diff --git a/src/agents/_run_impl.py b/src/agents/_run_impl.py index 54fceef57f..df6d243d04 100644 --- a/src/agents/_run_impl.py +++ b/src/agents/_run_impl.py @@ -95,6 +95,8 @@ ShellResult, ShellTool, Tool, + ToolOrigin, + ToolOriginType, resolve_computer, ) from .tool_context import ToolContext @@ -157,6 +159,20 @@ class ToolRunFunction: function_tool: FunctionTool +def _get_tool_origin_info(function_tool: FunctionTool) -> ToolOrigin | None: + """Extract origin information from a FunctionTool. + + Returns: + ToolOrigin object if origin is set, otherwise None (defaults to FUNCTION type). + """ + origin = function_tool._tool_origin + if origin is None: + # Default to FUNCTION if not explicitly set + return ToolOrigin(type=ToolOriginType.FUNCTION) + + return origin + + @dataclass class ToolRunComputerAction: tool_call: ResponseComputerToolCall @@ -750,11 +766,19 @@ def process_model_response( error = f"Tool {output.name} not found in agent {agent.name}" raise ModelBehaviorError(error) - items.append(ToolCallItem(raw_item=output, agent=agent)) + function_tool = function_map[output.name] + tool_origin = _get_tool_origin_info(function_tool) + items.append( + ToolCallItem( + raw_item=output, + agent=agent, + tool_origin=tool_origin, + ) + ) functions.append( ToolRunFunction( tool_call=output, - function_tool=function_map[output.name], + function_tool=function_tool, ) ) @@ -1019,6 +1043,7 @@ async def run_single_tool( output=result, raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, result), agent=agent, + tool_origin=_get_tool_origin_info(tool_run.function_tool), ), ) for tool_run, result in zip(tool_runs, results) diff --git a/src/agents/agent.py b/src/agents/agent.py index d8c7d19e20..d93fcc9723 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -29,6 +29,8 @@ FunctionToolResult, Tool, ToolErrorFunction, + ToolOrigin, + ToolOriginType, default_tool_error_function, function_tool, ) @@ -535,6 +537,11 @@ async def dispatch_stream_events() -> None: return run_result.final_output + # Set origin tracking on the FunctionTool created by @function_tool + run_agent._tool_origin = ToolOrigin( + type=ToolOriginType.AGENT_AS_TOOL, + agent_as_tool_name=self.name, + ) return run_agent async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None: diff --git a/src/agents/items.py b/src/agents/items.py index 991a7f8772..72f057e831 100644 --- a/src/agents/items.py +++ b/src/agents/items.py @@ -48,6 +48,7 @@ from .exceptions import AgentsException, ModelBehaviorError from .logger import logger from .tool import ( + ToolOrigin, ToolOutputFileContent, ToolOutputImage, ToolOutputText, @@ -244,6 +245,9 @@ class ToolCallItem(RunItemBase[Any]): type: Literal["tool_call_item"] = "tool_call_item" + tool_origin: ToolOrigin | None = field(default=None, repr=False) + """Information about the origin/source of the tool call. Only set for FunctionTool calls.""" + ToolCallOutputTypes: TypeAlias = Union[ FunctionCallOutput, @@ -267,6 +271,9 @@ class ToolCallOutputItem(RunItemBase[Any]): type: Literal["tool_call_output_item"] = "tool_call_output_item" + tool_origin: ToolOrigin | None = field(default=None, repr=False) + """Information about the origin/source of the tool call. Only set for FunctionTool calls.""" + def to_input_item(self) -> TResponseInputItem: """Converts the tool output into an input item for the next model turn. diff --git a/src/agents/mcp/util.py b/src/agents/mcp/util.py index 6cfe5c96d5..f82a6c732a 100644 --- a/src/agents/mcp/util.py +++ b/src/agents/mcp/util.py @@ -11,7 +11,7 @@ from ..logger import logger from ..run_context import RunContextWrapper from ..strict_schema import ensure_strict_json_schema -from ..tool import FunctionTool, Tool +from ..tool import FunctionTool, Tool, ToolOrigin, ToolOriginType from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span from ..util._types import MaybeAwaitable @@ -170,13 +170,18 @@ def to_function_tool( except Exception as e: logger.info(f"Error converting MCP schema to strict mode: {e}") - return FunctionTool( + function_tool = FunctionTool( name=tool.name, description=tool.description or "", params_json_schema=schema, on_invoke_tool=invoke_func, strict_json_schema=is_strict, ) + function_tool._tool_origin = ToolOrigin( + type=ToolOriginType.MCP, + mcp_server_name=server.name, + ) + return function_tool @classmethod async def invoke_mcp_tool( diff --git a/src/agents/run.py b/src/agents/run.py index 5b5e6fdfae..ab3dd183a4 100644 --- a/src/agents/run.py +++ b/src/agents/run.py @@ -10,6 +10,7 @@ from openai.types.responses import ( ResponseCompletedEvent, + ResponseFunctionToolCall, ResponseOutputItemDoneEvent, ) from openai.types.responses.response_prompt_param import ( @@ -27,6 +28,7 @@ RunImpl, SingleStepResult, TraceCtxManager, + _get_tool_origin_info, get_model_tracing_impl, ) from .agent import Agent @@ -71,7 +73,7 @@ RunItemStreamEvent, StreamEvent, ) -from .tool import Tool, dispose_resolved_computers +from .tool import FunctionTool, Tool, dispose_resolved_computers from .tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult from .tracing import Span, SpanError, agent_span, get_current_trace, trace from .tracing.span_data import AgentSpanData @@ -1395,6 +1397,7 @@ async def _run_single_turn_streamed( model = cls._get_model(agent, run_config) model_settings = agent.model_settings.resolve(run_config.model_settings) model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings) + function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)} final_response: ModelResponse | None = None @@ -1486,9 +1489,19 @@ async def _run_single_turn_streamed( if call_id and call_id not in emitted_tool_call_ids: emitted_tool_call_ids.add(call_id) + # Try to get origin info if this is a FunctionTool call + # Use same lookup logic as _run_impl (function_map with last-wins semantics) + tool_origin = None + if isinstance(output_item, ResponseFunctionToolCall): + tool_name = getattr(output_item, "name", None) + if tool_name and tool_name in function_map: + function_tool = function_map[tool_name] + tool_origin = _get_tool_origin_info(function_tool) + tool_item = ToolCallItem( raw_item=cast(ToolCallItemTypes, output_item), agent=agent, + tool_origin=tool_origin, ) streamed_result._event_queue.put_nowait( RunItemStreamEvent(item=tool_item, name="tool_called") diff --git a/src/agents/tool.py b/src/agents/tool.py index 8c8d3e9880..93461f60b8 100644 --- a/src/agents/tool.py +++ b/src/agents/tool.py @@ -1,5 +1,6 @@ from __future__ import annotations +import enum import inspect import json import weakref @@ -179,6 +180,42 @@ class ComputerProvider(Generic[ComputerT]): ] +class ToolOriginType(str, enum.Enum): + """The type of tool origin.""" + + FUNCTION = "function" + """Regular Python function tool created via @function_tool decorator.""" + + MCP = "mcp" + """MCP server tool converted via MCPUtil.to_function_tool().""" + + AGENT_AS_TOOL = "agent_as_tool" + """Agent converted to tool via agent.as_tool().""" + + +@dataclass +class ToolOrigin: + """Information about the origin/source of a function tool.""" + + type: ToolOriginType + """The type of tool origin.""" + + mcp_server_name: str | None = None + """The name of the MCP server. Only set when type is MCP.""" + + agent_as_tool_name: str | None = None + """The name of the agent. Only set when type is AGENT_AS_TOOL.""" + + def __repr__(self) -> str: + """Custom repr that only includes relevant fields.""" + parts = [f"type={self.type.value!r}"] + if self.mcp_server_name is not None: + parts.append(f"mcp_server_name={self.mcp_server_name!r}") + if self.agent_as_tool_name is not None: + parts.append(f"agent_as_tool_name={self.agent_as_tool_name!r}") + return f"ToolOrigin({', '.join(parts)})" + + @dataclass class FunctionToolResult: tool: FunctionTool @@ -235,6 +272,9 @@ class FunctionTool: tool_output_guardrails: list[ToolOutputGuardrail[Any]] | None = None """Optional list of output guardrails to run after invoking this tool.""" + _tool_origin: ToolOrigin | None = field(default=None, init=False, repr=False) + """Private field tracking tool origin. Set by SDK when creating tools.""" + def __post_init__(self): if self.strict_json_schema: self.params_json_schema = ensure_strict_json_schema(self.params_json_schema) diff --git a/tests/test_tool_origin.py b/tests/test_tool_origin.py new file mode 100644 index 0000000000..b3c528e78d --- /dev/null +++ b/tests/test_tool_origin.py @@ -0,0 +1,319 @@ +"""Tests for tool origin tracking feature.""" + +from __future__ import annotations + +import sys +from typing import cast + +import pytest + +from agents import Agent, FunctionTool, RunContextWrapper, Runner, function_tool +from agents.items import ToolCallItem, ToolCallItemTypes, ToolCallOutputItem +from agents.tool import ToolOrigin, ToolOriginType + +from .fake_model import FakeModel +from .test_responses import get_function_tool_call, get_text_message + +if sys.version_info >= (3, 10): + from .mcp.helpers import FakeMCPServer + + +@pytest.mark.asyncio +async def test_function_tool_origin(): + """Test that regular function tools have FUNCTION origin.""" + model = FakeModel() + + @function_tool + def test_tool(x: int) -> str: + """Test tool.""" + return f"result: {x}" + + agent = Agent(name="test", model=model, tools=[test_tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("test_tool", '{"x": 42}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.FUNCTION + assert tool_call_items[0].tool_origin.mcp_server_name is None + assert tool_call_items[0].tool_origin.agent_as_tool_name is None + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.FUNCTION + assert tool_output_items[0].tool_origin.mcp_server_name is None + assert tool_output_items[0].tool_origin.agent_as_tool_name is None + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_mcp_tool_origin(): + """Test that MCP tools have MCP origin with server name.""" + model = FakeModel() + server = FakeMCPServer(server_name="test_mcp_server") + server.add_tool("mcp_tool", {}) + + agent = Agent(name="test", model=model, mcp_servers=[server]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("mcp_tool", "")], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_call_items[0].tool_origin.mcp_server_name == "test_mcp_server" + assert tool_call_items[0].tool_origin.agent_as_tool_name is None + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_output_items[0].tool_origin.mcp_server_name == "test_mcp_server" + assert tool_output_items[0].tool_origin.agent_as_tool_name is None + + +@pytest.mark.asyncio +async def test_agent_as_tool_origin(): + """Test that agent-as-tool has AGENT_AS_TOOL origin with agent name.""" + model = FakeModel() + nested_model = FakeModel() + + nested_agent = Agent( + name="nested_agent", + model=nested_model, + instructions="You are a nested agent.", + ) + + nested_model.add_multiple_turn_outputs([[get_text_message("nested response")]]) + + tool = nested_agent.as_tool( + tool_name="nested_tool", + tool_description="A nested agent tool", + ) + + orchestrator = Agent(name="orchestrator", model=model, tools=[tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("nested_tool", '{"input": "test"}')], + [get_text_message("done")], + ] + ) + + result = await Runner.run(orchestrator, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_call_items[0].tool_origin.mcp_server_name is None + assert tool_call_items[0].tool_origin.agent_as_tool_name == "nested_agent" + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert tool_output_items[0].tool_origin.mcp_server_name is None + assert tool_output_items[0].tool_origin.agent_as_tool_name == "nested_agent" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_multiple_tool_origins(): + """Test that multiple tools from different origins work together.""" + model = FakeModel() + nested_model = FakeModel() + + @function_tool + def func_tool(x: int) -> str: + """Function tool.""" + return f"function: {x}" + + mcp_server = FakeMCPServer(server_name="mcp_server") + mcp_server.add_tool("mcp_tool", {}) + + nested_agent = Agent(name="nested", model=nested_model, instructions="Nested agent") + nested_model.add_multiple_turn_outputs([[get_text_message("nested response")]]) + agent_tool = nested_agent.as_tool(tool_name="agent_tool", tool_description="Agent tool") + + agent = Agent( + name="test", + model=model, + tools=[func_tool, agent_tool], + mcp_servers=[mcp_server], + ) + + model.add_multiple_turn_outputs( + [ + [ + get_function_tool_call("func_tool", '{"x": 1}'), + get_function_tool_call("mcp_tool", ""), + get_function_tool_call("agent_tool", '{"input": "test"}'), + ], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + tool_output_items = [item for item in result.new_items if isinstance(item, ToolCallOutputItem)] + + assert len(tool_call_items) == 3 + assert len(tool_output_items) == 3 + + # Find items by tool name + function_item = next( + item for item in tool_call_items if getattr(item.raw_item, "name", None) == "func_tool" + ) + mcp_item = next( + item for item in tool_call_items if getattr(item.raw_item, "name", None) == "mcp_tool" + ) + agent_item = next( + item for item in tool_call_items if getattr(item.raw_item, "name", None) == "agent_tool" + ) + + assert function_item.tool_origin is not None + assert function_item.tool_origin.type == ToolOriginType.FUNCTION + assert mcp_item.tool_origin is not None + assert mcp_item.tool_origin.type == ToolOriginType.MCP + assert mcp_item.tool_origin.mcp_server_name == "mcp_server" + assert agent_item.tool_origin is not None + assert agent_item.tool_origin.type == ToolOriginType.AGENT_AS_TOOL + assert agent_item.tool_origin.agent_as_tool_name == "nested" + + +@pytest.mark.asyncio +@pytest.mark.skipif(sys.version_info < (3, 10), reason="MCP tests require Python 3.10+") +async def test_tool_origin_streaming(): + """Test that tool origin is populated correctly in streaming scenarios.""" + model = FakeModel() + server = FakeMCPServer(server_name="streaming_server") + server.add_tool("streaming_tool", {}) + + agent = Agent(name="test", model=model, mcp_servers=[server]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("streaming_tool", "")], + [get_text_message("done")], + ] + ) + + result = Runner.run_streamed(agent, input="test") + tool_call_items = [] + tool_output_items = [] + + async for event in result.stream_events(): + if event.type == "run_item_stream_event": + if isinstance(event.item, ToolCallItem): + tool_call_items.append(event.item) + elif isinstance(event.item, ToolCallOutputItem): + tool_output_items.append(event.item) + + assert len(tool_call_items) == 1 + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_call_items[0].tool_origin.mcp_server_name == "streaming_server" + + assert len(tool_output_items) == 1 + assert tool_output_items[0].tool_origin is not None + assert tool_output_items[0].tool_origin.type == ToolOriginType.MCP + assert tool_output_items[0].tool_origin.mcp_server_name == "streaming_server" + + +@pytest.mark.asyncio +async def test_tool_origin_repr(): + """Test that ToolOrigin repr only shows relevant fields.""" + # FUNCTION origin + function_origin = ToolOrigin(type=ToolOriginType.FUNCTION) + assert "mcp_server_name" not in repr(function_origin) + assert "agent_as_tool_name" not in repr(function_origin) + + # MCP origin + mcp_origin = ToolOrigin(type=ToolOriginType.MCP, mcp_server_name="test_server") + assert "mcp_server_name='test_server'" in repr(mcp_origin) + assert "agent_as_tool_name" not in repr(mcp_origin) + + # AGENT_AS_TOOL origin + agent_origin = ToolOrigin(type=ToolOriginType.AGENT_AS_TOOL, agent_as_tool_name="test_agent") + assert "agent_as_tool_name='test_agent'" in repr(agent_origin) + assert "mcp_server_name" not in repr(agent_origin) + + +@pytest.mark.asyncio +async def test_tool_origin_defaults_to_function(): + """Test that tools without explicit origin default to FUNCTION.""" + model = FakeModel() + + # Create a FunctionTool directly without using @function_tool decorator + async def test_func(ctx: RunContextWrapper, args: str) -> str: + return "result" + + tool = FunctionTool( + name="direct_tool", + description="Direct tool", + params_json_schema={"type": "object", "properties": {}}, + on_invoke_tool=test_func, + ) + + agent = Agent(name="test", model=model, tools=[tool]) + + model.add_multiple_turn_outputs( + [ + [get_function_tool_call("direct_tool", "")], + [get_text_message("done")], + ] + ) + + result = await Runner.run(agent, input="test") + tool_call_items = [item for item in result.new_items if isinstance(item, ToolCallItem)] + + assert len(tool_call_items) == 1 + # Even though _tool_origin is None, _get_tool_origin_info defaults to FUNCTION + assert tool_call_items[0].tool_origin is not None + assert tool_call_items[0].tool_origin.type == ToolOriginType.FUNCTION + + +@pytest.mark.asyncio +async def test_non_function_tool_items_have_no_origin(): + """Test that non-FunctionTool items (computer, shell, etc.) don't have tool_origin.""" + model = FakeModel() + + @function_tool + def func_tool() -> str: + """Function tool.""" + return "result" + + agent = Agent(name="test", model=model, tools=[func_tool]) + + # Create a ToolCallItem for a non-function tool (simulating computer/shell tool) + computer_call = { + "type": "computer_use_preview", + "call_id": "call_123", + "actions": [], + } + + # This simulates what happens for non-FunctionTool items + # They should not have tool_origin set + item = ToolCallItem( + raw_item=cast(ToolCallItemTypes, computer_call), + agent=agent, + ) + + assert item.tool_origin is None