Skip to content

Commit f0fac5c

Browse files
committed
fix some pyright
1 parent a85068e commit f0fac5c

File tree

12 files changed

+71
-43
lines changed

12 files changed

+71
-43
lines changed

examples/servers/simple-streamablehttp/mcp_simple_streamablehttp/server.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from starlette.routing import Mount
1414
from starlette.types import Receive, Scope, Send
1515

16-
from pydantic import AnyUrl
1716

1817
from .event_store import InMemoryEventStore
1918

@@ -75,7 +74,7 @@ async def call_tool(name: str, arguments: dict[str, Any]) -> list[types.ContentB
7574

7675
# This will send a resource notificaiton though standalone SSE
7776
# established by GET request
78-
await ctx.session.send_resource_updated(uri=AnyUrl("http:///test_resource"))
77+
await ctx.session.send_resource_updated(uri="http:///test_resource")
7978
return [
8079
types.TextContent(
8180
type="text",

src/mcp/client/session.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,22 +23,22 @@
2323
class SamplingFnT(Protocol):
2424
async def __call__(
2525
self,
26-
context: RequestContext["ClientTransportSession", Any],
26+
context: RequestContext["ClientSession", Any],
2727
params: types.CreateMessageRequestParams,
2828
) -> types.CreateMessageResult | types.CreateMessageResultWithTools | types.ErrorData: ... # pragma: no branch
2929

3030

3131
class ElicitationFnT(Protocol):
3232
async def __call__(
3333
self,
34-
context: RequestContext["ClientTransportSession", Any],
34+
context: RequestContext["ClientSession", Any],
3535
params: types.ElicitRequestParams,
3636
) -> types.ElicitResult | types.ErrorData: ... # pragma: no branch
3737

3838

3939
class ListRootsFnT(Protocol):
4040
async def __call__(
41-
self, context: RequestContext["ClientTransportSession", Any]
41+
self, context: RequestContext["ClientSession", Any]
4242
) -> types.ListRootsResult | types.ErrorData: ... # pragma: no branch
4343

4444

@@ -418,7 +418,7 @@ async def send_roots_list_changed(self) -> None: # pragma: no cover
418418
await self.send_notification(types.RootsListChangedNotification())
419419

420420
async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
421-
ctx = RequestContext[ClientTransportSession, Any](
421+
ctx = RequestContext[ClientSession, Any](
422422
request_id=responder.request_id,
423423
meta=responder.request_meta,
424424
session=self,

src/mcp/client/transport_session.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
from abc import ABC, abstractmethod
2-
from datetime import timedelta
32
from typing import Any
43

5-
from pydantic import AnyUrl
6-
74
import mcp.types as types
85
from mcp.shared.session import ProgressFnT
6+
from mcp.types import RequestParamsMeta
97

108

119
class ClientTransportSession(ABC):
@@ -57,17 +55,17 @@ async def list_resource_templates(
5755
raise NotImplementedError
5856

5957
@abstractmethod
60-
async def read_resource(self, uri: AnyUrl) -> types.ReadResourceResult:
58+
async def read_resource(self, uri: str) -> types.ReadResourceResult:
6159
"""Send a resources/read request."""
6260
raise NotImplementedError
6361

6462
@abstractmethod
65-
async def subscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult:
63+
async def subscribe_resource(self, uri: str) -> types.EmptyResult:
6664
"""Send a resources/subscribe request."""
6765
raise NotImplementedError
6866

6967
@abstractmethod
70-
async def unsubscribe_resource(self, uri: str, *, meta: RequestParamsMeta | None = None) -> types.EmptyResult:
68+
async def unsubscribe_resource(self, uri: str) -> types.EmptyResult:
7169
"""Send a resources/unsubscribe request."""
7270
raise NotImplementedError
7371

@@ -87,7 +85,8 @@ async def call_tool(
8785
@abstractmethod
8886
async def list_prompts(
8987
self,
90-
cursor: str | None = None,
88+
*,
89+
params: types.PaginatedRequestParams | None = None,
9190
) -> types.ListPromptsResult:
9291
"""Send a prompts/list request."""
9392
raise NotImplementedError
@@ -114,7 +113,6 @@ async def complete(
114113
@abstractmethod
115114
async def list_tools(
116115
self,
117-
cursor: str | None = None,
118116
*,
119117
params: types.PaginatedRequestParams | None = None,
120118
) -> types.ListToolsResult:

src/mcp/server/elicitation.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
from pydantic import BaseModel
1010

1111
from mcp.server.transport_session import ServerTransportSession
12-
from mcp.types import RequestId
12+
from mcp.types import (
13+
ElicitResult,
14+
RequestId,
15+
)
1316

1417
ElicitSchemaModelT = TypeVar("ElicitSchemaModelT", bound=BaseModel)
1518

@@ -123,7 +126,7 @@ async def elicit_with_validation(
123126

124127
json_schema = schema.model_json_schema()
125128

126-
result = await session.elicit_form(
129+
result: ElicitResult = await session.elicit(
127130
message=message,
128131
requested_schema=json_schema,
129132
related_request_id=related_request_id,
@@ -143,7 +146,7 @@ async def elicit_with_validation(
143146

144147

145148
async def elicit_url(
146-
session: ServerSession,
149+
session: ServerTransportSession,
147150
message: str,
148151
url: str,
149152
elicitation_id: str,

src/mcp/server/experimental/task_result_handler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import anyio
1616

17-
from mcp.server.session import ServerSession
17+
from mcp.server.session import ServerTransportSession
1818
from mcp.shared.exceptions import McpError
1919
from mcp.shared.experimental.tasks.helpers import RELATED_TASK_METADATA_KEY, is_terminal
2020
from mcp.shared.experimental.tasks.message_queue import TaskMessageQueue
@@ -69,7 +69,7 @@ def __init__(
6969

7070
async def send_message(
7171
self,
72-
session: ServerSession,
72+
session: ServerTransportSession,
7373
message: SessionMessage,
7474
) -> None:
7575
"""Send a message via the session.
@@ -81,7 +81,7 @@ async def send_message(
8181
async def handle(
8282
self,
8383
request: GetTaskPayloadRequest,
84-
session: ServerSession,
84+
session: ServerTransportSession,
8585
request_id: RequestId,
8686
) -> GetTaskPayloadResult:
8787
"""Handle a tasks/result request.
@@ -131,7 +131,7 @@ async def handle(
131131
async def _deliver_queued_messages(
132132
self,
133133
task_id: str,
134-
session: ServerSession,
134+
session: ServerTransportSession,
135135
request_id: RequestId,
136136
) -> None:
137137
"""Dequeue and send all pending messages for a task.

src/mcp/server/session.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ async def handle_list_prompts(ctx: RequestContext) -> list[types.Prompt]:
4242
import anyio
4343
import anyio.lowlevel
4444
from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
45-
from pydantic import AnyUrl, TypeAdapter
45+
from pydantic import TypeAdapter
4646

4747
import mcp.types as types
4848
from mcp.server.experimental.session_features import ExperimentalServerSessionFeatures
@@ -232,7 +232,7 @@ async def send_log_message(
232232
related_request_id,
233233
)
234234

235-
async def send_resource_updated(self, uri: str | AnyUrl) -> None: # pragma: no cover
235+
async def send_resource_updated(self, uri: str) -> None: # pragma: no cover
236236
"""Send a resource updated notification."""
237237
await self.send_notification(
238238
types.ResourceUpdatedNotification(

src/mcp/server/transport_session.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,17 @@
33
from abc import ABC, abstractmethod
44
from typing import Any
55

6-
from pydantic import AnyUrl
76

87
import mcp.types as types
8+
from mcp.shared.message import SessionMessage
99

1010

1111
class ServerTransportSession(ABC):
1212
"""Abstract base class for transport sessions."""
13+
@abstractmethod
14+
async def send_message(self, message: SessionMessage) -> None:
15+
"""Send a raw session message."""
16+
raise NotImplementedError
1317

1418
@abstractmethod
1519
async def send_log_message(
@@ -23,7 +27,7 @@ async def send_log_message(
2327
raise NotImplementedError
2428

2529
@abstractmethod
26-
async def send_resource_updated(self, uri: AnyUrl) -> None:
30+
async def send_resource_updated(self, uri: str) -> None:
2731
"""Send a resource updated notification."""
2832
raise NotImplementedError
2933

@@ -36,12 +40,33 @@ async def list_roots(self) -> types.ListRootsResult:
3640
async def elicit(
3741
self,
3842
message: str,
39-
requestedSchema: types.ElicitRequestedSchema,
43+
requested_schema: types.ElicitRequestedSchema,
4044
related_request_id: types.RequestId | None = None,
4145
) -> types.ElicitResult:
4246
"""Send an elicitation/create request."""
4347
raise NotImplementedError
4448

49+
@abstractmethod
50+
async def elicit_form(
51+
self,
52+
message: str,
53+
requested_schema: types.ElicitRequestedSchema,
54+
related_request_id: types.RequestId | None = None,
55+
) -> types.ElicitResult:
56+
"""Send a form mode elicitation/create request."""
57+
raise NotImplementedError
58+
59+
@abstractmethod
60+
async def elicit_url(
61+
self,
62+
message: str,
63+
url: str,
64+
elicitation_id: str,
65+
related_request_id: types.RequestId | None = None,
66+
) -> types.ElicitResult:
67+
"""Send a URL mode elicitation/create request."""
68+
raise NotImplementedError
69+
4570
@abstractmethod
4671
async def send_ping(self) -> types.EmptyResult:
4772
"""Send a ping request."""

src/mcp/shared/context.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
from mcp import ClientTransportSession, ServerTransportSession
1414

1515
SessionT = TypeVar(
16-
"SessionT", bound=BaseSession[Any, Any, Any, Any, Any] | "ClientTransportSession" | "ServerTransportSession"
16+
"SessionT",
17+
bound=BaseSession[Any, Any, Any, Any, Any] | "ClientTransportSession" | "ServerTransportSession",
18+
covariant=True,
1719
)
1820
LifespanContextT = TypeVar("LifespanContextT")
1921
RequestT = TypeVar("RequestT", default=Any)

src/mcp/shared/session.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -153,8 +153,19 @@ def in_flight(self) -> bool: # pragma: no cover
153153
def cancelled(self) -> bool: # pragma: no cover
154154
return self._cancel_scope.cancel_called
155155

156+
class Session:
157+
"""Base class for a session that can send progress notifications."""
158+
async def send_progress_notification(
159+
self,
160+
progress_token: ProgressToken,
161+
progress: float,
162+
total: float | None = None,
163+
message: str | None = None,
164+
) -> None:
165+
"""Sends a progress notification for a request that is currently being processed."""
156166

157167
class BaseSession(
168+
Session,
158169
Generic[
159170
SendRequestT,
160171
SendNotificationT,
@@ -500,15 +511,6 @@ async def _received_notification(self, notification: ReceiveNotificationT) -> No
500511
to listen on the message stream.
501512
"""
502513

503-
async def send_progress_notification(
504-
self,
505-
progress_token: ProgressToken,
506-
progress: float,
507-
total: float | None = None,
508-
message: str | None = None,
509-
) -> None:
510-
"""Sends a progress notification for a request that is currently being processed."""
511-
512514
async def _handle_incoming(
513515
self, req: RequestResponder[ReceiveRequestT, SendResultT] | ReceiveNotificationT | Exception
514516
) -> None:

tests/client/test_list_roots_callback.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import pytest
22
from pydantic import FileUrl
33

4-
from mcp.client.transport_session import ClientTransportSession
5-
from mcp import Client
64
from mcp.client.session import ClientSession
5+
from mcp import Client
76
from mcp.server.fastmcp import FastMCP
87
from mcp.server.fastmcp.server import Context
98
from mcp.server.session import ServerSession
@@ -29,7 +28,7 @@ async def test_list_roots_callback():
2928
)
3029

3130
async def list_roots_callback(
32-
context: RequestContext[ClientTransportSession, None],
31+
context: RequestContext[ClientSession, None],
3332
) -> ListRootsResult:
3433
return callback_return
3534

0 commit comments

Comments
 (0)