Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions tests/client/test_optionals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
"""Tests for a2a.client.optionals module."""

import importlib
import sys

from unittest.mock import patch


def test_channel_import_failure():
"""Test Channel behavior when grpc is not available."""
with patch.dict('sys.modules', {'grpc': None, 'grpc.aio': None}):
if 'a2a.client.optionals' in sys.modules:
del sys.modules['a2a.client.optionals']

optionals = importlib.import_module('a2a.client.optionals')
assert optionals.Channel is None
118 changes: 118 additions & 0 deletions tests/server/test_models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""Tests for a2a.server.models module."""

from unittest.mock import MagicMock

from sqlalchemy.orm import DeclarativeBase

from a2a.server.models import (
PydanticListType,
PydanticType,
create_push_notification_config_model,
create_task_model,
)
from a2a.types import Artifact, TaskState, TaskStatus, TextPart


class TestPydanticType:
"""Tests for PydanticType SQLAlchemy type decorator."""

def test_process_bind_param_with_pydantic_model(self):
pydantic_type = PydanticType(TaskStatus)
status = TaskStatus(state=TaskState.working)
dialect = MagicMock()

result = pydantic_type.process_bind_param(status, dialect)
assert result['state'] == 'working'
assert result['message'] is None
# TaskStatus may have other optional fields

def test_process_bind_param_with_none(self):
pydantic_type = PydanticType(TaskStatus)
dialect = MagicMock()

result = pydantic_type.process_bind_param(None, dialect)
assert result is None

def test_process_result_value(self):
pydantic_type = PydanticType(TaskStatus)
dialect = MagicMock()

result = pydantic_type.process_result_value(
{'state': 'completed', 'message': None}, dialect
)
assert isinstance(result, TaskStatus)
assert result.state == 'completed'


class TestPydanticListType:
"""Tests for PydanticListType SQLAlchemy type decorator."""

def test_process_bind_param_with_list(self):
pydantic_list_type = PydanticListType(Artifact)
artifacts = [
Artifact(
artifact_id='1', parts=[TextPart(type='text', text='Hello')]
),
Artifact(
artifact_id='2', parts=[TextPart(type='text', text='World')]
),
]
dialect = MagicMock()

result = pydantic_list_type.process_bind_param(artifacts, dialect)
assert len(result) == 2
assert result[0]['artifactId'] == '1' # JSON mode uses camelCase
assert result[1]['artifactId'] == '2'

def test_process_result_value_with_list(self):
pydantic_list_type = PydanticListType(Artifact)
dialect = MagicMock()
data = [
{'artifact_id': '1', 'parts': [{'type': 'text', 'text': 'Hello'}]},
{'artifact_id': '2', 'parts': [{'type': 'text', 'text': 'World'}]},
]

result = pydantic_list_type.process_result_value(data, dialect)
assert len(result) == 2
assert all(isinstance(art, Artifact) for art in result)
assert result[0].artifact_id == '1'
assert result[1].artifact_id == '2'


def test_create_task_model():
"""Test dynamic task model creation."""

# Create a fresh base to avoid table conflicts
class TestBase(DeclarativeBase):
pass

# Create with default table name
default_task_model = create_task_model('test_tasks_1', TestBase)
assert default_task_model.__tablename__ == 'test_tasks_1'
assert default_task_model.__name__ == 'TaskModel_test_tasks_1'

# Create with custom table name
custom_task_model = create_task_model('test_tasks_2', TestBase)
assert custom_task_model.__tablename__ == 'test_tasks_2'
assert custom_task_model.__name__ == 'TaskModel_test_tasks_2'


def test_create_push_notification_config_model():
"""Test dynamic push notification config model creation."""

# Create a fresh base to avoid table conflicts
class TestBase(DeclarativeBase):
pass

# Create with default table name
default_model = create_push_notification_config_model(
'test_push_configs_1', TestBase
)
assert default_model.__tablename__ == 'test_push_configs_1'

# Create with custom table name
custom_model = create_push_notification_config_model(
'test_push_configs_2', TestBase
)
assert custom_model.__tablename__ == 'test_push_configs_2'
assert 'test_push_configs_2' in custom_model.__name__
21 changes: 21 additions & 0 deletions tests/utils/test_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Tests for a2a.utils.constants module."""

from a2a.utils import constants


def test_agent_card_constants():
"""Test that agent card constants have expected values."""
assert (
constants.AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent-card.json'
)
assert (
constants.PREV_AGENT_CARD_WELL_KNOWN_PATH == '/.well-known/agent.json'
)
assert (
constants.EXTENDED_AGENT_CARD_PATH == '/agent/authenticatedExtendedCard'
)


def test_default_rpc_url():
"""Test default RPC URL constant."""
assert constants.DEFAULT_RPC_URL == '/'
92 changes: 92 additions & 0 deletions tests/utils/test_error_handlers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
"""Tests for a2a.utils.error_handlers module."""

from unittest.mock import patch

import pytest

from a2a.types import (
InternalError,
InvalidRequestError,
MethodNotFoundError,
TaskNotFoundError,
)
from a2a.utils.error_handlers import (
A2AErrorToHttpStatus,
rest_error_handler,
rest_stream_error_handler,
)
from a2a.utils.errors import ServerError


class MockJSONResponse:
def __init__(self, content, status_code):
self.content = content
self.status_code = status_code


@pytest.mark.asyncio
async def test_rest_error_handler_server_error():
"""Test rest_error_handler with ServerError."""
error = InvalidRequestError(message='Bad request')

@rest_error_handler
async def failing_func():
raise ServerError(error=error)

with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse):
result = await failing_func()

assert isinstance(result, MockJSONResponse)
assert result.status_code == 400
assert result.content == {'message': 'Bad request'}


@pytest.mark.asyncio
async def test_rest_error_handler_unknown_exception():
"""Test rest_error_handler with unknown exception."""

@rest_error_handler
async def failing_func():
raise ValueError('Unexpected error')

with patch('a2a.utils.error_handlers.JSONResponse', MockJSONResponse):
result = await failing_func()

assert isinstance(result, MockJSONResponse)
assert result.status_code == 500
assert result.content == {'message': 'unknown exception'}


@pytest.mark.asyncio
async def test_rest_stream_error_handler_server_error():
"""Test rest_stream_error_handler with ServerError."""
error = InternalError(message='Internal server error')

@rest_stream_error_handler
async def failing_stream():
raise ServerError(error=error)

with pytest.raises(ServerError) as exc_info:
await failing_stream()

assert exc_info.value.error == error


@pytest.mark.asyncio
async def test_rest_stream_error_handler_reraises_exception():
"""Test rest_stream_error_handler reraises other exceptions."""

@rest_stream_error_handler
async def failing_stream():
raise RuntimeError('Stream failed')

with pytest.raises(RuntimeError, match='Stream failed'):
await failing_stream()


def test_a2a_error_to_http_status_mapping():
"""Test A2AErrorToHttpStatus mapping."""
assert A2AErrorToHttpStatus[InvalidRequestError] == 400
assert A2AErrorToHttpStatus[MethodNotFoundError] == 404
assert A2AErrorToHttpStatus[TaskNotFoundError] == 404
assert A2AErrorToHttpStatus[InternalError] == 500
Loading