Skip to content

Commit 900c2dd

Browse files
authored
Merge branch 'main' into fix-307-Temporary-Redirect
2 parents a84399b + f2f4dbd commit 900c2dd

File tree

8 files changed

+100
-13
lines changed

8 files changed

+100
-13
lines changed

examples/clients/simple-chatbot/mcp_simple_chatbot/main.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ def get_response(self, messages: list[dict[str, str]]) -> str:
245245
}
246246
payload = {
247247
"messages": messages,
248-
"model": "llama-3.2-90b-vision-preview",
248+
"model": "meta-llama/llama-4-scout-17b-16e-instruct",
249249
"temperature": 0.7,
250250
"max_tokens": 4096,
251251
"top_p": 1,
@@ -284,12 +284,9 @@ def __init__(self, servers: list[Server], llm_client: LLMClient) -> None:
284284

285285
async def cleanup_servers(self) -> None:
286286
"""Clean up all servers properly."""
287-
cleanup_tasks = [
288-
asyncio.create_task(server.cleanup()) for server in self.servers
289-
]
290-
if cleanup_tasks:
287+
for server in reversed(self.servers):
291288
try:
292-
await asyncio.gather(*cleanup_tasks, return_exceptions=True)
289+
await server.cleanup()
293290
except Exception as e:
294291
logging.warning(f"Warning during final cleanup: {e}")
295292

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,13 @@ members = ["examples/servers/*"]
109109
mcp = { workspace = true }
110110

111111
[tool.pytest.ini_options]
112+
log_cli = true
112113
xfail_strict = true
114+
addopts = """
115+
--color=yes
116+
--capture=fd
117+
--numprocesses auto
118+
"""
113119
filterwarnings = [
114120
"error",
115121
# This should be fixed on Uvicorn's side.

src/mcp/client/sse.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from httpx_sse import aconnect_sse
1111

1212
import mcp.types as types
13-
from mcp.shared._httpx_utils import create_mcp_http_client
13+
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
1414
from mcp.shared.message import SessionMessage
1515

1616
logger = logging.getLogger(__name__)
@@ -26,6 +26,7 @@ async def sse_client(
2626
headers: dict[str, Any] | None = None,
2727
timeout: float = 5,
2828
sse_read_timeout: float = 60 * 5,
29+
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
2930
auth: httpx.Auth | None = None,
3031
):
3132
"""
@@ -53,7 +54,7 @@ async def sse_client(
5354
async with anyio.create_task_group() as tg:
5455
try:
5556
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
56-
async with create_mcp_http_client(headers=headers, auth=auth) as client:
57+
async with httpx_client_factory(headers=headers, auth=auth) as client:
5758
async with aconnect_sse(
5859
client,
5960
"GET",

src/mcp/client/streamable_http.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
2020
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
2121

22-
from mcp.shared._httpx_utils import create_mcp_http_client
22+
from mcp.shared._httpx_utils import McpHttpClientFactory, create_mcp_http_client
2323
from mcp.shared.message import ClientMessageMetadata, SessionMessage
2424
from mcp.types import (
2525
ErrorData,
@@ -430,6 +430,7 @@ async def streamablehttp_client(
430430
timeout: timedelta = timedelta(seconds=30),
431431
sse_read_timeout: timedelta = timedelta(seconds=60 * 5),
432432
terminate_on_close: bool = True,
433+
httpx_client_factory: McpHttpClientFactory = create_mcp_http_client,
433434
auth: httpx.Auth | None = None,
434435
) -> AsyncGenerator[
435436
tuple[
@@ -464,7 +465,7 @@ async def streamablehttp_client(
464465
try:
465466
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
466467

467-
async with create_mcp_http_client(
468+
async with httpx_client_factory(
468469
headers=transport.request_headers,
469470
timeout=httpx.Timeout(
470471
transport.timeout.seconds, read=transport.sse_read_timeout.seconds

src/mcp/server/fastmcp/tools/base.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations as _annotations
22

3+
import functools
34
import inspect
45
from collections.abc import Callable
56
from typing import TYPE_CHECKING, Any, get_origin
@@ -53,7 +54,7 @@ def from_function(
5354
raise ValueError("You must provide a name for lambda functions")
5455

5556
func_doc = description or fn.__doc__ or ""
56-
is_async = inspect.iscoroutinefunction(fn)
57+
is_async = _is_async_callable(fn)
5758

5859
if context_kwarg is None:
5960
sig = inspect.signature(fn)
@@ -98,3 +99,12 @@ async def run(
9899
)
99100
except Exception as e:
100101
raise ToolError(f"Error executing tool {self.name}: {e}") from e
102+
103+
104+
def _is_async_callable(obj: Any) -> bool:
105+
while isinstance(obj, functools.partial):
106+
obj = obj.func
107+
108+
return inspect.iscoroutinefunction(obj) or (
109+
callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None))
110+
)

src/mcp/shared/_httpx_utils.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,21 @@
11
"""Utilities for creating standardized httpx AsyncClient instances."""
22

3-
from typing import Any
3+
from typing import Any, Protocol
44

55
import httpx
66

77
__all__ = ["create_mcp_http_client"]
88

99

10+
class McpHttpClientFactory(Protocol):
11+
def __call__(
12+
self,
13+
headers: dict[str, str] | None = None,
14+
timeout: httpx.Timeout | None = None,
15+
auth: httpx.Auth | None = None,
16+
) -> httpx.AsyncClient: ...
17+
18+
1019
def create_mcp_http_client(
1120
headers: dict[str, str] | None = None,
1221
timeout: httpx.Timeout | None = None,

tests/client/test_config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,9 @@ def test_command_execution(mock_config_path: Path):
4444

4545
test_args = [command] + args + ["--help"]
4646

47-
result = subprocess.run(test_args, capture_output=True, text=True, timeout=5)
47+
result = subprocess.run(
48+
test_args, capture_output=True, text=True, timeout=5, check=False
49+
)
4850

4951
assert result.returncode == 0
5052
assert "usage" in result.stdout.lower()

tests/server/fastmcp/test_tool_manager.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,39 @@ def create_user(user: UserInput, flag: bool) -> dict:
102102
assert "age" in tool.parameters["$defs"]["UserInput"]["properties"]
103103
assert "flag" in tool.parameters["properties"]
104104

105+
def test_add_callable_object(self):
106+
"""Test registering a callable object."""
107+
108+
class MyTool:
109+
def __init__(self):
110+
self.__name__ = "MyTool"
111+
112+
def __call__(self, x: int) -> int:
113+
return x * 2
114+
115+
manager = ToolManager()
116+
tool = manager.add_tool(MyTool())
117+
assert tool.name == "MyTool"
118+
assert tool.is_async is False
119+
assert tool.parameters["properties"]["x"]["type"] == "integer"
120+
121+
@pytest.mark.anyio
122+
async def test_add_async_callable_object(self):
123+
"""Test registering an async callable object."""
124+
125+
class MyAsyncTool:
126+
def __init__(self):
127+
self.__name__ = "MyAsyncTool"
128+
129+
async def __call__(self, x: int) -> int:
130+
return x * 2
131+
132+
manager = ToolManager()
133+
tool = manager.add_tool(MyAsyncTool())
134+
assert tool.name == "MyAsyncTool"
135+
assert tool.is_async is True
136+
assert tool.parameters["properties"]["x"]["type"] == "integer"
137+
105138
def test_add_invalid_tool(self):
106139
manager = ToolManager()
107140
with pytest.raises(AttributeError):
@@ -168,6 +201,34 @@ async def double(n: int) -> int:
168201
result = await manager.call_tool("double", {"n": 5})
169202
assert result == 10
170203

204+
@pytest.mark.anyio
205+
async def test_call_object_tool(self):
206+
class MyTool:
207+
def __init__(self):
208+
self.__name__ = "MyTool"
209+
210+
def __call__(self, x: int) -> int:
211+
return x * 2
212+
213+
manager = ToolManager()
214+
tool = manager.add_tool(MyTool())
215+
result = await tool.run({"x": 5})
216+
assert result == 10
217+
218+
@pytest.mark.anyio
219+
async def test_call_async_object_tool(self):
220+
class MyAsyncTool:
221+
def __init__(self):
222+
self.__name__ = "MyAsyncTool"
223+
224+
async def __call__(self, x: int) -> int:
225+
return x * 2
226+
227+
manager = ToolManager()
228+
tool = manager.add_tool(MyAsyncTool())
229+
result = await tool.run({"x": 5})
230+
assert result == 10
231+
171232
@pytest.mark.anyio
172233
async def test_call_tool_with_default_args(self):
173234
def add(a: int, b: int = 1) -> int:

0 commit comments

Comments
 (0)