From 57da9b3958a125017a5198f41fb5c4721719307b Mon Sep 17 00:00:00 2001 From: Junjie Bu Date: Mon, 19 May 2025 11:36:54 -0700 Subject: [PATCH 1/7] test: Adding 19 server/app integration tests --- tests/server/test_integration.py | 561 +++++++++++++++++++++++++++++++ 1 file changed, 561 insertions(+) create mode 100644 tests/server/test_integration.py diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py new file mode 100644 index 00000000..1ae56b3d --- /dev/null +++ b/tests/server/test_integration.py @@ -0,0 +1,561 @@ +import json +from typing import Any +import pytest +import asyncio +from unittest import mock +from starlette.testclient import TestClient +from starlette.responses import JSONResponse +from starlette.routing import Route + +from a2a.server.apps.starlette_app import A2AStarletteApplication +from a2a.types import ( + AgentCapabilities, + AgentCard, + Artifact, + DataPart, + GetTaskPushNotificationConfigSuccessResponse, + InternalError, + Part, + PushNotificationConfig, + TaskArtifactUpdateEvent, + TaskPushNotificationConfig, + TextPart, + UnsupportedOperationError, + InvalidRequestError, + JSONParseError, + Task, + TaskStatus, +) +from a2a.utils.errors import MethodNotImplementedError + +# === TEST SETUP === + +MINIMAL_AGENT_SKILL: dict[str, Any] = { + 'id': 'skill-123', + 'name': 'Recipe Finder', + 'description': 'Finds recipes', + 'tags': ['cooking'], +} + +MINIMAL_AGENT_AUTH: dict[str, Any] = {'schemes': ['Bearer']} + +AGENT_CAPS = AgentCapabilities( + pushNotifications=True, stateTransitionHistory=False, streaming=True + ) + +MINIMAL_AGENT_CARD: dict[str, Any] = { + 'authentication': MINIMAL_AGENT_AUTH, + 'capabilities': AGENT_CAPS, # AgentCapabilities is required but can be empty + 'defaultInputModes': ['text/plain'], + 'defaultOutputModes': ['application/json'], + 'description': 'Test Agent', + 'name': 'TestAgent', + 'skills': [MINIMAL_AGENT_SKILL], + 'url': 'http://example.com/agent', + 'version': '1.0', +} + +TEXT_PART_DATA: dict[str, Any] = {'type': 'text', 'text': 'Hello'} + +DATA_PART_DATA: dict[str, Any] = {'type': 'data', 'data': {'key': 'value'}} + +MINIMAL_MESSAGE_USER: dict[str, Any] = { + 'role': 'user', + 'parts': [TEXT_PART_DATA], + 'messageId': 'msg-123', + 'type': 'message', +} + +MINIMAL_TASK_STATUS: dict[str, Any] = {'state': 'submitted'} + +FULL_TASK_STATUS: dict[str, Any] = { + 'state': 'working', + 'message': MINIMAL_MESSAGE_USER, + 'timestamp': '2023-10-27T10:00:00Z', +} + +@pytest.fixture +def agent_card(): + return AgentCard(**MINIMAL_AGENT_CARD) + +@pytest.fixture +def handler(): + handler = mock.AsyncMock() + handler.on_message_send = mock.AsyncMock() + handler.on_cancel_task = mock.AsyncMock() + handler.on_get_task = mock.AsyncMock() + handler.set_push_notification = mock.AsyncMock() + handler.get_push_notification = mock.AsyncMock() + handler.on_message_send_stream = mock.Mock() + handler.on_resubscribe_to_task = mock.Mock() + return handler + +@pytest.fixture +def app(agent_card: AgentCard, handler: mock.AsyncMock): + return A2AStarletteApplication(agent_card, handler) + +@pytest.fixture +def client(app: A2AStarletteApplication): + """Create a test client with the app.""" + return TestClient(app.build()) + +# === BASIC FUNCTIONALITY TESTS === + +def test_agent_card_endpoint(client: TestClient, agent_card: AgentCard): + """Test the agent card endpoint returns expected data.""" + response = client.get("/.well-known/agent.json") + assert response.status_code == 200 + data = response.json() + assert data["name"] == agent_card.name + assert data["version"] == agent_card.version + assert "streaming" in data["capabilities"] + +def test_agent_card_custom_url(app: A2AStarletteApplication, agent_card: AgentCard): + """Test the agent card endpoint with a custom URL.""" + client = TestClient(app.build(agent_card_url="/my-agent")) + response = client.get("/my-agent") + assert response.status_code == 200 + data = response.json() + assert data["name"] == agent_card.name + +def test_rpc_endpoint_custom_url(app: A2AStarletteApplication, handler: mock.AsyncMock): + """Test the RPC endpoint with a custom URL.""" + # Provide a valid Task object as the return value + task_status = TaskStatus(**MINIMAL_TASK_STATUS) + task = Task(id="task1", contextId="ctx1", state="completed", status=task_status) + handler.on_get_task.return_value = task + client = TestClient(app.build(rpc_url="/api/rpc")) + response = client.post("/api/rpc", json={ + "jsonrpc": "2.0", + "id": "123", + "method": "tasks/get", + "params": {"id": "task1"} + }) + assert response.status_code == 200 + data = response.json() + assert data["result"]["id"] == "task1" + +def test_build_with_extra_routes(app: A2AStarletteApplication, agent_card: AgentCard): + """Test building the app with additional routes.""" + def custom_handler(request): + return JSONResponse({"message": "Hello"}) + + extra_route = Route("/hello", custom_handler, methods=["GET"]) + test_app = app.build(routes=[extra_route]) + client = TestClient(test_app) + + # Test the added route + response = client.get("/hello") + assert response.status_code == 200 + assert response.json() == {"message": "Hello"} + + # Ensure default routes still work + response = client.get("/.well-known/agent.json") + assert response.status_code == 200 + data = response.json() + assert data["name"] == agent_card.name + +# === REQUEST METHODS TESTS === + +def test_send_message(client: TestClient, handler: mock.AsyncMock): + """Test sending a message.""" + # Prepare mock response + task_status = TaskStatus(**MINIMAL_TASK_STATUS) + mock_task = Task(id="task1", contextId="session-xyz", state="completed", status=task_status) + handler.on_message_send.return_value = mock_task + + # Send request + response = client.post("/", json={ + "jsonrpc": "2.0", + "id": "123", + "method": "message/send", + "params": { + "message": { + "role": "agent", + "parts": [ + { + "type": "text", + "text": "Hello" + } + ], + "messageId": "111", + "type": "message", + "taskId": "task1", + "contextId": "session-xyz", + } + } + }) + + # Verify response + assert response.status_code == 200 + data = response.json() + assert "result" in data + assert data["result"]["id"] == "task1" + assert data["result"]["status"]["state"] == "submitted" + + # Verify handler was called + handler.on_message_send.assert_awaited_once() + +def test_cancel_task(client: TestClient, handler: mock.AsyncMock): + """Test cancelling a task.""" + # Setup mock response + task_status = TaskStatus(**MINIMAL_TASK_STATUS) + task_status.state = "cancelled" + task = Task(id="task1", contextId="ctx1", state="cancelled", status=task_status) + handler.on_cancel_task.return_value = task # JSONRPCResponse(root=task) + + # Send request + response = client.post("/", json={ + "jsonrpc": "2.0", + "id": "123", + "method": "tasks/cancel", + "params": {"id": "task1"} + }) + + # Verify response + assert response.status_code == 200 + data = response.json() + assert data["result"]["id"] == "task1" + assert data["result"]["status"]["state"] == "cancelled" + + # Verify handler was called + handler.on_cancel_task.assert_awaited_once() + +def test_get_task(client: TestClient, handler: mock.AsyncMock): + """Test getting a task.""" + # Setup mock response + task_status = TaskStatus(**MINIMAL_TASK_STATUS) + task = Task(id="task1", contextId="ctx1", state="completed", status=task_status) + handler.on_get_task.return_value = task # JSONRPCResponse(root=task) + + # Send request + response = client.post("/", json={ + "jsonrpc": "2.0", + "id": "123", + "method": "tasks/get", + "params": {"id": "task1"} + }) + + # Verify response + assert response.status_code == 200 + data = response.json() + assert data["result"]["id"] == "task1" + + # Verify handler was called + handler.on_get_task.assert_awaited_once() + +def test_set_push_notification_config(client: TestClient, handler: mock.AsyncMock): + """Test setting push notification configuration.""" + # Setup mock response + task_push_config = TaskPushNotificationConfig( + taskId='t2', + pushNotificationConfig=PushNotificationConfig( + url='https://example.com', token='secret-token' + ), + ) + handler.on_set_task_push_notification_config.return_value = task_push_config + + # Send request + response = client.post("/", json={ + "jsonrpc": "2.0", + "id": "123", + "method": "tasks/pushNotificationConfig/set", + "params": { + "taskId": "t2", + "pushNotificationConfig": { + "url": "https://example.com", + "token": "secret-token", + } + } + }) + + # Verify response + assert response.status_code == 200 + data = response.json() + assert data["result"]["pushNotificationConfig"]["token"] == "secret-token" + + # Verify handler was called + handler.on_set_task_push_notification_config.assert_awaited_once() + +def test_get_push_notification_config(client: TestClient, handler: mock.AsyncMock): + """Test getting push notification configuration.""" + # Setup mock response + task_push_config = TaskPushNotificationConfig( + taskId='task1', + pushNotificationConfig=PushNotificationConfig( + url='https://example.com', token='secret-token' + ), + ) + + # Wrap the response in GetTaskPushNotificationConfigSuccessResponse + mock_response = GetTaskPushNotificationConfigSuccessResponse( + id="123", # Match the request ID + jsonrpc="2.0", + result=task_push_config, + ) + + handler.on_get_task_push_notification_config.return_value = task_push_config + + # Send request + response = client.post("/", json={ + "jsonrpc": "2.0", + "id": "123", + "method": "tasks/pushNotificationConfig/get", + "params": {"id": "task1"} + }) + + # Verify response + assert response.status_code == 200 + data = response.json() + assert data["result"]["pushNotificationConfig"]["token"] == "secret-token" + + # Verify handler was called + handler.on_get_task_push_notification_config.assert_awaited_once() + +# === STREAMING TESTS === + +@pytest.mark.asyncio +async def test_message_send_stream(app: A2AStarletteApplication, handler: mock.AsyncMock): + """Test streaming message sending.""" + # Setup mock streaming response + async def stream_generator(): + for i in range(3): + text_part = TextPart(**TEXT_PART_DATA) + data_part = DataPart(**DATA_PART_DATA) + artifact = Artifact( + artifactId=f'artifact-{i}', + name='result_data', + parts=[Part(root=text_part), Part(root=data_part)], + ) + last = [False, False, True] + task_artifact_update_event_data: dict[str, Any] = { + 'artifact': artifact, + 'taskId': 'task_id', + 'contextId': 'session-xyz', + 'append': False, + 'lastChunk': last[i], + 'type': 'artifact-update', + } + event_data: dict[str, Any] = { + 'jsonrpc': '2.0', + 'id': 123, + 'result': task_artifact_update_event_data, + } + + yield TaskArtifactUpdateEvent.model_validate(task_artifact_update_event_data) + + handler.on_message_send_stream.return_value = stream_generator() + + client = None + try: + # Create client + client = TestClient(app.build(), raise_server_exceptions=False) + # Send request + with client.stream("POST", "/", json={ + "jsonrpc": "2.0", + "id": "123", + "method": "message/stream", + "params": { + "message": { + "role": "agent", + "parts": [ + { + "type": "text", + "text": "Hello" + } + ], + "messageId": "111", + "type": "message", + "taskId": "taskId", + "contextId": "session-xyz", + } + } + }) as response: + # Verify response is a stream + assert response.status_code == 200 + assert response.headers["content-type"].startswith("text/event-stream") + + # Read some content to verify streaming works + content = b"" + event_count = 0 + + for chunk in response.iter_bytes(): + content += chunk + if b"data" in chunk: # Naive check for SSE data lines + event_count +=1 + + # Check content has event data (e.g., part of the first event) + assert b'"artifactId":"artifact-0"' in content # Check for the actual JSON payload + assert b'"artifactId":"artifact-1"' in content # Check for the actual JSON payload + assert b'"artifactId":"artifact-2"' in content # Check for the actual JSON payload + assert event_count > 0 + finally: + # Ensure the client is closed + if client: + client.close() + # Allow event loop to process any pending callbacks + await asyncio.sleep(0.1) + +@pytest.mark.asyncio +async def test_task_resubscription(app: A2AStarletteApplication, handler: mock.AsyncMock): + """Test task resubscription streaming.""" + # Setup mock streaming response + async def stream_generator(): + for i in range(3): + text_part = TextPart(**TEXT_PART_DATA) + data_part = DataPart(**DATA_PART_DATA) + artifact = Artifact( + artifactId=f'artifact-{i}', + name='result_data', + parts=[Part(root=text_part), Part(root=data_part)], + ) + last = [False, False, True] + task_artifact_update_event_data: dict[str, Any] = { + 'artifact': artifact, + 'taskId': 'task_id', + 'contextId': 'session-xyz', + 'append': False, + 'lastChunk': last[i], + 'type': 'artifact-update', + } + yield TaskArtifactUpdateEvent.model_validate(task_artifact_update_event_data) + + handler.on_resubscribe_to_task.return_value = stream_generator() + + # Create client + client = TestClient(app.build(), raise_server_exceptions=False) + + try: + # Send request using client.stream() context manager + # Send request + with client.stream("POST", "/", json={ + "jsonrpc": "2.0", + "id": "123", # This ID is used in the success_event above + "method": "tasks/resubscribe", + "params": {"id": "task1"} + }) as response: + # Verify response is a stream + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + + # Read some content to verify streaming works + content = b"" + event_count = 0 + for chunk in response.iter_bytes(): + content += chunk + # A more robust check would be to parse each SSE event + if b"data:" in chunk: # Naive check for SSE data lines + event_count +=1 + if event_count >= 1 and len(content) > 20 : # Ensure we've processed at least one event + break + + # Check content has event data (e.g., part of the first event) + assert b'"artifactId":"artifact-0"' in content # Check for the actual JSON payload + assert b'"artifactId":"artifact-1"' in content # Check for the actual JSON payload + assert b'"artifactId":"artifact-2"' in content # Check for the actual JSON payload + assert event_count > 0 + finally: + # Ensure the client is closed + if client: + client.close() + # Allow event loop to process any pending callbacks + await asyncio.sleep(0.1) + +# === ERROR HANDLING TESTS === + +def test_invalid_json(client: TestClient): + """Test handling invalid JSON.""" + response = client.post("/", data="This is not JSON") + assert response.status_code == 200 # JSON-RPC errors still return 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == JSONParseError().code + +def test_invalid_request_structure(client: TestClient): + """Test handling an invalid request structure.""" + response = client.post("/", json={ + # Missing required fields + "id": "123" + }) + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == InvalidRequestError().code + +def test_method_not_implemented(client: TestClient, handler: mock.AsyncMock): + """Test handling MethodNotImplementedError.""" + handler.on_get_task.side_effect = MethodNotImplementedError() + + response = client.post("/", json={ + "jsonrpc": "2.0", + "id": "123", + "method": "tasks/get", + "params": {"id": "task1"} + }) + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == UnsupportedOperationError().code + +def test_unknown_method(client: TestClient): + """Test handling unknown method.""" + response = client.post("/", json={ + "jsonrpc": "2.0", + "id": "123", + "method": "unknown/method", + "params": {} + }) + assert response.status_code == 200 + data = response.json() + assert "error" in data + # This should produce an UnsupportedOperationError error code + assert data["error"]["code"] == InvalidRequestError().code + +def test_validation_error(client: TestClient): + """Test handling validation error.""" + # Missing required fields in the message + response = client.post("/", json={ + "jsonrpc": "2.0", + "id": "123", + "method": "messages/send", + "params": { + "message": { + # Missing required fields + "text": "Hello" + } + } + }) + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == InvalidRequestError().code + +def test_unhandled_exception(client: TestClient, handler: mock.AsyncMock): + """Test handling unhandled exception.""" + handler.on_get_task.side_effect = Exception("Unexpected error") + + response = client.post("/", json={ + "jsonrpc": "2.0", + "id": "123", + "method": "tasks/get", + "params": {"id": "task1"} + }) + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == InternalError().code + assert "Unexpected error" in data["error"]["message"] + +def test_get_method_to_rpc_endpoint(client: TestClient): + """Test sending GET request to RPC endpoint.""" + response = client.get("/") + # Should return 405 Method Not Allowed + assert response.status_code == 405 + +def test_non_dict_json(client: TestClient): + """Test handling JSON that's not a dict.""" + response = client.post("/", json=["not", "a", "dict"]) + assert response.status_code == 200 + data = response.json() + assert "error" in data + assert data["error"]["code"] == InvalidRequestError().code \ No newline at end of file From a658ff897c9b535f9a5561197320d31b1b95ea15 Mon Sep 17 00:00:00 2001 From: Junjie Bu Date: Mon, 19 May 2025 12:01:45 -0700 Subject: [PATCH 2/7] test: small change to pyproject.toml update pytest asyciomode to strict --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 419a677a..eda65828 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ testpaths = ["tests"] python_files = "test_*.py" python_functions = "test_*" addopts = "--cov=src --cov-config=.coveragerc --cov-report term --cov-report xml:coverage.xml --cov-branch" +asyncio_mode = "strict" [build-system] requires = ["hatchling", "uv-dynamic-versioning"] From 630171f31da16c4ef8c1a3f118a3b8cf1379c89d Mon Sep 17 00:00:00 2001 From: Junjie Bu Date: Mon, 19 May 2025 13:46:39 -0700 Subject: [PATCH 3/7] test: updated .coveragerc --- .coveragerc | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/.coveragerc b/.coveragerc index f3d9d841..461f9bbe 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,8 +1,18 @@ [run] branch = True +omit = + */tests/* + */site-packages/* + */__init__.py + */noxfile.py* [report] -exclude_also = - pass +exclude_lines = + pragma: no cover import + def __repr__ + raise NotImplementedError + if TYPE_CHECKING @abstractmethod + pass + raise ImportError \ No newline at end of file From dc34346e5783cbb859cd9079d43dbe40aea9e1cc Mon Sep 17 00:00:00 2001 From: Junjie Bu Date: Mon, 19 May 2025 13:57:21 -0700 Subject: [PATCH 4/7] test: updated test_integration.py based on type->kind rename in #34 --- tests/server/test_integration.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index 1ae56b3d..ed485d85 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -55,15 +55,15 @@ 'version': '1.0', } -TEXT_PART_DATA: dict[str, Any] = {'type': 'text', 'text': 'Hello'} +TEXT_PART_DATA: dict[str, Any] = {'kind': 'text', 'text': 'Hello'} -DATA_PART_DATA: dict[str, Any] = {'type': 'data', 'data': {'key': 'value'}} +DATA_PART_DATA: dict[str, Any] = {'kind': 'data', 'data': {'key': 'value'}} MINIMAL_MESSAGE_USER: dict[str, Any] = { 'role': 'user', 'parts': [TEXT_PART_DATA], 'messageId': 'msg-123', - 'type': 'message', + 'kind': 'message', } MINIMAL_TASK_STATUS: dict[str, Any] = {'state': 'submitted'} @@ -174,12 +174,12 @@ def test_send_message(client: TestClient, handler: mock.AsyncMock): "role": "agent", "parts": [ { - "type": "text", + "kind": "text", "text": "Hello" } ], "messageId": "111", - "type": "message", + "kind": "message", "taskId": "task1", "contextId": "session-xyz", } @@ -334,7 +334,7 @@ async def stream_generator(): 'contextId': 'session-xyz', 'append': False, 'lastChunk': last[i], - 'type': 'artifact-update', + 'kind': 'artifact-update', } event_data: dict[str, Any] = { 'jsonrpc': '2.0', @@ -360,12 +360,12 @@ async def stream_generator(): "role": "agent", "parts": [ { - "type": "text", + "kind": "text", "text": "Hello" } ], "messageId": "111", - "type": "message", + "kind": "message", "taskId": "taskId", "contextId": "session-xyz", } @@ -416,7 +416,7 @@ async def stream_generator(): 'contextId': 'session-xyz', 'append': False, 'lastChunk': last[i], - 'type': 'artifact-update', + 'kind': 'artifact-update', } yield TaskArtifactUpdateEvent.model_validate(task_artifact_update_event_data) From 471c98e440fc2d530a9fc87ddc2b622501c8ae25 Mon Sep 17 00:00:00 2001 From: Junjie Bu Date: Mon, 19 May 2025 14:48:08 -0700 Subject: [PATCH 5/7] test: updated test_integration.py with nox format --- tests/server/test_integration.py | 578 ++++++++++++++++++------------- 1 file changed, 345 insertions(+), 233 deletions(-) diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index ed485d85..dc28e2b7 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -1,11 +1,13 @@ -import json -from typing import Any -import pytest import asyncio + +from typing import Any from unittest import mock -from starlette.testclient import TestClient + +import pytest + from starlette.responses import JSONResponse from starlette.routing import Route +from starlette.testclient import TestClient from a2a.server.apps.starlette_app import A2AStarletteApplication from a2a.types import ( @@ -15,19 +17,20 @@ DataPart, GetTaskPushNotificationConfigSuccessResponse, InternalError, + InvalidRequestError, + JSONParseError, Part, PushNotificationConfig, + Task, TaskArtifactUpdateEvent, TaskPushNotificationConfig, + TaskStatus, TextPart, UnsupportedOperationError, - InvalidRequestError, - JSONParseError, - Task, - TaskStatus, ) from a2a.utils.errors import MethodNotImplementedError + # === TEST SETUP === MINIMAL_AGENT_SKILL: dict[str, Any] = { @@ -40,8 +43,8 @@ MINIMAL_AGENT_AUTH: dict[str, Any] = {'schemes': ['Bearer']} AGENT_CAPS = AgentCapabilities( - pushNotifications=True, stateTransitionHistory=False, streaming=True - ) + pushNotifications=True, stateTransitionHistory=False, streaming=True +) MINIMAL_AGENT_CARD: dict[str, Any] = { 'authentication': MINIMAL_AGENT_AUTH, @@ -74,10 +77,12 @@ 'timestamp': '2023-10-27T10:00:00Z', } + @pytest.fixture def agent_card(): return AgentCard(**MINIMAL_AGENT_CARD) + @pytest.fixture def handler(): handler = mock.AsyncMock() @@ -90,161 +95,200 @@ def handler(): handler.on_resubscribe_to_task = mock.Mock() return handler + @pytest.fixture def app(agent_card: AgentCard, handler: mock.AsyncMock): return A2AStarletteApplication(agent_card, handler) + @pytest.fixture def client(app: A2AStarletteApplication): """Create a test client with the app.""" return TestClient(app.build()) + # === BASIC FUNCTIONALITY TESTS === + def test_agent_card_endpoint(client: TestClient, agent_card: AgentCard): """Test the agent card endpoint returns expected data.""" - response = client.get("/.well-known/agent.json") + response = client.get('/.well-known/agent.json') assert response.status_code == 200 data = response.json() - assert data["name"] == agent_card.name - assert data["version"] == agent_card.version - assert "streaming" in data["capabilities"] + assert data['name'] == agent_card.name + assert data['version'] == agent_card.version + assert 'streaming' in data['capabilities'] + -def test_agent_card_custom_url(app: A2AStarletteApplication, agent_card: AgentCard): +def test_agent_card_custom_url( + app: A2AStarletteApplication, agent_card: AgentCard +): """Test the agent card endpoint with a custom URL.""" - client = TestClient(app.build(agent_card_url="/my-agent")) - response = client.get("/my-agent") + client = TestClient(app.build(agent_card_url='/my-agent')) + response = client.get('/my-agent') assert response.status_code == 200 data = response.json() - assert data["name"] == agent_card.name + assert data['name'] == agent_card.name + -def test_rpc_endpoint_custom_url(app: A2AStarletteApplication, handler: mock.AsyncMock): +def test_rpc_endpoint_custom_url( + app: A2AStarletteApplication, handler: mock.AsyncMock +): """Test the RPC endpoint with a custom URL.""" # Provide a valid Task object as the return value task_status = TaskStatus(**MINIMAL_TASK_STATUS) - task = Task(id="task1", contextId="ctx1", state="completed", status=task_status) + task = Task( + id='task1', contextId='ctx1', state='completed', status=task_status + ) handler.on_get_task.return_value = task - client = TestClient(app.build(rpc_url="/api/rpc")) - response = client.post("/api/rpc", json={ - "jsonrpc": "2.0", - "id": "123", - "method": "tasks/get", - "params": {"id": "task1"} - }) + client = TestClient(app.build(rpc_url='/api/rpc')) + response = client.post( + '/api/rpc', + json={ + 'jsonrpc': '2.0', + 'id': '123', + 'method': 'tasks/get', + 'params': {'id': 'task1'}, + }, + ) assert response.status_code == 200 data = response.json() - assert data["result"]["id"] == "task1" + assert data['result']['id'] == 'task1' -def test_build_with_extra_routes(app: A2AStarletteApplication, agent_card: AgentCard): + +def test_build_with_extra_routes( + app: A2AStarletteApplication, agent_card: AgentCard +): """Test building the app with additional routes.""" + def custom_handler(request): - return JSONResponse({"message": "Hello"}) - - extra_route = Route("/hello", custom_handler, methods=["GET"]) + return JSONResponse({'message': 'Hello'}) + + extra_route = Route('/hello', custom_handler, methods=['GET']) test_app = app.build(routes=[extra_route]) client = TestClient(test_app) - + # Test the added route - response = client.get("/hello") + response = client.get('/hello') assert response.status_code == 200 - assert response.json() == {"message": "Hello"} - + assert response.json() == {'message': 'Hello'} + # Ensure default routes still work - response = client.get("/.well-known/agent.json") + response = client.get('/.well-known/agent.json') assert response.status_code == 200 data = response.json() - assert data["name"] == agent_card.name + assert data['name'] == agent_card.name + # === REQUEST METHODS TESTS === + def test_send_message(client: TestClient, handler: mock.AsyncMock): """Test sending a message.""" # Prepare mock response task_status = TaskStatus(**MINIMAL_TASK_STATUS) - mock_task = Task(id="task1", contextId="session-xyz", state="completed", status=task_status) + mock_task = Task( + id='task1', + contextId='session-xyz', + state='completed', + status=task_status, + ) handler.on_message_send.return_value = mock_task # Send request - response = client.post("/", json={ - "jsonrpc": "2.0", - "id": "123", - "method": "message/send", - "params": { - "message": { - "role": "agent", - "parts": [ - { - "kind": "text", - "text": "Hello" - } - ], - "messageId": "111", - "kind": "message", - "taskId": "task1", - "contextId": "session-xyz", - } - } - }) - + response = client.post( + '/', + json={ + 'jsonrpc': '2.0', + 'id': '123', + 'method': 'message/send', + 'params': { + 'message': { + 'role': 'agent', + 'parts': [{'kind': 'text', 'text': 'Hello'}], + 'messageId': '111', + 'kind': 'message', + 'taskId': 'task1', + 'contextId': 'session-xyz', + } + }, + }, + ) + # Verify response assert response.status_code == 200 data = response.json() - assert "result" in data - assert data["result"]["id"] == "task1" - assert data["result"]["status"]["state"] == "submitted" + assert 'result' in data + assert data['result']['id'] == 'task1' + assert data['result']['status']['state'] == 'submitted' # Verify handler was called handler.on_message_send.assert_awaited_once() + def test_cancel_task(client: TestClient, handler: mock.AsyncMock): """Test cancelling a task.""" # Setup mock response task_status = TaskStatus(**MINIMAL_TASK_STATUS) - task_status.state = "cancelled" - task = Task(id="task1", contextId="ctx1", state="cancelled", status=task_status) - handler.on_cancel_task.return_value = task # JSONRPCResponse(root=task) - + task_status.state = 'cancelled' + task = Task( + id='task1', contextId='ctx1', state='cancelled', status=task_status + ) + handler.on_cancel_task.return_value = task # JSONRPCResponse(root=task) + # Send request - response = client.post("/", json={ - "jsonrpc": "2.0", - "id": "123", - "method": "tasks/cancel", - "params": {"id": "task1"} - }) - + response = client.post( + '/', + json={ + 'jsonrpc': '2.0', + 'id': '123', + 'method': 'tasks/cancel', + 'params': {'id': 'task1'}, + }, + ) + # Verify response assert response.status_code == 200 data = response.json() - assert data["result"]["id"] == "task1" - assert data["result"]["status"]["state"] == "cancelled" - + assert data['result']['id'] == 'task1' + assert data['result']['status']['state'] == 'cancelled' + # Verify handler was called handler.on_cancel_task.assert_awaited_once() + def test_get_task(client: TestClient, handler: mock.AsyncMock): """Test getting a task.""" # Setup mock response task_status = TaskStatus(**MINIMAL_TASK_STATUS) - task = Task(id="task1", contextId="ctx1", state="completed", status=task_status) - handler.on_get_task.return_value = task # JSONRPCResponse(root=task) - + task = Task( + id='task1', contextId='ctx1', state='completed', status=task_status + ) + handler.on_get_task.return_value = task # JSONRPCResponse(root=task) + # Send request - response = client.post("/", json={ - "jsonrpc": "2.0", - "id": "123", - "method": "tasks/get", - "params": {"id": "task1"} - }) - + response = client.post( + '/', + json={ + 'jsonrpc': '2.0', + 'id': '123', + 'method': 'tasks/get', + 'params': {'id': 'task1'}, + }, + ) + # Verify response assert response.status_code == 200 data = response.json() - assert data["result"]["id"] == "task1" - + assert data['result']['id'] == 'task1' + # Verify handler was called handler.on_get_task.assert_awaited_once() -def test_set_push_notification_config(client: TestClient, handler: mock.AsyncMock): + +def test_set_push_notification_config( + client: TestClient, handler: mock.AsyncMock +): """Test setting push notification configuration.""" # Setup mock response task_push_config = TaskPushNotificationConfig( @@ -252,32 +296,38 @@ def test_set_push_notification_config(client: TestClient, handler: mock.AsyncMoc pushNotificationConfig=PushNotificationConfig( url='https://example.com', token='secret-token' ), - ) - handler.on_set_task_push_notification_config.return_value = task_push_config - + ) + handler.on_set_task_push_notification_config.return_value = task_push_config + # Send request - response = client.post("/", json={ - "jsonrpc": "2.0", - "id": "123", - "method": "tasks/pushNotificationConfig/set", - "params": { - "taskId": "t2", - "pushNotificationConfig": { - "url": "https://example.com", - "token": "secret-token", - } - } - }) - + response = client.post( + '/', + json={ + 'jsonrpc': '2.0', + 'id': '123', + 'method': 'tasks/pushNotificationConfig/set', + 'params': { + 'taskId': 't2', + 'pushNotificationConfig': { + 'url': 'https://example.com', + 'token': 'secret-token', + }, + }, + }, + ) + # Verify response assert response.status_code == 200 data = response.json() - assert data["result"]["pushNotificationConfig"]["token"] == "secret-token" - + assert data['result']['pushNotificationConfig']['token'] == 'secret-token' + # Verify handler was called handler.on_set_task_push_notification_config.assert_awaited_once() -def test_get_push_notification_config(client: TestClient, handler: mock.AsyncMock): + +def test_get_push_notification_config( + client: TestClient, handler: mock.AsyncMock +): """Test getting push notification configuration.""" # Setup mock response task_push_config = TaskPushNotificationConfig( @@ -285,38 +335,46 @@ def test_get_push_notification_config(client: TestClient, handler: mock.AsyncMoc pushNotificationConfig=PushNotificationConfig( url='https://example.com', token='secret-token' ), - ) + ) # Wrap the response in GetTaskPushNotificationConfigSuccessResponse mock_response = GetTaskPushNotificationConfigSuccessResponse( - id="123", # Match the request ID - jsonrpc="2.0", + id='123', # Match the request ID + jsonrpc='2.0', result=task_push_config, ) - handler.on_get_task_push_notification_config.return_value = task_push_config - + handler.on_get_task_push_notification_config.return_value = task_push_config + # Send request - response = client.post("/", json={ - "jsonrpc": "2.0", - "id": "123", - "method": "tasks/pushNotificationConfig/get", - "params": {"id": "task1"} - }) - + response = client.post( + '/', + json={ + 'jsonrpc': '2.0', + 'id': '123', + 'method': 'tasks/pushNotificationConfig/get', + 'params': {'id': 'task1'}, + }, + ) + # Verify response assert response.status_code == 200 data = response.json() - assert data["result"]["pushNotificationConfig"]["token"] == "secret-token" - + assert data['result']['pushNotificationConfig']['token'] == 'secret-token' + # Verify handler was called handler.on_get_task_push_notification_config.assert_awaited_once() + # === STREAMING TESTS === + @pytest.mark.asyncio -async def test_message_send_stream(app: A2AStarletteApplication, handler: mock.AsyncMock): +async def test_message_send_stream( + app: A2AStarletteApplication, handler: mock.AsyncMock +): """Test streaming message sending.""" + # Setup mock streaming response async def stream_generator(): for i in range(3): @@ -342,52 +400,61 @@ async def stream_generator(): 'result': task_artifact_update_event_data, } - yield TaskArtifactUpdateEvent.model_validate(task_artifact_update_event_data) - + yield TaskArtifactUpdateEvent.model_validate( + task_artifact_update_event_data + ) + handler.on_message_send_stream.return_value = stream_generator() - - client = None - try: + + client = None + try: # Create client client = TestClient(app.build(), raise_server_exceptions=False) # Send request - with client.stream("POST", "/", json={ - "jsonrpc": "2.0", - "id": "123", - "method": "message/stream", - "params": { - "message": { - "role": "agent", - "parts": [ - { - "kind": "text", - "text": "Hello" - } - ], - "messageId": "111", - "kind": "message", - "taskId": "taskId", - "contextId": "session-xyz", - } - } - }) as response: + with client.stream( + 'POST', + '/', + json={ + 'jsonrpc': '2.0', + 'id': '123', + 'method': 'message/stream', + 'params': { + 'message': { + 'role': 'agent', + 'parts': [{'kind': 'text', 'text': 'Hello'}], + 'messageId': '111', + 'kind': 'message', + 'taskId': 'taskId', + 'contextId': 'session-xyz', + } + }, + }, + ) as response: # Verify response is a stream assert response.status_code == 200 - assert response.headers["content-type"].startswith("text/event-stream") - + assert response.headers['content-type'].startswith( + 'text/event-stream' + ) + # Read some content to verify streaming works - content = b"" + content = b'' event_count = 0 for chunk in response.iter_bytes(): content += chunk - if b"data" in chunk: # Naive check for SSE data lines - event_count +=1 - + if b'data' in chunk: # Naive check for SSE data lines + event_count += 1 + # Check content has event data (e.g., part of the first event) - assert b'"artifactId":"artifact-0"' in content # Check for the actual JSON payload - assert b'"artifactId":"artifact-1"' in content # Check for the actual JSON payload - assert b'"artifactId":"artifact-2"' in content # Check for the actual JSON payload + assert ( + b'"artifactId":"artifact-0"' in content + ) # Check for the actual JSON payload + assert ( + b'"artifactId":"artifact-1"' in content + ) # Check for the actual JSON payload + assert ( + b'"artifactId":"artifact-2"' in content + ) # Check for the actual JSON payload assert event_count > 0 finally: # Ensure the client is closed @@ -396,9 +463,13 @@ async def stream_generator(): # Allow event loop to process any pending callbacks await asyncio.sleep(0.1) + @pytest.mark.asyncio -async def test_task_resubscription(app: A2AStarletteApplication, handler: mock.AsyncMock): +async def test_task_resubscription( + app: A2AStarletteApplication, handler: mock.AsyncMock +): """Test task resubscription streaming.""" + # Setup mock streaming response async def stream_generator(): for i in range(3): @@ -418,41 +489,58 @@ async def stream_generator(): 'lastChunk': last[i], 'kind': 'artifact-update', } - yield TaskArtifactUpdateEvent.model_validate(task_artifact_update_event_data) + yield TaskArtifactUpdateEvent.model_validate( + task_artifact_update_event_data + ) handler.on_resubscribe_to_task.return_value = stream_generator() - + # Create client client = TestClient(app.build(), raise_server_exceptions=False) - + try: # Send request using client.stream() context manager # Send request - with client.stream("POST", "/", json={ - "jsonrpc": "2.0", - "id": "123", # This ID is used in the success_event above - "method": "tasks/resubscribe", - "params": {"id": "task1"} - }) as response: + with client.stream( + 'POST', + '/', + json={ + 'jsonrpc': '2.0', + 'id': '123', # This ID is used in the success_event above + 'method': 'tasks/resubscribe', + 'params': {'id': 'task1'}, + }, + ) as response: # Verify response is a stream assert response.status_code == 200 - assert response.headers["content-type"] == "text/event-stream; charset=utf-8" - + assert ( + response.headers['content-type'] + == 'text/event-stream; charset=utf-8' + ) + # Read some content to verify streaming works - content = b"" + content = b'' event_count = 0 - for chunk in response.iter_bytes(): + for chunk in response.iter_bytes(): content += chunk # A more robust check would be to parse each SSE event - if b"data:" in chunk: # Naive check for SSE data lines - event_count +=1 - if event_count >= 1 and len(content) > 20 : # Ensure we've processed at least one event + if b'data:' in chunk: # Naive check for SSE data lines + event_count += 1 + if ( + event_count >= 1 and len(content) > 20 + ): # Ensure we've processed at least one event break - + # Check content has event data (e.g., part of the first event) - assert b'"artifactId":"artifact-0"' in content # Check for the actual JSON payload - assert b'"artifactId":"artifact-1"' in content # Check for the actual JSON payload - assert b'"artifactId":"artifact-2"' in content # Check for the actual JSON payload + assert ( + b'"artifactId":"artifact-0"' in content + ) # Check for the actual JSON payload + assert ( + b'"artifactId":"artifact-1"' in content + ) # Check for the actual JSON payload + assert ( + b'"artifactId":"artifact-2"' in content + ) # Check for the actual JSON payload assert event_count > 0 finally: # Ensure the client is closed @@ -461,101 +549,125 @@ async def stream_generator(): # Allow event loop to process any pending callbacks await asyncio.sleep(0.1) + # === ERROR HANDLING TESTS === + def test_invalid_json(client: TestClient): """Test handling invalid JSON.""" - response = client.post("/", data="This is not JSON") + response = client.post('/', data='This is not JSON') assert response.status_code == 200 # JSON-RPC errors still return 200 data = response.json() - assert "error" in data - assert data["error"]["code"] == JSONParseError().code + assert 'error' in data + assert data['error']['code'] == JSONParseError().code + def test_invalid_request_structure(client: TestClient): """Test handling an invalid request structure.""" - response = client.post("/", json={ - # Missing required fields - "id": "123" - }) + response = client.post( + '/', + json={ + # Missing required fields + 'id': '123' + }, + ) assert response.status_code == 200 data = response.json() - assert "error" in data - assert data["error"]["code"] == InvalidRequestError().code + assert 'error' in data + assert data['error']['code'] == InvalidRequestError().code + def test_method_not_implemented(client: TestClient, handler: mock.AsyncMock): """Test handling MethodNotImplementedError.""" handler.on_get_task.side_effect = MethodNotImplementedError() - - response = client.post("/", json={ - "jsonrpc": "2.0", - "id": "123", - "method": "tasks/get", - "params": {"id": "task1"} - }) + + response = client.post( + '/', + json={ + 'jsonrpc': '2.0', + 'id': '123', + 'method': 'tasks/get', + 'params': {'id': 'task1'}, + }, + ) assert response.status_code == 200 data = response.json() - assert "error" in data - assert data["error"]["code"] == UnsupportedOperationError().code + assert 'error' in data + assert data['error']['code'] == UnsupportedOperationError().code + def test_unknown_method(client: TestClient): """Test handling unknown method.""" - response = client.post("/", json={ - "jsonrpc": "2.0", - "id": "123", - "method": "unknown/method", - "params": {} - }) + response = client.post( + '/', + json={ + 'jsonrpc': '2.0', + 'id': '123', + 'method': 'unknown/method', + 'params': {}, + }, + ) assert response.status_code == 200 data = response.json() - assert "error" in data + assert 'error' in data # This should produce an UnsupportedOperationError error code - assert data["error"]["code"] == InvalidRequestError().code + assert data['error']['code'] == InvalidRequestError().code + def test_validation_error(client: TestClient): """Test handling validation error.""" # Missing required fields in the message - response = client.post("/", json={ - "jsonrpc": "2.0", - "id": "123", - "method": "messages/send", - "params": { - "message": { - # Missing required fields - "text": "Hello" - } - } - }) + response = client.post( + '/', + json={ + 'jsonrpc': '2.0', + 'id': '123', + 'method': 'messages/send', + 'params': { + 'message': { + # Missing required fields + 'text': 'Hello' + } + }, + }, + ) assert response.status_code == 200 data = response.json() - assert "error" in data - assert data["error"]["code"] == InvalidRequestError().code + assert 'error' in data + assert data['error']['code'] == InvalidRequestError().code + def test_unhandled_exception(client: TestClient, handler: mock.AsyncMock): """Test handling unhandled exception.""" - handler.on_get_task.side_effect = Exception("Unexpected error") - - response = client.post("/", json={ - "jsonrpc": "2.0", - "id": "123", - "method": "tasks/get", - "params": {"id": "task1"} - }) + handler.on_get_task.side_effect = Exception('Unexpected error') + + response = client.post( + '/', + json={ + 'jsonrpc': '2.0', + 'id': '123', + 'method': 'tasks/get', + 'params': {'id': 'task1'}, + }, + ) assert response.status_code == 200 data = response.json() - assert "error" in data - assert data["error"]["code"] == InternalError().code - assert "Unexpected error" in data["error"]["message"] + assert 'error' in data + assert data['error']['code'] == InternalError().code + assert 'Unexpected error' in data['error']['message'] + def test_get_method_to_rpc_endpoint(client: TestClient): """Test sending GET request to RPC endpoint.""" - response = client.get("/") + response = client.get('/') # Should return 405 Method Not Allowed assert response.status_code == 405 + def test_non_dict_json(client: TestClient): """Test handling JSON that's not a dict.""" - response = client.post("/", json=["not", "a", "dict"]) + response = client.post('/', json=['not', 'a', 'dict']) assert response.status_code == 200 data = response.json() - assert "error" in data - assert data["error"]["code"] == InvalidRequestError().code \ No newline at end of file + assert 'error' in data + assert data['error']['code'] == InvalidRequestError().code From 4061b842eed99d47e2743615480d0e5c83001a76 Mon Sep 17 00:00:00 2001 From: Junjie Bu Date: Mon, 19 May 2025 14:58:02 -0700 Subject: [PATCH 6/7] test: remove unused local variable from test_integration.py --- tests/server/test_integration.py | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index dc28e2b7..6c03a16e 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -15,7 +15,6 @@ AgentCard, Artifact, DataPart, - GetTaskPushNotificationConfigSuccessResponse, InternalError, InvalidRequestError, JSONParseError, @@ -337,13 +336,6 @@ def test_get_push_notification_config( ), ) - # Wrap the response in GetTaskPushNotificationConfigSuccessResponse - mock_response = GetTaskPushNotificationConfigSuccessResponse( - id='123', # Match the request ID - jsonrpc='2.0', - result=task_push_config, - ) - handler.on_get_task_push_notification_config.return_value = task_push_config # Send request @@ -394,11 +386,6 @@ async def stream_generator(): 'lastChunk': last[i], 'kind': 'artifact-update', } - event_data: dict[str, Any] = { - 'jsonrpc': '2.0', - 'id': 123, - 'result': task_artifact_update_event_data, - } yield TaskArtifactUpdateEvent.model_validate( task_artifact_update_event_data From bf63d9eac4366a38a6fdebe6853bae93d8e05b36 Mon Sep 17 00:00:00 2001 From: Junjie Bu Date: Mon, 19 May 2025 15:15:43 -0700 Subject: [PATCH 7/7] test: fixed lint warning for incompatible types in assignments and untyped function --- tests/server/test_integration.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index 6c03a16e..79814577 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -23,6 +23,7 @@ Task, TaskArtifactUpdateEvent, TaskPushNotificationConfig, + TaskState, TaskStatus, TextPart, UnsupportedOperationError, @@ -229,11 +230,11 @@ def test_cancel_task(client: TestClient, handler: mock.AsyncMock): """Test cancelling a task.""" # Setup mock response task_status = TaskStatus(**MINIMAL_TASK_STATUS) - task_status.state = 'cancelled' + task_status.state = TaskState.canceled # 'cancelled' # task = Task( id='task1', contextId='ctx1', state='cancelled', status=task_status ) - handler.on_cancel_task.return_value = task # JSONRPCResponse(root=task) + handler.on_cancel_task.return_value = task # Send request response = client.post( @@ -250,7 +251,7 @@ def test_cancel_task(client: TestClient, handler: mock.AsyncMock): assert response.status_code == 200 data = response.json() assert data['result']['id'] == 'task1' - assert data['result']['status']['state'] == 'cancelled' + assert data['result']['status']['state'] == 'canceled' # Verify handler was called handler.on_cancel_task.assert_awaited_once() @@ -364,7 +365,7 @@ def test_get_push_notification_config( @pytest.mark.asyncio async def test_message_send_stream( app: A2AStarletteApplication, handler: mock.AsyncMock -): +) -> None: """Test streaming message sending.""" # Setup mock streaming response @@ -454,7 +455,7 @@ async def stream_generator(): @pytest.mark.asyncio async def test_task_resubscription( app: A2AStarletteApplication, handler: mock.AsyncMock -): +) -> None: """Test task resubscription streaming.""" # Setup mock streaming response