@@ -354,7 +354,9 @@ describe("BaseOpenAiCompatibleProvider", () => {
stream: true,
stream_options: { include_usage: true },
}),
undefined,
expect.objectContaining({
signal: expect.any(AbortSignal),
}),
)
})

49 changes: 38 additions & 11 deletions src/api/providers/__tests__/openai.spec.ts
@@ -549,7 +549,9 @@ describe("OpenAiHandler", () => {
model: mockOptions.openAiModelId,
messages: [{ role: "user", content: "Test prompt" }],
},
{},
expect.objectContaining({
signal: expect.any(AbortSignal),
}),
)
})

@@ -634,7 +636,10 @@ describe("OpenAiHandler", () => {
stream_options: { include_usage: true },
temperature: 0,
},
{ path: "/models/chat/completions" },
expect.objectContaining({
path: "/models/chat/completions",
signal: expect.any(AbortSignal),
}),
)

// Verify max_tokens is NOT included when includeMaxTokens is not set
@@ -680,7 +685,10 @@ describe("OpenAiHandler", () => {
{ role: "user", content: "Hello!" },
],
},
{ path: "/models/chat/completions" },
expect.objectContaining({
path: "/models/chat/completions",
signal: expect.any(AbortSignal),
}),
)

// Verify max_tokens is NOT included when includeMaxTokens is not set
@@ -697,7 +705,10 @@ describe("OpenAiHandler", () => {
model: azureOptions.openAiModelId,
messages: [{ role: "user", content: "Test prompt" }],
},
{ path: "/models/chat/completions" },
expect.objectContaining({
path: "/models/chat/completions",
signal: expect.any(AbortSignal),
}),
)

// Verify max_tokens is NOT included when includeMaxTokens is not set
@@ -737,7 +748,9 @@ describe("OpenAiHandler", () => {
model: grokOptions.openAiModelId,
stream: true,
}),
{},
expect.objectContaining({
signal: expect.any(AbortSignal),
}),
)

const mockCalls = mockCreate.mock.calls
@@ -796,7 +809,9 @@ describe("OpenAiHandler", () => {
// O3 models do not support deprecated max_tokens but do support max_completion_tokens
max_completion_tokens: 32000,
}),
{},
expect.objectContaining({
signal: expect.any(AbortSignal),
}),
)
})

@@ -953,7 +968,9 @@ describe("OpenAiHandler", () => {
reasoning_effort: "medium",
temperature: undefined,
}),
{},
expect.objectContaining({
signal: expect.any(AbortSignal),
}),
)

// Verify max_tokens is NOT included
@@ -997,7 +1014,9 @@ describe("OpenAiHandler", () => {
// O3 models do not support deprecated max_tokens but do support max_completion_tokens
max_completion_tokens: 65536, // Using default maxTokens from o3Options
}),
{},
expect.objectContaining({
signal: expect.any(AbortSignal),
}),
)

// Verify stream is not set
@@ -1074,7 +1093,9 @@ describe("OpenAiHandler", () => {
expect.objectContaining({
temperature: undefined, // Temperature is not supported for O3 models
}),
{},
expect.objectContaining({
signal: expect.any(AbortSignal),
}),
)
})

@@ -1099,7 +1120,10 @@ describe("OpenAiHandler", () => {
expect.objectContaining({
model: "o3-mini",
}),
{ path: "/models/chat/completions" },
expect.objectContaining({
path: "/models/chat/completions",
signal: expect.any(AbortSignal),
}),
)

// Verify max_tokens is NOT included when includeMaxTokens is false
@@ -1129,7 +1153,10 @@ describe("OpenAiHandler", () => {
model: "o3-mini",
// O3 models do not support max_tokens
}),
{ path: "/models/chat/completions" },
expect.objectContaining({
path: "/models/chat/completions",
signal: expect.any(AbortSignal),
}),
)
})
})
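Both test files above make the same kind of change: the second argument passed to the mocked `create` is no longer asserted to be exactly `undefined`, `{}`, or `{ path: ... }`, but checked with `expect.objectContaining` so the new `signal` option can be required without pinning every other key. A minimal sketch of that assertion pattern, using an illustrative mock client rather than the repository's actual test setup:

```typescript
import { describe, it, expect, vi } from "vitest"

// Illustrative mock whose `create` accepts (params, options), like the OpenAI SDK.
const mockCreate = vi.fn()
const client = { chat: { completions: { create: mockCreate } } }

// Code under test would forward an AbortSignal in the per-request options.
function sendRequest(signal: AbortSignal) {
	client.chat.completions.create(
		{ model: "gpt-4", messages: [{ role: "user", content: "Hi" }], stream: true },
		{ path: "/models/chat/completions", signal },
	)
}

describe("abort signal propagation", () => {
	it("passes an AbortSignal in the request options", () => {
		sendRequest(new AbortController().signal)

		expect(mockCreate).toHaveBeenCalledWith(
			expect.objectContaining({ model: "gpt-4", stream: true }),
			// objectContaining tolerates extra keys (e.g. `path`) while still
			// requiring that a real AbortSignal is present.
			expect.objectContaining({ signal: expect.any(AbortSignal) }),
		)
	})
})
```

`expect.any(AbortSignal)` only asserts the type, so the test stays valid no matter which `AbortController` instance the provider created for that request.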
175 changes: 101 additions & 74 deletions src/api/providers/base-openai-compatible-provider.ts
@@ -37,6 +37,9 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>

protected client: OpenAI

// Abort controller for cancelling ongoing requests
private abortController?: AbortController

constructor({
providerName,
baseURL,
@@ -106,7 +109,12 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
}

try {
return this.client.chat.completions.create(params, requestOptions)
// Merge abort signal with any existing request options
const mergedOptions: OpenAI.RequestOptions = {
...requestOptions,
signal: this.abortController?.signal,
}
return this.client.chat.completions.create(params, mergedOptions)
} catch (error) {
throw handleOpenAIError(error, this.providerName)
}
@@ -117,87 +125,99 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
messages: Anthropic.Messages.MessageParam[],
metadata?: ApiHandlerCreateMessageMetadata,
): ApiStream {
const stream = await this.createStream(systemPrompt, messages, metadata)
// Create AbortController for cancellation
this.abortController = new AbortController()

const matcher = new XmlMatcher(
"think",
(chunk) =>
({
type: chunk.matched ? "reasoning" : "text",
text: chunk.data,
}) as const,
)

let lastUsage: OpenAI.CompletionUsage | undefined
const activeToolCallIds = new Set<string>()
try {
const stream = await this.createStream(systemPrompt, messages, metadata)

const matcher = new XmlMatcher(
"think",
(chunk) =>
({
type: chunk.matched ? "reasoning" : "text",
text: chunk.data,
}) as const,
)

let lastUsage: OpenAI.CompletionUsage | undefined
const activeToolCallIds = new Set<string>()

for await (const chunk of stream) {
// Check if request was aborted
if (this.abortController?.signal.aborted) {
break
}

for await (const chunk of stream) {
// Check for provider-specific error responses (e.g., MiniMax base_resp)
const chunkAny = chunk as any
if (chunkAny.base_resp?.status_code && chunkAny.base_resp.status_code !== 0) {
throw new Error(
`${this.providerName} API Error (${chunkAny.base_resp.status_code}): ${chunkAny.base_resp.status_msg || "Unknown error"}`,
)
}
// Check for provider-specific error responses (e.g., MiniMax base_resp)
const chunkAny = chunk as any
if (chunkAny.base_resp?.status_code && chunkAny.base_resp.status_code !== 0) {
throw new Error(
`${this.providerName} API Error (${chunkAny.base_resp.status_code}): ${chunkAny.base_resp.status_msg || "Unknown error"}`,
)
}

const delta = chunk.choices?.[0]?.delta
const finishReason = chunk.choices?.[0]?.finish_reason
const delta = chunk.choices?.[0]?.delta
const finishReason = chunk.choices?.[0]?.finish_reason

if (delta?.content) {
for (const processedChunk of matcher.update(delta.content)) {
yield processedChunk
if (delta?.content) {
for (const processedChunk of matcher.update(delta.content)) {
yield processedChunk
}
}
}

if (delta) {
for (const key of ["reasoning_content", "reasoning"] as const) {
if (key in delta) {
const reasoning_content = ((delta as any)[key] as string | undefined) || ""
if (reasoning_content?.trim()) {
yield { type: "reasoning", text: reasoning_content }
if (delta) {
for (const key of ["reasoning_content", "reasoning"] as const) {
if (key in delta) {
const reasoning_content = ((delta as any)[key] as string | undefined) || ""
if (reasoning_content?.trim()) {
yield { type: "reasoning", text: reasoning_content }
}
break
}
break
}
}
}

// Emit raw tool call chunks - NativeToolCallParser handles state management
if (delta?.tool_calls) {
for (const toolCall of delta.tool_calls) {
if (toolCall.id) {
activeToolCallIds.add(toolCall.id)
// Emit raw tool call chunks - NativeToolCallParser handles state management
if (delta?.tool_calls) {
for (const toolCall of delta.tool_calls) {
if (toolCall.id) {
activeToolCallIds.add(toolCall.id)
}
yield {
type: "tool_call_partial",
index: toolCall.index,
id: toolCall.id,
name: toolCall.function?.name,
arguments: toolCall.function?.arguments,
}
}
yield {
type: "tool_call_partial",
index: toolCall.index,
id: toolCall.id,
name: toolCall.function?.name,
arguments: toolCall.function?.arguments,
}

// Emit tool_call_end events when finish_reason is "tool_calls"
// This ensures tool calls are finalized even if the stream doesn't properly close
if (finishReason === "tool_calls" && activeToolCallIds.size > 0) {
for (const id of activeToolCallIds) {
yield { type: "tool_call_end", id }
}
activeToolCallIds.clear()
}
}

// Emit tool_call_end events when finish_reason is "tool_calls"
// This ensures tool calls are finalized even if the stream doesn't properly close
if (finishReason === "tool_calls" && activeToolCallIds.size > 0) {
for (const id of activeToolCallIds) {
yield { type: "tool_call_end", id }
if (chunk.usage) {
lastUsage = chunk.usage
}
activeToolCallIds.clear()
}

if (chunk.usage) {
lastUsage = chunk.usage
if (lastUsage) {
yield this.processUsageMetrics(lastUsage, this.getModel().info)
}
}

if (lastUsage) {
yield this.processUsageMetrics(lastUsage, this.getModel().info)
}

// Process any remaining content
for (const processedChunk of matcher.final()) {
yield processedChunk
// Process any remaining content
for (const processedChunk of matcher.final()) {
yield processedChunk
}
} finally {
this.abortController = undefined
}
}

@@ -222,20 +242,25 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
}

async completePrompt(prompt: string): Promise<string> {
const { id: modelId, info: modelInfo } = this.getModel()
// Create AbortController for cancellation
this.abortController = new AbortController()

const params: OpenAI.Chat.Completions.ChatCompletionCreateParams = {
model: modelId,
messages: [{ role: "user", content: prompt }],
}
try {
const { id: modelId, info: modelInfo } = this.getModel()

// Add thinking parameter if reasoning is enabled and model supports it
if (this.options.enableReasoningEffort && modelInfo.supportsReasoningBinary) {
;(params as any).thinking = { type: "enabled" }
}
const params: OpenAI.Chat.Completions.ChatCompletionCreateParams = {
model: modelId,
messages: [{ role: "user", content: prompt }],
}

try {
const response = await this.client.chat.completions.create(params)
// Add thinking parameter if reasoning is enabled and model supports it
if (this.options.enableReasoningEffort && modelInfo.supportsReasoningBinary) {
;(params as any).thinking = { type: "enabled" }
}

const response = await this.client.chat.completions.create(params, {
signal: this.abortController.signal,
})

// Check for provider-specific error responses (e.g., MiniMax base_resp)
const responseAny = response as any
Expand All @@ -248,6 +273,8 @@ export abstract class BaseOpenAiCompatibleProvider<ModelName extends string>
return response.choices?.[0]?.message.content || ""
} catch (error) {
throw handleOpenAIError(error, this.providerName)
} finally {
this.abortController = undefined
}
}

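The provider now creates an `AbortController` per request, forwards its `signal` through the OpenAI SDK's per-request options, and clears the controller in a `finally` block. A rough sketch of that pattern in isolation, assuming an OpenAI-compatible client; the `cancel()` method and the model/base URL values are illustrative and not part of the diff above:

```typescript
import OpenAI from "openai"

class CancellableCompleter {
	// Placeholder endpoint and key, not the repository's configuration.
	private client = new OpenAI({ baseURL: "https://example.invalid/v1", apiKey: "sk-placeholder" })
	private abortController?: AbortController

	async complete(prompt: string): Promise<string> {
		this.abortController = new AbortController()
		try {
			// The OpenAI SDK accepts an AbortSignal in the per-request options;
			// aborting it rejects the pending request.
			const response = await this.client.chat.completions.create(
				{ model: "gpt-4o-mini", messages: [{ role: "user", content: prompt }] },
				{ signal: this.abortController.signal },
			)
			return response.choices?.[0]?.message.content ?? ""
		} finally {
			// Always clear the controller so a stale signal is never reused.
			this.abortController = undefined
		}
	}

	// Hypothetical entry point for callers; the diff above only stores the controller internally.
	cancel(): void {
		this.abortController?.abort()
	}
}
```

Clearing the controller in `finally` ensures the next request never inherits an already-aborted signal, whether the previous call resolved, threw, or was cancelled.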