diff --git a/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts b/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts
index 7d0d2548fc..9816d4fe47 100644
--- a/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts
+++ b/src/api/providers/__tests__/base-openai-compatible-provider.spec.ts
@@ -354,7 +354,9 @@ describe("BaseOpenAiCompatibleProvider", () => {
 				stream: true,
 				stream_options: { include_usage: true },
 			}),
-			undefined,
+			expect.objectContaining({
+				signal: expect.any(AbortSignal),
+			}),
 		)
 	})

diff --git a/src/api/providers/__tests__/openai.spec.ts b/src/api/providers/__tests__/openai.spec.ts
index 4469efd4d1..2bf4befb3d 100644
--- a/src/api/providers/__tests__/openai.spec.ts
+++ b/src/api/providers/__tests__/openai.spec.ts
@@ -549,7 +549,9 @@ describe("OpenAiHandler", () => {
 					model: mockOptions.openAiModelId,
 					messages: [{ role: "user", content: "Test prompt" }],
 				},
-				{},
+				expect.objectContaining({
+					signal: expect.any(AbortSignal),
+				}),
 			)
 		})

@@ -634,7 +636,10 @@ describe("OpenAiHandler", () => {
 					stream_options: { include_usage: true },
 					temperature: 0,
 				},
-				{ path: "/models/chat/completions" },
+				expect.objectContaining({
+					path: "/models/chat/completions",
+					signal: expect.any(AbortSignal),
+				}),
 			)

 			// Verify max_tokens is NOT included when includeMaxTokens is not set
@@ -680,7 +685,10 @@ describe("OpenAiHandler", () => {
 					{ role: "user", content: "Hello!" },
 				],
 			},
-				{ path: "/models/chat/completions" },
+				expect.objectContaining({
+					path: "/models/chat/completions",
+					signal: expect.any(AbortSignal),
+				}),
 			)

 			// Verify max_tokens is NOT included when includeMaxTokens is not set
@@ -697,7 +705,10 @@ describe("OpenAiHandler", () => {
 					model: azureOptions.openAiModelId,
 					messages: [{ role: "user", content: "Test prompt" }],
 				},
-				{ path: "/models/chat/completions" },
+				expect.objectContaining({
+					path: "/models/chat/completions",
+					signal: expect.any(AbortSignal),
+				}),
 			)

 			// Verify max_tokens is NOT included when includeMaxTokens is not set
@@ -737,7 +748,9 @@ describe("OpenAiHandler", () => {
 					model: grokOptions.openAiModelId,
 					stream: true,
 				}),
-				{},
+				expect.objectContaining({
+					signal: expect.any(AbortSignal),
+				}),
 			)

 			const mockCalls = mockCreate.mock.calls
@@ -796,7 +809,9 @@ describe("OpenAiHandler", () => {
 					// O3 models do not support deprecated max_tokens but do support max_completion_tokens
 					max_completion_tokens: 32000,
 				}),
-				{},
+				expect.objectContaining({
+					signal: expect.any(AbortSignal),
+				}),
 			)
 		})

@@ -953,7 +968,9 @@ describe("OpenAiHandler", () => {
 					reasoning_effort: "medium",
 					temperature: undefined,
 				}),
-				{},
+				expect.objectContaining({
+					signal: expect.any(AbortSignal),
+				}),
 			)

 			// Verify max_tokens is NOT included
@@ -997,7 +1014,9 @@ describe("OpenAiHandler", () => {
 					// O3 models do not support deprecated max_tokens but do support max_completion_tokens
 					max_completion_tokens: 65536, // Using default maxTokens from o3Options
 				}),
-				{},
+				expect.objectContaining({
+					signal: expect.any(AbortSignal),
+				}),
 			)

 			// Verify stream is not set
@@ -1074,7 +1093,9 @@ describe("OpenAiHandler", () => {
 				expect.objectContaining({
 					temperature: undefined, // Temperature is not supported for O3 models
 				}),
-				{},
+				expect.objectContaining({
+					signal: expect.any(AbortSignal),
+				}),
 			)
 		})

@@ -1099,7 +1120,10 @@ describe("OpenAiHandler", () => {
 				expect.objectContaining({
 					model: "o3-mini",
 				}),
-				{ path: "/models/chat/completions" },
+				expect.objectContaining({
+					path: "/models/chat/completions",
+					signal: expect.any(AbortSignal),
+				}),
 			)

 			// Verify max_tokens is NOT included when includeMaxTokens is false
@@ -1129,7 +1153,10 @@ describe("OpenAiHandler", () => {
 					model: "o3-mini",
 					// O3 models do not support max_tokens
 				}),
-				{ path: "/models/chat/completions" },
+				expect.objectContaining({
+					path: "/models/chat/completions",
+					signal: expect.any(AbortSignal),
+				}),
 			)
 		})
 	})

diff --git a/src/api/providers/base-openai-compatible-provider.ts b/src/api/providers/base-openai-compatible-provider.ts
index a2a55cdc10..54a4b44d98 100644
--- a/src/api/providers/base-openai-compatible-provider.ts
+++ b/src/api/providers/base-openai-compatible-provider.ts
@@ -37,6 +37,9 @@ export abstract class BaseOpenAiCompatibleProvider
 	protected client: OpenAI

+	// Abort controller for cancelling ongoing requests
+	private abortController?: AbortController
+
 	constructor({
 		providerName,
 		baseURL,
@@ -106,7 +109,12 @@ export abstract class BaseOpenAiCompatibleProvider
 		}

 		try {
-			return this.client.chat.completions.create(params, requestOptions)
+			// Merge abort signal with any existing request options
+			const mergedOptions: OpenAI.RequestOptions = {
+				...requestOptions,
+				signal: this.abortController?.signal,
+			}
+			return this.client.chat.completions.create(params, mergedOptions)
 		} catch (error) {
 			throw handleOpenAIError(error, this.providerName)
 		}
@@ -117,87 +125,99 @@ export abstract class BaseOpenAiCompatibleProvider
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		const stream = await this.createStream(systemPrompt, messages, metadata)
+		// Create AbortController for cancellation
+		this.abortController = new AbortController()

-		const matcher = new XmlMatcher(
-			"think",
-			(chunk) =>
-				({
-					type: chunk.matched ? "reasoning" : "text",
-					text: chunk.data,
-				}) as const,
-		)
-
-		let lastUsage: OpenAI.CompletionUsage | undefined
-		const activeToolCallIds = new Set<string>()
+		try {
+			const stream = await this.createStream(systemPrompt, messages, metadata)
+
+			const matcher = new XmlMatcher(
+				"think",
+				(chunk) =>
+					({
+						type: chunk.matched ? "reasoning" : "text",
+						text: chunk.data,
+					}) as const,
+			)
+
+			let lastUsage: OpenAI.CompletionUsage | undefined
+			const activeToolCallIds = new Set<string>()
+
+			for await (const chunk of stream) {
+				// Check if request was aborted
+				if (this.abortController?.signal.aborted) {
+					break
+				}

-		for await (const chunk of stream) {
-			// Check for provider-specific error responses (e.g., MiniMax base_resp)
-			const chunkAny = chunk as any
-			if (chunkAny.base_resp?.status_code && chunkAny.base_resp.status_code !== 0) {
-				throw new Error(
-					`${this.providerName} API Error (${chunkAny.base_resp.status_code}): ${chunkAny.base_resp.status_msg || "Unknown error"}`,
-				)
-			}
+				// Check for provider-specific error responses (e.g., MiniMax base_resp)
+				const chunkAny = chunk as any
+				if (chunkAny.base_resp?.status_code && chunkAny.base_resp.status_code !== 0) {
+					throw new Error(
+						`${this.providerName} API Error (${chunkAny.base_resp.status_code}): ${chunkAny.base_resp.status_msg || "Unknown error"}`,
+					)
+				}

-			const delta = chunk.choices?.[0]?.delta
-			const finishReason = chunk.choices?.[0]?.finish_reason
+				const delta = chunk.choices?.[0]?.delta
+				const finishReason = chunk.choices?.[0]?.finish_reason

-			if (delta?.content) {
-				for (const processedChunk of matcher.update(delta.content)) {
-					yield processedChunk
+				if (delta?.content) {
+					for (const processedChunk of matcher.update(delta.content)) {
+						yield processedChunk
+					}
 				}
-			}

-			if (delta) {
-				for (const key of ["reasoning_content", "reasoning"] as const) {
-					if (key in delta) {
-						const reasoning_content = ((delta as any)[key] as string | undefined) || ""
-						if (reasoning_content?.trim()) {
-							yield { type: "reasoning", text: reasoning_content }
+				if (delta) {
+					for (const key of ["reasoning_content", "reasoning"] as const) {
+						if (key in delta) {
+							const reasoning_content = ((delta as any)[key] as string | undefined) || ""
+							if (reasoning_content?.trim()) {
+								yield { type: "reasoning", text: reasoning_content }
+							}
+							break
 						}
-						break
 					}
 				}
-			}

-			// Emit raw tool call chunks - NativeToolCallParser handles state management
-			if (delta?.tool_calls) {
-				for (const toolCall of delta.tool_calls) {
-					if (toolCall.id) {
-						activeToolCallIds.add(toolCall.id)
+				// Emit raw tool call chunks - NativeToolCallParser handles state management
+				if (delta?.tool_calls) {
+					for (const toolCall of delta.tool_calls) {
+						if (toolCall.id) {
+							activeToolCallIds.add(toolCall.id)
+						}
+						yield {
+							type: "tool_call_partial",
+							index: toolCall.index,
+							id: toolCall.id,
+							name: toolCall.function?.name,
+							arguments: toolCall.function?.arguments,
+						}
 					}
-					yield {
-						type: "tool_call_partial",
-						index: toolCall.index,
-						id: toolCall.id,
-						name: toolCall.function?.name,
-						arguments: toolCall.function?.arguments,
+				}
+
+				// Emit tool_call_end events when finish_reason is "tool_calls"
+				// This ensures tool calls are finalized even if the stream doesn't properly close
+				if (finishReason === "tool_calls" && activeToolCallIds.size > 0) {
+					for (const id of activeToolCallIds) {
+						yield { type: "tool_call_end", id }
 					}
+					activeToolCallIds.clear()
 				}
-			}

-			// Emit tool_call_end events when finish_reason is "tool_calls"
-			// This ensures tool calls are finalized even if the stream doesn't properly close
-			if (finishReason === "tool_calls" && activeToolCallIds.size > 0) {
-				for (const id of activeToolCallIds) {
-					yield { type: "tool_call_end", id }
+				if (chunk.usage) {
+					lastUsage = chunk.usage
 				}
-				activeToolCallIds.clear()
 			}

-			if (chunk.usage) {
-				lastUsage = chunk.usage
+			if (lastUsage) {
+				yield this.processUsageMetrics(lastUsage, this.getModel().info)
 			}
-		}
-
-		if (lastUsage) {
-			yield this.processUsageMetrics(lastUsage, this.getModel().info)
-		}

-		// Process any remaining content
-		for (const processedChunk of matcher.final()) {
-			yield processedChunk
+			// Process any remaining content
+			for (const processedChunk of matcher.final()) {
+				yield processedChunk
+			}
+		} finally {
+			this.abortController = undefined
 		}
 	}
@@ -222,20 +242,25 @@ export abstract class BaseOpenAiCompatibleProvider
 	}

 	async completePrompt(prompt: string): Promise<string> {
-		const { id: modelId, info: modelInfo } = this.getModel()
+		// Create AbortController for cancellation
+		this.abortController = new AbortController()

-		const params: OpenAI.Chat.Completions.ChatCompletionCreateParams = {
-			model: modelId,
-			messages: [{ role: "user", content: prompt }],
-		}
+		try {
+			const { id: modelId, info: modelInfo } = this.getModel()

-		// Add thinking parameter if reasoning is enabled and model supports it
-		if (this.options.enableReasoningEffort && modelInfo.supportsReasoningBinary) {
-			;(params as any).thinking = { type: "enabled" }
-		}
+			const params: OpenAI.Chat.Completions.ChatCompletionCreateParams = {
+				model: modelId,
+				messages: [{ role: "user", content: prompt }],
+			}

-		try {
-			const response = await this.client.chat.completions.create(params)
+			// Add thinking parameter if reasoning is enabled and model supports it
+			if (this.options.enableReasoningEffort && modelInfo.supportsReasoningBinary) {
+				;(params as any).thinking = { type: "enabled" }
+			}
+
+			const response = await this.client.chat.completions.create(params, {
+				signal: this.abortController.signal,
+			})

 			// Check for provider-specific error responses (e.g., MiniMax base_resp)
 			const responseAny = response as any
@@ -248,6 +273,8 @@ export abstract class BaseOpenAiCompatibleProvider
 			return response.choices?.[0]?.message.content || ""
 		} catch (error) {
 			throw handleOpenAIError(error, this.providerName)
+		} finally {
+			this.abortController = undefined
 		}
 	}

diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts
index 9d632fbdf4..76fe7ba11e 100644
--- a/src/api/providers/openai.ts
+++ b/src/api/providers/openai.ts
@@ -33,6 +33,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 	protected options: ApiHandlerOptions
 	protected client: OpenAI
 	private readonly providerName = "OpenAI"
+	// Abort controller for cancelling ongoing requests
+	private abortController?: AbortController

 	constructor(options: ApiHandlerOptions) {
 		super()
@@ -85,193 +87,206 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		messages: Anthropic.Messages.MessageParam[],
 		metadata?: ApiHandlerCreateMessageMetadata,
 	): ApiStream {
-		const { info: modelInfo, reasoning } = this.getModel()
-		const modelUrl = this.options.openAiBaseUrl ?? ""
-		const modelId = this.options.openAiModelId ?? ""
-		const enabledR1Format = this.options.openAiR1FormatEnabled ?? false
-		const isAzureAiInference = this._isAzureAiInference(modelUrl)
-		const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format
-
-		if (modelId.includes("o1") || modelId.includes("o3") || modelId.includes("o4")) {
-			yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages, metadata)
-			return
-		}
+		// Create AbortController for cancellation
+		this.abortController = new AbortController()

-		let systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = {
-			role: "system",
-			content: systemPrompt,
-		}
+		try {
+			const { info: modelInfo, reasoning } = this.getModel()
+			const modelUrl = this.options.openAiBaseUrl ?? ""
+			const modelId = this.options.openAiModelId ?? ""
+			const enabledR1Format = this.options.openAiR1FormatEnabled ?? false
+			const isAzureAiInference = this._isAzureAiInference(modelUrl)
+			const deepseekReasoner = modelId.includes("deepseek-reasoner") || enabledR1Format
+
+			if (modelId.includes("o1") || modelId.includes("o3") || modelId.includes("o4")) {
+				yield* this.handleO3FamilyMessage(modelId, systemPrompt, messages, metadata)
+				return
+			}

-		if (this.options.openAiStreamingEnabled ?? true) {
-			let convertedMessages
+			let systemMessage: OpenAI.Chat.ChatCompletionSystemMessageParam = {
+				role: "system",
+				content: systemPrompt,
+			}

-			if (deepseekReasoner) {
-				convertedMessages = convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
-			} else {
-				if (modelInfo.supportsPromptCache) {
-					systemMessage = {
-						role: "system",
-						content: [
-							{
-								type: "text",
-								text: systemPrompt,
-								// @ts-ignore-next-line
-								cache_control: { type: "ephemeral" },
-							},
-						],
+			if (this.options.openAiStreamingEnabled ?? true) {
+				let convertedMessages
+
+				if (deepseekReasoner) {
+					convertedMessages = convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
+				} else {
+					if (modelInfo.supportsPromptCache) {
+						systemMessage = {
+							role: "system",
+							content: [
+								{
+									type: "text",
+									text: systemPrompt,
+									// @ts-ignore-next-line
+									cache_control: { type: "ephemeral" },
+								},
+							],
+						}
 					}
-				}

-				convertedMessages = [systemMessage, ...convertToOpenAiMessages(messages)]
+					convertedMessages = [systemMessage, ...convertToOpenAiMessages(messages)]

-				if (modelInfo.supportsPromptCache) {
-					// Note: the following logic is copied from openrouter:
-					// Add cache_control to the last two user messages
-					// (note: this works because we only ever add one user message at a time, but if we added multiple we'd need to mark the user message before the last assistant message)
-					const lastTwoUserMessages = convertedMessages.filter((msg) => msg.role === "user").slice(-2)
+					if (modelInfo.supportsPromptCache) {
+						// Note: the following logic is copied from openrouter:
+						// Add cache_control to the last two user messages
+						// (note: this works because we only ever add one user message at a time, but if we added multiple we'd need to mark the user message before the last assistant message)
+						const lastTwoUserMessages = convertedMessages.filter((msg) => msg.role === "user").slice(-2)

-					lastTwoUserMessages.forEach((msg) => {
-						if (typeof msg.content === "string") {
-							msg.content = [{ type: "text", text: msg.content }]
-						}
+						lastTwoUserMessages.forEach((msg) => {
+							if (typeof msg.content === "string") {
+								msg.content = [{ type: "text", text: msg.content }]
+							}
+
+							if (Array.isArray(msg.content)) {
+								// NOTE: this is fine since env details will always be added at the end. but if it weren't there, and the user added a image_url type message, it would pop a text part before it and then move it after to the end.
+								let lastTextPart = msg.content.filter((part) => part.type === "text").pop()

-						if (Array.isArray(msg.content)) {
-							// NOTE: this is fine since env details will always be added at the end. but if it weren't there, and the user added a image_url type message, it would pop a text part before it and then move it after to the end.
-							let lastTextPart = msg.content.filter((part) => part.type === "text").pop()
+								if (!lastTextPart) {
+									lastTextPart = { type: "text", text: "..." }
+									msg.content.push(lastTextPart)
+								}

-							if (!lastTextPart) {
-								lastTextPart = { type: "text", text: "..." }
-								msg.content.push(lastTextPart)
+								// @ts-ignore-next-line
+								lastTextPart["cache_control"] = { type: "ephemeral" }
 							}
+						})
+					}
+				}

-							// @ts-ignore-next-line
-							lastTextPart["cache_control"] = { type: "ephemeral" }
-						}
-					})
+				const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
+
+				const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
+					model: modelId,
+					temperature:
+						this.options.modelTemperature ?? (deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
+					messages: convertedMessages,
+					stream: true as const,
+					...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
+					...(reasoning && reasoning),
+					...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+					...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
+					...(metadata?.toolProtocol === "native" &&
+						metadata.parallelToolCalls === true && {
+							parallel_tool_calls: true,
+						}),
 				}
-			}

-			const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
+				// Add max_tokens if needed
+				this.addMaxTokensIfNeeded(requestOptions, modelInfo)

-			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
-				model: modelId,
-				temperature: this.options.modelTemperature ?? (deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
-				messages: convertedMessages,
-				stream: true as const,
-				...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
-				...(reasoning && reasoning),
-				...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
-				...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
-				...(metadata?.toolProtocol === "native" &&
-					metadata.parallelToolCalls === true && {
-						parallel_tool_calls: true,
-					}),
-			}
-
-			// Add max_tokens if needed
-			this.addMaxTokensIfNeeded(requestOptions, modelInfo)
+				let stream
+				try {
+					stream = await this.client.chat.completions.create(requestOptions, {
+						...(isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}),
+						signal: this.abortController.signal,
+					})
+				} catch (error) {
+					throw handleOpenAIError(error, this.providerName)
+				}

-			let stream
-			try {
-				stream = await this.client.chat.completions.create(
-					requestOptions,
-					isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
+				const matcher = new XmlMatcher(
+					"think",
+					(chunk) =>
+						({
+							type: chunk.matched ? "reasoning" : "text",
+							text: chunk.data,
+						}) as const,
 				)
-			} catch (error) {
-				throw handleOpenAIError(error, this.providerName)
-			}

-			const matcher = new XmlMatcher(
-				"think",
-				(chunk) =>
-					({
-						type: chunk.matched ? "reasoning" : "text",
-						text: chunk.data,
-					}) as const,
-			)
+				let lastUsage
+				const activeToolCallIds = new Set<string>()

-			let lastUsage
-			const activeToolCallIds = new Set<string>()
+				for await (const chunk of stream) {
+					// Check if request was aborted
+					if (this.abortController?.signal.aborted) {
+						break
+					}

-			for await (const chunk of stream) {
-				const delta = chunk.choices?.[0]?.delta ?? {}
-				const finishReason = chunk.choices?.[0]?.finish_reason
+					const delta = chunk.choices?.[0]?.delta ?? {}
+					const finishReason = chunk.choices?.[0]?.finish_reason

-				if (delta.content) {
-					for (const chunk of matcher.update(delta.content)) {
-						yield chunk
+					if (delta.content) {
+						for (const chunk of matcher.update(delta.content)) {
+							yield chunk
+						}
 					}
-				}

-				if ("reasoning_content" in delta && delta.reasoning_content) {
-					yield {
-						type: "reasoning",
-						text: (delta.reasoning_content as string | undefined) || "",
+					if ("reasoning_content" in delta && delta.reasoning_content) {
+						yield {
+							type: "reasoning",
+							text: (delta.reasoning_content as string | undefined) || "",
+						}
 					}
-				}

-				yield* this.processToolCalls(delta, finishReason, activeToolCallIds)
+					yield* this.processToolCalls(delta, finishReason, activeToolCallIds)

-				if (chunk.usage) {
-					lastUsage = chunk.usage
+					if (chunk.usage) {
+						lastUsage = chunk.usage
+					}
 				}
-			}

-			for (const chunk of matcher.final()) {
-				yield chunk
-			}
+				for (const chunk of matcher.final()) {
+					yield chunk
+				}

-			if (lastUsage) {
-				yield this.processUsageMetrics(lastUsage, modelInfo)
-			}
-		} else {
-			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
-				model: modelId,
-				messages: deepseekReasoner
-					? convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
-					: [systemMessage, ...convertToOpenAiMessages(messages)],
-				...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
-				...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
-				...(metadata?.toolProtocol === "native" &&
-					metadata.parallelToolCalls === true && {
-						parallel_tool_calls: true,
-					}),
-			}
+				if (lastUsage) {
+					yield this.processUsageMetrics(lastUsage, modelInfo)
+				}
+			} else {
+				const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsNonStreaming = {
+					model: modelId,
+					messages: deepseekReasoner
+						? convertToR1Format([{ role: "user", content: systemPrompt }, ...messages])
+						: [systemMessage, ...convertToOpenAiMessages(messages)],
+					...(metadata?.tools && { tools: this.convertToolsForOpenAI(metadata.tools) }),
+					...(metadata?.tool_choice && { tool_choice: metadata.tool_choice }),
+					...(metadata?.toolProtocol === "native" &&
+						metadata.parallelToolCalls === true && {
+							parallel_tool_calls: true,
+						}),
+				}

-			// Add max_tokens if needed
-			this.addMaxTokensIfNeeded(requestOptions, modelInfo)
+				// Add max_tokens if needed
+				this.addMaxTokensIfNeeded(requestOptions, modelInfo)

-			let response
-			try {
-				response = await this.client.chat.completions.create(
-					requestOptions,
-					this._isAzureAiInference(modelUrl) ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
-				)
-			} catch (error) {
-				throw handleOpenAIError(error, this.providerName)
-			}
+				let response
+				try {
+					response = await this.client.chat.completions.create(requestOptions, {
+						...(this._isAzureAiInference(modelUrl) ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}),
+						signal: this.abortController.signal,
+					})
+				} catch (error) {
+					throw handleOpenAIError(error, this.providerName)
+				}

-			const message = response.choices?.[0]?.message
+				const message = response.choices?.[0]?.message

-			if (message?.tool_calls) {
-				for (const toolCall of message.tool_calls) {
-					if (toolCall.type === "function") {
-						yield {
-							type: "tool_call",
-							id: toolCall.id,
-							name: toolCall.function.name,
-							arguments: toolCall.function.arguments,
+				if (message?.tool_calls) {
+					for (const toolCall of message.tool_calls) {
+						if (toolCall.type === "function") {
+							yield {
+								type: "tool_call",
+								id: toolCall.id,
+								name: toolCall.function.name,
+								arguments: toolCall.function.arguments,
+							}
 						}
 					}
 				}
-			}

-			yield {
-				type: "text",
-				text: message?.content || "",
-			}
+				yield {
+					type: "text",
+					text: message?.content || "",
+				}

-			yield this.processUsageMetrics(response.usage, modelInfo)
+				yield this.processUsageMetrics(response.usage, modelInfo)
+			}
+		} finally {
+			this.abortController = undefined
 		}
 	}
@@ -299,6 +314,9 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 	}

 	async completePrompt(prompt: string): Promise<string> {
+		// Create AbortController for cancellation
+		this.abortController = new AbortController()
+
 		try {
 			const isAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
 			const model = this.getModel()
@@ -314,10 +332,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 			let response
 			try {
-				response = await this.client.chat.completions.create(
-					requestOptions,
-					isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
-				)
+				response = await this.client.chat.completions.create(requestOptions, {
+					...(isAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}),
+					signal: this.abortController.signal,
+				})
 			} catch (error) {
 				throw handleOpenAIError(error, this.providerName)
 			}
@@ -329,6 +347,8 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 			}

 			throw error
+		} finally {
+			this.abortController = undefined
 		}
 	}
@@ -372,10 +392,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		let stream
 		try {
-			stream = await this.client.chat.completions.create(
-				requestOptions,
-				methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
-			)
+			stream = await this.client.chat.completions.create(requestOptions, {
+				...(methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}),
+				signal: this.abortController!.signal,
+			})
 		} catch (error) {
 			throw handleOpenAIError(error, this.providerName)
 		}
@@ -408,10 +428,10 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		let response
 		try {
-			response = await this.client.chat.completions.create(
-				requestOptions,
-				methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {},
-			)
+			response = await this.client.chat.completions.create(requestOptions, {
+				...(methodIsAzureAiInference ? { path: OPENAI_AZURE_AI_INFERENCE_PATH } : {}),
+				signal: this.abortController!.signal,
+			})
 		} catch (error) {
 			throw handleOpenAIError(error, this.providerName)
 		}
@@ -442,6 +462,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		const activeToolCallIds = new Set<string>()

 		for await (const chunk of stream) {
+			// Check if request was aborted
+			if (this.abortController?.signal.aborted) {
+				break
+			}
+
 			const delta = chunk.choices?.[0]?.delta
 			const finishReason = chunk.choices?.[0]?.finish_reason
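Usage sketch (not part of the patch): the hunks above store an AbortController per request, pass its signal to the OpenAI SDK, and bail out of the streaming loops once `signal.aborted` is set, but they do not show a public way to trigger the abort. The standalone class and the `cancel()` method below are hypothetical names used only to illustrate how a caller might fire that signal under the same pattern, assuming such a trigger were exposed.

import OpenAI from "openai"

class CancellableChat {
	// Mirrors the diff: one controller per in-flight request, cleared in finally
	private abortController?: AbortController

	constructor(private client: OpenAI) {}

	async *stream(model: string, prompt: string): AsyncGenerator<string> {
		this.abortController = new AbortController()
		try {
			const stream = await this.client.chat.completions.create(
				{ model, messages: [{ role: "user", content: prompt }], stream: true },
				{ signal: this.abortController.signal }, // same request-options signal as in the patch
			)
			for await (const chunk of stream) {
				// Stop consuming the stream once the request has been cancelled
				if (this.abortController?.signal.aborted) {
					break
				}
				yield chunk.choices?.[0]?.delta?.content ?? ""
			}
		} finally {
			this.abortController = undefined
		}
	}

	// Hypothetical public trigger; the patch itself does not expose one
	cancel(): void {
		this.abortController?.abort()
	}
}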