Skip to content

Commit 9676158

Browse files
authored
Merge branch 'main' into client_notif_support
2 parents 1c53a79 + 8ac0cab commit 9676158

File tree

27 files changed

+1043
-362
lines changed

27 files changed

+1043
-362
lines changed

README.md

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2241,12 +2241,12 @@ Run from the repository root:
22412241
import asyncio
22422242

22432243
from mcp import ClientSession
2244-
from mcp.client.streamable_http import streamablehttp_client
2244+
from mcp.client.streamable_http import streamable_http_client
22452245

22462246

22472247
async def main():
22482248
# Connect to a streamable HTTP server
2249-
async with streamablehttp_client("http://localhost:8000/mcp") as (
2249+
async with streamable_http_client("http://localhost:8000/mcp") as (
22502250
read_stream,
22512251
write_stream,
22522252
_,
@@ -2420,11 +2420,12 @@ cd to the `examples/snippets` directory and run:
24202420
import asyncio
24212421
from urllib.parse import parse_qs, urlparse
24222422

2423+
import httpx
24232424
from pydantic import AnyUrl
24242425

24252426
from mcp import ClientSession
24262427
from mcp.client.auth import OAuthClientProvider, TokenStorage
2427-
from mcp.client.streamable_http import streamablehttp_client
2428+
from mcp.client.streamable_http import streamable_http_client
24282429
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
24292430

24302431

@@ -2478,15 +2479,16 @@ async def main():
24782479
callback_handler=handle_callback,
24792480
)
24802481

2481-
async with streamablehttp_client("http://localhost:8001/mcp", auth=oauth_auth) as (read, write, _):
2482-
async with ClientSession(read, write) as session:
2483-
await session.initialize()
2482+
async with httpx.AsyncClient(auth=oauth_auth, follow_redirects=True) as custom_client:
2483+
async with streamable_http_client("http://localhost:8001/mcp", http_client=custom_client) as (read, write, _):
2484+
async with ClientSession(read, write) as session:
2485+
await session.initialize()
24842486

2485-
tools = await session.list_tools()
2486-
print(f"Available tools: {[tool.name for tool in tools.tools]}")
2487+
tools = await session.list_tools()
2488+
print(f"Available tools: {[tool.name for tool in tools.tools]}")
24872489

2488-
resources = await session.list_resources()
2489-
print(f"Available resources: {[r.uri for r in resources.resources]}")
2490+
resources = await session.list_resources()
2491+
print(f"Available resources: {[r.uri for r in resources.resources]}")
24902492

24912493

24922494
def run():

examples/clients/simple-auth-client/mcp_simple_auth_client/main.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@
1111
import threading
1212
import time
1313
import webbrowser
14-
from datetime import timedelta
1514
from http.server import BaseHTTPRequestHandler, HTTPServer
1615
from typing import Any
1716
from urllib.parse import parse_qs, urlparse
1817

18+
import httpx
1919
from mcp.client.auth import OAuthClientProvider, TokenStorage
2020
from mcp.client.session import ClientSession
2121
from mcp.client.sse import sse_client
22-
from mcp.client.streamable_http import streamablehttp_client
22+
from mcp.client.streamable_http import streamable_http_client
2323
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
2424

2525

@@ -193,7 +193,7 @@ async def _default_redirect_handler(authorization_url: str) -> None:
193193
# Create OAuth authentication handler using the new interface
194194
# Use client_metadata_url to enable CIMD when the server supports it
195195
oauth_auth = OAuthClientProvider(
196-
server_url=self.server_url,
196+
server_url=self.server_url.replace("/mcp", ""),
197197
client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict),
198198
storage=InMemoryTokenStorage(),
199199
redirect_handler=_default_redirect_handler,
@@ -212,12 +212,12 @@ async def _default_redirect_handler(authorization_url: str) -> None:
212212
await self._run_session(read_stream, write_stream, None)
213213
else:
214214
print("📡 Opening StreamableHTTP transport connection with auth...")
215-
async with streamablehttp_client(
216-
url=self.server_url,
217-
auth=oauth_auth,
218-
timeout=timedelta(seconds=60),
219-
) as (read_stream, write_stream, get_session_id):
220-
await self._run_session(read_stream, write_stream, get_session_id)
215+
async with httpx.AsyncClient(auth=oauth_auth, follow_redirects=True) as custom_client:
216+
async with streamable_http_client(
217+
url=self.server_url,
218+
http_client=custom_client,
219+
) as (read_stream, write_stream, get_session_id):
220+
await self._run_session(read_stream, write_stream, get_session_id)
221221

222222
except Exception as e:
223223
print(f"❌ Failed to connect: {e}")

examples/snippets/clients/oauth_client.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,12 @@
1010
import asyncio
1111
from urllib.parse import parse_qs, urlparse
1212

13+
import httpx
1314
from pydantic import AnyUrl
1415

1516
from mcp import ClientSession
1617
from mcp.client.auth import OAuthClientProvider, TokenStorage
17-
from mcp.client.streamable_http import streamablehttp_client
18+
from mcp.client.streamable_http import streamable_http_client
1819
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
1920

2021

@@ -68,15 +69,16 @@ async def main():
6869
callback_handler=handle_callback,
6970
)
7071

71-
async with streamablehttp_client("http://localhost:8001/mcp", auth=oauth_auth) as (read, write, _):
72-
async with ClientSession(read, write) as session:
73-
await session.initialize()
72+
async with httpx.AsyncClient(auth=oauth_auth, follow_redirects=True) as custom_client:
73+
async with streamable_http_client("http://localhost:8001/mcp", http_client=custom_client) as (read, write, _):
74+
async with ClientSession(read, write) as session:
75+
await session.initialize()
7476

75-
tools = await session.list_tools()
76-
print(f"Available tools: {[tool.name for tool in tools.tools]}")
77+
tools = await session.list_tools()
78+
print(f"Available tools: {[tool.name for tool in tools.tools]}")
7779

78-
resources = await session.list_resources()
79-
print(f"Available resources: {[r.uri for r in resources.resources]}")
80+
resources = await session.list_resources()
81+
print(f"Available resources: {[r.uri for r in resources.resources]}")
8082

8183

8284
def run():

examples/snippets/clients/streamable_basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,12 @@
66
import asyncio
77

88
from mcp import ClientSession
9-
from mcp.client.streamable_http import streamablehttp_client
9+
from mcp.client.streamable_http import streamable_http_client
1010

1111

1212
async def main():
1313
# Connect to a streamable HTTP server
14-
async with streamablehttp_client("http://localhost:8000/mcp") as (
14+
async with streamable_http_client("http://localhost:8000/mcp") as (
1515
read_stream,
1616
write_stream,
1717
_,

src/mcp/client/auth/oauth2.py

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import httpx
2020
from pydantic import BaseModel, Field, ValidationError
2121

22-
from mcp.client.auth.exceptions import OAuthFlowError, OAuthRegistrationError, OAuthTokenError
22+
from mcp.client.auth.exceptions import OAuthFlowError, OAuthTokenError
2323
from mcp.client.auth.utils import (
2424
build_oauth_authorization_server_metadata_discovery_urls,
2525
build_protected_resource_metadata_discovery_urls,
@@ -299,44 +299,6 @@ async def _handle_protected_resource_response(self, response: httpx.Response) ->
299299
f"Protected Resource Metadata request failed: {response.status_code}"
300300
) # pragma: no cover
301301

302-
async def _register_client(self) -> httpx.Request | None:
303-
"""Build registration request or skip if already registered."""
304-
if self.context.client_info:
305-
return None
306-
307-
if self.context.oauth_metadata and self.context.oauth_metadata.registration_endpoint:
308-
registration_url = str(self.context.oauth_metadata.registration_endpoint) # pragma: no cover
309-
else:
310-
auth_base_url = self.context.get_authorization_base_url(self.context.server_url)
311-
registration_url = urljoin(auth_base_url, "/register")
312-
313-
registration_data = self.context.client_metadata.model_dump(by_alias=True, mode="json", exclude_none=True)
314-
315-
# If token_endpoint_auth_method is None, auto-select based on server support
316-
if self.context.client_metadata.token_endpoint_auth_method is None:
317-
preference_order = ["client_secret_basic", "client_secret_post", "none"]
318-
319-
if self.context.oauth_metadata and self.context.oauth_metadata.token_endpoint_auth_methods_supported:
320-
supported = self.context.oauth_metadata.token_endpoint_auth_methods_supported
321-
for method in preference_order:
322-
if method in supported:
323-
registration_data["token_endpoint_auth_method"] = method
324-
break
325-
else:
326-
# No compatible methods between client and server
327-
raise OAuthRegistrationError(
328-
f"No compatible authentication methods. "
329-
f"Server supports: {supported}, "
330-
f"Client supports: {preference_order}"
331-
)
332-
else:
333-
# No server metadata available, use our default preference
334-
registration_data["token_endpoint_auth_method"] = preference_order[0]
335-
336-
return httpx.Request(
337-
"POST", registration_url, json=registration_data, headers={"Content-Type": "application/json"}
338-
)
339-
340302
async def _perform_authorization(self) -> httpx.Request:
341303
"""Perform the authorization flow."""
342304
auth_code, code_verifier = await self._perform_authorization_code_grant()

src/mcp/client/session.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ async def __call__(
2525
self,
2626
context: RequestContext["ClientSession", Any],
2727
params: types.CreateMessageRequestParams,
28-
) -> types.CreateMessageResult | types.ErrorData: ... # pragma: no branch
28+
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch
2929

3030

3131
class ElicitationFnT(Protocol):
@@ -104,7 +104,7 @@ async def _default_message_handler(
104104
async def _default_sampling_callback(
105105
context: RequestContext["ClientSession", Any],
106106
params: types.CreateMessageRequestParams,
107-
) -> types.CreateMessageResult | types.ErrorData:
107+
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData:
108108
return types.ErrorData(
109109
code=types.INVALID_REQUEST,
110110
message="Sampling not supported",
@@ -196,6 +196,7 @@ def __init__(
196196
message_handler: MessageHandlerFnT | None = None,
197197
client_info: types.Implementation | None = None,
198198
*,
199+
sampling_capabilities: types.SamplingCapability | None = None,
199200
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
200201
) -> None:
201202
super().__init__(
@@ -207,6 +208,7 @@ def __init__(
207208
)
208209
self._client_info = client_info or DEFAULT_CLIENT_INFO
209210
self._sampling_callback = sampling_callback or _default_sampling_callback
211+
self._sampling_capabilities = sampling_capabilities
210212
self._elicitation_callback = elicitation_callback or _default_elicitation_callback
211213
self._elicit_complete_callback = elicit_complete_callback or _default_elicit_complete_callback
212214
self._list_roots_callback = list_roots_callback or _default_list_roots_callback
@@ -225,7 +227,11 @@ def __init__(
225227
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()
226228

227229
async def initialize(self) -> types.InitializeResult:
228-
sampling = types.SamplingCapability() if self._sampling_callback is not _default_sampling_callback else None
230+
sampling = (
231+
(self._sampling_capabilities or types.SamplingCapability())
232+
if self._sampling_callback is not _default_sampling_callback
233+
else None
234+
)
229235
elicitation = (
230236
types.ElicitationCapability(
231237
form=types.FormElicitationCapability(),

src/mcp/client/session_group.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from typing import Any, TypeAlias, overload
1818

1919
import anyio
20+
import httpx
2021
from pydantic import BaseModel
2122
from typing_extensions import Self, deprecated
2223

@@ -25,7 +26,8 @@
2526
from mcp.client.session import ElicitationFnT, ListRootsFnT, LoggingFnT, MessageHandlerFnT, SamplingFnT
2627
from mcp.client.sse import sse_client
2728
from mcp.client.stdio import StdioServerParameters
28-
from mcp.client.streamable_http import streamablehttp_client
29+
from mcp.client.streamable_http import streamable_http_client
30+
from mcp.shared._httpx_utils import create_mcp_http_client
2931
from mcp.shared.exceptions import McpError
3032
from mcp.shared.session import ProgressFnT
3133

@@ -47,7 +49,7 @@ class SseServerParameters(BaseModel):
4749

4850

4951
class StreamableHttpParameters(BaseModel):
50-
"""Parameters for intializing a streamablehttp_client."""
52+
"""Parameters for intializing a streamable_http_client."""
5153

5254
# The endpoint URL.
5355
url: str
@@ -309,11 +311,18 @@ async def _establish_session(
309311
)
310312
read, write = await session_stack.enter_async_context(client)
311313
else:
312-
client = streamablehttp_client(
313-
url=server_params.url,
314+
httpx_client = create_mcp_http_client(
314315
headers=server_params.headers,
315-
timeout=server_params.timeout,
316-
sse_read_timeout=server_params.sse_read_timeout,
316+
timeout=httpx.Timeout(
317+
server_params.timeout.total_seconds(),
318+
read=server_params.sse_read_timeout.total_seconds(),
319+
),
320+
)
321+
await session_stack.enter_async_context(httpx_client)
322+
323+
client = streamable_http_client(
324+
url=server_params.url,
325+
http_client=httpx_client,
317326
terminate_on_close=server_params.terminate_on_close,
318327
)
319328
read, write, _ = await session_stack.enter_async_context(client)

src/mcp/client/sse.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,9 @@ async def sse_reader(
105105
task_status.started(endpoint_url)
106106

107107
case "message":
108+
# Skip empty data (keep-alive pings)
109+
if not sse.data:
110+
continue
108111
try:
109112
message = types.JSONRPCMessage.model_validate_json( # noqa: E501
110113
sse.data

0 commit comments

Comments
 (0)