Skip to content

Commit 1adb6e2

Browse files
committed
test: add coverage for error handlers, constants, optionals, and models
1 parent 1d8f92e commit 1adb6e2

File tree

4 files changed

+227
-0
lines changed

4 files changed

+227
-0
lines changed

tests/client/test_optionals.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
"""Tests for a2a.client.optionals module."""
2+
3+
import importlib
4+
import sys
5+
6+
from unittest.mock import patch
7+
8+
9+
def test_channel_import_failure():
10+
"""Test Channel behavior when grpc is not available."""
11+
with patch.dict('sys.modules', {'grpc': None, 'grpc.aio': None}):
12+
if 'a2a.client.optionals' in sys.modules:
13+
del sys.modules['a2a.client.optionals']
14+
15+
optionals = importlib.import_module('a2a.client.optionals')
16+
assert optionals.Channel is None

tests/server/test_models.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
"""Tests for a2a.server.models module."""
2+
3+
from unittest.mock import MagicMock
4+
5+
from sqlalchemy.orm import DeclarativeBase
6+
7+
from a2a.server.models import (
8+
PydanticListType,
9+
PydanticType,
10+
create_push_notification_config_model,
11+
create_task_model,
12+
)
13+
from a2a.types import Artifact, TaskState, TaskStatus, TextPart
14+
15+
16+
class TestPydanticType:
17+
"""Tests for PydanticType SQLAlchemy type decorator."""
18+
19+
def test_process_bind_param_with_pydantic_model(self):
20+
pydantic_type = PydanticType(TaskStatus)
21+
status = TaskStatus(state=TaskState.working)
22+
dialect = MagicMock()
23+
24+
result = pydantic_type.process_bind_param(status, dialect)
25+
assert result["state"] == "working"
26+
assert result["message"] is None
27+
# TaskStatus may have other optional fields
28+
29+
def test_process_bind_param_with_none(self):
30+
pydantic_type = PydanticType(TaskStatus)
31+
dialect = MagicMock()
32+
33+
result = pydantic_type.process_bind_param(None, dialect)
34+
assert result is None
35+
36+
def test_process_result_value(self):
37+
pydantic_type = PydanticType(TaskStatus)
38+
dialect = MagicMock()
39+
40+
result = pydantic_type.process_result_value({"state": "completed", "message": None}, dialect)
41+
assert isinstance(result, TaskStatus)
42+
assert result.state == "completed"
43+
44+
45+
class TestPydanticListType:
46+
"""Tests for PydanticListType SQLAlchemy type decorator."""
47+
48+
def test_process_bind_param_with_list(self):
49+
pydantic_list_type = PydanticListType(Artifact)
50+
artifacts = [
51+
Artifact(artifact_id="1", parts=[TextPart(type="text", text="Hello")]),
52+
Artifact(artifact_id="2", parts=[TextPart(type="text", text="World")])
53+
]
54+
dialect = MagicMock()
55+
56+
result = pydantic_list_type.process_bind_param(artifacts, dialect)
57+
assert len(result) == 2
58+
assert result[0]["artifactId"] == "1" # JSON mode uses camelCase
59+
assert result[1]["artifactId"] == "2"
60+
61+
def test_process_result_value_with_list(self):
62+
pydantic_list_type = PydanticListType(Artifact)
63+
dialect = MagicMock()
64+
data = [
65+
{"artifact_id": "1", "parts": [{"type": "text", "text": "Hello"}]},
66+
{"artifact_id": "2", "parts": [{"type": "text", "text": "World"}]}
67+
]
68+
69+
result = pydantic_list_type.process_result_value(data, dialect)
70+
assert len(result) == 2
71+
assert all(isinstance(art, Artifact) for art in result)
72+
assert result[0].artifact_id == "1"
73+
assert result[1].artifact_id == "2"
74+
75+
76+
def test_create_task_model():
77+
"""Test dynamic task model creation."""
78+
# Create a fresh base to avoid table conflicts
79+
class TestBase(DeclarativeBase):
80+
pass
81+
82+
# Create with default table name
83+
default_task_model = create_task_model('test_tasks_1', TestBase)
84+
assert default_task_model.__tablename__ == 'test_tasks_1'
85+
assert default_task_model.__name__ == 'TaskModel_test_tasks_1'
86+
87+
# Create with custom table name
88+
custom_task_model = create_task_model('test_tasks_2', TestBase)
89+
assert custom_task_model.__tablename__ == 'test_tasks_2'
90+
assert custom_task_model.__name__ == 'TaskModel_test_tasks_2'
91+
92+
93+
def test_create_push_notification_config_model():
94+
"""Test dynamic push notification config model creation."""
95+
# Create a fresh base to avoid table conflicts
96+
class TestBase(DeclarativeBase):
97+
pass
98+
99+
# Create with default table name
100+
default_model = create_push_notification_config_model('test_push_configs_1', TestBase)
101+
assert default_model.__tablename__ == 'test_push_configs_1'
102+
103+
# Create with custom table name
104+
custom_model = create_push_notification_config_model('test_push_configs_2', TestBase)
105+
assert custom_model.__tablename__ == 'test_push_configs_2'
106+
assert 'test_push_configs_2' in custom_model.__name__

tests/utils/test_constants.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
"""Tests for a2a.utils.constants module."""
2+
3+
from a2a.utils import constants
4+
5+
6+
def test_agent_card_constants():
7+
"""Test that agent card constants have expected values."""
8+
assert constants.AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent-card.json'
9+
assert constants.PREV_AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent.json'
10+
assert constants.EXTENDED_AGENT_CARD_PATH == '/agent/authenticatedExtendedCard'
11+
12+
13+
def test_default_rpc_url():
14+
"""Test default RPC URL constant."""
15+
assert constants.DEFAULT_RPC_URL == '/'

tests/utils/test_error_handlers.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
"""Tests for a2a.utils.error_handlers module."""
2+
3+
from unittest.mock import patch
4+
5+
import pytest
6+
7+
from a2a.types import (
8+
InternalError,
9+
InvalidRequestError,
10+
MethodNotFoundError,
11+
TaskNotFoundError,
12+
)
13+
from a2a.utils.error_handlers import (
14+
A2AErrorToHttpStatus,
15+
rest_error_handler,
16+
rest_stream_error_handler,
17+
)
18+
from a2a.utils.errors import ServerError
19+
20+
21+
class MockJSONResponse:
22+
def __init__(self, content, status_code):
23+
self.content = content
24+
self.status_code = status_code
25+
26+
27+
@pytest.mark.asyncio
28+
async def test_rest_error_handler_server_error():
29+
"""Test rest_error_handler with ServerError."""
30+
error = InvalidRequestError(message="Bad request")
31+
32+
@rest_error_handler
33+
async def failing_func():
34+
raise ServerError(error=error)
35+
36+
with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse):
37+
result = await failing_func()
38+
39+
assert isinstance(result, MockJSONResponse)
40+
assert result.status_code == 400
41+
assert result.content == {'message': 'Bad request'}
42+
43+
44+
@pytest.mark.asyncio
45+
async def test_rest_error_handler_unknown_exception():
46+
"""Test rest_error_handler with unknown exception."""
47+
@rest_error_handler
48+
async def failing_func():
49+
raise ValueError("Unexpected error")
50+
51+
with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse):
52+
result = await failing_func()
53+
54+
assert isinstance(result, MockJSONResponse)
55+
assert result.status_code == 500
56+
assert result.content == {'message': 'unknown exception'}
57+
58+
59+
@pytest.mark.asyncio
60+
async def test_rest_stream_error_handler_server_error():
61+
"""Test rest_stream_error_handler with ServerError."""
62+
error = InternalError(message="Internal server error")
63+
64+
@rest_stream_error_handler
65+
async def failing_stream():
66+
raise ServerError(error=error)
67+
68+
with pytest.raises(ServerError) as exc_info:
69+
await failing_stream()
70+
71+
assert exc_info.value.error == error
72+
73+
74+
@pytest.mark.asyncio
75+
async def test_rest_stream_error_handler_reraises_exception():
76+
"""Test rest_stream_error_handler reraises other exceptions."""
77+
@rest_stream_error_handler
78+
async def failing_stream():
79+
raise RuntimeError("Stream failed")
80+
81+
with pytest.raises(RuntimeError, match="Stream failed"):
82+
await failing_stream()
83+
84+
85+
def test_a2a_error_to_http_status_mapping():
86+
"""Test A2AErrorToHttpStatus mapping."""
87+
assert A2AErrorToHttpStatus[InvalidRequestError] == 400
88+
assert A2AErrorToHttpStatus[MethodNotFoundError] == 404
89+
assert A2AErrorToHttpStatus[TaskNotFoundError] == 404
90+
assert A2AErrorToHttpStatus[InternalError] == 500

0 commit comments

Comments
 (0)