Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/a2a/server/agent_execution/context.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import uuid

from a2a.server.context import ServerCallContext
from a2a.types import (
InvalidParamsError,
Message,
Expand All @@ -26,6 +27,7 @@ def __init__(
context_id: str | None = None,
task: Task | None = None,
related_tasks: list[Task] | None = None,
call_context: ServerCallContext | None = None,
):
"""Initializes the RequestContext.

Expand All @@ -43,6 +45,7 @@ def __init__(
self._context_id = context_id
self._current_task = task
self._related_tasks = related_tasks
self._call_context = call_context
# If the task id and context id were provided, make sure they
# match the request. Otherwise, create them
if self._params:
Expand Down Expand Up @@ -125,6 +128,11 @@ def configuration(self) -> MessageSendConfiguration | None:
return None
return self._params.configuration

@property
def call_context(self) -> ServerCallContext | None:
"""The server call context associated with this request."""
return self._call_context

def _check_or_generate_task_id(self) -> None:
"""Ensures a task ID is present, generating one if necessary."""
if not self._params:
Expand Down
2 changes: 2 additions & 0 deletions src/a2a/server/agent_execution/request_context_builder.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from abc import ABC, abstractmethod

from a2a.server.agent_execution import RequestContext
from a2a.server.context import ServerCallContext
from a2a.types import MessageSendParams, Task


Expand All @@ -14,5 +15,6 @@ async def build(
task_id: str | None = None,
context_id: str | None = None,
task: Task | None = None,
context: ServerCallContext | None = None,
) -> RequestContext:
pass
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio

from a2a.server.agent_execution import RequestContext, RequestContextBuilder
from a2a.server.context import ServerCallContext
from a2a.server.tasks import TaskStore
from a2a.types import MessageSendParams, Task

Expand All @@ -22,6 +23,7 @@ async def build(
task_id: str | None = None,
context_id: str | None = None,
task: Task | None = None,
context: ServerCallContext | None = None,
) -> RequestContext:
related_tasks: list[Task] | None = None

Expand All @@ -45,4 +47,5 @@ async def build(
context_id=context_id,
task=task,
related_tasks=related_tasks,
call_context=context,
)
69 changes: 55 additions & 14 deletions src/a2a/server/apps/starlette_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import logging
import traceback

from abc import ABC, abstractmethod
from collections.abc import AsyncGenerator
from typing import Any

Expand All @@ -12,9 +13,9 @@
from starlette.responses import JSONResponse, Response
from starlette.routing import Route

from a2a.server.request_handlers.request_handler import RequestHandler
from a2a.server.context import ServerCallContext
from a2a.server.request_handlers.jsonrpc_handler import JSONRPCHandler

from a2a.server.request_handlers.request_handler import RequestHandler
from a2a.types import (
A2AError,
A2ARequest,
Expand All @@ -41,6 +42,14 @@
logger = logging.getLogger(__name__)


class CallContextBuilder(ABC):
"""A class for building ServerCallContexts using the Starlette Request."""

@abstractmethod
def build(self, request: Request) -> ServerCallContext:
"""Builds a ServerCallContext from a Starlette Request."""


class A2AStarletteApplication:
"""A Starlette application implementing the A2A protocol server endpoints.

Expand All @@ -49,18 +58,27 @@ class A2AStarletteApplication:
(SSE).
"""

def __init__(self, agent_card: AgentCard, http_handler: RequestHandler):
def __init__(
self,
agent_card: AgentCard,
http_handler: RequestHandler,
context_builder: CallContextBuilder | None = None,
):
"""Initializes the A2AStarletteApplication.

Args:
agent_card: The AgentCard describing the agent's capabilities.
http_handler: The handler instance responsible for processing A2A
requests via http.
context_builder: The CallContextBuilder used to construct the
ServerCallContext passed to the http_handler. If None, no
ServerCallContext is passed.
"""
self.agent_card = agent_card
self.handler = JSONRPCHandler(
agent_card=agent_card, request_handler=http_handler
)
self._context_builder = context_builder

def _generate_error_response(
self, request_id: str | int | None, error: JSONRPCError | A2AError
Expand Down Expand Up @@ -122,6 +140,11 @@ async def _handle_requests(self, request: Request) -> Response:
try:
body = await request.json()
a2a_request = A2ARequest.model_validate(body)
call_context = (
self._context_builder.build(request)
if self._context_builder
else None
)

request_id = a2a_request.root.id
request_obj = a2a_request.root
Expand All @@ -131,11 +154,11 @@ async def _handle_requests(self, request: Request) -> Response:
TaskResubscriptionRequest | SendStreamingMessageRequest,
):
return await self._process_streaming_request(
request_id, a2a_request
request_id, a2a_request, call_context
)

return await self._process_non_streaming_request(
request_id, a2a_request
request_id, a2a_request, call_context
)
except MethodNotImplementedError:
traceback.print_exc()
Expand All @@ -161,7 +184,10 @@ async def _handle_requests(self, request: Request) -> Response:
)

async def _process_streaming_request(
self, request_id: str | int | None, a2a_request: A2ARequest
self,
request_id: str | int | None,
a2a_request: A2ARequest,
context: ServerCallContext,
) -> Response:
"""Processes streaming requests (message/stream or tasks/resubscribe).

Expand All @@ -178,14 +204,21 @@ async def _process_streaming_request(
request_obj,
SendStreamingMessageRequest,
):
handler_result = self.handler.on_message_send_stream(request_obj)
handler_result = self.handler.on_message_send_stream(
request_obj, context
)
elif isinstance(request_obj, TaskResubscriptionRequest):
handler_result = self.handler.on_resubscribe_to_task(request_obj)
handler_result = self.handler.on_resubscribe_to_task(
request_obj, context
)

return self._create_response(handler_result)

async def _process_non_streaming_request(
self, request_id: str | int | None, a2a_request: A2ARequest
self,
request_id: str | int | None,
a2a_request: A2ARequest,
context: ServerCallContext,
) -> Response:
"""Processes non-streaming requests (message/send, tasks/get, tasks/cancel, tasks/pushNotificationConfig/*).

Expand All @@ -200,18 +233,26 @@ async def _process_non_streaming_request(
handler_result: Any = None
match request_obj:
case SendMessageRequest():
handler_result = await self.handler.on_message_send(request_obj)
handler_result = await self.handler.on_message_send(
request_obj, context
)
case CancelTaskRequest():
handler_result = await self.handler.on_cancel_task(request_obj)
handler_result = await self.handler.on_cancel_task(
request_obj, context
)
case GetTaskRequest():
handler_result = await self.handler.on_get_task(request_obj)
handler_result = await self.handler.on_get_task(
request_obj, context
)
case SetTaskPushNotificationConfigRequest():
handler_result = await self.handler.set_push_notification(
request_obj
request_obj,
context,
)
case GetTaskPushNotificationConfigRequest():
handler_result = await self.handler.get_push_notification(
request_obj
request_obj,
context,
)
case _:
logger.error(
Expand Down
24 changes: 24 additions & 0 deletions src/a2a/server/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""Defines the ServerCallContext class."""

import collections.abc
import typing


State = collections.abc.MutableMapping[str, typing.Any]


class ServerCallContext:
"""A context passed when calling a server method.

This class allows storing arbitrary user data in the state attribute.
"""

def __init__(self, state: State | None = None):
if state is None:
state = {}
self._state = state

@property
def state(self) -> State:
"""Get the user-provided state."""
return self._state
35 changes: 28 additions & 7 deletions src/a2a/server/request_handlers/default_request_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
RequestContextBuilder,
SimpleRequestContextBuilder,
)
from a2a.server.context import ServerCallContext
from a2a.server.events import (
Event,
EventConsumer,
Expand Down Expand Up @@ -70,6 +71,8 @@ def __init__(
task_store: The `TaskStore` instance to manage task persistence.
queue_manager: The `QueueManager` instance to manage event queues. Defaults to `InMemoryQueueManager`.
push_notifier: The `PushNotifier` instance for sending push notifications. Defaults to None.
request_context_builder: The `RequestContextBuilder` instance used
to build request contexts. Defaults to `SimpleRequestContextBuilder`.
"""
self.agent_executor = agent_executor
self.task_store = task_store
Expand All @@ -85,14 +88,20 @@ def __init__(
self._running_agents = {}
self._running_agents_lock = asyncio.Lock()

async def on_get_task(self, params: TaskQueryParams) -> Task | None:
async def on_get_task(
self,
params: TaskQueryParams,
context: ServerCallContext | None = None,
) -> Task | None:
"""Default handler for 'tasks/get'."""
task: Task | None = await self.task_store.get(params.id)
if not task:
raise ServerError(error=TaskNotFoundError())
return task

async def on_cancel_task(self, params: TaskIdParams) -> Task | None:
async def on_cancel_task(
self, params: TaskIdParams, context: ServerCallContext | None = None
) -> Task | None:
"""Default handler for 'tasks/cancel'.

Attempts to cancel the task managed by the `AgentExecutor`.
Expand Down Expand Up @@ -150,7 +159,9 @@ async def _run_event_stream(
await queue.close()

async def on_message_send(
self, params: MessageSendParams
self,
params: MessageSendParams,
context: ServerCallContext | None = None,
) -> Message | Task:
"""Default handler for 'message/send' interface (non-streaming).

Expand Down Expand Up @@ -183,6 +194,7 @@ async def on_message_send(
task_id=task.id if task else None,
context_id=params.message.contextId,
task=task,
context=context,
)

task_id = cast(str, request_context.task_id)
Expand Down Expand Up @@ -232,7 +244,9 @@ async def on_message_send(
return result

async def on_message_send_stream(
self, params: MessageSendParams
self,
params: MessageSendParams,
context: ServerCallContext | None = None,
) -> AsyncGenerator[Event]:
"""Default handler for 'message/stream' (streaming).

Expand Down Expand Up @@ -270,6 +284,7 @@ async def on_message_send_stream(
task_id=task.id if task else None,
context_id=params.message.contextId,
task=task,
context=context,
)

task_id = cast(str, request_context.task_id)
Expand Down Expand Up @@ -334,7 +349,9 @@ async def _cleanup_producer(
self._running_agents.pop(task_id, None)

async def on_set_task_push_notification_config(
self, params: TaskPushNotificationConfig
self,
params: TaskPushNotificationConfig,
context: ServerCallContext | None = None,
) -> TaskPushNotificationConfig:
"""Default handler for 'tasks/pushNotificationConfig/set'.

Expand All @@ -355,7 +372,9 @@ async def on_set_task_push_notification_config(
return params

async def on_get_task_push_notification_config(
self, params: TaskIdParams
self,
params: TaskIdParams,
context: ServerCallContext | None = None,
) -> TaskPushNotificationConfig:
"""Default handler for 'tasks/pushNotificationConfig/get'.

Expand All @@ -377,7 +396,9 @@ async def on_get_task_push_notification_config(
)

async def on_resubscribe_to_task(
self, params: TaskIdParams
self,
params: TaskIdParams,
context: ServerCallContext | None = None,
) -> AsyncGenerator[Event]:
"""Default handler for 'tasks/resubscribe'.

Expand Down
Loading