|
7 | 7 | from examples.shared.in_memory_task_store import InMemoryTaskStore |
8 | 8 | from mcp.client.session import ClientSession |
9 | 9 | from mcp.server import Server |
| 10 | +from mcp.server.models import InitializationOptions |
| 11 | +from mcp.server.session import ServerSession |
10 | 12 | from mcp.shared.memory import create_client_server_memory_streams |
| 13 | +from mcp.shared.message import SessionMessage |
11 | 14 |
|
12 | 15 |
|
13 | 16 | @pytest.mark.anyio |
@@ -613,3 +616,298 @@ async def handle_tool(name: str, arguments: dict[str, str]) -> list[types.TextCo |
613 | 616 | assert server_task.status == "submitted" |
614 | 617 | finally: |
615 | 618 | tg.cancel_scope.cancel() |
| 619 | + |
| 620 | + |
| 621 | +@pytest.fixture |
| 622 | +async def server_session(): |
| 623 | + """Create a ServerSession for testing capability checking.""" |
| 624 | + from_client, to_server = anyio.create_memory_object_stream[SessionMessage](1) |
| 625 | + from_server, to_client = anyio.create_memory_object_stream[SessionMessage](1) |
| 626 | + |
| 627 | + session = ServerSession( |
| 628 | + to_server, |
| 629 | + from_server, |
| 630 | + InitializationOptions( |
| 631 | + server_name="test", |
| 632 | + server_version="1.0.0", |
| 633 | + capabilities=types.ServerCapabilities(), |
| 634 | + ), |
| 635 | + ) |
| 636 | + |
| 637 | + yield session |
| 638 | + |
| 639 | + # Cleanup |
| 640 | + await from_client.aclose() |
| 641 | + await to_server.aclose() |
| 642 | + await from_server.aclose() |
| 643 | + await to_client.aclose() |
| 644 | + |
| 645 | + |
| 646 | +@pytest.mark.anyio |
| 647 | +async def test_check_tasks_capability_no_requirements(server_session: ServerSession): |
| 648 | + """Test _check_tasks_capability returns True when no requirements specified.""" |
| 649 | + required = types.ClientTasksCapability(requests=None) |
| 650 | + client = types.ClientTasksCapability(requests=None) |
| 651 | + |
| 652 | + result = server_session._check_tasks_capability(required, client) |
| 653 | + assert result is True |
| 654 | + |
| 655 | + |
| 656 | +@pytest.mark.anyio |
| 657 | +async def test_check_tasks_capability_client_missing_requests(server_session: ServerSession): |
| 658 | + """Test _check_tasks_capability returns False when client has no requests capability.""" |
| 659 | + required = types.ClientTasksCapability( |
| 660 | + requests=types.ClientTasksRequestsCapability(sampling=types.TaskSamplingCapability(createMessage=True)) |
| 661 | + ) |
| 662 | + client = types.ClientTasksCapability(requests=None) |
| 663 | + |
| 664 | + result = server_session._check_tasks_capability(required, client) |
| 665 | + assert result is False |
| 666 | + |
| 667 | + |
| 668 | +@pytest.mark.anyio |
| 669 | +async def test_check_tasks_capability_sampling_missing(server_session: ServerSession): |
| 670 | + """Test _check_tasks_capability returns False when sampling capability missing.""" |
| 671 | + required = types.ClientTasksCapability( |
| 672 | + requests=types.ClientTasksRequestsCapability(sampling=types.TaskSamplingCapability(createMessage=True)) |
| 673 | + ) |
| 674 | + client = types.ClientTasksCapability(requests=types.ClientTasksRequestsCapability(sampling=None)) |
| 675 | + |
| 676 | + result = server_session._check_tasks_capability(required, client) |
| 677 | + assert result is False |
| 678 | + |
| 679 | + |
| 680 | +@pytest.mark.anyio |
| 681 | +async def test_check_tasks_capability_sampling_createMessage_false(server_session: ServerSession): |
| 682 | + """Test _check_tasks_capability returns False when sampling.createMessage is False.""" |
| 683 | + required = types.ClientTasksCapability( |
| 684 | + requests=types.ClientTasksRequestsCapability(sampling=types.TaskSamplingCapability(createMessage=True)) |
| 685 | + ) |
| 686 | + client = types.ClientTasksCapability( |
| 687 | + requests=types.ClientTasksRequestsCapability(sampling=types.TaskSamplingCapability(createMessage=False)) |
| 688 | + ) |
| 689 | + |
| 690 | + result = server_session._check_tasks_capability(required, client) |
| 691 | + assert result is False |
| 692 | + |
| 693 | + |
| 694 | +@pytest.mark.anyio |
| 695 | +async def test_check_tasks_capability_sampling_success(server_session: ServerSession): |
| 696 | + """Test _check_tasks_capability returns True when sampling capability matches.""" |
| 697 | + required = types.ClientTasksCapability( |
| 698 | + requests=types.ClientTasksRequestsCapability(sampling=types.TaskSamplingCapability(createMessage=True)) |
| 699 | + ) |
| 700 | + client = types.ClientTasksCapability( |
| 701 | + requests=types.ClientTasksRequestsCapability(sampling=types.TaskSamplingCapability(createMessage=True)) |
| 702 | + ) |
| 703 | + |
| 704 | + result = server_session._check_tasks_capability(required, client) |
| 705 | + assert result is True |
| 706 | + |
| 707 | + |
| 708 | +@pytest.mark.anyio |
| 709 | +async def test_check_tasks_capability_elicitation_missing(server_session: ServerSession): |
| 710 | + """Test _check_tasks_capability returns False when elicitation capability missing.""" |
| 711 | + required = types.ClientTasksCapability( |
| 712 | + requests=types.ClientTasksRequestsCapability(elicitation=types.TaskElicitationCapability(create=True)) |
| 713 | + ) |
| 714 | + client = types.ClientTasksCapability(requests=types.ClientTasksRequestsCapability(elicitation=None)) |
| 715 | + |
| 716 | + result = server_session._check_tasks_capability(required, client) |
| 717 | + assert result is False |
| 718 | + |
| 719 | + |
| 720 | +@pytest.mark.anyio |
| 721 | +async def test_check_tasks_capability_elicitation_create_false(server_session: ServerSession): |
| 722 | + """Test _check_tasks_capability returns False when elicitation.create is False.""" |
| 723 | + required = types.ClientTasksCapability( |
| 724 | + requests=types.ClientTasksRequestsCapability(elicitation=types.TaskElicitationCapability(create=True)) |
| 725 | + ) |
| 726 | + client = types.ClientTasksCapability( |
| 727 | + requests=types.ClientTasksRequestsCapability(elicitation=types.TaskElicitationCapability(create=False)) |
| 728 | + ) |
| 729 | + |
| 730 | + result = server_session._check_tasks_capability(required, client) |
| 731 | + assert result is False |
| 732 | + |
| 733 | + |
| 734 | +@pytest.mark.anyio |
| 735 | +async def test_check_tasks_capability_elicitation_success(server_session: ServerSession): |
| 736 | + """Test _check_tasks_capability returns True when elicitation capability matches.""" |
| 737 | + required = types.ClientTasksCapability( |
| 738 | + requests=types.ClientTasksRequestsCapability(elicitation=types.TaskElicitationCapability(create=True)) |
| 739 | + ) |
| 740 | + client = types.ClientTasksCapability( |
| 741 | + requests=types.ClientTasksRequestsCapability(elicitation=types.TaskElicitationCapability(create=True)) |
| 742 | + ) |
| 743 | + |
| 744 | + result = server_session._check_tasks_capability(required, client) |
| 745 | + assert result is True |
| 746 | + |
| 747 | + |
| 748 | +@pytest.mark.anyio |
| 749 | +async def test_check_tasks_capability_roots_missing(server_session: ServerSession): |
| 750 | + """Test _check_tasks_capability returns False when roots capability missing.""" |
| 751 | + required = types.ClientTasksCapability( |
| 752 | + requests=types.ClientTasksRequestsCapability(roots=types.TaskRootsCapability(list=True)) |
| 753 | + ) |
| 754 | + client = types.ClientTasksCapability(requests=types.ClientTasksRequestsCapability(roots=None)) |
| 755 | + |
| 756 | + result = server_session._check_tasks_capability(required, client) |
| 757 | + assert result is False |
| 758 | + |
| 759 | + |
| 760 | +@pytest.mark.anyio |
| 761 | +async def test_check_tasks_capability_roots_list_false(server_session: ServerSession): |
| 762 | + """Test _check_tasks_capability returns False when roots.list is False.""" |
| 763 | + required = types.ClientTasksCapability( |
| 764 | + requests=types.ClientTasksRequestsCapability(roots=types.TaskRootsCapability(list=True)) |
| 765 | + ) |
| 766 | + client = types.ClientTasksCapability( |
| 767 | + requests=types.ClientTasksRequestsCapability(roots=types.TaskRootsCapability(list=False)) |
| 768 | + ) |
| 769 | + |
| 770 | + result = server_session._check_tasks_capability(required, client) |
| 771 | + assert result is False |
| 772 | + |
| 773 | + |
| 774 | +@pytest.mark.anyio |
| 775 | +async def test_check_tasks_capability_roots_success(server_session: ServerSession): |
| 776 | + """Test _check_tasks_capability returns True when roots capability matches.""" |
| 777 | + required = types.ClientTasksCapability( |
| 778 | + requests=types.ClientTasksRequestsCapability(roots=types.TaskRootsCapability(list=True)) |
| 779 | + ) |
| 780 | + client = types.ClientTasksCapability( |
| 781 | + requests=types.ClientTasksRequestsCapability(roots=types.TaskRootsCapability(list=True)) |
| 782 | + ) |
| 783 | + |
| 784 | + result = server_session._check_tasks_capability(required, client) |
| 785 | + assert result is True |
| 786 | + |
| 787 | + |
| 788 | +@pytest.mark.anyio |
| 789 | +async def test_check_tasks_capability_tasks_missing(server_session: ServerSession): |
| 790 | + """Test _check_tasks_capability returns False when tasks capability missing.""" |
| 791 | + required = types.ClientTasksCapability( |
| 792 | + requests=types.ClientTasksRequestsCapability( |
| 793 | + tasks=types.TasksOperationsCapability(get=True, list=True, result=True, delete=True) |
| 794 | + ) |
| 795 | + ) |
| 796 | + client = types.ClientTasksCapability(requests=types.ClientTasksRequestsCapability(tasks=None)) |
| 797 | + |
| 798 | + result = server_session._check_tasks_capability(required, client) |
| 799 | + assert result is False |
| 800 | + |
| 801 | + |
| 802 | +@pytest.mark.anyio |
| 803 | +async def test_check_tasks_capability_tasks_get_false(server_session: ServerSession): |
| 804 | + """Test _check_tasks_capability returns False when tasks.get is False.""" |
| 805 | + required = types.ClientTasksCapability( |
| 806 | + requests=types.ClientTasksRequestsCapability( |
| 807 | + tasks=types.TasksOperationsCapability(get=True, list=False, result=False, delete=False) |
| 808 | + ) |
| 809 | + ) |
| 810 | + client = types.ClientTasksCapability( |
| 811 | + requests=types.ClientTasksRequestsCapability( |
| 812 | + tasks=types.TasksOperationsCapability(get=False, list=True, result=True, delete=True) |
| 813 | + ) |
| 814 | + ) |
| 815 | + |
| 816 | + result = server_session._check_tasks_capability(required, client) |
| 817 | + assert result is False |
| 818 | + |
| 819 | + |
| 820 | +@pytest.mark.anyio |
| 821 | +async def test_check_tasks_capability_tasks_list_false(server_session: ServerSession): |
| 822 | + """Test _check_tasks_capability returns False when tasks.list is False.""" |
| 823 | + required = types.ClientTasksCapability( |
| 824 | + requests=types.ClientTasksRequestsCapability( |
| 825 | + tasks=types.TasksOperationsCapability(get=False, list=True, result=False, delete=False) |
| 826 | + ) |
| 827 | + ) |
| 828 | + client = types.ClientTasksCapability( |
| 829 | + requests=types.ClientTasksRequestsCapability( |
| 830 | + tasks=types.TasksOperationsCapability(get=True, list=False, result=True, delete=True) |
| 831 | + ) |
| 832 | + ) |
| 833 | + |
| 834 | + result = server_session._check_tasks_capability(required, client) |
| 835 | + assert result is False |
| 836 | + |
| 837 | + |
| 838 | +@pytest.mark.anyio |
| 839 | +async def test_check_tasks_capability_tasks_result_false(server_session: ServerSession): |
| 840 | + """Test _check_tasks_capability returns False when tasks.result is False.""" |
| 841 | + required = types.ClientTasksCapability( |
| 842 | + requests=types.ClientTasksRequestsCapability( |
| 843 | + tasks=types.TasksOperationsCapability(get=False, list=False, result=True, delete=False) |
| 844 | + ) |
| 845 | + ) |
| 846 | + client = types.ClientTasksCapability( |
| 847 | + requests=types.ClientTasksRequestsCapability( |
| 848 | + tasks=types.TasksOperationsCapability(get=True, list=True, result=False, delete=True) |
| 849 | + ) |
| 850 | + ) |
| 851 | + |
| 852 | + result = server_session._check_tasks_capability(required, client) |
| 853 | + assert result is False |
| 854 | + |
| 855 | + |
| 856 | +@pytest.mark.anyio |
| 857 | +async def test_check_tasks_capability_tasks_delete_false(server_session: ServerSession): |
| 858 | + """Test _check_tasks_capability returns False when tasks.delete is False.""" |
| 859 | + required = types.ClientTasksCapability( |
| 860 | + requests=types.ClientTasksRequestsCapability( |
| 861 | + tasks=types.TasksOperationsCapability(get=False, list=False, result=False, delete=True) |
| 862 | + ) |
| 863 | + ) |
| 864 | + client = types.ClientTasksCapability( |
| 865 | + requests=types.ClientTasksRequestsCapability( |
| 866 | + tasks=types.TasksOperationsCapability(get=True, list=True, result=True, delete=False) |
| 867 | + ) |
| 868 | + ) |
| 869 | + |
| 870 | + result = server_session._check_tasks_capability(required, client) |
| 871 | + assert result is False |
| 872 | + |
| 873 | + |
| 874 | +@pytest.mark.anyio |
| 875 | +async def test_check_tasks_capability_tasks_all_operations_true(server_session: ServerSession): |
| 876 | + """Test _check_tasks_capability returns True when all required task operations match.""" |
| 877 | + required = types.ClientTasksCapability( |
| 878 | + requests=types.ClientTasksRequestsCapability( |
| 879 | + tasks=types.TasksOperationsCapability(get=True, list=True, result=True, delete=True) |
| 880 | + ) |
| 881 | + ) |
| 882 | + client = types.ClientTasksCapability( |
| 883 | + requests=types.ClientTasksRequestsCapability( |
| 884 | + tasks=types.TasksOperationsCapability(get=True, list=True, result=True, delete=True) |
| 885 | + ) |
| 886 | + ) |
| 887 | + |
| 888 | + result = server_session._check_tasks_capability(required, client) |
| 889 | + assert result is True |
| 890 | + |
| 891 | + |
| 892 | +@pytest.mark.anyio |
| 893 | +async def test_check_tasks_capability_all_capabilities_present(server_session: ServerSession): |
| 894 | + """Test _check_tasks_capability returns True when all capabilities are satisfied.""" |
| 895 | + required = types.ClientTasksCapability( |
| 896 | + requests=types.ClientTasksRequestsCapability( |
| 897 | + sampling=types.TaskSamplingCapability(createMessage=True), |
| 898 | + elicitation=types.TaskElicitationCapability(create=True), |
| 899 | + roots=types.TaskRootsCapability(list=True), |
| 900 | + tasks=types.TasksOperationsCapability(get=True, list=True, result=True, delete=True), |
| 901 | + ) |
| 902 | + ) |
| 903 | + client = types.ClientTasksCapability( |
| 904 | + requests=types.ClientTasksRequestsCapability( |
| 905 | + sampling=types.TaskSamplingCapability(createMessage=True), |
| 906 | + elicitation=types.TaskElicitationCapability(create=True), |
| 907 | + roots=types.TaskRootsCapability(list=True), |
| 908 | + tasks=types.TasksOperationsCapability(get=True, list=True, result=True, delete=True), |
| 909 | + ) |
| 910 | + ) |
| 911 | + |
| 912 | + result = server_session._check_tasks_capability(required, client) |
| 913 | + assert result is True |
0 commit comments