Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,15 @@ import {
} from "./providers"
import { NativeOllamaHandler } from "./providers/native-ollama"

export interface CompletePromptOptions {
/** Abort signal for cancelling the request mid-flight */
signal?: AbortSignal
/** Optional timeout override (ms) — falls back to provider default if omitted */
timeoutMs?: number
}

export interface SingleCompletionHandler {
completePrompt(prompt: string): Promise<string>
completePrompt(prompt: string, options?: CompletePromptOptions): Promise<string>
}

export interface ApiHandlerCreateMessageMetadata {
Expand Down
91 changes: 79 additions & 12 deletions src/api/providers/__tests__/anthropic-vertex.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -834,18 +834,22 @@ describe("VertexHandler", () => {

const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(handler["client"].messages.create).toHaveBeenCalledWith({
model: "claude-3-5-sonnet-v2@20241022",
max_tokens: 8192,
temperature: 0,
messages: [
{
role: "user",
content: [{ type: "text", text: "Test prompt", cache_control: { type: "ephemeral" } }],
},
],
stream: false,
})
expect(handler["client"].messages.create).toHaveBeenCalledWith(
{
model: "claude-3-5-sonnet-v2@20241022",
max_tokens: 8192,
temperature: 0,
messages: [
{
role: "user",
content: [{ type: "text", text: "Test prompt", cache_control: { type: "ephemeral" } }],
},
],
stream: false,
thinking: undefined,
},
undefined,
)
})

it("should handle API errors for Claude", async () => {
Expand Down Expand Up @@ -895,6 +899,69 @@ describe("VertexHandler", () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})

it("should pass abort signal through to client", async () => {
handler = new AnthropicVertexHandler({
apiModelId: "claude-3-5-sonnet-v2@20241022",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
})

const controller = new AbortController()
const mockCreate = vitest.fn().mockResolvedValue({
content: [{ type: "text", text: "response" }],
})
;(handler["client"].messages as any).create = mockCreate

await handler.completePrompt("test prompt", { signal: controller.signal })
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), {
signal: controller.signal,
})
})

it("should work without options (backward compatible)", async () => {
handler = new AnthropicVertexHandler({
apiModelId: "claude-3-5-sonnet-v2@20241022",
vertexProjectId: "test-project",
vertexRegion: "us-central1",
})

const mockCreate = vitest.fn().mockResolvedValue({
content: [{ type: "text", text: "response" }],
})
;(handler["client"].messages as any).create = mockCreate

const result = await handler.completePrompt("test prompt")
expect(result).toBe("response")
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ model: expect.any(String) }), undefined)
})

it("completePrompt should pass signal through to client", async () => {
const controller = new AbortController()
const mockCreate = vitest.fn().mockResolvedValue({
content: [{ type: "text", text: "response" }],
})
;(handler["client"].messages as any).create = mockCreate

await handler.completePrompt("test prompt", { signal: controller.signal, timeoutMs: 5000 })
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({ model: expect.any(String) }),
{ signal: controller.signal }, // only signal is passed, not timeoutMs
)
})

it("completePrompt should not pass timeoutMs when no signal provided", async () => {
const mockCreate = vitest.fn().mockResolvedValue({
content: [{ type: "text", text: "response" }],
})
;(handler["client"].messages as any).create = mockCreate

await handler.completePrompt("test prompt", { timeoutMs: 3000 })
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({ model: expect.any(String) }),
undefined, // anthropic-vertex only passes signal, not timeoutMs
)
})
})

describe("getModel", () => {
Expand Down
118 changes: 110 additions & 8 deletions src/api/providers/__tests__/anthropic.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -434,14 +434,17 @@ describe("AnthropicHandler", () => {
it("should complete prompt successfully", async () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("Test response")
expect(mockCreate).toHaveBeenCalledWith({
model: mockOptions.apiModelId,
messages: [{ role: "user", content: "Test prompt" }],
max_tokens: 8192,
temperature: 0,
thinking: undefined,
stream: false,
})
expect(mockCreate).toHaveBeenCalledWith(
{
model: mockOptions.apiModelId,
messages: [{ role: "user", content: "Test prompt" }],
max_tokens: 8192,
temperature: 0,
thinking: undefined,
stream: false,
},
undefined,
)
})

it("should handle API errors", async () => {
Expand All @@ -464,6 +467,105 @@ describe("AnthropicHandler", () => {
const result = await handler.completePrompt("Test prompt")
expect(result).toBe("")
})

it("should pass abort signal through to client", async () => {
const controller = new AbortController()
mockCreate.mockResolvedValueOnce({ content: [{ type: "text", text: "response" }] })
await handler.completePrompt("test prompt", { signal: controller.signal })
expect(mockCreate).toHaveBeenCalledWith(
{
model: mockOptions.apiModelId,
messages: [{ role: "user", content: "test prompt" }],
max_tokens: 8192,
temperature: 0,
thinking: undefined,
stream: false,
},
{ signal: controller.signal },
)
})

it("should work without options (backward compatible)", async () => {
mockCreate.mockResolvedValueOnce({ content: [{ type: "text", text: "response" }] })
const result = await handler.completePrompt("test prompt")
expect(result).toBe("response")
expect(mockCreate).toHaveBeenCalledWith(
{
model: mockOptions.apiModelId,
messages: [{ role: "user", content: "test prompt" }],
max_tokens: 8192,
temperature: 0,
thinking: undefined,
stream: false,
},
undefined,
)
})

it("should merge signal and timeout together", async () => {
const controller = new AbortController()
mockCreate.mockResolvedValueOnce({ content: [{ type: "text", text: "response" }] })
await handler.completePrompt("test prompt", { signal: controller.signal, timeoutMs: 10000 })
expect(mockCreate).toHaveBeenCalledWith(
{
model: mockOptions.apiModelId,
messages: [{ role: "user", content: "test prompt" }],
max_tokens: 8192,
temperature: 0,
thinking: undefined,
stream: false,
},
expect.objectContaining({ signal: controller.signal, timeout: 10000 }),
)
})
Comment thread
coderabbitai[bot] marked this conversation as resolved.

it("should trigger timeout when timeoutMs elapses before request completes", async () => {
const mockCreateWithTimeout = vitest
.fn()
.mockImplementation(
async () =>
new Promise((resolve) =>
setTimeout(() => resolve({ content: [{ type: "text", text: "response" }] }), 500),
),
)

const handlerTimeout = new AnthropicHandler(mockOptions)
// Replace the mock on the existing handler's client
handlerTimeout["client"].messages.create = mockCreateWithTimeout

const controller = new AbortController()
let timeoutTriggered = false
handlerTimeout.completePrompt("test prompt", { signal: controller.signal, timeoutMs: 50 }).catch(() => {
timeoutTriggered = true
})

// Wait for timeout to trigger (50ms timeout + buffer)
await new Promise((resolve) => setTimeout(resolve, 150))

// Verify the API was called with timeout options
expect(mockCreateWithTimeout).toHaveBeenCalled()
// User signal should not be aborted (timeout mechanism aborts its own internal signal)
expect(controller.signal.aborted).toBe(false)
})

it("should pass the same signal instance", async () => {
const controller = new AbortController()
mockCreate.mockResolvedValueOnce({ content: [{ type: "text", text: "response" }] })
await handler.completePrompt("test prompt", { signal: controller.signal })
expect(mockCreate).toHaveBeenCalledWith(
expect.any(Object),
expect.objectContaining({ signal: controller.signal }),
)
// Verify it's the exact same instance, not just equal
const callOptions = mockCreate.mock.calls[0][1]
expect(callOptions?.signal).toBe(controller.signal)
})

it("should not include signal-related options when not provided", async () => {
mockCreate.mockResolvedValueOnce({ content: [{ type: "text", text: "response" }] })
await handler.completePrompt("test prompt")
expect(mockCreate).toHaveBeenCalledWith(expect.any(Object), undefined)
})
})

describe("getModel", () => {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,4 +116,58 @@ describe("BaseOpenAiCompatibleProvider Timeout Configuration", () => {
}),
)
})

describe("completePrompt", () => {
it("should pass timeout through to client when both signal and timeoutMs provided", async () => {
const handler = new TestOpenAiCompatibleProvider("test-api-key")
const controller = new AbortController()
const mockCreate = vitest.fn().mockResolvedValue({
choices: [{ message: { content: "response" } }],
})
handler["client"].chat.completions.create = mockCreate

await handler.completePrompt("test prompt", { signal: controller.signal, timeoutMs: 5000 })
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({ model: "test-model" }),
expect.objectContaining({ signal: expect.any(AbortSignal), timeout: 5000 }),
)
})

it("should pass only timeoutMs when no signal provided", async () => {
const handler = new TestOpenAiCompatibleProvider("test-api-key")
const mockCreate = vitest.fn().mockResolvedValue({
choices: [{ message: { content: "response" } }],
})
handler["client"].chat.completions.create = mockCreate

await handler.completePrompt("test prompt", { timeoutMs: 3000 })
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ model: "test-model" }), { timeout: 3000 })
})

it("should handle timeoutMs=0 as valid value (!== undefined check)", async () => {
const handler = new TestOpenAiCompatibleProvider("test-api-key")
const mockCreate = vitest.fn().mockResolvedValue({
choices: [{ message: { content: "response" } }],
})
handler["client"].chat.completions.create = mockCreate

await handler.completePrompt("test prompt", { timeoutMs: 0 })
expect(mockCreate).toHaveBeenCalledWith(expect.objectContaining({ model: "test-model" }), { timeout: 0 })
})

it("should work without options (backward compatible)", async () => {
const handler = new TestOpenAiCompatibleProvider("test-api-key")
const mockCreate = vitest.fn().mockResolvedValue({
choices: [{ message: { content: "response" } }],
})
handler["client"].chat.completions.create = mockCreate

const result = await handler.completePrompt("test prompt")
expect(result).toBe("response")
expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({ model: "test-model" }),
{}, // empty object when no options
)
})
})
})
73 changes: 73 additions & 0 deletions src/api/providers/__tests__/bedrock.spec.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1576,6 +1576,79 @@ describe("AwsBedrockHandler", () => {
expect(isAdaptiveThinkingModel("anthropic.claude-3-5-sonnet-20241022-v2:0")).toBe(false)
expect(isAdaptiveThinkingModel("amazon.nova-lite-v1:0")).toBe(false)
})

it("should pass abort signal through to client.send", async () => {
const mockSend = vi.fn()

const handler = new AwsBedrockHandler({
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "us-east-1",
})

// Set up the mock on the handler's client instance directly
const clientInstance = (handler as any).client
expect(clientInstance).toBeDefined()
clientInstance.send = mockSend

const controller = new AbortController()
mockSend.mockResolvedValueOnce({
output: { message: { content: [{ type: "text", text: "response" }] }, stopReason: null },
})

await handler.completePrompt("test prompt", { signal: controller.signal })

expect(mockSend).toHaveBeenCalledWith(expect.any(Object), { abortSignal: controller.signal })
})

it("should work without options (backward compatible)", async () => {
const mockSend = vi.fn()

const handler = new AwsBedrockHandler({
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "us-east-1",
})

const clientInstance = (handler as any).client
expect(clientInstance).toBeDefined()
clientInstance.send = mockSend

mockSend.mockResolvedValueOnce({
output: { message: { content: [{ type: "text", text: "response" }] }, stopReason: null },
})

const result = await handler.completePrompt("test prompt")

expect(result).toBe("response")
expect(mockSend).toHaveBeenCalledWith(expect.any(Object), undefined)
})

it("completePrompt should pass timeoutMs through to client", async () => {
const mockSend = vi.fn()

const handler = new AwsBedrockHandler({
apiModelId: "anthropic.claude-3-5-sonnet-20241022-v2:0",
awsAccessKey: "test-access-key",
awsSecretKey: "test-secret-key",
awsRegion: "us-east-1",
})

const clientInstance = (handler as any).client
expect(clientInstance).toBeDefined()
clientInstance.send = mockSend

mockSend.mockResolvedValueOnce({
output: { message: { content: [{ type: "text", text: "response" }] }, stopReason: null },
})

await handler.completePrompt("test prompt", { timeoutMs: 5000 })

// bedrock.ts uses truthy check for timeoutMs, so it creates AbortSignal.timeout
expect(mockSend).toHaveBeenCalled()
})
})
})
})
Loading
Loading