Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/agents/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ def as_tool(
parameters: type[Any] | None = None,
input_builder: StructuredToolInputBuilder | None = None,
include_input_schema: bool = False,
) -> Tool:
) -> FunctionTool:
"""Transform this agent into a tool, callable by other agents.
This is different from handoffs in two ways:
Expand Down
209 changes: 76 additions & 133 deletions tests/test_agent_as_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,12 +421,9 @@ async def fake_run(

monkeypatch.setattr(Runner, "run", classmethod(fake_run))

tool = cast(
FunctionTool,
agent.as_tool(
tool_name="inherits_config_tool",
tool_description="inherit config",
),
tool = agent.as_tool(
tool_name="inherits_config_tool",
tool_description="inherit config",
)
tool_context = ToolContext(
context=None,
Expand Down Expand Up @@ -475,13 +472,10 @@ async def fake_run(

monkeypatch.setattr(Runner, "run", classmethod(fake_run))

tool = cast(
FunctionTool,
agent.as_tool(
tool_name="override_config_tool",
tool_description="override config",
run_config=explicit_run_config,
),
tool = agent.as_tool(
tool_name="override_config_tool",
tool_description="override config",
run_config=explicit_run_config,
)
tool_context = ToolContext(
context=None,
Expand Down Expand Up @@ -529,12 +523,9 @@ async def fake_run(

monkeypatch.setattr(Runner, "run", classmethod(fake_run))

tool = cast(
FunctionTool,
agent.as_tool(
tool_name="trace_config_tool",
tool_description="inherits trace config",
),
tool = agent.as_tool(
tool_name="trace_config_tool",
tool_description="inherits trace config",
)
tool_context = ToolContext(
context=None,
Expand All @@ -561,13 +552,10 @@ class TranslationInput(BaseModel):
target: str

agent = Agent(name="translator")
tool = cast(
FunctionTool,
agent.as_tool(
tool_name="translate",
tool_description="Translate text",
parameters=TranslationInput,
),
tool = agent.as_tool(
tool_name="translate",
tool_description="Translate text",
parameters=TranslationInput,
)

captured: dict[str, Any] = {}
Expand Down Expand Up @@ -626,12 +614,9 @@ async def test_agent_as_tool_clears_stale_tool_input_for_plain_tools(
"""Non-structured agent tools should not inherit stale tool input."""

agent = Agent(name="plain_agent")
tool = cast(
FunctionTool,
agent.as_tool(
tool_name="plain_tool",
tool_description="Plain tool",
),
tool = agent.as_tool(
tool_name="plain_tool",
tool_description="Plain tool",
)

run_context = RunContextWrapper({"locale": "en-US"})
Expand Down Expand Up @@ -685,13 +670,10 @@ class TranslationInput(BaseModel):
target: str = Field(description="Target language")

agent = Agent(name="summary_agent")
tool = cast(
FunctionTool,
agent.as_tool(
tool_name="summarize_schema",
tool_description="Summary tool",
parameters=TranslationInput,
),
tool = agent.as_tool(
tool_name="summarize_schema",
tool_description="Summary tool",
parameters=TranslationInput,
)

captured: dict[str, Any] = {}
Expand Down Expand Up @@ -756,14 +738,11 @@ async def builder(options: StructuredToolInputBuilderOptions):
builder_calls.append(options)
return custom_items

tool = cast(
FunctionTool,
agent.as_tool(
tool_name="builder_tool",
tool_description="Builder tool",
parameters=TranslationInput,
input_builder=builder,
),
tool = agent.as_tool(
tool_name="builder_tool",
tool_description="Builder tool",
parameters=TranslationInput,
input_builder=builder,
)

class DummyResult:
Expand Down Expand Up @@ -813,13 +792,10 @@ async def test_agent_as_tool_rejects_invalid_builder_output() -> None:
def builder(_options):
return 123

tool = cast(
FunctionTool,
agent.as_tool(
tool_name="invalid_builder_tool",
tool_description="Invalid builder tool",
input_builder=builder,
),
tool = agent.as_tool(
tool_name="invalid_builder_tool",
tool_description="Invalid builder tool",
input_builder=builder,
)

tool_context = ToolContext(
Expand All @@ -844,14 +820,11 @@ class TranslationInput(BaseModel):
target: str = Field(description="Target language")

agent = Agent(name="schema_agent")
tool = cast(
FunctionTool,
agent.as_tool(
tool_name="schema_tool",
tool_description="Schema tool",
parameters=TranslationInput,
include_input_schema=True,
),
tool = agent.as_tool(
tool_name="schema_tool",
tool_description="Schema tool",
parameters=TranslationInput,
include_input_schema=True,
)

captured: dict[str, Any] = {}
Expand Down Expand Up @@ -903,13 +876,10 @@ async def test_agent_as_tool_ignores_input_schema_without_parameters(
"""include_input_schema should be ignored when no parameters are provided."""

agent = Agent(name="default_schema_agent")
tool = cast(
FunctionTool,
agent.as_tool(
tool_name="default_schema_tool",
tool_description="Default schema tool",
include_input_schema=True,
),
tool = agent.as_tool(
tool_name="default_schema_tool",
tool_description="Default schema tool",
include_input_schema=True,
)

captured: dict[str, Any] = {}
Expand Down Expand Up @@ -1017,14 +987,11 @@ async def extractor(result: Any) -> str:
assert result is resumed_result
return "from_resume"

tool = cast(
FunctionTool,
agent.as_tool(
tool_name="outer_tool",
tool_description="Outer agent tool",
custom_output_extractor=extractor,
is_enabled=True,
),
tool = agent.as_tool(
tool_name="outer_tool",
tool_description="Outer agent tool",
custom_output_extractor=extractor,
is_enabled=True,
)

output = await tool.on_invoke_tool(tool_context, tool_call.arguments)
Expand Down Expand Up @@ -1102,13 +1069,10 @@ async def on_stream(payload: AgentToolStreamEvent) -> None:
type="function_call",
)

tool = cast(
FunctionTool,
agent.as_tool(
tool_name="stream_tool",
tool_description="Streams events",
on_stream=on_stream,
),
tool = agent.as_tool(
tool_name="stream_tool",
tool_description="Streams events",
on_stream=on_stream,
)

tool_context = ToolContext(
Expand Down Expand Up @@ -1179,13 +1143,10 @@ def fake_run_streamed(
async def on_stream(payload: AgentToolStreamEvent) -> None:
seen_agents.append(payload["agent"])

tool = cast(
FunctionTool,
first_agent.as_tool(
tool_name="delegate_tool",
tool_description="Streams handoff events",
on_stream=on_stream,
),
tool = first_agent.as_tool(
tool_name="delegate_tool",
tool_description="Streams handoff events",
on_stream=on_stream,
)

tool_call = ResponseFunctionToolCall(
Expand Down Expand Up @@ -1269,14 +1230,11 @@ async def on_stream(payload: AgentToolStreamEvent) -> None:
type="function_call",
)

tool = cast(
FunctionTool,
agent.as_tool(
tool_name="stream_tool",
tool_description="Streams events",
custom_output_extractor=extractor,
on_stream=on_stream,
),
tool = agent.as_tool(
tool_name="stream_tool",
tool_description="Streams events",
custom_output_extractor=extractor,
on_stream=on_stream,
)

tool_context = ToolContext(
Expand Down Expand Up @@ -1329,13 +1287,10 @@ def sync_handler(event: AgentToolStreamEvent) -> None:
type="function_call",
)

tool = cast(
FunctionTool,
agent.as_tool(
tool_name="sync_tool",
tool_description="Uses sync handler",
on_stream=sync_handler,
),
tool = agent.as_tool(
tool_name="sync_tool",
tool_description="Uses sync handler",
on_stream=sync_handler,
)
tool_context = ToolContext(
context=None,
Expand Down Expand Up @@ -1402,13 +1357,10 @@ async def on_stream(payload: AgentToolStreamEvent) -> None:
type="function_call",
)

tool = cast(
FunctionTool,
agent.as_tool(
tool_name="nonblocking_tool",
tool_description="Uses non-blocking streaming handler",
on_stream=on_stream,
),
tool = agent.as_tool(
tool_name="nonblocking_tool",
tool_description="Uses non-blocking streaming handler",
on_stream=on_stream,
)
tool_context = ToolContext(
context=None,
Expand Down Expand Up @@ -1468,13 +1420,10 @@ def bad_handler(event: AgentToolStreamEvent) -> None:
type="function_call",
)

tool = cast(
FunctionTool,
agent.as_tool(
tool_name="error_tool",
tool_description="Handler throws",
on_stream=bad_handler,
),
tool = agent.as_tool(
tool_name="error_tool",
tool_description="Handler throws",
on_stream=bad_handler,
)
tool_context = ToolContext(
context=None,
Expand Down Expand Up @@ -1525,12 +1474,9 @@ async def fake_run(
classmethod(lambda *args, **kwargs: (_ for _ in ()).throw(AssertionError("no stream"))),
)

tool = cast(
FunctionTool,
agent.as_tool(
tool_name="nostream_tool",
tool_description="No streaming path",
),
tool = agent.as_tool(
tool_name="nostream_tool",
tool_description="No streaming path",
)
tool_context = ToolContext(
context=None,
Expand Down Expand Up @@ -1581,13 +1527,10 @@ async def on_stream(event: AgentToolStreamEvent) -> None:
type="function_call",
)

tool = cast(
FunctionTool,
agent.as_tool(
tool_name="direct_stream_tool",
tool_description="Direct invocation",
on_stream=on_stream,
),
tool = agent.as_tool(
tool_name="direct_stream_tool",
tool_description="Direct invocation",
on_stream=on_stream,
)
tool_context = ToolContext(
context=None,
Expand Down
13 changes: 5 additions & 8 deletions tests/test_example_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
)
from agents.agent import ToolsToFinalOutputResult
from agents.items import TResponseInputItem
from agents.tool import FunctionTool, FunctionToolResult, function_tool
from agents.tool import FunctionToolResult, function_tool

from .fake_model import FakeModel
from .test_responses import (
Expand Down Expand Up @@ -444,13 +444,10 @@ async def test_agent_as_tool_streaming_example_collects_events() -> None:
async def on_stream(event: AgentToolStreamEvent) -> None:
received.append(event)

billing_tool = cast(
FunctionTool,
billing_agent.as_tool(
tool_name="billing_agent",
tool_description="Answer billing questions",
on_stream=on_stream,
),
billing_tool = billing_agent.as_tool(
tool_name="billing_agent",
tool_description="Answer billing questions",
on_stream=on_stream,
)

async def fake_invoke(ctx, input: str) -> str:
Expand Down