1+ """Tests for a2a.server.models module."""
2+
3+ from unittest .mock import MagicMock , patch
4+
5+ import pytest
6+
7+ from a2a .types import Artifact , TaskStatus
8+
9+
10+ class TestPydanticType :
11+ """Tests for PydanticType SQLAlchemy type decorator."""
12+
13+ def test_process_bind_param_with_pydantic_model (self ):
14+ from a2a .server .models import PydanticType
15+ from a2a .types import TaskState
16+
17+ pydantic_type = PydanticType (TaskStatus )
18+ status = TaskStatus (state = TaskState .working )
19+ dialect = MagicMock ()
20+
21+ result = pydantic_type .process_bind_param (status , dialect )
22+ assert result ["state" ] == "working"
23+ assert result ["message" ] is None
24+ # TaskStatus may have other optional fields
25+
26+ def test_process_bind_param_with_none (self ):
27+ from a2a .server .models import PydanticType
28+
29+ pydantic_type = PydanticType (TaskStatus )
30+ dialect = MagicMock ()
31+
32+ result = pydantic_type .process_bind_param (None , dialect )
33+ assert result is None
34+
35+ def test_process_result_value (self ):
36+ from a2a .server .models import PydanticType
37+
38+ pydantic_type = PydanticType (TaskStatus )
39+ dialect = MagicMock ()
40+
41+ result = pydantic_type .process_result_value ({"state" : "completed" , "message" : None }, dialect )
42+ assert isinstance (result , TaskStatus )
43+ assert result .state == "completed"
44+
45+
46+ class TestPydanticListType :
47+ """Tests for PydanticListType SQLAlchemy type decorator."""
48+
49+ def test_process_bind_param_with_list (self ):
50+ from a2a .server .models import PydanticListType
51+ from a2a .types import Artifact , TextPart
52+
53+ pydantic_list_type = PydanticListType (Artifact )
54+ artifacts = [
55+ Artifact (artifact_id = "1" , parts = [TextPart (type = "text" , text = "Hello" )]),
56+ Artifact (artifact_id = "2" , parts = [TextPart (type = "text" , text = "World" )])
57+ ]
58+ dialect = MagicMock ()
59+
60+ result = pydantic_list_type .process_bind_param (artifacts , dialect )
61+ assert len (result ) == 2
62+ assert result [0 ]["artifactId" ] == "1" # JSON mode uses camelCase
63+ assert result [1 ]["artifactId" ] == "2"
64+
65+ def test_process_result_value_with_list (self ):
66+ from a2a .server .models import PydanticListType
67+ from a2a .types import Artifact
68+
69+ pydantic_list_type = PydanticListType (Artifact )
70+ dialect = MagicMock ()
71+ data = [
72+ {"artifact_id" : "1" , "parts" : [{"type" : "text" , "text" : "Hello" }]},
73+ {"artifact_id" : "2" , "parts" : [{"type" : "text" , "text" : "World" }]}
74+ ]
75+
76+ result = pydantic_list_type .process_result_value (data , dialect )
77+ assert len (result ) == 2
78+ assert all (isinstance (art , Artifact ) for art in result )
79+ assert result [0 ].artifact_id == "1"
80+ assert result [1 ].artifact_id == "2"
81+
82+
83+ def test_create_task_model ():
84+ """Test dynamic task model creation."""
85+ from a2a .server .models import Base , create_task_model
86+ from sqlalchemy .orm import DeclarativeBase
87+
88+ # Create a fresh base to avoid table conflicts
89+ class TestBase (DeclarativeBase ):
90+ pass
91+
92+ # Create with default table name
93+ DefaultTaskModel = create_task_model ('test_tasks_1' , TestBase )
94+ assert DefaultTaskModel .__tablename__ == 'test_tasks_1'
95+ assert DefaultTaskModel .__name__ == 'TaskModel_test_tasks_1'
96+
97+ # Create with custom table name
98+ CustomTaskModel = create_task_model ('test_tasks_2' , TestBase )
99+ assert CustomTaskModel .__tablename__ == 'test_tasks_2'
100+ assert CustomTaskModel .__name__ == 'TaskModel_test_tasks_2'
101+
102+
103+ def test_create_push_notification_config_model ():
104+ """Test dynamic push notification config model creation."""
105+ from a2a .server .models import create_push_notification_config_model
106+ from sqlalchemy .orm import DeclarativeBase
107+
108+ # Create a fresh base to avoid table conflicts
109+ class TestBase (DeclarativeBase ):
110+ pass
111+
112+ # Create with default table name
113+ DefaultModel = create_push_notification_config_model ('test_push_configs_1' , TestBase )
114+ assert DefaultModel .__tablename__ == 'test_push_configs_1'
115+
116+ # Create with custom table name
117+ CustomModel = create_push_notification_config_model ('test_push_configs_2' , TestBase )
118+ assert CustomModel .__tablename__ == 'test_push_configs_2'
119+ assert 'test_push_configs_2' in CustomModel .__name__
0 commit comments