From 0c1d60a94acb0a751c64b953bfc66eab5808a727 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Wed, 21 Jan 2026 08:10:43 +0900 Subject: [PATCH 1/8] Refactor ag-ui to simplify flow --- .../ag-ui/agent_framework_ag_ui/_agent.py | 100 +-- .../ag-ui/agent_framework_ag_ui/_run.py | 730 ++++++++++++++++++ .../tests/test_agent_wrapper_comprehensive.py | 1 + 3 files changed, 750 insertions(+), 81 deletions(-) create mode 100644 python/packages/ag-ui/agent_framework_ag_ui/_run.py diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py index 806f5ab1bb..f04f498168 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft. All rights reserved. -"""AgentFrameworkAgent wrapper for AG-UI protocol - Clean Architecture.""" +"""AgentFrameworkAgent wrapper for AG-UI protocol.""" from collections.abc import AsyncGenerator from typing import Any, cast @@ -8,13 +8,7 @@ from ag_ui.core import BaseEvent from agent_framework import AgentProtocol -from ._confirmation_strategies import ConfirmationStrategy, DefaultConfirmationStrategy -from ._orchestrators import ( - DefaultOrchestrator, - ExecutionContext, - HumanInTheLoopOrchestrator, - Orchestrator, -) +from ._run import run_agent_stream class AgentConfig: @@ -26,19 +20,22 @@ def __init__( predict_state_config: dict[str, dict[str, str]] | None = None, use_service_thread: bool = False, require_confirmation: bool = True, + confirmation_strategy: Any | None = None, ): """Initialize agent configuration. Args: state_schema: Optional state schema for state management; accepts dict or Pydantic model/class - predict_state_config: Configuration for predictive state updates - use_service_thread: Whether the agent thread is service-managed - require_confirmation: Whether predictive updates require confirmation + predict_state_config: Configuration for predictive state updates (currently unused in simplified impl) + use_service_thread: Whether the agent thread is service-managed (currently unused) + require_confirmation: Whether predictive updates require confirmation (currently unused) + confirmation_strategy: Optional strategy for generating confirmation messages """ self.state_schema = self._normalize_state_schema(state_schema) self.predict_state_config = predict_state_config or {} self.use_service_thread = use_service_thread self.require_confirmation = require_confirmation + self.confirmation_strategy = confirmation_strategy @staticmethod def _normalize_state_schema(state_schema: Any | None) -> dict[str, Any]: @@ -72,12 +69,7 @@ class AgentFrameworkAgent: """Wraps Agent Framework agents for AG-UI protocol compatibility. Translates between Agent Framework's AgentProtocol and AG-UI's event-based - protocol. Uses orchestrators to handle different execution flows (standard - execution, human-in-the-loop, etc.). Orchestrators are checked in order; - the first matching orchestrator handles the request. - - Supports predictive state updates for agentic generative UI, with optional - confirmation requirements configurable per use case. + protocol. Follows a simple linear flow: RunStarted -> content events -> RunFinished. 
""" def __init__( @@ -88,9 +80,9 @@ def __init__( state_schema: Any | None = None, predict_state_config: dict[str, dict[str, str]] | None = None, require_confirmation: bool = True, - orchestrators: list[Orchestrator] | None = None, + orchestrators: Any = None, # Deprecated, kept for backwards compatibility use_service_thread: bool = False, - confirmation_strategy: ConfirmationStrategy | None = None, + confirmation_strategy: Any = None, # Deprecated, kept for backwards compatibility ): """Initialize the AG-UI compatible agent wrapper. @@ -99,15 +91,11 @@ def __init__( name: Optional name for the agent description: Optional description state_schema: Optional state schema for state management; accepts dict or Pydantic model/class - predict_state_config: Configuration for predictive state updates. - Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}} - require_confirmation: Whether predictive updates require confirmation. - Set to False for agentic generative UI that updates automatically. - orchestrators: Custom orchestrators (auto-configured if None). - Orchestrators are checked in order; first match handles the request. - use_service_thread: Whether the agent thread is service-managed. - confirmation_strategy: Strategy for generating confirmation messages. - Defaults to DefaultConfirmationStrategy if None. + predict_state_config: Configuration for predictive state updates (optional feature) + require_confirmation: Whether predictive updates require confirmation + orchestrators: Deprecated - kept for backwards compatibility + use_service_thread: Whether the agent thread is service-managed + confirmation_strategy: Strategy for generating confirmation messages """ self.agent = agent self.name = name or getattr(agent, "name", "agent") @@ -118,73 +106,23 @@ def __init__( predict_state_config=predict_state_config, use_service_thread=use_service_thread, require_confirmation=require_confirmation, + confirmation_strategy=confirmation_strategy, ) - # Configure orchestrators - if orchestrators is None: - self.orchestrators = self._default_orchestrators() - else: - self.orchestrators = orchestrators - - # Configure confirmation strategy - if confirmation_strategy is None: - self.confirmation_strategy: ConfirmationStrategy = DefaultConfirmationStrategy() - else: - self.confirmation_strategy = confirmation_strategy - - def _default_orchestrators(self) -> list[Orchestrator]: - """Create default orchestrator chain. - - Returns: - List of orchestrators in priority order. First matching orchestrator - handles the request, so order matters. - """ - return [ - HumanInTheLoopOrchestrator(), # Handle tool approval responses - # Add more specialized orchestrators here as needed - DefaultOrchestrator(), # Fallback: standard agent execution - ] - async def run_agent( self, input_data: dict[str, Any], ) -> AsyncGenerator[BaseEvent, None]: """Run the agent and yield AG-UI events. - This is the ONLY public method - much simpler than the original 376-line - implementation. All orchestration logic has been extracted into dedicated - Orchestrator classes. - - The method creates an ExecutionContext with all needed data, then finds - the first orchestrator that can handle the request and delegates to it. - Args: input_data: The AG-UI run input containing messages, state, etc. 
Yields: AG-UI events - - Raises: - RuntimeError: If no orchestrator matches (should never happen if - DefaultOrchestrator is last in the chain) """ - # Create execution context with all needed data - context = ExecutionContext( - input_data=input_data, - agent=self.agent, - config=self.config, - confirmation_strategy=self.confirmation_strategy, - ) - - # Find matching orchestrator and execute - for orchestrator in self.orchestrators: - if orchestrator.can_handle(context): - async for event in orchestrator.run(context): - yield event - return - - # Should never reach here if DefaultOrchestrator is last - raise RuntimeError("No orchestrator matched - check configuration") + async for event in run_agent_stream(input_data, self.agent, self.config): + yield event __all__ = [ diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_run.py new file mode 100644 index 0000000000..b240e2a4af --- /dev/null +++ b/python/packages/ag-ui/agent_framework_ag_ui/_run.py @@ -0,0 +1,730 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Simplified AG-UI orchestration - single linear flow.""" + +import json +import logging +import uuid +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from ag_ui.core import ( + BaseEvent, + CustomEvent, + MessagesSnapshotEvent, + RunFinishedEvent, + RunStartedEvent, + StateSnapshotEvent, + TextMessageContentEvent, + TextMessageEndEvent, + TextMessageStartEvent, + ToolCallArgsEvent, + ToolCallEndEvent, + ToolCallResultEvent, + ToolCallStartEvent, +) +from agent_framework import ( + AgentProtocol, + AgentThread, + ChatMessage, + FunctionApprovalRequestContent, + FunctionCallContent, + FunctionResultContent, + TextContent, + prepare_function_call_results, +) + +from ._message_adapters import normalize_agui_input_messages +from ._orchestration._predictive_state import PredictiveStateHandler +from ._orchestration._tooling import collect_server_tools, merge_tools, register_additional_client_tools +from ._utils import convert_agui_tools_to_agent_framework, generate_event_id, get_conversation_id_from_update + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator + + from ._agent import AgentConfig + +logger = logging.getLogger(__name__) + + +@dataclass +class FlowState: + """Minimal explicit state for a single AG-UI run.""" + + message_id: str | None = None # Current text message being streamed + tool_call_id: str | None = None # Current tool call being streamed + tool_call_name: str | None = None # Name of current tool call + waiting_for_approval: bool = False # Stop after approval request + current_state: dict[str, Any] = field(default_factory=dict) # Shared state + accumulated_text: str = "" # For MessagesSnapshotEvent + pending_tool_calls: list[dict[str, Any]] = field(default_factory=list) # For MessagesSnapshotEvent + tool_calls_by_id: dict[str, dict[str, Any]] = field(default_factory=dict) + tool_results: list[dict[str, Any]] = field(default_factory=list) + + +def _create_state_context_message( + current_state: dict[str, Any], + state_schema: dict[str, Any], +) -> ChatMessage | None: + """Create a system message with current state context. + + This injects the current state into the conversation so the model + knows what state exists and can make informed updates. 
+ + Args: + current_state: The current state to inject + state_schema: The state schema (used to determine if injection is needed) + + Returns: + ChatMessage with state context, or None if not needed + """ + if not current_state or not state_schema: + return None + + state_json = json.dumps(current_state, indent=2) + return ChatMessage( + role="system", + contents=[ + TextContent( + text=( + "Current state of the application:\n" + f"{state_json}\n\n" + "When modifying state, you MUST include ALL existing data plus your changes.\n" + "For example, if adding one new item to a list, include ALL existing items PLUS the new item.\n" + "Never replace existing data - always preserve and append or merge." + ) + ) + ], + ) + + +def _inject_state_context( + messages: list[ChatMessage], + current_state: dict[str, Any], + state_schema: dict[str, Any], +) -> list[ChatMessage]: + """Inject state context message into messages if appropriate. + + The state context is injected before the last user message to give + the model visibility into the current application state. + + Args: + messages: The messages to potentially inject into + current_state: The current state + state_schema: The state schema + + Returns: + Messages with state context injected if appropriate + """ + state_msg = _create_state_context_message(current_state, state_schema) + if not state_msg: + return messages + + # Check if the last message is from a user (new user turn) + if not messages: + return messages + + from ._utils import get_role_value + + last_role = get_role_value(messages[-1]) + if last_role != "user": + return messages + + # Always inject state context if state is provided + # This ensures UI state changes are visible to the model + + # Insert state context before the last user message + result = list(messages[:-1]) + result.append(state_msg) + result.append(messages[-1]) + return result + + +def _emit_text(content: TextContent, flow: FlowState, skip_text: bool = False) -> list[BaseEvent]: + """Emit TextMessage events for TextContent.""" + if not content.text: + return [] + + # Skip if we're in structured output mode or waiting for approval + if skip_text or flow.waiting_for_approval: + return [] + + events: list[BaseEvent] = [] + if not flow.message_id: + flow.message_id = generate_event_id() + events.append(TextMessageStartEvent(message_id=flow.message_id, role="assistant")) + + events.append(TextMessageContentEvent(message_id=flow.message_id, delta=content.text)) + flow.accumulated_text += content.text + return events + + +def _emit_tool_call( + content: FunctionCallContent, + flow: FlowState, + predictive_handler: PredictiveStateHandler | None = None, +) -> list[BaseEvent]: + """Emit ToolCall events for FunctionCallContent.""" + events: list[BaseEvent] = [] + + tool_call_id = content.call_id or flow.tool_call_id or generate_event_id() + + # Emit start event when we have a new tool call + if content.name and tool_call_id != flow.tool_call_id: + flow.tool_call_id = tool_call_id + flow.tool_call_name = content.name + if predictive_handler: + predictive_handler.reset_streaming() + + events.append( + ToolCallStartEvent( + tool_call_id=tool_call_id, + tool_call_name=content.name, + parent_message_id=flow.message_id, + ) + ) + + # Track for MessagesSnapshotEvent + tool_entry = { + "id": tool_call_id, + "type": "function", + "function": {"name": content.name, "arguments": ""}, + } + flow.pending_tool_calls.append(tool_entry) + flow.tool_calls_by_id[tool_call_id] = tool_entry + + elif tool_call_id: + flow.tool_call_id = 
tool_call_id + + # Emit args if present + if content.arguments: + delta = content.arguments if isinstance(content.arguments, str) else json.dumps(content.arguments) + events.append(ToolCallArgsEvent(tool_call_id=tool_call_id, delta=delta)) + + # Track args for MessagesSnapshotEvent + if tool_call_id in flow.tool_calls_by_id: + flow.tool_calls_by_id[tool_call_id]["function"]["arguments"] += delta + + # Emit predictive state deltas + if predictive_handler and flow.tool_call_name: + delta_events = predictive_handler.emit_streaming_deltas(flow.tool_call_name, delta) + events.extend(delta_events) + + return events + + +def _emit_tool_result( + content: FunctionResultContent, + flow: FlowState, + predictive_handler: PredictiveStateHandler | None = None, +) -> list[BaseEvent]: + """Emit ToolCallResult events for FunctionResultContent.""" + events: list[BaseEvent] = [] + + if content.call_id: + events.append(ToolCallEndEvent(tool_call_id=content.call_id)) + + result_content = prepare_function_call_results(content.result) + message_id = generate_event_id() + events.append( + ToolCallResultEvent( + message_id=message_id, + tool_call_id=content.call_id, + content=result_content, + role="tool", + ) + ) + + # Track for MessagesSnapshotEvent + flow.tool_results.append( + { + "id": message_id, + "role": "tool", + "toolCallId": content.call_id, + "content": result_content, + } + ) + + # Apply predictive state updates and emit snapshot + if predictive_handler: + predictive_handler.apply_pending_updates() + if flow.current_state: + events.append(StateSnapshotEvent(snapshot=flow.current_state)) + + # Reset tool tracking + flow.tool_call_id = None + flow.tool_call_name = None + + return events + + +def _emit_approval_request( + content: FunctionApprovalRequestContent, + flow: FlowState, + predictive_handler: PredictiveStateHandler | None = None, + require_confirmation: bool = True, +) -> list[BaseEvent]: + """Emit events for function approval request.""" + events: list[BaseEvent] = [] + + # Extract state from function arguments if predictive + if predictive_handler: + parsed_args = content.function_call.parse_arguments() + result = predictive_handler.extract_state_value(content.function_call.name, parsed_args) + if result: + state_key, state_value = result + flow.current_state[state_key] = state_value + events.append(StateSnapshotEvent(snapshot=flow.current_state)) + + # End the original tool call + if content.function_call.call_id: + events.append(ToolCallEndEvent(tool_call_id=content.function_call.call_id)) + + # Emit custom event for UI + events.append( + CustomEvent( + name="function_approval_request", + value={ + "id": content.id, + "function_call": { + "call_id": content.function_call.call_id, + "name": content.function_call.name, + "arguments": content.function_call.parse_arguments(), + }, + }, + ) + ) + + # Emit confirm_changes tool call for UI compatibility + if require_confirmation: + confirm_id = generate_event_id() + events.append( + ToolCallStartEvent( + tool_call_id=confirm_id, + tool_call_name="confirm_changes", + parent_message_id=flow.message_id, + ) + ) + args = { + "function_name": content.function_call.name, + "function_call_id": content.function_call.call_id, + "function_arguments": content.function_call.parse_arguments() or {}, + "steps": [{"description": f"Execute {content.function_call.name}", "status": "enabled"}], + } + events.append(ToolCallArgsEvent(tool_call_id=confirm_id, delta=json.dumps(args))) + events.append(ToolCallEndEvent(tool_call_id=confirm_id)) + + 
flow.waiting_for_approval = True + return events + + +def _emit_content( + content: Any, + flow: FlowState, + predictive_handler: PredictiveStateHandler | None = None, + skip_text: bool = False, + require_confirmation: bool = True, +) -> list[BaseEvent]: + """Emit appropriate events for any content type.""" + if isinstance(content, TextContent): + return _emit_text(content, flow, skip_text) + elif isinstance(content, FunctionCallContent): + return _emit_tool_call(content, flow, predictive_handler) + elif isinstance(content, FunctionResultContent): + return _emit_tool_result(content, flow, predictive_handler) + elif isinstance(content, FunctionApprovalRequestContent): + return _emit_approval_request(content, flow, predictive_handler, require_confirmation) + return [] + + +def _is_confirm_changes_response(messages: list[Any]) -> bool: + """Check if the last message is a confirm_changes tool result (state confirmation flow). + + This returns True for confirm_changes flows where we emit a confirmation message + and stop. The key indicator is the presence of a 'steps' key in the tool result + (even if empty), combined with 'accepted' boolean. + """ + if not messages: + return False + last = messages[-1] + if not last.additional_properties.get("is_tool_result", False): + return False + + # Parse the content to check if it has the confirm_changes structure + for content in last.contents: + if isinstance(content, TextContent): + try: + result = json.loads(content.text) + # confirm_changes results have 'accepted' and 'steps' keys + if "accepted" in result and "steps" in result: + return True + except json.JSONDecodeError: + pass + return False + + +def _handle_step_based_approval(messages: list[Any]) -> list[BaseEvent]: + """Handle step-based approval response and emit confirmation message.""" + events: list[BaseEvent] = [] + last = messages[-1] + + # Parse the approval content + approval_text = "" + for content in last.contents: + if isinstance(content, TextContent): + approval_text = content.text + break + + try: + result = json.loads(approval_text) + accepted = result.get("accepted", False) + steps = result.get("steps", []) + + if accepted: + # Generate acceptance message with step descriptions + enabled_steps = [s for s in steps if s.get("status") == "enabled"] + if enabled_steps: + message_parts = [f"Executing {len(enabled_steps)} approved steps:\n\n"] + for i, step in enumerate(enabled_steps, 1): + message_parts.append(f"{i}. {step.get('description', 'Step')}\n") + message_parts.append("\nAll steps completed successfully!") + message = "".join(message_parts) + else: + message = "Changes confirmed and applied successfully!" + else: + # Rejection message + message = "No problem! What would you like me to change about the plan?" + except json.JSONDecodeError: + message = "Acknowledged." + + message_id = generate_event_id() + events.append(TextMessageStartEvent(message_id=message_id, role="assistant")) + events.append(TextMessageContentEvent(message_id=message_id, delta=message)) + events.append(TextMessageEndEvent(message_id=message_id)) + + return events + + +async def _resolve_approval_responses( + messages: list[Any], + tools: list[Any], + agent: AgentProtocol, + run_kwargs: dict[str, Any], +) -> None: + """Execute approved function calls and replace approval content with results. + + This modifies the messages list in place, replacing FunctionApprovalResponseContent + with FunctionResultContent containing the actual tool execution result. 
+ + Args: + messages: List of messages (will be modified in place) + tools: List of available tools + agent: The agent instance (to get chat_client and config) + run_kwargs: Kwargs for tool execution + """ + from agent_framework._middleware import extract_and_merge_function_middleware + from agent_framework._tools import ( + FunctionInvocationConfiguration, + _collect_approval_responses, + _replace_approval_contents_with_results, + _try_execute_function_calls, + ) + + fcc_todo = _collect_approval_responses(messages) + if not fcc_todo: + return + + approved_responses = [resp for resp in fcc_todo.values() if resp.approved] + rejected_responses = [resp for resp in fcc_todo.values() if not resp.approved] + approved_function_results: list[Any] = [] + + # Execute approved tool calls + if approved_responses and tools: + chat_client = getattr(agent, "chat_client", None) + config = getattr(chat_client, "function_invocation_configuration", None) or FunctionInvocationConfiguration() + middleware_pipeline = extract_and_merge_function_middleware(chat_client, run_kwargs) + # Filter out AG-UI-specific kwargs that should not be passed to tool execution + tool_kwargs = {k: v for k, v in run_kwargs.items() if k != "options"} + try: + results, _ = await _try_execute_function_calls( + custom_args=tool_kwargs, + attempt_idx=0, + function_calls=approved_responses, + tools=tools, + middleware_pipeline=middleware_pipeline, + config=config, + ) + approved_function_results = list(results) + except Exception: + logger.error("Failed to execute approved tool calls; injecting error results.") + approved_function_results = [] + + # Build normalized results for approved responses + normalized_results: list[FunctionResultContent] = [] + for idx, approval in enumerate(approved_responses): + if idx < len(approved_function_results) and isinstance(approved_function_results[idx], FunctionResultContent): + normalized_results.append(approved_function_results[idx]) + continue + call_id = approval.function_call.call_id or approval.id + normalized_results.append(FunctionResultContent(call_id=call_id, result="Error: Tool call invocation failed.")) + + # Build rejection results + for rejection in rejected_responses: + call_id = rejection.function_call.call_id or rejection.id + normalized_results.append( + FunctionResultContent(call_id=call_id, result="Error: Tool call invocation was rejected by user.") + ) + + _replace_approval_contents_with_results(messages, fcc_todo, normalized_results) # type: ignore + + +def _build_messages_snapshot( + flow: FlowState, + snapshot_messages: list[dict[str, Any]], +) -> MessagesSnapshotEvent: + """Build MessagesSnapshotEvent from current flow state.""" + all_messages = list(snapshot_messages) + + # Add assistant message with tool calls + if flow.pending_tool_calls: + tool_call_message = { + "id": flow.message_id or generate_event_id(), + "role": "assistant", + "tool_calls": flow.pending_tool_calls.copy(), + } + if flow.accumulated_text: + tool_call_message["content"] = flow.accumulated_text + all_messages.append(tool_call_message) + + # Add tool results + all_messages.extend(flow.tool_results) + + # Add text-only assistant message if no tool calls + if flow.accumulated_text and not flow.pending_tool_calls: + all_messages.append( + { + "id": flow.message_id or generate_event_id(), + "role": "assistant", + "content": flow.accumulated_text, + } + ) + + return MessagesSnapshotEvent(messages=all_messages) # type: ignore[arg-type] + + +async def run_agent_stream( + input_data: dict[str, Any], + agent: 
AgentProtocol, + config: "AgentConfig", +) -> "AsyncGenerator[BaseEvent, None]": + """Run agent and yield AG-UI events. + + This is the single entry point for all AG-UI agent runs. It follows a simple + linear flow: RunStarted -> content events -> RunFinished. + + Args: + input_data: AG-UI request data with messages, state, tools, etc. + agent: The Agent Framework agent to run + config: Agent configuration + + Yields: + AG-UI events + """ + # Parse IDs + thread_id = input_data.get("thread_id") or input_data.get("threadId") or str(uuid.uuid4()) + run_id = input_data.get("run_id") or input_data.get("runId") or str(uuid.uuid4()) + + # Initialize flow state with schema defaults + flow = FlowState() + if input_data.get("state"): + flow.current_state = dict(input_data["state"]) + + # Apply schema defaults for missing state keys + if config.state_schema: + for key, schema in config.state_schema.items(): + if key in flow.current_state: + continue + if isinstance(schema, dict) and schema.get("type") == "array": + flow.current_state[key] = [] + else: + flow.current_state[key] = {} + + # Initialize predictive state handler if configured + predictive_handler: PredictiveStateHandler | None = None + if config.predict_state_config: + predictive_handler = PredictiveStateHandler( + predict_state_config=config.predict_state_config, + current_state=flow.current_state, + ) + + # Normalize messages + raw_messages = input_data.get("messages", []) + messages, snapshot_messages = normalize_agui_input_messages(raw_messages) + + # Check for structured output mode (skip text content) + skip_text = False + response_format = None + from agent_framework import ChatAgent + + if isinstance(agent, ChatAgent): + response_format = agent.default_options.get("response_format") + skip_text = response_format is not None + + # Handle empty messages (emit RunStarted immediately since no agent response) + if not messages: + logger.warning("No messages provided in AG-UI input") + yield RunStartedEvent(run_id=run_id, thread_id=thread_id) + yield RunFinishedEvent(run_id=run_id, thread_id=thread_id) + return + + # Prepare tools + client_tools = convert_agui_tools_to_agent_framework(input_data.get("tools")) + server_tools = collect_server_tools(agent) + register_additional_client_tools(agent, client_tools) + tools = merge_tools(server_tools, client_tools) + + # Create thread (with service thread support) + if config.use_service_thread: + supplied_thread_id = input_data.get("thread_id") or input_data.get("threadId") + thread = AgentThread(service_thread_id=supplied_thread_id) + else: + thread = AgentThread() + + # Inject metadata for AG-UI orchestration + thread.metadata = { # type: ignore[attr-defined] + "ag_ui_thread_id": thread_id, + "ag_ui_run_id": run_id, + } + if flow.current_state: + thread.metadata["current_state"] = flow.current_state # type: ignore[attr-defined] + + # Build run kwargs + run_kwargs: dict[str, Any] = {"thread": thread} + if tools: + run_kwargs["tools"] = tools + + # Resolve approval responses (execute approved tools, replace approvals with results) + # This must happen before running the agent so it sees the tool results + tools_for_execution = tools if tools is not None else server_tools + await _resolve_approval_responses(messages, tools_for_execution, agent, run_kwargs) + + # Handle confirm_changes response (state confirmation flow - emit confirmation and stop) + if _is_confirm_changes_response(messages): + yield RunStartedEvent(run_id=run_id, thread_id=thread_id) + for event in 
_handle_step_based_approval(messages): + yield event + yield RunFinishedEvent(run_id=run_id, thread_id=thread_id) + return + + # Inject state context message so the model knows current application state + # This is critical for shared state scenarios where the UI state needs to be visible + if config.state_schema and flow.current_state: + messages = _inject_state_context(messages, flow.current_state, config.state_schema) + + # Stream from agent - emit RunStarted after first update to get service IDs + run_started_emitted = False + all_updates: list[Any] = [] # Collect for structured output processing + async for update in agent.run_stream(messages, **run_kwargs): + # Collect updates for structured output processing + if response_format is not None: + all_updates.append(update) + + # Update IDs from service response on first update and emit RunStarted + if not run_started_emitted: + conv_id = get_conversation_id_from_update(update) + if conv_id: + thread_id = conv_id + if update.response_id: + run_id = update.response_id + # NOW emit RunStarted with proper IDs + yield RunStartedEvent(run_id=run_id, thread_id=thread_id) + # Emit PredictState custom event if configured + if config.predict_state_config: + predict_state_value = [ + { + "state_key": state_key, + "tool": cfg["tool"], + "tool_argument": cfg["tool_argument"], + } + for state_key, cfg in config.predict_state_config.items() + ] + yield CustomEvent(name="PredictState", value=predict_state_value) + # Emit initial state snapshot only if we have both state_schema and state + if config.state_schema and flow.current_state: + yield StateSnapshotEvent(snapshot=flow.current_state) + run_started_emitted = True + + # Emit events for each content item + for content in update.contents: + for event in _emit_content( + content, + flow, + predictive_handler, + skip_text, + config.require_confirmation, + ): + yield event + + # Stop if waiting for approval + if flow.waiting_for_approval: + break + + # If no updates at all, still emit RunStarted + if not run_started_emitted: + yield RunStartedEvent(run_id=run_id, thread_id=thread_id) + if config.predict_state_config: + predict_state_value = [ + { + "state_key": state_key, + "tool": cfg["tool"], + "tool_argument": cfg["tool_argument"], + } + for state_key, cfg in config.predict_state_config.items() + ] + yield CustomEvent(name="PredictState", value=predict_state_value) + if config.state_schema and flow.current_state: + yield StateSnapshotEvent(snapshot=flow.current_state) + + # Process structured output if response_format is set + if response_format is not None and all_updates: + from agent_framework import AgentResponse + from pydantic import BaseModel + + logger.info(f"Processing structured output, update count: {len(all_updates)}") + final_response = AgentResponse.from_agent_run_response_updates(all_updates, output_format_type=response_format) + + if final_response.value and isinstance(final_response.value, BaseModel): + response_dict = final_response.value.model_dump(mode="json", exclude_none=True) + logger.info(f"Received structured output keys: {list(response_dict.keys())}") + + # Extract state updates - if no state_schema, all non-message fields are state + state_keys = ( + set(config.state_schema.keys()) if config.state_schema else set(response_dict.keys()) - {"message"} + ) + state_updates = {k: v for k, v in response_dict.items() if k in state_keys} + + if state_updates: + flow.current_state.update(state_updates) + yield StateSnapshotEvent(snapshot=flow.current_state) + logger.info(f"Emitted 
StateSnapshotEvent with updates: {list(state_updates.keys())}") + + # Emit message field as text if present + if "message" in response_dict and response_dict["message"]: + message_id = generate_event_id() + yield TextMessageStartEvent(message_id=message_id, role="assistant") + yield TextMessageContentEvent(message_id=message_id, delta=response_dict["message"]) + yield TextMessageEndEvent(message_id=message_id) + logger.info(f"Emitted conversational message with length={len(response_dict['message'])}") + + # Close any open message + if flow.message_id: + yield TextMessageEndEvent(message_id=flow.message_id) + + # Emit MessagesSnapshotEvent if we have tool calls or results + if flow.pending_tool_calls or flow.tool_results or flow.accumulated_text: + yield _build_messages_snapshot(flow, snapshot_messages) + + yield RunFinishedEvent(run_id=run_id, thread_id=thread_id) + + +__all__ = ["FlowState", "run_agent_stream"] diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index f8f5c1db8a..3adfe494cc 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -595,6 +595,7 @@ async def stream_fn( assert len(tool_events) == 0 +@pytest.mark.skip(reason="confirmation_strategy feature removed in orchestrator rewrite") async def test_suppressed_summary_with_document_state(): """Test suppressed summary uses document state for confirmation message.""" from agent_framework.ag_ui import AgentFrameworkAgent, DocumentWriterConfirmationStrategy From 0f07739eea71e9999cfba5ae29262a615eff46cc Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Wed, 21 Jan 2026 13:41:18 +0900 Subject: [PATCH 2/8] Refactoring --- .../ag-ui/agent_framework_ag_ui/__init__.py | 12 - .../ag-ui/agent_framework_ag_ui/_agent.py | 24 +- .../_confirmation_strategies.py | 217 ----- .../ag-ui/agent_framework_ag_ui/_endpoint.py | 14 +- .../ag-ui/agent_framework_ag_ui/_events.py | 585 ------------ .../_orchestration/_helpers.py | 172 +--- .../_orchestration/_state_manager.py | 106 --- .../_orchestration/_tooling.py | 26 +- .../agent_framework_ag_ui/_orchestrators.py | 802 ---------------- .../ag-ui/agent_framework_ag_ui/_run.py | 308 ++++++- .../agents/document_writer_agent.py | 13 +- .../agents/recipe_agent.py | 3 +- .../agents/task_planner_agent.py | 3 +- .../server/main.py | 7 +- .../tests/test_backend_tool_rendering.py | 129 --- ...t_confirmation_strategies_comprehensive.py | 275 ------ .../ag-ui/tests/test_document_writer_flow.py | 236 ----- .../ag-ui/tests/test_events_comprehensive.py | 827 ----------------- .../ag-ui/tests/test_human_in_the_loop.py | 180 ---- .../ag-ui/tests/test_orchestrators.py | 307 ------ .../tests/test_orchestrators_coverage.py | 872 ------------------ .../packages/ag-ui/tests/test_shared_state.py | 108 --- .../ag-ui/tests/test_state_manager.py | 51 - .../packages/ag-ui/tests/utils_test_ag_ui.py | 12 +- .../packages/core/agent_framework/_tools.py | 8 + 25 files changed, 350 insertions(+), 4947 deletions(-) delete mode 100644 python/packages/ag-ui/agent_framework_ag_ui/_confirmation_strategies.py delete mode 100644 python/packages/ag-ui/agent_framework_ag_ui/_events.py delete mode 100644 python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py delete mode 100644 python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py delete mode 100644 python/packages/ag-ui/tests/test_backend_tool_rendering.py delete mode 100644 
python/packages/ag-ui/tests/test_confirmation_strategies_comprehensive.py delete mode 100644 python/packages/ag-ui/tests/test_document_writer_flow.py delete mode 100644 python/packages/ag-ui/tests/test_events_comprehensive.py delete mode 100644 python/packages/ag-ui/tests/test_human_in_the_loop.py delete mode 100644 python/packages/ag-ui/tests/test_orchestrators.py delete mode 100644 python/packages/ag-ui/tests/test_orchestrators_coverage.py delete mode 100644 python/packages/ag-ui/tests/test_shared_state.py delete mode 100644 python/packages/ag-ui/tests/test_state_manager.py diff --git a/python/packages/ag-ui/agent_framework_ag_ui/__init__.py b/python/packages/ag-ui/agent_framework_ag_ui/__init__.py index c6dc575d36..f2c2ba7fe1 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/__init__.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/__init__.py @@ -6,13 +6,6 @@ from ._agent import AgentFrameworkAgent from ._client import AGUIChatClient -from ._confirmation_strategies import ( - ConfirmationStrategy, - DefaultConfirmationStrategy, - DocumentWriterConfirmationStrategy, - RecipeConfirmationStrategy, - TaskPlannerConfirmationStrategy, -) from ._endpoint import add_agent_framework_fastapi_endpoint from ._event_converters import AGUIEventConverter from ._http_service import AGUIHttpService @@ -35,13 +28,8 @@ "AGUIHttpService", "AGUIRequest", "AgentState", - "ConfirmationStrategy", - "DefaultConfirmationStrategy", "PredictStateConfig", "RunMetadata", - "TaskPlannerConfirmationStrategy", - "RecipeConfirmationStrategy", - "DocumentWriterConfirmationStrategy", "DEFAULT_TAGS", "__version__", ] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py index f04f498168..56488df876 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py @@ -20,22 +20,19 @@ def __init__( predict_state_config: dict[str, dict[str, str]] | None = None, use_service_thread: bool = False, require_confirmation: bool = True, - confirmation_strategy: Any | None = None, ): """Initialize agent configuration. 
Args: state_schema: Optional state schema for state management; accepts dict or Pydantic model/class - predict_state_config: Configuration for predictive state updates (currently unused in simplified impl) - use_service_thread: Whether the agent thread is service-managed (currently unused) - require_confirmation: Whether predictive updates require confirmation (currently unused) - confirmation_strategy: Optional strategy for generating confirmation messages + predict_state_config: Configuration for predictive state updates + use_service_thread: Whether the agent thread is service-managed + require_confirmation: Whether predictive updates require user confirmation before applying """ self.state_schema = self._normalize_state_schema(state_schema) self.predict_state_config = predict_state_config or {} self.use_service_thread = use_service_thread self.require_confirmation = require_confirmation - self.confirmation_strategy = confirmation_strategy @staticmethod def _normalize_state_schema(state_schema: Any | None) -> dict[str, Any]: @@ -55,12 +52,12 @@ def _normalize_state_schema(state_schema: Any | None) -> dict[str, Any]: base_model_type = None if base_model_type is not None and isinstance(state_schema, base_model_type): - schema_dict = state_schema.__class__.model_json_schema() + schema_dict = state_schema.__class__.model_json_schema() # type: ignore[union-attr] return schema_dict.get("properties", {}) or {} if base_model_type is not None and isinstance(state_schema, type) and issubclass(state_schema, base_model_type): - schema_dict = state_schema.model_json_schema() - return schema_dict.get("properties", {}) or {} + schema_dict = state_schema.model_json_schema() # type: ignore[union-attr] + return schema_dict.get("properties", {}) or {} # type: ignore return {} @@ -80,9 +77,7 @@ def __init__( state_schema: Any | None = None, predict_state_config: dict[str, dict[str, str]] | None = None, require_confirmation: bool = True, - orchestrators: Any = None, # Deprecated, kept for backwards compatibility use_service_thread: bool = False, - confirmation_strategy: Any = None, # Deprecated, kept for backwards compatibility ): """Initialize the AG-UI compatible agent wrapper. 
@@ -91,11 +86,9 @@ def __init__( name: Optional name for the agent description: Optional description state_schema: Optional state schema for state management; accepts dict or Pydantic model/class - predict_state_config: Configuration for predictive state updates (optional feature) - require_confirmation: Whether predictive updates require confirmation - orchestrators: Deprecated - kept for backwards compatibility + predict_state_config: Configuration for predictive state updates + require_confirmation: Whether predictive updates require user confirmation before applying use_service_thread: Whether the agent thread is service-managed - confirmation_strategy: Strategy for generating confirmation messages """ self.agent = agent self.name = name or getattr(agent, "name", "agent") @@ -106,7 +99,6 @@ def __init__( predict_state_config=predict_state_config, use_service_thread=use_service_thread, require_confirmation=require_confirmation, - confirmation_strategy=confirmation_strategy, ) async def run_agent( diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_confirmation_strategies.py b/python/packages/ag-ui/agent_framework_ag_ui/_confirmation_strategies.py deleted file mode 100644 index 35e648c100..0000000000 --- a/python/packages/ag-ui/agent_framework_ag_ui/_confirmation_strategies.py +++ /dev/null @@ -1,217 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Confirmation strategies for human-in-the-loop approval flows. - -Each agent can provide a custom confirmation strategy to generate domain-specific -messages when users approve or reject changes/actions. -""" - -from abc import ABC, abstractmethod -from typing import Any - - -class ConfirmationStrategy(ABC): - """Strategy for generating confirmation messages during human-in-the-loop flows. - - Subclasses must define the message properties. The methods use those properties - by default, but can be overridden for complete customization. - """ - - @property - @abstractmethod - def approval_header(self) -> str: - """Header for approval accepted message. Must be overridden.""" - ... - - @property - @abstractmethod - def approval_footer(self) -> str: - """Footer for approval accepted message. Must be overridden.""" - ... - - @property - @abstractmethod - def rejection_message(self) -> str: - """Message when user rejects. Must be overridden.""" - ... - - @property - @abstractmethod - def state_confirmed_message(self) -> str: - """Message when state is confirmed. Must be overridden.""" - ... - - @property - @abstractmethod - def state_rejected_message(self) -> str: - """Message when state is rejected. Must be overridden.""" - ... - - def on_approval_accepted(self, steps: list[dict[str, Any]]) -> str: - """Generate message when user approves function execution. - - Default implementation uses header/footer properties. - Override for complete customization. - - Args: - steps: List of approved steps with 'description', 'status', etc. - - Returns: - Message to display to user - """ - enabled_steps = [s for s in steps if s.get("status") == "enabled"] - message_parts = [self.approval_header.format(count=len(enabled_steps))] - for i, step in enumerate(enabled_steps, 1): - message_parts.append(f"{i}. {step['description']}\n") - message_parts.append(self.approval_footer) - return "".join(message_parts) - - def on_approval_rejected(self, steps: list[dict[str, Any]]) -> str: - """Generate message when user rejects function execution. 
- - Args: - steps: List of rejected steps - - Returns: - Message to display to user - """ - return self.rejection_message - - def on_state_confirmed(self) -> str: - """Generate message when user confirms predictive state changes. - - Returns: - Message to display to user - """ - return self.state_confirmed_message - - def on_state_rejected(self) -> str: - """Generate message when user rejects predictive state changes. - - Returns: - Message to display to user - """ - return self.state_rejected_message - - -class DefaultConfirmationStrategy(ConfirmationStrategy): - """Generic confirmation messages suitable for most agents.""" - - @property - def approval_header(self) -> str: - return "Executing {count} approved steps:\n\n" - - @property - def approval_footer(self) -> str: - return "\nAll steps completed successfully!" - - @property - def rejection_message(self) -> str: - return "No problem! What would you like me to change about the plan?" - - @property - def state_confirmed_message(self) -> str: - return "Changes confirmed and applied successfully!" - - @property - def state_rejected_message(self) -> str: - return "No problem! What would you like me to change?" - - -class TaskPlannerConfirmationStrategy(ConfirmationStrategy): - """Domain-specific confirmation messages for task planning agents.""" - - @property - def approval_header(self) -> str: - return "Executing your requested tasks:\n\n" - - @property - def approval_footer(self) -> str: - return "\nAll tasks completed successfully!" - - @property - def rejection_message(self) -> str: - return "No problem! Let me revise the plan. What would you like me to change?" - - @property - def state_confirmed_message(self) -> str: - return "Tasks confirmed and ready to execute!" - - @property - def state_rejected_message(self) -> str: - return "No problem! How should I adjust the task list?" - - -class RecipeConfirmationStrategy(ConfirmationStrategy): - """Domain-specific confirmation messages for recipe agents.""" - - @property - def approval_header(self) -> str: - return "Updating your recipe:\n\n" - - @property - def approval_footer(self) -> str: - return "\nRecipe updated successfully!" - - @property - def rejection_message(self) -> str: - return "No problem! What ingredients or steps should I change?" - - @property - def state_confirmed_message(self) -> str: - return "Recipe changes applied successfully!" - - @property - def state_rejected_message(self) -> str: - return "No problem! What would you like me to adjust in the recipe?" - - -class DocumentWriterConfirmationStrategy(ConfirmationStrategy): - """Domain-specific confirmation messages for document writing agents.""" - - @property - def approval_header(self) -> str: - return "Applying your edits:\n\n" - - @property - def approval_footer(self) -> str: - return "\nDocument updated successfully!" - - @property - def rejection_message(self) -> str: - return "No problem! Which changes should I keep or modify?" - - @property - def state_confirmed_message(self) -> str: - return "Document edits applied!" - - @property - def state_rejected_message(self) -> str: - return "No problem! What should I change about the document?" - - -def apply_confirmation_strategy( - strategy: ConfirmationStrategy | None, - accepted: bool, - steps: list[dict[str, Any]], -) -> str: - """Apply a confirmation strategy to generate a message. - - This helper consolidates the pattern used in multiple orchestrators. 
- - Args: - strategy: Strategy to use, or None for default - accepted: Whether the user approved - steps: List of steps (may be empty for state confirmations) - - Returns: - Generated message string - """ - if strategy is None: - strategy = DefaultConfirmationStrategy() - - if not steps: - # State confirmation (no steps) - return strategy.on_state_confirmed() if accepted else strategy.on_state_rejected() - # Step-based approval - return strategy.on_approval_accepted(steps) if accepted else strategy.on_approval_rejected(steps) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py b/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py index 7948d4f935..07e818882d 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py @@ -82,12 +82,14 @@ async def event_generator(): event_count = 0 async for event in wrapped_agent.run_agent(input_data): event_count += 1 - logger.debug(f"[{path}] Event {event_count}: {type(event).__name__}") - - # Log event payload for debugging - if hasattr(event, "model_dump"): - event_data = event.model_dump(exclude_none=True) - logger.debug(f"[{path}] Event payload: {event_data}") + event_type_name = getattr(event, "type", type(event).__name__) + # Log important events at INFO level + if "TOOL_CALL" in str(event_type_name) or "RUN" in str(event_type_name): + if hasattr(event, "model_dump"): + event_data = event.model_dump(exclude_none=True) + logger.info(f"[{path}] Event {event_count}: {event_type_name} - {event_data}") + else: + logger.info(f"[{path}] Event {event_count}: {event_type_name}") encoded = encoder.encode(event) logger.debug( diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_events.py b/python/packages/ag-ui/agent_framework_ag_ui/_events.py deleted file mode 100644 index 34c1e3ed86..0000000000 --- a/python/packages/ag-ui/agent_framework_ag_ui/_events.py +++ /dev/null @@ -1,585 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Event bridge for converting Agent Framework events to AG-UI protocol.""" - -import json -import logging -import re -from copy import deepcopy -from typing import Any - -from ag_ui.core import ( - BaseEvent, - CustomEvent, - RunFinishedEvent, - RunStartedEvent, - StateDeltaEvent, - StateSnapshotEvent, - TextMessageContentEvent, - TextMessageEndEvent, - TextMessageStartEvent, - ToolCallArgsEvent, - ToolCallEndEvent, - ToolCallResultEvent, - ToolCallStartEvent, -) -from agent_framework import ( - AgentResponseUpdate, - Content, - prepare_function_call_results, -) - -from ._utils import extract_state_from_tool_args, generate_event_id, safe_json_parse - -logger = logging.getLogger(__name__) - - -class AgentFrameworkEventBridge: - """Converts Agent Framework responses to AG-UI events.""" - - def __init__( - self, - run_id: str, - thread_id: str, - predict_state_config: dict[str, dict[str, str]] | None = None, - current_state: dict[str, Any] | None = None, - skip_text_content: bool = False, - require_confirmation: bool = True, - approval_tool_name: str | None = None, - ) -> None: - """ - Initialize the event bridge. - - Args: - run_id: The run identifier. - thread_id: The thread identifier. - predict_state_config: Configuration for predictive state updates. - Format: {"state_key": {"tool": "tool_name", "tool_argument": "arg_name"}} - current_state: Reference to the current state dict for tracking updates. - skip_text_content: If True, skip emitting TextMessageContentEvents (for structured outputs). 
- require_confirmation: Whether predictive state updates require user confirmation. - """ - self.run_id = run_id - self.thread_id = thread_id - self.current_message_id: str | None = None - self.current_tool_call_id: str | None = None - self.current_tool_call_name: str | None = None # Track the tool name across streaming chunks - self.predict_state_config = predict_state_config or {} - self.current_state = current_state or {} - self.pending_state_updates: dict[str, Any] = {} # Track updates from tool calls - self.skip_text_content = skip_text_content - self.require_confirmation = require_confirmation - self.approval_tool_name = approval_tool_name - - # For predictive state updates: accumulate streaming arguments - self.streaming_tool_args: str = "" # Accumulated JSON string - self.last_emitted_state: dict[str, Any] = {} # Track last emitted state to avoid duplicates - self.state_delta_count: int = 0 # Counter for sampling log output - self.should_stop_after_confirm: bool = False # Flag to stop run after confirm_changes - self.suppressed_summary: str = "" # Store LLM summary to show after confirmation - - async def from_agent_run_update(self, update: AgentResponseUpdate) -> list[BaseEvent]: - """ - Convert an AgentResponseUpdate to AG-UI events. - - Args: - update: The agent run update to convert. - - Returns: - List of AG-UI events. - """ - events: list[BaseEvent] = [] - - logger.info(f"Processing AgentRunUpdate with {len(update.contents)} content items") - for idx, content in enumerate(update.contents): - logger.info(f" Content {idx}: type={type(content).__name__}") - match content.type: - case "text": - events.extend(self._handle_text_content(content)) - case "function_call": - events.extend(self._handle_function_call_content(content)) - case "function_result": - events.extend(self._handle_function_result_content(content)) - case "function_approval_request": - events.extend(self._handle_function_approval_request_content(content)) - case _: - logger.warning(f" Unsupported content type: {content.type}, skipping.") - return events - - def _handle_text_content(self, content: Content) -> list[BaseEvent]: - events: list[BaseEvent] = [] - logger.info(f" TextContent found: length={len(content.text)}") # type: ignore[arg-type] - logger.info( - " Flags: skip_text_content=%s, should_stop_after_confirm=%s", - self.skip_text_content, - self.should_stop_after_confirm, - ) - - if self.skip_text_content: - logger.info(" SKIPPING TextContent: skip_text_content is True") - return events - - if self.should_stop_after_confirm: - logger.info(" SKIPPING TextContent: waiting for confirm_changes response") - self.suppressed_summary += content.text # type: ignore[operator] - logger.info(f" Suppressed summary length={len(self.suppressed_summary)}") - return events - - # Skip empty text chunks to avoid emitting - # TextMessageContentEvent with an empty `delta` which fails - # Pydantic validation (AG-UI requires non-empty strings). 
- if not content.text: - logger.info(" SKIPPING TextContent: empty chunk") - return events - - if not self.current_message_id: - self.current_message_id = generate_event_id() - start_event = TextMessageStartEvent( - message_id=self.current_message_id, - role="assistant", - ) - logger.info(f" EMITTING TextMessageStartEvent with message_id={self.current_message_id}") - events.append(start_event) - - event = TextMessageContentEvent( - message_id=self.current_message_id, - delta=content.text, - ) - logger.info(f" EMITTING TextMessageContentEvent with text_len={len(content.text)}") - events.append(event) - return events - - def _handle_function_call_content(self, content: Content) -> list[BaseEvent]: - events: list[BaseEvent] = [] - if content.name: - logger.debug(f"Tool call: {content.name} (call_id: {content.call_id})") - - if not content.name and not content.call_id and not self.current_tool_call_name: - args_length = len(str(content.arguments)) if content.arguments else 0 - logger.warning(f"Content missing name and call_id. args_length={args_length}") - - tool_call_id = self._coalesce_tool_call_id(content) - # Only emit ToolCallStartEvent once per tool call (when it's a new tool call) - if content.name and tool_call_id != self.current_tool_call_id: - self.streaming_tool_args = "" - self.state_delta_count = 0 - self.current_tool_call_id = tool_call_id - self.current_tool_call_name = content.name - - tool_start_event = ToolCallStartEvent( - tool_call_id=tool_call_id, - tool_call_name=content.name, - parent_message_id=self.current_message_id, - ) - logger.info(f"Emitting ToolCallStartEvent with name='{content.name}', id='{tool_call_id}'") - events.append(tool_start_event) - elif tool_call_id: - self.current_tool_call_id = tool_call_id - - if content.arguments: - delta_str = content.arguments if isinstance(content.arguments, str) else json.dumps(content.arguments) - logger.info(f"Emitting ToolCallArgsEvent with delta_length={len(delta_str)}, id='{tool_call_id}'") - args_event = ToolCallArgsEvent( - tool_call_id=tool_call_id, - delta=delta_str, - ) - events.append(args_event) - - events.extend(self._emit_predictive_state_deltas(delta_str)) - - return events - - def _coalesce_tool_call_id(self, content: Content) -> str: - if content.call_id: - return content.call_id - if self.current_tool_call_id: - return self.current_tool_call_id - return generate_event_id() - - def _emit_predictive_state_deltas(self, argument_chunk: str) -> list[BaseEvent]: - events: list[BaseEvent] = [] - if not self.current_tool_call_name or not self.predict_state_config: - return events - - self.streaming_tool_args += argument_chunk - logger.debug( - "Predictive state: accumulated %s chars for tool '%s'", - len(self.streaming_tool_args), - self.current_tool_call_name, - ) - - parsed_args = safe_json_parse(self.streaming_tool_args) - if parsed_args is None: - for state_key, config in self.predict_state_config.items(): - if config["tool"] != self.current_tool_call_name: - continue - tool_arg_name = config["tool_argument"] - pattern = rf'"{re.escape(tool_arg_name)}":\s*"([^"]*)' - match = re.search(pattern, self.streaming_tool_args) - - if match: - partial_value = match.group(1).replace("\\n", "\n").replace('\\"', '"').replace("\\\\", "\\") - - if state_key not in self.last_emitted_state or self.last_emitted_state[state_key] != partial_value: - state_delta_event = StateDeltaEvent( - delta=[ - { - "op": "replace", - "path": f"/{state_key}", - "value": partial_value, - } - ], - ) - - self.state_delta_count += 1 - if 
self.state_delta_count % 10 == 1: - logger.info( - "StateDeltaEvent #%s for '%s': op=replace, path=/%s, value_length=%s", - self.state_delta_count, - state_key, - state_key, - len(str(partial_value)), - ) - elif self.state_delta_count % 100 == 0: - logger.info(f"StateDeltaEvent #{self.state_delta_count} emitted") - - events.append(state_delta_event) - self.last_emitted_state[state_key] = partial_value - self.pending_state_updates[state_key] = partial_value - - if parsed_args: - for state_key, config in self.predict_state_config.items(): - if config["tool"] != self.current_tool_call_name: - continue - tool_arg_name = config["tool_argument"] - - state_value = extract_state_from_tool_args(parsed_args, tool_arg_name) - if state_value is None: - continue - - if state_key not in self.last_emitted_state or self.last_emitted_state[state_key] != state_value: - state_delta_event = StateDeltaEvent( - delta=[ - { - "op": "replace", - "path": f"/{state_key}", - "value": state_value, - } - ], - ) - - self.state_delta_count += 1 - if self.state_delta_count % 10 == 1: - logger.info( - "StateDeltaEvent #%s for '%s': op=replace, path=/%s, value_length=%s", - self.state_delta_count, - state_key, - state_key, - len(str(state_value)), - ) - elif self.state_delta_count % 100 == 0: - logger.info(f"StateDeltaEvent #{self.state_delta_count} emitted") - - events.append(state_delta_event) - self.last_emitted_state[state_key] = state_value - self.pending_state_updates[state_key] = state_value - return events - - def _handle_function_result_content(self, content: Content) -> list[BaseEvent]: - events: list[BaseEvent] = [] - if content.call_id: - end_event = ToolCallEndEvent( - tool_call_id=content.call_id, - ) - logger.info(f"Emitting ToolCallEndEvent for completed tool call '{content.call_id}'") - events.append(end_event) - - if self.state_delta_count > 0: - logger.info( - "Tool call '%s' complete: emitted %s StateDeltaEvents total", - content.call_id, - self.state_delta_count, - ) - - self.streaming_tool_args = "" - self.state_delta_count = 0 - - result_message_id = generate_event_id() - result_content = prepare_function_call_results(content.result) - - result_event = ToolCallResultEvent( - message_id=result_message_id, - tool_call_id=content.call_id, # type: ignore[arg-type] - content=result_content, - role="tool", - ) - events.append(result_event) - events.extend(self._emit_state_snapshot_and_confirmation()) - - return events - - def _emit_state_snapshot_and_confirmation(self) -> list[BaseEvent]: - events: list[BaseEvent] = [] - if self.pending_state_updates: - for key, value in self.pending_state_updates.items(): - self.current_state[key] = value - - logger.info(f"Emitting StateSnapshotEvent with keys: {list(self.current_state.keys())}") - if "recipe" in self.current_state: - recipe = self.current_state["recipe"] - logger.info( - "Recipe fields: title=%s, skill_level=%s, ingredients_count=%s, instructions_count=%s", - recipe.get("title"), - recipe.get("skill_level"), - len(recipe.get("ingredients", [])), - len(recipe.get("instructions", [])), - ) - - state_snapshot_event = StateSnapshotEvent( - snapshot=self.current_state, - ) - events.append(state_snapshot_event) - - tool_was_predictive = False - logger.debug( - "Checking predictive state: current_tool='%s', predict_config=%s", - self.current_tool_call_name, - list(self.predict_state_config.keys()) if self.predict_state_config else "None", - ) - for state_key, config in self.predict_state_config.items(): - if self.current_tool_call_name and config["tool"] == 
self.current_tool_call_name: - logger.info( - "Tool '%s' matches predictive config for state key '%s'", - self.current_tool_call_name, - state_key, - ) - tool_was_predictive = True - break - - if tool_was_predictive and self.require_confirmation: - events.extend(self._emit_confirm_changes_tool_call()) - elif tool_was_predictive: - logger.info("Skipping confirm_changes - require_confirmation is False") - - self.pending_state_updates.clear() - self.last_emitted_state = deepcopy(self.current_state) - self.current_tool_call_name = None - return events - - def _emit_confirm_changes_tool_call(self, function_call: Content | None = None) -> list[BaseEvent]: - """Emit a confirm_changes tool call for Dojo UI compatibility. - - Args: - function_call: Optional function call that needs confirmation. - If provided, includes function info in the confirm_changes args - so Dojo UI can display what's being confirmed. - """ - events: list[BaseEvent] = [] - confirm_call_id = generate_event_id() - logger.info("Emitting confirm_changes tool call for predictive update") - - confirm_start = ToolCallStartEvent( - tool_call_id=confirm_call_id, - tool_call_name="confirm_changes", - parent_message_id=self.current_message_id, - ) - events.append(confirm_start) - - # Include function info if this is for a function approval - # This helps Dojo UI display meaningful confirmation info - if function_call: - args_dict = { - "function_name": function_call.name, - "function_call_id": function_call.call_id, - "function_arguments": function_call.parse_arguments() or {}, - "steps": [ - { - "description": f"Execute {function_call.name}", - "status": "enabled", - } - ], - } - args_json = json.dumps(args_dict) - else: - args_json = "{}" - - confirm_args = ToolCallArgsEvent( - tool_call_id=confirm_call_id, - delta=args_json, - ) - events.append(confirm_args) - - confirm_end = ToolCallEndEvent( - tool_call_id=confirm_call_id, - ) - events.append(confirm_end) - - self.should_stop_after_confirm = True - logger.info("Set flag to stop run after confirm_changes") - return events - - def _emit_function_approval_tool_call(self, function_call: Content) -> list[BaseEvent]: - """Emit a tool call that can drive UI approval for function requests.""" - tool_call_name = "confirm_changes" - if self.approval_tool_name and self.approval_tool_name != function_call.name: - tool_call_name = self.approval_tool_name - - tool_call_id = generate_event_id() - tool_start = ToolCallStartEvent( - tool_call_id=tool_call_id, - tool_call_name=tool_call_name, - parent_message_id=self.current_message_id, - ) - events: list[BaseEvent] = [tool_start] - - args_dict = { - "function_name": function_call.name, - "function_call_id": function_call.call_id, - "function_arguments": function_call.parse_arguments() or {}, - "steps": [ - { - "description": f"Execute {function_call.name}", - "status": "enabled", - } - ], - } - args_json = json.dumps(args_dict) - - events.append( - ToolCallArgsEvent( - tool_call_id=tool_call_id, - delta=args_json, - ) - ) - events.append( - ToolCallEndEvent( - tool_call_id=tool_call_id, - ) - ) - - self.should_stop_after_confirm = True - logger.info("Set flag to stop run after confirm_changes") - return events - - def _handle_function_approval_request_content(self, content: Content) -> list[BaseEvent]: - events: list[BaseEvent] = [] - logger.info("=== FUNCTION APPROVAL REQUEST ===") - logger.info(f" Function: {content.function_call.name}") # type: ignore[union-attr] - logger.info(f" Call ID: {content.function_call.call_id}") # type: 
ignore[union-attr] - - parsed_args = content.function_call.parse_arguments() # type: ignore[union-attr] - parsed_arg_keys = list(parsed_args.keys()) if parsed_args else "None" - logger.info(f" Parsed args keys: {parsed_arg_keys}") - - if parsed_args and self.predict_state_config: - logger.info( - " Checking predict_state_config keys: %s", - list(self.predict_state_config.keys()) if self.predict_state_config else "None", - ) - for state_key, config in self.predict_state_config.items(): - if config["tool"] != content.function_call.name: # type: ignore[union-attr] - continue - tool_arg_name = config["tool_argument"] - logger.info( - " MATCHED tool '%s' for state key '%s', arg='%s'", - content.function_call.name, # type: ignore[union-attr] - state_key, - tool_arg_name, - ) - - state_value = extract_state_from_tool_args(parsed_args, tool_arg_name) - if state_value is None: - logger.warning(f" Tool argument '{tool_arg_name}' not found in parsed args") - continue - - self.current_state[state_key] = state_value - logger.info("Emitting StateSnapshotEvent for key '%s', value type: %s", state_key, type(state_value)) # type: ignore - state_snapshot = StateSnapshotEvent( - snapshot=self.current_state, - ) - events.append(state_snapshot) - - if content.function_call.call_id: # type: ignore[union-attr] - end_event = ToolCallEndEvent( - tool_call_id=content.function_call.call_id, # type: ignore[union-attr] - ) - logger.info(f"Emitting ToolCallEndEvent for approval-required tool '{content.function_call.call_id}'") # type: ignore[union-attr] - events.append(end_event) - - # Emit the function_approval_request custom event for UI implementations that support it - approval_event = CustomEvent( - name="function_approval_request", - value={ - "id": content.id, - "function_call": { - "call_id": content.function_call.call_id, # type: ignore[union-attr] - "name": content.function_call.name, # type: ignore[union-attr] - "arguments": content.function_call.parse_arguments(), # type: ignore[union-attr] - }, - }, - ) - logger.info(f"Emitting function_approval_request custom event for '{content.function_call.name}'") # type: ignore[union-attr] - events.append(approval_event) - - # Emit a UI-friendly approval tool call for function approvals. - if self.require_confirmation: - events.extend(self._emit_function_approval_tool_call(content.function_call)) # type: ignore[arg-type] - - # Signal orchestrator to stop the run and wait for user approval response - self.should_stop_after_confirm = True - logger.info("Set flag to stop run - waiting for function approval response") - return events - - def create_run_started_event(self) -> RunStartedEvent: - """Create a run started event.""" - return RunStartedEvent( - run_id=self.run_id, - thread_id=self.thread_id, - ) - - def create_run_finished_event(self, result: Any = None) -> RunFinishedEvent: - """Create a run finished event.""" - return RunFinishedEvent( - run_id=self.run_id, - thread_id=self.thread_id, - result=result, - ) - - def create_message_start_event(self, message_id: str, role: str = "assistant") -> TextMessageStartEvent: - """Create a message start event.""" - return TextMessageStartEvent( - message_id=message_id, - role=role, # type: ignore - ) - - def create_message_end_event(self, message_id: str) -> TextMessageEndEvent: - """Create a message end event.""" - return TextMessageEndEvent( - message_id=message_id, - ) - - def create_state_snapshot_event(self, state: dict[str, Any]) -> StateSnapshotEvent: - """Create a state snapshot event. 
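The deleted `_emit_predictive_state_deltas` above pulls partial string values out of a still-incomplete JSON argument buffer with a regex, so the UI can render predicted state while tool arguments are still streaming. A minimal standalone sketch of that idea follows; `extract_partial_string_arg` is a hypothetical name, not part of this package.

import re

def extract_partial_string_arg(buffer: str, arg_name: str) -> str | None:
    """Return the (possibly truncated) string value of `arg_name` from a partial JSON buffer."""
    # Match `"arg_name": "` followed by everything up to the next quote (or end of buffer).
    pattern = rf'"{re.escape(arg_name)}":\s*"([^"]*)'
    match = re.search(pattern, buffer)
    if not match:
        return None
    # Undo the most common JSON escapes so the streamed preview reads naturally.
    return match.group(1).replace("\\n", "\n").replace('\\"', '"').replace("\\\\", "\\")

# Usage: feed it the accumulated argument chunks of a streaming tool call.
assert extract_partial_string_arg('{"document": "Hello wo', "document") == "Hello wo"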
- - Args: - state: The complete state snapshot. - - Returns: - StateSnapshotEvent. - """ - return StateSnapshotEvent( - snapshot=state, - ) - - def create_state_delta_event(self, delta: list[dict[str, Any]]) -> StateDeltaEvent: - """Create a state delta event using JSON Patch format (RFC 6902). - - Args: - delta: List of JSON Patch operations. - - Returns: - StateDeltaEvent. - """ - return StateDeltaEvent( - delta=delta, - ) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py index b2e3b1d5eb..d38c125092 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py @@ -1,12 +1,15 @@ # Copyright (c) Microsoft. All rights reserved. -"""Helper functions for orchestration logic.""" +"""Helper functions for orchestration logic. + +Most orchestration helpers have been moved inline to _run.py. +This module retains utilities that may be useful for testing or extensions. +""" import json import logging -from typing import TYPE_CHECKING, Any +from typing import Any -from ag_ui.core import StateSnapshotEvent from agent_framework import ( ChatMessage, Content, @@ -14,10 +17,6 @@ from .._utils import get_role_value, safe_json_parse -if TYPE_CHECKING: - from .._events import AgentFrameworkEventBridge - from ._state_manager import StateManager - logger = logging.getLogger(__name__) @@ -111,53 +110,6 @@ def tool_name_for_call_id( return str(name) if name else None -def tool_calls_match_state( - provider_messages: list[ChatMessage], - state_manager: "StateManager", -) -> bool: - """Check if tool calls in messages match current state. - - Args: - provider_messages: Messages to check - state_manager: State manager with config and current state - - Returns: - True if tool calls match state configuration - """ - if not state_manager.predict_state_config or not state_manager.current_state: - return False - - for state_key, config in state_manager.predict_state_config.items(): - tool_name = config["tool"] - tool_arg_name = config["tool_argument"] - tool_args: dict[str, Any] | None = None - - for msg in reversed(provider_messages): - if get_role_value(msg) != "assistant": - continue - for content in msg.contents: - if content.type == "function_call" and content.name == tool_name: - tool_args = safe_json_parse(content.arguments) - break - if tool_args is not None: - break - - if not tool_args: - return False - - if tool_arg_name == "*": - state_value = tool_args - elif tool_arg_name in tool_args: - state_value = tool_args[tool_arg_name] - else: - return False - - if state_manager.current_state.get(state_key) != state_value: - return False - - return True - - def schema_has_steps(schema: Any) -> bool: """Check if a schema has a steps array property. @@ -202,45 +154,10 @@ def select_approval_tool_name(client_tools: list[Any] | None) -> str | None: return None -def select_messages_to_run( - provider_messages: list[ChatMessage], - state_manager: "StateManager", -) -> list[ChatMessage]: - """Select and prepare messages for agent execution. - - Injects state context message when appropriate. 
- - Args: - provider_messages: Original messages from client - state_manager: State manager instance - - Returns: - Messages ready for agent execution - """ - if not provider_messages: - return [] - - is_new_user_turn = get_role_value(provider_messages[-1]) == "user" - conversation_has_tool_calls = tool_calls_match_state(provider_messages, state_manager) - state_context_msg = state_manager.state_context_message( - is_new_user_turn=is_new_user_turn, conversation_has_tool_calls=conversation_has_tool_calls - ) - if not state_context_msg: - return list(provider_messages) - - messages_to_run = [msg for msg in provider_messages if not is_state_context_message(msg)] - if pending_tool_call_ids(messages_to_run): - return messages_to_run - - insert_index = len(messages_to_run) - 1 if is_new_user_turn else len(messages_to_run) - if insert_index < 0: - insert_index = 0 - messages_to_run.insert(insert_index, state_context_msg) - return messages_to_run - - def build_safe_metadata(thread_metadata: dict[str, Any] | None) -> dict[str, Any]: - """Build metadata dict with truncated string values. + """Build metadata dict with truncated string values for Azure compatibility. + + Azure has a 512 character limit per metadata value. Args: thread_metadata: Raw metadata dict @@ -259,63 +176,6 @@ def build_safe_metadata(thread_metadata: dict[str, Any] | None) -> dict[str, Any return safe_metadata -def collect_approved_state_snapshots( - provider_messages: list[ChatMessage], - predict_state_config: dict[str, dict[str, str]] | None, - current_state: dict[str, Any], - event_bridge: "AgentFrameworkEventBridge", -) -> list[StateSnapshotEvent]: - """Collect state snapshots from approved function calls. - - Args: - provider_messages: Messages containing approvals - predict_state_config: Predictive state configuration - current_state: Current state dict (will be mutated) - event_bridge: Event bridge for creating events - - Returns: - List of state snapshot events - """ - if not predict_state_config: - return [] - - events: list[StateSnapshotEvent] = [] - for msg in provider_messages: - if get_role_value(msg) != "user": - continue - for content in msg.contents: - if content.type == "function_approval_response": - if not content.function_call or not content.approved: - continue - parsed_args = content.function_call.parse_arguments() - state_args = None - if content.additional_properties: - state_args = content.additional_properties.get("ag_ui_state_args") - if not isinstance(state_args, dict): - state_args = parsed_args - if not state_args: - continue - for state_key, config in predict_state_config.items(): - if config["tool"] != content.function_call.name: - continue - tool_arg_name = config["tool_argument"] - if tool_arg_name == "*": - state_value = state_args - elif isinstance(state_args, dict) and tool_arg_name in state_args: - state_value = state_args[tool_arg_name] - else: - continue - current_state[state_key] = state_value - event_bridge.current_state[state_key] = state_value - logger.info( - f"Emitting StateSnapshotEvent for approved state key '{state_key}' " - f"with {len(state_value) if isinstance(state_value, list) else 'N/A'} items" - ) - events.append(StateSnapshotEvent(snapshot=current_state)) - break - return events - - def latest_approval_response(messages: list[ChatMessage]) -> Content | None: """Get the latest approval response from messages. 
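The deleted `select_messages_to_run` strips any previous state-context message and, on a new user turn, inserts a fresh one directly before the user's newest message (otherwise it appends), skipping injection while tool calls are still pending. A minimal sketch of just the insertion rule, using plain dicts and a hypothetical `is_state_context` marker in place of the real message metadata:

def insert_state_context(messages: list[dict], context_msg: dict, is_new_user_turn: bool) -> list[dict]:
    """Drop stale context, then place the new context before the last user message on a new turn."""
    result = [m for m in messages if not m.get("is_state_context")]
    index = len(result) - 1 if is_new_user_turn else len(result)
    result.insert(max(index, 0), context_msg)
    return result

msgs = [{"role": "user", "content": "hi"}, {"role": "assistant", "content": "ok"}, {"role": "user", "content": "update it"}]
ctx = {"role": "system", "content": "Current state ...", "is_state_context": True}
# The fresh context lands directly before the user's newest message.
assert insert_state_context(msgs, ctx, is_new_user_turn=True)[-2] == ctx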
@@ -384,3 +244,17 @@ def is_step_based_approval( if config.get("tool") == tool_name and config.get("tool_argument") == "steps": return True return False + + +__all__ = [ + "pending_tool_call_ids", + "is_state_context_message", + "ensure_tool_call_entry", + "tool_name_for_call_id", + "schema_has_steps", + "select_approval_tool_name", + "build_safe_metadata", + "latest_approval_response", + "approval_steps", + "is_step_based_approval", +] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py deleted file mode 100644 index 05cc55228d..0000000000 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_state_manager.py +++ /dev/null @@ -1,106 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""State orchestration utilities.""" - -import json -from typing import Any - -from ag_ui.core import CustomEvent, EventType -from agent_framework import ChatMessage, Content - - -class StateManager: - """Coordinates state defaults, snapshots, and structured updates.""" - - def __init__( - self, - state_schema: dict[str, Any] | None, - predict_state_config: dict[str, dict[str, str]] | None, - require_confirmation: bool, - ) -> None: - self.state_schema = state_schema or {} - self.predict_state_config = predict_state_config or {} - self.require_confirmation = require_confirmation - self.current_state: dict[str, Any] = {} - self._state_from_input: bool = False - - def initialize(self, initial_state: dict[str, Any] | None) -> dict[str, Any]: - """Initialize state with schema defaults.""" - self._state_from_input = initial_state is not None - self.current_state = (initial_state or {}).copy() - self._apply_schema_defaults() - return self.current_state - - def predict_state_event(self) -> CustomEvent | None: - """Create predict-state custom event when configured.""" - if not self.predict_state_config: - return None - - predict_state_value = [ - { - "state_key": state_key, - "tool": config["tool"], - "tool_argument": config["tool_argument"], - } - for state_key, config in self.predict_state_config.items() - ] - - return CustomEvent( - type=EventType.CUSTOM, - name="PredictState", - value=predict_state_value, - ) - - def initial_snapshot_event(self, event_bridge: Any) -> Any: - """Emit initial snapshot when schema and state present.""" - if not self.state_schema: - return None - self._apply_schema_defaults() - return event_bridge.create_state_snapshot_event(self.current_state) - - def state_context_message(self, is_new_user_turn: bool, conversation_has_tool_calls: bool) -> ChatMessage | None: - """Inject state context only when starting a new user turn.""" - if not self.current_state or not self.state_schema: - return None - if not is_new_user_turn: - return None - if conversation_has_tool_calls and not self._state_from_input: - return None - - state_json = json.dumps(self.current_state, indent=2) - return ChatMessage( - role="system", - contents=[ - Content.from_text( - text=( - "Current state of the application:\n" - f"{state_json}\n\n" - "When modifying state, you MUST include ALL existing data plus your changes.\n" - "For example, if adding one new item to a list, include ALL existing items PLUS the one new item.\n" - "Never replace existing data - always preserve and append or merge." 
- ) - ) - ], - ) - - def extract_state_updates(self, response_dict: dict[str, Any]) -> dict[str, Any]: - """Extract state updates from structured response payloads.""" - if self.state_schema: - return {key: response_dict[key] for key in self.state_schema.keys() if key in response_dict} - return {k: v for k, v in response_dict.items() if k != "message"} - - def apply_state_updates(self, updates: dict[str, Any]) -> None: - """Merge state updates into current state.""" - if not updates: - return - self.current_state.update(updates) - - def _apply_schema_defaults(self) -> None: - """Fill missing state fields based on schema hints.""" - for key, schema in self.state_schema.items(): - if key in self.current_state: - continue - if isinstance(schema, dict) and schema.get("type") == "array": # type: ignore - self.current_state[key] = [] - else: - self.current_state[key] = {} diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py index 0f86516448..5df6cd1d14 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_tooling.py @@ -84,9 +84,26 @@ def register_additional_client_tools(agent: "AgentProtocol", client_tools: list[ logger.debug(f"[TOOLS] Registered {len(client_tools)} client tools as additional_tools (declaration-only)") +def _has_approval_tools(tools: list[Any]) -> bool: + """Check if any tools require approval.""" + return any(getattr(tool, "approval_mode", None) == "always_require" for tool in tools) + + def merge_tools(server_tools: list[Any], client_tools: list[Any] | None) -> list[Any] | None: - """Combine server and client tools without overriding server metadata.""" + """Combine server and client tools without overriding server metadata. + + IMPORTANT: When server tools have approval_mode="always_require", we MUST return + them so they get passed to the streaming response handler. Otherwise, the approval + check in _try_execute_function_calls won't find the tool and won't trigger approval. + """ if not client_tools: + # Even without client tools, we must pass server tools if any require approval + if server_tools and _has_approval_tools(server_tools): + logger.info( + f"[TOOLS] No client tools but server has approval tools - " + f"passing {len(server_tools)} server tools for approval mode" + ) + return server_tools logger.info("[TOOLS] No client tools - not passing tools= parameter (using agent's configured tools)") return None @@ -94,6 +111,13 @@ def merge_tools(server_tools: list[Any], client_tools: list[Any] | None) -> list unique_client_tools = [tool for tool in client_tools if getattr(tool, "name", None) not in server_tool_names] if not unique_client_tools: + # Same check: must pass server tools if any require approval + if server_tools and _has_approval_tools(server_tools): + logger.info( + f"[TOOLS] Client tools duplicate server but server has approval tools - " + f"passing {len(server_tools)} server tools for approval mode" + ) + return server_tools logger.info("[TOOLS] All client tools duplicate server tools - not passing tools= parameter") return None diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py deleted file mode 100644 index 2bd24de8c8..0000000000 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py +++ /dev/null @@ -1,802 +0,0 @@ -# Copyright (c) Microsoft. 
All rights reserved. - -"""Orchestrators for multi-turn agent flows.""" - -import json -import logging -import uuid -from abc import ABC, abstractmethod -from collections.abc import AsyncGenerator, Sequence -from typing import TYPE_CHECKING, Any - -from ag_ui.core import ( - BaseEvent, - MessagesSnapshotEvent, - RunErrorEvent, - TextMessageContentEvent, - TextMessageEndEvent, - TextMessageStartEvent, - ToolCallArgsEvent, - ToolCallEndEvent, - ToolCallResultEvent, - ToolCallStartEvent, -) -from agent_framework import ( - AgentProtocol, - AgentThread, - ChatAgent, - Content, - FunctionInvocationConfiguration, -) -from agent_framework._middleware import extract_and_merge_function_middleware -from agent_framework._tools import ( - _collect_approval_responses, # type: ignore - _replace_approval_contents_with_results, # type: ignore - _try_execute_function_calls, # type: ignore -) - -from ._orchestration._helpers import ( - approval_steps, - build_safe_metadata, - collect_approved_state_snapshots, - ensure_tool_call_entry, - is_step_based_approval, - latest_approval_response, - select_approval_tool_name, - select_messages_to_run, - tool_name_for_call_id, -) -from ._orchestration._tooling import ( - collect_server_tools, - merge_tools, - register_additional_client_tools, -) -from ._utils import ( - convert_agui_tools_to_agent_framework, - generate_event_id, - get_conversation_id_from_update, - get_role_value, -) - -if TYPE_CHECKING: - from ._agent import AgentConfig - from ._confirmation_strategies import ConfirmationStrategy - from ._events import AgentFrameworkEventBridge - from ._orchestration._state_manager import StateManager - - -logger = logging.getLogger(__name__) - - -class ExecutionContext: - """Shared context for orchestrators.""" - - def __init__( - self, - input_data: dict[str, Any], - agent: AgentProtocol, - config: "AgentConfig", # noqa: F821 - confirmation_strategy: "ConfirmationStrategy | None" = None, # noqa: F821 - ): - """Initialize execution context. - - Args: - input_data: AG-UI run input containing messages, state, etc. 
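The `ExecutionContext` being removed resolves run and thread ids lazily, accepting either snake_case or camelCase keys from the AG-UI payload and falling back to a fresh UUID. A small sketch of that resolution rule (the helper name is illustrative only):

import uuid
from typing import Any

def resolve_id(input_data: dict[str, Any], snake: str, camel: str) -> str:
    """Prefer the id supplied by the client under either key; otherwise mint a new one."""
    supplied = input_data.get(snake) or input_data.get(camel)
    return str(supplied) if supplied else str(uuid.uuid4())

run_id = resolve_id({"runId": "run-42"}, "run_id", "runId")   # -> "run-42"
thread_id = resolve_id({}, "thread_id", "threadId")           # -> freshly generated UUID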
- agent: The Agent Framework agent to execute - config: Agent configuration - confirmation_strategy: Strategy for generating confirmation messages - """ - self.input_data = input_data - self.agent = agent - self.config = config - self.confirmation_strategy = confirmation_strategy - - # Lazy-loaded properties - self._messages = None - self._snapshot_messages = None - self._last_message = None - self._run_id: str | None = None - self._thread_id: str | None = None - self._supplied_run_id: str | None = None - self._supplied_thread_id: str | None = None - - @property - def messages(self): - """Get converted Agent Framework messages (lazy loaded).""" - if self._messages is None: - from ._message_adapters import normalize_agui_input_messages - - raw = self.input_data.get("messages", []) - if not isinstance(raw, list): - raw = [] - self._messages, self._snapshot_messages = normalize_agui_input_messages(raw) - return self._messages - - @property - def snapshot_messages(self) -> list[dict[str, Any]]: - """Get normalized AG-UI snapshot messages (lazy loaded).""" - if self._snapshot_messages is None: - if self._messages is None: - _ = self.messages - else: - from ._message_adapters import agent_framework_messages_to_agui, agui_messages_to_snapshot_format - - raw_snapshot = agent_framework_messages_to_agui(self._messages) - self._snapshot_messages = agui_messages_to_snapshot_format(raw_snapshot) - return self._snapshot_messages or [] - - @property - def last_message(self): - """Get the last message in the conversation (lazy loaded).""" - if self._last_message is None and self.messages: - self._last_message = self.messages[-1] - return self._last_message - - @property - def supplied_run_id(self) -> str | None: - """Get the supplied run ID, if any.""" - if self._supplied_run_id is None: - self._supplied_run_id = self.input_data.get("run_id") or self.input_data.get("runId") - return self._supplied_run_id - - @property - def run_id(self) -> str: - """Get supplied run ID or generate a new run ID.""" - if self._run_id: - return self._run_id - - if self.supplied_run_id: - self._run_id = self.supplied_run_id - - if self._run_id is None: - self._run_id = str(uuid.uuid4()) - - return self._run_id - - @property - def supplied_thread_id(self) -> str | None: - """Get the supplied thread ID, if any.""" - if self._supplied_thread_id is None: - self._supplied_thread_id = self.input_data.get("thread_id") or self.input_data.get("threadId") - return self._supplied_thread_id - - @property - def thread_id(self) -> str: - """Get supplied thread ID or generate a new thread ID.""" - if self._thread_id: - return self._thread_id - - if self.supplied_thread_id: - self._thread_id = self.supplied_thread_id - - if self._thread_id is None: - self._thread_id = str(uuid.uuid4()) - - return self._thread_id - - def update_run_id(self, new_run_id: str) -> None: - """Update the run ID in the context. - - Args: - new_run_id: The new run ID to set - """ - self._supplied_run_id = new_run_id - self._run_id = new_run_id - - def update_thread_id(self, new_thread_id: str) -> None: - """Update the thread ID in the context. - - Args: - new_thread_id: The new thread ID to set - """ - self._supplied_thread_id = new_thread_id - self._thread_id = new_thread_id - - -class Orchestrator(ABC): - """Base orchestrator for agent execution flows.""" - - @abstractmethod - def can_handle(self, context: ExecutionContext) -> bool: - """Determine if this orchestrator handles the current request. 
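The orchestrator chain being deleted here dispatches each request to the first orchestrator whose `can_handle` returns True, so list order encodes priority (specialized handlers first, the catch-all `DefaultOrchestrator` last). A minimal sketch of that dispatch, with a hypothetical protocol standing in for the real base class:

from typing import Protocol

class SupportsOrchestration(Protocol):  # illustrative protocol, not part of the package
    def can_handle(self, context: object) -> bool: ...

def pick_orchestrator(orchestrators: list[SupportsOrchestration], context: object) -> SupportsOrchestration:
    """First match wins; callers put the fallback orchestrator last."""
    for orchestrator in orchestrators:
        if orchestrator.can_handle(context):
            return orchestrator
    raise RuntimeError("No orchestrator accepted the request")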
- - Args: - context: Execution context with input data and agent - - Returns: - True if this orchestrator should handle the request - """ - ... - - @abstractmethod - async def run( - self, - context: ExecutionContext, - ) -> AsyncGenerator[BaseEvent, None]: - """Execute the orchestration and yield events. - - Args: - context: Execution context - - Yields: - AG-UI events - """ - # This is never executed - just satisfies mypy's requirement for async generators - if False: # pragma: no cover - yield - raise NotImplementedError - - -class HumanInTheLoopOrchestrator(Orchestrator): - """Handles tool approval responses from user.""" - - def can_handle(self, context: ExecutionContext) -> bool: - """Check if last message is a tool approval response. - - Args: - context: Execution context - - Returns: - True if last message is a tool result - """ - msg = context.last_message - if not msg: - return False - - return bool(msg.additional_properties.get("is_tool_result", False)) - - async def run( - self, - context: ExecutionContext, - ) -> AsyncGenerator[BaseEvent, None]: - """Process approval response and generate confirmation events. - - This implementation is extracted from the legacy _agent.py lines 144-244. - - Args: - context: Execution context - - Yields: - AG-UI events (TextMessage, RunFinished) - """ - from ._confirmation_strategies import DefaultConfirmationStrategy - from ._events import AgentFrameworkEventBridge - - logger.info("=== TOOL RESULT DETECTED (HumanInTheLoopOrchestrator) ===") - - # Create event bridge for run events - event_bridge = AgentFrameworkEventBridge( - run_id=context.run_id, - thread_id=context.thread_id, - ) - - # CRITICAL: Every AG-UI run must start with RunStartedEvent - yield event_bridge.create_run_started_event() - - # Get confirmation strategy (use default if none provided) - strategy = context.confirmation_strategy - if strategy is None: - strategy = DefaultConfirmationStrategy() - - # Parse the tool result content - tool_content_text = "" - last_message = context.last_message - if last_message: - for content in last_message.contents: - if content.type == "text": - tool_content_text = content.text - break - - try: - tool_result = json.loads(tool_content_text) # type: ignore[arg-type] - accepted = tool_result.get("accepted", False) - steps = tool_result.get("steps", []) - - logger.info(f" Accepted: {accepted}") - logger.info(f" Steps count: {len(steps)}") - - # Emit a text message confirming execution - message_id = generate_event_id() - - yield TextMessageStartEvent(message_id=message_id, role="assistant") - - # Check if this is confirm_changes (no steps) or function approval (has steps) - if not steps: - # This is confirm_changes for predictive state updates - if accepted: - confirmation_message = strategy.on_state_confirmed() - else: - confirmation_message = strategy.on_state_rejected() - elif accepted: - # User approved - execute the enabled steps (function approval flow) - confirmation_message = strategy.on_approval_accepted(steps) - else: - # User rejected - confirmation_message = strategy.on_approval_rejected(steps) - - yield TextMessageContentEvent( - message_id=message_id, - delta=confirmation_message, - ) - - yield TextMessageEndEvent(message_id=message_id) - - # Emit run finished - yield event_bridge.create_run_finished_event() - - except json.JSONDecodeError: - logger.error(f"Failed to parse tool result: {tool_content_text}") - yield RunErrorEvent(message=f"Invalid tool result format: {tool_content_text[:100]}") # type: ignore[index] - yield 
event_bridge.create_run_finished_event() - - -class DefaultOrchestrator(Orchestrator): - """Standard agent execution (no special handling).""" - - def can_handle(self, context: ExecutionContext) -> bool: - """Always returns True as this is the fallback orchestrator. - - Args: - context: Execution context - - Returns: - Always True - """ - return True - - def _create_initial_events( - self, event_bridge: "AgentFrameworkEventBridge", state_manager: "StateManager" - ) -> Sequence[BaseEvent]: - """Generate initial events for the run. - - Args: - event_bridge: Event bridge for creating events - Returns: - Initial AG-UI events - """ - events: list[BaseEvent] = [event_bridge.create_run_started_event()] - - predict_event = state_manager.predict_state_event() - if predict_event: - events.append(predict_event) - - snapshot_event = state_manager.initial_snapshot_event(event_bridge) - if snapshot_event: - events.append(snapshot_event) - - return events - - async def run( - self, - context: ExecutionContext, - ) -> AsyncGenerator[BaseEvent, None]: - """Standard agent run with event translation. - - This implements the default agent execution flow using the event bridge - to translate Agent Framework events to AG-UI events. - - Args: - context: Execution context - - Yields: - AG-UI events - """ - from ._events import AgentFrameworkEventBridge - from ._orchestration._state_manager import StateManager - - logger.info(f"Starting default agent run for thread_id={context.thread_id}, run_id={context.run_id}") - - response_format = None - if isinstance(context.agent, ChatAgent): - response_format = context.agent.default_options.get("response_format") - skip_text_content = response_format is not None - - client_tools = convert_agui_tools_to_agent_framework(context.input_data.get("tools")) - approval_tool_name = select_approval_tool_name(client_tools) - - state_manager = StateManager( - state_schema=context.config.state_schema, - predict_state_config=context.config.predict_state_config, - require_confirmation=context.config.require_confirmation, - ) - current_state = state_manager.initialize(context.input_data.get("state")) - - event_bridge = AgentFrameworkEventBridge( - run_id=context.run_id, - thread_id=context.thread_id, - predict_state_config=context.config.predict_state_config, - current_state=current_state, - skip_text_content=skip_text_content, - require_confirmation=context.config.require_confirmation, - approval_tool_name=approval_tool_name, - ) - - if context.config.use_service_thread: - thread = AgentThread(service_thread_id=context.supplied_thread_id) - else: - thread = AgentThread() - - thread.metadata = { # type: ignore[attr-defined] - "ag_ui_thread_id": context.thread_id, - "ag_ui_run_id": context.run_id, - } - if current_state: - thread.metadata["current_state"] = current_state # type: ignore[attr-defined] - - provider_messages = context.messages or [] - snapshot_messages = context.snapshot_messages - if not provider_messages: - for event in self._create_initial_events(event_bridge, state_manager): - yield event - logger.warning("No messages provided in AG-UI input") - yield event_bridge.create_run_finished_event() - return - - logger.info(f"Received {len(provider_messages)} provider messages from client") - for i, msg in enumerate(provider_messages): - role = get_role_value(msg) - msg_id = getattr(msg, "message_id", None) - logger.info(f" Message {i}: role={role}, id={msg_id}") - if hasattr(msg, "contents") and msg.contents: - for j, content in enumerate(msg.contents): - if content.type == 
"text": - logger.debug(" Content %s: %s - text_length=%s", j, content.type, len(content.text)) # type: ignore[arg-type] - elif content.type == "function_call": - arg_length = len(str(content.arguments)) if content.arguments else 0 - logger.debug( - " Content %s: %s - %s args_length=%s", j, content.type, content.name, arg_length - ) - elif content.type == "function_result": - result_preview = type(content.result).__name__ if content.result is not None else "None" - logger.debug( - " Content %s: %s - call_id=%s, result_type=%s", - j, - content.type, - content.call_id, - result_preview, - ) - else: - logger.debug(f" Content {j}: {content.type}") - - pending_tool_calls: list[dict[str, Any]] = [] - tool_calls_by_id: dict[str, dict[str, Any]] = {} - tool_results: list[dict[str, Any]] = [] - tool_calls_ended: set[str] = set() - messages_snapshot_emitted = False - accumulated_text_content = "" - active_message_id: str | None = None - - # Check for FunctionApprovalResponseContent and emit updated state snapshot - # This ensures the UI shows the approved state (e.g., 2 steps) not the original (3 steps) - for snapshot_evt in collect_approved_state_snapshots( - provider_messages, - context.config.predict_state_config, - current_state, - event_bridge, - ): - yield snapshot_evt - - messages_to_run = select_messages_to_run(provider_messages, state_manager) - - logger.info(f"[TOOLS] Client sent {len(client_tools) if client_tools else 0} tools") - if client_tools: - for tool in client_tools: - tool_name = getattr(tool, "name", "unknown") - declaration_only = getattr(tool, "declaration_only", None) - logger.info(f"[TOOLS] - Client tool: {tool_name}, declaration_only={declaration_only}") - - server_tools = collect_server_tools(context.agent) - register_additional_client_tools(context.agent, client_tools) - tools_param = merge_tools(server_tools, client_tools) - - collect_updates = response_format is not None - all_updates: list[Any] | None = [] if collect_updates else None - update_count = 0 - # Prepare metadata for chat client (Azure requires string values) - safe_metadata = build_safe_metadata(getattr(thread, "metadata", None)) - - run_kwargs: dict[str, Any] = { - "thread": thread, - "tools": tools_param, - "options": {"metadata": safe_metadata}, - } - if safe_metadata: - run_kwargs["options"]["store"] = True - - async def _resolve_approval_responses( - messages: list[Any], - tools_for_execution: list[Any], - ) -> None: - fcc_todo = _collect_approval_responses(messages) - if not fcc_todo: - return - - approved_responses = [resp for resp in fcc_todo.values() if resp.approved] - approved_function_results: list[Any] = [] - if approved_responses and tools_for_execution: - chat_client = getattr(context.agent, "chat_client", None) - config = ( - getattr(chat_client, "function_invocation_configuration", None) or FunctionInvocationConfiguration() - ) - middleware_pipeline = extract_and_merge_function_middleware(chat_client, run_kwargs) - try: - results, _ = await _try_execute_function_calls( - custom_args=run_kwargs, - attempt_idx=0, - function_calls=approved_responses, - tools=tools_for_execution, - middleware_pipeline=middleware_pipeline, - config=config, - ) - approved_function_results = list(results) - except Exception: - logger.error("Failed to execute approved tool calls; injecting error results.") - approved_function_results = [] - - normalized_results: list[Content] = [] - for idx, approval in enumerate(approved_responses): - if idx < len(approved_function_results) and approved_function_results[idx].type 
== "function_result": - normalized_results.append(approved_function_results[idx]) - continue - call_id = approval.function_call.call_id or approval.id # type: ignore[union-attr] - normalized_results.append( - Content.from_function_result(call_id=call_id, result="Error: Tool call invocation failed.") # type: ignore[arg-type] - ) - - _replace_approval_contents_with_results(messages, fcc_todo, normalized_results) # type: ignore - - def _should_emit_tool_snapshot(tool_name: str | None) -> bool: - if not pending_tool_calls or not tool_results: - return False - if tool_name and context.config.predict_state_config and not context.config.require_confirmation: - for config in context.config.predict_state_config.values(): - if config["tool"] == tool_name: - logger.info( - f"Skipping intermediate MessagesSnapshotEvent for predictive tool '{tool_name}' " - " - delaying until summary" - ) - return False - return True - - def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnapshotEvent: - has_text_content = bool(accumulated_text_content) - all_messages = snapshot_messages.copy() - - if pending_tool_calls: - if tool_message_id and not has_text_content: - tool_call_message_id = tool_message_id - else: - tool_call_message_id = ( - active_message_id if not has_text_content and active_message_id else generate_event_id() - ) - tool_call_message = { - "id": tool_call_message_id, - "role": "assistant", - "tool_calls": pending_tool_calls.copy(), - } - all_messages.append(tool_call_message) - - all_messages.extend(tool_results) - - if has_text_content and active_message_id: - assistant_text_message = { - "id": active_message_id, - "role": "assistant", - "content": accumulated_text_content, - } - all_messages.append(assistant_text_message) - - return MessagesSnapshotEvent( - messages=all_messages, # type: ignore[arg-type] - ) - - # Use tools_param if available (includes client tools), otherwise fall back to server_tools - # This ensures both server tools AND client tools can be executed after approval - tools_for_approval = tools_param if tools_param is not None else server_tools - latest_approval = latest_approval_response(messages_to_run) - await _resolve_approval_responses(messages_to_run, tools_for_approval) - - if latest_approval and is_step_based_approval(latest_approval, context.config.predict_state_config): - from ._confirmation_strategies import DefaultConfirmationStrategy - - strategy = context.confirmation_strategy - if strategy is None: - strategy = DefaultConfirmationStrategy() - - steps = approval_steps(latest_approval) - if steps: - if latest_approval.approved: - confirmation_message = strategy.on_approval_accepted(steps) - else: - confirmation_message = strategy.on_approval_rejected(steps) - else: - if latest_approval.approved: - confirmation_message = strategy.on_state_confirmed() - else: - confirmation_message = strategy.on_state_rejected() - - message_id = generate_event_id() - for event in self._create_initial_events(event_bridge, state_manager): - yield event - yield TextMessageStartEvent(message_id=message_id, role="assistant") - yield TextMessageContentEvent(message_id=message_id, delta=confirmation_message) - yield TextMessageEndEvent(message_id=message_id) - yield event_bridge.create_run_finished_event() - return - - should_recreate_event_bridge = False - async for update in context.agent.run_stream(messages_to_run, **run_kwargs): - conv_id = get_conversation_id_from_update(update) - if conv_id and conv_id != context.thread_id: - context.update_thread_id(conv_id) - 
should_recreate_event_bridge = True - - if update.response_id and update.response_id != context.run_id: - context.update_run_id(update.response_id) - should_recreate_event_bridge = True - - if should_recreate_event_bridge: - event_bridge = AgentFrameworkEventBridge( - run_id=context.run_id, - thread_id=context.thread_id, - predict_state_config=context.config.predict_state_config, - current_state=current_state, - skip_text_content=skip_text_content, - require_confirmation=context.config.require_confirmation, - approval_tool_name=approval_tool_name, - ) - should_recreate_event_bridge = False - - if update_count == 0: - for event in self._create_initial_events(event_bridge, state_manager): - yield event - - update_count += 1 - logger.info(f"[STREAM] Received update #{update_count} from agent") - if all_updates is not None: - all_updates.append(update) - if event_bridge.current_message_id is None and update.contents: - has_tool_call = any(content.type == "function_call" for content in update.contents) - has_text = any(content.type == "text" for content in update.contents) - if has_tool_call and not has_text: - tool_message_id = generate_event_id() - event_bridge.current_message_id = tool_message_id - active_message_id = tool_message_id - accumulated_text_content = "" - logger.info( - "[STREAM] Emitting TextMessageStartEvent for tool-only response message_id=%s", - tool_message_id, - ) - yield TextMessageStartEvent(message_id=tool_message_id, role="assistant") - events = await event_bridge.from_agent_run_update(update) - logger.info(f"[STREAM] Update #{update_count} produced {len(events)} events") - for event in events: - if isinstance(event, TextMessageStartEvent): - active_message_id = event.message_id - accumulated_text_content = "" - elif isinstance(event, TextMessageContentEvent): - accumulated_text_content += event.delta - elif isinstance(event, ToolCallStartEvent): - tool_call_entry = ensure_tool_call_entry(event.tool_call_id, tool_calls_by_id, pending_tool_calls) - tool_call_entry["function"]["name"] = event.tool_call_name - elif isinstance(event, ToolCallArgsEvent): - tool_call_entry = ensure_tool_call_entry(event.tool_call_id, tool_calls_by_id, pending_tool_calls) - tool_call_entry["function"]["arguments"] += event.delta - elif isinstance(event, ToolCallEndEvent): - tool_calls_ended.add(event.tool_call_id) - elif isinstance(event, ToolCallResultEvent): - tool_results.append( - { - "id": event.message_id, - "role": "tool", - "toolCallId": event.tool_call_id, - "content": event.content, - } - ) - logger.info(f"[STREAM] Yielding event: {type(event).__name__}") - yield event - if isinstance(event, ToolCallResultEvent): - tool_name = tool_name_for_call_id(tool_calls_by_id, event.tool_call_id) - if _should_emit_tool_snapshot(tool_name): - messages_snapshot_emitted = True - messages_snapshot = _build_messages_snapshot() - logger.info(f"[STREAM] Yielding event: {type(messages_snapshot).__name__}") - yield messages_snapshot - elif isinstance(event, ToolCallEndEvent): - tool_name = tool_name_for_call_id(tool_calls_by_id, event.tool_call_id) - if tool_name == "confirm_changes": - messages_snapshot_emitted = True - messages_snapshot = _build_messages_snapshot() - logger.info(f"[STREAM] Yielding event: {type(messages_snapshot).__name__}") - yield messages_snapshot - - logger.info(f"[STREAM] Agent stream completed. 
Total updates: {update_count}") - - if event_bridge.should_stop_after_confirm: - logger.info("Stopping run - waiting for user approval/confirmation response") - if event_bridge.current_message_id: - logger.info(f"[CONFIRM] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") - yield event_bridge.create_message_end_event(event_bridge.current_message_id) - event_bridge.current_message_id = None - yield event_bridge.create_run_finished_event() - return - - if pending_tool_calls: - pending_without_end = [tc for tc in pending_tool_calls if tc.get("id") not in tool_calls_ended] - if pending_without_end: - logger.info( - "Found %s pending tool calls without end event - emitting ToolCallEndEvent", - len(pending_without_end), - ) - for tool_call in pending_without_end: - tool_call_id = tool_call.get("id") - if tool_call_id: - end_event = ToolCallEndEvent(tool_call_id=tool_call_id) - logger.info(f"Emitting ToolCallEndEvent for declaration-only tool call '{tool_call_id}'") - yield end_event - - if response_format and all_updates: - from agent_framework import AgentResponse - from pydantic import BaseModel - - logger.info(f"Processing structured output, update count: {len(all_updates)}") - final_response = AgentResponse.from_agent_run_response_updates( - all_updates, output_format_type=response_format - ) - - if final_response.value and isinstance(final_response.value, BaseModel): - response_dict = final_response.value.model_dump(mode="json", exclude_none=True) - logger.info(f"Received structured output keys: {list(response_dict.keys())}") - - state_updates = state_manager.extract_state_updates(response_dict) - if state_updates: - state_manager.apply_state_updates(state_updates) - state_snapshot = event_bridge.create_state_snapshot_event(current_state) - yield state_snapshot - logger.info(f"Emitted StateSnapshotEvent with updates: {list(state_updates.keys())}") - - if "message" in response_dict and response_dict["message"]: - message_id = generate_event_id() - yield TextMessageStartEvent(message_id=message_id, role="assistant") - yield TextMessageContentEvent(message_id=message_id, delta=response_dict["message"]) - yield TextMessageEndEvent(message_id=message_id) - logger.info(f"Emitted conversational message with length={len(response_dict['message'])}") - - if all_updates is not None and len(all_updates) == 0: - logger.info("No updates received from agent - emitting initial events") - for event in self._create_initial_events(event_bridge, state_manager): - yield event - - logger.info(f"[FINALIZE] Checking for unclosed message. 
current_message_id={event_bridge.current_message_id}") - if event_bridge.current_message_id: - logger.info(f"[FINALIZE] Emitting TextMessageEndEvent for message_id={event_bridge.current_message_id}") - yield event_bridge.create_message_end_event(event_bridge.current_message_id) - - messages_snapshot = _build_messages_snapshot(tool_message_id=event_bridge.current_message_id) - messages_snapshot_emitted = True - logger.info( - f"[FINALIZE] Emitting MessagesSnapshotEvent with {len(messages_snapshot.messages)} messages " - f"(text content length: {len(accumulated_text_content)})" - ) - yield messages_snapshot - else: - logger.info("[FINALIZE] No current_message_id - skipping TextMessageEndEvent") - if not messages_snapshot_emitted and (pending_tool_calls or tool_results): - messages_snapshot = _build_messages_snapshot() - messages_snapshot_emitted = True - logger.info( - f"[FINALIZE] Emitting MessagesSnapshotEvent with {len(messages_snapshot.messages)} messages" - ) - yield messages_snapshot - - logger.info("[FINALIZE] Emitting RUN_FINISHED event") - yield event_bridge.create_run_finished_event() - logger.info(f"Completed agent run for thread_id={context.thread_id}, run_id={context.run_id}") - - -__all__ = [ - "Orchestrator", - "ExecutionContext", - "HumanInTheLoopOrchestrator", - "DefaultOrchestrator", -] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_run.py index b240e2a4af..0bc4f69610 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_run.py @@ -27,10 +27,7 @@ AgentProtocol, AgentThread, ChatMessage, - FunctionApprovalRequestContent, - FunctionCallContent, - FunctionResultContent, - TextContent, + Content, prepare_function_call_results, ) @@ -47,6 +44,106 @@ logger = logging.getLogger(__name__) +def _build_safe_metadata(thread_metadata: dict[str, Any] | None) -> dict[str, Any]: + """Build metadata dict with truncated string values for Azure compatibility. + + Azure has a 512 character limit per metadata value. + + Args: + thread_metadata: Raw metadata dict + + Returns: + Metadata with string values truncated to 512 chars + """ + if not thread_metadata: + return {} + safe_metadata: dict[str, Any] = {} + for key, value in thread_metadata.items(): + value_str = value if isinstance(value, str) else json.dumps(value) + if len(value_str) > 512: + value_str = value_str[:512] + safe_metadata[key] = value_str + return safe_metadata + + +def _has_only_tool_calls(contents: list[Any]) -> bool: + """Check if contents have only tool calls (no text). + + Args: + contents: List of content items + + Returns: + True if there are tool calls but no text content + """ + has_tool_call = any(getattr(c, "type", None) == "function_call" for c in contents) + has_text = any(getattr(c, "type", None) == "text" and getattr(c, "text", None) for c in contents) + return has_tool_call and not has_text + + +def _should_suppress_intermediate_snapshot( + tool_name: str | None, + predict_state_config: dict[str, dict[str, str]] | None, + require_confirmation: bool, +) -> bool: + """Check if intermediate MessagesSnapshotEvent should be suppressed for this tool. + + For predictive tools without confirmation, we delay the snapshot until the end. 
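The new helpers just above are small and pure, so their behavior is easy to pin down with a couple of assertions. A hedged usage sketch, assuming the private helpers are importable from `agent_framework_ag_ui._run` and using simple stand-in objects for content items:

from types import SimpleNamespace

from agent_framework_ag_ui._run import _build_safe_metadata, _has_only_tool_calls  # assumed import path

# String values longer than 512 characters are clipped to satisfy Azure metadata limits.
meta = _build_safe_metadata({"current_state": "x" * 600, "ag_ui_run_id": "run-1"})
assert len(meta["current_state"]) == 512 and meta["ag_ui_run_id"] == "run-1"

# A chunk that carries only function calls (no text) counts as tool-only.
call = SimpleNamespace(type="function_call", text=None)
assert _has_only_tool_calls([call]) is True
assert _has_only_tool_calls([call, SimpleNamespace(type="text", text="hi")]) is False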
+ + Args: + tool_name: Name of the tool that just completed + predict_state_config: Predictive state configuration + require_confirmation: Whether confirmation is required + + Returns: + True if snapshot should be suppressed + """ + if not tool_name or not predict_state_config: + return False + # Only suppress when confirmation is disabled + if require_confirmation: + return False + # Check if this tool is a predictive tool + for config in predict_state_config.values(): + if config["tool"] == tool_name: + logger.info(f"Suppressing intermediate MessagesSnapshotEvent for predictive tool '{tool_name}'") + return True + return False + + +def _extract_approved_state_updates( + messages: list[Any], + predictive_handler: PredictiveStateHandler | None, +) -> dict[str, Any]: + """Extract state updates from function_approval_response content. + + This emits StateSnapshotEvent for approved state-changing tools before running agent. + + Args: + messages: List of messages to scan + predictive_handler: Predictive state handler + + Returns: + Dict of state updates to apply + """ + if not predictive_handler: + return {} + + updates: dict[str, Any] = {} + for msg in messages: + for content in msg.contents: + if getattr(content, "type", None) != "function_approval_response": + continue + if not getattr(content, "approved", False) or not getattr(content, "function_call", None): + continue + parsed_args = content.function_call.parse_arguments() + result = predictive_handler.extract_state_value(content.function_call.name, parsed_args) + if result: + state_key, state_value = result + updates[state_key] = state_value + logger.info(f"Found approved state update for key '{state_key}'") + return updates + + @dataclass class FlowState: """Minimal explicit state for a single AG-UI run.""" @@ -55,11 +152,24 @@ class FlowState: tool_call_id: str | None = None # Current tool call being streamed tool_call_name: str | None = None # Name of current tool call waiting_for_approval: bool = False # Stop after approval request + pending_confirm_id: str | None = None # ID of pending confirm_changes tool call current_state: dict[str, Any] = field(default_factory=dict) # Shared state accumulated_text: str = "" # For MessagesSnapshotEvent pending_tool_calls: list[dict[str, Any]] = field(default_factory=list) # For MessagesSnapshotEvent tool_calls_by_id: dict[str, dict[str, Any]] = field(default_factory=dict) tool_results: list[dict[str, Any]] = field(default_factory=list) + tool_calls_ended: set[str] = field(default_factory=set) # Track which tool calls have been ended + + def get_tool_name(self, call_id: str | None) -> str | None: + """Get tool name by call ID.""" + if not call_id or call_id not in self.tool_calls_by_id: + return None + name = self.tool_calls_by_id[call_id]["function"].get("name") + return str(name) if name else None + + def get_pending_without_end(self) -> list[dict[str, Any]]: + """Get tool calls that started but never received an end event (declaration-only).""" + return [tc for tc in self.pending_tool_calls if tc.get("id") not in self.tool_calls_ended] def _create_state_context_message( @@ -85,7 +195,7 @@ def _create_state_context_message( return ChatMessage( role="system", contents=[ - TextContent( + Content.from_text( text=( "Current state of the application:\n" f"{state_json}\n\n" @@ -140,7 +250,7 @@ def _inject_state_context( return result -def _emit_text(content: TextContent, flow: FlowState, skip_text: bool = False) -> list[BaseEvent]: +def _emit_text(content: Content, flow: FlowState, skip_text: bool = 
False) -> list[BaseEvent]: """Emit TextMessage events for TextContent.""" if not content.text: return [] @@ -160,7 +270,7 @@ def _emit_text(content: TextContent, flow: FlowState, skip_text: bool = False) - def _emit_tool_call( - content: FunctionCallContent, + content: Content, flow: FlowState, predictive_handler: PredictiveStateHandler | None = None, ) -> list[BaseEvent]: @@ -214,7 +324,7 @@ def _emit_tool_call( def _emit_tool_result( - content: FunctionResultContent, + content: Content, flow: FlowState, predictive_handler: PredictiveStateHandler | None = None, ) -> list[BaseEvent]: @@ -223,6 +333,7 @@ def _emit_tool_result( if content.call_id: events.append(ToolCallEndEvent(tool_call_id=content.call_id)) + flow.tool_calls_ended.add(content.call_id) # Track ended tool calls result_content = prepare_function_call_results(content.result) message_id = generate_event_id() @@ -259,26 +370,39 @@ def _emit_tool_result( def _emit_approval_request( - content: FunctionApprovalRequestContent, + content: Content, flow: FlowState, predictive_handler: PredictiveStateHandler | None = None, require_confirmation: bool = True, ) -> list[BaseEvent]: """Emit events for function approval request.""" events: list[BaseEvent] = [] + logger.info(f"[APPROVAL-REQUEST] Starting _emit_approval_request, require_confirmation={require_confirmation}") + + # function_call is required for approval requests - skip if missing + func_call = content.function_call + logger.info(f"[APPROVAL-REQUEST] func_call={func_call}, content.id={content.id}") + if not func_call: + logger.warning("Approval request content missing function_call, skipping") + return events + + func_name = func_call.name or "" + func_call_id = func_call.call_id + logger.info(f"[APPROVAL-REQUEST] func_name={func_name}, func_call_id={func_call_id}") # Extract state from function arguments if predictive - if predictive_handler: - parsed_args = content.function_call.parse_arguments() - result = predictive_handler.extract_state_value(content.function_call.name, parsed_args) + if predictive_handler and func_name: + parsed_args = func_call.parse_arguments() + result = predictive_handler.extract_state_value(func_name, parsed_args) if result: state_key, state_value = result flow.current_state[state_key] = state_value events.append(StateSnapshotEvent(snapshot=flow.current_state)) # End the original tool call - if content.function_call.call_id: - events.append(ToolCallEndEvent(tool_call_id=content.function_call.call_id)) + if func_call_id: + events.append(ToolCallEndEvent(tool_call_id=func_call_id)) + flow.tool_calls_ended.add(func_call_id) # Track ended tool calls # Emit custom event for UI events.append( @@ -287,17 +411,21 @@ def _emit_approval_request( value={ "id": content.id, "function_call": { - "call_id": content.function_call.call_id, - "name": content.function_call.name, - "arguments": content.function_call.parse_arguments(), + "call_id": func_call_id, + "name": func_name, + "arguments": func_call.parse_arguments(), }, }, ) ) # Emit confirm_changes tool call for UI compatibility + # IMPORTANT: Do NOT emit ToolCallEndEvent here - the tool must remain in "executing" + # status for the frontend to show the confirmation dialog. The end event will be + # emitted when the user responds with their confirmation/rejection. 
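The comment above captures the key behavioral change in the approval path: `confirm_changes` gets a Start and an Args event but no End, so the frontend keeps it in "executing" state and shows the confirmation dialog; the End is emitted only on the next run, once the user's response arrives. A condensed sketch of just that event sequence, assuming the `FlowState` defined above and using a plain UUID in place of `generate_event_id`:

import json
import uuid

from ag_ui.core import ToolCallArgsEvent, ToolCallStartEvent

def emit_pending_confirmation(flow, func_name: str, call_id: str | None, args: dict) -> list:
    """Start confirm_changes and stream its args, but do NOT end it yet."""
    confirm_id = str(uuid.uuid4())  # stand-in for generate_event_id()
    flow.pending_confirm_id = confirm_id
    flow.waiting_for_approval = True
    payload = {
        "function_name": func_name,
        "function_call_id": call_id,
        "function_arguments": args,
        "steps": [{"description": f"Execute {func_name}", "status": "enabled"}],
    }
    return [
        ToolCallStartEvent(tool_call_id=confirm_id, tool_call_name="confirm_changes", parent_message_id=flow.message_id),
        ToolCallArgsEvent(tool_call_id=confirm_id, delta=json.dumps(payload)),
        # No ToolCallEndEvent here: the tool stays "executing" until the user's response arrives.
    ]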
if require_confirmation: confirm_id = generate_event_id() + logger.info(f"[APPROVAL-REQUEST] Emitting confirm_changes with id={confirm_id} (no end event - stays executing)") events.append( ToolCallStartEvent( tool_call_id=confirm_id, @@ -306,15 +434,17 @@ def _emit_approval_request( ) ) args = { - "function_name": content.function_call.name, - "function_call_id": content.function_call.call_id, - "function_arguments": content.function_call.parse_arguments() or {}, - "steps": [{"description": f"Execute {content.function_call.name}", "status": "enabled"}], + "function_name": func_name, + "function_call_id": func_call_id, + "function_arguments": func_call.parse_arguments() or {}, + "steps": [{"description": f"Execute {func_name}", "status": "enabled"}], } events.append(ToolCallArgsEvent(tool_call_id=confirm_id, delta=json.dumps(args))) - events.append(ToolCallEndEvent(tool_call_id=confirm_id)) + # Store the confirm_id in flow so we can track it for the response + flow.pending_confirm_id = confirm_id flow.waiting_for_approval = True + logger.info(f"[APPROVAL-REQUEST] Returning {len(events)} events") return events @@ -326,13 +456,16 @@ def _emit_content( require_confirmation: bool = True, ) -> list[BaseEvent]: """Emit appropriate events for any content type.""" - if isinstance(content, TextContent): + content_type = getattr(content, "type", None) + logger.info(f"[EMIT-CONTENT] Processing content type: {content_type}") + if content_type == "text": return _emit_text(content, flow, skip_text) - elif isinstance(content, FunctionCallContent): + elif content_type == "function_call": return _emit_tool_call(content, flow, predictive_handler) - elif isinstance(content, FunctionResultContent): + elif content_type == "function_result": return _emit_tool_result(content, flow, predictive_handler) - elif isinstance(content, FunctionApprovalRequestContent): + elif content_type == "function_approval_request": + logger.info("[EMIT-CONTENT] Got function_approval_request - emitting approval events") return _emit_approval_request(content, flow, predictive_handler, require_confirmation) return [] @@ -352,7 +485,7 @@ def _is_confirm_changes_response(messages: list[Any]) -> bool: # Parse the content to check if it has the confirm_changes structure for content in last.contents: - if isinstance(content, TextContent): + if getattr(content, "type", None) == "text": try: result = json.loads(content.text) # confirm_changes results have 'accepted' and 'steps' keys @@ -371,7 +504,7 @@ def _handle_step_based_approval(messages: list[Any]) -> list[BaseEvent]: # Parse the approval content approval_text = "" for content in last.contents: - if isinstance(content, TextContent): + if getattr(content, "type", None) == "text": approval_text = content.text break @@ -460,19 +593,22 @@ async def _resolve_approval_responses( approved_function_results = [] # Build normalized results for approved responses - normalized_results: list[FunctionResultContent] = [] + normalized_results: list[Content] = [] for idx, approval in enumerate(approved_responses): - if idx < len(approved_function_results) and isinstance(approved_function_results[idx], FunctionResultContent): + if idx < len(approved_function_results) and getattr(approved_function_results[idx], "type", None) == "function_result": normalized_results.append(approved_function_results[idx]) continue - call_id = approval.function_call.call_id or approval.id - normalized_results.append(FunctionResultContent(call_id=call_id, result="Error: Tool call invocation failed.")) + # Get call_id from 
function_call if present, otherwise use approval.id + func_call = approval.function_call + call_id = (func_call.call_id if func_call else None) or approval.id or "" + normalized_results.append(Content.from_function_result(call_id=call_id, result="Error: Tool call invocation failed.")) # Build rejection results for rejection in rejected_responses: - call_id = rejection.function_call.call_id or rejection.id + func_call = rejection.function_call + call_id = (func_call.call_id if func_call else None) or rejection.id or "" normalized_results.append( - FunctionResultContent(call_id=call_id, result="Error: Tool call invocation was rejected by user.") + Content.from_function_result(call_id=call_id, result="Error: Tool call invocation was rejected by user.") ) _replace_approval_contents_with_results(messages, fcc_todo, normalized_results) # type: ignore @@ -590,27 +726,44 @@ async def run_agent_stream( else: thread = AgentThread() - # Inject metadata for AG-UI orchestration - thread.metadata = { # type: ignore[attr-defined] + # Inject metadata for AG-UI orchestration (Feature #2: Azure-safe truncation) + base_metadata: dict[str, Any] = { "ag_ui_thread_id": thread_id, "ag_ui_run_id": run_id, } if flow.current_state: - thread.metadata["current_state"] = flow.current_state # type: ignore[attr-defined] + base_metadata["current_state"] = flow.current_state + thread.metadata = _build_safe_metadata(base_metadata) # type: ignore[attr-defined] - # Build run kwargs + # Build run kwargs (Feature #6: Azure store flag when metadata present) run_kwargs: dict[str, Any] = {"thread": thread} if tools: run_kwargs["tools"] = tools + logger.info(f"[DEBUG] Setting run_kwargs['tools'] with {len(tools)} tools") + for t in tools: + logger.info(f"[DEBUG] - {getattr(t, 'name', 'unknown')}: approval_mode={getattr(t, 'approval_mode', None)}") + safe_metadata = _build_safe_metadata(thread.metadata) # type: ignore[attr-defined] + if safe_metadata: + run_kwargs["options"] = {"metadata": safe_metadata, "store": True} # Resolve approval responses (execute approved tools, replace approvals with results) # This must happen before running the agent so it sees the tool results tools_for_execution = tools if tools is not None else server_tools await _resolve_approval_responses(messages, tools_for_execution, agent, run_kwargs) + # Feature #3: Emit StateSnapshotEvent for approved state-changing tools before agent runs + approved_state_updates = _extract_approved_state_updates(messages, predictive_handler) + approved_state_snapshot_emitted = False + if approved_state_updates: + flow.current_state.update(approved_state_updates) + approved_state_snapshot_emitted = True + # Handle confirm_changes response (state confirmation flow - emit confirmation and stop) if _is_confirm_changes_response(messages): yield RunStartedEvent(run_id=run_id, thread_id=thread_id) + # Emit approved state snapshot before confirmation message + if approved_state_snapshot_emitted: + yield StateSnapshotEvent(snapshot=flow.current_state) for event in _handle_step_based_approval(messages): yield event yield RunFinishedEvent(run_id=run_id, thread_id=thread_id) @@ -654,6 +807,13 @@ async def run_agent_stream( yield StateSnapshotEvent(snapshot=flow.current_state) run_started_emitted = True + # Feature #4: Detect tool-only messages (no text content) + # Emit TextMessageStartEvent to create message context for tool calls + if not flow.message_id and _has_only_tool_calls(update.contents): + flow.message_id = generate_event_id() + logger.info(f"Tool-only response detected, 
creating message_id={flow.message_id}") + yield TextMessageStartEvent(message_id=flow.message_id, role="assistant") + # Emit events for each content item for content in update.contents: for event in _emit_content( @@ -716,15 +876,79 @@ async def run_agent_stream( yield TextMessageEndEvent(message_id=message_id) logger.info(f"Emitted conversational message with length={len(response_dict['message'])}") + # Feature #1: Emit ToolCallEndEvent for declaration-only tools (tools without results) + pending_without_end = flow.get_pending_without_end() + if pending_without_end: + logger.info(f"Found {len(pending_without_end)} pending tool calls without end event") + for tool_call in pending_without_end: + tool_call_id = tool_call.get("id") + tool_name = tool_call.get("function", {}).get("name") + if tool_call_id: + logger.info(f"Emitting ToolCallEndEvent for declaration-only tool '{tool_call_id}'") + yield ToolCallEndEvent(tool_call_id=tool_call_id) + + # For predictive tools with require_confirmation, emit confirm_changes + if config.require_confirmation and config.predict_state_config and tool_name: + is_predictive_tool = any( + cfg["tool"] == tool_name for cfg in config.predict_state_config.values() + ) + if is_predictive_tool: + logger.info(f"Emitting confirm_changes for predictive tool '{tool_name}'") + # Extract state value from tool arguments for StateSnapshot + if predictive_handler: + try: + args_str = tool_call.get("function", {}).get("arguments", "{}") + args = json.loads(args_str) if isinstance(args_str, str) else args_str + result = predictive_handler.extract_state_value(tool_name, args) + if result: + state_key, state_value = result + flow.current_state[state_key] = state_value + yield StateSnapshotEvent(snapshot=flow.current_state) + except json.JSONDecodeError: + pass + + # Emit confirm_changes tool call + confirm_id = generate_event_id() + yield ToolCallStartEvent( + tool_call_id=confirm_id, + tool_call_name="confirm_changes", + parent_message_id=flow.message_id, + ) + confirm_args = { + "function_name": tool_name, + "function_call_id": tool_call_id, + "function_arguments": json.loads(tool_call.get("function", {}).get("arguments", "{}")), + "steps": [{"description": f"Execute {tool_name}", "status": "enabled"}], + } + yield ToolCallArgsEvent(tool_call_id=confirm_id, delta=json.dumps(confirm_args)) + yield ToolCallEndEvent(tool_call_id=confirm_id) + flow.waiting_for_approval = True + # Close any open message if flow.message_id: yield TextMessageEndEvent(message_id=flow.message_id) # Emit MessagesSnapshotEvent if we have tool calls or results - if flow.pending_tool_calls or flow.tool_results or flow.accumulated_text: - yield _build_messages_snapshot(flow, snapshot_messages) - - yield RunFinishedEvent(run_id=run_id, thread_id=thread_id) + # Feature #5: Suppress intermediate snapshots for predictive tools without confirmation + should_emit_snapshot = flow.pending_tool_calls or flow.tool_results or flow.accumulated_text + if should_emit_snapshot: + # Check if we should suppress for predictive tool + last_tool_name = None + if flow.tool_results: + last_result = flow.tool_results[-1] + last_call_id = last_result.get("toolCallId") + last_tool_name = flow.get_tool_name(last_call_id) + if not _should_suppress_intermediate_snapshot( + last_tool_name, config.predict_state_config, config.require_confirmation + ): + yield _build_messages_snapshot(flow, snapshot_messages) + + # Only emit RunFinished if we're not waiting for approval with an active confirm_changes tool + # The AG-UI protocol 
requires all tool calls to be ended before RUN_FINISHED + if not flow.pending_confirm_id: + yield RunFinishedEvent(run_id=run_id, thread_id=thread_id) + else: + logger.info(f"Skipping RunFinishedEvent - waiting for approval on confirm_changes id={flow.pending_confirm_id}") __all__ = ["FlowState", "run_agent_stream"] diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/document_writer_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/document_writer_agent.py index bddc51846b..34ade05032 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/document_writer_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/document_writer_agent.py @@ -3,11 +3,11 @@ """Example agent demonstrating predictive state updates with document writing.""" from agent_framework import ChatAgent, ChatClientProtocol, ai_function -from agent_framework.ag_ui import AgentFrameworkAgent, DocumentWriterConfirmationStrategy +from agent_framework.ag_ui import AgentFrameworkAgent -@ai_function -def write_document_local(document: str) -> str: +@ai_function(approval_mode="always_require") +def write_document(document: str) -> str: """Write a document. Use markdown formatting to format the document. It's good to format the document extensively so it's easy to read. @@ -28,7 +28,7 @@ def write_document_local(document: str) -> str: _DOCUMENT_WRITER_INSTRUCTIONS = ( "You are a helpful assistant for writing documents. " - "To write the document, you MUST use the write_document_local tool. " + "To write the document, you MUST use the write_document tool. " "You MUST write the full document, even when changing only a few words. " "When you wrote the document, DO NOT repeat it as a message. " "Just briefly summarize the changes you made. 2 sentences max. 
" @@ -51,7 +51,7 @@ def document_writer_agent(chat_client: ChatClientProtocol) -> AgentFrameworkAgen name="document_writer", instructions=_DOCUMENT_WRITER_INSTRUCTIONS, chat_client=chat_client, - tools=[write_document_local], + tools=[write_document], ) return AgentFrameworkAgent( @@ -62,7 +62,6 @@ def document_writer_agent(chat_client: ChatClientProtocol) -> AgentFrameworkAgen "document": {"type": "string", "description": "The current document content"}, }, predict_state_config={ - "document": {"tool": "write_document_local", "tool_argument": "document"}, + "document": {"tool": "write_document", "tool_argument": "document"}, }, - confirmation_strategy=DocumentWriterConfirmationStrategy(), ) diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/recipe_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/recipe_agent.py index 05c42efb30..2d38e612aa 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/recipe_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/recipe_agent.py @@ -6,7 +6,7 @@ from typing import Any from agent_framework import ChatAgent, ChatClientProtocol, ai_function -from agent_framework.ag_ui import AgentFrameworkAgent, RecipeConfirmationStrategy +from agent_framework.ag_ui import AgentFrameworkAgent from pydantic import BaseModel, Field @@ -128,6 +128,5 @@ def recipe_agent(chat_client: ChatClientProtocol[Any]) -> AgentFrameworkAgent: predict_state_config={ "recipe": {"tool": "update_recipe", "tool_argument": "recipe"}, }, - confirmation_strategy=RecipeConfirmationStrategy(), require_confirmation=False, ) diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_planner_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_planner_agent.py index c79c36f511..442c9e6182 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_planner_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/task_planner_agent.py @@ -5,7 +5,7 @@ from typing import Any from agent_framework import ChatAgent, ChatClientProtocol, ai_function -from agent_framework.ag_ui import AgentFrameworkAgent, TaskPlannerConfirmationStrategy +from agent_framework.ag_ui import AgentFrameworkAgent @ai_function(approval_mode="always_require") @@ -81,5 +81,4 @@ def task_planner_agent(chat_client: ChatClientProtocol[Any]) -> AgentFrameworkAg agent=agent, name="TaskPlanner", description="Plans and executes tasks with user approval", - confirmation_strategy=TaskPlannerConfirmationStrategy(), ) diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py index e71abe7507..8d9c212a5e 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py @@ -4,9 +4,10 @@ import logging import os -from typing import TYPE_CHECKING import uvicorn +from agent_framework import ChatOptions +from agent_framework._clients import BaseChatClient from agent_framework.ag_ui import add_agent_framework_fastapi_endpoint from agent_framework.azure import AzureOpenAIChatClient from fastapi import FastAPI @@ -20,10 +21,6 @@ from ..agents.ui_generator_agent import ui_generator_agent from ..agents.weather_agent import weather_agent -if TYPE_CHECKING: - from agent_framework import ChatOptions - from agent_framework._clients import BaseChatClient - # Configure logging to file and console (disabled by default - set 
ENABLE_DEBUG_LOGGING=1 to enable) if os.getenv("ENABLE_DEBUG_LOGGING"): log_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "ag_ui_server.log") diff --git a/python/packages/ag-ui/tests/test_backend_tool_rendering.py b/python/packages/ag-ui/tests/test_backend_tool_rendering.py deleted file mode 100644 index 594d127532..0000000000 --- a/python/packages/ag-ui/tests/test_backend_tool_rendering.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Tests for backend tool rendering.""" - -from typing import cast - -from ag_ui.core import ( - TextMessageContentEvent, - TextMessageStartEvent, - ToolCallArgsEvent, - ToolCallEndEvent, - ToolCallResultEvent, - ToolCallStartEvent, -) -from agent_framework import AgentResponseUpdate, Content - -from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - -async def test_tool_call_flow(): - """Test complete tool call flow: call -> args -> end -> result.""" - bridge = AgentFrameworkEventBridge(run_id="test-run", thread_id="test-thread") - - # Step 1: Tool call starts - tool_call = Content.from_function_call( - call_id="weather-123", - name="get_weather", - arguments={"location": "Seattle"}, - ) - - update1 = AgentResponseUpdate(contents=[tool_call]) - events1 = await bridge.from_agent_run_update(update1) - - # Should have: ToolCallStartEvent, ToolCallArgsEvent - assert len(events1) == 2 - assert isinstance(events1[0], ToolCallStartEvent) - assert isinstance(events1[1], ToolCallArgsEvent) - - start_event = events1[0] - assert start_event.tool_call_id == "weather-123" - assert start_event.tool_call_name == "get_weather" - - args_event = events1[1] - assert "Seattle" in args_event.delta - - # Step 2: Tool result comes back - tool_result = Content.from_function_result( - call_id="weather-123", - result="Weather in Seattle: Rainy, 52°F", - ) - - update2 = AgentResponseUpdate(contents=[tool_result]) - events2 = await bridge.from_agent_run_update(update2) - - # Should have: ToolCallEndEvent, ToolCallResultEvent - assert len(events2) == 2 - assert isinstance(events2[0], ToolCallEndEvent) - assert isinstance(events2[1], ToolCallResultEvent) - - end_event = events2[0] - assert end_event.tool_call_id == "weather-123" - - result_event = events2[1] - assert result_event.tool_call_id == "weather-123" - assert "Seattle" in result_event.content - assert "Rainy" in result_event.content - - -async def test_text_with_tool_call(): - """Test agent response with both text and tool calls.""" - bridge = AgentFrameworkEventBridge(run_id="test-run", thread_id="test-thread") - - # Agent says something then calls a tool - text_content = Content.from_text(text="Let me check the weather for you.") - tool_call = Content.from_function_call( - call_id="weather-456", - name="get_forecast", - arguments={"location": "San Francisco", "days": 3}, - ) - - update = AgentResponseUpdate(contents=[text_content, tool_call]) - events = await bridge.from_agent_run_update(update) - - # Should have: TextMessageStart, TextMessageContent, ToolCallStart, ToolCallArgs - assert len(events) == 4 - - assert isinstance(events[0], TextMessageStartEvent) - assert isinstance(events[1], TextMessageContentEvent) - assert isinstance(events[2], ToolCallStartEvent) - assert isinstance(events[3], ToolCallArgsEvent) - - text_event = events[1] - assert "check the weather" in text_event.delta - - tool_start = events[2] - assert tool_start.tool_call_name == "get_forecast" - - -async def test_multiple_tool_results(): - """Test handling multiple tool 
results in sequence.""" - bridge = AgentFrameworkEventBridge(run_id="test-run", thread_id="test-thread") - - # Multiple tool results - results = [ - Content.from_function_result(call_id="tool-1", result="Result 1"), - Content.from_function_result(call_id="tool-2", result="Result 2"), - Content.from_function_result(call_id="tool-3", result="Result 3"), - ] - - update = AgentResponseUpdate(contents=results) - events = await bridge.from_agent_run_update(update) - - # Should have 3 pairs of ToolCallEndEvent + ToolCallResultEvent = 6 events - assert len(events) == 6 - - # Verify the pattern: End, Result, End, Result, End, Result - for i in range(3): - end_idx = i * 2 - result_idx = i * 2 + 1 - - assert isinstance(events[end_idx], ToolCallEndEvent) - assert isinstance(events[result_idx], ToolCallResultEvent) - - end_event = cast(ToolCallEndEvent, events[end_idx]) - result_event = cast(ToolCallResultEvent, events[result_idx]) - - assert end_event.tool_call_id == f"tool-{i + 1}" - assert result_event.tool_call_id == f"tool-{i + 1}" - assert f"Result {i + 1}" in result_event.content diff --git a/python/packages/ag-ui/tests/test_confirmation_strategies_comprehensive.py b/python/packages/ag-ui/tests/test_confirmation_strategies_comprehensive.py deleted file mode 100644 index ab355d8995..0000000000 --- a/python/packages/ag-ui/tests/test_confirmation_strategies_comprehensive.py +++ /dev/null @@ -1,275 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Comprehensive tests for all confirmation strategies.""" - -import pytest - -from agent_framework_ag_ui._confirmation_strategies import ( - ConfirmationStrategy, - DefaultConfirmationStrategy, - DocumentWriterConfirmationStrategy, - RecipeConfirmationStrategy, - TaskPlannerConfirmationStrategy, -) - - -@pytest.fixture -def sample_steps() -> list[dict[str, str]]: - """Sample steps for testing approval messages.""" - return [ - {"description": "Step 1: Do something", "status": "enabled"}, - {"description": "Step 2: Do another thing", "status": "enabled"}, - {"description": "Step 3: Disabled step", "status": "disabled"}, - ] - - -@pytest.fixture -def all_enabled_steps() -> list[dict[str, str]]: - """All steps enabled.""" - return [ - {"description": "Task A", "status": "enabled"}, - {"description": "Task B", "status": "enabled"}, - {"description": "Task C", "status": "enabled"}, - ] - - -@pytest.fixture -def empty_steps() -> list[dict[str, str]]: - """Empty steps list.""" - return [] - - -class TestDefaultConfirmationStrategy: - """Tests for DefaultConfirmationStrategy.""" - - def test_on_approval_accepted_with_enabled_steps(self, sample_steps: list[dict[str, str]]) -> None: - strategy = DefaultConfirmationStrategy() - message = strategy.on_approval_accepted(sample_steps) - - assert "Executing 2 approved steps" in message - assert "Step 1: Do something" in message - assert "Step 2: Do another thing" in message - assert "Step 3" not in message # Disabled step shouldn't appear - assert "All steps completed successfully!" 
in message - - def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps: list[dict[str, str]]) -> None: - strategy = DefaultConfirmationStrategy() - message = strategy.on_approval_accepted(all_enabled_steps) - - assert "Executing 3 approved steps" in message - assert "Task A" in message - assert "Task B" in message - assert "Task C" in message - - def test_on_approval_accepted_with_empty_steps(self, empty_steps: list[dict[str, str]]) -> None: - strategy = DefaultConfirmationStrategy() - message = strategy.on_approval_accepted(empty_steps) - - assert "Executing 0 approved steps" in message - assert "All steps completed successfully!" in message - - def test_on_approval_rejected(self, sample_steps: list[dict[str, str]]) -> None: - strategy = DefaultConfirmationStrategy() - message = strategy.on_approval_rejected(sample_steps) - - assert "No problem!" in message - assert "What would you like me to change" in message - - def test_on_state_confirmed(self) -> None: - strategy = DefaultConfirmationStrategy() - message = strategy.on_state_confirmed() - - assert "Changes confirmed" in message - assert "successfully" in message - - def test_on_state_rejected(self) -> None: - strategy = DefaultConfirmationStrategy() - message = strategy.on_state_rejected() - - assert "No problem!" in message - assert "What would you like me to change" in message - - -class TestTaskPlannerConfirmationStrategy: - """Tests for TaskPlannerConfirmationStrategy.""" - - def test_on_approval_accepted_with_enabled_steps(self, sample_steps: list[dict[str, str]]) -> None: - strategy = TaskPlannerConfirmationStrategy() - message = strategy.on_approval_accepted(sample_steps) - - assert "Executing your requested tasks" in message - assert "1. Step 1: Do something" in message - assert "2. Step 2: Do another thing" in message - assert "Step 3" not in message - assert "All tasks completed successfully!" in message - - def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps: list[dict[str, str]]) -> None: - strategy = TaskPlannerConfirmationStrategy() - message = strategy.on_approval_accepted(all_enabled_steps) - - assert "Executing your requested tasks" in message - assert "1. Task A" in message - assert "2. Task B" in message - assert "3. Task C" in message - - def test_on_approval_accepted_with_empty_steps(self, empty_steps: list[dict[str, str]]) -> None: - strategy = TaskPlannerConfirmationStrategy() - message = strategy.on_approval_accepted(empty_steps) - - assert "Executing your requested tasks" in message - assert "All tasks completed successfully!" in message - - def test_on_approval_rejected(self, sample_steps: list[dict[str, str]]) -> None: - strategy = TaskPlannerConfirmationStrategy() - message = strategy.on_approval_rejected(sample_steps) - - assert "No problem!" in message - assert "revise the plan" in message - - def test_on_state_confirmed(self) -> None: - strategy = TaskPlannerConfirmationStrategy() - message = strategy.on_state_confirmed() - - assert "Tasks confirmed" in message - assert "ready to execute" in message - - def test_on_state_rejected(self) -> None: - strategy = TaskPlannerConfirmationStrategy() - message = strategy.on_state_rejected() - - assert "No problem!" 
in message - assert "adjust the task list" in message - - -class TestRecipeConfirmationStrategy: - """Tests for RecipeConfirmationStrategy.""" - - def test_on_approval_accepted_with_enabled_steps(self, sample_steps: list[dict[str, str]]) -> None: - strategy = RecipeConfirmationStrategy() - message = strategy.on_approval_accepted(sample_steps) - - assert "Updating your recipe" in message - assert "1. Step 1: Do something" in message - assert "2. Step 2: Do another thing" in message - assert "Step 3" not in message - assert "Recipe updated successfully!" in message - - def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps: list[dict[str, str]]) -> None: - strategy = RecipeConfirmationStrategy() - message = strategy.on_approval_accepted(all_enabled_steps) - - assert "Updating your recipe" in message - assert "1. Task A" in message - assert "2. Task B" in message - assert "3. Task C" in message - - def test_on_approval_accepted_with_empty_steps(self, empty_steps: list[dict[str, str]]) -> None: - strategy = RecipeConfirmationStrategy() - message = strategy.on_approval_accepted(empty_steps) - - assert "Updating your recipe" in message - assert "Recipe updated successfully!" in message - - def test_on_approval_rejected(self, sample_steps: list[dict[str, str]]) -> None: - strategy = RecipeConfirmationStrategy() - message = strategy.on_approval_rejected(sample_steps) - - assert "No problem!" in message - assert "ingredients or steps" in message - - def test_on_state_confirmed(self) -> None: - strategy = RecipeConfirmationStrategy() - message = strategy.on_state_confirmed() - - assert "Recipe changes applied" in message - assert "successfully" in message - - def test_on_state_rejected(self) -> None: - strategy = RecipeConfirmationStrategy() - message = strategy.on_state_rejected() - - assert "No problem!" in message - assert "adjust in the recipe" in message - - -class TestDocumentWriterConfirmationStrategy: - """Tests for DocumentWriterConfirmationStrategy.""" - - def test_on_approval_accepted_with_enabled_steps(self, sample_steps: list[dict[str, str]]) -> None: - strategy = DocumentWriterConfirmationStrategy() - message = strategy.on_approval_accepted(sample_steps) - - assert "Applying your edits" in message - assert "1. Step 1: Do something" in message - assert "2. Step 2: Do another thing" in message - assert "Step 3" not in message - assert "Document updated successfully!" in message - - def test_on_approval_accepted_with_all_enabled(self, all_enabled_steps: list[dict[str, str]]) -> None: - strategy = DocumentWriterConfirmationStrategy() - message = strategy.on_approval_accepted(all_enabled_steps) - - assert "Applying your edits" in message - assert "1. Task A" in message - assert "2. Task B" in message - assert "3. Task C" in message - - def test_on_approval_accepted_with_empty_steps(self, empty_steps: list[dict[str, str]]) -> None: - strategy = DocumentWriterConfirmationStrategy() - message = strategy.on_approval_accepted(empty_steps) - - assert "Applying your edits" in message - assert "Document updated successfully!" in message - - def test_on_approval_rejected(self, sample_steps: list[dict[str, str]]) -> None: - strategy = DocumentWriterConfirmationStrategy() - message = strategy.on_approval_rejected(sample_steps) - - assert "No problem!" in message - assert "keep or modify" in message - - def test_on_state_confirmed(self) -> None: - strategy = DocumentWriterConfirmationStrategy() - message = strategy.on_state_confirmed() - - assert "Document edits applied!" 
in message - - def test_on_state_rejected(self) -> None: - strategy = DocumentWriterConfirmationStrategy() - message = strategy.on_state_rejected() - - assert "No problem!" in message - assert "change about the document" in message - - -class TestConfirmationStrategyInterface: - """Tests for ConfirmationStrategy abstract base class.""" - - def test_cannot_instantiate_abstract_class(self): - """Verify ConfirmationStrategy is abstract and cannot be instantiated.""" - with pytest.raises(TypeError): - ConfirmationStrategy() # type: ignore - - def test_all_strategies_implement_interface(self): - """Verify all concrete strategies implement the full interface.""" - strategies = [ - DefaultConfirmationStrategy(), - TaskPlannerConfirmationStrategy(), - RecipeConfirmationStrategy(), - DocumentWriterConfirmationStrategy(), - ] - - sample_steps = [{"description": "Test", "status": "enabled"}] - - for strategy in strategies: - # All should have these methods - assert callable(strategy.on_approval_accepted) - assert callable(strategy.on_approval_rejected) - assert callable(strategy.on_state_confirmed) - assert callable(strategy.on_state_rejected) - - # All should return strings - assert isinstance(strategy.on_approval_accepted(sample_steps), str) - assert isinstance(strategy.on_approval_rejected(sample_steps), str) - assert isinstance(strategy.on_state_confirmed(), str) - assert isinstance(strategy.on_state_rejected(), str) diff --git a/python/packages/ag-ui/tests/test_document_writer_flow.py b/python/packages/ag-ui/tests/test_document_writer_flow.py deleted file mode 100644 index 7e154682b4..0000000000 --- a/python/packages/ag-ui/tests/test_document_writer_flow.py +++ /dev/null @@ -1,236 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Tests for document writer predictive state flow with confirm_changes.""" - -from ag_ui.core import EventType, StateDeltaEvent, ToolCallArgsEvent, ToolCallEndEvent, ToolCallStartEvent -from agent_framework import AgentResponseUpdate, Content - -from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - -async def test_streaming_document_with_state_deltas(): - """Test that streaming tool arguments emit progressive StateDeltaEvents.""" - predict_config = { - "document": {"tool": "write_document_local", "tool_argument": "document"}, - } - - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - predict_state_config=predict_config, - ) - - # Simulate streaming tool call - first chunk with name - tool_call_start = Content.from_function_call( - call_id="call_123", - name="write_document_local", - arguments='{"document":"Once', - ) - update1 = AgentResponseUpdate(contents=[tool_call_start]) - events1 = await bridge.from_agent_run_update(update1) - - # Should have ToolCallStartEvent and ToolCallArgsEvent - assert any(e.type == EventType.TOOL_CALL_START for e in events1) - assert any(e.type == EventType.TOOL_CALL_ARGS for e in events1) - - # Second chunk - incomplete JSON, should try partial extraction - tool_call_chunk2 = Content.from_function_call( - call_id="call_123", name="write_document_local", arguments=" upon a time" - ) - update2 = AgentResponseUpdate(contents=[tool_call_chunk2]) - events2 = await bridge.from_agent_run_update(update2) - - # Should emit StateDeltaEvent with partial document - state_deltas = [e for e in events2 if isinstance(e, StateDeltaEvent)] - assert len(state_deltas) >= 1 - - # Check JSON Patch format - delta = state_deltas[0] - assert isinstance(delta.delta, list) - assert len(delta.delta) > 0 
- assert delta.delta[0]["op"] == "replace" - assert delta.delta[0]["path"] == "/document" - assert "Once upon a time" in delta.delta[0]["value"] - - -async def test_confirm_changes_emission(): - """Test that confirm_changes tool call is emitted after predictive tool completion.""" - predict_config = { - "document": {"tool": "write_document_local", "tool_argument": "document"}, - } - - current_state: dict[str, str] = {} - - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - predict_state_config=predict_config, - current_state=current_state, - ) - - # Set current tool name (simulating earlier tool call start) - bridge.current_tool_call_name = "write_document_local" - bridge.pending_state_updates = {"document": "A short story"} - - # Tool result - tool_result = Content.from_function_result( - call_id="call_123", - result="Document written.", - ) - - update = AgentResponseUpdate(contents=[tool_result]) - events = await bridge.from_agent_run_update(update) - - # Should have: ToolCallEndEvent, ToolCallResultEvent, StateSnapshotEvent, confirm_changes sequence - assert any(e.type == EventType.TOOL_CALL_END for e in events) - assert any(e.type == EventType.TOOL_CALL_RESULT for e in events) - assert any(e.type == EventType.STATE_SNAPSHOT for e in events) - - # Check for confirm_changes tool call - confirm_starts = [e for e in events if isinstance(e, ToolCallStartEvent) and e.tool_call_name == "confirm_changes"] - assert len(confirm_starts) == 1 - - confirm_args = [e for e in events if isinstance(e, ToolCallArgsEvent) and e.delta == "{}"] - assert len(confirm_args) >= 1 - - confirm_ends = [e for e in events if isinstance(e, ToolCallEndEvent)] - # At least 2: one for write_document_local, one for confirm_changes - assert len(confirm_ends) >= 2 - - # Check that stop flag is set - assert bridge.should_stop_after_confirm is True - - -async def test_text_suppression_before_confirm(): - """Test that text messages are suppressed when confirm_changes is pending.""" - predict_config = { - "document": {"tool": "write_document_local", "tool_argument": "document"}, - } - - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - predict_state_config=predict_config, - ) - - # Set flag indicating we're waiting for confirmation - bridge.should_stop_after_confirm = True - - # Text content that should be suppressed - text = Content.from_text(text="I have written a story about pirates.") - update = AgentResponseUpdate(contents=[text]) - - events = await bridge.from_agent_run_update(update) - - # Should NOT emit TextMessageContentEvent - text_events = [e for e in events if e.type == EventType.TEXT_MESSAGE_CONTENT] - assert len(text_events) == 0 - - # But should save the text - assert bridge.suppressed_summary == "I have written a story about pirates." 
- - -async def test_no_confirm_for_non_predictive_tools(): - """Test that confirm_changes is NOT emitted for regular tool calls.""" - predict_config = { - "document": {"tool": "write_document_local", "tool_argument": "document"}, - } - - current_state: dict[str, str] = {} - - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - predict_state_config=predict_config, - current_state=current_state, - ) - - # Different tool (not in predict_state_config) - bridge.current_tool_call_name = "get_weather" - - tool_result = Content.from_function_result( - call_id="call_456", - result="Sunny, 72°F", - ) - - update = AgentResponseUpdate(contents=[tool_result]) - events = await bridge.from_agent_run_update(update) - - # Should NOT have confirm_changes - confirm_starts = [e for e in events if isinstance(e, ToolCallStartEvent) and e.tool_call_name == "confirm_changes"] - assert len(confirm_starts) == 0 - - # Stop flag should NOT be set - assert bridge.should_stop_after_confirm is False - - -async def test_state_delta_deduplication(): - """Test that duplicate state values don't emit multiple StateDeltaEvents.""" - predict_config = { - "document": {"tool": "write_document_local", "tool_argument": "document"}, - } - - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - predict_state_config=predict_config, - ) - - # First tool call with document - tool_call1 = Content.from_function_call( - call_id="call_1", - name="write_document_local", - arguments='{"document":"Same text"}', - ) - update1 = AgentResponseUpdate(contents=[tool_call1]) - events1 = await bridge.from_agent_run_update(update1) - - # Count state deltas - state_deltas_1 = [e for e in events1 if isinstance(e, StateDeltaEvent)] - assert len(state_deltas_1) >= 1 - - # Second tool call with SAME document (shouldn't emit new delta) - bridge.current_tool_call_name = "write_document_local" - tool_call2 = Content.from_function_call( - call_id="call_2", - name="write_document_local", - arguments='{"document":"Same text"}', # Identical content - ) - update2 = AgentResponseUpdate(contents=[tool_call2]) - events2 = await bridge.from_agent_run_update(update2) - - # Should NOT emit state delta (same value) - state_deltas_2 = [e for e in events2 if e.type == EventType.STATE_DELTA] - assert len(state_deltas_2) == 0 - - -async def test_predict_state_config_multiple_fields(): - """Test predictive state with multiple state fields.""" - predict_config = { - "title": {"tool": "create_post", "tool_argument": "title"}, - "content": {"tool": "create_post", "tool_argument": "body"}, - } - - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - predict_state_config=predict_config, - ) - - # Tool call with both fields - tool_call = Content.from_function_call( - call_id="call_999", - name="create_post", - arguments='{"title":"My Post","body":"Post content"}', - ) - update = AgentResponseUpdate(contents=[tool_call]) - events = await bridge.from_agent_run_update(update) - - # Should emit StateDeltaEvent for both fields - state_deltas = [e for e in events if isinstance(e, StateDeltaEvent)] - assert len(state_deltas) >= 2 - - # Check both fields are present - paths = [delta.delta[0]["path"] for delta in state_deltas] - assert "/title" in paths - assert "/content" in paths diff --git a/python/packages/ag-ui/tests/test_events_comprehensive.py b/python/packages/ag-ui/tests/test_events_comprehensive.py deleted file mode 100644 index 75e923123f..0000000000 --- 
a/python/packages/ag-ui/tests/test_events_comprehensive.py +++ /dev/null @@ -1,827 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Comprehensive tests for AgentFrameworkEventBridge (_events.py).""" - -import json - -from agent_framework import ( - AgentResponseUpdate, - Content, -) - - -async def test_basic_text_message_conversion(): - """Test basic TextContent to AG-UI events.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - update = AgentResponseUpdate(contents=[Content.from_text(text="Hello")]) - events = await bridge.from_agent_run_update(update) - - assert len(events) == 2 - assert events[0].type == "TEXT_MESSAGE_START" - assert events[0].role == "assistant" - assert events[1].type == "TEXT_MESSAGE_CONTENT" - assert events[1].delta == "Hello" - - -async def test_text_message_streaming(): - """Test streaming TextContent with multiple chunks.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - update1 = AgentResponseUpdate(contents=[Content.from_text(text="Hello ")]) - update2 = AgentResponseUpdate(contents=[Content.from_text(text="world")]) - - events1 = await bridge.from_agent_run_update(update1) - events2 = await bridge.from_agent_run_update(update2) - - # First update: START + CONTENT - assert len(events1) == 2 - assert events1[0].type == "TEXT_MESSAGE_START" - assert events1[1].delta == "Hello " - - # Second update: just CONTENT (same message) - assert len(events2) == 1 - assert events2[0].type == "TEXT_MESSAGE_CONTENT" - assert events2[0].delta == "world" - - # Both content events should have same message_id - assert events1[1].message_id == events2[0].message_id - - -async def test_skip_text_content_for_structured_outputs(): - """Test that text content is skipped when skip_text_content=True.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread", skip_text_content=True) - - update = AgentResponseUpdate(contents=[Content.from_text(text='{"result": "data"}')]) - events = await bridge.from_agent_run_update(update) - - # No events should be emitted - assert len(events) == 0 - - -async def test_skip_text_content_for_empty_text(): - """Test streaming TextContent with empty chunks.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - update1 = AgentResponseUpdate(contents=[Content.from_text(text="Hello ")]) - update2 = AgentResponseUpdate(contents=[Content.from_text(text="")]) # Empty chunk - update3 = AgentResponseUpdate(contents=[Content.from_text(text="world")]) - - events1 = await bridge.from_agent_run_update(update1) - events2 = await bridge.from_agent_run_update(update2) - events3 = await bridge.from_agent_run_update(update3) - - # First update: START + CONTENT - assert len(events1) == 2 - assert events1[0].type == "TEXT_MESSAGE_START" - assert events1[1].delta == "Hello " - - # Second update: should skip empty chunk, no events - assert len(events2) == 0 - - # Third update: just CONTENT (same message) - assert len(events3) == 1 - assert events3[0].type == "TEXT_MESSAGE_CONTENT" - assert events3[0].delta == "world" - - # Both content events should have same message_id - assert events1[1].message_id == events3[0].message_id - - -async def 
test_tool_call_with_name(): - """Test FunctionCallContent with name emits ToolCallStartEvent.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - update = AgentResponseUpdate(contents=[Content.from_function_call(name="search_web", call_id="call_123")]) - events = await bridge.from_agent_run_update(update) - - assert len(events) == 1 - assert events[0].type == "TOOL_CALL_START" - assert events[0].tool_call_name == "search_web" - assert events[0].tool_call_id == "call_123" - - -async def test_tool_call_streaming_args(): - """Test streaming tool call arguments.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - # First chunk: name only - update1 = AgentResponseUpdate(contents=[Content.from_function_call(name="search_web", call_id="call_123")]) - events1 = await bridge.from_agent_run_update(update1) - - # Second chunk: arguments chunk 1 (name can be empty string for continuation) - update2 = AgentResponseUpdate( - contents=[Content.from_function_call(name="", call_id="call_123", arguments='{"query": "')] - ) - events2 = await bridge.from_agent_run_update(update2) - - # Third chunk: arguments chunk 2 - update3 = AgentResponseUpdate(contents=[Content.from_function_call(name="", call_id="call_123", arguments='AI"}')]) - events3 = await bridge.from_agent_run_update(update3) - - # First update: ToolCallStartEvent - assert len(events1) == 1 - assert events1[0].type == "TOOL_CALL_START" - - # Second update: ToolCallArgsEvent - assert len(events2) == 1 - assert events2[0].type == "TOOL_CALL_ARGS" - assert events2[0].delta == '{"query": "' - - # Third update: ToolCallArgsEvent - assert len(events3) == 1 - assert events3[0].type == "TOOL_CALL_ARGS" - assert events3[0].delta == 'AI"}' - - # All should have same tool_call_id - assert events1[0].tool_call_id == events2[0].tool_call_id == events3[0].tool_call_id - - -async def test_streaming_tool_call_no_duplicate_start_events(): - """Test that streaming tool calls emit exactly one ToolCallStartEvent. - - This is a regression test for the Anthropic streaming fix where input_json_delta - events were incorrectly passing the tool name, causing duplicate ToolCallStartEvents. 
- - The correct behavior is: - - Initial FunctionCallContent with name -> emits ToolCallStartEvent - - Subsequent FunctionCallContent with name="" -> emits only ToolCallArgsEvent - - See: https://github.com/microsoft/agent-framework/pull/3051 - """ - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - # Simulate streaming tool call: first chunk has name, subsequent chunks have name="" - update1 = AgentResponseUpdate(contents=[Content.from_function_call(name="get_weather", call_id="call_789")]) - update2 = AgentResponseUpdate( - contents=[Content.from_function_call(name="", call_id="call_789", arguments='{"loc":')] - ) - update3 = AgentResponseUpdate(contents=[Content.from_function_call(name="", call_id="call_789", arguments='"SF"}')]) - - events1 = await bridge.from_agent_run_update(update1) - events2 = await bridge.from_agent_run_update(update2) - events3 = await bridge.from_agent_run_update(update3) - - # Count all ToolCallStartEvents - should be exactly 1 - all_events = events1 + events2 + events3 - tool_call_start_count = sum(1 for e in all_events if e.type == "TOOL_CALL_START") - assert tool_call_start_count == 1, f"Expected 1 ToolCallStartEvent, got {tool_call_start_count}" - - # Verify event types - assert events1[0].type == "TOOL_CALL_START" - assert events2[0].type == "TOOL_CALL_ARGS" - assert events3[0].type == "TOOL_CALL_ARGS" - - -async def test_tool_result_with_dict(): - """Test FunctionResultContent with dict result.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - result_data = {"status": "success", "count": 42} - update = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_123", result=result_data)]) - events = await bridge.from_agent_run_update(update) - - # Should emit ToolCallEndEvent + ToolCallResultEvent - assert len(events) == 2 - assert events[0].type == "TOOL_CALL_END" - assert events[0].tool_call_id == "call_123" - - assert events[1].type == "TOOL_CALL_RESULT" - assert events[1].tool_call_id == "call_123" - assert events[1].role == "tool" - # Result should be JSON-serialized - assert json.loads(events[1].content) == result_data - - -async def test_tool_result_with_string(): - """Test FunctionResultContent with string result.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - update = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_123", result="Search complete")]) - events = await bridge.from_agent_run_update(update) - - assert len(events) == 2 - assert events[0].type == "TOOL_CALL_END" - assert events[1].type == "TOOL_CALL_RESULT" - assert events[1].content == "Search complete" - - -async def test_tool_result_with_none(): - """Test FunctionResultContent with None result.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - update = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_123", result=None)]) - events = await bridge.from_agent_run_update(update) - - assert len(events) == 2 - assert events[0].type == "TOOL_CALL_END" - assert events[1].type == "TOOL_CALL_RESULT" - # prepare_function_call_results serializes None as JSON "null" - assert events[1].content == "null" - - 
-async def test_multiple_tool_results_in_sequence(): - """Test multiple tool results processed sequentially.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - update = AgentResponseUpdate( - contents=[ - Content.from_function_result(call_id="call_1", result="Result 1"), - Content.from_function_result(call_id="call_2", result="Result 2"), - ] - ) - events = await bridge.from_agent_run_update(update) - - # Each result emits: ToolCallEndEvent + ToolCallResultEvent = 4 events total - assert len(events) == 4 - assert events[0].tool_call_id == "call_1" - assert events[1].tool_call_id == "call_1" - assert events[2].tool_call_id == "call_2" - assert events[3].tool_call_id == "call_2" - - -async def test_function_approval_request_basic(): - """Test FunctionApprovalRequestContent conversion.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - # Set require_confirmation=False to test just the function_approval_request event - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - require_confirmation=False, - ) - - func_call = Content.from_function_call( - call_id="call_123", - name="send_email", - arguments={"to": "user@example.com", "subject": "Test"}, - ) - approval = Content.from_function_approval_request( - id="approval_001", - function_call=func_call, - ) - - update = AgentResponseUpdate(contents=[approval]) - events = await bridge.from_agent_run_update(update) - - # Should emit: ToolCallEndEvent + CustomEvent - assert len(events) == 2 - - # First: ToolCallEndEvent to close the tool call - assert events[0].type == "TOOL_CALL_END" - assert events[0].tool_call_id == "call_123" - - # Second: CustomEvent with approval details - assert events[1].type == "CUSTOM" - assert events[1].name == "function_approval_request" - assert events[1].value["id"] == "approval_001" - assert events[1].value["function_call"]["name"] == "send_email" - - -async def test_empty_predict_state_config(): - """Test behavior with no predictive state configuration.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - predict_state_config={}, # Empty config - ) - - # Tool call with arguments - update = AgentResponseUpdate( - contents=[ - Content.from_function_call(name="write_doc", call_id="call_1", arguments='{"content": "test"}'), - Content.from_function_result(call_id="call_1", result="Done"), - ] - ) - events = await bridge.from_agent_run_update(update) - - # Should NOT emit StateDeltaEvent or confirm_changes - event_types = [e.type for e in events] - assert "STATE_DELTA" not in event_types - assert "STATE_SNAPSHOT" not in event_types - - # Should have: ToolCallStart, ToolCallArgs, ToolCallEnd, ToolCallResult - assert event_types == [ - "TOOL_CALL_START", - "TOOL_CALL_ARGS", - "TOOL_CALL_END", - "TOOL_CALL_RESULT", - ] - - -async def test_tool_not_in_predict_state_config(): - """Test tool that doesn't match any predict_state_config entry.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - predict_state_config={ - "document": {"tool": "write_document", "tool_argument": "content"}, - }, - ) - - # Different tool name - update = AgentResponseUpdate( - contents=[ - Content.from_function_call(name="search_web", call_id="call_1", arguments='{"query": 
"AI"}'), - Content.from_function_result(call_id="call_1", result="Results"), - ] - ) - events = await bridge.from_agent_run_update(update) - - # Should NOT emit StateDeltaEvent or confirm_changes - event_types = [e.type for e in events] - assert "STATE_DELTA" not in event_types - assert "STATE_SNAPSHOT" not in event_types - - -async def test_state_management_tracking(): - """Test current_state and pending_state_updates tracking.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - initial_state = {"document": ""} - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - predict_state_config={ - "document": {"tool": "write_doc", "tool_argument": "content"}, - }, - current_state=initial_state, - ) - - # Streaming tool call - update1 = AgentResponseUpdate( - contents=[ - Content.from_function_call(name="write_doc", call_id="call_1"), - Content.from_function_call(name="", call_id="call_1", arguments='{"content": "Hello"}'), - ] - ) - await bridge.from_agent_run_update(update1) - - # Check pending_state_updates was populated - assert "document" in bridge.pending_state_updates - assert bridge.pending_state_updates["document"] == "Hello" - - # Tool result should update current_state - update2 = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_1", result="Done")]) - await bridge.from_agent_run_update(update2) - - # current_state should be updated - assert bridge.current_state["document"] == "Hello" - - # pending_state_updates should be cleared - assert len(bridge.pending_state_updates) == 0 - - -async def test_wildcard_tool_argument(): - """Test tool_argument='*' uses all arguments as state value.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - predict_state_config={ - "recipe": {"tool": "create_recipe", "tool_argument": "*"}, - }, - current_state={}, - ) - - # Complete tool call with dict arguments - update = AgentResponseUpdate( - contents=[ - Content.from_function_call( - name="create_recipe", - call_id="call_1", - arguments={"title": "Pasta", "ingredients": ["pasta", "sauce"]}, - ), - Content.from_function_result(call_id="call_1", result="Created"), - ] - ) - events = await bridge.from_agent_run_update(update) - - # Find StateDeltaEvent - delta_events = [e for e in events if e.type == "STATE_DELTA"] - assert len(delta_events) > 0 - - # Value should be the entire arguments dict - delta = delta_events[0].delta[0] - assert delta["path"] == "/recipe" - assert delta["value"] == {"title": "Pasta", "ingredients": ["pasta", "sauce"]} - - -async def test_run_lifecycle_events(): - """Test RunStartedEvent and RunFinishedEvent creation.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - started = bridge.create_run_started_event() - assert started.type == "RUN_STARTED" - assert started.run_id == "test_run" - assert started.thread_id == "test_thread" - - finished = bridge.create_run_finished_event(result={"status": "complete"}) - assert finished.type == "RUN_FINISHED" - assert finished.run_id == "test_run" - assert finished.thread_id == "test_thread" - assert finished.result == {"status": "complete"} - - -async def test_message_lifecycle_events(): - """Test TextMessageStartEvent and TextMessageEndEvent creation.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = 
AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - start = bridge.create_message_start_event("msg_123", role="assistant") - assert start.type == "TEXT_MESSAGE_START" - assert start.message_id == "msg_123" - assert start.role == "assistant" - - end = bridge.create_message_end_event("msg_123") - assert end.type == "TEXT_MESSAGE_END" - assert end.message_id == "msg_123" - - -async def test_state_event_creation(): - """Test StateSnapshotEvent and StateDeltaEvent creation helpers.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - # StateSnapshotEvent - snapshot = bridge.create_state_snapshot_event({"document": "content"}) - assert snapshot.type == "STATE_SNAPSHOT" - assert snapshot.snapshot == {"document": "content"} - - # StateDeltaEvent with JSON Patch - delta = bridge.create_state_delta_event([{"op": "replace", "path": "/document", "value": "new content"}]) - assert delta.type == "STATE_DELTA" - assert len(delta.delta) == 1 - assert delta.delta[0]["op"] == "replace" - assert delta.delta[0]["path"] == "/document" - assert delta.delta[0]["value"] == "new content" - - -async def test_state_snapshot_after_tool_result(): - """Test StateSnapshotEvent emission after tool result with pending updates.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - predict_state_config={ - "document": {"tool": "write_doc", "tool_argument": "content"}, - }, - current_state={"document": ""}, - ) - - # Tool call with streaming args - update1 = AgentResponseUpdate( - contents=[ - Content.from_function_call(name="write_doc", call_id="call_1"), - Content.from_function_call(name="", call_id="call_1", arguments='{"content": "Test"}'), - ] - ) - await bridge.from_agent_run_update(update1) - - # Tool result should trigger StateSnapshotEvent - update2 = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_1", result="Done")]) - events = await bridge.from_agent_run_update(update2) - - # Should have: ToolCallEnd, ToolCallResult, StateSnapshot, ToolCallStart (confirm_changes), ToolCallArgs, ToolCallEnd - snapshot_events = [e for e in events if e.type == "STATE_SNAPSHOT"] - assert len(snapshot_events) == 1 - assert snapshot_events[0].snapshot["document"] == "Test" - - -async def test_message_id_persistence_across_chunks(): - """Test that message_id persists across multiple text chunks.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - # First chunk - update1 = AgentResponseUpdate(contents=[Content.from_text(text="Hello ")]) - events1 = await bridge.from_agent_run_update(update1) - message_id = events1[0].message_id - - # Second chunk - update2 = AgentResponseUpdate(contents=[Content.from_text(text="world")]) - events2 = await bridge.from_agent_run_update(update2) - - # Should use same message_id - assert events2[0].message_id == message_id - assert bridge.current_message_id == message_id - - -async def test_tool_call_id_tracking(): - """Test tool_call_id tracking across streaming chunks.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - # First chunk with name - update1 = AgentResponseUpdate(contents=[Content.from_function_call(name="search", 
call_id="call_1")]) - await bridge.from_agent_run_update(update1) - - assert bridge.current_tool_call_id == "call_1" - assert bridge.current_tool_call_name == "search" - - # Second chunk with args but no name - update2 = AgentResponseUpdate( - contents=[Content.from_function_call(name="", call_id="call_1", arguments='{"q":"AI"}')] - ) - events2 = await bridge.from_agent_run_update(update2) - - # Should still track same tool call - assert bridge.current_tool_call_id == "call_1" - assert events2[0].tool_call_id == "call_1" - - -async def test_tool_name_reset_after_result(): - """Test current_tool_call_name is reset after tool result.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - predict_state_config={ - "document": {"tool": "write_doc", "tool_argument": "content"}, - }, - ) - - # Tool call - update1 = AgentResponseUpdate( - contents=[ - Content.from_function_call(name="write_doc", call_id="call_1"), - Content.from_function_call(name="", call_id="call_1", arguments='{"content": "Test"}'), - ] - ) - await bridge.from_agent_run_update(update1) - - assert bridge.current_tool_call_name == "write_doc" - - # Tool result with predictive state (should trigger confirm_changes and reset) - update2 = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_1", result="Done")]) - await bridge.from_agent_run_update(update2) - - # Tool name should be reset - assert bridge.current_tool_call_name is None - - -async def test_function_approval_with_wildcard_argument(): - """Test function approval with wildcard * argument.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - predict_state_config={ - "payload": {"tool": "submit", "tool_argument": "*"}, - }, - ) - - approval_content = Content.from_function_approval_request( - id="approval_1", - function_call=Content.from_function_call( - name="submit", call_id="call_1", arguments='{"key1": "value1", "key2": "value2"}' - ), - ) - - update = AgentResponseUpdate(contents=[approval_content]) - events = await bridge.from_agent_run_update(update) - - # Should emit StateSnapshotEvent with entire parsed args as value - snapshot_events = [e for e in events if e.type == "STATE_SNAPSHOT"] - assert len(snapshot_events) == 1 - assert snapshot_events[0].snapshot["payload"] == {"key1": "value1", "key2": "value2"} - - -async def test_function_approval_missing_argument(): - """Test function approval when specified argument is not in parsed args.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - predict_state_config={ - "data": {"tool": "process", "tool_argument": "missing_field"}, - }, - ) - - approval_content = Content.from_function_approval_request( - id="approval_1", - function_call=Content.from_function_call( - name="process", call_id="call_1", arguments='{"other_field": "value"}' - ), - ) - - update = AgentResponseUpdate(contents=[approval_content]) - events = await bridge.from_agent_run_update(update) - - # Should not emit StateSnapshotEvent since argument not found - snapshot_events = [e for e in events if e.type == "STATE_SNAPSHOT"] - assert len(snapshot_events) == 0 - - -async def test_empty_predict_state_config_no_deltas(): - """Test with empty predict_state_config (no predictive updates).""" - from 
agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread", predict_state_config={}) - - # Tool call with arguments - update = AgentResponseUpdate( - contents=[ - Content.from_function_call(name="search", call_id="call_1"), - Content.from_function_call(name="", call_id="call_1", arguments='{"query": "test"}'), - ] - ) - events = await bridge.from_agent_run_update(update) - - # Should not emit any StateDeltaEvents - delta_events = [e for e in events if e.type == "STATE_DELTA"] - assert len(delta_events) == 0 - - -async def test_tool_with_no_matching_config(): - """Test tool call for tool not in predict_state_config.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - predict_state_config={"document": {"tool": "write_doc", "tool_argument": "content"}}, - ) - - # Tool call for different tool - update = AgentResponseUpdate( - contents=[ - Content.from_function_call(name="search_web", call_id="call_1"), - Content.from_function_call(name="", call_id="call_1", arguments='{"query": "test"}'), - ] - ) - events = await bridge.from_agent_run_update(update) - - # Should not emit StateDeltaEvents - delta_events = [e for e in events if e.type == "STATE_DELTA"] - assert len(delta_events) == 0 - - -async def test_tool_call_without_name_or_id(): - """Test handling FunctionCallContent with no name and no call_id.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - # This should not crash but log an error - update = AgentResponseUpdate(contents=[Content.from_function_call(name="", call_id="", arguments='{"arg": "val"}')]) - events = await bridge.from_agent_run_update(update) - - # Should emit ToolCallArgsEvent with generated ID - assert len(events) >= 1 - - -async def test_state_delta_count_logging(): - """Test that state delta count increments and logs at intervals.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}}, - ) - - # Emit multiple state deltas with different content each time - for i in range(15): - update = AgentResponseUpdate( - contents=[ - Content.from_function_call(name="", call_id="call_1", arguments=f'{{"text": "Content variation {i}"}}'), - ] - ) - # Set the tool name to match config - bridge.current_tool_call_name = "write" - await bridge.from_agent_run_update(update) - - # State delta count should have incremented (one per unique state update) - assert bridge.state_delta_count >= 1 - - -# Tests for list type tool results (MCP tool serialization) - - -async def test_tool_result_with_empty_list(): - """Test FunctionResultContent with empty list result.""" - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - update = AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_123", result=[])]) - events = await bridge.from_agent_run_update(update) - - assert len(events) == 2 - assert events[0].type == "TOOL_CALL_END" - assert events[1].type == "TOOL_CALL_RESULT" - # Empty list serializes as JSON empty array - assert events[1].content == "[]" - - -async def 
test_tool_result_with_single_text_content(): - """Test FunctionResultContent with single TextContent-like item (MCP tool result).""" - from dataclasses import dataclass - - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - @dataclass - class MockTextContent: - text: str - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - update = AgentResponseUpdate( - contents=[Content.from_function_result(call_id="call_123", result=[MockTextContent("Hello from MCP tool!")])] - ) - events = await bridge.from_agent_run_update(update) - - assert len(events) == 2 - assert events[0].type == "TOOL_CALL_END" - assert events[1].type == "TOOL_CALL_RESULT" - # TextContent text is extracted and serialized as JSON array - assert events[1].content == '["Hello from MCP tool!"]' - - -async def test_tool_result_with_multiple_text_contents(): - """Test FunctionResultContent with multiple TextContent-like items (MCP tool result).""" - from dataclasses import dataclass - - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - @dataclass - class MockTextContent: - text: str - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - update = AgentResponseUpdate( - contents=[ - Content.from_function_result( - call_id="call_123", - result=[MockTextContent("First result"), MockTextContent("Second result")], - ) - ] - ) - events = await bridge.from_agent_run_update(update) - - assert len(events) == 2 - assert events[0].type == "TOOL_CALL_END" - assert events[1].type == "TOOL_CALL_RESULT" - # Multiple TextContent items should return JSON array - assert events[1].content == '["First result", "Second result"]' - - -async def test_tool_result_with_model_dump_objects(): - """Test FunctionResultContent with Pydantic BaseModel objects.""" - from pydantic import BaseModel - - from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - class MockModel(BaseModel): - value: int - - bridge = AgentFrameworkEventBridge(run_id="test_run", thread_id="test_thread") - - update = AgentResponseUpdate( - contents=[Content.from_function_result(call_id="call_123", result=[MockModel(value=1), MockModel(value=2)])] - ) - events = await bridge.from_agent_run_update(update) - - assert len(events) == 2 - assert events[1].type == "TOOL_CALL_RESULT" - # Should be properly serialized JSON array without double escaping - assert events[1].content == '[{"value": 1}, {"value": 2}]' diff --git a/python/packages/ag-ui/tests/test_human_in_the_loop.py b/python/packages/ag-ui/tests/test_human_in_the_loop.py deleted file mode 100644 index b643465e36..0000000000 --- a/python/packages/ag-ui/tests/test_human_in_the_loop.py +++ /dev/null @@ -1,180 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. 
- -"""Tests for human in the loop (function approval requests).""" - -from agent_framework import AgentResponseUpdate, Content - -from agent_framework_ag_ui._events import AgentFrameworkEventBridge - - -async def test_function_approval_request_emission(): - """Test that CustomEvent is emitted for FunctionApprovalRequestContent.""" - # Set require_confirmation=False to test just the function_approval_request event - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - require_confirmation=False, - ) - - # Create approval request - func_call = Content.from_function_call( - call_id="call_123", - name="send_email", - arguments={"to": "user@example.com", "subject": "Test"}, - ) - approval_request = Content.from_function_approval_request( - id="approval_001", - function_call=func_call, - ) - - update = AgentResponseUpdate(contents=[approval_request]) - events = await bridge.from_agent_run_update(update) - - # Should emit ToolCallEndEvent + CustomEvent for approval request - assert len(events) == 2 - - # First event: ToolCallEndEvent to close the tool call - assert events[0].type == "TOOL_CALL_END" - assert events[0].tool_call_id == "call_123" - - # Second event: CustomEvent with approval details - event = events[1] - assert event.type == "CUSTOM" - assert event.name == "function_approval_request" - assert event.value["id"] == "approval_001" - assert event.value["function_call"]["call_id"] == "call_123" - assert event.value["function_call"]["name"] == "send_email" - assert event.value["function_call"]["arguments"]["to"] == "user@example.com" - assert event.value["function_call"]["arguments"]["subject"] == "Test" - - -async def test_function_approval_request_with_confirm_changes(): - """Test that confirm_changes is also emitted when require_confirmation=True.""" - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - require_confirmation=True, - ) - - func_call = Content.from_function_call( - call_id="call_456", - name="delete_file", - arguments={"path": "/tmp/test.txt"}, - ) - approval_request = Content.from_function_approval_request( - id="approval_002", - function_call=func_call, - ) - - update = AgentResponseUpdate(contents=[approval_request]) - events = await bridge.from_agent_run_update(update) - - # Should emit: ToolCallEndEvent, CustomEvent, and confirm_changes (Start, Args, End) = 5 events - assert len(events) == 5 - - # Check ToolCallEndEvent - assert events[0].type == "TOOL_CALL_END" - assert events[0].tool_call_id == "call_456" - - # Check function_approval_request CustomEvent - assert events[1].type == "CUSTOM" - assert events[1].name == "function_approval_request" - - # Check confirm_changes tool call events - assert events[2].type == "TOOL_CALL_START" - assert events[2].tool_call_name == "confirm_changes" - assert events[3].type == "TOOL_CALL_ARGS" - # Verify confirm_changes includes function info for Dojo UI - import json - - args = json.loads(events[3].delta) - assert args["function_name"] == "delete_file" - assert args["function_call_id"] == "call_456" - assert args["function_arguments"] == {"path": "/tmp/test.txt"} - assert args["steps"] == [ - { - "description": "Execute delete_file", - "status": "enabled", - } - ] - assert events[4].type == "TOOL_CALL_END" - - -async def test_multiple_approval_requests(): - """Test handling multiple approval requests in one update.""" - # Set require_confirmation=False to simplify the test - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - 
require_confirmation=False, - ) - - func_call_1 = Content.from_function_call( - call_id="call_1", - name="create_event", - arguments={"title": "Meeting"}, - ) - approval_1 = Content.from_function_approval_request( - id="approval_1", - function_call=func_call_1, - ) - - func_call_2 = Content.from_function_call( - call_id="call_2", - name="book_room", - arguments={"room": "Conference A"}, - ) - approval_2 = Content.from_function_approval_request( - id="approval_2", - function_call=func_call_2, - ) - - update = AgentResponseUpdate(contents=[approval_1, approval_2]) - events = await bridge.from_agent_run_update(update) - - # Should emit ToolCallEndEvent + CustomEvent for each approval (4 events total) - assert len(events) == 4 - - # Events should alternate: End, Custom, End, Custom - assert events[0].type == "TOOL_CALL_END" - assert events[0].tool_call_id == "call_1" - - assert events[1].type == "CUSTOM" - assert events[1].name == "function_approval_request" - assert events[1].value["id"] == "approval_1" - - assert events[2].type == "TOOL_CALL_END" - assert events[2].tool_call_id == "call_2" - - assert events[3].type == "CUSTOM" - assert events[3].name == "function_approval_request" - assert events[3].value["id"] == "approval_2" - - -async def test_function_approval_request_sets_stop_flag(): - """Test that function approval request sets should_stop_after_confirm flag. - - This ensures the orchestrator stops the run after emitting the approval request, - allowing the UI to send back an approval response. - """ - bridge = AgentFrameworkEventBridge( - run_id="test_run", - thread_id="test_thread", - ) - - assert bridge.should_stop_after_confirm is False - - func_call = Content.from_function_call( - call_id="call_stop_test", - name="get_datetime", - arguments={}, - ) - approval_request = Content.from_function_approval_request( - id="approval_stop_test", - function_call=func_call, - ) - - update = AgentResponseUpdate(contents=[approval_request]) - await bridge.from_agent_run_update(update) - - assert bridge.should_stop_after_confirm is True diff --git a/python/packages/ag-ui/tests/test_orchestrators.py b/python/packages/ag-ui/tests/test_orchestrators.py deleted file mode 100644 index c951246bfa..0000000000 --- a/python/packages/ag-ui/tests/test_orchestrators.py +++ /dev/null @@ -1,307 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Tests for AG-UI orchestrators.""" - -from collections.abc import AsyncGenerator -from typing import Any -from unittest.mock import MagicMock - -from ag_ui.core import BaseEvent, RunFinishedEvent -from agent_framework import ( - AgentResponseUpdate, - AgentThread, - BaseChatClient, - ChatAgent, - ChatResponseUpdate, - Content, - FunctionInvocationConfiguration, - ai_function, -) - -from agent_framework_ag_ui._agent import AgentConfig -from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, ExecutionContext - - -@ai_function -def server_tool() -> str: - """Server-executable tool.""" - return "server" - - -def _create_mock_chat_agent( - tools: list[Any] | None = None, - response_format: Any = None, - capture_tools: list[Any] | None = None, - capture_messages: list[Any] | None = None, -) -> ChatAgent: - """Create a ChatAgent with mocked chat client for testing. - - Args: - tools: Tools to configure on the agent. - response_format: Response format to configure. - capture_tools: If provided, tools passed to run_stream will be appended here. - capture_messages: If provided, messages passed to run_stream will be appended here. 
- """ - mock_chat_client = MagicMock(spec=BaseChatClient) - mock_chat_client.function_invocation_configuration = FunctionInvocationConfiguration() - - agent = ChatAgent( - chat_client=mock_chat_client, - tools=tools or [server_tool], - response_format=response_format, - ) - - # Create a mock run_stream that captures parameters and yields a simple response - async def mock_run_stream( - messages: list[Any], - *, - # thread: AgentThread, - # tools: list[Any] | None = None, - # **kwargs: Any, - # ) -> AsyncGenerator[AgentRunResponseUpdate, None]: - # self.seen_tools = tools - # yield AgentRunResponseUpdate( - # contents=[TextContent(text="ok")], - # role="assistant", - # response_id=thread.metadata.get("ag_ui_run_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) - # raw_representation=ChatResponseUpdate( - # contents=[TextContent(text="ok")], - # conversation_id=thread.metadata.get("ag_ui_thread_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) - # response_id=thread.metadata.get("ag_ui_run_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) - # ), - # ) - thread: AgentThread, - tools: list[Any] | None = None, - **kwargs: Any, - ) -> AsyncGenerator[AgentResponseUpdate, None]: - if capture_tools is not None and tools is not None: - capture_tools.extend(tools) - if capture_messages is not None: - capture_messages.extend(messages) - yield AgentResponseUpdate( - contents=[Content.from_text(text="ok")], - role="assistant", - response_id=thread.metadata.get("ag_ui_run_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) - raw_representation=ChatResponseUpdate( - contents=[Content.from_text(text="ok")], - conversation_id=thread.metadata.get("ag_ui_thread_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) - response_id=thread.metadata.get("ag_ui_run_id"), # type: ignore[attr-defined] (metadata always created in orchestrator) - ), - ) - - # Patch the run_stream method - agent.run_stream = mock_run_stream # type: ignore[method-assign] - - return agent - - -async def test_default_orchestrator_merges_client_tools() -> None: - """Client tool declarations are merged with server tools before running agent.""" - captured_tools: list[Any] = [] - agent = _create_mock_chat_agent(tools=[server_tool], capture_tools=captured_tools) - orchestrator = DefaultOrchestrator() - - input_data = { - "messages": [ - { - "role": "user", - "content": [{"type": "input_text", "text": "Hello"}], - } - ], - "tools": [ - { - "name": "get_weather", - "description": "Client weather lookup.", - "parameters": { - "type": "object", - "properties": {"location": {"type": "string"}}, - "required": ["location"], - }, - } - ], - } - - context = ExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(), - ) - - events = [] - async for event in orchestrator.run(context): - events.append(event) - - assert len(captured_tools) > 0 - tool_names = [getattr(tool, "name", "?") for tool in captured_tools] - assert "server_tool" in tool_names - assert "get_weather" in tool_names - assert agent.chat_client.function_invocation_configuration.additional_tools - - -async def test_default_orchestrator_with_camel_case_ids() -> None: - """Client tool is able to extract camelCase IDs.""" - agent = _create_mock_chat_agent() - orchestrator = DefaultOrchestrator() - - input_data = { - "runId": "test-camelcase-runid", - "threadId": "test-camelcase-threadid", - "messages": [ - { - "role": "user", - "content": 
[{"type": "input_text", "text": "Hello"}], - } - ], - "tools": [], - } - - context = ExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(), - ) - - events = [] - async for event in orchestrator.run(context): - events.append(event) - - # assert the last event has the expected run_id and thread_id - assert isinstance(events[-1], RunFinishedEvent) - last_event = events[-1] - assert last_event.run_id == "test-camelcase-runid" - assert last_event.thread_id == "test-camelcase-threadid" - - -async def test_default_orchestrator_with_snake_case_ids() -> None: - """Client tool is able to extract snake_case IDs.""" - agent = _create_mock_chat_agent() - orchestrator = DefaultOrchestrator() - - input_data = { - "run_id": "test-snakecase-runid", - "thread_id": "test-snakecase-threadid", - "messages": [ - { - "role": "user", - "content": [{"type": "input_text", "text": "Hello"}], - } - ], - "tools": [], - } - - context = ExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(), - ) - - events: list[BaseEvent] = [] - async for event in orchestrator.run(context): - events.append(event) - - # assert the last event has the expected run_id and thread_id - assert isinstance(events[-1], RunFinishedEvent) - last_event = events[-1] - assert last_event.run_id == "test-snakecase-runid" - assert last_event.thread_id == "test-snakecase-threadid" - - -async def test_state_context_injected_when_tool_call_state_mismatch() -> None: - """State context should be injected when current state differs from tool call args.""" - captured_messages: list[Any] = [] - agent = _create_mock_chat_agent(tools=[], capture_messages=captured_messages) - orchestrator = DefaultOrchestrator() - - tool_recipe = {"title": "Salad", "special_preferences": []} - current_recipe = {"title": "Salad", "special_preferences": ["Vegetarian"]} - - input_data = { - "state": {"recipe": current_recipe}, - "messages": [ - {"role": "system", "content": "Instructions"}, - { - "role": "assistant", - "tool_calls": [ - { - "id": "call_1", - "type": "function", - "function": {"name": "update_recipe", "arguments": {"recipe": tool_recipe}}, - } - ], - }, - {"role": "user", "content": "What are the dietary preferences?"}, - ], - } - - context = ExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig( - state_schema={"recipe": {"type": "object"}}, - predict_state_config={"recipe": {"tool": "update_recipe", "tool_argument": "recipe"}}, - require_confirmation=False, - ), - ) - - async for _event in orchestrator.run(context): - pass - - assert len(captured_messages) > 0 - state_messages = [] - for msg in captured_messages: - role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - if role_value != "system": - continue - for content in msg.contents or []: - if content.type == "text" and content.text.startswith("Current state of the application:"): - state_messages.append(content.text) - assert state_messages - assert "Vegetarian" in state_messages[0] - - -async def test_state_context_not_injected_when_tool_call_matches_state() -> None: - """State context should be skipped when tool call args match current state.""" - captured_messages: list[Any] = [] - agent = _create_mock_chat_agent(tools=[], capture_messages=captured_messages) - orchestrator = DefaultOrchestrator() - - input_data = { - "messages": [ - {"role": "system", "content": "Instructions"}, - { - "role": "assistant", - "tool_calls": [ - { - "id": "call_1", - "type": "function", - "function": {"name": "update_recipe", 
"arguments": {"recipe": {}}}, - } - ], - }, - {"role": "user", "content": "What are the dietary preferences?"}, - ], - } - - context = ExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig( - state_schema={"recipe": {"type": "object"}}, - predict_state_config={"recipe": {"tool": "update_recipe", "tool_argument": "recipe"}}, - require_confirmation=False, - ), - ) - - async for _event in orchestrator.run(context): - pass - - assert len(captured_messages) > 0 - state_messages = [] - for msg in captured_messages: - role_value = msg.role.value if hasattr(msg.role, "value") else str(msg.role) - if role_value != "system": - continue - for content in msg.contents or []: - if content.type == "text" and content.text.startswith("Current state of the application:"): - state_messages.append(content.text) - assert not state_messages diff --git a/python/packages/ag-ui/tests/test_orchestrators_coverage.py b/python/packages/ag-ui/tests/test_orchestrators_coverage.py deleted file mode 100644 index d579c691b7..0000000000 --- a/python/packages/ag-ui/tests/test_orchestrators_coverage.py +++ /dev/null @@ -1,872 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Comprehensive tests for orchestrator coverage.""" - -import sys -from collections.abc import AsyncGenerator -from pathlib import Path -from types import SimpleNamespace -from typing import Any - -from agent_framework import AgentResponseUpdate, ChatMessage, Content, ai_function -from pydantic import BaseModel - -from agent_framework_ag_ui._agent import AgentConfig -from agent_framework_ag_ui._orchestrators import DefaultOrchestrator, HumanInTheLoopOrchestrator - -sys.path.insert(0, str(Path(__file__).parent)) -from utils_test_ag_ui import StubAgent, TestExecutionContext - - -@ai_function(approval_mode="always_require") -def approval_tool(param: str) -> str: - """Tool requiring approval.""" - return f"executed: {param}" - - -DEFAULT_OPTIONS: dict[str, Any] = {"tools": [approval_tool], "response_format": None} - - -async def test_human_in_the_loop_json_decode_error() -> None: - """Test HumanInTheLoopOrchestrator handles invalid JSON in tool result.""" - orchestrator = HumanInTheLoopOrchestrator() - - input_data: dict[str, Any] = { - "messages": [ - { - "role": "tool", - "content": [{"type": "text", "text": "not valid json {"}], - } - ], - } - - messages = [ - ChatMessage( - role="tool", - contents=[Content.from_text(text="not valid json {")], - additional_properties={"is_tool_result": True}, - ) - ] - - agent = StubAgent( - default_options={"tools": [approval_tool], "response_format": None}, - updates=[AgentResponseUpdate(contents=[Content.from_text(text="response")], role="assistant")], - ) - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(), - ) - context.set_messages(messages, normalize=False) - - assert orchestrator.can_handle(context) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # Should emit RunErrorEvent for invalid JSON - error_events: list[Any] = [e for e in events if e.type == "RUN_ERROR"] - assert len(error_events) == 1 - assert "Invalid tool result format" in error_events[0].message - - -async def test_sanitize_tool_history_confirm_changes() -> None: - """Test sanitize_tool_history logic for confirm_changes synthetic result.""" - from agent_framework import ChatMessage - - # Create messages that will trigger confirm_changes synthetic result injection - messages = [ - ChatMessage( - role="assistant", - 
contents=[ - Content.from_function_call( - name="confirm_changes", - call_id="call_confirm_123", - arguments='{"changes": "test"}', - ) - ], - ), - ChatMessage( - role="user", - contents=[Content.from_text(text='{"accepted": true}')], - ), - ] - - # The sanitize_tool_history function is internal to DefaultOrchestrator.run - # We'll test it indirectly by checking the orchestrator processes it correctly - orchestrator = DefaultOrchestrator() - - # Use pre-constructed ChatMessage objects to bypass message adapter - input_data: dict[str, Any] = {"messages": []} - - agent = StubAgent( - default_options=DEFAULT_OPTIONS, - ) - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(), - ) - # Override the messages property to use our pre-constructed messages - context.set_messages(messages) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # Agent should receive synthetic tool result - assert len(agent.messages_received) > 0 - tool_messages = [ - msg - for msg in agent.messages_received - if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "tool" - ] - assert len(tool_messages) == 1 - assert str(tool_messages[0].contents[0].call_id) == "call_confirm_123" - assert tool_messages[0].contents[0].result == "Confirmed" - - -async def test_sanitize_tool_history_orphaned_tool_result() -> None: - """Test sanitize_tool_history removes orphaned tool results.""" - from agent_framework import ChatMessage - - # Tool result without preceding assistant tool call - messages = [ - ChatMessage( - role="tool", - contents=[Content.from_function_result(call_id="orphan_123", result="orphaned data")], - ), - ChatMessage( - role="user", - contents=[Content.from_text(text="Hello")], - ), - ] - - orchestrator = DefaultOrchestrator() - input_data: dict[str, Any] = {"messages": []} - agent = StubAgent( - default_options=DEFAULT_OPTIONS, - ) - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(), - ) - context.set_messages(messages) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # Orphaned tool result should be filtered out - tool_messages = [ - msg - for msg in agent.messages_received - if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "tool" - ] - assert len(tool_messages) == 0 - - -async def test_orphaned_tool_result_sanitization() -> None: - """Test that orphaned tool results are filtered out.""" - orchestrator = DefaultOrchestrator() - - input_data: dict[str, Any] = { - "messages": [ - { - "role": "tool", - "content": [{"type": "tool_result", "tool_call_id": "orphan_123", "content": "result"}], - }, - { - "role": "user", - "content": [{"type": "text", "text": "Hello"}], - }, - ], - } - - agent = StubAgent( - default_options=DEFAULT_OPTIONS, - ) - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(), - ) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # Orphaned tool result should be filtered, only user message remains - tool_messages = [ - msg - for msg in agent.messages_received - if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "tool" - ] - assert len(tool_messages) == 0 - - -async def test_deduplicate_messages_empty_tool_results() -> None: - """Test deduplicate_messages prefers non-empty tool results.""" - from agent_framework import ChatMessage - - messages = [ - ChatMessage( 
- role="assistant", - contents=[Content.from_function_call(name="test_tool", call_id="call_789", arguments="{}")], - ), - ChatMessage( - role="tool", - contents=[Content.from_function_result(call_id="call_789", result="")], - ), - ChatMessage( - role="tool", - contents=[Content.from_function_result(call_id="call_789", result="real data")], - ), - ] - - orchestrator = DefaultOrchestrator() - input_data: dict[str, Any] = {"messages": []} - agent = StubAgent( - default_options=DEFAULT_OPTIONS, - ) - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(), - ) - context.set_messages(messages) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # Should have only one tool result with actual data - tool_messages = [ - msg - for msg in agent.messages_received - if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "tool" - ] - assert len(tool_messages) == 1 - assert tool_messages[0].contents[0].result == "real data" - - -async def test_deduplicate_messages_duplicate_assistant_tool_calls() -> None: - """Test deduplicate_messages removes duplicate assistant tool call messages.""" - from agent_framework import ChatMessage - - messages = [ - ChatMessage( - role="assistant", - contents=[Content.from_function_call(name="test_tool", call_id="call_abc", arguments="{}")], - ), - ChatMessage( - role="assistant", - contents=[Content.from_function_call(name="test_tool", call_id="call_abc", arguments="{}")], - ), - ChatMessage( - role="tool", - contents=[Content.from_function_result(call_id="call_abc", result="result")], - ), - ] - - orchestrator = DefaultOrchestrator() - input_data: dict[str, Any] = {"messages": []} - agent = StubAgent( - default_options=DEFAULT_OPTIONS, - ) - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(), - ) - context.set_messages(messages) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # Should have only one assistant message - assistant_messages = [ - msg - for msg in agent.messages_received - if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "assistant" - ] - assert len(assistant_messages) == 1 - - -async def test_deduplicate_messages_duplicate_system_messages() -> None: - """Test that deduplication logic is invoked for system messages.""" - from agent_framework import ChatMessage - - messages = [ - ChatMessage( - role="system", - contents=[Content.from_text(text="You are a helpful assistant.")], - ), - ChatMessage( - role="system", - contents=[Content.from_text(text="You are a helpful assistant.")], - ), - ChatMessage( - role="user", - contents=[Content.from_text(text="Hello")], - ), - ] - - orchestrator = DefaultOrchestrator() - input_data: dict[str, Any] = {"messages": []} - agent = StubAgent( - default_options=DEFAULT_OPTIONS, - ) - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(), - ) - context.set_messages(messages) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # Deduplication uses hash() which may not deduplicate identical content - # This test verifies deduplication logic runs without errors - system_messages = [ - msg - for msg in agent.messages_received - if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "system" - ] - # At least one system message should be present - assert len(system_messages) >= 1 - - -async def 
test_state_context_injection() -> None: - """Test state context message injection for first request.""" - orchestrator = DefaultOrchestrator() - - input_data: dict[str, Any] = { - "messages": [ - { - "role": "user", - "content": [{"type": "text", "text": "Hello"}], - } - ], - "state": {"items": ["apple", "banana"]}, - } - - agent = StubAgent( - default_options=DEFAULT_OPTIONS, - ) - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(state_schema={"items": {"type": "array"}}), - ) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # Should inject system message with current state - system_messages = [ - msg - for msg in agent.messages_received - if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "system" - ] - assert len(system_messages) == 1 - assert "apple" in system_messages[0].contents[0].text - assert "banana" in system_messages[0].contents[0].text - - -async def test_state_context_injection_with_tool_calls_and_input_state() -> None: - """Test state context is injected when state is provided, even with tool calls.""" - from agent_framework import ChatMessage - - messages = [ - ChatMessage( - role="assistant", - contents=[Content.from_function_call(name="get_weather", call_id="call_xyz", arguments="{}")], - ), - ChatMessage( - role="tool", - contents=[Content.from_function_result(call_id="call_xyz", result="sunny")], - ), - ChatMessage( - role="user", - contents=[Content.from_text(text="Thanks")], - ), - ] - - orchestrator = DefaultOrchestrator() - input_data: dict[str, Any] = {"messages": [], "state": {"weather": "sunny"}} - agent = StubAgent( - default_options=DEFAULT_OPTIONS, - ) - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(state_schema={"weather": {"type": "string"}}), - ) - context.set_messages(messages) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # Should inject state context system message because input state is provided - system_messages = [ - msg - for msg in agent.messages_received - if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "system" - ] - assert len(system_messages) == 1 - - -async def test_structured_output_processing() -> None: - """Test structured output extraction and state update.""" - - class RecipeState(BaseModel): - ingredients: list[str] - message: str - - orchestrator = DefaultOrchestrator() - - input_data: dict[str, Any] = { - "messages": [ - { - "role": "user", - "content": [{"type": "text", "text": "Add tomato"}], - } - ], - } - - # Agent with structured output - agent = StubAgent( - default_options=DEFAULT_OPTIONS, - updates=[ - AgentResponseUpdate( - contents=[Content.from_text(text='{"ingredients": ["tomato"], "message": "Added tomato"}')], - role="assistant", - ) - ], - ) - agent.default_options["response_format"] = RecipeState - - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(state_schema={"ingredients": {"type": "array"}}), - ) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # Should emit StateSnapshotEvent with ingredients - state_events: list[Any] = [e for e in events if e.type == "STATE_SNAPSHOT"] - assert len(state_events) >= 1 - - # Should emit TextMessage with message field - text_content_events: list[Any] = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"] - assert len(text_content_events) >= 1 - 
assert any("Added tomato" in e.delta for e in text_content_events) - - -async def test_duplicate_client_tools_filtered() -> None: - """Test that client tools duplicating server tools are filtered out.""" - - @ai_function - def get_weather(location: str) -> str: - """Get weather for location.""" - return f"Weather in {location}" - - orchestrator = DefaultOrchestrator() - - input_data: dict[str, Any] = { - "messages": [ - { - "role": "user", - "content": [{"type": "text", "text": "Hello"}], - } - ], - "tools": [ - { - "name": "get_weather", - "description": "Client weather tool.", - "parameters": { - "type": "object", - "properties": {"location": {"type": "string"}}, - "required": ["location"], - }, - } - ], - } - - agent = StubAgent( - default_options=DEFAULT_OPTIONS, - ) - agent.default_options["tools"] = [get_weather] - - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(), - ) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # tools parameter should not be passed since client tool duplicates server tool - assert agent.tools_received is None - - -async def test_unique_client_tools_merged() -> None: - """Test that unique client tools are merged with server tools.""" - - @ai_function - def server_tool() -> str: - """Server tool.""" - return "server" - - orchestrator = DefaultOrchestrator() - - input_data: dict[str, Any] = { - "messages": [ - { - "role": "user", - "content": [{"type": "text", "text": "Hello"}], - } - ], - "tools": [ - { - "name": "client_tool", - "description": "Unique client tool.", - "parameters": { - "type": "object", - "properties": {"param": {"type": "string"}}, - "required": ["param"], - }, - } - ], - } - - agent = StubAgent( - default_options=DEFAULT_OPTIONS, - ) - agent.default_options["tools"] = [server_tool] - - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(), - ) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # tools parameter should be passed with both server and client tools - assert agent.tools_received is not None - tool_names = [getattr(tool, "name", None) for tool in agent.tools_received] - assert "server_tool" in tool_names - assert "client_tool" in tool_names - - -async def test_empty_messages_handling() -> None: - """Test orchestrator handles empty message list gracefully.""" - orchestrator = DefaultOrchestrator() - - input_data: dict[str, Any] = {"messages": []} - - agent = StubAgent( - default_options=DEFAULT_OPTIONS, - ) - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(), - ) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # Should emit run lifecycle events but not call agent - assert len(agent.messages_received) == 0 - run_started = [e for e in events if e.type == "RUN_STARTED"] - run_finished = [e for e in events if e.type == "RUN_FINISHED"] - assert len(run_started) == 1 - assert len(run_finished) == 1 - - -async def test_all_messages_filtered_handling() -> None: - """Test orchestrator handles case where all messages are filtered out.""" - orchestrator = DefaultOrchestrator() - - input_data: dict[str, Any] = { - "messages": [ - { - "role": "tool", - "content": [{"type": "tool_result", "tool_call_id": "orphan", "content": "data"}], - } - ] - } - - agent = StubAgent( - default_options=DEFAULT_OPTIONS, - ) - context = TestExecutionContext( - 
input_data=input_data, - agent=agent, - config=AgentConfig(), - ) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # Should finish without calling agent - assert len(agent.messages_received) == 0 - run_finished = [e for e in events if e.type == "RUN_FINISHED"] - assert len(run_finished) == 1 - - -async def test_confirm_changes_with_invalid_json_fallback() -> None: - """Test confirm_changes with invalid JSON falls back to normal processing.""" - from agent_framework import ChatMessage - - messages = [ - ChatMessage( - role="assistant", - contents=[ - Content.from_function_call( - name="confirm_changes", - call_id="call_confirm_invalid", - arguments='{"changes": "test"}', - ) - ], - ), - ChatMessage( - role="user", - contents=[Content.from_text(text="invalid json {")], - ), - ] - - orchestrator = DefaultOrchestrator() - input_data: dict[str, Any] = {"messages": []} - agent = StubAgent( - default_options=DEFAULT_OPTIONS, - ) - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(), - ) - context.set_messages(messages) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # Invalid JSON should fall back - user message should be included - user_messages = [ - msg - for msg in agent.messages_received - if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "user" - ] - assert len(user_messages) == 1 - - -async def test_confirm_changes_closes_active_message_before_finish() -> None: - """Confirm-changes flow closes any active text message before run finishes.""" - from ag_ui.core import TextMessageEndEvent, TextMessageStartEvent - - updates = [ - AgentResponseUpdate( - contents=[ - Content.from_function_call( - name="write_document_local", - call_id="call_1", - arguments='{"document": "Draft"}', - ) - ] - ), - AgentResponseUpdate(contents=[Content.from_function_result(call_id="call_1", result="Done")]), - ] - - orchestrator = DefaultOrchestrator() - input_data: dict[str, Any] = {"messages": [{"role": "user", "content": "Start"}]} - agent = StubAgent( - default_options=DEFAULT_OPTIONS, - updates=updates, - ) - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig( - predict_state_config={"document": {"tool": "write_document_local", "tool_argument": "document"}}, - require_confirmation=True, - ), - ) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - start_events = [e for e in events if isinstance(e, TextMessageStartEvent)] - end_events = [e for e in events if isinstance(e, TextMessageEndEvent)] - assert len(start_events) == 1 - assert len(end_events) == 1 - assert end_events[0].message_id == start_events[0].message_id - - end_index = events.index(end_events[0]) - finished_index = events.index([e for e in events if e.type == "RUN_FINISHED"][0]) - assert end_index < finished_index - - -async def test_tool_result_kept_when_call_id_matches() -> None: - """Test tool result is kept when call_id matches pending tool calls.""" - from agent_framework import ChatMessage - - messages = [ - ChatMessage( - role="assistant", - contents=[Content.from_function_call(name="get_data", call_id="call_match", arguments="{}")], - ), - ChatMessage( - role="tool", - contents=[Content.from_function_result(call_id="call_match", result="data")], - ), - ] - - orchestrator = DefaultOrchestrator() - input_data: dict[str, Any] = {"messages": []} - agent = StubAgent( - 
default_options=DEFAULT_OPTIONS, - ) - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(), - ) - context.set_messages(messages) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # Tool result should be kept - tool_messages = [ - msg - for msg in agent.messages_received - if (msg.role.value if hasattr(msg.role, "value") else str(msg.role)) == "tool" - ] - assert len(tool_messages) == 1 - assert tool_messages[0].contents[0].result == "data" - - -async def test_agent_protocol_fallback_paths() -> None: - """Test fallback paths for non-ChatAgent implementations.""" - - class CustomAgent: - """Custom agent without ChatAgent type.""" - - def __init__(self) -> None: - self.default_options: dict[str, Any] = {"tools": [], "response_format": None} - self.chat_client = SimpleNamespace(function_invocation_configuration=SimpleNamespace()) - self.messages_received: list[Any] = [] - - async def run_stream( - self, - messages: list[Any], - *, - thread: Any = None, - tools: list[Any] | None = None, - **kwargs: Any, - ) -> AsyncGenerator[AgentResponseUpdate, None]: - self.messages_received = messages - yield AgentResponseUpdate(contents=[Content.from_text(text="response")], role="assistant") - - from agent_framework import ChatMessage - - messages = [ChatMessage(role="user", contents=[Content.from_text(text="Hello")])] - - orchestrator = DefaultOrchestrator() - input_data: dict[str, Any] = {"messages": []} - agent = CustomAgent() - context = TestExecutionContext( - input_data=input_data, - agent=agent, # type: ignore - config=AgentConfig(), - ) - context.set_messages(messages) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # Should work with custom agent implementation - assert len(agent.messages_received) > 0 - - -async def test_initial_state_snapshot_with_array_schema() -> None: - """Test state initialization with array type schema.""" - from agent_framework import ChatMessage - - messages = [ChatMessage(role="user", contents=[Content.from_text(text="Hello")])] - - orchestrator = DefaultOrchestrator() - input_data: dict[str, Any] = {"messages": [], "state": {}} - agent = StubAgent( - default_options=DEFAULT_OPTIONS, - ) - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(state_schema={"items": {"type": "array"}}), - ) - context.set_messages(messages) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # Should emit state snapshot with empty array for items - state_events: list[Any] = [e for e in events if e.type == "STATE_SNAPSHOT"] - assert len(state_events) >= 1 - - -async def test_response_format_skip_text_content() -> None: - """Test that response_format causes skip_text_content to be set.""" - - class OutputModel(BaseModel): - result: str - - from agent_framework import ChatMessage - - messages = [ChatMessage(role="user", contents=[Content.from_text(text="Hello")])] - - orchestrator = DefaultOrchestrator() - input_data: dict[str, Any] = {"messages": []} - - agent = StubAgent( - default_options=DEFAULT_OPTIONS, - ) - agent.default_options["response_format"] = OutputModel - - context = TestExecutionContext( - input_data=input_data, - agent=agent, - config=AgentConfig(), - ) - context.set_messages(messages) - - events: list[Any] = [] - async for event in orchestrator.run(context): - events.append(event) - - # Test passes if no errors occur - verifies 
response_format code path - assert len(events) > 0 diff --git a/python/packages/ag-ui/tests/test_shared_state.py b/python/packages/ag-ui/tests/test_shared_state.py deleted file mode 100644 index 4b3f5ebb23..0000000000 --- a/python/packages/ag-ui/tests/test_shared_state.py +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -"""Tests for shared state management.""" - -import sys -from pathlib import Path -from typing import Any - -import pytest -from ag_ui.core import StateSnapshotEvent -from agent_framework import ChatAgent, ChatResponseUpdate, Content - -from agent_framework_ag_ui._agent import AgentFrameworkAgent -from agent_framework_ag_ui._events import AgentFrameworkEventBridge - -sys.path.insert(0, str(Path(__file__).parent)) -from utils_test_ag_ui import StreamingChatClientStub, stream_from_updates - - -@pytest.fixture -def mock_agent() -> ChatAgent: - """Create a mock agent for testing.""" - updates = [ChatResponseUpdate(contents=[Content.from_text(text="Hello!")])] - chat_client = StreamingChatClientStub(stream_from_updates(updates)) - return ChatAgent(name="test_agent", instructions="Test agent", chat_client=chat_client) - - -def test_state_snapshot_event(): - """Test creating state snapshot events.""" - bridge = AgentFrameworkEventBridge(run_id="test-run", thread_id="test-thread") - - state = { - "recipe": { - "name": "Chocolate Chip Cookies", - "ingredients": ["flour", "sugar", "chocolate chips"], - "instructions": ["Mix ingredients", "Bake at 350°F"], - "servings": 24, - } - } - - event = bridge.create_state_snapshot_event(state) - - assert isinstance(event, StateSnapshotEvent) - assert event.snapshot == state - assert event.snapshot["recipe"]["name"] == "Chocolate Chip Cookies" - assert len(event.snapshot["recipe"]["ingredients"]) == 3 - - -def test_state_delta_event(): - """Test creating state delta events using JSON Patch format.""" - bridge = AgentFrameworkEventBridge(run_id="test-run", thread_id="test-thread") - - # JSON Patch operations (RFC 6902) - delta = [ - {"op": "add", "path": "/recipe/ingredients/-", "value": "vanilla extract"}, - {"op": "replace", "path": "/recipe/servings", "value": 30}, - ] - - event = bridge.create_state_delta_event(delta) - - assert event.delta == delta - assert len(event.delta) == 2 - assert event.delta[0]["op"] == "add" - assert event.delta[1]["op"] == "replace" - - -async def test_agent_with_initial_state(mock_agent: ChatAgent) -> None: - """Test agent emits state snapshot when initial state provided.""" - state_schema: dict[str, Any] = {"recipe": {"type": "object", "properties": {"name": {"type": "string"}}}} - - agent = AgentFrameworkAgent( - agent=mock_agent, - state_schema=state_schema, - ) - - initial_state = {"recipe": {"name": "Test Recipe"}} - - input_data: dict[str, Any] = { - "messages": [{"role": "user", "content": "Hello"}], - "state": initial_state, - } - - events: list[Any] = [] - async for event in agent.run_agent(input_data): - events.append(event) - - # Should have RunStartedEvent, StateSnapshotEvent, RunFinishedEvent at minimum - snapshot_events = [e for e in events if isinstance(e, StateSnapshotEvent)] - assert len(snapshot_events) == 1 - assert snapshot_events[0].snapshot == initial_state - - -async def test_agent_without_state_schema(mock_agent: ChatAgent) -> None: - """Test agent doesn't emit state events without state schema.""" - agent = AgentFrameworkAgent(agent=mock_agent) - - input_data: dict[str, Any] = { - "messages": [{"role": "user", "content": "Hello"}], - "state": {"some": 
"state"}, - } - - events: list[Any] = [] - async for event in agent.run_agent(input_data): - events.append(event) - - # Should NOT have any StateSnapshotEvent - snapshot_events = [e for e in events if isinstance(e, StateSnapshotEvent)] - assert len(snapshot_events) == 0 diff --git a/python/packages/ag-ui/tests/test_state_manager.py b/python/packages/ag-ui/tests/test_state_manager.py deleted file mode 100644 index 47b2940978..0000000000 --- a/python/packages/ag-ui/tests/test_state_manager.py +++ /dev/null @@ -1,51 +0,0 @@ -# Copyright (c) Microsoft. All rights reserved. - -from ag_ui.core import CustomEvent, EventType -from agent_framework import ChatMessage - -from agent_framework_ag_ui._events import AgentFrameworkEventBridge -from agent_framework_ag_ui._orchestration._state_manager import StateManager - - -def test_state_manager_initializes_defaults_and_snapshot() -> None: - state_manager = StateManager( - state_schema={"items": {"type": "array"}, "metadata": {"type": "object"}}, - predict_state_config=None, - require_confirmation=True, - ) - current_state = state_manager.initialize({"metadata": {"a": 1}}) - bridge = AgentFrameworkEventBridge(run_id="run", thread_id="thread", current_state=current_state) - - snapshot_event = state_manager.initial_snapshot_event(bridge) - assert snapshot_event is not None - assert snapshot_event.snapshot["items"] == [] - assert snapshot_event.snapshot["metadata"] == {"a": 1} - - -def test_state_manager_predict_state_event_shape() -> None: - state_manager = StateManager( - state_schema=None, - predict_state_config={"doc": {"tool": "write_document_local", "tool_argument": "document"}}, - require_confirmation=True, - ) - predict_event = state_manager.predict_state_event() - assert isinstance(predict_event, CustomEvent) - assert predict_event.type == EventType.CUSTOM - assert predict_event.name == "PredictState" - assert predict_event.value[0]["state_key"] == "doc" - - -def test_state_context_only_when_new_user_turn() -> None: - state_manager = StateManager( - state_schema={"items": {"type": "array"}}, - predict_state_config=None, - require_confirmation=True, - ) - state_manager.initialize({"items": [1]}) - - assert state_manager.state_context_message(is_new_user_turn=False, conversation_has_tool_calls=False) is None - - message = state_manager.state_context_message(is_new_user_turn=True, conversation_has_tool_calls=False) - assert isinstance(message, ChatMessage) - assert message.contents[0].type == "text" - assert "Current state of the application" in message.contents[0].text diff --git a/python/packages/ag-ui/tests/utils_test_ag_ui.py b/python/packages/ag-ui/tests/utils_test_ag_ui.py index 33f462257e..50710191ec 100644 --- a/python/packages/ag-ui/tests/utils_test_ag_ui.py +++ b/python/packages/ag-ui/tests/utils_test_ag_ui.py @@ -21,7 +21,6 @@ from agent_framework._clients import TOptions_co from agent_framework_ag_ui._message_adapters import _deduplicate_messages, _sanitize_tool_history -from agent_framework_ag_ui._orchestrators import ExecutionContext if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover @@ -127,12 +126,5 @@ def get_new_thread(self, **kwargs: Any) -> AgentThread: return AgentThread() -class TestExecutionContext(ExecutionContext): - """ExecutionContext helper that allows setting messages for tests.""" - - def set_messages(self, messages: list[ChatMessage], *, normalize: bool = True) -> None: - if normalize: - self._messages = _deduplicate_messages(_sanitize_tool_history(messages)) - else: - 
self._messages = messages
-        self._snapshot_messages = None
+# Note: TestExecutionContext was removed along with _orchestrators.py
+# Tests should now use run_agent_stream() directly or the StubAgent class
diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py
index 3ea7d33e72..1e86e1d4d5 100644
--- a/python/packages/core/agent_framework/_tools.py
+++ b/python/packages/core/agent_framework/_tools.py
@@ -1689,6 +1689,7 @@ async def _try_execute_function_calls(
     tool_map = _get_tool_map(tools)
     approval_tools = [tool_name for tool_name, tool in tool_map.items() if tool.approval_mode == "always_require"]
+    logger.info(f"[APPROVAL-DEBUG] _try_execute_function_calls: tool_map keys={list(tool_map.keys())}, approval_tools={approval_tools}")
     declaration_only = [tool_name for tool_name, tool in tool_map.items() if tool.declaration_only]
     additional_tool_names = [tool.name for tool in config.additional_tools] if config.additional_tools else []
     # check if any are calling functions that need approval
@@ -1696,7 +1697,9 @@
     approval_needed = False
     declaration_only_flag = False
     for fcc in function_calls:
+        logger.info(f"[APPROVAL-DEBUG] Checking fcc: type={fcc.type}, name={getattr(fcc, 'name', None)}, in approval_tools={getattr(fcc, 'name', None) in approval_tools}")
         if fcc.type == "function_call" and fcc.name in approval_tools:  # type: ignore[attr-defined]
+            logger.info(f"[APPROVAL-DEBUG] APPROVAL NEEDED for {fcc.name}")
             approval_needed = True
             break
         if fcc.type == "function_call" and (fcc.name in declaration_only or fcc.name in additional_tool_names):  # type: ignore[attr-defined]
@@ -1706,6 +1709,7 @@
             raise KeyError(f'Error: Requested function "{fcc.name}" not found.')  # type: ignore[attr-defined]
     if approval_needed:
         # approval can only be needed for Function Call Content, not Approval Responses.
+        logger.info("[APPROVAL-DEBUG] Returning function_approval_request contents")
         return (
             [
                 Content.from_function_approval_request(id=fcc.call_id, function_call=fcc)  # type: ignore[attr-defined, arg-type]
@@ -2148,6 +2152,10 @@ async def streaming_function_invocation_wrapper(
         # we load the tools here, since middleware might have changed them compared to before calling func.
        tools = _extract_tools(options)
+        logger.info(f"[APPROVAL-DEBUG-STREAMING] tools extracted: {tools is not None}, function_calls: {len(function_calls) if function_calls else 0}")
+        if tools:
+            for t in (tools if isinstance(tools, list) else [tools]):
+                logger.info(f"[APPROVAL-DEBUG-STREAMING] - {getattr(t, 'name', 'unknown')}: approval_mode={getattr(t, 'approval_mode', None)}")
         if function_calls and tools:
             # Use the stored middleware pipeline instead of extracting from kwargs
             # because kwargs may have been modified by the underlying function

From 04e570ca79a1c39e2444a32f307abfcf0bf44286 Mon Sep 17 00:00:00 2001
From: Evan Mattson
Date: Wed, 21 Jan 2026 14:05:37 +0900
Subject: [PATCH 3/8] Fix backend tool

---
 python/packages/ag-ui/agent_framework_ag_ui/_run.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_run.py
index 0bc4f69610..9586186c34 100644
--- a/python/packages/ag-ui/agent_framework_ag_ui/_run.py
+++ b/python/packages/ag-ui/agent_framework_ag_ui/_run.py
@@ -362,9 +362,11 @@ def _emit_tool_result(
         if flow.current_state:
             events.append(StateSnapshotEvent(snapshot=flow.current_state))

-    # Reset tool tracking
+    # Reset tool tracking and message context
+    # After tool result, any subsequent text should start a new message
     flow.tool_call_id = None
     flow.tool_call_name = None
+    flow.message_id = None  # Reset so next text content starts a new message

From 5169e219e053c9d6183777fc3de57c37b46aee24 Mon Sep 17 00:00:00 2001
From: Evan Mattson
Date: Wed, 21 Jan 2026 14:31:27 +0900
Subject: [PATCH 4/8] Update tests

---
 .../ag-ui/agent_framework_ag_ui/_agent.py     |   6 -
 .../_message_adapters.py                      |   8 -
 .../_orchestration/_helpers.py                |  14 -
 .../ag-ui/agent_framework_ag_ui/_run.py       |  43 +-
 .../ag-ui/agent_framework_ag_ui/_types.py     |  12 +-
 .../agents/weather_agent.py                   |   3 +-
 .../server/main.py                            |   7 +-
 python/packages/ag-ui/tests/test_helpers.py   | 503 ++++++++++++++++++
 .../ag-ui/tests/test_message_adapters.py      | 106 ++++
 .../ag-ui/tests/test_predictive_state.py      | 363 +++++++++++++
 python/packages/ag-ui/tests/test_run.py       | 378 +++++++++++++
 python/packages/ag-ui/tests/test_tooling.py   |  92 ++++
 python/packages/ag-ui/tests/test_utils.py     | 170 ++++++
 13 files changed, 1636 insertions(+), 69 deletions(-)
 create mode 100644 python/packages/ag-ui/tests/test_helpers.py
 create mode 100644 python/packages/ag-ui/tests/test_predictive_state.py
 create mode 100644 python/packages/ag-ui/tests/test_run.py

diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py
index 56488df876..38ca0e9767 100644
--- a/python/packages/ag-ui/agent_framework_ag_ui/_agent.py
+++ b/python/packages/ag-ui/agent_framework_ag_ui/_agent.py
@@ -115,9 +115,3 @@ async def run_agent(
         """
         async for event in run_agent_stream(input_data, self.agent, self.config):
             yield event
-
-
-__all__ = [
-    "AgentFrameworkAgent",
-    "AgentConfig",
-]
diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py
index cf14641258..525fb2faab 100644
--- a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py
+++ b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py
@@ -781,11 +781,3 @@ def agui_messages_to_snapshot_format(messages: list[dict[str, Any]]) -> list[dic
         result.append(normalized_msg)

     return result
-
-
-__all__ = [
-    "agui_messages_to_agent_framework",
"agent_framework_messages_to_agui", - "agui_messages_to_snapshot_format", - "extract_text_from_contents", -] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py index d38c125092..66b0160469 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py @@ -244,17 +244,3 @@ def is_step_based_approval( if config.get("tool") == tool_name and config.get("tool_argument") == "steps": return True return False - - -__all__ = [ - "pending_tool_call_ids", - "is_state_context_message", - "ensure_tool_call_entry", - "tool_name_for_call_id", - "schema_has_steps", - "select_approval_tool_name", - "build_safe_metadata", - "latest_approval_response", - "approval_steps", - "is_step_based_approval", -] diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_run.py index 9586186c34..554a044f39 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_run.py @@ -152,7 +152,6 @@ class FlowState: tool_call_id: str | None = None # Current tool call being streamed tool_call_name: str | None = None # Name of current tool call waiting_for_approval: bool = False # Stop after approval request - pending_confirm_id: str | None = None # ID of pending confirm_changes tool call current_state: dict[str, Any] = field(default_factory=dict) # Shared state accumulated_text: str = "" # For MessagesSnapshotEvent pending_tool_calls: list[dict[str, Any]] = field(default_factory=list) # For MessagesSnapshotEvent @@ -379,18 +378,15 @@ def _emit_approval_request( ) -> list[BaseEvent]: """Emit events for function approval request.""" events: list[BaseEvent] = [] - logger.info(f"[APPROVAL-REQUEST] Starting _emit_approval_request, require_confirmation={require_confirmation}") # function_call is required for approval requests - skip if missing func_call = content.function_call - logger.info(f"[APPROVAL-REQUEST] func_call={func_call}, content.id={content.id}") if not func_call: logger.warning("Approval request content missing function_call, skipping") return events func_name = func_call.name or "" func_call_id = func_call.call_id - logger.info(f"[APPROVAL-REQUEST] func_name={func_name}, func_call_id={func_call_id}") # Extract state from function arguments if predictive if predictive_handler and func_name: @@ -422,12 +418,9 @@ def _emit_approval_request( ) # Emit confirm_changes tool call for UI compatibility - # IMPORTANT: Do NOT emit ToolCallEndEvent here - the tool must remain in "executing" - # status for the frontend to show the confirmation dialog. The end event will be - # emitted when the user responds with their confirmation/rejection. 
+ # The complete sequence (Start -> Args -> End) signals the UI to show the confirmation dialog if require_confirmation: confirm_id = generate_event_id() - logger.info(f"[APPROVAL-REQUEST] Emitting confirm_changes with id={confirm_id} (no end event - stays executing)") events.append( ToolCallStartEvent( tool_call_id=confirm_id, @@ -442,11 +435,9 @@ def _emit_approval_request( "steps": [{"description": f"Execute {func_name}", "status": "enabled"}], } events.append(ToolCallArgsEvent(tool_call_id=confirm_id, delta=json.dumps(args))) - # Store the confirm_id in flow so we can track it for the response - flow.pending_confirm_id = confirm_id + events.append(ToolCallEndEvent(tool_call_id=confirm_id)) flow.waiting_for_approval = True - logger.info(f"[APPROVAL-REQUEST] Returning {len(events)} events") return events @@ -459,7 +450,6 @@ def _emit_content( ) -> list[BaseEvent]: """Emit appropriate events for any content type.""" content_type = getattr(content, "type", None) - logger.info(f"[EMIT-CONTENT] Processing content type: {content_type}") if content_type == "text": return _emit_text(content, flow, skip_text) elif content_type == "function_call": @@ -467,7 +457,6 @@ def _emit_content( elif content_type == "function_result": return _emit_tool_result(content, flow, predictive_handler) elif content_type == "function_approval_request": - logger.info("[EMIT-CONTENT] Got function_approval_request - emitting approval events") return _emit_approval_request(content, flow, predictive_handler, require_confirmation) return [] @@ -597,13 +586,18 @@ async def _resolve_approval_responses( # Build normalized results for approved responses normalized_results: list[Content] = [] for idx, approval in enumerate(approved_responses): - if idx < len(approved_function_results) and getattr(approved_function_results[idx], "type", None) == "function_result": + if ( + idx < len(approved_function_results) + and getattr(approved_function_results[idx], "type", None) == "function_result" + ): normalized_results.append(approved_function_results[idx]) continue # Get call_id from function_call if present, otherwise use approval.id func_call = approval.function_call call_id = (func_call.call_id if func_call else None) or approval.id or "" - normalized_results.append(Content.from_function_result(call_id=call_id, result="Error: Tool call invocation failed.")) + normalized_results.append( + Content.from_function_result(call_id=call_id, result="Error: Tool call invocation failed.") + ) # Build rejection results for rejection in rejected_responses: @@ -741,9 +735,6 @@ async def run_agent_stream( run_kwargs: dict[str, Any] = {"thread": thread} if tools: run_kwargs["tools"] = tools - logger.info(f"[DEBUG] Setting run_kwargs['tools'] with {len(tools)} tools") - for t in tools: - logger.info(f"[DEBUG] - {getattr(t, 'name', 'unknown')}: approval_mode={getattr(t, 'approval_mode', None)}") safe_metadata = _build_safe_metadata(thread.metadata) # type: ignore[attr-defined] if safe_metadata: run_kwargs["options"] = {"metadata": safe_metadata, "store": True} @@ -891,9 +882,7 @@ async def run_agent_stream( # For predictive tools with require_confirmation, emit confirm_changes if config.require_confirmation and config.predict_state_config and tool_name: - is_predictive_tool = any( - cfg["tool"] == tool_name for cfg in config.predict_state_config.values() - ) + is_predictive_tool = any(cfg["tool"] == tool_name for cfg in config.predict_state_config.values()) if is_predictive_tool: logger.info(f"Emitting confirm_changes for predictive tool 
'{tool_name}'") # Extract state value from tool arguments for StateSnapshot @@ -945,12 +934,6 @@ async def run_agent_stream( ): yield _build_messages_snapshot(flow, snapshot_messages) - # Only emit RunFinished if we're not waiting for approval with an active confirm_changes tool - # The AG-UI protocol requires all tool calls to be ended before RUN_FINISHED - if not flow.pending_confirm_id: - yield RunFinishedEvent(run_id=run_id, thread_id=thread_id) - else: - logger.info(f"Skipping RunFinishedEvent - waiting for approval on confirm_changes id={flow.pending_confirm_id}") - - -__all__ = ["FlowState", "run_agent_stream"] + # Always emit RunFinished - confirm_changes tool call is complete (Start -> Args -> End) + # The UI will show confirmation dialog and send a new request when user responds + yield RunFinishedEvent(run_id=run_id, thread_id=thread_id) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_types.py b/python/packages/ag-ui/agent_framework_ag_ui/_types.py index f88dceb78b..7466f09371 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_types.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_types.py @@ -13,14 +13,6 @@ else: from typing_extensions import TypeVar -__all__ = [ - "AGUIChatOptions", - "AgentState", - "PredictStateConfig", - "RunMetadata", -] - - class PredictStateConfig(TypedDict): """Configuration for predictive state updates.""" @@ -62,6 +54,10 @@ class AGUIRequest(BaseModel): None, description="Optional shared state for agentic generative UI", ) + tools: list[dict[str, Any]] | None = Field( + None, + description="Client-side tools to advertise to the LLM", + ) # region AG-UI Chat Options TypedDict diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/weather_agent.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/weather_agent.py index 32324d72eb..5ebdc10d73 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/weather_agent.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/agents/weather_agent.py @@ -71,7 +71,8 @@ def weather_agent(chat_client: ChatClientProtocol[Any]) -> ChatAgent[Any]: instructions=( "You are a helpful weather assistant. " "Use the get_weather and get_forecast functions to help users with weather information. " - "Always provide friendly and informative responses." + "Always provide friendly and informative responses. " + "First return the weather result, and then return details about the forecast." 
), chat_client=chat_client, tools=[get_weather, get_forecast], diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py index 8d9c212a5e..54dcc5f558 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py @@ -9,7 +9,9 @@ from agent_framework import ChatOptions from agent_framework._clients import BaseChatClient from agent_framework.ag_ui import add_agent_framework_fastapi_endpoint -from agent_framework.azure import AzureOpenAIChatClient + +# from agent_framework.azure import AzureOpenAIChatClient +from agent_framework.anthropic import AnthropicClient from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -62,7 +64,8 @@ # Create a shared chat client for all agents # You can use different chat clients for different agents if needed -chat_client: BaseChatClient[ChatOptions] = AzureOpenAIChatClient() +# chat_client: BaseChatClient[ChatOptions] = AzureOpenAIChatClient() +chat_client: BaseChatClient[ChatOptions] = AnthropicClient() # Agentic Chat - basic chat agent add_agent_framework_fastapi_endpoint( diff --git a/python/packages/ag-ui/tests/test_helpers.py b/python/packages/ag-ui/tests/test_helpers.py new file mode 100644 index 0000000000..3762d3c5fd --- /dev/null +++ b/python/packages/ag-ui/tests/test_helpers.py @@ -0,0 +1,503 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for orchestration helper functions.""" + +import pytest +from agent_framework import ChatMessage, Content + +from agent_framework_ag_ui._orchestration._helpers import ( + approval_steps, + build_safe_metadata, + ensure_tool_call_entry, + is_state_context_message, + is_step_based_approval, + latest_approval_response, + pending_tool_call_ids, + schema_has_steps, + select_approval_tool_name, + tool_name_for_call_id, +) + + +class TestPendingToolCallIds: + """Tests for pending_tool_call_ids function.""" + + def test_empty_messages(self): + """Returns empty set for empty messages list.""" + result = pending_tool_call_ids([]) + assert result == set() + + def test_no_tool_calls(self): + """Returns empty set when no tool calls in messages.""" + messages = [ + ChatMessage(role="user", contents=[Content.from_text("Hello")]), + ChatMessage(role="assistant", contents=[Content.from_text("Hi there")]), + ] + result = pending_tool_call_ids(messages) + assert result == set() + + def test_pending_tool_call(self): + """Returns pending tool call ID when no result exists.""" + messages = [ + ChatMessage( + role="assistant", + contents=[Content.from_function_call(call_id="call_123", name="get_weather", arguments="{}")], + ), + ] + result = pending_tool_call_ids(messages) + assert result == {"call_123"} + + def test_resolved_tool_call(self): + """Returns empty set when tool call has result.""" + messages = [ + ChatMessage( + role="assistant", + contents=[Content.from_function_call(call_id="call_123", name="get_weather", arguments="{}")], + ), + ChatMessage( + role="tool", + contents=[Content.from_function_result(call_id="call_123", result="sunny")], + ), + ] + result = pending_tool_call_ids(messages) + assert result == set() + + def test_multiple_tool_calls_some_resolved(self): + """Returns only unresolved tool call IDs.""" + messages = [ + ChatMessage( + role="assistant", + contents=[ + Content.from_function_call(call_id="call_1", name="tool_a", arguments="{}"), + Content.from_function_call(call_id="call_2", 
name="tool_b", arguments="{}"), + Content.from_function_call(call_id="call_3", name="tool_c", arguments="{}"), + ], + ), + ChatMessage( + role="tool", + contents=[Content.from_function_result(call_id="call_1", result="result_a")], + ), + ChatMessage( + role="tool", + contents=[Content.from_function_result(call_id="call_3", result="result_c")], + ), + ] + result = pending_tool_call_ids(messages) + assert result == {"call_2"} + + +class TestIsStateContextMessage: + """Tests for is_state_context_message function.""" + + def test_state_context_message(self): + """Returns True for state context message.""" + message = ChatMessage( + role="system", + contents=[Content.from_text("Current state of the application: {}")], + ) + assert is_state_context_message(message) is True + + def test_non_system_message(self): + """Returns False for non-system message.""" + message = ChatMessage( + role="user", + contents=[Content.from_text("Current state of the application: {}")], + ) + assert is_state_context_message(message) is False + + def test_system_message_without_state_prefix(self): + """Returns False for system message without state prefix.""" + message = ChatMessage( + role="system", + contents=[Content.from_text("You are a helpful assistant.")], + ) + assert is_state_context_message(message) is False + + def test_empty_contents(self): + """Returns False for message with empty contents.""" + message = ChatMessage(role="system", contents=[]) + assert is_state_context_message(message) is False + + +class TestEnsureToolCallEntry: + """Tests for ensure_tool_call_entry function.""" + + def test_creates_new_entry(self): + """Creates new entry when ID not found.""" + tool_calls_by_id: dict = {} + pending_tool_calls: list = [] + + entry = ensure_tool_call_entry("call_123", tool_calls_by_id, pending_tool_calls) + + assert entry["id"] == "call_123" + assert entry["type"] == "function" + assert entry["function"]["name"] == "" + assert entry["function"]["arguments"] == "" + assert "call_123" in tool_calls_by_id + assert len(pending_tool_calls) == 1 + + def test_returns_existing_entry(self): + """Returns existing entry when ID found.""" + existing_entry = { + "id": "call_123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"city": "NYC"}'}, + } + tool_calls_by_id = {"call_123": existing_entry} + pending_tool_calls: list = [] + + entry = ensure_tool_call_entry("call_123", tool_calls_by_id, pending_tool_calls) + + assert entry is existing_entry + assert entry["function"]["name"] == "get_weather" + assert len(pending_tool_calls) == 0 # Not added again + + +class TestToolNameForCallId: + """Tests for tool_name_for_call_id function.""" + + def test_returns_tool_name(self): + """Returns tool name for valid entry.""" + tool_calls_by_id = { + "call_123": { + "id": "call_123", + "function": {"name": "get_weather", "arguments": "{}"}, + } + } + result = tool_name_for_call_id(tool_calls_by_id, "call_123") + assert result == "get_weather" + + def test_returns_none_for_missing_id(self): + """Returns None when ID not found.""" + tool_calls_by_id: dict = {} + result = tool_name_for_call_id(tool_calls_by_id, "call_123") + assert result is None + + def test_returns_none_for_missing_function(self): + """Returns None when function key missing.""" + tool_calls_by_id = {"call_123": {"id": "call_123"}} + result = tool_name_for_call_id(tool_calls_by_id, "call_123") + assert result is None + + def test_returns_none_for_non_dict_function(self): + """Returns None when function is not a dict.""" + 
tool_calls_by_id = {"call_123": {"id": "call_123", "function": "not_a_dict"}} + result = tool_name_for_call_id(tool_calls_by_id, "call_123") + assert result is None + + def test_returns_none_for_empty_name(self): + """Returns None when name is empty.""" + tool_calls_by_id = {"call_123": {"id": "call_123", "function": {"name": "", "arguments": "{}"}}} + result = tool_name_for_call_id(tool_calls_by_id, "call_123") + assert result is None + + +class TestSchemaHasSteps: + """Tests for schema_has_steps function.""" + + def test_schema_with_steps_array(self): + """Returns True when schema has steps array property.""" + schema = {"properties": {"steps": {"type": "array"}}} + assert schema_has_steps(schema) is True + + def test_schema_without_steps(self): + """Returns False when schema doesn't have steps.""" + schema = {"properties": {"name": {"type": "string"}}} + assert schema_has_steps(schema) is False + + def test_schema_with_non_array_steps(self): + """Returns False when steps is not array type.""" + schema = {"properties": {"steps": {"type": "string"}}} + assert schema_has_steps(schema) is False + + def test_non_dict_schema(self): + """Returns False for non-dict schema.""" + assert schema_has_steps(None) is False + assert schema_has_steps("not a dict") is False + assert schema_has_steps([]) is False + + def test_missing_properties(self): + """Returns False when properties key is missing.""" + schema = {"type": "object"} + assert schema_has_steps(schema) is False + + def test_non_dict_properties(self): + """Returns False when properties is not a dict.""" + schema = {"properties": "not a dict"} + assert schema_has_steps(schema) is False + + def test_non_dict_steps(self): + """Returns False when steps is not a dict.""" + schema = {"properties": {"steps": "not a dict"}} + assert schema_has_steps(schema) is False + + +class TestSelectApprovalToolName: + """Tests for select_approval_tool_name function.""" + + def test_none_client_tools(self): + """Returns None when client_tools is None.""" + result = select_approval_tool_name(None) + assert result is None + + def test_empty_client_tools(self): + """Returns None when client_tools is empty.""" + result = select_approval_tool_name([]) + assert result is None + + def test_finds_approval_tool(self): + """Returns tool name when tool has steps schema.""" + + class MockTool: + name = "generate_task_steps" + + def parameters(self): + return {"properties": {"steps": {"type": "array"}}} + + result = select_approval_tool_name([MockTool()]) + assert result == "generate_task_steps" + + def test_skips_tool_without_name(self): + """Skips tools without name attribute.""" + + class MockToolNoName: + def parameters(self): + return {"properties": {"steps": {"type": "array"}}} + + result = select_approval_tool_name([MockToolNoName()]) + assert result is None + + def test_skips_tool_without_parameters_method(self): + """Skips tools without callable parameters method.""" + + class MockToolNoParams: + name = "some_tool" + parameters = "not callable" + + result = select_approval_tool_name([MockToolNoParams()]) + assert result is None + + def test_skips_tool_without_steps_schema(self): + """Skips tools that don't have steps in schema.""" + + class MockToolNoSteps: + name = "other_tool" + + def parameters(self): + return {"properties": {"data": {"type": "string"}}} + + result = select_approval_tool_name([MockToolNoSteps()]) + assert result is None + + +class TestBuildSafeMetadata: + """Tests for build_safe_metadata function.""" + + def test_none_metadata(self): + """Returns 
empty dict for None metadata.""" + result = build_safe_metadata(None) + assert result == {} + + def test_empty_metadata(self): + """Returns empty dict for empty metadata.""" + result = build_safe_metadata({}) + assert result == {} + + def test_string_values_under_limit(self): + """Preserves string values under 512 chars.""" + metadata = {"key1": "short value", "key2": "another value"} + result = build_safe_metadata(metadata) + assert result == metadata + + def test_truncates_long_string_values(self): + """Truncates string values over 512 chars.""" + long_value = "x" * 1000 + metadata = {"key": long_value} + result = build_safe_metadata(metadata) + assert len(result["key"]) == 512 + assert result["key"] == "x" * 512 + + def test_non_string_values_serialized(self): + """Serializes non-string values to JSON.""" + metadata = {"count": 42, "items": ["a", "b"]} + result = build_safe_metadata(metadata) + assert result["count"] == "42" + assert result["items"] == '["a", "b"]' + + def test_truncates_serialized_values(self): + """Truncates serialized JSON values over 512 chars.""" + long_list = list(range(200)) # Will serialize to >512 chars + metadata = {"data": long_list} + result = build_safe_metadata(metadata) + assert len(result["data"]) == 512 + + +class TestLatestApprovalResponse: + """Tests for latest_approval_response function.""" + + def test_empty_messages(self): + """Returns None for empty messages.""" + result = latest_approval_response([]) + assert result is None + + def test_no_approval_response(self): + """Returns None when no approval response in last message.""" + messages = [ + ChatMessage(role="assistant", contents=[Content.from_text("Hello")]), + ] + result = latest_approval_response(messages) + assert result is None + + def test_finds_approval_response(self): + """Returns approval response from last message.""" + # Create a function call content first + fc = Content.from_function_call(call_id="call_123", name="test_tool", arguments="{}") + approval_content = Content.from_function_approval_response( + approved=True, + id="approval_123", + function_call=fc, + ) + messages = [ + ChatMessage(role="user", contents=[approval_content]), + ] + result = latest_approval_response(messages) + assert result is approval_content + + +class TestApprovalSteps: + """Tests for approval_steps function.""" + + def test_steps_from_ag_ui_state_args(self): + """Extracts steps from ag_ui_state_args.""" + fc = Content.from_function_call(call_id="call_123", name="test_tool", arguments="{}") + approval = Content.from_function_approval_response( + approved=True, + id="approval_123", + function_call=fc, + additional_properties={"ag_ui_state_args": {"steps": [{"id": 1}, {"id": 2}]}}, + ) + result = approval_steps(approval) + assert result == [{"id": 1}, {"id": 2}] + + def test_steps_from_function_call(self): + """Extracts steps from function call arguments.""" + fc = Content.from_function_call( + call_id="call_123", + name="test", + arguments='{"steps": [{"step": 1}]}', + ) + approval = Content.from_function_approval_response( + approved=True, + id="approval_123", + function_call=fc, + ) + result = approval_steps(approval) + assert result == [{"step": 1}] + + def test_empty_steps_when_no_state_args(self): + """Returns empty list when no ag_ui_state_args.""" + fc = Content.from_function_call(call_id="call_123", name="test_tool", arguments="{}") + approval = Content.from_function_approval_response( + approved=True, + id="approval_123", + function_call=fc, + ) + result = approval_steps(approval) + assert result 
== [] + + def test_empty_steps_when_state_args_not_dict(self): + """Returns empty list when ag_ui_state_args is not a dict.""" + fc = Content.from_function_call(call_id="call_123", name="test_tool", arguments="{}") + approval = Content.from_function_approval_response( + approved=True, + id="approval_123", + function_call=fc, + additional_properties={"ag_ui_state_args": "not a dict"}, + ) + result = approval_steps(approval) + assert result == [] + + def test_empty_steps_when_steps_not_list(self): + """Returns empty list when steps is not a list.""" + fc = Content.from_function_call(call_id="call_123", name="test_tool", arguments="{}") + approval = Content.from_function_approval_response( + approved=True, + id="approval_123", + function_call=fc, + additional_properties={"ag_ui_state_args": {"steps": "not a list"}}, + ) + result = approval_steps(approval) + assert result == [] + + +class TestIsStepBasedApproval: + """Tests for is_step_based_approval function.""" + + def test_returns_true_when_has_steps(self): + """Returns True when approval has steps.""" + fc = Content.from_function_call(call_id="call_123", name="test_tool", arguments="{}") + approval = Content.from_function_approval_response( + approved=True, + id="approval_123", + function_call=fc, + additional_properties={"ag_ui_state_args": {"steps": [{"id": 1}]}}, + ) + result = is_step_based_approval(approval, None) + assert result is True + + def test_returns_false_no_steps_no_function_call(self): + """Returns False when no steps and no function call.""" + # Create content directly to have no function_call + approval = Content( + type="function_approval_response", + function_call=None, + ) + result = is_step_based_approval(approval, None) + assert result is False + + def test_returns_false_no_predict_config(self): + """Returns False when no predict_state_config.""" + fc = Content.from_function_call(call_id="call_123", name="some_tool", arguments="{}") + approval = Content.from_function_approval_response( + approved=True, + id="approval_123", + function_call=fc, + ) + result = is_step_based_approval(approval, None) + assert result is False + + def test_returns_true_when_tool_matches_config(self): + """Returns True when tool matches predict_state_config with steps.""" + fc = Content.from_function_call(call_id="call_123", name="generate_steps", arguments="{}") + approval = Content.from_function_approval_response( + approved=True, + id="approval_123", + function_call=fc, + ) + config = {"steps": {"tool": "generate_steps", "tool_argument": "steps"}} + result = is_step_based_approval(approval, config) + assert result is True + + def test_returns_false_when_tool_not_in_config(self): + """Returns False when tool not in predict_state_config.""" + fc = Content.from_function_call(call_id="call_123", name="other_tool", arguments="{}") + approval = Content.from_function_approval_response( + approved=True, + id="approval_123", + function_call=fc, + ) + config = {"steps": {"tool": "generate_steps", "tool_argument": "steps"}} + result = is_step_based_approval(approval, config) + assert result is False + + def test_returns_false_when_tool_arg_not_steps(self): + """Returns False when tool_argument is not 'steps'.""" + fc = Content.from_function_call(call_id="call_123", name="generate_steps", arguments="{}") + approval = Content.from_function_approval_response( + approved=True, + id="approval_123", + function_call=fc, + ) + config = {"document": {"tool": "generate_steps", "tool_argument": "content"}} + result = is_step_based_approval(approval, config) + 
assert result is False diff --git a/python/packages/ag-ui/tests/test_message_adapters.py b/python/packages/ag-ui/tests/test_message_adapters.py index 970a4fe76b..4f6c3f1d42 100644 --- a/python/packages/ag-ui/tests/test_message_adapters.py +++ b/python/packages/ag-ui/tests/test_message_adapters.py @@ -642,3 +642,109 @@ class MockTextContent: agui_msg = messages[0] # Multiple items should return JSON array assert agui_msg["content"] == '["First result", "Second result"]' + + +# Additional tests for better coverage + + +def test_extract_text_from_contents_empty(): + """Test extracting text from empty contents.""" + result = extract_text_from_contents([]) + assert result == "" + + +def test_extract_text_from_contents_multiple(): + """Test extracting text from multiple text contents.""" + contents = [ + Content.from_text("Hello "), + Content.from_text("World"), + ] + result = extract_text_from_contents(contents) + assert result == "Hello World" + + +def test_extract_text_from_contents_non_text(): + """Test extracting text ignores non-text contents.""" + contents = [ + Content.from_text("Hello"), + Content.from_function_call(call_id="call_1", name="tool", arguments="{}"), + ] + result = extract_text_from_contents(contents) + assert result == "Hello" + + +def test_agui_to_agent_framework_with_tool_calls(): + """Test converting AG-UI message with tool_calls.""" + messages = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": {"name": "get_weather", "arguments": '{"city": "NYC"}'}, + } + ], + } + ] + + result = agui_messages_to_agent_framework(messages) + + assert len(result) == 1 + assert len(result[0].contents) == 1 + assert result[0].contents[0].type == "function_call" + assert result[0].contents[0].name == "get_weather" + + +def test_agui_to_agent_framework_tool_result(): + """Test converting AG-UI tool result message.""" + messages = [ + { + "role": "assistant", + "content": "", + "tool_calls": [ + { + "id": "call_123", + "type": "function", + "function": {"name": "get_weather", "arguments": "{}"}, + } + ], + }, + { + "role": "tool", + "content": "Sunny", + "toolCallId": "call_123", + }, + ] + + result = agui_messages_to_agent_framework(messages) + + assert len(result) == 2 + # Second message should be tool result + tool_msg = result[1] + assert tool_msg.role == Role.TOOL + assert tool_msg.contents[0].type == "function_result" + assert tool_msg.contents[0].result == "Sunny" + + +def test_agui_messages_to_snapshot_format_empty(): + """Test converting empty messages to snapshot format.""" + result = agui_messages_to_snapshot_format([]) + assert result == [] + + +def test_agui_messages_to_snapshot_format_basic(): + """Test converting messages to snapshot format.""" + messages = [ + {"role": "user", "content": "Hello", "id": "msg_1"}, + {"role": "assistant", "content": "Hi there", "id": "msg_2"}, + ] + + result = agui_messages_to_snapshot_format(messages) + + assert len(result) == 2 + assert result[0]["role"] == "user" + assert result[0]["content"] == "Hello" + assert result[1]["role"] == "assistant" + assert result[1]["content"] == "Hi there" diff --git a/python/packages/ag-ui/tests/test_predictive_state.py b/python/packages/ag-ui/tests/test_predictive_state.py new file mode 100644 index 0000000000..59113bf05c --- /dev/null +++ b/python/packages/ag-ui/tests/test_predictive_state.py @@ -0,0 +1,363 @@ +# Copyright (c) Microsoft. All rights reserved. 
+ +"""Tests for predictive state handling.""" + +import pytest +from ag_ui.core import StateDeltaEvent + +from agent_framework_ag_ui._orchestration._predictive_state import PredictiveStateHandler + + +class TestPredictiveStateHandlerInit: + """Tests for PredictiveStateHandler initialization.""" + + def test_default_init(self): + """Initializes with default values.""" + handler = PredictiveStateHandler() + assert handler.predict_state_config == {} + assert handler.current_state == {} + assert handler.streaming_tool_args == "" + assert handler.last_emitted_state == {} + assert handler.state_delta_count == 0 + assert handler.pending_state_updates == {} + + def test_init_with_config(self): + """Initializes with provided config.""" + config = {"document": {"tool": "write_doc", "tool_argument": "content"}} + state = {"document": "initial"} + handler = PredictiveStateHandler(predict_state_config=config, current_state=state) + assert handler.predict_state_config == config + assert handler.current_state == state + + +class TestResetStreaming: + """Tests for reset_streaming method.""" + + def test_resets_streaming_state(self): + """Resets streaming-related state.""" + handler = PredictiveStateHandler() + handler.streaming_tool_args = "some accumulated args" + handler.state_delta_count = 5 + + handler.reset_streaming() + + assert handler.streaming_tool_args == "" + assert handler.state_delta_count == 0 + + +class TestExtractStateValue: + """Tests for extract_state_value method.""" + + def test_no_config(self): + """Returns None when no config.""" + handler = PredictiveStateHandler() + result = handler.extract_state_value("some_tool", {"arg": "value"}) + assert result is None + + def test_no_args(self): + """Returns None when args is None.""" + handler = PredictiveStateHandler( + predict_state_config={"key": {"tool": "tool", "tool_argument": "arg"}} + ) + result = handler.extract_state_value("tool", None) + assert result is None + + def test_empty_args(self): + """Returns None when args is empty string.""" + handler = PredictiveStateHandler( + predict_state_config={"key": {"tool": "tool", "tool_argument": "arg"}} + ) + result = handler.extract_state_value("tool", "") + assert result is None + + def test_tool_not_in_config(self): + """Returns None when tool not in config.""" + handler = PredictiveStateHandler( + predict_state_config={"key": {"tool": "other_tool", "tool_argument": "arg"}} + ) + result = handler.extract_state_value("some_tool", {"arg": "value"}) + assert result is None + + def test_extracts_specific_argument(self): + """Extracts value from specific tool argument.""" + handler = PredictiveStateHandler( + predict_state_config={"document": {"tool": "write_doc", "tool_argument": "content"}} + ) + result = handler.extract_state_value("write_doc", {"content": "Hello world"}) + assert result == ("document", "Hello world") + + def test_extracts_with_wildcard(self): + """Extracts entire args with * wildcard.""" + handler = PredictiveStateHandler( + predict_state_config={"data": {"tool": "update_data", "tool_argument": "*"}} + ) + args = {"key1": "value1", "key2": "value2"} + result = handler.extract_state_value("update_data", args) + assert result == ("data", args) + + def test_extracts_from_json_string(self): + """Extracts value from JSON string args.""" + handler = PredictiveStateHandler( + predict_state_config={"document": {"tool": "write_doc", "tool_argument": "content"}} + ) + result = handler.extract_state_value("write_doc", '{"content": "Hello world"}') + assert result == ("document", 
"Hello world") + + def test_argument_not_in_args(self): + """Returns None when tool_argument not in args.""" + handler = PredictiveStateHandler( + predict_state_config={"document": {"tool": "write_doc", "tool_argument": "content"}} + ) + result = handler.extract_state_value("write_doc", {"other": "value"}) + assert result is None + + +class TestIsPredictiveTool: + """Tests for is_predictive_tool method.""" + + def test_none_tool_name(self): + """Returns False for None tool name.""" + handler = PredictiveStateHandler( + predict_state_config={"key": {"tool": "some_tool", "tool_argument": "arg"}} + ) + assert handler.is_predictive_tool(None) is False + + def test_no_config(self): + """Returns False when no config.""" + handler = PredictiveStateHandler() + assert handler.is_predictive_tool("some_tool") is False + + def test_tool_in_config(self): + """Returns True when tool is in config.""" + handler = PredictiveStateHandler( + predict_state_config={"key": {"tool": "some_tool", "tool_argument": "arg"}} + ) + assert handler.is_predictive_tool("some_tool") is True + + def test_tool_not_in_config(self): + """Returns False when tool not in config.""" + handler = PredictiveStateHandler( + predict_state_config={"key": {"tool": "other_tool", "tool_argument": "arg"}} + ) + assert handler.is_predictive_tool("some_tool") is False + + +class TestEmitStreamingDeltas: + """Tests for emit_streaming_deltas method.""" + + def test_no_tool_name(self): + """Returns empty list for None tool name.""" + handler = PredictiveStateHandler( + predict_state_config={"key": {"tool": "tool", "tool_argument": "arg"}} + ) + result = handler.emit_streaming_deltas(None, '{"arg": "value"}') + assert result == [] + + def test_no_config(self): + """Returns empty list when no config.""" + handler = PredictiveStateHandler() + result = handler.emit_streaming_deltas("some_tool", '{"arg": "value"}') + assert result == [] + + def test_accumulates_args(self): + """Accumulates argument chunks.""" + handler = PredictiveStateHandler( + predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} + ) + handler.emit_streaming_deltas("write", '{"text') + handler.emit_streaming_deltas("write", '": "hello') + assert handler.streaming_tool_args == '{"text": "hello' + + def test_emits_delta_on_complete_json(self): + """Emits delta when JSON is complete.""" + handler = PredictiveStateHandler( + predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} + ) + events = handler.emit_streaming_deltas("write", '{"text": "hello"}') + assert len(events) == 1 + assert isinstance(events[0], StateDeltaEvent) + assert events[0].delta[0]["path"] == "/doc" + assert events[0].delta[0]["value"] == "hello" + assert events[0].delta[0]["op"] == "replace" + + def test_emits_delta_on_partial_json(self): + """Emits delta from partial JSON using regex.""" + handler = PredictiveStateHandler( + predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} + ) + # First chunk - partial + events = handler.emit_streaming_deltas("write", '{"text": "hel') + assert len(events) == 1 + assert events[0].delta[0]["value"] == "hel" + + def test_does_not_emit_duplicate_deltas(self): + """Does not emit delta when value unchanged.""" + handler = PredictiveStateHandler( + predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} + ) + # First emission + events1 = handler.emit_streaming_deltas("write", '{"text": "hello"}') + assert len(events1) == 1 + + # Reset and emit same value again + handler.streaming_tool_args = "" + events2 = 
handler.emit_streaming_deltas("write", '{"text": "hello"}') + assert len(events2) == 0 # No duplicate + + def test_emits_delta_on_value_change(self): + """Emits delta when value changes.""" + handler = PredictiveStateHandler( + predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} + ) + # First value + events1 = handler.emit_streaming_deltas("write", '{"text": "hello"}') + assert len(events1) == 1 + + # Reset and new value + handler.streaming_tool_args = "" + events2 = handler.emit_streaming_deltas("write", '{"text": "world"}') + assert len(events2) == 1 + assert events2[0].delta[0]["value"] == "world" + + def test_tracks_pending_updates(self): + """Tracks pending state updates.""" + handler = PredictiveStateHandler( + predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} + ) + handler.emit_streaming_deltas("write", '{"text": "hello"}') + assert handler.pending_state_updates == {"doc": "hello"} + + +class TestEmitPartialDeltas: + """Tests for _emit_partial_deltas method.""" + + def test_unescapes_newlines(self): + """Unescapes \\n in partial values.""" + handler = PredictiveStateHandler( + predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} + ) + handler.streaming_tool_args = '{"text": "line1\\nline2' + events = handler._emit_partial_deltas("write") + assert len(events) == 1 + assert events[0].delta[0]["value"] == "line1\nline2" + + def test_handles_escaped_quotes_partially(self): + """Handles escaped quotes - regex stops at quote character.""" + handler = PredictiveStateHandler( + predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} + ) + # The regex pattern [^"]* stops at ANY quote, including escaped ones. + # This is expected behavior for partial streaming - the full JSON + # will be parsed correctly when complete. 
+ handler.streaming_tool_args = '{"text": "say \\"hi' + events = handler._emit_partial_deltas("write") + assert len(events) == 1 + # Captures "say \" then the backslash gets converted to empty string + # by the replace("\\\\", "\\") first, then replace('\\"', '"') + # but since there's no closing quote, we get "say \" + # After .replace("\\\\", "\\") -> "say \" + # After .replace('\\"', '"') -> "say " (but actually still "say \" due to order) + # The actual result: backslash is preserved since it's not a valid escape sequence + assert events[0].delta[0]["value"] == "say \\" + + def test_unescapes_backslashes(self): + """Unescapes \\\\ in partial values.""" + handler = PredictiveStateHandler( + predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} + ) + handler.streaming_tool_args = '{"text": "path\\\\to\\\\file' + events = handler._emit_partial_deltas("write") + assert len(events) == 1 + assert events[0].delta[0]["value"] == "path\\to\\file" + + +class TestEmitCompleteDeltas: + """Tests for _emit_complete_deltas method.""" + + def test_emits_for_matching_tool(self): + """Emits delta for tool matching config.""" + handler = PredictiveStateHandler( + predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} + ) + events = handler._emit_complete_deltas("write", {"text": "content"}) + assert len(events) == 1 + assert events[0].delta[0]["value"] == "content" + + def test_skips_non_matching_tool(self): + """Skips tools not matching config.""" + handler = PredictiveStateHandler( + predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} + ) + events = handler._emit_complete_deltas("other_tool", {"text": "content"}) + assert len(events) == 0 + + def test_handles_wildcard_argument(self): + """Handles * wildcard for entire args.""" + handler = PredictiveStateHandler( + predict_state_config={"data": {"tool": "update", "tool_argument": "*"}} + ) + args = {"key1": "val1", "key2": "val2"} + events = handler._emit_complete_deltas("update", args) + assert len(events) == 1 + assert events[0].delta[0]["value"] == args + + def test_skips_missing_argument(self): + """Skips when tool_argument not in args.""" + handler = PredictiveStateHandler( + predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} + ) + events = handler._emit_complete_deltas("write", {"other": "value"}) + assert len(events) == 0 + + +class TestCreateDeltaEvent: + """Tests for _create_delta_event method.""" + + def test_creates_event(self): + """Creates StateDeltaEvent with correct structure.""" + handler = PredictiveStateHandler() + event = handler._create_delta_event("key", "value") + + assert isinstance(event, StateDeltaEvent) + assert event.delta[0]["op"] == "replace" + assert event.delta[0]["path"] == "/key" + assert event.delta[0]["value"] == "value" + + def test_increments_count(self): + """Increments state_delta_count.""" + handler = PredictiveStateHandler() + handler._create_delta_event("key", "value") + assert handler.state_delta_count == 1 + handler._create_delta_event("key", "value2") + assert handler.state_delta_count == 2 + + +class TestApplyPendingUpdates: + """Tests for apply_pending_updates method.""" + + def test_applies_pending_to_current(self): + """Applies pending updates to current state.""" + handler = PredictiveStateHandler(current_state={"existing": "value"}) + handler.pending_state_updates = {"doc": "new content", "count": 5} + + handler.apply_pending_updates() + + assert handler.current_state == {"existing": "value", "doc": "new content", "count": 
5} + + def test_clears_pending_updates(self): + """Clears pending updates after applying.""" + handler = PredictiveStateHandler() + handler.pending_state_updates = {"doc": "content"} + + handler.apply_pending_updates() + + assert handler.pending_state_updates == {} + + def test_overwrites_existing_keys(self): + """Overwrites existing keys in current state.""" + handler = PredictiveStateHandler(current_state={"doc": "old"}) + handler.pending_state_updates = {"doc": "new"} + + handler.apply_pending_updates() + + assert handler.current_state["doc"] == "new" diff --git a/python/packages/ag-ui/tests/test_run.py b/python/packages/ag-ui/tests/test_run.py new file mode 100644 index 0000000000..0f842325a6 --- /dev/null +++ b/python/packages/ag-ui/tests/test_run.py @@ -0,0 +1,378 @@ +# Copyright (c) Microsoft. All rights reserved. + +"""Tests for _run.py helper functions and FlowState.""" + +import pytest +from agent_framework import ChatMessage, Content + +from agent_framework_ag_ui._run import ( + FlowState, + _build_safe_metadata, + _create_state_context_message, + _has_only_tool_calls, + _inject_state_context, + _should_suppress_intermediate_snapshot, +) + + +class TestBuildSafeMetadata: + """Tests for _build_safe_metadata function.""" + + def test_none_metadata(self): + """Returns empty dict for None.""" + result = _build_safe_metadata(None) + assert result == {} + + def test_empty_metadata(self): + """Returns empty dict for empty dict.""" + result = _build_safe_metadata({}) + assert result == {} + + def test_short_string_values(self): + """Preserves short string values.""" + metadata = {"key1": "short", "key2": "value"} + result = _build_safe_metadata(metadata) + assert result == metadata + + def test_truncates_long_strings(self): + """Truncates strings over 512 chars.""" + long_value = "x" * 1000 + metadata = {"key": long_value} + result = _build_safe_metadata(metadata) + assert len(result["key"]) == 512 + + def test_serializes_non_strings(self): + """Serializes non-string values to JSON.""" + metadata = {"count": 42, "items": [1, 2, 3]} + result = _build_safe_metadata(metadata) + assert result["count"] == "42" + assert result["items"] == "[1, 2, 3]" + + def test_truncates_serialized_values(self): + """Truncates serialized values over 512 chars.""" + long_list = list(range(200)) + metadata = {"data": long_list} + result = _build_safe_metadata(metadata) + assert len(result["data"]) == 512 + + +class TestHasOnlyToolCalls: + """Tests for _has_only_tool_calls function.""" + + def test_only_tool_calls(self): + """Returns True when only function_call content.""" + contents = [ + Content.from_function_call(call_id="call_1", name="tool1", arguments="{}"), + ] + assert _has_only_tool_calls(contents) is True + + def test_tool_call_with_text(self): + """Returns False when both tool call and text.""" + contents = [ + Content.from_text("Some text"), + Content.from_function_call(call_id="call_1", name="tool1", arguments="{}"), + ] + assert _has_only_tool_calls(contents) is False + + def test_only_text(self): + """Returns False when only text.""" + contents = [Content.from_text("Just text")] + assert _has_only_tool_calls(contents) is False + + def test_empty_contents(self): + """Returns False for empty contents.""" + assert _has_only_tool_calls([]) is False + + def test_tool_call_with_empty_text(self): + """Returns True when text content has empty text.""" + contents = [ + Content.from_text(""), + Content.from_function_call(call_id="call_1", name="tool1", arguments="{}"), + ] + assert 
_has_only_tool_calls(contents) is True + + +class TestShouldSuppressIntermediateSnapshot: + """Tests for _should_suppress_intermediate_snapshot function.""" + + def test_no_tool_name(self): + """Returns False when no tool name.""" + result = _should_suppress_intermediate_snapshot( + None, {"key": {"tool": "write_doc", "tool_argument": "content"}}, False + ) + assert result is False + + def test_no_config(self): + """Returns False when no config.""" + result = _should_suppress_intermediate_snapshot("write_doc", None, False) + assert result is False + + def test_confirmation_required(self): + """Returns False when confirmation is required.""" + config = {"key": {"tool": "write_doc", "tool_argument": "content"}} + result = _should_suppress_intermediate_snapshot("write_doc", config, True) + assert result is False + + def test_tool_not_in_config(self): + """Returns False when tool not in config.""" + config = {"key": {"tool": "other_tool", "tool_argument": "content"}} + result = _should_suppress_intermediate_snapshot("write_doc", config, False) + assert result is False + + def test_suppresses_predictive_tool(self): + """Returns True for predictive tool without confirmation.""" + config = {"document": {"tool": "write_doc", "tool_argument": "content"}} + result = _should_suppress_intermediate_snapshot("write_doc", config, False) + assert result is True + + +class TestFlowState: + """Tests for FlowState dataclass.""" + + def test_default_values(self): + """Tests default initialization.""" + flow = FlowState() + assert flow.message_id is None + assert flow.tool_call_id is None + assert flow.tool_call_name is None + assert flow.waiting_for_approval is False + assert flow.current_state == {} + assert flow.accumulated_text == "" + assert flow.pending_tool_calls == [] + assert flow.tool_calls_by_id == {} + assert flow.tool_results == [] + assert flow.tool_calls_ended == set() + + def test_get_tool_name(self): + """Tests get_tool_name method.""" + flow = FlowState() + flow.tool_calls_by_id = { + "call_123": {"function": {"name": "get_weather", "arguments": "{}"}} + } + + assert flow.get_tool_name("call_123") == "get_weather" + assert flow.get_tool_name("nonexistent") is None + assert flow.get_tool_name(None) is None + + def test_get_tool_name_empty_name(self): + """Tests get_tool_name with empty name.""" + flow = FlowState() + flow.tool_calls_by_id = {"call_123": {"function": {"name": "", "arguments": "{}"}}} + + assert flow.get_tool_name("call_123") is None + + def test_get_pending_without_end(self): + """Tests get_pending_without_end method.""" + flow = FlowState() + flow.pending_tool_calls = [ + {"id": "call_1", "function": {"name": "tool1"}}, + {"id": "call_2", "function": {"name": "tool2"}}, + {"id": "call_3", "function": {"name": "tool3"}}, + ] + flow.tool_calls_ended = {"call_1", "call_3"} + + result = flow.get_pending_without_end() + assert len(result) == 1 + assert result[0]["id"] == "call_2" + + +class TestCreateStateContextMessage: + """Tests for _create_state_context_message function.""" + + def test_no_state(self): + """Returns None when no state.""" + result = _create_state_context_message({}, {"properties": {}}) + assert result is None + + def test_no_schema(self): + """Returns None when no schema.""" + result = _create_state_context_message({"key": "value"}, {}) + assert result is None + + def test_creates_message(self): + """Creates state context message.""" + from agent_framework import Role + + state = {"document": "Hello world"} + schema = {"properties": {"document": {"type": 
"string"}}} + + result = _create_state_context_message(state, schema) + + assert result is not None + assert result.role == Role.SYSTEM + assert len(result.contents) == 1 + assert "Hello world" in result.contents[0].text + assert "Current state" in result.contents[0].text + + +class TestInjectStateContext: + """Tests for _inject_state_context function.""" + + def test_no_state_message(self): + """Returns original messages when no state context needed.""" + messages = [ChatMessage(role="user", contents=[Content.from_text("Hello")])] + result = _inject_state_context(messages, {}, {}) + assert result == messages + + def test_empty_messages(self): + """Returns empty list for empty messages.""" + result = _inject_state_context([], {"key": "value"}, {"properties": {}}) + assert result == [] + + def test_last_message_not_user(self): + """Returns original messages when last message is not from user.""" + messages = [ + ChatMessage(role="user", contents=[Content.from_text("Hello")]), + ChatMessage(role="assistant", contents=[Content.from_text("Hi")]), + ] + state = {"key": "value"} + schema = {"properties": {"key": {"type": "string"}}} + + result = _inject_state_context(messages, state, schema) + assert result == messages + + def test_injects_before_last_user_message(self): + """Injects state context before last user message.""" + from agent_framework import Role + + messages = [ + ChatMessage(role="system", contents=[Content.from_text("You are helpful")]), + ChatMessage(role="user", contents=[Content.from_text("Hello")]), + ] + state = {"document": "content"} + schema = {"properties": {"document": {"type": "string"}}} + + result = _inject_state_context(messages, state, schema) + + assert len(result) == 3 + # System message first + assert result[0].role == Role.SYSTEM + assert "helpful" in result[0].contents[0].text + # State context second + assert result[1].role == Role.SYSTEM + assert "Current state" in result[1].contents[0].text + # User message last + assert result[2].role == Role.USER + assert "Hello" in result[2].contents[0].text + + +# Additional tests for _run.py functions + + +def test_emit_text_basic(): + """Test _emit_text emits correct events.""" + from agent_framework_ag_ui._run import _emit_text + + flow = FlowState() + content = Content.from_text("Hello world") + + events = _emit_text(content, flow) + + assert len(events) == 2 # TextMessageStartEvent + TextMessageContentEvent + assert flow.message_id is not None + assert flow.accumulated_text == "Hello world" + + +def test_emit_text_skip_empty(): + """Test _emit_text skips empty text.""" + from agent_framework_ag_ui._run import _emit_text + + flow = FlowState() + content = Content.from_text("") + + events = _emit_text(content, flow) + + assert len(events) == 0 + + +def test_emit_text_continues_existing_message(): + """Test _emit_text continues existing message.""" + from agent_framework_ag_ui._run import _emit_text + + flow = FlowState() + flow.message_id = "existing-id" + content = Content.from_text("more text") + + events = _emit_text(content, flow) + + assert len(events) == 1 # Only TextMessageContentEvent, no new start + assert flow.message_id == "existing-id" + + +def test_emit_text_skips_when_waiting_for_approval(): + """Test _emit_text skips when waiting for approval.""" + from agent_framework_ag_ui._run import _emit_text + + flow = FlowState() + flow.waiting_for_approval = True + content = Content.from_text("should skip") + + events = _emit_text(content, flow) + + assert len(events) == 0 + + +def 
test_emit_text_skips_when_skip_text_flag(): + """Test _emit_text skips with skip_text flag.""" + from agent_framework_ag_ui._run import _emit_text + + flow = FlowState() + content = Content.from_text("should skip") + + events = _emit_text(content, flow, skip_text=True) + + assert len(events) == 0 + + +def test_emit_tool_call_basic(): + """Test _emit_tool_call emits correct events.""" + from agent_framework_ag_ui._run import _emit_tool_call + + flow = FlowState() + content = Content.from_function_call( + call_id="call_123", + name="get_weather", + arguments='{"city": "NYC"}', + ) + + events = _emit_tool_call(content, flow) + + assert len(events) >= 1 # At least ToolCallStartEvent + assert flow.tool_call_id == "call_123" + assert flow.tool_call_name == "get_weather" + + +def test_emit_tool_call_generates_id(): + """Test _emit_tool_call generates ID when not provided.""" + from agent_framework_ag_ui._run import _emit_tool_call + + flow = FlowState() + # Create content without call_id + content = Content(type="function_call", name="test_tool", arguments="{}") + + events = _emit_tool_call(content, flow) + + assert len(events) >= 1 + assert flow.tool_call_id is not None # ID should be generated + + +def test_extract_approved_state_updates_no_handler(): + """Test _extract_approved_state_updates returns empty with no handler.""" + from agent_framework_ag_ui._run import _extract_approved_state_updates + + messages = [ChatMessage(role="user", contents=[Content.from_text("Hello")])] + result = _extract_approved_state_updates(messages, None) + assert result == {} + + +def test_extract_approved_state_updates_no_approval(): + """Test _extract_approved_state_updates returns empty when no approval content.""" + from agent_framework_ag_ui._orchestration._predictive_state import PredictiveStateHandler + from agent_framework_ag_ui._run import _extract_approved_state_updates + + handler = PredictiveStateHandler( + predict_state_config={"doc": {"tool": "write", "tool_argument": "content"}} + ) + messages = [ChatMessage(role="user", contents=[Content.from_text("Hello")])] + result = _extract_approved_state_updates(messages, handler) + assert result == {} diff --git a/python/packages/ag-ui/tests/test_tooling.py b/python/packages/ag-ui/tests/test_tooling.py index 23d82dda90..b8c9700cd4 100644 --- a/python/packages/ag-ui/tests/test_tooling.py +++ b/python/packages/ag-ui/tests/test_tooling.py @@ -129,3 +129,95 @@ def test_collect_server_tools_with_mcp_tools_via_public_property() -> None: assert "regular_tool" in names assert "mcp_function" in names assert len(tools) == 2 + + +# Additional tests for tooling coverage + + +def test_collect_server_tools_no_default_options() -> None: + """collect_server_tools returns empty list when agent has no default_options.""" + + class MockAgent: + pass + + agent = MockAgent() + tools = collect_server_tools(agent) + assert tools == [] + + +def test_register_additional_client_tools_no_tools() -> None: + """register_additional_client_tools does nothing with None tools.""" + mock_chat_client = MagicMock() + agent = ChatAgent(chat_client=mock_chat_client) + + # Should not raise + register_additional_client_tools(agent, None) + + +def test_register_additional_client_tools_no_chat_client() -> None: + """register_additional_client_tools does nothing when agent has no chat_client.""" + from agent_framework_ag_ui._orchestration._tooling import register_additional_client_tools + + class MockAgent: + pass + + agent = MockAgent() + tools = [DummyTool("x")] + + # Should not raise + 
register_additional_client_tools(agent, tools) + + +def test_merge_tools_no_client_tools() -> None: + """merge_tools returns None when no client tools.""" + server = [DummyTool("a")] + result = merge_tools(server, None) + assert result is None + + +def test_merge_tools_all_duplicates() -> None: + """merge_tools returns None when all client tools duplicate server tools.""" + server = [DummyTool("a"), DummyTool("b")] + client = [DummyTool("a"), DummyTool("b")] + result = merge_tools(server, client) + assert result is None + + +def test_merge_tools_empty_server() -> None: + """merge_tools works with empty server tools.""" + server: list = [] + client = [DummyTool("a"), DummyTool("b")] + result = merge_tools(server, client) + assert result is not None + assert len(result) == 2 + + +def test_merge_tools_with_approval_tools_no_client() -> None: + """merge_tools returns server tools when they have approval mode even without client tools.""" + + class ApprovalTool: + def __init__(self, name: str): + self.name = name + self.approval_mode = "always_require" + + server = [ApprovalTool("write_doc")] + result = merge_tools(server, None) + assert result is not None + assert len(result) == 1 + assert result[0].name == "write_doc" + + +def test_merge_tools_with_approval_tools_all_duplicates() -> None: + """merge_tools returns server tools with approval mode even when client duplicates.""" + + class ApprovalTool: + def __init__(self, name: str): + self.name = name + self.approval_mode = "always_require" + + server = [ApprovalTool("write_doc")] + client = [DummyTool("write_doc")] # Same name as server + result = merge_tools(server, client) + assert result is not None + assert len(result) == 1 + assert result[0].approval_mode == "always_require" diff --git a/python/packages/ag-ui/tests/test_utils.py b/python/packages/ag-ui/tests/test_utils.py index b077468b81..4c743bd34f 100644 --- a/python/packages/ag-ui/tests/test_utils.py +++ b/python/packages/ag-ui/tests/test_utils.py @@ -307,3 +307,173 @@ def tool2(y: str) -> str: assert len(result) == 2 assert result[0]["name"] == "tool1" assert result[1]["name"] == "tool2" + + +# Additional tests for utils coverage + + +def test_safe_json_parse_with_dict(): + """Test safe_json_parse with dict input.""" + from agent_framework_ag_ui._utils import safe_json_parse + + input_dict = {"key": "value"} + result = safe_json_parse(input_dict) + assert result == input_dict + + +def test_safe_json_parse_with_json_string(): + """Test safe_json_parse with JSON string.""" + from agent_framework_ag_ui._utils import safe_json_parse + + result = safe_json_parse('{"key": "value"}') + assert result == {"key": "value"} + + +def test_safe_json_parse_with_invalid_json(): + """Test safe_json_parse with invalid JSON.""" + from agent_framework_ag_ui._utils import safe_json_parse + + result = safe_json_parse("not json") + assert result is None + + +def test_safe_json_parse_with_non_dict_json(): + """Test safe_json_parse with JSON that parses to non-dict.""" + from agent_framework_ag_ui._utils import safe_json_parse + + result = safe_json_parse("[1, 2, 3]") + assert result is None + + +def test_safe_json_parse_with_none(): + """Test safe_json_parse with None input.""" + from agent_framework_ag_ui._utils import safe_json_parse + + result = safe_json_parse(None) + assert result is None + + +def test_get_role_value_with_enum(): + """Test get_role_value with enum role.""" + from agent_framework import ChatMessage, Content, Role + + from agent_framework_ag_ui._utils import get_role_value + + message = 
ChatMessage(role=Role.USER, contents=[Content.from_text("test")]) + result = get_role_value(message) + assert result == "user" + + +def test_get_role_value_with_string(): + """Test get_role_value with string role.""" + from agent_framework_ag_ui._utils import get_role_value + + class MockMessage: + role = "assistant" + + result = get_role_value(MockMessage()) + assert result == "assistant" + + +def test_get_role_value_with_none(): + """Test get_role_value with no role.""" + from agent_framework_ag_ui._utils import get_role_value + + class MockMessage: + pass + + result = get_role_value(MockMessage()) + assert result == "" + + +def test_normalize_agui_role_developer(): + """Test normalize_agui_role maps developer to system.""" + from agent_framework_ag_ui._utils import normalize_agui_role + + assert normalize_agui_role("developer") == "system" + + +def test_normalize_agui_role_valid(): + """Test normalize_agui_role with valid roles.""" + from agent_framework_ag_ui._utils import normalize_agui_role + + assert normalize_agui_role("user") == "user" + assert normalize_agui_role("assistant") == "assistant" + assert normalize_agui_role("system") == "system" + assert normalize_agui_role("tool") == "tool" + + +def test_normalize_agui_role_invalid(): + """Test normalize_agui_role with invalid role defaults to user.""" + from agent_framework_ag_ui._utils import normalize_agui_role + + assert normalize_agui_role("invalid") == "user" + assert normalize_agui_role(123) == "user" + + +def test_extract_state_from_tool_args(): + """Test extract_state_from_tool_args.""" + from agent_framework_ag_ui._utils import extract_state_from_tool_args + + # Specific key + assert extract_state_from_tool_args({"key": "value"}, "key") == "value" + + # Wildcard + args = {"a": 1, "b": 2} + assert extract_state_from_tool_args(args, "*") == args + + # Missing key + assert extract_state_from_tool_args({"other": "value"}, "key") is None + + # None args + assert extract_state_from_tool_args(None, "key") is None + + +def test_convert_agui_tools_to_agent_framework(): + """Test convert_agui_tools_to_agent_framework.""" + from agent_framework_ag_ui._utils import convert_agui_tools_to_agent_framework + + agui_tools = [ + { + "name": "test_tool", + "description": "A test tool", + "parameters": {"type": "object", "properties": {"arg": {"type": "string"}}}, + } + ] + + result = convert_agui_tools_to_agent_framework(agui_tools) + + assert result is not None + assert len(result) == 1 + assert result[0].name == "test_tool" + assert result[0].description == "A test tool" + assert result[0].declaration_only is True + + +def test_convert_agui_tools_to_agent_framework_none(): + """Test convert_agui_tools_to_agent_framework with None.""" + from agent_framework_ag_ui._utils import convert_agui_tools_to_agent_framework + + result = convert_agui_tools_to_agent_framework(None) + assert result is None + + +def test_convert_agui_tools_to_agent_framework_empty(): + """Test convert_agui_tools_to_agent_framework with empty list.""" + from agent_framework_ag_ui._utils import convert_agui_tools_to_agent_framework + + result = convert_agui_tools_to_agent_framework([]) + assert result is None + + +def test_make_json_safe_unconvertible(): + """Test make_json_safe with object that has no standard conversion.""" + + class NoConversion: + __slots__ = () # No __dict__ + + from agent_framework_ag_ui._utils import make_json_safe + + result = make_json_safe(NoConversion()) + # Falls back to str() + assert isinstance(result, str) From 
e6bdeb0b987789bdf95573fbcc158562be0d0b62 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Wed, 21 Jan 2026 15:13:52 +0900 Subject: [PATCH 5/8] Improvements --- .../ag-ui/agent_framework_ag_ui/_run.py | 31 +++++++----- .../tests/test_agent_wrapper_comprehensive.py | 49 +++++++++++++++---- .../agent_framework_anthropic/_chat_client.py | 5 +- .../packages/core/agent_framework/_agents.py | 6 ++- .../packages/core/agent_framework/_tools.py | 2 +- 5 files changed, 68 insertions(+), 25 deletions(-) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_run.py index 554a044f39..afcf12bd86 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_run.py @@ -30,6 +30,13 @@ Content, prepare_function_call_results, ) +from agent_framework._middleware import extract_and_merge_function_middleware +from agent_framework._tools import ( + FunctionInvocationConfiguration, + _collect_approval_responses, # type: ignore + _replace_approval_contents_with_results, # type: ignore + _try_execute_function_calls, # type: ignore +) from ._message_adapters import normalize_agui_input_messages from ._orchestration._predictive_state import PredictiveStateHandler @@ -43,6 +50,9 @@ logger = logging.getLogger(__name__) +# Keys that are internal to AG-UI orchestration and should not be passed to chat clients +AG_UI_INTERNAL_METADATA_KEYS = {"ag_ui_thread_id", "ag_ui_run_id", "current_state"} + def _build_safe_metadata(thread_metadata: dict[str, Any] | None) -> dict[str, Any]: """Build metadata dict with truncated string values for Azure compatibility. @@ -546,14 +556,6 @@ async def _resolve_approval_responses( agent: The agent instance (to get chat_client and config) run_kwargs: Kwargs for tool execution """ - from agent_framework._middleware import extract_and_merge_function_middleware - from agent_framework._tools import ( - FunctionInvocationConfiguration, - _collect_approval_responses, - _replace_approval_contents_with_results, - _try_execute_function_calls, - ) - fcc_todo = _collect_approval_responses(messages) if not fcc_todo: return @@ -579,8 +581,8 @@ async def _resolve_approval_responses( config=config, ) approved_function_results = list(results) - except Exception: - logger.error("Failed to execute approved tool calls; injecting error results.") + except Exception as e: + logger.exception("Failed to execute approved tool calls; injecting error results: %s", e) approved_function_results = [] # Build normalized results for approved responses @@ -735,7 +737,14 @@ async def run_agent_stream( run_kwargs: dict[str, Any] = {"thread": thread} if tools: run_kwargs["tools"] = tools - safe_metadata = _build_safe_metadata(thread.metadata) # type: ignore[attr-defined] + # Filter out AG-UI internal metadata keys before passing to chat client + # These are used internally for orchestration and should not be sent to the LLM provider + client_metadata = { + k: v + for k, v in (thread.metadata or {}).items() + if k not in AG_UI_INTERNAL_METADATA_KEYS # type: ignore[attr-defined] + } + safe_metadata = _build_safe_metadata(client_metadata) if client_metadata else {} if safe_metadata: run_kwargs["options"] = {"metadata": safe_metadata, "store": True} diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index 3adfe494cc..2add81a7d9 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ 
b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -420,17 +420,25 @@ async def stream_fn( async def test_thread_metadata_tracking(): - """Test that thread metadata includes ag_ui_thread_id and ag_ui_run_id.""" + """Test that thread metadata includes ag_ui_thread_id and ag_ui_run_id. + + AG-UI internal metadata is stored in thread.metadata for orchestration, + but filtered out before passing to the chat client's options.metadata. + """ from agent_framework.ag_ui import AgentFrameworkAgent - thread_metadata: dict[str, Any] = {} + captured_thread: dict[str, Any] = {} + captured_options: dict[str, Any] = {} async def stream_fn( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - metadata = options.get("metadata") - if metadata: - thread_metadata.update(metadata) + # Capture the thread object from kwargs + thread = kwargs.get("thread") + if thread and hasattr(thread, "metadata"): + captured_thread["metadata"] = thread.metadata + # Capture options to verify internal keys are NOT passed to chat client + captured_options.update(options) yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) @@ -446,22 +454,37 @@ async def stream_fn( async for event in wrapper.run_agent(input_data): events.append(event) + # AG-UI internal metadata should be stored in thread.metadata + thread_metadata = captured_thread.get("metadata", {}) assert thread_metadata.get("ag_ui_thread_id") == "test_thread_123" assert thread_metadata.get("ag_ui_run_id") == "test_run_456" + # Internal metadata should NOT be passed to chat client options + options_metadata = captured_options.get("metadata", {}) + assert "ag_ui_thread_id" not in options_metadata + assert "ag_ui_run_id" not in options_metadata + async def test_state_context_injection(): - """Test that current state is injected into thread metadata.""" + """Test that current state is injected into thread metadata. + + AG-UI internal metadata (including current_state) is stored in thread.metadata + for orchestration, but filtered out before passing to the chat client's options.metadata. 
+ """ from agent_framework_ag_ui import AgentFrameworkAgent - thread_metadata: dict[str, Any] = {} + captured_thread: dict[str, Any] = {} + captured_options: dict[str, Any] = {} async def stream_fn( messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any ) -> AsyncIterator[ChatResponseUpdate]: - metadata = options.get("metadata") - if metadata: - thread_metadata.update(metadata) + # Capture the thread object from kwargs + thread = kwargs.get("thread") + if thread and hasattr(thread, "metadata"): + captured_thread["metadata"] = thread.metadata + # Capture options to verify internal keys are NOT passed to chat client + captured_options.update(options) yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")]) agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) @@ -479,11 +502,17 @@ async def stream_fn( async for event in wrapper.run_agent(input_data): events.append(event) + # Current state should be stored in thread.metadata + thread_metadata = captured_thread.get("metadata", {}) current_state = thread_metadata.get("current_state") if isinstance(current_state, str): current_state = json.loads(current_state) assert current_state == {"document": "Test content"} + # Internal metadata should NOT be passed to chat client options + options_metadata = captured_options.get("metadata", {}) + assert "current_state" not in options_metadata + async def test_no_messages_provided(): """Test handling when no messages are provided.""" diff --git a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py index 4fdcdfadc7..a52faafc80 100644 --- a/python/packages/anthropic/agent_framework_anthropic/_chat_client.py +++ b/python/packages/anthropic/agent_framework_anthropic/_chat_client.py @@ -145,6 +145,7 @@ class AnthropicChatOptions(ChatOptions, total=False): frequency_penalty: None # type: ignore[misc] presence_penalty: None # type: ignore[misc] store: None # type: ignore[misc] + conversation_id: None # type: ignore[misc] TAnthropicOptions = TypeVar( @@ -468,7 +469,9 @@ def _prepare_message_for_anthropic(self, message: ChatMessage) -> dict[str, Any] for content in message.contents: match content.type: case "text": - a_content.append({"type": "text", "text": content.text}) + # Skip empty text content blocks - Anthropic API rejects them + if content.text: + a_content.append({"type": "text", "text": content.text}) case "data": if content.has_top_level_media_type("image"): a_content.append({ diff --git a/python/packages/core/agent_framework/_agents.py b/python/packages/core/agent_framework/_agents.py index 412221af1f..2092ebcb32 100644 --- a/python/packages/core/agent_framework/_agents.py +++ b/python/packages/core/agent_framework/_agents.py @@ -89,8 +89,10 @@ def _merge_options(base: dict[str, Any], override: dict[str, Any]) -> dict[str, if value is None: continue if key == "tools" and result.get("tools"): - # Combine tool lists - result["tools"] = list(result["tools"]) + list(value) + # Combine tool lists, avoiding duplicates by name + existing_names = {getattr(t, "name", None) for t in result["tools"]} + unique_new = [t for t in value if getattr(t, "name", None) not in existing_names] + result["tools"] = list(result["tools"]) + unique_new elif key == "logit_bias" and result.get("logit_bias"): # Merge logit_bias dicts result["logit_bias"] = {**result["logit_bias"], **value} diff --git a/python/packages/core/agent_framework/_tools.py 
b/python/packages/core/agent_framework/_tools.py index 1e86e1d4d5..cbd211ed5a 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -1553,7 +1553,7 @@ async def _auto_invoke_function( runtime_kwargs: dict[str, Any] = { key: value for key, value in (custom_args or {}).items() - if key not in {"_function_middleware_pipeline", "middleware"} + if key not in {"_function_middleware_pipeline", "middleware", "conversation_id"} } try: args = tool.input_model.model_validate(parsed_args) From 7d5440a33ec507d08f21da324f79a8af91e649a7 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Wed, 21 Jan 2026 15:24:01 +0900 Subject: [PATCH 6/8] Fix mypy --- .../ag-ui/agent_framework_ag_ui/_client.py | 4 +- .../ag-ui/agent_framework_ag_ui/_endpoint.py | 8 +- .../_message_adapters.py | 51 +++++------ .../_orchestration/_helpers.py | 2 +- .../ag-ui/agent_framework_ag_ui/_run.py | 13 +-- .../ag-ui/agent_framework_ag_ui/_types.py | 1 + python/packages/ag-ui/tests/test_helpers.py | 1 - .../ag-ui/tests/test_predictive_state.py | 85 +++++-------------- python/packages/ag-ui/tests/test_run.py | 9 +- .../packages/ag-ui/tests/utils_test_ag_ui.py | 2 - .../packages/core/agent_framework/_tools.py | 22 +++-- 11 files changed, 76 insertions(+), 122 deletions(-) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_client.py b/python/packages/ag-ui/agent_framework_ag_ui/_client.py index a336f28b76..7a03949b66 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_client.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_client.py @@ -74,7 +74,7 @@ def _apply_server_function_call_unwrap(chat_client: TBaseChatClient) -> TBaseCha original_get_streaming_response = chat_client.get_streaming_response @wraps(original_get_streaming_response) - async def streaming_wrapper(self, *args: Any, **kwargs: Any) -> AsyncIterable[ChatResponseUpdate]: + async def streaming_wrapper(self: Any, *args: Any, **kwargs: Any) -> AsyncIterable[ChatResponseUpdate]: async for update in original_get_streaming_response(self, *args, **kwargs): _unwrap_server_function_call_contents(cast(MutableSequence[Content | dict[str, Any]], update.contents)) yield update @@ -84,7 +84,7 @@ async def streaming_wrapper(self, *args: Any, **kwargs: Any) -> AsyncIterable[Ch original_get_response = chat_client.get_response @wraps(original_get_response) - async def response_wrapper(self, *args: Any, **kwargs: Any) -> ChatResponse: + async def response_wrapper(self: Any, *args: Any, **kwargs: Any) -> ChatResponse: response = await original_get_response(self, *args, **kwargs) if response.messages: for message in response.messages: diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py b/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py index 07e818882d..dc39be77e7 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_endpoint.py @@ -4,7 +4,7 @@ import copy import logging -from collections.abc import Sequence +from collections.abc import AsyncGenerator, Sequence from typing import Any from ag_ui.encoder import EventEncoder @@ -56,8 +56,8 @@ def add_agent_framework_fastapi_endpoint( else: wrapped_agent = agent - @app.post(path, tags=tags or ["AG-UI"], dependencies=dependencies) # type: ignore[arg-type] - async def agent_endpoint(request_body: AGUIRequest): # type: ignore[misc] + @app.post(path, tags=tags or ["AG-UI"], dependencies=dependencies, response_model=None) # type: ignore[arg-type] + async def 
agent_endpoint(request_body: AGUIRequest) -> StreamingResponse | dict[str, str]: """Handle AG-UI agent requests. Note: Function is accessed via FastAPI's decorator registration, @@ -77,7 +77,7 @@ async def agent_endpoint(request_body: AGUIRequest): # type: ignore[misc] ) logger.info(f"Received request at {path}: {input_data.get('run_id', 'no-run-id')}") - async def event_generator(): + async def event_generator() -> AsyncGenerator[str, None]: encoder = EventEncoder() event_count = 0 async for event in wrapped_agent.run_agent(input_data): diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py index 525fb2faab..f8f1623a30 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_message_adapters.py @@ -252,22 +252,19 @@ def _update_tool_call_arguments( tool_calls = raw_msg.get("tool_calls") or raw_msg.get("toolCalls") if not isinstance(tool_calls, list): continue - tool_calls_list = cast(list[Any], tool_calls) - for tool_call in tool_calls_list: + for tool_call in tool_calls: if not isinstance(tool_call, dict): continue - tool_call_dict = cast(dict[str, Any], tool_call) - if str(tool_call_dict.get("id", "")) != tool_call_id: + if str(tool_call.get("id", "")) != tool_call_id: continue - function_payload = tool_call_dict.get("function") + function_payload = tool_call.get("function") if not isinstance(function_payload, dict): return - function_payload_dict = cast(dict[str, Any], function_payload) - existing_args = function_payload_dict.get("arguments") + existing_args = function_payload.get("arguments") if isinstance(existing_args, str): - function_payload_dict["arguments"] = json.dumps(modified_args) + function_payload["arguments"] = json.dumps(modified_args) else: - function_payload_dict["arguments"] = modified_args + function_payload["arguments"] = modified_args return def _find_matching_func_call(call_id: str) -> Content | None: @@ -433,8 +430,7 @@ def _filter_modified_args( if desc: approved_by_description[str(desc)] = step_item_dict merged_steps: list[Any] = [] - original_steps_list = cast(list[Any], original_steps) - for orig_step in original_steps_list: + for orig_step in original_steps: if not isinstance(orig_step, dict): merged_steps.append(orig_step) continue @@ -498,9 +494,9 @@ def _filter_modified_args( if isinstance(result_content, str): func_result = result_content elif isinstance(result_content, dict): - func_result = cast(dict[str, Any], result_content) + func_result = result_content elif isinstance(result_content, list): - func_result = cast(list[Any], result_content) + func_result = result_content else: func_result = str(result_content) chat_msg = ChatMessage( @@ -734,40 +730,35 @@ def agui_messages_to_snapshot_format(messages: list[dict[str, Any]]) -> list[dic if isinstance(content, list): # Convert content array format to simple string text_parts: list[str] = [] - content_list = cast(list[Any], content) - for item in content_list: + for item in content: if isinstance(item, dict): - item_dict = cast(dict[str, Any], item) # Convert 'input_text' to 'text' type - if item_dict.get("type") == "input_text": - text_parts.append(str(item_dict.get("text", ""))) - elif item_dict.get("type") == "text": - text_parts.append(str(item_dict.get("text", ""))) + if item.get("type") == "input_text": + text_parts.append(str(item.get("text", ""))) + elif item.get("type") == "text": + text_parts.append(str(item.get("text", ""))) else: # 
Other types - just extract text field if present - text_parts.append(str(item_dict.get("text", ""))) + text_parts.append(str(item.get("text", ""))) normalized_msg["content"] = "".join(text_parts) elif content is None: normalized_msg["content"] = "" tool_calls = normalized_msg.get("tool_calls") or normalized_msg.get("toolCalls") if isinstance(tool_calls, list): - tool_calls_list = cast(list[Any], tool_calls) - for tool_call in tool_calls_list: + for tool_call in tool_calls: if not isinstance(tool_call, dict): continue - tool_call_dict = cast(dict[str, Any], tool_call) - function_payload = tool_call_dict.get("function") + function_payload = tool_call.get("function") if not isinstance(function_payload, dict): continue - function_payload_dict = cast(dict[str, Any], function_payload) - if "arguments" not in function_payload_dict: + if "arguments" not in function_payload: continue - arguments = function_payload_dict.get("arguments") + arguments = function_payload.get("arguments") if arguments is None: - function_payload_dict["arguments"] = "" + function_payload["arguments"] = "" elif not isinstance(arguments, str): - function_payload_dict["arguments"] = json.dumps(arguments) + function_payload["arguments"] = json.dumps(arguments) # Normalize tool_call_id to toolCallId for tool messages normalized_msg["role"] = normalize_agui_role(normalized_msg.get("role")) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py index 66b0160469..f12430a086 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_orchestration/_helpers.py @@ -15,7 +15,7 @@ Content, ) -from .._utils import get_role_value, safe_json_parse +from .._utils import get_role_value logger = logging.getLogger(__name__) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_run.py index afcf12bd86..bc8d510641 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_run.py @@ -340,9 +340,12 @@ def _emit_tool_result( """Emit ToolCallResult events for FunctionResultContent.""" events: list[BaseEvent] = [] - if content.call_id: - events.append(ToolCallEndEvent(tool_call_id=content.call_id)) - flow.tool_calls_ended.add(content.call_id) # Track ended tool calls + # Cannot emit tool result without a call_id to associate it with + if not content.call_id: + return events + + events.append(ToolCallEndEvent(tool_call_id=content.call_id)) + flow.tool_calls_ended.add(content.call_id) # Track ended tool calls result_content = prepare_function_call_results(content.result) message_id = generate_event_id() @@ -740,9 +743,7 @@ async def run_agent_stream( # Filter out AG-UI internal metadata keys before passing to chat client # These are used internally for orchestration and should not be sent to the LLM provider client_metadata = { - k: v - for k, v in (thread.metadata or {}).items() - if k not in AG_UI_INTERNAL_METADATA_KEYS # type: ignore[attr-defined] + k: v for k, v in (getattr(thread, "metadata", None) or {}).items() if k not in AG_UI_INTERNAL_METADATA_KEYS } safe_metadata = _build_safe_metadata(client_metadata) if client_metadata else {} if safe_metadata: diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_types.py b/python/packages/ag-ui/agent_framework_ag_ui/_types.py index 7466f09371..aedfb43990 100644 --- 
a/python/packages/ag-ui/agent_framework_ag_ui/_types.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_types.py @@ -13,6 +13,7 @@ else: from typing_extensions import TypeVar + class PredictStateConfig(TypedDict): """Configuration for predictive state updates.""" diff --git a/python/packages/ag-ui/tests/test_helpers.py b/python/packages/ag-ui/tests/test_helpers.py index 3762d3c5fd..b4a7e9f047 100644 --- a/python/packages/ag-ui/tests/test_helpers.py +++ b/python/packages/ag-ui/tests/test_helpers.py @@ -2,7 +2,6 @@ """Tests for orchestration helper functions.""" -import pytest from agent_framework import ChatMessage, Content from agent_framework_ag_ui._orchestration._helpers import ( diff --git a/python/packages/ag-ui/tests/test_predictive_state.py b/python/packages/ag-ui/tests/test_predictive_state.py index 59113bf05c..31ad46fc3a 100644 --- a/python/packages/ag-ui/tests/test_predictive_state.py +++ b/python/packages/ag-ui/tests/test_predictive_state.py @@ -2,7 +2,6 @@ """Tests for predictive state handling.""" -import pytest from ag_ui.core import StateDeltaEvent from agent_framework_ag_ui._orchestration._predictive_state import PredictiveStateHandler @@ -56,25 +55,19 @@ def test_no_config(self): def test_no_args(self): """Returns None when args is None.""" - handler = PredictiveStateHandler( - predict_state_config={"key": {"tool": "tool", "tool_argument": "arg"}} - ) + handler = PredictiveStateHandler(predict_state_config={"key": {"tool": "tool", "tool_argument": "arg"}}) result = handler.extract_state_value("tool", None) assert result is None def test_empty_args(self): """Returns None when args is empty string.""" - handler = PredictiveStateHandler( - predict_state_config={"key": {"tool": "tool", "tool_argument": "arg"}} - ) + handler = PredictiveStateHandler(predict_state_config={"key": {"tool": "tool", "tool_argument": "arg"}}) result = handler.extract_state_value("tool", "") assert result is None def test_tool_not_in_config(self): """Returns None when tool not in config.""" - handler = PredictiveStateHandler( - predict_state_config={"key": {"tool": "other_tool", "tool_argument": "arg"}} - ) + handler = PredictiveStateHandler(predict_state_config={"key": {"tool": "other_tool", "tool_argument": "arg"}}) result = handler.extract_state_value("some_tool", {"arg": "value"}) assert result is None @@ -88,9 +81,7 @@ def test_extracts_specific_argument(self): def test_extracts_with_wildcard(self): """Extracts entire args with * wildcard.""" - handler = PredictiveStateHandler( - predict_state_config={"data": {"tool": "update_data", "tool_argument": "*"}} - ) + handler = PredictiveStateHandler(predict_state_config={"data": {"tool": "update_data", "tool_argument": "*"}}) args = {"key1": "value1", "key2": "value2"} result = handler.extract_state_value("update_data", args) assert result == ("data", args) @@ -117,9 +108,7 @@ class TestIsPredictiveTool: def test_none_tool_name(self): """Returns False for None tool name.""" - handler = PredictiveStateHandler( - predict_state_config={"key": {"tool": "some_tool", "tool_argument": "arg"}} - ) + handler = PredictiveStateHandler(predict_state_config={"key": {"tool": "some_tool", "tool_argument": "arg"}}) assert handler.is_predictive_tool(None) is False def test_no_config(self): @@ -129,16 +118,12 @@ def test_no_config(self): def test_tool_in_config(self): """Returns True when tool is in config.""" - handler = PredictiveStateHandler( - predict_state_config={"key": {"tool": "some_tool", "tool_argument": "arg"}} - ) + handler = 
PredictiveStateHandler(predict_state_config={"key": {"tool": "some_tool", "tool_argument": "arg"}}) assert handler.is_predictive_tool("some_tool") is True def test_tool_not_in_config(self): """Returns False when tool not in config.""" - handler = PredictiveStateHandler( - predict_state_config={"key": {"tool": "other_tool", "tool_argument": "arg"}} - ) + handler = PredictiveStateHandler(predict_state_config={"key": {"tool": "other_tool", "tool_argument": "arg"}}) assert handler.is_predictive_tool("some_tool") is False @@ -147,9 +132,7 @@ class TestEmitStreamingDeltas: def test_no_tool_name(self): """Returns empty list for None tool name.""" - handler = PredictiveStateHandler( - predict_state_config={"key": {"tool": "tool", "tool_argument": "arg"}} - ) + handler = PredictiveStateHandler(predict_state_config={"key": {"tool": "tool", "tool_argument": "arg"}}) result = handler.emit_streaming_deltas(None, '{"arg": "value"}') assert result == [] @@ -161,18 +144,14 @@ def test_no_config(self): def test_accumulates_args(self): """Accumulates argument chunks.""" - handler = PredictiveStateHandler( - predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} - ) + handler = PredictiveStateHandler(predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}}) handler.emit_streaming_deltas("write", '{"text') handler.emit_streaming_deltas("write", '": "hello') assert handler.streaming_tool_args == '{"text": "hello' def test_emits_delta_on_complete_json(self): """Emits delta when JSON is complete.""" - handler = PredictiveStateHandler( - predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} - ) + handler = PredictiveStateHandler(predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}}) events = handler.emit_streaming_deltas("write", '{"text": "hello"}') assert len(events) == 1 assert isinstance(events[0], StateDeltaEvent) @@ -182,9 +161,7 @@ def test_emits_delta_on_complete_json(self): def test_emits_delta_on_partial_json(self): """Emits delta from partial JSON using regex.""" - handler = PredictiveStateHandler( - predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} - ) + handler = PredictiveStateHandler(predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}}) # First chunk - partial events = handler.emit_streaming_deltas("write", '{"text": "hel') assert len(events) == 1 @@ -192,9 +169,7 @@ def test_emits_delta_on_partial_json(self): def test_does_not_emit_duplicate_deltas(self): """Does not emit delta when value unchanged.""" - handler = PredictiveStateHandler( - predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} - ) + handler = PredictiveStateHandler(predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}}) # First emission events1 = handler.emit_streaming_deltas("write", '{"text": "hello"}') assert len(events1) == 1 @@ -206,9 +181,7 @@ def test_does_not_emit_duplicate_deltas(self): def test_emits_delta_on_value_change(self): """Emits delta when value changes.""" - handler = PredictiveStateHandler( - predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} - ) + handler = PredictiveStateHandler(predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}}) # First value events1 = handler.emit_streaming_deltas("write", '{"text": "hello"}') assert len(events1) == 1 @@ -221,9 +194,7 @@ def test_emits_delta_on_value_change(self): def test_tracks_pending_updates(self): """Tracks pending state updates.""" - handler = PredictiveStateHandler( - 
predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} - ) + handler = PredictiveStateHandler(predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}}) handler.emit_streaming_deltas("write", '{"text": "hello"}') assert handler.pending_state_updates == {"doc": "hello"} @@ -233,9 +204,7 @@ class TestEmitPartialDeltas: def test_unescapes_newlines(self): """Unescapes \\n in partial values.""" - handler = PredictiveStateHandler( - predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} - ) + handler = PredictiveStateHandler(predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}}) handler.streaming_tool_args = '{"text": "line1\\nline2' events = handler._emit_partial_deltas("write") assert len(events) == 1 @@ -243,9 +212,7 @@ def test_unescapes_newlines(self): def test_handles_escaped_quotes_partially(self): """Handles escaped quotes - regex stops at quote character.""" - handler = PredictiveStateHandler( - predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} - ) + handler = PredictiveStateHandler(predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}}) # The regex pattern [^"]* stops at ANY quote, including escaped ones. # This is expected behavior for partial streaming - the full JSON # will be parsed correctly when complete. @@ -262,9 +229,7 @@ def test_handles_escaped_quotes_partially(self): def test_unescapes_backslashes(self): """Unescapes \\\\ in partial values.""" - handler = PredictiveStateHandler( - predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} - ) + handler = PredictiveStateHandler(predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}}) handler.streaming_tool_args = '{"text": "path\\\\to\\\\file' events = handler._emit_partial_deltas("write") assert len(events) == 1 @@ -276,26 +241,20 @@ class TestEmitCompleteDeltas: def test_emits_for_matching_tool(self): """Emits delta for tool matching config.""" - handler = PredictiveStateHandler( - predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} - ) + handler = PredictiveStateHandler(predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}}) events = handler._emit_complete_deltas("write", {"text": "content"}) assert len(events) == 1 assert events[0].delta[0]["value"] == "content" def test_skips_non_matching_tool(self): """Skips tools not matching config.""" - handler = PredictiveStateHandler( - predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} - ) + handler = PredictiveStateHandler(predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}}) events = handler._emit_complete_deltas("other_tool", {"text": "content"}) assert len(events) == 0 def test_handles_wildcard_argument(self): """Handles * wildcard for entire args.""" - handler = PredictiveStateHandler( - predict_state_config={"data": {"tool": "update", "tool_argument": "*"}} - ) + handler = PredictiveStateHandler(predict_state_config={"data": {"tool": "update", "tool_argument": "*"}}) args = {"key1": "val1", "key2": "val2"} events = handler._emit_complete_deltas("update", args) assert len(events) == 1 @@ -303,9 +262,7 @@ def test_handles_wildcard_argument(self): def test_skips_missing_argument(self): """Skips when tool_argument not in args.""" - handler = PredictiveStateHandler( - predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}} - ) + handler = PredictiveStateHandler(predict_state_config={"doc": {"tool": "write", "tool_argument": "text"}}) events = 
handler._emit_complete_deltas("write", {"other": "value"}) assert len(events) == 0 diff --git a/python/packages/ag-ui/tests/test_run.py b/python/packages/ag-ui/tests/test_run.py index 0f842325a6..a415000692 100644 --- a/python/packages/ag-ui/tests/test_run.py +++ b/python/packages/ag-ui/tests/test_run.py @@ -2,7 +2,6 @@ """Tests for _run.py helper functions and FlowState.""" -import pytest from agent_framework import ChatMessage, Content from agent_framework_ag_ui._run import ( @@ -146,9 +145,7 @@ def test_default_values(self): def test_get_tool_name(self): """Tests get_tool_name method.""" flow = FlowState() - flow.tool_calls_by_id = { - "call_123": {"function": {"name": "get_weather", "arguments": "{}"}} - } + flow.tool_calls_by_id = {"call_123": {"function": {"name": "get_weather", "arguments": "{}"}}} assert flow.get_tool_name("call_123") == "get_weather" assert flow.get_tool_name("nonexistent") is None @@ -370,9 +367,7 @@ def test_extract_approved_state_updates_no_approval(): from agent_framework_ag_ui._orchestration._predictive_state import PredictiveStateHandler from agent_framework_ag_ui._run import _extract_approved_state_updates - handler = PredictiveStateHandler( - predict_state_config={"doc": {"tool": "write", "tool_argument": "content"}} - ) + handler = PredictiveStateHandler(predict_state_config={"doc": {"tool": "write", "tool_argument": "content"}}) messages = [ChatMessage(role="user", contents=[Content.from_text("Hello")])] result = _extract_approved_state_updates(messages, handler) assert result == {} diff --git a/python/packages/ag-ui/tests/utils_test_ag_ui.py b/python/packages/ag-ui/tests/utils_test_ag_ui.py index 50710191ec..9d0ce11f65 100644 --- a/python/packages/ag-ui/tests/utils_test_ag_ui.py +++ b/python/packages/ag-ui/tests/utils_test_ag_ui.py @@ -20,8 +20,6 @@ ) from agent_framework._clients import TOptions_co -from agent_framework_ag_ui._message_adapters import _deduplicate_messages, _sanitize_tool_history - if sys.version_info >= (3, 12): from typing import override # type: ignore # pragma: no cover else: diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index cbd211ed5a..449bbf3b08 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -1689,7 +1689,10 @@ async def _try_execute_function_calls( tool_map = _get_tool_map(tools) approval_tools = [tool_name for tool_name, tool in tool_map.items() if tool.approval_mode == "always_require"] - logger.info(f"[APPROVAL-DEBUG] _try_execute_function_calls: tool_map keys={list(tool_map.keys())}, approval_tools={approval_tools}") + logger.info( + f"[APPROVAL-DEBUG] _try_execute_function_calls: tool_map keys={list(tool_map.keys())}, " + f"approval_tools={approval_tools}" + ) declaration_only = [tool_name for tool_name, tool in tool_map.items() if tool.declaration_only] additional_tool_names = [tool.name for tool in config.additional_tools] if config.additional_tools else [] # check if any are calling functions that need approval @@ -1697,7 +1700,11 @@ async def _try_execute_function_calls( approval_needed = False declaration_only_flag = False for fcc in function_calls: - logger.info(f"[APPROVAL-DEBUG] Checking fcc: type={fcc.type}, name={getattr(fcc, 'name', None)}, in approval_tools={getattr(fcc, 'name', None) in approval_tools}") + fcc_name = getattr(fcc, "name", None) + logger.info( + f"[APPROVAL-DEBUG] Checking fcc: type={fcc.type}, name={fcc_name}, " + f"in approval_tools={fcc_name in approval_tools}" + 
) if fcc.type == "function_call" and fcc.name in approval_tools: # type: ignore[attr-defined] logger.info(f"[APPROVAL-DEBUG] APPROVAL NEEDED for {fcc.name}") approval_needed = True @@ -2152,10 +2159,15 @@ async def streaming_function_invocation_wrapper( # we load the tools here, since middleware might have changed them compared to before calling func. tools = _extract_tools(options) - logger.info(f"[APPROVAL-DEBUG-STREAMING] tools extracted: {tools is not None}, function_calls: {len(function_calls) if function_calls else 0}") + fc_count = len(function_calls) if function_calls else 0 + logger.info( + f"[APPROVAL-DEBUG-STREAMING] tools extracted: {tools is not None}, function_calls: {fc_count}" + ) if tools: - for t in (tools if isinstance(tools, list) else [tools]): - logger.info(f"[APPROVAL-DEBUG-STREAMING] - {getattr(t, 'name', 'unknown')}: approval_mode={getattr(t, 'approval_mode', None)}") + for t in tools if isinstance(tools, list) else [tools]: + t_name = getattr(t, "name", "unknown") + t_approval = getattr(t, "approval_mode", None) + logger.info(f"[APPROVAL-DEBUG-STREAMING] - {t_name}: approval_mode={t_approval}") if function_calls and tools: # Use the stored middleware pipeline instead of extracting from kwargs # because kwargs may have been modified by the underlying function From 9aefad9ec110bc9bda0ebc9b8c5d9ebe3277ed32 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Thu, 22 Jan 2026 10:04:18 +0900 Subject: [PATCH 7/8] Fixes --- .../ag-ui/agent_framework_ag_ui/_run.py | 11 ++- .../ag-ui/agent_framework_ag_ui/_types.py | 12 +++ .../server/main.py | 9 +- python/packages/ag-ui/tests/test_types.py | 82 ++++++++++++++++++- .../packages/ag-ui/tests/utils_test_ag_ui.py | 4 - .../packages/core/agent_framework/_tools.py | 27 +++--- 6 files changed, 123 insertions(+), 22 deletions(-) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_run.py index bc8d510641..c652ce2e01 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_run.py @@ -496,7 +496,8 @@ def _is_confirm_changes_response(messages: list[Any]) -> bool: if "accepted" in result and "steps" in result: return True except json.JSONDecodeError: - pass + # Content is not valid JSON; continue checking other content items + logger.debug("Failed to parse confirm_changes tool result as JSON; treating as non-confirmation.") return False @@ -906,7 +907,13 @@ async def run_agent_stream( flow.current_state[state_key] = state_value yield StateSnapshotEvent(snapshot=flow.current_state) except json.JSONDecodeError: - pass + # Ignore malformed JSON in tool arguments for predictive state; + # predictive updates are best-effort and should not break the flow. 
+ logger.warning( + "Failed to decode JSON arguments for predictive tool '%s' (tool_call_id=%s).", + tool_name, + tool_call_id, + ) # Emit confirm_changes tool call confirm_id = generate_event_id() diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_types.py b/python/packages/ag-ui/agent_framework_ag_ui/_types.py index aedfb43990..a80cd155d2 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_types.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_types.py @@ -59,6 +59,18 @@ class AGUIRequest(BaseModel): None, description="Client-side tools to advertise to the LLM", ) + context: list[dict[str, Any]] | None = Field( + None, + description="List of context objects provided to the agent", + ) + forwarded_props: dict[str, Any] | None = Field( + None, + description="Additional properties forwarded to the agent", + ) + parent_run_id: str | None = Field( + None, + description="ID of the run that spawned this run", + ) # region AG-UI Chat Options TypedDict diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py b/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py index 54dcc5f558..7369c84679 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/server/main.py @@ -9,9 +9,8 @@ from agent_framework import ChatOptions from agent_framework._clients import BaseChatClient from agent_framework.ag_ui import add_agent_framework_fastapi_endpoint - -# from agent_framework.azure import AzureOpenAIChatClient from agent_framework.anthropic import AnthropicClient +from agent_framework.azure import AzureOpenAIChatClient from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware @@ -64,8 +63,10 @@ # Create a shared chat client for all agents # You can use different chat clients for different agents if needed -# chat_client: BaseChatClient[ChatOptions] = AzureOpenAIChatClient() -chat_client: BaseChatClient[ChatOptions] = AnthropicClient() +# Set CHAT_CLIENT=anthropic to use Anthropic, defaults to Azure OpenAI +chat_client: BaseChatClient[ChatOptions] = ( + AnthropicClient() if os.getenv("CHAT_CLIENT", "").lower() == "anthropic" else AzureOpenAIChatClient() +) # Agentic Chat - basic chat agent add_agent_framework_fastapi_endpoint( diff --git a/python/packages/ag-ui/tests/test_types.py b/python/packages/ag-ui/tests/test_types.py index 3c61278d9e..6b0b00a687 100644 --- a/python/packages/ag-ui/tests/test_types.py +++ b/python/packages/ag-ui/tests/test_types.py @@ -2,7 +2,7 @@ """Tests for type definitions in _types.py.""" -from agent_framework_ag_ui._types import AgentState, PredictStateConfig, RunMetadata +from agent_framework_ag_ui._types import AgentState, AGUIRequest, PredictStateConfig, RunMetadata class TestPredictStateConfig: @@ -143,3 +143,83 @@ def test_agent_state_complex_messages(self) -> None: assert len(state["messages"]) == 2 assert "metadata" in state["messages"][0] assert "tool_calls" in state["messages"][1] + + +class TestAGUIRequest: + """Test AGUIRequest Pydantic model.""" + + def test_agui_request_minimal(self) -> None: + """Test creating AGUIRequest with only required fields.""" + request = AGUIRequest(messages=[{"role": "user", "content": "Hello"}]) + + assert len(request.messages) == 1 + assert request.messages[0]["content"] == "Hello" + assert request.run_id is None + assert request.thread_id is None + assert request.state is None + assert request.tools is None + assert request.context is None + assert request.forwarded_props is None + 
assert request.parent_run_id is None + + def test_agui_request_all_fields(self) -> None: + """Test creating AGUIRequest with all fields populated.""" + request = AGUIRequest( + messages=[{"role": "user", "content": "Hello"}], + run_id="run-123", + thread_id="thread-456", + state={"counter": 0}, + tools=[{"name": "search", "description": "Search tool"}], + context=[{"type": "document", "content": "Some context"}], + forwarded_props={"custom_key": "custom_value"}, + parent_run_id="parent-run-789", + ) + + assert request.run_id == "run-123" + assert request.thread_id == "thread-456" + assert request.state == {"counter": 0} + assert request.tools == [{"name": "search", "description": "Search tool"}] + assert request.context == [{"type": "document", "content": "Some context"}] + assert request.forwarded_props == {"custom_key": "custom_value"} + assert request.parent_run_id == "parent-run-789" + + def test_agui_request_model_dump_excludes_none(self) -> None: + """Test that model_dump(exclude_none=True) excludes None fields.""" + request = AGUIRequest( + messages=[{"role": "user", "content": "test"}], + tools=[{"name": "my_tool"}], + context=[{"id": "ctx1"}], + ) + + dumped = request.model_dump(exclude_none=True) + + assert "messages" in dumped + assert "tools" in dumped + assert "context" in dumped + assert "run_id" not in dumped + assert "thread_id" not in dumped + assert "state" not in dumped + assert "forwarded_props" not in dumped + assert "parent_run_id" not in dumped + + def test_agui_request_model_dump_includes_all_set_fields(self) -> None: + """Test that model_dump preserves all explicitly set fields. + + This is critical for the fix - ensuring tools, context, forwarded_props, + and parent_run_id are not stripped during request validation. + """ + request = AGUIRequest( + messages=[{"role": "user", "content": "test"}], + tools=[{"name": "client_tool", "parameters": {"type": "object"}}], + context=[{"type": "snippet", "content": "code here"}], + forwarded_props={"auth_token": "secret", "user_id": "user-1"}, + parent_run_id="parent-456", + ) + + dumped = request.model_dump(exclude_none=True) + + # Verify all fields are preserved (the main bug fix) + assert dumped["tools"] == [{"name": "client_tool", "parameters": {"type": "object"}}] + assert dumped["context"] == [{"type": "snippet", "content": "code here"}] + assert dumped["forwarded_props"] == {"auth_token": "secret", "user_id": "user-1"} + assert dumped["parent_run_id"] == "parent-456" diff --git a/python/packages/ag-ui/tests/utils_test_ag_ui.py b/python/packages/ag-ui/tests/utils_test_ag_ui.py index 9d0ce11f65..5c2415583c 100644 --- a/python/packages/ag-ui/tests/utils_test_ag_ui.py +++ b/python/packages/ag-ui/tests/utils_test_ag_ui.py @@ -122,7 +122,3 @@ async def _stream() -> AsyncIterator[AgentResponseUpdate]: def get_new_thread(self, **kwargs: Any) -> AgentThread: return AgentThread() - - -# Note: TestExecutionContext was removed along with _orchestrators.py -# Tests should now use run_agent_stream() directly or the StubAgent class diff --git a/python/packages/core/agent_framework/_tools.py b/python/packages/core/agent_framework/_tools.py index 449bbf3b08..e2edce3585 100644 --- a/python/packages/core/agent_framework/_tools.py +++ b/python/packages/core/agent_framework/_tools.py @@ -1689,9 +1689,10 @@ async def _try_execute_function_calls( tool_map = _get_tool_map(tools) approval_tools = [tool_name for tool_name, tool in tool_map.items() if tool.approval_mode == "always_require"] - logger.info( - f"[APPROVAL-DEBUG] 
_try_execute_function_calls: tool_map keys={list(tool_map.keys())}, " - f"approval_tools={approval_tools}" + logger.debug( + "_try_execute_function_calls: tool_map keys=%s, approval_tools=%s", + list(tool_map.keys()), + approval_tools, ) declaration_only = [tool_name for tool_name, tool in tool_map.items() if tool.declaration_only] additional_tool_names = [tool.name for tool in config.additional_tools] if config.additional_tools else [] @@ -1701,12 +1702,14 @@ async def _try_execute_function_calls( declaration_only_flag = False for fcc in function_calls: fcc_name = getattr(fcc, "name", None) - logger.info( - f"[APPROVAL-DEBUG] Checking fcc: type={fcc.type}, name={fcc_name}, " - f"in approval_tools={fcc_name in approval_tools}" + logger.debug( + "Checking function call: type=%s, name=%s, in approval_tools=%s", + fcc.type, + fcc_name, + fcc_name in approval_tools, ) if fcc.type == "function_call" and fcc.name in approval_tools: # type: ignore[attr-defined] - logger.info(f"[APPROVAL-DEBUG] APPROVAL NEEDED for {fcc.name}") + logger.debug("Approval needed for function: %s", fcc.name) approval_needed = True break if fcc.type == "function_call" and (fcc.name in declaration_only or fcc.name in additional_tool_names): # type: ignore[attr-defined] @@ -1716,7 +1719,7 @@ async def _try_execute_function_calls( raise KeyError(f'Error: Requested function "{fcc.name}" not found.') # type: ignore[attr-defined] if approval_needed: # approval can only be needed for Function Call Content, not Approval Responses. - logger.info("[APPROVAL-DEBUG] Returning function_approval_request contents") + logger.debug("Returning function_approval_request contents") return ( [ Content.from_function_approval_request(id=fcc.call_id, function_call=fcc) # type: ignore[attr-defined, arg-type] @@ -2160,14 +2163,16 @@ async def streaming_function_invocation_wrapper( # we load the tools here, since middleware might have changed them compared to before calling func. 
tools = _extract_tools(options) fc_count = len(function_calls) if function_calls else 0 - logger.info( - f"[APPROVAL-DEBUG-STREAMING] tools extracted: {tools is not None}, function_calls: {fc_count}" + logger.debug( + "Streaming: tools extracted=%s, function_calls=%d", + tools is not None, + fc_count, ) if tools: for t in tools if isinstance(tools, list) else [tools]: t_name = getattr(t, "name", "unknown") t_approval = getattr(t, "approval_mode", None) - logger.info(f"[APPROVAL-DEBUG-STREAMING] - {t_name}: approval_mode={t_approval}") + logger.debug(" Tool %s: approval_mode=%s", t_name, t_approval) if function_calls and tools: # Use the stored middleware pipeline instead of extracting from kwargs # because kwargs may have been modified by the underlying function From 21399f614851c6dbbc2b7db5451b8827075e28e7 Mon Sep 17 00:00:00 2001 From: Evan Mattson Date: Thu, 22 Jan 2026 11:05:44 +0900 Subject: [PATCH 8/8] Fix json serialize errors --- .../ag-ui/agent_framework_ag_ui/_run.py | 15 +++++-- .../agent_framework_ag_ui_examples/README.md | 34 --------------- .../tests/test_agent_wrapper_comprehensive.py | 43 ------------------- .../core/agent_framework/ag_ui/__init__.py | 5 --- .../core/agent_framework/ag_ui/__init__.pyi | 10 ----- 5 files changed, 11 insertions(+), 96 deletions(-) diff --git a/python/packages/ag-ui/agent_framework_ag_ui/_run.py b/python/packages/ag-ui/agent_framework_ag_ui/_run.py index c652ce2e01..d1229620a7 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui/_run.py +++ b/python/packages/ag-ui/agent_framework_ag_ui/_run.py @@ -41,7 +41,12 @@ from ._message_adapters import normalize_agui_input_messages from ._orchestration._predictive_state import PredictiveStateHandler from ._orchestration._tooling import collect_server_tools, merge_tools, register_additional_client_tools -from ._utils import convert_agui_tools_to_agent_framework, generate_event_id, get_conversation_id_from_update +from ._utils import ( + convert_agui_tools_to_agent_framework, + generate_event_id, + get_conversation_id_from_update, + make_json_safe, +) if TYPE_CHECKING: from collections.abc import AsyncGenerator @@ -317,7 +322,9 @@ def _emit_tool_call( # Emit args if present if content.arguments: - delta = content.arguments if isinstance(content.arguments, str) else json.dumps(content.arguments) + delta = ( + content.arguments if isinstance(content.arguments, str) else json.dumps(make_json_safe(content.arguments)) + ) events.append(ToolCallArgsEvent(tool_call_id=tool_call_id, delta=delta)) # Track args for MessagesSnapshotEvent @@ -424,7 +431,7 @@ def _emit_approval_request( "function_call": { "call_id": func_call_id, "name": func_name, - "arguments": func_call.parse_arguments(), + "arguments": make_json_safe(func_call.parse_arguments()), }, }, ) @@ -444,7 +451,7 @@ def _emit_approval_request( args = { "function_name": func_name, "function_call_id": func_call_id, - "function_arguments": func_call.parse_arguments() or {}, + "function_arguments": make_json_safe(func_call.parse_arguments()) or {}, "steps": [{"description": f"Execute {func_name}", "status": "enabled"}], } events.append(ToolCallArgsEvent(tool_call_id=confirm_id, delta=json.dumps(args))) diff --git a/python/packages/ag-ui/agent_framework_ag_ui_examples/README.md b/python/packages/ag-ui/agent_framework_ag_ui_examples/README.md index e9d6d4ed17..f22969f883 100644 --- a/python/packages/ag-ui/agent_framework_ag_ui_examples/README.md +++ b/python/packages/ag-ui/agent_framework_ag_ui_examples/README.md @@ -289,40 +289,6 @@ wrapped_agent = 
AgentFrameworkAgent( ) ``` -### Custom Confirmation Strategies - -Provide domain-specific confirmation messages: - -```python -from typing import Any -from agent_framework import ChatAgent -from agent_framework.azure import AzureOpenAIChatClient -from agent_framework.ag_ui import AgentFrameworkAgent, ConfirmationStrategy - -class CustomConfirmationStrategy(ConfirmationStrategy): - def on_approval_accepted(self, steps: list[dict[str, Any]]) -> str: - return "Your custom approval message!" - - def on_approval_rejected(self, steps: list[dict[str, Any]]) -> str: - return "Your custom rejection message!" - - def on_state_confirmed(self) -> str: - return "State changes confirmed!" - - def on_state_rejected(self) -> str: - return "State changes rejected!" - -agent = ChatAgent( - name="custom_agent", - chat_client=AzureOpenAIChatClient(model_id="gpt-4o"), -) - -wrapped_agent = AgentFrameworkAgent( - agent=agent, - confirmation_strategy=CustomConfirmationStrategy(), -) -``` - ### Human in the Loop Human-in-the-loop is automatically handled when tools are marked for approval: diff --git a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py index 2add81a7d9..8acd56a094 100644 --- a/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py +++ b/python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py @@ -624,49 +624,6 @@ async def stream_fn( assert len(tool_events) == 0 -@pytest.mark.skip(reason="confirmation_strategy feature removed in orchestrator rewrite") -async def test_suppressed_summary_with_document_state(): - """Test suppressed summary uses document state for confirmation message.""" - from agent_framework.ag_ui import AgentFrameworkAgent, DocumentWriterConfirmationStrategy - - async def stream_fn( - messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any - ) -> AsyncIterator[ChatResponseUpdate]: - yield ChatResponseUpdate(contents=[Content.from_text(text="Response")]) - - agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn)) - wrapper = AgentFrameworkAgent( - agent=agent, - state_schema={"document": {"type": "string"}}, - predict_state_config={"document": {"tool": "write_doc", "tool_argument": "content"}}, - confirmation_strategy=DocumentWriterConfirmationStrategy(), - ) - - # Simulate confirmation with document state - tool_result: dict[str, Any] = {"accepted": True, "steps": []} - input_data: dict[str, Any] = { - "messages": [ - { - "role": "tool", - "content": json.dumps(tool_result), - "toolCallId": "confirm_123", - } - ], - "state": {"document": "This is the beginning of a document. 
It contains important information."}, - } - - events: list[Any] = [] - async for event in wrapper.run_agent(input_data): - events.append(event) - - # Should generate fallback summary from document state - text_events = [e for e in events if e.type == "TEXT_MESSAGE_CONTENT"] - assert len(text_events) > 0 - # Should contain some reference to the document - full_text = "".join(e.delta for e in text_events) - assert "written" in full_text.lower() or "document" in full_text.lower() - - async def test_agent_with_use_service_thread_is_false(): """Test that when use_service_thread is False, the AgentThread used to run the agent is NOT set to the service thread ID.""" from agent_framework.ag_ui import AgentFrameworkAgent diff --git a/python/packages/core/agent_framework/ag_ui/__init__.py b/python/packages/core/agent_framework/ag_ui/__init__.py index 941a586d30..b469bb8a60 100644 --- a/python/packages/core/agent_framework/ag_ui/__init__.py +++ b/python/packages/core/agent_framework/ag_ui/__init__.py @@ -12,11 +12,6 @@ "AGUIChatClient", "AGUIEventConverter", "AGUIHttpService", - "ConfirmationStrategy", - "DefaultConfirmationStrategy", - "TaskPlannerConfirmationStrategy", - "RecipeConfirmationStrategy", - "DocumentWriterConfirmationStrategy", ] diff --git a/python/packages/core/agent_framework/ag_ui/__init__.pyi b/python/packages/core/agent_framework/ag_ui/__init__.pyi index 201e1a0256..d7b6acafec 100644 --- a/python/packages/core/agent_framework/ag_ui/__init__.pyi +++ b/python/packages/core/agent_framework/ag_ui/__init__.pyi @@ -5,11 +5,6 @@ from agent_framework_ag_ui import ( AGUIChatClient, AGUIEventConverter, AGUIHttpService, - ConfirmationStrategy, - DefaultConfirmationStrategy, - DocumentWriterConfirmationStrategy, - RecipeConfirmationStrategy, - TaskPlannerConfirmationStrategy, __version__, add_agent_framework_fastapi_endpoint, ) @@ -19,11 +14,6 @@ __all__ = [ "AGUIEventConverter", "AGUIHttpService", "AgentFrameworkAgent", - "ConfirmationStrategy", - "DefaultConfirmationStrategy", - "DocumentWriterConfirmationStrategy", - "RecipeConfirmationStrategy", - "TaskPlannerConfirmationStrategy", "__version__", "add_agent_framework_fastapi_endpoint", ]