|
35 | 35 | CreateMessageRequestParams, |
36 | 36 | CreateMessageResult, |
37 | 37 | CreateTaskResult, |
| 38 | + ElicitRequest, |
| 39 | + ElicitRequestFormParams, |
| 40 | + ElicitRequestParams, |
| 41 | + ElicitResult, |
38 | 42 | ErrorData, |
39 | 43 | GetTaskPayloadRequest, |
40 | 44 | GetTaskPayloadRequestParams, |
@@ -501,6 +505,150 @@ async def run_client() -> None: |
501 | 505 | store.cleanup() |
502 | 506 |
|
503 | 507 |
|
| 508 | +@pytest.mark.anyio |
| 509 | +async def test_client_task_augmented_elicitation(client_streams: ClientTestStreams) -> None: |
| 510 | + """Test that client can handle task-augmented elicitation request from server.""" |
| 511 | + with anyio.fail_after(10): |
| 512 | + store = InMemoryTaskStore() |
| 513 | + elicitation_completed = Event() |
| 514 | + created_task_id: list[str | None] = [None] |
| 515 | + background_tg: list[TaskGroup | None] = [None] |
| 516 | + |
| 517 | + async def task_augmented_elicitation_callback( |
| 518 | + context: RequestContext[ClientSession, None], |
| 519 | + params: ElicitRequestParams, |
| 520 | + task_metadata: TaskMetadata, |
| 521 | + ) -> CreateTaskResult | ErrorData: |
| 522 | + task = await store.create_task(task_metadata) |
| 523 | + created_task_id[0] = task.taskId |
| 524 | + |
| 525 | + async def do_elicitation() -> None: |
| 526 | + # Simulate user providing elicitation response |
| 527 | + result = ElicitResult(action="accept", content={"name": "Test User"}) |
| 528 | + await store.store_result(task.taskId, result) |
| 529 | + await store.update_task(task.taskId, status="completed") |
| 530 | + elicitation_completed.set() |
| 531 | + |
| 532 | + assert background_tg[0] is not None |
| 533 | + background_tg[0].start_soon(do_elicitation) |
| 534 | + return CreateTaskResult(task=task) |
| 535 | + |
| 536 | + async def get_task_handler( |
| 537 | + context: RequestContext[ClientSession, None], |
| 538 | + params: GetTaskRequestParams, |
| 539 | + ) -> GetTaskResult | ErrorData: |
| 540 | + task = await store.get_task(params.taskId) |
| 541 | + if task is None: |
| 542 | + return ErrorData(code=types.INVALID_REQUEST, message="Task not found") |
| 543 | + return GetTaskResult( |
| 544 | + taskId=task.taskId, |
| 545 | + status=task.status, |
| 546 | + statusMessage=task.statusMessage, |
| 547 | + createdAt=task.createdAt, |
| 548 | + lastUpdatedAt=task.lastUpdatedAt, |
| 549 | + ttl=task.ttl, |
| 550 | + pollInterval=task.pollInterval, |
| 551 | + ) |
| 552 | + |
| 553 | + async def get_task_result_handler( |
| 554 | + context: RequestContext[ClientSession, None], |
| 555 | + params: GetTaskPayloadRequestParams, |
| 556 | + ) -> GetTaskPayloadResult | ErrorData: |
| 557 | + result = await store.get_result(params.taskId) |
| 558 | + if result is None: |
| 559 | + return ErrorData(code=types.INVALID_REQUEST, message="Result not found") |
| 560 | + assert isinstance(result, ElicitResult) |
| 561 | + return GetTaskPayloadResult(**result.model_dump()) |
| 562 | + |
| 563 | + task_handlers = ExperimentalTaskHandlers( |
| 564 | + augmented_elicitation=task_augmented_elicitation_callback, |
| 565 | + get_task=get_task_handler, |
| 566 | + get_task_result=get_task_result_handler, |
| 567 | + ) |
| 568 | + client_ready = anyio.Event() |
| 569 | + |
| 570 | + async with anyio.create_task_group() as tg: |
| 571 | + background_tg[0] = tg |
| 572 | + |
| 573 | + async def run_client() -> None: |
| 574 | + async with ClientSession( |
| 575 | + client_streams.client_receive, |
| 576 | + client_streams.client_send, |
| 577 | + message_handler=_default_message_handler, |
| 578 | + experimental_task_handlers=task_handlers, |
| 579 | + ): |
| 580 | + client_ready.set() |
| 581 | + await anyio.sleep_forever() |
| 582 | + |
| 583 | + tg.start_soon(run_client) |
| 584 | + await client_ready.wait() |
| 585 | + |
| 586 | + # Step 1: Server sends task-augmented ElicitRequest |
| 587 | + typed_request = ElicitRequest( |
| 588 | + params=ElicitRequestFormParams( |
| 589 | + message="What is your name?", |
| 590 | + requestedSchema={"type": "object", "properties": {"name": {"type": "string"}}}, |
| 591 | + task=TaskMetadata(ttl=60000), |
| 592 | + ) |
| 593 | + ) |
| 594 | + request = types.JSONRPCRequest( |
| 595 | + jsonrpc="2.0", |
| 596 | + id="req-elicit", |
| 597 | + **typed_request.model_dump(by_alias=True), |
| 598 | + ) |
| 599 | + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(request))) |
| 600 | + |
| 601 | + # Step 2: Client responds with CreateTaskResult |
| 602 | + response_msg = await client_streams.server_receive.receive() |
| 603 | + response = response_msg.message.root |
| 604 | + assert isinstance(response, types.JSONRPCResponse) |
| 605 | + |
| 606 | + task_result = CreateTaskResult.model_validate(response.result) |
| 607 | + task_id = task_result.task.taskId |
| 608 | + assert task_id == created_task_id[0] |
| 609 | + |
| 610 | + # Step 3: Wait for background elicitation |
| 611 | + await elicitation_completed.wait() |
| 612 | + |
| 613 | + # Step 4: Server polls task status |
| 614 | + typed_poll = GetTaskRequest(params=GetTaskRequestParams(taskId=task_id)) |
| 615 | + poll_request = types.JSONRPCRequest( |
| 616 | + jsonrpc="2.0", |
| 617 | + id="req-poll", |
| 618 | + **typed_poll.model_dump(by_alias=True), |
| 619 | + ) |
| 620 | + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(poll_request))) |
| 621 | + |
| 622 | + poll_response_msg = await client_streams.server_receive.receive() |
| 623 | + poll_response = poll_response_msg.message.root |
| 624 | + assert isinstance(poll_response, types.JSONRPCResponse) |
| 625 | + |
| 626 | + status = GetTaskResult.model_validate(poll_response.result) |
| 627 | + assert status.status == "completed" |
| 628 | + |
| 629 | + # Step 5: Server gets result |
| 630 | + typed_result_req = GetTaskPayloadRequest(params=GetTaskPayloadRequestParams(taskId=task_id)) |
| 631 | + result_request = types.JSONRPCRequest( |
| 632 | + jsonrpc="2.0", |
| 633 | + id="req-result", |
| 634 | + **typed_result_req.model_dump(by_alias=True), |
| 635 | + ) |
| 636 | + await client_streams.server_send.send(SessionMessage(types.JSONRPCMessage(result_request))) |
| 637 | + |
| 638 | + result_response_msg = await client_streams.server_receive.receive() |
| 639 | + result_response = result_response_msg.message.root |
| 640 | + assert isinstance(result_response, types.JSONRPCResponse) |
| 641 | + |
| 642 | + # Verify the elicitation result |
| 643 | + assert isinstance(result_response.result, dict) |
| 644 | + assert result_response.result["action"] == "accept" |
| 645 | + assert result_response.result["content"] == {"name": "Test User"} |
| 646 | + |
| 647 | + tg.cancel_scope.cancel() |
| 648 | + |
| 649 | + store.cleanup() |
| 650 | + |
| 651 | + |
504 | 652 | @pytest.mark.anyio |
505 | 653 | async def test_client_returns_error_for_unhandled_task_request(client_streams: ClientTestStreams) -> None: |
506 | 654 | """Test that client returns error when no handler is registered for task request.""" |
|
0 commit comments