Skip to content

Commit 64b35dd

Browse files
committed
test: add coverage for error handlers, constants, optionals, and models
1 parent 1d8f92e commit 64b35dd

File tree

4 files changed

+242
-0
lines changed

4 files changed

+242
-0
lines changed

tests/client/test_optionals.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
"""Tests for a2a.client.optionals module."""
2+
3+
import importlib
4+
import sys
5+
from unittest.mock import patch
6+
7+
import pytest
8+
9+
10+
def test_channel_import_failure():
11+
"""Test Channel behavior when grpc is not available."""
12+
with patch.dict('sys.modules', {'grpc': None, 'grpc.aio': None}):
13+
if 'a2a.client.optionals' in sys.modules:
14+
del sys.modules['a2a.client.optionals']
15+
16+
optionals = importlib.import_module('a2a.client.optionals')
17+
assert optionals.Channel is None

tests/server/test_models.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""Tests for a2a.server.models module."""
2+
3+
from unittest.mock import MagicMock, patch
4+
5+
import pytest
6+
7+
from a2a.types import Artifact, TaskStatus
8+
9+
10+
class TestPydanticType:
11+
"""Tests for PydanticType SQLAlchemy type decorator."""
12+
13+
def test_process_bind_param_with_pydantic_model(self):
14+
from a2a.server.models import PydanticType
15+
from a2a.types import TaskState
16+
17+
pydantic_type = PydanticType(TaskStatus)
18+
status = TaskStatus(state=TaskState.working)
19+
dialect = MagicMock()
20+
21+
result = pydantic_type.process_bind_param(status, dialect)
22+
assert result["state"] == "working"
23+
assert result["message"] is None
24+
# TaskStatus may have other optional fields
25+
26+
def test_process_bind_param_with_none(self):
27+
from a2a.server.models import PydanticType
28+
29+
pydantic_type = PydanticType(TaskStatus)
30+
dialect = MagicMock()
31+
32+
result = pydantic_type.process_bind_param(None, dialect)
33+
assert result is None
34+
35+
def test_process_result_value(self):
36+
from a2a.server.models import PydanticType
37+
38+
pydantic_type = PydanticType(TaskStatus)
39+
dialect = MagicMock()
40+
41+
result = pydantic_type.process_result_value({"state": "completed", "message": None}, dialect)
42+
assert isinstance(result, TaskStatus)
43+
assert result.state == "completed"
44+
45+
46+
class TestPydanticListType:
47+
"""Tests for PydanticListType SQLAlchemy type decorator."""
48+
49+
def test_process_bind_param_with_list(self):
50+
from a2a.server.models import PydanticListType
51+
from a2a.types import Artifact, TextPart
52+
53+
pydantic_list_type = PydanticListType(Artifact)
54+
artifacts = [
55+
Artifact(artifact_id="1", parts=[TextPart(type="text", text="Hello")]),
56+
Artifact(artifact_id="2", parts=[TextPart(type="text", text="World")])
57+
]
58+
dialect = MagicMock()
59+
60+
result = pydantic_list_type.process_bind_param(artifacts, dialect)
61+
assert len(result) == 2
62+
assert result[0]["artifactId"] == "1" # JSON mode uses camelCase
63+
assert result[1]["artifactId"] == "2"
64+
65+
def test_process_result_value_with_list(self):
66+
from a2a.server.models import PydanticListType
67+
from a2a.types import Artifact
68+
69+
pydantic_list_type = PydanticListType(Artifact)
70+
dialect = MagicMock()
71+
data = [
72+
{"artifact_id": "1", "parts": [{"type": "text", "text": "Hello"}]},
73+
{"artifact_id": "2", "parts": [{"type": "text", "text": "World"}]}
74+
]
75+
76+
result = pydantic_list_type.process_result_value(data, dialect)
77+
assert len(result) == 2
78+
assert all(isinstance(art, Artifact) for art in result)
79+
assert result[0].artifact_id == "1"
80+
assert result[1].artifact_id == "2"
81+
82+
83+
def test_create_task_model():
84+
"""Test dynamic task model creation."""
85+
from a2a.server.models import Base, create_task_model
86+
from sqlalchemy.orm import DeclarativeBase
87+
88+
# Create a fresh base to avoid table conflicts
89+
class TestBase(DeclarativeBase):
90+
pass
91+
92+
# Create with default table name
93+
DefaultTaskModel = create_task_model('test_tasks_1', TestBase)
94+
assert DefaultTaskModel.__tablename__ == 'test_tasks_1'
95+
assert DefaultTaskModel.__name__ == 'TaskModel_test_tasks_1'
96+
97+
# Create with custom table name
98+
CustomTaskModel = create_task_model('test_tasks_2', TestBase)
99+
assert CustomTaskModel.__tablename__ == 'test_tasks_2'
100+
assert CustomTaskModel.__name__ == 'TaskModel_test_tasks_2'
101+
102+
103+
def test_create_push_notification_config_model():
104+
"""Test dynamic push notification config model creation."""
105+
from a2a.server.models import create_push_notification_config_model
106+
from sqlalchemy.orm import DeclarativeBase
107+
108+
# Create a fresh base to avoid table conflicts
109+
class TestBase(DeclarativeBase):
110+
pass
111+
112+
# Create with default table name
113+
DefaultModel = create_push_notification_config_model('test_push_configs_1', TestBase)
114+
assert DefaultModel.__tablename__ == 'test_push_configs_1'
115+
116+
# Create with custom table name
117+
CustomModel = create_push_notification_config_model('test_push_configs_2', TestBase)
118+
assert CustomModel.__tablename__ == 'test_push_configs_2'
119+
assert 'test_push_configs_2' in CustomModel.__name__

tests/utils/test_constants.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Tests for a2a.utils.constants module."""
2+
3+
from a2a.utils import constants
4+
5+
6+
def test_agent_card_constants():
7+
"""Test that agent card constants have expected values."""
8+
assert constants.AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent-card.json'
9+
assert constants.PREV_AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent.json'
10+
assert constants.EXTENDED_AGENT_CARD_PATH == '/agent/authenticatedExtendedCard'
11+
12+
13+
def test_default_rpc_url():
14+
"""Test default RPC URL constant."""
15+
assert constants.DEFAULT_RPC_URL == '/'

tests/utils/test_error_handlers.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
"""Tests for a2a.utils.error_handlers module."""
2+
3+
import logging
4+
from unittest.mock import patch
5+
6+
import pytest
7+
8+
from a2a.types import (
9+
InternalError,
10+
InvalidRequestError,
11+
MethodNotFoundError,
12+
TaskNotFoundError,
13+
)
14+
from a2a.utils.error_handlers import (
15+
A2AErrorToHttpStatus,
16+
rest_error_handler,
17+
rest_stream_error_handler,
18+
)
19+
from a2a.utils.errors import ServerError
20+
21+
22+
class MockJSONResponse:
23+
def __init__(self, content, status_code):
24+
self.content = content
25+
self.status_code = status_code
26+
27+
28+
@pytest.mark.asyncio
29+
async def test_rest_error_handler_server_error():
30+
"""Test rest_error_handler with ServerError."""
31+
error = InvalidRequestError(message="Bad request")
32+
33+
@rest_error_handler
34+
async def failing_func():
35+
raise ServerError(error=error)
36+
37+
with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse):
38+
result = await failing_func()
39+
40+
assert isinstance(result, MockJSONResponse)
41+
assert result.status_code == 400
42+
assert result.content == {'message': 'Bad request'}
43+
44+
45+
@pytest.mark.asyncio
46+
async def test_rest_error_handler_unknown_exception():
47+
"""Test rest_error_handler with unknown exception."""
48+
@rest_error_handler
49+
async def failing_func():
50+
raise ValueError("Unexpected error")
51+
52+
with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse):
53+
result = await failing_func()
54+
55+
assert isinstance(result, MockJSONResponse)
56+
assert result.status_code == 500
57+
assert result.content == {'message': 'unknown exception'}
58+
59+
60+
@pytest.mark.asyncio
61+
async def test_rest_stream_error_handler_server_error():
62+
"""Test rest_stream_error_handler with ServerError."""
63+
error = InternalError(message="Internal server error")
64+
65+
@rest_stream_error_handler
66+
async def failing_stream():
67+
raise ServerError(error=error)
68+
69+
with pytest.raises(ServerError) as exc_info:
70+
await failing_stream()
71+
72+
assert exc_info.value.error == error
73+
74+
75+
@pytest.mark.asyncio
76+
async def test_rest_stream_error_handler_reraises_exception():
77+
"""Test rest_stream_error_handler reraises other exceptions."""
78+
@rest_stream_error_handler
79+
async def failing_stream():
80+
raise RuntimeError("Stream failed")
81+
82+
with pytest.raises(RuntimeError, match="Stream failed"):
83+
await failing_stream()
84+
85+
86+
def test_a2a_error_to_http_status_mapping():
87+
"""Test A2AErrorToHttpStatus mapping."""
88+
assert A2AErrorToHttpStatus[InvalidRequestError] == 400
89+
assert A2AErrorToHttpStatus[MethodNotFoundError] == 404
90+
assert A2AErrorToHttpStatus[TaskNotFoundError] == 404
91+
assert A2AErrorToHttpStatus[InternalError] == 500

0 commit comments

Comments
 (0)