Skip to content

Commit 2f3b792

Browse files
committed
Add test for task-augmented elicitation (covers client/session.py line 569)
Adds test_client_task_augmented_elicitation to test the client-side handling of task-augmented elicitation requests from servers, similar to the existing test_client_task_augmented_sampling test.
1 parent 4eb5e45 commit 2f3b792

File tree

1 file changed

+148
-0
lines changed

1 file changed

+148
-0
lines changed

tests/experimental/tasks/client/test_handlers.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,10 @@
3535
CreateMessageRequestParams,
3636
CreateMessageResult,
3737
CreateTaskResult,
38+
ElicitRequest,
39+
ElicitRequestFormParams,
40+
ElicitRequestParams,
41+
ElicitResult,
3842
ErrorData,
3943
GetTaskPayloadRequest,
4044
GetTaskPayloadRequestParams,
@@ -501,6 +505,150 @@ async def run_client() -> None:
501505
store.cleanup()
502506

503507

508+
@pytest.mark.anyio
509+
async def test_client_task_augmented_elicitation(client_streams: ClientTestStreams) -> None:
510+
"""Test that client can handle task-augmented elicitation request from server."""
511+
with anyio.fail_after(10):
512+
store = InMemoryTaskStore()
513+
elicitation_completed = Event()
514+
created_task_id: list[str | None] = [None]
515+
background_tg: list[TaskGroup | None] = [None]
516+
517+
async def task_augmented_elicitation_callback(
518+
context: RequestContext[ClientSession, None],
519+
params: ElicitRequestParams,
520+
task_metadata: TaskMetadata,
521+
) -> CreateTaskResult | ErrorData:
522+
task = await store.create_task(task_metadata)
523+
created_task_id[0] = task.taskId
524+
525+
async def do_elicitation() -> None:
526+
# Simulate user providing elicitation response
527+
result = ElicitResult(action="accept", content={"name": "Test User"})
528+
await store.store_result(task.taskId, result)
529+
await store.update_task(task.taskId, status="completed")
530+
elicitation_completed.set()
531+
532+
assert background_tg[0] is not None
533+
background_tg[0].start_soon(do_elicitation)
534+
return CreateTaskResult(task=task)
535+
536+
async def get_task_handler(
537+
context: RequestContext[ClientSession, None],
538+
params: GetTaskRequestParams,
539+
) -> GetTaskResult | ErrorData:
540+
task = await store.get_task(params.taskId)
541+
if task is None:
542+
return ErrorData(code=types.INVALID_REQUEST, message="Task not found")
543+
return GetTaskResult(
544+
taskId=task.taskId,
545+
status=task.status,
546+
statusMessage=task.statusMessage,
547+
createdAt=task.createdAt,
548+
lastUpdatedAt=task.lastUpdatedAt,
549+
ttl=task.ttl,
550+
pollInterval=task.pollInterval,
551+
)
552+
553+
async def get_task_result_handler(
554+
context: RequestContext[ClientSession, None],
555+
params: GetTaskPayloadRequestParams,
556+
) -> GetTaskPayloadResult | ErrorData:
557+
result = await store.get_result(params.taskId)
558+
if result is None:
559+
return ErrorData(code=types.INVALID_REQUEST, message="Result not found")
560+
assert isinstance(result, ElicitResult)
561+
return GetTaskPayloadResult(**result.model_dump())
562+
563+
task_handlers = ExperimentalTaskHandlers(
564+
augmented_elicitation=task_augmented_elicitation_callback,
565+
get_task=get_task_handler,
566+
get_task_result=get_task_result_handler,
567+
)
568+
client_ready = anyio.Event()
569+
570+
async with anyio.create_task_group() as tg:
571+
background_tg[0] = tg
572+
573+
async def run_client() -> None:
574+
async with ClientSession(
575+
client_streams.client_receive,
576+
client_streams.client_send,
577+
message_handler=_default_message_handler,
578+
experimental_task_handlers=task_handlers,
579+
):
580+
client_ready.set()
581+
await anyio.sleep_forever()
582+
583+
tg.start_soon(run_client)
584+
await client_ready.wait()
585+
586+
# Step 1: Server sends task-augmented ElicitRequest
587+
typed_request = ElicitRequest(
588+
params=ElicitRequestFormParams(
589+
message="What is your name?",
590+
requestedSchema={"type": "object", "properties": {"name": {"type": "string"}}},
591+
task=TaskMetadata(ttl=60000),
592+
)
593+
)
594+
request = types.JSONRPCRequest(
595+
jsonrpc="2.0",
596+
id="req-elicit",
597+
**typed_request.model_dump(by_alias=True),
598+
)
599+
await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request)))
600+
601+
# Step 2: Client responds with CreateTaskResult
602+
response_msg = await client_streams.server_receive.receive()
603+
response = response_msg.message.root
604+
assert isinstance(response, types.JSONRPCResponse)
605+
606+
task_result = CreateTaskResult.model_validate(response.result)
607+
task_id = task_result.task.taskId
608+
assert task_id == created_task_id[0]
609+
610+
# Step 3: Wait for background elicitation
611+
await elicitation_completed.wait()
612+
613+
# Step 4: Server polls task status
614+
typed_poll = GetTaskRequest(params=GetTaskRequestParams(taskId=task_id))
615+
poll_request = types.JSONRPCRequest(
616+
jsonrpc="2.0",
617+
id="req-poll",
618+
**typed_poll.model_dump(by_alias=True),
619+
)
620+
await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(poll_request)))
621+
622+
poll_response_msg = await client_streams.server_receive.receive()
623+
poll_response = poll_response_msg.message.root
624+
assert isinstance(poll_response, types.JSONRPCResponse)
625+
626+
status = GetTaskResult.model_validate(poll_response.result)
627+
assert status.status == "completed"
628+
629+
# Step 5: Server gets result
630+
typed_result_req = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task_id))
631+
result_request = types.JSONRPCRequest(
632+
jsonrpc="2.0",
633+
id="req-result",
634+
**typed_result_req.model_dump(by_alias=True),
635+
)
636+
await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(result_request)))
637+
638+
result_response_msg = await client_streams.server_receive.receive()
639+
result_response = result_response_msg.message.root
640+
assert isinstance(result_response, types.JSONRPCResponse)
641+
642+
# Verify the elicitation result
643+
assert isinstance(result_response.result, dict)
644+
assert result_response.result["action"] == "accept"
645+
assert result_response.result["content"] == {"name": "Test User"}
646+
647+
tg.cancel_scope.cancel()
648+
649+
store.cleanup()
650+
651+
504652
@pytest.mark.anyio
505653
async def test_client_returns_error_for_unhandled_task_request(client_streams: ClientTestStreams) -> None:
506654
"""Test that client returns error when no handler is registered for task request."""

0 commit comments

Comments
 (0)