Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/api/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,15 @@ export interface ApiHandlerCreateMessageMetadata {
* Only applies when toolProtocol is "native".
*/
parallelToolCalls?: boolean
/**
* Optional array of tool names that the model is allowed to call.
* When provided, all tool definitions are passed to the model (so it can reference
* historical tool calls), but only the specified tools can actually be invoked.
* This is used when switching modes to prevent model errors from missing tool
* definitions while still restricting callable tools to the current mode's permissions.
* Only applies to providers that support function calling restrictions (e.g., Gemini).
*/
allowedFunctionNames?: string[]
}

export interface ApiHandler {
Expand Down
149 changes: 149 additions & 0 deletions src/api/providers/__tests__/gemini-handler.spec.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { t } from "i18next"
import { FunctionCallingConfigMode } from "@google/genai"

import { GeminiHandler } from "../gemini"
import type { ApiHandlerOptions } from "../../../shared/api"
Expand Down Expand Up @@ -141,4 +142,152 @@ describe("GeminiHandler backend support", () => {
}).rejects.toThrow(t("common:errors.gemini.generate_stream", { error: "API rate limit exceeded" }))
})
})

describe("allowedFunctionNames support", () => {
const testTools = [
{
type: "function" as const,
function: {
name: "read_file",
description: "Read a file",
parameters: { type: "object", properties: {} },
},
},
{
type: "function" as const,
function: {
name: "write_to_file",
description: "Write to a file",
parameters: { type: "object", properties: {} },
},
},
{
type: "function" as const,
function: {
name: "execute_command",
description: "Execute a command",
parameters: { type: "object", properties: {} },
},
},
]

it("should pass allowedFunctionNames to toolConfig when provided", async () => {
const options = {
apiProvider: "gemini",
} as ApiHandlerOptions
const handler = new GeminiHandler(options)
const stub = vi.fn().mockReturnValue((async function* () {})())
// @ts-ignore access private client
handler["client"].models.generateContentStream = stub

await handler
.createMessage("test", [] as any, {
taskId: "test-task",
tools: testTools,
allowedFunctionNames: ["read_file", "write_to_file"],
})
.next()

const config = stub.mock.calls[0][0].config
expect(config.toolConfig).toEqual({
functionCallingConfig: {
mode: FunctionCallingConfigMode.ANY,
allowedFunctionNames: ["read_file", "write_to_file"],
},
})
})

it("should include all tools but restrict callable functions via allowedFunctionNames", async () => {
const options = {
apiProvider: "gemini",
} as ApiHandlerOptions
const handler = new GeminiHandler(options)
const stub = vi.fn().mockReturnValue((async function* () {})())
// @ts-ignore access private client
handler["client"].models.generateContentStream = stub

await handler
.createMessage("test", [] as any, {
taskId: "test-task",
tools: testTools,
allowedFunctionNames: ["read_file"],
})
.next()

const config = stub.mock.calls[0][0].config
// All tools should be passed to the model
expect(config.tools[0].functionDeclarations).toHaveLength(3)
// But only read_file should be allowed to be called
expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toEqual(["read_file"])
})

it("should take precedence over tool_choice when allowedFunctionNames is provided", async () => {
const options = {
apiProvider: "gemini",
} as ApiHandlerOptions
const handler = new GeminiHandler(options)
const stub = vi.fn().mockReturnValue((async function* () {})())
// @ts-ignore access private client
handler["client"].models.generateContentStream = stub

await handler
.createMessage("test", [] as any, {
taskId: "test-task",
tools: testTools,
tool_choice: "auto",
allowedFunctionNames: ["read_file"],
})
.next()

const config = stub.mock.calls[0][0].config
// allowedFunctionNames should take precedence - mode should be ANY, not AUTO
expect(config.toolConfig.functionCallingConfig.mode).toBe(FunctionCallingConfigMode.ANY)
expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toEqual(["read_file"])
})

it("should fall back to tool_choice when allowedFunctionNames is empty", async () => {
const options = {
apiProvider: "gemini",
} as ApiHandlerOptions
const handler = new GeminiHandler(options)
const stub = vi.fn().mockReturnValue((async function* () {})())
// @ts-ignore access private client
handler["client"].models.generateContentStream = stub

await handler
.createMessage("test", [] as any, {
taskId: "test-task",
tools: testTools,
tool_choice: "auto",
allowedFunctionNames: [],
})
.next()

const config = stub.mock.calls[0][0].config
// Empty allowedFunctionNames should fall back to tool_choice behavior
expect(config.toolConfig.functionCallingConfig.mode).toBe(FunctionCallingConfigMode.AUTO)
expect(config.toolConfig.functionCallingConfig.allowedFunctionNames).toBeUndefined()
})

it("should not set toolConfig when allowedFunctionNames is undefined and no tool_choice", async () => {
const options = {
apiProvider: "gemini",
} as ApiHandlerOptions
const handler = new GeminiHandler(options)
const stub = vi.fn().mockReturnValue((async function* () {})())
// @ts-ignore access private client
handler["client"].models.generateContentStream = stub

await handler
.createMessage("test", [] as any, {
taskId: "test-task",
tools: testTools,
})
.next()

const config = stub.mock.calls[0][0].config
// No toolConfig should be set when neither allowedFunctionNames nor tool_choice is provided
expect(config.toolConfig).toBeUndefined()
})
})
})
14 changes: 13 additions & 1 deletion src/api/providers/gemini.ts
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,19 @@ export class GeminiHandler extends BaseProvider implements SingleCompletionHandl
...(tools.length > 0 ? { tools } : {}),
}

if (metadata?.tool_choice) {
// Handle allowedFunctionNames for mode-restricted tool access.
// When provided, all tool definitions are passed to the model (so it can reference
// historical tool calls in conversation), but only the specified tools can be invoked.
// This takes precedence over tool_choice to ensure mode restrictions are honored.
if (metadata?.allowedFunctionNames && metadata.allowedFunctionNames.length > 0) {
config.toolConfig = {
functionCallingConfig: {
// Use ANY mode to allow calling any of the allowed functions
mode: FunctionCallingConfigMode.ANY,
allowedFunctionNames: metadata.allowedFunctionNames,
},
}
} else if (metadata?.tool_choice) {
const choice = metadata.tool_choice
let mode: FunctionCallingConfigMode
let allowedFunctionNames: string[] | undefined
Expand Down
53 changes: 52 additions & 1 deletion src/core/prompts/tools/__tests__/filter-tools-for-mode.spec.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
import { describe, it, expect, beforeEach, afterEach } from "vitest"
import type OpenAI from "openai"
import type { ModeConfig, ModelInfo } from "@roo-code/types"
import { filterNativeToolsForMode, filterMcpToolsForMode, applyModelToolCustomization } from "../filter-tools-for-mode"
import {
filterNativeToolsForMode,
filterMcpToolsForMode,
applyModelToolCustomization,
resolveToolAlias,
} from "../filter-tools-for-mode"
import * as toolsModule from "../../../../shared/tools"

describe("filterNativeToolsForMode", () => {
Expand Down Expand Up @@ -859,3 +864,49 @@ describe("filterMcpToolsForMode", () => {
})
})
})

describe("resolveToolAlias", () => {
it("should resolve known alias to canonical name", () => {
// write_file is an alias for write_to_file (defined in TOOL_ALIASES)
expect(resolveToolAlias("write_file")).toBe("write_to_file")
})

it("should return canonical name unchanged", () => {
expect(resolveToolAlias("write_to_file")).toBe("write_to_file")
expect(resolveToolAlias("read_file")).toBe("read_file")
expect(resolveToolAlias("apply_diff")).toBe("apply_diff")
})

it("should return unknown tool names unchanged", () => {
expect(resolveToolAlias("unknown_tool")).toBe("unknown_tool")
expect(resolveToolAlias("custom_tool_xyz")).toBe("custom_tool_xyz")
})

it("should ensure allowedFunctionNames are consistent with functionDeclarations", () => {
// This test documents the fix for the Gemini allowedFunctionNames issue.
// When tools are renamed via aliasRenames, the alias names must be resolved
// back to canonical names for allowedFunctionNames to match functionDeclarations.
//
// Example scenario:
// - Model specifies includedTools: ["write_file"] (an alias)
// - filterNativeToolsForMode returns tool with name "write_file"
// - But allTools (functionDeclarations) contains "write_to_file" (canonical)
// - If allowedFunctionNames contains "write_file", Gemini will error
// - Resolving aliases ensures consistency: resolveToolAlias("write_file") -> "write_to_file"

const aliasToolName = "write_file"
const canonicalToolName = "write_to_file"

// Simulate extracting name from a filtered tool that was renamed to alias
const extractedName = aliasToolName

// Before the fix: allowedFunctionNames would contain alias name
// This would cause Gemini to error because "write_file" doesn't exist in functionDeclarations

// After the fix: we resolve to canonical name
const resolvedName = resolveToolAlias(extractedName)

// The resolved name matches what's in functionDeclarations (canonical names)
expect(resolvedName).toBe(canonicalToolName)
})
})
24 changes: 21 additions & 3 deletions src/core/task/Task.ts
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ import { sanitizeToolUseId } from "../../utils/tool-id"
// prompts
import { formatResponse } from "../prompts/responses"
import { SYSTEM_PROMPT } from "../prompts/system"
import { buildNativeToolsArray } from "./build-tools"
import { buildNativeToolsArrayWithRestrictions } from "./build-tools"

// core modules
import { ToolRepetitionDetector } from "../tools/ToolRepetitionDetector"
Expand Down Expand Up @@ -4091,15 +4091,27 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
const taskProtocol = this._taskToolProtocol ?? "xml"
const shouldIncludeTools = taskProtocol === TOOL_PROTOCOL.NATIVE && (modelInfo.supportsNativeTools ?? false)

// Build complete tools array: native tools + dynamic MCP tools, filtered by mode restrictions
// Build complete tools array: native tools + dynamic MCP tools
// When includeAllToolsWithRestrictions is true, returns all tools but provides
// allowedFunctionNames for providers (like Gemini) that need to see all tool
// definitions in history while restricting callable tools for the current mode.
// Only Gemini currently supports this - other providers filter tools normally.
let allTools: OpenAI.Chat.ChatCompletionTool[] = []
let allowedFunctionNames: string[] | undefined

// Gemini requires all tool definitions to be present for history compatibility,
// but uses allowedFunctionNames to restrict which tools can be called.
// Other providers (Anthropic, OpenAI, etc.) don't support this feature yet,
// so they continue to receive only the filtered tools for the current mode.
const supportsAllowedFunctionNames = apiConfiguration?.apiProvider === "gemini"

if (shouldIncludeTools) {
const provider = this.providerRef.deref()
if (!provider) {
throw new Error("Provider reference lost during tool building")
}

allTools = await buildNativeToolsArray({
const toolsResult = await buildNativeToolsArrayWithRestrictions({
provider,
cwd: this.cwd,
mode,
Expand All @@ -4111,7 +4123,10 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
browserToolEnabled: state?.browserToolEnabled ?? true,
modelInfo,
diffEnabled: this.diffEnabled,
includeAllToolsWithRestrictions: supportsAllowedFunctionNames,
})
allTools = toolsResult.tools
allowedFunctionNames = toolsResult.allowedFunctionNames
}

// Parallel tool calls are disabled - feature is on hold
Expand All @@ -4129,6 +4144,9 @@ export class Task extends EventEmitter<TaskEvents> implements TaskLike {
tool_choice: "auto",
toolProtocol: taskProtocol,
parallelToolCalls: parallelToolCallsEnabled,
// When mode restricts tools, provide allowedFunctionNames so providers
// like Gemini can see all tools in history but only call allowed ones
...(allowedFunctionNames ? { allowedFunctionNames } : {}),
}
: {}),
}
Expand Down
Loading
Loading