From 2d2d31ed811480370d378f4b665d5406181c125d Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 18 Feb 2026 16:16:19 +0000 Subject: [PATCH 01/15] feat(server): implement `resource scoping` for `tasks` and `push notifications` Introduces caller identity isolation to ensure clients only access authorized resources, as mandated by the A2A spec. - Add 'owner' field to `TaskMixin` and `PushNotificationConfig` database models. - Add 'last_updated' field to `TaskMixin` for optimized sorting and indexing. - Update `DatabaseTaskStore`, `InMemoryTaskStore` and `DatabasePushNotificationConfigStore` to use `OwnerResolver`. - Add relevant Unit tests. - Add Alembic configuration to enable users to update their own databases with non-optional `owner` field in `tasks` table. --- alembic.ini | 45 ++++++++ alembic/README | 56 ++++++++++ alembic/env.py | 85 +++++++++++++++ alembic/script.py.mako | 28 +++++ .../6419d2d130f6_add_owner_to_task.py | 38 +++++++ pyproject.toml | 86 +++++++++++++++ src/a2a/server/models.py | 18 +++- src/a2a/server/owner_resolver.py | 18 ++++ ...database_push_notification_config_store.py | 62 +++++++---- src/a2a/server/tasks/database_task_store.py | 100 +++++++++++++----- src/a2a/server/tasks/inmemory_task_store.py | 85 +++++++++++---- .../tasks/push_notification_config_store.py | 17 ++- .../test_default_request_handler.py | 14 ++- ...database_push_notification_config_store.py | 89 ++++++++++++++++ .../server/tasks/test_database_task_store.py | 66 ++++++++++++ .../server/tasks/test_inmemory_task_store.py | 65 ++++++++++++ tests/server/test_owner_resolver.py | 31 ++++++ 17 files changed, 823 insertions(+), 80 deletions(-) create mode 100644 alembic.ini create mode 100644 alembic/README create mode 100644 alembic/env.py create mode 100644 alembic/script.py.mako create mode 100644 alembic/versions/6419d2d130f6_add_owner_to_task.py create mode 100644 src/a2a/server/owner_resolver.py create mode 100644 tests/server/test_owner_resolver.py diff --git 
a/alembic.ini b/alembic.ini new file mode 100644 index 000000000..58249b073 --- /dev/null +++ b/alembic.ini @@ -0,0 +1,45 @@ +# A generic, single database configuration. + +[alembic] + +# database URL. This is consumed by the user-maintained env.py script only. +# other means of configuring database URLs may be customized within the env.py +# file. +# IMPORTANT: This is a placeholder and an example, and should be replaced with your actual database URL. +sqlalchemy.url = sqlite+aiosqlite:///./test.db + + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARNING +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARNING +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/alembic/README b/alembic/README new file mode 100644 index 000000000..0c6d7dba1 --- /dev/null +++ b/alembic/README @@ -0,0 +1,56 @@ +# Database Migrations with Alembic + +This directory contains database migration scripts for the A2A SDK, managed by [Alembic](https://alembic.sqlalchemy.org/). + +## Configuration + +- `alembic.ini`: Global configuration for Alembic, including the database URL. +- `env.py`: Python script that runs when the Alembic environment is invoked. It configures the SQLAlchemy engine and connects it to the migration context. +- `versions/`: Directory containing individual migration scripts. + +## Common Commands + +All commands should be run from the project root using `uv run`. 
+ +### Viewing Status +```bash +# View current migration version of the database +uv run alembic current + +# View migration history +uv run alembic history --verbose +``` + +### Running Migrations +```bash +# Upgrade to the latest version +uv run alembic upgrade head + +# Downgrade by one version +uv run alembic downgrade -1 +``` + +### Creating Migrations +```bash +# Create a new migration manually +uv run alembic revision -m "description of changes" + +# Create a new migration automatically (detects changes in models.py) +uv run alembic revision --autogenerate -m "description of changes" +``` + +## Troubleshooting + +### "duplicate column name" error +If you see an error like `sqlalchemy.exc.OperationalError: (sqlite3.OperationalError) duplicate column name: owner`, it usually means the column was already created (perhaps by `Base.metadata.create_all()` in tests or development) but Alembic doesn't know about it yet. + +To fix this, "stamp" the database to tell Alembic it is already at the latest version: +```bash +uv run alembic stamp head +``` + +## How to add a new migration +1. Modify the models in `src/a2a/server/models.py`. +2. Run `uv run alembic revision --autogenerate -m "Add new field to Task"`. +3. Review the generated script in `alembic/versions/`. +4. Apply the migration with `uv run alembic upgrade head`. diff --git a/alembic/env.py b/alembic/env.py new file mode 100644 index 000000000..d541fe140 --- /dev/null +++ b/alembic/env.py @@ -0,0 +1,85 @@ +import asyncio + +from logging.config import fileConfig + +from sqlalchemy import pool +from sqlalchemy.ext.asyncio import async_engine_from_config + +from a2a.server.models import Base +from alembic import context + + +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. +config = context.config + +# Interpret the config file for Python logging. +# This line sets up loggers basically. 
+if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# add your model's MetaData object here for 'autogenerate' support +target_metadata = Base.metadata + +# other values from the config, defined by the needs of env.py, +# can be acquired: +# my_important_option = config.get_main_option("my_important_option") +# ... etc. + + +def run_migrations_offline() -> None: + """Run migrations in 'offline' mode. + + This configures the context with just a URL + and not an Engine, though an Engine is acceptable + here as well. By skipping the Engine creation + we don't even need a DBAPI to be available. + + Calls to context.execute() here emit the given string to the + script output. + + """ + url = config.get_main_option('sqlalchemy.url') + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={'paramstyle': 'named'}, + ) + + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection): + context.configure(connection=connection, target_metadata=target_metadata) + + with context.begin_transaction(): + context.run_migrations() + + +async def run_async_migrations(): + """In this scenario we need to create an Engine + and associate a connection with the context. 
+ """ + connectable = async_engine_from_config( + config.get_section(config.config_ini_section), + prefix='sqlalchemy.', + poolclass=pool.NullPool, + ) + + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + + await connectable.dispose() + + +def run_migrations_online(): + """Run migrations in 'online' mode.""" + asyncio.run(run_async_migrations()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/alembic/script.py.mako b/alembic/script.py.mako new file mode 100644 index 000000000..11016301e --- /dev/null +++ b/alembic/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, Sequence[str], None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + """Upgrade schema.""" + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + """Downgrade schema.""" + ${downgrades if downgrades else "pass"} diff --git a/alembic/versions/6419d2d130f6_add_owner_to_task.py b/alembic/versions/6419d2d130f6_add_owner_to_task.py new file mode 100644 index 000000000..3b96a5c9e --- /dev/null +++ b/alembic/versions/6419d2d130f6_add_owner_to_task.py @@ -0,0 +1,38 @@ +"""add_owner_to_task + +Revision ID: 6419d2d130f6 +Revises: +Create Date: 2026-02-17 09:23:06.758085 + +""" + +from collections.abc import Sequence + +import sqlalchemy as sa + +from alembic import op + + +# revision identifiers, used by Alembic. 
+revision: str = '6419d2d130f6' +down_revision: str | Sequence[str] | None = None +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Upgrade schema.""" + op.add_column( + 'tasks', + sa.Column( + 'owner', + sa.String(255), + nullable=False, + server_default='unknown', # Set your desired default value here + ), + ) + + +def downgrade() -> None: + """Downgrade schema.""" + op.drop_column('tasks', 'owner') diff --git a/pyproject.toml b/pyproject.toml index 1a8f0af68..0d580f745 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -323,3 +323,89 @@ docstring-code-format = true docstring-code-line-length = "dynamic" quote-style = "single" indent-style = "space" + + +[tool.alembic] + +# path to migration scripts. +# this is typically a path given in POSIX (e.g. forward slashes) +# format, relative to the token %(here)s which refers to the location of this +# ini file +script_location = "%(here)s/alembic" + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +# file_template = "%%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s" +# Or organize into date-based subdirectories (requires recursive_version_locations = true) +# file_template = "%%(year)d/%%(month).2d/%%(day).2d_%%(hour).2d%%(minute).2d_%%(second).2d_%%(rev)s_%%(slug)s" + +# additional paths to be prepended to sys.path. defaults to the current working directory. +prepend_sys_path = [ + "." +] + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the tzdata library which can be installed by adding +# `alembic[tz]` to the pip requirements. 
+# string value is passed to ZoneInfo() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +# sourceless = false + +# version location specification; This defaults +# to /versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# version_locations = [ +# "%(here)s/alembic/versions", +# "%(here)s/foo/bar" +# ] + + +# set to 'true' to search source files recursively +# in each "version_locations" directory +# new in Alembic version 1.10 +# recursive_version_locations = false + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = "utf-8" + +# This section defines scripts or Python functions that are run +# on newly generated revision scripts. 
See the documentation for further +# detail and examples +# [[tool.alembic.post_write_hooks]] +# format using "black" - use the console_scripts runner, +# against the "black" entrypoint +# name = "black" +# type = "console_scripts" +# entrypoint = "black" +# options = "-l 79 REVISION_SCRIPT_FILENAME" +# +# [[tool.alembic.post_write_hooks]] +# lint with attempts to fix using "ruff" - use the module runner, against the "ruff" module +# name = "ruff" +# type = "module" +# module = "ruff" +# options = "check --fix REVISION_SCRIPT_FILENAME" +# +# [[tool.alembic.post_write_hooks]] +# Alternatively, use the exec runner to execute a binary found on your PATH +# name = "ruff" +# type = "exec" +# executable = "ruff" +# options = "check --fix REVISION_SCRIPT_FILENAME" + diff --git a/src/a2a/server/models.py b/src/a2a/server/models.py index 4b0f7504c..636efedcb 100644 --- a/src/a2a/server/models.py +++ b/src/a2a/server/models.py @@ -1,3 +1,5 @@ +import datetime + from typing import TYPE_CHECKING, Any, Generic, TypeVar @@ -16,7 +18,7 @@ def override(func): # noqa: ANN001, ANN201 try: - from sqlalchemy import JSON, Dialect, LargeBinary, String + from sqlalchemy import JSON, Dialect, Index, LargeBinary, String from sqlalchemy.orm import ( DeclarativeBase, Mapped, @@ -127,6 +129,8 @@ class TaskMixin: kind: Mapped[str] = mapped_column( String(16), nullable=False, default='task' ) + owner: Mapped[str] = mapped_column(String(255), nullable=False, index=True) + last_updated: Mapped[datetime] = mapped_column(String(22), nullable=True) # Properly typed Pydantic fields with automatic serialization status: Mapped[TaskStatus] = mapped_column(PydanticType(TaskStatus)) @@ -152,6 +156,17 @@ def __repr__(self) -> str: f'context_id="{self.context_id}", status="{self.status}")>' ) + @declared_attr + @classmethod + def __table_args__(cls) -> tuple[Any, ...]: + """Define a unique index (owner, last_updated) for each table that uses the mixin.""" + tablename = getattr(cls, '__tablename__', 'tasks') 
+ return ( + Index( + f'idx_{tablename}_owner_last_updated', 'owner', 'last_updated' + ), + ) + def create_task_model( table_name: str = 'tasks', base: type[DeclarativeBase] = Base @@ -212,6 +227,7 @@ class PushNotificationConfigMixin: task_id: Mapped[str] = mapped_column(String(36), primary_key=True) config_id: Mapped[str] = mapped_column(String(255), primary_key=True) config_data: Mapped[bytes] = mapped_column(LargeBinary, nullable=False) + owner: Mapped[str] = mapped_column(String(255), nullable=False, index=True) @override def __repr__(self) -> str: diff --git a/src/a2a/server/owner_resolver.py b/src/a2a/server/owner_resolver.py new file mode 100644 index 000000000..7c2756075 --- /dev/null +++ b/src/a2a/server/owner_resolver.py @@ -0,0 +1,18 @@ +from collections.abc import Callable + +from a2a.server.context import ServerCallContext + + +# Definition +OwnerResolver = Callable[[ServerCallContext], str] + + +# Example Default Implementation +def resolve_user_scope(context: ServerCallContext) -> str: + """Resolves the owner scope based on the user in the context.""" + if not context: + return 'unknown' + if not context.user: + raise ValueError('User not found in context.') + # Example: Basic user name. Adapt as needed for your user model. 
+ return context.user.user_name diff --git a/src/a2a/server/tasks/database_push_notification_config_store.py b/src/a2a/server/tasks/database_push_notification_config_store.py index e125f22a1..b1a30157f 100644 --- a/src/a2a/server/tasks/database_push_notification_config_store.py +++ b/src/a2a/server/tasks/database_push_notification_config_store.py @@ -8,11 +8,7 @@ try: - from sqlalchemy import ( - Table, - delete, - select, - ) + from sqlalchemy import Table, and_, delete, select from sqlalchemy.ext.asyncio import ( AsyncEngine, AsyncSession, @@ -29,11 +25,13 @@ "or 'pip install a2a-sdk[sql]'" ) from e +from a2a.server.context import ServerCallContext from a2a.server.models import ( Base, PushNotificationConfigModel, create_push_notification_config_model, ) +from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope from a2a.server.tasks.push_notification_config_store import ( PushNotificationConfigStore, ) @@ -59,6 +57,7 @@ class DatabasePushNotificationConfigStore(PushNotificationConfigStore): _initialized: bool config_model: type[PushNotificationConfigModel] _fernet: 'Fernet | None' + owner_resolver: OwnerResolver def __init__( self, @@ -66,6 +65,7 @@ def __init__( create_table: bool = True, table_name: str = 'push_notification_configs', encryption_key: str | bytes | None = None, + owner_resolver: OwnerResolver = resolve_user_scope, ) -> None: """Initializes the DatabasePushNotificationConfigStore. @@ -76,6 +76,7 @@ def __init__( encryption_key: A key for encrypting sensitive configuration data. If provided, `config_data` will be encrypted in the database. The key must be a URL-safe base64-encoded 32-byte key. + owner_resolver: Function to resolve the owner from the context. 
""" logger.debug( 'Initializing DatabasePushNotificationConfigStore with existing engine, table: %s', @@ -87,6 +88,7 @@ def __init__( ) self.create_table = create_table self._initialized = False + self.owner_resolver = owner_resolver self.config_model = ( PushNotificationConfigModel if table_name == 'push_notification_configs' @@ -139,7 +141,7 @@ async def _ensure_initialized(self) -> None: await self.initialize() def _to_orm( - self, task_id: str, config: PushNotificationConfig + self, task_id: str, config: PushNotificationConfig, owner: str ) -> PushNotificationConfigModel: """Maps a Pydantic PushNotificationConfig to a SQLAlchemy model instance. @@ -155,6 +157,7 @@ def _to_orm( return self.config_model( task_id=task_id, config_id=config.id, + owner=owner, config_data=data_to_store, ) @@ -223,30 +226,43 @@ def _from_orm( ) from e async def set_info( - self, task_id: str, notification_config: PushNotificationConfig + self, + task_id: str, + notification_config: PushNotificationConfig, + context: ServerCallContext | None = None, ) -> None: """Sets or updates the push notification configuration for a task.""" await self._ensure_initialized() + owner = self.owner_resolver(context) config_to_save = notification_config.model_copy() if config_to_save.id is None: config_to_save.id = task_id - db_config = self._to_orm(task_id, config_to_save) + db_config = self._to_orm(task_id, config_to_save, owner) async with self.async_session_maker.begin() as session: await session.merge(db_config) logger.debug( - 'Push notification config for task %s with config id %s saved/updated.', + 'Push notification config for task %s with config id %s for owner %s saved/updated.', task_id, config_to_save.id, + owner, ) - async def get_info(self, task_id: str) -> list[PushNotificationConfig]: - """Retrieves all push notification configurations for a task.""" + async def get_info( + self, + task_id: str, + context: ServerCallContext | None = None, + ) -> list[PushNotificationConfig]: + 
"""Retrieves all push notification configurations for a task, for the given owner.""" await self._ensure_initialized() + owner = self.owner_resolver(context) async with self.async_session_maker() as session: stmt = select(self.config_model).where( - self.config_model.task_id == task_id + and_( + self.config_model.task_id == task_id, + self.config_model.owner == owner, + ) ) result = await session.execute(stmt) models = result.scalars().all() @@ -257,24 +273,32 @@ async def get_info(self, task_id: str) -> list[PushNotificationConfig]: configs.append(self._from_orm(model)) except ValueError: # noqa: PERF203 logger.exception( - 'Could not deserialize push notification config for task %s, config %s', + 'Could not deserialize push notification config for task %s, config %s, owner %s', model.task_id, model.config_id, + owner, ) return configs async def delete_info( - self, task_id: str, config_id: str | None = None + self, + task_id: str, + config_id: str | None = None, + context: ServerCallContext | None = None, ) -> None: """Deletes push notification configurations for a task. If config_id is provided, only that specific configuration is deleted. - If config_id is None, all configurations for the task are deleted. + If config_id is None, all configurations for the task for the owner are deleted. 
""" await self._ensure_initialized() + owner = self.owner_resolver(context) async with self.async_session_maker.begin() as session: stmt = delete(self.config_model).where( - self.config_model.task_id == task_id + and_( + self.config_model.task_id == task_id, + self.config_model.owner == owner, + ) ) if config_id is not None: stmt = stmt.where(self.config_model.config_id == config_id) @@ -283,13 +307,15 @@ async def delete_info( if result.rowcount > 0: logger.info( - 'Deleted %s push notification config(s) for task %s.', + 'Deleted %s push notification config(s) for task %s, owner %s.', result.rowcount, task_id, + owner, ) else: logger.warning( - 'Attempted to delete push notification config for task %s with config_id: %s that does not exist.', + 'Attempted to delete push notification config for task %s, owner %s with config_id: %s that does not exist.', task_id, + owner, config_id, ) diff --git a/src/a2a/server/tasks/database_task_store.py b/src/a2a/server/tasks/database_task_store.py index 1605c601a..503be64d2 100644 --- a/src/a2a/server/tasks/database_task_store.py +++ b/src/a2a/server/tasks/database_task_store.py @@ -30,6 +30,7 @@ from a2a.server.context import ServerCallContext from a2a.server.models import Base, TaskModel, create_task_model +from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope from a2a.server.tasks.task_store import TaskStore, TasksPage from a2a.types import ListTasksParams, Task from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE @@ -50,12 +51,14 @@ class DatabaseTaskStore(TaskStore): create_table: bool _initialized: bool task_model: type[TaskModel] + owner_resolver: OwnerResolver def __init__( self, engine: AsyncEngine, create_table: bool = True, table_name: str = 'tasks', + owner_resolver: OwnerResolver = resolve_user_scope, ) -> None: """Initializes the DatabaseTaskStore. 
@@ -63,6 +66,7 @@ def __init__( engine: An existing SQLAlchemy AsyncEngine to be used by Task Store create_table: If true, create tasks table on initialization. table_name: Name of the database table. Defaults to 'tasks'. + owner_resolver: Function to resolve the owner from the context. """ logger.debug( 'Initializing DatabaseTaskStore with existing engine, table: %s', @@ -74,6 +78,7 @@ def __init__( ) self.create_table = create_table self._initialized = False + self.owner_resolver = owner_resolver self.task_model = ( TaskModel @@ -104,12 +109,14 @@ async def _ensure_initialized(self) -> None: if not self._initialized: await self.initialize() - def _to_orm(self, task: Task) -> TaskModel: + def _to_orm(self, task: Task, owner: str) -> TaskModel: """Maps a Pydantic Task to a SQLAlchemy TaskModel instance.""" return self.task_model( id=task.id, context_id=task.context_id, kind=task.kind, + owner=owner, + last_updated=task.status.timestamp, status=task.status, artifacts=task.artifacts, history=task.history, @@ -123,6 +130,7 @@ def _from_orm(self, task_model: TaskModel) -> Task: 'id': task_model.id, 'context_id': task_model.context_id, 'kind': task_model.kind, + 'owner': task_model.owner, 'status': task_model.status, 'artifacts': task_model.artifacts, 'history': task_model.history, @@ -134,38 +142,60 @@ def _from_orm(self, task_model: TaskModel) -> Task: async def save( self, task: Task, context: ServerCallContext | None = None ) -> None: - """Saves or updates a task in the database.""" + """Saves or updates a task in the database for the resolved owner.""" await self._ensure_initialized() - db_task = self._to_orm(task) + owner = self.owner_resolver(context) + db_task = self._to_orm(task, owner) async with self.async_session_maker.begin() as session: await session.merge(db_task) - logger.debug('Task %s saved/updated successfully.', task.id) + logger.debug( + 'Task %s for owner %s saved/updated successfully.', + task.id, + owner, + ) async def get( self, task_id: str, 
context: ServerCallContext | None = None ) -> Task | None: - """Retrieves a task from the database by ID.""" + """Retrieves a task from the database by ID, for the given owner.""" await self._ensure_initialized() + owner = self.owner_resolver(context) async with self.async_session_maker() as session: - stmt = select(self.task_model).where(self.task_model.id == task_id) + stmt = select(self.task_model).where( + and_( + self.task_model.id == task_id, + self.task_model.owner == owner, + ) + ) result = await session.execute(stmt) task_model = result.scalar_one_or_none() if task_model: task = self._from_orm(task_model) - logger.debug('Task %s retrieved successfully.', task_id) + logger.debug( + 'Task %s retrieved successfully for owner %s.', + task_id, + owner, + ) return task - logger.debug('Task %s not found in store.', task_id) + logger.debug( + 'Task %s not found in store for owner %s.', task_id, owner + ) return None async def list( self, params: ListTasksParams, context: ServerCallContext | None = None ) -> TasksPage: - """Retrieves all tasks from the database.""" + """Retrieves tasks from the database based on provided parameters, for the given owner.""" await self._ensure_initialized() + owner = self.owner_resolver(context) + logger.debug('Listing tasks for owner %s with params %s', owner, params) + async with self.async_session_maker() as session: - timestamp_col = self.task_model.status['timestamp'].as_string() - base_stmt = select(self.task_model) + timestamp_col = self.task_model.last_updated + base_stmt = select(self.task_model).where( + self.task_model.owner == owner + ) # Add filters if params.context_id: @@ -202,30 +232,36 @@ async def list( start_task = ( await session.execute( select(self.task_model).where( - self.task_model.id == start_task_id + and_( + self.task_model.id == start_task_id, + self.task_model.owner == owner, + ) ) ) ).scalar_one_or_none() if not start_task: raise ValueError(f'Invalid page token: {params.page_token}') - if 
start_task.status.timestamp: - stmt = stmt.where( - or_( - and_( - timestamp_col == start_task.status.timestamp, - self.task_model.id <= start_task.id, - ), - timestamp_col < start_task.status.timestamp, - timestamp_col.is_(None), + + start_task_timestamp = start_task.status.timestamp + where_clauses = [] + if start_task_timestamp: + where_clauses.append( + and_( + timestamp_col == start_task_timestamp, + self.task_model.id <= start_task_id, ) ) + where_clauses.append(timestamp_col < start_task_timestamp) + where_clauses.append(timestamp_col.is_(None)) else: - stmt = stmt.where( + where_clauses.append( and_( timestamp_col.is_(None), - self.task_model.id <= start_task.id, + self.task_model.id <= start_task_id, ) ) + stmt = stmt.where(or_(*where_clauses)) + page_size = params.page_size or DEFAULT_LIST_TASKS_PAGE_SIZE stmt = stmt.limit(page_size + 1) # Add 1 for next page token @@ -248,17 +284,27 @@ async def list( async def delete( self, task_id: str, context: ServerCallContext | None = None ) -> None: - """Deletes a task from the database by ID.""" + """Deletes a task from the database by ID, for the given owner.""" await self._ensure_initialized() + owner = self.owner_resolver(context) async with self.async_session_maker.begin() as session: - stmt = delete(self.task_model).where(self.task_model.id == task_id) + stmt = delete(self.task_model).where( + and_( + self.task_model.id == task_id, + self.task_model.owner == owner, + ) + ) result = await session.execute(stmt) # Commit is automatic when using session.begin() if result.rowcount > 0: - logger.info('Task %s deleted successfully.', task_id) + logger.info( + 'Task %s deleted successfully for owner %s.', task_id, owner + ) else: logger.warning( - 'Attempted to delete nonexistent task with id: %s', task_id + 'Attempted to delete nonexistent task with id: %s and owner %s', + task_id, + owner, ) diff --git a/src/a2a/server/tasks/inmemory_task_store.py b/src/a2a/server/tasks/inmemory_task_store.py index 
31d42a310..246282650 100644 --- a/src/a2a/server/tasks/inmemory_task_store.py +++ b/src/a2a/server/tasks/inmemory_task_store.py @@ -1,9 +1,11 @@ import asyncio import logging +from collections import defaultdict from datetime import datetime, timezone from a2a.server.context import ServerCallContext +from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope from a2a.server.tasks.task_store import TaskStore, TasksPage from a2a.types import ListTasksParams, Task from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE @@ -16,45 +18,70 @@ class InMemoryTaskStore(TaskStore): """In-memory implementation of TaskStore. - Stores task objects in a dictionary in memory. Task data is lost when the - server process stops. + Stores task objects in a nested dictionary in memory, keyed by owner then task_id. + Task data is lost when the server process stops. """ - def __init__(self) -> None: + def __init__( + self, + owner_resolver: OwnerResolver = resolve_user_scope, + ) -> None: """Initializes the InMemoryTaskStore.""" logger.debug('Initializing InMemoryTaskStore') - self.tasks: dict[str, Task] = {} + self.tasks: dict[str, dict[str, Task]] = defaultdict(dict) self.lock = asyncio.Lock() + self.owner_resolver = owner_resolver async def save( self, task: Task, context: ServerCallContext | None = None ) -> None: - """Saves or updates a task in the in-memory store.""" + """Saves or updates a task in the in-memory store for the resolved owner.""" + owner = self.owner_resolver(context) + async with self.lock: - self.tasks[task.id] = task - logger.debug('Task %s saved successfully.', task.id) + self.tasks[owner][task.id] = task + logger.debug( + 'Task %s for owner %s saved successfully.', task.id, owner + ) async def get( self, task_id: str, context: ServerCallContext | None = None ) -> Task | None: - """Retrieves a task from the in-memory store by ID.""" + """Retrieves a task from the in-memory store by ID, for the given owner.""" + owner = 
self.owner_resolver(context) async with self.lock: - logger.debug('Attempting to get task with id: %s', task_id) - task = self.tasks.get(task_id) - if task: - logger.debug('Task %s retrieved successfully.', task_id) - else: - logger.debug('Task %s not found in store.', task_id) - return task + logger.debug( + 'Attempting to get task with id: %s for owner: %s', + task_id, + owner, + ) + owner_tasks = self.tasks.get(owner) + if owner_tasks: + task = owner_tasks.get(task_id) + if task: + logger.debug( + 'Task %s retrieved successfully for owner %s.', + task_id, + owner, + ) + return task + logger.debug( + 'Task %s not found in store for owner %s.', task_id, owner + ) + return None async def list( self, params: ListTasksParams, context: ServerCallContext | None = None, ) -> TasksPage: - """Retrieves a list of tasks from the store.""" + """Retrieves a list of tasks from the store, for the given owner.""" + owner = self.owner_resolver(context) + logger.debug('Listing tasks for owner %s with params %s', owner, params) + async with self.lock: - tasks = list(self.tasks.values()) + owner_tasks = self.tasks.get(owner, {}) + tasks = list(owner_tasks.values()) # Filter tasks if params.context_id: @@ -118,13 +145,25 @@ async def list( async def delete( self, task_id: str, context: ServerCallContext | None = None ) -> None: - """Deletes a task from the in-memory store by ID.""" + """Deletes a task from the in-memory store by ID, for the given owner.""" + owner = self.owner_resolver(context) async with self.lock: - logger.debug('Attempting to delete task with id: %s', task_id) - if task_id in self.tasks: - del self.tasks[task_id] - logger.debug('Task %s deleted successfully.', task_id) + logger.debug( + 'Attempting to delete task with id: %s for owner %s', + task_id, + owner, + ) + if owner in self.tasks and task_id in self.tasks[owner]: + del self.tasks[owner][task_id] + logger.debug( + 'Task %s deleted successfully for owner %s.', task_id, owner + ) + if not self.tasks[owner]: + 
del self.tasks[owner] + logger.debug('Removed empty owner %s from store.', owner) else: logger.warning( - 'Attempted to delete nonexistent task with id: %s', task_id + 'Attempted to delete nonexistent task with id: %s for owner %s', + task_id, + owner, ) diff --git a/src/a2a/server/tasks/push_notification_config_store.py b/src/a2a/server/tasks/push_notification_config_store.py index efe46b40a..388d86c1e 100644 --- a/src/a2a/server/tasks/push_notification_config_store.py +++ b/src/a2a/server/tasks/push_notification_config_store.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +from a2a.server.context import ServerCallContext from a2a.types import PushNotificationConfig @@ -8,16 +9,26 @@ class PushNotificationConfigStore(ABC): @abstractmethod async def set_info( - self, task_id: str, notification_config: PushNotificationConfig + self, + task_id: str, + notification_config: PushNotificationConfig, + context: ServerCallContext | None = None, ) -> None: """Sets or updates the push notification configuration for a task.""" @abstractmethod - async def get_info(self, task_id: str) -> list[PushNotificationConfig]: + async def get_info( + self, + task_id: str, + context: ServerCallContext | None = None, + ) -> list[PushNotificationConfig]: """Retrieves the push notification configuration for a task.""" @abstractmethod async def delete_info( - self, task_id: str, config_id: str | None = None + self, + task_id: str, + config_id: str | None = None, + context: ServerCallContext | None = None, ) -> None: """Deletes the push notification configuration for a task.""" diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index daeba947f..16c85d400 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -915,9 +915,8 @@ async def test_on_message_send_non_blocking(): ), ) - result = await 
request_handler.on_message_send( - params, create_server_call_context() - ) + context = create_server_call_context() + result = await request_handler.on_message_send(params, context) assert result is not None assert isinstance(result, Task) @@ -927,7 +926,7 @@ async def test_on_message_send_non_blocking(): task: Task | None = None for _ in range(5): await asyncio.sleep(0.1) - task = await task_store.get(result.id) + task = await task_store.get(result.id, context) assert task is not None if task.status.state == TaskState.completed: break @@ -964,9 +963,8 @@ async def test_on_message_send_limit_history(): ), ) - result = await request_handler.on_message_send( - params, create_server_call_context() - ) + context = create_server_call_context() + result = await request_handler.on_message_send(params, context) # verify that history_length is honored assert result is not None @@ -975,7 +973,7 @@ async def test_on_message_send_limit_history(): assert result.status.state == TaskState.completed # verify that history is still persisted to the store - task = await task_store.get(result.id) + task = await task_store.get(result.id, context) assert task is not None assert task.history is not None and len(task.history) > 1 diff --git a/tests/server/tasks/test_database_push_notification_config_store.py b/tests/server/tasks/test_database_push_notification_config_store.py index 0c3bd4683..26d968912 100644 --- a/tests/server/tasks/test_database_push_notification_config_store.py +++ b/tests/server/tasks/test_database_push_notification_config_store.py @@ -3,6 +3,8 @@ from collections.abc import AsyncGenerator import pytest +from a2a.server.context import ServerCallContext +from a2a.auth.user import User # Skip entire test module if SQLAlchemy is not installed @@ -94,6 +96,21 @@ ) +class TestUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return 
True + + @property + def user_name(self) -> str: + return self._user_name + + @pytest_asyncio.fixture(params=DB_CONFIGS) async def db_store_parameterized( request, @@ -547,6 +564,7 @@ async def test_parsing_error_after_successful_decryption( task_id=task_id, config_id=config_id, config_data=encrypted_data, + owner='test-owner', ) session.add(db_model) await session.commit() @@ -563,3 +581,74 @@ async def test_parsing_error_after_successful_decryption( with pytest.raises(ValueError): db_store_parameterized._from_orm(db_model_retrieved) # type: ignore + + +@pytest.mark.asyncio +async def test_owner_resource_scoping( + db_store_parameterized: DatabasePushNotificationConfigStore, +) -> None: + """Test that operations are scoped to the correct owner.""" + config_store = db_store_parameterized + + context_user1 = ServerCallContext(user=TestUser(user_name='user1')) + context_user2 = ServerCallContext(user=TestUser(user_name='user2')) + + # Create configs for different owners + task1_u1_config1 = PushNotificationConfig( + id='t1-u1-c1', url='http://u1.com/1' + ) + task1_u1_config2 = PushNotificationConfig( + id='t1-u1-c2', url='http://u1.com/2' + ) + task1_u2_config1 = PushNotificationConfig( + id='t1-u2-c1', url='http://u2.com/1' + ) + task2_u1_config1 = PushNotificationConfig( + id='t2-u1-c1', url='http://u1.com/3' + ) + + await config_store.set_info('task1', task1_u1_config1, context_user1) + await config_store.set_info('task1', task1_u1_config2, context_user1) + await config_store.set_info('task1', task1_u2_config1, context_user2) + await config_store.set_info('task2', task2_u1_config1, context_user1) + + # Test GET_INFO + # User 1 should get only their configs for task1 + u1_task1_configs = await config_store.get_info('task1', context_user1) + assert len(u1_task1_configs) == 2 + assert {c.id for c in u1_task1_configs} == {'t1-u1-c1', 't1-u1-c2'} + + # User 2 should get only their configs for task1 + u2_task1_configs = await config_store.get_info('task1', 
context_user2) + assert len(u2_task1_configs) == 1 + assert u2_task1_configs[0].id == 't1-u2-c1' + + # User 2 should get no configs for task2 + u2_task2_configs = await config_store.get_info('task2', context_user2) + assert len(u2_task2_configs) == 0 + + # User 1 should get their config for task2 + u1_task2_configs = await config_store.get_info('task2', context_user1) + assert len(u1_task2_configs) == 1 + assert u1_task2_configs[0].id == 't2-u1-c1' + + # Test DELETE_INFO + # User 2 deleting User 1's config should not work + await config_store.delete_info('task1', 't1-u1-c1', context_user2) + u1_task1_configs = await config_store.get_info('task1', context_user1) + assert len(u1_task1_configs) == 2 + + # User 1 deleting their own config + await config_store.delete_info('task1', 't1-u1-c1', context_user1) + u1_task1_configs = await config_store.get_info('task1', context_user1) + assert len(u1_task1_configs) == 1 + assert u1_task1_configs[0].id == 't1-u1-c2' + + # User 1 deleting all configs for task2 + await config_store.delete_info('task2', context=context_user1) + u1_task2_configs = await config_store.get_info('task2', context_user1) + assert len(u1_task2_configs) == 0 + + # Cleanup remaining + await config_store.delete_info('task1', context=context_user1) + await config_store.delete_info('task1', context=context_user2) diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py index 495d2e4fd..5c35b391a 100644 --- a/tests/server/tasks/test_database_task_store.py +++ b/tests/server/tasks/test_database_task_store.py @@ -28,6 +28,23 @@ TaskStatus, TextPart, ) +from a2a.auth.user import User +from a2a.server.context import ServerCallContext + + +class TestUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name 
# DSNs for different databases @@ -608,4 +625,53 @@ async def test_metadata_field_mapping( await db_store_parameterized.delete('task-metadata-test-4') +@pytest.mark.asyncio +async def test_owner_resource_scoping( + db_store_parameterized: DatabaseTaskStore, +) -> None: + """Test that operations are scoped to the correct owner.""" + task_store = db_store_parameterized + + context_user1 = ServerCallContext(user=TestUser(user_name='user1')) + context_user2 = ServerCallContext(user=TestUser(user_name='user2')) + + # Create tasks for different owners + task1_user1 = MINIMAL_TASK_OBJ.model_copy(update={'id': 'u1-task1'}) + task2_user1 = MINIMAL_TASK_OBJ.model_copy(update={'id': 'u1-task2'}) + task1_user2 = MINIMAL_TASK_OBJ.model_copy(update={'id': 'u2-task1'}) + + await task_store.save(task1_user1, context_user1) + await task_store.save(task2_user1, context_user1) + await task_store.save(task1_user2, context_user2) + + # Test GET + assert await task_store.get('u1-task1', context_user1) is not None + assert await task_store.get('u1-task1', context_user2) is None + assert await task_store.get('u2-task1', context_user1) is None + assert await task_store.get('u2-task1', context_user2) is not None + + # Test LIST + params = ListTasksParams() + page_user1 = await task_store.list(params, context_user1) + assert len(page_user1.tasks) == 2 + assert {t.id for t in page_user1.tasks} == {'u1-task1', 'u1-task2'} + assert page_user1.total_size == 2 + + page_user2 = await task_store.list(params, context_user2) + assert len(page_user2.tasks) == 1 + assert {t.id for t in page_user2.tasks} == {'u2-task1'} + assert page_user2.total_size == 1 + + # Test DELETE + await task_store.delete('u1-task1', context_user2) # Should not delete + assert await task_store.get('u1-task1', context_user1) is not None + + await task_store.delete('u1-task1', context_user1) # Should delete + assert await task_store.get('u1-task1', context_user1) is None + + # Cleanup remaining tasks + await 
task_store.delete('u1-task2', context_user1) + await task_store.delete('u2-task1', context_user2) + + # Ensure aiosqlite, asyncpg, and aiomysql are installed in the test environment (added to pyproject.toml). diff --git a/tests/server/tasks/test_inmemory_task_store.py b/tests/server/tasks/test_inmemory_task_store.py index ee91b9261..2fd77e0b0 100644 --- a/tests/server/tasks/test_inmemory_task_store.py +++ b/tests/server/tasks/test_inmemory_task_store.py @@ -1,9 +1,26 @@ from typing import Any +from a2a.server.context import ServerCallContext import pytest from a2a.server.tasks import InMemoryTaskStore from a2a.types import ListTasksParams, Task, TaskState, TaskStatus +from a2a.auth.user import User + + +class TestUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name MINIMAL_TASK: dict[str, Any] = { @@ -259,3 +276,51 @@ async def test_in_memory_task_store_delete_nonexistent() -> None: """Test deleting a nonexistent task.""" store = InMemoryTaskStore() await store.delete('nonexistent') + + +@pytest.mark.asyncio +async def test_owner_resource_scoping() -> None: + """Test that operations are scoped to the correct owner.""" + store = InMemoryTaskStore() + task = Task(**MINIMAL_TASK) + + context_user1 = ServerCallContext(user=TestUser(user_name='user1')) + context_user2 = ServerCallContext(user=TestUser(user_name='user2')) + + # Create tasks for different owners + task1_user1 = task.model_copy(update={'id': 'u1-task1'}) + task2_user1 = task.model_copy(update={'id': 'u1-task2'}) + task1_user2 = task.model_copy(update={'id': 'u2-task1'}) + + await store.save(task1_user1, context_user1) + await store.save(task2_user1, context_user1) + await store.save(task1_user2, context_user2) + + # Test GET + assert await store.get('u1-task1', context_user1) is not None + 
assert await store.get('u1-task1', context_user2) is None + assert await store.get('u2-task1', context_user1) is None + assert await store.get('u2-task1', context_user2) is not None + + # Test LIST + params = ListTasksParams() + page_user1 = await store.list(params, context_user1) + assert len(page_user1.tasks) == 2 + assert {t.id for t in page_user1.tasks} == {'u1-task1', 'u1-task2'} + assert page_user1.total_size == 2 + + page_user2 = await store.list(params, context_user2) + assert len(page_user2.tasks) == 1 + assert {t.id for t in page_user2.tasks} == {'u2-task1'} + assert page_user2.total_size == 1 + + # Test DELETE + await store.delete('u1-task1', context_user2) # Should not delete + assert await store.get('u1-task1', context_user1) is not None + + await store.delete('u1-task1', context_user1) # Should delete + assert await store.get('u1-task1', context_user1) is None + + # Cleanup remaining tasks + await store.delete('u1-task2', context_user1) + await store.delete('u2-task1', context_user2) diff --git a/tests/server/test_owner_resolver.py b/tests/server/test_owner_resolver.py new file mode 100644 index 000000000..8a0686865 --- /dev/null +++ b/tests/server/test_owner_resolver.py @@ -0,0 +1,31 @@ +from a2a.auth.user import User + +from a2a.server.context import ServerCallContext +from a2a.server.owner_resolver import resolve_user_scope + + +class TestUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +def test_resolve_user_scope_valid_user(): + """Test resolve_user_scope with a valid user in the context.""" + user = TestUser(user_name='testuser') + context = ServerCallContext(user=user) + assert resolve_user_scope(context) == 'testuser' + + +def test_resolve_user_scope_no_context(): + """Test resolve_user_scope when the context is None.""" + 
assert resolve_user_scope(None) == 'unknown' From 6093f7f197e2a1ab0fce09c32befe6c9dd24583c Mon Sep 17 00:00:00 2001 From: sokoliva Date: Wed, 18 Feb 2026 16:26:29 +0000 Subject: [PATCH 02/15] fix: add poolclass to allow.txt --- .github/actions/spelling/allow.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/actions/spelling/allow.txt b/.github/actions/spelling/allow.txt index 8d0b13c8c..cda0a4b3d 100644 --- a/.github/actions/spelling/allow.txt +++ b/.github/actions/spelling/allow.txt @@ -69,6 +69,7 @@ oauthoidc oidc opensource otherurl +poolclass postgres POSTGRES postgresql From 6600b4719076e3a96b43a113c29f874006adc80b Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 19 Feb 2026 09:58:44 +0000 Subject: [PATCH 03/15] fix: test_inmemory_task_store.py merge caused error --- tests/server/tasks/test_inmemory_task_store.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/tests/server/tasks/test_inmemory_task_store.py b/tests/server/tasks/test_inmemory_task_store.py index ed30f2356..97befb755 100644 --- a/tests/server/tasks/test_inmemory_task_store.py +++ b/tests/server/tasks/test_inmemory_task_store.py @@ -269,15 +269,23 @@ async def test_in_memory_task_store_delete_nonexistent() -> None: async def test_owner_resource_scoping() -> None: """Test that operations are scoped to the correct owner.""" store = InMemoryTaskStore() - task = Task(**MINIMAL_TASK) + task = create_minimal_task() context_user1 = ServerCallContext(user=TestUser(user_name='user1')) context_user2 = ServerCallContext(user=TestUser(user_name='user2')) # Create tasks for different owners - task1_user1 = task.model_copy(update={'id': 'u1-task1'}) - task2_user1 = task.model_copy(update={'id': 'u1-task2'}) - task1_user2 = task.model_copy(update={'id': 'u2-task1'}) + task1_user1 = Task() + task1_user1.CopyFrom(task) + task1_user1.id = 'u1-task1' + + task2_user1 = Task() + task2_user1.CopyFrom(task) + task2_user1.id = 'u1-task2' + + task1_user2 = Task() + 
task1_user2.CopyFrom(task) + task1_user2.id = 'u2-task1' await store.save(task1_user1, context_user1) await store.save(task2_user1, context_user1) From ea89bbbe6296ee180fa2bfa43262d074b13a88e6 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 19 Feb 2026 10:47:12 +0000 Subject: [PATCH 04/15] fix: - add alembic to dev field in pyproject.toml - fix alembic README.md error - make ServerCallContext optional in OwnerResolver --- alembic/README | 3 +++ pyproject.toml | 1 + src/a2a/server/owner_resolver.py | 2 +- 3 files changed, 5 insertions(+), 1 deletion(-) diff --git a/alembic/README b/alembic/README index 0c6d7dba1..06ec9e9a8 100644 --- a/alembic/README +++ b/alembic/README @@ -27,6 +27,9 @@ uv run alembic history --verbose uv run alembic upgrade head # Downgrade by one version +uv run alembic downgrade -1 + +# Revert all migrations uv run alembic downgrade base ``` diff --git a/pyproject.toml b/pyproject.toml index 6be567814..7e3b6a2f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,6 +106,7 @@ style = "pep440" [dependency-groups] dev = [ + "alembic>=1.14.0", "mypy>=1.15.0", "PyJWT>=2.0.0", "pytest>=8.3.5", diff --git a/src/a2a/server/owner_resolver.py b/src/a2a/server/owner_resolver.py index 6c50cd79f..4fa310b92 100644 --- a/src/a2a/server/owner_resolver.py +++ b/src/a2a/server/owner_resolver.py @@ -4,7 +4,7 @@ # Definition -OwnerResolver = Callable[[ServerCallContext], str] +OwnerResolver = Callable[[ServerCallContext | None], str] # Example Default Implementation From 9301b8c77b4fea37b506deb487164007d9029501 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 19 Feb 2026 10:55:14 +0000 Subject: [PATCH 05/15] fix: update uv.lock --- uv.lock | 37 +++++++++++++++++++++++++++++++++---- 1 file changed, 33 insertions(+), 4 deletions(-) diff --git a/uv.lock b/uv.lock index 2cecfc177..748ef3ee6 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.10" resolution-markers = [ "python_full_version >= 
'3.14'", @@ -70,6 +70,7 @@ telemetry = [ [package.dev-dependencies] dev = [ { name = "a2a-sdk", extra = ["all"] }, + { name = "alembic" }, { name = "autoflake" }, { name = "mypy" }, { name = "no-implicit-optional" }, @@ -135,6 +136,7 @@ provides-extras = ["all", "encryption", "grpc", "http-server", "mysql", "postgre [package.metadata.requires-dev] dev = [ { name = "a2a-sdk", extras = ["all"], editable = "." }, + { name = "alembic", specifier = ">=1.14.0" }, { name = "autoflake" }, { name = "mypy", specifier = ">=1.15.0" }, { name = "no-implicit-optional" }, @@ -177,6 +179,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/00/b7/e3bf5133d697a08128598c8d0abc5e16377b51465a33756de24fa7dee953/aiosqlite-0.22.1-py3-none-any.whl", hash = "sha256:21c002eb13823fad740196c5a2e9d8e62f6243bd9e7e4a1f87fb5e44ecb4fceb", size = 17405, upload-time = "2025-12-23T19:25:42.139Z" }, ] +[[package]] +name = "alembic" +version = "1.18.4" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "mako" }, + { name = "sqlalchemy" }, + { name = "tomli", marker = "python_full_version < '3.11'" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/94/13/8b084e0f2efb0275a1d534838844926f798bd766566b1375174e2448cd31/alembic-1.18.4.tar.gz", hash = "sha256:cb6e1fd84b6174ab8dbb2329f86d631ba9559dd78df550b57804d607672cedbc", size = 2056725, upload-time = "2026-02-10T16:00:47.195Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d2/29/6533c317b74f707ea28f8d633734dbda2119bbadfc61b2f3640ba835d0f7/alembic-1.18.4-py3-none-any.whl", hash = "sha256:a5ed4adcf6d8a4cb575f3d759f071b03cd6e5c7618eb796cb52497be25bfe19a", size = 263893, upload-time = "2026-02-10T16:00:49.997Z" }, +] + [[package]] name = "annotated-doc" version = "0.0.4" @@ -1277,6 +1294,18 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/94/d1/433b3c06e78f23486fe4fdd19bc134657eb30997d2054b0dbf52bbf3382e/librt-0.8.0-cp314-cp314t-win_arm64.whl", hash = "sha256:92249938ab744a5890580d3cb2b22042f0dce71cdaa7c1369823df62bedf7cbc", size = 48753, upload-time = "2026-02-12T14:53:38.539Z" }, ] +[[package]] +name = "mako" +version = "1.3.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "markupsafe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/9e/38/bd5b78a920a64d708fe6bc8e0a2c075e1389d53bef8413725c63ba041535/mako-1.3.10.tar.gz", hash = "sha256:99579a6f39583fa7e5630a28c3c1f440e4e97a414b80372649c0ce338da2ea28", size = 392474, upload-time = "2025-04-10T12:44:31.16Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/87/fb/99f81ac72ae23375f22b7afdb7642aba97c00a713c217124420147681a2f/mako-1.3.10-py3-none-any.whl", hash = "sha256:baef24a52fc4fc514a0887ac600f9f1cff3d82c61d4d700a1fa84d597b88db59", size = 78509, upload-time = "2025-04-10T12:50:53.297Z" }, +] + [[package]] name = "markupsafe" version = "3.0.3" @@ -2323,7 +2352,7 @@ wheels = [ [[package]] name = "virtualenv" -version = "20.37.0" +version = "20.38.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "distlib" }, @@ -2331,9 +2360,9 @@ dependencies = [ { name = "platformdirs" }, { name = "typing-extensions", marker = "python_full_version < '3.11'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c1/ef/d9d4ce633df789bf3430bd81fb0d8b9d9465dfc1d1f0deb3fb62cd80f5c2/virtualenv-20.37.0.tar.gz", hash = "sha256:6f7e2064ed470aa7418874e70b6369d53b66bcd9e9fd5389763e96b6c94ccb7c", size = 5864710, upload-time = "2026-02-16T16:17:59.42Z" } +sdist = { url = "https://files.pythonhosted.org/packages/d2/03/a94d404ca09a89a7301a7008467aed525d4cdeb9186d262154dd23208709/virtualenv-20.38.0.tar.gz", hash = "sha256:94f39b1abaea5185bf7ea5a46702b56f1d0c9aa2f41a6c2b8b0af4ddc74c10a7", size = 5864558, upload-time = 
"2026-02-19T07:48:02.385Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/42/4b/6cf85b485be7ec29db837ec2a1d8cd68bc1147b1abf23d8636c5bd65b3cc/virtualenv-20.37.0-py3-none-any.whl", hash = "sha256:5d3951c32d57232ae3569d4de4cc256c439e045135ebf43518131175d9be435d", size = 5837480, upload-time = "2026-02-16T16:17:57.341Z" }, + { url = "https://files.pythonhosted.org/packages/42/d7/394801755d4c8684b655d35c665aea7836ec68320304f62ab3c94395b442/virtualenv-20.38.0-py3-none-any.whl", hash = "sha256:d6e78e5889de3a4742df2d3d44e779366325a90cf356f15621fddace82431794", size = 5837778, upload-time = "2026-02-19T07:47:59.778Z" }, ] [[package]] From 62dce316c78fba89d6a5bc920372799528b1c60f Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 19 Feb 2026 14:36:51 +0000 Subject: [PATCH 06/15] fix: - remove redundant 'index=True' in owner field declaration - add owner resource scoping to `InMemoryPushNotificationConfigStore` and a related unit test --- src/a2a/server/models.py | 6 +- ...inmemory_push_notification_config_store.py | 117 +++++++++++--- .../server/tasks/test_database_task_store.py | 2 +- .../tasks/test_inmemory_push_notifications.py | 151 ++++++++++++++---- .../server/tasks/test_inmemory_task_store.py | 6 +- 5 files changed, 222 insertions(+), 60 deletions(-) diff --git a/src/a2a/server/models.py b/src/a2a/server/models.py index b1e013e6b..a7e80d81c 100644 --- a/src/a2a/server/models.py +++ b/src/a2a/server/models.py @@ -148,7 +148,7 @@ class TaskMixin: kind: Mapped[str] = mapped_column( String(16), nullable=False, default='task' ) - owner: Mapped[str] = mapped_column(String(255), nullable=False, index=True) + owner: Mapped[str] = mapped_column(String(255), nullable=False) last_updated: Mapped[str] = mapped_column(String(22), nullable=True) # Properly typed Pydantic fields with automatic serialization @@ -175,10 +175,10 @@ def __repr__(self) -> str: f'context_id="{self.context_id}", status="{self.status}")>' ) - @declared_attr + @declared_attr.directive 
@classmethod def __table_args__(cls) -> tuple[Any, ...]: - """Define a unique index (owner, last_updated) for each table that uses the mixin.""" + """Define a composite index (owner, last_updated) for each table that uses the mixin.""" tablename = getattr(cls, '__tablename__', 'tasks') return ( Index( diff --git a/src/a2a/server/tasks/inmemory_push_notification_config_store.py b/src/a2a/server/tasks/inmemory_push_notification_config_store.py index 707156593..54d6e1894 100644 --- a/src/a2a/server/tasks/inmemory_push_notification_config_store.py +++ b/src/a2a/server/tasks/inmemory_push_notification_config_store.py @@ -1,6 +1,10 @@ import asyncio import logging +from collections import defaultdict + +from a2a.server.context import ServerCallContext +from a2a.server.owner_resolver import OwnerResolver, resolve_user_scope from a2a.server.tasks.push_notification_config_store import ( PushNotificationConfigStore, ) @@ -13,56 +17,117 @@ class InMemoryPushNotificationConfigStore(PushNotificationConfigStore): """In-memory implementation of PushNotificationConfigStore interface. - Stores push notification configurations in memory + Stores push notification configurations in a nested dictionary in memory, + keyed by owner then task_id. 
""" - def __init__(self) -> None: + def __init__( + self, + owner_resolver: OwnerResolver = resolve_user_scope, + ) -> None: """Initializes the InMemoryPushNotificationConfigStore.""" self.lock = asyncio.Lock() self._push_notification_infos: dict[ - str, list[PushNotificationConfig] - ] = {} + str, dict[str, list[PushNotificationConfig]] + ] = defaultdict(dict) + self.owner_resolver = owner_resolver async def set_info( - self, task_id: str, notification_config: PushNotificationConfig + self, + task_id: str, + notification_config: PushNotificationConfig, + context: ServerCallContext | None = None, ) -> None: """Sets or updates the push notification configuration for a task in memory.""" + owner = self.owner_resolver(context) async with self.lock: - if task_id not in self._push_notification_infos: - self._push_notification_infos[task_id] = [] + owner_infos = self._push_notification_infos[owner] + if task_id not in owner_infos: + owner_infos[task_id] = [] if not notification_config.id: notification_config.id = task_id - for config in self._push_notification_infos[task_id]: + # Remove existing config with the same ID + for config in owner_infos[task_id]: if config.id == notification_config.id: - self._push_notification_infos[task_id].remove(config) + owner_infos[task_id].remove(config) break - self._push_notification_infos[task_id].append(notification_config) - - async def get_info(self, task_id: str) -> list[PushNotificationConfig]: - """Retrieves the push notification configuration for a task from memory.""" + owner_infos[task_id].append(notification_config) + logger.debug( + 'Push notification config for task %s with config id %s for owner %s saved/updated.', + task_id, + notification_config.id, + owner, + ) + + async def get_info( + self, + task_id: str, + context: ServerCallContext | None = None, + ) -> list[PushNotificationConfig]: + """Retrieves all push notification configurations for a task from memory, for the given owner.""" + owner = 
self.owner_resolver(context) async with self.lock: - return self._push_notification_infos.get(task_id) or [] + owner_infos = self._push_notification_infos.get(owner) + if owner_infos: + return list(owner_infos.get(task_id, [])) + return [] async def delete_info( - self, task_id: str, config_id: str | None = None + self, + task_id: str, + config_id: str | None = None, + context: ServerCallContext | None = None, ) -> None: - """Deletes the push notification configuration for a task from memory.""" - async with self.lock: - if config_id is None: - config_id = task_id + """Deletes push notification configurations for a task from memory. - if task_id in self._push_notification_infos: - configurations = self._push_notification_infos[task_id] - if not configurations: - return + If config_id is provided, only that specific configuration is deleted. + If config_id is None, all configurations for the task for the owner are deleted. + """ + owner = self.owner_resolver(context) + async with self.lock: + owner_infos = self._push_notification_infos.get(owner) + if not owner_infos or task_id not in owner_infos: + logger.warning( + 'Attempted to delete push notification config for task %s, owner %s that does not exist.', + task_id, + owner, + ) + return + if config_id is None: + del owner_infos[task_id] + logger.info( + 'Deleted all push notification configs for task %s, owner %s.', + task_id, + owner, + ) + else: + configurations = owner_infos[task_id] + found = False for config in configurations: if config.id == config_id: configurations.remove(config) + found = True break - - if len(configurations) == 0: - del self._push_notification_infos[task_id] + if found: + logger.info( + 'Deleted push notification config %s for task %s, owner %s.', + config_id, + task_id, + owner, + ) + if len(configurations) == 0: + del owner_infos[task_id] + else: + logger.warning( + 'Attempted to delete push notification config %s for task %s, owner %s that does not exist.', + config_id, + task_id, + 
owner, + ) + + if not owner_infos: + del self._push_notification_infos[owner] diff --git a/tests/server/tasks/test_database_task_store.py b/tests/server/tasks/test_database_task_store.py index e1396d082..e6b67701c 100644 --- a/tests/server/tasks/test_database_task_store.py +++ b/tests/server/tasks/test_database_task_store.py @@ -33,6 +33,7 @@ ) from a2a.auth.user import User from a2a.server.context import ServerCallContext +from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE class TestUser(User): @@ -48,7 +49,6 @@ def is_authenticated(self) -> bool: @property def user_name(self) -> str: return self._user_name -from a2a.utils.constants import DEFAULT_LIST_TASKS_PAGE_SIZE # DSNs for different databases diff --git a/tests/server/tasks/test_inmemory_push_notifications.py b/tests/server/tasks/test_inmemory_push_notifications.py index bbb01de2c..f1de00782 100644 --- a/tests/server/tasks/test_inmemory_push_notifications.py +++ b/tests/server/tasks/test_inmemory_push_notifications.py @@ -5,6 +5,8 @@ import httpx from google.protobuf.json_format import MessageToDict +from a2a.auth.user import User +from a2a.server.context import ServerCallContext from a2a.server.tasks.base_push_notification_sender import ( BasePushNotificationSender, ) @@ -43,6 +45,21 @@ def create_sample_push_config( return PushNotificationConfig(id=config_id, url=url, token=token) +class SampleUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + class TestInMemoryPushNotifier(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) @@ -60,10 +77,8 @@ async def test_set_info_adds_new_config(self) -> None: await self.config_store.set_info(task_id, config) - self.assertIn(task_id, 
self.config_store._push_notification_infos) - self.assertEqual( - self.config_store._push_notification_infos[task_id], [config] - ) + retrieved = await self.config_store.get_info(task_id) + self.assertEqual(retrieved, [config]) async def test_set_info_appends_to_existing_config(self) -> None: task_id = 'task_update' @@ -77,15 +92,10 @@ async def test_set_info_appends_to_existing_config(self) -> None: ) await self.config_store.set_info(task_id, updated_config) - self.assertIn(task_id, self.config_store._push_notification_infos) - self.assertEqual( - self.config_store._push_notification_infos[task_id][0], - initial_config, - ) - self.assertEqual( - self.config_store._push_notification_infos[task_id][1], - updated_config, - ) + retrieved = await self.config_store.get_info(task_id) + self.assertEqual(len(retrieved), 2) + self.assertEqual(retrieved[0], initial_config) + self.assertEqual(retrieved[1], updated_config) async def test_set_info_without_config_id(self) -> None: task_id = 'task1' @@ -94,21 +104,17 @@ async def test_set_info_without_config_id(self) -> None: ) await self.config_store.set_info(task_id, initial_config) - assert ( - self.config_store._push_notification_infos[task_id][0].id == task_id - ) + retrieved = await self.config_store.get_info(task_id) + assert retrieved[0].id == task_id updated_config = PushNotificationConfig( url='http://initial.url/callback_new' ) await self.config_store.set_info(task_id, updated_config) - self.assertIn(task_id, self.config_store._push_notification_infos) - assert len(self.config_store._push_notification_infos[task_id]) == 1 - self.assertEqual( - self.config_store._push_notification_infos[task_id][0].url, - updated_config.url, - ) + retrieved = await self.config_store.get_info(task_id) + assert len(retrieved) == 1 + self.assertEqual(retrieved[0].url, updated_config.url) async def test_get_info_existing_config(self) -> None: task_id = 'task_get_exist' @@ -128,9 +134,12 @@ async def test_delete_info_existing_config(self) -> 
None: config = create_sample_push_config(url='http://delete.this/callback') await self.config_store.set_info(task_id, config) - self.assertIn(task_id, self.config_store._push_notification_infos) + retrieved = await self.config_store.get_info(task_id) + self.assertEqual(len(retrieved), 1) + await self.config_store.delete_info(task_id, config_id=config.id) - self.assertNotIn(task_id, self.config_store._push_notification_infos) + retrieved = await self.config_store.get_info(task_id) + self.assertEqual(len(retrieved), 0) async def test_delete_info_non_existent_config(self) -> None: task_id = 'task_delete_non_exist' @@ -141,9 +150,8 @@ async def test_delete_info_non_existent_config(self) -> None: self.fail( f'delete_info raised {e} unexpectedly for nonexistent task_id' ) - self.assertNotIn( - task_id, self.config_store._push_notification_infos - ) # Should still not be there + retrieved = await self.config_store.get_info(task_id) + self.assertEqual(len(retrieved), 0) async def test_send_notification_success(self) -> None: task_id = 'task_send_success' @@ -295,6 +303,95 @@ async def test_send_notification_with_auth( ) # auth is not passed by current implementation mock_response.raise_for_status.assert_called_once() + async def test_owner_resource_scoping(self) -> None: + """Test that operations are scoped to the correct owner.""" + context_user1 = ServerCallContext(user=SampleUser(user_name='user1')) + context_user2 = ServerCallContext(user=SampleUser(user_name='user2')) + + # Create configs for different owners + task1_u1_config1 = PushNotificationConfig( + id='t1-u1-c1', url='http://u1.com/1' + ) + task1_u1_config2 = PushNotificationConfig( + id='t1-u1-c2', url='http://u1.com/2' + ) + task1_u2_config1 = PushNotificationConfig( + id='t1-u2-c1', url='http://u2.com/1' + ) + task2_u1_config1 = PushNotificationConfig( + id='t2-u1-c1', url='http://u1.com/3' + ) + + await self.config_store.set_info( + 'task1', task1_u1_config1, context_user1 + ) + await 
self.config_store.set_info( + 'task1', task1_u1_config2, context_user1 + ) + await self.config_store.set_info( + 'task1', task1_u2_config1, context_user2 + ) + await self.config_store.set_info( + 'task2', task2_u1_config1, context_user1 + ) + + # Test GET_INFO + # User 1 should get only their configs for task1 + u1_task1_configs = await self.config_store.get_info( + 'task1', context_user1 + ) + self.assertEqual(len(u1_task1_configs), 2) + self.assertEqual( + {c.id for c in u1_task1_configs}, {'t1-u1-c1', 't1-u1-c2'} + ) + + # User 2 should get only their configs for task1 + u2_task1_configs = await self.config_store.get_info( + 'task1', context_user2 + ) + self.assertEqual(len(u2_task1_configs), 1) + self.assertEqual(u2_task1_configs[0].id, 't1-u2-c1') + + # User 2 should get no configs for task2 + u2_task2_configs = await self.config_store.get_info( + 'task2', context_user2 + ) + self.assertEqual(len(u2_task2_configs), 0) + + # User 1 should get their config for task2 + u1_task2_configs = await self.config_store.get_info( + 'task2', context_user1 + ) + self.assertEqual(len(u1_task2_configs), 1) + self.assertEqual(u1_task2_configs[0].id, 't2-u1-c1') + + # Test DELETE_INFO + # User 2 deleting User 1's config should not work + await self.config_store.delete_info('task1', 't1-u1-c1', context_user2) + u1_task1_configs = await self.config_store.get_info( + 'task1', context_user1 + ) + self.assertEqual(len(u1_task1_configs), 2) + + # User 1 deleting their own config + await self.config_store.delete_info('task1', 't1-u1-c1', context_user1) + u1_task1_configs = await self.config_store.get_info( + 'task1', context_user1 + ) + self.assertEqual(len(u1_task1_configs), 1) + self.assertEqual(u1_task1_configs[0].id, 't1-u1-c2') + + # User 1 deleting all configs for task2 + await self.config_store.delete_info('task2', context=context_user1) + u1_task2_configs = await self.config_store.get_info( + 'task2', context_user1 + ) + self.assertEqual(len(u1_task2_configs), 0) + + # Cleanup 
remaining + await self.config_store.delete_info('task1', context=context_user1) + await self.config_store.delete_info('task1', context=context_user2) + if __name__ == '__main__': unittest.main() diff --git a/tests/server/tasks/test_inmemory_task_store.py b/tests/server/tasks/test_inmemory_task_store.py index 8f6849e7a..f6093b64e 100644 --- a/tests/server/tasks/test_inmemory_task_store.py +++ b/tests/server/tasks/test_inmemory_task_store.py @@ -9,7 +9,7 @@ from a2a.auth.user import User -class TestUser(User): +class SampleUser(User): """A test implementation of the User interface.""" def __init__(self, user_name: str): @@ -273,8 +273,8 @@ async def test_owner_resource_scoping() -> None: store = InMemoryTaskStore() task = create_minimal_task() - context_user1 = ServerCallContext(user=TestUser(user_name='user1')) - context_user2 = ServerCallContext(user=TestUser(user_name='user2')) + context_user1 = ServerCallContext(user=SampleUser(user_name='user1')) + context_user2 = ServerCallContext(user=SampleUser(user_name='user2')) # Create tasks for different owners task1_user1 = Task() From 99cf89ff51704214b8e6928f44af2d2d777e3485 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 19 Feb 2026 15:34:46 +0000 Subject: [PATCH 07/15] fix: fix linter issues --- alembic/__init__.py | 1 + alembic/env.py | 17 +++++++++++++---- .../versions/6419d2d130f6_add_owner_to_task.py | 2 +- 3 files changed, 15 insertions(+), 5 deletions(-) create mode 100644 alembic/__init__.py diff --git a/alembic/__init__.py b/alembic/__init__.py new file mode 100644 index 000000000..7b55fb93e --- /dev/null +++ b/alembic/__init__.py @@ -0,0 +1 @@ +"Alembic database migration package." 
diff --git a/alembic/env.py b/alembic/env.py index d541fe140..dcc644655 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -23,7 +23,7 @@ # other values from the config, defined by the needs of env.py, # can be acquired: -# my_important_option = config.get_main_option("my_important_option") +# my_important_option = config.get_main_option("my_important_option") # noqa: ERA001 # ... etc. @@ -51,14 +51,23 @@ def run_migrations_offline() -> None: context.run_migrations() -def do_run_migrations(connection): +def do_run_migrations(connection) -> None: + """Run migrations in 'online' mode. + + This function is called within a synchronous context (via run_sync) + to configure the migration context with the provided connection + and target metadata, then execute the migrations within a transaction. + + Args: + connection: The SQLAlchemy connection to use for the migrations. + """ context.configure(connection=connection, target_metadata=target_metadata) with context.begin_transaction(): context.run_migrations() -async def run_async_migrations(): +async def run_async_migrations() -> None: """In this scenario we need to create an Engine and associate a connection with the context. """ @@ -74,7 +83,7 @@ async def run_async_migrations(): await connectable.dispose() -def run_migrations_online(): +def run_migrations_online() -> None: """Run migrations in 'online' mode.""" asyncio.run(run_async_migrations()) diff --git a/alembic/versions/6419d2d130f6_add_owner_to_task.py b/alembic/versions/6419d2d130f6_add_owner_to_task.py index 3b96a5c9e..6e2ede603 100644 --- a/alembic/versions/6419d2d130f6_add_owner_to_task.py +++ b/alembic/versions/6419d2d130f6_add_owner_to_task.py @@ -1,4 +1,4 @@ -"""add_owner_to_task +"""add_owner_to_task. 
Revision ID: 6419d2d130f6 Revises: From feb5033bfc7af52bf21b7d64053af43b5b2f11cc Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 19 Feb 2026 15:44:09 +0000 Subject: [PATCH 08/15] Fix: fix some more linter errors --- alembic/env.py | 8 +++++--- alembic/versions/__init__.py | 1 + 2 files changed, 6 insertions(+), 3 deletions(-) create mode 100644 alembic/versions/__init__.py diff --git a/alembic/env.py b/alembic/env.py index dcc644655..f516c886c 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -2,7 +2,7 @@ from logging.config import fileConfig -from sqlalchemy import pool +from sqlalchemy import pool, Connection from sqlalchemy.ext.asyncio import async_engine_from_config from a2a.server.models import Base @@ -51,7 +51,7 @@ def run_migrations_offline() -> None: context.run_migrations() -def do_run_migrations(connection) -> None: +def do_run_migrations(connection: Connection) -> None: """Run migrations in 'online' mode. This function is called within a synchronous context (via run_sync) @@ -68,7 +68,9 @@ def do_run_migrations(connection) -> None: async def run_async_migrations() -> None: - """In this scenario we need to create an Engine + """Run migrations using an Engine. + + In this scenario we need to create an Engine and associate a connection with the context. 
""" connectable = async_engine_from_config( diff --git a/alembic/versions/__init__.py b/alembic/versions/__init__.py new file mode 100644 index 000000000..23a018c29 --- /dev/null +++ b/alembic/versions/__init__.py @@ -0,0 +1 @@ +"""Alembic versioned migrations for the A2A project.""" From 212ad37c73fbcb27dd0b4c1623d9a20b1e482215 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Thu, 19 Feb 2026 15:48:58 +0000 Subject: [PATCH 09/15] fix: more linter errors fixed --- alembic/env.py | 2 +- alembic/versions/__init__.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/alembic/env.py b/alembic/env.py index f516c886c..07864de4d 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -2,7 +2,7 @@ from logging.config import fileConfig -from sqlalchemy import pool, Connection +from sqlalchemy import Connection, pool from sqlalchemy.ext.asyncio import async_engine_from_config from a2a.server.models import Base diff --git a/alembic/versions/__init__.py b/alembic/versions/__init__.py index 23a018c29..574828c67 100644 --- a/alembic/versions/__init__.py +++ b/alembic/versions/__init__.py @@ -1 +1 @@ -"""Alembic versioned migrations for the A2A project.""" +"""Alembic migrations scripts for the A2A project.""" From f7b5c1cc1e3787c079ef7a5c6ba3824879cbc874 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 20 Feb 2026 16:53:06 +0000 Subject: [PATCH 10/15] fix: make parameter `ServerCallContext` non-optional in `PushNotificationConfigStore` methods. 
--- .../default_request_handler.py | 11 +- .../tasks/base_push_notification_sender.py | 8 +- ...database_push_notification_config_store.py | 6 +- ...inmemory_push_notification_config_store.py | 6 +- .../tasks/push_notification_config_store.py | 6 +- .../test_default_request_handler.py | 54 ++++---- .../request_handlers/test_jsonrpc_handler.py | 4 +- ...database_push_notification_config_store.py | 127 ++++++++++++------ .../tasks/test_inmemory_push_notifications.py | 123 +++++++++++------ .../tasks/test_push_notification_sender.py | 65 ++++++--- 10 files changed, 267 insertions(+), 143 deletions(-) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 63d0fdc74..9860d96e2 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -228,7 +228,7 @@ async def _run_event_stream( async def _setup_message_execution( self, params: SendMessageRequest, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]: """Common setup logic for both streaming and non-streaming message handling. 
@@ -284,7 +284,7 @@ async def _setup_message_execution( and params.configuration.push_notification_config ): await self._push_config_store.set_info( - task_id, params.configuration.push_notification_config + task_id, params.configuration.push_notification_config, context ) queue = await self._queue_manager.create_or_tap(task_id) @@ -498,6 +498,7 @@ async def on_create_task_push_notification_config( await self._push_config_store.set_info( task_id, params.config, + context, ) return TaskPushNotificationConfig( @@ -524,7 +525,7 @@ async def on_get_task_push_notification_config( raise ServerError(error=TaskNotFoundError()) push_notification_configs: list[PushNotificationConfig] = ( - await self._push_config_store.get_info(task_id) or [] + await self._push_config_store.get_info(task_id, context) or [] ) for config in push_notification_configs: @@ -596,7 +597,7 @@ async def on_list_task_push_notification_configs( raise ServerError(error=TaskNotFoundError()) push_notification_config_list = await self._push_config_store.get_info( - task_id + task_id, context ) return ListTaskPushNotificationConfigsResponse( @@ -627,4 +628,4 @@ async def on_delete_task_push_notification_config( if not task: raise ServerError(error=TaskNotFoundError()) - await self._push_config_store.delete_info(task_id, config_id) + await self._push_config_store.delete_info(task_id, context, config_id) diff --git a/src/a2a/server/tasks/base_push_notification_sender.py b/src/a2a/server/tasks/base_push_notification_sender.py index 4e4444923..84f544f5e 100644 --- a/src/a2a/server/tasks/base_push_notification_sender.py +++ b/src/a2a/server/tasks/base_push_notification_sender.py @@ -5,6 +5,7 @@ from google.protobuf.json_format import MessageToDict +from a2a.server.context import ServerCallContext from a2a.server.tasks.push_notification_config_store import ( PushNotificationConfigStore, ) @@ -22,19 +23,24 @@ def __init__( self, httpx_client: httpx.AsyncClient, config_store: PushNotificationConfigStore, + context: 
ServerCallContext, ) -> None: """Initializes the BasePushNotificationSender. Args: httpx_client: An async HTTP client instance to send notifications. config_store: A PushNotificationConfigStore instance to retrieve configurations. + context: The `ServerCallContext` that this push notification is produced under. """ self._client = httpx_client self._config_store = config_store + self._call_context: ServerCallContext = context async def send_notification(self, task: Task) -> None: """Sends a push notification for a task if configuration exists.""" - push_configs = await self._config_store.get_info(task.id) + push_configs = await self._config_store.get_info( + task.id, self._call_context + ) if not push_configs: return diff --git a/src/a2a/server/tasks/database_push_notification_config_store.py b/src/a2a/server/tasks/database_push_notification_config_store.py index 32dd47fd8..be8f16121 100644 --- a/src/a2a/server/tasks/database_push_notification_config_store.py +++ b/src/a2a/server/tasks/database_push_notification_config_store.py @@ -241,7 +241,7 @@ async def set_info( self, task_id: str, notification_config: PushNotificationConfig, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> None: """Sets or updates the push notification configuration for a task.""" await self._ensure_initialized() @@ -266,7 +266,7 @@ async def set_info( async def get_info( self, task_id: str, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> list[PushNotificationConfig]: """Retrieves all push notification configurations for a task, for the given owner.""" await self._ensure_initialized() @@ -297,8 +297,8 @@ async def get_info( async def delete_info( self, task_id: str, + context: ServerCallContext, config_id: str | None = None, - context: ServerCallContext | None = None, ) -> None: """Deletes push notification configurations for a task. 
diff --git a/src/a2a/server/tasks/inmemory_push_notification_config_store.py b/src/a2a/server/tasks/inmemory_push_notification_config_store.py index 54d6e1894..4de8b82fa 100644 --- a/src/a2a/server/tasks/inmemory_push_notification_config_store.py +++ b/src/a2a/server/tasks/inmemory_push_notification_config_store.py @@ -36,7 +36,7 @@ async def set_info( self, task_id: str, notification_config: PushNotificationConfig, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> None: """Sets or updates the push notification configuration for a task in memory.""" owner = self.owner_resolver(context) @@ -65,7 +65,7 @@ async def set_info( async def get_info( self, task_id: str, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> list[PushNotificationConfig]: """Retrieves all push notification configurations for a task from memory, for the given owner.""" owner = self.owner_resolver(context) @@ -78,8 +78,8 @@ async def get_info( async def delete_info( self, task_id: str, + context: ServerCallContext, config_id: str | None = None, - context: ServerCallContext | None = None, ) -> None: """Deletes push notification configurations for a task from memory. 
diff --git a/src/a2a/server/tasks/push_notification_config_store.py b/src/a2a/server/tasks/push_notification_config_store.py index e47060d7d..f1db64664 100644 --- a/src/a2a/server/tasks/push_notification_config_store.py +++ b/src/a2a/server/tasks/push_notification_config_store.py @@ -12,7 +12,7 @@ async def set_info( self, task_id: str, notification_config: PushNotificationConfig, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> None: """Sets or updates the push notification configuration for a task.""" @@ -20,7 +20,7 @@ async def set_info( async def get_info( self, task_id: str, - context: ServerCallContext | None = None, + context: ServerCallContext, ) -> list[PushNotificationConfig]: """Retrieves the push notification configuration for a task.""" @@ -28,7 +28,7 @@ async def get_info( async def delete_info( self, task_id: str, + context: ServerCallContext, config_id: str | None = None, - context: ServerCallContext | None = None, ) -> None: """Deletes the push notification configuration for a task.""" diff --git a/tests/server/request_handlers/test_default_request_handler.py b/tests/server/request_handlers/test_default_request_handler.py index 691731c94..410c2d21f 100644 --- a/tests/server/request_handlers/test_default_request_handler.py +++ b/tests/server/request_handlers/test_default_request_handler.py @@ -546,6 +546,7 @@ async def mock_current_result(): lambda self: mock_current_result() ) + context = create_server_call_context() with ( patch( 'a2a.server.request_handlers.default_request_handler.ResultAggregator', @@ -560,12 +561,10 @@ async def mock_current_result(): return_value=sample_initial_task, ), ): # Ensure task object is returned - await request_handler.on_message_send( - params, create_server_call_context() - ) + await request_handler.on_message_send(params, context) mock_push_notification_store.set_info.assert_awaited_once_with( - task_id, push_config + task_id, push_config, context ) # Other assertions for full flow if 
needed (e.g., agent execution) mock_agent_executor.execute.assert_awaited_once() @@ -665,6 +664,7 @@ async def mock_consume_and_break_on_interrupt( mock_consume_and_break_on_interrupt ) + context = create_server_call_context() with ( patch( 'a2a.server.request_handlers.default_request_handler.ResultAggregator', @@ -680,9 +680,7 @@ async def mock_consume_and_break_on_interrupt( ), ): # Execute the non-blocking request - result = await request_handler.on_message_send( - params, create_server_call_context() - ) + result = await request_handler.on_message_send(params, context) # Verify the result is the initial task (non-blocking behavior) assert result == initial_task @@ -700,7 +698,7 @@ async def mock_consume_and_break_on_interrupt( # Verify that the push notification config was stored mock_push_notification_store.set_info.assert_awaited_once_with( - task_id, push_config + task_id, push_config, context ) @@ -763,6 +761,7 @@ async def mock_current_result(): lambda self: mock_current_result() ) + context = create_server_call_context() with ( patch( 'a2a.server.request_handlers.default_request_handler.ResultAggregator', @@ -773,12 +772,10 @@ async def mock_current_result(): return_value=None, ), ): - await request_handler.on_message_send( - params, create_server_call_context() - ) + await request_handler.on_message_send(params, context) mock_push_notification_store.set_info.assert_awaited_once_with( - task_id, push_config + task_id, push_config, context ) # Other assertions for full flow if needed (e.g., agent execution) mock_agent_executor.execute.assert_awaited_once() @@ -1382,6 +1379,7 @@ def sync_get_event_stream_gen_for_prop_test(*args, **kwargs): side_effect=[get_current_result_coro1(), get_current_result_coro2()] ) + context = create_server_call_context() with ( patch( 'a2a.server.request_handlers.default_request_handler.ResultAggregator', @@ -1397,16 +1395,16 @@ def sync_get_event_stream_gen_for_prop_test(*args, **kwargs): ), ): # Consume the stream - async for 
_ in request_handler.on_message_send_stream( - params, create_server_call_context() - ): + async for _ in request_handler.on_message_send_stream(params, context): pass await asyncio.wait_for(execute_called.wait(), timeout=0.1) # Assertions # 1. set_info called once at the beginning if task exists (or after task is created from message) - mock_push_config_store.set_info.assert_any_call(task_id, push_config) + mock_push_config_store.set_info.assert_any_call( + task_id, push_config, context + ) # 2. send_notification called for each task event yielded by aggregator assert mock_push_sender.send_notification.await_count == 2 @@ -2082,7 +2080,9 @@ async def test_get_task_push_notification_config_info_not_found(): exc_info.value.error, InternalError ) # Current code raises InternalError mock_task_store.get.assert_awaited_once_with('non_existent_task', context) - mock_push_store.get_info.assert_awaited_once_with('non_existent_task') + mock_push_store.get_info.assert_awaited_once_with( + 'non_existent_task', context + ) @pytest.mark.asyncio @@ -2236,7 +2236,7 @@ async def test_on_message_send_stream(): async def consume_stream(): events = [] async for event in request_handler.on_message_send_stream( - message_params + message_params, create_server_call_context() ): events.append(event) if len(events) >= 3: @@ -2340,8 +2340,9 @@ async def test_list_task_push_notification_config_info_with_config(): ) push_store = InMemoryPushNotificationConfigStore() - await push_store.set_info('task_1', push_config1) - await push_store.set_info('task_1', push_config2) + context = create_server_call_context() + await push_store.set_info('task_1', push_config1, context) + await push_store.set_info('task_1', push_config2, context) request_handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), @@ -2467,6 +2468,7 @@ async def test_delete_no_task_push_notification_config_info(): await push_store.set_info( 'task_2', PushNotificationConfig(id='config_1', url='http://example.com'), + 
create_server_call_context(), ) request_handler = DefaultRequestHandler( @@ -2509,9 +2511,10 @@ async def test_delete_task_push_notification_config_info_with_config(): ) push_store = InMemoryPushNotificationConfigStore() - await push_store.set_info('task_1', push_config1) - await push_store.set_info('task_1', push_config2) - await push_store.set_info('task_2', push_config1) + context = create_server_call_context() + await push_store.set_info('task_1', push_config1, context) + await push_store.set_info('task_1', push_config2, context) + await push_store.set_info('task_2', push_config1, context) request_handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), @@ -2550,8 +2553,9 @@ async def test_delete_task_push_notification_config_info_with_config_and_no_id() # insertion without id should replace the existing config push_store = InMemoryPushNotificationConfigStore() - await push_store.set_info('task_1', push_config) - await push_store.set_info('task_1', push_config) + context = create_server_call_context() + await push_store.set_info('task_1', push_config, context) + await push_store.set_info('task_1', push_config, context) request_handler = DefaultRequestHandler( agent_executor=MockAgentExecutor(), diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index fca1175af..90b7be1c8 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -554,7 +554,7 @@ async def test_set_push_notification_success(self) -> None: self.assertIsInstance(response, dict) self.assertTrue(is_success_response(response)) mock_push_notification_store.set_info.assert_called_once_with( - mock_task.id, push_config + mock_task.id, push_config, None ) async def test_get_push_notification_success(self) -> None: @@ -601,7 +601,7 @@ async def test_on_message_stream_new_message_send_push_notification_success( mock_httpx_client = 
AsyncMock(spec=httpx.AsyncClient) push_notification_store = InMemoryPushNotificationConfigStore() push_notification_sender = BasePushNotificationSender( - mock_httpx_client, push_notification_store + mock_httpx_client, push_notification_store, ServerCallContext() ) request_handler = DefaultRequestHandler( mock_agent_executor, diff --git a/tests/server/tasks/test_database_push_notification_config_store.py b/tests/server/tasks/test_database_push_notification_config_store.py index 9336493a2..042ff8000 100644 --- a/tests/server/tasks/test_database_push_notification_config_store.py +++ b/tests/server/tasks/test_database_push_notification_config_store.py @@ -104,7 +104,7 @@ def _create_timestamp() -> Timestamp: ) -class TestUser(User): +class SampleUser(User): """A test implementation of the User interface.""" def __init__(self, user_name: str): @@ -119,6 +119,9 @@ def user_name(self) -> str: return self._user_name +MINIMAL_CALL_CONTEXT = ServerCallContext(user=SampleUser(user_name='user')) + + @pytest_asyncio.fixture(params=DB_CONFIGS) async def db_store_parameterized( request, @@ -198,8 +201,10 @@ async def test_set_and_get_info_single_config( task_id = 'task-1' config = PushNotificationConfig(id='config-1', url='http://example.com') - await db_store_parameterized.set_info(task_id, config) - retrieved_configs = await db_store_parameterized.get_info(task_id) + await db_store_parameterized.set_info(task_id, config, MINIMAL_CALL_CONTEXT) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 1 assert retrieved_configs[0] == config @@ -215,9 +220,15 @@ async def test_set_and_get_info_multiple_configs( config1 = PushNotificationConfig(id='config-1', url='http://example.com/1') config2 = PushNotificationConfig(id='config-2', url='http://example.com/2') - await db_store_parameterized.set_info(task_id, config1) - await db_store_parameterized.set_info(task_id, config2) - retrieved_configs = await 
db_store_parameterized.get_info(task_id) + await db_store_parameterized.set_info( + task_id, config1, MINIMAL_CALL_CONTEXT + ) + await db_store_parameterized.set_info( + task_id, config2, MINIMAL_CALL_CONTEXT + ) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 2 assert config1 in retrieved_configs @@ -238,9 +249,15 @@ async def test_set_info_updates_existing_config( id=config_id, url='http://updated.url' ) - await db_store_parameterized.set_info(task_id, initial_config) - await db_store_parameterized.set_info(task_id, updated_config) - retrieved_configs = await db_store_parameterized.get_info(task_id) + await db_store_parameterized.set_info( + task_id, initial_config, MINIMAL_CALL_CONTEXT + ) + await db_store_parameterized.set_info( + task_id, updated_config, MINIMAL_CALL_CONTEXT + ) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 1 assert retrieved_configs[0].url == 'http://updated.url' @@ -254,8 +271,10 @@ async def test_set_info_defaults_config_id_to_task_id( task_id = 'task-1' config = PushNotificationConfig(url='http://example.com') # id is None - await db_store_parameterized.set_info(task_id, config) - retrieved_configs = await db_store_parameterized.get_info(task_id) + await db_store_parameterized.set_info(task_id, config, MINIMAL_CALL_CONTEXT) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 1 assert retrieved_configs[0].id == task_id @@ -267,7 +286,7 @@ async def test_get_info_not_found( ): """Test getting info for a task with no configs returns an empty list.""" retrieved_configs = await db_store_parameterized.get_info( - 'non-existent-task' + 'non-existent-task', MINIMAL_CALL_CONTEXT ) assert retrieved_configs == [] @@ -281,11 +300,19 @@ async def test_delete_info_specific_config( config1 = 
PushNotificationConfig(id='config-1', url='http://a.com') config2 = PushNotificationConfig(id='config-2', url='http://b.com') - await db_store_parameterized.set_info(task_id, config1) - await db_store_parameterized.set_info(task_id, config2) + await db_store_parameterized.set_info( + task_id, config1, MINIMAL_CALL_CONTEXT + ) + await db_store_parameterized.set_info( + task_id, config2, MINIMAL_CALL_CONTEXT + ) - await db_store_parameterized.delete_info(task_id, 'config-1') - retrieved_configs = await db_store_parameterized.get_info(task_id) + await db_store_parameterized.delete_info( + task_id, MINIMAL_CALL_CONTEXT, 'config-1' + ) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 1 assert retrieved_configs[0] == config2 @@ -301,11 +328,19 @@ async def test_delete_info_all_for_task( config1 = PushNotificationConfig(id='config-1', url='http://a.com') config2 = PushNotificationConfig(id='config-2', url='http://b.com') - await db_store_parameterized.set_info(task_id, config1) - await db_store_parameterized.set_info(task_id, config2) + await db_store_parameterized.set_info( + task_id, config1, MINIMAL_CALL_CONTEXT + ) + await db_store_parameterized.set_info( + task_id, config2, MINIMAL_CALL_CONTEXT + ) - await db_store_parameterized.delete_info(task_id, None) - retrieved_configs = await db_store_parameterized.get_info(task_id) + await db_store_parameterized.delete_info( + task_id, MINIMAL_CALL_CONTEXT, None + ) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert retrieved_configs == [] @@ -316,7 +351,9 @@ async def test_delete_info_not_found( ): """Test that deleting a non-existent config does not raise an error.""" # Should not raise - await db_store_parameterized.delete_info('task-1', 'non-existent-config') + await db_store_parameterized.delete_info( + 'task-1', MINIMAL_CALL_CONTEXT, 'non-existent-config' + ) @pytest.mark.asyncio @@ 
-330,7 +367,7 @@ async def test_data_is_encrypted_in_db( ) plain_json = MessageToJson(config) - await db_store_parameterized.set_info(task_id, config) + await db_store_parameterized.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # Directly query the database to inspect the raw data async_session = async_sessionmaker( @@ -360,7 +397,7 @@ async def test_decryption_error_with_wrong_key( task_id = 'wrong-key-task' config = PushNotificationConfig(id='config-1', url='http://secret.url') - await db_store_parameterized.set_info(task_id, config) + await db_store_parameterized.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # 2. Try to read with a different key # Directly query the database to inspect the raw data @@ -369,7 +406,7 @@ async def test_decryption_error_with_wrong_key( db_store_parameterized.engine, encryption_key=wrong_key ) - retrieved_configs = await store2.get_info(task_id) + retrieved_configs = await store2.get_info(task_id, MINIMAL_CALL_CONTEXT) assert retrieved_configs == [] # _from_orm should raise a ValueError @@ -394,13 +431,13 @@ async def test_decryption_error_with_no_key( task_id = 'wrong-key-task' config = PushNotificationConfig(id='config-1', url='http://secret.url') - await db_store_parameterized.set_info(task_id, config) + await db_store_parameterized.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # 2. 
Try to read with no key set # Directly query the database to inspect the raw data store2 = DatabasePushNotificationConfigStore(db_store_parameterized.engine) - retrieved_configs = await store2.get_info(task_id) + retrieved_configs = await store2.get_info(task_id, MINIMAL_CALL_CONTEXT) assert retrieved_configs == [] # _from_orm should raise a ValueError @@ -437,8 +474,10 @@ async def test_custom_table_name( config = PushNotificationConfig(id='config-1', url='http://custom.url') # This will create the table on first use - await custom_store.set_info(task_id, config) - retrieved_configs = await custom_store.get_info(task_id) + await custom_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) + retrieved_configs = await custom_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved_configs) == 1 assert retrieved_configs[0] == config @@ -482,9 +521,9 @@ async def test_set_and_get_info_multiple_configs_no_key( config1 = PushNotificationConfig(id='config-1', url='http://example.com/1') config2 = PushNotificationConfig(id='config-2', url='http://example.com/2') - await store.set_info(task_id, config1) - await store.set_info(task_id, config2) - retrieved_configs = await store.get_info(task_id) + await store.set_info(task_id, config1, MINIMAL_CALL_CONTEXT) + await store.set_info(task_id, config2, MINIMAL_CALL_CONTEXT) + retrieved_configs = await store.get_info(task_id, MINIMAL_CALL_CONTEXT) assert len(retrieved_configs) == 2 assert config1 in retrieved_configs @@ -508,7 +547,7 @@ async def test_data_is_not_encrypted_in_db_if_no_key_is_set( config = PushNotificationConfig(id='config-1', url='http://example.com/1') plain_json = MessageToJson(config) - await store.set_info(task_id, config) + await store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # Directly query the database to inspect the raw data async_session = async_sessionmaker( @@ -539,10 +578,12 @@ async def test_decryption_fallback_for_unencrypted_data( task_id = 'mixed-encryption-task' config = 
PushNotificationConfig(id='config-1', url='http://plain.url') - await unencrypted_store.set_info(task_id, config) + await unencrypted_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # 2. Try to read with the encryption-enabled store from the fixture - retrieved_configs = await db_store_parameterized.get_info(task_id) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) # Should fall back to parsing as plain JSON and not fail assert len(retrieved_configs) == 1 @@ -572,13 +613,15 @@ async def test_parsing_error_after_successful_decryption( task_id=task_id, config_id=config_id, config_data=encrypted_data, - owner='test-owner', + owner='user', ) session.add(db_model) await session.commit() # 3. get_info should log an error and return an empty list - retrieved_configs = await db_store_parameterized.get_info(task_id) + retrieved_configs = await db_store_parameterized.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert retrieved_configs == [] # 4. 
_from_orm should raise a ValueError @@ -598,8 +641,8 @@ async def test_owner_resource_scoping( """Test that operations are scoped to the correct owner.""" config_store = db_store_parameterized - context_user1 = ServerCallContext(user=TestUser(user_name='user1')) - context_user2 = ServerCallContext(user=TestUser(user_name='user2')) + context_user1 = ServerCallContext(user=SampleUser(user_name='user1')) + context_user2 = ServerCallContext(user=SampleUser(user_name='user2')) # Create configs for different owners task1_u1_config1 = PushNotificationConfig( @@ -642,12 +685,16 @@ async def test_owner_resource_scoping( # Test DELETE_INFO # User 2 deleting User 1's config should not work - await config_store.delete_info('task1', 't1-u1-c1', context_user2) + await config_store.delete_info('task1', context_user2, 't1-u1-c1') u1_task1_configs = await config_store.get_info('task1', context_user1) assert len(u1_task1_configs) == 2 # User 1 deleting their own config - await config_store.delete_info('task1', 't1-u1-c1', context_user1) + await config_store.delete_info( + 'task1', + context_user1, + 't1-u1-c1', + ) u1_task1_configs = await config_store.get_info('task1', context_user1) assert len(u1_task1_configs) == 1 assert u1_task1_configs[0].id == 't1-u1-c2' diff --git a/tests/server/tasks/test_inmemory_push_notifications.py b/tests/server/tasks/test_inmemory_push_notifications.py index f1de00782..0024a95a6 100644 --- a/tests/server/tasks/test_inmemory_push_notifications.py +++ b/tests/server/tasks/test_inmemory_push_notifications.py @@ -26,7 +26,7 @@ # logging.disable(logging.CRITICAL) -def create_sample_task( +def _create_sample_task( task_id: str = 'task123', status_state: TaskState = TaskState.TASK_STATE_COMPLETED, ) -> Task: @@ -37,7 +37,7 @@ def create_sample_task( ) -def create_sample_push_config( +def _create_sample_push_config( url: str = 'http://example.com/callback', config_id: str = 'cfg1', token: str | None = None, @@ -60,12 +60,17 @@ def user_name(self) -> str: 
return self._user_name +MINIMAL_CALL_CONTEXT = ServerCallContext(user=SampleUser(user_name='user')) + + class TestInMemoryPushNotifier(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) self.config_store = InMemoryPushNotificationConfigStore() self.notifier = BasePushNotificationSender( - httpx_client=self.mock_httpx_client, config_store=self.config_store + httpx_client=self.mock_httpx_client, + config_store=self.config_store, + context=MINIMAL_CALL_CONTEXT, ) # Corrected argument name def test_constructor_stores_client(self) -> None: @@ -73,26 +78,34 @@ def test_constructor_stores_client(self) -> None: async def test_set_info_adds_new_config(self) -> None: task_id = 'task_new' - config = create_sample_push_config(url='http://new.url/callback') + config = _create_sample_push_config(url='http://new.url/callback') - await self.config_store.set_info(task_id, config) + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) - retrieved = await self.config_store.get_info(task_id) + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) self.assertEqual(retrieved, [config]) async def test_set_info_appends_to_existing_config(self) -> None: task_id = 'task_update' - initial_config = create_sample_push_config( + initial_config = _create_sample_push_config( url='http://initial.url/callback', config_id='cfg_initial' ) - await self.config_store.set_info(task_id, initial_config) + await self.config_store.set_info( + task_id, initial_config, MINIMAL_CALL_CONTEXT + ) - updated_config = create_sample_push_config( + updated_config = _create_sample_push_config( url='http://updated.url/callback', config_id='cfg_updated' ) - await self.config_store.set_info(task_id, updated_config) + await self.config_store.set_info( + task_id, updated_config, MINIMAL_CALL_CONTEXT + ) - retrieved = await self.config_store.get_info(task_id) + retrieved = await self.config_store.get_info( + 
task_id, MINIMAL_CALL_CONTEXT + ) self.assertEqual(len(retrieved), 2) self.assertEqual(retrieved[0], initial_config) self.assertEqual(retrieved[1], updated_config) @@ -102,62 +115,84 @@ async def test_set_info_without_config_id(self) -> None: initial_config = PushNotificationConfig( url='http://initial.url/callback' ) - await self.config_store.set_info(task_id, initial_config) + await self.config_store.set_info( + task_id, initial_config, MINIMAL_CALL_CONTEXT + ) - retrieved = await self.config_store.get_info(task_id) + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert retrieved[0].id == task_id updated_config = PushNotificationConfig( url='http://initial.url/callback_new' ) - await self.config_store.set_info(task_id, updated_config) + await self.config_store.set_info( + task_id, updated_config, MINIMAL_CALL_CONTEXT + ) - retrieved = await self.config_store.get_info(task_id) + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert len(retrieved) == 1 self.assertEqual(retrieved[0].url, updated_config.url) async def test_get_info_existing_config(self) -> None: task_id = 'task_get_exist' - config = create_sample_push_config(url='http://get.this/callback') - await self.config_store.set_info(task_id, config) + config = _create_sample_push_config(url='http://get.this/callback') + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) - retrieved_config = await self.config_store.get_info(task_id) + retrieved_config = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) self.assertEqual(retrieved_config, [config]) async def test_get_info_non_existent_config(self) -> None: task_id = 'task_get_non_exist' - retrieved_config = await self.config_store.get_info(task_id) + retrieved_config = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) assert retrieved_config == [] async def test_delete_info_existing_config(self) -> None: task_id = 
'task_delete_exist' - config = create_sample_push_config(url='http://delete.this/callback') - await self.config_store.set_info(task_id, config) + config = _create_sample_push_config(url='http://delete.this/callback') + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) - retrieved = await self.config_store.get_info(task_id) + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) self.assertEqual(len(retrieved), 1) - await self.config_store.delete_info(task_id, config_id=config.id) - retrieved = await self.config_store.get_info(task_id) + await self.config_store.delete_info( + task_id, config_id=config.id, context=MINIMAL_CALL_CONTEXT + ) + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) self.assertEqual(len(retrieved), 0) async def test_delete_info_non_existent_config(self) -> None: task_id = 'task_delete_non_exist' # Ensure it doesn't raise an error try: - await self.config_store.delete_info(task_id) + await self.config_store.delete_info( + task_id, context=MINIMAL_CALL_CONTEXT + ) except Exception as e: self.fail( f'delete_info raised {e} unexpectedly for nonexistent task_id' ) - retrieved = await self.config_store.get_info(task_id) + retrieved = await self.config_store.get_info( + task_id, MINIMAL_CALL_CONTEXT + ) self.assertEqual(len(retrieved), 0) async def test_send_notification_success(self) -> None: task_id = 'task_send_success' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/here') - await self.config_store.set_info(task_id, config) + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/here') + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # Mock the post call to simulate success mock_response = AsyncMock(spec=httpx.Response) @@ -180,11 +215,11 @@ async def test_send_notification_success(self) -> None: async def 
test_send_notification_with_token_success(self) -> None: task_id = 'task_send_success' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config( + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config( url='http://notify.me/here', token='unique_token' ) - await self.config_store.set_info(task_id, config) + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) # Mock the post call to simulate success mock_response = AsyncMock(spec=httpx.Response) @@ -211,7 +246,7 @@ async def test_send_notification_with_token_success(self) -> None: async def test_send_notification_no_config(self) -> None: task_id = 'task_send_no_config' - task_data = create_sample_task(task_id=task_id) + task_data = _create_sample_task(task_id=task_id) await self.notifier.send_notification(task_data) # Pass only task_data @@ -222,9 +257,9 @@ async def test_send_notification_http_status_error( self, mock_logger: MagicMock ) -> None: task_id = 'task_send_http_err' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/http_error') - await self.config_store.set_info(task_id, config) + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/http_error') + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) mock_response = MagicMock( spec=httpx.Response @@ -252,9 +287,9 @@ async def test_send_notification_request_error( self, mock_logger: MagicMock ) -> None: task_id = 'task_send_req_err' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/req_error') - await self.config_store.set_info(task_id, config) + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/req_error') + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) request_error = httpx.RequestError('Network 
issue', request=MagicMock()) self.mock_httpx_client.post.side_effect = request_error @@ -279,11 +314,11 @@ async def test_send_notification_with_auth( still works even if the config has an authentication field set. """ task_id = 'task_send_auth' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/auth') + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/auth') # The current implementation doesn't use the authentication field # It only supports token-based auth via the token field - await self.config_store.set_info(task_id, config) + await self.config_store.set_info(task_id, config, MINIMAL_CALL_CONTEXT) mock_response = AsyncMock(spec=httpx.Response) mock_response.status_code = 200 @@ -367,14 +402,14 @@ async def test_owner_resource_scoping(self) -> None: # Test DELETE_INFO # User 2 deleting User 1's config should not work - await self.config_store.delete_info('task1', 't1-u1-c1', context_user2) + await self.config_store.delete_info('task1', context_user2, 't1-u1-c1') u1_task1_configs = await self.config_store.get_info( 'task1', context_user1 ) self.assertEqual(len(u1_task1_configs), 2) # User 1 deleting their own config - await self.config_store.delete_info('task1', 't1-u1-c1', context_user1) + await self.config_store.delete_info('task1', context_user1, 't1-u1-c1') u1_task1_configs = await self.config_store.get_info( 'task1', context_user1 ) diff --git a/tests/server/tasks/test_push_notification_sender.py b/tests/server/tasks/test_push_notification_sender.py index a7b5f7603..985ae6b7a 100644 --- a/tests/server/tasks/test_push_notification_sender.py +++ b/tests/server/tasks/test_push_notification_sender.py @@ -5,6 +5,8 @@ import httpx from google.protobuf.json_format import MessageToDict +from a2a.auth.user import User +from a2a.server.context import ServerCallContext from a2a.server.tasks.base_push_notification_sender import ( 
BasePushNotificationSender, ) @@ -17,7 +19,22 @@ ) -def create_sample_task( +class SampleUser(User): + """A test implementation of the User interface.""" + + def __init__(self, user_name: str): + self._user_name = user_name + + @property + def is_authenticated(self) -> bool: + return True + + @property + def user_name(self) -> str: + return self._user_name + + +def _create_sample_task( task_id: str = 'task123', status_state: TaskState = TaskState.TASK_STATE_COMPLETED, ) -> Task: @@ -28,7 +45,7 @@ def create_sample_task( ) -def create_sample_push_config( +def _create_sample_push_config( url: str = 'http://example.com/callback', config_id: str = 'cfg1', token: str | None = None, @@ -36,6 +53,9 @@ def create_sample_push_config( return PushNotificationConfig(id=config_id, url=url, token=token) +MINIMAL_CALL_CONTEXT = ServerCallContext(user=SampleUser(user_name='user')) + + class TestBasePushNotificationSender(unittest.IsolatedAsyncioTestCase): def setUp(self) -> None: self.mock_httpx_client = AsyncMock(spec=httpx.AsyncClient) @@ -43,6 +63,7 @@ def setUp(self) -> None: self.sender = BasePushNotificationSender( httpx_client=self.mock_httpx_client, config_store=self.mock_config_store, + context=MINIMAL_CALL_CONTEXT, ) def test_constructor_stores_client_and_config_store(self) -> None: @@ -51,8 +72,8 @@ def test_constructor_stores_client_and_config_store(self) -> None: async def test_send_notification_success(self) -> None: task_id = 'task_send_success' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/here') + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/here') self.mock_config_store.get_info.return_value = [config] mock_response = AsyncMock(spec=httpx.Response) @@ -61,7 +82,9 @@ async def test_send_notification_success(self) -> None: await self.sender.send_notification(task_data) - self.mock_config_store.get_info.assert_awaited_once_with + 
self.mock_config_store.get_info.assert_awaited_once_with( + task_id, MINIMAL_CALL_CONTEXT + ) # assert httpx_client post method got invoked with right parameters self.mock_httpx_client.post.assert_awaited_once_with( @@ -73,8 +96,8 @@ async def test_send_notification_success(self) -> None: async def test_send_notification_with_token_success(self) -> None: task_id = 'task_send_success' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config( + task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config( url='http://notify.me/here', token='unique_token' ) self.mock_config_store.get_info.return_value = [config] @@ -85,7 +108,9 @@ async def test_send_notification_with_token_success(self) -> None: await self.sender.send_notification(task_data) - self.mock_config_store.get_info.assert_awaited_once_with + self.mock_config_store.get_info.assert_awaited_once_with( + task_id, MINIMAL_CALL_CONTEXT + ) # assert httpx_client post method got invoked with right parameters self.mock_httpx_client.post.assert_awaited_once_with( @@ -97,12 +122,14 @@ async def test_send_notification_with_token_success(self) -> None: async def test_send_notification_no_config(self) -> None: task_id = 'task_send_no_config' - task_data = create_sample_task(task_id=task_id) + task_data = _create_sample_task(task_id=task_id) self.mock_config_store.get_info.return_value = [] await self.sender.send_notification(task_data) - self.mock_config_store.get_info.assert_awaited_once_with(task_id) + self.mock_config_store.get_info.assert_awaited_once_with( + task_id, MINIMAL_CALL_CONTEXT + ) self.mock_httpx_client.post.assert_not_called() @patch('a2a.server.tasks.base_push_notification_sender.logger') @@ -110,8 +137,8 @@ async def test_send_notification_http_status_error( self, mock_logger: MagicMock ) -> None: task_id = 'task_send_http_err' - task_data = create_sample_task(task_id=task_id) - config = create_sample_push_config(url='http://notify.me/http_error') 
+ task_data = _create_sample_task(task_id=task_id) + config = _create_sample_push_config(url='http://notify.me/http_error') self.mock_config_store.get_info.return_value = [config] mock_response = MagicMock(spec=httpx.Response) @@ -124,7 +151,9 @@ async def test_send_notification_http_status_error( await self.sender.send_notification(task_data) - self.mock_config_store.get_info.assert_awaited_once_with(task_id) + self.mock_config_store.get_info.assert_awaited_once_with( + task_id, MINIMAL_CALL_CONTEXT + ) self.mock_httpx_client.post.assert_awaited_once_with( config.url, json=MessageToDict(StreamResponse(task=task_data)), @@ -134,11 +163,11 @@ async def test_send_notification_http_status_error( async def test_send_notification_multiple_configs(self) -> None: task_id = 'task_multiple_configs' - task_data = create_sample_task(task_id=task_id) - config1 = create_sample_push_config( + task_data = _create_sample_task(task_id=task_id) + config1 = _create_sample_push_config( url='http://notify.me/cfg1', config_id='cfg1' ) - config2 = create_sample_push_config( + config2 = _create_sample_push_config( url='http://notify.me/cfg2', config_id='cfg2' ) self.mock_config_store.get_info.return_value = [config1, config2] @@ -149,7 +178,9 @@ async def test_send_notification_multiple_configs(self) -> None: await self.sender.send_notification(task_data) - self.mock_config_store.get_info.assert_awaited_once_with(task_id) + self.mock_config_store.get_info.assert_awaited_once_with( + task_id, MINIMAL_CALL_CONTEXT + ) self.assertEqual(self.mock_httpx_client.post.call_count, 2) # Check calls for config1 From 38d7df6706e61fdd300086bd2ee8d3c96289ade3 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 20 Feb 2026 17:01:27 +0000 Subject: [PATCH 11/15] fix: add ServerCallContext to tests/e2e/push_notifications/agent_app.py --- tests/e2e/push_notifications/agent_app.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/e2e/push_notifications/agent_app.py 
b/tests/e2e/push_notifications/agent_app.py index ef8276c4e..dfe71566a 100644 --- a/tests/e2e/push_notifications/agent_app.py +++ b/tests/e2e/push_notifications/agent_app.py @@ -4,6 +4,7 @@ from a2a.server.agent_execution import AgentExecutor, RequestContext from a2a.server.apps import A2ARESTFastAPIApplication +from a2a.server.context import ServerCallContext from a2a.server.events import EventQueue from a2a.server.request_handlers import DefaultRequestHandler from a2a.server.tasks import ( @@ -148,6 +149,7 @@ def create_agent_app( push_sender=BasePushNotificationSender( httpx_client=notification_client, config_store=push_config_store, + context=ServerCallContext(), ), ), ) From c4b282a999317b4a5afa73d035d9692212be5245 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 20 Feb 2026 17:10:05 +0000 Subject: [PATCH 12/15] fix: small fix --- .../default_request_handler.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/src/a2a/server/request_handlers/default_request_handler.py b/src/a2a/server/request_handlers/default_request_handler.py index 24c4b586e..104b256de 100644 --- a/src/a2a/server/request_handlers/default_request_handler.py +++ b/src/a2a/server/request_handlers/default_request_handler.py @@ -227,7 +227,7 @@ async def _run_event_stream( async def _setup_message_execution( self, params: SendMessageRequest, - context: ServerCallContext, + context: ServerCallContext | None, ) -> tuple[TaskManager, str, EventQueue, ResultAggregator, asyncio.Task]: """Common setup logic for both streaming and non-streaming message handling. 
@@ -283,7 +283,9 @@ async def _setup_message_execution( and params.configuration.push_notification_config ): await self._push_config_store.set_info( - task_id, params.configuration.push_notification_config, context + task_id, + params.configuration.push_notification_config, + context or ServerCallContext(), ) queue = await self._queue_manager.create_or_tap(task_id) @@ -495,7 +497,7 @@ async def on_create_task_push_notification_config( await self._push_config_store.set_info( task_id, params.config, - context, + context or ServerCallContext(), ) return TaskPushNotificationConfig( @@ -522,7 +524,10 @@ async def on_get_task_push_notification_config( raise ServerError(error=TaskNotFoundError()) push_notification_configs: list[PushNotificationConfig] = ( - await self._push_config_store.get_info(task_id, context) or [] + await self._push_config_store.get_info( + task_id, context or ServerCallContext() + ) + or [] ) for config in push_notification_configs: @@ -598,7 +603,7 @@ async def on_list_task_push_notification_configs( raise ServerError(error=TaskNotFoundError()) push_notification_config_list = await self._push_config_store.get_info( - task_id, context + task_id, context or ServerCallContext() ) return ListTaskPushNotificationConfigsResponse( @@ -629,4 +634,6 @@ async def on_delete_task_push_notification_config( if not task: raise ServerError(error=TaskNotFoundError()) - await self._push_config_store.delete_info(task_id, context, config_id) + await self._push_config_store.delete_info( + task_id, context or ServerCallContext(), config_id + ) From 0090ecc3b73c6eeda52c326af6261068ae7d0e40 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 20 Feb 2026 17:22:30 +0000 Subject: [PATCH 13/15] fix: fix unit test error --- tests/server/request_handlers/test_jsonrpc_handler.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/server/request_handlers/test_jsonrpc_handler.py b/tests/server/request_handlers/test_jsonrpc_handler.py index 
2ab43b44d..aa448f354 100644 --- a/tests/server/request_handlers/test_jsonrpc_handler.py +++ b/tests/server/request_handlers/test_jsonrpc_handler.py @@ -550,11 +550,12 @@ async def test_set_push_notification_success(self) -> None: task_id=mock_task.id, config=push_config, ) - response = await handler.set_push_notification_config(request) + context = ServerCallContext() + response = await handler.set_push_notification_config(request, context) self.assertIsInstance(response, dict) self.assertTrue(is_success_response(response)) mock_push_notification_store.set_info.assert_called_once_with( - mock_task.id, push_config, None + mock_task.id, push_config, context ) async def test_get_push_notification_success(self) -> None: From 0f51ef3832ac0394d2da482cf7feecffcd05f8 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 20 Feb 2026 18:06:29 +0000 Subject: [PATCH 14/15] fix: simplify task lookup in `InMemoryTaskStore.get` and extend owner-scoping tests --- src/a2a/server/tasks/inmemory_task_store.py | 19 +++++++++---------- .../server/tasks/test_inmemory_task_store.py | 8 ++++++++ 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/src/a2a/server/tasks/inmemory_task_store.py b/src/a2a/server/tasks/inmemory_task_store.py index 45d5c5b93..84c6556e8 100644 --- a/src/a2a/server/tasks/inmemory_task_store.py +++ b/src/a2a/server/tasks/inmemory_task_store.py @@ -55,16 +55,15 @@ async def get( task_id, owner, ) - owner_tasks = self.tasks.get(owner) - if owner_tasks: - task = owner_tasks.get(task_id) - if task: - logger.debug( - 'Task %s retrieved successfully for owner %s.', - task_id, - owner, - ) - return task + owner_tasks = self.tasks.get(owner, {}) + task = owner_tasks.get(task_id) + if task: + logger.debug( + 'Task %s retrieved successfully for owner %s.', + task_id, + owner, + ) + return task logger.debug( 'Task %s not found in store for owner %s.', task_id, owner ) diff --git a/tests/server/tasks/test_inmemory_task_store.py b/tests/server/tasks/test_inmemory_task_store.py index f6093b64e..6aa1bb7e5 100644 ---
a/tests/server/tasks/test_inmemory_task_store.py +++ b/tests/server/tasks/test_inmemory_task_store.py @@ -275,6 +275,9 @@ async def test_owner_resource_scoping() -> None: context_user1 = ServerCallContext(user=SampleUser(user_name='user1')) context_user2 = ServerCallContext(user=SampleUser(user_name='user2')) + context_user3 = ServerCallContext( + user=SampleUser(user_name='user3') + ) # For testing non-existent user # Create tasks for different owners task1_user1 = Task() @@ -298,6 +301,7 @@ async def test_owner_resource_scoping() -> None: assert await store.get('u1-task1', context_user2) is None assert await store.get('u2-task1', context_user1) is None assert await store.get('u2-task1', context_user2) is not None + assert await store.get('u2-task1', context_user3) is None # Test LIST params = ListTasksRequest() @@ -311,6 +315,10 @@ async def test_owner_resource_scoping() -> None: assert {t.id for t in page_user2.tasks} == {'u2-task1'} assert page_user2.total_size == 1 + page_user3 = await store.list(params, context_user3) + assert len(page_user3.tasks) == 0 + assert page_user3.total_size == 0 + # Test DELETE await store.delete('u1-task1', context_user2) # Should not delete assert await store.get('u1-task1', context_user1) is not None From 00e5eacede1ebf8742795cb739e38e1fa5121c22 Mon Sep 17 00:00:00 2001 From: sokoliva Date: Fri, 20 Feb 2026 18:06:51 +0000 Subject: [PATCH 15/15] fix: fix --- ...inmemory_push_notification_config_store.py | 10 ++++----- src/a2a/server/tasks/inmemory_task_store.py | 21 +++++++++++-------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/src/a2a/server/tasks/inmemory_push_notification_config_store.py b/src/a2a/server/tasks/inmemory_push_notification_config_store.py index 4de8b82fa..eb336e329 100644 --- a/src/a2a/server/tasks/inmemory_push_notification_config_store.py +++ b/src/a2a/server/tasks/inmemory_push_notification_config_store.py @@ -70,10 +70,8 @@ async def get_info( """Retrieves all push notification 
configurations for a task from memory, for the given owner.""" owner = self.owner_resolver(context) async with self.lock: - owner_infos = self._push_notification_infos.get(owner) - if owner_infos: - return list(owner_infos.get(task_id, [])) - return [] + owner_infos = self._push_notification_infos.get(owner, {}) + return list(owner_infos.get(task_id, [])) async def delete_info( self, @@ -88,8 +86,8 @@ async def delete_info( """ owner = self.owner_resolver(context) async with self.lock: - owner_infos = self._push_notification_infos.get(owner) - if not owner_infos or task_id not in owner_infos: + owner_infos = self._push_notification_infos.get(owner, {}) + if task_id not in owner_infos: logger.warning( 'Attempted to delete push notification config for task %s, owner %s that does not exist.', task_id, diff --git a/src/a2a/server/tasks/inmemory_task_store.py b/src/a2a/server/tasks/inmemory_task_store.py index 84c6556e8..019fd773e 100644 --- a/src/a2a/server/tasks/inmemory_task_store.py +++ b/src/a2a/server/tasks/inmemory_task_store.py @@ -160,17 +160,20 @@ async def delete( task_id, owner, ) - if owner in self.tasks and task_id in self.tasks[owner]: - del self.tasks[owner][task_id] - logger.debug( - 'Task %s deleted successfully for owner %s.', task_id, owner - ) - if not self.tasks[owner]: - del self.tasks[owner] - logger.debug('Removed empty owner %s from store.', owner) - else: + + owner_tasks = self.tasks.get(owner, {}) + if task_id not in owner_tasks: logger.warning( 'Attempted to delete nonexistent task with id: %s for owner %s', task_id, owner, ) + return + + del owner_tasks[task_id] + logger.debug( + 'Task %s deleted successfully for owner %s.', task_id, owner + ) + if not owner_tasks: + del self.tasks[owner] + logger.debug('Removed empty owner %s from store.', owner)