Skip to content

Commit 6259bea

Browse files
committed
perf: Improve performance and code style for proto_utils.py
1 parent 42ff0d4 commit 6259bea

File tree

1 file changed

+76
-94
lines changed

1 file changed

+76
-94
lines changed

src/a2a/utils/proto_utils.py

Lines changed: 76 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,30 @@
1818

1919

2020
# Regexp patterns for matching
21-
_TASK_NAME_MATCH = r'tasks/([\w-]+)'
22-
_TASK_PUSH_CONFIG_NAME_MATCH = (
21+
_TASK_NAME_MATCH = re.compile(r'tasks/([\w-]+)')
22+
_TASK_PUSH_CONFIG_NAME_MATCH = re.compile(
2323
r'tasks/([\w-]+)/pushNotificationConfigs/([\w-]+)'
2424
)
2525

2626

27+
def dict_to_struct(dictionary: dict[str, Any]) -> struct_pb2.Struct:
28+
"""Converts a Python dict to a Struct proto.
29+
30+
Args:
31+
dictionary: The Python dict to convert.
32+
33+
Returns:
34+
The Struct proto.
35+
"""
36+
struct = struct_pb2.Struct()
37+
for key, val in dictionary.items():
38+
if isinstance(val, dict):
39+
struct[key] = dict_to_struct(val)
40+
else:
41+
struct[key] = val
42+
return struct
43+
44+
2745
class ToProto:
2846
"""Converts Python types to proto types."""
2947

@@ -33,11 +51,11 @@ def message(cls, message: types.Message | None) -> a2a_pb2.Message | None:
3351
return None
3452
return a2a_pb2.Message(
3553
message_id=message.message_id,
36-
content=[ToProto.part(p) for p in message.parts],
54+
content=[cls.part(p) for p in message.parts],
3755
context_id=message.context_id or '',
3856
task_id=message.task_id or '',
3957
role=cls.role(message.role),
40-
metadata=ToProto.metadata(message.metadata),
58+
metadata=cls.metadata(message.metadata),
4159
)
4260

4361
@classmethod
@@ -53,20 +71,14 @@ def part(cls, part: types.Part) -> a2a_pb2.Part:
5371
if isinstance(part.root, types.TextPart):
5472
return a2a_pb2.Part(text=part.root.text)
5573
if isinstance(part.root, types.FilePart):
56-
return a2a_pb2.Part(file=ToProto.file(part.root.file))
74+
return a2a_pb2.Part(file=cls.file(part.root.file))
5775
if isinstance(part.root, types.DataPart):
58-
return a2a_pb2.Part(data=ToProto.data(part.root.data))
76+
return a2a_pb2.Part(data=cls.data(part.root.data))
5977
raise ValueError(f'Unsupported part type: {part.root}')
6078

6179
@classmethod
6280
def data(cls, data: dict[str, Any]) -> a2a_pb2.DataPart:
63-
json_data = json.dumps(data)
64-
return a2a_pb2.DataPart(
65-
data=json_format.Parse(
66-
json_data,
67-
struct_pb2.Struct(),
68-
)
69-
)
81+
return a2a_pb2.DataPart(data=dict_to_struct(data))
7082

7183
@classmethod
7284
def file(
@@ -87,14 +99,14 @@ def task(cls, task: types.Task) -> a2a_pb2.Task:
8799
return a2a_pb2.Task(
88100
id=task.id,
89101
context_id=task.context_id,
90-
status=ToProto.task_status(task.status),
102+
status=cls.task_status(task.status),
91103
artifacts=(
92-
[ToProto.artifact(a) for a in task.artifacts]
104+
[cls.artifact(a) for a in task.artifacts]
93105
if task.artifacts
94106
else None
95107
),
96108
history=(
97-
[ToProto.message(h) for h in task.history] # type: ignore[misc]
109+
[cls.message(h) for h in task.history] # type: ignore[misc]
98110
if task.history
99111
else None
100112
),
@@ -103,8 +115,8 @@ def task(cls, task: types.Task) -> a2a_pb2.Task:
103115
@classmethod
104116
def task_status(cls, status: types.TaskStatus) -> a2a_pb2.TaskStatus:
105117
return a2a_pb2.TaskStatus(
106-
state=ToProto.task_state(status.state),
107-
update=ToProto.message(status.message),
118+
state=cls.task_state(status.state),
119+
update=cls.message(status.message),
108120
)
109121

110122
@classmethod
@@ -132,9 +144,9 @@ def artifact(cls, artifact: types.Artifact) -> a2a_pb2.Artifact:
132144
return a2a_pb2.Artifact(
133145
artifact_id=artifact.artifact_id,
134146
description=artifact.description,
135-
metadata=ToProto.metadata(artifact.metadata),
147+
metadata=cls.metadata(artifact.metadata),
136148
name=artifact.name,
137-
parts=[ToProto.part(p) for p in artifact.parts],
149+
parts=[cls.part(p) for p in artifact.parts],
138150
)
139151

140152
@classmethod
@@ -151,7 +163,7 @@ def push_notification_config(
151163
cls, config: types.PushNotificationConfig
152164
) -> a2a_pb2.PushNotificationConfig:
153165
auth_info = (
154-
ToProto.authentication_info(config.authentication)
166+
cls.authentication_info(config.authentication)
155167
if config.authentication
156168
else None
157169
)
@@ -169,8 +181,8 @@ def task_artifact_update_event(
169181
return a2a_pb2.TaskArtifactUpdateEvent(
170182
task_id=event.task_id,
171183
context_id=event.context_id,
172-
artifact=ToProto.artifact(event.artifact),
173-
metadata=ToProto.metadata(event.metadata),
184+
artifact=cls.artifact(event.artifact),
185+
metadata=cls.metadata(event.metadata),
174186
append=event.append or False,
175187
last_chunk=event.last_chunk or False,
176188
)
@@ -182,8 +194,8 @@ def task_status_update_event(
182194
return a2a_pb2.TaskStatusUpdateEvent(
183195
task_id=event.task_id,
184196
context_id=event.context_id,
185-
status=ToProto.task_status(event.status),
186-
metadata=ToProto.metadata(event.metadata),
197+
status=cls.task_status(event.status),
198+
metadata=cls.metadata(event.metadata),
187199
final=event.final,
188200
)
189201

@@ -195,7 +207,7 @@ def message_send_configuration(
195207
return a2a_pb2.SendMessageConfiguration()
196208
return a2a_pb2.SendMessageConfiguration(
197209
accepted_output_modes=config.accepted_output_modes,
198-
push_notification=ToProto.push_notification_config(
210+
push_notification=cls.push_notification_config(
199211
config.push_notification_config
200212
)
201213
if config.push_notification_config
@@ -213,19 +225,7 @@ def update_event(
213225
| types.TaskArtifactUpdateEvent,
214226
) -> a2a_pb2.StreamResponse:
215227
"""Converts a task, message, or task update event to a StreamResponse."""
216-
if isinstance(event, types.TaskStatusUpdateEvent):
217-
return a2a_pb2.StreamResponse(
218-
status_update=ToProto.task_status_update_event(event)
219-
)
220-
if isinstance(event, types.TaskArtifactUpdateEvent):
221-
return a2a_pb2.StreamResponse(
222-
artifact_update=ToProto.task_artifact_update_event(event)
223-
)
224-
if isinstance(event, types.Message):
225-
return a2a_pb2.StreamResponse(msg=ToProto.message(event))
226-
if isinstance(event, types.Task):
227-
return a2a_pb2.StreamResponse(task=ToProto.task(event))
228-
raise ValueError(f'Unsupported event type: {type(event)}')
228+
return cls.stream_response(event)
229229

230230
@classmethod
231231
def task_or_message(
@@ -257,9 +257,11 @@ def stream_response(
257257
return a2a_pb2.StreamResponse(
258258
status_update=cls.task_status_update_event(event),
259259
)
260-
return a2a_pb2.StreamResponse(
261-
artifact_update=cls.task_artifact_update_event(event),
262-
)
260+
if isinstance(event, types.TaskArtifactUpdateEvent):
261+
return a2a_pb2.StreamResponse(
262+
artifact_update=cls.task_artifact_update_event(event),
263+
)
264+
raise ValueError(f'Unsupported event type: {type(event)}')
263265

264266
@classmethod
265267
def task_push_notification_config(
@@ -480,11 +482,11 @@ class FromProto:
480482
def message(cls, message: a2a_pb2.Message) -> types.Message:
481483
return types.Message(
482484
message_id=message.message_id,
483-
parts=[FromProto.part(p) for p in message.content],
485+
parts=[cls.part(p) for p in message.content],
484486
context_id=message.context_id or None,
485487
task_id=message.task_id or None,
486-
role=FromProto.role(message.role),
487-
metadata=FromProto.metadata(message.metadata),
488+
role=cls.role(message.role),
489+
metadata=cls.metadata(message.metadata),
488490
)
489491

490492
@classmethod
@@ -498,13 +500,9 @@ def part(cls, part: a2a_pb2.Part) -> types.Part:
498500
if part.HasField('text'):
499501
return types.Part(root=types.TextPart(text=part.text))
500502
if part.HasField('file'):
501-
return types.Part(
502-
root=types.FilePart(file=FromProto.file(part.file))
503-
)
503+
return types.Part(root=types.FilePart(file=cls.file(part.file)))
504504
if part.HasField('data'):
505-
return types.Part(
506-
root=types.DataPart(data=FromProto.data(part.data))
507-
)
505+
return types.Part(root=types.DataPart(data=cls.data(part.data)))
508506
raise ValueError(f'Unsupported part type: {part}')
509507

510508
@classmethod
@@ -543,16 +541,16 @@ def task(cls, task: a2a_pb2.Task) -> types.Task:
543541
return types.Task(
544542
id=task.id,
545543
context_id=task.context_id,
546-
status=FromProto.task_status(task.status),
547-
artifacts=[FromProto.artifact(a) for a in task.artifacts],
548-
history=[FromProto.message(h) for h in task.history],
544+
status=cls.task_status(task.status),
545+
artifacts=[cls.artifact(a) for a in task.artifacts],
546+
history=[cls.message(h) for h in task.history],
549547
)
550548

551549
@classmethod
552550
def task_status(cls, status: a2a_pb2.TaskStatus) -> types.TaskStatus:
553551
return types.TaskStatus(
554-
state=FromProto.task_state(status.state),
555-
message=FromProto.message(status.update),
552+
state=cls.task_state(status.state),
553+
message=cls.message(status.update),
556554
)
557555

558556
@classmethod
@@ -580,9 +578,9 @@ def artifact(cls, artifact: a2a_pb2.Artifact) -> types.Artifact:
580578
return types.Artifact(
581579
artifact_id=artifact.artifact_id,
582580
description=artifact.description,
583-
metadata=FromProto.metadata(artifact.metadata),
581+
metadata=cls.metadata(artifact.metadata),
584582
name=artifact.name,
585-
parts=[FromProto.part(p) for p in artifact.parts],
583+
parts=[cls.part(p) for p in artifact.parts],
586584
)
587585

588586
@classmethod
@@ -592,8 +590,8 @@ def task_artifact_update_event(
592590
return types.TaskArtifactUpdateEvent(
593591
task_id=event.task_id,
594592
context_id=event.context_id,
595-
artifact=FromProto.artifact(event.artifact),
596-
metadata=FromProto.metadata(event.metadata),
593+
artifact=cls.artifact(event.artifact),
594+
metadata=cls.metadata(event.metadata),
597595
append=event.append,
598596
last_chunk=event.last_chunk,
599597
)
@@ -605,8 +603,8 @@ def task_status_update_event(
605603
return types.TaskStatusUpdateEvent(
606604
task_id=event.task_id,
607605
context_id=event.context_id,
608-
status=FromProto.task_status(event.status),
609-
metadata=FromProto.metadata(event.metadata),
606+
status=cls.task_status(event.status),
607+
metadata=cls.metadata(event.metadata),
610608
final=event.final,
611609
)
612610

@@ -618,7 +616,7 @@ def push_notification_config(
618616
id=config.id,
619617
url=config.url,
620618
token=config.token,
621-
authentication=FromProto.authentication_info(config.authentication)
619+
authentication=cls.authentication_info(config.authentication)
622620
if config.HasField('authentication')
623621
else None,
624622
)
@@ -638,7 +636,7 @@ def message_send_configuration(
638636
) -> types.MessageSendConfiguration:
639637
return types.MessageSendConfiguration(
640638
accepted_output_modes=list(config.accepted_output_modes),
641-
push_notification_config=FromProto.push_notification_config(
639+
push_notification_config=cls.push_notification_config(
642640
config.push_notification
643641
)
644642
if config.HasField('push_notification')
@@ -666,18 +664,16 @@ def task_id_params(
666664
| a2a_pb2.GetTaskPushNotificationConfigRequest
667665
),
668666
) -> types.TaskIdParams:
669-
# This is currently incomplete until the core sdk supports multiple
670-
# configs for a single task.
671667
if isinstance(request, a2a_pb2.GetTaskPushNotificationConfigRequest):
672-
m = re.match(_TASK_PUSH_CONFIG_NAME_MATCH, request.name)
668+
m = _TASK_PUSH_CONFIG_NAME_MATCH.match(request.name)
673669
if not m:
674670
raise ServerError(
675671
error=types.InvalidParamsError(
676672
message=f'No task for {request.name}'
677673
)
678674
)
679675
return types.TaskIdParams(id=m.group(1))
680-
m = re.match(_TASK_NAME_MATCH, request.name)
676+
m = _TASK_NAME_MATCH.match(request.name)
681677
if not m:
682678
raise ServerError(
683679
error=types.InvalidParamsError(
@@ -691,7 +687,7 @@ def task_push_notification_config_request(
691687
cls,
692688
request: a2a_pb2.CreateTaskPushNotificationConfigRequest,
693689
) -> types.TaskPushNotificationConfig:
694-
m = re.match(_TASK_NAME_MATCH, request.parent)
690+
m = _TASK_NAME_MATCH.match(request.parent)
695691
if not m:
696692
raise ServerError(
697693
error=types.InvalidParamsError(
@@ -710,7 +706,7 @@ def task_push_notification_config(
710706
cls,
711707
config: a2a_pb2.TaskPushNotificationConfig,
712708
) -> types.TaskPushNotificationConfig:
713-
m = re.match(_TASK_PUSH_CONFIG_NAME_MATCH, config.name)
709+
m = _TASK_PUSH_CONFIG_NAME_MATCH.match(config.name)
714710
if not m:
715711
raise ServerError(
716712
error=types.InvalidParamsError(
@@ -767,7 +763,7 @@ def task_query_params(
767763
cls,
768764
request: a2a_pb2.GetTaskRequest,
769765
) -> types.TaskQueryParams:
770-
m = re.match(_TASK_NAME_MATCH, request.name)
766+
m = _TASK_NAME_MATCH.match(request.name)
771767
if not m:
772768
raise ServerError(
773769
error=types.InvalidParamsError(
@@ -862,6 +858,12 @@ def security_scheme(
862858
flows=cls.oauth2_flows(scheme.oauth2_security_scheme.flows),
863859
)
864860
)
861+
if scheme.HasField('mtls_security_scheme'):
862+
return types.SecurityScheme(
863+
root=types.MutualTLSSecurityScheme(
864+
description=scheme.mtls_security_scheme.description,
865+
)
866+
)
865867
return types.SecurityScheme(
866868
root=types.OpenIdConnectSecurityScheme(
867869
description=scheme.open_id_connect_security_scheme.description,
@@ -920,7 +922,9 @@ def stream_response(
920922
return cls.task(response.task)
921923
if response.HasField('status_update'):
922924
return cls.task_status_update_event(response.status_update)
923-
return cls.task_artifact_update_event(response.artifact_update)
925+
if response.HasField('artifact_update'):
926+
return cls.task_artifact_update_event(response.artifact_update)
927+
raise ValueError('Unsupported StreamResponse type')
924928

925929
@classmethod
926930
def skill(cls, skill: a2a_pb2.AgentSkill) -> types.AgentSkill:
@@ -943,25 +947,3 @@ def role(cls, role: a2a_pb2.Role) -> types.Role:
943947
return types.Role.agent
944948
case _:
945949
return types.Role.agent
946-
947-
948-
def dict_to_struct(dictionary: dict[str, Any]) -> struct_pb2.Struct:
949-
"""Converts a Python dict to a Struct proto.
950-
951-
Unfortunately, using `json_format.ParseDict` does not work because this
952-
wants the dictionary to be an exact match of the Struct proto with fields
953-
and keys and values, not the traditional Python dict structure.
954-
955-
Args:
956-
dictionary: The Python dict to convert.
957-
958-
Returns:
959-
The Struct proto.
960-
"""
961-
struct = struct_pb2.Struct()
962-
for key, val in dictionary.items():
963-
if isinstance(val, dict):
964-
struct[key] = dict_to_struct(val)
965-
else:
966-
struct[key] = val
967-
return struct

0 commit comments

Comments
 (0)