diff --git a/pyproject.toml b/pyproject.toml index 1a8f0af68..80b50e73b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,6 +13,7 @@ dependencies = [ "pydantic>=2.11.3", "protobuf>=5.29.5", "google-api-core>=1.26.0", + "jsonschema>=4.23.0", ] classifiers = [ @@ -37,6 +38,8 @@ postgresql = ["sqlalchemy[asyncio,postgresql-asyncpg]>=2.0.0"] mysql = ["sqlalchemy[asyncio,aiomysql]>=2.0.0"] signing = ["PyJWT>=2.0.0"] sqlite = ["sqlalchemy[asyncio,aiosqlite]>=2.0.0"] +performance = ["uvloop>=0.21.0"] +tls = ["cryptography>=43.0.0"] sql = ["a2a-sdk[postgresql,mysql,sqlite]"] @@ -47,6 +50,8 @@ all = [ "a2a-sdk[grpc]", "a2a-sdk[telemetry]", "a2a-sdk[signing]", + "a2a-sdk[performance]", + "a2a-sdk[tls]", ] [project.urls] diff --git a/src/a2a/__init__.py b/src/a2a/__init__.py index 86893a97a..a12e13b05 100644 --- a/src/a2a/__init__.py +++ b/src/a2a/__init__.py @@ -1 +1,32 @@ -"""The A2A Python SDK.""" +"""The A2A Python SDK. + +This SDK provides tools for building A2A (Agent-to-Agent) protocol clients +and servers with support for: + +- TLS/SSL secure communication (see a2a.client.TLSConfig) +- JSON Schema message validation (see a2a.validation) +- High-performance async operations via uvloop (see a2a.performance) + +Example usage: + + from a2a.client import Client, ClientConfig, TLSConfig + from a2a.validation import validate_message + from a2a.performance import install_uvloop + + # Enable uvloop for better performance + install_uvloop() + + # Configure TLS + tls = TLSConfig( + enabled=True, + verify='/path/to/ca.pem', + cert=('/path/to/client.pem', '/path/to/key.pem'), + ) + + config = ClientConfig(tls_config=tls) +""" + +from a2a import client, performance, types, validation + + +__all__ = ['client', 'performance', 'types', 'validation'] diff --git a/src/a2a/client/__init__.py b/src/a2a/client/__init__.py index 4fccd0810..56bad035a 100644 --- a/src/a2a/client/__init__.py +++ b/src/a2a/client/__init__.py @@ -20,6 +20,11 @@ from a2a.client.helpers import create_text_message_object from a2a.client.legacy import A2AClient from a2a.client.middleware import ClientCallContext, ClientCallInterceptor +from a2a.client.tls import ( + TLSConfig, + create_grpc_channel_factory, + create_server_ssl_context, +) logger = logging.getLogger(__name__) @@ -62,6 +67,9 @@ def __init__(self, *args, **kwargs): 'Consumer', 'CredentialService', 'InMemoryContextCredentialStore', + 'TLSConfig', + 'create_grpc_channel_factory', + 'create_server_ssl_context', 'create_text_message_object', 'minimal_agent_card', ] diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 286641a79..fc5909c8d 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Callable, Coroutine -from typing import Any +from typing import TYPE_CHECKING, Any import httpx @@ -24,6 +24,10 @@ ) +if TYPE_CHECKING: + from a2a.client.tls import TLSConfig + + logger = logging.getLogger(__name__) @@ -70,6 +74,46 @@ class ClientConfig: extensions: list[str] = dataclasses.field(default_factory=list) """A list of extension URIs the client supports.""" + tls_config: 'TLSConfig | None' = None + """TLS/SSL configuration for secure communication. If provided, this + will be used to configure secure connections for HTTP and gRPC + transports. Ignored if httpx_client or grpc_channel_factory is + explicitly provided.""" + + validate_messages: bool = False + """Whether to validate messages against JSON Schema before sending + and after receiving. Useful for protocol compliance testing.""" + + def get_httpx_client(self) -> httpx.AsyncClient: + """Get or create an httpx client with TLS configuration. + + Returns: + Configured httpx.AsyncClient instance. + """ + if self.httpx_client is not None: + return self.httpx_client + + if self.tls_config is not None: + return self.tls_config.create_httpx_client() + + return httpx.AsyncClient() + + def get_grpc_channel_factory(self) -> Callable[[str], Channel] | None: + """Get or create a gRPC channel factory with TLS configuration. + + Returns: + A callable that creates gRPC channels, or None. + """ + if self.grpc_channel_factory is not None: + return self.grpc_channel_factory + + if self.tls_config is not None: + from a2a.client.tls import create_grpc_channel_factory + + return create_grpc_channel_factory(self.tls_config) + + return None + UpdateEvent = TaskStatusUpdateEvent | TaskArtifactUpdateEvent | None # Alias for emitted events from client diff --git a/src/a2a/client/client_factory.py b/src/a2a/client/client_factory.py index c3d5762eb..e9e39e677 100644 --- a/src/a2a/client/client_factory.py +++ b/src/a2a/client/client_factory.py @@ -71,12 +71,17 @@ def __init__( def _register_defaults( self, supported: list[str | TransportProtocol] ) -> None: - # Empty support list implies JSON-RPC only. + httpx_client = self._config.httpx_client + if httpx_client is None and self._config.tls_config is not None: + httpx_client = self._config.tls_config.create_httpx_client() + elif httpx_client is None: + httpx_client = httpx.AsyncClient() + if TransportProtocol.jsonrpc in supported or not supported: self.register( TransportProtocol.jsonrpc, lambda card, url, config, interceptors: JsonRpcTransport( - config.httpx_client or httpx.AsyncClient(), + httpx_client, card, url, interceptors, @@ -87,7 +92,7 @@ def _register_defaults( self.register( TransportProtocol.http_json, lambda card, url, config, interceptors: RestTransport( - config.httpx_client or httpx.AsyncClient(), + httpx_client, card, url, interceptors, diff --git a/src/a2a/client/tls.py b/src/a2a/client/tls.py new file mode 100644 index 000000000..28374ec52 --- /dev/null +++ b/src/a2a/client/tls.py @@ -0,0 +1,242 @@ +"""TLS/SSL configuration for secure agent-to-agent communication. + +This module provides TLS configuration classes for securing HTTP and gRPC +transport layers in A2A client-server communication. +""" + +import ssl + +from collections.abc import Callable +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import httpx + + +@dataclass +class TLSConfig: + """TLS/SSL configuration for secure A2A communication. + + This class encapsulates all TLS-related settings for both client and + server-side secure communication, including certificate verification, + client authentication (mTLS), and custom SSL context configuration. + + Attributes: + enabled: Whether TLS is enabled. Defaults to True. + verify: Whether to verify server certificates. Can be a boolean, + path to CA bundle file, or ssl.SSLContext. Defaults to True. + cert: Client certificate for mTLS. Can be a tuple of (cert_file, key_file) + or (cert_file, key_file, password). Defaults to None. + ca_cert: Path to CA certificate file for server verification. + Defaults to None (use system defaults). + min_version: Minimum TLS version. Defaults to TLSv1.2. + cipher_suites: List of allowed cipher suites. Defaults to None + (use system defaults). + verify_hostname: Whether to verify hostname in certificate. + Defaults to True. + """ + + enabled: bool = True + verify: bool | str | ssl.SSLContext | None = True + cert: tuple[str, str] | tuple[str, str, str] | None = None + ca_cert: str | Path | None = None + min_version: str = 'TLSv1_2' + cipher_suites: list[str] | None = None + verify_hostname: bool = True + + def create_ssl_context(self) -> ssl.SSLContext: + """Create an SSL context from this configuration. + + Returns: + A configured ssl.SSLContext instance. + """ + if isinstance(self.verify, ssl.SSLContext): + return self.verify + + context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) + + version_map = { + 'TLSv1_2': ssl.TLSVersion.TLSv1_2, + 'TLSv1_3': ssl.TLSVersion.TLSv1_3, + } + context.minimum_version = version_map.get( + self.min_version, ssl.TLSVersion.TLSv1_2 + ) + + if self.ca_cert: + context.load_verify_locations(cafile=str(self.ca_cert)) + elif isinstance(self.verify, str): + context.load_verify_locations(cafile=self.verify) + + if self.verify is not False: + context.verify_mode = ssl.CERT_REQUIRED + context.check_hostname = self.verify_hostname + + if self.cert: + if len(self.cert) == 2: + context.load_cert_chain(self.cert[0], self.cert[1]) + elif len(self.cert) == 3: + context.load_cert_chain( + self.cert[0], self.cert[1], password=self.cert[2].encode() + ) + + if self.cipher_suites: + context.set_ciphers(':'.join(self.cipher_suites)) + + return context + + def get_httpx_verify(self) -> bool | str | ssl.SSLContext: + """Get the verify parameter for httpx client. + + Returns: + Value suitable for httpx.AsyncClient verify parameter. + """ + if not self.enabled: + return False + if isinstance(self.verify, ssl.SSLContext): + return self.verify + if isinstance(self.verify, str): + return self.verify + if self.ca_cert: + return str(self.ca_cert) + return self.verify if self.verify is not None else True + + def create_httpx_client( + self, + base_url: str | httpx.URL | None = None, + **kwargs: Any, + ) -> httpx.AsyncClient: + """Create an httpx AsyncClient with this TLS configuration. + + Args: + base_url: Base URL for the client. + **kwargs: Additional arguments passed to httpx.AsyncClient. + + Returns: + Configured httpx.AsyncClient instance. + """ + client_kwargs: dict[str, Any] = {**kwargs} + if base_url is not None: + client_kwargs['base_url'] = base_url + + if not self.enabled: + client_kwargs['verify'] = False + return httpx.AsyncClient(**client_kwargs) + + client_kwargs['verify'] = self.get_httpx_verify() + client_kwargs['cert'] = self.cert + + return httpx.AsyncClient(**client_kwargs) + + +def create_grpc_credentials( + tls_config: TLSConfig | None = None, + root_certificates: bytes | None = None, +) -> Any: + """Create gRPC channel credentials from TLS configuration. + + Args: + tls_config: TLS configuration. If None, creates default SSL credentials. + root_certificates: Optional root certificates bytes for verification. + + Returns: + gRPC channel credentials object. + """ + try: + import grpc + except ImportError as e: + raise ImportError( + 'gRPC credentials require grpcio to be installed. ' + "Install with: 'pip install a2a-sdk[grpc]'" + ) from e + + if tls_config is None or not tls_config.enabled: + return grpc.ssl_channel_credentials(root_certificates=root_certificates) + + cert_chain: bytes | None = None + private_key: bytes | None = None + + if tls_config.cert: + with open(tls_config.cert[0], 'rb') as f: + cert_chain = f.read() + with open(tls_config.cert[1], 'rb') as f: + private_key = f.read() + + if tls_config.ca_cert: + with open(tls_config.ca_cert, 'rb') as f: + root_certificates = f.read() + + return grpc.ssl_channel_credentials( + root_certificates=root_certificates, + private_key=private_key, + certificate_chain=cert_chain, + ) + + +def create_grpc_channel_factory( + tls_config: TLSConfig | None = None, +) -> Callable[[str], Any]: + """Create a gRPC channel factory with TLS configuration. + + Args: + tls_config: TLS configuration for secure channels. + + Returns: + A callable that creates gRPC channels for given URLs. + """ + try: + import grpc + except ImportError as e: + raise ImportError( + 'gRPC channel factory requires grpcio to be installed. ' + "Install with: 'pip install a2a-sdk[grpc]'" + ) from e + + def factory(url: str) -> Any: + if tls_config is None or not tls_config.enabled: + return grpc.aio.insecure_channel(url) + + credentials = create_grpc_credentials(tls_config) + return grpc.aio.secure_channel(url, credentials) + + return factory + + +def create_server_ssl_context( + cert_file: str | Path, + key_file: str | Path, + ca_cert: str | Path | None = None, + require_client_cert: bool = False, + min_version: str = 'TLSv1_2', +) -> ssl.SSLContext: + """Create an SSL context for A2A server. + + Args: + cert_file: Path to server certificate file. + key_file: Path to server private key file. + ca_cert: Path to CA certificate for client verification. + require_client_cert: Whether to require client certificates (mTLS). + min_version: Minimum TLS version. + + Returns: + Configured ssl.SSLContext for server use. + """ + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.load_cert_chain(str(cert_file), str(key_file)) + + version_map = { + 'TLSv1_2': ssl.TLSVersion.TLSv1_2, + 'TLSv1_3': ssl.TLSVersion.TLSv1_3, + } + context.minimum_version = version_map.get( + min_version, ssl.TLSVersion.TLSv1_2 + ) + + if ca_cert: + context.load_verify_locations(cafile=str(ca_cert)) + + if require_client_cert: + context.verify_mode = ssl.CERT_REQUIRED + + return context diff --git a/src/a2a/client/transports/grpc.py b/src/a2a/client/transports/grpc.py index 085f8c970..c861f9bf9 100644 --- a/src/a2a/client/transports/grpc.py +++ b/src/a2a/client/transports/grpc.py @@ -80,9 +80,14 @@ def create( interceptors: list[ClientCallInterceptor], ) -> 'GrpcTransport': """Creates a gRPC transport for the A2A client.""" - if config.grpc_channel_factory is None: + channel_factory = config.grpc_channel_factory + if channel_factory is None and config.tls_config is not None: + from a2a.client.tls import create_grpc_channel_factory + + channel_factory = create_grpc_channel_factory(config.tls_config) + if channel_factory is None: raise ValueError('grpc_channel_factory is required when using gRPC') - return cls(config.grpc_channel_factory(url), card, config.extensions) + return cls(channel_factory(url), card, config.extensions) async def send_message( self, diff --git a/src/a2a/performance.py b/src/a2a/performance.py new file mode 100644 index 000000000..784fc4f31 --- /dev/null +++ b/src/a2a/performance.py @@ -0,0 +1,257 @@ +"""uvloop integration for high-performance async operations. + +This module provides utilities for enabling uvloop, a fast drop-in replacement +for asyncio's default event loop. uvloop can significantly improve the +performance of async I/O bound applications. + +Example usage: + + from a2a.performance import install_uvloop, run_with_uvloop + + # Option 1: Install uvloop as the default event loop + install_uvloop() + + # Option 2: Run a specific coroutine with uvloop + async def main(): + # Your async code here + pass + + run_with_uvloop(main()) +""" + +import asyncio +import logging +import sys + +from collections.abc import Coroutine +from typing import Any, TypeVar + + +_T = TypeVar('_T') + +logger = logging.getLogger(__name__) + +_UVLOOP_AVAILABLE = False + +try: + import uvloop + + _UVLOOP_AVAILABLE = True +except ImportError: + uvloop = None # type: ignore[assignment] + + +def is_uvloop_available() -> bool: + """Check if uvloop is available for use. + + Returns: + True if uvloop is installed, False otherwise. + """ + return _UVLOOP_AVAILABLE + + +def is_uvloop_installed() -> bool: + """Check if uvloop is currently installed as the event loop policy. + + Returns: + True if uvloop is the current event loop policy, False otherwise. + """ + if not _UVLOOP_AVAILABLE: + return False + + try: + policy = asyncio.get_event_loop_policy() + return isinstance(policy, uvloop.EventLoopPolicy) # type: ignore[union-attr] + except Exception: + return False + + +def install_uvloop() -> bool: + """Install uvloop as the default event loop policy. + + This should be called before any async code is executed, typically at + the start of your application. + + Returns: + True if uvloop was installed successfully, False if not available. + + Raises: + RuntimeError: If called from a running event loop. + + Example: + if __name__ == '__main__': + install_uvloop() + asyncio.run(main()) + """ + if not _UVLOOP_AVAILABLE: + logger.debug( + 'uvloop is not available. Install with: pip install a2a-sdk[performance]' + ) + return False + + try: + if sys.platform == 'win32': + logger.warning( + 'uvloop is not supported on Windows. Using default asyncio loop.' + ) + return False + + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) # type: ignore[union-attr] + logger.info('uvloop event loop policy installed successfully') + return True + except Exception as e: + logger.warning('Failed to install uvloop: %s', e) + return False + + +def uninstall_uvloop() -> bool: + """Uninstall uvloop and restore the default asyncio event loop policy. + + Returns: + True if uvloop was uninstalled, False if it wasn't installed. + """ + if not _UVLOOP_AVAILABLE: + return False + + try: + asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy()) + logger.info('Restored default asyncio event loop policy') + return True + except Exception as e: + logger.warning('Failed to uninstall uvloop: %s', e) + return False + + +def run_with_uvloop(coro: Coroutine[Any, Any, _T]) -> _T: + """Run a coroutine with uvloop if available, otherwise use asyncio.run(). + + This is a convenience function that automatically handles uvloop + installation and cleanup. + + Args: + coro: The coroutine to run. + + Returns: + The result of the coroutine. + + Example: + async def main(): + client = A2AClient(...) + await client.send_message(...) + + result = run_with_uvloop(main()) + """ + if sys.platform == 'win32': + return asyncio.run(coro) + + if not _UVLOOP_AVAILABLE: + logger.debug('uvloop not available, using asyncio.run()') + return asyncio.run(coro) + + was_installed = is_uvloop_installed() + + if not was_installed: + install_uvloop() + + try: + return asyncio.run(coro) + finally: + if not was_installed: + uninstall_uvloop() + + +class UvloopRunner: + """Context manager for running code with uvloop. + + Provides a clean way to enable uvloop for a block of code and + automatically restore the previous event loop policy on exit. + + Example: + async def main(): + async with UvloopRunner(): + # uvloop is active here + client = A2AClient(...) + await client.send_message(...) + """ + + def __init__(self, *, force: bool = False): + """Initialize the uvloop runner. + + Args: + force: If True, raise an error if uvloop is not available. + """ + self._force = force + self._was_installed = False + self._previous_policy: asyncio.AbstractEventLoopPolicy | None = None + + def __enter__(self) -> 'UvloopRunner': + if not _UVLOOP_AVAILABLE: + if self._force: + raise RuntimeError( + 'uvloop is not available. Install with: ' + 'pip install a2a-sdk[performance]' + ) + return self + + self._was_installed = is_uvloop_installed() + + if not self._was_installed: + self._previous_policy = asyncio.get_event_loop_policy() + install_uvloop() + + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if not _UVLOOP_AVAILABLE: + return + + if not self._was_installed and self._previous_policy is not None: + asyncio.set_event_loop_policy(self._previous_policy) + + +def get_event_loop_optimization_info() -> dict[str, Any]: + """Get information about the current event loop setup. + + Returns: + Dictionary with event loop information. + """ + info = { + 'uvloop_available': _UVLOOP_AVAILABLE, + 'uvloop_installed': is_uvloop_installed(), + 'platform': sys.platform, + 'python_version': sys.version, + } + + try: + policy = asyncio.get_event_loop_policy() + info['policy_class'] = policy.__class__.__name__ + except Exception as e: + info['policy_error'] = str(e) + + try: + loop = asyncio.get_running_loop() + info['loop_class'] = loop.__class__.__name__ + except RuntimeError: + info['loop_class'] = None + + return info + + +def optimize_event_loop() -> bool: + """Optimize the event loop for A2A operations. + + This function installs uvloop if available and appropriate for the + current platform. It's designed to be called at application startup. + + Returns: + True if optimization was applied, False otherwise. + """ + if sys.platform == 'win32': + logger.debug('Event loop optimization not available on Windows') + return False + + if is_uvloop_installed(): + logger.debug('uvloop already installed') + return True + + return install_uvloop() diff --git a/src/a2a/validation.py b/src/a2a/validation.py new file mode 100644 index 000000000..e2eb29b0d --- /dev/null +++ b/src/a2a/validation.py @@ -0,0 +1,406 @@ +"""JSON Schema validation for A2A protocol messages. + +This module provides JSON Schema generation and validation for all A2A +message types, ensuring protocol compliance for incoming and outgoing messages. +""" + +from functools import lru_cache +from typing import Any, TypeVar + +import jsonschema + +from pydantic import BaseModel + +from a2a.types import ( + AgentCard, + CancelTaskRequest, + CancelTaskResponse, + GetTaskPushNotificationConfigParams, + GetTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigResponse, + GetTaskRequest, + GetTaskResponse, + JSONRPCErrorResponse, + JSONRPCRequest, + Message, + MessageSendParams, + SendMessageRequest, + SendMessageResponse, + SendStreamingMessageRequest, + SendStreamingMessageResponse, + SetTaskPushNotificationConfigRequest, + SetTaskPushNotificationConfigResponse, + Task, + TaskArtifactUpdateEvent, + TaskIdParams, + TaskPushNotificationConfig, + TaskQueryParams, + TaskResubscriptionRequest, + TaskStatusUpdateEvent, +) + + +T = TypeVar('T', bound=BaseModel) + +_MESSAGE_TYPES: tuple[type[BaseModel], ...] = ( + SendMessageRequest, + SendStreamingMessageRequest, + GetTaskRequest, + CancelTaskRequest, + SetTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigRequest, + TaskResubscriptionRequest, + SendMessageResponse, + SendStreamingMessageResponse, + GetTaskResponse, + CancelTaskResponse, + SetTaskPushNotificationConfigResponse, + GetTaskPushNotificationConfigResponse, + JSONRPCRequest, + JSONRPCErrorResponse, + Message, + Task, + TaskStatusUpdateEvent, + TaskArtifactUpdateEvent, + AgentCard, + TaskQueryParams, + TaskIdParams, + MessageSendParams, + TaskPushNotificationConfig, + GetTaskPushNotificationConfigParams, +) + +_TYPE_SCHEMA_CACHE: dict[type[BaseModel], dict[str, Any]] = {} + + +class ValidationError(Exception): + """Raised when message validation fails against JSON Schema.""" + + def __init__( + self, + message: str, + errors: list[dict[str, Any]] | None = None, + schema: dict[str, Any] | None = None, + instance: Any = None, + ): + super().__init__(message) + self.errors = errors or [] + self.schema = schema + self.instance = instance + + +def get_schema_for_type(model_type: type[BaseModel]) -> dict[str, Any]: + """Generate JSON Schema for a Pydantic model type. + + Args: + model_type: A Pydantic BaseModel subclass. + + Returns: + A dictionary containing the JSON Schema for the model. + """ + if model_type in _TYPE_SCHEMA_CACHE: + return _TYPE_SCHEMA_CACHE[model_type] + + schema = model_type.model_json_schema( + mode='serialization', + by_alias=True, + ref_template='#/definitions/{model}', + ) + + if '$defs' in schema: + definitions = schema.pop('$defs') + if definitions: + schema['definitions'] = definitions + + _TYPE_SCHEMA_CACHE[model_type] = schema + return schema + + +@lru_cache(maxsize=1) +def get_protocol_schemas() -> dict[str, dict[str, Any]]: + """Generate JSON Schemas for all A2A protocol message types. + + Returns: + A dictionary mapping type names to their JSON Schemas. + """ + schemas: dict[str, dict[str, Any]] = {} + + for model_type in _MESSAGE_TYPES: + schema = get_schema_for_type(model_type) + schemas[model_type.__name__] = schema + + return schemas + + +def get_request_schemas() -> dict[str, dict[str, Any]]: + """Get JSON Schemas for all A2A request types. + + Returns: + Dictionary of request type names to their schemas. + """ + request_types = ( + SendMessageRequest, + SendStreamingMessageRequest, + GetTaskRequest, + CancelTaskRequest, + SetTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigRequest, + TaskResubscriptionRequest, + ) + return {t.__name__: get_schema_for_type(t) for t in request_types} + + +def get_response_schemas() -> dict[str, dict[str, Any]]: + """Get JSON Schemas for all A2A response types. + + Returns: + Dictionary of response type names to their schemas. + """ + response_types = ( + SendMessageResponse, + SendStreamingMessageResponse, + GetTaskResponse, + CancelTaskResponse, + SetTaskPushNotificationConfigResponse, + GetTaskPushNotificationConfigResponse, + ) + return {t.__name__: get_schema_for_type(t) for t in response_types} + + +def get_event_schemas() -> dict[str, dict[str, Any]]: + """Get JSON Schemas for all A2A event types. + + Returns: + Dictionary of event type names to their schemas. + """ + event_types = ( + TaskStatusUpdateEvent, + TaskArtifactUpdateEvent, + ) + return {t.__name__: get_schema_for_type(t) for t in event_types} + + +def validate_message( + data: dict[str, Any], + model_type: type[T], + *, + strict: bool = True, +) -> T: + """Validate message data against a Pydantic model's JSON Schema. + + This performs both JSON Schema validation and Pydantic model validation. + + Args: + data: The raw message data to validate. + model_type: The expected Pydantic model type. + strict: Whether to use strict validation mode. + + Returns: + The validated and parsed model instance. + + Raises: + ValidationError: If validation fails against the schema. + """ + schema = get_schema_for_type(model_type) + + try: + jsonschema.validate( + instance=data, + schema=schema, + cls=jsonschema.Draft7Validator, + ) + except jsonschema.ValidationError as e: + raise ValidationError( + f'JSON Schema validation failed: {e.message}', + errors=[{'path': list(e.path), 'message': e.message}], + schema=schema, + instance=data, + ) from e + + try: + return model_type.model_validate(data, strict=strict) + except Exception as e: + raise ValidationError( + f'Pydantic validation failed: {e}', + schema=schema, + instance=data, + ) from e + + +def _validate_against_types( + data: dict[str, Any], + model_types: tuple[type[BaseModel], ...], + category_name: str, +) -> BaseModel: + """Validate data against multiple model types and return first match. + + Args: + data: Raw data to validate. + model_types: Tuple of model types to try. + category_name: Name of the category for error messages (e.g., 'request', 'response'). + + Returns: + The validated model instance. + + Raises: + ValidationError: If data doesn't match any of the provided types. + """ + errors: list[dict[str, Any]] = [] + + for model_type in model_types: + try: + return validate_message(data, model_type, strict=False) + except ValidationError as e: + errors.append( + { + 'type': model_type.__name__, + 'path': e.errors[0].get('path', []) if e.errors else [], + 'message': e.errors[0].get('message', str(e)) + if e.errors + else str(e), + } + ) + + error_details = '; '.join( + f'{e["type"]}: {e["message"]} (path: {".".join(map(str, e["path"])) or "root"})' + for e in errors + ) + raise ValidationError( + f'Data does not match any known A2A {category_name} type. ' + f'Attempted types: {[e["type"] for e in errors]}. Details: {error_details}', + errors=errors, + instance=data, + ) + + +def validate_request(data: dict[str, Any]) -> BaseModel: + """Validate and parse an A2A request message. + + Attempts to validate against all known request types and returns + the first successful match. + + Args: + data: Raw request data to validate. + + Returns: + The validated request model instance. + + Raises: + ValidationError: If data doesn't match any request type. + """ + request_types = ( + SendMessageRequest, + SendStreamingMessageRequest, + GetTaskRequest, + CancelTaskRequest, + SetTaskPushNotificationConfigRequest, + GetTaskPushNotificationConfigRequest, + TaskResubscriptionRequest, + ) + + return _validate_against_types(data, request_types, 'request') + + +def validate_response(data: dict[str, Any]) -> BaseModel: + """Validate and parse an A2A response message. + + Args: + data: Raw response data to validate. + + Returns: + The validated response model instance. + + Raises: + ValidationError: If validation fails. + """ + response_types = ( + SendMessageResponse, + SendStreamingMessageResponse, + GetTaskResponse, + CancelTaskResponse, + SetTaskPushNotificationConfigResponse, + GetTaskPushNotificationConfigResponse, + ) + + return _validate_against_types(data, response_types, 'response') + + +class MessageValidator: + """A reusable validator for A2A messages with caching. + + This class provides efficient validation by caching schemas and + supporting batch validation operations. + """ + + def __init__(self, *, strict: bool = True): + """Initialize the message validator. + + Args: + strict: Whether to use strict validation mode by default. + """ + self._strict = strict + self._schemas = get_protocol_schemas() + + def validate( + self, + data: dict[str, Any], + model_type: type[T], + ) -> T: + """Validate data against a specific model type. + + Args: + data: Raw data to validate. + model_type: Expected model type. + + Returns: + Validated model instance. + + Raises: + ValidationError: If validation fails. + """ + return validate_message(data, model_type, strict=self._strict) + + def validate_batch( + self, + messages: list[dict[str, Any]], + model_type: type[T], + ) -> list[T]: + """Validate multiple messages of the same type. + + Args: + messages: List of raw message data. + model_type: Expected model type for all messages. + + Returns: + List of validated model instances. + + Raises: + ValidationError: If any message fails validation. + """ + return [self.validate(msg, model_type) for msg in messages] + + def get_schema(self, type_name: str) -> dict[str, Any] | None: + """Get a cached schema by type name. + + Args: + type_name: Name of the type to get schema for. + + Returns: + JSON Schema dictionary or None if not found. + """ + return self._schemas.get(type_name) + + def list_schemas(self) -> list[str]: + """List all available schema type names. + + Returns: + List of type names with cached schemas. + """ + return list(self._schemas.keys()) + + def clear_cache(self) -> None: + """Clear all cached schemas.""" + self._schemas.clear() + _TYPE_SCHEMA_CACHE.clear() + get_protocol_schemas.cache_clear() + self._schemas = get_protocol_schemas()