Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 27 additions & 2 deletions src/agents/_run_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@
ShellResult,
ShellTool,
Tool,
ToolOrigin,
ToolOriginType,
resolve_computer,
)
from .tool_context import ToolContext
Expand Down Expand Up @@ -157,6 +159,20 @@ class ToolRunFunction:
function_tool: FunctionTool


def _get_tool_origin_info(function_tool: FunctionTool) -> ToolOrigin | None:
"""Extract origin information from a FunctionTool.

Returns:
ToolOrigin object if origin is set, otherwise None (defaults to FUNCTION type).
"""
origin = function_tool._tool_origin
if origin is None:
# Default to FUNCTION if not explicitly set
return ToolOrigin(type=ToolOriginType.FUNCTION)

return origin


@dataclass
class ToolRunComputerAction:
tool_call: ResponseComputerToolCall
Expand Down Expand Up @@ -750,11 +766,19 @@ def process_model_response(
error = f"Tool {output.name} not found in agent {agent.name}"
raise ModelBehaviorError(error)

items.append(ToolCallItem(raw_item=output, agent=agent))
function_tool = function_map[output.name]
tool_origin = _get_tool_origin_info(function_tool)
items.append(
ToolCallItem(
raw_item=output,
agent=agent,
tool_origin=tool_origin,
)
)
functions.append(
ToolRunFunction(
tool_call=output,
function_tool=function_map[output.name],
function_tool=function_tool,
)
)

Expand Down Expand Up @@ -1019,6 +1043,7 @@ async def run_single_tool(
output=result,
raw_item=ItemHelpers.tool_call_output_item(tool_run.tool_call, result),
agent=agent,
tool_origin=_get_tool_origin_info(tool_run.function_tool),
),
)
for tool_run, result in zip(tool_runs, results)
Expand Down
7 changes: 7 additions & 0 deletions src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
FunctionToolResult,
Tool,
ToolErrorFunction,
ToolOrigin,
ToolOriginType,
default_tool_error_function,
function_tool,
)
Expand Down Expand Up @@ -535,6 +537,11 @@ async def dispatch_stream_events() -> None:

return run_result.final_output

# Set origin tracking on the FunctionTool created by @function_tool
run_agent._tool_origin = ToolOrigin(
type=ToolOriginType.AGENT_AS_TOOL,
agent_as_tool_name=self.name,
)
return run_agent

async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> str | None:
Expand Down
7 changes: 7 additions & 0 deletions src/agents/items.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
from .exceptions import AgentsException, ModelBehaviorError
from .logger import logger
from .tool import (
ToolOrigin,
ToolOutputFileContent,
ToolOutputImage,
ToolOutputText,
Expand Down Expand Up @@ -244,6 +245,9 @@ class ToolCallItem(RunItemBase[Any]):

type: Literal["tool_call_item"] = "tool_call_item"

tool_origin: ToolOrigin | None = field(default=None, repr=False)
"""Information about the origin/source of the tool call. Only set for FunctionTool calls."""


ToolCallOutputTypes: TypeAlias = Union[
FunctionCallOutput,
Expand All @@ -267,6 +271,9 @@ class ToolCallOutputItem(RunItemBase[Any]):

type: Literal["tool_call_output_item"] = "tool_call_output_item"

tool_origin: ToolOrigin | None = field(default=None, repr=False)
"""Information about the origin/source of the tool call. Only set for FunctionTool calls."""

def to_input_item(self) -> TResponseInputItem:
"""Converts the tool output into an input item for the next model turn.

Expand Down
9 changes: 7 additions & 2 deletions src/agents/mcp/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from ..logger import logger
from ..run_context import RunContextWrapper
from ..strict_schema import ensure_strict_json_schema
from ..tool import FunctionTool, Tool
from ..tool import FunctionTool, Tool, ToolOrigin, ToolOriginType
from ..tracing import FunctionSpanData, get_current_span, mcp_tools_span
from ..util._types import MaybeAwaitable

Expand Down Expand Up @@ -170,13 +170,18 @@ def to_function_tool(
except Exception as e:
logger.info(f"Error converting MCP schema to strict mode: {e}")

return FunctionTool(
function_tool = FunctionTool(
name=tool.name,
description=tool.description or "",
params_json_schema=schema,
on_invoke_tool=invoke_func,
strict_json_schema=is_strict,
)
function_tool._tool_origin = ToolOrigin(
type=ToolOriginType.MCP,
mcp_server_name=server.name,
)
return function_tool

@classmethod
async def invoke_mcp_tool(
Expand Down
15 changes: 14 additions & 1 deletion src/agents/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from openai.types.responses import (
ResponseCompletedEvent,
ResponseFunctionToolCall,
ResponseOutputItemDoneEvent,
)
from openai.types.responses.response_prompt_param import (
Expand All @@ -27,6 +28,7 @@
RunImpl,
SingleStepResult,
TraceCtxManager,
_get_tool_origin_info,
get_model_tracing_impl,
)
from .agent import Agent
Expand Down Expand Up @@ -71,7 +73,7 @@
RunItemStreamEvent,
StreamEvent,
)
from .tool import Tool, dispose_resolved_computers
from .tool import FunctionTool, Tool, dispose_resolved_computers
from .tool_guardrails import ToolInputGuardrailResult, ToolOutputGuardrailResult
from .tracing import Span, SpanError, agent_span, get_current_trace, trace
from .tracing.span_data import AgentSpanData
Expand Down Expand Up @@ -1395,6 +1397,7 @@ async def _run_single_turn_streamed(
model = cls._get_model(agent, run_config)
model_settings = agent.model_settings.resolve(run_config.model_settings)
model_settings = RunImpl.maybe_reset_tool_choice(agent, tool_use_tracker, model_settings)
function_map = {tool.name: tool for tool in all_tools if isinstance(tool, FunctionTool)}

final_response: ModelResponse | None = None

Expand Down Expand Up @@ -1486,9 +1489,19 @@ async def _run_single_turn_streamed(
if call_id and call_id not in emitted_tool_call_ids:
emitted_tool_call_ids.add(call_id)

# Try to get origin info if this is a FunctionTool call
# Use same lookup logic as _run_impl (function_map with last-wins semantics)
tool_origin = None
if isinstance(output_item, ResponseFunctionToolCall):
tool_name = getattr(output_item, "name", None)
if tool_name and tool_name in function_map:
function_tool = function_map[tool_name]
tool_origin = _get_tool_origin_info(function_tool)

tool_item = ToolCallItem(
raw_item=cast(ToolCallItemTypes, output_item),
agent=agent,
tool_origin=tool_origin,
)
streamed_result._event_queue.put_nowait(
RunItemStreamEvent(item=tool_item, name="tool_called")
Expand Down
40 changes: 40 additions & 0 deletions src/agents/tool.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import enum
import inspect
import json
import weakref
Expand Down Expand Up @@ -179,6 +180,42 @@ class ComputerProvider(Generic[ComputerT]):
]


class ToolOriginType(str, enum.Enum):
"""The type of tool origin."""

FUNCTION = "function"
"""Regular Python function tool created via @function_tool decorator."""

MCP = "mcp"
"""MCP server tool converted via MCPUtil.to_function_tool()."""

AGENT_AS_TOOL = "agent_as_tool"
"""Agent converted to tool via agent.as_tool()."""


@dataclass
class ToolOrigin:
"""Information about the origin/source of a function tool."""

type: ToolOriginType
"""The type of tool origin."""

mcp_server_name: str | None = None
"""The name of the MCP server. Only set when type is MCP."""

agent_as_tool_name: str | None = None
"""The name of the agent. Only set when type is AGENT_AS_TOOL."""

def __repr__(self) -> str:
"""Custom repr that only includes relevant fields."""
parts = [f"type={self.type.value!r}"]
if self.mcp_server_name is not None:
parts.append(f"mcp_server_name={self.mcp_server_name!r}")
if self.agent_as_tool_name is not None:
parts.append(f"agent_as_tool_name={self.agent_as_tool_name!r}")
return f"ToolOrigin({', '.join(parts)})"


@dataclass
class FunctionToolResult:
tool: FunctionTool
Expand Down Expand Up @@ -235,6 +272,9 @@ class FunctionTool:
tool_output_guardrails: list[ToolOutputGuardrail[Any]] | None = None
"""Optional list of output guardrails to run after invoking this tool."""

_tool_origin: ToolOrigin | None = field(default=None, init=False, repr=False)
"""Private field tracking tool origin. Set by SDK when creating tools."""

def __post_init__(self):
if self.strict_json_schema:
self.params_json_schema = ensure_strict_json_schema(self.params_json_schema)
Expand Down
Loading