diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py index 85d0c9584ea9..c4d938ada8d9 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py @@ -49,8 +49,6 @@ def from_agent_framework( credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] = None, thread_repository: Optional[AgentThreadRepository] = None, checkpoint_repository: Optional[CheckpointRepository] = None, - managed_checkpoints: bool = False, - project_endpoint: Optional[str] = None, ) -> "AgentFrameworkWorkflowAdapter": """ Create an Agent Framework Workflow Adapter. @@ -68,13 +66,9 @@ def from_agent_framework( :param thread_repository: Optional thread repository for agent thread management. :type thread_repository: Optional[AgentThreadRepository] :param checkpoint_repository: Optional checkpoint repository for workflow checkpointing. + Use ``InMemoryCheckpointRepository``, ``FileCheckpointRepository``, or + ``FoundryCheckpointRepository`` for Azure AI Foundry managed storage. :type checkpoint_repository: Optional[CheckpointRepository] - :param managed_checkpoints: If True, use Azure AI Foundry managed checkpoint storage. - :type managed_checkpoints: bool - :param project_endpoint: The Azure AI Foundry project endpoint. If not provided, - will be read from AZURE_AI_PROJECT_ENDPOINT environment variable. - Example: "https://.services.ai.azure.com/api/projects/" - :type project_endpoint: Optional[str] :return: An instance of AgentFrameworkWorkflowAdapter. 
:rtype: AgentFrameworkWorkflowAdapter """ @@ -86,8 +80,6 @@ def from_agent_framework( credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] = None, thread_repository: Optional[AgentThreadRepository] = None, checkpoint_repository: Optional[CheckpointRepository] = None, - managed_checkpoints: bool = False, - project_endpoint: Optional[str] = None, ) -> "AgentFrameworkAgent": """ Create an Agent Framework Adapter from either an AgentProtocol/BaseAgent or a @@ -101,19 +93,13 @@ def from_agent_framework( :param thread_repository: Optional thread repository for agent thread management. :type thread_repository: Optional[AgentThreadRepository] :param checkpoint_repository: Optional checkpoint repository for workflow checkpointing. + Use ``InMemoryCheckpointRepository``, ``FileCheckpointRepository``, or + ``FoundryCheckpointRepository`` for Azure AI Foundry managed storage. :type checkpoint_repository: Optional[CheckpointRepository] - :param managed_checkpoints: If True, use Azure AI Foundry managed checkpoint storage. - :type managed_checkpoints: bool - :param project_endpoint: The Azure AI Foundry project endpoint. If not provided, - will be read from AZURE_AI_PROJECT_ENDPOINT environment variable. - Example: "https://.services.ai.azure.com/api/projects/" - :type project_endpoint: Optional[str] :return: An instance of AgentFrameworkAgent. :rtype: AgentFrameworkAgent :raises TypeError: If neither or both of agent and workflow are provided, or if the provided types are incorrect. - :raises ValueError: If managed_checkpoints=True but required parameters are missing, - or if both managed_checkpoints=True and checkpoint_repository are provided. 
""" if isinstance(agent_or_workflow, WorkflowBuilder): @@ -122,8 +108,6 @@ def from_agent_framework( credentials=credentials, thread_repository=thread_repository, checkpoint_repository=checkpoint_repository, - managed_checkpoints=managed_checkpoints, - project_endpoint=project_endpoint, ) if isinstance(agent_or_workflow, Callable): # type: ignore return AgentFrameworkWorkflowAdapter( @@ -131,8 +115,6 @@ def from_agent_framework( credentials=credentials, thread_repository=thread_repository, checkpoint_repository=checkpoint_repository, - managed_checkpoints=managed_checkpoints, - project_endpoint=project_endpoint, ) # raise TypeError("workflow must be a WorkflowBuilder or callable returning a Workflow") diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py index d5cf61ad62ae..a119d697a377 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py @@ -16,7 +16,7 @@ from azure.core.credentials_async import AsyncTokenCredential from azure.ai.agentserver.core import AgentRunContext -from azure.ai.agentserver.core.logger import get_logger, get_project_endpoint +from azure.ai.agentserver.core.logger import get_logger from azure.ai.agentserver.core.models import ( Response as OpenAIResponse, ResponseStreamEvent, @@ -28,7 +28,7 @@ from .models.agent_framework_output_non_streaming_converter import ( AgentFrameworkOutputNonStreamingConverter, ) -from .persistence import AgentThreadRepository, CheckpointRepository, FoundryCheckpointRepository +from .persistence import AgentThreadRepository, CheckpointRepository logger = get_logger() @@ -40,35 +40,9 @@ def __init__( credentials: 
Optional[Union[AsyncTokenCredential, TokenCredential]] = None, thread_repository: Optional[AgentThreadRepository] = None, checkpoint_repository: Optional[CheckpointRepository] = None, - managed_checkpoints: bool = False, - project_endpoint: Optional[str] = None, ) -> None: super().__init__(credentials, thread_repository) self._workflow_factory = workflow_factory - - # Validate mutual exclusion of managed_checkpoints and checkpoint_repository - if managed_checkpoints and checkpoint_repository is not None: - raise ValueError( - "Cannot use both managed_checkpoints=True and checkpoint_repository. " - "Use managed_checkpoints=True for Azure AI Foundry managed storage, " - "or provide your own checkpoint_repository, but not both." - ) - - # Handle managed checkpoints - if managed_checkpoints: - resolved_endpoint = get_project_endpoint() or project_endpoint - if not resolved_endpoint: - raise ValueError( - "project_endpoint is required when managed_checkpoints=True. " - "Set AZURE_AI_PROJECT_ENDPOINT environment variable or pass project_endpoint parameter." 
- ) - if not credentials: - raise ValueError("credentials are required when managed_checkpoints=True") - checkpoint_repository = FoundryCheckpointRepository( - project_endpoint=resolved_endpoint, - credential=credentials, - ) - self._checkpoint_repository = checkpoint_repository async def agent_run( # pylint: disable=too-many-statements diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_from_agent_framework_managed.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_from_agent_framework_managed.py index a0cfbea51f6c..c1616c475cb4 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_from_agent_framework_managed.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_from_agent_framework_managed.py @@ -1,93 +1,47 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -"""Unit tests for from_agent_framework with managed checkpoints.""" +"""Unit tests for from_agent_framework with checkpoint repository.""" -import os import pytest -from unittest.mock import Mock, AsyncMock, patch +from unittest.mock import Mock from azure.core.credentials_async import AsyncTokenCredential @pytest.mark.unit -def test_managed_checkpoints_requires_project_endpoint() -> None: - """Test that managed_checkpoints=True requires project_endpoint when env var not set.""" +def test_checkpoint_repository_is_optional() -> None: + """Test that checkpoint_repository is optional and defaults to None.""" from azure.ai.agentserver.agentframework import from_agent_framework from agent_framework import WorkflowBuilder builder = WorkflowBuilder() - mock_credential = Mock(spec=AsyncTokenCredential) - # Ensure environment variable is not set - with patch.dict(os.environ, {}, clear=True): - with pytest.raises(ValueError) as exc_info: - from_agent_framework( - 
builder, - credentials=mock_credential, - managed_checkpoints=True, - project_endpoint=None, - ) + # Should not raise + adapter = from_agent_framework(builder) - assert "project_endpoint" in str(exc_info.value) + assert adapter is not None @pytest.mark.unit -def test_managed_checkpoints_requires_credentials() -> None: - """Test that managed_checkpoints=True requires credentials.""" +def test_foundry_checkpoint_repository_passed_directly() -> None: + """Test that FoundryCheckpointRepository can be passed via checkpoint_repository.""" from azure.ai.agentserver.agentframework import from_agent_framework + from azure.ai.agentserver.agentframework.persistence import FoundryCheckpointRepository from agent_framework import WorkflowBuilder builder = WorkflowBuilder() + mock_credential = Mock(spec=AsyncTokenCredential) - with pytest.raises(ValueError) as exc_info: - from_agent_framework( - builder, - credentials=None, - managed_checkpoints=True, - project_endpoint="https://test.services.ai.azure.com/api/projects/test-project", - ) - - assert "credentials" in str(exc_info.value) - - -@pytest.mark.unit -def test_managed_checkpoints_false_does_not_require_parameters() -> None: - """Test that managed_checkpoints=False does not require project_endpoint.""" - from azure.ai.agentserver.agentframework import from_agent_framework - from agent_framework import WorkflowBuilder - - builder = WorkflowBuilder() + repo = FoundryCheckpointRepository( + project_endpoint="https://test.services.ai.azure.com/api/projects/test-project", + credential=mock_credential, + ) - # Should not raise adapter = from_agent_framework( builder, - managed_checkpoints=False, + checkpoint_repository=repo, ) assert adapter is not None - - -@pytest.mark.unit -def test_managed_checkpoints_and_checkpoint_repository_are_mutually_exclusive() -> None: - """Test that managed_checkpoints=True and checkpoint_repository cannot be used together.""" - from azure.ai.agentserver.agentframework import from_agent_framework - 
from azure.ai.agentserver.agentframework.persistence import InMemoryCheckpointRepository - from agent_framework import WorkflowBuilder - - builder = WorkflowBuilder() - mock_credential = Mock(spec=AsyncTokenCredential) - checkpoint_repo = InMemoryCheckpointRepository() - - with pytest.raises(ValueError) as exc_info: - from_agent_framework( - builder, - credentials=mock_credential, - managed_checkpoints=True, - checkpoint_repository=checkpoint_repo, - project_endpoint="https://test.services.ai.azure.com/api/projects/test-project", - ) - - assert "Cannot use both" in str(exc_info.value) - assert "managed_checkpoints" in str(exc_info.value) - assert "checkpoint_repository" in str(exc_info.value) + assert adapter._checkpoint_repository is repo diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/__init__.py index f63eaa05ca0c..a3eef4358564 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/__init__.py @@ -3,7 +3,7 @@ # --------------------------------------------------------- __path__ = __import__("pkgutil").extend_path(__path__, __name__) -from typing import Optional, TYPE_CHECKING +from typing import Optional, Union, TYPE_CHECKING from azure.ai.agentserver.core.application import PackageMetadata, set_current_app @@ -12,17 +12,30 @@ from .langgraph import LangGraphAdapter if TYPE_CHECKING: # pragma: no cover + from langgraph.graph.state import CompiledStateGraph from .models.response_api_converter import ResponseAPIConverter from azure.core.credentials_async import AsyncTokenCredential + from azure.core.credentials import TokenCredential def from_langgraph( - agent, + agent: "CompiledStateGraph", /, - credentials: Optional["AsyncTokenCredential"] = None, - converter: Optional["ResponseAPIConverter"] = None 
+ credentials: Optional[Union["AsyncTokenCredential", "TokenCredential"]] = None, + converter: Optional["ResponseAPIConverter"] = None, ) -> "LangGraphAdapter": - + """Create a LangGraph adapter for Azure AI Agent Server. + + :param agent: The compiled LangGraph state graph. To use persistent checkpointing, + compile the graph with a checkpointer via ``builder.compile(checkpointer=saver)``. + :type agent: CompiledStateGraph + :param credentials: Azure credentials for authentication. + :type credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] + :param converter: Custom response converter. + :type converter: Optional[ResponseAPIConverter] + :return: A LangGraphAdapter instance. + :rtype: LangGraphAdapter + """ return LangGraphAdapter(agent, credentials=credentials, converter=converter) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/__init__.py new file mode 100644 index 000000000000..9e91582733d3 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/__init__.py @@ -0,0 +1,8 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Checkpoint saver implementations for LangGraph with Azure AI Foundry.""" + +from ._foundry_checkpoint_saver import FoundryCheckpointSaver + +__all__ = ["FoundryCheckpointSaver"] diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/_foundry_checkpoint_saver.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/_foundry_checkpoint_saver.py new file mode 100644 index 000000000000..abfeb4bf3fa8 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/_foundry_checkpoint_saver.py @@ -0,0 +1,606 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Foundry-backed checkpoint saver for LangGraph.""" + +import logging +from contextlib import AbstractAsyncContextManager +from types import TracebackType +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple, Union + +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.base import ( + BaseCheckpointSaver, + ChannelVersions, + Checkpoint, + CheckpointMetadata, + CheckpointTuple, + SerializerProtocol, + get_checkpoint_id, +) + +from azure.ai.agentserver.core.checkpoints.client import ( + CheckpointItem, + CheckpointItemId, + CheckpointSession, + FoundryCheckpointClient, +) + +from ._item_id import ItemType, ParsedItemId, make_item_id, parse_item_id + +logger = logging.getLogger(__name__) + + +class FoundryCheckpointSaver( + BaseCheckpointSaver[str], AbstractAsyncContextManager["FoundryCheckpointSaver"] +): + """Checkpoint saver backed by Azure AI Foundry checkpoint storage. 
+
+    Implements LangGraph's BaseCheckpointSaver interface using the
+    FoundryCheckpointClient for remote storage.
+
+    This saver only supports async operations. Sync methods will raise
+    NotImplementedError.
+
+    :param project_endpoint: The Azure AI Foundry project endpoint URL.
+        Example: "https://<your-resource>.services.ai.azure.com/api/projects/<your-project>"
+    :type project_endpoint: str
+    :param credential: Credential for authentication. Must be an async credential.
+    :type credential: Union[AsyncTokenCredential, TokenCredential]
+    :param serde: Optional serializer protocol. Defaults to JsonPlusSerializer.
+    :type serde: Optional[SerializerProtocol]
+
+    Example::
+
+        from azure.ai.agentserver.langgraph.checkpointer import FoundryCheckpointSaver
+        from azure.identity.aio import DefaultAzureCredential
+
+        saver = FoundryCheckpointSaver(
+            project_endpoint="https://myresource.services.ai.azure.com/api/projects/my-project",
+            credential=DefaultAzureCredential(),
+        )
+
+        # Use with LangGraph
+        graph = builder.compile(checkpointer=saver)
+    """
+
+    def __init__(
+        self,
+        project_endpoint: str,
+        credential: Union[AsyncTokenCredential, TokenCredential],
+        *,
+        serde: Optional[SerializerProtocol] = None,
+    ) -> None:
+        """Initialize the Foundry checkpoint saver.
+
+        :param project_endpoint: The Azure AI Foundry project endpoint URL.
+        :type project_endpoint: str
+        :param credential: Credential for authentication. Must be an async credential.
+        :type credential: Union[AsyncTokenCredential, TokenCredential]
+        :param serde: Optional serializer protocol.
+        :type serde: Optional[SerializerProtocol]
+        :raises TypeError: If credential is not an AsyncTokenCredential.
+        """
+        super().__init__(serde=serde)
+        if not isinstance(credential, AsyncTokenCredential):
+            raise TypeError(
+                "FoundryCheckpointSaver requires an AsyncTokenCredential. "
+                "Please use an async credential like DefaultAzureCredential from azure.identity.aio."
+ ) + self._client = FoundryCheckpointClient(project_endpoint, credential) + self._session_cache: set[str] = set() + + async def __aenter__(self) -> "FoundryCheckpointSaver": + """Enter the async context manager. + + :return: The saver instance. + :rtype: FoundryCheckpointSaver + """ + await self._client.__aenter__() + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + """Exit the async context manager. + + :param exc_type: Exception type if an exception occurred. + :param exc_val: Exception value if an exception occurred. + :param exc_tb: Exception traceback if an exception occurred. + """ + await self._client.__aexit__(exc_type, exc_val, exc_tb) + + async def _ensure_session(self, thread_id: str) -> None: + """Ensure a session exists for the thread. + + :param thread_id: The thread identifier. + :type thread_id: str + """ + if thread_id not in self._session_cache: + session = CheckpointSession(session_id=thread_id) + await self._client.upsert_session(session) + self._session_cache.add(thread_id) + + async def _get_latest_checkpoint_id( + self, thread_id: str, checkpoint_ns: str + ) -> Optional[str]: + """Find the latest checkpoint ID for a thread and namespace. + + :param thread_id: The thread identifier. + :type thread_id: str + :param checkpoint_ns: The checkpoint namespace. + :type checkpoint_ns: str + :return: The latest checkpoint ID, or None if not found. 
+ :rtype: Optional[str] + """ + item_ids = await self._client.list_item_ids(thread_id) + + # Filter to checkpoint items in this namespace + checkpoint_ids: List[str] = [] + for item_id in item_ids: + try: + parsed = parse_item_id(item_id.item_id) + if parsed.item_type == "checkpoint" and parsed.checkpoint_ns == checkpoint_ns: + checkpoint_ids.append(parsed.checkpoint_id) + except ValueError: + continue + + if not checkpoint_ids: + return None + + # Return the latest (max) checkpoint ID + return max(checkpoint_ids) + + async def _load_pending_writes( + self, thread_id: str, checkpoint_ns: str, checkpoint_id: str + ) -> List[Tuple[str, str, Any]]: + """Load pending writes for a checkpoint. + + :param thread_id: The thread identifier. + :type thread_id: str + :param checkpoint_ns: The checkpoint namespace. + :type checkpoint_ns: str + :param checkpoint_id: The checkpoint identifier. + :type checkpoint_id: str + :return: List of pending writes as (task_id, channel, value) tuples. + :rtype: List[Tuple[str, str, Any]] + """ + item_ids = await self._client.list_item_ids(thread_id) + writes: List[Tuple[str, str, Any]] = [] + + for item_id in item_ids: + try: + parsed = parse_item_id(item_id.item_id) + if ( + parsed.item_type == "writes" + and parsed.checkpoint_ns == checkpoint_ns + and parsed.checkpoint_id == checkpoint_id + ): + item = await self._client.read_item(item_id) + if item: + task_id, channel, value, _ = self.serde.loads_typed(item.data) + writes.append((task_id, channel, value)) + except (ValueError, TypeError): + continue + + return writes + + async def _load_blobs( + self, thread_id: str, checkpoint_ns: str, checkpoint_id: str, versions: ChannelVersions + ) -> Dict[str, Any]: + """Load channel blobs for a checkpoint. + + :param thread_id: The thread identifier. + :type thread_id: str + :param checkpoint_ns: The checkpoint namespace. + :type checkpoint_ns: str + :param checkpoint_id: The checkpoint identifier. 
+ :type checkpoint_id: str + :param versions: The channel versions to load. + :type versions: ChannelVersions + :return: Dictionary of channel values. + :rtype: Dict[str, Any] + """ + channel_values: Dict[str, Any] = {} + + for channel, version in versions.items(): + blob_item_id = make_item_id( + checkpoint_ns, checkpoint_id, "blob", f"{channel}:{version}" + ) + item_id = CheckpointItemId(session_id=thread_id, item_id=blob_item_id) + item = await self._client.read_item(item_id) + if item: + type_tag, data = self.serde.loads_typed(item.data) + if type_tag != "empty": + channel_values[channel] = data + + return channel_values + + # Async methods (primary implementation) + + async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Asynchronously get a checkpoint tuple by config. + + :param config: Configuration specifying which checkpoint to retrieve. + :type config: RunnableConfig + :return: The checkpoint tuple, or None if not found. + :rtype: Optional[CheckpointTuple] + """ + thread_id: str = config["configurable"]["thread_id"] + checkpoint_ns: str = config["configurable"].get("checkpoint_ns", "") + checkpoint_id = get_checkpoint_id(config) + + # Ensure session exists + await self._ensure_session(thread_id) + + # If no checkpoint_id, find the latest + if not checkpoint_id: + checkpoint_id = await self._get_latest_checkpoint_id(thread_id, checkpoint_ns) + if not checkpoint_id: + return None + + # Load the checkpoint item + item_id_str = make_item_id(checkpoint_ns, checkpoint_id, "checkpoint") + item = await self._client.read_item( + CheckpointItemId(session_id=thread_id, item_id=item_id_str) + ) + if not item: + return None + + # Deserialize checkpoint data + checkpoint_data = self.serde.loads_typed(item.data) + checkpoint: Checkpoint = checkpoint_data["checkpoint"] + metadata: CheckpointMetadata = checkpoint_data["metadata"] + + # Load channel values (blobs) + channel_values = await self._load_blobs( + thread_id, checkpoint_ns, 
checkpoint_id, checkpoint.get("channel_versions", {}) + ) + checkpoint = {**checkpoint, "channel_values": channel_values} + + # Load pending writes + pending_writes = await self._load_pending_writes(thread_id, checkpoint_ns, checkpoint_id) + + # Build parent config if parent exists + parent_config: Optional[RunnableConfig] = None + if item.parent_id: + try: + parent_parsed = parse_item_id(item.parent_id) + parent_config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": parent_parsed.checkpoint_ns, + "checkpoint_id": parent_parsed.checkpoint_id, + } + } + except ValueError: + pass + + return CheckpointTuple( + config={ + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint_id, + } + }, + checkpoint=checkpoint, + metadata=metadata, + parent_config=parent_config, + pending_writes=pending_writes, + ) + + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Asynchronously store a checkpoint. + + :param config: Configuration for the checkpoint. + :type config: RunnableConfig + :param checkpoint: The checkpoint to store. + :type checkpoint: Checkpoint + :param metadata: Additional metadata for the checkpoint. + :type metadata: CheckpointMetadata + :param new_versions: New channel versions as of this write. + :type new_versions: ChannelVersions + :return: Updated configuration with the checkpoint ID. 
+ :rtype: RunnableConfig + """ + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"].get("checkpoint_ns", "") + checkpoint_id = checkpoint["id"] + + # Ensure session exists + await self._ensure_session(thread_id) + + # Determine parent + parent_checkpoint_id = config["configurable"].get("checkpoint_id") + parent_item_id: Optional[str] = None + if parent_checkpoint_id: + parent_item_id = make_item_id(checkpoint_ns, parent_checkpoint_id, "checkpoint") + + # Prepare checkpoint data (without channel_values - stored as blobs) + checkpoint_copy = checkpoint.copy() + channel_values: Dict[str, Any] = checkpoint_copy.pop("channel_values", {}) # type: ignore[misc] + + checkpoint_data = self.serde.dumps_typed({ + "checkpoint": checkpoint_copy, + "metadata": metadata, + }) + + # Create checkpoint item + item_id_str = make_item_id(checkpoint_ns, checkpoint_id, "checkpoint") + items: List[CheckpointItem] = [ + CheckpointItem( + session_id=thread_id, + item_id=item_id_str, + data=checkpoint_data, + parent_id=parent_item_id, + ) + ] + + # Create blob items for channel values with new versions + for channel, version in new_versions.items(): + if channel in channel_values: + blob_data = self.serde.dumps_typed(channel_values[channel]) + else: + blob_data = self.serde.dumps_typed(("empty", b"")) + + blob_item_id = make_item_id( + checkpoint_ns, checkpoint_id, "blob", f"{channel}:{version}" + ) + items.append( + CheckpointItem( + session_id=thread_id, + item_id=blob_item_id, + data=blob_data, + parent_id=item_id_str, + ) + ) + + await self._client.create_items(items) + + logger.debug( + "Saved checkpoint %s to Foundry session %s", + checkpoint_id, + thread_id, + ) + + return { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint_id, + } + } + + async def aput_writes( + self, + config: RunnableConfig, + writes: Sequence[Tuple[str, Any]], + task_id: str, + task_path: str = "", + ) -> None: + 
"""Asynchronously store intermediate writes for a checkpoint. + + :param config: Configuration of the related checkpoint. + :type config: RunnableConfig + :param writes: List of writes to store as (channel, value) pairs. + :type writes: Sequence[Tuple[str, Any]] + :param task_id: Identifier for the task creating the writes. + :type task_id: str + :param task_path: Path of the task creating the writes. + :type task_path: str + """ + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"].get("checkpoint_ns", "") + checkpoint_id = config["configurable"]["checkpoint_id"] + + checkpoint_item_id = make_item_id(checkpoint_ns, checkpoint_id, "checkpoint") + + items: List[CheckpointItem] = [] + for idx, (channel, value) in enumerate(writes): + write_data = self.serde.dumps_typed((task_id, channel, value, task_path)) + write_item_id = make_item_id( + checkpoint_ns, checkpoint_id, "writes", f"{task_id}:{idx}" + ) + items.append( + CheckpointItem( + session_id=thread_id, + item_id=write_item_id, + data=write_data, + parent_id=checkpoint_item_id, + ) + ) + + if items: + await self._client.create_items(items) + logger.debug( + "Saved %d writes for checkpoint %s", + len(items), + checkpoint_id, + ) + + async def alist( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[Dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> AsyncIterator[CheckpointTuple]: + """Asynchronously list checkpoints matching filter criteria. + + :param config: Base configuration for filtering checkpoints. + :type config: Optional[RunnableConfig] + :param filter: Additional filtering criteria for metadata. + :type filter: Optional[Dict[str, Any]] + :param before: List checkpoints created before this configuration. + :type before: Optional[RunnableConfig] + :param limit: Maximum number of checkpoints to return. + :type limit: Optional[int] + :return: Async iterator of matching checkpoint tuples. 
+ :rtype: AsyncIterator[CheckpointTuple] + """ + if not config: + return + + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"].get("checkpoint_ns") + + # Get all items for this session + item_ids = await self._client.list_item_ids(thread_id) + + # Filter to checkpoint items only + checkpoint_items: List[Tuple[ParsedItemId, CheckpointItemId]] = [] + for item_id in item_ids: + try: + parsed = parse_item_id(item_id.item_id) + if parsed.item_type == "checkpoint": + # Filter by namespace if specified + if checkpoint_ns is None or parsed.checkpoint_ns == checkpoint_ns: + checkpoint_items.append((parsed, item_id)) + except ValueError: + continue + + # Sort by checkpoint_id in reverse order (newest first) + checkpoint_items.sort(key=lambda x: x[0].checkpoint_id, reverse=True) + + # Apply before cursor + if before: + before_id = get_checkpoint_id(before) + if before_id: + checkpoint_items = [ + (p, i) for p, i in checkpoint_items if p.checkpoint_id < before_id + ] + + # Apply limit + if limit: + checkpoint_items = checkpoint_items[:limit] + + # Load and yield each checkpoint + for parsed, _ in checkpoint_items: + tuple_config: RunnableConfig = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": parsed.checkpoint_ns, + "checkpoint_id": parsed.checkpoint_id, + } + } + checkpoint_tuple = await self.aget_tuple(tuple_config) + if checkpoint_tuple: + # Apply metadata filter if provided + if filter: + if not all( + checkpoint_tuple.metadata.get(k) == v for k, v in filter.items() + ): + continue + yield checkpoint_tuple + + async def adelete_thread(self, thread_id: str) -> None: + """Delete all checkpoints and writes for a thread. + + :param thread_id: The thread ID whose checkpoints should be deleted. 
+ :type thread_id: str + """ + await self._client.delete_session(thread_id) + self._session_cache.discard(thread_id) + logger.debug("Deleted session %s", thread_id) + + # Sync methods (raise NotImplementedError) + + def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Sync version not supported - use aget_tuple instead. + + :raises NotImplementedError: Always raised. + """ + raise NotImplementedError( + "FoundryCheckpointSaver requires async usage. Use aget_tuple() instead." + ) + + def list( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[Dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> Iterator[CheckpointTuple]: + """Sync version not supported - use alist instead. + + :raises NotImplementedError: Always raised. + """ + raise NotImplementedError( + "FoundryCheckpointSaver requires async usage. Use alist() instead." + ) + + def put( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Sync version not supported - use aput instead. + + :raises NotImplementedError: Always raised. + """ + raise NotImplementedError( + "FoundryCheckpointSaver requires async usage. Use aput() instead." + ) + + def put_writes( + self, + config: RunnableConfig, + writes: Sequence[Tuple[str, Any]], + task_id: str, + task_path: str = "", + ) -> None: + """Sync version not supported - use aput_writes instead. + + :raises NotImplementedError: Always raised. + """ + raise NotImplementedError( + "FoundryCheckpointSaver requires async usage. Use aput_writes() instead." + ) + + def delete_thread(self, thread_id: str) -> None: + """Sync version not supported - use adelete_thread instead. + + :raises NotImplementedError: Always raised. + """ + raise NotImplementedError( + "FoundryCheckpointSaver requires async usage. Use adelete_thread() instead." 
+ ) + + def get_next_version(self, current: Optional[str], channel: None) -> str: + """Generate the next version ID for a channel. + + Uses string versions with format "{counter}.{random}". + + :param current: The current version identifier. + :type current: Optional[str] + :param channel: Deprecated argument, kept for backwards compatibility. + :return: The next version identifier. + :rtype: str + """ + import random as rand + + if current is None: + current_v = 0 + elif isinstance(current, int): + current_v = current + else: + current_v = int(current.split(".")[0]) + next_v = current_v + 1 + next_h = rand.random() + return f"{next_v:032}.{next_h:016}" diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/_item_id.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/_item_id.py new file mode 100644 index 000000000000..8758181ce2f2 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/_item_id.py @@ -0,0 +1,96 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Item ID utilities for composite checkpoint item identifiers.""" + +from dataclasses import dataclass +from typing import Literal + +ItemType = Literal["checkpoint", "writes", "blob"] + + +@dataclass +class ParsedItemId: + """Parsed components of a checkpoint item ID. + + :ivar checkpoint_ns: The checkpoint namespace. + :ivar checkpoint_id: The checkpoint identifier. + :ivar item_type: The type of item (checkpoint, writes, or blob). + :ivar sub_key: Additional key for writes or blobs. + """ + + checkpoint_ns: str + checkpoint_id: str + item_type: ItemType + sub_key: str + + +def _encode(s: str) -> str: + """URL-safe encode a string (escape colons and percent signs). + + :param s: The string to encode. 
+ :type s: str + :return: The encoded string. + :rtype: str + """ + return s.replace("%", "%25").replace(":", "%3A") + + +def _decode(s: str) -> str: + """Decode a URL-safe encoded string. + + :param s: The encoded string. + :type s: str + :return: The decoded string. + :rtype: str + """ + return s.replace("%3A", ":").replace("%25", "%") + + +def make_item_id( + checkpoint_ns: str, + checkpoint_id: str, + item_type: ItemType, + sub_key: str = "", +) -> str: + """Create a composite item ID. + + Format: {checkpoint_ns}:{checkpoint_id}:{type}:{sub_key} + + :param checkpoint_ns: The checkpoint namespace. + :type checkpoint_ns: str + :param checkpoint_id: The checkpoint identifier. + :type checkpoint_id: str + :param item_type: The type of item (checkpoint, writes, or blob). + :type item_type: ItemType + :param sub_key: Additional key for writes or blobs. + :type sub_key: str + :return: The composite item ID. + :rtype: str + """ + return f"{_encode(checkpoint_ns)}:{_encode(checkpoint_id)}:{item_type}:{_encode(sub_key)}" + + +def parse_item_id(item_id: str) -> ParsedItemId: + """Parse a composite item ID back to components. + + :param item_id: The composite item ID to parse. + :type item_id: str + :return: The parsed item ID components. + :rtype: ParsedItemId + :raises ValueError: If the item ID format is invalid. 
+ """ + parts = item_id.split(":", 3) + if len(parts) != 4: + raise ValueError(f"Invalid item_id format: {item_id}") + + item_type = parts[2] + if item_type not in ("checkpoint", "writes", "blob"): + raise ValueError(f"Invalid item_type in item_id: {item_type}") + + return ParsedItemId( + checkpoint_ns=_decode(parts[0]), + checkpoint_id=_decode(parts[1]), + item_type=item_type, # type: ignore[arg-type] + sub_key=_decode(parts[3]), + ) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/__init__.py new file mode 100644 index 000000000000..315126869940 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/__init__.py @@ -0,0 +1,4 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for the checkpointer module.""" diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/test_foundry_checkpoint_saver.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/test_foundry_checkpoint_saver.py new file mode 100644 index 000000000000..d25a34f6289d --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/test_foundry_checkpoint_saver.py @@ -0,0 +1,441 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
+# --------------------------------------------------------- +"""Unit tests for FoundryCheckpointSaver.""" + +import pytest +from unittest.mock import Mock + +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential + +from azure.ai.agentserver.langgraph.checkpointer import FoundryCheckpointSaver +from azure.ai.agentserver.langgraph.checkpointer._foundry_checkpoint_saver import ( + BaseCheckpointSaver, +) +from azure.ai.agentserver.langgraph.checkpointer._item_id import make_item_id + +from ..mocks import MockFoundryCheckpointClient + + +class TestableFoundryCheckpointSaver(FoundryCheckpointSaver): + """Testable version that accepts a mock client directly (bypasses credential check).""" + + def __init__(self, client: MockFoundryCheckpointClient) -> None: + """Initialize with a mock client.""" + # Skip FoundryCheckpointSaver.__init__ and call BaseCheckpointSaver directly + BaseCheckpointSaver.__init__(self, serde=None) + self._client = client # type: ignore[assignment] + self._session_cache: set[str] = set() + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aget_tuple_returns_none_for_missing_checkpoint() -> None: + """Test that aget_tuple returns None when checkpoint doesn't exist.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + config = {"configurable": {"thread_id": "thread-1"}} + result = await saver.aget_tuple(config) + + assert result is None + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aput_creates_checkpoint_item() -> None: + """Test that aput creates a checkpoint item.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + checkpoint = { + "id": "cp-001", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + metadata = {"source": "test"} + + result = 
await saver.aput(config, checkpoint, metadata, {}) + + assert result["configurable"]["checkpoint_id"] == "cp-001" + assert result["configurable"]["thread_id"] == "thread-1" + + # Verify item was created + item_ids = await client.list_item_ids("thread-1") + assert len(item_ids) == 1 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aput_creates_blob_items_for_new_versions() -> None: + """Test that aput creates blob items for channel values.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + checkpoint = { + "id": "cp-001", + "channel_values": {"messages": ["hello", "world"]}, + "channel_versions": {"messages": "1"}, + "versions_seen": {}, + "pending_sends": [], + } + metadata = {"source": "test"} + new_versions = {"messages": "1"} + + await saver.aput(config, checkpoint, metadata, new_versions) + + # Should have 2 items: checkpoint + 1 blob + item_ids = await client.list_item_ids("thread-1") + assert len(item_ids) == 2 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aput_returns_config_with_checkpoint_id() -> None: + """Test that aput returns config with the correct checkpoint ID.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "ns1"}} + checkpoint = { + "id": "my-checkpoint-id", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + + result = await saver.aput(config, checkpoint, {}, {}) + + assert result["configurable"]["checkpoint_id"] == "my-checkpoint-id" + assert result["configurable"]["thread_id"] == "thread-1" + assert result["configurable"]["checkpoint_ns"] == "ns1" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aget_tuple_returns_checkpoint_with_data() -> None: + """Test that aget_tuple returns checkpoint data correctly.""" 
+ client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + # Save a checkpoint first + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + checkpoint = { + "id": "cp-001", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + metadata = {"source": "test", "step": 1} + + await saver.aput(config, checkpoint, metadata, {}) + + # Now retrieve it + get_config = {"configurable": {"thread_id": "thread-1", "checkpoint_id": "cp-001"}} + result = await saver.aget_tuple(get_config) + + assert result is not None + assert result.checkpoint["id"] == "cp-001" + assert result.metadata["source"] == "test" + assert result.metadata["step"] == 1 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aget_tuple_returns_latest_without_checkpoint_id() -> None: + """Test that aget_tuple returns the latest checkpoint when no ID specified.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + # Save multiple checkpoints + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + + for i in range(3): + checkpoint = { + "id": f"cp-00{i}", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + config = await saver.aput(config, checkpoint, {"step": i}, {}) + + # Retrieve without specifying checkpoint_id + get_config = {"configurable": {"thread_id": "thread-1"}} + result = await saver.aget_tuple(get_config) + + assert result is not None + # Should get the latest (max checkpoint_id) + assert result.checkpoint["id"] == "cp-002" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aput_writes_creates_write_items() -> None: + """Test that aput_writes creates write items.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + # First create a checkpoint + config = {"configurable": {"thread_id": "thread-1", 
"checkpoint_ns": ""}} + checkpoint = { + "id": "cp-001", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + config = await saver.aput(config, checkpoint, {}, {}) + + # Now add writes + writes = [("channel1", "value1"), ("channel2", "value2")] + await saver.aput_writes(config, writes, task_id="task-1") + + # Should have 3 items: checkpoint + 2 writes + item_ids = await client.list_item_ids("thread-1") + assert len(item_ids) == 3 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aget_tuple_returns_pending_writes() -> None: + """Test that aget_tuple includes pending writes.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + # Create checkpoint and add writes + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + checkpoint = { + "id": "cp-001", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + config = await saver.aput(config, checkpoint, {}, {}) + + writes = [("channel1", "value1")] + await saver.aput_writes(config, writes, task_id="task-1") + + # Retrieve and check pending writes + result = await saver.aget_tuple(config) + + assert result is not None + assert result.pending_writes is not None + assert len(result.pending_writes) == 1 + assert result.pending_writes[0][1] == "channel1" + assert result.pending_writes[0][2] == "value1" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_alist_returns_checkpoints_in_order() -> None: + """Test that alist returns checkpoints in reverse order.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + # Save multiple checkpoints + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + for i in range(3): + checkpoint = { + "id": f"cp-00{i}", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + config = await 
saver.aput(config, checkpoint, {"step": i}, {}) + + # List checkpoints + list_config = {"configurable": {"thread_id": "thread-1"}} + results = [] + async for cp in saver.alist(list_config): + results.append(cp) + + assert len(results) == 3 + # Should be in reverse order (newest first) + assert results[0].checkpoint["id"] == "cp-002" + assert results[1].checkpoint["id"] == "cp-001" + assert results[2].checkpoint["id"] == "cp-000" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_alist_filters_by_namespace() -> None: + """Test that alist filters by checkpoint namespace.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + # Save checkpoints in different namespaces + for ns in ["ns1", "ns2"]: + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ns}} + checkpoint = { + "id": f"cp-{ns}", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + await saver.aput(config, checkpoint, {}, {}) + + # List only ns1 + list_config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "ns1"}} + results = [] + async for cp in saver.alist(list_config): + results.append(cp) + + assert len(results) == 1 + assert results[0].checkpoint["id"] == "cp-ns1" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_alist_applies_limit() -> None: + """Test that alist respects the limit parameter.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + # Save multiple checkpoints + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + for i in range(5): + checkpoint = { + "id": f"cp-00{i}", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + config = await saver.aput(config, checkpoint, {}, {}) + + # List with limit + list_config = {"configurable": {"thread_id": "thread-1"}} + results = [] + async for cp in saver.alist(list_config, limit=2): 
+ results.append(cp) + + assert len(results) == 2 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_adelete_thread_deletes_session() -> None: + """Test that adelete_thread removes all checkpoints for a thread.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + # Create a checkpoint + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + checkpoint = { + "id": "cp-001", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + await saver.aput(config, checkpoint, {}, {}) + + # Verify it exists + item_ids = await client.list_item_ids("thread-1") + assert len(item_ids) == 1 + + # Delete the thread + await saver.adelete_thread("thread-1") + + # Verify it's gone + item_ids = await client.list_item_ids("thread-1") + assert len(item_ids) == 0 + + +@pytest.mark.unit +def test_sync_methods_raise_not_implemented() -> None: + """Test that sync methods raise NotImplementedError.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + config = {"configurable": {"thread_id": "thread-1"}} + checkpoint = {"id": "cp-001", "channel_values": {}, "channel_versions": {}} + + with pytest.raises(NotImplementedError, match="aget_tuple"): + saver.get_tuple(config) + + with pytest.raises(NotImplementedError, match="aput"): + saver.put(config, checkpoint, {}, {}) + + with pytest.raises(NotImplementedError, match="aput_writes"): + saver.put_writes(config, [], "task-1") + + with pytest.raises(NotImplementedError, match="alist"): + list(saver.list(config)) + + with pytest.raises(NotImplementedError, match="adelete_thread"): + saver.delete_thread("thread-1") + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aget_tuple_returns_parent_config() -> None: + """Test that aget_tuple includes parent config when checkpoint has parent.""" + client = MockFoundryCheckpointClient() + saver = 
TestableFoundryCheckpointSaver(client=client) + + # Save parent checkpoint + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + parent_checkpoint = { + "id": "cp-001", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + config = await saver.aput(config, parent_checkpoint, {}, {}) + + # Save child checkpoint + child_checkpoint = { + "id": "cp-002", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + await saver.aput(config, child_checkpoint, {}, {}) + + # Retrieve child + get_config = {"configurable": {"thread_id": "thread-1", "checkpoint_id": "cp-002"}} + result = await saver.aget_tuple(get_config) + + assert result is not None + assert result.parent_config is not None + assert result.parent_config["configurable"]["checkpoint_id"] == "cp-001" + + +@pytest.mark.unit +def test_constructor_requires_async_credential() -> None: + """Test that FoundryCheckpointSaver raises TypeError for sync credentials.""" + mock_credential = Mock(spec=TokenCredential) + + with pytest.raises(TypeError, match="AsyncTokenCredential"): + FoundryCheckpointSaver( + project_endpoint="https://test.services.ai.azure.com/api/projects/test", + credential=mock_credential, + ) + + +@pytest.mark.unit +def test_constructor_accepts_async_credential() -> None: + """Test that FoundryCheckpointSaver accepts AsyncTokenCredential.""" + mock_credential = Mock(spec=AsyncTokenCredential) + + saver = FoundryCheckpointSaver( + project_endpoint="https://test.services.ai.azure.com/api/projects/test", + credential=mock_credential, + ) + + assert saver is not None + assert saver._client is not None diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/test_item_id.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/test_item_id.py new file mode 100644 index 000000000000..ddb11ffa62f0 --- /dev/null +++ 
b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/test_item_id.py @@ -0,0 +1,125 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for item ID utilities.""" + +import pytest + +from azure.ai.agentserver.langgraph.checkpointer._item_id import ( + ParsedItemId, + make_item_id, + parse_item_id, +) + + +@pytest.mark.unit +def test_make_item_id_formats_correctly() -> None: + """Test that make_item_id creates correct composite IDs.""" + item_id = make_item_id("ns1", "cp-001", "checkpoint") + assert item_id == "ns1:cp-001:checkpoint:" + + +@pytest.mark.unit +def test_make_item_id_with_sub_key() -> None: + """Test that make_item_id includes sub_key correctly.""" + item_id = make_item_id("ns1", "cp-001", "writes", "task1:0") + assert item_id == "ns1:cp-001:writes:task1%3A0" + + +@pytest.mark.unit +def test_make_item_id_with_blob() -> None: + """Test blob item ID format.""" + item_id = make_item_id("", "cp-001", "blob", "messages:v2") + assert item_id == ":cp-001:blob:messages%3Av2" + + +@pytest.mark.unit +def test_parse_item_id_extracts_components() -> None: + """Test that parse_item_id extracts all components correctly.""" + item_id = "ns1:cp-001:checkpoint:" + parsed = parse_item_id(item_id) + + assert parsed.checkpoint_ns == "ns1" + assert parsed.checkpoint_id == "cp-001" + assert parsed.item_type == "checkpoint" + assert parsed.sub_key == "" + + +@pytest.mark.unit +def test_parse_item_id_extracts_sub_key() -> None: + """Test that parse_item_id extracts sub_key correctly.""" + item_id = "ns1:cp-001:writes:task1%3A0" + parsed = parse_item_id(item_id) + + assert parsed.checkpoint_ns == "ns1" + assert parsed.checkpoint_id == "cp-001" + assert parsed.item_type == "writes" + assert parsed.sub_key == "task1:0" + + +@pytest.mark.unit +def test_roundtrip_simple() -> None: + """Test roundtrip 
encoding/decoding of simple IDs.""" + original_ns = "namespace" + original_id = "checkpoint-123" + original_type = "checkpoint" + original_key = "" + + item_id = make_item_id(original_ns, original_id, original_type, original_key) + parsed = parse_item_id(item_id) + + assert parsed.checkpoint_ns == original_ns + assert parsed.checkpoint_id == original_id + assert parsed.item_type == original_type + assert parsed.sub_key == original_key + + +@pytest.mark.unit +def test_roundtrip_with_special_characters() -> None: + """Test roundtrip encoding/decoding with special characters (colons).""" + original_ns = "ns:with:colons" + original_id = "cp:123:abc" + original_type = "blob" + original_key = "channel:v1:extra" + + item_id = make_item_id(original_ns, original_id, original_type, original_key) + parsed = parse_item_id(item_id) + + assert parsed.checkpoint_ns == original_ns + assert parsed.checkpoint_id == original_id + assert parsed.item_type == original_type + assert parsed.sub_key == original_key + + +@pytest.mark.unit +def test_roundtrip_with_percent_signs() -> None: + """Test roundtrip encoding/decoding with percent signs.""" + original_ns = "ns%test" + original_id = "cp%123" + original_type = "checkpoint" + original_key = "key%value" + + item_id = make_item_id(original_ns, original_id, original_type, original_key) + parsed = parse_item_id(item_id) + + assert parsed.checkpoint_ns == original_ns + assert parsed.checkpoint_id == original_id + assert parsed.item_type == original_type + assert parsed.sub_key == original_key + + +@pytest.mark.unit +def test_parse_item_id_raises_on_invalid_format() -> None: + """Test that parse_item_id raises ValueError for invalid format.""" + with pytest.raises(ValueError, match="Invalid item_id format"): + parse_item_id("invalid:format") + + with pytest.raises(ValueError, match="Invalid item_id format"): + parse_item_id("only:two:parts") + + +@pytest.mark.unit +def test_parse_item_id_raises_on_invalid_type() -> None: + """Test that 
parse_item_id raises ValueError for invalid item type.""" + with pytest.raises(ValueError, match="Invalid item_type"): + parse_item_id("ns:cp:invalid:key") diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/mocks/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/mocks/__init__.py new file mode 100644 index 000000000000..4436d04866df --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/mocks/__init__.py @@ -0,0 +1,8 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Mock implementations for testing.""" + +from .mock_checkpoint_client import MockFoundryCheckpointClient + +__all__ = ["MockFoundryCheckpointClient"] diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/mocks/mock_checkpoint_client.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/mocks/mock_checkpoint_client.py new file mode 100644 index 000000000000..ffc1e2fcc4c1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/mocks/mock_checkpoint_client.py @@ -0,0 +1,156 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Mock implementation of FoundryCheckpointClient for testing.""" + +from typing import Any, Dict, List, Optional + +from azure.ai.agentserver.core.checkpoints.client import ( + CheckpointItem, + CheckpointItemId, + CheckpointSession, +) + + +class MockFoundryCheckpointClient: + """In-memory mock for FoundryCheckpointClient for unit testing. + + Stores checkpoints in memory without making any HTTP calls. + """ + + def __init__(self, endpoint: str = "https://mock.endpoint") -> None: + """Initialize the mock client. + + :param endpoint: The mock endpoint URL. 
+ :type endpoint: str + """ + self._endpoint = endpoint + self._sessions: Dict[str, CheckpointSession] = {} + self._items: Dict[str, CheckpointItem] = {} + + def _item_key(self, item_id: CheckpointItemId) -> str: + """Generate a unique key for a checkpoint item. + + :param item_id: The checkpoint item identifier. + :type item_id: CheckpointItemId + :return: The unique key. + :rtype: str + """ + return f"{item_id.session_id}:{item_id.item_id}" + + # Session operations + + async def upsert_session(self, session: CheckpointSession) -> CheckpointSession: + """Create or update a checkpoint session. + + :param session: The checkpoint session to upsert. + :type session: CheckpointSession + :return: The upserted checkpoint session. + :rtype: CheckpointSession + """ + self._sessions[session.session_id] = session + return session + + async def read_session(self, session_id: str) -> Optional[CheckpointSession]: + """Read a checkpoint session by ID. + + :param session_id: The session identifier. + :type session_id: str + :return: The checkpoint session if found, None otherwise. + :rtype: Optional[CheckpointSession] + """ + return self._sessions.get(session_id) + + async def delete_session(self, session_id: str) -> None: + """Delete a checkpoint session. + + :param session_id: The session identifier. + :type session_id: str + """ + self._sessions.pop(session_id, None) + # Also delete all items in the session + keys_to_delete = [ + key for key, item in self._items.items() if item.session_id == session_id + ] + for key in keys_to_delete: + del self._items[key] + + # Item operations + + async def create_items(self, items: List[CheckpointItem]) -> List[CheckpointItem]: + """Create checkpoint items in batch. + + :param items: The checkpoint items to create. + :type items: List[CheckpointItem] + :return: The created checkpoint items. 
+ :rtype: List[CheckpointItem] + """ + for item in items: + key = self._item_key(item.to_item_id()) + self._items[key] = item + return items + + async def read_item(self, item_id: CheckpointItemId) -> Optional[CheckpointItem]: + """Read a checkpoint item by ID. + + :param item_id: The checkpoint item identifier. + :type item_id: CheckpointItemId + :return: The checkpoint item if found, None otherwise. + :rtype: Optional[CheckpointItem] + """ + key = self._item_key(item_id) + return self._items.get(key) + + async def delete_item(self, item_id: CheckpointItemId) -> bool: + """Delete a checkpoint item. + + :param item_id: The checkpoint item identifier. + :type item_id: CheckpointItemId + :return: True if the item was deleted, False if not found. + :rtype: bool + """ + key = self._item_key(item_id) + if key in self._items: + del self._items[key] + return True + return False + + async def list_item_ids( + self, session_id: str, parent_id: Optional[str] = None + ) -> List[CheckpointItemId]: + """List checkpoint item IDs for a session. + + :param session_id: The session identifier. + :type session_id: str + :param parent_id: Optional parent item identifier for filtering. + :type parent_id: Optional[str] + :return: List of checkpoint item identifiers. + :rtype: List[CheckpointItemId] + """ + result = [] + for item in self._items.values(): + if item.session_id == session_id: + if parent_id is None or item.parent_id == parent_id: + result.append(item.to_item_id()) + return result + + # Context manager methods + + async def close(self) -> None: + """Close the client (no-op for mock).""" + pass + + async def __aenter__(self) -> "MockFoundryCheckpointClient": + """Enter the async context manager. + + :return: The client instance. + :rtype: MockFoundryCheckpointClient + """ + return self + + async def __aexit__(self, *exc_details: Any) -> None: + """Exit the async context manager. + + :param exc_details: Exception details if an exception occurred. 
+ """ + pass diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_from_langgraph_managed.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_from_langgraph_managed.py new file mode 100644 index 000000000000..5a8b5cf2a1f4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_from_langgraph_managed.py @@ -0,0 +1,72 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for from_langgraph with checkpointer via compile.""" + +import pytest +from unittest.mock import Mock + +from azure.core.credentials_async import AsyncTokenCredential + + +@pytest.mark.unit +def test_from_langgraph_basic() -> None: + """Test that from_langgraph works without a checkpointer.""" + from azure.ai.agentserver.langgraph import from_langgraph + from langgraph.graph import StateGraph + from typing_extensions import TypedDict + + class State(TypedDict): + messages: list + + builder = StateGraph(State) + builder.add_node("node1", lambda x: x) + builder.set_entry_point("node1") + builder.set_finish_point("node1") + graph = builder.compile() + + adapter = from_langgraph(graph) + + assert adapter is not None + + +@pytest.mark.unit +def test_graph_with_foundry_checkpointer_via_compile() -> None: + """Test that FoundryCheckpointSaver can be set via builder.compile().""" + from azure.ai.agentserver.langgraph import from_langgraph + from azure.ai.agentserver.langgraph.checkpointer import FoundryCheckpointSaver + from langgraph.graph import StateGraph + from typing_extensions import TypedDict + + class State(TypedDict): + messages: list + + builder = StateGraph(State) + builder.add_node("node1", lambda x: x) + builder.add_node("node2", lambda x: x) + builder.add_edge("node1", "node2") + builder.set_entry_point("node1") + builder.set_finish_point("node2") + + mock_credential = 
Mock(spec=AsyncTokenCredential) + saver = FoundryCheckpointSaver( + project_endpoint="https://test.services.ai.azure.com/api/projects/test-project", + credential=mock_credential, + ) + + # User sets checkpointer via LangGraph's native compile() + graph = builder.compile( + checkpointer=saver, + interrupt_before=["node1"], + interrupt_after=["node2"], + debug=True, + ) + + adapter = from_langgraph(graph) + + # Verify checkpointer and compile parameters are preserved + assert adapter is not None + assert isinstance(adapter._graph.checkpointer, FoundryCheckpointSaver) + assert adapter._graph.interrupt_before_nodes == ["node1"] + assert adapter._graph.interrupt_after_nodes == ["node2"] + assert adapter._graph.debug is True