Skip to content

Commit 91e150a

Browse files
mikeas1martimfasantos
authored andcommitted
feat: Introduce a ServerCallContext (#94)
* Introduce a ServerCallContext parameter * Finish comment * Remove Starlette-specific DefaultCallContextBuilder * Update comment
1 parent cf61bce commit 91e150a

File tree

9 files changed

+212
-61
lines changed

9 files changed

+212
-61
lines changed

src/a2a/server/agent_execution/context.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import uuid
22

3+
from a2a.server.context import ServerCallContext
34
from a2a.types import (
45
InvalidParamsError,
56
Message,
@@ -26,6 +27,7 @@ def __init__(
2627
context_id: str | None = None,
2728
task: Task | None = None,
2829
related_tasks: list[Task] | None = None,
30+
call_context: ServerCallContext | None = None,
2931
):
3032
"""Initializes the RequestContext.
3133
@@ -43,6 +45,7 @@ def __init__(
4345
self._context_id = context_id
4446
self._current_task = task
4547
self._related_tasks = related_tasks
48+
self._call_context = call_context
4649
# If the task id and context id were provided, make sure they
4750
# match the request. Otherwise, create them
4851
if self._params:
@@ -125,6 +128,11 @@ def configuration(self) -> MessageSendConfiguration | None:
125128
return None
126129
return self._params.configuration
127130

131+
@property
132+
def call_context(self) -> ServerCallContext | None:
133+
"""The server call context associated with this request."""
134+
return self._call_context
135+
128136
def _check_or_generate_task_id(self) -> None:
129137
"""Ensures a task ID is present, generating one if necessary."""
130138
if not self._params:

src/a2a/server/agent_execution/request_context_builder.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from abc import ABC, abstractmethod
22

33
from a2a.server.agent_execution import RequestContext
4+
from a2a.server.context import ServerCallContext
45
from a2a.types import MessageSendParams, Task
56

67

@@ -14,5 +15,6 @@ async def build(
1415
task_id: str | None = None,
1516
context_id: str | None = None,
1617
task: Task | None = None,
18+
context: ServerCallContext | None = None,
1719
) -> RequestContext:
1820
pass

src/a2a/server/agent_execution/simple_request_context_builder.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22

33
from a2a.server.agent_execution import RequestContext, RequestContextBuilder
4+
from a2a.server.context import ServerCallContext
45
from a2a.server.tasks import TaskStore
56
from a2a.types import MessageSendParams, Task
67

@@ -22,6 +23,7 @@ async def build(
2223
task_id: str | None = None,
2324
context_id: str | None = None,
2425
task: Task | None = None,
26+
context: ServerCallContext | None = None,
2527
) -> RequestContext:
2628
related_tasks: list[Task] | None = None
2729

@@ -45,4 +47,5 @@ async def build(
4547
context_id=context_id,
4648
task=task,
4749
related_tasks=related_tasks,
50+
call_context=context,
4851
)

src/a2a/server/apps/starlette_app.py

Lines changed: 57 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import logging
22

3+
from abc import ABC, abstractmethod
4+
from collections.abc import AsyncGenerator
35
from typing import Any
46

57
from starlette.applications import Starlette
68
from starlette.routing import Route
79

8-
from a2a.server.request_handlers.request_handler import RequestHandler
10+
from a2a.server.context import ServerCallContext
911
from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler
10-
12+
from a2a.server.request_handlers.request_handler import RequestHandler
1113
from a2a.types import (
1214
A2AError,
1315
A2ARequest,
@@ -34,26 +36,43 @@
3436
logger = logging.getLogger(__name__)
3537

3638

37-
class A2AStarletteApplication(DefaultA2AApplication):
39+
class CallContextBuilder(ABC):
40+
"""A class for building ServerCallContexts using the Starlette Request."""
41+
42+
@abstractmethod
43+
def build(self, request: Request) -> ServerCallContext:
44+
"""Builds a ServerCallContext from a Starlette Request."""
45+
46+
47+
class A2AStarletteApplication:
3848
"""A Starlette application implementing the A2A protocol server endpoints.
3949
4050
Handles incoming JSON-RPC requests, routes them to the appropriate
4151
handler methods, and manages response generation including Server-Sent Events
4252
(SSE).
4353
"""
4454

45-
def __init__(self, agent_card: AgentCard, http_handler: RequestHandler):
55+
def __init__(
56+
self,
57+
agent_card: AgentCard,
58+
http_handler: RequestHandler,
59+
context_builder: CallContextBuilder | None = None,
60+
):
4661
"""Initializes the A2AStarletteApplication.
4762
4863
Args:
4964
agent_card: The AgentCard describing the agent's capabilities.
5065
http_handler: The handler instance responsible for processing A2A
5166
requests via http.
67+
context_builder: The CallContextBuilder used to construct the
68+
ServerCallContext passed to the http_handler. If None, no
69+
ServerCallContext is passed.
5270
"""
5371
self.agent_card = agent_card
5472
self.handler = JSONRPCHandler(
5573
agent_card=agent_card, request_handler=http_handler
5674
)
75+
self._context_builder = context_builder
5776

5877
def _generate_error_response(
5978
self, request_id: str | int | None, error: JSONRPCError | A2AError
@@ -115,6 +134,11 @@ async def _handle_requests(self, request: Request) -> Response:
115134
try:
116135
body = await request.json()
117136
a2a_request = A2ARequest.model_validate(body)
137+
call_context = (
138+
self._context_builder.build(request)
139+
if self._context_builder
140+
else None
141+
)
118142

119143
request_id = a2a_request.root.id
120144
request_obj = a2a_request.root
@@ -124,11 +148,11 @@ async def _handle_requests(self, request: Request) -> Response:
124148
TaskResubscriptionRequest | SendStreamingMessageRequest,
125149
):
126150
return await self._process_streaming_request(
127-
request_id, a2a_request
151+
request_id, a2a_request, call_context
128152
)
129153

130154
return await self._process_non_streaming_request(
131-
request_id, a2a_request
155+
request_id, a2a_request, call_context
132156
)
133157
except MethodNotImplementedError:
134158
traceback.print_exc()
@@ -154,7 +178,10 @@ async def _handle_requests(self, request: Request) -> Response:
154178
)
155179

156180
async def _process_streaming_request(
157-
self, request_id: str | int | None, a2a_request: A2ARequest
181+
self,
182+
request_id: str | int | None,
183+
a2a_request: A2ARequest,
184+
context: ServerCallContext,
158185
) -> Response:
159186
"""Processes streaming requests (message/stream or tasks/resubscribe).
160187
@@ -171,14 +198,21 @@ async def _process_streaming_request(
171198
request_obj,
172199
SendStreamingMessageRequest,
173200
):
174-
handler_result = self.handler.on_message_send_stream(request_obj)
201+
handler_result = self.handler.on_message_send_stream(
202+
request_obj, context
203+
)
175204
elif isinstance(request_obj, TaskResubscriptionRequest):
176-
handler_result = self.handler.on_resubscribe_to_task(request_obj)
205+
handler_result = self.handler.on_resubscribe_to_task(
206+
request_obj, context
207+
)
177208

178209
return self._create_response(handler_result)
179210

180211
async def _process_non_streaming_request(
181-
self, request_id: str | int | None, a2a_request: A2ARequest
212+
self,
213+
request_id: str | int | None,
214+
a2a_request: A2ARequest,
215+
context: ServerCallContext,
182216
) -> Response:
183217
"""Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*).
184218
@@ -193,18 +227,26 @@ async def _process_non_streaming_request(
193227
handler_result: Any = None
194228
match request_obj:
195229
case SendMessageRequest():
196-
handler_result = await self.handler.on_message_send(request_obj)
230+
handler_result = await self.handler.on_message_send(
231+
request_obj, context
232+
)
197233
case CancelTaskRequest():
198-
handler_result = await self.handler.on_cancel_task(request_obj)
234+
handler_result = await self.handler.on_cancel_task(
235+
request_obj, context
236+
)
199237
case GetTaskRequest():
200-
handler_result = await self.handler.on_get_task(request_obj)
238+
handler_result = await self.handler.on_get_task(
239+
request_obj, context
240+
)
201241
case SetTaskPushNotificationConfigRequest():
202242
handler_result = await self.handler.set_push_notification(
203-
request_obj
243+
request_obj,
244+
context,
204245
)
205246
case GetTaskPushNotificationConfigRequest():
206247
handler_result = await self.handler.get_push_notification(
207-
request_obj
248+
request_obj,
249+
context,
208250
)
209251
case _:
210252
logger.error(

src/a2a/server/context.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
"""Defines the ServerCallContext class."""
2+
3+
import collections.abc
4+
import typing
5+
6+
7+
State = collections.abc.MutableMapping[str, typing.Any]
8+
9+
10+
class ServerCallContext:
11+
"""A context passed when calling a server method.
12+
13+
This class allows storing arbitrary user data in the state attribute.
14+
"""
15+
16+
def __init__(self, state: State | None = None):
17+
if state is None:
18+
state = {}
19+
self._state = state
20+
21+
@property
22+
def state(self) -> State:
23+
"""Get the user-provided state."""
24+
return self._state

src/a2a/server/request_handlers/default_request_handler.py

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
RequestContextBuilder,
1111
SimpleRequestContextBuilder,
1212
)
13+
from a2a.server.context import ServerCallContext
1314
from a2a.server.events import (
1415
Event,
1516
EventConsumer,
@@ -70,6 +71,8 @@ def __init__(
7071
task_store: The `TaskStore` instance to manage task persistence.
7172
queue_manager: The `QueueManager` instance to manage event queues. Defaults to `InMemoryQueueManager`.
7273
push_notifier: The `PushNotifier` instance for sending push notifications. Defaults to None.
74+
request_context_builder: The `RequestContextBuilder` instance used
75+
to build request contexts. Defaults to `SimpleRequestContextBuilder`.
7376
"""
7477
self.agent_executor = agent_executor
7578
self.task_store = task_store
@@ -85,14 +88,20 @@ def __init__(
8588
self._running_agents = {}
8689
self._running_agents_lock = asyncio.Lock()
8790

88-
async def on_get_task(self, params: TaskQueryParams) -> Task | None:
91+
async def on_get_task(
92+
self,
93+
params: TaskQueryParams,
94+
context: ServerCallContext | None = None,
95+
) -> Task | None:
8996
"""Default handler for 'tasks/get'."""
9097
task: Task | None = await self.task_store.get(params.id)
9198
if not task:
9299
raise ServerError(error=TaskNotFoundError())
93100
return task
94101

95-
async def on_cancel_task(self, params: TaskIdParams) -> Task | None:
102+
async def on_cancel_task(
103+
self, params: TaskIdParams, context: ServerCallContext | None = None
104+
) -> Task | None:
96105
"""Default handler for 'tasks/cancel'.
97106
98107
Attempts to cancel the task managed by the `AgentExecutor`.
@@ -150,7 +159,9 @@ async def _run_event_stream(
150159
await queue.close()
151160

152161
async def on_message_send(
153-
self, params: MessageSendParams
162+
self,
163+
params: MessageSendParams,
164+
context: ServerCallContext | None = None,
154165
) -> Message | Task:
155166
"""Default handler for 'message/send' interface (non-streaming).
156167
@@ -183,6 +194,7 @@ async def on_message_send(
183194
task_id=task.id if task else None,
184195
context_id=params.message.contextId,
185196
task=task,
197+
context=context,
186198
)
187199

188200
task_id = cast(str, request_context.task_id)
@@ -232,7 +244,9 @@ async def on_message_send(
232244
return result
233245

234246
async def on_message_send_stream(
235-
self, params: MessageSendParams
247+
self,
248+
params: MessageSendParams,
249+
context: ServerCallContext | None = None,
236250
) -> AsyncGenerator[Event]:
237251
"""Default handler for 'message/stream' (streaming).
238252
@@ -270,6 +284,7 @@ async def on_message_send_stream(
270284
task_id=task.id if task else None,
271285
context_id=params.message.contextId,
272286
task=task,
287+
context=context,
273288
)
274289

275290
task_id = cast(str, request_context.task_id)
@@ -334,7 +349,9 @@ async def _cleanup_producer(
334349
self._running_agents.pop(task_id, None)
335350

336351
async def on_set_task_push_notification_config(
337-
self, params: TaskPushNotificationConfig
352+
self,
353+
params: TaskPushNotificationConfig,
354+
context: ServerCallContext | None = None,
338355
) -> TaskPushNotificationConfig:
339356
"""Default handler for 'tasks/pushNotificationConfig/set'.
340357
@@ -355,7 +372,9 @@ async def on_set_task_push_notification_config(
355372
return params
356373

357374
async def on_get_task_push_notification_config(
358-
self, params: TaskIdParams
375+
self,
376+
params: TaskIdParams,
377+
context: ServerCallContext | None = None,
359378
) -> TaskPushNotificationConfig:
360379
"""Default handler for 'tasks/pushNotificationConfig/get'.
361380
@@ -377,7 +396,9 @@ async def on_get_task_push_notification_config(
377396
)
378397

379398
async def on_resubscribe_to_task(
380-
self, params: TaskIdParams
399+
self,
400+
params: TaskIdParams,
401+
context: ServerCallContext | None = None,
381402
) -> AsyncGenerator[Event]:
382403
"""Default handler for 'tasks/resubscribe'.
383404

0 commit comments

Comments
 (0)