Skip to content
49 changes: 41 additions & 8 deletions python/packages/ag-ui/agent_framework_ag_ui/_orchestrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@

logger = logging.getLogger(__name__)

AG_UI_INTERNAL_METADATA_KEYS = {"ag_ui_thread_id", "ag_ui_run_id", "current_state"}


class ExecutionContext:
"""Shared context for orchestrators."""
Expand Down Expand Up @@ -465,6 +467,24 @@ async def run(
messages_snapshot_emitted = False
accumulated_text_content = ""
active_message_id: str | None = None
initial_events_emitted = False

# Check if this is an approval response flow (needs RunStartedEvent before any other events)
has_approval_response = any(
hasattr(msg, "contents")
and any(
getattr(c, "type", None) == "function_approval_response"
or type(c).__name__ == "FunctionApprovalResponseContent"
for c in (msg.contents or [])
)
for msg in provider_messages
)

# For approval responses, emit initial events upfront since the agent may not stream any updates
if has_approval_response:
for event in self._create_initial_events(event_bridge, state_manager):
yield event
initial_events_emitted = True

# Check for FunctionApprovalResponseContent and emit updated state snapshot
# This ensures the UI shows the approved state (e.g., 2 steps) not the original (3 steps)
Expand Down Expand Up @@ -493,15 +513,16 @@ async def run(
all_updates: list[Any] | None = [] if collect_updates else None
update_count = 0
# Prepare metadata for chat client (Azure requires string values)
# Filter out AG-UI internal fields (ag_ui_thread_id, ag_ui_run_id) that are
# used only for AG-UI orchestration and not understood by chat clients.
safe_metadata = build_safe_metadata(getattr(thread, "metadata", None))
client_metadata = {k: v for k, v in safe_metadata.items() if k not in AG_UI_INTERNAL_METADATA_KEYS}

run_kwargs: dict[str, Any] = {
"thread": thread,
"tools": tools_param,
"options": {"metadata": safe_metadata},
"options": {"metadata": client_metadata} if client_metadata else {},
}
if safe_metadata:
run_kwargs["options"]["store"] = True

async def _resolve_approval_responses(
messages: list[Any],
Expand All @@ -519,9 +540,13 @@ async def _resolve_approval_responses(
getattr(chat_client, "function_invocation_configuration", None) or FunctionInvocationConfiguration()
)
middleware_pipeline = extract_and_merge_function_middleware(chat_client, run_kwargs)
# Filter out AG-UI-specific kwargs that should not be passed to tool execution.
# 'options' contains metadata/store for Azure AI client requirements but is not
# understood by external tools like MCP servers.
tool_kwargs = {k: v for k, v in run_kwargs.items() if k != "options"}
try:
results, _ = await _try_execute_function_calls(
custom_args=run_kwargs,
custom_args=tool_kwargs,
attempt_idx=0,
function_calls=approved_responses,
tools=tools_for_execution,
Expand Down Expand Up @@ -616,8 +641,6 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap
confirmation_message = strategy.on_state_rejected()

message_id = generate_event_id()
for event in self._create_initial_events(event_bridge, state_manager):
yield event
yield TextMessageStartEvent(message_id=message_id, role="assistant")
yield TextMessageContentEvent(message_id=message_id, delta=confirmation_message)
yield TextMessageEndEvent(message_id=message_id)
Expand All @@ -636,6 +659,9 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap
should_recreate_event_bridge = True

if should_recreate_event_bridge:
# Preserve state from the old bridge to avoid orphaned messages and lost flags
old_message_id = event_bridge.current_message_id
old_should_stop_after_confirm = event_bridge.should_stop_after_confirm
event_bridge = AgentFrameworkEventBridge(
run_id=context.run_id,
thread_id=context.thread_id,
Expand All @@ -645,11 +671,16 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap
require_confirmation=context.config.require_confirmation,
approval_tool_name=approval_tool_name,
)
# Restore state so messages can be properly closed and confirmation flow works
event_bridge.current_message_id = old_message_id
event_bridge.should_stop_after_confirm = old_should_stop_after_confirm
should_recreate_event_bridge = False

if update_count == 0:
# Emit initial events after the first update when we have the correct thread_id/run_id
if not initial_events_emitted:
for event in self._create_initial_events(event_bridge, state_manager):
yield event
initial_events_emitted = True

update_count += 1
logger.info(f"[STREAM] Received update #{update_count} from agent")
Expand Down Expand Up @@ -762,10 +793,12 @@ def _build_messages_snapshot(tool_message_id: str | None = None) -> MessagesSnap
yield TextMessageEndEvent(message_id=message_id)
logger.info(f"Emitted conversational message with length={len(response_dict['message'])}")

if all_updates is not None and len(all_updates) == 0:
# Ensure initial events are emitted even if the stream was empty
if not initial_events_emitted:
logger.info("No updates received from agent - emitting initial events")
for event in self._create_initial_events(event_bridge, state_manager):
yield event
initial_events_emitted = True

logger.info(f"[FINALIZE] Checking for unclosed message. current_message_id={event_bridge.current_message_id}")
if event_bridge.current_message_id:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

import logging
import os
from typing import TYPE_CHECKING

import uvicorn
from agent_framework import BaseChatClient, ChatOptions
from agent_framework.ag_ui import add_agent_framework_fastapi_endpoint
from agent_framework.azure import AzureOpenAIChatClient
from fastapi import FastAPI
Expand All @@ -20,10 +20,6 @@
from ..agents.ui_generator_agent import ui_generator_agent
from ..agents.weather_agent import weather_agent

if TYPE_CHECKING:
from agent_framework import ChatOptions
from agent_framework._clients import BaseChatClient

# Configure logging to file and console (disabled by default - set ENABLE_DEBUG_LOGGING=1 to enable)
if os.getenv("ENABLE_DEBUG_LOGGING"):
log_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "..", "ag_ui_server.log")
Expand Down
38 changes: 27 additions & 11 deletions python/packages/ag-ui/tests/test_agent_wrapper_comprehensive.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,17 +420,22 @@ async def stream_fn(


async def test_thread_metadata_tracking():
"""Test that thread metadata includes ag_ui_thread_id and ag_ui_run_id."""
"""Test that thread metadata includes ag_ui_thread_id and ag_ui_run_id.

Note: AG-UI internal metadata (ag_ui_thread_id, ag_ui_run_id) is stored in
thread.metadata for orchestration purposes, but is NOT passed to chat clients
via options.metadata since external clients may not accept these fields.
"""
from agent_framework.ag_ui import AgentFrameworkAgent

thread_metadata: dict[str, Any] = {}
captured_thread: list[Any] = []

async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
metadata = options.get("metadata")
if metadata:
thread_metadata.update(metadata)
thread = kwargs.get("thread")
if thread:
captured_thread.append(thread)
yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")])

agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
Expand All @@ -446,22 +451,30 @@ async def stream_fn(
async for event in wrapper.run_agent(input_data):
events.append(event)

# AG-UI internal metadata is stored in thread.metadata (not in options.metadata)
assert len(captured_thread) > 0, "Thread should be passed to chat client"
thread_metadata = getattr(captured_thread[0], "metadata", {})
assert thread_metadata.get("ag_ui_thread_id") == "test_thread_123"
assert thread_metadata.get("ag_ui_run_id") == "test_run_456"


async def test_state_context_injection():
"""Test that current state is injected into thread metadata."""
from agent_framework_ag_ui import AgentFrameworkAgent
"""Test that current state is injected into thread metadata.

thread_metadata: dict[str, Any] = {}
Note: AG-UI internal metadata (including current_state) is stored in
thread.metadata for orchestration purposes, but is NOT passed to chat clients
via options.metadata since external clients may not accept these fields.
"""
from agent_framework.ag_ui import AgentFrameworkAgent

captured_thread: list[Any] = []

async def stream_fn(
messages: MutableSequence[ChatMessage], options: dict[str, Any], **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
metadata = options.get("metadata")
if metadata:
thread_metadata.update(metadata)
thread = kwargs.get("thread")
if thread:
captured_thread.append(thread)
yield ChatResponseUpdate(contents=[Content.from_text(text="Hello")])

agent = ChatAgent(name="test_agent", instructions="Test", chat_client=StreamingChatClientStub(stream_fn))
Expand All @@ -479,6 +492,9 @@ async def stream_fn(
async for event in wrapper.run_agent(input_data):
events.append(event)

# AG-UI internal metadata is stored in thread.metadata (not in options.metadata)
assert len(captured_thread) > 0, "Thread should be passed to chat client"
thread_metadata = getattr(captured_thread[0], "metadata", {})
current_state = thread_metadata.get("current_state")
if isinstance(current_state, str):
current_state = json.loads(current_state)
Expand Down
88 changes: 88 additions & 0 deletions python/packages/ag-ui/tests/test_orchestrators.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,3 +305,91 @@ async def test_state_context_not_injected_when_tool_call_matches_state() -> None
if content.type == "text" and content.text.startswith("Current state of the application:"):
state_messages.append(content.text)
assert not state_messages


def test_options_filtered_from_tool_kwargs() -> None:
"""Verify 'options' is filtered when creating tool_kwargs from run_kwargs.

The AG-UI orchestrator adds 'options' (containing metadata/store for Azure AI)
to run_kwargs, but this should NOT be passed to _try_execute_function_calls
as external tools like MCP servers don't understand these kwargs.

This test verifies the filtering logic inline rather than through the full
orchestrator flow, matching the pattern in _resolve_approval_responses:
tool_kwargs = {k: v for k, v in run_kwargs.items() if k != "options"}
"""
# Simulate the run_kwargs that the orchestrator creates
run_kwargs: dict[str, Any] = {
"thread": MagicMock(),
"tools": [server_tool],
"options": {"metadata": {"thread_id": "test-123"}, "store": True},
}

# This is the exact filtering logic from _resolve_approval_responses
tool_kwargs = {k: v for k, v in run_kwargs.items() if k != "options"}

# Verify 'options' was filtered out
assert "options" not in tool_kwargs, "'options' should be filtered out before tool execution"

# Verify other kwargs are preserved
assert "thread" in tool_kwargs, "'thread' should be preserved"
assert "tools" in tool_kwargs, "'tools' should be preserved"

# Verify the original run_kwargs still has options (it's needed for run_stream)
assert "options" in run_kwargs, "Original run_kwargs should still have 'options'"


def test_orchestrator_filters_options_in_resolve_approval_responses() -> None:
"""Verify the orchestrator code filters 'options' before tool execution.

This is a code inspection test that verifies the fix is present in the
_resolve_approval_responses function within DefaultOrchestrator.run().
"""
import inspect

# Get the source code of the DefaultOrchestrator.run method
source = inspect.getsource(DefaultOrchestrator.run)

# Verify the filtering pattern is present
assert 'k != "options"' in source, (
"Expected 'options' filtering in DefaultOrchestrator.run(). "
"The line 'tool_kwargs = {k: v for k, v in run_kwargs.items() if k != \"options\"}' "
"should be present in _resolve_approval_responses."
)

# Verify tool_kwargs is passed to _try_execute_function_calls (not run_kwargs)
assert "custom_args=tool_kwargs" in source, (
"Expected _try_execute_function_calls to receive tool_kwargs (not run_kwargs). "
"This ensures 'options' is filtered out before tool execution."
)


def test_agui_internal_metadata_filtered_from_client_metadata() -> None:
"""Verify AG-UI internal metadata is filtered before passing to chat client.

AG-UI internal fields like 'ag_ui_thread_id', 'ag_ui_run_id', and 'current_state'
are used for orchestration tracking but should NOT be passed to chat clients
(e.g., Anthropic API only accepts 'user_id' in metadata).
"""
import inspect

# Get the source code of the DefaultOrchestrator.run method
source = inspect.getsource(DefaultOrchestrator.run)

# Verify the AG-UI internal metadata keys are defined
assert "AG_UI_INTERNAL_METADATA_KEYS" in source, (
"Expected AG_UI_INTERNAL_METADATA_KEYS to be defined for filtering internal metadata."
)

# Verify the internal keys include the AG-UI specific fields
assert '"ag_ui_thread_id"' in source, "Expected 'ag_ui_thread_id' to be filtered"
assert '"ag_ui_run_id"' in source, "Expected 'ag_ui_run_id' to be filtered"
assert '"current_state"' in source, "Expected 'current_state' to be filtered"

# Verify client_metadata is used instead of safe_metadata for options
assert "client_metadata = {k: v for k, v in safe_metadata.items()" in source, (
"Expected client_metadata to be created by filtering safe_metadata."
)
assert '"options": {"metadata": client_metadata}' in source, (
"Expected client_metadata (not safe_metadata) to be passed in options."
)
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@
ANTHROPIC_DEFAULT_MAX_TOKENS: Final[int] = 1024
BETA_FLAGS: Final[list[str]] = ["mcp-client-2025-04-04", "code-execution-2025-08-25"]


# region Anthropic Chat Options TypedDict


Expand Down Expand Up @@ -145,6 +144,7 @@ class AnthropicChatOptions(ChatOptions, total=False):
frequency_penalty: None # type: ignore[misc]
presence_penalty: None # type: ignore[misc]
store: None # type: ignore[misc]
conversation_id: None # type: ignore[misc]


TAnthropicOptions = TypeVar(
Expand Down Expand Up @@ -384,7 +384,7 @@ def _prepare_options(

messages = prepend_instructions_to_messages(list(messages), instructions, role="system")

# Start with a copy of options
# Start with a copy of options, excluding already-handled options
run_options: dict[str, Any] = {k: v for k, v in options.items() if v is not None and k not in {"instructions"}}

# Translation between options keys and Anthropic Messages API
Expand Down
6 changes: 5 additions & 1 deletion python/packages/core/agent_framework/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,8 +754,12 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Content] | Any
# Filter out framework kwargs that cannot be serialized by the MCP SDK.
# These are internal objects passed through the function invocation pipeline
# that should not be forwarded to external MCP servers.
# conversation_id is an internal tracking ID used by services like Azure AI.
# options contains metadata/store used by AG-UI for Azure AI client requirements.
filtered_kwargs = {
k: v for k, v in kwargs.items() if k not in {"chat_options", "tools", "tool_choice", "thread"}
k: v
for k, v in kwargs.items()
if k not in {"chat_options", "tools", "tool_choice", "thread", "conversation_id", "options"}
}

# Try the operation, reconnecting once if the connection is closed
Expand Down
3 changes: 2 additions & 1 deletion python/packages/core/agent_framework/_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1550,10 +1550,11 @@ async def _auto_invoke_function(
parsed_args: dict[str, Any] = dict(function_call_content.parse_arguments() or {})

# Filter out internal framework kwargs before passing to tools.
# conversation_id is an internal tracking ID that should not be forwarded to tools.
runtime_kwargs: dict[str, Any] = {
key: value
for key, value in (custom_args or {}).items()
if key not in {"_function_middleware_pipeline", "middleware"}
if key not in {"_function_middleware_pipeline", "middleware", "conversation_id"}
}
try:
args = tool.input_model.model_validate(parsed_args)
Expand Down
Loading
Loading