Skip to content

Commit ee78ac1

Browse files
Merge branch 'main' into fix-httpstreamable-errorcode
2 parents 3a45353 + 959d4e3 commit ee78ac1

File tree

6 files changed

+297
-53
lines changed

6 files changed

+297
-53
lines changed

src/mcp/client/auth.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -526,8 +526,8 @@ async def async_auth_flow(self, request: httpx.Request) -> AsyncGenerator[httpx.
526526
break
527527
except ValidationError:
528528
continue
529-
elif oauth_metadata_response.status_code != 404:
530-
break # Non-404 error, stop trying
529+
elif oauth_metadata_response.status_code < 400 or oauth_metadata_response.status_code >= 500:
530+
break # Non-4XX error, stop trying
531531

532532
# Step 3: Register client if needed
533533
registration_request = await self._register_client()

src/mcp/server/fastmcp/server.py

Lines changed: 71 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,13 @@
44

55
import inspect
66
import re
7-
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable, Sequence
8-
from contextlib import (
9-
AbstractAsyncContextManager,
10-
asynccontextmanager,
11-
)
7+
from collections.abc import AsyncIterator, Awaitable, Callable, Collection, Iterable, Sequence
8+
from contextlib import AbstractAsyncContextManager, asynccontextmanager
129
from typing import Any, Generic, Literal
1310

1411
import anyio
1512
import pydantic_core
16-
from pydantic import BaseModel, Field
13+
from pydantic import BaseModel
1714
from pydantic.networks import AnyUrl
1815
from pydantic_settings import BaseSettings, SettingsConfigDict
1916
from starlette.applications import Starlette
@@ -25,10 +22,7 @@
2522
from starlette.types import Receive, Scope, Send
2623

2724
from mcp.server.auth.middleware.auth_context import AuthContextMiddleware
28-
from mcp.server.auth.middleware.bearer_auth import (
29-
BearerAuthBackend,
30-
RequireAuthMiddleware,
31-
)
25+
from mcp.server.auth.middleware.bearer_auth import BearerAuthBackend, RequireAuthMiddleware
3226
from mcp.server.auth.provider import OAuthAuthorizationServerProvider, ProviderTokenVerifier, TokenVerifier
3327
from mcp.server.auth.settings import AuthSettings
3428
from mcp.server.elicitation import ElicitationResult, ElicitSchemaModelT, elicit_with_validation
@@ -48,12 +42,7 @@
4842
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
4943
from mcp.server.transport_security import TransportSecuritySettings
5044
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
51-
from mcp.types import (
52-
AnyFunction,
53-
ContentBlock,
54-
GetPromptResult,
55-
ToolAnnotations,
56-
)
45+
from mcp.types import AnyFunction, ContentBlock, GetPromptResult, ToolAnnotations
5746
from mcp.types import Prompt as MCPPrompt
5847
from mcp.types import PromptArgument as MCPPromptArgument
5948
from mcp.types import Resource as MCPResource
@@ -79,58 +68,57 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
7968
)
8069

8170
# Server settings
82-
debug: bool = False
83-
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO"
71+
debug: bool
72+
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]
8473

8574
# HTTP settings
86-
host: str = "127.0.0.1"
87-
port: int = 8000
88-
mount_path: str = "/" # Mount path (e.g. "/github", defaults to root path)
89-
sse_path: str = "/sse"
90-
message_path: str = "/messages/"
91-
streamable_http_path: str = "/mcp"
75+
host: str
76+
port: int
77+
mount_path: str
78+
sse_path: str
79+
message_path: str
80+
streamable_http_path: str
9281

9382
# StreamableHTTP settings
94-
json_response: bool = False
95-
stateless_http: bool = False # If True, uses true stateless mode (new transport per request)
83+
json_response: bool
84+
stateless_http: bool
85+
"""Define if the server should create a new transport per request."""
9686

9787
# resource settings
98-
warn_on_duplicate_resources: bool = True
88+
warn_on_duplicate_resources: bool
9989

10090
# tool settings
101-
warn_on_duplicate_tools: bool = True
91+
warn_on_duplicate_tools: bool
10292

10393
# prompt settings
104-
warn_on_duplicate_prompts: bool = True
94+
warn_on_duplicate_prompts: bool
10595

106-
dependencies: list[str] = Field(
107-
default_factory=list,
108-
description="List of dependencies to install in the server environment",
109-
)
96+
# TODO(Marcelo): Investigate if this is used. If it is, it's probably a good idea to remove it.
97+
dependencies: list[str]
98+
"""A list of dependencies to install in the server environment."""
11099

111-
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]] | None = Field(
112-
None, description="Lifespan context manager"
113-
)
100+
lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None
101+
"""A async context manager that will be called when the server is started."""
114102

115-
auth: AuthSettings | None = None
103+
auth: AuthSettings | None
116104

117105
# Transport security settings (DNS rebinding protection)
118-
transport_security: TransportSecuritySettings | None = None
106+
transport_security: TransportSecuritySettings | None
119107

120108

121109
def lifespan_wrapper(
122-
app: FastMCP,
123-
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
124-
) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]]:
110+
app: FastMCP[LifespanResultT],
111+
lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]],
112+
) -> Callable[[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[LifespanResultT]]:
125113
@asynccontextmanager
126-
async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]:
114+
async def wrap(_: MCPServer[LifespanResultT, Request]) -> AsyncIterator[LifespanResultT]:
127115
async with lifespan(app) as context:
128116
yield context
129117

130118
return wrap
131119

132120

133-
class FastMCP:
121+
class FastMCP(Generic[LifespanResultT]):
134122
def __init__(
135123
self,
136124
name: str | None = None,
@@ -140,14 +128,50 @@ def __init__(
140128
event_store: EventStore | None = None,
141129
*,
142130
tools: list[Tool] | None = None,
143-
**settings: Any,
131+
debug: bool = False,
132+
log_level: Literal["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] = "INFO",
133+
host: str = "127.0.0.1",
134+
port: int = 8000,
135+
mount_path: str = "/",
136+
sse_path: str = "/sse",
137+
message_path: str = "/messages/",
138+
streamable_http_path: str = "/mcp",
139+
json_response: bool = False,
140+
stateless_http: bool = False,
141+
warn_on_duplicate_resources: bool = True,
142+
warn_on_duplicate_tools: bool = True,
143+
warn_on_duplicate_prompts: bool = True,
144+
dependencies: Collection[str] = (),
145+
lifespan: Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None = None,
146+
auth: AuthSettings | None = None,
147+
transport_security: TransportSecuritySettings | None = None,
144148
):
145-
self.settings = Settings(**settings)
149+
self.settings = Settings(
150+
debug=debug,
151+
log_level=log_level,
152+
host=host,
153+
port=port,
154+
mount_path=mount_path,
155+
sse_path=sse_path,
156+
message_path=message_path,
157+
streamable_http_path=streamable_http_path,
158+
json_response=json_response,
159+
stateless_http=stateless_http,
160+
warn_on_duplicate_resources=warn_on_duplicate_resources,
161+
warn_on_duplicate_tools=warn_on_duplicate_tools,
162+
warn_on_duplicate_prompts=warn_on_duplicate_prompts,
163+
dependencies=list(dependencies),
164+
lifespan=lifespan,
165+
auth=auth,
166+
transport_security=transport_security,
167+
)
146168

147169
self._mcp_server = MCPServer(
148170
name=name or "FastMCP",
149171
instructions=instructions,
150-
lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan),
172+
# TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an FastMCP and Server.
173+
# We need to create a Lifespan type that is a generic on the server type, like Starlette does.
174+
lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore
151175
)
152176
self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools)
153177
self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources)
@@ -257,7 +281,7 @@ async def list_tools(self) -> list[MCPTool]:
257281
for info in tools
258282
]
259283

260-
def get_context(self) -> Context[ServerSession, object, Request]:
284+
def get_context(self) -> Context[ServerSession, LifespanResultT, Request]:
261285
"""
262286
Returns a Context object. Note that the context will only be valid
263287
during a request; outside a request, most methods will error.

src/mcp/server/lowlevel/server.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ async def main():
9393

9494
logger = logging.getLogger(__name__)
9595

96-
LifespanResultT = TypeVar("LifespanResultT")
96+
LifespanResultT = TypeVar("LifespanResultT", default=Any)
9797
RequestT = TypeVar("RequestT", default=Any)
9898

9999
# type aliases for tool call results
@@ -118,7 +118,7 @@ def __init__(
118118

119119

120120
@asynccontextmanager
121-
async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]:
121+
async def lifespan(_: Server[LifespanResultT, RequestT]) -> AsyncIterator[dict[str, Any]]:
122122
"""Default lifespan context manager that does nothing.
123123
124124
Args:
@@ -647,6 +647,12 @@ async def _handle_request(
647647
response = await handler(req)
648648
except McpError as err:
649649
response = err.error
650+
except anyio.get_cancelled_exc_class():
651+
logger.info(
652+
"Request %s cancelled - duplicate response suppressed",
653+
message.request_id,
654+
)
655+
return
650656
except Exception as err:
651657
if raise_exceptions:
652658
raise err

tests/client/test_auth.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,109 @@ async def test_oauth_discovery_fallback_order(self, oauth_provider):
261261
"https://api.example.com/v1/mcp/.well-known/openid-configuration",
262262
]
263263

264+
@pytest.mark.anyio
265+
async def test_oauth_discovery_fallback_conditions(self, oauth_provider):
266+
"""Test the conditions during which an AS metadata discovery fallback will be attempted."""
267+
# Ensure no tokens are stored
268+
oauth_provider.context.current_tokens = None
269+
oauth_provider.context.token_expiry_time = None
270+
oauth_provider._initialized = True
271+
272+
# Mock client info to skip DCR
273+
oauth_provider.context.client_info = OAuthClientInformationFull(
274+
client_id="existing_client",
275+
redirect_uris=[AnyUrl("http://localhost:3030/callback")],
276+
)
277+
278+
# Create a test request
279+
test_request = httpx.Request("GET", "https://api.example.com/v1/mcp")
280+
281+
# Mock the auth flow
282+
auth_flow = oauth_provider.async_auth_flow(test_request)
283+
284+
# First request should be the original request without auth header
285+
request = await auth_flow.__anext__()
286+
assert "Authorization" not in request.headers
287+
288+
# Send a 401 response to trigger the OAuth flow
289+
response = httpx.Response(
290+
401,
291+
headers={
292+
"WWW-Authenticate": 'Bearer resource_metadata="https://api.example.com/.well-known/oauth-protected-resource"'
293+
},
294+
request=test_request,
295+
)
296+
297+
# Next request should be to discover protected resource metadata
298+
discovery_request = await auth_flow.asend(response)
299+
assert str(discovery_request.url) == "https://api.example.com/.well-known/oauth-protected-resource"
300+
assert discovery_request.method == "GET"
301+
302+
# Send a successful discovery response with minimal protected resource metadata
303+
discovery_response = httpx.Response(
304+
200,
305+
content=b'{"resource": "https://api.example.com/v1/mcp", "authorization_servers": ["https://auth.example.com/v1/mcp"]}',
306+
request=discovery_request,
307+
)
308+
309+
# Next request should be to discover OAuth metadata
310+
oauth_metadata_request_1 = await auth_flow.asend(discovery_response)
311+
assert (
312+
str(oauth_metadata_request_1.url)
313+
== "https://auth.example.com/.well-known/oauth-authorization-server/v1/mcp"
314+
)
315+
assert oauth_metadata_request_1.method == "GET"
316+
317+
# Send a 404 response
318+
oauth_metadata_response_1 = httpx.Response(
319+
404,
320+
content=b"Not Found",
321+
request=oauth_metadata_request_1,
322+
)
323+
324+
# Next request should be to discover OAuth metadata at the next endpoint
325+
oauth_metadata_request_2 = await auth_flow.asend(oauth_metadata_response_1)
326+
assert str(oauth_metadata_request_2.url) == "https://auth.example.com/.well-known/oauth-authorization-server"
327+
assert oauth_metadata_request_2.method == "GET"
328+
329+
# Send a 400 response
330+
oauth_metadata_response_2 = httpx.Response(
331+
400,
332+
content=b"Bad Request",
333+
request=oauth_metadata_request_2,
334+
)
335+
336+
# Next request should be to discover OAuth metadata at the next endpoint
337+
oauth_metadata_request_3 = await auth_flow.asend(oauth_metadata_response_2)
338+
assert str(oauth_metadata_request_3.url) == "https://auth.example.com/.well-known/openid-configuration/v1/mcp"
339+
assert oauth_metadata_request_3.method == "GET"
340+
341+
# Send a 500 response
342+
oauth_metadata_response_3 = httpx.Response(
343+
500,
344+
content=b"Internal Server Error",
345+
request=oauth_metadata_request_3,
346+
)
347+
348+
# Mock the authorization process to minimize unnecessary state in this test
349+
oauth_provider._perform_authorization = mock.AsyncMock(return_value=("test_auth_code", "test_code_verifier"))
350+
351+
# Next request should fall back to legacy behavior and auth with the RS (mocked /authorize, next is /token)
352+
token_request = await auth_flow.asend(oauth_metadata_response_3)
353+
assert str(token_request.url) == "https://api.example.com/token"
354+
assert token_request.method == "POST"
355+
356+
# Send a successful token response
357+
token_response = httpx.Response(
358+
200,
359+
content=(
360+
b'{"access_token": "new_access_token", "token_type": "Bearer", "expires_in": 3600, '
361+
b'"refresh_token": "new_refresh_token"}'
362+
),
363+
request=token_request,
364+
)
365+
token_request = await auth_flow.asend(token_response)
366+
264367
@pytest.mark.anyio
265368
async def test_handle_metadata_response_success(self, oauth_provider):
266369
"""Test successful metadata response handling."""

0 commit comments

Comments
 (0)