1818
1919
2020# Regexp patterns for matching
21- _TASK_NAME_MATCH = r'tasks/([\w-]+)'
22- _TASK_PUSH_CONFIG_NAME_MATCH = (
21+ _TASK_NAME_MATCH = re . compile ( r'tasks/([\w-]+)' )
22+ _TASK_PUSH_CONFIG_NAME_MATCH = re . compile (
2323 r'tasks/([\w-]+)/pushNotificationConfigs/([\w-]+)'
2424)
2525
2626
27+ def dict_to_struct (dictionary : dict [str , Any ]) -> struct_pb2 .Struct :
28+ """Converts a Python dict to a Struct proto.
29+
30+ Args:
31+ dictionary: The Python dict to convert.
32+
33+ Returns:
34+ The Struct proto.
35+ """
36+ struct = struct_pb2 .Struct ()
37+ for key , val in dictionary .items ():
38+ if isinstance (val , dict ):
39+ struct [key ] = dict_to_struct (val )
40+ else :
41+ struct [key ] = val
42+ return struct
43+
44+
2745class ToProto :
2846 """Converts Python types to proto types."""
2947
@@ -33,11 +51,11 @@ def message(cls, message: types.Message | None) -> a2a_pb2.Message | None:
3351 return None
3452 return a2a_pb2 .Message (
3553 message_id = message .message_id ,
36- content = [ToProto .part (p ) for p in message .parts ],
54+ content = [cls .part (p ) for p in message .parts ],
3755 context_id = message .context_id or '' ,
3856 task_id = message .task_id or '' ,
3957 role = cls .role (message .role ),
40- metadata = ToProto .metadata (message .metadata ),
58+ metadata = cls .metadata (message .metadata ),
4159 )
4260
4361 @classmethod
@@ -53,20 +71,14 @@ def part(cls, part: types.Part) -> a2a_pb2.Part:
5371 if isinstance (part .root , types .TextPart ):
5472 return a2a_pb2 .Part (text = part .root .text )
5573 if isinstance (part .root , types .FilePart ):
56- return a2a_pb2 .Part (file = ToProto .file (part .root .file ))
74+ return a2a_pb2 .Part (file = cls .file (part .root .file ))
5775 if isinstance (part .root , types .DataPart ):
58- return a2a_pb2 .Part (data = ToProto .data (part .root .data ))
76+ return a2a_pb2 .Part (data = cls .data (part .root .data ))
5977 raise ValueError (f'Unsupported part type: { part .root } ' )
6078
6179 @classmethod
6280 def data (cls , data : dict [str , Any ]) -> a2a_pb2 .DataPart :
63- json_data = json .dumps (data )
64- return a2a_pb2 .DataPart (
65- data = json_format .Parse (
66- json_data ,
67- struct_pb2 .Struct (),
68- )
69- )
81+ return a2a_pb2 .DataPart (data = dict_to_struct (data ))
7082
7183 @classmethod
7284 def file (
@@ -87,14 +99,14 @@ def task(cls, task: types.Task) -> a2a_pb2.Task:
8799 return a2a_pb2 .Task (
88100 id = task .id ,
89101 context_id = task .context_id ,
90- status = ToProto .task_status (task .status ),
102+ status = cls .task_status (task .status ),
91103 artifacts = (
92- [ToProto .artifact (a ) for a in task .artifacts ]
104+ [cls .artifact (a ) for a in task .artifacts ]
93105 if task .artifacts
94106 else None
95107 ),
96108 history = (
97- [ToProto .message (h ) for h in task .history ] # type: ignore[misc]
109+ [cls .message (h ) for h in task .history ] # type: ignore[misc]
98110 if task .history
99111 else None
100112 ),
@@ -103,8 +115,8 @@ def task(cls, task: types.Task) -> a2a_pb2.Task:
103115 @classmethod
104116 def task_status (cls , status : types .TaskStatus ) -> a2a_pb2 .TaskStatus :
105117 return a2a_pb2 .TaskStatus (
106- state = ToProto .task_state (status .state ),
107- update = ToProto .message (status .message ),
118+ state = cls .task_state (status .state ),
119+ update = cls .message (status .message ),
108120 )
109121
110122 @classmethod
@@ -132,9 +144,9 @@ def artifact(cls, artifact: types.Artifact) -> a2a_pb2.Artifact:
132144 return a2a_pb2 .Artifact (
133145 artifact_id = artifact .artifact_id ,
134146 description = artifact .description ,
135- metadata = ToProto .metadata (artifact .metadata ),
147+ metadata = cls .metadata (artifact .metadata ),
136148 name = artifact .name ,
137- parts = [ToProto .part (p ) for p in artifact .parts ],
149+ parts = [cls .part (p ) for p in artifact .parts ],
138150 )
139151
140152 @classmethod
@@ -151,7 +163,7 @@ def push_notification_config(
151163 cls , config : types .PushNotificationConfig
152164 ) -> a2a_pb2 .PushNotificationConfig :
153165 auth_info = (
154- ToProto .authentication_info (config .authentication )
166+ cls .authentication_info (config .authentication )
155167 if config .authentication
156168 else None
157169 )
@@ -169,8 +181,8 @@ def task_artifact_update_event(
169181 return a2a_pb2 .TaskArtifactUpdateEvent (
170182 task_id = event .task_id ,
171183 context_id = event .context_id ,
172- artifact = ToProto .artifact (event .artifact ),
173- metadata = ToProto .metadata (event .metadata ),
184+ artifact = cls .artifact (event .artifact ),
185+ metadata = cls .metadata (event .metadata ),
174186 append = event .append or False ,
175187 last_chunk = event .last_chunk or False ,
176188 )
@@ -182,8 +194,8 @@ def task_status_update_event(
182194 return a2a_pb2 .TaskStatusUpdateEvent (
183195 task_id = event .task_id ,
184196 context_id = event .context_id ,
185- status = ToProto .task_status (event .status ),
186- metadata = ToProto .metadata (event .metadata ),
197+ status = cls .task_status (event .status ),
198+ metadata = cls .metadata (event .metadata ),
187199 final = event .final ,
188200 )
189201
@@ -195,7 +207,7 @@ def message_send_configuration(
195207 return a2a_pb2 .SendMessageConfiguration ()
196208 return a2a_pb2 .SendMessageConfiguration (
197209 accepted_output_modes = config .accepted_output_modes ,
198- push_notification = ToProto .push_notification_config (
210+ push_notification = cls .push_notification_config (
199211 config .push_notification_config
200212 )
201213 if config .push_notification_config
@@ -213,19 +225,7 @@ def update_event(
213225 | types .TaskArtifactUpdateEvent ,
214226 ) -> a2a_pb2 .StreamResponse :
215227 """Converts a task, message, or task update event to a StreamResponse."""
216- if isinstance (event , types .TaskStatusUpdateEvent ):
217- return a2a_pb2 .StreamResponse (
218- status_update = ToProto .task_status_update_event (event )
219- )
220- if isinstance (event , types .TaskArtifactUpdateEvent ):
221- return a2a_pb2 .StreamResponse (
222- artifact_update = ToProto .task_artifact_update_event (event )
223- )
224- if isinstance (event , types .Message ):
225- return a2a_pb2 .StreamResponse (msg = ToProto .message (event ))
226- if isinstance (event , types .Task ):
227- return a2a_pb2 .StreamResponse (task = ToProto .task (event ))
228- raise ValueError (f'Unsupported event type: { type (event )} ' )
228+ return cls .stream_response (event )
229229
230230 @classmethod
231231 def task_or_message (
@@ -257,9 +257,11 @@ def stream_response(
257257 return a2a_pb2 .StreamResponse (
258258 status_update = cls .task_status_update_event (event ),
259259 )
260- return a2a_pb2 .StreamResponse (
261- artifact_update = cls .task_artifact_update_event (event ),
262- )
260+ if isinstance (event , types .TaskArtifactUpdateEvent ):
261+ return a2a_pb2 .StreamResponse (
262+ artifact_update = cls .task_artifact_update_event (event ),
263+ )
264+ raise ValueError (f'Unsupported event type: { type (event )} ' )
263265
264266 @classmethod
265267 def task_push_notification_config (
@@ -480,11 +482,11 @@ class FromProto:
480482 def message (cls , message : a2a_pb2 .Message ) -> types .Message :
481483 return types .Message (
482484 message_id = message .message_id ,
483- parts = [FromProto .part (p ) for p in message .content ],
485+ parts = [cls .part (p ) for p in message .content ],
484486 context_id = message .context_id or None ,
485487 task_id = message .task_id or None ,
486- role = FromProto .role (message .role ),
487- metadata = FromProto .metadata (message .metadata ),
488+ role = cls .role (message .role ),
489+ metadata = cls .metadata (message .metadata ),
488490 )
489491
490492 @classmethod
@@ -498,13 +500,9 @@ def part(cls, part: a2a_pb2.Part) -> types.Part:
498500 if part .HasField ('text' ):
499501 return types .Part (root = types .TextPart (text = part .text ))
500502 if part .HasField ('file' ):
501- return types .Part (
502- root = types .FilePart (file = FromProto .file (part .file ))
503- )
503+ return types .Part (root = types .FilePart (file = cls .file (part .file )))
504504 if part .HasField ('data' ):
505- return types .Part (
506- root = types .DataPart (data = FromProto .data (part .data ))
507- )
505+ return types .Part (root = types .DataPart (data = cls .data (part .data )))
508506 raise ValueError (f'Unsupported part type: { part } ' )
509507
510508 @classmethod
@@ -543,16 +541,16 @@ def task(cls, task: a2a_pb2.Task) -> types.Task:
543541 return types .Task (
544542 id = task .id ,
545543 context_id = task .context_id ,
546- status = FromProto .task_status (task .status ),
547- artifacts = [FromProto .artifact (a ) for a in task .artifacts ],
548- history = [FromProto .message (h ) for h in task .history ],
544+ status = cls .task_status (task .status ),
545+ artifacts = [cls .artifact (a ) for a in task .artifacts ],
546+ history = [cls .message (h ) for h in task .history ],
549547 )
550548
551549 @classmethod
552550 def task_status (cls , status : a2a_pb2 .TaskStatus ) -> types .TaskStatus :
553551 return types .TaskStatus (
554- state = FromProto .task_state (status .state ),
555- message = FromProto .message (status .update ),
552+ state = cls .task_state (status .state ),
553+ message = cls .message (status .update ),
556554 )
557555
558556 @classmethod
@@ -580,9 +578,9 @@ def artifact(cls, artifact: a2a_pb2.Artifact) -> types.Artifact:
580578 return types .Artifact (
581579 artifact_id = artifact .artifact_id ,
582580 description = artifact .description ,
583- metadata = FromProto .metadata (artifact .metadata ),
581+ metadata = cls .metadata (artifact .metadata ),
584582 name = artifact .name ,
585- parts = [FromProto .part (p ) for p in artifact .parts ],
583+ parts = [cls .part (p ) for p in artifact .parts ],
586584 )
587585
588586 @classmethod
@@ -592,8 +590,8 @@ def task_artifact_update_event(
592590 return types .TaskArtifactUpdateEvent (
593591 task_id = event .task_id ,
594592 context_id = event .context_id ,
595- artifact = FromProto .artifact (event .artifact ),
596- metadata = FromProto .metadata (event .metadata ),
593+ artifact = cls .artifact (event .artifact ),
594+ metadata = cls .metadata (event .metadata ),
597595 append = event .append ,
598596 last_chunk = event .last_chunk ,
599597 )
@@ -605,8 +603,8 @@ def task_status_update_event(
605603 return types .TaskStatusUpdateEvent (
606604 task_id = event .task_id ,
607605 context_id = event .context_id ,
608- status = FromProto .task_status (event .status ),
609- metadata = FromProto .metadata (event .metadata ),
606+ status = cls .task_status (event .status ),
607+ metadata = cls .metadata (event .metadata ),
610608 final = event .final ,
611609 )
612610
@@ -618,7 +616,7 @@ def push_notification_config(
618616 id = config .id ,
619617 url = config .url ,
620618 token = config .token ,
621- authentication = FromProto .authentication_info (config .authentication )
619+ authentication = cls .authentication_info (config .authentication )
622620 if config .HasField ('authentication' )
623621 else None ,
624622 )
@@ -638,7 +636,7 @@ def message_send_configuration(
638636 ) -> types .MessageSendConfiguration :
639637 return types .MessageSendConfiguration (
640638 accepted_output_modes = list (config .accepted_output_modes ),
641- push_notification_config = FromProto .push_notification_config (
639+ push_notification_config = cls .push_notification_config (
642640 config .push_notification
643641 )
644642 if config .HasField ('push_notification' )
@@ -666,18 +664,16 @@ def task_id_params(
666664 | a2a_pb2 .GetTaskPushNotificationConfigRequest
667665 ),
668666 ) -> types .TaskIdParams :
669- # This is currently incomplete until the core sdk supports multiple
670- # configs for a single task.
671667 if isinstance (request , a2a_pb2 .GetTaskPushNotificationConfigRequest ):
672- m = re .match (_TASK_PUSH_CONFIG_NAME_MATCH , request .name )
668+ m = _TASK_PUSH_CONFIG_NAME_MATCH .match (request .name )
673669 if not m :
674670 raise ServerError (
675671 error = types .InvalidParamsError (
676672 message = f'No task for { request .name } '
677673 )
678674 )
679675 return types .TaskIdParams (id = m .group (1 ))
680- m = re .match (_TASK_NAME_MATCH , request .name )
676+ m = _TASK_NAME_MATCH .match (request .name )
681677 if not m :
682678 raise ServerError (
683679 error = types .InvalidParamsError (
@@ -691,7 +687,7 @@ def task_push_notification_config_request(
691687 cls ,
692688 request : a2a_pb2 .CreateTaskPushNotificationConfigRequest ,
693689 ) -> types .TaskPushNotificationConfig :
694- m = re .match (_TASK_NAME_MATCH , request .parent )
690+ m = _TASK_NAME_MATCH .match (request .parent )
695691 if not m :
696692 raise ServerError (
697693 error = types .InvalidParamsError (
@@ -710,7 +706,7 @@ def task_push_notification_config(
710706 cls ,
711707 config : a2a_pb2 .TaskPushNotificationConfig ,
712708 ) -> types .TaskPushNotificationConfig :
713- m = re .match (_TASK_PUSH_CONFIG_NAME_MATCH , config .name )
709+ m = _TASK_PUSH_CONFIG_NAME_MATCH .match (config .name )
714710 if not m :
715711 raise ServerError (
716712 error = types .InvalidParamsError (
@@ -767,7 +763,7 @@ def task_query_params(
767763 cls ,
768764 request : a2a_pb2 .GetTaskRequest ,
769765 ) -> types .TaskQueryParams :
770- m = re .match (_TASK_NAME_MATCH , request .name )
766+ m = _TASK_NAME_MATCH .match (request .name )
771767 if not m :
772768 raise ServerError (
773769 error = types .InvalidParamsError (
@@ -862,6 +858,12 @@ def security_scheme(
862858 flows = cls .oauth2_flows (scheme .oauth2_security_scheme .flows ),
863859 )
864860 )
861+ if scheme .HasField ('mtls_security_scheme' ):
862+ return types .SecurityScheme (
863+ root = types .MutualTLSSecurityScheme (
864+ description = scheme .mtls_security_scheme .description ,
865+ )
866+ )
865867 return types .SecurityScheme (
866868 root = types .OpenIdConnectSecurityScheme (
867869 description = scheme .open_id_connect_security_scheme .description ,
@@ -920,7 +922,9 @@ def stream_response(
920922 return cls .task (response .task )
921923 if response .HasField ('status_update' ):
922924 return cls .task_status_update_event (response .status_update )
923- return cls .task_artifact_update_event (response .artifact_update )
925+ if response .HasField ('artifact_update' ):
926+ return cls .task_artifact_update_event (response .artifact_update )
927+ raise ValueError ('Unsupported StreamResponse type' )
924928
925929 @classmethod
926930 def skill (cls , skill : a2a_pb2 .AgentSkill ) -> types .AgentSkill :
@@ -943,25 +947,3 @@ def role(cls, role: a2a_pb2.Role) -> types.Role:
943947 return types .Role .agent
944948 case _:
945949 return types .Role .agent
946-
947-
948- def dict_to_struct (dictionary : dict [str , Any ]) -> struct_pb2 .Struct :
949- """Converts a Python dict to a Struct proto.
950-
951- Unfortunately, using `json_format.ParseDict` does not work because this
952- wants the dictionary to be an exact match of the Struct proto with fields
953- and keys and values, not the traditional Python dict structure.
954-
955- Args:
956- dictionary: The Python dict to convert.
957-
958- Returns:
959- The Struct proto.
960- """
961- struct = struct_pb2 .Struct ()
962- for key , val in dictionary .items ():
963- if isinstance (val , dict ):
964- struct [key ] = dict_to_struct (val )
965- else :
966- struct [key ] = val
967- return struct
0 commit comments