From bcd51e36441b2b87d07881c5359bdf2eb82288c8 Mon Sep 17 00:00:00 2001 From: Krishna Thota Date: Tue, 20 May 2025 13:58:50 -0700 Subject: [PATCH 1/8] fix: Throw exception for task_id mismatches --- examples/langgraph/agent_executor.py | 6 ++-- .../default_request_handler.py | 34 ++++++++++++------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/examples/langgraph/agent_executor.py b/examples/langgraph/agent_executor.py index add3c045..4bac9d3b 100644 --- a/examples/langgraph/agent_executor.py +++ b/examples/langgraph/agent_executor.py @@ -1,5 +1,5 @@ -from agent import CurrencyAgent # type: ignore[import-untyped] -from typing_extensions import override +from agent import CurrencyAgent # type: ignore[import-untyped] + from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events.event_queue import EventQueue from a2a.types import ( @@ -17,7 +17,6 @@ class CurrencyAgentExecutor(AgentExecutor): def __init__(self): self.agent = CurrencyAgent() - @override async def execute( self, context: RequestContext, @@ -89,7 +88,6 @@ async def execute( ) ) - @override async def cancel( self, context: RequestContext, event_queue: EventQueue ) -> None: diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index eb8de0a5..fa6d7ea8 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -16,7 +16,6 @@ EventQueue, InMemoryQueueManager, QueueManager, - TaskQueueExists, ) from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.tasks import ( @@ -212,6 +211,15 @@ async def on_message_send( ) = await result_aggregator.consume_and_break_on_interrupt(consumer) if not result: raise ServerError(error=InternalError()) + + if isinstance(result, Task) and task_id != result.id: + logger.error( + f'Agent generated task_id={result.id} does not match the RequestContext task_id={task_id}.' + ) + raise ServerError( + InternalError(message='Task ID mismatch in agent response') + ) + finally: if interrupted: # TODO: Track this disconnected cleanup task. @@ -278,27 +286,27 @@ async def on_message_send_stream( consumer = EventConsumer(queue) producer_task.add_done_callback(consumer.agent_task_callback) async for event in result_aggregator.consume_and_emit(consumer): - if isinstance(event, Task) and task_id != event.id: - logger.warning( - f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.' - ) - try: - created_task: Task = event - await self._queue_manager.add(created_task.id, queue) - task_id = created_task.id - except TaskQueueExists: - logging.info( - 'Multiple Task objects created in event stream.' + if isinstance(event, Task): + if task_id != event.id: + logger.error( + f'Agent generated task_id={event.id} does not match the RequestContext task_id={task_id}.' ) + raise ServerError( + InternalError( + message='Task ID mismatch in agent response' + ) + ) + if ( self._push_notifier and params.configuration and params.configuration.pushNotificationConfig ): await self._push_notifier.set_info( - created_task.id, + task_id, params.configuration.pushNotificationConfig, ) + if self._push_notifier and task_id: latest_task = await result_aggregator.current_result if isinstance(latest_task, Task): From 73dbc3f37839e7bbdf0602ab0bf84fa4c0d566a2 Mon Sep 17 00:00:00 2001 From: Krishna Thota Date: Tue, 20 May 2025 15:25:06 -0700 Subject: [PATCH 2/8] Add tests --- .../request_handlers/test_jsonrpc_handler.py | 110 +++++++++++++++++- 1 file changed, 105 insertions(+), 5 deletions(-) diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index 7e36aeb9..ee3843e2 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -8,7 +8,7 @@ import pytest import httpx -from a2a.server.agent_execution import AgentExecutor +from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.events import ( QueueManager, ) @@ -55,6 +55,7 @@ TaskStatusUpdateEvent, TextPart, UnsupportedOperationError, + InternalError, ) from a2a.utils.errors import ServerError @@ -183,7 +184,12 @@ async def test_on_cancel_task_not_found(self) -> None: mock_task_store.get.assert_called_once_with('nonexistent_id') mock_agent_executor.cancel.assert_not_called() - async def test_on_message_new_message_success(self) -> None: + @patch( + 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build' + ) + async def test_on_message_new_message_success( + self, _mock_builder_build: AsyncMock + ) -> None: mock_agent_executor = AsyncMock(spec=AgentExecutor) mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandler( @@ -194,6 +200,14 @@ async def test_on_message_new_message_success(self) -> None: mock_task_store.get.return_value = mock_task mock_agent_executor.execute.return_value = None + _mock_builder_build.return_value = RequestContext( + request=MagicMock(), + task_id='task_123', + context_id='session-xyz', + task=None, + related_tasks=None, + ) + async def streaming_coro(): yield mock_task @@ -279,15 +293,28 @@ async def streaming_coro(): assert response.root.error == UnsupportedOperationError() # type: ignore mock_agent_executor.execute.assert_called_once() - async def test_on_message_stream_new_message_success(self) -> None: + @patch( + 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build' + ) + async def test_on_message_stream_new_message_success( + self, _mock_builder_build: AsyncMock + ) -> None: mock_agent_executor = AsyncMock(spec=AgentExecutor) mock_task_store = AsyncMock(spec=TaskStore) request_handler = DefaultRequestHandler( mock_agent_executor, mock_task_store ) - self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) + self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) handler = JSONRPCHandler(self.mock_agent_card, request_handler) + _mock_builder_build.return_value = RequestContext( + request=MagicMock(), + task_id='task_123', + context_id='session-xyz', + task=None, + related_tasks=None, + ) + events: list[Any] = [ Task(**MINIMAL_TASK), TaskArtifactUpdateEvent( @@ -462,8 +489,11 @@ async def test_get_push_notification_success(self) -> None: ) assert get_response.root.result == task_push_config # type: ignore + @patch( + 'a2a.server.agent_execution.simple_request_context_builder.SimpleRequestContextBuilder.build' + ) async def test_on_message_stream_new_message_send_push_notification_success( - self, + self, _mock_builder_build: AsyncMock ) -> None: mock_agent_executor = AsyncMock(spec=AgentExecutor) mock_task_store = AsyncMock(spec=TaskStore) @@ -475,6 +505,13 @@ async def test_on_message_stream_new_message_send_push_notification_success( self.mock_agent_card.capabilities = AgentCapabilities( streaming=True, pushNotifications=True ) + _mock_builder_build.return_value = RequestContext( + request=MagicMock(), + task_id='task_123', + context_id='session-xyz', + task=None, + related_tasks=None, + ) handler = JSONRPCHandler(self.mock_agent_card, request_handler) events: list[Any] = [ @@ -642,3 +679,66 @@ async def test_on_resubscribe_no_existing_task_error(self) -> None: assert len(collected_events) == 1 self.assertIsInstance(collected_events[0].root, JSONRPCErrorResponse) assert collected_events[0].root.error == TaskNotFoundError() + + async def test_on_message_send_task_id_mismatch(self) -> None: + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + mock_task = Task(**MINIMAL_TASK) + mock_task_store.get.return_value = mock_task + mock_agent_executor.execute.return_value = None + + async def streaming_coro(): + yield mock_task + + with patch( + 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', + return_value=streaming_coro(), + ): + request = SendMessageRequest( + id='1', + params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)), + ) + response = await handler.on_message_send(request) + assert mock_agent_executor.execute.call_count == 1 + self.assertIsInstance(response.root, JSONRPCErrorResponse) + self.assertIsInstance(response.root.error, InternalError) # type: ignore + + async def test_on_message_stream_task_id_mismatch(self) -> None: + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + + self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + events: list[Any] = [Task(**MINIMAL_TASK)] + + async def streaming_coro(): + for event in events: + yield event + + with patch( + 'a2a.server.request_handlers.default_request_handler.EventConsumer.consume_all', + return_value=streaming_coro(), + ): + mock_task_store.get.return_value = None + mock_agent_executor.execute.return_value = None + request = SendStreamingMessageRequest( + id='1', + params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)), + ) + response = handler.on_message_send_stream(request) + assert isinstance(response, AsyncGenerator) + collected_events: list[Any] = [] + async for event in response: + collected_events.append(event) + assert len(collected_events) == 1 + self.assertIsInstance( + collected_events[0].root, JSONRPCErrorResponse + ) + self.assertIsInstance(collected_events[0].root.error, InternalError) From ee0cc76d8765ee86e58a54a40e9d744afbf09d99 Mon Sep 17 00:00:00 2001 From: Holt Skinner <13262395+holtskinner@users.noreply.github.com> Date: Tue, 20 May 2025 13:24:10 -0700 Subject: [PATCH 3/8] ci: Create GitHub Action to generate `types.py` from specification JSON (#67) --- .github/actions/spelling/allow.txt | 1 + .github/release-please.yml | 1 + .github/workflows/update-a2a-types.yml | 86 ++++++++++++++++++++++++++ development.md | 18 +++++- 4 files changed, 103 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/update-a2a-types.yml diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index dd4dd5ec..19f28982 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -26,6 +26,7 @@ dunders genai gle inmemory +kwarg langgraph lifecycles linting diff --git a/.github/release-please.yml b/.github/release-please.yml index b0a050d4..8d4679d2 100644 --- a/.github/release-please.yml +++ b/.github/release-please.yml @@ -1,3 +1,4 @@ releaseType: python handleGHRelease: true bumpMinorPreMajor: false +bumpPatchForMinorPreMajor: true diff --git a/.github/workflows/update-a2a-types.yml b/.github/workflows/update-a2a-types.yml new file mode 100644 index 00000000..56bf302a --- /dev/null +++ b/.github/workflows/update-a2a-types.yml @@ -0,0 +1,86 @@ +name: Update A2A Schema from Specification + +on: + schedule: + - cron: "0 0 * * *" + workflow_dispatch: + +jobs: + check_and_update: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.13" + + - name: Install uv + run: curl -LsSf https://astral.sh/uv/install.sh | sh + + - name: Configure uv shell + run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH + + - name: Install dependencies (datamodel-code-generator) + run: uv sync + + - name: Define output file variable + id: vars + run: | + GENERATED_FILE="./src/a2a/types.py" + echo "GENERATED_FILE=$GENERATED_FILE" >> "$GITHUB_OUTPUT" + + - name: Run datamodel-codegen + run: | + set -euo pipefail # Exit immediately if a command exits with a non-zero status + + REMOTE_URL="https://raw.githubusercontent.com/google/A2A/refs/heads/main/specification/json/a2a.json" + GENERATED_FILE="${{ steps.vars.outputs.GENERATED_FILE }}" + + echo "Running datamodel-codegen..." + uv run datamodel-codegen \ + --url "$REMOTE_URL" \ + --input-file-type jsonschema \ + --output "$GENERATED_FILE" \ + --target-python-version 3.10 \ + --output-model-type pydantic_v2.BaseModel \ + --disable-timestamp \ + --use-schema-description \ + --use-union-operator \ + --use-field-description \ + --use-default \ + --use-default-kwarg \ + --use-one-literal-as-default \ + --class-name A2A \ + --use-standard-collections + echo "Codegen finished." + + - name: Commit and push if generated file changed + if: github.ref == 'refs/heads/main' # Or your default branch name + run: | + set -euo pipefail + + GENERATED_FILE="${{ steps.vars.outputs.GENERATED_FILE }}" + + # Check if the generated file has any changes compared to HEAD + if git diff --quiet "$GENERATED_FILE"; then + echo "$GENERATED_FILE has no changes after codegen. Nothing to commit." + else + echo "Changes detected in $GENERATED_FILE. Committing..." + # Configure git user for the commit + git config user.name "github-actions" + git config user.email "github-actions@github.com" + + # Add the generated file + git add "$GENERATED_FILE" + + # Commit changes + git commit -m "🤖 chore: Auto-update A2A schema from specification" + + # Push changes + git push + echo "Changes committed and pushed." + fi diff --git a/development.md b/development.md index 0d9ef29c..7f42a75c 100644 --- a/development.md +++ b/development.md @@ -2,8 +2,20 @@ ## Type generation from spec - - ```bash -uv run datamodel-codegen --input ./spec.json --input-file-type jsonschema --output ./src/a2a/types.py --target-python-version 3.10 --output-model-type pydantic_v2.BaseModel --disable-timestamp --use-schema-description --use-union-operator --use-field-description --use-default --use-default-kwarg --use-one-literal-as-default --class-name A2A --use-standard-collections +uv run datamodel-codegen \ + --url https://raw.githubusercontent.com/google/A2A/refs/heads/main/specification/json/a2a.json \ + --input-file-type jsonschema \ + --output ./src/a2a/types.py \ + --target-python-version 3.10 \ + --output-model-type pydantic_v2.BaseModel \ + --disable-timestamp \ + --use-schema-description \ + --use-union-operator \ + --use-field-description \ + --use-default \ + --use-default-kwarg \ + --use-one-literal-as-default \ + --class-name A2A \ + --use-standard-collections ``` From c522c0ff678ed38019a41adbd9c8c1c1f49cf748 Mon Sep 17 00:00:00 2001 From: holtskinner Date: Tue, 20 May 2025 14:05:50 -0700 Subject: [PATCH 4/8] ci: Remove update-a2a-types.yml workflow --- .github/workflows/update-a2a-types.yml | 86 -------------------------- 1 file changed, 86 deletions(-) delete mode 100644 .github/workflows/update-a2a-types.yml diff --git a/.github/workflows/update-a2a-types.yml b/.github/workflows/update-a2a-types.yml deleted file mode 100644 index 56bf302a..00000000 --- a/.github/workflows/update-a2a-types.yml +++ /dev/null @@ -1,86 +0,0 @@ -name: Update A2A Schema from Specification - -on: - schedule: - - cron: "0 0 * * *" - workflow_dispatch: - -jobs: - check_and_update: - runs-on: ubuntu-latest - - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Set up Python - uses: actions/setup-python@v5 - with: - python-version: "3.13" - - - name: Install uv - run: curl -LsSf https://astral.sh/uv/install.sh | sh - - - name: Configure uv shell - run: echo "$HOME/.cargo/bin" >> $GITHUB_PATH - - - name: Install dependencies (datamodel-code-generator) - run: uv sync - - - name: Define output file variable - id: vars - run: | - GENERATED_FILE="./src/a2a/types.py" - echo "GENERATED_FILE=$GENERATED_FILE" >> "$GITHUB_OUTPUT" - - - name: Run datamodel-codegen - run: | - set -euo pipefail # Exit immediately if a command exits with a non-zero status - - REMOTE_URL="https://raw.githubusercontent.com/google/A2A/refs/heads/main/specification/json/a2a.json" - GENERATED_FILE="${{ steps.vars.outputs.GENERATED_FILE }}" - - echo "Running datamodel-codegen..." - uv run datamodel-codegen \ - --url "$REMOTE_URL" \ - --input-file-type jsonschema \ - --output "$GENERATED_FILE" \ - --target-python-version 3.10 \ - --output-model-type pydantic_v2.BaseModel \ - --disable-timestamp \ - --use-schema-description \ - --use-union-operator \ - --use-field-description \ - --use-default \ - --use-default-kwarg \ - --use-one-literal-as-default \ - --class-name A2A \ - --use-standard-collections - echo "Codegen finished." - - - name: Commit and push if generated file changed - if: github.ref == 'refs/heads/main' # Or your default branch name - run: | - set -euo pipefail - - GENERATED_FILE="${{ steps.vars.outputs.GENERATED_FILE }}" - - # Check if the generated file has any changes compared to HEAD - if git diff --quiet "$GENERATED_FILE"; then - echo "$GENERATED_FILE has no changes after codegen. Nothing to commit." - else - echo "Changes detected in $GENERATED_FILE. Committing..." - # Configure git user for the commit - git config user.name "github-actions" - git config user.email "github-actions@github.com" - - # Add the generated file - git add "$GENERATED_FILE" - - # Commit changes - git commit -m "🤖 chore: Auto-update A2A schema from specification" - - # Push changes - git push - echo "Changes committed and pushed." - fi From 6fb0e50dea948752085524a11c6cf46277964c74 Mon Sep 17 00:00:00 2001 From: Holt Skinner <13262395+holtskinner@users.noreply.github.com> Date: Tue, 20 May 2025 14:09:24 -0700 Subject: [PATCH 5/8] chore: Regenerate types.py from spec (#71) --- src/a2a/types.py | 1286 +++++++++++++++++++++++++++++++++++----------- 1 file changed, 983 insertions(+), 303 deletions(-) diff --git a/src/a2a/types.py b/src/a2a/types.py index 675ed832..52731b1e 100644 --- a/src/a2a/types.py +++ b/src/a2a/types.py @@ -1,7 +1,5 @@ # generated by datamodel-codegen: -# filename: spec.json - -"""Data models representing the A2A protocol.""" +# filename: https://raw.githubusercontent.com/google/A2A/refs/heads/main/specification/json/a2a.json from __future__ import annotations @@ -12,13 +10,13 @@ class A2A(RootModel[Any]): - """Root model for the A2A specification.""" - root: Any class In(Enum): - """The location of the API key. Valid values are "query", "header", or "cookie".""" + """ + The location of the API key. Valid values are "query", "header", or "cookie". + """ cookie = 'cookie' header = 'header' @@ -26,377 +24,719 @@ class In(Enum): class APIKeySecurityScheme(BaseModel): - """API Key security scheme.""" + """ + API Key security scheme. + """ description: str | None = None - """Description of this security scheme.""" + """ + Description of this security scheme. + """ in_: In = Field(..., alias='in') - """The location of the API key. Valid values are "query", "header", or "cookie".""" + """ + The location of the API key. Valid values are "query", "header", or "cookie". + """ name: str - """The name of the header, query or cookie parameter to be used.""" + """ + The name of the header, query or cookie parameter to be used. + """ type: Literal['apiKey'] = 'apiKey' class AgentCapabilities(BaseModel): - """Defines optional capabilities supported by an agent.""" + """ + Defines optional capabilities supported by an agent. + """ pushNotifications: bool | None = None - """True if the agent can notify updates to client.""" + """ + true if the agent can notify updates to client. + """ stateTransitionHistory: bool | None = None - """True if the agent exposes status change history for tasks.""" + """ + true if the agent exposes status change history for tasks. + """ streaming: bool | None = None - """True if the agent supports SSE.""" + """ + true if the agent supports SSE. + """ class AgentProvider(BaseModel): - """Represents the service provider of an agent.""" + """ + Represents the service provider of an agent. + """ organization: str - """Agent provider's organization name.""" + """ + Agent provider's organization name. + """ url: str - """Agent provider's url.""" + """ + Agent provider's URL. + """ class AgentSkill(BaseModel): - """Represents a unit of capability that an agent can perform.""" + """ + Represents a unit of capability that an agent can perform. + """ description: str - """Description of the skill - will be used by the client or a human as a hint to understand what the skill does.""" + """ + Description of the skill - will be used by the client or a human + as a hint to understand what the skill does. + """ examples: list[str] | None = None - """The set of example scenarios that the skill can perform. Will be used by the client as a hint to understand how the skill can be used.""" + """ + The set of example scenarios that the skill can perform. + Will be used by the client as a hint to understand how the skill can be + used. + """ id: str - """Unique identifier for the agent's skill.""" + """ + Unique identifier for the agent's skill. + """ inputModes: list[str] | None = None - """The set of interaction modes that the skill supports (if different than the default). Supported mime types for input.""" + """ + The set of interaction modes that the skill supports + (if different than the default). + Supported mime types for input. + """ name: str - """Human readable name of the skill.""" + """ + Human readable name of the skill. + """ outputModes: list[str] | None = None - """Supported mime types for output.""" + """ + Supported mime types for output. + """ tags: list[str] - """Set of tagwords describing classes of capabilities for this specific skill.""" + """ + Set of tagwords describing classes of capabilities for this specific + skill. + """ class AuthorizationCodeOAuthFlow(BaseModel): - """Configuration details for a supported OAuth Flow.""" + """ + Configuration details for a supported OAuth Flow + """ authorizationUrl: str - """The authorization URL to be used for this flow. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS.""" + """ + The authorization URL to be used for this flow. This MUST be in the form of a URL. The OAuth2 + standard requires the use of TLS + """ refreshUrl: str | None = None - """The URL to be used for obtaining refresh tokens. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS.""" + """ + The URL to be used for obtaining refresh tokens. This MUST be in the form of a URL. The OAuth2 + standard requires the use of TLS. + """ scopes: dict[str, str] - """The available scopes for the OAuth2 security scheme. A map between the scope name and a short description for it. The map MAY be empty.""" + """ + The available scopes for the OAuth2 security scheme. A map between the scope name and a short + description for it. The map MAY be empty. + """ tokenUrl: str - """The token URL to be used for this flow. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS.""" + """ + The token URL to be used for this flow. This MUST be in the form of a URL. The OAuth2 standard + requires the use of TLS. + """ class ClientCredentialsOAuthFlow(BaseModel): - """Configuration details for a supported OAuth Flow.""" + """ + Configuration details for a supported OAuth Flow + """ refreshUrl: str | None = None - """The URL to be used for obtaining refresh tokens. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS.""" + """ + The URL to be used for obtaining refresh tokens. This MUST be in the form of a URL. The OAuth2 + standard requires the use of TLS. + """ scopes: dict[str, str] - """The available scopes for the OAuth2 security scheme. A map between the scope name and a short description for it. The map MAY be empty.""" + """ + The available scopes for the OAuth2 security scheme. A map between the scope name and a short + description for it. The map MAY be empty. + """ tokenUrl: str - """The token URL to be used for this flow. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS.""" + """ + The token URL to be used for this flow. This MUST be in the form of a URL. The OAuth2 standard + requires the use of TLS. + """ class ContentTypeNotSupportedError(BaseModel): - """A2A specific error indicating incompatible content types between request and agent capabilities.""" + """ + A2A specific error indicating incompatible content types between request and agent capabilities. + """ code: Literal[-32005] = -32005 - """A Number that indicates the error type that occurred.""" + """ + A Number that indicates the error type that occurred. + """ data: Any | None = None - """A Primitive or Structured value that contains additional information about the error. This may be omitted.""" + """ + A Primitive or Structured value that contains additional information about the error. + This may be omitted. + """ message: str | None = 'Incompatible content types' - """A String providing a short description of the error.""" + """ + A String providing a short description of the error. + """ class DataPart(BaseModel): - """Represents a structured data segment within a message part.""" + """ + Represents a structured data segment within a message part. + """ data: dict[str, Any] - """Structured data content.""" + """ + Structured data content + """ kind: Literal['data'] = 'data' - """Part type - data for DataParts.""" + """ + Part type - data for DataParts + """ metadata: dict[str, Any] | None = None - """Optional metadata associated with the part.""" + """ + Optional metadata associated with the part. + """ class FileBase(BaseModel): - """Represents the base entity for FileParts.""" + """ + Represents the base entity for FileParts + """ mimeType: str | None = None - """Optional mimeType for the file.""" + """ + Optional mimeType for the file + """ name: str | None = None - """Optional name for the file.""" + """ + Optional name for the file + """ -class FileWithBytes(FileBase): - """Define the variant where 'bytes' is present and 'uri' is absent.""" +class FileWithBytes(BaseModel): + """ + Define the variant where 'bytes' is present and 'uri' is absent + """ bytes: str - """Base64 encoded content of the file.""" + """ + base64 encoded content of the file + """ + mimeType: str | None = None + """ + Optional mimeType for the file + """ + name: str | None = None + """ + Optional name for the file + """ -class FileWithUri(FileBase): - """Define the variant where 'uri' is present and 'bytes' is absent.""" +class FileWithUri(BaseModel): + """ + Define the variant where 'uri' is present and 'bytes' is absent + """ + mimeType: str | None = None + """ + Optional mimeType for the file + """ + name: str | None = None + """ + Optional name for the file + """ uri: str + """ + URL for the File content + """ class HTTPAuthSecurityScheme(BaseModel): - """HTTP Authentication security scheme.""" + """ + HTTP Authentication security scheme. + """ bearerFormat: str | None = None - """A hint to the client to identify how the bearer token is formatted. Bearer tokens are usually generated by an authorization server, so this information is primarily for documentation purposes.""" + """ + A hint to the client to identify how the bearer token is formatted. Bearer tokens are usually + generated by an authorization server, so this information is primarily for documentation + purposes. + """ description: str | None = None - """Description of this security scheme.""" + """ + Description of this security scheme. + """ scheme: str - """The name of the HTTP Authentication scheme to be used in the Authorization header as defined in RFC7235. The values used SHOULD be registered in the IANA Authentication Scheme registry. The value is case-insensitive, as defined in RFC7235.""" + """ + The name of the HTTP Authentication scheme to be used in the Authorization header as defined + in RFC7235. The values used SHOULD be registered in the IANA Authentication Scheme registry. + The value is case-insensitive, as defined in RFC7235. + """ type: Literal['http'] = 'http' class ImplicitOAuthFlow(BaseModel): - """Configuration details for a supported OAuth Flow.""" + """ + Configuration details for a supported OAuth Flow + """ authorizationUrl: str - """The authorization URL to be used for this flow. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS.""" + """ + The authorization URL to be used for this flow. This MUST be in the form of a URL. The OAuth2 + standard requires the use of TLS + """ refreshUrl: str | None = None - """The URL to be used for obtaining refresh tokens. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS.""" + """ + The URL to be used for obtaining refresh tokens. This MUST be in the form of a URL. The OAuth2 + standard requires the use of TLS. + """ scopes: dict[str, str] - """The available scopes for the OAuth2 security scheme. A map between the scope name and a short description for it. The map MAY be empty.""" + """ + The available scopes for the OAuth2 security scheme. A map between the scope name and a short + description for it. The map MAY be empty. + """ class InternalError(BaseModel): - """JSON-RPC error indicating an internal JSON-RPC error on the server.""" + """ + JSON-RPC error indicating an internal JSON-RPC error on the server. + """ code: Literal[-32603] = -32603 - """A Number that indicates the error type that occurred.""" + """ + A Number that indicates the error type that occurred. + """ data: Any | None = None - """A Primitive or Structured value that contains additional information about the error. This may be omitted.""" + """ + A Primitive or Structured value that contains additional information about the error. + This may be omitted. + """ message: str | None = 'Internal error' - """A String providing a short description of the error.""" + """ + A String providing a short description of the error. + """ class InvalidAgentResponseError(BaseModel): - """A2A specific error indicating agent returned invalid response for the current method.""" + """ + A2A specific error indicating agent returned invalid response for the current method + """ code: Literal[-32006] = -32006 - """A Number that indicates the error type that occurred.""" + """ + A Number that indicates the error type that occurred. + """ data: Any | None = None - """A Primitive or Structured value that contains additional information about the error. This may be omitted.""" + """ + A Primitive or Structured value that contains additional information about the error. + This may be omitted. + """ message: str | None = 'Invalid agent response' - """A String providing a short description of the error.""" + """ + A String providing a short description of the error. + """ class InvalidParamsError(BaseModel): - """JSON-RPC error indicating invalid method parameter(s).""" + """ + JSON-RPC error indicating invalid method parameter(s). + """ code: Literal[-32602] = -32602 - """A Number that indicates the error type that occurred.""" + """ + A Number that indicates the error type that occurred. + """ data: Any | None = None - """A Primitive or Structured value that contains additional information about the error. This may be omitted.""" + """ + A Primitive or Structured value that contains additional information about the error. + This may be omitted. + """ message: str | None = 'Invalid parameters' - """A String providing a short description of the error.""" + """ + A String providing a short description of the error. + """ class InvalidRequestError(BaseModel): - """JSON-RPC error indicating the JSON sent is not a valid Request object.""" + """ + JSON-RPC error indicating the JSON sent is not a valid Request object. + """ code: Literal[-32600] = -32600 - """A Number that indicates the error type that occurred.""" + """ + A Number that indicates the error type that occurred. + """ data: Any | None = None - """A Primitive or Structured value that contains additional information about the error. This may be omitted.""" + """ + A Primitive or Structured value that contains additional information about the error. + This may be omitted. + """ message: str | None = 'Request payload validation error' - """A String providing a short description of the error.""" + """ + A String providing a short description of the error. + """ class JSONParseError(BaseModel): - """JSON-RPC error indicating invalid JSON was received by the server.""" + """ + JSON-RPC error indicating invalid JSON was received by the server. + """ code: Literal[-32700] = -32700 - """A Number that indicates the error type that occurred.""" + """ + A Number that indicates the error type that occurred. + """ data: Any | None = None - """A Primitive or Structured value that contains additional information about the error. This may be omitted.""" + """ + A Primitive or Structured value that contains additional information about the error. + This may be omitted. + """ message: str | None = 'Invalid JSON payload' - """A String providing a short description of the error.""" + """ + A String providing a short description of the error. + """ class JSONRPCError(BaseModel): - """Represents a JSON-RPC 2.0 Error object.""" + """ + Represents a JSON-RPC 2.0 Error object. + This is typically included in a JSONRPCErrorResponse when an error occurs. + """ code: int - """A Number that indicates the error type that occurred.""" + """ + A Number that indicates the error type that occurred. + """ data: Any | None = None - """A Primitive or Structured value that contains additional information about the error. This may be omitted.""" + """ + A Primitive or Structured value that contains additional information about the error. + This may be omitted. + """ message: str - """A String providing a short description of the error.""" + """ + A String providing a short description of the error. + """ class JSONRPCMessage(BaseModel): - """Base interface for any JSON-RPC 2.0 request or response.""" + """ + Base interface for any JSON-RPC 2.0 request or response. + """ id: str | int | None = None - """An identifier established by the Client that MUST contain a String, Number. Numbers SHOULD NOT contain fractional parts.""" + """ + An identifier established by the Client that MUST contain a String, Number + Numbers SHOULD NOT contain fractional parts. + """ jsonrpc: Literal['2.0'] = '2.0' - """Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0".""" + """ + Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0". + """ -class JSONRPCRequest(JSONRPCMessage): - """Represents a JSON-RPC 2.0 Request object.""" +class JSONRPCRequest(BaseModel): + """ + Represents a JSON-RPC 2.0 Request object. + """ + id: str | int | None = None + """ + An identifier established by the Client that MUST contain a String, Number + Numbers SHOULD NOT contain fractional parts. + """ + jsonrpc: Literal['2.0'] = '2.0' + """ + Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0". + """ method: str - """A String containing the name of the method to be invoked.""" + """ + A String containing the name of the method to be invoked. + """ params: dict[str, Any] | None = None - """A Structured value that holds the parameter values to be used during the invocation of the method.""" + """ + A Structured value that holds the parameter values to be used during the invocation of the method. + """ -class JSONRPCResult(JSONRPCMessage): - """Represents a JSON-RPC 2.0 Result object.""" +class JSONRPCResult(BaseModel): + """ + Represents a JSON-RPC 2.0 Result object. + """ + id: str | int | None = None + """ + An identifier established by the Client that MUST contain a String, Number + Numbers SHOULD NOT contain fractional parts. + """ + jsonrpc: Literal['2.0'] = '2.0' + """ + Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0". + """ result: Any - """The result object on success.""" + """ + The result object on success + """ class Role(Enum): - """Message sender's role.""" + """ + Message sender's role + """ agent = 'agent' user = 'user' class MethodNotFoundError(BaseModel): - """JSON-RPC error indicating the method does not exist or is not available.""" + """ + JSON-RPC error indicating the method does not exist or is not available. + """ code: Literal[-32601] = -32601 - """A Number that indicates the error type that occurred.""" + """ + A Number that indicates the error type that occurred. + """ data: Any | None = None - """A Primitive or Structured value that contains additional information about the error. This may be omitted.""" + """ + A Primitive or Structured value that contains additional information about the error. + This may be omitted. + """ message: str | None = 'Method not found' - """A String providing a short description of the error.""" + """ + A String providing a short description of the error. + """ class OpenIdConnectSecurityScheme(BaseModel): - """OpenID Connect security scheme configuration.""" + """ + OpenID Connect security scheme configuration. + """ description: str | None = None - """Description of this security scheme.""" + """ + Description of this security scheme. + """ openIdConnectUrl: str - """Well-known URL to discover the [[OpenID-Connect-Discovery]] provider metadata.""" + """ + Well-known URL to discover the [[OpenID-Connect-Discovery]] provider metadata. + """ type: Literal['openIdConnect'] = 'openIdConnect' class PartBase(BaseModel): - """Base properties common to all message parts.""" + """ + Base properties common to all message parts. + """ metadata: dict[str, Any] | None = None - """Optional metadata associated with the part.""" + """ + Optional metadata associated with the part. + """ class PasswordOAuthFlow(BaseModel): - """Configuration details for a supported OAuth Flow.""" + """ + Configuration details for a supported OAuth Flow + """ refreshUrl: str | None = None - """The URL to be used for obtaining refresh tokens. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS.""" + """ + The URL to be used for obtaining refresh tokens. This MUST be in the form of a URL. The OAuth2 + standard requires the use of TLS. + """ scopes: dict[str, str] - """The available scopes for the OAuth2 security scheme. A map between the scope name and a short description for it. The map MAY be empty.""" + """ + The available scopes for the OAuth2 security scheme. A map between the scope name and a short + description for it. The map MAY be empty. + """ tokenUrl: str - """The token URL to be used for this flow. This MUST be in the form of a URL. The OAuth2 standard requires the use of TLS.""" + """ + The token URL to be used for this flow. This MUST be in the form of a URL. The OAuth2 standard + requires the use of TLS. + """ class PushNotificationAuthenticationInfo(BaseModel): - """Defines authentication details for push notifications.""" + """ + Defines authentication details for push notifications. + """ credentials: str | None = None - """Optional credentials.""" + """ + Optional credentials + """ schemes: list[str] - """Supported authentication schemes - e.g. Basic, Bearer.""" + """ + Supported authentication schemes - e.g. Basic, Bearer + """ class PushNotificationConfig(BaseModel): - """Configuration for setting up push notifications for task updates.""" + """ + Configuration for setting up push notifications for task updates. + """ authentication: PushNotificationAuthenticationInfo | None = None token: str | None = None - """Token unique to this task/session.""" + """ + Token unique to this task/session. + """ url: str - """URL for sending the push notifications.""" + """ + URL for sending the push notifications. + """ class PushNotificationNotSupportedError(BaseModel): - """A2A specific error indicating the agent does not support push notifications.""" + """ + A2A specific error indicating the agent does not support push notifications. + """ code: Literal[-32003] = -32003 - """A Number that indicates the error type that occurred.""" + """ + A Number that indicates the error type that occurred. + """ data: Any | None = None - """A Primitive or Structured value that contains additional information about the error. This may be omitted.""" + """ + A Primitive or Structured value that contains additional information about the error. + This may be omitted. + """ message: str | None = 'Push Notification is not supported' - """A String providing a short description of the error.""" + """ + A String providing a short description of the error. + """ class SecuritySchemeBase(BaseModel): - """Base properties shared by all security schemes.""" + """ + Base properties shared by all security schemes. + """ description: str | None = None - """Description of this security scheme.""" + """ + Description of this security scheme. + """ class TaskIdParams(BaseModel): - """Parameters containing only a task ID, used for simple task operations.""" + """ + Parameters containing only a task ID, used for simple task operations. + """ id: str - """Task id.""" + """ + Task id. + """ metadata: dict[str, Any] | None = None class TaskNotCancelableError(BaseModel): - """A2A specific error indicating the task is in a state where it cannot be canceled.""" + """ + A2A specific error indicating the task is in a state where it cannot be canceled. + """ code: Literal[-32002] = -32002 - """A Number that indicates the error type that occurred.""" + """ + A Number that indicates the error type that occurred. + """ data: Any | None = None - """A Primitive or Structured value that contains additional information about the error. This may be omitted.""" + """ + A Primitive or Structured value that contains additional information about the error. + This may be omitted. + """ message: str | None = 'Task cannot be canceled' - """A String providing a short description of the error.""" + """ + A String providing a short description of the error. + """ class TaskNotFoundError(BaseModel): - """A2A specific error indicating the requested task ID was not found.""" + """ + A2A specific error indicating the requested task ID was not found. + """ code: Literal[-32001] = -32001 - """A Number that indicates the error type that occurred.""" + """ + A Number that indicates the error type that occurred. + """ data: Any | None = None - """A Primitive or Structured value that contains additional information about the error. This may be omitted.""" + """ + A Primitive or Structured value that contains additional information about the error. + This may be omitted. + """ message: str | None = 'Task not found' - """A String providing a short description of the error.""" + """ + A String providing a short description of the error. + """ class TaskPushNotificationConfig(BaseModel): - """Parameters for setting or getting push notification configuration for a task.""" + """ + Parameters for setting or getting push notification configuration for a task + """ pushNotificationConfig: PushNotificationConfig + """ + Push notification configuration. + """ taskId: str - """Task id.""" + """ + Task id. + """ -class TaskQueryParams(TaskIdParams): - """Parameters for querying a task, including optional history length.""" +class TaskQueryParams(BaseModel): + """ + Parameters for querying a task, including optional history length. + """ historyLength: int | None = None - """Number of recent messages to be retrieved.""" + """ + Number of recent messages to be retrieved. + """ + id: str + """ + Task id. + """ + metadata: dict[str, Any] | None = None -class TaskResubscriptionRequest(JSONRPCRequest): - """JSON-RPC request model for the 'tasks/resubscribe' method.""" +class TaskResubscriptionRequest(BaseModel): + """ + JSON-RPC request model for the 'tasks/resubscribe' method. + """ + id: str | int | None = None + """ + An identifier established by the Client that MUST contain a String, Number + Numbers SHOULD NOT contain fractional parts. + """ + jsonrpc: Literal['2.0'] = '2.0' + """ + Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0". + """ method: Literal['tasks/resubscribe'] = 'tasks/resubscribe' - """A String containing the name of the method to be invoked.""" + """ + A String containing the name of the method to be invoked. + """ params: TaskIdParams - """A Structured value that holds the parameter values to be used during the invocation of the method.""" + """ + A Structured value that holds the parameter values to be used during the invocation of the method. + """ class TaskState(Enum): - """Represents the possible states of a Task.""" + """ + Represents the possible states of a Task. + """ submitted = 'submitted' working = 'working' @@ -409,24 +749,43 @@ class TaskState(Enum): unknown = 'unknown' -class TextPart(PartBase): - """Represents a text segment within parts.""" +class TextPart(BaseModel): + """ + Represents a text segment within parts. + """ kind: Literal['text'] = 'text' - """Part type - text for TextParts.""" + """ + Part type - text for TextParts + """ + metadata: dict[str, Any] | None = None + """ + Optional metadata associated with the part. + """ text: str - """Text content.""" + """ + Text content + """ class UnsupportedOperationError(BaseModel): - """A2A specific error indicating the requested operation is not supported by the agent.""" + """ + A2A specific error indicating the requested operation is not supported by the agent. + """ code: Literal[-32004] = -32004 - """A Number that indicates the error type that occurred.""" + """ + A Number that indicates the error type that occurred. + """ data: Any | None = None - """A Primitive or Structured value that contains additional information about the error. This may be omitted.""" + """ + A Primitive or Structured value that contains additional information about the error. + This may be omitted. + """ message: str | None = 'This operation is not supported' - """A String providing a short description of the error.""" + """ + A String providing a short description of the error. + """ class A2AError( @@ -444,8 +803,6 @@ class A2AError( | InvalidAgentResponseError ] ): - """Represents A2A specific JSON-RPC error responses.""" - root: ( JSONParseError | InvalidRequestError @@ -461,53 +818,123 @@ class A2AError( ) -class CancelTaskRequest(JSONRPCRequest): - """JSON-RPC request model for the 'tasks/cancel' method.""" +class CancelTaskRequest(BaseModel): + """ + JSON-RPC request model for the 'tasks/cancel' method. + """ + id: str | int | None = None + """ + An identifier established by the Client that MUST contain a String, Number + Numbers SHOULD NOT contain fractional parts. + """ + jsonrpc: Literal['2.0'] = '2.0' + """ + Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0". + """ method: Literal['tasks/cancel'] = 'tasks/cancel' - """A String containing the name of the method to be invoked.""" + """ + A String containing the name of the method to be invoked. + """ params: TaskIdParams - """A Structured value that holds the parameter values to be used during the invocation of the method.""" + """ + A Structured value that holds the parameter values to be used during the invocation of the method. + """ -class FilePart(PartBase): - """Represents a File segment within parts.""" +class FilePart(BaseModel): + """ + Represents a File segment within parts. + """ file: FileWithBytes | FileWithUri - """File content either as url or bytes.""" + """ + File content either as url or bytes + """ kind: Literal['file'] = 'file' - """Part type - file for FileParts.""" + """ + Part type - file for FileParts + """ + metadata: dict[str, Any] | None = None + """ + Optional metadata associated with the part. + """ -class GetTaskPushNotificationConfigRequest(JSONRPCRequest): - """JSON-RPC request model for the 'tasks/pushNotificationConfig/get' method.""" +class GetTaskPushNotificationConfigRequest(BaseModel): + """ + JSON-RPC request model for the 'tasks/pushNotificationConfig/get' method. + """ + id: str | int | None = None + """ + An identifier established by the Client that MUST contain a String, Number + Numbers SHOULD NOT contain fractional parts. + """ + jsonrpc: Literal['2.0'] = '2.0' + """ + Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0". + """ method: Literal['tasks/pushNotificationConfig/get'] = ( 'tasks/pushNotificationConfig/get' ) - """A String containing the name of the method to be invoked.""" + """ + A String containing the name of the method to be invoked. + """ params: TaskIdParams - """A Structured value that holds the parameter values to be used during the invocation of the method.""" + """ + A Structured value that holds the parameter values to be used during the invocation of the method. + """ -class GetTaskPushNotificationConfigSuccessResponse(JSONRPCResult): - """JSON-RPC success response model for the 'tasks/pushNotificationConfig/get' method.""" +class GetTaskPushNotificationConfigSuccessResponse(BaseModel): + """ + JSON-RPC success response model for the 'tasks/pushNotificationConfig/get' method. + """ + id: str | int | None = None + """ + An identifier established by the Client that MUST contain a String, Number + Numbers SHOULD NOT contain fractional parts. + """ + jsonrpc: Literal['2.0'] = '2.0' + """ + Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0". + """ result: TaskPushNotificationConfig - """The result object on success.""" + """ + The result object on success. + """ -class GetTaskRequest(JSONRPCRequest): - """JSON-RPC request model for the 'tasks/get' method.""" +class GetTaskRequest(BaseModel): + """ + JSON-RPC request model for the 'tasks/get' method. + """ + id: str | int | None = None + """ + An identifier established by the Client that MUST contain a String, Number + Numbers SHOULD NOT contain fractional parts. + """ + jsonrpc: Literal['2.0'] = '2.0' + """ + Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0". + """ method: Literal['tasks/get'] = 'tasks/get' - """A String containing the name of the method to be invoked.""" + """ + A String containing the name of the method to be invoked. + """ params: TaskQueryParams - """A Structured value that holds the parameter values to be used during the invocation of the method.""" + """ + A Structured value that holds the parameter values to be used during the invocation of the method. + """ -class JSONRPCErrorResponse(JSONRPCMessage): - """Represents a JSON-RPC 2.0 Error Response object.""" +class JSONRPCErrorResponse(BaseModel): + """ + Represents a JSON-RPC 2.0 Error Response object. + """ error: ( JSONRPCError @@ -523,126 +950,223 @@ class JSONRPCErrorResponse(JSONRPCMessage): | ContentTypeNotSupportedError | InvalidAgentResponseError ) + id: str | int | None = None + """ + An identifier established by the Client that MUST contain a String, Number + Numbers SHOULD NOT contain fractional parts. + """ + jsonrpc: Literal['2.0'] = '2.0' + """ + Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0". + """ class MessageSendConfiguration(BaseModel): - """Configuration for the send message request.""" + """ + Configuration for the send message request. + """ acceptedOutputModes: list[str] - """Accepted output modalities by the client.""" + """ + Accepted output modalities by the client. + """ blocking: bool | None = None - """If the server should treat the client as a blocking request.""" + """ + If the server should treat the client as a blocking request. + """ historyLength: int | None = None - """Number of recent messages to be retrieved.""" + """ + Number of recent messages to be retrieved. + """ pushNotificationConfig: PushNotificationConfig | None = None - """Where the server should send notifications when disconnected.""" + """ + Where the server should send notifications when disconnected. + """ class OAuthFlows(BaseModel): - """Allows configuration of the supported OAuth Flows.""" + """ + Allows configuration of the supported OAuth Flows + """ authorizationCode: AuthorizationCodeOAuthFlow | None = None - """Configuration for the OAuth Authorization Code flow. Previously called accessCode in OpenAPI 2.0.""" + """ + Configuration for the OAuth Authorization Code flow. Previously called accessCode in OpenAPI 2.0. + """ clientCredentials: ClientCredentialsOAuthFlow | None = None - """Configuration for the OAuth Client Credentials flow. Previously called application in OpenAPI 2.0.""" + """ + Configuration for the OAuth Client Credentials flow. Previously called application in OpenAPI 2.0 + """ implicit: ImplicitOAuthFlow | None = None - """Configuration for the OAuth Implicit flow.""" + """ + Configuration for the OAuth Implicit flow + """ password: PasswordOAuthFlow | None = None - """Configuration for the OAuth Resource Owner Password flow.""" + """ + Configuration for the OAuth Resource Owner Password flow + """ class Part(RootModel[TextPart | FilePart | DataPart]): - """Represents a part of a message, which can be text, a file, or structured data.""" - root: TextPart | FilePart | DataPart + """ + Represents a part of a message, which can be text, a file, or structured data. + """ -class SetTaskPushNotificationConfigRequest(JSONRPCRequest): - """JSON-RPC request model for the 'tasks/pushNotificationConfig/set' method.""" +class SetTaskPushNotificationConfigRequest(BaseModel): + """ + JSON-RPC request model for the 'tasks/pushNotificationConfig/set' method. + """ + id: str | int | None = None + """ + An identifier established by the Client that MUST contain a String, Number + Numbers SHOULD NOT contain fractional parts. + """ + jsonrpc: Literal['2.0'] = '2.0' + """ + Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0". + """ method: Literal['tasks/pushNotificationConfig/set'] = ( 'tasks/pushNotificationConfig/set' ) - """A String containing the name of the method to be invoked.""" + """ + A String containing the name of the method to be invoked. + """ params: TaskPushNotificationConfig - """A Structured value that holds the parameter values to be used during the invocation of the method.""" + """ + A Structured value that holds the parameter values to be used during the invocation of the method. + """ -class SetTaskPushNotificationConfigSuccessResponse(JSONRPCResult): - """JSON-RPC success response model for the 'tasks/pushNotificationConfig/set' method.""" +class SetTaskPushNotificationConfigSuccessResponse(BaseModel): + """ + JSON-RPC success response model for the 'tasks/pushNotificationConfig/set' method. + """ + id: str | int | None = None + """ + An identifier established by the Client that MUST contain a String, Number + Numbers SHOULD NOT contain fractional parts. + """ + jsonrpc: Literal['2.0'] = '2.0' + """ + Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0". + """ result: TaskPushNotificationConfig - """The result object on success.""" + """ + The result object on success. + """ class Artifact(BaseModel): - """Represents an artifact generated for a task task.""" + """ + Represents an artifact generated for a task task. + """ artifactId: str - """Unique identifier for the artifact.""" + """ + Unique identifier for the artifact. + """ description: str | None = None - """Optional description for the artifact.""" + """ + Optional description for the artifact. + """ metadata: dict[str, Any] | None = None - """Extension metadata.""" + """ + Extension metadata. + """ name: str | None = None - """Optional name for the artifact.""" + """ + Optional name for the artifact. + """ parts: list[Part] - """Artifact parts.""" + """ + Artifact parts. + """ class GetTaskPushNotificationConfigResponse( - RootModel[ - JSONRPCErrorResponse | GetTaskPushNotificationConfigSuccessResponse - ] + RootModel[JSONRPCErrorResponse | GetTaskPushNotificationConfigSuccessResponse] ): - """JSON-RPC response for the 'tasks/pushNotificationConfig/get' method.""" - root: JSONRPCErrorResponse | GetTaskPushNotificationConfigSuccessResponse + """ + JSON-RPC response for the 'tasks/pushNotificationConfig/set' method. + """ class Message(BaseModel): - """Represents a single message exchanged between user and agent.""" + """ + Represents a single message exchanged between user and agent. + """ contextId: str | None = None - """The context the message is associated with.""" + """ + The context the message is associated with + """ kind: Literal['message'] = 'message' - """Event type.""" + """ + Event type + """ messageId: str - """Identifier created by the message creator.""" + """ + Identifier created by the message creator + """ metadata: dict[str, Any] | None = None - """Extension metadata.""" + """ + Extension metadata. + """ parts: list[Part] - """Message content.""" + """ + Message content + """ referenceTaskIds: list[str] | None = None """ list of tasks referenced as context by this message. """ role: Role - """Message sender's role.""" + """ + Message sender's role + """ taskId: str | None = None - """Identifier of task the message is related to.""" - final: bool | None = None - """Indicates if this is the final message in a stream.""" + """ + Identifier of task the message is related to + """ class MessageSendParams(BaseModel): - """Sent by the client to the agent as a request. May create, continue or restart a task.""" + """ + Sent by the client to the agent as a request. May create, continue or restart a task. + """ configuration: MessageSendConfiguration | None = None - """Send message configuration.""" + """ + Send message configuration. + """ message: Message - """The message being sent to the server.""" + """ + The message being sent to the server. + """ metadata: dict[str, Any] | None = None - """Extension metadata.""" + """ + Extension metadata. + """ class OAuth2SecurityScheme(BaseModel): - """OAuth2.0 security scheme configuration.""" + """ + OAuth2.0 security scheme configuration. + """ description: str | None = None - """Description of this security scheme.""" + """ + Description of this security scheme. + """ flows: OAuthFlows - """An object containing configuration information for the flow types supported.""" + """ + An object containing configuration information for the flow types supported. + """ type: Literal['oauth2'] = 'oauth2' @@ -654,88 +1178,155 @@ class SecurityScheme( | OpenIdConnectSecurityScheme ] ): - """Mirrors the OpenAPI Security Scheme Object (https://swagger.io/specification/#security-scheme-object).""" - root: ( APIKeySecurityScheme | HTTPAuthSecurityScheme | OAuth2SecurityScheme | OpenIdConnectSecurityScheme ) + """ + Mirrors the OpenAPI Security Scheme Object + (https://swagger.io/specification/#security-scheme-object) + """ -class SendMessageRequest(JSONRPCRequest): - """JSON-RPC request model for the 'message/send' method.""" +class SendMessageRequest(BaseModel): + """ + JSON-RPC request model for the 'message/send' method. + """ + id: str | int | None = None + """ + An identifier established by the Client that MUST contain a String, Number + Numbers SHOULD NOT contain fractional parts. + """ + jsonrpc: Literal['2.0'] = '2.0' + """ + Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0". + """ method: Literal['message/send'] = 'message/send' - """A String containing the name of the method to be invoked.""" + """ + A String containing the name of the method to be invoked. + """ params: MessageSendParams - """A Structured value that holds the parameter values to be used during the invocation of the method.""" + """ + A Structured value that holds the parameter values to be used during the invocation of the method. + """ -class SendStreamingMessageRequest(JSONRPCRequest): - """JSON-RPC request model for the 'message/stream' method.""" +class SendStreamingMessageRequest(BaseModel): + """ + JSON-RPC request model for the 'message/stream' method. + """ + id: str | int | None = None + """ + An identifier established by the Client that MUST contain a String, Number + Numbers SHOULD NOT contain fractional parts. + """ + jsonrpc: Literal['2.0'] = '2.0' + """ + Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0". + """ method: Literal['message/stream'] = 'message/stream' - """A String containing the name of the method to be invoked.""" + """ + A String containing the name of the method to be invoked. + """ params: MessageSendParams - """A Structured value that holds the parameter values to be used during the invocation of the method.""" + """ + A Structured value that holds the parameter values to be used during the invocation of the method. + """ class SetTaskPushNotificationConfigResponse( - RootModel[ - JSONRPCErrorResponse | SetTaskPushNotificationConfigSuccessResponse - ] + RootModel[JSONRPCErrorResponse | SetTaskPushNotificationConfigSuccessResponse] ): - """JSON-RPC response for the 'tasks/pushNotificationConfig/set' method.""" - root: JSONRPCErrorResponse | SetTaskPushNotificationConfigSuccessResponse + """ + JSON-RPC response for the 'tasks/pushNotificationConfig/set' method. + """ class TaskArtifactUpdateEvent(BaseModel): - """Event sent by server during sendStream or subscribe requests indicating an artifact update.""" + """ + sent by server during sendStream or subscribe requests + """ append: bool | None = None - """Indicates if this artifact appends to a previous one.""" + """ + Indicates if this artifact appends to a previous one + """ artifact: Artifact - """Generated artifact.""" + """ + Generated artifact + """ contextId: str - """The context the task is associated with.""" + """ + The context the task is associated with + """ kind: Literal['artifact-update'] = 'artifact-update' - """Event type.""" + """ + Event type + """ lastChunk: bool | None = None - """Indicates if this is the last chunk of the artifact.""" + """ + Indicates if this is the last chunk of the artifact + """ metadata: dict[str, Any] | None = None - """Extension metadata.""" + """ + Extension metadata. + """ taskId: str - """Task id.""" + """ + Task id + """ class TaskStatus(BaseModel): - """TaskState and accompanying message.""" + """ + TaskState and accompanying message. + """ message: Message | None = None - """Additional status updates for client.""" + """ + Additional status updates for client + """ state: TaskState timestamp: str | None = None - """ISO 8601 datetime string when the status was recorded.""" + """ + ISO 8601 datetime string when the status was recorded. + """ class TaskStatusUpdateEvent(BaseModel): - """Event sent by server during sendStream or subscribe requests indicating a status update.""" + """ + sent by server during sendStream or subscribe requests + """ contextId: str - """The context the task is associated with.""" + """ + The context the task is associated with + """ final: bool - """Indicates the end of the event stream for this task (implies the task is finished or failed).""" + """ + Indicates the end of the event stream + """ kind: Literal['status-update'] = 'status-update' - """Event type.""" + """ + Event type + """ metadata: dict[str, Any] | None = None - """Extension metadata.""" + """ + Extension metadata. + """ status: TaskStatus - """Current status of the task.""" + """ + Current status of the task + """ taskId: str - """Task id.""" + """ + Task id + """ class A2ARequest( @@ -749,8 +1340,6 @@ class A2ARequest( | TaskResubscriptionRequest ] ): - """A2A supported request types.""" - root: ( SendMessageRequest | SendStreamingMessageRequest @@ -760,12 +1349,14 @@ class A2ARequest( | GetTaskPushNotificationConfigRequest | TaskResubscriptionRequest ) + """ + A2A supported request types + """ class AgentCard(BaseModel): - """An AgentCard conveys key information about an agent. - - Includes: + """ + An AgentCard conveys key information: - Overall details (version, name, description, uses) - Skills: A set of capabilities the agent can perform - Default modalities/content types supported by the agent. @@ -773,90 +1364,178 @@ class AgentCard(BaseModel): """ capabilities: AgentCapabilities - """Optional capabilities supported by the agent.""" + """ + Optional capabilities supported by the agent. + """ defaultInputModes: list[str] - """The set of interaction modes that the agent supports across all skills. This can be overridden per-skill. Supported mime types for input.""" + """ + The set of interaction modes that the agent + supports across all skills. This can be overridden per-skill. + Supported mime types for input. + """ defaultOutputModes: list[str] - """Supported mime types for output.""" + """ + Supported mime types for output. + """ description: str - """A human-readable description of the agent. Used to assist users and other agents in understanding what the agent can do.""" + """ + A human-readable description of the agent. Used to assist users and + other agents in understanding what the agent can do. + """ documentationUrl: str | None = None - """A URL to documentation for the agent.""" + """ + A URL to documentation for the agent. + """ name: str - """Human readable name of the agent.""" + """ + Human readable name of the agent. + """ provider: AgentProvider | None = None - """The service provider of the agent.""" + """ + The service provider of the agent + """ security: list[dict[str, list[str]]] | None = None - """Security requirements for contacting the agent.""" + """ + Security requirements for contacting the agent. + """ securitySchemes: dict[str, SecurityScheme] | None = None - """Security scheme details used for authenticating with this agent.""" + """ + Security scheme details used for authenticating with this agent. + """ skills: list[AgentSkill] - """Skills are a unit of capability that an agent can perform.""" + """ + Skills are a unit of capability that an agent can perform. + """ url: str - """A URL to the address the agent is hosted at.""" + """ + A URL to the address the agent is hosted at. + """ version: str - """The version of the agent - format is up to the provider.""" + """ + The version of the agent - format is up to the provider. + """ class Task(BaseModel): - """Represents a task managed by the agent.""" - artifacts: list[Artifact] | None = None - """Collection of artifacts created by the agent.""" + """ + Collection of artifacts created by the agent. + """ contextId: str - """Server-generated id for contextual alignment across interactions.""" + """ + Server-generated id for contextual alignment across interactions + """ history: list[Message] | None = None - """History of messages associated with the task.""" id: str - """Unique identifier for the task.""" + """ + Unique identifier for the task + """ kind: Literal['task'] = 'task' - """Event type.""" + """ + Event type + """ metadata: dict[str, Any] | None = None - """Extension metadata.""" + """ + Extension metadata. + """ status: TaskStatus - """Current status of the task.""" + """ + Current status of the task + """ -class CancelTaskSuccessResponse(JSONRPCResult): - """JSON-RPC success response model for the 'tasks/cancel' method.""" +class CancelTaskSuccessResponse(BaseModel): + """ + JSON-RPC success response model for the 'tasks/cancel' method. + """ + id: str | int | None = None + """ + An identifier established by the Client that MUST contain a String, Number + Numbers SHOULD NOT contain fractional parts. + """ + jsonrpc: Literal['2.0'] = '2.0' + """ + Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0". + """ result: Task - """The result object on success.""" + """ + The result object on success. + """ -class GetTaskSuccessResponse(JSONRPCResult): - """JSON-RPC success response for the 'tasks/get' method.""" +class GetTaskSuccessResponse(BaseModel): + """ + JSON-RPC success response for the 'tasks/get' method. + """ + id: str | int | None = None + """ + An identifier established by the Client that MUST contain a String, Number + Numbers SHOULD NOT contain fractional parts. + """ + jsonrpc: Literal['2.0'] = '2.0' + """ + Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0". + """ result: Task - """The result object on success.""" + """ + The result object on success. + """ -class SendMessageSuccessResponse(JSONRPCResult): - """JSON-RPC success response model for the 'message/send' method.""" +class SendMessageSuccessResponse(BaseModel): + """ + JSON-RPC success response model for the 'message/send' method. + """ + id: str | int | None = None + """ + An identifier established by the Client that MUST contain a String, Number + Numbers SHOULD NOT contain fractional parts. + """ + jsonrpc: Literal['2.0'] = '2.0' + """ + Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0". + """ result: Task | Message - """The result object on success.""" + """ + The result object on success + """ -class SendStreamingMessageSuccessResponse(JSONRPCResult): - """JSON-RPC success response model for the 'message/stream' method.""" +class SendStreamingMessageSuccessResponse(BaseModel): + """ + JSON-RPC success response model for the 'message/stream' method. + """ + id: str | int | None = None + """ + An identifier established by the Client that MUST contain a String, Number + Numbers SHOULD NOT contain fractional parts. + """ + jsonrpc: Literal['2.0'] = '2.0' + """ + Specifies the version of the JSON-RPC protocol. MUST be exactly "2.0". + """ result: Task | Message | TaskStatusUpdateEvent | TaskArtifactUpdateEvent - """The result object on success.""" - + """ + The result object on success + """ -class CancelTaskResponse( - RootModel[JSONRPCErrorResponse | CancelTaskSuccessResponse] -): - """JSON-RPC response for the 'tasks/cancel' method.""" +class CancelTaskResponse(RootModel[JSONRPCErrorResponse | CancelTaskSuccessResponse]): root: JSONRPCErrorResponse | CancelTaskSuccessResponse + """ + JSON-RPC response for the 'tasks/cancel' method. + """ class GetTaskResponse(RootModel[JSONRPCErrorResponse | GetTaskSuccessResponse]): - """JSON-RPC success response for the 'tasks/get' method.""" - root: JSONRPCErrorResponse | GetTaskSuccessResponse + """ + JSON-RPC success response for the 'tasks/get' method. + """ class JSONRPCResponse( @@ -870,8 +1549,6 @@ class JSONRPCResponse( | GetTaskPushNotificationConfigSuccessResponse ] ): - """Represents a JSON-RPC 2.0 Response object.""" - root: ( JSONRPCErrorResponse | SendMessageSuccessResponse @@ -881,19 +1558,22 @@ class JSONRPCResponse( | SetTaskPushNotificationConfigSuccessResponse | GetTaskPushNotificationConfigSuccessResponse ) + """ + Represents a JSON-RPC 2.0 Response object. + """ -class SendMessageResponse( - RootModel[JSONRPCErrorResponse | SendMessageSuccessResponse] -): - """JSON-RPC response model for the 'message/send' method.""" - +class SendMessageResponse(RootModel[JSONRPCErrorResponse | SendMessageSuccessResponse]): root: JSONRPCErrorResponse | SendMessageSuccessResponse + """ + JSON-RPC response model for the 'message/send' method. + """ class SendStreamingMessageResponse( RootModel[JSONRPCErrorResponse | SendStreamingMessageSuccessResponse] ): - """JSON-RPC response model for the 'message/stream' method.""" - root: JSONRPCErrorResponse | SendStreamingMessageSuccessResponse + """ + JSON-RPC response model for the 'message/stream' method. + """ From 4e3f68a68118b78ad9b757ba6601764ef24cbd72 Mon Sep 17 00:00:00 2001 From: "release-please[bot]" <55107282+release-please[bot]@users.noreply.github.com> Date: Tue, 20 May 2025 21:11:06 +0000 Subject: [PATCH 6/8] chore(main): release 0.2.3 (#68) Co-authored-by: release-please[bot] <55107282+release-please[bot]@users.noreply.github.com> --- CHANGELOG.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index ed67f7f7..e3784083 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [0.2.3](https://github.com/google/a2a-python/compare/v0.2.2...v0.2.3) (2025-05-20) + + +### Features + +* Add request context builder with referenceTasks ([#56](https://github.com/google/a2a-python/issues/56)) ([f20bfe7](https://github.com/google/a2a-python/commit/f20bfe74b8cc854c9c29720b2ea3859aff8f509e)) + ## [0.2.2](https://github.com/google/a2a-python/compare/v0.2.1...v0.2.2) (2025-05-20) From 7ee14784df247ba506769866161ba4a659650a98 Mon Sep 17 00:00:00 2001 From: Junjie Bu Date: Tue, 20 May 2025 15:29:20 -0700 Subject: [PATCH 7/8] test: Adding 8 tests for jsonrpc_handler.py and also fix minor waring in test_integration.py (#72) * test: Adding 8 tests for jsonrpc_handler.py and also fix minor waring in test_integration.py * test: remove comments --- .../request_handlers/test_jsonrpc_handler.py | 270 +++++++++++++++++- tests/server/test_integration.py | 4 +- 2 files changed, 269 insertions(+), 5 deletions(-) diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index ee3843e2..05f95c7c 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -3,12 +3,15 @@ from collections.abc import AsyncGenerator from typing import Any -from unittest.mock import AsyncMock, MagicMock, patch, call +from unittest.mock import AsyncMock, MagicMock, call, patch -import pytest import httpx +import pytest from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.agent_execution.request_context_builder import ( + RequestContextBuilder, +) from a2a.server.events import ( QueueManager, ) @@ -30,14 +33,15 @@ GetTaskRequest, GetTaskResponse, GetTaskSuccessResponse, + InternalError, JSONRPCErrorResponse, Message, MessageSendConfiguration, MessageSendParams, Part, + PushNotificationConfig, SendMessageRequest, SendMessageSuccessResponse, - PushNotificationConfig, SendStreamingMessageRequest, SendStreamingMessageSuccessResponse, SetTaskPushNotificationConfigRequest, @@ -59,6 +63,7 @@ ) from a2a.utils.errors import ServerError + MINIMAL_TASK: dict[str, Any] = { 'id': 'task_123', 'contextId': 'session-xyz', @@ -680,6 +685,265 @@ async def test_on_resubscribe_no_existing_task_error(self) -> None: self.assertIsInstance(collected_events[0].root, JSONRPCErrorResponse) assert collected_events[0].root.error == TaskNotFoundError() + async def test_streaming_not_supported_error( + self, + ) -> None: + """Test that on_message_send_stream raises an error when streaming not supported.""" + # Arrange + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + # Create agent card with streaming capability disabled + self.mock_agent_card.capabilities = AgentCapabilities(streaming=False) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + + # Act & Assert + request = SendStreamingMessageRequest( + id='1', + params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)), + ) + + # Should raise ServerError about streaming not supported + with self.assertRaises(ServerError) as context: + async for _ in handler.on_message_send_stream(request): + pass + + aaa = context.exception + self.assertEqual( + str(context.exception.error.message), + 'Streaming is not supported by the agent', + ) + + async def test_push_notifications_not_supported_error(self) -> None: + """Test that set_push_notification raises an error when push notifications not supported.""" + # Arrange + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + # Create agent card with push notifications capability disabled + self.mock_agent_card.capabilities = AgentCapabilities( + pushNotifications=False, streaming=True + ) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + + # Act & Assert + task_push_config = TaskPushNotificationConfig( + taskId='task_123', + pushNotificationConfig=PushNotificationConfig( + url='http://example.com' + ), + ) + request = SetTaskPushNotificationConfigRequest( + id='1', params=task_push_config + ) + + # Should raise ServerError about push notifications not supported + with self.assertRaises(ServerError) as context: + await handler.set_push_notification(request) + + self.assertEqual( + str(context.exception.error.message), + 'Push notifications are not supported by the agent', + ) + + async def test_on_get_push_notification_no_push_notifier(self) -> None: + """Test get_push_notification with no push notifier configured.""" + # Arrange + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + # Create request handler without a push notifier + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + self.mock_agent_card.capabilities = AgentCapabilities( + pushNotifications=True + ) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + + mock_task = Task(**MINIMAL_TASK) + mock_task_store.get.return_value = mock_task + + # Act + get_request = GetTaskPushNotificationConfigRequest( + id='1', params=TaskIdParams(id=mock_task.id) + ) + response = await handler.get_push_notification(get_request) + + # Assert + self.assertIsInstance(response.root, JSONRPCErrorResponse) + self.assertEqual(response.root.error, UnsupportedOperationError()) + + async def test_on_set_push_notification_no_push_notifier(self) -> None: + """Test set_push_notification with no push notifier configured.""" + # Arrange + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + # Create request handler without a push notifier + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + self.mock_agent_card.capabilities = AgentCapabilities( + pushNotifications=True + ) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + + mock_task = Task(**MINIMAL_TASK) + mock_task_store.get.return_value = mock_task + + # Act + task_push_config = TaskPushNotificationConfig( + taskId=mock_task.id, + pushNotificationConfig=PushNotificationConfig( + url='http://example.com' + ), + ) + request = SetTaskPushNotificationConfigRequest( + id='1', params=task_push_config + ) + response = await handler.set_push_notification(request) + + # Assert + self.assertIsInstance(response.root, JSONRPCErrorResponse) + self.assertEqual(response.root.error, UnsupportedOperationError()) + + async def test_on_message_send_internal_error(self) -> None: + """Test on_message_send with an internal error.""" + # Arrange + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + + # Make the request handler raise an Internal error without specifying an error type + async def raise_server_error(*args, **kwargs): + raise ServerError(InternalError(message='Internal Error')) + + # Patch the method to raise an error + with patch.object( + request_handler, 'on_message_send', side_effect=raise_server_error + ): + # Act + request = SendMessageRequest( + id='1', + params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)), + ) + response = await handler.on_message_send(request) + + # Assert + self.assertIsInstance(response.root, JSONRPCErrorResponse) + self.assertIsInstance(response.root.error, InternalError) + + async def test_on_message_stream_internal_error(self) -> None: + """Test on_message_send_stream with an internal error.""" + # Arrange + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + self.mock_agent_card.capabilities = AgentCapabilities(streaming=True) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + + # Make the request handler raise an Internal error without specifying an error type + async def raise_server_error(*args, **kwargs): + raise ServerError(InternalError(message='Internal Error')) + yield # Need this to make it an async generator + + # Patch the method to raise an error + with patch.object( + request_handler, + 'on_message_send_stream', + return_value=raise_server_error(), + ): + # Act + request = SendStreamingMessageRequest( + id='1', + params=MessageSendParams(message=Message(**MESSAGE_PAYLOAD)), + ) + + # Get the single error response + responses = [] + async for response in handler.on_message_send_stream(request): + responses.append(response) + + # Assert + self.assertEqual(len(responses), 1) + self.assertIsInstance(responses[0].root, JSONRPCErrorResponse) + self.assertIsInstance(responses[0].root.error, InternalError) + + async def test_default_request_handler_with_custom_components(self) -> None: + """Test DefaultRequestHandler initialization with custom components.""" + # Arrange + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + mock_queue_manager = AsyncMock(spec=QueueManager) + mock_push_notifier = AsyncMock(spec=PushNotifier) + mock_request_context_builder = AsyncMock(spec=RequestContextBuilder) + + # Act + handler = DefaultRequestHandler( + agent_executor=mock_agent_executor, + task_store=mock_task_store, + queue_manager=mock_queue_manager, + push_notifier=mock_push_notifier, + request_context_builder=mock_request_context_builder, + ) + + # Assert + self.assertEqual(handler.agent_executor, mock_agent_executor) + self.assertEqual(handler.task_store, mock_task_store) + self.assertEqual(handler._queue_manager, mock_queue_manager) + self.assertEqual(handler._push_notifier, mock_push_notifier) + self.assertEqual( + handler._request_context_builder, mock_request_context_builder + ) + + async def test_on_message_send_error_handling(self) -> None: + """Test error handling in on_message_send when consuming raises ServerError.""" + # Arrange + mock_agent_executor = AsyncMock(spec=AgentExecutor) + mock_task_store = AsyncMock(spec=TaskStore) + request_handler = DefaultRequestHandler( + mock_agent_executor, mock_task_store + ) + handler = JSONRPCHandler(self.mock_agent_card, request_handler) + + # Let task exist + mock_task = Task(**MINIMAL_TASK) + mock_task_store.get.return_value = mock_task + + # Set up consume_and_break_on_interrupt to raise ServerError + async def consume_raises_error(*args, **kwargs): + raise ServerError(error=UnsupportedOperationError()) + + with patch( + 'a2a.server.tasks.result_aggregator.ResultAggregator.consume_and_break_on_interrupt', + side_effect=consume_raises_error, + ): + # Act + request = SendMessageRequest( + id='1', + params=MessageSendParams( + message=Message( + **MESSAGE_PAYLOAD, + taskId=mock_task.id, + contextId=mock_task.contextId, + ) + ), + ) + + response = await handler.on_message_send(request) + + # Assert + self.assertIsInstance(response.root, JSONRPCErrorResponse) + self.assertEqual(response.root.error, UnsupportedOperationError()) + async def test_on_message_send_task_id_mismatch(self) -> None: mock_agent_executor = AsyncMock(spec=AgentExecutor) mock_task_store = AsyncMock(spec=TaskStore) diff --git a/tests/server/test_integration.py b/tests/server/test_integration.py index 79814577..b116b2cc 100644 --- a/tests/server/test_integration.py +++ b/tests/server/test_integration.py @@ -230,7 +230,7 @@ def test_cancel_task(client: TestClient, handler: mock.AsyncMock): """Test cancelling a task.""" # Setup mock response task_status = TaskStatus(**MINIMAL_TASK_STATUS) - task_status.state = TaskState.canceled # 'cancelled' # + task_status.state = TaskState.canceled # 'cancelled' # task = Task( id='task1', contextId='ctx1', state='cancelled', status=task_status ) @@ -543,7 +543,7 @@ async def stream_generator(): def test_invalid_json(client: TestClient): """Test handling invalid JSON.""" - response = client.post('/', data='This is not JSON') + response = client.post('/', content=b'This is not JSON') # Use bytes assert response.status_code == 200 # JSON-RPC errors still return 200 data = response.json() assert 'error' in data From 962a25f3b0a6c66e6f904c6fbafc1a2ad8c3cd2b Mon Sep 17 00:00:00 2001 From: Krishna Thota Date: Tue, 20 May 2025 15:25:06 -0700 Subject: [PATCH 8/8] Add tests --- tests/server/request_handlers/test_jsonrpc_handler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index 05f95c7c..5921ba8d 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -775,7 +775,7 @@ async def test_on_get_push_notification_no_push_notifier(self) -> None: # Assert self.assertIsInstance(response.root, JSONRPCErrorResponse) - self.assertEqual(response.root.error, UnsupportedOperationError()) + self.assertEqual(response.root.error, UnsupportedOperationError()) # type: ignore async def test_on_set_push_notification_no_push_notifier(self) -> None: """Test set_push_notification with no push notifier configured.""" @@ -808,7 +808,7 @@ async def test_on_set_push_notification_no_push_notifier(self) -> None: # Assert self.assertIsInstance(response.root, JSONRPCErrorResponse) - self.assertEqual(response.root.error, UnsupportedOperationError()) + self.assertEqual(response.root.error, UnsupportedOperationError()) # type: ignore async def test_on_message_send_internal_error(self) -> None: """Test on_message_send with an internal error.""" @@ -837,7 +837,7 @@ async def raise_server_error(*args, **kwargs): # Assert self.assertIsInstance(response.root, JSONRPCErrorResponse) - self.assertIsInstance(response.root.error, InternalError) + self.assertIsInstance(response.root.error, InternalError) # type: ignore async def test_on_message_stream_internal_error(self) -> None: """Test on_message_send_stream with an internal error."""