diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 00000000..0fc09d50
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,111 @@
+# Examples
+
+This directory contains example scripts and interactive tools for exploring the Llama Stack Client Python SDK.
+
+## Interactive Agent CLI
+
+`interactive_agent_cli.py` - An interactive command-line tool for exploring agent turn/step events with server-side tools.
+
+### Features
+
+- šŸ” **File Search Integration**: Automatically sets up a vector store with a sample knowledge base
+- šŸ“Š **Event Streaming**: See real-time turn/step events as the agent processes your queries
+- šŸŽÆ **Server-Side Tools**: Demonstrates file_search and other server-side tool execution
+- šŸ’¬ **Interactive REPL**: Chat-style interface for easy exploration
+
+### Prerequisites
+
+1. Start a Llama Stack server with the OpenAI provider:
+   ```bash
+   cd ~/local/llama-stack
+   source ../stack-venv/bin/activate
+   export OPENAI_API_KEY=<your-api-key>
+   llama stack run ci-tests --port 8321
+   ```
+
+2. Install the client (from the repository root):
+   ```bash
+   cd path/to/llama-stack-client-python
+   uv sync
+   ```
+
+### Usage
+
+Basic usage (defaults: model `openai/gpt-4o`, server `http://localhost:8321/v1`):
+```bash
+cd examples
+uv run python interactive_agent_cli.py
+```
+
+With custom options:
+```bash
+uv run python interactive_agent_cli.py --model openai/gpt-4o-mini --base-url http://localhost:8321/v1
+```
+
+### Example Session
+
+```
+╔══════════════════════════════════════════════════════════════╗
+ā•‘                                                                ā•‘
+ā•‘          šŸ¤– Interactive Agent Explorer šŸ”                      ā•‘
+ā•‘                                                                ā•‘
+ā•‘    Explore agent turn/step events with server-side tools      ā•‘
+ā•‘                                                                ā•‘
+ā•šā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•
+
+šŸ”§ Configuration:
+   Model: openai/gpt-4o
+   Server: http://localhost:8321/v1
+
+šŸ”Œ Connecting to server...
+   āœ“ Connected
+
+šŸ“š Setting up knowledge base...
+   Indexing documents....... āœ“
+   Vector store ID: vs_abc123
+
+šŸ¤– Creating agent with tools...
+   āœ“ Agent ready
+
+šŸ’¬ Type your questions (or 'quit' to exit, 'help' for suggestions)
+──────────────────────────────────────────────────────────────
+
+šŸ§‘ You: What is Project Phoenix?
+
+šŸ¤– Assistant:
+
+  ā”Œā”€ā”€ā”€ Turn turn_abc123 started ───┐
+  │                                 │
+  │ 🧠 Inference Step 0 started     │
+  │ šŸ” Tool Execution Step 1        │
+  │    Tool: knowledge_search       │
+  │    Status: server_side          │
+  │ 🧠 Inference Step 2             │
+  │ āœ“ Response: Project Phoenix...  │
+  │                                 │
+  └─── Turn completed ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”€ā”˜
+
+Project Phoenix is a next-generation distributed systems platform launched in 2024...
+```
+
+### What You'll See
+
+The tool uses `AgentEventLogger` to display:
+- **Turn lifecycle**: TurnStarted → TurnCompleted
+- **Inference steps**: When the model is thinking/generating text
+- **Tool execution steps**: When server-side tools (like file_search) are running
+- **Step metadata**: Whether tools are server-side or client-side
+- **Real-time streaming**: Text appears as it's generated
+
+### Sample Questions
+
+Type `help` in the interactive session to see suggested questions, or try:
+- "What is Project Phoenix?"
+- "Who is the lead architect?"
+- "What ports does the system use?"
+- "How long do JWT tokens last?"
+- "Where is the production environment deployed?"
+
+### Exit
+
+Type `quit`, `exit`, `q`, or press `Ctrl+C` to exit.
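+
+### Programmatic Usage (sketch)
+
+The interactive CLI drives the same high-level `Agent` API you can call from your own scripts. The sketch below shows the minimal flow (client → agent → session → streamed turn); the server URL, model ID, and vector store ID (`vs_abc123`) are placeholders for your own deployment, and `OPENAI_API_KEY` is assumed to be exported.
+
+```python
+from openai import OpenAI
+
+from llama_stack_client import Agent, AgentEventLogger
+
+# Placeholder endpoint; point this at your running Llama Stack server.
+client = OpenAI(base_url="http://localhost:8321/v1")
+
+agent = Agent(
+    client=client,
+    model="openai/gpt-4o",
+    instructions="Answer questions using the knowledge base.",
+    tools=[{"type": "file_search", "vector_store_ids": ["vs_abc123"]}],  # placeholder ID
+)
+
+session_id = agent.create_session("quickstart")
+messages = [
+    {
+        "type": "message",
+        "role": "user",
+        "content": [{"type": "input_text", "text": "What is Project Phoenix?"}],
+    }
+]
+
+# Stream the turn and render each event as a printable fragment.
+logger = AgentEventLogger()
+for chunk in agent.create_turn(messages=messages, session_id=session_id, stream=True):
+    for fragment in logger.log([chunk]):
+        print(fragment, end="", flush=True)
+print()
+```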
diff --git a/examples/interactive_agent_cli.py b/examples/interactive_agent_cli.py new file mode 100755 index 00000000..c79d9efe --- /dev/null +++ b/examples/interactive_agent_cli.py @@ -0,0 +1,320 @@ +#!/usr/bin/env python3 +"""Interactive CLI for exploring agent turn/step events with server-side tools. + +Usage: + python interactive_agent_cli.py [--model MODEL] [--base-url URL] +""" + +import argparse +import io +import json +import os +import sys +import time +from pathlib import Path +from typing import Optional +from uuid import uuid4 + +from openai import OpenAI +from llama_stack_client import Agent, AgentEventLogger + + +CACHE_DIR = Path(os.path.expanduser("~/.cache/interactive-agent-cli")) +CACHE_DIR.mkdir(parents=True, exist_ok=True) +CACHE_FILE = CACHE_DIR / "vector_store.json" + + +def load_cached_vector_store() -> Optional[str]: + try: + with CACHE_FILE.open("r", encoding="utf-8") as fh: + payload = json.load(fh) + return payload.get("vector_store_id") + except FileNotFoundError: + return None + except Exception as exc: # pragma: no cover - defensive + print(f"āš ļø Failed to load cached vector store info: {exc}", file=sys.stderr) + return None + + +def save_cached_vector_store(vector_store_id: str) -> None: + try: + with CACHE_FILE.open("w", encoding="utf-8") as fh: + json.dump({"vector_store_id": vector_store_id}, fh) + except Exception as exc: # pragma: no cover - defensive + print(f"āš ļø Failed to cache vector store id: {exc}", file=sys.stderr) + + +def ensure_vector_store(client: OpenAI) -> str: + cached_id = load_cached_vector_store() + if cached_id: + # Verify the vector store still exists on the server + existing = client.vector_stores.list().data + if any(store.id == cached_id for store in existing): + print(f"šŸ“š Reusing cached knowledge base (vector store {cached_id})") + return cached_id + else: + print("āš ļø Cached vector store not found on server; creating a new one.") + + return setup_knowledge_base(client) + + +def setup_knowledge_base(client: OpenAI) -> str: + """Create a vector store with interesting test knowledge.""" + print("šŸ“š Setting up knowledge base...") + + # Create interesting test content + knowledge_content = """ + # Project Phoenix Documentation + + ## Overview + Project Phoenix is a next-generation distributed systems platform launched in 2024. + + ## Key Components + - **Phoenix Core**: The main orchestration engine + - **Phoenix Mesh**: Service mesh implementation + - **Phoenix Analytics**: Real-time data processing pipeline + + ## Authentication + - Primary auth method: OAuth 2.0 with JWT tokens + - Token expiration: 24 hours + - Refresh token validity: 7 days + + ## Architecture + The system uses a microservices architecture with: + - API Gateway on port 8080 + - Auth service on port 8081 + - Data service on port 8082 + + ## Team + - Lead Architect: Dr. 
Sarah Chen + - Security Lead: James Rodriguez + - DevOps Lead: Maria Santos + + ## Deployment + - Production: AWS us-east-1 + - Staging: AWS us-west-2 + - Development: Local Kubernetes cluster + """ + + # Upload file + file_payload = io.BytesIO(knowledge_content.encode("utf-8")) + uploaded_file = client.files.create( + file=("project_phoenix_docs.txt", file_payload, "text/plain"), + purpose="assistants", + ) + + # Create vector store + vector_store = client.vector_stores.create( + name=f"phoenix-kb-{uuid4().hex[:8]}", + extra_body={ + "provider_id": "faiss", + "embedding_model": "nomic-ai/nomic-embed-text-v1.5", + }, + ) + + # Add file to vector store + vector_store_file = client.vector_stores.files.create( + vector_store_id=vector_store.id, + file_id=uploaded_file.id, + ) + + # Wait for ingestion + print(" Indexing documents...", end="", flush=True) + deadline = time.time() + 60.0 + while vector_store_file.status != "completed": + if vector_store_file.status in {"failed", "cancelled"}: + raise RuntimeError(f"Vector store ingestion failed: {vector_store_file.status}") + if time.time() > deadline: + raise TimeoutError("Vector store file ingest timed out") + time.sleep(0.5) + vector_store_file = client.vector_stores.files.retrieve( + vector_store_id=vector_store.id, + file_id=vector_store_file.id, + ) + print(".", end="", flush=True) + + print(" āœ“") + print(f" Vector store ID: {vector_store.id}") + print() + save_cached_vector_store(vector_store.id) + return vector_store.id + + +def print_banner(): + """Print a nice banner.""" + banner = """ +╔══════════════════════════════════════════════════════════════╗ +ā•‘ ā•‘ +ā•‘ šŸ¤– Interactive Agent Explorer šŸ” ā•‘ +ā•‘ ā•‘ +ā•‘ Explore agent turn/step events with server-side tools ā•‘ +ā•‘ ā•‘ +ā•šā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā•ā• +""" + print(banner) + + +def create_agent_with_tools(client, model, vector_store_id): + """Create an agent with file_search and other server-side tools.""" + tools = [ + { + "type": "file_search", + "vector_store_ids": [vector_store_id], + } + ] + + instructions = """You are a helpful AI assistant with access to a knowledge base about Project Phoenix. + +When answering questions: +1. ALWAYS search the knowledge base first using file_search +2. Provide specific details from the documentation +3. If information isn't in the knowledge base, say so clearly +4. 
Be concise but thorough + +Available tools: +- file_search: Search the Project Phoenix documentation +""" + + agent = Agent( + client=client, + model=model, + instructions=instructions, + tools=tools, + ) + + return agent + + +def interactive_loop(agent): + """Run the interactive query loop with nice event logging.""" + session_id = agent.create_session(f"interactive-{uuid4().hex[:8]}") + print(f"šŸ“ Session created: {session_id}\n") + + print("šŸ’¬ Type your questions (or 'quit' to exit, 'help' for suggestions)") + print("─" * 70) + print() + + while True: + try: + # Get user input + user_input = input("\nšŸ§‘ You: ").strip() + + if not user_input: + continue + + if user_input.lower() in {"quit", "exit", "q"}: + print("\nšŸ‘‹ Goodbye!") + break + + if user_input.lower() == "help": + print("\nšŸ’” Try asking:") + print(" • What is Project Phoenix?") + print(" • Who is the lead architect?") + print(" • What ports does the system use?") + print(" • How long do JWT tokens last?") + print(" • Where is the production environment deployed?") + continue + + # Create message + messages = [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": user_input}], + } + ] + + print() + print("šŸ¤– Assistant:", end=" ", flush=True) + + # Stream response with event logging + event_logger = AgentEventLogger() + + for chunk in agent.create_turn(messages=messages, session_id=session_id, stream=True): + # Log the event + for log_msg in event_logger.log([chunk]): + print(log_msg, end="", flush=True) + + print() # New line after response + + except KeyboardInterrupt: + print("\n\nšŸ‘‹ Goodbye!") + break + except Exception as e: + print(f"\nāŒ Error: {e}", file=sys.stderr) + print(" Please try again or type 'quit' to exit", file=sys.stderr) + + +def main(): + """Main entry point.""" + parser = argparse.ArgumentParser( + description="Interactive agent CLI with server-side tools", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s + %(prog)s --model openai/gpt-4o + %(prog)s --base-url http://localhost:8321/v1 + """, + ) + parser.add_argument( + "--model", + default="openai/gpt-4o", + help="Model to use (default: openai/gpt-4o)", + ) + parser.add_argument( + "--base-url", + default="http://localhost:8321/v1", + help="Llama Stack server URL (default: http://localhost:8321/v1)", + ) + + args = parser.parse_args() + + print_banner() + print(f"šŸ”§ Configuration:") + print(f" Model: {args.model}") + print(f" Server: {args.base_url}") + print() + + # Create client + print("šŸ”Œ Connecting to server...") + try: + client = OpenAI(base_url=args.base_url) + models = client.models.list() + identifiers = [model.identifier for model in models] + if args.model not in identifiers: + print(f" āœ— Model {args.model} not found", file=sys.stderr) + print(f" Available models: {', '.join(identifiers)}", file=sys.stderr) + sys.exit(1) + + print(" āœ“ Connected") + print() + except Exception as e: + print(f" āœ— Failed to connect: {e}", file=sys.stderr) + print(f"\n Make sure the server is running at {args.base_url}", file=sys.stderr) + sys.exit(1) + + # Setup knowledge base + try: + print("šŸ” Setting up knowledge base...") + vector_store_id = ensure_vector_store(client) + except Exception as e: + print(f"āŒ Failed to setup knowledge base: {e}", file=sys.stderr) + sys.exit(1) + + # Create agent + print("šŸ¤– Creating agent with tools...") + try: + agent = create_agent_with_tools(client, args.model, vector_store_id) + print(" āœ“ Agent ready") + print() + except 
Exception as e: + print(f" āœ— Failed to create agent: {e}", file=sys.stderr) + sys.exit(1) + + # Run interactive loop + interactive_loop(agent) + + +if __name__ == "__main__": + main() diff --git a/src/llama_stack_client/__init__.py b/src/llama_stack_client/__init__.py index cc2fcb9b..c910dfe5 100644 --- a/src/llama_stack_client/__init__.py +++ b/src/llama_stack_client/__init__.py @@ -41,8 +41,8 @@ from .lib.agents.agent import Agent from .lib.agents.event_logger import EventLogger as AgentEventLogger from .lib.inference.event_logger import EventLogger as InferenceEventLogger -from .types.alpha.agents.turn_create_params import Document from .types.shared_params.document import Document as RAGDocument +from .types.alpha.agents.turn_create_params import Document __all__ = [ "types", diff --git a/src/llama_stack_client/lib/agents/agent.py b/src/llama_stack_client/lib/agents/agent.py index 5e00a88b..417a53e2 100644 --- a/src/llama_stack_client/lib/agents/agent.py +++ b/src/llama_stack_client/lib/agents/agent.py @@ -3,244 +3,161 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import json import logging -from typing import Any, AsyncIterator, Callable, Iterator, List, Optional, Tuple, Union - -from llama_stack_client import LlamaStackClient -from llama_stack_client.types import ToolResponseMessage, UserMessage -from llama_stack_client.types.alpha import ToolResponseParam -from llama_stack_client.types.alpha.agent_create_params import AgentConfig -from llama_stack_client.types.alpha.agents.agent_turn_response_stream_chunk import ( - AgentTurnResponseStreamChunk, +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Iterator, + List, + Optional, + Tuple, + Union, + TypedDict, ) -from llama_stack_client.types.alpha.agents.turn import CompletionMessage, Turn -from llama_stack_client.types.alpha.agents.turn_create_params import Document, Toolgroup -from llama_stack_client.types.shared.tool_call import ToolCall -from llama_stack_client.types.shared_params.agent_config import ToolConfig -from llama_stack_client.types.shared_params.response_format import ResponseFormat -from llama_stack_client.types.shared_params.sampling_params import SamplingParams +from uuid import uuid4 from ..._types import Headers from .client_tool import ClientTool, client_tool from .tool_parser import ToolParser +from .turn_events import ( + AgentStreamChunk, + StepCompleted, + StepProgress, + StepStarted, + ToolCallIssuedDelta, + TurnFailed, + ToolExecutionStepResult, +) +from .event_synthesizer import TurnEventSynthesizer +from .types import CompletionMessage, ToolCall, ToolResponse -DEFAULT_MAX_ITER = 10 -logger = logging.getLogger(__name__) +class ToolResponsePayload(TypedDict, total=False): + call_id: str + tool_name: str + content: Any + metadata: Dict[str, Any] -class AgentUtils: - @staticmethod - def get_client_tools( - tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]], - ) -> List[ClientTool]: - if not tools: - return [] +logger = logging.getLogger(__name__) - # Wrap any function in client_tool decorator - tools = [client_tool(tool) if (callable(tool) and not isinstance(tool, ClientTool)) else tool for tool in tools] - return [tool for tool in tools if isinstance(tool, ClientTool)] +class ToolUtils: @staticmethod - def get_tool_calls(chunk: AgentTurnResponseStreamChunk, tool_parser: Optional[ToolParser] = None) -> List[ToolCall]: - if chunk.event.payload.event_type not in { - "turn_complete", - 
"turn_awaiting_input", - }: - return [] - - message = chunk.event.payload.turn.output_message - if message.stop_reason == "out_of_tokens": - return [] - - if tool_parser: - return tool_parser.get_tool_calls(message) - - return message.tool_calls + def coerce_tool_content(content: Any) -> str: + if isinstance(content, str): + return content + if content is None: + return "" + if isinstance(content, (dict, list)): + try: + return json.dumps(content) + except TypeError: + return str(content) + return str(content) @staticmethod - def get_turn_id(chunk: AgentTurnResponseStreamChunk) -> Optional[str]: - if chunk.event.payload.event_type not in [ - "turn_complete", - "turn_awaiting_input", - ]: - return None - - return chunk.event.payload.turn.turn_id + def parse_tool_arguments(arguments: Any) -> Dict[str, Any]: + if isinstance(arguments, dict): + return arguments + if not arguments: + return {} + if isinstance(arguments, str): + try: + parsed = json.loads(arguments) + except json.JSONDecodeError: + logger.warning("Failed to decode tool arguments JSON", exc_info=True) + return {} + if isinstance(parsed, dict): + return parsed + logger.warning("Tool arguments JSON did not decode into a dict: %s", type(parsed)) + return {} + logger.warning("Unsupported tool arguments type: %s", type(arguments)) + return {} @staticmethod - def get_agent_config( - model: Optional[str] = None, - instructions: Optional[str] = None, - tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None, - tool_config: Optional[ToolConfig] = None, - sampling_params: Optional[SamplingParams] = None, - max_infer_iters: Optional[int] = None, - input_shields: Optional[List[str]] = None, - output_shields: Optional[List[str]] = None, - response_format: Optional[ResponseFormat] = None, - enable_session_persistence: Optional[bool] = None, - name: str | None = None, - ) -> AgentConfig: - # Create a minimal valid AgentConfig with required fields - if model is None or instructions is None: - raise ValueError("Both 'model' and 'instructions' are required when agent_config is not provided") - - agent_config = { - "model": model, - "instructions": instructions, - "toolgroups": [], - "client_tools": [], - } - - # Add optional parameters if provided - if enable_session_persistence is not None: - agent_config["enable_session_persistence"] = enable_session_persistence - if max_infer_iters is not None: - agent_config["max_infer_iters"] = max_infer_iters - if input_shields is not None: - agent_config["input_shields"] = input_shields - if output_shields is not None: - agent_config["output_shields"] = output_shields - if response_format is not None: - agent_config["response_format"] = response_format - if sampling_params is not None: - agent_config["sampling_params"] = sampling_params - if tool_config is not None: - agent_config["tool_config"] = tool_config - if name is not None: - agent_config["name"] = name - if tools is not None: - toolgroups: List[Toolgroup] = [] - for tool in tools: - if isinstance(tool, str) or isinstance(tool, dict): - toolgroups.append(tool) - - agent_config["toolgroups"] = toolgroups - agent_config["client_tools"] = [tool.get_tool_definition() for tool in AgentUtils.get_client_tools(tools)] - - agent_config = AgentConfig(**agent_config) - return agent_config + def normalize_tool_response(tool_response: Any) -> ToolResponsePayload: + if isinstance(tool_response, ToolResponse): + payload: ToolResponsePayload = { + "call_id": tool_response.call_id, + "tool_name": str(tool_response.tool_name), + "content": 
ToolUtils.coerce_tool_content(tool_response.content), + "metadata": dict(tool_response.metadata), + } + return payload + + if isinstance(tool_response, dict): + call_id = tool_response.get("call_id") + tool_name = tool_response.get("tool_name") + if call_id is None or tool_name is None: + raise KeyError("Tool response missing required keys 'call_id' or 'tool_name'") + payload: ToolResponsePayload = { + "call_id": str(call_id), + "tool_name": str(tool_name), + "content": ToolUtils.coerce_tool_content(tool_response.get("content")), + "metadata": dict(tool_response.get("metadata") or {}), + } + return payload + + raise TypeError(f"Unsupported tool response type: {type(tool_response)!r}") class Agent: def __init__( self, - client: LlamaStackClient, - # begin deprecated - agent_config: Optional[AgentConfig] = None, - client_tools: Tuple[ClientTool, ...] = (), - # end deprecated + client: Any, # Accept any OpenAI-compatible client + *, + model: str, + instructions: str, + tools: Optional[List[Union[Dict[str, Any], ClientTool, Callable[..., Any]]]] = None, tool_parser: Optional[ToolParser] = None, - model: Optional[str] = None, - instructions: Optional[str] = None, - tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None, - tool_config: Optional[ToolConfig] = None, - sampling_params: Optional[SamplingParams] = None, - max_infer_iters: Optional[int] = None, - input_shields: Optional[List[str]] = None, - output_shields: Optional[List[str]] = None, - response_format: Optional[ResponseFormat] = None, - enable_session_persistence: Optional[bool] = None, extra_headers: Headers | None = None, - name: str | None = None, ): - """Construct an Agent with the given parameters. - - :param client: The LlamaStackClient instance. - :param agent_config: The AgentConfig instance. - ::deprecated: use other parameters instead - :param client_tools: A tuple of ClientTool instances. - ::deprecated: use tools instead - :param tool_parser: Custom logic that parses tool calls from a message. - :param model: The model to use for the agent. - :param instructions: The instructions for the agent. - :param tools: A list of tools for the agent. Values can be one of the following: - - dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}} - - a python function with a docstring. See @client_tool for more details. - - str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search" - - str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent - - an instance of ClientTool: A client tool object. - :param tool_config: The tool configuration for the agent. - :param sampling_params: The sampling parameters for the agent. - :param max_infer_iters: The maximum number of inference iterations. - :param input_shields: The input shields for the agent. - :param output_shields: The output shields for the agent. - :param response_format: The response format for the agent. - :param enable_session_persistence: Whether to enable session persistence. - :param extra_headers: Extra headers to add to all requests sent by the agent. - :param name: Optional name for the agent, used in telemetry and identification. + """Construct an Agent backed by the responses + conversations APIs. + + Args: + client: An OpenAI-compatible client (e.g., openai.OpenAI()). + The client must support the responses and conversations APIs. 
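+
+        Example (a minimal sketch; the server URL and model ID below are
+        placeholders for your own deployment):
+
+            from openai import OpenAI
+
+            client = OpenAI(base_url="http://localhost:8321/v1")
+            agent = Agent(
+                client=client,
+                model="openai/gpt-4o",
+                instructions="You are a helpful assistant.",
+            )
+            session_id = agent.create_session("demo")
+            for chunk in agent.create_turn(
+                messages=[
+                    {
+                        "type": "message",
+                        "role": "user",
+                        "content": [{"type": "input_text", "text": "Hello"}],
+                    }
+                ],
+                session_id=session_id,
+                stream=True,
+            ):
+                print(chunk.event)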
""" self.client = client - - if agent_config is not None: - logger.warning("`agent_config` is deprecated. Use inlined parameters instead.") - if client_tools != (): - logger.warning("`client_tools` is deprecated. Use `tools` instead.") - - # Construct agent_config from parameters if not provided - if agent_config is None: - agent_config = AgentUtils.get_agent_config( - model=model, - instructions=instructions, - tools=tools, - tool_config=tool_config, - sampling_params=sampling_params, - max_infer_iters=max_infer_iters, - input_shields=input_shields, - output_shields=output_shields, - response_format=response_format, - enable_session_persistence=enable_session_persistence, - name=name, - ) - client_tools = AgentUtils.get_client_tools(tools) - - self.agent_config = agent_config - self.client_tools = {t.get_name(): t for t in client_tools} - self.sessions = [] self.tool_parser = tool_parser - self.builtin_tools = {} self.extra_headers = extra_headers - self.initialize() + self._model = model + self._instructions = instructions - def initialize(self) -> None: - agentic_system_create_response = self.client.alpha.agents.create( - agent_config=self.agent_config, - extra_headers=self.extra_headers, - ) - self.agent_id = agentic_system_create_response.agent_id - for tg in self.agent_config["toolgroups"]: - toolgroup_id = tg if isinstance(tg, str) else tg.get("name") - for tool in self.client.tools.list(toolgroup_id=toolgroup_id, extra_headers=self.extra_headers): - self.builtin_tools[tool.name] = tg.get("args", {}) if isinstance(tg, dict) else {} + # Convert all tools to API format and separate client-side functions + self._tools, client_tools = AgentUtils.normalize_tools(tools) + self.client_tools = {tool.get_name(): tool for tool in client_tools} + + self.sessions: List[str] = [] def create_session(self, session_name: str) -> str: - agentic_system_create_session_response = self.client.alpha.agents.session.create( - agent_id=self.agent_id, - session_name=session_name, + conversation = self.client.conversations.create( extra_headers=self.extra_headers, + metadata={"name": session_name}, ) - self.session_id = agentic_system_create_session_response.session_id - self.sessions.append(self.session_id) - return self.session_id + self.sessions.append(conversation.id) + return conversation.id - def _run_tool_calls(self, tool_calls: List[ToolCall]) -> List[ToolResponseParam]: - responses = [] + def _run_tool_calls(self, tool_calls: List[ToolCall]) -> List[ToolResponsePayload]: + responses: List[ToolResponsePayload] = [] for tool_call in tool_calls: - responses.append(self._run_single_tool(tool_call)) + raw_result = self._run_single_tool(tool_call) + responses.append(ToolUtils.normalize_tool_response(raw_result)) return responses - def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam: - # custom client tools + def _run_single_tool(self, tool_call: ToolCall) -> Any: + # Execute client-side tools if tool_call.tool_name in self.client_tools: tool = self.client_tools[tool_call.tool_name] - # NOTE: tool.run() expects a list of messages, we only pass in last message here - # but we could pass in the entire message history result_message = tool.run( [ CompletionMessage( role="assistant", - content=tool_call.tool_name, + content=tool_call.arguments, tool_calls=[tool_call], stop_reason="end_of_turn", ) @@ -248,277 +165,201 @@ def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam: ) return result_message - # builtin tools executed by tool_runtime - if tool_call.tool_name in 
self.builtin_tools: - tool_result = self.client.tool_runtime.invoke_tool( - tool_name=tool_call.tool_name, - kwargs={ - **tool_call.arguments, - **self.builtin_tools[tool_call.tool_name], - }, - extra_headers=self.extra_headers, - ) - return ToolResponseParam( - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content=tool_result.content, - ) - - # cannot find tools - return ToolResponseParam( - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content=f"Unknown tool `{tool_call.tool_name}` was called.", - ) + # Server-side tools should never reach here (they execute within response stream) + # If we get here, it's an error + return { + "call_id": tool_call.call_id, + "tool_name": tool_call.tool_name, + "content": f"Unknown tool `{tool_call.tool_name}` was called.", + } def create_turn( self, - messages: List[Union[UserMessage, ToolResponseMessage]], - session_id: Optional[str] = None, - toolgroups: Optional[List[Toolgroup]] = None, - documents: Optional[List[Document]] = None, + messages: List[Dict[str, Any]], + session_id: str, stream: bool = True, # TODO: deprecate this extra_headers: Headers | None = None, - ) -> Iterator[AgentTurnResponseStreamChunk] | Turn: + ) -> Iterator[AgentStreamChunk] | Any: if stream: - return self._create_turn_streaming( - messages, session_id, toolgroups, documents, extra_headers=extra_headers or self.extra_headers - ) + return self._create_turn_streaming(messages, session_id, extra_headers=extra_headers or self.extra_headers) else: - chunks = [ - x - for x in self._create_turn_streaming( - messages, - session_id, - toolgroups, - documents, - extra_headers=extra_headers or self.extra_headers, - ) - ] - if not chunks: + last_chunk: Optional[AgentStreamChunk] = None + for chunk in self._create_turn_streaming( + messages, session_id, extra_headers=extra_headers or self.extra_headers + ): + last_chunk = chunk + + if not last_chunk or not last_chunk.response: raise Exception("Turn did not complete") - last_chunk = chunks[-1] - if hasattr(last_chunk, "error"): - if "message" in last_chunk.error: - error_msg = last_chunk.error["message"] - else: - error_msg = str(last_chunk.error) - raise RuntimeError(f"Turn did not complete. Error: {error_msg}") - try: - return last_chunk.event.payload.turn - except AttributeError: - raise RuntimeError(f"Turn did not complete. Output: {last_chunk}") from None + return last_chunk.response def _create_turn_streaming( self, - messages: List[Union[UserMessage, ToolResponseMessage]], - session_id: Optional[str] = None, - toolgroups: Optional[List[Toolgroup]] = None, - documents: Optional[List[Document]] = None, + messages: List[Dict[str, Any]], + session_id: str, # TODO: deprecate this extra_headers: Headers | None = None, - ) -> Iterator[AgentTurnResponseStreamChunk]: - n_iter = 0 - - # 1. 
create an agent turn - turn_response = self.client.alpha.agents.turn.create( - agent_id=self.agent_id, - # use specified session_id or last session created - session_id=session_id or self.session_id[-1], - messages=messages, - stream=True, - documents=documents, - toolgroups=toolgroups, - extra_headers=extra_headers or self.extra_headers, - ) + ) -> Iterator[AgentStreamChunk]: + # Generate turn_id + turn_id = f"turn_{uuid4().hex[:12]}" + + # Create synthesizer + synthesizer = TurnEventSynthesizer(session_id=session_id, turn_id=turn_id) + + request_headers = extra_headers or self.extra_headers + + # Main turn loop + while True: + # Create response stream + raw_stream = self.client.responses.create( + model=self._model, + instructions=self._instructions, + conversation=session_id, + input=messages, + tools=self._tools, + stream=True, + extra_headers=request_headers, + ) + + # Process events + function_calls_to_execute: List[ToolCall] = [] # Only client-side! - # 2. process turn and resume if there's a tool call - is_turn_complete = False - while not is_turn_complete: - is_turn_complete = True - for chunk in turn_response: - if hasattr(chunk, "error"): - yield chunk + for high_level_event in synthesizer.process_raw_stream(raw_stream): + # Handle failures + if isinstance(high_level_event, TurnFailed): + yield AgentStreamChunk(event=high_level_event) return - tool_calls = AgentUtils.get_tool_calls(chunk, self.tool_parser) - if not tool_calls: - yield chunk - else: - is_turn_complete = False - # End of turn is reached, do not resume even if there's a tool call - # We only check for this if tool_parser is not set, because otherwise - # tool call will be parsed on client side, and server will always return "end_of_turn" - if not self.tool_parser and chunk.event.payload.turn.output_message.stop_reason in {"end_of_turn"}: - yield chunk - break - - turn_id = AgentUtils.get_turn_id(chunk) - if n_iter == 0: - yield chunk - - # run the tools - tool_responses = self._run_tool_calls(tool_calls) - - # pass it to next iteration - turn_response = self.client.alpha.agents.turn.resume( - agent_id=self.agent_id, - session_id=session_id or self.session_id[-1], - turn_id=turn_id, + + # Track function calls that need client execution + if isinstance(high_level_event, StepCompleted): + if high_level_event.step_type == "inference": + result = high_level_event.result + if result.function_calls: # Only client-side function calls + function_calls_to_execute = result.function_calls + + yield AgentStreamChunk(event=high_level_event) + + # If no client-side function calls, turn is done + if not function_calls_to_execute: + # Emit TurnCompleted + response = synthesizer.last_response + if not response: + raise RuntimeError("No response available") + for event in synthesizer.finish_turn(): + yield AgentStreamChunk(event=event, response=response) + break + + # Execute client-side tools (emit tool execution step events) + tool_step_id = f"{turn_id}_step_{synthesizer.step_counter}" + synthesizer.step_counter += 1 + + yield AgentStreamChunk( + event=StepStarted( + step_id=tool_step_id, + step_type="tool_execution", + turn_id=turn_id, + metadata={"server_side": False}, + ) + ) + + tool_responses = self._run_tool_calls(function_calls_to_execute) + + yield AgentStreamChunk( + event=StepCompleted( + step_id=tool_step_id, + step_type="tool_execution", + turn_id=turn_id, + result=ToolExecutionStepResult( + step_id=tool_step_id, + tool_calls=function_calls_to_execute, tool_responses=tool_responses, - stream=True, - 
extra_headers=extra_headers or self.extra_headers, - ) - n_iter += 1 + ), + ) + ) - if self.tool_parser and n_iter > self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER): - raise Exception("Max inference iterations reached") + # Continue loop with tool outputs as input + messages = [ + { + "type": "function_call_output", + "call_id": payload["call_id"], + "output": payload["content"], + } + for payload in tool_responses + ] class AsyncAgent: def __init__( self, - client: LlamaStackClient, - # begin deprecated - agent_config: Optional[AgentConfig] = None, - client_tools: Tuple[ClientTool, ...] = (), - # end deprecated + client: Any, # Accept any async OpenAI-compatible client + *, + model: str, + instructions: str, + tools: Optional[List[Union[Dict[str, Any], ClientTool, Callable[..., Any]]]] = None, tool_parser: Optional[ToolParser] = None, - model: Optional[str] = None, - instructions: Optional[str] = None, - tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None, - tool_config: Optional[ToolConfig] = None, - sampling_params: Optional[SamplingParams] = None, - max_infer_iters: Optional[int] = None, - input_shields: Optional[List[str]] = None, - output_shields: Optional[List[str]] = None, - response_format: Optional[ResponseFormat] = None, - enable_session_persistence: Optional[bool] = None, extra_headers: Headers | None = None, - name: str | None = None, ): - """Construct an Agent with the given parameters. - - :param client: The LlamaStackClient instance. - :param agent_config: The AgentConfig instance. - ::deprecated: use other parameters instead - :param client_tools: A tuple of ClientTool instances. - ::deprecated: use tools instead - :param tool_parser: Custom logic that parses tool calls from a message. - :param model: The model to use for the agent. - :param instructions: The instructions for the agent. - :param tools: A list of tools for the agent. Values can be one of the following: - - dict representing a toolgroup/tool with arguments: e.g. {"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}} - - a python function with a docstring. See @client_tool for more details. - - str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search" - - str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent - - an instance of ClientTool: A client tool object. - :param tool_config: The tool configuration for the agent. - :param sampling_params: The sampling parameters for the agent. - :param max_infer_iters: The maximum number of inference iterations. - :param input_shields: The input shields for the agent. - :param output_shields: The output shields for the agent. - :param response_format: The response format for the agent. - :param enable_session_persistence: Whether to enable session persistence. - :param extra_headers: Extra headers to add to all requests sent by the agent. - :param name: Optional name for the agent, used in telemetry and identification. + """Construct an async Agent backed by the responses + conversations APIs. + + Args: + client: An async OpenAI-compatible client (e.g., openai.AsyncOpenAI() or AsyncLlamaStackClient). + The client must support the responses and conversations APIs. """ self.client = client - if agent_config is not None: - logger.warning("`agent_config` is deprecated. Use inlined parameters instead.") - if client_tools != (): - logger.warning("`client_tools` is deprecated. 
Use `tools` instead.") - - # Construct agent_config from parameters if not provided - if agent_config is None: - agent_config = AgentUtils.get_agent_config( - model=model, - instructions=instructions, - tools=tools, - tool_config=tool_config, - sampling_params=sampling_params, - max_infer_iters=max_infer_iters, - input_shields=input_shields, - output_shields=output_shields, - response_format=response_format, - enable_session_persistence=enable_session_persistence, - name=name, - ) - client_tools = AgentUtils.get_client_tools(tools) - - self.agent_config = agent_config - self.client_tools = {t.get_name(): t for t in client_tools} - self.sessions = [] self.tool_parser = tool_parser - self.builtin_tools = {} self.extra_headers = extra_headers - self._agent_id = None - - if isinstance(client, LlamaStackClient): - raise ValueError("AsyncAgent must be initialized with an AsyncLlamaStackClient") + self._model = model + self._instructions = instructions - @property - def agent_id(self) -> str: - if not self._agent_id: - raise RuntimeError("Agent ID not initialized. Call initialize() first.") - return self._agent_id + # Convert all tools to API format and separate client-side functions + self._tools, client_tools = AgentUtils.normalize_tools(tools) + self.client_tools = {tool.get_name(): tool for tool in client_tools} - async def initialize(self) -> None: - if self._agent_id: - return - - agentic_system_create_response = await self.client.alpha.agents.create( - agent_config=self.agent_config, - ) - self._agent_id = agentic_system_create_response.agent_id - for tg in self.agent_config["toolgroups"]: - for tool in await self.client.tools.list(toolgroup_id=tg, extra_headers=self.extra_headers): - self.builtin_tools[tool.name] = tg.get("args", {}) if isinstance(tg, dict) else {} + self.sessions: List[str] = [] async def create_session(self, session_name: str) -> str: - await self.initialize() - agentic_system_create_session_response = await self.client.alpha.agents.session.create( - agent_id=self.agent_id, - session_name=session_name, + conversation = await self.client.conversations.create( # type: ignore[union-attr] extra_headers=self.extra_headers, + metadata={"name": session_name}, ) - self.session_id = agentic_system_create_session_response.session_id - self.sessions.append(self.session_id) - return self.session_id + self.sessions.append(conversation.id) + return conversation.id async def create_turn( self, - messages: List[Union[UserMessage, ToolResponseMessage]], - session_id: Optional[str] = None, - toolgroups: Optional[List[Toolgroup]] = None, - documents: Optional[List[Document]] = None, + messages: List[Dict[str, Any]], + session_id: str, stream: bool = True, - ) -> AsyncIterator[AgentTurnResponseStreamChunk] | Turn: + ) -> AsyncIterator[AgentStreamChunk] | Any: if stream: - return self._create_turn_streaming(messages, session_id, toolgroups, documents) + return self._create_turn_streaming(messages, session_id) else: - chunks = [x async for x in self._create_turn_streaming(messages, session_id, toolgroups, documents)] - if not chunks: + last_chunk: Optional[AgentStreamChunk] = None + async for chunk in self._create_turn_streaming(messages, session_id): + last_chunk = chunk + if not last_chunk or not last_chunk.response: raise Exception("Turn did not complete") - return chunks[-1].event.payload.turn + return last_chunk.response - async def _run_tool_calls(self, tool_calls: List[ToolCall]) -> List[ToolResponseParam]: - responses = [] + async def _run_tool_calls(self, tool_calls: List[ToolCall]) 
-> List[ToolResponsePayload]: + responses: List[ToolResponsePayload] = [] for tool_call in tool_calls: - responses.append(await self._run_single_tool(tool_call)) + raw_result = await self._run_single_tool(tool_call) + responses.append(ToolUtils.normalize_tool_response(raw_result)) return responses - async def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam: - # custom client tools + async def _run_single_tool(self, tool_call: ToolCall) -> Any: + # Execute client-side tools if tool_call.tool_name in self.client_tools: tool = self.client_tools[tool_call.tool_name] result_message = await tool.async_run( [ CompletionMessage( role="assistant", - content=tool_call.tool_name, + content=tool_call.arguments, tool_calls=[tool_call], stop_reason="end_of_turn", ) @@ -526,86 +367,181 @@ async def _run_single_tool(self, tool_call: ToolCall) -> ToolResponseParam: ) return result_message - # builtin tools executed by tool_runtime - if tool_call.tool_name in self.builtin_tools: - tool_result = await self.client.tool_runtime.invoke_tool( - tool_name=tool_call.tool_name, - kwargs={ - **tool_call.arguments, - **self.builtin_tools[tool_call.tool_name], - }, - extra_headers=self.extra_headers, - ) - return ToolResponseParam( - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content=tool_result.content, - ) - - # cannot find tools - return ToolResponseParam( - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content=f"Unknown tool `{tool_call.tool_name}` was called.", - ) + # Server-side tools should never reach here (they execute within response stream) + # If we get here, it's an error + return { + "call_id": tool_call.call_id, + "tool_name": tool_call.tool_name, + "content": f"Unknown tool `{tool_call.tool_name}` was called.", + } async def _create_turn_streaming( self, - messages: List[Union[UserMessage, ToolResponseMessage]], - session_id: Optional[str] = None, - toolgroups: Optional[List[Toolgroup]] = None, - documents: Optional[List[Document]] = None, - ) -> AsyncIterator[AgentTurnResponseStreamChunk]: - n_iter = 0 - - # 1. create an agent turn - turn_response = await self.client.alpha.agents.turn.create( - agent_id=self.agent_id, - # use specified session_id or last session created - session_id=session_id or self.session_id[-1], - messages=messages, - stream=True, - documents=documents, - toolgroups=toolgroups, - extra_headers=self.extra_headers, - ) + messages: List[Dict[str, Any]], + session_id: str, + ) -> AsyncIterator[AgentStreamChunk]: + await self.initialize() - # 2. process turn and resume if there's a tool call - is_turn_complete = False - while not is_turn_complete: - is_turn_complete = True - async for chunk in turn_response: - if hasattr(chunk, "error"): - yield chunk + # Generate turn_id + turn_id = f"turn_{uuid4().hex[:12]}" + + # Create synthesizer + synthesizer = TurnEventSynthesizer(session_id=session_id, turn_id=turn_id) + + request_headers = self.extra_headers + + # Main turn loop + while True: + # Create response stream + raw_stream = await self.client.responses.create( + model=self._model, + instructions=self._instructions, + conversation=session_id, + input=messages, + tools=self._tools, + stream=True, + extra_headers=request_headers, + ) + + # Process events + function_calls_to_execute: List[ToolCall] = [] # Only client-side! 
+ + for high_level_event in synthesizer.process_raw_stream(raw_stream): + # Handle failures + if isinstance(high_level_event, TurnFailed): + yield AgentStreamChunk(event=high_level_event) return - tool_calls = AgentUtils.get_tool_calls(chunk, self.tool_parser) - if not tool_calls: - yield chunk - else: - is_turn_complete = False - # End of turn is reached, do not resume even if there's a tool call - if not self.tool_parser and chunk.event.payload.turn.output_message.stop_reason in {"end_of_turn"}: - yield chunk - break - - turn_id = AgentUtils.get_turn_id(chunk) - if n_iter == 0: - yield chunk - - # run the tools - tool_responses = await self._run_tool_calls(tool_calls) - - # pass it to next iteration - turn_response = await self.client.alpha.agents.turn.resume( - agent_id=self.agent_id, - session_id=session_id or self.session_id[-1], - turn_id=turn_id, + # Track function calls that need client execution + if isinstance(high_level_event, StepCompleted): + if high_level_event.step_type == "inference": + result = high_level_event.result + if result.function_calls: # Only client-side function calls + function_calls_to_execute = result.function_calls + + yield AgentStreamChunk(event=high_level_event) + + # If no client-side function calls, turn is done + if not function_calls_to_execute: + # Emit TurnCompleted + response = synthesizer.last_response + if not response: + raise RuntimeError("No response available") + for event in synthesizer.finish_turn(): + yield AgentStreamChunk(event=event, response=response) + break + + # Execute client-side tools (emit tool execution step events) + tool_step_id = f"{turn_id}_step_{synthesizer.step_counter}" + synthesizer.step_counter += 1 + + yield AgentStreamChunk( + event=StepStarted( + step_id=tool_step_id, + step_type="tool_execution", + turn_id=turn_id, + metadata={"server_side": False}, + ) + ) + + tool_responses = await self._run_tool_calls(function_calls_to_execute) + + yield AgentStreamChunk( + event=StepCompleted( + step_id=tool_step_id, + step_type="tool_execution", + turn_id=turn_id, + result=ToolExecutionStepResult( + step_id=tool_step_id, + tool_calls=function_calls_to_execute, tool_responses=tool_responses, - stream=True, - extra_headers=self.extra_headers, - ) - n_iter += 1 + ), + ) + ) + + # Continue loop with tool outputs as input + messages = [ + { + "type": "function_call_output", + "call_id": payload["call_id"], + "output": payload["content"], + } + for payload in tool_responses + ] + - if self.tool_parser and n_iter > self.agent_config.get("max_infer_iters", DEFAULT_MAX_ITER): - raise Exception("Max inference iterations reached") +class AgentUtils: + @staticmethod + def get_client_tools( + tools: Optional[List[Union[Dict[str, Any], ClientTool, Callable[..., Any]]]], + ) -> List[ClientTool]: + if not tools: + return [] + + # Wrap any function in client_tool decorator + tools = [client_tool(tool) if (callable(tool) and not isinstance(tool, ClientTool)) else tool for tool in tools] + return [tool for tool in tools if isinstance(tool, ClientTool)] + + @staticmethod + def get_tool_calls(chunk: AgentStreamChunk, tool_parser: Optional[ToolParser] = None) -> List[ToolCall]: + if not isinstance(chunk.event, StepProgress): + return [] + + delta = chunk.event.delta + if not isinstance(delta, ToolCallIssuedDelta) or delta.tool_type != "function": + return [] + + tool_call = ToolCall( + call_id=delta.call_id, + tool_name=delta.tool_name, + arguments=delta.arguments, + ) + + if tool_parser: + completion = CompletionMessage( + role="assistant", + 
content="", + tool_calls=[tool_call], + stop_reason="end_of_turn", + ) + return tool_parser.get_tool_calls(completion) + + return [tool_call] + + @staticmethod + def get_turn_id(chunk: AgentStreamChunk) -> Optional[str]: + return chunk.response.turn.turn_id if chunk.response else None + + @staticmethod + def normalize_tools( + tools: Optional[List[Union[Dict[str, Any], ClientTool, Callable[..., Any]]]], + ) -> Tuple[List[Dict[str, Any]], List[ClientTool]]: + """Convert all tools to API format dicts. + + Returns: + - List of tool dicts for responses.create(tools=...) + - List of ClientTool instances for client-side execution + """ + if not tools: + return [], [] + + tool_dicts: List[Dict[str, Any]] = [] + client_tool_instances: List[ClientTool] = [] + + for tool in tools: + # Convert callable to ClientTool + if callable(tool) and not isinstance(tool, ClientTool): + tool = client_tool(tool) + + if isinstance(tool, ClientTool): + # Convert ClientTool to function tool dict + tool_def = tool.get_tool_definition() + tool_dicts.append(tool_def) + client_tool_instances.append(tool) + elif isinstance(tool, dict): + # Server-side tool dict (file_search, web_search, etc.) + tool_dicts.append(tool) # type: ignore[arg-type] + else: + raise TypeError(f"Unsupported tool type: {type(tool)!r}") + + return tool_dicts, client_tool_instances diff --git a/src/llama_stack_client/lib/agents/client_tool.py b/src/llama_stack_client/lib/agents/client_tool.py index 09164361..63a5bfd7 100644 --- a/src/llama_stack_client/lib/agents/client_tool.py +++ b/src/llama_stack_client/lib/agents/client_tool.py @@ -21,9 +21,7 @@ from typing_extensions import TypedDict -from llama_stack_client.types import CompletionMessage, Message -from llama_stack_client.types.alpha import ToolResponse -from llama_stack_client.types.tool_def_param import ToolDefParam +from .types import CompletionMessage, Message, ToolDefinition, ToolResponse class JSONSchema(TypedDict, total=False): @@ -61,13 +59,13 @@ def get_input_schema(self) -> JSONSchema: def get_instruction_string(self) -> str: return f"Use the function '{self.get_name()}' to: {self.get_description()}" - def get_tool_definition(self) -> ToolDefParam: - return ToolDefParam( - name=self.get_name(), - description=self.get_description(), - input_schema=self.get_input_schema(), - metadata={}, - ) + def get_tool_definition(self) -> ToolDefinition: + return { + "type": "function", + "name": self.get_name(), + "description": self.get_description(), + "parameters": self.get_input_schema(), + } def run( self, @@ -82,7 +80,6 @@ def run( metadata = {} try: params = json.loads(tool_call.arguments) - response = self.run_impl(**params) if isinstance(response, dict) and "content" in response: content = json.dumps(response["content"], ensure_ascii=False) @@ -108,7 +105,8 @@ async def async_run( tool_call = last_message.tool_calls[0] metadata = {} try: - response = await self.async_run_impl(**tool_call.arguments) + params = json.loads(tool_call.arguments) + response = await self.async_run_impl(**params) if isinstance(response, dict) and "content" in response: content = json.dumps(response["content"], ensure_ascii=False) metadata = response.get("metadata", {}) diff --git a/src/llama_stack_client/lib/agents/event_logger.py b/src/llama_stack_client/lib/agents/event_logger.py index b4e1a219..8b56f398 100644 --- a/src/llama_stack_client/lib/agents/event_logger.py +++ b/src/llama_stack_client/lib/agents/event_logger.py @@ -4,142 +4,131 @@ # This source code is licensed under the terms described in the LICENSE 
file in # the root directory of this source tree. -from typing import Any, Iterator, Optional - -from termcolor import cprint - -from llama_stack_client.types import InterleavedContent - - -def interleaved_content_as_str(content: InterleavedContent, sep: str = " ") -> str: - def _process(c: Any) -> str: - if isinstance(c, str): - return c - elif hasattr(c, "type"): - if c.type == "text": - return c.text - elif c.type == "image": - return "" - else: - raise ValueError(f"Unexpected type {c}") - else: - raise ValueError(f"Unsupported content type: {type(c)}") - - if isinstance(content, list): - return sep.join(_process(c) for c in content) - else: - return _process(content) - - -class TurnStreamPrintableEvent: - def __init__( - self, - role: Optional[str] = None, - content: str = "", - end: Optional[str] = "\n", - color: str = "white", - ) -> None: - self.role = role - self.content = content - self.color = color - self.end = "\n" if end is None else end - - def __str__(self) -> str: - if self.role is not None: - return f"{self.role}> {self.content}" - else: - return f"{self.content}" - - def print(self, flush: bool = True) -> None: - cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush) - - -class TurnStreamEventPrinter: - def yield_printable_events(self, chunk: Any) -> Iterator[TurnStreamPrintableEvent]: - for printable_event in self._yield_printable_events(chunk): - yield printable_event - - def _yield_printable_events(self, chunk: Any) -> Iterator[TurnStreamPrintableEvent]: - if hasattr(chunk, "error"): - yield TurnStreamPrintableEvent(role=None, content=chunk.error["message"], color="red") - return - - event = chunk.event - event_type = event.payload.event_type - - if event_type in {"turn_start", "turn_complete", "turn_awaiting_input"}: - # Currently not logging any turn realted info - yield TurnStreamPrintableEvent(role=None, content="", end="", color="grey") - return - - step_type = event.payload.step_type - # handle safety - if step_type == "shield_call" and event_type == "step_complete": - violation = event.payload.step_details.violation - if not violation: - yield TurnStreamPrintableEvent(role=step_type, content="No Violation", color="magenta") - else: - yield TurnStreamPrintableEvent( - role=step_type, - content=f"{violation.metadata} {violation.user_message}", - color="red", - ) - - # handle inference - if step_type == "inference": - if event_type == "step_start": - yield TurnStreamPrintableEvent(role=step_type, content="", end="", color="yellow") - elif event_type == "step_progress": - if event.payload.delta.type == "tool_call": - if isinstance(event.payload.delta.tool_call, str): - yield TurnStreamPrintableEvent( - role=None, - content=event.payload.delta.tool_call, - end="", - color="cyan", - ) - elif event.payload.delta.type == "text": - yield TurnStreamPrintableEvent( - role=None, - content=event.payload.delta.text, - end="", - color="yellow", - ) - else: - # step complete - yield TurnStreamPrintableEvent(role=None, content="") - - # handle tool_execution - if step_type == "tool_execution" and event_type == "step_complete": - # Only print tool calls and responses at the step_complete event - details = event.payload.step_details - for t in details.tool_calls: - yield TurnStreamPrintableEvent( - role=step_type, - content=f"Tool:{t.tool_name} Args:{t.arguments}", - color="green", - ) - - for r in details.tool_responses: - if r.tool_name == "query_from_memory": - inserted_context = interleaved_content_as_str(r.content) - content = f"fetched {len(inserted_context)} 
bytes from memory" - - yield TurnStreamPrintableEvent( - role=step_type, - content=content, - color="cyan", - ) - else: - yield TurnStreamPrintableEvent( - role=step_type, - content=f"Tool:{r.tool_name} Response:{r.content}", - color="green", - ) - - -class EventLogger: - def log(self, event_generator: Iterator[Any]) -> Iterator[TurnStreamPrintableEvent]: - printer = TurnStreamEventPrinter() +"""Event logger for agent interactions. + +This module provides a simple logger that converts agent stream events +into human-readable printable strings for console output. +""" + +from typing import Iterator + +from .turn_events import ( + AgentStreamChunk, + TurnStarted, + TurnCompleted, + TurnFailed, + StepStarted, + StepProgress, + StepCompleted, + TextDelta, + ToolCallIssuedDelta, + ToolCallDelta, +) + +__all__ = ["AgentEventLogger", "EventLogger"] + + +class AgentEventLogger: + """Logger for agent events with turn/step semantics. + + This logger converts high-level agent events into printable strings + that can be displayed to users. It handles: + - Turn lifecycle events + - Step boundaries (inference, tool execution) + - Streaming content (text, tool calls) + - Server-side and client-side tool execution + + Usage: + logger = AgentEventLogger() + for chunk in agent.create_turn(...): + for printable in logger.log([chunk]): + print(printable, end="", flush=True) + """ + + def log(self, event_generator: Iterator[AgentStreamChunk]) -> Iterator[str]: + """Generate printable strings from agent stream chunks. + + Args: + event_generator: Iterator of AgentStreamChunk objects + + Yields: + Printable string fragments + """ for chunk in event_generator: - yield from printer.yield_printable_events(chunk) + event = chunk.event + + if isinstance(event, TurnStarted): + # Optionally log turn start (commented out to reduce noise) + # yield f"[Turn {event.turn_id}]\n" + pass + + elif isinstance(event, StepStarted): + if event.step_type == "inference": + # Indicate model is thinking (no newline) + yield "šŸ¤” " + elif event.step_type == "tool_execution": + # Indicate tools are executing + server_side = event.metadata and event.metadata.get("server_side", False) + if server_side: + tool_type = event.metadata.get("tool_type", "tool") + yield f"\nšŸ”§ Executing {tool_type} (server-side)...\n" + else: + yield "\nšŸ”§ Executing function tools (client-side)...\n" + + elif isinstance(event, StepProgress): + if event.step_type == "inference": + if isinstance(event.delta, TextDelta): + # Stream text as it comes + yield event.delta.text + + elif isinstance(event.delta, ToolCallIssuedDelta): + # Log client-side function calls (server-side handled as separate tool_execution steps) + if event.delta.tool_type == "function": + # Client-side function call + yield f"\nšŸ“ž Function call: {event.delta.tool_name}({event.delta.arguments})" + + elif isinstance(event.delta, ToolCallDelta): + # Optionally stream tool arguments (can be noisy, so commented out) + # yield event.delta.arguments_delta + pass + + elif event.step_type == "tool_execution": + # Handle tool execution progress (for server-side tools) + if isinstance(event.delta, ToolCallIssuedDelta): + # Don't log again, already logged at StepStarted + pass + elif isinstance(event.delta, ToolCallDelta): + # Optionally log argument streaming + pass + + elif isinstance(event, StepCompleted): + if event.step_type == "inference": + result = event.result + # Server-side tools already logged during progress + if not result.function_calls: + # End of inference with no function calls + 
yield "\n" + + elif event.step_type == "tool_execution": + # Log client-side tool execution results + result = event.result + for resp in result.tool_responses: + tool_name = resp.get("tool_name", "unknown") + content = resp.get("content", "") + # Truncate long responses for readability + if isinstance(content, str) and len(content) > 100: + content = content[:100] + "..." + yield f" → {tool_name}: {content}\n" + + elif isinstance(event, TurnCompleted): + # Optionally log turn completion (commented out to reduce noise) + # yield f"\n[Completed in {event.num_steps} steps]\n" + pass + + elif isinstance(event, TurnFailed): + # Always log failures + yield f"\nāŒ Turn failed: {event.error_message}\n" + + +# Alias for backwards compatibility +EventLogger = AgentEventLogger diff --git a/src/llama_stack_client/lib/agents/event_synthesizer.py b/src/llama_stack_client/lib/agents/event_synthesizer.py new file mode 100644 index 00000000..22ce5c24 --- /dev/null +++ b/src/llama_stack_client/lib/agents/event_synthesizer.py @@ -0,0 +1,449 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Translate Responses API stream events into structured turn events. + +TurnEventSynthesizer keeps just enough state to expose turns and steps for +agents. It consumes the raw Responses API stream and emits the higher-level +events defined in ``turn_events.py`` without introducing an intermediate +low-level event layer. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass +from typing import Any, Dict, Iterable, Iterator, List, Optional + +from logging import getLogger + +from .types import ToolCall + +from .turn_events import ( + AgentEvent, + InferenceStepResult, + StepCompleted, + StepProgress, + StepStarted, + TextDelta, + ToolCallDelta, + ToolCallIssuedDelta, + ToolExecutionStepResult, + TurnCompleted, + TurnFailed, + TurnStarted, +) + +logger = getLogger(__name__) + + +@dataclass +class _ToolCallState: + call_id: str + tool_name: str + tool_type: str + server_side: bool + arguments: str = "" + + def update(self, *, delta: Optional[str] = None, final: Optional[str] = None) -> None: + if final is not None: + self.arguments = final or "{}" + elif delta: + self.arguments += delta + + def as_tool_call(self) -> ToolCall: + payload = self.arguments or "{}" + return ToolCall(call_id=self.call_id, tool_name=self.tool_name, arguments=payload) + + +class TurnEventSynthesizer: + """Produce turn/step events directly from Responses API streaming events.""" + + def __init__(self, session_id: str, turn_id: str): + self.session_id = session_id + self.turn_id = turn_id + + self.step_counter = 0 + self.current_step_id: Optional[str] = None + self.current_step_type: Optional[str] = None + + self.current_response_id: Optional[str] = None + self.text_parts: List[str] = [] + self._function_call_ids: List[str] = [] + self._tool_calls: Dict[str, _ToolCallState] = {} + + self.turn_started = False + self.all_response_ids: List[str] = [] + self.last_response: Optional[Any] = None + + # ------------------------------------------------------------------ helpers + + def _next_step_id(self) -> str: + step_id = f"{self.turn_id}_step_{self.step_counter}" + self.step_counter += 1 + return step_id + + def _maybe_emit_turn_started(self) -> Iterator[AgentEvent]: + if not self.turn_started: + self.turn_started = True + yield 
TurnStarted(turn_id=self.turn_id, session_id=self.session_id) + + def _start_inference_step( + self, *, response_id: Optional[str] = None, reset_tool_state: bool = False + ) -> Iterator[AgentEvent]: + if response_id: + self.current_response_id = response_id + if reset_tool_state: + self._tool_calls = {} + self.text_parts = [] + self._function_call_ids = [] + self.current_step_id = self._next_step_id() + self.current_step_type = "inference" + yield StepStarted(step_id=self.current_step_id, step_type="inference", turn_id=self.turn_id) + + def _complete_inference_step(self, *, stop_reason: str, response_id: Optional[str] = None) -> Iterator[AgentEvent]: + if self.current_step_type != "inference": + return + step_id = self.current_step_id or self._next_step_id() + resolved_response_id = response_id or self.current_response_id or "" + self.current_response_id = resolved_response_id + + function_calls: List[ToolCall] = [] + for call_id in self._function_call_ids: + state = self._tool_calls.get(call_id) + if state is None: + continue + function_calls.append(state.as_tool_call()) + + yield StepCompleted( + step_id=step_id, + step_type="inference", + turn_id=self.turn_id, + result=InferenceStepResult( + step_id=step_id, + response_id=resolved_response_id, + text_content="".join(self.text_parts), + function_calls=function_calls, + server_tool_executions=[], + stop_reason=stop_reason, + ), + ) + + for call_id in list(self._function_call_ids): + # Drop client-side call state once we hand it back to the agent. + self._tool_calls.pop(call_id, None) + self._function_call_ids = [] + self.text_parts = [] + self.current_step_id = None + self.current_step_type = None + + def _start_tool_execution_step(self, call_state: _ToolCallState) -> Iterator[AgentEvent]: + self.current_step_id = self._next_step_id() + self.current_step_type = "tool_execution" + yield StepStarted( + step_id=self.current_step_id, + step_type="tool_execution", + turn_id=self.turn_id, + metadata={"server_side": True, "tool_type": call_state.tool_type, "tool_name": call_state.tool_name}, + ) + yield StepProgress( + step_id=self.current_step_id, + step_type="tool_execution", + turn_id=self.turn_id, + delta=ToolCallIssuedDelta( + call_id=call_state.call_id, + tool_type=call_state.tool_type, # type: ignore[arg-type] + tool_name=call_state.tool_name, + arguments=call_state.arguments or "{}", + ), + ) + + def _complete_tool_execution_step(self, call_state: _ToolCallState) -> Iterator[AgentEvent]: + if self.current_step_type != "tool_execution": + return + step_id = self.current_step_id or self._next_step_id() + yield StepCompleted( + step_id=step_id, + step_type="tool_execution", + turn_id=self.turn_id, + result=ToolExecutionStepResult( + step_id=step_id, + tool_calls=[call_state.as_tool_call()], + tool_responses=[], + ), + ) + self._tool_calls.pop(call_state.call_id, None) + self.current_step_id = None + self.current_step_type = None + + @staticmethod + def _coerce_arguments(payload: Any) -> Optional[str]: + if payload is None: + return None + if isinstance(payload, str): + return payload + try: + return json.dumps(payload) + except Exception: # pragma: no cover - defensive + return str(payload) + + def _register_tool_call( + self, + *, + call_id: Optional[str], + tool_name: str, + tool_type: str, + arguments: Optional[str], + ) -> _ToolCallState: + resolved_call_id = call_id or f"{tool_name}_{len(self._tool_calls)}" + state = _ToolCallState( + call_id=resolved_call_id, + tool_name=tool_name, + tool_type=tool_type, + server_side=tool_type != 
"function", + ) + if arguments: + state.update(final=arguments) + self._tool_calls[resolved_call_id] = state + return state + + def _classify_tool_type(self, tool_name: str) -> str: + server_side_tools = { + "file_search", + "file_search_call", + "knowledge_search", + "web_search", + "web_search_call", + "query_from_memory", + "mcp_call", + "mcp_list_tools", + "memory_retrieval", + } + + if tool_name in server_side_tools: + if tool_name in {"file_search_call", "knowledge_search"}: + return "file_search" + if tool_name == "web_search_call": + return "web_search" + return tool_name + + return "function" + + # ------------------------------------------------------------------ handlers + + def process_raw_stream(self, events: Iterable[Any]) -> Iterator[AgentEvent]: + current_response_id: Optional[str] = None + + for event in events: + yield from self._maybe_emit_turn_started() + + response_id = getattr(event, "response_id", None) + if response_id is None and hasattr(event, "response"): + response = getattr(event, "response") + response_id = getattr(response, "id", None) + if response_id is not None: + current_response_id = response_id + + event_type = getattr(event, "type", None) + if not event_type: + logger.debug("Unhandled stream event with no type: %r", event) + continue + + if event_type == "response.in_progress": + response = getattr(event, "response", None) + response_id = getattr(response, "id", current_response_id) + if response_id is None: + continue + if not self.all_response_ids or self.all_response_ids[-1] != response_id: + self.all_response_ids.append(response_id) + yield from self._start_inference_step(response_id=response_id, reset_tool_state=True) + + elif event_type == "response.output_text.delta": + if self.current_step_type != "inference": + continue + text = getattr(event, "delta", "") or "" + if not text: + continue + self.text_parts.append(text) + yield StepProgress( + step_id=self.current_step_id or "", + step_type="inference", + turn_id=self.turn_id, + delta=TextDelta(text=text), + ) + + elif event_type == "response.output_text.done": + # Text completions are tracked via the final StepCompleted event. + continue + + elif event_type == "response.output_item.added": + yield from self._handle_output_item_added(getattr(event, "item", None)) + + elif event_type == "response.output_item.delta": + yield from self._handle_output_item_delta(getattr(event, "delta", None)) + + elif event_type == "response.output_item.done": + yield from self._handle_output_item_done(getattr(event, "item", None)) + + elif event_type == "response.completed": + response = getattr(event, "response", None) + if response is not None: + self.last_response = response + stop_reason = "tool_calls" if self._function_call_ids else "end_of_turn" + yield from self._complete_inference_step(stop_reason=stop_reason, response_id=current_response_id) + + elif event_type == "response.failed": + response = getattr(event, "response", None) + error_obj = getattr(response, "error", None) + error_message = getattr(error_obj, "message", None) if error_obj else None + yield TurnFailed( + turn_id=self.turn_id, + session_id=self.session_id, + error_message=error_message or "Unknown error", + ) + + else: # pragma: no cover - depends on streaming responses + # Allow unknown streaming events to pass silently; they are often ancillary metadata. 
+ continue + + def _handle_output_item_added(self, item: Any) -> Iterator[AgentEvent]: + if item is None: + return + item_type = getattr(item, "type", None) + if item_type is None: + return + + if item_type == "message": + # Messages mirror text deltas, nothing extra to emit. + return + + if item_type in { + "function_call", + "web_search", + "web_search_call", + "mcp_call", + "mcp_list_tools", + "file_search_call", + }: + call_id = getattr(item, "call_id", None) or getattr(item, "id", None) + tool_name = getattr(item, "name", None) or getattr(item, "type", "") + arguments = self._coerce_arguments(getattr(item, "arguments", None)) + tool_type = self._classify_tool_type(tool_name if item_type == "function_call" else item_type) + + state = self._register_tool_call( + call_id=call_id, + tool_name=tool_name, + tool_type=tool_type, + arguments=arguments, + ) + + if state.server_side: + yield from self._complete_inference_step( + stop_reason="server_tool_call", response_id=self.current_response_id + ) + yield from self._start_tool_execution_step(state) + else: + if state.call_id not in self._function_call_ids: + self._function_call_ids.append(state.call_id) + yield StepProgress( + step_id=self.current_step_id or "", + step_type="inference", + turn_id=self.turn_id, + delta=ToolCallIssuedDelta( + call_id=state.call_id, + tool_type="function", + tool_name=state.tool_name, + arguments=state.arguments or "{}", + ), + ) + + def _handle_output_item_delta(self, delta: Any) -> Iterator[AgentEvent]: + if delta is None: + return + delta_type = getattr(delta, "type", None) + if delta_type not in { + "function_call", + "web_search", + "web_search_call", + "mcp_call", + "mcp_list_tools", + "file_search_call", + }: + return + + call_id = getattr(delta, "call_id", None) or getattr(delta, "id", None) + if call_id is None: + return + + arguments_delta = getattr(delta, "arguments_delta", None) + if arguments_delta is None: + arguments_delta = getattr(delta, "arguments", None) + if arguments_delta is None and isinstance(delta, dict): + arguments_delta = delta.get("arguments_delta") or delta.get("arguments") + if arguments_delta is None: + return + + state = self._tool_calls.get(call_id) + if state is None: + return + + state.update(delta=arguments_delta) + step_type = "tool_execution" if state.server_side else "inference" + yield StepProgress( + step_id=self.current_step_id or "", + step_type=step_type, # type: ignore[arg-type] + turn_id=self.turn_id, + delta=ToolCallDelta(call_id=state.call_id, arguments_delta=arguments_delta), + ) + + def _handle_output_item_done(self, item: Any) -> Iterator[AgentEvent]: + if item is None: + return + + item_type = getattr(item, "type", None) + if item_type not in { + "function_call", + "web_search", + "web_search_call", + "mcp_call", + "mcp_list_tools", + "file_search_call", + }: + return + + call_id = getattr(item, "call_id", None) or getattr(item, "id", None) + if call_id is None: + return + + state = self._tool_calls.get(call_id) + if state is None: + return + + arguments = self._coerce_arguments(getattr(item, "arguments", None)) + if arguments: + state.update(final=arguments) + + if state.server_side: + yield from self._complete_tool_execution_step(state) + # Start a fresh inference step so the model can continue reasoning. 
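+ # Note that reset_tool_state stays False here, so any still-streaming tool-call
+ # state from this response survives the step boundary; only the text buffer and
+ # the client function-call list start fresh for the new inference step.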
+ yield from self._start_inference_step() + else: + if call_id not in self._function_call_ids: + self._function_call_ids.append(call_id) + + # ------------------------------------------------------------------ turn end + + def finish_turn(self) -> Iterator[AgentEvent]: + if not self.last_response: + raise RuntimeError("Cannot finish turn without a response") + + yield TurnCompleted( + turn_id=self.turn_id, + session_id=self.session_id, + final_text=self.last_response.output_text, + response_ids=self.all_response_ids, + num_steps=self.step_counter, + ) diff --git a/src/llama_stack_client/lib/agents/react/agent.py b/src/llama_stack_client/lib/agents/react/agent.py index 919f0420..cd14c6bf 100644 --- a/src/llama_stack_client/lib/agents/react/agent.py +++ b/src/llama_stack_client/lib/agents/react/agent.py @@ -4,225 +4,91 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. import logging -from typing import Any, Callable, List, Optional, Tuple, Union - -from llama_stack_client import LlamaStackClient -from llama_stack_client.types.agent_create_params import AgentConfig -from llama_stack_client.types.agents.turn_create_params import Toolgroup -from llama_stack_client.types.shared_params.agent_config import ToolConfig -from llama_stack_client.types.shared_params.response_format import ResponseFormat -from llama_stack_client.types.shared_params.sampling_params import SamplingParams +from collections.abc import Mapping +from typing import Any, Callable, Dict, List, Optional, Tuple, Union from ...._types import Headers from ..agent import Agent, AgentUtils from ..client_tool import ClientTool from ..tool_parser import ToolParser from .prompts import DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE -from .tool_parser import ReActOutput, ReActToolParser +from .tool_parser import ReActToolParser logger = logging.getLogger(__name__) -def get_tool_defs( - client: LlamaStackClient, builtin_toolgroups: Tuple[Toolgroup] = (), client_tools: Tuple[ClientTool] = () -): - tool_defs = [] - for x in builtin_toolgroups: - if isinstance(x, str): - toolgroup_id = x - else: - toolgroup_id = x["name"] - tool_defs.extend( - [ - { - "name": tool.identifier, - "description": tool.description, - "input_schema": tool.input_schema, - } - for tool in client.tools.list(toolgroup_id=toolgroup_id) - ] - ) +def _tool_definition_from_mapping(tool: Mapping[str, Any]) -> Dict[str, Any]: + name = tool.get("name") or tool.get("identifier") or tool.get("tool_name") or tool.get("type") or "tool" + description = tool.get("description") or tool.get("summary") or "" + parameters = tool.get("parameters") or tool.get("input_schema") or {} + return { + "name": str(name), + "description": str(description), + "input_schema": parameters, + } + +def _collect_tool_definitions( + server_tools: Tuple[Mapping[str, Any], ...], + client_tools: Tuple[ClientTool, ...], +) -> List[Dict[str, Any]]: + tool_defs = [_tool_definition_from_mapping(tool) for tool in server_tools] tool_defs.extend( - [ - { - "name": tool.get_name(), - "description": tool.get_description(), - "input_schema": tool.get_input_schema(), - } - for tool in client_tools - ] + { + "name": tool.get_name(), + "description": tool.get_description(), + "input_schema": tool.get_input_schema(), + } + for tool in client_tools ) return tool_defs -def get_default_react_instructions( - client: LlamaStackClient, builtin_toolgroups: Tuple[str] = (), client_tools: Tuple[ClientTool] = () -): - tool_defs = get_tool_defs(client, 
builtin_toolgroups, client_tools) - tool_names = ", ".join([x["name"] for x in tool_defs]) - tool_descriptions = "\n".join([f"- {x['name']}: {x}" for x in tool_defs]) +def get_default_react_instructions(tool_defs: List[Dict[str, Any]]) -> str: + tool_names = ", ".join([definition["name"] for definition in tool_defs]) + tool_descriptions = "\n".join( + [ + f"- {definition['name']}: {definition['description'] or definition['input_schema']}" + for definition in tool_defs + ] + ) instruction = DEFAULT_REACT_AGENT_SYSTEM_PROMPT_TEMPLATE.replace("<>", tool_names).replace( "<>", tool_descriptions ) return instruction -def get_agent_config_DEPRECATED( - client: LlamaStackClient, - model: str, - builtin_toolgroups: Tuple[str] = (), - client_tools: Tuple[ClientTool] = (), - json_response_format: bool = False, - custom_agent_config: Optional[AgentConfig] = None, -) -> AgentConfig: - if custom_agent_config is None: - instruction = get_default_react_instructions(client, builtin_toolgroups, client_tools) - - # user default toolgroups - agent_config = AgentConfig( - model=model, - instructions=instruction, - toolgroups=builtin_toolgroups, - client_tools=[client_tool.get_tool_definition() for client_tool in client_tools], - tool_config={ - "tool_choice": "auto", - "system_message_behavior": "replace", - }, - input_shields=[], - output_shields=[], - enable_session_persistence=False, - ) - else: - agent_config = custom_agent_config - - if json_response_format: - agent_config["response_format"] = { - "type": "json_schema", - "json_schema": ReActOutput.model_json_schema(), - } - - return agent_config - - class ReActAgent(Agent): - """ReAct agent. - - Simple wrapper around Agent to add prepare prompts for creating a ReAct agent from a list of tools. - """ - def __init__( self, - client: LlamaStackClient, + client: Any, + *, model: str, - tool_parser: ToolParser = ReActToolParser(), + tools: Optional[List[Union[Dict[str, Any], ClientTool, Callable[..., Any]]]] = None, + tool_parser: Optional[ToolParser] = None, instructions: Optional[str] = None, - tools: Optional[List[Union[Toolgroup, ClientTool, Callable[..., Any]]]] = None, - tool_config: Optional[ToolConfig] = None, - sampling_params: Optional[SamplingParams] = None, - max_infer_iters: Optional[int] = None, - input_shields: Optional[List[str]] = None, - output_shields: Optional[List[str]] = None, - response_format: Optional[ResponseFormat] = None, - enable_session_persistence: Optional[bool] = None, - json_response_format: bool = False, - builtin_toolgroups: Tuple[str] = (), # DEPRECATED - client_tools: Tuple[ClientTool] = (), # DEPRECATED - custom_agent_config: Optional[AgentConfig] = None, # DEPRECATED extra_headers: Headers | None = None, + json_response_format: bool = False, ): - """Construct an Agent with the given parameters. - - :param client: The LlamaStackClient instance. - :param custom_agent_config: The AgentConfig instance. - ::deprecated: use other parameters instead - :param client_tools: A tuple of ClientTool instances. - ::deprecated: use tools instead - :param builtin_toolgroups: A tuple of Toolgroup instances. - ::deprecated: use tools instead - :param tool_parser: Custom logic that parses tool calls from a message. - :param model: The model to use for the agent. - :param instructions: The instructions for the agent. - :param tools: A list of tools for the agent. Values can be one of the following: - - dict representing a toolgroup/tool with arguments: e.g. 
{"name": "builtin::rag/knowledge_search", "args": {"vector_db_ids": [123]}} - - a python function with a docstring. See @client_tool for more details. - - str representing a tool within a toolgroup: e.g. "builtin::rag/knowledge_search" - - str representing a toolgroup_id: e.g. "builtin::rag", "builtin::code_interpreter", where all tools in the toolgroup will be added to the agent - - an instance of ClientTool: A client tool object. - :param tool_config: The tool configuration for the agent. - :param sampling_params: The sampling parameters for the agent. - :param max_infer_iters: The maximum number of inference iterations. - :param input_shields: The input shields for the agent. - :param output_shields: The output shields for the agent. - :param response_format: The response format for the agent. - :param enable_session_persistence: Whether to enable session persistence. - :param json_response_format: Whether to use the json response format with default ReAct output schema. - ::deprecated: use response_format instead - :param extra_headers: Extra headers to add to all requests sent by the agent. - """ - use_deprecated_params = False - if custom_agent_config is not None: - logger.warning("`custom_agent_config` is deprecated. Use inlined parameters instead.") - use_deprecated_params = True - if client_tools != (): - logger.warning("`client_tools` is deprecated. Use `tools` instead.") - use_deprecated_params = True - if builtin_toolgroups != (): - logger.warning("`builtin_toolgroups` is deprecated. Use `tools` instead.") - use_deprecated_params = True + if json_response_format: + logger.warning("`json_response_format` is deprecated and will be removed in a future release.") - if use_deprecated_params: - agent_config = get_agent_config_DEPRECATED( - client=client, - model=model, - builtin_toolgroups=builtin_toolgroups, - client_tools=client_tools, - json_response_format=json_response_format, - ) - super().__init__( - client=client, - agent_config=agent_config, - client_tools=client_tools, - tool_parser=tool_parser, - extra_headers=extra_headers, - ) + if tool_parser is None: + tool_parser = ReActToolParser() - else: - if not tool_config: - tool_config = { - "tool_choice": "auto", - "system_message_behavior": "replace", - } + tool_list = tools or [] + client_tool_instances = AgentUtils.get_client_tools(tool_list) + server_tool_defs = tuple(tool for tool in tool_list if isinstance(tool, Mapping)) - if json_response_format: - if instructions is not None: - logger.warning( - "Using a custom instructions, but json_response_format is set. Please make sure instructions are" - "compatible with the default ReAct output format." 
- ) - response_format = { - "type": "json_schema", - "json_schema": ReActOutput.model_json_schema(), - } + if instructions is None: + tool_definitions = _collect_tool_definitions(tuple(server_tool_defs), tuple(client_tool_instances)) + instructions = get_default_react_instructions(tool_definitions) - # build REACT instructions - client_tools = AgentUtils.get_client_tools(tools) - builtin_toolgroups = [x for x in tools if isinstance(x, str) or isinstance(x, dict)] - if not instructions: - instructions = get_default_react_instructions(client, builtin_toolgroups, client_tools) - - super().__init__( - client=client, - model=model, - tool_parser=tool_parser, - instructions=instructions, - tools=tools, - tool_config=tool_config, - sampling_params=sampling_params, - max_infer_iters=max_infer_iters, - input_shields=input_shields, - output_shields=output_shields, - response_format=response_format, - enable_session_persistence=enable_session_persistence, - extra_headers=extra_headers, - ) + super().__init__( + client=client, + model=model, + instructions=instructions, + tools=tool_list, + tool_parser=tool_parser, + extra_headers=extra_headers, + ) diff --git a/src/llama_stack_client/lib/agents/react/tool_parser.py b/src/llama_stack_client/lib/agents/react/tool_parser.py index a796abac..9120f83d 100644 --- a/src/llama_stack_client/lib/agents/react/tool_parser.py +++ b/src/llama_stack_client/lib/agents/react/tool_parser.py @@ -8,8 +8,7 @@ import uuid from typing import List, Optional, Union -from llama_stack_client.types.shared.completion_message import CompletionMessage -from llama_stack_client.types.shared.tool_call import ToolCall +from ..types import CompletionMessage, ToolCall from pydantic import BaseModel, ValidationError diff --git a/src/llama_stack_client/lib/agents/tool_parser.py b/src/llama_stack_client/lib/agents/tool_parser.py index ca8d28ea..0e6a97ad 100644 --- a/src/llama_stack_client/lib/agents/tool_parser.py +++ b/src/llama_stack_client/lib/agents/tool_parser.py @@ -7,8 +7,7 @@ from abc import abstractmethod from typing import List -from llama_stack_client.types.alpha.agents.turn import CompletionMessage -from llama_stack_client.types.shared.tool_call import ToolCall +from .types import CompletionMessage, ToolCall class ToolParser: diff --git a/src/llama_stack_client/lib/agents/turn_events.py b/src/llama_stack_client/lib/agents/turn_events.py new file mode 100644 index 00000000..cca11095 --- /dev/null +++ b/src/llama_stack_client/lib/agents/turn_events.py @@ -0,0 +1,275 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""High-level turn and step events for agent interactions. + +This module defines the semantic event model that wraps the lower-level +responses API stream events. It provides a turn/step conceptual model that +makes agent interactions easier to understand and work with. 
+ +Key concepts: +- Turn: A complete interaction loop that may span multiple responses +- Step: A distinct phase within a turn (inference or tool_execution) +- Delta: Incremental updates during step execution +- Result: Complete output when a step finishes +""" + +from dataclasses import dataclass +from typing import Union, List, Optional, Dict, Any, Literal + +from .types import ToolCall + +__all__ = [ + "TurnStarted", + "TurnCompleted", + "TurnFailed", + "StepStarted", + "StepProgress", + "StepCompleted", + "TextDelta", + "ToolCallIssuedDelta", + "ToolCallDelta", + "ToolCallCompletedDelta", + "StepDelta", + "InferenceStepResult", + "ToolExecutionStepResult", + "StepResult", + "AgentEvent", + "AgentStreamChunk", +] + + +# ============= Turn-Level Events ============= + + +@dataclass +class TurnStarted: + """Emitted when agent begins processing user input. + + This marks the beginning of a complete interaction cycle that may + involve multiple inference steps and tool executions. + """ + + turn_id: str + session_id: str + event_type: Literal["turn_started"] = "turn_started" + + +@dataclass +class TurnCompleted: + """Emitted when agent finishes with final answer. + + This marks the end of a turn when the model has produced a final + response without any pending client-side tool calls. + """ + + turn_id: str + session_id: str + final_text: str + response_ids: List[str] # All response IDs involved in this turn + num_steps: int + event_type: Literal["turn_completed"] = "turn_completed" + + +@dataclass +class TurnFailed: + """Emitted if turn processing fails. + + This indicates an unrecoverable error during turn processing. + """ + + turn_id: str + session_id: str + error_message: str + event_type: Literal["turn_failed"] = "turn_failed" + + +# ============= Step-Level Events ============= + + +@dataclass +class StepStarted: + """Emitted when a distinct work phase begins. + + Steps represent distinct phases of work within a turn: + - inference: Model thinking/generation (deciding what to do) + - tool_execution: Tool execution (server-side or client-side) + """ + + step_id: str + step_type: Literal["inference", "tool_execution"] + turn_id: str + event_type: Literal["step_started"] = "step_started" + metadata: Optional[Dict[str, Any]] = None # e.g., {"server_side": True/False, "tool_type": "file_search"} + + +# ============= Progress Delta Types ============= + + +@dataclass +class TextDelta: + """Incremental text during inference. + + Emitted as the model generates text token by token. + """ + + text: str + delta_type: Literal["text"] = "text" + + +@dataclass +class ToolCallIssuedDelta: + """Model initiates a tool call (client or server-side). + + This is emitted when the model decides to call a tool. The tool_type + field indicates whether this is: + - "function": Client-side tool requiring client execution + - Other types: Server-side tools executed within the response + """ + + call_id: str + tool_type: Literal["function", "file_search", "web_search", "mcp_call", "mcp_list_tools", "memory_retrieval"] + tool_name: str + arguments: str # JSON string + delta_type: Literal["tool_call_issued"] = "tool_call_issued" + + +@dataclass +class ToolCallDelta: + """Incremental tool call arguments (streaming). + + Emitted as the model streams tool call arguments. The arguments + are accumulated over multiple deltas to form the complete JSON. 
+ """ + + call_id: str + arguments_delta: str + delta_type: Literal["tool_call_delta"] = "tool_call_delta" + + +@dataclass +class ToolCallCompletedDelta: + """Server-side tool execution completed. + + Emitted when a server-side tool (file_search, web_search, etc.) + finishes execution. The result field contains the tool output. + + Note: Client-side function tools do NOT emit this event; instead + they trigger a separate tool_execution step. + """ + + call_id: str + tool_type: Literal["file_search", "web_search", "mcp_call", "mcp_list_tools", "memory_retrieval"] + tool_name: str + result: Any # Tool execution result from server + delta_type: Literal["tool_call_completed"] = "tool_call_completed" + + +# Union of all delta types +StepDelta = Union[TextDelta, ToolCallIssuedDelta, ToolCallDelta, ToolCallCompletedDelta] + + +@dataclass +class StepProgress: + """Emitted during step execution with streaming updates. + + Progress events provide real-time updates as a step executes, + including text deltas and tool call information. + """ + + step_id: str + step_type: Literal["inference", "tool_execution"] + turn_id: str + delta: StepDelta + event_type: Literal["step_progress"] = "step_progress" + + +# ============= Step Result Types ============= + + +@dataclass +class InferenceStepResult: + """Complete inference step output. + + This contains the final accumulated state after an inference step + completes. It separates client-side function calls (which need + client execution) from server-side tool executions (which are + included for logging/reference only). + """ + + step_id: str + response_id: str + text_content: str + function_calls: List[ToolCall] # Client-side function calls that need execution + server_tool_executions: List[Dict[str, Any]] # Server-side tool calls (for reference/logging) + stop_reason: str + + +@dataclass +class ToolExecutionStepResult: + """Complete tool execution step output (client-side only). + + This contains the results of executing client-side function tools. + These results will be fed back to the model in the next inference step. + """ + + step_id: str + tool_calls: List[ToolCall] # Function calls executed + tool_responses: List[Dict[str, Any]] # Normalized responses + + +# Union of all result types +StepResult = Union[InferenceStepResult, ToolExecutionStepResult] + + +@dataclass +class StepCompleted: + """Emitted when a step finishes. + + This provides the complete result of the step execution, including + all accumulated data and final state. + """ + + step_id: str + step_type: Literal["inference", "tool_execution"] + turn_id: str + result: StepResult + event_type: Literal["step_completed"] = "step_completed" + + +# ============= Unified Event Type ============= + + +# Union of all event types +AgentEvent = Union[ + TurnStarted, + StepStarted, + StepProgress, + StepCompleted, + TurnCompleted, + TurnFailed, +] + + +@dataclass +class AgentStreamChunk: + """What the agent yields to users. + + This is the top-level container for streaming events. Each chunk + contains a high-level event (turn or step) and optionally the + final response payload when the turn completes. + + Usage: + for chunk in agent.create_turn(messages, session_id, stream=True): + if isinstance(chunk.event, StepProgress): + if isinstance(chunk.event.delta, TextDelta): + print(chunk.event.delta.text, end="") + elif isinstance(chunk.event, TurnCompleted): + print(f"\\nDone! 
Response: {chunk.response}") + """ + + event: AgentEvent + response: Optional[Any] = None # Only set on TurnCompleted diff --git a/src/llama_stack_client/lib/agents/types.py b/src/llama_stack_client/lib/agents/types.py new file mode 100644 index 00000000..8ec67485 --- /dev/null +++ b/src/llama_stack_client/lib/agents/types.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +"""Lightweight agent-facing types that avoid llama-stack SDK dependencies.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Protocol, TypedDict + + +@dataclass +class ToolCall: + """Minimal representation of an issued tool call.""" + + call_id: str + tool_name: str + arguments: str + + +@dataclass +class ToolResponse: + """Payload returned from executing a client-side tool.""" + + call_id: str + tool_name: str + content: Any + metadata: Dict[str, Any] = field(default_factory=dict) + + +@dataclass +class CompletionMessage: + """Synthetic completion message mirroring the OpenAI Responses schema.""" + + role: str + content: Any + tool_calls: List[ToolCall] + stop_reason: str + + +Message = CompletionMessage + + +class ToolDefinition(TypedDict, total=False): + """Definition object passed to the Responses API when registering tools.""" + + type: str + name: str + description: str + parameters: Dict[str, Any] + + +class FunctionTool(Protocol): + """Protocol describing the minimal surface area we expect from tools.""" + + def get_name(self) -> str: ... + + def get_description(self) -> str: ... + + def get_input_schema(self) -> Dict[str, Any]: ... diff --git a/tests/integration/test_agent_responses_e2e.py b/tests/integration/test_agent_responses_e2e.py new file mode 100644 index 00000000..7a6cd7d1 --- /dev/null +++ b/tests/integration/test_agent_responses_e2e.py @@ -0,0 +1,151 @@ +import io +import os +import time +from uuid import uuid4 + +import pytest + +from llama_stack_client import BadRequestError, AgentEventLogger +from llama_stack_client.types import ResponseObject, response_create_params +from llama_stack_client.lib.agents.agent import Agent + +MODEL_ID = os.environ.get("LLAMA_STACK_TEST_MODEL") +BASE_URL = os.environ.get("TEST_API_BASE_URL") +KNOWLEDGE_SNIPPET = "SKYRIM-DRAGON-ALLOY" +_VECTOR_STORE_READY_TIMEOUT = 60.0 +_VECTOR_STORE_POLL_INTERVAL = 0.5 + + +def _wrap_response_retrieval(client) -> None: + original_retrieve = client.responses.retrieve + + def retrying_retrieve(response_id: str, **kwargs): + attempts = 0 + while True: + try: + return original_retrieve(response_id, **kwargs) + except BadRequestError as exc: + if getattr(exc, "status_code", None) != 400 or attempts >= 5: + raise + time.sleep(0.2) + attempts += 1 + + client.responses.retrieve = retrying_retrieve # type: ignore[assignment] + + +def _create_vector_store_with_document(client) -> str: + file_payload = io.BytesIO( + f"The secret project codename is {KNOWLEDGE_SNIPPET}. 
Preserve the hyphens exactly.".encode("utf-8") + ) + uploaded_file = client.files.create( + file=("agent_e2e_notes.txt", file_payload, "text/plain"), + purpose="assistants", + ) + + vector_store = client.vector_stores.create(name=f"agent-e2e-{uuid4().hex[:8]}") + vector_store_file = client.vector_stores.files.create( + vector_store_id=vector_store.id, + file_id=uploaded_file.id, + ) + + deadline = time.time() + _VECTOR_STORE_READY_TIMEOUT + while vector_store_file.status != "completed": + if vector_store_file.status in {"failed", "cancelled"}: + raise RuntimeError(f"Vector store ingestion did not succeed: {vector_store_file.status}") + if time.time() > deadline: + raise TimeoutError("Vector store file ingest timed out") + time.sleep(_VECTOR_STORE_POLL_INTERVAL) + vector_store_file = client.vector_stores.files.retrieve( + vector_store_id=vector_store.id, + file_id=vector_store_file.id, + ) + + return vector_store.id + + +pytestmark = pytest.mark.skipif( + MODEL_ID is None or BASE_URL in (None, "http://127.0.0.1:4010"), + reason="requires a running llama stack server, TEST_API_BASE_URL, and LLAMA_STACK_TEST_MODEL", +) + + +def test_agent_streaming_and_follow_up_turn(client) -> None: + _wrap_response_retrieval(client) + vector_store_id = _create_vector_store_with_document(client) + + agent = Agent( + client=client, + model=MODEL_ID, + instructions="You can search the uploaded vector store to answer with precise facts.", + tools=[{"type": "file_search", "vector_store_ids": [vector_store_id]}], + ) + + session_id = agent.create_session(f"agent-session-{uuid4().hex[:8]}") + + messages: list[response_create_params.InputUnionMember1] = [ + { + "type": "message", + "role": "user", + "content": [ + { + "type": "input_text", + "text": "Retrieve the secret project codename from the knowledge base and reply as 'codename: '.", + } + ], + } + ] + + event_logger = AgentEventLogger() + stream_chunks = [] + + for chunk in agent.create_turn(messages=messages, session_id=session_id, stream=True): + stream_chunks.append(chunk) + # Drain the event logger for streaming responses (no-op assertions but ensures coverage). 
+ for printable in event_logger.log([chunk]): + _ = printable + + completed_chunks = [chunk for chunk in stream_chunks if chunk.response is not None] + assert completed_chunks, "Expected streaming turn to yield a final response chunk" + + streamed_response = completed_chunks[-1].response + assert isinstance(streamed_response, ResponseObject) + first_response_id = streamed_response.id + + assert streamed_response.model == MODEL_ID + assert agent._last_response_id == first_response_id + assert agent._session_last_response_id.get(session_id) == first_response_id + assert streamed_response.output, "Response output should include tool and message items" + + tool_call_outputs = [item for item in streamed_response.output if getattr(item, "type", None) == "file_search_call"] + assert tool_call_outputs, "Expected a file_search tool call in the response output" + assert any( + KNOWLEDGE_SNIPPET in getattr(result, "text", "") + for output in tool_call_outputs + for result in getattr(output, "results", []) or [] + ), "Vector store results should surface the knowledge snippet" + + assert KNOWLEDGE_SNIPPET in streamed_response.output_text, "Assistant reply should incorporate retrieved snippet" + + follow_up_messages: list[response_create_params.InputUnionMember1] = [ + { + "type": "message", + "role": "user", + "content": [ + { + "type": "input_text", + "text": "Briefly explain why that codename matters.", + } + ], + } + ] + + follow_up_response = agent.create_turn( + messages=follow_up_messages, + session_id=session_id, + stream=False, + ) + + assert isinstance(follow_up_response, ResponseObject) + assert follow_up_response.previous_response_id == first_response_id + assert agent._last_response_id == follow_up_response.id + assert KNOWLEDGE_SNIPPET in follow_up_response.output_text diff --git a/tests/integration/test_agent_turn_step_events.py b/tests/integration/test_agent_turn_step_events.py new file mode 100644 index 00000000..086915df --- /dev/null +++ b/tests/integration/test_agent_turn_step_events.py @@ -0,0 +1,321 @@ +"""Integration tests for agent turn/step event model. + +These tests verify the core architecture of the turn/step event system: +1. Turn = complete interaction loop +2. Inference steps = model thinking/deciding +3. Tool execution steps = ANY tool executing (server OR client-side) + +Key architectural validations: +- Server-side tools (file_search, web_search) appear as tool_execution steps +- Client-side tools (function) appear as tool_execution steps +- Both are properly annotated with metadata +""" + +import io +import os +import time +from uuid import uuid4 + +import pytest +from openai import OpenAI + +from llama_stack_client.lib.agents.agent import Agent +from llama_stack_client.lib.agents.turn_events import ( + TextDelta, + StepStarted, + TurnStarted, + StepProgress, + StepCompleted, + TurnCompleted, + ToolCallIssuedDelta, +) + +# Test configuration +MODEL_ID = os.environ.get("LLAMA_STACK_TEST_MODEL", "openai/gpt-4o") +BASE_URL = os.environ.get("TEST_API_BASE_URL", "http://localhost:8321/v1") + + +pytestmark = pytest.mark.skipif( + not BASE_URL or BASE_URL == "http://127.0.0.1:4010", + reason="requires a running llama stack server", +) + + +@pytest.fixture +def openai_client(): + return OpenAI(api_key="fake", base_url=BASE_URL) + + +@pytest.fixture +def agent_with_no_tools(openai_client): + """Create an agent with no tools for basic text-only tests.""" + return Agent( + client=openai_client, + model=MODEL_ID, + instructions="You are a helpful assistant. 
Keep responses brief and concise.", + tools=None, + ) + + +@pytest.fixture +def agent_with_file_search(openai_client): + """Create an agent with file_search tool (server-side).""" + # Create a vector store with test content (unique info to force tool use) + file_content = "Project Nightingale is a classified initiative. The project codename is BLUE_FALCON_7. The lead researcher is Dr. Elena Vasquez." + file_payload = io.BytesIO(file_content.encode("utf-8")) + + uploaded_file = openai_client.files.create( + file=("test_knowledge.txt", file_payload, "text/plain"), + purpose="assistants", + ) + + vector_store = openai_client.vector_stores.create( + name=f"test-vs-{uuid4().hex[:8]}", + extra_body={ + "provider_id": "faiss", + "embedding_model": "nomic-ai/nomic-embed-text-v1.5", + }, + ) + vector_store_file = openai_client.vector_stores.files.create( + vector_store_id=vector_store.id, + file_id=uploaded_file.id, + ) + + # Wait for vector store to be ready + deadline = time.time() + 60.0 + while vector_store_file.status != "completed": + if vector_store_file.status in {"failed", "cancelled"}: + raise RuntimeError(f"Vector store ingestion failed: {vector_store_file.status}") + if time.time() > deadline: + raise TimeoutError("Vector store file ingest timed out") + time.sleep(0.5) + vector_store_file = openai_client.vector_stores.files.retrieve( + vector_store_id=vector_store.id, + file_id=vector_store_file.id, + ) + + return Agent( + client=openai_client, + model=MODEL_ID, + instructions="You MUST search the knowledge base to answer every question. Never answer from your own knowledge.", + tools=[{"type": "file_search", "vector_store_ids": [vector_store.id]}], + ) + + +def test_basic_turn_without_tools(agent_with_no_tools): + # Expected event sequence: + # 1. TurnStarted + # 2. StepStarted(inference) + # 3. StepProgress(TextDelta) x N + # 4. StepCompleted(inference) + # 5. 
TurnCompleted + agent = agent_with_no_tools + session_id = agent.create_session(f"test-session-{uuid4().hex[:8]}") + + messages = [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "Say hello in exactly 3 words."}], + } + ] + + events = [] + for chunk in agent.create_turn(messages=messages, session_id=session_id, stream=True): + events.append(chunk.event) + + # Verify event sequence + assert len(events) > 0, "Should have at least some events" + + # First event should be TurnStarted + assert isinstance(events[0], TurnStarted), f"First event should be TurnStarted, got {type(events[0])}" + assert events[0].session_id == session_id + + # Second event should be StepStarted(inference) + assert isinstance(events[1], StepStarted), f"Second event should be StepStarted, got {type(events[1])}" + assert events[1].step_type == "inference" + assert events[1].metadata is None or events[1].metadata.get("server_side") is None + + # Should have some StepProgress(TextDelta) events + text_deltas = [e for e in events if isinstance(e, StepProgress) and isinstance(e.delta, TextDelta)] + assert len(text_deltas) > 0, "Should have at least one text delta" + + # Should have NO tool_execution steps (no tools configured) + tool_execution_starts = [e for e in events if isinstance(e, StepStarted) and e.step_type == "tool_execution"] + assert len(tool_execution_starts) == 0, "Should have no tool_execution steps without tools" + + # Second-to-last event should be StepCompleted(inference) + inference_completes = [e for e in events if isinstance(e, StepCompleted) and e.step_type == "inference"] + assert len(inference_completes) >= 1, "Should have at least one inference step completion" + + # Last inference completion should have no function calls + last_inference = inference_completes[-1] + assert len(last_inference.result.function_calls) == 0, "Should have no function calls" + assert last_inference.result.stop_reason == "end_of_turn" + + # Last event should be TurnCompleted + assert isinstance(events[-1], TurnCompleted), f"Last event should be TurnCompleted, got {type(events[-1])}" + assert events[-1].session_id == session_id + assert len(events[-1].final_text) > 0, "Should have some final text" + + +def test_server_side_file_search_tool(agent_with_file_search): + # Expected event sequence: + # 1. TurnStarted + # 2. StepStarted(inference) - model decides to search + # 3. StepProgress(TextDelta) - optional text before tool + # 4. StepCompleted(inference) - inference done, decided to use file_search + # 5. StepStarted(tool_execution, metadata.server_side=True) + # 6. StepCompleted(tool_execution) - file_search results + # 7. StepStarted(inference) - model processes results + # 8. StepProgress(TextDelta) - model generates response + # 9. StepCompleted(inference) + # 10. 
TurnCompleted + agent = agent_with_file_search + session_id = agent.create_session(f"test-session-{uuid4().hex[:8]}") + + messages = [ + { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "What is the codename for Project Nightingale?"}], + } + ] + + events = [] + for chunk in agent.create_turn(messages=messages, session_id=session_id, stream=True): + events.append(chunk.event) + + # Verify Turn started and completed + assert isinstance(events[0], TurnStarted) + assert isinstance(events[-1], TurnCompleted) + + # KEY ASSERTION: Should have at least one tool_execution step (server-side file_search) + tool_execution_starts = [e for e in events if isinstance(e, StepStarted) and e.step_type == "tool_execution"] + assert len(tool_execution_starts) >= 1, ( + f"Should have at least one tool_execution step, found {len(tool_execution_starts)}" + ) + + # KEY ASSERTION: The tool_execution step should be marked as server_side + file_search_step = tool_execution_starts[0] + assert file_search_step.metadata is not None, "Tool execution step should have metadata" + assert file_search_step.metadata.get("server_side") is True, "file_search should be marked as server_side" + assert file_search_step.metadata.get("tool_type") == "file_search", "Should identify as file_search tool" + + # Should have tool_execution completion + tool_execution_completes = [e for e in events if isinstance(e, StepCompleted) and e.step_type == "tool_execution"] + assert len(tool_execution_completes) >= 1, "Should have at least one tool_execution completion" + + # Should have multiple inference steps (before and after tool execution) + inference_starts = [e for e in events if isinstance(e, StepStarted) and e.step_type == "inference"] + assert len(inference_starts) >= 2, ( + f"Should have at least 2 inference steps (before/after tool), found {len(inference_starts)}" + ) + + # Final response should contain the answer (BLUE_FALCON_7) + assert "BLUE_FALCON" in events[-1].final_text or "BLUE_FALCON_7" in events[-1].final_text, ( + "Response should contain the codename" + ) + + +def test_client_side_function_tool(openai_client): + # We are going to test + # Expected event sequence: + # 1. TurnStarted + # 2. StepStarted(inference) + # 3. StepProgress(ToolCallIssuedDelta) - function call + # 4. StepCompleted(inference) - with function_calls + # 5. StepStarted(tool_execution, metadata.server_side=False) + # 6. StepCompleted(tool_execution) - client executed function + # 7. StepStarted(inference) - model processes results + # 8. StepProgress(TextDelta) + # 9. StepCompleted(inference) + # 10. TurnCompleted + + # Create a function that returns data requiring model processing + def get_user_secret_token(user_id: str) -> str: + """Get the encrypted authentication token for a user from the secure database. + + The token is returned in encrypted hex format and must be decoded by the AI. + + :param user_id: The unique identifier of the user + """ + import time + import hashlib + + unique = f"{user_id}-{time.time()}-SECRET" + token_hash = hashlib.sha256(unique.encode()).hexdigest()[:16] + return f'{{"status": "success", "encrypted_token": "{token_hash}", "format": "hex", "expires_in_hours": 24}}' + + agent = Agent( + client=openai_client, + model=MODEL_ID, + instructions="You are a helpful assistant. When retrieving tokens, you MUST call get_user_secret_token, then parse the JSON response and present it in a user-friendly format. 
Explain what the token is and when it expires.", + tools=[get_user_secret_token], + ) + + session_id = agent.create_session(f"test-session-{uuid4().hex[:8]}") + messages = [ + { + "type": "message", + "role": "user", + "content": [ + { + "type": "input_text", + "text": "Can you get me the authentication token for user_12345? Please explain what it is.", + } + ], + } + ] + + events = [] + + for chunk in agent.create_turn(messages=messages, session_id=session_id, stream=True): + events.append(chunk.event) + + # Verify Turn started and completed + assert isinstance(events[0], TurnStarted) + assert isinstance(events[-1], TurnCompleted) + + # Should have ToolCallIssuedDelta for the function call + function_calls = [ + e + for e in events + if isinstance(e, StepProgress) and isinstance(e.delta, ToolCallIssuedDelta) and e.delta.tool_type == "function" + ] + assert len(function_calls) >= 1, ( + f"Should have at least one function call (get_user_secret_token), found {len(function_calls)}" + ) + + # KEY ASSERTION: Should have tool_execution step (client-side function) + tool_execution_starts = [e for e in events if isinstance(e, StepStarted) and e.step_type == "tool_execution"] + assert len(tool_execution_starts) >= 1, ( + f"Should have at least one tool_execution step, found {len(tool_execution_starts)}" + ) + + # KEY ASSERTION: The tool_execution step should be marked as client-side + function_step = tool_execution_starts[0] + assert function_step.metadata is not None, "Tool execution step should have metadata" + assert function_step.metadata.get("server_side") is False, "Function tool should be marked as client_side" + + # Should have at least one inference step (before tool execution) + inference_starts = [e for e in events if isinstance(e, StepStarted) and e.step_type == "inference"] + assert len(inference_starts) >= 1, f"Should have at least 1 inference step, found {len(inference_starts)}" + + # Verify the tool was actually executed with proper arguments + tool_execution_completes = [e for e in events if isinstance(e, StepCompleted) and e.step_type == "tool_execution"] + assert len(tool_execution_completes) >= 1, "Should have at least one tool_execution completion" + + # Final response should contain the token data (proves function was called and processed) + final_text = events[-1].final_text.lower() + assert "token" in final_text or "encrypted" in final_text, ( + "Response should contain token information from function call" + ) + assert "24" in events[-1].final_text or "hour" in final_text or "expire" in final_text, ( + "Response should mention expiration from parsed JSON" + ) + + +if __name__ == "__main__": + # Allow running tests directly for development + pytest.main([__file__, "-v", "-s"]) diff --git a/tests/lib/agents/test_agent_responses.py b/tests/lib/agents/test_agent_responses.py new file mode 100644 index 00000000..4ed2c3b6 --- /dev/null +++ b/tests/lib/agents/test_agent_responses.py @@ -0,0 +1,206 @@ +from __future__ import annotations + +from types import SimpleNamespace +from typing import Any, Dict, Iterable, Iterator, List, Optional + +import pytest + +from llama_stack_client.lib.agents.agent import Agent +from llama_stack_client.lib.agents.client_tool import client_tool +from llama_stack_client.lib.agents.turn_events import ( + AgentStreamChunk, + StepCompleted, + StepProgress, + StepStarted, + ToolExecutionStepResult, + TurnCompleted, + TurnStarted, +) + + +def _event(event_type: str, **payload: Any) -> SimpleNamespace: + return SimpleNamespace(type=event_type, **payload) + 
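+# The helpers below script minimal Responses API event sequences
+# (in_progress -> text deltas / output items -> completed) so the Agent loop
+# can be exercised against an in-memory fake client, without a running server.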
+ +def _in_progress(response_id: str) -> SimpleNamespace: + return _event("response.in_progress", response=SimpleNamespace(id=response_id)) + + +def _completed(response_id: str, text: str) -> SimpleNamespace: + response = FakeResponse(response_id, text) + return _event("response.completed", response=response) + + +class FakeResponse: + def __init__(self, response_id: str, text: str) -> None: + self.id = response_id + self.output_text = text + self.turn = SimpleNamespace(turn_id=f"turn_{response_id}") + + +class FakeResponsesAPI: + def __init__(self, event_script: Iterable[Iterable[SimpleNamespace]]) -> None: + self._event_script: List[List[SimpleNamespace]] = [list(events) for events in event_script] + self.create_calls: List[Dict[str, Any]] = [] + + def create(self, **kwargs: Any) -> Iterator[SimpleNamespace]: + self.create_calls.append(kwargs) + if not self._event_script: + raise AssertionError("No scripted events left for responses.create") + return iter(self._event_script.pop(0)) + + +class FakeConversationsAPI: + def __init__(self) -> None: + self._counter = 0 + + def create(self, **_: Any) -> SimpleNamespace: + self._counter += 1 + return SimpleNamespace(id=f"conv_{self._counter}") + + +class FakeClient: + def __init__(self, event_script: Iterable[Iterable[SimpleNamespace]]) -> None: + self.responses = FakeResponsesAPI(event_script) + self.conversations = FakeConversationsAPI() + + +def make_completion_events(response_id: str, text: str) -> List[SimpleNamespace]: + return [ + _in_progress(response_id), + _event("response.output_text.delta", delta=text, output_index=0), + _completed(response_id, text), + ] + + +def make_function_tool_events(response_id: str, call_id: str, tool_name: str, arguments: str) -> List[SimpleNamespace]: + tool_item = SimpleNamespace(type="function_call", call_id=call_id, name=tool_name, arguments=arguments) + return [ + _in_progress(response_id), + _event("response.output_item.added", item=tool_item), + _event("response.output_item.done", item=tool_item), + _completed(response_id, ""), + ] + + +def make_server_tool_events(response_id: str, call_id: str, arguments: str, final_text: str) -> List[SimpleNamespace]: + tool_item = SimpleNamespace(type="file_search_call", id=call_id, arguments=arguments) + completion = FakeResponse(response_id, final_text) + return [ + _in_progress(response_id), + _event("response.output_item.added", item=tool_item), + _event("response.output_item.done", item=tool_item), + _event("response.output_text.delta", delta=final_text, output_index=0), + _completed(response_id, final_text), + ] + + +def test_agent_tracks_multiple_sessions() -> None: + event_script = [ + make_completion_events("resp_a1", "session A turn 1"), + make_completion_events("resp_b1", "session B turn 1"), + make_completion_events("resp_a2", "session A turn 2"), + ] + + client = FakeClient(event_script) + agent = Agent(client=client, model="test-model", instructions="test") + + session_a = agent.create_session("A") + session_b = agent.create_session("B") + + message = { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "hi"}], + } + + agent.create_turn([message], session_id=session_a, stream=False) + agent.create_turn([message], session_id=session_b, stream=False) + agent.create_turn([message], session_id=session_a, stream=False) + + calls = client.responses.create_calls + assert calls[0]["conversation"] == session_a + assert calls[1]["conversation"] == session_b + assert calls[2]["conversation"] == session_a + + assert 
agent._session_last_response_id[session_a] == "resp_a2" + assert agent._session_last_response_id[session_b] == "resp_b1" + assert agent._last_response_id == "resp_a2" + + +def test_agent_handles_client_tool_and_finishes_turn() -> None: + tool_invocations: List[str] = [] + + @client_tool + def echo_tool(text: str) -> str: + """Echo text back to the caller. + + :param text: value to echo + """ + tool_invocations.append(text) + return text + + event_script = [ + make_function_tool_events("resp_intermediate", "call_1", "echo_tool", '{"text": "pong"}'), + make_completion_events("resp_final", "all done"), + ] + + client = FakeClient(event_script) + agent = Agent(client=client, model="test-model", instructions="use tools", tools=[echo_tool]) + + session_id = agent.create_session("default") + message = { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "run the tool"}], + } + + response = agent.create_turn([message], session_id=session_id, stream=False) + + assert response.id == "resp_final" + assert response.output_text == "all done" + assert tool_invocations == ["pong"] + assert len(client.responses.create_calls) == 2 + + +def test_agent_streams_server_tool_events() -> None: + event_script = [ + make_server_tool_events("resp_server", "server_call", '{"query": "docs"}', "tool finished"), + ] + + client = FakeClient(event_script) + agent = Agent(client=client, model="test-model", instructions="use server tool") + + session_id = agent.create_session("default") + message = { + "type": "message", + "role": "user", + "content": [{"type": "input_text", "text": "find info"}], + } + + chunks = list(agent.create_turn([message], session_id=session_id, stream=True)) + + events = [chunk.event for chunk in chunks] + assert isinstance(events[0], TurnStarted) + assert isinstance(events[1], StepStarted) + assert events[1].step_type == "inference" + + # Look for the tool execution step in the stream + tool_step_started = next(event for event in events if isinstance(event, StepStarted) and event.step_type == "tool_execution") + assert tool_step_started.metadata == {"server_side": True, "tool_type": "file_search", "tool_name": "file_search_call"} + + tool_step_completed = next( + event for event in events if isinstance(event, StepCompleted) and event.step_type == "tool_execution" + ) + assert isinstance(tool_step_completed.result, ToolExecutionStepResult) + assert tool_step_completed.result.tool_calls[0].call_id == "server_call" + + text_progress = [ + event.delta.text + for event in events + if isinstance(event, StepProgress) and hasattr(event.delta, "text") + ] + assert text_progress == ["tool finished"] + + assert isinstance(events[-1], TurnCompleted) + assert chunks[-1].response and chunks[-1].response.output_text == "tool finished" diff --git a/uv.lock b/uv.lock index 053f41d2..9d63e22e 100644 --- a/uv.lock +++ b/uv.lock @@ -424,7 +424,7 @@ wheels = [ [[package]] name = "llama-stack-client" -version = "0.2.23" +version = "0.3.0a6" source = { editable = "." } dependencies = [ { name = "anyio" },
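Taken together, the pieces introduced in this change reduce to a small consumer loop. The snippet below is a minimal sketch rather than part of the patch: the `OpenAI(...)` client construction, base URL, and model id are assumptions mirrored from the integration tests above, and the loop simply branches on the turn/step dataclasses defined in `turn_events.py`.

```python
# Illustrative sketch only. The base URL and model id are placeholders taken
# from the integration tests; they are assumptions, not part of this patch.
from openai import OpenAI

from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.turn_events import (
    StepProgress,
    StepStarted,
    TextDelta,
    TurnCompleted,
)

client = OpenAI(api_key="fake", base_url="http://localhost:8321/v1")
agent = Agent(client=client, model="openai/gpt-4o", instructions="Keep answers brief.")
session_id = agent.create_session("demo-session")

messages = [
    {
        "type": "message",
        "role": "user",
        "content": [{"type": "input_text", "text": "Say hello in exactly three words."}],
    }
]

for chunk in agent.create_turn(messages=messages, session_id=session_id, stream=True):
    event = chunk.event
    if isinstance(event, StepStarted):
        # Each inference or tool_execution phase announces itself before streaming.
        print(f"\n[{event.step_type} step started]")
    elif isinstance(event, StepProgress) and isinstance(event.delta, TextDelta):
        # Model text arrives incrementally, mirroring the Responses API deltas.
        print(event.delta.text, end="", flush=True)
    elif isinstance(event, TurnCompleted):
        print(f"\n[turn finished in {event.num_steps} step(s)]")
```

For plain console output, the same chunks can be passed to `AgentEventLogger.log(...)` instead, as the end-to-end test does.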