Skip to content

Commit 54cf133

Browse files
authored
Merge branch 'main' into auth-metadata
2 parents 6788f52 + 7f94bef commit 54cf133

File tree

21 files changed

+777
-84
lines changed

21 files changed

+777
-84
lines changed

README.md

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -315,27 +315,42 @@ async def long_task(files: list[str], ctx: Context) -> str:
315315
Authentication can be used by servers that want to expose tools accessing protected resources.
316316

317317
`mcp.server.auth` implements an OAuth 2.0 server interface, which servers can use by
318-
providing an implementation of the `OAuthServerProvider` protocol.
318+
providing an implementation of the `OAuthAuthorizationServerProvider` protocol.
319319

320-
```
321-
mcp = FastMCP("My App",
322-
auth_server_provider=MyOAuthServerProvider(),
323-
auth=AuthSettings(
324-
issuer_url="https://myapp.com",
325-
revocation_options=RevocationOptions(
326-
enabled=True,
327-
),
328-
client_registration_options=ClientRegistrationOptions(
329-
enabled=True,
330-
valid_scopes=["myscope", "myotherscope"],
331-
default_scopes=["myscope"],
332-
),
333-
required_scopes=["myscope"],
320+
```python
321+
from mcp import FastMCP
322+
from mcp.server.auth.provider import OAuthAuthorizationServerProvider
323+
from mcp.server.auth.settings import (
324+
AuthSettings,
325+
ClientRegistrationOptions,
326+
RevocationOptions,
327+
)
328+
329+
330+
class MyOAuthServerProvider(OAuthAuthorizationServerProvider):
331+
# See an example on how to implement at `examples/servers/simple-auth`
332+
...
333+
334+
335+
mcp = FastMCP(
336+
"My App",
337+
auth_server_provider=MyOAuthServerProvider(),
338+
auth=AuthSettings(
339+
issuer_url="https://myapp.com",
340+
revocation_options=RevocationOptions(
341+
enabled=True,
342+
),
343+
client_registration_options=ClientRegistrationOptions(
344+
enabled=True,
345+
valid_scopes=["myscope", "myotherscope"],
346+
default_scopes=["myscope"],
334347
),
348+
required_scopes=["myscope"],
349+
),
335350
)
336351
```
337352

338-
See [OAuthServerProvider](src/mcp/server/auth/provider.py) for more details.
353+
See [OAuthAuthorizationServerProvider](src/mcp/server/auth/provider.py) for more details.
339354

340355
## Running Your Server
341356

@@ -462,15 +477,12 @@ For low level server with Streamable HTTP implementations, see:
462477
- Stateful server: [`examples/servers/simple-streamablehttp/`](examples/servers/simple-streamablehttp/)
463478
- Stateless server: [`examples/servers/simple-streamablehttp-stateless/`](examples/servers/simple-streamablehttp-stateless/)
464479

465-
466-
467480
The streamable HTTP transport supports:
468481
- Stateful and stateless operation modes
469482
- Resumability with event stores
470-
- JSON or SSE response formats
483+
- JSON or SSE response formats
471484
- Better scalability for multi-node deployments
472485

473-
474486
### Mounting to an Existing ASGI Server
475487

476488
> **Note**: SSE transport is being superseded by [Streamable HTTP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http).

src/mcp/client/session.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -116,12 +116,18 @@ def __init__(
116116
self._message_handler = message_handler or _default_message_handler
117117

118118
async def initialize(self) -> types.InitializeResult:
119-
sampling = types.SamplingCapability()
120-
roots = types.RootsCapability(
119+
sampling = (
120+
types.SamplingCapability()
121+
if self._sampling_callback is not _default_sampling_callback
122+
else None
123+
)
124+
roots = (
121125
# TODO: Should this be based on whether we
122126
# _will_ send notifications, or only whether
123127
# they're supported?
124-
listChanged=True,
128+
types.RootsCapability(listChanged=True)
129+
if self._list_roots_callback is not _default_list_roots_callback
130+
else None
125131
)
126132

127133
result = await self.send_request(

src/mcp/client/sse.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ async def sse_client(
5353

5454
async with anyio.create_task_group() as tg:
5555
try:
56-
logger.info(f"Connecting to SSE endpoint: {remove_request_params(url)}")
56+
logger.debug(f"Connecting to SSE endpoint: {remove_request_params(url)}")
5757
async with httpx_client_factory(headers=headers, auth=auth) as client:
5858
async with aconnect_sse(
5959
client,
@@ -73,7 +73,7 @@ async def sse_reader(
7373
match sse.event:
7474
case "endpoint":
7575
endpoint_url = urljoin(url, sse.data)
76-
logger.info(
76+
logger.debug(
7777
f"Received endpoint URL: {endpoint_url}"
7878
)
7979

@@ -146,7 +146,7 @@ async def post_writer(endpoint_url: str):
146146
await write_stream.aclose()
147147

148148
endpoint_url = await tg.start(sse_reader)
149-
logger.info(
149+
logger.debug(
150150
f"Starting post writer with endpoint URL: {endpoint_url}"
151151
)
152152
tg.start_soon(post_writer, endpoint_url)

src/mcp/client/stdio/__init__.py

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -108,20 +108,28 @@ async def stdio_client(server: StdioServerParameters, errlog: TextIO = sys.stder
108108
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
109109
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
110110

111-
command = _get_executable_command(server.command)
112-
113-
# Open process with stderr piped for capture
114-
process = await _create_platform_compatible_process(
115-
command=command,
116-
args=server.args,
117-
env=(
118-
{**get_default_environment(), **server.env}
119-
if server.env is not None
120-
else get_default_environment()
121-
),
122-
errlog=errlog,
123-
cwd=server.cwd,
124-
)
111+
try:
112+
command = _get_executable_command(server.command)
113+
114+
# Open process with stderr piped for capture
115+
process = await _create_platform_compatible_process(
116+
command=command,
117+
args=server.args,
118+
env=(
119+
{**get_default_environment(), **server.env}
120+
if server.env is not None
121+
else get_default_environment()
122+
),
123+
errlog=errlog,
124+
cwd=server.cwd,
125+
)
126+
except OSError:
127+
# Clean up streams if process creation fails
128+
await read_stream.aclose()
129+
await write_stream.aclose()
130+
await read_stream_writer.aclose()
131+
await write_stream_reader.aclose()
132+
raise
125133

126134
async def stdout_reader():
127135
assert process.stdout, "Opened process is missing stdout"
@@ -177,12 +185,18 @@ async def stdin_writer():
177185
yield read_stream, write_stream
178186
finally:
179187
# Clean up process to prevent any dangling orphaned processes
180-
if sys.platform == "win32":
181-
await terminate_windows_process(process)
182-
else:
183-
process.terminate()
188+
try:
189+
if sys.platform == "win32":
190+
await terminate_windows_process(process)
191+
else:
192+
process.terminate()
193+
except ProcessLookupError:
194+
# Process already exited, which is fine
195+
pass
184196
await read_stream.aclose()
185197
await write_stream.aclose()
198+
await read_stream_writer.aclose()
199+
await write_stream_reader.aclose()
186200

187201

188202
def _get_executable_command(command: str) -> str:

src/mcp/client/streamable_http.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -463,7 +463,7 @@ async def streamablehttp_client(
463463

464464
async with anyio.create_task_group() as tg:
465465
try:
466-
logger.info(f"Connecting to StreamableHTTP endpoint: {url}")
466+
logger.debug(f"Connecting to StreamableHTTP endpoint: {url}")
467467

468468
async with httpx_client_factory(
469469
headers=transport.request_headers,

src/mcp/server/fastmcp/server.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@
4949
from mcp.server.stdio import stdio_server
5050
from mcp.server.streamable_http import EventStore
5151
from mcp.server.streamable_http_manager import StreamableHTTPSessionManager
52-
from mcp.shared.context import LifespanContextT, RequestContext
52+
from mcp.shared.context import LifespanContextT, RequestContext, RequestT
5353
from mcp.types import (
5454
AnyFunction,
5555
EmbeddedResource,
@@ -124,9 +124,11 @@ class Settings(BaseSettings, Generic[LifespanResultT]):
124124
def lifespan_wrapper(
125125
app: FastMCP,
126126
lifespan: Callable[[FastMCP], AbstractAsyncContextManager[LifespanResultT]],
127-
) -> Callable[[MCPServer[LifespanResultT]], AbstractAsyncContextManager[object]]:
127+
) -> Callable[
128+
[MCPServer[LifespanResultT, Request]], AbstractAsyncContextManager[object]
129+
]:
128130
@asynccontextmanager
129-
async def wrap(s: MCPServer[LifespanResultT]) -> AsyncIterator[object]:
131+
async def wrap(s: MCPServer[LifespanResultT, Request]) -> AsyncIterator[object]:
130132
async with lifespan(app) as context:
131133
yield context
132134

@@ -260,7 +262,7 @@ async def list_tools(self) -> list[MCPTool]:
260262
for info in tools
261263
]
262264

263-
def get_context(self) -> Context[ServerSession, object]:
265+
def get_context(self) -> Context[ServerSession, object, Request]:
264266
"""
265267
Returns a Context object. Note that the context will only be valid
266268
during a request; outside a request, most methods will error.
@@ -917,7 +919,7 @@ def _convert_to_content(
917919
return [TextContent(type="text", text=result)]
918920

919921

920-
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT]):
922+
class Context(BaseModel, Generic[ServerSessionT, LifespanContextT, RequestT]):
921923
"""Context object providing access to MCP capabilities.
922924
923925
This provides a cleaner interface to MCP's RequestContext functionality.
@@ -951,13 +953,15 @@ def my_tool(x: int, ctx: Context) -> str:
951953
The context is optional - tools that don't need it can omit the parameter.
952954
"""
953955

954-
_request_context: RequestContext[ServerSessionT, LifespanContextT] | None
956+
_request_context: RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
955957
_fastmcp: FastMCP | None
956958

957959
def __init__(
958960
self,
959961
*,
960-
request_context: RequestContext[ServerSessionT, LifespanContextT] | None = None,
962+
request_context: (
963+
RequestContext[ServerSessionT, LifespanContextT, RequestT] | None
964+
) = None,
961965
fastmcp: FastMCP | None = None,
962966
**kwargs: Any,
963967
):
@@ -973,7 +977,9 @@ def fastmcp(self) -> FastMCP:
973977
return self._fastmcp
974978

975979
@property
976-
def request_context(self) -> RequestContext[ServerSessionT, LifespanContextT]:
980+
def request_context(
981+
self,
982+
) -> RequestContext[ServerSessionT, LifespanContextT, RequestT]:
977983
"""Access to the underlying request context."""
978984
if self._request_context is None:
979985
raise ValueError("Context is not available outside of a request")

src/mcp/server/fastmcp/tools/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
if TYPE_CHECKING:
1515
from mcp.server.fastmcp.server import Context
1616
from mcp.server.session import ServerSessionT
17-
from mcp.shared.context import LifespanContextT
17+
from mcp.shared.context import LifespanContextT, RequestT
1818

1919

2020
class Tool(BaseModel):
@@ -85,7 +85,7 @@ def from_function(
8585
async def run(
8686
self,
8787
arguments: dict[str, Any],
88-
context: Context[ServerSessionT, LifespanContextT] | None = None,
88+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
8989
) -> Any:
9090
"""Run the tool with arguments."""
9191
try:

src/mcp/server/fastmcp/tools/tool_manager.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from mcp.server.fastmcp.exceptions import ToolError
77
from mcp.server.fastmcp.tools.base import Tool
88
from mcp.server.fastmcp.utilities.logging import get_logger
9-
from mcp.shared.context import LifespanContextT
9+
from mcp.shared.context import LifespanContextT, RequestT
1010
from mcp.types import ToolAnnotations
1111

1212
if TYPE_CHECKING:
@@ -65,7 +65,7 @@ async def call_tool(
6565
self,
6666
name: str,
6767
arguments: dict[str, Any],
68-
context: Context[ServerSessionT, LifespanContextT] | None = None,
68+
context: Context[ServerSessionT, LifespanContextT, RequestT] | None = None,
6969
) -> Any:
7070
"""Call a tool by name with arguments."""
7171
tool = self.get_tool(name)

src/mcp/server/lowlevel/server.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,11 +72,12 @@ async def main():
7272
import warnings
7373
from collections.abc import AsyncIterator, Awaitable, Callable, Iterable
7474
from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
75-
from typing import Any, Generic, TypeVar
75+
from typing import Any, Generic
7676

7777
import anyio
7878
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
7979
from pydantic import AnyUrl
80+
from typing_extensions import TypeVar
8081

8182
import mcp.types as types
8283
from mcp.server.lowlevel.helper_types import ReadResourceContents
@@ -85,15 +86,16 @@ async def main():
8586
from mcp.server.stdio import stdio_server as stdio_server
8687
from mcp.shared.context import RequestContext
8788
from mcp.shared.exceptions import McpError
88-
from mcp.shared.message import SessionMessage
89+
from mcp.shared.message import ServerMessageMetadata, SessionMessage
8990
from mcp.shared.session import RequestResponder
9091

9192
logger = logging.getLogger(__name__)
9293

9394
LifespanResultT = TypeVar("LifespanResultT")
95+
RequestT = TypeVar("RequestT", default=Any)
9496

9597
# This will be properly typed in each Server instance's context
96-
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any]] = (
98+
request_ctx: contextvars.ContextVar[RequestContext[ServerSession, Any, Any]] = (
9799
contextvars.ContextVar("request_ctx")
98100
)
99101

@@ -111,7 +113,7 @@ def __init__(
111113

112114

113115
@asynccontextmanager
114-
async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
116+
async def lifespan(server: Server[LifespanResultT, RequestT]) -> AsyncIterator[object]:
115117
"""Default lifespan context manager that does nothing.
116118
117119
Args:
@@ -123,14 +125,15 @@ async def lifespan(server: Server[LifespanResultT]) -> AsyncIterator[object]:
123125
yield {}
124126

125127

126-
class Server(Generic[LifespanResultT]):
128+
class Server(Generic[LifespanResultT, RequestT]):
127129
def __init__(
128130
self,
129131
name: str,
130132
version: str | None = None,
131133
instructions: str | None = None,
132134
lifespan: Callable[
133-
[Server[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]
135+
[Server[LifespanResultT, RequestT]],
136+
AbstractAsyncContextManager[LifespanResultT],
134137
] = lifespan,
135138
):
136139
self.name = name
@@ -215,7 +218,9 @@ def get_capabilities(
215218
)
216219

217220
@property
218-
def request_context(self) -> RequestContext[ServerSession, LifespanResultT]:
221+
def request_context(
222+
self,
223+
) -> RequestContext[ServerSession, LifespanResultT, RequestT]:
219224
"""If called outside of a request context, this will raise a LookupError."""
220225
return request_ctx.get()
221226

@@ -555,6 +560,13 @@ async def _handle_request(
555560

556561
token = None
557562
try:
563+
# Extract request context from message metadata
564+
request_data = None
565+
if message.message_metadata is not None and isinstance(
566+
message.message_metadata, ServerMessageMetadata
567+
):
568+
request_data = message.message_metadata.request_context
569+
558570
# Set our global state that can be retrieved via
559571
# app.get_request_context()
560572
token = request_ctx.set(
@@ -563,6 +575,7 @@ async def _handle_request(
563575
message.request_meta,
564576
session,
565577
lifespan_context,
578+
request=request_data,
566579
)
567580
)
568581
response = await handler(req)

0 commit comments

Comments
 (0)