Skip to content

Commit 508d8e7

Browse files
authored
Merge branch 'main' into fix-leaked-stream
2 parents a0afd1d + 239d682 commit 508d8e7

File tree

8 files changed

+335
-134
lines changed

8 files changed

+335
-134
lines changed

.github/actions/conformance/client.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,27 @@ async def run_client_credentials_basic(server_url: str) -> None:
275275
async def run_auth_code_client(server_url: str) -> None:
276276
"""Authorization code flow (default for auth/* scenarios)."""
277277
callback_handler = ConformanceOAuthCallbackHandler()
278+
storage = InMemoryTokenStorage()
279+
280+
# Check for pre-registered client credentials from context
281+
context_json = os.environ.get("MCP_CONFORMANCE_CONTEXT")
282+
if context_json:
283+
try:
284+
context = json.loads(context_json)
285+
client_id = context.get("client_id")
286+
client_secret = context.get("client_secret")
287+
if client_id:
288+
await storage.set_client_info(
289+
OAuthClientInformationFull(
290+
client_id=client_id,
291+
client_secret=client_secret,
292+
redirect_uris=[AnyUrl("http://localhost:3000/callback")],
293+
token_endpoint_auth_method="client_secret_basic" if client_secret else "none",
294+
)
295+
)
296+
logger.debug(f"Pre-loaded client credentials: client_id={client_id}")
297+
except json.JSONDecodeError:
298+
logger.exception("Failed to parse MCP_CONFORMANCE_CONTEXT")
278299

279300
oauth_auth = OAuthClientProvider(
280301
server_url=server_url,
@@ -284,7 +305,7 @@ async def run_auth_code_client(server_url: str) -> None:
284305
grant_types=["authorization_code", "refresh_token"],
285306
response_types=["code"],
286307
),
287-
storage=InMemoryTokenStorage(),
308+
storage=storage,
288309
redirect_handler=callback_handler.handle_redirect,
289310
callback_handler=callback_handler.handle_callback,
290311
client_metadata_url="https://conformance-test.local/client-metadata.json",

.github/workflows/conformance.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,4 @@ jobs:
4242
with:
4343
node-version: 24
4444
- run: uv sync --frozen --all-extras --package mcp
45-
- run: npx @modelcontextprotocol/conformance@0.1.10 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all
45+
- run: npx @modelcontextprotocol/conformance@0.1.13 client --command 'uv run --frozen python .github/actions/conformance/client.py' --suite all

CLAUDE.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ This document contains critical information about working with this codebase. Fo
2929
- IMPORTANT: The `tests/client/test_client.py` is the most well designed test file. Follow its patterns.
3030
- IMPORTANT: Be minimal, and focus on E2E tests: Use the `mcp.client.Client` whenever possible.
3131

32+
Test files mirror the source tree: `src/mcp/client/streamable_http.py``tests/client/test_streamable_http.py`
33+
Add tests to the existing file for that module.
34+
3235
- For commits fixing bugs or adding features based on user reports add:
3336

3437
```bash

src/mcp/client/auth/oauth2.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,7 @@ def __init__(
229229
callback_handler: Callable[[], Awaitable[tuple[str, str | None]]] | None = None,
230230
timeout: float = 300.0,
231231
client_metadata_url: str | None = None,
232+
validate_resource_url: Callable[[str, str | None], Awaitable[None]] | None = None,
232233
):
233234
"""Initialize OAuth2 authentication.
234235
@@ -243,6 +244,10 @@ def __init__(
243244
advertises client_id_metadata_document_supported=true, this URL will be
244245
used as the client_id instead of performing dynamic client registration.
245246
Must be a valid HTTPS URL with a non-root pathname.
247+
validate_resource_url: Optional callback to override resource URL validation.
248+
Called with (server_url, prm_resource) where prm_resource is the resource
249+
from Protected Resource Metadata (or None if not present). If not provided,
250+
default validation rejects mismatched resources per RFC 8707.
246251
247252
Raises:
248253
ValueError: If client_metadata_url is provided but not a valid HTTPS URL
@@ -263,6 +268,7 @@ def __init__(
263268
timeout=timeout,
264269
client_metadata_url=client_metadata_url,
265270
)
271+
self._validate_resource_url_callback = validate_resource_url
266272
self._initialized = False
267273

268274
async def _handle_protected_resource_response(self, response: httpx.Response) -> bool:
@@ -476,6 +482,26 @@ async def _handle_oauth_metadata_response(self, response: httpx.Response) -> Non
476482
metadata = OAuthMetadata.model_validate_json(content)
477483
self.context.oauth_metadata = metadata
478484

485+
async def _validate_resource_match(self, prm: ProtectedResourceMetadata) -> None:
486+
"""Validate that PRM resource matches the server URL per RFC 8707."""
487+
prm_resource = str(prm.resource) if prm.resource else None
488+
489+
if self._validate_resource_url_callback is not None:
490+
await self._validate_resource_url_callback(self.context.server_url, prm_resource)
491+
return
492+
493+
if not prm_resource:
494+
return # pragma: no cover
495+
default_resource = resource_url_from_server_url(self.context.server_url)
496+
# Normalize: Pydantic AnyHttpUrl adds trailing slash to root URLs
497+
# (e.g. "https://example.com/") while resource_url_from_server_url may not.
498+
if not default_resource.endswith("/"):
499+
default_resource += "/"
500+
if not prm_resource.endswith("/"):
501+
prm_resource += "/"
502+
if not check_resource_allowed(requested_resource=default_resource, configured_resource=prm_resource):
503+
raise OAuthFlowError(f"Protected resource {prm_resource} does not match expected {default_resource}")
504+
479505
async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.Request, httpx.Response]:
480506
"""HTTPX auth flow integration."""
481507
async with self.context.lock:
@@ -517,6 +543,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
517543

518544
prm = await handle_protected_resource_response(discovery_response)
519545
if prm:
546+
# Validate PRM resource matches server URL (RFC 8707)
547+
await self._validate_resource_match(prm)
520548
self.context.protected_resource_metadata = prm
521549

522550
# todo: try all authorization_servers to find the OASM

src/mcp/client/streamable_http.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@
1313
from anyio.abc import TaskGroup
1414
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
1515
from httpx_sse import EventSource, ServerSentEvent, aconnect_sse
16+
from pydantic import ValidationError
1617

1718
from mcp.client._transport import TransportStreams
1819
from mcp.shared._httpx_utils import create_mcp_http_client
1920
from mcp.shared.message import ClientMessageMetadata, SessionMessage
2021
from mcp.types import (
22+
INVALID_REQUEST,
23+
PARSE_ERROR,
2124
ErrorData,
2225
InitializeResult,
2326
JSONRPCError,
@@ -163,6 +166,11 @@ async def _handle_sse_event(
163166

164167
except Exception as exc: # pragma: no cover
165168
logger.exception("Error parsing SSE message")
169+
if original_request_id is not None:
170+
error_data = ErrorData(code=PARSE_ERROR, message=f"Failed to parse SSE message: {exc}")
171+
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=original_request_id, error=error_data))
172+
await read_stream_writer.send(error_msg)
173+
return True
166174
await read_stream_writer.send(exc)
167175
return False
168176
else: # pragma: no cover
@@ -260,7 +268,9 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
260268

261269
if response.status_code == 404: # pragma: no branch
262270
if isinstance(message, JSONRPCRequest): # pragma: no branch
263-
await self._send_session_terminated_error(ctx.read_stream_writer, message.id)
271+
error_data = ErrorData(code=INVALID_REQUEST, message="Session terminated")
272+
session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data))
273+
await ctx.read_stream_writer.send(session_message)
264274
return
265275

266276
response.raise_for_status()
@@ -272,20 +282,24 @@ async def _handle_post_request(self, ctx: RequestContext) -> None:
272282
if isinstance(message, JSONRPCRequest):
273283
content_type = response.headers.get("content-type", "").lower()
274284
if content_type.startswith("application/json"):
275-
await self._handle_json_response(response, ctx.read_stream_writer, is_initialization)
285+
await self._handle_json_response(
286+
response, ctx.read_stream_writer, is_initialization, request_id=message.id
287+
)
276288
elif content_type.startswith("text/event-stream"):
277289
await self._handle_sse_response(response, ctx, is_initialization)
278290
else:
279-
await self._handle_unexpected_content_type( # pragma: no cover
280-
content_type, # pragma: no cover
281-
ctx.read_stream_writer, # pragma: no cover
282-
) # pragma: no cover
291+
logger.error(f"Unexpected content type: {content_type}")
292+
error_data = ErrorData(code=INVALID_REQUEST, message=f"Unexpected content type: {content_type}")
293+
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data))
294+
await ctx.read_stream_writer.send(error_msg)
283295

284296
async def _handle_json_response(
285297
self,
286298
response: httpx.Response,
287299
read_stream_writer: StreamWriter,
288300
is_initialization: bool = False,
301+
*,
302+
request_id: RequestId,
289303
) -> None:
290304
"""Handle JSON response from the server."""
291305
try:
@@ -298,9 +312,11 @@ async def _handle_json_response(
298312

299313
session_message = SessionMessage(message)
300314
await read_stream_writer.send(session_message)
301-
except Exception as exc: # pragma: no cover
315+
except (httpx.StreamError, ValidationError) as exc:
302316
logger.exception("Error parsing JSON response")
303-
await read_stream_writer.send(exc)
317+
error_data = ErrorData(code=PARSE_ERROR, message=f"Failed to parse JSON response: {exc}")
318+
error_msg = SessionMessage(JSONRPCError(jsonrpc="2.0", id=request_id, error=error_data))
319+
await read_stream_writer.send(error_msg)
304320

305321
async def _handle_sse_response(
306322
self,
@@ -312,6 +328,11 @@ async def _handle_sse_response(
312328
last_event_id: str | None = None
313329
retry_interval_ms: int | None = None
314330

331+
# The caller (_handle_post_request) only reaches here inside
332+
# isinstance(message, JSONRPCRequest), so this is always a JSONRPCRequest.
333+
assert isinstance(ctx.session_message.message, JSONRPCRequest)
334+
original_request_id = ctx.session_message.message.id
335+
315336
try:
316337
event_source = EventSource(response)
317338
async for sse in event_source.aiter_sse(): # pragma: no branch
@@ -326,6 +347,7 @@ async def _handle_sse_response(
326347
is_complete = await self._handle_sse_event(
327348
sse,
328349
ctx.read_stream_writer,
350+
original_request_id=original_request_id,
329351
resumption_callback=(ctx.metadata.on_resumption_token_update if ctx.metadata else None),
330352
is_initialization=is_initialization,
331353
)
@@ -334,8 +356,8 @@ async def _handle_sse_response(
334356
if is_complete:
335357
await response.aclose()
336358
return # Normal completion, no reconnect needed
337-
except Exception as e:
338-
logger.debug(f"SSE stream ended: {e}") # pragma: no cover
359+
except Exception:
360+
logger.debug("SSE stream ended", exc_info=True) # pragma: no cover
339361

340362
# Stream ended without response - reconnect if we received an event with ID
341363
if last_event_id is not None: # pragma: no branch
@@ -400,24 +422,6 @@ async def _handle_reconnection(
400422
# Try to reconnect again if we still have an event ID
401423
await self._handle_reconnection(ctx, last_event_id, retry_interval_ms, attempt + 1)
402424

403-
async def _handle_unexpected_content_type(
404-
self, content_type: str, read_stream_writer: StreamWriter
405-
) -> None: # pragma: no cover
406-
"""Handle unexpected content type in response."""
407-
error_msg = f"Unexpected content type: {content_type}" # pragma: no cover
408-
logger.error(error_msg) # pragma: no cover
409-
await read_stream_writer.send(ValueError(error_msg)) # pragma: no cover
410-
411-
async def _send_session_terminated_error(self, read_stream_writer: StreamWriter, request_id: RequestId) -> None:
412-
"""Send a session terminated error response."""
413-
jsonrpc_error = JSONRPCError(
414-
jsonrpc="2.0",
415-
id=request_id,
416-
error=ErrorData(code=32600, message="Session terminated"),
417-
)
418-
session_message = SessionMessage(jsonrpc_error)
419-
await read_stream_writer.send(session_message)
420-
421425
async def post_writer(
422426
self,
423427
client: httpx.AsyncClient,
@@ -467,8 +471,8 @@ async def handle_request_async():
467471
else:
468472
await handle_request_async()
469473

470-
except Exception:
471-
logger.exception("Error in post_writer") # pragma: no cover
474+
except Exception: # pragma: lax no cover
475+
logger.exception("Error in post_writer")
472476
finally:
473477
await read_stream_writer.aclose()
474478
await write_stream.aclose()

src/mcp/types/jsonrpc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class JSONRPCError(BaseModel):
7575
"""A response to a request that indicates an error occurred."""
7676

7777
jsonrpc: Literal["2.0"]
78-
id: str | int
78+
id: RequestId
7979
error: ErrorData
8080

8181

0 commit comments

Comments
 (0)