diff --git a/src/mcp/server/checkpoint.py b/src/mcp/server/checkpoint.py new file mode 100644 index 000000000..a7951e9f0 --- /dev/null +++ b/src/mcp/server/checkpoint.py @@ -0,0 +1,142 @@ +from __future__ import annotations + +import abc +import time +from dataclasses import dataclass +from typing import Protocol, runtime_checkable + +from mcp.server.session import ServerSession +from mcp.types import ( + CheckpointCreateParams, + CheckpointCreateResult, + CheckpointValidateParams, + CheckpointValidateResult, + CheckpointResumeParams, + CheckpointResumeResult, + CheckpointDeleteParams, + CheckpointDeleteResult, +) + + +@runtime_checkable +class CheckpointBackend(Protocol): + """Backend that actually stores and restores state behind handles.""" + + async def create_checkpoint( + self, + session: ServerSession, + params: CheckpointCreateParams, + ) -> CheckpointCreateResult: ... + + async def validate_checkpoint( + self, + session: ServerSession, + params: CheckpointValidateParams, + ) -> CheckpointValidateResult: ... + + async def resume_checkpoint( + self, + session: ServerSession, + params: CheckpointResumeParams, + ) -> CheckpointResumeResult: ... + + async def delete_checkpoint( + self, + session: ServerSession, + params: CheckpointDeleteParams, + ) -> CheckpointDeleteResult: ... + + +@dataclass +class InMemoryHandleEntry: + value: object + digest: str + expires_at: float + + +class InMemoryCheckpointBackend(CheckpointBackend): + """Simple in-memory backend you can use for tests/POC. + + This is intentionally generic; concrete servers (data, browser, etc.) + decide *what* `value` is and how to interpret it. + """ + + def __init__(self, ttl_seconds: int = 1800) -> None: + self._ttl = ttl_seconds + self._handles: dict[str, InMemoryHandleEntry] = {} + + def _now(self) -> float: + return time.time() + + async def create_checkpoint( + self, + session: ServerSession, + params: CheckpointCreateParams, + ) -> CheckpointCreateResult: + # session.fastmcp or session.server can expose some "current state" + # For now you can override this backend in your server and implement + # your own snapshot logic. + raise NotImplementedError( + "Subclass InMemoryCheckpointBackend and override create_checkpoint " + "to capture concrete state (e.g. data tables, browser session)." + ) + + async def validate_checkpoint( + self, + session: ServerSession, + params: CheckpointValidateParams, + ) -> CheckpointValidateResult: + entry = self._handles.get(params.handle) + if not entry: + return CheckpointValidateResult( + valid=False, + remainingTtlSeconds=0, + digestMatch=False, + ) + + now = self._now() + if now >= entry.expires_at: + return CheckpointValidateResult( + valid=False, + remainingTtlSeconds=0, + digestMatch=params.expectedDigest == entry.digest, + ) + + remaining = int(entry.expires_at - now) + return CheckpointValidateResult( + valid=True, + remainingTtlSeconds=remaining, + digestMatch=( + params.expectedDigest is None + or params.expectedDigest == entry.digest + ), + ) + + async def resume_checkpoint( + self, + session: ServerSession, + params: CheckpointResumeParams, + ) -> CheckpointResumeResult: + entry = self._handles.get(params.handle) + if not entry: + # You’ll map this to HANDLE_NOT_FOUND at JSON-RPC level + return CheckpointResumeResult(resumed=False, handle=params.handle) + + if self._now() >= entry.expires_at: + # Map to EXPIRED + return CheckpointResumeResult(resumed=False, handle=params.handle) + + # Subclasses should take `entry.value` and rehydrate into session state. + raise NotImplementedError( + "Subclass InMemoryCheckpointBackend.resume_checkpoint to rehydrate " + "concrete session state from stored value." + ) + + async def delete_checkpoint( + self, + session: ServerSession, + params: CheckpointDeleteParams, + ) -> CheckpointDeleteResult: + deleted = params.handle in self._handles + self._handles.pop(params.handle, None) + return CheckpointDeleteResult(deleted=deleted) \ No newline at end of file diff --git a/src/mcp/server/fastmcp/server.py b/src/mcp/server/fastmcp/server.py index f74b65557..b269542ec 100644 --- a/src/mcp/server/fastmcp/server.py +++ b/src/mcp/server/fastmcp/server.py @@ -64,6 +64,7 @@ from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings +from mcp.server.checkpoint import CheckpointBackend from mcp.shared.context import LifespanContextT, RequestContext, RequestT from mcp.types import Annotations, AnyFunction, ContentBlock, GetPromptResult, Icon, ToolAnnotations from mcp.types import Prompt as MCPPrompt @@ -173,6 +174,7 @@ def __init__( # noqa: PLR0913 lifespan: (Callable[[FastMCP[LifespanResultT]], AbstractAsyncContextManager[LifespanResultT]] | None) = None, auth: AuthSettings | None = None, transport_security: TransportSecuritySettings | None = None, + checkpoint_backend: CheckpointBackend | None = None, ): # Auto-enable DNS rebinding protection for localhost (IPv4 and IPv6) if transport_security is None and host in ("127.0.0.1", "localhost", "::1"): @@ -202,6 +204,7 @@ def __init__( # noqa: PLR0913 transport_security=transport_security, ) + self._checkpoint_backend = checkpoint_backend self._mcp_server = MCPServer( name=name or "FastMCP", instructions=instructions, @@ -210,6 +213,7 @@ def __init__( # noqa: PLR0913 # TODO(Marcelo): It seems there's a type mismatch between the lifespan type from an FastMCP and Server. # We need to create a Lifespan type that is a generic on the server type, like Starlette does. lifespan=(lifespan_wrapper(self, self.settings.lifespan) if self.settings.lifespan else default_lifespan), # type: ignore + checkpoint_backend=self._checkpoint_backend, ) self._tool_manager = ToolManager(tools=tools, warn_on_duplicate_tools=self.settings.warn_on_duplicate_tools) self._resource_manager = ResourceManager(warn_on_duplicate_resources=self.settings.warn_on_duplicate_resources) diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index e29c021b7..6f79ca9cf 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -94,6 +94,7 @@ async def main(): from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.session import RequestResponder from mcp.shared.tool_name_validation import validate_and_warn_tool_name +from mcp.server.checkpoint import CheckpointBackend logger = logging.getLogger(__name__) @@ -146,6 +147,9 @@ def __init__( [Server[LifespanResultT, RequestT]], AbstractAsyncContextManager[LifespanResultT], ] = lifespan, + *, + stateless: bool = False, + checkpoint_backend: CheckpointBackend | None = None, ): self.name = name self.version = version @@ -159,6 +163,8 @@ def __init__( self.notification_handlers: dict[type, Callable[..., Awaitable[None]]] = {} self._tool_cache: dict[str, types.Tool] = {} self._experimental_handlers: ExperimentalHandlers | None = None + self._stateless = stateless + self._checkpoint_backend = checkpoint_backend logger.debug("Initializing server %r", name) def create_initialization_options( @@ -650,6 +656,7 @@ async def run( write_stream, initialization_options, stateless=stateless, + checkpoint_backend=self._checkpoint_backend, ) ) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 8f0baa3e9..ae0153801 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -38,7 +38,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: """ from enum import Enum -from typing import Any, TypeVar, overload +from typing import Any, TypeVar, overload, TYPE_CHECKING import anyio import anyio.lowlevel @@ -57,7 +57,8 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]: RequestResponder, ) from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS - +if TYPE_CHECKING: + from mcp.server.checkpoint import CheckpointBackend class InitializationState(Enum): NotInitialized = 1 @@ -91,6 +92,7 @@ def __init__( write_stream: MemoryObjectSendStream[SessionMessage], init_options: InitializationOptions, stateless: bool = False, + checkpoint_backend: "CheckpointBackend | None" = None, ) -> None: super().__init__(read_stream, write_stream, types.ClientRequest, types.ClientNotification) self._initialization_state = ( @@ -102,6 +104,7 @@ def __init__( ServerRequestResponder ](0) self._exit_stack.push_async_callback(lambda: self._incoming_message_stream_reader.aclose()) + self._checkpoint_backend = checkpoint_backend @property def client_params(self) -> types.InitializeRequestParams | None: @@ -116,6 +119,11 @@ def experimental(self) -> ExperimentalServerSessionFeatures: if self._experimental_features is None: self._experimental_features = ExperimentalServerSessionFeatures(self) return self._experimental_features + + @property + def checkpoint_backend(self) -> "CheckpointBackend | None": + """Optional checkpoint backend attached to this session.""" + return self._checkpoint_backend def check_client_capability(self, capability: types.ClientCapabilities) -> bool: # pragma: no cover """Check if the client supports a specific capability.""" @@ -688,4 +696,4 @@ async def _handle_incoming(self, req: ServerRequestResponder) -> None: def incoming_messages( self, ) -> MemoryObjectReceiveStream[ServerRequestResponder]: - return self._incoming_message_stream_reader + return self._incoming_message_stream_reader \ No newline at end of file diff --git a/src/mcp/types.py b/src/mcp/types.py index 9ca8ffc18..3be6b0e26 100644 --- a/src/mcp/types.py +++ b/src/mcp/types.py @@ -1996,3 +1996,67 @@ class ServerNotification(RootModel[ServerNotificationType]): class ServerResult(RootModel[ServerResultType]): pass + + +# --- Checkpoint protocol extensions ----------------------------------------- + +class CheckpointHandle(BaseModel): + """Opaque checkpoint handle returned by servers.""" + handle: str + digest: str + ttlSeconds: int + + +class CheckpointCreateParams(BaseModel): + """Params for checkpoint/create. + + For v1 you can keep this empty – the server infers the session + from transport/session context – but we define it for forward compat. + """ + # Optional: allow tools to tag a logical name + label: str | None = None + + +class CheckpointCreateResult(BaseModel): + """Result of checkpoint/create.""" + handle: str + digest: str + ttlSeconds: int + + +class CheckpointValidateParams(BaseModel): + """Params for checkpoint/validate.""" + handle: str + expectedDigest: str | None = None + + +class CheckpointValidateResult(BaseModel): + """Result of checkpoint/validate.""" + valid: bool + remainingTtlSeconds: int + digestMatch: bool + + +class CheckpointResumeParams(BaseModel): + """Params for checkpoint/resume.""" + handle: str + + +class CheckpointResumeResult(BaseModel): + """Result of checkpoint/resume. + + You can expand this later if you want to + surface metadata to the client. + """ + resumed: bool + handle: str + + +class CheckpointDeleteParams(BaseModel): + """Params for checkpoint/delete.""" + handle: str + + +class CheckpointDeleteResult(BaseModel): + """Result of checkpoint/delete.""" + deleted: bool \ No newline at end of file