diff --git a/src/agents/agent.py b/src/agents/agent.py index 113de8847..bf9760e79 100644 --- a/src/agents/agent.py +++ b/src/agents/agent.py @@ -485,7 +485,7 @@ def as_tool( parameters: type[Any] | None = None, input_builder: StructuredToolInputBuilder | None = None, include_input_schema: bool = False, - ) -> Tool: + ) -> FunctionTool: """Transform this agent into a tool, callable by other agents. This is different from handoffs in two ways: diff --git a/tests/test_agent_as_tool.py b/tests/test_agent_as_tool.py index 693fc9a74..d2cdafd15 100644 --- a/tests/test_agent_as_tool.py +++ b/tests/test_agent_as_tool.py @@ -421,12 +421,9 @@ async def fake_run( monkeypatch.setattr(Runner, "run", classmethod(fake_run)) - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="inherits_config_tool", - tool_description="inherit config", - ), + tool = agent.as_tool( + tool_name="inherits_config_tool", + tool_description="inherit config", ) tool_context = ToolContext( context=None, @@ -475,13 +472,10 @@ async def fake_run( monkeypatch.setattr(Runner, "run", classmethod(fake_run)) - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="override_config_tool", - tool_description="override config", - run_config=explicit_run_config, - ), + tool = agent.as_tool( + tool_name="override_config_tool", + tool_description="override config", + run_config=explicit_run_config, ) tool_context = ToolContext( context=None, @@ -529,12 +523,9 @@ async def fake_run( monkeypatch.setattr(Runner, "run", classmethod(fake_run)) - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="trace_config_tool", - tool_description="inherits trace config", - ), + tool = agent.as_tool( + tool_name="trace_config_tool", + tool_description="inherits trace config", ) tool_context = ToolContext( context=None, @@ -561,13 +552,10 @@ class TranslationInput(BaseModel): target: str agent = Agent(name="translator") - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="translate", - tool_description="Translate text", - parameters=TranslationInput, - ), + tool = agent.as_tool( + tool_name="translate", + tool_description="Translate text", + parameters=TranslationInput, ) captured: dict[str, Any] = {} @@ -626,12 +614,9 @@ async def test_agent_as_tool_clears_stale_tool_input_for_plain_tools( """Non-structured agent tools should not inherit stale tool input.""" agent = Agent(name="plain_agent") - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="plain_tool", - tool_description="Plain tool", - ), + tool = agent.as_tool( + tool_name="plain_tool", + tool_description="Plain tool", ) run_context = RunContextWrapper({"locale": "en-US"}) @@ -685,13 +670,10 @@ class TranslationInput(BaseModel): target: str = Field(description="Target language") agent = Agent(name="summary_agent") - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="summarize_schema", - tool_description="Summary tool", - parameters=TranslationInput, - ), + tool = agent.as_tool( + tool_name="summarize_schema", + tool_description="Summary tool", + parameters=TranslationInput, ) captured: dict[str, Any] = {} @@ -756,14 +738,11 @@ async def builder(options: StructuredToolInputBuilderOptions): builder_calls.append(options) return custom_items - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="builder_tool", - tool_description="Builder tool", - parameters=TranslationInput, - input_builder=builder, - ), + tool = agent.as_tool( + tool_name="builder_tool", + tool_description="Builder tool", + parameters=TranslationInput, + input_builder=builder, ) class DummyResult: @@ -813,13 +792,10 @@ async def test_agent_as_tool_rejects_invalid_builder_output() -> None: def builder(_options): return 123 - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="invalid_builder_tool", - tool_description="Invalid builder tool", - input_builder=builder, - ), + tool = agent.as_tool( + tool_name="invalid_builder_tool", + tool_description="Invalid builder tool", + input_builder=builder, ) tool_context = ToolContext( @@ -844,14 +820,11 @@ class TranslationInput(BaseModel): target: str = Field(description="Target language") agent = Agent(name="schema_agent") - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="schema_tool", - tool_description="Schema tool", - parameters=TranslationInput, - include_input_schema=True, - ), + tool = agent.as_tool( + tool_name="schema_tool", + tool_description="Schema tool", + parameters=TranslationInput, + include_input_schema=True, ) captured: dict[str, Any] = {} @@ -903,13 +876,10 @@ async def test_agent_as_tool_ignores_input_schema_without_parameters( """include_input_schema should be ignored when no parameters are provided.""" agent = Agent(name="default_schema_agent") - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="default_schema_tool", - tool_description="Default schema tool", - include_input_schema=True, - ), + tool = agent.as_tool( + tool_name="default_schema_tool", + tool_description="Default schema tool", + include_input_schema=True, ) captured: dict[str, Any] = {} @@ -1017,14 +987,11 @@ async def extractor(result: Any) -> str: assert result is resumed_result return "from_resume" - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="outer_tool", - tool_description="Outer agent tool", - custom_output_extractor=extractor, - is_enabled=True, - ), + tool = agent.as_tool( + tool_name="outer_tool", + tool_description="Outer agent tool", + custom_output_extractor=extractor, + is_enabled=True, ) output = await tool.on_invoke_tool(tool_context, tool_call.arguments) @@ -1102,13 +1069,10 @@ async def on_stream(payload: AgentToolStreamEvent) -> None: type="function_call", ) - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="stream_tool", - tool_description="Streams events", - on_stream=on_stream, - ), + tool = agent.as_tool( + tool_name="stream_tool", + tool_description="Streams events", + on_stream=on_stream, ) tool_context = ToolContext( @@ -1179,13 +1143,10 @@ def fake_run_streamed( async def on_stream(payload: AgentToolStreamEvent) -> None: seen_agents.append(payload["agent"]) - tool = cast( - FunctionTool, - first_agent.as_tool( - tool_name="delegate_tool", - tool_description="Streams handoff events", - on_stream=on_stream, - ), + tool = first_agent.as_tool( + tool_name="delegate_tool", + tool_description="Streams handoff events", + on_stream=on_stream, ) tool_call = ResponseFunctionToolCall( @@ -1269,14 +1230,11 @@ async def on_stream(payload: AgentToolStreamEvent) -> None: type="function_call", ) - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="stream_tool", - tool_description="Streams events", - custom_output_extractor=extractor, - on_stream=on_stream, - ), + tool = agent.as_tool( + tool_name="stream_tool", + tool_description="Streams events", + custom_output_extractor=extractor, + on_stream=on_stream, ) tool_context = ToolContext( @@ -1329,13 +1287,10 @@ def sync_handler(event: AgentToolStreamEvent) -> None: type="function_call", ) - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="sync_tool", - tool_description="Uses sync handler", - on_stream=sync_handler, - ), + tool = agent.as_tool( + tool_name="sync_tool", + tool_description="Uses sync handler", + on_stream=sync_handler, ) tool_context = ToolContext( context=None, @@ -1402,13 +1357,10 @@ async def on_stream(payload: AgentToolStreamEvent) -> None: type="function_call", ) - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="nonblocking_tool", - tool_description="Uses non-blocking streaming handler", - on_stream=on_stream, - ), + tool = agent.as_tool( + tool_name="nonblocking_tool", + tool_description="Uses non-blocking streaming handler", + on_stream=on_stream, ) tool_context = ToolContext( context=None, @@ -1468,13 +1420,10 @@ def bad_handler(event: AgentToolStreamEvent) -> None: type="function_call", ) - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="error_tool", - tool_description="Handler throws", - on_stream=bad_handler, - ), + tool = agent.as_tool( + tool_name="error_tool", + tool_description="Handler throws", + on_stream=bad_handler, ) tool_context = ToolContext( context=None, @@ -1525,12 +1474,9 @@ async def fake_run( classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no stream"))), ) - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="nostream_tool", - tool_description="No streaming path", - ), + tool = agent.as_tool( + tool_name="nostream_tool", + tool_description="No streaming path", ) tool_context = ToolContext( context=None, @@ -1581,13 +1527,10 @@ async def on_stream(event: AgentToolStreamEvent) -> None: type="function_call", ) - tool = cast( - FunctionTool, - agent.as_tool( - tool_name="direct_stream_tool", - tool_description="Direct invocation", - on_stream=on_stream, - ), + tool = agent.as_tool( + tool_name="direct_stream_tool", + tool_description="Direct invocation", + on_stream=on_stream, ) tool_context = ToolContext( context=None, diff --git a/tests/test_example_workflows.py b/tests/test_example_workflows.py index a3603acc2..5dd0bc4b0 100644 --- a/tests/test_example_workflows.py +++ b/tests/test_example_workflows.py @@ -26,7 +26,7 @@ ) from agents.agent import ToolsToFinalOutputResult from agents.items import TResponseInputItem -from agents.tool import FunctionTool, FunctionToolResult, function_tool +from agents.tool import FunctionToolResult, function_tool from .fake_model import FakeModel from .test_responses import ( @@ -444,13 +444,10 @@ async def test_agent_as_tool_streaming_example_collects_events() -> None: async def on_stream(event: AgentToolStreamEvent) -> None: received.append(event) - billing_tool = cast( - FunctionTool, - billing_agent.as_tool( - tool_name="billing_agent", - tool_description="Answer billing questions", - on_stream=on_stream, - ), + billing_tool = billing_agent.as_tool( + tool_name="billing_agent", + tool_description="Answer billing questions", + on_stream=on_stream, ) async def fake_invoke(ctx, input: str) -> str: