Skip to content

Commit 85b521d

Browse files
authored
feat: Introduce a ServerCallContext (#94)
* Introduce a ServerCallContext parameter * Finish comment * Remove Starlette-specific DefaultCallContextBuilder * Update comment
1 parent c351656 commit 85b521d

File tree

9 files changed

+210
-60
lines changed

9 files changed

+210
-60
lines changed

src/a2a/server/agent_execution/context.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import uuid
22

3+
from a2a.server.context import ServerCallContext
34
from a2a.types import (
45
InvalidParamsError,
56
Message,
@@ -26,6 +27,7 @@ def __init__(
2627
context_id: str | None = None,
2728
task: Task | None = None,
2829
related_tasks: list[Task] | None = None,
30+
call_context: ServerCallContext | None = None,
2931
):
3032
"""Initializes the RequestContext.
3133
@@ -43,6 +45,7 @@ def __init__(
4345
self._context_id = context_id
4446
self._current_task = task
4547
self._related_tasks = related_tasks
48+
self._call_context = call_context
4649
# If the task id and context id were provided, make sure they
4750
# match the request. Otherwise, create them
4851
if self._params:
@@ -125,6 +128,11 @@ def configuration(self) -> MessageSendConfiguration | None:
125128
return None
126129
return self._params.configuration
127130

131+
@property
132+
def call_context(self) -> ServerCallContext | None:
133+
"""The server call context associated with this request."""
134+
return self._call_context
135+
128136
def _check_or_generate_task_id(self) -> None:
129137
"""Ensures a task ID is present, generating one if necessary."""
130138
if not self._params:

src/a2a/server/agent_execution/request_context_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from abc import ABC, abstractmethod
22

33
from a2a.server.agent_execution import RequestContext
4+
from a2a.server.context import ServerCallContext
45
from a2a.types import MessageSendParams, Task
56

67

@@ -14,5 +15,6 @@ async def build(
1415
task_id: str | None = None,
1516
context_id: str | None = None,
1617
task: Task | None = None,
18+
context: ServerCallContext | None = None,
1719
) -> RequestContext:
1820
pass

src/a2a/server/agent_execution/simple_request_context_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22

33
from a2a.server.agent_execution import RequestContext, RequestContextBuilder
4+
from a2a.server.context import ServerCallContext
45
from a2a.server.tasks import TaskStore
56
from a2a.types import MessageSendParams, Task
67

@@ -22,6 +23,7 @@ async def build(
2223
task_id: str | None = None,
2324
context_id: str | None = None,
2425
task: Task | None = None,
26+
context: ServerCallContext | None = None,
2527
) -> RequestContext:
2628
related_tasks: list[Task] | None = None
2729

@@ -45,4 +47,5 @@ async def build(
4547
context_id=context_id,
4648
task=task,
4749
related_tasks=related_tasks,
50+
call_context=context,
4851
)

src/a2a/server/apps/starlette_app.py

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import logging
33
import traceback
44

5+
from abc import ABC, abstractmethod
56
from collections.abc import AsyncGenerator
67
from typing import Any
78

@@ -12,9 +13,9 @@
1213
from starlette.responses import JSONResponse, Response
1314
from starlette.routing import Route
1415

15-
from a2a.server.request_handlers.request_handler import RequestHandler
16+
from a2a.server.context import ServerCallContext
1617
from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler
17-
18+
from a2a.server.request_handlers.request_handler import RequestHandler
1819
from a2a.types import (
1920
A2AError,
2021
A2ARequest,
@@ -41,6 +42,14 @@
4142
logger = logging.getLogger(__name__)
4243

4344

45+
class CallContextBuilder(ABC):
46+
"""A class for building ServerCallContexts using the Starlette Request."""
47+
48+
@abstractmethod
49+
def build(self, request: Request) -> ServerCallContext:
50+
"""Builds a ServerCallContext from a Starlette Request."""
51+
52+
4453
class A2AStarletteApplication:
4554
"""A Starlette application implementing the A2A protocol server endpoints.
4655
@@ -49,18 +58,27 @@ class A2AStarletteApplication:
4958
(SSE).
5059
"""
5160

52-
def __init__(self, agent_card: AgentCard, http_handler: RequestHandler):
61+
def __init__(
62+
self,
63+
agent_card: AgentCard,
64+
http_handler: RequestHandler,
65+
context_builder: CallContextBuilder | None = None,
66+
):
5367
"""Initializes the A2AStarletteApplication.
5468
5569
Args:
5670
agent_card: The AgentCard describing the agent's capabilities.
5771
http_handler: The handler instance responsible for processing A2A
5872
requests via http.
73+
context_builder: The CallContextBuilder used to construct the
74+
ServerCallContext passed to the http_handler. If None, no
75+
ServerCallContext is passed.
5976
"""
6077
self.agent_card = agent_card
6178
self.handler = JSONRPCHandler(
6279
agent_card=agent_card, request_handler=http_handler
6380
)
81+
self._context_builder = context_builder
6482

6583
def _generate_error_response(
6684
self, request_id: str | int | None, error: JSONRPCError | A2AError
@@ -122,6 +140,11 @@ async def _handle_requests(self, request: Request) -> Response:
122140
try:
123141
body = await request.json()
124142
a2a_request = A2ARequest.model_validate(body)
143+
call_context = (
144+
self._context_builder.build(request)
145+
if self._context_builder
146+
else None
147+
)
125148

126149
request_id = a2a_request.root.id
127150
request_obj = a2a_request.root
@@ -131,11 +154,11 @@ async def _handle_requests(self, request: Request) -> Response:
131154
TaskResubscriptionRequest | SendStreamingMessageRequest,
132155
):
133156
return await self._process_streaming_request(
134-
request_id, a2a_request
157+
request_id, a2a_request, call_context
135158
)
136159

137160
return await self._process_non_streaming_request(
138-
request_id, a2a_request
161+
request_id, a2a_request, call_context
139162
)
140163
except MethodNotImplementedError:
141164
traceback.print_exc()
@@ -161,7 +184,10 @@ async def _handle_requests(self, request: Request) -> Response:
161184
)
162185

163186
async def _process_streaming_request(
164-
self, request_id: str | int | None, a2a_request: A2ARequest
187+
self,
188+
request_id: str | int | None,
189+
a2a_request: A2ARequest,
190+
context: ServerCallContext,
165191
) -> Response:
166192
"""Processes streaming requests (message/stream or tasks/resubscribe).
167193
@@ -178,14 +204,21 @@ async def _process_streaming_request(
178204
request_obj,
179205
SendStreamingMessageRequest,
180206
):
181-
handler_result = self.handler.on_message_send_stream(request_obj)
207+
handler_result = self.handler.on_message_send_stream(
208+
request_obj, context
209+
)
182210
elif isinstance(request_obj, TaskResubscriptionRequest):
183-
handler_result = self.handler.on_resubscribe_to_task(request_obj)
211+
handler_result = self.handler.on_resubscribe_to_task(
212+
request_obj, context
213+
)
184214

185215
return self._create_response(handler_result)
186216

187217
async def _process_non_streaming_request(
188-
self, request_id: str | int | None, a2a_request: A2ARequest
218+
self,
219+
request_id: str | int | None,
220+
a2a_request: A2ARequest,
221+
context: ServerCallContext,
189222
) -> Response:
190223
"""Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*).
191224
@@ -200,18 +233,26 @@ async def _process_non_streaming_request(
200233
handler_result: Any = None
201234
match request_obj:
202235
case SendMessageRequest():
203-
handler_result = await self.handler.on_message_send(request_obj)
236+
handler_result = await self.handler.on_message_send(
237+
request_obj, context
238+
)
204239
case CancelTaskRequest():
205-
handler_result = await self.handler.on_cancel_task(request_obj)
240+
handler_result = await self.handler.on_cancel_task(
241+
request_obj, context
242+
)
206243
case GetTaskRequest():
207-
handler_result = await self.handler.on_get_task(request_obj)
244+
handler_result = await self.handler.on_get_task(
245+
request_obj, context
246+
)
208247
case SetTaskPushNotificationConfigRequest():
209248
handler_result = await self.handler.set_push_notification(
210-
request_obj
249+
request_obj,
250+
context,
211251
)
212252
case GetTaskPushNotificationConfigRequest():
213253
handler_result = await self.handler.get_push_notification(
214-
request_obj
254+
request_obj,
255+
context,
215256
)
216257
case _:
217258
logger.error(

src/a2a/server/context.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""Defines the ServerCallContext class."""
2+
3+
import collections.abc
4+
import typing
5+
6+
7+
State = collections.abc.MutableMapping[str, typing.Any]
8+
9+
10+
class ServerCallContext:
11+
"""A context passed when calling a server method.
12+
13+
This class allows storing arbitrary user data in the state attribute.
14+
"""
15+
16+
def __init__(self, state: State | None = None):
17+
if state is None:
18+
state = {}
19+
self._state = state
20+
21+
@property
22+
def state(self) -> State:
23+
"""Get the user-provided state."""
24+
return self._state

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
RequestContextBuilder,
1111
SimpleRequestContextBuilder,
1212
)
13+
from a2a.server.context import ServerCallContext
1314
from a2a.server.events import (
1415
Event,
1516
EventConsumer,
@@ -70,6 +71,8 @@ def __init__(
7071
task_store: The `TaskStore` instance to manage task persistence.
7172
queue_manager: The `QueueManager` instance to manage event queues. Defaults to `InMemoryQueueManager`.
7273
push_notifier: The `PushNotifier` instance for sending push notifications. Defaults to None.
74+
request_context_builder: The `RequestContextBuilder` instance used
75+
to build request contexts. Defaults to `SimpleRequestContextBuilder`.
7376
"""
7477
self.agent_executor = agent_executor
7578
self.task_store = task_store
@@ -85,14 +88,20 @@ def __init__(
8588
self._running_agents = {}
8689
self._running_agents_lock = asyncio.Lock()
8790

88-
async def on_get_task(self, params: TaskQueryParams) -> Task | None:
91+
async def on_get_task(
92+
self,
93+
params: TaskQueryParams,
94+
context: ServerCallContext | None = None,
95+
) -> Task | None:
8996
"""Default handler for 'tasks/get'."""
9097
task: Task | None = await self.task_store.get(params.id)
9198
if not task:
9299
raise ServerError(error=TaskNotFoundError())
93100
return task
94101

95-
async def on_cancel_task(self, params: TaskIdParams) -> Task | None:
102+
async def on_cancel_task(
103+
self, params: TaskIdParams, context: ServerCallContext | None = None
104+
) -> Task | None:
96105
"""Default handler for 'tasks/cancel'.
97106
98107
Attempts to cancel the task managed by the `AgentExecutor`.
@@ -150,7 +159,9 @@ async def _run_event_stream(
150159
await queue.close()
151160

152161
async def on_message_send(
153-
self, params: MessageSendParams
162+
self,
163+
params: MessageSendParams,
164+
context: ServerCallContext | None = None,
154165
) -> Message | Task:
155166
"""Default handler for 'message/send' interface (non-streaming).
156167
@@ -183,6 +194,7 @@ async def on_message_send(
183194
task_id=task.id if task else None,
184195
context_id=params.message.contextId,
185196
task=task,
197+
context=context,
186198
)
187199

188200
task_id = cast(str, request_context.task_id)
@@ -232,7 +244,9 @@ async def on_message_send(
232244
return result
233245

234246
async def on_message_send_stream(
235-
self, params: MessageSendParams
247+
self,
248+
params: MessageSendParams,
249+
context: ServerCallContext | None = None,
236250
) -> AsyncGenerator[Event]:
237251
"""Default handler for 'message/stream' (streaming).
238252
@@ -270,6 +284,7 @@ async def on_message_send_stream(
270284
task_id=task.id if task else None,
271285
context_id=params.message.contextId,
272286
task=task,
287+
context=context,
273288
)
274289

275290
task_id = cast(str, request_context.task_id)
@@ -334,7 +349,9 @@ async def _cleanup_producer(
334349
self._running_agents.pop(task_id, None)
335350

336351
async def on_set_task_push_notification_config(
337-
self, params: TaskPushNotificationConfig
352+
self,
353+
params: TaskPushNotificationConfig,
354+
context: ServerCallContext | None = None,
338355
) -> TaskPushNotificationConfig:
339356
"""Default handler for 'tasks/pushNotificationConfig/set'.
340357
@@ -355,7 +372,9 @@ async def on_set_task_push_notification_config(
355372
return params
356373

357374
async def on_get_task_push_notification_config(
358-
self, params: TaskIdParams
375+
self,
376+
params: TaskIdParams,
377+
context: ServerCallContext | None = None,
359378
) -> TaskPushNotificationConfig:
360379
"""Default handler for 'tasks/pushNotificationConfig/get'.
361380
@@ -377,7 +396,9 @@ async def on_get_task_push_notification_config(
377396
)
378397

379398
async def on_resubscribe_to_task(
380-
self, params: TaskIdParams
399+
self,
400+
params: TaskIdParams,
401+
context: ServerCallContext | None = None,
381402
) -> AsyncGenerator[Event]:
382403
"""Default handler for 'tasks/resubscribe'.
383404

0 commit comments

Comments
 (0)