diff --git a/src/a2a/server/agent_execution/context.py b/src/a2a/server/agent_execution/context.py index 371b8419..27464414 100644 --- a/src/a2a/server/agent_execution/context.py +++ b/src/a2a/server/agent_execution/context.py @@ -1,5 +1,6 @@ import uuid +from a2a.server.context import ServerCallContext from a2a.types import ( InvalidParamsError, Message, @@ -26,6 +27,7 @@ def __init__( context_id: str | None = None, task: Task | None = None, related_tasks: list[Task] | None = None, + call_context: ServerCallContext | None = None, ): """Initializes the RequestContext. @@ -43,6 +45,7 @@ def __init__( self._context_id = context_id self._current_task = task self._related_tasks = related_tasks + self._call_context = call_context # If the task id and context id were provided, make sure they # match the request. Otherwise, create them if self._params: @@ -125,6 +128,11 @@ def configuration(self) -> MessageSendConfiguration | None: return None return self._params.configuration + @property + def call_context(self) -> ServerCallContext | None: + """The server call context associated with this request.""" + return self._call_context + def _check_or_generate_task_id(self) -> None: """Ensures a task ID is present, generating one if necessary.""" if not self._params: diff --git a/src/a2a/server/agent_execution/request_context_builder.py b/src/a2a/server/agent_execution/request_context_builder.py index 5a59eb96..0e36254b 100644 --- a/src/a2a/server/agent_execution/request_context_builder.py +++ b/src/a2a/server/agent_execution/request_context_builder.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from a2a.server.agent_execution import RequestContext +from a2a.server.context import ServerCallContext from a2a.types import MessageSendParams, Task @@ -14,5 +15,6 @@ async def build( task_id: str | None = None, context_id: str | None = None, task: Task | None = None, + context: ServerCallContext | None = None, ) -> RequestContext: pass diff --git a/src/a2a/server/agent_execution/simple_request_context_builder.py b/src/a2a/server/agent_execution/simple_request_context_builder.py index 4a9b9a88..16a84d7b 100644 --- a/src/a2a/server/agent_execution/simple_request_context_builder.py +++ b/src/a2a/server/agent_execution/simple_request_context_builder.py @@ -1,6 +1,7 @@ import asyncio from a2a.server.agent_execution import RequestContext, RequestContextBuilder +from a2a.server.context import ServerCallContext from a2a.server.tasks import TaskStore from a2a.types import MessageSendParams, Task @@ -22,6 +23,7 @@ async def build( task_id: str | None = None, context_id: str | None = None, task: Task | None = None, + context: ServerCallContext | None = None, ) -> RequestContext: related_tasks: list[Task] | None = None @@ -45,4 +47,5 @@ async def build( context_id=context_id, task=task, related_tasks=related_tasks, + call_context=context, ) diff --git a/src/a2a/server/apps/starlette_app.py b/src/a2a/server/apps/starlette_app.py index 87b3edaf..7f4d6e8b 100644 --- a/src/a2a/server/apps/starlette_app.py +++ b/src/a2a/server/apps/starlette_app.py @@ -2,6 +2,7 @@ import logging import traceback +from abc import ABC, abstractmethod from collections.abc import AsyncGenerator from typing import Any @@ -12,9 +13,9 @@ from starlette.responses import JSONResponse, Response from starlette.routing import Route -from a2a.server.request_handlers.request_handler import RequestHandler +from a2a.server.context import ServerCallContext from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler - +from a2a.server.request_handlers.request_handler import RequestHandler from a2a.types import ( A2AError, A2ARequest, @@ -41,6 +42,14 @@ logger = logging.getLogger(__name__) +class CallContextBuilder(ABC): + """A class for building ServerCallContexts using the Starlette Request.""" + + @abstractmethod + def build(self, request: Request) -> ServerCallContext: + """Builds a ServerCallContext from a Starlette Request.""" + + class A2AStarletteApplication: """A Starlette application implementing the A2A protocol server endpoints. @@ -49,18 +58,27 @@ class A2AStarletteApplication: (SSE). """ - def __init__(self, agent_card: AgentCard, http_handler: RequestHandler): + def __init__( + self, + agent_card: AgentCard, + http_handler: RequestHandler, + context_builder: CallContextBuilder | None = None, + ): """Initializes the A2AStarletteApplication. Args: agent_card: The AgentCard describing the agent's capabilities. http_handler: The handler instance responsible for processing A2A requests via http. + context_builder: The CallContextBuilder used to construct the + ServerCallContext passed to the http_handler. If None, no + ServerCallContext is passed. """ self.agent_card = agent_card self.handler = JSONRPCHandler( agent_card=agent_card, request_handler=http_handler ) + self._context_builder = context_builder def _generate_error_response( self, request_id: str | int | None, error: JSONRPCError | A2AError @@ -122,6 +140,11 @@ async def _handle_requests(self, request: Request) -> Response: try: body = await request.json() a2a_request = A2ARequest.model_validate(body) + call_context = ( + self._context_builder.build(request) + if self._context_builder + else None + ) request_id = a2a_request.root.id request_obj = a2a_request.root @@ -131,11 +154,11 @@ async def _handle_requests(self, request: Request) -> Response: TaskResubscriptionRequest | SendStreamingMessageRequest, ): return await self._process_streaming_request( - request_id, a2a_request + request_id, a2a_request, call_context ) return await self._process_non_streaming_request( - request_id, a2a_request + request_id, a2a_request, call_context ) except MethodNotImplementedError: traceback.print_exc() @@ -161,7 +184,10 @@ async def _handle_requests(self, request: Request) -> Response: ) async def _process_streaming_request( - self, request_id: str | int | None, a2a_request: A2ARequest + self, + request_id: str | int | None, + a2a_request: A2ARequest, + context: ServerCallContext, ) -> Response: """Processes streaming requests (message/stream or tasks/resubscribe). @@ -178,14 +204,21 @@ async def _process_streaming_request( request_obj, SendStreamingMessageRequest, ): - handler_result = self.handler.on_message_send_stream(request_obj) + handler_result = self.handler.on_message_send_stream( + request_obj, context + ) elif isinstance(request_obj, TaskResubscriptionRequest): - handler_result = self.handler.on_resubscribe_to_task(request_obj) + handler_result = self.handler.on_resubscribe_to_task( + request_obj, context + ) return self._create_response(handler_result) async def _process_non_streaming_request( - self, request_id: str | int | None, a2a_request: A2ARequest + self, + request_id: str | int | None, + a2a_request: A2ARequest, + context: ServerCallContext, ) -> Response: """Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*). @@ -200,18 +233,26 @@ async def _process_non_streaming_request( handler_result: Any = None match request_obj: case SendMessageRequest(): - handler_result = await self.handler.on_message_send(request_obj) + handler_result = await self.handler.on_message_send( + request_obj, context + ) case CancelTaskRequest(): - handler_result = await self.handler.on_cancel_task(request_obj) + handler_result = await self.handler.on_cancel_task( + request_obj, context + ) case GetTaskRequest(): - handler_result = await self.handler.on_get_task(request_obj) + handler_result = await self.handler.on_get_task( + request_obj, context + ) case SetTaskPushNotificationConfigRequest(): handler_result = await self.handler.set_push_notification( - request_obj + request_obj, + context, ) case GetTaskPushNotificationConfigRequest(): handler_result = await self.handler.get_push_notification( - request_obj + request_obj, + context, ) case _: logger.error( diff --git a/src/a2a/server/context.py b/src/a2a/server/context.py new file mode 100644 index 00000000..21f2d66d --- /dev/null +++ b/src/a2a/server/context.py @@ -0,0 +1,24 @@ +"""Defines the ServerCallContext class.""" + +import collections.abc +import typing + + +State = collections.abc.MutableMapping[str, typing.Any] + + +class ServerCallContext: + """A context passed when calling a server method. + + This class allows storing arbitrary user data in the state attribute. + """ + + def __init__(self, state: State | None = None): + if state is None: + state = {} + self._state = state + + @property + def state(self) -> State: + """Get the user-provided state.""" + return self._state diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index e0ecd97a..09b1d304 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -10,6 +10,7 @@ RequestContextBuilder, SimpleRequestContextBuilder, ) +from a2a.server.context import ServerCallContext from a2a.server.events import ( Event, EventConsumer, @@ -70,6 +71,8 @@ def __init__( task_store: The `TaskStore` instance to manage task persistence. queue_manager: The `QueueManager` instance to manage event queues. Defaults to `InMemoryQueueManager`. push_notifier: The `PushNotifier` instance for sending push notifications. Defaults to None. + request_context_builder: The `RequestContextBuilder` instance used + to build request contexts. Defaults to `SimpleRequestContextBuilder`. """ self.agent_executor = agent_executor self.task_store = task_store @@ -85,14 +88,20 @@ def __init__( self._running_agents = {} self._running_agents_lock = asyncio.Lock() - async def on_get_task(self, params: TaskQueryParams) -> Task | None: + async def on_get_task( + self, + params: TaskQueryParams, + context: ServerCallContext | None = None, + ) -> Task | None: """Default handler for 'tasks/get'.""" task: Task | None = await self.task_store.get(params.id) if not task: raise ServerError(error=TaskNotFoundError()) return task - async def on_cancel_task(self, params: TaskIdParams) -> Task | None: + async def on_cancel_task( + self, params: TaskIdParams, context: ServerCallContext | None = None + ) -> Task | None: """Default handler for 'tasks/cancel'. Attempts to cancel the task managed by the `AgentExecutor`. @@ -150,7 +159,9 @@ async def _run_event_stream( await queue.close() async def on_message_send( - self, params: MessageSendParams + self, + params: MessageSendParams, + context: ServerCallContext | None = None, ) -> Message | Task: """Default handler for 'message/send' interface (non-streaming). @@ -183,6 +194,7 @@ async def on_message_send( task_id=task.id if task else None, context_id=params.message.contextId, task=task, + context=context, ) task_id = cast(str, request_context.task_id) @@ -232,7 +244,9 @@ async def on_message_send( return result async def on_message_send_stream( - self, params: MessageSendParams + self, + params: MessageSendParams, + context: ServerCallContext | None = None, ) -> AsyncGenerator[Event]: """Default handler for 'message/stream' (streaming). @@ -270,6 +284,7 @@ async def on_message_send_stream( task_id=task.id if task else None, context_id=params.message.contextId, task=task, + context=context, ) task_id = cast(str, request_context.task_id) @@ -334,7 +349,9 @@ async def _cleanup_producer( self._running_agents.pop(task_id, None) async def on_set_task_push_notification_config( - self, params: TaskPushNotificationConfig + self, + params: TaskPushNotificationConfig, + context: ServerCallContext | None = None, ) -> TaskPushNotificationConfig: """Default handler for 'tasks/pushNotificationConfig/set'. @@ -355,7 +372,9 @@ async def on_set_task_push_notification_config( return params async def on_get_task_push_notification_config( - self, params: TaskIdParams + self, + params: TaskIdParams, + context: ServerCallContext | None = None, ) -> TaskPushNotificationConfig: """Default handler for 'tasks/pushNotificationConfig/get'. @@ -377,7 +396,9 @@ async def on_get_task_push_notification_config( ) async def on_resubscribe_to_task( - self, params: TaskIdParams + self, + params: TaskIdParams, + context: ServerCallContext | None = None, ) -> AsyncGenerator[Event]: """Default handler for 'tasks/resubscribe'. diff --git a/src/a2a/server/request_handlers/jsonrpc_handler.py b/src/a2a/server/request_handlers/jsonrpc_handler.py index c766d999..13d2854b 100644 --- a/src/a2a/server/request_handlers/jsonrpc_handler.py +++ b/src/a2a/server/request_handlers/jsonrpc_handler.py @@ -2,6 +2,7 @@ from collections.abc import AsyncIterable +from a2a.server.context import ServerCallContext from a2a.server.request_handlers.request_handler import RequestHandler from a2a.server.request_handlers.response_helpers import prepare_response_object from a2a.types import ( @@ -61,12 +62,15 @@ def __init__( self.request_handler = request_handler async def on_message_send( - self, request: SendMessageRequest + self, + request: SendMessageRequest, + context: ServerCallContext | None = None, ) -> SendMessageResponse: """Handles the 'message/send' JSON-RPC method. Args: request: The incoming `SendMessageRequest` object. + context: Context provided by the server. Returns: A `SendMessageResponse` object containing the result (Task or Message) @@ -75,7 +79,7 @@ async def on_message_send( # TODO: Wrap in error handler to return error states try: task_or_message = await self.request_handler.on_message_send( - request.params + request.params, context ) return prepare_response_object( request.id, @@ -96,7 +100,9 @@ async def on_message_send( 'Streaming is not supported by the agent', ) async def on_message_send_stream( - self, request: SendStreamingMessageRequest + self, + request: SendStreamingMessageRequest, + context: ServerCallContext | None = None, ) -> AsyncIterable[SendStreamingMessageResponse]: """Handles the 'message/stream' JSON-RPC method. @@ -104,6 +110,7 @@ async def on_message_send_stream( Args: request: The incoming `SendStreamingMessageRequest` object. + context: Context provided by the server. Yields: `SendStreamingMessageResponse` objects containing streaming events @@ -112,7 +119,7 @@ async def on_message_send_stream( """ try: async for event in self.request_handler.on_message_send_stream( - request.params + request.params, context ): yield prepare_response_object( request.id, @@ -134,18 +141,23 @@ async def on_message_send_stream( ) async def on_cancel_task( - self, request: CancelTaskRequest + self, + request: CancelTaskRequest, + context: ServerCallContext | None = None, ) -> CancelTaskResponse: """Handles the 'tasks/cancel' JSON-RPC method. Args: request: The incoming `CancelTaskRequest` object. + context: Context provided by the server. Returns: A `CancelTaskResponse` object containing the updated Task or a JSON-RPC error. """ try: - task = await self.request_handler.on_cancel_task(request.params) + task = await self.request_handler.on_cancel_task( + request.params, context + ) if task: return prepare_response_object( request.id, @@ -163,7 +175,9 @@ async def on_cancel_task( ) async def on_resubscribe_to_task( - self, request: TaskResubscriptionRequest + self, + request: TaskResubscriptionRequest, + context: ServerCallContext | None = None, ) -> AsyncIterable[SendStreamingMessageResponse]: """Handles the 'tasks/resubscribe' JSON-RPC method. @@ -171,6 +185,7 @@ async def on_resubscribe_to_task( Args: request: The incoming `TaskResubscriptionRequest` object. + context: Context provided by the server. Yields: `SendStreamingMessageResponse` objects containing streaming events @@ -178,7 +193,7 @@ async def on_resubscribe_to_task( """ try: async for event in self.request_handler.on_resubscribe_to_task( - request.params + request.params, context ): yield prepare_response_object( request.id, @@ -200,12 +215,15 @@ async def on_resubscribe_to_task( ) async def get_push_notification( - self, request: GetTaskPushNotificationConfigRequest + self, + request: GetTaskPushNotificationConfigRequest, + context: ServerCallContext | None = None, ) -> GetTaskPushNotificationConfigResponse: """Handles the 'tasks/pushNotificationConfig/get' JSON-RPC method. Args: request: The incoming `GetTaskPushNotificationConfigRequest` object. + context: Context provided by the server. Returns: A `GetTaskPushNotificationConfigResponse` object containing the config or a JSON-RPC error. @@ -213,7 +231,7 @@ async def get_push_notification( try: config = ( await self.request_handler.on_get_task_push_notification_config( - request.params + request.params, context ) ) return prepare_response_object( @@ -235,7 +253,9 @@ async def get_push_notification( 'Push notifications are not supported by the agent', ) async def set_push_notification( - self, request: SetTaskPushNotificationConfigRequest + self, + request: SetTaskPushNotificationConfigRequest, + context: ServerCallContext | None = None, ) -> SetTaskPushNotificationConfigResponse: """Handles the 'tasks/pushNotificationConfig/set' JSON-RPC method. @@ -243,6 +263,7 @@ async def set_push_notification( Args: request: The incoming `SetTaskPushNotificationConfigRequest` object. + context: Context provided by the server. Returns: A `SetTaskPushNotificationConfigResponse` object containing the config or a JSON-RPC error. @@ -254,7 +275,7 @@ async def set_push_notification( try: config = ( await self.request_handler.on_set_task_push_notification_config( - request.params + request.params, context ) ) return prepare_response_object( @@ -271,17 +292,24 @@ async def set_push_notification( ) ) - async def on_get_task(self, request: GetTaskRequest) -> GetTaskResponse: + async def on_get_task( + self, + request: GetTaskRequest, + context: ServerCallContext | None = None, + ) -> GetTaskResponse: """Handles the 'tasks/get' JSON-RPC method. Args: request: The incoming `GetTaskRequest` object. + context: Context provided by the server. Returns: A `GetTaskResponse` object containing the Task or a JSON-RPC error. """ try: - task = await self.request_handler.on_get_task(request.params) + task = await self.request_handler.on_get_task( + request.params, context + ) if task: return prepare_response_object( request.id, diff --git a/src/a2a/server/request_handlers/request_handler.py b/src/a2a/server/request_handlers/request_handler.py index a8229a8a..811c8da2 100644 --- a/src/a2a/server/request_handlers/request_handler.py +++ b/src/a2a/server/request_handlers/request_handler.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncGenerator +from a2a.server.context import ServerCallContext from a2a.server.events.event_queue import Event from a2a.types import ( Message, @@ -22,26 +23,36 @@ class RequestHandler(ABC): """ @abstractmethod - async def on_get_task(self, params: TaskQueryParams) -> Task | None: + async def on_get_task( + self, + params: TaskQueryParams, + context: ServerCallContext | None = None, + ) -> Task | None: """Handles the 'tasks/get' method. Retrieves the state and history of a specific task. Args: params: Parameters specifying the task ID and optionally history length. + context: Context provided by the server. Returns: The `Task` object if found, otherwise `None`. """ @abstractmethod - async def on_cancel_task(self, params: TaskIdParams) -> Task | None: + async def on_cancel_task( + self, + params: TaskIdParams, + context: ServerCallContext | None = None, + ) -> Task | None: """Handles the 'tasks/cancel' method. Requests the agent to cancel an ongoing task. Args: params: Parameters specifying the task ID. + context: Context provided by the server. Returns: The `Task` object with its status updated to canceled, or `None` if the task was not found. @@ -49,7 +60,9 @@ async def on_cancel_task(self, params: TaskIdParams) -> Task | None: @abstractmethod async def on_message_send( - self, params: MessageSendParams + self, + params: MessageSendParams, + context: ServerCallContext | None = None, ) -> Task | Message: """Handles the 'message/send' method (non-streaming). @@ -58,6 +71,7 @@ async def on_message_send( Args: params: Parameters including the message and configuration. + context: Context provided by the server. Returns: The final `Task` object or a final `Message` object. @@ -65,7 +79,9 @@ async def on_message_send( @abstractmethod async def on_message_send_stream( - self, params: MessageSendParams + self, + params: MessageSendParams, + context: ServerCallContext | None = None, ) -> AsyncGenerator[Event]: """Handles the 'message/stream' method (streaming). @@ -74,6 +90,7 @@ async def on_message_send_stream( Args: params: Parameters including the message and configuration. + context: Context provided by the server. Yields: `Event` objects from the agent's execution. @@ -86,7 +103,9 @@ async def on_message_send_stream( @abstractmethod async def on_set_task_push_notification_config( - self, params: TaskPushNotificationConfig + self, + params: TaskPushNotificationConfig, + context: ServerCallContext | None = None, ) -> TaskPushNotificationConfig: """Handles the 'tasks/pushNotificationConfig/set' method. @@ -94,6 +113,7 @@ async def on_set_task_push_notification_config( Args: params: Parameters including the task ID and push notification configuration. + context: Context provided by the server. Returns: The provided `TaskPushNotificationConfig` upon success. @@ -101,7 +121,9 @@ async def on_set_task_push_notification_config( @abstractmethod async def on_get_task_push_notification_config( - self, params: TaskIdParams + self, + params: TaskIdParams, + context: ServerCallContext | None = None, ) -> TaskPushNotificationConfig: """Handles the 'tasks/pushNotificationConfig/get' method. @@ -109,6 +131,7 @@ async def on_get_task_push_notification_config( Args: params: Parameters including the task ID. + context: Context provided by the server. Returns: The `TaskPushNotificationConfig` for the task. @@ -116,7 +139,9 @@ async def on_get_task_push_notification_config( @abstractmethod async def on_resubscribe_to_task( - self, params: TaskIdParams + self, + params: TaskIdParams, + context: ServerCallContext | None = None, ) -> AsyncGenerator[Event]: """Handles the 'tasks/resubscribe' method. @@ -124,6 +149,7 @@ async def on_resubscribe_to_task( Args: params: Parameters including the task ID. + context: Context provided by the server. Yields: `Event` objects from the agent's ongoing execution for the specified task. diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index 431c54b1..459b6e29 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -1,6 +1,5 @@ import unittest import unittest.async_case - from collections.abc import AsyncGenerator from typing import Any from unittest.mock import AsyncMock, MagicMock, call, patch @@ -8,19 +7,14 @@ import httpx import pytest - from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.agent_execution.request_context_builder import ( RequestContextBuilder, ) -from a2a.server.events import ( - QueueManager, -) +from a2a.server.context import ServerCallContext +from a2a.server.events import QueueManager from a2a.server.events.event_queue import EventQueue -from a2a.server.request_handlers import ( - DefaultRequestHandler, - JSONRPCHandler, -) +from a2a.server.request_handlers import DefaultRequestHandler, JSONRPCHandler from a2a.server.tasks import InMemoryPushNotifier, PushNotifier, TaskStore from a2a.types import ( AgentCapabilities, @@ -60,11 +54,9 @@ TaskStatusUpdateEvent, TextPart, UnsupportedOperationError, - InternalError, ) from a2a.utils.errors import ServerError - MINIMAL_TASK: dict[str, Any] = { 'id': 'task_123', 'contextId': 'session-xyz', @@ -91,12 +83,15 @@ async def test_on_get_task_success(self) -> None: request_handler = DefaultRequestHandler( mock_agent_executor, mock_task_store ) + call_context = ServerCallContext(state={'foo': 'bar'}) handler = JSONRPCHandler(self.mock_agent_card, request_handler) task_id = 'test_task_id' mock_task = Task(**MINIMAL_TASK) mock_task_store.get.return_value = mock_task request = GetTaskRequest(id='1', params=TaskQueryParams(id=task_id)) - response: GetTaskResponse = await handler.on_get_task(request) + response: GetTaskResponse = await handler.on_get_task( + request, call_context + ) self.assertIsInstance(response.root, GetTaskSuccessResponse) assert response.root.result == mock_task # type: ignore mock_task_store.get.assert_called_once_with(task_id) @@ -114,7 +109,10 @@ async def test_on_get_task_not_found(self) -> None: method='tasks/get', params=TaskQueryParams(id='nonexistent_id'), ) - response: GetTaskResponse = await handler.on_get_task(request) + call_context = ServerCallContext(state={'foo': 'bar'}) + response: GetTaskResponse = await handler.on_get_task( + request, call_context + ) self.assertIsInstance(response.root, JSONRPCErrorResponse) assert response.root.error == TaskNotFoundError() # type: ignore @@ -129,6 +127,7 @@ async def test_on_cancel_task_success(self) -> None: mock_task = Task(**MINIMAL_TASK) mock_task_store.get.return_value = mock_task mock_agent_executor.cancel.return_value = None + call_context = ServerCallContext(state={'foo': 'bar'}) async def streaming_coro(): yield mock_task @@ -138,7 +137,7 @@ async def streaming_coro(): return_value=streaming_coro(), ): request = CancelTaskRequest(id='1', params=TaskIdParams(id=task_id)) - response = await handler.on_cancel_task(request) + response = await handler.on_cancel_task(request, call_context) assert mock_agent_executor.cancel.call_count == 1 self.assertIsInstance(response.root, CancelTaskSuccessResponse) assert response.root.result == mock_task # type: ignore @@ -155,6 +154,7 @@ async def test_on_cancel_task_not_supported(self) -> None: mock_task = Task(**MINIMAL_TASK) mock_task_store.get.return_value = mock_task mock_agent_executor.cancel.return_value = None + call_context = ServerCallContext(state={'foo': 'bar'}) async def streaming_coro(): raise ServerError(UnsupportedOperationError()) @@ -165,7 +165,7 @@ async def streaming_coro(): return_value=streaming_coro(), ): request = CancelTaskRequest(id='1', params=TaskIdParams(id=task_id)) - response = await handler.on_cancel_task(request) + response = await handler.on_cancel_task(request, call_context) assert mock_agent_executor.cancel.call_count == 1 self.assertIsInstance(response.root, JSONRPCErrorResponse) assert response.root.error == UnsupportedOperationError() # type: ignore @@ -778,7 +778,6 @@ async def test_on_get_push_notification_no_push_notifier(self) -> None: self.assertIsInstance(response.root, JSONRPCErrorResponse) self.assertEqual(response.root.error, UnsupportedOperationError()) # type: ignore - async def test_on_set_push_notification_no_push_notifier(self) -> None: """Test set_push_notification with no push notifier configured.""" # Arrange @@ -812,7 +811,6 @@ async def test_on_set_push_notification_no_push_notifier(self) -> None: self.assertIsInstance(response.root, JSONRPCErrorResponse) self.assertEqual(response.root.error, UnsupportedOperationError()) # type: ignore - async def test_on_message_send_internal_error(self) -> None: """Test on_message_send with an internal error.""" # Arrange @@ -842,7 +840,6 @@ async def raise_server_error(*args, **kwargs): self.assertIsInstance(response.root, JSONRPCErrorResponse) self.assertIsInstance(response.root.error, InternalError) # type: ignore - async def test_on_message_stream_internal_error(self) -> None: """Test on_message_send_stream with an internal error.""" # Arrange