Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion python/packages/core/agent_framework/_mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -756,10 +756,12 @@ async def call_tool(self, tool_name: str, **kwargs: Any) -> list[Content] | Any
# that should not be forwarded to external MCP servers.
# conversation_id is an internal tracking ID used by services like Azure AI.
# options contains metadata/store used by AG-UI for Azure AI client requirements.
# response_format is a Pydantic model class used for structured output (not serializable).
filtered_kwargs = {
k: v
for k, v in kwargs.items()
if k not in {"chat_options", "tools", "tool_choice", "thread", "conversation_id", "options"}
if k
not in {"chat_options", "tools", "tool_choice", "thread", "conversation_id", "options", "response_format"}
}

# Try the operation, reconnecting once if the connection is closed
Expand Down
194 changes: 194 additions & 0 deletions python/packages/core/tests/core/test_mcp.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
# type: ignore[reportPrivateUsage]
import logging
import os
from contextlib import _AsyncGeneratorContextManager # type: ignore
from typing import Any
Expand Down Expand Up @@ -30,6 +31,7 @@
_parse_message_from_mcp,
_prepare_content_for_mcp,
_prepare_message_for_mcp,
logger,
)
from agent_framework.exceptions import ToolException, ToolExecutionException

Expand Down Expand Up @@ -2514,3 +2516,195 @@ async def test_mcp_tool_safe_close_handles_cancelled_error():

# Verify aclose was called
mock_exit_stack.aclose.assert_called_once()


async def test_connect_sets_logging_level_when_logger_level_is_set():
"""Test that connect() sets the MCP server logging level when the logger level is not NOTSET."""

tool = MCPStdioTool(
name="test_server",
command="test_command",
args=["arg1"],
load_tools=False,
load_prompts=False,
)

# Mock the transport and session
mock_transport = (Mock(), Mock())
mock_context = AsyncMock()
mock_context.__aenter__ = AsyncMock(return_value=mock_transport)
mock_context.__aexit__ = AsyncMock()

mock_session = Mock()
mock_session._request_id = 1
mock_session.initialize = AsyncMock()
mock_session.set_logging_level = AsyncMock()

mock_session_context = AsyncMock()
mock_session_context.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_context.__aexit__ = AsyncMock()

with (
patch.object(tool, "get_mcp_client", return_value=mock_context),
patch("agent_framework._mcp.ClientSession", return_value=mock_session_context),
patch.object(logger, "level", logging.DEBUG), # Set logger level to DEBUG
):
await tool.connect()

# Verify set_logging_level was called with "debug"
mock_session.set_logging_level.assert_called_once_with("debug")


async def test_connect_does_not_set_logging_level_when_logger_level_is_notset():
"""Test that connect() does not set logging level when logger level is NOTSET."""

tool = MCPStdioTool(
name="test_server",
command="test_command",
args=["arg1"],
load_tools=False,
load_prompts=False,
)

# Mock the transport and session
mock_transport = (Mock(), Mock())
mock_context = AsyncMock()
mock_context.__aenter__ = AsyncMock(return_value=mock_transport)
mock_context.__aexit__ = AsyncMock()

mock_session = Mock()
mock_session._request_id = 1
mock_session.initialize = AsyncMock()
mock_session.set_logging_level = AsyncMock()

mock_session_context = AsyncMock()
mock_session_context.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_context.__aexit__ = AsyncMock()

with (
patch.object(tool, "get_mcp_client", return_value=mock_context),
patch("agent_framework._mcp.ClientSession", return_value=mock_session_context),
patch.object(logger, "level", logging.NOTSET), # Set logger level to NOTSET
):
await tool.connect()

# Verify set_logging_level was NOT called
mock_session.set_logging_level.assert_not_called()


async def test_connect_handles_set_logging_level_exception():
"""Test that connect() handles exceptions from set_logging_level gracefully."""

tool = MCPStdioTool(
name="test_server",
command="test_command",
args=["arg1"],
load_tools=False,
load_prompts=False,
)

# Mock the transport and session
mock_transport = (Mock(), Mock())
mock_context = AsyncMock()
mock_context.__aenter__ = AsyncMock(return_value=mock_transport)
mock_context.__aexit__ = AsyncMock()

mock_session = Mock()
mock_session._request_id = 1
mock_session.initialize = AsyncMock()
# Make set_logging_level raise an exception
mock_session.set_logging_level = AsyncMock(side_effect=RuntimeError("Server doesn't support logging level"))

mock_session_context = AsyncMock()
mock_session_context.__aenter__ = AsyncMock(return_value=mock_session)
mock_session_context.__aexit__ = AsyncMock()

with (
patch.object(tool, "get_mcp_client", return_value=mock_context),
patch("agent_framework._mcp.ClientSession", return_value=mock_session_context),
patch.object(logger, "level", logging.INFO), # Set logger level to INFO
patch.object(logger, "warning") as mock_warning,
):
# Should NOT raise - the exception should be caught and logged
await tool.connect()

# Verify set_logging_level was called
mock_session.set_logging_level.assert_called_once_with("info")

# Verify warning was logged
mock_warning.assert_called_once()
call_args = mock_warning.call_args
assert "Failed to set log level" in call_args[0][0]


async def test_mcp_tool_filters_framework_kwargs():
"""Test that call_tool filters out framework-specific kwargs before calling MCP session.

This verifies that non-serializable kwargs like response_format (Pydantic model class),
chat_options, tools, tool_choice, thread, conversation_id, and options are filtered out
before being passed to the external MCP server.
"""

class TestServer(MCPTool):
async def connect(self):
self.session = Mock(spec=ClientSession)
self.session.list_tools = AsyncMock(
return_value=types.ListToolsResult(
tools=[
types.Tool(
name="test_tool",
description="Test tool",
inputSchema={
"type": "object",
"properties": {"param": {"type": "string"}},
"required": ["param"],
},
)
]
)
)
# Mock call_tool to capture the arguments it receives
self.session.call_tool = AsyncMock(
return_value=types.CallToolResult(content=[types.TextContent(type="text", text="Success")])
)

def get_mcp_client(self) -> _AsyncGeneratorContextManager[Any, None]:
return None

# Create a mock Pydantic model class to use as response_format
class MockResponseFormat(BaseModel):
result: str

server = TestServer(name="test_server")
async with server:
await server.load_tools()
func = server.functions[0]

# Invoke the tool with framework kwargs that should be filtered out
await func.invoke(
param="test_value",
response_format=MockResponseFormat, # Should be filtered
chat_options={"some": "option"}, # Should be filtered
tools=[Mock()], # Should be filtered
tool_choice="auto", # Should be filtered
thread=Mock(), # Should be filtered
conversation_id="conv-123", # Should be filtered
options={"metadata": "value"}, # Should be filtered
)

# Verify call_tool was called with only the valid argument
server.session.call_tool.assert_called_once()
call_args = server.session.call_tool.call_args

# Check that the arguments dict only contains 'param' and none of the framework kwargs
arguments = call_args.kwargs.get("arguments", call_args[1] if len(call_args) > 1 else {})
assert arguments == {"param": "test_value"}, f"Expected only 'param' but got: {arguments}"

# Explicitly verify that framework kwargs were NOT passed
assert "response_format" not in arguments
assert "chat_options" not in arguments
assert "tools" not in arguments
assert "tool_choice" not in arguments
assert "thread" not in arguments
assert "conversation_id" not in arguments
assert "options" not in arguments
Loading