Skip to content

Commit 19c3939

Browse files
committed
Implement more capability tests
1 parent 4096f2e commit 19c3939

File tree

1 file changed

+298
-0
lines changed

1 file changed

+298
-0
lines changed

tests/server/test_session_tasks.py

Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@
77
from examples.shared.in_memory_task_store import InMemoryTaskStore
88
from mcp.client.session import ClientSession
99
from mcp.server import Server
10+
from mcp.server.models import InitializationOptions
11+
from mcp.server.session import ServerSession
1012
from mcp.shared.memory import create_client_server_memory_streams
13+
from mcp.shared.message import SessionMessage
1114

1215

1316
@pytest.mark.anyio
@@ -613,3 +616,298 @@ async def handle_tool(name: str, arguments: dict[str, str]) -> list[types.TextCo
613616
assert server_task.status == "submitted"
614617
finally:
615618
tg.cancel_scope.cancel()
619+
620+
621+
@pytest.fixture
622+
async def server_session():
623+
"""Create a ServerSession for testing capability checking."""
624+
from_client, to_server = anyio.create_memory_object_stream[SessionMessage](1)
625+
from_server, to_client = anyio.create_memory_object_stream[SessionMessage](1)
626+
627+
session = ServerSession(
628+
to_server,
629+
from_server,
630+
InitializationOptions(
631+
server_name="test",
632+
server_version="1.0.0",
633+
capabilities=types.ServerCapabilities(),
634+
),
635+
)
636+
637+
yield session
638+
639+
# Cleanup
640+
await from_client.aclose()
641+
await to_server.aclose()
642+
await from_server.aclose()
643+
await to_client.aclose()
644+
645+
646+
@pytest.mark.anyio
647+
async def test_check_tasks_capability_no_requirements(server_session: ServerSession):
648+
"""Test _check_tasks_capability returns True when no requirements specified."""
649+
required = types.ClientTasksCapability(requests=None)
650+
client = types.ClientTasksCapability(requests=None)
651+
652+
result = server_session._check_tasks_capability(required, client)
653+
assert result is True
654+
655+
656+
@pytest.mark.anyio
657+
async def test_check_tasks_capability_client_missing_requests(server_session: ServerSession):
658+
"""Test _check_tasks_capability returns False when client has no requests capability."""
659+
required = types.ClientTasksCapability(
660+
requests=types.ClientTasksRequestsCapability(sampling=types.TaskSamplingCapability(createMessage=True))
661+
)
662+
client = types.ClientTasksCapability(requests=None)
663+
664+
result = server_session._check_tasks_capability(required, client)
665+
assert result is False
666+
667+
668+
@pytest.mark.anyio
669+
async def test_check_tasks_capability_sampling_missing(server_session: ServerSession):
670+
"""Test _check_tasks_capability returns False when sampling capability missing."""
671+
required = types.ClientTasksCapability(
672+
requests=types.ClientTasksRequestsCapability(sampling=types.TaskSamplingCapability(createMessage=True))
673+
)
674+
client = types.ClientTasksCapability(requests=types.ClientTasksRequestsCapability(sampling=None))
675+
676+
result = server_session._check_tasks_capability(required, client)
677+
assert result is False
678+
679+
680+
@pytest.mark.anyio
681+
async def test_check_tasks_capability_sampling_createMessage_false(server_session: ServerSession):
682+
"""Test _check_tasks_capability returns False when sampling.createMessage is False."""
683+
required = types.ClientTasksCapability(
684+
requests=types.ClientTasksRequestsCapability(sampling=types.TaskSamplingCapability(createMessage=True))
685+
)
686+
client = types.ClientTasksCapability(
687+
requests=types.ClientTasksRequestsCapability(sampling=types.TaskSamplingCapability(createMessage=False))
688+
)
689+
690+
result = server_session._check_tasks_capability(required, client)
691+
assert result is False
692+
693+
694+
@pytest.mark.anyio
695+
async def test_check_tasks_capability_sampling_success(server_session: ServerSession):
696+
"""Test _check_tasks_capability returns True when sampling capability matches."""
697+
required = types.ClientTasksCapability(
698+
requests=types.ClientTasksRequestsCapability(sampling=types.TaskSamplingCapability(createMessage=True))
699+
)
700+
client = types.ClientTasksCapability(
701+
requests=types.ClientTasksRequestsCapability(sampling=types.TaskSamplingCapability(createMessage=True))
702+
)
703+
704+
result = server_session._check_tasks_capability(required, client)
705+
assert result is True
706+
707+
708+
@pytest.mark.anyio
709+
async def test_check_tasks_capability_elicitation_missing(server_session: ServerSession):
710+
"""Test _check_tasks_capability returns False when elicitation capability missing."""
711+
required = types.ClientTasksCapability(
712+
requests=types.ClientTasksRequestsCapability(elicitation=types.TaskElicitationCapability(create=True))
713+
)
714+
client = types.ClientTasksCapability(requests=types.ClientTasksRequestsCapability(elicitation=None))
715+
716+
result = server_session._check_tasks_capability(required, client)
717+
assert result is False
718+
719+
720+
@pytest.mark.anyio
721+
async def test_check_tasks_capability_elicitation_create_false(server_session: ServerSession):
722+
"""Test _check_tasks_capability returns False when elicitation.create is False."""
723+
required = types.ClientTasksCapability(
724+
requests=types.ClientTasksRequestsCapability(elicitation=types.TaskElicitationCapability(create=True))
725+
)
726+
client = types.ClientTasksCapability(
727+
requests=types.ClientTasksRequestsCapability(elicitation=types.TaskElicitationCapability(create=False))
728+
)
729+
730+
result = server_session._check_tasks_capability(required, client)
731+
assert result is False
732+
733+
734+
@pytest.mark.anyio
735+
async def test_check_tasks_capability_elicitation_success(server_session: ServerSession):
736+
"""Test _check_tasks_capability returns True when elicitation capability matches."""
737+
required = types.ClientTasksCapability(
738+
requests=types.ClientTasksRequestsCapability(elicitation=types.TaskElicitationCapability(create=True))
739+
)
740+
client = types.ClientTasksCapability(
741+
requests=types.ClientTasksRequestsCapability(elicitation=types.TaskElicitationCapability(create=True))
742+
)
743+
744+
result = server_session._check_tasks_capability(required, client)
745+
assert result is True
746+
747+
748+
@pytest.mark.anyio
749+
async def test_check_tasks_capability_roots_missing(server_session: ServerSession):
750+
"""Test _check_tasks_capability returns False when roots capability missing."""
751+
required = types.ClientTasksCapability(
752+
requests=types.ClientTasksRequestsCapability(roots=types.TaskRootsCapability(list=True))
753+
)
754+
client = types.ClientTasksCapability(requests=types.ClientTasksRequestsCapability(roots=None))
755+
756+
result = server_session._check_tasks_capability(required, client)
757+
assert result is False
758+
759+
760+
@pytest.mark.anyio
761+
async def test_check_tasks_capability_roots_list_false(server_session: ServerSession):
762+
"""Test _check_tasks_capability returns False when roots.list is False."""
763+
required = types.ClientTasksCapability(
764+
requests=types.ClientTasksRequestsCapability(roots=types.TaskRootsCapability(list=True))
765+
)
766+
client = types.ClientTasksCapability(
767+
requests=types.ClientTasksRequestsCapability(roots=types.TaskRootsCapability(list=False))
768+
)
769+
770+
result = server_session._check_tasks_capability(required, client)
771+
assert result is False
772+
773+
774+
@pytest.mark.anyio
775+
async def test_check_tasks_capability_roots_success(server_session: ServerSession):
776+
"""Test _check_tasks_capability returns True when roots capability matches."""
777+
required = types.ClientTasksCapability(
778+
requests=types.ClientTasksRequestsCapability(roots=types.TaskRootsCapability(list=True))
779+
)
780+
client = types.ClientTasksCapability(
781+
requests=types.ClientTasksRequestsCapability(roots=types.TaskRootsCapability(list=True))
782+
)
783+
784+
result = server_session._check_tasks_capability(required, client)
785+
assert result is True
786+
787+
788+
@pytest.mark.anyio
789+
async def test_check_tasks_capability_tasks_missing(server_session: ServerSession):
790+
"""Test _check_tasks_capability returns False when tasks capability missing."""
791+
required = types.ClientTasksCapability(
792+
requests=types.ClientTasksRequestsCapability(
793+
tasks=types.TasksOperationsCapability(get=True, list=True, result=True, delete=True)
794+
)
795+
)
796+
client = types.ClientTasksCapability(requests=types.ClientTasksRequestsCapability(tasks=None))
797+
798+
result = server_session._check_tasks_capability(required, client)
799+
assert result is False
800+
801+
802+
@pytest.mark.anyio
803+
async def test_check_tasks_capability_tasks_get_false(server_session: ServerSession):
804+
"""Test _check_tasks_capability returns False when tasks.get is False."""
805+
required = types.ClientTasksCapability(
806+
requests=types.ClientTasksRequestsCapability(
807+
tasks=types.TasksOperationsCapability(get=True, list=False, result=False, delete=False)
808+
)
809+
)
810+
client = types.ClientTasksCapability(
811+
requests=types.ClientTasksRequestsCapability(
812+
tasks=types.TasksOperationsCapability(get=False, list=True, result=True, delete=True)
813+
)
814+
)
815+
816+
result = server_session._check_tasks_capability(required, client)
817+
assert result is False
818+
819+
820+
@pytest.mark.anyio
821+
async def test_check_tasks_capability_tasks_list_false(server_session: ServerSession):
822+
"""Test _check_tasks_capability returns False when tasks.list is False."""
823+
required = types.ClientTasksCapability(
824+
requests=types.ClientTasksRequestsCapability(
825+
tasks=types.TasksOperationsCapability(get=False, list=True, result=False, delete=False)
826+
)
827+
)
828+
client = types.ClientTasksCapability(
829+
requests=types.ClientTasksRequestsCapability(
830+
tasks=types.TasksOperationsCapability(get=True, list=False, result=True, delete=True)
831+
)
832+
)
833+
834+
result = server_session._check_tasks_capability(required, client)
835+
assert result is False
836+
837+
838+
@pytest.mark.anyio
839+
async def test_check_tasks_capability_tasks_result_false(server_session: ServerSession):
840+
"""Test _check_tasks_capability returns False when tasks.result is False."""
841+
required = types.ClientTasksCapability(
842+
requests=types.ClientTasksRequestsCapability(
843+
tasks=types.TasksOperationsCapability(get=False, list=False, result=True, delete=False)
844+
)
845+
)
846+
client = types.ClientTasksCapability(
847+
requests=types.ClientTasksRequestsCapability(
848+
tasks=types.TasksOperationsCapability(get=True, list=True, result=False, delete=True)
849+
)
850+
)
851+
852+
result = server_session._check_tasks_capability(required, client)
853+
assert result is False
854+
855+
856+
@pytest.mark.anyio
857+
async def test_check_tasks_capability_tasks_delete_false(server_session: ServerSession):
858+
"""Test _check_tasks_capability returns False when tasks.delete is False."""
859+
required = types.ClientTasksCapability(
860+
requests=types.ClientTasksRequestsCapability(
861+
tasks=types.TasksOperationsCapability(get=False, list=False, result=False, delete=True)
862+
)
863+
)
864+
client = types.ClientTasksCapability(
865+
requests=types.ClientTasksRequestsCapability(
866+
tasks=types.TasksOperationsCapability(get=True, list=True, result=True, delete=False)
867+
)
868+
)
869+
870+
result = server_session._check_tasks_capability(required, client)
871+
assert result is False
872+
873+
874+
@pytest.mark.anyio
875+
async def test_check_tasks_capability_tasks_all_operations_true(server_session: ServerSession):
876+
"""Test _check_tasks_capability returns True when all required task operations match."""
877+
required = types.ClientTasksCapability(
878+
requests=types.ClientTasksRequestsCapability(
879+
tasks=types.TasksOperationsCapability(get=True, list=True, result=True, delete=True)
880+
)
881+
)
882+
client = types.ClientTasksCapability(
883+
requests=types.ClientTasksRequestsCapability(
884+
tasks=types.TasksOperationsCapability(get=True, list=True, result=True, delete=True)
885+
)
886+
)
887+
888+
result = server_session._check_tasks_capability(required, client)
889+
assert result is True
890+
891+
892+
@pytest.mark.anyio
893+
async def test_check_tasks_capability_all_capabilities_present(server_session: ServerSession):
894+
"""Test _check_tasks_capability returns True when all capabilities are satisfied."""
895+
required = types.ClientTasksCapability(
896+
requests=types.ClientTasksRequestsCapability(
897+
sampling=types.TaskSamplingCapability(createMessage=True),
898+
elicitation=types.TaskElicitationCapability(create=True),
899+
roots=types.TaskRootsCapability(list=True),
900+
tasks=types.TasksOperationsCapability(get=True, list=True, result=True, delete=True),
901+
)
902+
)
903+
client = types.ClientTasksCapability(
904+
requests=types.ClientTasksRequestsCapability(
905+
sampling=types.TaskSamplingCapability(createMessage=True),
906+
elicitation=types.TaskElicitationCapability(create=True),
907+
roots=types.TaskRootsCapability(list=True),
908+
tasks=types.TasksOperationsCapability(get=True, list=True, result=True, delete=True),
909+
)
910+
)
911+
912+
result = server_session._check_tasks_capability(required, client)
913+
assert result is True

0 commit comments

Comments
 (0)