Skip to content

Commit cd4fcf2

Browse files
committed
Completely drop RootModel from types module
1 parent 8e567ec commit cd4fcf2

34 files changed

+572
-790
lines changed

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,10 @@ select = [
137137
]
138138
ignore = ["PERF203", "PLC0415", "PLR0402"]
139139

140+
[tool.ruff.lint.flake8-tidy-imports.banned-api]
141+
"pydantic.RootModel".msg = "Use `pydantic.TypeAdapter` instead."
142+
143+
140144
[tool.ruff.lint.mccabe]
141145
max-complexity = 24 # Default is 10
142146

src/mcp/client/experimental/task_handlers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def build_capability(self) -> types.ClientTasksCapability | None:
242242
def handles_request(request: types.ServerRequest) -> bool:
243243
"""Check if this handler handles the given request type."""
244244
return isinstance(
245-
request.root,
245+
request,
246246
types.GetTaskRequest | types.GetTaskPayloadRequest | types.ListTasksRequest | types.CancelTaskRequest,
247247
)
248248

@@ -259,7 +259,7 @@ async def handle_request(
259259
types.ClientResult | types.ErrorData
260260
)
261261

262-
match responder.request.root:
262+
match responder.request:
263263
case types.GetTaskRequest(params=params):
264264
response = await self.get_task(ctx, params)
265265
client_response = client_response_type.validate_python(response)
@@ -281,7 +281,7 @@ async def handle_request(
281281
await responder.respond(client_response)
282282

283283
case _: # pragma: no cover
284-
raise ValueError(f"Unhandled request type: {type(responder.request.root)}")
284+
raise ValueError(f"Unhandled request type: {type(responder.request)}")
285285

286286

287287
# Backwards compatibility aliases

src/mcp/client/experimental/tasks.py

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -92,15 +92,13 @@ async def call_tool_as_task(
9292
_meta = types.RequestParams.Meta(**meta)
9393

9494
return await self._session.send_request(
95-
types.ClientRequest(
96-
types.CallToolRequest(
97-
params=types.CallToolRequestParams(
98-
name=name,
99-
arguments=arguments,
100-
task=types.TaskMetadata(ttl=ttl),
101-
_meta=_meta,
102-
),
103-
)
95+
types.CallToolRequest(
96+
params=types.CallToolRequestParams(
97+
name=name,
98+
arguments=arguments,
99+
task=types.TaskMetadata(ttl=ttl),
100+
_meta=_meta,
101+
),
104102
),
105103
types.CreateTaskResult,
106104
)
@@ -115,10 +113,8 @@ async def get_task(self, task_id: str) -> types.GetTaskResult:
115113
GetTaskResult containing the task status and metadata
116114
"""
117115
return await self._session.send_request(
118-
types.ClientRequest(
119-
types.GetTaskRequest(
120-
params=types.GetTaskRequestParams(task_id=task_id),
121-
)
116+
types.GetTaskRequest(
117+
params=types.GetTaskRequestParams(task_id=task_id),
122118
),
123119
types.GetTaskResult,
124120
)
@@ -142,10 +138,8 @@ async def get_task_result(
142138
The task result, validated against result_type
143139
"""
144140
return await self._session.send_request(
145-
types.ClientRequest(
146-
types.GetTaskPayloadRequest(
147-
params=types.GetTaskPayloadRequestParams(task_id=task_id),
148-
)
141+
types.GetTaskPayloadRequest(
142+
params=types.GetTaskPayloadRequestParams(task_id=task_id),
149143
),
150144
result_type,
151145
)
@@ -164,9 +158,7 @@ async def list_tasks(
164158
"""
165159
params = types.PaginatedRequestParams(cursor=cursor) if cursor else None
166160
return await self._session.send_request(
167-
types.ClientRequest(
168-
types.ListTasksRequest(params=params),
169-
),
161+
types.ListTasksRequest(params=params),
170162
types.ListTasksResult,
171163
)
172164

@@ -180,10 +172,8 @@ async def cancel_task(self, task_id: str) -> types.CancelTaskResult:
180172
CancelTaskResult with the updated task state
181173
"""
182174
return await self._session.send_request(
183-
types.ClientRequest(
184-
types.CancelTaskRequest(
185-
params=types.CancelTaskRequestParams(task_id=task_id),
186-
)
175+
types.CancelTaskRequest(
176+
params=types.CancelTaskRequestParams(task_id=task_id),
187177
),
188178
types.CancelTaskResult,
189179
)

src/mcp/client/session.py

Lines changed: 50 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -122,13 +122,7 @@ def __init__(
122122
sampling_capabilities: types.SamplingCapability | None = None,
123123
experimental_task_handlers: ExperimentalTaskHandlers | None = None,
124124
) -> None:
125-
super().__init__(
126-
read_stream,
127-
write_stream,
128-
types.ServerRequest,
129-
types.ServerNotification,
130-
read_timeout_seconds=read_timeout_seconds,
131-
)
125+
super().__init__(read_stream, write_stream, read_timeout_seconds=read_timeout_seconds)
132126
self._client_info = client_info or DEFAULT_CLIENT_INFO
133127
self._sampling_callback = sampling_callback or _default_sampling_callback
134128
self._sampling_capabilities = sampling_capabilities
@@ -143,6 +137,14 @@ def __init__(
143137
# Experimental: Task handlers (use defaults if not provided)
144138
self._task_handlers = experimental_task_handlers or ExperimentalTaskHandlers()
145139

140+
@property
141+
def _receive_request_adapter(self) -> TypeAdapter[types.ServerRequest]:
142+
return types.server_request_adapter
143+
144+
@property
145+
def _receive_notification_adapter(self) -> TypeAdapter[types.ServerNotification]:
146+
return types.server_notification_adapter
147+
146148
async def initialize(self) -> types.InitializeResult:
147149
sampling = (
148150
(self._sampling_capabilities or types.SamplingCapability())
@@ -167,20 +169,18 @@ async def initialize(self) -> types.InitializeResult:
167169
)
168170

169171
result = await self.send_request(
170-
types.ClientRequest(
171-
types.InitializeRequest(
172-
params=types.InitializeRequestParams(
173-
protocol_version=types.LATEST_PROTOCOL_VERSION,
174-
capabilities=types.ClientCapabilities(
175-
sampling=sampling,
176-
elicitation=elicitation,
177-
experimental=None,
178-
roots=roots,
179-
tasks=self._task_handlers.build_capability(),
180-
),
181-
client_info=self._client_info,
172+
types.InitializeRequest(
173+
params=types.InitializeRequestParams(
174+
protocol_version=types.LATEST_PROTOCOL_VERSION,
175+
capabilities=types.ClientCapabilities(
176+
sampling=sampling,
177+
elicitation=elicitation,
178+
experimental=None,
179+
roots=roots,
180+
tasks=self._task_handlers.build_capability(),
182181
),
183-
)
182+
client_info=self._client_info,
183+
),
184184
),
185185
types.InitializeResult,
186186
)
@@ -190,7 +190,7 @@ async def initialize(self) -> types.InitializeResult:
190190

191191
self._server_capabilities = result.capabilities
192192

193-
await self.send_notification(types.ClientNotification(types.InitializedNotification()))
193+
await self.send_notification(types.InitializedNotification())
194194

195195
return result
196196

@@ -218,10 +218,7 @@ def experimental(self) -> ExperimentalClientFeatures:
218218

219219
async def send_ping(self) -> types.EmptyResult:
220220
"""Send a ping request."""
221-
return await self.send_request(
222-
types.ClientRequest(types.PingRequest()),
223-
types.EmptyResult,
224-
)
221+
return await self.send_request(types.PingRequest(), types.EmptyResult)
225222

226223
async def send_progress_notification(
227224
self,
@@ -232,26 +229,20 @@ async def send_progress_notification(
232229
) -> None:
233230
"""Send a progress notification."""
234231
await self.send_notification(
235-
types.ClientNotification(
236-
types.ProgressNotification(
237-
params=types.ProgressNotificationParams(
238-
progress_token=progress_token,
239-
progress=progress,
240-
total=total,
241-
message=message,
242-
),
232+
types.ProgressNotification(
233+
params=types.ProgressNotificationParams(
234+
progress_token=progress_token,
235+
progress=progress,
236+
total=total,
237+
message=message,
243238
),
244239
)
245240
)
246241

247242
async def set_logging_level(self, level: types.LoggingLevel) -> types.EmptyResult:
248243
"""Send a logging/setLevel request."""
249244
return await self.send_request( # pragma: no cover
250-
types.ClientRequest(
251-
types.SetLevelRequest(
252-
params=types.SetLevelRequestParams(level=level),
253-
)
254-
),
245+
types.SetLevelRequest(params=types.SetLevelRequestParams(level=level)),
255246
types.EmptyResult,
256247
)
257248

@@ -261,10 +252,7 @@ async def list_resources(self, *, params: types.PaginatedRequestParams | None =
261252
Args:
262253
params: Full pagination parameters including cursor and any future fields
263254
"""
264-
return await self.send_request(
265-
types.ClientRequest(types.ListResourcesRequest(params=params)),
266-
types.ListResourcesResult,
267-
)
255+
return await self.send_request(types.ListResourcesRequest(params=params), types.ListResourcesResult)
268256

269257
async def list_resource_templates(
270258
self, *, params: types.PaginatedRequestParams | None = None
@@ -275,28 +263,28 @@ async def list_resource_templates(
275263
params: Full pagination parameters including cursor and any future fields
276264
"""
277265
return await self.send_request(
278-
types.ClientRequest(types.ListResourceTemplatesRequest(params=params)),
266+
types.ListResourceTemplatesRequest(params=params),
279267
types.ListResourceTemplatesResult,
280268
)
281269

282270
async def read_resource(self, uri: str | AnyUrl) -> types.ReadResourceResult:
283271
"""Send a resources/read request."""
284272
return await self.send_request(
285-
types.ClientRequest(types.ReadResourceRequest(params=types.ReadResourceRequestParams(uri=str(uri)))),
273+
types.ReadResourceRequest(params=types.ReadResourceRequestParams(uri=str(uri))),
286274
types.ReadResourceResult,
287275
)
288276

289277
async def subscribe_resource(self, uri: str | AnyUrl) -> types.EmptyResult:
290278
"""Send a resources/subscribe request."""
291279
return await self.send_request( # pragma: no cover
292-
types.ClientRequest(types.SubscribeRequest(params=types.SubscribeRequestParams(uri=str(uri)))),
280+
types.SubscribeRequest(params=types.SubscribeRequestParams(uri=str(uri))),
293281
types.EmptyResult,
294282
)
295283

296284
async def unsubscribe_resource(self, uri: str | AnyUrl) -> types.EmptyResult:
297285
"""Send a resources/unsubscribe request."""
298286
return await self.send_request( # pragma: no cover
299-
types.ClientRequest(types.UnsubscribeRequest(params=types.UnsubscribeRequestParams(uri=str(uri)))),
287+
types.UnsubscribeRequest(params=types.UnsubscribeRequestParams(uri=str(uri))),
300288
types.EmptyResult,
301289
)
302290

@@ -316,10 +304,8 @@ async def call_tool(
316304
_meta = types.RequestParams.Meta(**meta)
317305

318306
result = await self.send_request(
319-
types.ClientRequest(
320-
types.CallToolRequest(
321-
params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta),
322-
)
307+
types.CallToolRequest(
308+
params=types.CallToolRequestParams(name=name, arguments=arguments, _meta=_meta),
323309
),
324310
types.CallToolResult,
325311
request_read_timeout_seconds=read_timeout_seconds,
@@ -364,17 +350,15 @@ async def list_prompts(self, *, params: types.PaginatedRequestParams | None = No
364350
params: Full pagination parameters including cursor and any future fields
365351
"""
366352
return await self.send_request(
367-
types.ClientRequest(types.ListPromptsRequest(params=params)),
353+
types.ListPromptsRequest(params=params),
368354
types.ListPromptsResult,
369355
)
370356

371357
async def get_prompt(self, name: str, arguments: dict[str, str] | None = None) -> types.GetPromptResult:
372358
"""Send a prompts/get request."""
373359
return await self.send_request(
374-
types.ClientRequest(
375-
types.GetPromptRequest(
376-
params=types.GetPromptRequestParams(name=name, arguments=arguments),
377-
)
360+
types.GetPromptRequest(
361+
params=types.GetPromptRequestParams(name=name, arguments=arguments),
378362
),
379363
types.GetPromptResult,
380364
)
@@ -391,14 +375,12 @@ async def complete(
391375
context = types.CompletionContext(arguments=context_arguments)
392376

393377
return await self.send_request(
394-
types.ClientRequest(
395-
types.CompleteRequest(
396-
params=types.CompleteRequestParams(
397-
ref=ref,
398-
argument=types.CompletionArgument(**argument),
399-
context=context,
400-
),
401-
)
378+
types.CompleteRequest(
379+
params=types.CompleteRequestParams(
380+
ref=ref,
381+
argument=types.CompletionArgument(**argument),
382+
context=context,
383+
),
402384
),
403385
types.CompleteResult,
404386
)
@@ -410,7 +392,7 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None
410392
params: Full pagination parameters including cursor and any future fields
411393
"""
412394
result = await self.send_request(
413-
types.ClientRequest(types.ListToolsRequest(params=params)),
395+
types.ListToolsRequest(params=params),
414396
types.ListToolsResult,
415397
)
416398

@@ -423,7 +405,7 @@ async def list_tools(self, *, params: types.PaginatedRequestParams | None = None
423405

424406
async def send_roots_list_changed(self) -> None: # pragma: no cover
425407
"""Send a roots/list_changed notification."""
426-
await self.send_notification(types.ClientNotification(types.RootsListChangedNotification()))
408+
await self.send_notification(types.RootsListChangedNotification())
427409

428410
async def _received_request(self, responder: RequestResponder[types.ServerRequest, types.ClientResult]) -> None:
429411
ctx = RequestContext[ClientSession, Any](
@@ -440,7 +422,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
440422
return None
441423

442424
# Core request handling
443-
match responder.request.root:
425+
match responder.request:
444426
case types.CreateMessageRequest(params=params):
445427
with responder:
446428
# Check if this is a task-augmented request
@@ -469,7 +451,7 @@ async def _received_request(self, responder: RequestResponder[types.ServerReques
469451

470452
case types.PingRequest(): # pragma: no cover
471453
with responder:
472-
return await responder.respond(types.ClientResult(root=types.EmptyResult()))
454+
return await responder.respond(types.EmptyResult())
473455

474456
case _: # pragma: no cover
475457
pass # Task requests handled above by _task_handlers
@@ -486,7 +468,7 @@ async def _handle_incoming(
486468
async def _received_notification(self, notification: types.ServerNotification) -> None:
487469
"""Handle notifications from the server."""
488470
# Process specific notification types
489-
match notification.root:
471+
match notification:
490472
case types.LoggingMessageNotification(params=params):
491473
await self._logging_callback(params)
492474
case types.ElicitCompleteNotification(params=params):

0 commit comments

Comments
 (0)