Skip to content

Commit 499602e

Browse files
committed
Use typed request classes instead of hardcoded method strings in test_handlers.py
Replace raw JSONRPCRequest construction with typed request classes: - GetTaskRequest/GetTaskRequestParams for tasks/get - GetTaskPayloadRequest/GetTaskPayloadRequestParams for tasks/result - ListTasksRequest for tasks/list - CancelTaskRequest/CancelTaskRequestParams for tasks/cancel - CreateMessageRequest/CreateMessageRequestParams for sampling/createMessage This eliminates hardcoded method strings and ensures params are validated.
1 parent 76b3a26 commit 499602e

File tree

1 file changed

+28
-19
lines changed

1 file changed

+28
-19
lines changed

tests/experimental/tasks/client/test_handlers.py

Lines changed: 28 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,24 @@
2727
from mcp.shared.message import SessionMessage
2828
from mcp.shared.session import RequestResponder
2929
from mcp.types import (
30+
CancelTaskRequest,
3031
CancelTaskRequestParams,
3132
CancelTaskResult,
3233
ClientResult,
34+
CreateMessageRequest,
3335
CreateMessageRequestParams,
3436
CreateMessageResult,
3537
CreateTaskResult,
3638
ErrorData,
39+
GetTaskPayloadRequest,
3740
GetTaskPayloadRequestParams,
3841
GetTaskPayloadResult,
42+
GetTaskRequest,
3943
GetTaskRequestParams,
4044
GetTaskResult,
45+
ListTasksRequest,
4146
ListTasksResult,
47+
SamplingMessage,
4248
ServerNotification,
4349
ServerRequest,
4450
TaskMetadata,
@@ -142,11 +148,11 @@ async def run_client() -> None:
142148
tg.start_soon(run_client)
143149
await client_ready.wait()
144150

151+
typed_request = GetTaskRequest(params=GetTaskRequestParams(taskId="test-task-123"))
145152
request = types.JSONRPCRequest(
146153
jsonrpc="2.0",
147154
id="req-1",
148-
method="tasks/get",
149-
params={"taskId": "test-task-123"},
155+
**typed_request.model_dump(by_alias=True),
150156
)
151157
await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request)))
152158

@@ -206,11 +212,11 @@ async def run_client() -> None:
206212
tg.start_soon(run_client)
207213
await client_ready.wait()
208214

215+
typed_request = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId="test-task-456"))
209216
request = types.JSONRPCRequest(
210217
jsonrpc="2.0",
211218
id="req-2",
212-
method="tasks/result",
213-
params={"taskId": "test-task-456"},
219+
**typed_request.model_dump(by_alias=True),
214220
)
215221
await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request)))
216222

@@ -264,10 +270,11 @@ async def run_client() -> None:
264270
tg.start_soon(run_client)
265271
await client_ready.wait()
266272

273+
typed_request = ListTasksRequest()
267274
request = types.JSONRPCRequest(
268275
jsonrpc="2.0",
269276
id="req-3",
270-
method="tasks/list",
277+
**typed_request.model_dump(by_alias=True),
271278
)
272279
await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request)))
273280

@@ -327,11 +334,11 @@ async def run_client() -> None:
327334
tg.start_soon(run_client)
328335
await client_ready.wait()
329336

337+
typed_request = CancelTaskRequest(params=CancelTaskRequestParams(taskId="task-to-cancel"))
330338
request = types.JSONRPCRequest(
331339
jsonrpc="2.0",
332340
id="req-4",
333-
method="tasks/cancel",
334-
params={"taskId": "task-to-cancel"},
341+
**typed_request.model_dump(by_alias=True),
335342
)
336343
await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request)))
337344

@@ -431,15 +438,17 @@ async def run_client() -> None:
431438
await client_ready.wait()
432439

433440
# Step 1: Server sends task-augmented CreateMessageRequest
441+
typed_request = CreateMessageRequest(
442+
params=CreateMessageRequestParams(
443+
messages=[SamplingMessage(role="user", content=TextContent(type="text", text="Hello"))],
444+
maxTokens=100,
445+
task=TaskMetadata(ttl=60000),
446+
)
447+
)
434448
request = types.JSONRPCRequest(
435449
jsonrpc="2.0",
436450
id="req-sampling",
437-
method="sampling/createMessage",
438-
params={
439-
"messages": [{"role": "user", "content": {"type": "text", "text": "Hello"}}],
440-
"maxTokens": 100,
441-
"task": {"ttl": 60000},
442-
},
451+
**typed_request.model_dump(by_alias=True),
443452
)
444453
await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request)))
445454

@@ -456,11 +465,11 @@ async def run_client() -> None:
456465
await sampling_completed.wait()
457466

458467
# Step 4: Server polls task status
468+
typed_poll = GetTaskRequest(params=GetTaskRequestParams(taskId=task_id))
459469
poll_request = types.JSONRPCRequest(
460470
jsonrpc="2.0",
461471
id="req-poll",
462-
method="tasks/get",
463-
params={"taskId": task_id},
472+
**typed_poll.model_dump(by_alias=True),
464473
)
465474
await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(poll_request)))
466475

@@ -472,11 +481,11 @@ async def run_client() -> None:
472481
assert status.status == "completed"
473482

474483
# Step 5: Server gets result
484+
typed_result_req = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task_id))
475485
result_request = types.JSONRPCRequest(
476486
jsonrpc="2.0",
477487
id="req-result",
478-
method="tasks/result",
479-
params={"taskId": task_id},
488+
**typed_result_req.model_dump(by_alias=True),
480489
)
481490
await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(result_request)))
482491

@@ -512,11 +521,11 @@ async def run_client() -> None:
512521
tg.start_soon(run_client)
513522
await client_ready.wait()
514523

524+
typed_request = GetTaskRequest(params=GetTaskRequestParams(taskId="nonexistent"))
515525
request = types.JSONRPCRequest(
516526
jsonrpc="2.0",
517527
id="req-unhandled",
518-
method="tasks/get",
519-
params={"taskId": "nonexistent"},
528+
**typed_request.model_dump(by_alias=True),
520529
)
521530
await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request)))
522531

0 commit comments

Comments
 (0)