From 7e7348bb9e364d44c1206b399b0acaa66e8d5a7b Mon Sep 17 00:00:00 2001 From: Eason Liang Date: Sun, 21 Jun 2026 03:41:10 +0800 Subject: [PATCH 1/2] feat: add abort singal core plumbing --- src/api/index.ts | 6 ++ src/core/task/Task.ts | 1 + src/core/task/__tests__/Task.spec.ts | 139 +++++++++++++++++++++++++++ 3 files changed, 146 insertions(+) diff --git a/src/api/index.ts b/src/api/index.ts index e52b41200..0c901f8e2 100644 --- a/src/api/index.ts +++ b/src/api/index.ts @@ -90,6 +90,12 @@ export interface ApiHandlerCreateMessageMetadata { * Only applies to providers that support function calling restrictions (e.g., Gemini). */ allowedFunctionNames?: string[] + /** + * Abort signal for cancelling the HTTP request mid-stream. + * Passed through to AI SDK's streamText() so the underlying HTTP request is aborted + * when the user clicks stop, preventing wasted API tokens/compute on the provider side. + */ + abortSignal?: AbortSignal } export interface ApiHandler { diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts index 2f1b370b4..53b5768e9 100644 --- a/src/core/task/Task.ts +++ b/src/core/task/Task.ts @@ -4161,6 +4161,7 @@ export class Task extends EventEmitter implements TaskLike { // Create an AbortController to allow cancelling the request mid-stream this.currentRequestAbortController = new AbortController() const abortSignal = this.currentRequestAbortController.signal + metadata.abortSignal = abortSignal // Reset the flag after using it this.skipPrevResponseIdOnce = false diff --git a/src/core/task/__tests__/Task.spec.ts b/src/core/task/__tests__/Task.spec.ts index dd6313580..e84b023de 100644 --- a/src/core/task/__tests__/Task.spec.ts +++ b/src/core/task/__tests__/Task.spec.ts @@ -1795,6 +1795,145 @@ describe("Cline", () => { // Verify cancelCurrentRequest was called expect(cancelSpy).toHaveBeenCalled() }) + describe("abortSignal", () => { + it("should pass AbortController signal to createMessage metadata", async () => { + const task = new Task({ + provider: mockProvider, + apiConfiguration: mockApiConfig, + task: "test task", + startTask: false, + }) + + // Mock required methods for attemptApiRequest to work without hanging + vi.spyOn(task as any, "getSystemPrompt").mockResolvedValue("mock system prompt") + + vi.spyOn(task.api, "getModel").mockReturnValue({ + id: mockApiConfig.apiModelId!, + info: { + supportsImages: false, + supportsPromptCache: true, + contextWindow: 200000, + maxTokens: 4096, + inputPrice: 0.3, + outputPrice: 1.5, + } as ModelInfo, + }) + + const providerState = await mockProvider.getState() + vi.spyOn(mockProvider, "getState").mockResolvedValue({ + ...providerState, + apiConfiguration: mockApiConfig, + autoApprovalEnabled: true, + requestDelaySeconds: 0, + }) + + // Mock the API stream response + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { type: "text", text: "response" } + }, + async next() { + return { done: true, value: { type: "text", text: "response" } } + }, + async return() { + return { done: true, value: undefined } + }, + async throw(e: any) { + throw e + }, + [Symbol.asyncDispose]: async () => {}, + } as AsyncGenerator + + const createMessageSpy = vi.spyOn(task.api, "createMessage").mockReturnValue(mockStream) + + task.apiConversationHistory = [ + { + role: "user" as const, + content: [{ type: "text" as const, text: "test message" }], + ts: Date.now(), + }, + ] as any + + const iterator = task.attemptApiRequest(0) + await iterator.next() + + // Verify createMessage was called with metadata containing abortSignal + expect(createMessageSpy).toHaveBeenCalled() + const [, , metadata] = createMessageSpy.mock.calls[0]! + + expect(metadata).toBeDefined() + expect(metadata!.abortSignal).toBeInstanceOf(AbortSignal) + }) + + it("should use the same AbortController signal as currentRequestAbortController", async () => { + const task = new Task({ + provider: mockProvider, + apiConfiguration: mockApiConfig, + task: "test task", + startTask: false, + }) + + // Mock required methods for attemptApiRequest to work without hanging + vi.spyOn(task as any, "getSystemPrompt").mockResolvedValue("mock system prompt") + + vi.spyOn(task.api, "getModel").mockReturnValue({ + id: mockApiConfig.apiModelId!, + info: { + supportsImages: false, + supportsPromptCache: true, + contextWindow: 200000, + maxTokens: 4096, + inputPrice: 0.3, + outputPrice: 1.5, + } as ModelInfo, + }) + + const providerState = await mockProvider.getState() + vi.spyOn(mockProvider, "getState").mockResolvedValue({ + ...providerState, + apiConfiguration: mockApiConfig, + autoApprovalEnabled: true, + requestDelaySeconds: 0, + }) + + // Mock the API stream response + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { type: "text", text: "response" } + }, + async next() { + return { done: true, value: { type: "text", text: "response" } } + }, + async return() { + return { done: true, value: undefined } + }, + async throw(e: any) { + throw e + }, + [Symbol.asyncDispose]: async () => {}, + } as AsyncGenerator + + const createMessageSpy = vi.spyOn(task.api, "createMessage").mockReturnValue(mockStream) + + task.apiConversationHistory = [ + { + role: "user" as const, + content: [{ type: "text" as const, text: "test message" }], + ts: Date.now(), + }, + ] as any + + const iterator = task.attemptApiRequest(0) + await iterator.next() + + // Get the signal from metadata + const [, , metadata] = createMessageSpy.mock.calls[0]! + const metadataSignal = metadata!.abortSignal + + // The signal in metadata should be the same as the one from currentRequestAbortController + expect(metadataSignal).toBe(task.currentRequestAbortController!.signal) + }) + }) }) }) From 589b8872ca10cc5e3a348ec7f500bf8bacbc4ec0 Mon Sep 17 00:00:00 2001 From: Eason Liang Date: Sun, 21 Jun 2026 08:37:42 +0800 Subject: [PATCH 2/2] feat: add abort singal for pass-through providers --- .../__tests__/anthropic-vertex.spec.ts | 176 +++++++++++++++++- src/api/providers/__tests__/anthropic.spec.ts | 74 ++++++++ .../base-openai-compatible-provider.spec.ts | 88 +++++++++ src/api/providers/__tests__/deepseek.spec.ts | 148 +++++++++++++++ src/api/providers/__tests__/gemini.spec.ts | 102 ++++++++++ src/api/providers/__tests__/lite-llm.spec.ts | 160 ++++++++++++++++ src/api/providers/__tests__/lmstudio.spec.ts | 79 ++++++++ src/api/providers/__tests__/mimo.spec.ts | 92 +++++++++ src/api/providers/__tests__/minimax.spec.ts | 116 ++++++++++-- src/api/providers/__tests__/moonshot.spec.ts | 84 +++++++++ src/api/providers/__tests__/openai.spec.ts | 98 ++++++++++ .../providers/__tests__/openrouter.spec.ts | 100 ++++++++++ src/api/providers/__tests__/poe.spec.ts | 79 ++++++++ .../__tests__/qwen-code-native-tools.spec.ts | 93 +++++++++ src/api/providers/__tests__/requesty.spec.ts | 91 +++++++++ src/api/providers/__tests__/unbound.spec.ts | 114 ++++++++++++ src/api/providers/__tests__/xai.spec.ts | 84 +++++++++ src/api/providers/__tests__/zai.spec.ts | 102 ++++++++++ .../providers/__tests__/zoo-gateway.spec.ts | 88 +++++++++ src/api/providers/anthropic-vertex.ts | 9 +- src/api/providers/anthropic.ts | 9 +- .../base-openai-compatible-provider.ts | 7 +- src/api/providers/deepseek.ts | 5 +- src/api/providers/gemini.ts | 1 + src/api/providers/lite-llm.ts | 5 +- src/api/providers/lm-studio.ts | 20 +- src/api/providers/mimo.ts | 5 +- src/api/providers/minimax.ts | 5 +- src/api/providers/openai-compatible.ts | 1 + src/api/providers/openai.ts | 16 +- src/api/providers/openrouter.ts | 13 +- src/api/providers/poe.ts | 1 + src/api/providers/qwen-code.ts | 7 +- src/api/providers/requesty.ts | 6 +- src/api/providers/unbound.ts | 5 +- src/api/providers/xai.ts | 13 +- src/api/providers/zai.ts | 1 + src/api/providers/zoo-gateway.ts | 1 + 38 files changed, 2040 insertions(+), 58 deletions(-) diff --git a/src/api/providers/__tests__/anthropic-vertex.spec.ts b/src/api/providers/__tests__/anthropic-vertex.spec.ts index 2121e8695..55416e1f3 100644 --- a/src/api/providers/__tests__/anthropic-vertex.spec.ts +++ b/src/api/providers/__tests__/anthropic-vertex.spec.ts @@ -263,10 +263,13 @@ describe("VertexHandler", () => { ], stream: true, // Tools are now always present (minimum 6 from ALWAYS_AVAILABLE_TOOLS) - tools: expect.any(Array), - tool_choice: expect.any(Object), + tools: [], + tool_choice: { + disable_parallel_tool_use: false, + type: "auto", + }, }), - undefined, + {}, ) }) @@ -481,7 +484,7 @@ describe("VertexHandler", () => { }), ], }), - undefined, + {}, ) }) @@ -1156,7 +1159,7 @@ describe("VertexHandler", () => { } // Verify the API was called without the beta header - expect(mockCreate).toHaveBeenCalledWith(expect.anything(), undefined) + expect(mockCreate).toHaveBeenCalledWith(expect.anything(), {}) }) }) @@ -1246,7 +1249,7 @@ describe("VertexHandler", () => { thinking: { type: "enabled", budget_tokens: 4096 }, temperature: 1.0, // Thinking requires temperature 1.0 }), - undefined, + {}, ) }) @@ -1273,7 +1276,7 @@ describe("VertexHandler", () => { expect.objectContaining({ thinking: { type: "adaptive" }, }), - undefined, + {}, ) const request = mockCreate.mock.calls[0][0] @@ -1302,7 +1305,7 @@ describe("VertexHandler", () => { expect.objectContaining({ thinking: { type: "adaptive" }, }), - undefined, + {}, ) const request = mockCreate.mock.calls[0][0] @@ -1393,7 +1396,7 @@ describe("VertexHandler", () => { ]), tool_choice: { type: "auto", disable_parallel_tool_use: false }, }), - undefined, + {}, ) }) @@ -1446,7 +1449,7 @@ describe("VertexHandler", () => { }), ]), }), - undefined, + {}, ) }) @@ -1611,4 +1614,157 @@ describe("VertexHandler", () => { }) }) }) + + describe("abort signal", () => { + it("should handle abort signal triggered during request", async () => { + const controller = new AbortController() + const handler = new AnthropicVertexHandler({ + apiModelId: "claude-3-sonnet", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + const mockStream = async function* () { + await new Promise((resolve) => setTimeout(resolve, 10)) + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { + type: "message_start", + message: { usage: { input_tokens: 10, output_tokens: 0 } }, + } + } + + ;(handler["client"].messages as any).create = vitest.fn().mockResolvedValue(mockStream()) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }], { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should not pass signal when abortSignal is undefined", async () => { + const handler = new AnthropicVertexHandler({ + apiModelId: "claude-3-sonnet", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + const mockStream = async function* () { + yield { + type: "message_start", + message: { usage: { input_tokens: 10, output_tokens: 5 } }, + } + yield { + type: "content_block_start", + content_block: { type: "text", text: "" }, + } + yield { + type: "content_block_delta", + delta: { type: "text_delta", text: "response" }, + } + } + + ;(handler["client"].messages as any).create = vitest.fn().mockResolvedValue(mockStream()) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }]) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should abort immediately if signal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + const testHandler = new AnthropicVertexHandler({ + apiModelId: "claude-3-sonnet", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + testHandler["client"].messages.create = vitest.fn().mockImplementation(async (options, requestOptions) => { + // Verify that the signal was passed and is already aborted + expect(requestOptions).toHaveProperty("signal", controller.signal) + expect(controller.signal.aborted).toBe(true) + + return { + [Symbol.asyncIterator]: async function* () { + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { + type: "message_start", + message: { usage: { input_tokens: 10, output_tokens: 5 } }, + } + }, + } + }) + + const stream = testHandler.createMessage("system", [{ role: "user", content: "Hello" }], { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should pass signal when provided", async () => { + const controller = new AbortController() + let capturedRequestOptions: any + + const testHandler = new AnthropicVertexHandler({ + apiModelId: "claude-3-sonnet", + vertexProjectId: "test-project", + vertexRegion: "us-central1", + }) + + testHandler["client"].messages.create = vitest.fn().mockImplementation(async (options, requestOptions) => { + capturedRequestOptions = requestOptions + return { + [Symbol.asyncIterator]: async function* () { + yield { + type: "message_start", + message: { usage: { input_tokens: 10, output_tokens: 5 } }, + } + yield { + type: "content_block_delta", + delta: { type: "text_delta", text: "response" }, + } + }, + } + }) + + const stream = testHandler.createMessage("system", [{ role: "user", content: "Hello" }], { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + expect(capturedRequestOptions).toHaveProperty("signal", controller.signal) + }) + }) }) diff --git a/src/api/providers/__tests__/anthropic.spec.ts b/src/api/providers/__tests__/anthropic.spec.ts index a2c0cb88e..6a1ca4b19 100644 --- a/src/api/providers/__tests__/anthropic.spec.ts +++ b/src/api/providers/__tests__/anthropic.spec.ts @@ -1057,4 +1057,78 @@ describe("AnthropicHandler", () => { }) }) }) + + describe("abort signal", () => { + it("should pass abortSignal to the SDK options", async () => { + const controller = new AbortController() + + mockCreate.mockImplementation(async (options, requestOptions) => { + // Verify that the signal was passed + expect(requestOptions).toHaveProperty("signal", controller.signal) + return { + async *[Symbol.asyncIterator]() { + yield { + type: "message_start", + message: { usage: { input_tokens: 10, output_tokens: 5 } }, + } + yield { + type: "content_block_delta", + delta: { type: "text_delta", text: "response" }, + } + }, + } + }) + + const handler = new AnthropicHandler(mockOptions) + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }], { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should work normally without abortSignal", async () => { + const handler = new AnthropicHandler(mockOptions) + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }]) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should not pass signal when abortSignal is undefined", async () => { + mockCreate.mockImplementation(async (options, requestOptions) => { + // When no abortSignal is provided, requestOptions should be undefined or not have signal + expect(requestOptions).toBeUndefined() + return { + async *[Symbol.asyncIterator]() { + yield { + type: "message_start", + message: { usage: { input_tokens: 10, output_tokens: 5 } }, + } + }, + } + }) + + const handler = new AnthropicHandler(mockOptions) + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }]) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + }) }) diff --git a/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts b/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts index 847aa6e4d..e4a98f19e 100644 --- a/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts +++ b/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts @@ -551,4 +551,92 @@ describe("BaseOpenAiCompatibleProvider", () => { expect(endChunks).toHaveLength(0) }) }) + + describe("abort signal", () => { + it("should handle abort signal triggered during request", async () => { + const controller = new AbortController() + handler = new TestOpenAiCompatibleProvider("test-api-key") + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + while (!controller.signal.aborted) { + await new Promise((resolve) => setTimeout(resolve, 10)) + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { + choices: [{ delta: { content: "response" } }], + usage: null, + } + } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + setTimeout(() => controller.abort(), 50) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should work normally without abortSignal", async () => { + handler = new TestOpenAiCompatibleProvider("test-api-key") + + mockCreate.mockResolvedValue({ + async *[Symbol.asyncIterator]() { + yield { choices: [{ delta: { content: "Hello" } }], usage: null } + yield { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 5 } } + }, + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should abort immediately if signal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + handler = new TestOpenAiCompatibleProvider("test-api-key") + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { choices: [{ delta: { content: "response" } }], usage: null } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + }) }) diff --git a/src/api/providers/__tests__/deepseek.spec.ts b/src/api/providers/__tests__/deepseek.spec.ts index 2f0482eee..a0d230818 100644 --- a/src/api/providers/__tests__/deepseek.spec.ts +++ b/src/api/providers/__tests__/deepseek.spec.ts @@ -710,4 +710,152 @@ describe("DeepSeekHandler", () => { expect(toolCallChunks[0].name).toBe("get_weather") }) }) + describe("abort signal", () => { + it("should handle abort signal triggered during request", async () => { + const controller = new AbortController() + const testHandler = new DeepSeekHandler(mockOptions) + + mockCreate.mockImplementation(async () => { + return { + async *[Symbol.asyncIterator]() { + while (!controller.signal.aborted) { + await new Promise((resolve) => setTimeout(resolve, 10)) + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { + choices: [{ delta: { content: "response" }, index: 0 }], + usage: null, + } + } + }, + } + }) + + const stream = testHandler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + setTimeout(() => controller.abort(), 50) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should work normally without abortSignal", async () => { + const testHandler = new DeepSeekHandler(mockOptions) + + mockCreate.mockResolvedValue({ + async *[Symbol.asyncIterator]() { + yield { choices: [{ delta: { content: "Hello" }, index: 0 }] } + yield { + choices: [{ delta: {}, index: 0 }], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, + } + }, + }) + + const stream = testHandler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should abort immediately if signal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + const testHandler = new DeepSeekHandler(mockOptions) + + mockCreate.mockImplementation(async () => { + return { + async *[Symbol.asyncIterator]() { + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { choices: [{ delta: { content: "response" }, index: 0 }] } + }, + } + }) + + const stream = testHandler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should pass abort signal to OpenAI client create call", async () => { + const controller = new AbortController() + const testHandler = new DeepSeekHandler(mockOptions) + + let receivedConfig: any + mockCreate.mockImplementation(async (options: unknown, config?: unknown) => { + receivedConfig = config + return { + async *[Symbol.asyncIterator]() { + yield { choices: [{ delta: { content: "response" }, index: 0 }] } + yield { + choices: [{ delta: {}, index: 0 }], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + } + }) + + const stream = testHandler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await stream.next() + + expect(mockCreate).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ + signal: controller.signal, + }), + ) + }) + + it("should not pass signal when no abortSignal is provided", async () => { + const testHandler = new DeepSeekHandler(mockOptions) + + let receivedConfig: any + mockCreate.mockImplementation(async (options: unknown, config?: unknown) => { + receivedConfig = config + return { + async *[Symbol.asyncIterator]() { + yield { choices: [{ delta: { content: "response" }, index: 0 }] } + yield { + choices: [{ delta: {}, index: 0 }], + usage: { prompt_tokens: 1, completion_tokens: 1, total_tokens: 2 }, + } + }, + } + }) + + const stream = testHandler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + await stream.next() + + expect(mockCreate).toHaveBeenCalledWith(expect.anything(), {}) + }) + }) }) diff --git a/src/api/providers/__tests__/gemini.spec.ts b/src/api/providers/__tests__/gemini.spec.ts index e2633474a..ef0c08e33 100644 --- a/src/api/providers/__tests__/gemini.spec.ts +++ b/src/api/providers/__tests__/gemini.spec.ts @@ -366,4 +366,106 @@ describe("GeminiHandler", () => { expect(mockCaptureException).toHaveBeenCalled() }) }) + + describe("abort signal", () => { + it("should pass abortSignal to the SDK options", async () => { + const controller = new AbortController() + let capturedConfig: any + + handler["client"].models.generateContentStream = vitest.fn().mockImplementation(async (params) => { + capturedConfig = params?.config + return { + [Symbol.asyncIterator]: async function* () { + yield { text: "Hello world!" } + yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + expect(capturedConfig).toHaveProperty("signal", controller.signal) + }) + + it("should work normally without abortSignal", async () => { + handler["client"].models.generateContentStream = vitest.fn().mockResolvedValue({ + [Symbol.asyncIterator]: async function* () { + yield { text: "Hello world!" } + yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } } + }, + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should not pass signal when abortSignal is undefined", async () => { + let capturedConfig: any + + handler["client"].models.generateContentStream = vitest.fn().mockImplementation(async (params) => { + capturedConfig = params?.config + return { + [Symbol.asyncIterator]: async function* () { + yield { text: "Hello world!" } + yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + expect(capturedConfig).not.toHaveProperty("signal") + }) + + it("should pass signal when provided", async () => { + const controller = new AbortController() + let capturedConfig: any + + handler["client"].models.generateContentStream = vitest.fn().mockImplementation(async (params) => { + capturedConfig = params?.config + return { + [Symbol.asyncIterator]: async function* () { + yield { text: "Hello world!" } + yield { usageMetadata: { promptTokenCount: 10, candidatesTokenCount: 5 } } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + expect(capturedConfig).toHaveProperty("signal", controller.signal) + }) + }) }) diff --git a/src/api/providers/__tests__/lite-llm.spec.ts b/src/api/providers/__tests__/lite-llm.spec.ts index ab2f26105..4ab966269 100644 --- a/src/api/providers/__tests__/lite-llm.spec.ts +++ b/src/api/providers/__tests__/lite-llm.spec.ts @@ -1180,4 +1180,164 @@ describe("LiteLLMHandler", () => { expect(requestHeaders).not.toHaveProperty("X-Zoo-Session-ID") }) }) + + describe("abort signal", () => { + it("should handle abort signal triggered during request", async () => { + const controller = new AbortController() + const testHandler = new LiteLLMHandler(mockOptions) + + const mockStream = { + async *[Symbol.asyncIterator]() { + while (!controller.signal.aborted) { + await new Promise((resolve) => setTimeout(resolve, 10)) + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { + choices: [{ delta: { content: "response" } }], + usage: null, + } + } + }, + } + + mockCreate.mockReturnValue({ + withResponse: vi.fn().mockResolvedValue({ data: mockStream }), + }) + + const stream = testHandler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + setTimeout(() => controller.abort(), 50) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should work normally without abortSignal", async () => { + const testHandler = new LiteLLMHandler(mockOptions) + + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { choices: [{ delta: { content: "Hello" } }] } + yield { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 5 } } + }, + } + + mockCreate.mockReturnValue({ + withResponse: vi.fn().mockResolvedValue({ data: mockStream }), + }) + + const stream = testHandler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should abort immediately if signal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + const testHandler = new LiteLLMHandler(mockOptions) + + const mockStream = { + async *[Symbol.asyncIterator]() { + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { choices: [{ delta: { content: "response" } }] } + }, + } + + mockCreate.mockReturnValue({ + withResponse: vi.fn().mockResolvedValue({ data: mockStream }), + }) + + const stream = testHandler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should pass abort signal to OpenAI client create call", async () => { + const controller = new AbortController() + const testHandler = new LiteLLMHandler(mockOptions) + + let receivedConfig: any + mockCreate.mockImplementation((options: unknown, config?: unknown) => { + receivedConfig = config + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { choices: [{ delta: { content: "response" } }] } + yield { + choices: [{ delta: {} }], + usage: { prompt_tokens: 1, completion_tokens: 1 }, + } + }, + } + return { + withResponse: vi.fn().mockResolvedValue({ data: mockStream }), + } + }) + + const stream = testHandler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await stream.next() + + expect(mockCreate).toHaveBeenCalledWith( + expect.anything(), + expect.objectContaining({ + signal: controller.signal, + }), + ) + }) + + it("should not pass signal when no abortSignal is provided", async () => { + const testHandler = new LiteLLMHandler(mockOptions) + + let receivedConfig: any + mockCreate.mockImplementation((options: unknown, config?: unknown) => { + receivedConfig = config + const mockStream = { + async *[Symbol.asyncIterator]() { + yield { choices: [{ delta: { content: "response" } }] } + yield { + choices: [{ delta: {} }], + usage: { prompt_tokens: 1, completion_tokens: 1 }, + } + }, + } + return { + withResponse: vi.fn().mockResolvedValue({ data: mockStream }), + } + }) + + const stream = testHandler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + await stream.next() + + expect(receivedConfig).not.toHaveProperty("signal") + }) + }) }) diff --git a/src/api/providers/__tests__/lmstudio.spec.ts b/src/api/providers/__tests__/lmstudio.spec.ts index c6ebd8a6e..7fd2125cf 100644 --- a/src/api/providers/__tests__/lmstudio.spec.ts +++ b/src/api/providers/__tests__/lmstudio.spec.ts @@ -157,6 +157,85 @@ describe("LmStudioHandler", () => { }) }) + describe("abort signal", () => { + it("should handle abort signal triggered during request", async () => { + const controller = new AbortController() + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + while (!controller.signal.aborted) { + await new Promise((resolve) => setTimeout(resolve, 10)) + if (controller.signal.aborted) { + throw new DOMException("The operation was aborted.", "AbortError") + } + yield { + choices: [{ delta: { content: "response" }, index: 0 }], + usage: null, + } + } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + setTimeout(() => controller.abort(), 50) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should work normally without abortSignal", async () => { + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should abort immediately if signal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + if (controller.signal.aborted) { + throw new DOMException("The operation was aborted.", "AbortError") + } + yield { + choices: [{ delta: { content: "response" }, index: 0 }], + usage: null, + } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + }) + describe("getModel", () => { it("should return model info", () => { const modelInfo = handler.getModel() diff --git a/src/api/providers/__tests__/mimo.spec.ts b/src/api/providers/__tests__/mimo.spec.ts index 7da1c8446..f0d8c97f0 100644 --- a/src/api/providers/__tests__/mimo.spec.ts +++ b/src/api/providers/__tests__/mimo.spec.ts @@ -376,6 +376,7 @@ describe("MimoHandler", () => { expect.objectContaining({ extra_body: { thinking: { type: "enabled" } }, }), + undefined, ) }) @@ -999,4 +1000,95 @@ describe("MimoHandler", () => { expect(params.model).toBe("mimo-v2.5") }) }) + + describe("abort signal", () => { + it("should handle abort signal triggered during request", async () => { + const controller = new AbortController() + const handler = new MimoHandler(mockOptions) + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + while (!controller.signal.aborted) { + await new Promise((resolve) => setTimeout(resolve, 10)) + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { + choices: [{ delta: { content: "response" }, index: 0 }], + usage: null, + } + } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + setTimeout(() => controller.abort(), 50) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should work normally without abortSignal", async () => { + const handler = new MimoHandler(mockOptions) + + mockCreate.mockResolvedValue({ + async *[Symbol.asyncIterator]() { + yield { choices: [{ delta: { content: "Hello" }, index: 0 }], usage: null } + yield { + choices: [{ delta: {}, index: 0, finish_reason: "stop" }], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, + } + }, + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should abort immediately if signal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + const handler = new MimoHandler(mockOptions) + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { choices: [{ delta: { content: "response" }, index: 0 }], usage: null } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + }) }) diff --git a/src/api/providers/__tests__/minimax.spec.ts b/src/api/providers/__tests__/minimax.spec.ts index d87ae1190..448a33801 100644 --- a/src/api/providers/__tests__/minimax.spec.ts +++ b/src/api/providers/__tests__/minimax.spec.ts @@ -296,16 +296,17 @@ describe("MiniMaxHandler", () => { const messageGenerator = handlerWithModel.createMessage(systemPrompt, messages) await messageGenerator.next() - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - model: modelId, - max_tokens: Math.min(modelInfo.maxTokens, Math.ceil(modelInfo.contextWindow * 0.2)), - temperature: 1, - system: expect.any(Array), - messages: expect.any(Array), - stream: true, - }), + // Verify the first argument contains the expected request params + const firstCallArgs = mockCreate.mock.calls[0] + const requestParams = firstCallArgs[0] + expect(requestParams.model).toBe(modelId) + expect(requestParams.stream).toBe(true) + expect(requestParams.max_tokens).toBe( + Math.min(modelInfo.maxTokens, Math.ceil(modelInfo.contextWindow * 0.2)), ) + expect(requestParams.temperature).toBe(1) + expect(requestParams.system).toBeInstanceOf(Array) + expect(requestParams.messages).toBeInstanceOf(Array) }) it("should use temperature 1 by default", async () => { @@ -320,11 +321,10 @@ describe("MiniMaxHandler", () => { const messageGenerator = handler.createMessage("test", []) await messageGenerator.next() - expect(mockCreate).toHaveBeenCalledWith( - expect.objectContaining({ - temperature: 1, - }), - ) + // Verify the first argument contains the expected request params + const firstCallArgs = mockCreate.mock.calls[0] + const requestParams = firstCallArgs[0] + expect(requestParams.temperature).toBe(1) }) it("should handle thinking blocks in stream", async () => { @@ -480,4 +480,92 @@ describe("MiniMaxHandler", () => { expect(model.contextWindow).toBe(204_800) }) }) + + describe("abort signal", () => { + it("should handle abort signal triggered during request", async () => { + const controller = new AbortController() + handler = new MiniMaxHandler({ minimaxApiKey: "test-key" }) + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + while (!controller.signal.aborted) { + await new Promise((resolve) => setTimeout(resolve, 10)) + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { + type: "content_block_start", + content_block: { type: "text", text: "" }, + } + } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + setTimeout(() => controller.abort(), 50) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should work normally without abortSignal", async () => { + handler = new MiniMaxHandler({ minimaxApiKey: "test-key" }) + + mockCreate.mockResolvedValue({ + async *[Symbol.asyncIterator]() { + yield { type: "message_start", message: { usage: { input_tokens: 10, output_tokens: 5 } } } + yield { type: "content_block_delta", delta: { type: "text_delta", text: "response" } } + }, + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should abort immediately if signal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + handler = new MiniMaxHandler({ minimaxApiKey: "test-key" }) + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { type: "content_block_delta", delta: { type: "text_delta", text: "response" } } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + }) }) diff --git a/src/api/providers/__tests__/moonshot.spec.ts b/src/api/providers/__tests__/moonshot.spec.ts index c0fd832a1..c04c061b2 100644 --- a/src/api/providers/__tests__/moonshot.spec.ts +++ b/src/api/providers/__tests__/moonshot.spec.ts @@ -471,4 +471,88 @@ describe("MoonshotHandler", () => { expect(toolCallChunks[0].arguments).toBe('{"path":"test.ts"}') }) }) + + describe("abort signal", () => { + it("should handle abort signal triggered during request", async () => { + const controller = new AbortController() + + async function* mockFullStream() { + while (!controller.signal.aborted) { + await new Promise((resolve) => setTimeout(resolve, 10)) + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { type: "text-delta", text: "response" } + } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 1, outputTokens: 1, details: {}, raw: {} }), + }) + + const stream = handler.createMessage("system", [{ role: "user" as const, content: "Hello" }], { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + setTimeout(() => controller.abort(), 50) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should work normally without abortSignal", async () => { + async function* mockFullStream() { + yield { type: "text-delta", text: "Hello" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 10, outputTokens: 5, details: {}, raw: {} }), + }) + + const stream = handler.createMessage("system", [{ role: "user" as const, content: "Hello" }]) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should abort immediately if signal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + async function* mockFullStream() { + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { type: "text-delta", text: "response" } + } + + mockStreamText.mockReturnValue({ + fullStream: mockFullStream(), + usage: Promise.resolve({ inputTokens: 1, outputTokens: 1, details: {}, raw: {} }), + }) + + const stream = handler.createMessage("system", [{ role: "user" as const, content: "Hello" }], { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + }) }) diff --git a/src/api/providers/__tests__/openai.spec.ts b/src/api/providers/__tests__/openai.spec.ts index 3c006f831..b256745b0 100644 --- a/src/api/providers/__tests__/openai.spec.ts +++ b/src/api/providers/__tests__/openai.spec.ts @@ -1279,6 +1279,104 @@ describe("OpenAiHandler", () => { }) }) }) +describe("abort signal", () => { + let mockOptions: ApiHandlerOptions + + beforeEach(() => { + mockOptions = { + openAiApiKey: "test-api-key", + openAiModelId: "gpt-4", + openAiBaseUrl: "https://api.openai.com/v1", + } + vi.clearAllMocks() + }) + + it("should handle abort signal triggered during request", async () => { + const controller = new AbortController() + const handler = new OpenAiHandler(mockOptions) + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + while (!controller.signal.aborted) { + await new Promise((resolve) => setTimeout(resolve, 10)) + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { + choices: [{ delta: { content: "response" } }], + usage: null, + } + } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + setTimeout(() => controller.abort(), 50) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should work normally without abortSignal", async () => { + const handler = new OpenAiHandler(mockOptions) + + mockCreate.mockResolvedValue({ + async *[Symbol.asyncIterator]() { + yield { choices: [{ delta: { content: "Hello" } }], usage: null } + yield { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 5 } } + }, + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should abort immediately if signal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + const handler = new OpenAiHandler(mockOptions) + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { choices: [{ delta: { content: "response" } }], usage: null } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) +}) describe("getOpenAiModels", () => { beforeEach(() => { diff --git a/src/api/providers/__tests__/openrouter.spec.ts b/src/api/providers/__tests__/openrouter.spec.ts index b21d409d0..c29ad7939 100644 --- a/src/api/providers/__tests__/openrouter.spec.ts +++ b/src/api/providers/__tests__/openrouter.spec.ts @@ -715,4 +715,104 @@ describe("OpenRouterHandler", () => { ) }) }) + + describe("abort signal", () => { + it("should handle abort signal triggered during request", async () => { + const controller = new AbortController() + const handler = new OpenRouterHandler(mockOptions) + + const mockCreate = vitest.fn().mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + while (!controller.signal.aborted) { + await new Promise((resolve) => setTimeout(resolve, 10)) + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { + choices: [{ delta: { content: "response" } }], + usage: null, + } + } + }, + } + }) + + ;(OpenAI as any).prototype.chat = { + completions: { create: mockCreate }, + } as any + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + setTimeout(() => controller.abort(), 50) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should work normally without abortSignal", async () => { + const handler = new OpenRouterHandler(mockOptions) + + const mockCreate = vitest.fn().mockResolvedValue({ + async *[Symbol.asyncIterator]() { + yield { choices: [{ delta: { content: "Hello" } }], usage: null } + yield { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 5 } } + }, + }) + + ;(OpenAI as any).prototype.chat = { + completions: { create: mockCreate }, + } as any + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should abort immediately if signal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + const handler = new OpenRouterHandler(mockOptions) + + const mockCreate = vitest.fn().mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { choices: [{ delta: { content: "response" } }], usage: null } + }, + } + }) + + ;(OpenAI as any).prototype.chat = { + completions: { create: mockCreate }, + } as any + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + }) }) diff --git a/src/api/providers/__tests__/poe.spec.ts b/src/api/providers/__tests__/poe.spec.ts index b22d42179..9fae57991 100644 --- a/src/api/providers/__tests__/poe.spec.ts +++ b/src/api/providers/__tests__/poe.spec.ts @@ -310,4 +310,83 @@ describe("PoeHandler", () => { ) }) }) + + describe("abort signal", () => { + it("should handle abort signal triggered during request", async () => { + const controller = new AbortController() + const handler = new PoeHandler({ poeApiKey: "test-key" }) + + mockStreamText.mockReturnValue({ + [Symbol.asyncIterator]: async function* () { + await new Promise((resolve) => setTimeout(resolve, 10)) + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { type: "text", text: "response" } + }, + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + setTimeout(() => controller.abort(), 50) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should work normally without abortSignal", async () => { + const handler = new PoeHandler({ poeApiKey: "test-key" }) + + mockStreamText.mockReturnValue({ + [Symbol.asyncIterator]: async function* () { + yield { type: "text", text: "Hello" } + yield { type: "text", text: " world!" } + }, + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should abort immediately if signal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + const handler = new PoeHandler({ poeApiKey: "test-key" }) + + mockStreamText.mockReturnValue({ + [Symbol.asyncIterator]: async function* () { + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { type: "text", text: "response" } + }, + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + }) }) diff --git a/src/api/providers/__tests__/qwen-code-native-tools.spec.ts b/src/api/providers/__tests__/qwen-code-native-tools.spec.ts index 3615c0f92..42bef3abf 100644 --- a/src/api/providers/__tests__/qwen-code-native-tools.spec.ts +++ b/src/api/providers/__tests__/qwen-code-native-tools.spec.ts @@ -103,6 +103,7 @@ describe("QwenCodeHandler Native Tools", () => { ]), parallel_tool_calls: true, }), + undefined, ) }) @@ -126,6 +127,7 @@ describe("QwenCodeHandler Native Tools", () => { expect.objectContaining({ tool_choice: "auto", }), + undefined, ) }) @@ -237,6 +239,7 @@ describe("QwenCodeHandler Native Tools", () => { expect.objectContaining({ parallel_tool_calls: true, }), + undefined, ) }) @@ -444,4 +447,94 @@ describe("QwenCodeHandler Native Tools", () => { expect(endChunks).toHaveLength(1) }) }) + + describe("abort signal", () => { + it("should handle abort signal triggered during request", async () => { + const controller = new AbortController() + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + while (!controller.signal.aborted) { + await new Promise((resolve) => setTimeout(resolve, 10)) + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { choices: [{ delta: { content: "response" } }], usage: null } + } + }, + } + }) + + const handler = new QwenCodeHandler({ + apiKey: "test-key", + } as any) + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + setTimeout(() => controller.abort(), 50) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should work normally without abortSignal", async () => { + const handler = new QwenCodeHandler({ + apiKey: "test-key", + } as any) + + mockCreate.mockResolvedValue({ + async *[Symbol.asyncIterator]() { + yield { choices: [{ delta: { content: "Hello" } }], usage: null } + yield { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 5 } } + }, + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should abort immediately if signal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { choices: [{ delta: { content: "response" } }], usage: null } + }, + } + }) + + const handler = new QwenCodeHandler({ + apiKey: "test-key", + } as any) + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + }) }) diff --git a/src/api/providers/__tests__/requesty.spec.ts b/src/api/providers/__tests__/requesty.spec.ts index 4dfa2a7c9..d395ace4b 100644 --- a/src/api/providers/__tests__/requesty.spec.ts +++ b/src/api/providers/__tests__/requesty.spec.ts @@ -218,6 +218,7 @@ describe("RequestyHandler", () => { stream_options: { include_usage: true }, temperature: 0, }), + undefined, ) }) @@ -251,6 +252,7 @@ describe("RequestyHandler", () => { thinking: { type: "adaptive" }, temperature: undefined, }), + undefined, ) }) @@ -400,6 +402,7 @@ describe("RequestyHandler", () => { ]), tool_choice: "auto", }), + undefined, ) }) @@ -538,4 +541,92 @@ describe("RequestyHandler", () => { await expect(handler.completePrompt("test prompt")).rejects.toThrow("Unexpected error") }) }) + + describe("abort signal", () => { + it("should handle abort signal triggered during request", async () => { + const controller = new AbortController() + const handler = new RequestyHandler(mockOptions) + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + while (!controller.signal.aborted) { + await new Promise((resolve) => setTimeout(resolve, 10)) + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { + choices: [{ delta: { content: "response" } }], + usage: null, + } + } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + setTimeout(() => controller.abort(), 50) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should work normally without abortSignal", async () => { + const handler = new RequestyHandler(mockOptions) + + mockCreate.mockResolvedValue({ + async *[Symbol.asyncIterator]() { + yield { choices: [{ delta: { content: "Hello" } }], usage: null } + yield { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 5 } } + }, + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should abort immediately if signal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + const handler = new RequestyHandler(mockOptions) + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { choices: [{ delta: { content: "response" } }], usage: null } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + }) }) diff --git a/src/api/providers/__tests__/unbound.spec.ts b/src/api/providers/__tests__/unbound.spec.ts index d8f75fe85..4873d081a 100644 --- a/src/api/providers/__tests__/unbound.spec.ts +++ b/src/api/providers/__tests__/unbound.spec.ts @@ -179,6 +179,7 @@ describe("UnboundHandler", () => { mode: "architect", }, }), + undefined, ) }) @@ -201,4 +202,117 @@ describe("UnboundHandler", () => { }), ) }) + + describe("abort signal", () => { + it("should handle abort signal triggered during request", async () => { + const controller = new AbortController() + const mockCreateLocal = vi.fn().mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + while (!controller.signal.aborted) { + yield { choices: [{ delta: { content: "response" } }], usage: null } + await new Promise((resolve) => setTimeout(resolve, 10)) + } + throw new Error("AbortError: The operation was aborted") + }, + } + }) + + vi.mocked(OpenAI).mockImplementation( + () => + ({ + chat: { completions: { create: mockCreateLocal } as any } as any, + }) as any, + ) + + const handler = new UnboundHandler({ + unboundApiKey: "test-key", + unboundModelId: "openai/gpt-4o", + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + setTimeout(() => controller.abort(), 50) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should work normally without abortSignal", async () => { + const mockCreateLocal = vi.fn().mockResolvedValue({ + async *[Symbol.asyncIterator]() { + yield { choices: [{ delta: { content: "Hello" } }], usage: null } + yield { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 5 } } + }, + }) + + vi.mocked(OpenAI).mockImplementation( + () => + ({ + chat: { completions: { create: mockCreateLocal } as any } as any, + }) as any, + ) + + const handler = new UnboundHandler({ + unboundApiKey: "test-key", + unboundModelId: "openai/gpt-4o", + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should abort immediately if signal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + const mockCreateLocal = vi.fn().mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { choices: [{ delta: { content: "response" } }], usage: null } + }, + } + }) + + vi.mocked(OpenAI).mockImplementation( + () => + ({ + chat: { completions: { create: mockCreateLocal } as any } as any, + }) as any, + ) + + const handler = new UnboundHandler({ + unboundApiKey: "test-key", + unboundModelId: "openai/gpt-4o", + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + }) }) diff --git a/src/api/providers/__tests__/xai.spec.ts b/src/api/providers/__tests__/xai.spec.ts index 78f091b8e..890f03028 100644 --- a/src/api/providers/__tests__/xai.spec.ts +++ b/src/api/providers/__tests__/xai.spec.ts @@ -106,6 +106,7 @@ describe("XAIHandler", () => { store: false, include: ["reasoning.encrypted_content"], }), + undefined, ) }) @@ -235,6 +236,7 @@ describe("XAIHandler", () => { tool_choice: "auto", parallel_tool_calls: true, }), + undefined, ) }) @@ -272,6 +274,7 @@ describe("XAIHandler", () => { reasoning_effort: "high", }), }), + undefined, ) }) @@ -297,4 +300,85 @@ describe("XAIHandler", () => { const stream = handler.createMessage("test prompt", []) await expect(stream.next()).rejects.toThrow(`xAI completion error: ${errorMessage}`) }) + + describe("abort signal", () => { + it("should handle abort signal triggered during request", async () => { + const controller = new AbortController() + + mockResponsesCreate.mockResolvedValueOnce({ + [Symbol.asyncIterator]: async function* () { + while (true) { + await new Promise((resolve) => setTimeout(resolve, 10)) + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { type: "response.output_text.delta", delta: "chunk" } + } + }, + }) + + handler = new XAIHandler({ xaiApiKey: "test-key" }) + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + setTimeout(() => controller.abort(), 50) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should work normally without abortSignal", async () => { + handler = new XAIHandler({ xaiApiKey: "test-key" }) + + mockResponsesCreate.mockResolvedValueOnce({ + [Symbol.asyncIterator]: async function* () { + yield { type: "response.output_text.delta", delta: "Hello" } + yield { type: "response.output_text.delta", delta: " world!" } + yield { type: "response.completed", response: { id: "test" }, usage: {} } + }, + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should abort immediately if signal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + mockResponsesCreate.mockResolvedValueOnce({ + [Symbol.asyncIterator]: async function* () { + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { type: "text", text: "response" } + }, + }) + + handler = new XAIHandler({ xaiApiKey: "test-key" }) + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + }) }) diff --git a/src/api/providers/__tests__/zai.spec.ts b/src/api/providers/__tests__/zai.spec.ts index 66266a2fe..c2a3558cd 100644 --- a/src/api/providers/__tests__/zai.spec.ts +++ b/src/api/providers/__tests__/zai.spec.ts @@ -544,6 +544,7 @@ describe("ZAiHandler", () => { model: "glm-5.1", max_tokens: 40_000, }), + undefined, ) }) @@ -584,6 +585,7 @@ describe("ZAiHandler", () => { model: "glm-5.1", max_tokens: 100_000, }), + undefined, ) }) @@ -614,6 +616,7 @@ describe("ZAiHandler", () => { model: "glm-4.7", thinking: { type: "enabled" }, }), + undefined, ) }) @@ -644,6 +647,7 @@ describe("ZAiHandler", () => { thinking: { type: "enabled" }, reasoning_effort: "high", }), + undefined, ) }) @@ -674,6 +678,7 @@ describe("ZAiHandler", () => { thinking: { type: "enabled" }, reasoning_effort: "max", }), + undefined, ) }) @@ -730,6 +735,7 @@ describe("ZAiHandler", () => { thinking: { type: "enabled" }, reasoning_effort: "high", }), + undefined, ) }) @@ -761,6 +767,7 @@ describe("ZAiHandler", () => { model: "glm-4.7", thinking: { type: "disabled" }, }), + undefined, ) }) @@ -792,6 +799,7 @@ describe("ZAiHandler", () => { model: "glm-4.7", thinking: { type: "enabled" }, }), + undefined, ) }) @@ -845,6 +853,7 @@ describe("ZAiHandler", () => { model: "glm-5-turbo", thinking: { type: "enabled" }, }), + undefined, ) }) @@ -875,7 +884,100 @@ describe("ZAiHandler", () => { model: "glm-5-turbo", thinking: { type: "disabled" }, }), + undefined, ) }) }) + + describe("abort signal", () => { + beforeEach(() => { + vitest.clearAllMocks() + }) + + it("should handle abort signal triggered during request", async () => { + const controller = new AbortController() + handler = new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "international_coding" }) + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + while (!controller.signal.aborted) { + await new Promise((resolve) => setTimeout(resolve, 10)) + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { + choices: [{ delta: { content: "response" } }], + usage: null, + } + } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + setTimeout(() => controller.abort(), 50) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should work normally without abortSignal", async () => { + handler = new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "international_coding" }) + + mockCreate.mockResolvedValue({ + async *[Symbol.asyncIterator]() { + yield { choices: [{ delta: { content: "Hello" } }], usage: null } + yield { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 5 } } + }, + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should abort immediately if signal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + handler = new ZAiHandler({ zaiApiKey: "test-zai-api-key", zaiApiLine: "international_coding" }) + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { choices: [{ delta: { content: "response" } }], usage: null } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + }) }) diff --git a/src/api/providers/__tests__/zoo-gateway.spec.ts b/src/api/providers/__tests__/zoo-gateway.spec.ts index e0c060db3..3374c24e4 100644 --- a/src/api/providers/__tests__/zoo-gateway.spec.ts +++ b/src/api/providers/__tests__/zoo-gateway.spec.ts @@ -635,4 +635,92 @@ describe("ZooGatewayHandler", () => { ) }) }) + + describe("abort signal", () => { + it("should handle abort signal triggered during request", async () => { + const controller = new AbortController() + const handler = new ZooGatewayHandler(mockOptions) + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + while (!controller.signal.aborted) { + await new Promise((resolve) => setTimeout(resolve, 10)) + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { + choices: [{ delta: { content: "response" } }], + usage: null, + } + } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + setTimeout(() => controller.abort(), 50) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + + it("should work normally without abortSignal", async () => { + const handler = new ZooGatewayHandler(mockOptions) + + mockCreate.mockResolvedValue({ + async *[Symbol.asyncIterator]() { + yield { choices: [{ delta: { content: "Hello" } }], usage: null } + yield { choices: [{ delta: {} }], usage: { prompt_tokens: 10, completion_tokens: 5 } } + }, + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any) + + const chunks: any[] = [] + for await (const chunk of stream) { + chunks.push(chunk) + } + + expect(chunks.length).toBeGreaterThan(0) + }) + + it("should abort immediately if signal is already aborted", async () => { + const controller = new AbortController() + controller.abort() + + const handler = new ZooGatewayHandler(mockOptions) + + mockCreate.mockImplementation(async (options: unknown) => { + return { + async *[Symbol.asyncIterator]() { + if (controller.signal.aborted) { + throw new Error("AbortError: The operation was aborted") + } + yield { choices: [{ delta: { content: "response" } }], usage: null } + }, + } + }) + + const stream = handler.createMessage("system", [{ role: "user", content: "Hello" }] as any, { + taskId: "test", + tools: [], + abortSignal: controller.signal, + }) + + await expect(async () => { + for await (const _chunk of stream) { + // consume stream + } + }).rejects.toThrow(/abort/i) + }) + }) }) diff --git a/src/api/providers/anthropic-vertex.ts b/src/api/providers/anthropic-vertex.ts index 9089562f4..c9be114dd 100644 --- a/src/api/providers/anthropic-vertex.ts +++ b/src/api/providers/anthropic-vertex.ts @@ -112,10 +112,11 @@ export class AnthropicVertexHandler extends BaseProvider implements SingleComple ...nativeToolParams, } as Anthropic.Messages.MessageCreateParamsStreaming - // and prompt caching - const requestOptions = betas?.length ? { headers: { "anthropic-beta": betas.join(",") } } : undefined - - const stream = await this.client.messages.create(params, requestOptions) + const stream = await this.client.messages.create(params, { + // and prompt caching + ...(betas?.length ? { headers: { "anthropic-beta": betas.join(",") } } : {}), + ...(metadata?.abortSignal ? { signal: metadata.abortSignal } : {}), + }) for await (const chunk of stream) { switch (chunk.type) { diff --git a/src/api/providers/anthropic.ts b/src/api/providers/anthropic.ts index 7a4ef30ad..b96b222db 100644 --- a/src/api/providers/anthropic.ts +++ b/src/api/providers/anthropic.ts @@ -150,6 +150,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa stream: true, ...nativeToolParams, } + const abortSignal = metadata?.abortSignal stream = await this.client.messages.create( requestParams as Anthropic.Messages.MessageCreateParamsStreaming, (() => { @@ -176,9 +177,12 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa case "claude-haiku-4-5-20251001": case "claude-3-haiku-20240307": betas.push("prompt-caching-2024-07-31") - return { headers: { "anthropic-beta": betas.join(",") } } + return { + headers: { "anthropic-beta": betas.join(",") }, + ...(metadata?.abortSignal ? { signal: metadata.abortSignal } : {}), + } default: - return undefined + return abortSignal ? { signal: abortSignal } : undefined } })(), ) @@ -209,6 +213,7 @@ export class AnthropicHandler extends BaseProvider implements SingleCompletionHa } stream = (await this.client.messages.create( requestParams as Anthropic.Messages.MessageCreateParamsStreaming, + metadata?.abortSignal ? { signal: metadata?.abortSignal } : undefined, )) as any } catch (error) { TelemetryService.instance.captureException( diff --git a/src/api/providers/base-openai-compatible-provider.ts b/src/api/providers/base-openai-compatible-provider.ts index 28c812660..fd2f6c748 100644 --- a/src/api/providers/base-openai-compatible-provider.ts +++ b/src/api/providers/base-openai-compatible-provider.ts @@ -115,7 +115,12 @@ export abstract class BaseOpenAiCompatibleProvider messages: Anthropic.Messages.MessageParam[], metadata?: ApiHandlerCreateMessageMetadata, ): ApiStream { - const stream = await this.createStream(systemPrompt, messages, metadata) + const stream = await this.createStream( + systemPrompt, + messages, + metadata, + metadata?.abortSignal ? { signal: metadata.abortSignal } : undefined, + ) const matcher = new TagMatcher( "think", diff --git a/src/api/providers/deepseek.ts b/src/api/providers/deepseek.ts index 819fe6c7b..292b93cec 100644 --- a/src/api/providers/deepseek.ts +++ b/src/api/providers/deepseek.ts @@ -134,7 +134,10 @@ export class DeepSeekHandler extends OpenAiHandler { try { stream = await this.client.chat.completions.create( requestOptions as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming, - isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}, + { + ...(isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}), + ...(metadata?.abortSignal ? { signal: metadata?.abortSignal } : {}), + }, ) } catch (error) { const { handleOpenAIError } = await import("./utils/openai-error-handler") diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts index 6c8168cae..d704bb23e 100644 --- a/src/api/providers/gemini.ts +++ b/src/api/providers/gemini.ts @@ -302,6 +302,7 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl maxOutputTokens, temperature: temperatureConfig, ...(tools.length > 0 ? { tools } : {}), + ...(metadata?.abortSignal ? { signal: metadata.abortSignal } : {}), } // Do not pass metadata.allowedFunctionNames to Gemini. Live API testing showed diff --git a/src/api/providers/lite-llm.ts b/src/api/providers/lite-llm.ts index 981f984de..6081cc83d 100644 --- a/src/api/providers/lite-llm.ts +++ b/src/api/providers/lite-llm.ts @@ -237,7 +237,10 @@ export class LiteLLMHandler extends RouterProvider implements SingleCompletionHa try { const { data: completion } = await this.client.chat.completions - .create(requestOptions, { headers: requestHeaders }) + .create(requestOptions, { + headers: requestHeaders, + ...(metadata?.abortSignal ? { signal: metadata.abortSignal } : {}), + }) .withResponse() let lastUsage diff --git a/src/api/providers/lm-studio.ts b/src/api/providers/lm-studio.ts index d04bd157c..7b088f0c1 100644 --- a/src/api/providers/lm-studio.ts +++ b/src/api/providers/lm-studio.ts @@ -98,11 +98,19 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan let results try { - results = await this.client.chat.completions.create(params) + results = await this.client.chat.completions.create( + params, + metadata?.abortSignal ? { signal: metadata?.abortSignal } : undefined, + ) } catch (error) { throw handleOpenAIError(error, this.providerName) } + // Check if signal was already aborted before entering the loop + if (metadata?.abortSignal?.aborted) { + throw new DOMException("The operation was aborted.", "AbortError") + } + const matcher = new TagMatcher( "think", (chunk) => @@ -113,6 +121,11 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan ) for await (const chunk of results) { + // Check abort signal during stream iteration + if (metadata?.abortSignal?.aborted) { + throw new DOMException("The operation was aborted.", "AbortError") + } + const delta = chunk.choices[0]?.delta const finishReason = chunk.choices[0]?.finish_reason @@ -163,6 +176,11 @@ export class LmStudioHandler extends BaseProvider implements SingleCompletionHan outputTokens, } as const } catch (error) { + // Preserve AbortError instead of wrapping it + if (error instanceof DOMException && error.name === "AbortError") { + throw error + } + throw new Error( "Please check the LM Studio developer logs to debug what went wrong. You may need to load the model with a larger context length to work with Roo Code's prompts.", ) diff --git a/src/api/providers/mimo.ts b/src/api/providers/mimo.ts index 2901c2e92..4c6a38755 100644 --- a/src/api/providers/mimo.ts +++ b/src/api/providers/mimo.ts @@ -100,7 +100,10 @@ export class MimoHandler extends OpenAiHandler { let stream: AsyncIterable try { - stream = (await this.client.chat.completions.create(params as any)) as any + stream = (await this.client.chat.completions.create( + params as any, + metadata?.abortSignal ? { signal: metadata.abortSignal } : undefined, + )) as any } catch (error) { throw handleProviderError(error, "MiMo") } diff --git a/src/api/providers/minimax.ts b/src/api/providers/minimax.ts index 93aa7ea8f..43298f3c0 100644 --- a/src/api/providers/minimax.ts +++ b/src/api/providers/minimax.ts @@ -113,7 +113,10 @@ export class MiniMaxHandler extends BaseProvider implements SingleCompletionHand tool_choice: convertOpenAIToolChoice(metadata?.tool_choice), } - const stream = await this.client.messages.create(requestParams) + const stream = await this.client.messages.create( + requestParams, + metadata?.abortSignal ? { signal: metadata.abortSignal } : undefined, + ) let inputTokens = 0 let outputTokens = 0 diff --git a/src/api/providers/openai-compatible.ts b/src/api/providers/openai-compatible.ts index d129e7245..14dce3729 100644 --- a/src/api/providers/openai-compatible.ts +++ b/src/api/providers/openai-compatible.ts @@ -174,6 +174,7 @@ export abstract class OpenAICompatibleHandler extends BaseProvider implements Si maxOutputTokens: this.getMaxOutputTokens(), tools: aiSdkTools, toolChoice: this.mapToolChoice(metadata?.tool_choice), + abortSignal: metadata?.abortSignal, } // Use streamText for streaming responses diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts index abef612d8..161ee6611 100644 --- a/src/api/providers/openai.ts +++ b/src/api/providers/openai.ts @@ -175,10 +175,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl let stream try { - stream = await this.client.chat.completions.create( - requestOptions, - isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}, - ) + stream = await this.client.chat.completions.create(requestOptions, { + ...(isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}), + ...(metadata?.abortSignal ? { signal: metadata?.abortSignal } : {}), + }) } catch (error) { throw handleOpenAIError(error, this.providerName) } @@ -241,10 +241,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl let response try { - response = await this.client.chat.completions.create( - requestOptions, - this._isAzureAiInference(modelUrl) ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}, - ) + response = await this.client.chat.completions.create(requestOptions, { + ...(this._isAzureAiInference(modelUrl) ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}), + ...(metadata?.abortSignal ? { signal: metadata?.abortSignal } : {}), + }) } catch (error) { throw handleOpenAIError(error, this.providerName) } diff --git a/src/api/providers/openrouter.ts b/src/api/providers/openrouter.ts index 1ac9c465b..f426e5967 100644 --- a/src/api/providers/openrouter.ts +++ b/src/api/providers/openrouter.ts @@ -331,14 +331,15 @@ export class OpenRouterHandler extends BaseProvider implements SingleCompletionH tool_choice: metadata?.tool_choice, } - // Add Anthropic beta header for fine-grained tool streaming when using Anthropic models - const requestOptions = modelId.startsWith("anthropic/") - ? { headers: { "x-anthropic-beta": "fine-grained-tool-streaming-2025-05-14" } } - : undefined - let stream try { - stream = await this.client.chat.completions.create(completionParams, requestOptions) + stream = await this.client.chat.completions.create(completionParams, { + // Add Anthropic beta header for fine-grained tool streaming when using Anthropic models + ...(modelId.startsWith("anthropic/") + ? { headers: { "x-anthropic-beta": "fine-grained-tool-streaming-2025-05-14" } } + : {}), + ...(metadata?.abortSignal ? { signal: metadata.abortSignal } : {}), + }) } catch (error) { // Try to parse as OpenRouter error structure using Zod const parseResult = OpenRouterErrorResponseSchema.safeParse(error) diff --git a/src/api/providers/poe.ts b/src/api/providers/poe.ts index 536d222ac..92fc1b365 100644 --- a/src/api/providers/poe.ts +++ b/src/api/providers/poe.ts @@ -101,6 +101,7 @@ export class PoeHandler extends BaseProvider implements SingleCompletionHandler tools: aiSdkTools, toolChoice: mapToolChoice(metadata?.tool_choice as any), ...(Object.keys(providerOptions).length > 0 && { providerOptions }), + ...(metadata?.abortSignal ? { signal: metadata.abortSignal } : {}), }) } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error) diff --git a/src/api/providers/qwen-code.ts b/src/api/providers/qwen-code.ts index cdf0f88e4..fc0dcd427 100644 --- a/src/api/providers/qwen-code.ts +++ b/src/api/providers/qwen-code.ts @@ -239,7 +239,12 @@ export class QwenCodeHandler extends BaseProvider implements SingleCompletionHan parallel_tool_calls: metadata?.parallelToolCalls ?? true, } - const stream = await this.callApiWithRetry(() => client.chat.completions.create(requestOptions)) + const stream = await this.callApiWithRetry(() => + client.chat.completions.create( + requestOptions, + metadata?.abortSignal ? { signal: metadata.abortSignal } : undefined, + ), + ) let fullContent = "" diff --git a/src/api/providers/requesty.ts b/src/api/providers/requesty.ts index c490227d4..9fe70909b 100644 --- a/src/api/providers/requesty.ts +++ b/src/api/providers/requesty.ts @@ -162,8 +162,10 @@ export class RequestyHandler extends BaseProvider implements SingleCompletionHan let stream try { - // With streaming params type, SDK returns an async iterable stream - stream = await this.client.chat.completions.create(completionParams) + stream = await this.client.chat.completions.create( + completionParams, + metadata?.abortSignal ? { signal: metadata.abortSignal } : undefined, + ) } catch (error) { throw handleOpenAIError(error, this.providerName) } diff --git a/src/api/providers/unbound.ts b/src/api/providers/unbound.ts index 0ec7a2466..2fb9b89c7 100644 --- a/src/api/providers/unbound.ts +++ b/src/api/providers/unbound.ts @@ -151,7 +151,10 @@ export class UnboundHandler extends BaseProvider implements SingleCompletionHand let stream try { - stream = await this.client.chat.completions.create(completionParams) + stream = await this.client.chat.completions.create( + completionParams, + metadata?.abortSignal ? { signal: metadata.abortSignal } : undefined, + ) } catch (error) { throw handleOpenAIError(error, this.providerName) } diff --git a/src/api/providers/xai.ts b/src/api/providers/xai.ts index e5c0ba0a8..01c1b1855 100644 --- a/src/api/providers/xai.ts +++ b/src/api/providers/xai.ts @@ -127,10 +127,15 @@ export class XAIHandler extends BaseProvider implements SingleCompletionHandler let stream: AsyncIterable try { - stream = (await this.client.responses.create({ - ...requestBody, - stream: true, - } as any)) as unknown as AsyncIterable + // Support metadata.abortSignal for request cancellation (OpenAI SDK Responses API) + const createOptions = metadata?.abortSignal ? { signal: metadata.abortSignal } : undefined + stream = (await this.client.responses.create( + { + ...requestBody, + stream: true, + } as any, + createOptions, + )) as unknown as AsyncIterable } catch (error) { const errorMessage = error instanceof Error ? error.message : String(error) const apiError = new ApiProviderError(errorMessage, this.providerName, model.id, "createMessage") diff --git a/src/api/providers/zai.ts b/src/api/providers/zai.ts index c8f720a97..6485f5143 100644 --- a/src/api/providers/zai.ts +++ b/src/api/providers/zai.ts @@ -125,6 +125,7 @@ export class ZAiHandler extends BaseOpenAiCompatibleProvider { try { return this.client.chat.completions.create( params as OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming, + metadata?.abortSignal ? { signal: metadata.abortSignal } : undefined, ) } catch (error) { throw handleOpenAIError(error, this.providerName) diff --git a/src/api/providers/zoo-gateway.ts b/src/api/providers/zoo-gateway.ts index 4724464ff..d2d862b53 100644 --- a/src/api/providers/zoo-gateway.ts +++ b/src/api/providers/zoo-gateway.ts @@ -220,6 +220,7 @@ export class ZooGatewayHandler extends RouterProvider implements SingleCompletio try { const completion = await this.client.chat.completions.create(body, { headers: requestHeaders, + ...(metadata?.abortSignal ? { signal: metadata.abortSignal } : {}), }) for await (const chunk of completion) {