Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/actions/spelling/expect.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
excinfo
fetchrow
fetchval
GVsb
initdb
isready
notif
otherurl
POSTGRES
postgres
postgresql
1 change: 1 addition & 0 deletions .github/workflows/spelling.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ jobs:
cspell:sql/src/tsql.txt
cspell:terraform/dict/terraform.txt
cspell:typescript/dict/typescript.txt

check_extra_dictionaries: ""
only_check_changed_files: true
longest_word: "10"
14 changes: 14 additions & 0 deletions .github/workflows/unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ jobs:
runs-on: ubuntu-latest

if: github.repository == 'google/a2a-python'
services:
postgres:
image: postgres:15-alpine
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: a2a_test
ports:
- 5432:5432

strategy:
matrix:
Expand All @@ -28,6 +37,11 @@ jobs:
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
- name: Set postgres for tests
run: |
sudo apt-get update && sudo apt-get install -y postgresql-client
PGPASSWORD=postgres psql -h localhost -p 5432 -U postgres -d a2a_test -f ${{ github.workspace }}/docker/postgres/init.sql
export POSTGRES_TEST_DSN="postgresql://postgres:postgres@localhost:5432/a2a_test"

- name: Install uv
run: |
Expand Down
28 changes: 28 additions & 0 deletions docker/postgres/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
version: "3.8"

services:
postgres:
image: postgres:15-alpine
ports:
- "5432:5432"
environment:
- POSTGRES_USER=postgres
- POSTGRES_PASSWORD=postgres
- POSTGRES_DB=a2a_test
volumes:
- postgres_data:/var/lib/postgresql/data
- ./docker/postgres/init.sql:/docker-entrypoint-initdb.d/init.sql
networks:
- a2a-network
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 5s
timeout: 5s
retries: 5

volumes:
postgres_data:

networks:
a2a-network:
driver: bridge
8 changes: 8 additions & 0 deletions docker/postgres/init.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- Create a dedicated user for the application
CREATE USER a2a WITH PASSWORD 'a2a_password';

-- Create the tasks database
CREATE DATABASE a2a_tasks;

GRANT ALL PRIVILEGES ON DATABASE a2a_test TO a2a;

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ authors = [{ name = "Google LLC", email = "googleapis-packages@google.com" }]
requires-python = ">=3.13"
keywords = ["A2A", "A2A SDK", "A2A Protocol", "Agent2Agent"]
dependencies = [
"asyncpg>=0.30.0",
"httpx>=0.28.1",
"httpx-sse>=0.4.0",
"opentelemetry-api>=1.33.0",
Expand Down Expand Up @@ -70,6 +71,7 @@ members = [

[dependency-groups]
dev = [
"asyncpg-stubs>=0.30.1",
"datamodel-code-generator>=0.30.0",
"mypy>=1.15.0",
"pytest>=8.3.5",
Expand Down
2 changes: 2 additions & 0 deletions src/a2a/server/tasks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from a2a.server.tasks.inmemory_push_notifier import InMemoryPushNotifier
from a2a.server.tasks.inmemory_task_store import InMemoryTaskStore
from a2a.server.tasks.postgresql_task_store import PostgreSQLTaskStore
from a2a.server.tasks.push_notifier import PushNotifier
from a2a.server.tasks.result_aggregator import ResultAggregator
from a2a.server.tasks.task_manager import TaskManager
Expand All @@ -12,6 +13,7 @@
__all__ = [
'InMemoryPushNotifier',
'InMemoryTaskStore',
'PostgreSQLTaskStore',
'PushNotifier',
'ResultAggregator',
'TaskManager',
Expand Down
136 changes: 136 additions & 0 deletions src/a2a/server/tasks/postgresql_task_store.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
import json
import logging

import asyncpg

from a2a.server.tasks.task_store import TaskStore
from a2a.types import Task


logger = logging.getLogger(__name__)


class PostgreSQLTaskStore(TaskStore):
"""PostgreSQL implementation of TaskStore.

Stores task objects in a PostgreSQL database.
"""

def __init__(
self,
url: str,
table_name: str = 'tasks',
create_table: bool = True,
) -> None:
"""Initializes the PostgreSQLTaskStore.

Args:
url: PostgreSQL connection string in the format:
postgresql://username:password@hostname:port/database
table_name: The name of the table to store tasks in
create_table: Whether to create the table if it doesn't exist
"""
logger.debug('Initializing PostgreSQLTaskStore')
self.url = url
self.table_name = table_name
self.create_table = create_table

self.pool: asyncpg.Pool | None = None

async def initialize(self) -> None:
"""Initialize the database connection pool and create the table
if needed.
"""
if self.pool is not None:
return

logger.debug('Creating connection pool')
self.pool = await asyncpg.create_pool(self.url)

if self.create_table:
async with self.pool.acquire() as conn:
logger.debug('Creating tasks table if not exists')
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.table_name} (
id TEXT PRIMARY KEY,
data JSONB NOT NULL

)
"""
)

async def close(self) -> None:
"""Close the database connection pool."""
if self.pool is not None:
await self.pool.close()
self.pool = None

async def save(self, task: Task) -> None:
"""Saves or updates a task in the PostgreSQL store."""
await self._ensure_initialized()

assert self.pool is not None
async with self.pool.acquire() as conn, conn.transaction():
task_json = task.model_dump()

await conn.execute(
f"""
INSERT INTO {self.table_name} (id, data)
VALUES ($1, $2)
ON CONFLICT (id) DO UPDATE
SET data = $2
""",
task.id,
json.dumps(task_json),
)

logger.info('Task %s saved successfully.', task.id)

async def get(self, task_id: str) -> Task | None:
"""Retrieves a task from the PostgreSQL store by ID."""
await self._ensure_initialized()

assert self.pool is not None
async with self.pool.acquire() as conn, conn.transaction():
logger.debug('Attempting to get task with id: %s', task_id)

row = await conn.fetchrow(
f'SELECT data FROM {self.table_name} WHERE id = $1',
task_id,
)

if row:
task_json = json.loads(row['data'])
task = Task.model_validate(task_json)
logger.debug('Task %s retrieved successfully.', task_id)
return task

logger.debug('Task %s not found in store.', task_id)
return None

async def delete(self, task_id: str) -> None:
"""Deletes a task from the PostgreSQL store by ID."""
await self._ensure_initialized()

assert self.pool is not None
async with self.pool.acquire() as conn, conn.transaction():
logger.debug('Attempting to delete task with id: %s', task_id)

result = await conn.execute(
f'DELETE FROM {self.table_name} WHERE id = $1',
task_id,
)

if result.split()[-1] != '0': # Check if rows were affected
logger.info('Task %s deleted successfully.', task_id)
else:
logger.warning(
'Attempted to delete nonexistent task with id: %s',
task_id,
)

async def _ensure_initialized(self) -> None:
"""Ensure the database connection is initialized."""
if self.pool is None:
await self.initialize()
18 changes: 13 additions & 5 deletions src/a2a/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,7 @@ class TaskResubscriptionRequest(BaseModel):
"""


class TaskState(Enum):
class TaskState(str, Enum):
"""
Represents the possible states of a Task.
"""
Expand Down Expand Up @@ -1088,7 +1088,9 @@ class Artifact(BaseModel):


class GetTaskPushNotificationConfigResponse(
RootModel[JSONRPCErrorResponse | GetTaskPushNotificationConfigSuccessResponse]
RootModel[
JSONRPCErrorResponse | GetTaskPushNotificationConfigSuccessResponse
]
):
root: JSONRPCErrorResponse | GetTaskPushNotificationConfigSuccessResponse
"""
Expand Down Expand Up @@ -1239,7 +1241,9 @@ class SendStreamingMessageRequest(BaseModel):


class SetTaskPushNotificationConfigResponse(
RootModel[JSONRPCErrorResponse | SetTaskPushNotificationConfigSuccessResponse]
RootModel[
JSONRPCErrorResponse | SetTaskPushNotificationConfigSuccessResponse
]
):
root: JSONRPCErrorResponse | SetTaskPushNotificationConfigSuccessResponse
"""
Expand Down Expand Up @@ -1524,7 +1528,9 @@ class SendStreamingMessageSuccessResponse(BaseModel):
"""


class CancelTaskResponse(RootModel[JSONRPCErrorResponse | CancelTaskSuccessResponse]):
class CancelTaskResponse(
RootModel[JSONRPCErrorResponse | CancelTaskSuccessResponse]
):
root: JSONRPCErrorResponse | CancelTaskSuccessResponse
"""
JSON-RPC response for the 'tasks/cancel' method.
Expand Down Expand Up @@ -1563,7 +1569,9 @@ class JSONRPCResponse(
"""


class SendMessageResponse(RootModel[JSONRPCErrorResponse | SendMessageSuccessResponse]):
class SendMessageResponse(
RootModel[JSONRPCErrorResponse | SendMessageSuccessResponse]
):
root: JSONRPCErrorResponse | SendMessageSuccessResponse
"""
JSON-RPC response model for the 'message/send' method.
Expand Down
Loading