diff --git a/src/api/index.ts b/src/api/index.ts
index 2ee882ad72c..456a99f74ed 100644
--- a/src/api/index.ts
+++ b/src/api/index.ts
@@ -94,6 +94,15 @@ export interface ApiHandlerCreateMessageMetadata {
 	 * Only applies when toolProtocol is "native".
 	 */
 	parallelToolCalls?: boolean
+	/**
+	 * Optional array of tool names that the model is allowed to call.
+	 * When provided, all tool definitions are passed to the model (so it can reference
+	 * historical tool calls), but only the specified tools can actually be invoked.
+	 * This is used when switching modes to prevent model errors from missing tool
+	 * definitions while still restricting callable tools to the current mode's permissions.
+	 * Only applies to providers that support function calling restrictions (e.g., Gemini).
+	 */
+	allowedFunctionNames?: string[]
 }
 
 export interface ApiHandler {
diff --git a/src/api/providers/__tests__/gemini-handler.spec.ts b/src/api/providers/__tests__/gemini-handler.spec.ts
index 541ffd5611c..5ddd5a98a9c 100644
--- a/src/api/providers/__tests__/gemini-handler.spec.ts
+++ b/src/api/providers/__tests__/gemini-handler.spec.ts
@@ -1,4 +1,5 @@
 import { t } from "i18next"
+import { FunctionCallingConfigMode } from "@google/genai"
 
 import { GeminiHandler } from "../gemini"
 import type { ApiHandlerOptions } from "../../../shared/api"
@@ -141,4 +142,152 @@ describe("GeminiHandler backend support", () => {
 			}).rejects.toThrow(t("common:errors.gemini.generate_stream", { error: "API rate limit exceeded" }))
 		})
 	})
+
+	describe("allowedFunctionNames support", () => {
+		const testTools = [
+			{
+				type: "function" as const,
+				function: {
+					name: "read_file",
+					description: "Read a file",
+					parameters: { type: "object", properties: {} },
+				},
+			},
+			{
+				type: "function" as const,
+				function: {
+					name: "write_to_file",
+					description: "Write to a file",
+					parameters: { type: "object", properties: {} },
+				},
+			},
+			{
+				type: "function" as const,
+				function: {
+					name: "execute_command",
+					description: "Execute a command",
+					parameters: { type: "object", properties: {} },
+				},
+			},
+		]
+
+		it("should pass allowedFunctionNames to toolConfig when provided", async () => {
+			const options = {
+				apiProvider: "gemini",
+			} as ApiHandlerOptions
+			const handler = new GeminiHandler(options)
+			const stub = vi.fn().mockReturnValue((async function* () {})())
+			// @ts-ignore access private client
+			handler["client"].models.generateContentStream = stub
+
+			await handler
+				.createMessage("test", [] as any, {
+					taskId: "test-task",
+					tools: testTools,
+					allowedFunctionNames: ["read_file", "write_to_file"],
+				})
+				.next()
+
+			const config = stub.mock.calls[0][0].config
+			expect(config.toolConfig).toEqual({
+				functionCallingConfig: {
+					mode: FunctionCallingConfigMode.ANY,
+					allowedFunctionNames: ["read_file", "write_to_file"],
+				},
+			})
+		})
+
+		it("should include all tools but restrict callable functions via allowedFunctionNames", async () => {
+			const options = {
+				apiProvider: "gemini",
+			} as ApiHandlerOptions
+			const handler = new GeminiHandler(options)
+			const stub = vi.fn().mockReturnValue((async function* () {})())
+			// @ts-ignore access private client
+			handler["client"].models.generateContentStream = stub
+
+			await handler
+				.createMessage("test", [] as any, {
+					taskId: "test-task",
+					tools: testTools,
+					allowedFunctionNames: ["read_file"],
+				})
+				.next()
+
+			const config = stub.mock.calls[0][0].config
+			// All tools should be passed to the model
+			expect(config.tools[0].functionDeclarations).toHaveLength(3)
+			// But only read_file should be allowed to be called
+			expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toEqual(["read_file"])
+		})
+
+		it("should take precedence over tool_choice when allowedFunctionNames is provided", async () => {
+			const options = {
+				apiProvider: "gemini",
+			} as ApiHandlerOptions
+			const handler = new GeminiHandler(options)
+			const stub = vi.fn().mockReturnValue((async function* () {})())
+			// @ts-ignore access private client
+			handler["client"].models.generateContentStream = stub
+
+			await handler
+				.createMessage("test", [] as any, {
+					taskId: "test-task",
+					tools: testTools,
+					tool_choice: "auto",
+					allowedFunctionNames: ["read_file"],
+				})
+				.next()
+
+			const config = stub.mock.calls[0][0].config
+			// allowedFunctionNames should take precedence - mode should be ANY, not AUTO
+			expect(config.toolConfig.functionCallingConfig.mode).toBe(FunctionCallingConfigMode.ANY)
+			expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toEqual(["read_file"])
+		})
+
+		it("should fall back to tool_choice when allowedFunctionNames is empty", async () => {
+			const options = {
+				apiProvider: "gemini",
+			} as ApiHandlerOptions
+			const handler = new GeminiHandler(options)
+			const stub = vi.fn().mockReturnValue((async function* () {})())
+			// @ts-ignore access private client
+			handler["client"].models.generateContentStream = stub
+
+			await handler
+				.createMessage("test", [] as any, {
+					taskId: "test-task",
+					tools: testTools,
+					tool_choice: "auto",
+					allowedFunctionNames: [],
+				})
+				.next()
+
+			const config = stub.mock.calls[0][0].config
+			// Empty allowedFunctionNames should fall back to tool_choice behavior
+			expect(config.toolConfig.functionCallingConfig.mode).toBe(FunctionCallingConfigMode.AUTO)
+			expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toBeUndefined()
+		})
+
+		it("should not set toolConfig when allowedFunctionNames is undefined and no tool_choice", async () => {
+			const options = {
+				apiProvider: "gemini",
+			} as ApiHandlerOptions
+			const handler = new GeminiHandler(options)
+			const stub = vi.fn().mockReturnValue((async function* () {})())
+			// @ts-ignore access private client
+			handler["client"].models.generateContentStream = stub
+
+			await handler
+				.createMessage("test", [] as any, {
+					taskId: "test-task",
+					tools: testTools,
+				})
+				.next()
+
+			const config = stub.mock.calls[0][0].config
+			// No toolConfig should be set when neither allowedFunctionNames nor tool_choice is provided
+			expect(config.toolConfig).toBeUndefined()
+		})
+	})
 })
diff --git a/src/api/providers/gemini.ts b/src/api/providers/gemini.ts
index 1cc9228256c..dada9db14cc 100644
--- a/src/api/providers/gemini.ts
+++ b/src/api/providers/gemini.ts
@@ -172,7 +172,19 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
 			...(tools.length > 0 ? { tools } : {}),
 		}
 
-		if (metadata?.tool_choice) {
+		// Handle allowedFunctionNames for mode-restricted tool access.
+		// When provided, all tool definitions are passed to the model (so it can reference
+		// historical tool calls in conversation), but only the specified tools can be invoked.
+		// This takes precedence over tool_choice to ensure mode restrictions are honored.
+		if (metadata?.allowedFunctionNames && metadata.allowedFunctionNames.length > 0) {
+			config.toolConfig = {
+				functionCallingConfig: {
+					// Use ANY mode to allow calling any of the allowed functions
+					mode: FunctionCallingConfigMode.ANY,
+					allowedFunctionNames: metadata.allowedFunctionNames,
+				},
+			}
+		} else if (metadata?.tool_choice) {
 			const choice = metadata.tool_choice
 			let mode: FunctionCallingConfigMode
 			let allowedFunctionNames: string[] | undefined
diff --git a/src/core/prompts/tools/__tests__/filter-tools-for-mode.spec.ts b/src/core/prompts/tools/__tests__/filter-tools-for-mode.spec.ts
index 50db6984f22..5cdfe2f1e79 100644
--- a/src/core/prompts/tools/__tests__/filter-tools-for-mode.spec.ts
+++ b/src/core/prompts/tools/__tests__/filter-tools-for-mode.spec.ts
@@ -1,7 +1,12 @@
 import { describe, it, expect, beforeEach, afterEach } from "vitest"
 import type OpenAI from "openai"
 import type { ModeConfig, ModelInfo } from "@roo-code/types"
-import { filterNativeToolsForMode, filterMcpToolsForMode, applyModelToolCustomization } from "../filter-tools-for-mode"
+import {
+	filterNativeToolsForMode,
+	filterMcpToolsForMode,
+	applyModelToolCustomization,
+	resolveToolAlias,
+} from "../filter-tools-for-mode"
 import * as toolsModule from "../../../../shared/tools"
 
 describe("filterNativeToolsForMode", () => {
@@ -859,3 +864,49 @@ describe("filterMcpToolsForMode", () => {
 		})
 	})
 })
+
+describe("resolveToolAlias", () => {
+	it("should resolve known alias to canonical name", () => {
+		// write_file is an alias for write_to_file (defined in TOOL_ALIASES)
+		expect(resolveToolAlias("write_file")).toBe("write_to_file")
+	})
+
+	it("should return canonical name unchanged", () => {
+		expect(resolveToolAlias("write_to_file")).toBe("write_to_file")
+		expect(resolveToolAlias("read_file")).toBe("read_file")
+		expect(resolveToolAlias("apply_diff")).toBe("apply_diff")
+	})
+
+	it("should return unknown tool names unchanged", () => {
+		expect(resolveToolAlias("unknown_tool")).toBe("unknown_tool")
+		expect(resolveToolAlias("custom_tool_xyz")).toBe("custom_tool_xyz")
+	})
+
+	it("should ensure allowedFunctionNames are consistent with functionDeclarations", () => {
+		// This test documents the fix for the Gemini allowedFunctionNames issue.
+		// When tools are renamed via aliasRenames, the alias names must be resolved
+		// back to canonical names for allowedFunctionNames to match functionDeclarations.
+		//
+		// Example scenario:
+		// - Model specifies includedTools: ["write_file"] (an alias)
+		// - filterNativeToolsForMode returns tool with name "write_file"
+		// - But allTools (functionDeclarations) contains "write_to_file" (canonical)
+		// - If allowedFunctionNames contains "write_file", Gemini will error
+		// - Resolving aliases ensures consistency: resolveToolAlias("write_file") -> "write_to_file"
+
+		const aliasToolName = "write_file"
+		const canonicalToolName = "write_to_file"
+
+		// Simulate extracting name from a filtered tool that was renamed to alias
+		const extractedName = aliasToolName
+
+		// Before the fix: allowedFunctionNames would contain alias name
+		// This would cause Gemini to error because "write_file" doesn't exist in functionDeclarations
+
+		// After the fix: we resolve to canonical name
+		const resolvedName = resolveToolAlias(extractedName)
+
+		// The resolved name matches what's in functionDeclarations (canonical names)
+		expect(resolvedName).toBe(canonicalToolName)
+	})
+})
diff --git a/src/core/task/Task.ts b/src/core/task/Task.ts
index fa0a8311b78..b39c2f9b368 100644
--- a/src/core/task/Task.ts
+++ b/src/core/task/Task.ts
@@ -95,7 +95,7 @@ import { sanitizeToolUseId } from "../../utils/tool-id"
 // prompts
 import { formatResponse } from "../prompts/responses"
 import { SYSTEM_PROMPT } from "../prompts/system"
-import { buildNativeToolsArray } from "./build-tools"
+import { buildNativeToolsArrayWithRestrictions } from "./build-tools"
 
 // core modules
 import { ToolRepetitionDetector } from "../tools/ToolRepetitionDetector"
@@ -4091,15 +4091,27 @@ export class Task extends EventEmitter implements TaskLike {
 		const taskProtocol = this._taskToolProtocol ?? "xml"
 		const shouldIncludeTools = taskProtocol === TOOL_PROTOCOL.NATIVE && (modelInfo.supportsNativeTools ?? false)
 
-		// Build complete tools array: native tools + dynamic MCP tools, filtered by mode restrictions
+		// Build complete tools array: native tools + dynamic MCP tools
+		// When includeAllToolsWithRestrictions is true, returns all tools but provides
+		// allowedFunctionNames for providers (like Gemini) that need to see all tool
+		// definitions in history while restricting callable tools for the current mode.
+		// Only Gemini currently supports this - other providers filter tools normally.
 		let allTools: OpenAI.Chat.ChatCompletionTool[] = []
+		let allowedFunctionNames: string[] | undefined
+
+		// Gemini requires all tool definitions to be present for history compatibility,
+		// but uses allowedFunctionNames to restrict which tools can be called.
+		// Other providers (Anthropic, OpenAI, etc.) don't support this feature yet,
+		// so they continue to receive only the filtered tools for the current mode.
+		const supportsAllowedFunctionNames = apiConfiguration?.apiProvider === "gemini"
+
 		if (shouldIncludeTools) {
 			const provider = this.providerRef.deref()
 			if (!provider) {
 				throw new Error("Provider reference lost during tool building")
 			}
 
-			allTools = await buildNativeToolsArray({
+			const toolsResult = await buildNativeToolsArrayWithRestrictions({
 				provider,
 				cwd: this.cwd,
 				mode,
@@ -4111,7 +4123,10 @@ export class Task extends EventEmitter implements TaskLike {
 				browserToolEnabled: state?.browserToolEnabled ?? true,
 				modelInfo,
 				diffEnabled: this.diffEnabled,
+				includeAllToolsWithRestrictions: supportsAllowedFunctionNames,
 			})
+			allTools = toolsResult.tools
+			allowedFunctionNames = toolsResult.allowedFunctionNames
 		}
 
 		// Parallel tool calls are disabled - feature is on hold
@@ -4129,6 +4144,9 @@ export class Task extends EventEmitter implements TaskLike {
 						tool_choice: "auto",
 						toolProtocol: taskProtocol,
 						parallelToolCalls: parallelToolCallsEnabled,
+						// When mode restricts tools, provide allowedFunctionNames so providers
+						// like Gemini can see all tools in history but only call allowed ones
+						...(allowedFunctionNames ? { allowedFunctionNames } : {}),
 					}
 				: {}),
 		}
diff --git a/src/core/task/build-tools.ts b/src/core/task/build-tools.ts
index 52a9f2eb82f..fe884314965 100644
--- a/src/core/task/build-tools.ts
+++ b/src/core/task/build-tools.ts
@@ -9,7 +9,11 @@ import type { ClineProvider } from "../webview/ClineProvider"
 import { getRooDirectoriesForCwd } from "../../services/roo-config/index.js"
 
 import { getNativeTools, getMcpServerTools } from "../prompts/tools/native-tools"
-import { filterNativeToolsForMode, filterMcpToolsForMode } from "../prompts/tools/filter-tools-for-mode"
+import {
+	filterNativeToolsForMode,
+	filterMcpToolsForMode,
+	resolveToolAlias,
+} from "../prompts/tools/filter-tools-for-mode"
 
 interface BuildToolsOptions {
 	provider: ClineProvider
@@ -23,6 +27,35 @@ interface BuildToolsOptions {
 	browserToolEnabled: boolean
 	modelInfo?: ModelInfo
 	diffEnabled: boolean
+	/**
+	 * If true, returns all tools without mode filtering, but also includes
+	 * the list of allowed tool names for use with allowedFunctionNames.
+	 * This enables providers that support function call restrictions (e.g., Gemini)
+	 * to pass all tool definitions while restricting callable tools.
+	 */
+	includeAllToolsWithRestrictions?: boolean
+}
+
+interface BuildToolsResult {
+	/**
+	 * The tools to pass to the model.
+	 * If includeAllToolsWithRestrictions is true, this includes ALL tools.
+	 * Otherwise, it includes only mode-filtered tools.
+	 */
+	tools: OpenAI.Chat.ChatCompletionTool[]
+	/**
+	 * The names of tools that are allowed to be called based on mode restrictions.
+	 * Only populated when includeAllToolsWithRestrictions is true.
+	 * Use this with allowedFunctionNames in providers that support it.
+	 */
+	allowedFunctionNames?: string[]
+}
+
+/**
+ * Extracts the function name from a tool definition.
+ */
+function getToolName(tool: OpenAI.Chat.ChatCompletionTool): string {
+	return (tool as OpenAI.Chat.ChatCompletionFunctionTool).function.name
 }
 
 /**
@@ -33,6 +66,23 @@ interface BuildToolsOptions {
  * @returns Array of filtered native and MCP tools
  */
 export async function buildNativeToolsArray(options: BuildToolsOptions): Promise<OpenAI.Chat.ChatCompletionTool[]> {
+	const result = await buildNativeToolsArrayWithRestrictions(options)
+	return result.tools
+}
+
+/**
+ * Builds the complete tools array for native protocol requests with optional mode restrictions.
+ * When includeAllToolsWithRestrictions is true, returns ALL tools but also provides
+ * the list of allowed tool names for use with allowedFunctionNames.
+ *
+ * This enables providers like Gemini to pass all tool definitions to the model
+ * (so it can reference historical tool calls) while restricting which tools
+ * can actually be invoked via allowedFunctionNames in toolConfig.
+ *
+ * @param options - Configuration options for building the tools
+ * @returns BuildToolsResult with tools array and optional allowedFunctionNames
+ */
+export async function buildNativeToolsArrayWithRestrictions(options: BuildToolsOptions): Promise<BuildToolsResult> {
 	const {
 		provider,
 		cwd,
@@ -45,6 +95,7 @@ export async function buildNativeToolsArray(options: BuildToolsOptions): Promise
 		browserToolEnabled,
 		modelInfo,
 		diffEnabled,
+		includeAllToolsWithRestrictions,
 	} = options
 
 	const mcpHub = provider.getMcpHub()
@@ -102,5 +153,29 @@ export async function buildNativeToolsArray(options: BuildToolsOptions): Promise
 		}
 	}
 
-	return [...filteredNativeTools, ...filteredMcpTools, ...nativeCustomTools]
+	// Combine filtered tools (for backward compatibility and for allowedFunctionNames)
+	const filteredTools = [...filteredNativeTools, ...filteredMcpTools, ...nativeCustomTools]
+
+	// If includeAllToolsWithRestrictions is true, return ALL tools but provide
+	// allowed names based on mode filtering
+	if (includeAllToolsWithRestrictions) {
+		// Combine ALL tools (unfiltered native + all MCP + custom)
+		const allTools = [...nativeTools, ...mcpTools, ...nativeCustomTools]
+
+		// Extract names of tools that are allowed based on mode filtering.
+		// Resolve any alias names to canonical names to ensure consistency with allTools
+		// (which uses canonical names). This prevents Gemini errors when tools are renamed
+		// to aliases in filteredTools but allTools contains the original canonical names.
+		const allowedFunctionNames = filteredTools.map((tool) => resolveToolAlias(getToolName(tool)))
+
+		return {
+			tools: allTools,
+			allowedFunctionNames,
+		}
+	}
+
+	// Default behavior: return only filtered tools
+	return {
+		tools: filteredTools,
+	}
 }