diff --git a/src/a2a/client/client.py b/src/a2a/client/client.py index 286641a79..c79f76d82 100644 --- a/src/a2a/client/client.py +++ b/src/a2a/client/client.py @@ -3,10 +3,13 @@ from abc import ABC, abstractmethod from collections.abc import AsyncIterator, Callable, Coroutine +from types import TracebackType from typing import Any import httpx +from typing_extensions import Self + from a2a.client.middleware import ClientCallContext, ClientCallInterceptor from a2a.client.optionals import Channel from a2a.types import ( @@ -107,6 +110,19 @@ def __init__( self._consumers = consumers self._middleware = middleware + async def __aenter__(self) -> Self: + """Enters the async context manager.""" + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exits the async context manager and closes the client.""" + await self.close() + @abstractmethod async def send_message( self, @@ -209,3 +225,7 @@ async def consume( return for c in self._consumers: await c(event, card) + + @abstractmethod + async def close(self) -> None: + """Closes the client and releases any underlying resources."""