Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions scripts/generate_types.sh
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ uv run datamodel-codegen \
--no-alias

echo "Formatting generated file with ruff..."
uv run ruff check --fix-only "$GENERATED_FILE"
uv run ruff format "$GENERATED_FILE"

echo "Codegen finished successfully."
43 changes: 22 additions & 21 deletions src/a2a/types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# generated by datamodel-codegen:
# filename: https://raw.githubusercontent.com/a2aproject/A2A/refs/heads/main/specification/json/a2a.json
# filename: https://raw.githubusercontent.com/a2aproject/A2A/refs/heads/uuid-fields/specification/json/a2a.json

from __future__ import annotations

from enum import Enum
from typing import Any, Literal
from uuid import UUID

Check failure on line 8 in src/a2a/types.py

View workflow job for this annotation

GitHub Actions / Lint Code Base

Ruff (TC003)

src/a2a/types.py:8:18: TC003 Move standard library import `uuid.UUID` into a type-checking block

from pydantic import Field, RootModel

Expand Down Expand Up @@ -293,15 +294,15 @@
Defines parameters for deleting a specific push notification configuration for a task.
"""

id: str
id: UUID
"""
The unique identifier of the task.
"""
metadata: dict[str, Any] | None = None
"""
Optional metadata associated with the request.
"""
push_notification_config_id: str
push_notification_config_id: UUID
"""
The ID of the push notification configuration to delete.
"""
Expand Down Expand Up @@ -430,15 +431,15 @@
Defines parameters for fetching a specific push notification configuration for a task.
"""

id: str
id: UUID
"""
The unique identifier of the task.
"""
metadata: dict[str, Any] | None = None
"""
Optional metadata associated with the request.
"""
push_notification_config_id: str | None = None
push_notification_config_id: UUID | None = None
"""
The ID of the push notification configuration to retrieve.
"""
Expand Down Expand Up @@ -675,7 +676,7 @@
Defines parameters for listing all push notification configurations associated with a task.
"""

id: str
id: UUID
"""
The unique identifier of the task.
"""
Expand Down Expand Up @@ -828,7 +829,7 @@
"""
Optional authentication details for the agent to use when calling the notification URL.
"""
id: str | None = None
id: UUID | None = None
"""
A unique ID for the push notification configuration, set by the client
to support multiple notification callbacks.
Expand Down Expand Up @@ -879,7 +880,7 @@
Defines parameters containing a task ID, used for simple task operations.
"""

id: str
id: UUID
"""
The unique identifier of the task.
"""
Expand Down Expand Up @@ -938,7 +939,7 @@
"""
The push notification configuration for this task.
"""
task_id: str
task_id: UUID
"""
The ID of the task.
"""
Expand All @@ -953,7 +954,7 @@
"""
The number of most recent messages from the task's history to retrieve.
"""
id: str
id: UUID
"""
The unique identifier of the task.
"""
Expand Down Expand Up @@ -1374,7 +1375,7 @@
Represents a file, data structure, or other resource generated by an agent during a task.
"""

artifact_id: str
artifact_id: UUID
"""
A unique identifier for the artifact within the scope of the task.
"""
Expand Down Expand Up @@ -1438,7 +1439,7 @@
Represents a single message in the conversation between a user and an agent.
"""

context_id: str | None = None
context_id: UUID | None = None
"""
The context identifier for this message, used to group related interactions.
"""
Expand All @@ -1450,7 +1451,7 @@
"""
The type of this object, used as a discriminator. Always 'message' for a Message.
"""
message_id: str
message_id: UUID
"""
A unique identifier for the message, typically a UUID, generated by the sender.
"""
Expand All @@ -1463,15 +1464,15 @@
An array of content parts that form the message body. A message can be
composed of multiple parts of different types (e.g., text and files).
"""
reference_task_ids: list[str] | None = None
reference_task_ids: list[UUID] | None = None
"""
A list of other task IDs that this message references for additional context.
"""
role: Role
"""
Identifies the sender of the message. `user` for the client, `agent` for the service.
"""
task_id: str | None = None
task_id: UUID | None = None
"""
The identifier of the task this message is part of. Can be omitted for the first message of a new task.
"""
Expand Down Expand Up @@ -1614,7 +1615,7 @@
"""
The artifact that was generated or updated.
"""
context_id: str
context_id: UUID
"""
The context ID associated with the task.
"""
Expand All @@ -1630,7 +1631,7 @@
"""
Optional metadata for extensions.
"""
task_id: str
task_id: UUID
"""
The ID of the task this artifact belongs to.
"""
Expand Down Expand Up @@ -1663,7 +1664,7 @@
This is typically used in streaming or subscription models.
"""

context_id: str
context_id: UUID
"""
The context ID associated with the task.
"""
Expand All @@ -1683,7 +1684,7 @@
"""
The new status of the task.
"""
task_id: str
task_id: UUID
"""
The ID of the task that was updated.
"""
Expand Down Expand Up @@ -1861,15 +1862,15 @@
"""
A collection of artifacts generated by the agent during the execution of the task.
"""
context_id: str
context_id: UUID
"""
A server-generated identifier for maintaining context across multiple related tasks or interactions.
"""
history: list[Message] | None = None
"""
An array of messages exchanged during the task, representing the conversation history.
"""
id: str
id: UUID
"""
A unique identifier for the task, generated by the server for a new task.
"""
Expand Down
22 changes: 11 additions & 11 deletions src/a2a/utils/proto_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ def message(cls, message: types.Message | None) -> a2a_pb2.Message | None:
if message is None:
return None
return a2a_pb2.Message(
message_id=message.message_id,
message_id=str(message.message_id),
content=[ToProto.part(p) for p in message.parts],
context_id=message.context_id or '',
task_id=message.task_id or '',
context_id=str(message.context_id) if message.context_id else None,
task_id=str(message.task_id) if message.task_id else None,
role=cls.role(message.role),
metadata=ToProto.metadata(message.metadata),
)
Expand Down Expand Up @@ -86,8 +86,8 @@ def file(
@classmethod
def task(cls, task: types.Task) -> a2a_pb2.Task:
return a2a_pb2.Task(
id=task.id,
context_id=task.context_id,
id=str(task.id),
context_id=str(task.context_id),
status=ToProto.task_status(task.status),
artifacts=(
[ToProto.artifact(a) for a in task.artifacts]
Expand Down Expand Up @@ -129,7 +129,7 @@ def task_state(cls, state: types.TaskState) -> a2a_pb2.TaskState:
@classmethod
def artifact(cls, artifact: types.Artifact) -> a2a_pb2.Artifact:
return a2a_pb2.Artifact(
artifact_id=artifact.artifact_id,
artifact_id=str(artifact.artifact_id),
description=artifact.description,
metadata=ToProto.metadata(artifact.metadata),
name=artifact.name,
Expand All @@ -155,7 +155,7 @@ def push_notification_config(
else None
)
return a2a_pb2.PushNotificationConfig(
id=config.id or '',
id=str(config.id) if config.id else None,
url=config.url,
token=config.token,
authentication=auth_info,
Expand All @@ -166,8 +166,8 @@ def task_artifact_update_event(
cls, event: types.TaskArtifactUpdateEvent
) -> a2a_pb2.TaskArtifactUpdateEvent:
return a2a_pb2.TaskArtifactUpdateEvent(
task_id=event.task_id,
context_id=event.context_id,
task_id=str(event.task_id),
context_id=str(event.context_id),
artifact=ToProto.artifact(event.artifact),
metadata=ToProto.metadata(event.metadata),
append=event.append or False,
Expand All @@ -179,8 +179,8 @@ def task_status_update_event(
cls, event: types.TaskStatusUpdateEvent
) -> a2a_pb2.TaskStatusUpdateEvent:
return a2a_pb2.TaskStatusUpdateEvent(
task_id=event.task_id,
context_id=event.context_id,
task_id=str(event.task_id),
context_id=str(event.context_id),
status=ToProto.task_status(event.status),
metadata=ToProto.metadata(event.metadata),
final=event.final,
Expand Down
16 changes: 2 additions & 14 deletions src/a2a/utils/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,10 @@ def new_task(request: Message) -> Task:
if isinstance(part.root, TextPart) and not part.root.text:
raise ValueError('TextPart content cannot be empty')

context_id_str = request.context_id
if context_id_str is not None:
try:
uuid.UUID(context_id_str)
context_id = context_id_str
except (ValueError, AttributeError, TypeError) as e:
raise ValueError(
f"Invalid context_id: '{context_id_str}' is not a valid UUID."
) from e
else:
context_id = str(uuid.uuid4())

return Task(
status=TaskStatus(state=TaskState.submitted),
id=(request.task_id if request.task_id else str(uuid.uuid4())),
context_id=context_id,
id=request.task_id or uuid.uuid4(),
context_id=request.context_id or uuid.uuid4(),
history=[request],
)

Expand Down
4 changes: 2 additions & 2 deletions tests/client/test_auth_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def build_success_response(request: httpx.Request) -> httpx.Response:
jsonrpc='2.0',
result=Message(
kind='message',
message_id='message-id',
message_id='c222a603-645e-4c37-8f7b-e49f3ea80e9e',
role=Role.agent,
parts=[],
),
Expand All @@ -75,7 +75,7 @@ def build_success_response(request: httpx.Request) -> httpx.Response:
def build_message() -> Message:
"""Builds a minimal Message."""
return Message(
message_id='msg1',
message_id='87c8541d-f773-4825-bbb1-f518727231f2',
role=Role.user,
parts=[],
)
Expand Down
20 changes: 10 additions & 10 deletions tests/client/test_base_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def sample_agent_card():
def sample_message():
return Message(
role=Role.user,
message_id='msg-1',
message_id='15957e91-63e6-40ac-8205-1d1ffb09a5b2',
parts=[Part(root=TextPart(text='Hello'))],
)

Expand All @@ -65,8 +65,8 @@ async def test_send_message_streaming(
):
async def create_stream(*args, **kwargs):
yield Task(
id='task-123',
context_id='ctx-456',
id='536ab032-6915-47d1-9909-4172dbee4aa0',
context_id='9f18b6e9-63c4-4d44-a8b8-f4648003b6b8',
status=TaskStatus(state=TaskState.completed),
)

Expand All @@ -77,7 +77,7 @@ async def create_stream(*args, **kwargs):
mock_transport.send_message_streaming.assert_called_once()
assert not mock_transport.send_message.called
assert len(events) == 1
assert events[0][0].id == 'task-123'
assert str(events[0][0].id) == '536ab032-6915-47d1-9909-4172dbee4aa0'


@pytest.mark.asyncio
Expand All @@ -86,8 +86,8 @@ async def test_send_message_non_streaming(
):
base_client._config.streaming = False
mock_transport.send_message.return_value = Task(
id='task-456',
context_id='ctx-789',
id='9368e3b5-c796-46cf-9318-6c73e1a37e58',
context_id='0a934875-fa22-4af0-8b40-79b13d46e4a6',
status=TaskStatus(state=TaskState.completed),
)

Expand All @@ -96,7 +96,7 @@ async def test_send_message_non_streaming(
mock_transport.send_message.assert_called_once()
assert not mock_transport.send_message_streaming.called
assert len(events) == 1
assert events[0][0].id == 'task-456'
assert str(events[0][0].id) == '9368e3b5-c796-46cf-9318-6c73e1a37e58'


@pytest.mark.asyncio
Expand All @@ -105,8 +105,8 @@ async def test_send_message_non_streaming_agent_capability_false(
):
base_client._card.capabilities.streaming = False
mock_transport.send_message.return_value = Task(
id='task-789',
context_id='ctx-101',
id='d7541723-0796-4231-8849-f6f137ea3bf8',
context_id='dab80cd1-224d-47cd-abd8-cc53101fb273',
status=TaskStatus(state=TaskState.completed),
)

Expand All @@ -115,4 +115,4 @@ async def test_send_message_non_streaming_agent_capability_false(
mock_transport.send_message.assert_called_once()
assert not mock_transport.send_message_streaming.called
assert len(events) == 1
assert events[0][0].id == 'task-789'
assert str(events[0][0].id) == 'd7541723-0796-4231-8849-f6f137ea3bf8'
Loading
Loading