diff --git a/sdk/agentserver/AGENTS.md b/sdk/agentserver/AGENTS.md new file mode 100644 index 000000000000..5a5f20dc457c --- /dev/null +++ b/sdk/agentserver/AGENTS.md @@ -0,0 +1,200 @@ +# AGENTS.md + +This file provides comprehensive guidance to Coding Agents(Codex, Claude Code, GitHub Copilot, etc.) when working with Python code in this repository. + + +## 🎯 (Read-first) Project Awareness & Context + +- **Always read `PLANNING.md`** at the start of a new conversation to understand the project's architecture, goals, style, and constraints. +- **Check `TASK.md`** before starting a new task. If the task isn't listed, add it with a brief description and today's date. + +## πŸ† Core Development Philosophy + +### KISS (Keep It Simple, Stupid) + +Simplicity should be a key goal in design. Choose straightforward solutions over complex ones whenever possible. Simple solutions are easier to understand, maintain, and debug. + +### YAGNI (You Aren't Gonna Need It) + +Avoid building functionality on speculation. Implement features only when they are needed, not when you anticipate they might be useful in the future. + +### Design Principles + +- **Dependency Inversion**: High-level modules should not depend on low-level modules. Both should depend on abstractions. +- **Open/Closed Principle**: Software entities should be open for extension but closed for modification. +- **Single Responsibility**: Each function, class, and module should have one clear purpose. +- **Fail Fast**: Check for potential errors early and raise exceptions immediately when issues occur. +- **Encapsulation**: Hide internal state and require all interaction to be performed through an object's methods. + +### Implementation Patterns + +- **Type-Safe**: Prefer explicit, type‑safe structures (TypedDicts, dataclasses, enums, unions) over Dict, Any, or other untyped containers, unless there is no viable alternative. +- **Explicit validation at boundaries**: Validate inputs and identifiers early; reject malformed descriptors and missing required fields. +- **Separation of resolution and execution**: Keep discovery/selection separate from invocation to allow late binding and interchangeable implementations. +- **Context-aware execution**: Thread request/user context through calls via scoped providers; avoid global mutable state. +- **Cache with safety**: Use bounded TTL caching with concurrency-safe in-flight de-duplication; invalidate on errors. +- **Stable naming and deterministic mapping**: Derive stable, unique names deterministically to avoid collisions. +- **Graceful defaults, loud misconfiguration**: Provide sensible defaults when optional data is missing; raise clear errors when required configuration is absent. +- **Thin integration layers**: Use adapters/middleware to translate between layers without leaking internals. + +## βœ… Work Process (required) + +- **Before coding**: confirm the task in `TASK.md` β†’ **Now**. +- **While working**: Add new sub-tasks or TODOs discovered during development to `TASK.md` under a "Discovered During Work" section. +- **After finishing**: mark the task done immediately, and note what changed (files/areas) in `TASK.md`. +- **Update CHANGELOG.md only when required** by release policy. + +### TASK.md template +```markdown +## Now (active) +- [ ] YYYY-MM-DD β€” + - Scope: + - Exit criteria: + +## Next (queued) +- [ ] YYYY-MM-DD β€” + +## Discovered During Work +- [ ] YYYY-MM-DD β€” + +## Done +- [x] YYYY-MM-DD β€” +``` + +## πŸ“Ž Style & Conventions & Standards + +### File and Function Limits + +- **Never create a file longer than 500 lines of code**. If approaching this limit, refactor by splitting into modules. +- **Functions should be under 50 lines** with a single, clear responsibility. +- **Classes should be under 100 lines** and represent a single concept or entity. +- **Organize code into clearly separated modules**, grouped by feature or responsibility. +- **Line length should be max 120 characters**, as enforced by Ruff in `pyproject.toml`. +- **Use the standard repo workflow**: from the package root, run tests and linters via `tox` (for example, `tox -e pytest` or `tox -e pylint`) rather than relying on a custom virtual environment name. +- **Keep modules focused and cohesive**; split by feature responsibility when a file grows large or mixes concerns. +- **Avoid drive-by refactors** unless required by the task. +- **Preserve public API stability** and match existing patterns in the package you touch. + +### Naming Conventions + +- **Variables and functions**: `snake_case` +- **Classes**: `PascalCase` +- **Constants**: `UPPER_SNAKE_CASE` +- **Private attributes/methods**: `_leading_underscore` +- **Type aliases**: `PascalCase` +- **Enum values**: `UPPER_SNAKE_CASE` + +### Project File Conventions + +- **Follow per-package `pyproject.toml`** and `[tool.azure-sdk-build]` settings. +- **Do not edit generated code**. +- **Do not edit files with the header** `Code generated by Microsoft (R) Python Code Generator.` + - `sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/projects/` + - `sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py` + - `sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_version.py` + - `sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_version.py` +- **Do not introduce secrets or credentials.** +- **Do not disable TLS/SSL verification** without explicit approval. +- **Keep logs free of sensitive data.** + +### Python Code Standards + +- **Follow existing package patterns** and Azure SDK for Python guidelines. +- **Type hints and async patterns** should match existing code. +- **Respect Ruff settings** (line length 120, isort rules) in each package `pyproject.toml`. +- **Avoid new dependencies** unless necessary and approved. + +### Docstring Standards +Use reStructuredText (reST) / Sphinx Style docstrings for all public functions, classes, and modules: + +```python +def calculate_discount( + price: Decimal, + discount_percent: float, + min_amount: Decimal = Decimal("0.01") +) -> Decimal: + """Calculate the discounted price for a product. + + :param price: The original price of the product. + :type price: Decimal + :param discount_percent: The discount percentage to apply (0-100). + :type discount_percent: float + :param min_amount: The minimum amount after discount (default is 0.01). + :type min_amount: Decimal + :return: The price after applying the discount, not less than min_amount. + :rtype: Decimal + :raises ValueError: If discount_percent is not between 0 and 100. + + Example: + Calculate a 15% discount on a $100 product: + + .. code-block:: python + + calculate_discount(Decimal("100.00"), 15.0) + """ +``` + +### Error Handling Standards + +- Prefer explicit validation at API boundaries and raise errors **as early as possible**. +- Use standard Python exceptions (`ValueError`, `TypeError`, `KeyError`, etc.) when they accurately describe the problem. +- When a domain-specific error is needed, define a clear, documented exception type and reuse it consistently. +- Do **not** silently swallow exceptions. Either handle them meaningfully (with clear recovery behavior) or let them propagate. +- Preserve the original traceback when re-raising (`raise` without arguments) so issues remain diagnosable. +- Fail fast on programmer errors (e.g., inconsistent state, impossible branches) using assertions or explicit exceptions. +- For public APIs, validate user input and return helpful, actionable messages without leaking secrets or internal implementation details. + +#### Exception Best Practices + +- Avoid `except Exception:` and **never** use bare `except:`; always catch the most specific exception type possible. +- Keep `try` blocks **small** and focused so that it is clear which statements may raise the handled exception. +- When adding context to an error, use either `raise NewError("message") from exc` or log the context and re-raise with `raise`. +- Do not use exceptions for normal control flow; reserve them for truly exceptional or error conditions. +- When a function can raise non-obvious exceptions, document them in the docstring under a `:raises:` section. +- In asynchronous code, make sure exceptions are not lost in background tasks; gather and handle them explicitly where needed. + +### Logging Standards + +- Use the standard library `logging` module for all diagnostic output; **do not** use `print` in library or service code. +- Create a module-level logger via `logger = logging.getLogger(__name__)` and use it consistently within that module. +- Choose log levels appropriately: + - `logger.debug(...)` for detailed diagnostics and tracing. + - `logger.info(...)` for high-level lifecycle events (startup, shutdown, major state changes). + - `logger.warning(...)` for recoverable issues or unexpected-but-tolerated conditions. + - `logger.error(...)` for failures where the current operation cannot succeed. + - `logger.critical(...)` for unrecoverable conditions affecting process health. +- Never log secrets, credentials, access tokens, full connection strings, or sensitive customer data. +- When logging exceptions, prefer `logger.exception("message")` inside an `except` block so the traceback is included. +- Keep log messages clear and structured (include identifiers like request IDs, resource names, or correlation IDs when available). +## ⚠️ Important Notes + +- **NEVER ASSUME OR GUESS** - When in doubt, ask for clarification +- **Always verify file paths and module names** before use +- **Keep this file (`AGENTS.md`) updated** when adding new patterns or dependencies +- **Test your code** - No feature is complete without tests +- **Document your decisions** - Future developers (including yourself) will thank you + +## πŸ“š Documentation & Explainability + +- **Keep samples runnable** and focused on SDK usage. +- **Follow third-party dependency guidance** in `CONTRIBUTING.md`. + +## πŸ› οΈ Development Environment + +### βœ… Testing & Quality Gates (tox) + +Run from a package root (e.g., `sdk/agentserver/azure-ai-agentserver-core`). + +- `tox run -e sphinx -c ../../../eng/tox/tox.ini --root .` + - Docs output: `.tox/sphinx/tmp/dist/site/index.html` +- `tox run -e pylint -c ../../../eng/tox/tox.ini --root .` + - Uses repo `pylintrc`; `next-pylint` uses `eng/pylintrc` +- `tox run -e mypy -c ../../../eng/tox/tox.ini --root .` +- `tox run -e pyright -c ../../../eng/tox/tox.ini --root .` +- `tox run -e verifytypes -c ../../../eng/tox/tox.ini --root .` +- `tox run -e whl -c ../../../eng/tox/tox.ini --root .` +- `tox run -e sdist -c ../../../eng/tox/tox.ini --root .` +- `tox run -e samples -c ../../../eng/tox/tox.ini --root .` (runs all samples) +- `tox run -e apistub -c ../../../eng/tox/tox.ini --root .` + +Check each package `pyproject.toml` under `[tool.azure-sdk-build]` to see which checks are enabled/disabled. diff --git a/sdk/agentserver/CLAUDE.md b/sdk/agentserver/CLAUDE.md new file mode 100644 index 000000000000..43c994c2d361 --- /dev/null +++ b/sdk/agentserver/CLAUDE.md @@ -0,0 +1 @@ +@AGENTS.md diff --git a/sdk/agentserver/PLANNING.md b/sdk/agentserver/PLANNING.md new file mode 100644 index 000000000000..6a65e2925a83 --- /dev/null +++ b/sdk/agentserver/PLANNING.md @@ -0,0 +1,100 @@ +# 🧭 PLANNING.md + +## 🎯 What this project is +AgentServer is a set of Python packages under `sdk/agentserver` that host agents for +Azure AI Foundry. The core package provides the runtime/server, tooling runtime, and +Responses API models, while the adapter packages wrap popular frameworks. The primary +users are SDK consumers who want to run agents locally and deploy them as Foundry-hosted +containers. Work is β€œdone” when adapters faithfully translate framework execution into +Responses API-compatible outputs and the packages pass their expected tests and samples. + +**Behavioral/policy rules live in `AGENTS.md`.** This document is architecture + repo map + doc index. + +## 🎯 Goals / Non-goals +Goals: +- Keep a stable architecture snapshot and repo map for fast onboarding. +- Document key request/response flows, including streaming. +- Clarify the development workflow and testing expectations for AgentServer packages. + +Non-goals: +- Detailed API documentation (belongs in package docs and docstrings). +- Per-initiative plans (belong in `TASK.md` or a dedicated plan file). +- Speculative refactors (align with KISS/YAGNI in `AGENTS.md`). + +## 🧩 Architecture (snapshot) +### πŸ—οΈ Project Structure +- **azure-ai-agentserver-core**: Core library + - Runtime/context + - HTTP gateway + - Foundry integrations + - Responses API protocol (current) +- **azure-ai-agentserver-agentframework**: adapters for Agent Framework agents/workflows, + thread and checkpoint persistence. +- **azure-ai-agentserver-langgraph**: adapter and converters for LangGraph agents and + Response API events. + +### Current vs target +- Current: OpenAI Responses API protocol lives in `azure-ai-agentserver-core` alongside + core runtime and HTTP gateway code; framework adapters layer on top. +- Target (planned, not fully implemented): + - Core layer: app/runtime/context, foundry integrations (tools, checkpointing), HTTP gateway + - Protocol layer: Responses API in its own package + - Framework layer: adapters (agentframework, langgraph, other frameworks) + +### Key flows +- Request path: `/runs` or `/responses` β†’ `AgentRunContext` β†’ agent execution β†’ Responses + API payload. +- Streaming path: generator/async generator β†’ SSE event stream. +- Framework adapter path: framework input β†’ converter β†’ Response API output (streaming + or non-streaming). +- Tools path: Foundry tool runtime invoked via core `tools/runtime` APIs. + +## πŸ—ΊοΈ Repo map +- `azure-ai-agentserver-core`: Core library (runtime/context, HTTP gateway, Foundry integrations, + Responses API protocol today). +- `azure-ai-agentserver-agentframework`: Agent Framework adapter. +- `azure-ai-agentserver-langgraph`: LangGraph adapter. +- Core runtime and models: `azure-ai-agentserver-core/azure/ai/agentserver/core/` +- Agent Framework adapter: `azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/` +- LangGraph adapter: `azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/` +- Samples: `azure-ai-agentserver-*/samples/` +- Tests: `azure-ai-agentserver-*/tests/` +- Package docs (Sphinx inputs): `azure-ai-agentserver-*/doc/` +- Repo-wide guidance: `CONTRIBUTING.md`, `doc/dev/tests.md`, `doc/eng_sys_checks.md` + +## πŸ“š Doc index +### **Read repo-wide guidance**: +- `CONTRIBUTING.md` +- `doc/dev/tests.md` +- `doc/eng_sys_checks.md` + +### **Read the package READMEs**: + - `sdk/agentserver/azure-ai-agentserver-core/README.md` + - `sdk/agentserver/azure-ai-agentserver-agentframework/README.md` + - `sdk/agentserver/azure-ai-agentserver-langgraph/README.md` + +### β€œIf you need X, look at Y” +- Enable/disable checks for a package β†’ that package `pyproject.toml` β†’ `[tool.azure-sdk-build]` +- How to run tests / live-recorded tests β†’ `doc/dev/tests.md` +- Engineering system checks / gates β†’ `doc/eng_sys_checks.md` +- Adapter conversion behavior β†’ the relevant adapter package + its tests + samples + +## βœ… Testing strategy +- Unit/integration tests live in each package’s `tests/` directory. +- Samples are part of validation via the `samples` tox environment. +- For live/recorded testing patterns, follow `doc/dev/tests.md`. + +## πŸš€ Rollout / migrations +- Preserve public API stability and follow Azure SDK release policy. +- Do not modify generated code (see paths in `AGENTS.md`). +- CI checks are controlled per package in `pyproject.toml` under + `[tool.azure-sdk-build]`. + +## ⚠️ Risks / edge cases +- Streaming event ordering and keep-alive behavior. +- Credential handling (async credentials and adapters). +- Response API schema compatibility across adapters. +- Tool invocation failures and error surfacing. + +## πŸ“Œ Progress +See `TASK.md` for active work items; no checklists here. diff --git a/sdk/agentserver/TASK.md b/sdk/agentserver/TASK.md new file mode 100644 index 000000000000..aecd759e81ce --- /dev/null +++ b/sdk/agentserver/TASK.md @@ -0,0 +1,27 @@ +# TASK.md + +## Now (active) + +## Done + +- [x] 2026-02-06 β€” Add README files for Foundry checkpoint samples + - Files: `azure-ai-agentserver-agentframework/samples/workflow_with_foundry_checkpoints/README.md`, + `azure-ai-agentserver-langgraph/samples/simple_agent_with_foundry_checkpointer/README.md` + - Updated setup/run/request docs, added missing LangGraph sample README, and corrected `.env` setup guidance. + +- [x] 2026-02-04 β€” Implement managed checkpoints feature + - Files: core/checkpoints/ (new), agentframework/persistence/_foundry_checkpoint_*.py (new), + agentframework/__init__.py (modified) + - Added: FoundryCheckpointClient, FoundryCheckpointStorage, FoundryCheckpointRepository + - Modified: from_agent_framework() with managed_checkpoints, foundry_endpoint, project_id params + +- [x] 2026-01-30 β€” Attach agent server metadata to OpenAIResponse.metadata + header + - Files: azure-ai-agentserver-core/azure/ai/agentserver/core/application/_metadata.py, + azure-ai-agentserver-core/azure/ai/agentserver/core/application/__init__.py, + azure-ai-agentserver-core/azure/ai/agentserver/core/server/_response_metadata.py, + azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_configuration.py, + azure-ai-agentserver-core/tests/unit_tests/server/test_response_metadata.py +- [x] 2026-01-30 β€” Restore previous OTel context in streaming generator + - Files: azure-ai-agentserver-core/azure/ai/agentserver/core/server/base.py, + azure-ai-agentserver-core/tests/unit_tests/server/test_otel_context.py +- [x] 2026-01-29 β€” Create `AGENTS.md`, `PLANNING.md`, `TASK.md`. diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md index cfcf2445e256..d9605e795da3 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/CHANGELOG.md @@ -1,5 +1,93 @@ # Release History + +## 1.0.0b10 (2026-01-27) + +### Bugs Fixed + +- Support implicit message item type. +- Make AZURE_AI_PROJECTS_ENDPOINT optional. + +## 1.0.0b9 (2026-01-23) + +### Features Added + +- Integrated with Foundry Tools +- Add persistence for agent thread and checkpoint +- Fixed WorkflowAgent concurrency issue +- Support Human-in-the-Loop + + +## 1.0.0b8 (2026-01-21) + +### Features Added + +- Support keep alive for long-running streaming responses. + +### Bugs Fixed + +- Fixed AgentFramework breaking change and pin version to >=1.0.0b251112,<=1.0.0b260107 + + +## 1.0.0b7 (2025-12-05) + +### Features Added + +- Update response with created_by + +### Bugs Fixed + +- Fixed error response handling in stream and non-stream modes + +## 1.0.0b6 (2025-11-26) + +### Features Added + +- Support Agent-framework greater than 251112 + + +## 1.0.0b5 (2025-11-16) + +### Features Added + +- Support Tools Oauth + +### Bugs Fixed + +- Fixed streaming generation issues. + + +## 1.0.0b4 (2025-11-13) + +### Features Added + +- Adapters support tools + +### Bugs Fixed + +- Pin azure-ai-projects and azure-ai-agents version to avoid version confliction + + +## 1.0.0b3 (2025-11-11) + +### Bugs Fixed + +- Fixed Id generator format. + +- Fixed trace initialization for agent-framework. + + +## 1.0.0b2 (2025-11-10) + +### Bugs Fixed + +- Fixed Id generator format. + +- Improved stream mode error messsage. + +- Updated application insights related configuration environment variables. + + ## 1.0.0b1 (2025-11-07) ### Features Added diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py index af980a34799f..c4d938ada8d9 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/__init__.py @@ -1,16 +1,135 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- +# pylint: disable=docstring-should-be-keyword __path__ = __import__("pkgutil").extend_path(__path__, __name__) +from typing import Callable, Optional, Union, overload + +from agent_framework import AgentProtocol, BaseAgent, Workflow, WorkflowBuilder +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.credentials import TokenCredential + +from azure.ai.agentserver.core.application import PackageMetadata, set_current_app # pylint: disable=import-error,no-name-in-module + from ._version import VERSION +from ._agent_framework import AgentFrameworkAgent +from ._ai_agent_adapter import AgentFrameworkAIAgentAdapter +from ._workflow_agent_adapter import AgentFrameworkWorkflowAdapter +from ._foundry_tools import FoundryToolsChatMiddleware +from .persistence import AgentThreadRepository, CheckpointRepository, FoundryCheckpointRepository + + +@overload +def from_agent_framework( + agent: Union[BaseAgent, AgentProtocol], + /, + credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] = None, + thread_repository: Optional[AgentThreadRepository]=None + ) -> "AgentFrameworkAIAgentAdapter": + """ + Create an Agent Framework AI Agent Adapter from an AgentProtocol or BaseAgent. + + :param agent: The agent to adapt. + :type agent: Union[BaseAgent, AgentProtocol] + :param credentials: Optional asynchronous token credential for authentication. + :type credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] + :param thread_repository: Optional thread repository for agent thread management. + :type thread_repository: Optional[AgentThreadRepository] + :return: An instance of AgentFrameworkAIAgentAdapter. + :rtype: AgentFrameworkAIAgentAdapter + """ + ... -def from_agent_framework(agent): - from .agent_framework import AgentFrameworkCBAgent +@overload +def from_agent_framework( + workflow: Union[WorkflowBuilder, Callable[[], Workflow]], + /, + credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] = None, + thread_repository: Optional[AgentThreadRepository] = None, + checkpoint_repository: Optional[CheckpointRepository] = None, + ) -> "AgentFrameworkWorkflowAdapter": + """ + Create an Agent Framework Workflow Adapter. + The argument `workflow` can be either a WorkflowBuilder or a factory function + that returns a Workflow. + It will be called to create a new Workflow instance and `.as_agent()` will be + called as well for each incoming CreateResponse request. Please ensure that the + workflow definition can be converted to a WorkflowAgent. For more information, + see the agent-framework samples and documentation. - return AgentFrameworkCBAgent(agent) + :param workflow: The workflow builder or factory function to adapt. + :type workflow: Union[WorkflowBuilder, Callable[[], Workflow]] + :param credentials: Optional asynchronous token credential for authentication. + :type credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] + :param thread_repository: Optional thread repository for agent thread management. + :type thread_repository: Optional[AgentThreadRepository] + :param checkpoint_repository: Optional checkpoint repository for workflow checkpointing. + Use ``InMemoryCheckpointRepository``, ``FileCheckpointRepository``, or + ``FoundryCheckpointRepository`` for Azure AI Foundry managed storage. + :type checkpoint_repository: Optional[CheckpointRepository] + :return: An instance of AgentFrameworkWorkflowAdapter. + :rtype: AgentFrameworkWorkflowAdapter + """ + ... +def from_agent_framework( + agent_or_workflow: Union[BaseAgent, AgentProtocol, WorkflowBuilder, Callable[[], Workflow]], + /, + credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] = None, + thread_repository: Optional[AgentThreadRepository] = None, + checkpoint_repository: Optional[CheckpointRepository] = None, +) -> "AgentFrameworkAgent": + """ + Create an Agent Framework Adapter from either an AgentProtocol/BaseAgent or a + WorkflowAgent. + One of agent or workflow must be provided. -__all__ = ["from_agent_framework"] + :param agent_or_workflow: The agent to adapt. + :type agent_or_workflow: Optional[Union[BaseAgent, AgentProtocol]] + :param credentials: Optional asynchronous token credential for authentication. + :type credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] + :param thread_repository: Optional thread repository for agent thread management. + :type thread_repository: Optional[AgentThreadRepository] + :param checkpoint_repository: Optional checkpoint repository for workflow checkpointing. + Use ``InMemoryCheckpointRepository``, ``FileCheckpointRepository``, or + ``FoundryCheckpointRepository`` for Azure AI Foundry managed storage. + :type checkpoint_repository: Optional[CheckpointRepository] + :return: An instance of AgentFrameworkAgent. + :rtype: AgentFrameworkAgent + :raises TypeError: If neither or both of agent and workflow are provided, or if + the provided types are incorrect. + """ + + if isinstance(agent_or_workflow, WorkflowBuilder): + return AgentFrameworkWorkflowAdapter( + workflow_factory=agent_or_workflow.build, + credentials=credentials, + thread_repository=thread_repository, + checkpoint_repository=checkpoint_repository, + ) + if isinstance(agent_or_workflow, Callable): # type: ignore + return AgentFrameworkWorkflowAdapter( + workflow_factory=agent_or_workflow, + credentials=credentials, + thread_repository=thread_repository, + checkpoint_repository=checkpoint_repository, + ) + # raise TypeError("workflow must be a WorkflowBuilder or callable returning a Workflow") + + if isinstance(agent_or_workflow, (AgentProtocol, BaseAgent)): + return AgentFrameworkAIAgentAdapter(agent_or_workflow, + credentials=credentials, + thread_repository=thread_repository) + raise TypeError("You must provide one of the instances of type " + "[AgentProtocol, BaseAgent, WorkflowBuilder or callable returning a Workflow]") + + +__all__ = [ + "from_agent_framework", + "FoundryToolsChatMiddleware", +] __version__ = VERSION + +set_current_app(PackageMetadata.from_dist("azure-ai-agentserver-agentframework")) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py new file mode 100644 index 000000000000..32c171ec9eff --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_agent_framework.py @@ -0,0 +1,282 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# pylint: disable=logging-fstring-interpolation,no-name-in-module,no-member,do-not-import-asyncio +from __future__ import annotations + +import os +from typing import Any, AsyncGenerator, Optional, TYPE_CHECKING, Union, Callable + +from agent_framework import AgentProtocol, AgentThread, WorkflowAgent +from agent_framework.azure import AzureAIClient # pylint: disable=no-name-in-module +from opentelemetry import trace + +from azure.ai.agentserver.core import AgentRunContext, FoundryCBAgent +from azure.ai.agentserver.core.constants import Constants as AdapterConstants +from azure.ai.agentserver.core.logger import APPINSIGHT_CONNSTR_ENV_NAME, get_logger +from azure.ai.agentserver.core.models import ( + Response as OpenAIResponse, + ResponseStreamEvent, +) +from azure.ai.agentserver.core.models.projects import ResponseErrorEvent, ResponseFailedEvent +from azure.ai.agentserver.core.tools import OAuthConsentRequiredError # pylint: disable=import-error + +from .models.agent_framework_output_streaming_converter import AgentFrameworkOutputStreamingConverter +from .models.human_in_the_loop_helper import HumanInTheLoopHelper +from .persistence import AgentThreadRepository + +if TYPE_CHECKING: + from azure.core.credentials_async import AsyncTokenCredential + +logger = get_logger() + + +class AgentFrameworkAgent(FoundryCBAgent): + """ + Adapter class for integrating Agent Framework agents with the FoundryCB agent interface. + + This class wraps an Agent Framework `AgentProtocol` instance and provides a unified interface + for running agents in both streaming and non-streaming modes. It handles input and output + conversion between the Agent Framework and the expected formats for FoundryCB agents. + + Parameters: + agent (AgentProtocol): An instance of an Agent Framework agent to be adapted. + + Usage: + - Instantiate with an Agent Framework agent. + - Call `agent_run` with a `CreateResponse` request body to execute the agent. + - Supports both streaming and non-streaming responses based on the `stream` flag. + """ + + def __init__(self, + credentials: "Optional[AsyncTokenCredential]" = None, + thread_repository: Optional[AgentThreadRepository] = None): + """Initialize the AgentFrameworkAgent with an AgentProtocol. + + :param credentials: Azure credentials for authentication. + :type credentials: Optional[AsyncTokenCredential] + :param thread_repository: An optional AgentThreadRepository instance for managing thread messages. + :type thread_repository: Optional[AgentThreadRepository] + """ + super().__init__(credentials=credentials) # pylint: disable=unexpected-keyword-arg + self._thread_repository = thread_repository + self._hitl_helper = HumanInTheLoopHelper() + + def init_tracing(self): + try: + otel_exporter_endpoint = os.environ.get(AdapterConstants.OTEL_EXPORTER_ENDPOINT) + otel_exporter_protocol = os.environ.get(AdapterConstants.OTEL_EXPORTER_OTLP_PROTOCOL) + app_insights_conn_str = os.environ.get(APPINSIGHT_CONNSTR_ENV_NAME) + project_endpoint = os.environ.get(AdapterConstants.AZURE_AI_PROJECT_ENDPOINT) + + exporters = [] + if otel_exporter_endpoint: + otel_exporter = self._create_otlp_exporter(otel_exporter_endpoint, protocol=otel_exporter_protocol) + if otel_exporter: + exporters.append(otel_exporter) + if app_insights_conn_str: + appinsight_exporter = self._create_application_insights_exporter(app_insights_conn_str) + if appinsight_exporter: + exporters.append(appinsight_exporter) + + if exporters and self._setup_observability(exporters): + logger.info("Observability setup completed with provided exporters.") + elif project_endpoint: + self._setup_tracing_with_azure_ai_client(project_endpoint) + except Exception as e: # pylint: disable=broad-exception-caught + logger.warning(f"Failed to initialize tracing: {e}", exc_info=True) + self.tracer = trace.get_tracer(__name__) + + def _create_application_insights_exporter(self, connection_string): + try: + from azure.monitor.opentelemetry.exporter import AzureMonitorTraceExporter + + return AzureMonitorTraceExporter.from_connection_string(connection_string) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error(f"Failed to create Application Insights exporter: {e}", exc_info=True) + return None + + def _create_otlp_exporter(self, endpoint, protocol=None): + try: + if protocol and protocol.lower() in ("http", "http/protobuf", "http/json"): + from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter + + return OTLPSpanExporter(endpoint=endpoint) + + from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter + + return OTLPSpanExporter(endpoint=endpoint) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error(f"Failed to create OTLP exporter: {e}", exc_info=True) + return None + + def _setup_observability(self, exporters) -> bool: + setup_function = self._try_import_configure_otel_providers() + if not setup_function: # fallback to early version with setup_observability + setup_function = self._try_import_setup_observability() + if setup_function: + setup_function( + enable_sensitive_data=True, + exporters=exporters, + ) + return True + return False + + def _try_import_setup_observability(self): + try: + from agent_framework.observability import setup_observability + return setup_observability + except ImportError as e: + logger.warning(f"Failed to import setup_observability: {e}") + return None + + def _try_import_configure_otel_providers(self): + try: + from agent_framework.observability import configure_otel_providers + return configure_otel_providers + except ImportError as e: + logger.warning(f"Failed to import configure_otel_providers: {e}") + return None + + def _setup_tracing_with_azure_ai_client(self, project_endpoint: str): + async def setup_async(): + async with AzureAIClient( + project_endpoint=project_endpoint, + async_credential=self.credentials, + credential=self.credentials, # Af breaking change, keep both for compatibility + ) as agent_client: + try: + await agent_client.configure_azure_monitor() + except AttributeError: + await agent_client.setup_azure_ai_observability() + + import asyncio + + loop = asyncio.get_event_loop() + if loop.is_running(): + # If loop is already running, schedule as a task + asyncio.create_task(setup_async()) + else: + # Run in new event loop + loop.run_until_complete(setup_async()) + + async def agent_run( # pylint: disable=too-many-statements + self, context: AgentRunContext + ) -> Union[ + OpenAIResponse, + AsyncGenerator[ResponseStreamEvent, Any], + ]: + raise NotImplementedError("This method is implemented in the base class.") + + async def _load_agent_thread( + self, + context: AgentRunContext, + agent: Union[AgentProtocol, WorkflowAgent], + ) -> Optional[AgentThread]: + """Load the agent thread for a given conversation ID. + + :param context: The agent run context. + :type context: AgentRunContext + :param agent: The agent instance. + :type agent: AgentProtocol | WorkflowAgent + + :return: The loaded AgentThread if available, None otherwise. + :rtype: Optional[AgentThread] + """ + if self._thread_repository: + conversation_id = context.conversation_id + agent_thread = await self._thread_repository.get(conversation_id) + if agent_thread: + logger.info(f"Loaded agent thread for conversation: {conversation_id}") + return agent_thread + return agent.get_new_thread() + return None + + async def _save_agent_thread(self, context: AgentRunContext, agent_thread: AgentThread) -> None: + """Save the agent thread for a given conversation ID. + + :param context: The agent run context. + :type context: AgentRunContext + :param agent_thread: The agent thread to save. + :type agent_thread: AgentThread + + :return: None + :rtype: None + """ + if agent_thread and self._thread_repository and (conversation_id := context.conversation_id): + await self._thread_repository.set(conversation_id, agent_thread) + logger.info(f"Saved agent thread for conversation: {conversation_id}") + + def _run_streaming_updates( + self, + context: AgentRunContext, + run_stream: Callable[[], AsyncGenerator[Any, None]], + agent_thread: Optional[AgentThread] = None, + ) -> AsyncGenerator[ResponseStreamEvent, Any]: + """ + Execute a streaming run with shared OAuth/error handling. + + :param context: The agent run context. + :type context: AgentRunContext + :param run_stream: A callable that invokes the agent in stream mode + :type run_stream: Callable[[], AsyncGenerator[Any, None]] + :param agent_thread: The agent thread to use during streaming updates. + :type agent_thread: Optional[AgentThread] + + :return: An async generator yielding streaming events. + :rtype: AsyncGenerator[ResponseStreamEvent, Any] + """ + logger.info("Running agent in streaming mode") + streaming_converter = AgentFrameworkOutputStreamingConverter( + context, + hitl_helper=self._hitl_helper, + ) + + async def stream_updates(): + try: + update_count = 0 + try: + updates = run_stream() + async for event in streaming_converter.convert(updates): + update_count += 1 + yield event + + await self._save_agent_thread(context, agent_thread) + logger.info("Streaming completed with %d updates", update_count) + except OAuthConsentRequiredError as e: + logger.info("OAuth consent required during streaming updates") + if update_count == 0: + async for event in self.respond_with_oauth_consent_astream(context, e): + yield event + else: + yield ResponseErrorEvent( + sequence_number=streaming_converter.next_sequence(), + code="server_error", + message=f"OAuth consent required: {e.consent_url}", + param="agent_run", + ) + yield ResponseFailedEvent( + sequence_number=streaming_converter.next_sequence(), + response=streaming_converter._build_response( # pylint: disable=protected-access + status="failed" + ), + ) + except Exception as e: # pylint: disable=broad-exception-caught + logger.error("Unhandled exception during streaming updates: %s", e, exc_info=True) + yield ResponseErrorEvent( + sequence_number=streaming_converter.next_sequence(), + code="server_error", + message=str(e), + param="agent_run", + ) + yield ResponseFailedEvent( + sequence_number=streaming_converter.next_sequence(), + response=streaming_converter._build_response( # pylint: disable=protected-access + status="failed" + ), + ) + finally: + # No request-scoped resources to clean up today, but keep hook for future use. + pass + + return stream_updates() diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py new file mode 100644 index 000000000000..100f75ed3d82 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_ai_agent_adapter.py @@ -0,0 +1,88 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# pylint: disable=no-name-in-module,import-error +from __future__ import annotations + +from typing import Any, AsyncGenerator, Optional, Union + +from agent_framework import AgentProtocol +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential + +from azure.ai.agentserver.core import AgentRunContext +from azure.ai.agentserver.core.tools import OAuthConsentRequiredError +from azure.ai.agentserver.core.logger import get_logger +from azure.ai.agentserver.core.models import ( + Response as OpenAIResponse, + ResponseStreamEvent, +) + +from .models.agent_framework_input_converters import AgentFrameworkInputConverter +from .models.agent_framework_output_non_streaming_converter import ( + AgentFrameworkOutputNonStreamingConverter, +) +from ._agent_framework import AgentFrameworkAgent +from .persistence import AgentThreadRepository + +logger = get_logger() + +class AgentFrameworkAIAgentAdapter(AgentFrameworkAgent): + def __init__(self, agent: AgentProtocol, + credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] = None, + thread_repository: Optional[AgentThreadRepository] = None) -> None: + super().__init__(credentials, thread_repository) + self._agent = agent + + async def agent_run( # pylint: disable=too-many-statements + self, context: AgentRunContext + ) -> Union[ + OpenAIResponse, + AsyncGenerator[ResponseStreamEvent, Any], + ]: + try: + logger.info("Starting AIAgent agent_run with stream=%s", context.stream) + request_input = context.request.get("input") + + agent_thread = await self._load_agent_thread(context, self._agent) + + input_converter = AgentFrameworkInputConverter(hitl_helper=self._hitl_helper) + message = await input_converter.transform_input( + request_input, + agent_thread=agent_thread) + logger.debug("Transformed input message type: %s", type(message)) + # Use split converters + if context.stream: + return self._run_streaming_updates( + context=context, + run_stream=lambda: self._agent.run_stream( + message, + thread=agent_thread, + ), + agent_thread=agent_thread, + ) + + # Non-streaming path + logger.info("Running agent in non-streaming mode") + result = await self._agent.run( + message, + thread=agent_thread) + logger.debug("Agent run completed, result type: %s", type(result)) + await self._save_agent_thread(context, agent_thread) + + non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context, hitl_helper=self._hitl_helper) + transformed_result = non_streaming_converter.transform_output_for_response(result) + logger.info("Agent run and transformation completed successfully") + return transformed_result + except OAuthConsentRequiredError as e: + logger.info("OAuth consent required during agent run") + if context.stream: + # Yield OAuth consent response events + # Capture e in the closure by passing it as a default argument + async def oauth_consent_stream(error=e): + async for event in self.respond_with_oauth_consent_astream(context, error): + yield event + return oauth_consent_stream() + return await self.respond_with_oauth_consent(context, e) + finally: + pass diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py new file mode 100644 index 000000000000..f936d32e4ec1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_foundry_tools.py @@ -0,0 +1,138 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# pylint: disable=client-accepts-api-version-keyword,missing-client-constructor-parameter-credential,missing-client-constructor-parameter-kwargs +# pylint: disable=no-name-in-module,import-error +from __future__ import annotations + +import inspect +from typing import Any, Awaitable, Callable, Dict, List, Optional, Sequence + +from agent_framework import AIFunction, ChatContext, ChatOptions, ChatMiddleware +from pydantic import Field, create_model + +from azure.ai.agentserver.core import AgentServerContext +from azure.ai.agentserver.core.logger import get_logger +from azure.ai.agentserver.core.tools import FoundryToolLike, ResolvedFoundryTool, ensure_foundry_tool + +logger = get_logger() + + +def _attach_signature_from_pydantic_model(func, input_model) -> None: + params = [] + annotations: Dict[str, Any] = {} + + for name, field in input_model.model_fields.items(): + ann = field.annotation or Any + annotations[name] = ann + + default = inspect._empty if field.is_required() else field.default # pylint: disable=protected-access + params.append( + inspect.Parameter( + name=name, + kind=inspect.Parameter.KEYWORD_ONLY, + default=default, + annotation=ann, + ) + ) + + func.__signature__ = inspect.Signature(parameters=params, return_annotation=Any) + func.__annotations__ = {**annotations, "return": Any} + +class FoundryToolClient: + + def __init__( + self, + tools: Sequence[FoundryToolLike], + ) -> None: + self._allowed_tools: List[FoundryToolLike] = [ensure_foundry_tool(tool) for tool in tools] + + async def list_tools(self) -> List[AIFunction]: + server_context = AgentServerContext.get() + foundry_tool_catalog = server_context.tools.catalog + resolved_tools = await foundry_tool_catalog.list(self._allowed_tools) + return [self._to_aifunction(tool) for tool in resolved_tools] + + def _to_aifunction(self, foundry_tool: "ResolvedFoundryTool") -> AIFunction: + """Convert an FoundryTool to an Agent Framework AI Function + + :param foundry_tool: The FoundryTool to convert. + :type foundry_tool: ~azure.ai.agentserver.core.client.tools.aio.FoundryTool + :return: An AI Function Tool. + :rtype: AIFunction + """ + # Get the input schema from the tool descriptor + input_schema = foundry_tool.input_schema or {} + + # Create a Pydantic model from the input schema + properties = input_schema.properties or {} + required_fields = set(input_schema.required or []) + + # Build field definitions for the Pydantic model + field_definitions: Dict[str, Any] = {} + for field_name, field_info in properties.items(): + if field_info.type is None: + logger.warning("Skipping field '%s' in tool '%s': unknown or empty schema type.", + field_name, foundry_tool.name) + continue + field_type = field_info.type.py_type + field_description = field_info.description or "" + is_required = field_name in required_fields + + if is_required: + field_definitions[field_name] = (field_type, Field(description=field_description)) + else: + field_definitions[field_name] = (Optional[field_type], + Field(default=None, description=field_description)) + + # Create the Pydantic model dynamically + input_model = create_model( + f"{foundry_tool.name}_input", + **field_definitions + ) + + # Create a wrapper function that calls the Azure tool + async def tool_func(**kwargs: Any) -> Any: + """Dynamically generated function to invoke the Azure AI tool. + + :return: The result from the tool invocation. + :rtype: Any + """ + server_context = AgentServerContext.get() + logger.debug("Invoking tool: %s with input: %s", foundry_tool.name, kwargs) + return await server_context.tools.invoke(foundry_tool, kwargs) + _attach_signature_from_pydantic_model(tool_func, input_model) + + # Create and return the AIFunction + return AIFunction( + name=foundry_tool.name, + description=foundry_tool.description or "No description available", + func=tool_func, + input_model=input_model + ) + + +class FoundryToolsChatMiddleware(ChatMiddleware): + """Chat middleware to inject Foundry tools into ChatOptions on each call.""" + + def __init__( + self, + tools: Sequence[FoundryToolLike]) -> None: + self._foundry_tool_client = FoundryToolClient(tools=tools) + + async def process( + self, + context: ChatContext, + next: Callable[[ChatContext], Awaitable[None]], + ) -> None: + tools = await self._foundry_tool_client.list_tools() + base_chat_options = context.chat_options + if not base_chat_options: + logger.debug("No existing ChatOptions found, creating new one with Foundry tools.") + base_chat_options = ChatOptions(tools=tools) + context.chat_options = base_chat_options + else: + logger.debug("Adding Foundry tools to existing ChatOptions.") + base_tools = base_chat_options.tools or [] + context.chat_options.tools = base_tools + tools + await next(context) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_version.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_version.py index be71c81bd282..9ab0a006e0d0 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_version.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_version.py @@ -6,4 +6,4 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -VERSION = "1.0.0b1" +VERSION = "1.0.0b10" diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py new file mode 100644 index 000000000000..a119d697a377 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/_workflow_agent_adapter.py @@ -0,0 +1,162 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# pylint: disable=no-name-in-module,import-error +from typing import ( + Any, + AsyncGenerator, + Callable, + Optional, + Union, +) + +from agent_framework import Workflow, CheckpointStorage, WorkflowAgent, WorkflowCheckpoint +from agent_framework._workflows import get_checkpoint_summary +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential + +from azure.ai.agentserver.core import AgentRunContext +from azure.ai.agentserver.core.logger import get_logger +from azure.ai.agentserver.core.models import ( + Response as OpenAIResponse, + ResponseStreamEvent, +) +from azure.ai.agentserver.core.tools import OAuthConsentRequiredError + +from ._agent_framework import AgentFrameworkAgent +from .models.agent_framework_input_converters import AgentFrameworkInputConverter +from .models.agent_framework_output_non_streaming_converter import ( + AgentFrameworkOutputNonStreamingConverter, +) +from .persistence import AgentThreadRepository, CheckpointRepository + +logger = get_logger() + +class AgentFrameworkWorkflowAdapter(AgentFrameworkAgent): + """Adapter to run WorkflowBuilder agents within the Agent Framework CBAgent structure.""" + def __init__( + self, + workflow_factory: Callable[[], Workflow], + credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] = None, + thread_repository: Optional[AgentThreadRepository] = None, + checkpoint_repository: Optional[CheckpointRepository] = None, + ) -> None: + super().__init__(credentials, thread_repository) + self._workflow_factory = workflow_factory + self._checkpoint_repository = checkpoint_repository + + async def agent_run( # pylint: disable=too-many-statements + self, context: AgentRunContext + ) -> Union[ + OpenAIResponse, + AsyncGenerator[ResponseStreamEvent, Any], + ]: + try: + agent = self._build_agent() + + logger.info("Starting WorkflowAgent agent_run with stream=%s", context.stream) + request_input = context.request.get("input") + + agent_thread = await self._load_agent_thread(context, agent) + + checkpoint_storage = None + selected_checkpoint = None + if self._checkpoint_repository and (conversation_id := context.conversation_id): + checkpoint_storage = await self._checkpoint_repository.get_or_create(conversation_id) + if checkpoint_storage: + selected_checkpoint = await self._get_latest_checkpoint(checkpoint_storage) + if selected_checkpoint: + summary = get_checkpoint_summary(selected_checkpoint) + if summary.status == "completed": + logger.warning( + "Selected checkpoint %s is completed. Will not resume from it.", + selected_checkpoint.checkpoint_id, + ) + selected_checkpoint = None # Do not resume from completed checkpoints + else: + await self._load_checkpoint(agent, selected_checkpoint, checkpoint_storage) + logger.info("Loaded checkpoint with ID: %s", selected_checkpoint.checkpoint_id) + + input_converter = AgentFrameworkInputConverter(hitl_helper=self._hitl_helper) + message = await input_converter.transform_input( + request_input, + agent_thread=agent_thread, + checkpoint=selected_checkpoint) + logger.debug("Transformed input message type: %s", type(message)) + + # Use split converters + if context.stream: + return self._run_streaming_updates( + context=context, + run_stream=lambda: agent.run_stream( + message, + thread=agent_thread, + checkpoint_storage=checkpoint_storage, + ), + agent_thread=agent_thread, + ) + + # Non-streaming path + logger.info("Running WorkflowAgent in non-streaming mode") + result = await agent.run( + message, + thread=agent_thread, + checkpoint_storage=checkpoint_storage) + logger.debug("WorkflowAgent run completed, result type: %s", type(result)) + + await self._save_agent_thread(context, agent_thread) + + non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context, hitl_helper=self._hitl_helper) + transformed_result = non_streaming_converter.transform_output_for_response(result) + logger.info("Agent run and transformation completed successfully") + return transformed_result + except OAuthConsentRequiredError as e: + logger.info("OAuth consent required during agent run") + if context.stream: + # Yield OAuth consent response events + # Capture e in the closure by passing it as a default argument + async def oauth_consent_stream(error=e): + async for event in self.respond_with_oauth_consent_astream(context, error): + yield event + return oauth_consent_stream() + return await self.respond_with_oauth_consent(context, e) + finally: + pass + + def _build_agent(self) -> WorkflowAgent: + return self._workflow_factory().as_agent() + + async def _get_latest_checkpoint(self, + checkpoint_storage: CheckpointStorage) -> Optional[Any]: + """Load the latest checkpoint from the given storage. + + :param checkpoint_storage: The checkpoint storage to load from. + :type checkpoint_storage: CheckpointStorage + + :return: The latest checkpoint if available, None otherwise. + :rtype: Optional[Any] + """ + checkpoints = await checkpoint_storage.list_checkpoints() + if checkpoints: + latest_checkpoint = max(checkpoints, key=lambda cp: cp.timestamp) + return latest_checkpoint + return None + + async def _load_checkpoint(self, + agent: WorkflowAgent, + checkpoint: WorkflowCheckpoint, + checkpoint_storage: CheckpointStorage) -> None: + """Load the checkpoint data from the given WorkflowCheckpoint. + + :param agent: The WorkflowAgent to load the checkpoint into. + :type agent: WorkflowAgent + :param checkpoint: The WorkflowCheckpoint to load data from. + :type checkpoint: WorkflowCheckpoint + :param checkpoint_storage: The storage to load the checkpoint from. + :type checkpoint_storage: CheckpointStorage + + :return: None + :rtype: None + """ + await agent.run(checkpoint_id=checkpoint.checkpoint_id, + checkpoint_storage=checkpoint_storage) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/agent_framework.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/agent_framework.py deleted file mode 100644 index 7177b522d2a9..000000000000 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/agent_framework.py +++ /dev/null @@ -1,153 +0,0 @@ -# --------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# --------------------------------------------------------- -# pylint: disable=logging-fstring-interpolation -from __future__ import annotations - -import asyncio # pylint: disable=do-not-import-asyncio -import os -from typing import Any, AsyncGenerator, Union - -from agent_framework import AgentProtocol -from agent_framework.azure import AzureAIAgentClient # pylint: disable=no-name-in-module -from opentelemetry import trace - -from azure.ai.agentserver.core import AgentRunContext, FoundryCBAgent -from azure.ai.agentserver.core.constants import Constants as AdapterConstants -from azure.ai.agentserver.core.logger import get_logger -from azure.ai.agentserver.core.models import ( - CreateResponse, - Response as OpenAIResponse, - ResponseStreamEvent, -) -from azure.ai.projects import AIProjectClient -from azure.identity import DefaultAzureCredential - -from .models.agent_framework_input_converters import AgentFrameworkInputConverter -from .models.agent_framework_output_non_streaming_converter import ( - AgentFrameworkOutputNonStreamingConverter, -) -from .models.agent_framework_output_streaming_converter import AgentFrameworkOutputStreamingConverter -from .models.constants import Constants - -logger = get_logger() - - -class AgentFrameworkCBAgent(FoundryCBAgent): - """ - Adapter class for integrating Agent Framework agents with the FoundryCB agent interface. - - This class wraps an Agent Framework `AgentProtocol` instance and provides a unified interface - for running agents in both streaming and non-streaming modes. It handles input and output - conversion between the Agent Framework and the expected formats for FoundryCB agents. - - Parameters: - agent (AgentProtocol): An instance of an Agent Framework agent to be adapted. - - Usage: - - Instantiate with an Agent Framework agent. - - Call `agent_run` with a `CreateResponse` request body to execute the agent. - - Supports both streaming and non-streaming responses based on the `stream` flag. - """ - - def __init__(self, agent: AgentProtocol): - super().__init__() - self.agent = agent - logger.info(f"Initialized AgentFrameworkCBAgent with agent: {type(agent).__name__}") - - def _resolve_stream_timeout(self, request_body: CreateResponse) -> float: - """Resolve idle timeout for streaming updates. - - Order of precedence: - 1) request_body.stream_timeout_s (if provided) - 2) env var Constants.AGENTS_ADAPTER_STREAM_TIMEOUT_S - 3) Constants.DEFAULT_STREAM_TIMEOUT_S - - :param request_body: The CreateResponse request body. - :type request_body: CreateResponse - - :return: The resolved stream timeout in seconds. - :rtype: float - """ - override = request_body.get("stream_timeout_s", None) - if override is not None: - return float(override) - env_val = os.getenv(Constants.AGENTS_ADAPTER_STREAM_TIMEOUT_S) - return float(env_val) if env_val is not None else float(Constants.DEFAULT_STREAM_TIMEOUT_S) - - def init_tracing(self): - exporter = os.environ.get(AdapterConstants.OTEL_EXPORTER_ENDPOINT) - app_insights_conn_str = os.environ.get(AdapterConstants.APPLICATION_INSIGHTS_CONNECTION_STRING) - project_endpoint = os.environ.get(AdapterConstants.AZURE_AI_PROJECT_ENDPOINT) - - if project_endpoint: - project_client = AIProjectClient(endpoint=project_endpoint, credential=DefaultAzureCredential()) - agent_client = AzureAIAgentClient(project_client=project_client) - agent_client.setup_azure_ai_observability() - elif exporter or app_insights_conn_str: - os.environ["WORKFLOW_ENABLE_OTEL"] = "true" - from agent_framework.observability import setup_observability - - setup_observability( - enable_sensitive_data=True, - otlp_endpoint=exporter, - applicationinsights_connection_string=app_insights_conn_str, - ) - self.tracer = trace.get_tracer(__name__) - - async def agent_run( - self, context: AgentRunContext - ) -> Union[ - OpenAIResponse, - AsyncGenerator[ResponseStreamEvent, Any], - ]: - logger.info(f"Starting agent_run with stream={context.stream}") - request_input = context.request.get("input") - - input_converter = AgentFrameworkInputConverter() - message = input_converter.transform_input(request_input) - logger.debug(f"Transformed input message type: {type(message)}") - - # Use split converters - if context.stream: - logger.info("Running agent in streaming mode") - streaming_converter = AgentFrameworkOutputStreamingConverter(context) - - async def stream_updates(): - update_count = 0 - timeout_s = self._resolve_stream_timeout(context.request) - logger.info("Starting streaming with idle-timeout=%.2fs", timeout_s) - for ev in streaming_converter.initial_events(): - yield ev - - # Iterate with per-update timeout; terminate if idle too long - aiter = self.agent.run_stream(message).__aiter__() - while True: - try: - update = await asyncio.wait_for(aiter.__anext__(), timeout=timeout_s) - except StopAsyncIteration: - logger.debug("Agent streaming iterator finished (StopAsyncIteration)") - break - except asyncio.TimeoutError: - logger.warning("Streaming idle timeout reached (%.1fs); terminating stream.", timeout_s) - for ev in streaming_converter.completion_events(): - yield ev - return - update_count += 1 - transformed = streaming_converter.transform_output_for_streaming(update) - for event in transformed: - yield event - for ev in streaming_converter.completion_events(): - yield ev - logger.info("Streaming completed with %d updates", update_count) - - return stream_updates() - - # Non-streaming path - logger.info("Running agent in non-streaming mode") - non_streaming_converter = AgentFrameworkOutputNonStreamingConverter(context) - result = await self.agent.run(message) - logger.debug(f"Agent run completed, result type: {type(result)}") - transformed_result = non_streaming_converter.transform_output_for_response(result) - logger.info("Agent run and transformation completed successfully") - return transformed_result diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py index 993be43e85c8..a21e5b9c44c7 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_input_converters.py @@ -5,9 +5,15 @@ # mypy: disable-error-code="no-redef" from __future__ import annotations -from typing import Dict, List - -from agent_framework import ChatMessage, Role as ChatRole +from typing import Dict, List, Optional + +from agent_framework import ( + AgentThread, + ChatMessage, + RequestInfoEvent, + Role as ChatRole, + WorkflowCheckpoint, +) from agent_framework._types import TextContent from azure.ai.agentserver.core.logger import get_logger @@ -21,10 +27,14 @@ class AgentFrameworkInputConverter: Accepts: str | List | None Returns: None | str | ChatMessage | list[str] | list[ChatMessage] """ + def __init__(self, *, hitl_helper=None) -> None: + self._hitl_helper = hitl_helper - def transform_input( + async def transform_input( self, input: str | List[Dict] | None, + agent_thread: Optional[AgentThread] = None, + checkpoint: Optional[WorkflowCheckpoint] = None, ) -> str | ChatMessage | list[str] | list[ChatMessage] | None: logger.debug("Transforming input of type: %s", type(input)) @@ -34,6 +44,26 @@ def transform_input( if isinstance(input, str): return input + if self._hitl_helper: + # load pending requests from checkpoint and thread messages if available + thread_messages = [] + if agent_thread and agent_thread.message_store: + thread_messages = await agent_thread.message_store.list_messages() + pending_hitl_requests = self._hitl_helper.get_pending_hitl_request(thread_messages, checkpoint) + if pending_hitl_requests: + logger.info("Pending HitL requests: %s", list(pending_hitl_requests.keys())) + hitl_response = self._hitl_helper.validate_and_convert_hitl_response( + input, + pending_requests=pending_hitl_requests) + if hitl_response: + return hitl_response + + return self._transform_input_internal(input) + + def _transform_input_internal( + self, + input: str | List[Dict] | None, + ) -> str | ChatMessage | list[str] | list[ChatMessage] | None: try: if isinstance(input, list): messages: list[str | ChatMessage] = [] @@ -53,9 +83,9 @@ def transform_input( if text_parts: messages.append(" ".join(text_parts)) - # Case 2: Explicit message params (user/assistant/system) + # Case 2: message params (user/assistant/system) elif ( - item.get("type") == "message" + item.get("type") in ("", None, "message") and item.get("role") is not None and item.get("content") is not None ): @@ -118,3 +148,35 @@ def _extract_input_text(self, content_item: Dict) -> str: if isinstance(text_content, str): return text_content return None # type: ignore + + def _validate_and_convert_hitl_response( + self, + pending_request: Dict, + input: List[Dict], + ) -> Optional[List[ChatMessage]]: + if not self._hitl_helper: + logger.warning("HitL helper not provided; cannot validate HitL response.") + return None + if isinstance(input, str): + logger.warning("Expected list input for HitL response validation, got str.") + return None + if not isinstance(input, list) or len(input) != 1: + logger.warning("Expected single-item list input for HitL response validation.") + return None + + item = input[0] + if item.get("type") != "function_call_output": + logger.warning("Expected function_call_output type for HitL response validation.") + return None + call_id = item.get("call_id", None) + if not call_id or call_id not in pending_request: + logger.warning("Function call output missing valid call_id for HitL response validation.") + return None + request_info = pending_request[call_id] + if isinstance(request_info, dict): + request_info = RequestInfoEvent.from_dict(request_info) + if not isinstance(request_info, RequestInfoEvent): + logger.warning("No valid pending request info found for call_id: %s", call_id) + return None + + return self._hitl_helper.convert_response(request_info, item) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py index 805a5eeb9dec..4984b2fc0423 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_non_streaming_converter.py @@ -7,19 +7,23 @@ import json from typing import Any, List -from agent_framework import AgentRunResponse, FunctionResultContent -from agent_framework._types import FunctionCallContent, TextContent +from agent_framework import ( + AgentRunResponse, + FunctionCallContent, + FunctionResultContent, + ErrorContent, + TextContent, +) +from agent_framework._types import UserInputRequestContents from azure.ai.agentserver.core import AgentRunContext from azure.ai.agentserver.core.logger import get_logger from azure.ai.agentserver.core.models import Response as OpenAIResponse -from azure.ai.agentserver.core.models.projects import ( - ItemContentOutputText, - ResponsesAssistantMessageItemResource, -) +from azure.ai.agentserver.core.models.projects import ItemContentOutputText from .agent_id_generator import AgentIdGenerator from .constants import Constants +from .human_in_the_loop_helper import HumanInTheLoopHelper logger = get_logger() @@ -27,10 +31,11 @@ class AgentFrameworkOutputNonStreamingConverter: # pylint: disable=name-too-long """Non-streaming converter: AgentRunResponse -> OpenAIResponse.""" - def __init__(self, context: AgentRunContext): + def __init__(self, context: AgentRunContext, *, hitl_helper: HumanInTheLoopHelper=None): self._context = context self._response_id = None self._response_created_at = None + self._hitl_helper = hitl_helper def _ensure_response_started(self) -> None: if not self._response_id: @@ -41,11 +46,19 @@ def _ensure_response_started(self) -> None: def _build_item_content_output_text(self, text: str) -> ItemContentOutputText: return ItemContentOutputText(text=text, annotations=[]) - def _new_assistant_message_item(self, message_text: str) -> ResponsesAssistantMessageItemResource: - item_content = self._build_item_content_output_text(message_text) - return ResponsesAssistantMessageItemResource( - id=self._context.id_generator.generate_message_id(), status="completed", content=[item_content] - ) + def _build_created_by(self, author_name: str) -> dict: + self._ensure_response_started() + + agent_dict = { + "type": "agent_id", + "name": author_name or "", + "version": "", # Default to empty string + } + + return { + "agent": agent_dict, + "response_id": self._response_id, + } def transform_output_for_response(self, response: AgentRunResponse) -> OpenAIResponse: """Build an OpenAIResponse capturing all supported content types. @@ -72,9 +85,11 @@ def transform_output_for_response(self, response: AgentRunResponse) -> OpenAIRes contents = getattr(message, "contents", None) if not contents: continue + # Extract author_name from this message + msg_author_name = getattr(message, "author_name", None) or "" for j, content in enumerate(contents): logger.debug(" content index=%d in message=%d type=%s", j, i, type(content).__name__) - self._append_content_item(content, completed_items) + self._append_content_item(content, completed_items, msg_author_name) response_data = self._construct_response_data(completed_items) openai_response = OpenAIResponse(response_data) @@ -87,7 +102,7 @@ def transform_output_for_response(self, response: AgentRunResponse) -> OpenAIRes # ------------------------- helper append methods ------------------------- - def _append_content_item(self, content: Any, sink: List[dict]) -> None: + def _append_content_item(self, content: Any, sink: List[dict], author_name: str) -> None: """Dispatch a content object to the appropriate append helper. Adding this indirection keeps the main transform method compact and makes it @@ -97,20 +112,26 @@ def _append_content_item(self, content: Any, sink: List[dict]) -> None: :type content: Any :param sink: The list to append the converted content dict to. :type sink: List[dict] + :param author_name: The author name for the created_by field. + :type author_name: str :return: None :rtype: None """ if isinstance(content, TextContent): - self._append_text_content(content, sink) + self._append_text_content(content, sink, author_name) elif isinstance(content, FunctionCallContent): - self._append_function_call_content(content, sink) + self._append_function_call_content(content, sink, author_name) elif isinstance(content, FunctionResultContent): - self._append_function_result_content(content, sink) + self._append_function_result_content(content, sink, author_name) + elif isinstance(content, UserInputRequestContents): + self._append_user_input_request_contents(content, sink, author_name) + elif isinstance(content, ErrorContent): + raise ValueError(f"ErrorContent received: code={content.error_code}, message={content.message}") else: logger.debug("unsupported content type skipped: %s", type(content).__name__) - def _append_text_content(self, content: TextContent, sink: List[dict]) -> None: + def _append_text_content(self, content: TextContent, sink: List[dict], author_name: str) -> None: text_value = getattr(content, "text", None) if not text_value: return @@ -129,11 +150,12 @@ def _append_text_content(self, content: TextContent, sink: List[dict]) -> None: "logprobs": [], } ], + "created_by": self._build_created_by(author_name), } ) logger.debug(" added message item id=%s text_len=%d", item_id, len(text_value)) - def _append_function_call_content(self, content: FunctionCallContent, sink: List[dict]) -> None: + def _append_function_call_content(self, content: FunctionCallContent, sink: List[dict], author_name: str) -> None: name = getattr(content, "name", "") or "" arguments = getattr(content, "arguments", "") if not isinstance(arguments, str): @@ -151,6 +173,7 @@ def _append_function_call_content(self, content: FunctionCallContent, sink: List "call_id": call_id, "name": name, "arguments": arguments or "", + "created_by": self._build_created_by(author_name), } ) logger.debug( @@ -161,7 +184,12 @@ def _append_function_call_content(self, content: FunctionCallContent, sink: List len(arguments or ""), ) - def _append_function_result_content(self, content: FunctionResultContent, sink: List[dict]) -> None: + def _append_function_result_content( + self, + content: FunctionResultContent, + sink: List[dict], + author_name: str, + ) -> None: # Coerce the function result into a simple display string. result = [] raw = getattr(content, "result", None) @@ -169,7 +197,7 @@ def _append_function_result_content(self, content: FunctionResultContent, sink: result = [raw] elif isinstance(raw, list): for item in raw: - result.append(self._coerce_result_text(item)) # type: ignore + result.append(self._coerce_result_text(item)) # type: ignore call_id = getattr(content, "call_id", None) or "" func_out_id = self._context.id_generator.generate_function_output_id() sink.append( @@ -179,15 +207,42 @@ def _append_function_result_content(self, content: FunctionResultContent, sink: "status": "completed", "call_id": call_id, "output": json.dumps(result) if len(result) > 0 else "", + "created_by": self._build_created_by(author_name), } ) logger.debug( - "added function_call_output item id=%s call_id=%s output_len=%d", + "added function_call_output item id=%s call_id=%s " + "output_len=%d", func_out_id, call_id, len(result), ) + def _append_user_input_request_contents( + self, + content: UserInputRequestContents, + sink: List[dict], + author_name: str, + ) -> None: + item_id = self._context.id_generator.generate_function_call_id() + content = self._hitl_helper.convert_user_input_request_content(content) + sink.append( + { + "id": item_id, + "type": "function_call", + "status": "in_progress", + "call_id": content["call_id"], + "name": content["name"], + "arguments": content["arguments"], + "created_by": self._build_created_by(author_name), + } + ) + logger.debug( + " added user_input_request item id=%s call_id=%s", + item_id, + content["call_id"], + ) + # ------------- simple normalization helper ------------------------- def _coerce_result_text(self, value: Any) -> str | dict: """ diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py index d9bc3199efb5..d506476a4cc9 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/agent_framework_output_streaming_converter.py @@ -1,24 +1,27 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -# pylint: disable=attribute-defined-outside-init,protected-access -# mypy: disable-error-code="call-overload,assignment,arg-type" +# pylint: disable=attribute-defined-outside-init,protected-access,unnecessary-lambda-assignment +# mypy: disable-error-code="call-overload,assignment,arg-type,override" from __future__ import annotations import datetime import json -import uuid -from typing import Any, List, Optional, cast +from typing import Any, AsyncIterable, List, Union -from agent_framework import AgentRunResponseUpdate, FunctionApprovalRequestContent, FunctionResultContent +from agent_framework import ( + AgentRunResponseUpdate, + BaseContent, + FunctionResultContent, +) from agent_framework._types import ( ErrorContent, FunctionCallContent, TextContent, + UserInputRequestContents, ) from azure.ai.agentserver.core import AgentRunContext -from azure.ai.agentserver.core.logger import get_logger from azure.ai.agentserver.core.models import ( Response as OpenAIResponse, ResponseStreamEvent, @@ -27,11 +30,11 @@ FunctionToolCallItemResource, FunctionToolCallOutputItemResource, ItemContentOutputText, + ItemResource, ResponseCompletedEvent, ResponseContentPartAddedEvent, ResponseContentPartDoneEvent, ResponseCreatedEvent, - ResponseErrorEvent, ResponseFunctionCallArgumentsDeltaEvent, ResponseFunctionCallArgumentsDoneEvent, ResponseInProgressEvent, @@ -43,500 +46,409 @@ ) from .agent_id_generator import AgentIdGenerator - -logger = get_logger() +from .human_in_the_loop_helper import HumanInTheLoopHelper +from .utils.async_iter import chunk_on_change, peek class _BaseStreamingState: """Base interface for streaming state handlers.""" - def prework(self, ctx: Any) -> List[ResponseStreamEvent]: # pylint: disable=unused-argument - return [] - - def convert_content(self, ctx: Any, content) -> List[ResponseStreamEvent]: # pylint: disable=unused-argument + async def convert_contents( + self, + contents: AsyncIterable[BaseContent], + author_name: str, + ) -> AsyncIterable[ResponseStreamEvent]: + # pylint: disable=unused-argument raise NotImplementedError - def afterwork(self, ctx: Any) -> List[ResponseStreamEvent]: # pylint: disable=unused-argument - return [] - class _TextContentStreamingState(_BaseStreamingState): """State handler for text and reasoning-text content during streaming.""" - def __init__(self, context: AgentRunContext) -> None: - self.context = context - self.item_id = None - self.output_index = None - self.text_buffer = "" - self.text_part_started = False - - def prework(self, ctx: Any) -> List[ResponseStreamEvent]: - events: List[ResponseStreamEvent] = [] - if self.item_id is not None: - return events - - # Start a new assistant message item (in_progress) - self.item_id = self.context.id_generator.generate_message_id() - self.output_index = ctx._next_output_index # pylint: disable=protected-access - ctx._next_output_index += 1 - - message_item = ResponsesAssistantMessageItemResource( - id=self.item_id, - status="in_progress", - content=[], + def __init__(self, parent: AgentFrameworkOutputStreamingConverter): + self._parent = parent + + async def convert_contents( + self, + contents: AsyncIterable[TextContent], + author_name: str, + ) -> AsyncIterable[ResponseStreamEvent]: + item_id = self._parent.context.id_generator.generate_message_id() + output_index = self._parent.next_output_index() + + yield ResponseOutputItemAddedEvent( + sequence_number=self._parent.next_sequence(), + output_index=output_index, + item=ResponsesAssistantMessageItemResource( + id=item_id, + status="in_progress", + content=[], + created_by=self._parent._build_created_by(author_name), + ), ) - events.append( - ResponseOutputItemAddedEvent( - sequence_number=ctx.next_sequence(), - output_index=self.output_index, - item=message_item, - ) + yield ResponseContentPartAddedEvent( + sequence_number=self._parent.next_sequence(), + item_id=item_id, + output_index=output_index, + content_index=0, + part=ItemContentOutputText(text="", annotations=[], logprobs=[]), ) - if not self.text_part_started: - empty_part = ItemContentOutputText(text="", annotations=[], logprobs=[]) - events.append( - ResponseContentPartAddedEvent( - sequence_number=ctx.next_sequence(), - item_id=self.item_id, - output_index=self.output_index, - content_index=0, - part=empty_part, - ) - ) - self.text_part_started = True - return events - - def convert_content(self, ctx: Any, content: TextContent) -> List[ResponseStreamEvent]: - events: List[ResponseStreamEvent] = [] - if isinstance(content, TextContent): - delta = content.text or "" - else: - delta = getattr(content, "text", None) or getattr(content, "reasoning", "") or "" - - # buffer accumulated text - self.text_buffer += delta - - # emit delta event for text - assert self.item_id is not None, "Text state not initialized: missing item_id" - assert self.output_index is not None, "Text state not initialized: missing output_index" - events.append( - ResponseTextDeltaEvent( - sequence_number=ctx.next_sequence(), - item_id=self.item_id, - output_index=self.output_index, + text = "" + async for content in contents: + delta = content.text + text += delta + + yield ResponseTextDeltaEvent( + sequence_number=self._parent.next_sequence(), + item_id=item_id, + output_index=output_index, content_index=0, delta=delta, ) + + yield ResponseTextDoneEvent( + sequence_number=self._parent.next_sequence(), + item_id=item_id, + output_index=output_index, + content_index=0, + text=text, ) - return events - - def afterwork(self, ctx: Any) -> List[ResponseStreamEvent]: - events: List[ResponseStreamEvent] = [] - if not self.item_id: - return events - - full_text = self.text_buffer - assert self.item_id is not None and self.output_index is not None - events.append( - ResponseTextDoneEvent( - sequence_number=ctx.next_sequence(), - item_id=self.item_id, - output_index=self.output_index, - content_index=0, - text=full_text, - ) - ) - final_part = ItemContentOutputText(text=full_text, annotations=[], logprobs=[]) - events.append( - ResponseContentPartDoneEvent( - sequence_number=ctx.next_sequence(), - item_id=self.item_id, - output_index=self.output_index, - content_index=0, - part=final_part, - ) - ) - completed_item = ResponsesAssistantMessageItemResource( - id=self.item_id, status="completed", content=[final_part] + + content_part = ItemContentOutputText(text=text, annotations=[], logprobs=[]) + yield ResponseContentPartDoneEvent( + sequence_number=self._parent.next_sequence(), + item_id=item_id, + output_index=output_index, + content_index=0, + part=content_part, ) - events.append( - ResponseOutputItemDoneEvent( - sequence_number=ctx.next_sequence(), - output_index=self.output_index, - item=completed_item, - ) + + item = ResponsesAssistantMessageItemResource( + id=item_id, + status="completed", + content=[content_part], + created_by=self._parent._build_created_by(author_name), ) - ctx._last_completed_text = full_text # pylint: disable=protected-access - # store for final response - ctx._completed_output_items.append( - { - "id": self.item_id, - "type": "message", - "status": "completed", - "content": [ - { - "type": "output_text", - "text": full_text, - "annotations": [], - "logprobs": [], - } - ], - "role": "assistant", - } + yield ResponseOutputItemDoneEvent( + sequence_number=self._parent.next_sequence(), + output_index=output_index, + item=item, ) - # reset state - self.item_id = None - self.output_index = None - self.text_buffer = "" - self.text_part_started = False - return events + + self._parent.add_completed_output_item(item) # pylint: disable=protected-access class _FunctionCallStreamingState(_BaseStreamingState): """State handler for function_call content during streaming.""" - def __init__(self, context: AgentRunContext) -> None: - self.context = context - self.item_id = None - self.output_index = None - self.call_id = None - self.name = None - self.args_buffer = "" - self.requires_approval = False - self.approval_request_id: str | None = None - - def prework(self, ctx: Any) -> List[ResponseStreamEvent]: - events: List[ResponseStreamEvent] = [] - if self.item_id is not None: - return events - # initialize function-call item - self.item_id = self.context.id_generator.generate_function_call_id() - self.output_index = ctx._next_output_index - ctx._next_output_index += 1 - - self.call_id = self.call_id or str(uuid.uuid4()) - function_item = FunctionToolCallItemResource( - id=self.item_id, - status="in_progress", - call_id=self.call_id, - name=self.name or "", - arguments="", - ) - events.append( - ResponseOutputItemAddedEvent( - sequence_number=ctx.next_sequence(), - output_index=self.output_index, - item=function_item, + def __init__(self, + parent: AgentFrameworkOutputStreamingConverter, + hitl_helper: HumanInTheLoopHelper): + self._parent = parent + self._hitl_helper = hitl_helper + + async def convert_contents( + self, contents: AsyncIterable[Union[FunctionCallContent, UserInputRequestContents]], author_name: str + ) -> AsyncIterable[ResponseStreamEvent]: + content_by_call_id = {} + ids_by_call_id = {} + hitl_contents = [] + + async for content in contents: + if isinstance(content, FunctionCallContent): + if content.call_id not in content_by_call_id: + item_id = self._parent.context.id_generator.generate_function_call_id() + output_index = self._parent.next_output_index() + + content_by_call_id[content.call_id] = content + ids_by_call_id[content.call_id] = (item_id, output_index) + + yield ResponseOutputItemAddedEvent( + sequence_number=self._parent.next_sequence(), + output_index=output_index, + item=FunctionToolCallItemResource( + id=item_id, + status="in_progress", + call_id=content.call_id, + name=content.name, + arguments="", + created_by=self._parent._build_created_by(author_name), + ), + ) + else: + content_by_call_id[content.call_id] = content_by_call_id[content.call_id] + content + item_id, output_index = ids_by_call_id[content.call_id] + + args_delta = content.arguments if isinstance(content.arguments, str) else "" + yield ResponseFunctionCallArgumentsDeltaEvent( + sequence_number=self._parent.next_sequence(), + item_id=item_id, + output_index=output_index, + delta=args_delta, + ) + + elif isinstance(content, UserInputRequestContents): + converted_hitl = self._hitl_helper.convert_user_input_request_content(content) + if converted_hitl: + hitl_contents.append(converted_hitl) + + for call_id, content in content_by_call_id.items(): + item_id, output_index = ids_by_call_id[call_id] + args = self._serialize_arguments(content.arguments) + yield ResponseFunctionCallArgumentsDoneEvent( + sequence_number=self._parent.next_sequence(), + item_id=item_id, + output_index=output_index, + arguments=args, ) - ) - return events - - def convert_content(self, ctx: Any, content: FunctionCallContent) -> List[ResponseStreamEvent]: - events: List[ResponseStreamEvent] = [] - # record identifiers (once available) - self.name = getattr(content, "name", None) or self.name or "" - self.call_id = getattr(content, "call_id", None) or self.call_id or str(uuid.uuid4()) - - args_delta = content.arguments if isinstance(content.arguments, str) else json.dumps(content.arguments) - args_delta = args_delta or "" - self.args_buffer += args_delta - assert self.item_id is not None and self.output_index is not None - for ch in args_delta: - events.append( - ResponseFunctionCallArgumentsDeltaEvent( - sequence_number=ctx.next_sequence(), - item_id=self.item_id, - output_index=self.output_index, - delta=ch, - ) + + item = FunctionToolCallItemResource( + id=item_id, + status="completed", + call_id=call_id, + name=content.name, + arguments=args, + created_by=self._parent._build_created_by(author_name), + ) + yield ResponseOutputItemDoneEvent( + sequence_number=self._parent.next_sequence(), + output_index=output_index, + item=item, ) - # finalize if arguments are detected to be complete - is_done = bool( - getattr(content, "is_final", False) - or getattr(content, "final", False) - or getattr(content, "done", False) - or getattr(content, "arguments_final", False) - or getattr(content, "arguments_done", False) - or getattr(content, "finish", False) - ) - if not is_done and self.args_buffer: - try: - json.loads(self.args_buffer) - is_done = True - except Exception: # pylint: disable=broad-exception-caught - pass - - if is_done: - events.append( - ResponseFunctionCallArgumentsDoneEvent( - sequence_number=ctx.next_sequence(), - item_id=self.item_id, - output_index=self.output_index, - arguments=self.args_buffer, - ) + self._parent.add_completed_output_item(item) # pylint: disable=protected-access + + # process HITL contents after function calls + for content in hitl_contents: + item_id = self._parent.context.id_generator.generate_function_call_id() + output_index = self._parent.next_output_index() + + yield ResponseOutputItemAddedEvent( + sequence_number=self._parent.next_sequence(), + output_index=output_index, + item=FunctionToolCallItemResource( + id=item_id, + status="in_progress", + call_id=content["call_id"], + name=content["name"], + arguments="", + created_by=self._parent._build_created_by(author_name), + ), ) - events.extend(self.afterwork(ctx)) - return events - - def afterwork(self, ctx: Any) -> List[ResponseStreamEvent]: - events: List[ResponseStreamEvent] = [] - if not self.item_id: - return events - assert self.call_id is not None - done_item = FunctionToolCallItemResource( - id=self.item_id, - status="completed", - call_id=self.call_id, - name=self.name or "", - arguments=self.args_buffer, - ) - assert self.output_index is not None - events.append( - ResponseOutputItemDoneEvent( - sequence_number=ctx.next_sequence(), - output_index=self.output_index, - item=done_item, + yield ResponseFunctionCallArgumentsDeltaEvent( + sequence_number=self._parent.next_sequence(), + item_id=item_id, + output_index=output_index, + delta=content["arguments"], ) - ) - # store for final response - ctx._completed_output_items.append( - { - "id": self.item_id, - "type": "function_call", - "call_id": self.call_id, - "name": self.name or "", - "arguments": self.args_buffer, - "status": "requires_approval" if self.requires_approval else "completed", - "requires_approval": self.requires_approval, - "approval_request_id": self.approval_request_id, - } - ) - # reset - self.item_id = None - self.output_index = None - self.args_buffer = "" - self.call_id = None - self.name = None - self.requires_approval = False - self.approval_request_id = None - return events + + yield ResponseFunctionCallArgumentsDoneEvent( + sequence_number=self._parent.next_sequence(), + item_id=item_id, + output_index=output_index, + arguments=content["arguments"], + ) + item = FunctionToolCallItemResource( + id=item_id, + status="in_progress", + call_id=content["call_id"], + name=content["name"], + arguments=content["arguments"], + created_by=self._parent._build_created_by(author_name), + ) + yield ResponseOutputItemDoneEvent( + sequence_number=self._parent.next_sequence(), + output_index=output_index, + item=item, + ) + self._parent.add_completed_output_item(item) + + def _serialize_arguments(self, arguments: Any) -> str: + if isinstance(arguments, str): + return arguments + try: + return json.dumps(arguments) + except Exception: # pylint: disable=broad-exception-caught + return str(arguments) class _FunctionCallOutputStreamingState(_BaseStreamingState): """Handles function_call_output items streaming (non-chunked simple output).""" - def __init__( - self, - context: AgentRunContext, - call_id: Optional[str] = None, - output: Optional[list[str]] = None, - ) -> None: - # Avoid mutable default argument (Ruff B006) - self.context = context - self.item_id = None - self.output_index = None - self.call_id = call_id - self.output = output if output is not None else [] - - def prework(self, ctx: Any) -> List[ResponseStreamEvent]: - events: List[ResponseStreamEvent] = [] - if self.item_id is not None: - return events - self.item_id = self.context.id_generator.generate_function_output_id() - self.output_index = ctx._next_output_index - ctx._next_output_index += 1 - - self.call_id = self.call_id or str(uuid.uuid4()) - item = FunctionToolCallOutputItemResource( - id=self.item_id, - status="in_progress", - call_id=self.call_id, - output="", - ) - events.append( - ResponseOutputItemAddedEvent( - sequence_number=ctx.next_sequence(), - output_index=self.output_index, - item=item, + def __init__(self, parent: AgentFrameworkOutputStreamingConverter): + self._parent = parent + + async def convert_contents( + self, contents: AsyncIterable[FunctionResultContent], author_name: str + ) -> AsyncIterable[ResponseStreamEvent]: + async for content in contents: + item_id = self._parent.context.id_generator.generate_function_output_id() + output_index = self._parent.next_output_index() + + output = (f"{type(content.exception)}({str(content.exception)})" + if content.exception + else self._to_output(content.result)) + + item = FunctionToolCallOutputItemResource( + id=item_id, + status="completed", + call_id=content.call_id, + output=output, + created_by=self._parent._build_created_by(author_name), ) - ) - return events - - def convert_content(self, ctx: Any, content: Any) -> List[ResponseStreamEvent]: # no delta events for now - events: List[ResponseStreamEvent] = [] - # treat entire output as final - result = [] - raw = getattr(content, "result", None) - if isinstance(raw, str): - result = [raw or self.output] - elif isinstance(raw, list): - for item in raw: - result.append(self._coerce_result_text(item)) - self.output = json.dumps(result) if len(result) > 0 else "" - - events.extend(self.afterwork(ctx)) - return events - - def _coerce_result_text(self, value: Any) -> str | dict: - """ - Return a string if value is already str or a TextContent-like object; else str(value). - - :param value: The value to coerce. - :type value: Any - - :return: The coerced string or dict. - :rtype: str | dict - """ - if value is None: - return "" - if isinstance(value, str): - return value - # Direct TextContent instance - if isinstance(value, TextContent): - content_payload = {"type": "text", "text": getattr(value, "text", "")} - return content_payload - return "" + yield ResponseOutputItemAddedEvent( + sequence_number=self._parent.next_sequence(), + output_index=output_index, + item=item, + ) - def afterwork(self, ctx: Any) -> List[ResponseStreamEvent]: - events: List[ResponseStreamEvent] = [] - if not self.item_id: - return events - # Ensure types conform: call_id must be str (guarantee non-None) and output is a single string - str_call_id = self.call_id or "" - single_output: str = cast(str, self.output[0]) if self.output else "" - done_item = FunctionToolCallOutputItemResource( - id=self.item_id, - status="completed", - call_id=str_call_id, - output=single_output, - ) - assert self.output_index is not None - events.append( - ResponseOutputItemDoneEvent( - sequence_number=ctx.next_sequence(), - output_index=self.output_index, - item=done_item, + yield ResponseOutputItemDoneEvent( + sequence_number=self._parent.next_sequence(), + output_index=output_index, + item=item, ) - ) - ctx._completed_output_items.append( - { - "id": self.item_id, - "type": "function_call_output", - "status": "completed", - "call_id": self.call_id, - "output": self.output, - } - ) - self.item_id = None - self.output_index = None - return events + + self._parent.add_completed_output_item(item) # pylint: disable=protected-access + + @classmethod + def _to_output(cls, result: Any) -> str: + if isinstance(result, str): + return result + if isinstance(result, list): + text = [] + for item in result: + if isinstance(item, BaseContent): + text.append(item.to_dict()) + else: + text.append(str(item)) + return json.dumps(text) + return "" class AgentFrameworkOutputStreamingConverter: """Streaming converter using content-type-specific state handlers.""" - def __init__(self, context: AgentRunContext) -> None: + def __init__(self, context: AgentRunContext, *, hitl_helper: HumanInTheLoopHelper=None) -> None: self._context = context # sequence numbers must start at 0 for first emitted event - self._sequence = 0 - self._response_id = None + self._sequence = -1 + self._next_output_index = -1 + self._response_id = self._context.response_id self._response_created_at = None - self._next_output_index = 0 - self._last_completed_text = "" - self._active_state: Optional[_BaseStreamingState] = None - self._active_kind = None # "text" | "function_call" | "error" - # accumulate completed output items for final response - self._completed_output_items: List[dict] = [] - - def _ensure_response_started(self) -> None: - if not self._response_id: - self._response_id = self._context.response_id - if not self._response_created_at: - self._response_created_at = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) + self._completed_output_items: List[ItemResource] = [] + self._hitl_helper = hitl_helper def next_sequence(self) -> int: self._sequence += 1 return self._sequence - def _switch_state(self, kind: str) -> List[ResponseStreamEvent]: - events: List[ResponseStreamEvent] = [] - if self._active_state and self._active_kind != kind: - events.extend(self._active_state.afterwork(self)) - self._active_state = None - self._active_kind = None - - if self._active_state is None: - if kind == "text": - self._active_state = _TextContentStreamingState(self._context) - elif kind == "function_call": - self._active_state = _FunctionCallStreamingState(self._context) - elif kind == "function_call_output": - self._active_state = _FunctionCallOutputStreamingState(self._context) - else: - self._active_state = None - self._active_kind = kind - if self._active_state: - events.extend(self._active_state.prework(self)) - return events - - def transform_output_for_streaming(self, update: AgentRunResponseUpdate) -> List[ResponseStreamEvent]: - logger.debug( - "Transforming streaming update with %d contents", - len(update.contents) if getattr(update, "contents", None) else 0, - ) + def next_output_index(self) -> int: + self._next_output_index += 1 + return self._next_output_index + + def add_completed_output_item(self, item: ItemResource) -> None: + self._completed_output_items.append(item) + + @property + def context(self) -> AgentRunContext: + return self._context + + async def convert(self, updates: AsyncIterable[AgentRunResponseUpdate]) -> AsyncIterable[ResponseStreamEvent]: self._ensure_response_started() - events: List[ResponseStreamEvent] = [] - - if getattr(update, "contents", None): - for i, content in enumerate(update.contents): - logger.debug("Processing content %d: %s", i, type(content)) - if isinstance(content, TextContent): - events.extend(self._switch_state("text")) - if isinstance(self._active_state, _TextContentStreamingState): - events.extend(self._active_state.convert_content(self, content)) - elif isinstance(content, FunctionCallContent): - events.extend(self._switch_state("function_call")) - if isinstance(self._active_state, _FunctionCallStreamingState): - events.extend(self._active_state.convert_content(self, content)) - elif isinstance(content, FunctionResultContent): - events.extend(self._switch_state("function_call_output")) - if isinstance(self._active_state, _FunctionCallOutputStreamingState): - call_id = getattr(content, "call_id", None) - if call_id: - self._active_state.call_id = call_id - events.extend(self._active_state.convert_content(self, content)) - elif isinstance(content, FunctionApprovalRequestContent): - events.extend(self._switch_state("function_call")) - if isinstance(self._active_state, _FunctionCallStreamingState): - self._active_state.requires_approval = True - self._active_state.approval_request_id = getattr(content, "id", None) - events.extend(self._active_state.convert_content(self, content.function_call)) - elif isinstance(content, ErrorContent): - # errors are stateless; flush current state and emit error - events.extend(self._switch_state("error")) - events.append( - ResponseErrorEvent( - sequence_number=self.next_sequence(), - code=getattr(content, "error_code", None) or "server_error", - message=getattr(content, "message", None) or "An error occurred", - param="", - ) - ) - return events - def finalize_last_content(self) -> List[ResponseStreamEvent]: - events: List[ResponseStreamEvent] = [] - if self._active_state: - events.extend(self._active_state.afterwork(self)) - self._active_state = None - self._active_kind = None - return events + created_response = self._build_response(status="in_progress") + yield ResponseCreatedEvent( + sequence_number=self.next_sequence(), + response=created_response, + ) + + yield ResponseInProgressEvent( + sequence_number=self.next_sequence(), + response=created_response, + ) + + is_changed = ( + lambda a, b: a is not None \ + and b is not None \ + and a.message_id != b.message_id + ) + + async for group in chunk_on_change(updates, is_changed): + has_value, first_tuple, contents_with_author = await peek(self._read_updates(group)) + if not has_value or first_tuple is None: + continue + + first, author_name = first_tuple # Extract content and author_name from tuple + + state = None + if isinstance(first, TextContent): + state = _TextContentStreamingState(self) + elif isinstance(first, (FunctionCallContent, UserInputRequestContents)): + state = _FunctionCallStreamingState(self, self._hitl_helper) + elif isinstance(first, FunctionResultContent): + state = _FunctionCallOutputStreamingState(self) + elif isinstance(first, ErrorContent): + error_msg = ( + f"ErrorContent received: code={first.error_code}, " + f"message={first.message}" + ) + raise ValueError(error_msg) + if not state: + continue + + # Extract just the content from (content, author_name) tuples using async generator + async def extract_contents(): + async for content, _ in contents_with_author: # pylint: disable=cell-var-from-loop + yield content + + async for content in state.convert_contents(extract_contents(), author_name): + yield content - def build_response(self, status: str) -> OpenAIResponse: + yield ResponseCompletedEvent( + sequence_number=self.next_sequence(), + response=self._build_response(status="completed"), + ) + + def _build_created_by(self, author_name: str) -> dict: + agent_dict = { + "type": "agent_id", + "name": author_name or "", + "version": "", + } + + return { + "agent": agent_dict, + "response_id": self._response_id, + } + + async def _read_updates( + self, + updates: AsyncIterable[AgentRunResponseUpdate], + ) -> AsyncIterable[tuple[BaseContent, str]]: + async for update in updates: + if not update.contents: + continue + + # Extract author_name from each update + author_name = getattr(update, "author_name", "") or "" + + accepted_types = (TextContent, + FunctionCallContent, + UserInputRequestContents, + FunctionResultContent, + ErrorContent) + for content in update.contents: + if isinstance(content, accepted_types): + yield (content, author_name) + + def _ensure_response_started(self) -> None: + if not self._response_created_at: + self._response_created_at = int(datetime.datetime.now(datetime.timezone.utc).timestamp()) + + def _build_response(self, status: str) -> OpenAIResponse: self._ensure_response_started() agent_id = AgentIdGenerator.generate(self._context) response_data = { @@ -545,51 +457,10 @@ def build_response(self, status: str) -> OpenAIResponse: "id": self._response_id, "status": status, "created_at": self._response_created_at, + "conversation": self._context.get_conversation_object(), } - if status == "completed" and self._completed_output_items: + + # set output even if _completed_output_items is empty, never leave the output as null + if status == "completed": response_data["output"] = self._completed_output_items return OpenAIResponse(response_data) - - # High-level helpers to emit lifecycle events for streaming - def initial_events(self) -> List[ResponseStreamEvent]: - """ - Emit ResponseCreatedEvent and an initial ResponseInProgressEvent. - - :return: List of initial response stream events. - :rtype: List[ResponseStreamEvent] - """ - self._ensure_response_started() - events: List[ResponseStreamEvent] = [] - created_response = self.build_response(status="in_progress") - events.append( - ResponseCreatedEvent( - sequence_number=self.next_sequence(), - response=created_response, - ) - ) - events.append( - ResponseInProgressEvent( - sequence_number=self.next_sequence(), - response=self.build_response(status="in_progress"), - ) - ) - return events - - def completion_events(self) -> List[ResponseStreamEvent]: - """ - Finalize any active content and emit a single ResponseCompletedEvent. - - :return: List of completion response stream events. - :rtype: List[ResponseStreamEvent] - """ - self._ensure_response_started() - events: List[ResponseStreamEvent] = [] - events.extend(self.finalize_last_content()) - completed_response = self.build_response(status="completed") - events.append( - ResponseCompletedEvent( - sequence_number=self.next_sequence(), - response=completed_response, - ) - ) - return events diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py new file mode 100644 index 000000000000..82c959612faa --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/human_in_the_loop_helper.py @@ -0,0 +1,121 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Any, List, Dict, Optional, Union +import json + +from agent_framework import ( + ChatMessage, + FunctionResultContent, + FunctionApprovalResponseContent, + RequestInfoEvent, + WorkflowCheckpoint, +) +from agent_framework._types import UserInputRequestContents + +from azure.ai.agentserver.core.logger import get_logger +from azure.ai.agentserver.core.server.common.constants import HUMAN_IN_THE_LOOP_FUNCTION_NAME + +logger = get_logger() + +class HumanInTheLoopHelper: + + def get_pending_hitl_request(self, + thread_messages: List[ChatMessage] = None, + checkpoint: Optional[WorkflowCheckpoint] = None, + ) -> dict[str, Union[RequestInfoEvent, Any]]: + res = {} + # if has checkpoint (WorkflowAgent), find pending request info from checkpoint + if checkpoint and checkpoint.pending_request_info_events: + for call_id, request in checkpoint.pending_request_info_events.items(): + # find if the request is already responded in the thread messages + if isinstance(request, dict): + request_obj = RequestInfoEvent.from_dict(request) + res[call_id] = request_obj + return res + + if not thread_messages: + return res + + # if no checkpoint (Agent), find user input request and pair the feedbacks + for message in thread_messages: + for content in message.contents: + if isinstance(content, UserInputRequestContents): + # is a human input request + function_call = content.function_call + call_id = getattr(function_call, "call_id", "") + if call_id: + res[call_id] = RequestInfoEvent( + source_executor_id="agent", + request_id=call_id, + response_type=None, + request_data=function_call, + ) + elif isinstance(content, FunctionResultContent): + if content.call_id and content.call_id in res: + # remove requests that already got feedback + res.pop(content.call_id) + elif isinstance(content, FunctionApprovalResponseContent): + function_call = content.function_call + call_id = getattr(function_call, "call_id", "") + if call_id and call_id in res: + res.pop(call_id) + return res + + def convert_user_input_request_content(self, content: UserInputRequestContents) -> dict: + function_call = content.function_call + call_id = getattr(function_call, "call_id", "") + arguments = self.convert_request_arguments(getattr(function_call, "arguments", "")) + return { + "call_id": call_id, + "name": HUMAN_IN_THE_LOOP_FUNCTION_NAME, + "arguments": arguments or "", + } + + def convert_request_arguments(self, arguments: Any) -> str: + # convert data to payload if possible + if isinstance(arguments, dict): + data = arguments.get("data") + if data and hasattr(data, "convert_to_payload"): + return data.convert_to_payload() + + if not isinstance(arguments, str): + if hasattr(arguments, "to_dict"): # agentframework models have to_dict method + arguments = arguments.to_dict() + try: + arguments = json.dumps(arguments) + except Exception: # pragma: no cover - fallback # pylint: disable=broad-exception-caught + arguments = str(arguments) + return arguments + + def validate_and_convert_hitl_response(self, + input: Union[str, List[Dict], None], + pending_requests: Dict[str, RequestInfoEvent], + ) -> Optional[List[ChatMessage]]: + + if input is None or isinstance(input, str): + logger.warning("Expected list input for HitL response validation, got str.") + return None + + res = [] + for item in input: + if item.get("type") != "function_call_output": + logger.warning("Expected function_call_output type for HitL response validation.") + return None + call_id = item.get("call_id", None) + if call_id and call_id in pending_requests: + res.append(self.convert_response(pending_requests[call_id], item)) + return res + + def convert_response(self, hitl_request: RequestInfoEvent, input: Dict) -> ChatMessage: + response_type = hitl_request.response_type + response_result = input.get("output", "") + logger.info(f"response_type {type(response_type)}: %s", response_type) + if response_type and hasattr(response_type, "convert_from_payload"): + response_result = response_type.convert_from_payload(input.get("output", "")) + logger.info(f"response_result {type(response_result)}: %s", response_result) + response_content = FunctionResultContent( + call_id=hitl_request.request_id, + result=response_result, + ) + return ChatMessage(role="tool", contents=[response_content]) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/utils/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/utils/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/utils/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/utils/async_iter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/utils/async_iter.py new file mode 100644 index 000000000000..fdf3b2fbb2a3 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/models/utils/async_iter.py @@ -0,0 +1,145 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from __future__ import annotations + +from collections.abc import AsyncIterable, AsyncIterator, Callable +from typing import TypeVar, Optional, Tuple + +TSource = TypeVar("TSource") +TKey = TypeVar("TKey") +T = TypeVar("T") + + +async def chunk_on_change( + source: AsyncIterable[TSource], + is_changed: Optional[Callable[[Optional[TSource], Optional[TSource]], bool]] = None, +) -> AsyncIterator[AsyncIterable[TSource]]: + """ + Chunks an async iterable into groups based on when consecutive elements change. + + :param source: Async iterable of items. + :type source: AsyncIterable[TSource] + :param is_changed: Function(prev, current) -> bool indicating if value changed. + If None, uses != by default. + :type is_changed: Optional[Callable[[Optional[TSource], Optional[TSource]], bool]] + :return: An async iterator of async iterables (chunks). + :rtype: AsyncIterator[AsyncIterable[TSource]] + """ + + if is_changed is None: + # Default equality: use the value itself as key, == as equality + async for group in chunk_by_key(source, lambda x: x): + yield group + else: + # Equivalent to C#: EqualityComparer.Create((x, y) => !isChanged(x, y)) + def key_equal(a: TSource, b: TSource) -> bool: + return not is_changed(a, b) + + async for group in chunk_by_key(source, lambda x: x, key_equal=key_equal): + yield group + + +async def chunk_by_key( + source: AsyncIterable[TSource], + key_selector: Callable[[TSource], TKey], + key_equal: Optional[Callable[[TKey, TKey], bool]] = None, +) -> AsyncIterator[AsyncIterable[TSource]]: + """ + Chunks the async iterable into groups based on a key selector. + + :param source: Async iterable of items. + :type source: AsyncIterable[TSource] + :param key_selector: Function mapping item -> key. + :type key_selector: Callable[[TSource], TKey] + :param key_equal: Optional equality function for keys. Defaults to '=='. + :type key_equal: Optional[Callable[[TKey, TKey], bool]] + :return: An async iterator of async iterables (chunks). + :rtype: AsyncIterator[AsyncIterable[TSource]] + """ + + if key_equal is None: + def key_equal(a: TKey, b: TKey) -> bool: # type: ignore[no-redef] + return a == b + + it = source.__aiter__() + + # Prime the iterator + try: + pending = await it.__anext__() + except StopAsyncIteration: + return + + pending_key = key_selector(pending) + has_pending = True + + while has_pending: + current_key = pending_key + + async def inner() -> AsyncIterator[TSource]: + nonlocal pending, pending_key, has_pending + + # First element of the group + yield pending + + # Consume until key changes or source ends + while True: + try: + item = await it.__anext__() + except StopAsyncIteration: + # Source ended; tell outer loop to stop after this group + has_pending = False + return + + k = key_selector(item) + if not key_equal(k, current_key): + # Hand first item of next group back to outer loop + pending = item + pending_key = k + return + + yield item + + # Yield an async iterable representing the current chunk + yield inner() + + +async def peek( + source: AsyncIterable[T], +) -> Tuple[bool, Optional[T], AsyncIterable[T]]: + """ + Peeks at the first element of an async iterable without consuming it. + + :param source: Async iterable. + :type source: AsyncIterable[T] + :return: (has_value, first, full_sequence_including_first) + :rtype: Tuple[bool, Optional[T], AsyncIterable[T]] + """ + + it = source.__aiter__() + + try: + first = await it.__anext__() + except StopAsyncIteration: + return False, None, _empty_async() + + async def sequence() -> AsyncIterator[T]: + try: + # Yield the peeked element first + yield first + # Then the rest of the original iterator + async for item in it: + yield item + finally: + # Try to close underlying async generator if it supports it + aclose = getattr(it, "aclose", None) + if aclose is not None: + await aclose() + + return True, first, sequence() + + +async def _empty_async() -> AsyncIterator[T]: + if False: # pylint: disable=using-constant-test + # This is just to make this an async generator for typing + yield None # type: ignore[misc] diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/__init__.py new file mode 100644 index 000000000000..d59d1154650c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/__init__.py @@ -0,0 +1,25 @@ +from .agent_thread_repository import ( + AgentThreadRepository, + InMemoryAgentThreadRepository, + SerializedAgentThreadRepository, + JsonLocalFileAgentThreadRepository, +) +from .checkpoint_repository import ( + CheckpointRepository, + InMemoryCheckpointRepository, + FileCheckpointRepository, +) +from ._foundry_checkpoint_storage import FoundryCheckpointStorage +from ._foundry_checkpoint_repository import FoundryCheckpointRepository + +__all__ = [ + "AgentThreadRepository", + "InMemoryAgentThreadRepository", + "SerializedAgentThreadRepository", + "JsonLocalFileAgentThreadRepository", + "CheckpointRepository", + "InMemoryCheckpointRepository", + "FileCheckpointRepository", + "FoundryCheckpointStorage", + "FoundryCheckpointRepository", +] diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_checkpoint_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_checkpoint_repository.py new file mode 100644 index 000000000000..96b80615445e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_checkpoint_repository.py @@ -0,0 +1,99 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Foundry-backed checkpoint repository implementation.""" + +import logging +from typing import Dict, Optional, Union + +from agent_framework import CheckpointStorage +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential + +from azure.ai.agentserver.core.checkpoints.client import ( + CheckpointSession, + FoundryCheckpointClient, +) + +from .checkpoint_repository import CheckpointRepository +from ._foundry_checkpoint_storage import FoundryCheckpointStorage + +logger = logging.getLogger(__name__) + + +class FoundryCheckpointRepository(CheckpointRepository): + """Repository that creates FoundryCheckpointStorage instances per conversation. + + Manages checkpoint sessions on the Foundry backend, creating sessions on demand + and caching storage instances for reuse within the same process. + + :param project_endpoint: The Azure AI Foundry project endpoint URL. + Example: "https://.services.ai.azure.com/api/projects/" + :type project_endpoint: str + :param credential: Credential for authentication. + :type credential: Union[AsyncTokenCredential, TokenCredential] + """ + + def __init__( + self, + project_endpoint: str, + credential: Union[AsyncTokenCredential, TokenCredential], + ) -> None: + """Initialize the Foundry checkpoint repository. + + :param project_endpoint: The Azure AI Foundry project endpoint URL. + :type project_endpoint: str + :param credential: Credential for authentication. + :type credential: Union[AsyncTokenCredential, TokenCredential] + """ + # Convert sync credential to async if needed + if not isinstance(credential, AsyncTokenCredential): + # For now, we require async credentials + raise TypeError( + "FoundryCheckpointRepository requires an AsyncTokenCredential. " + "Please use an async credential like DefaultAzureCredential." + ) + + self._client = FoundryCheckpointClient(project_endpoint, credential) + self._inventory: Dict[str, CheckpointStorage] = {} + + async def get_or_create( + self, conversation_id: Optional[str] + ) -> Optional[CheckpointStorage]: + """Get or create a checkpoint storage for the given conversation. + + :param conversation_id: The conversation ID (maps to session_id on backend). + :type conversation_id: Optional[str] + :return: CheckpointStorage instance or None if no conversation_id. + :rtype: Optional[CheckpointStorage] + """ + if not conversation_id: + return None + + if conversation_id not in self._inventory: + # Ensure session exists on backend + await self._ensure_session(conversation_id) + self._inventory[conversation_id] = FoundryCheckpointStorage( + client=self._client, + session_id=conversation_id, + ) + logger.debug( + "Created FoundryCheckpointStorage for conversation %s", + conversation_id, + ) + + return self._inventory[conversation_id] + + async def _ensure_session(self, session_id: str) -> None: + """Ensure a session exists on the backend, creating if needed. + + :param session_id: The session identifier. + :type session_id: str + """ + session = CheckpointSession(session_id=session_id) + await self._client.upsert_session(session) + logger.debug("Ensured session %s exists on Foundry", session_id) + + async def close(self) -> None: + """Close the underlying HTTP client.""" + await self._client.close() diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_checkpoint_storage.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_checkpoint_storage.py new file mode 100644 index 000000000000..65e5b2f7b3c4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/_foundry_checkpoint_storage.py @@ -0,0 +1,182 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Foundry-backed checkpoint storage implementation.""" + +import json +import logging +from typing import List, Optional + +from agent_framework import WorkflowCheckpoint + +from azure.ai.agentserver.core.checkpoints.client import ( + CheckpointItem, + CheckpointItemId, + FoundryCheckpointClient, +) + +logger = logging.getLogger(__name__) + + +class FoundryCheckpointStorage: + """CheckpointStorage implementation backed by Azure AI Foundry. + + Implements the agent_framework.CheckpointStorage protocol by delegating + to the FoundryCheckpointClient HTTP client for remote storage. + + :param client: The Foundry checkpoint client. + :type client: FoundryCheckpointClient + :param session_id: The session identifier (maps to conversation_id). + :type session_id: str + """ + + def __init__(self, client: FoundryCheckpointClient, session_id: str) -> None: + """Initialize the Foundry checkpoint storage. + + :param client: The Foundry checkpoint client. + :type client: FoundryCheckpointClient + :param session_id: The session identifier. + :type session_id: str + """ + self._client = client + self._session_id = session_id + + async def save_checkpoint(self, checkpoint: WorkflowCheckpoint) -> str: + """Save a checkpoint and return its ID. + + :param checkpoint: The workflow checkpoint to save. + :type checkpoint: WorkflowCheckpoint + :return: The checkpoint ID. + :rtype: str + """ + serialized = self._serialize_checkpoint(checkpoint) + item = CheckpointItem( + session_id=self._session_id, + item_id=checkpoint.checkpoint_id, + data=serialized, + parent_id=None, + ) + result = await self._client.create_items([item]) + if result: + logger.debug( + "Saved checkpoint %s to Foundry session %s", + checkpoint.checkpoint_id, + self._session_id, + ) + return result[0].item_id + return checkpoint.checkpoint_id + + async def load_checkpoint(self, checkpoint_id: str) -> Optional[WorkflowCheckpoint]: + """Load a checkpoint by ID. + + :param checkpoint_id: The checkpoint identifier. + :type checkpoint_id: str + :return: The workflow checkpoint if found, None otherwise. + :rtype: Optional[WorkflowCheckpoint] + """ + item_id = CheckpointItemId( + session_id=self._session_id, + item_id=checkpoint_id, + ) + item = await self._client.read_item(item_id) + if item is None: + return None + checkpoint = self._deserialize_checkpoint(item.data) + logger.debug( + "Loaded checkpoint %s from Foundry session %s", + checkpoint_id, + self._session_id, + ) + return checkpoint + + async def list_checkpoint_ids( + self, workflow_id: Optional[str] = None + ) -> List[str]: + """List checkpoint IDs. + + If workflow_id is provided, filter by that workflow. + + :param workflow_id: Optional workflow identifier for filtering. + :type workflow_id: Optional[str] + :return: List of checkpoint identifiers. + :rtype: List[str] + """ + item_ids = await self._client.list_item_ids(self._session_id) + ids = [item_id.item_id for item_id in item_ids] + + # Filter by workflow_id if provided + if workflow_id is not None: + filtered_ids = [] + for checkpoint_id in ids: + checkpoint = await self.load_checkpoint(checkpoint_id) + if checkpoint and checkpoint.workflow_id == workflow_id: + filtered_ids.append(checkpoint_id) + return filtered_ids + + return ids + + async def list_checkpoints( + self, workflow_id: Optional[str] = None + ) -> List[WorkflowCheckpoint]: + """List checkpoint objects. + + If workflow_id is provided, filter by that workflow. + + :param workflow_id: Optional workflow identifier for filtering. + :type workflow_id: Optional[str] + :return: List of workflow checkpoints. + :rtype: List[WorkflowCheckpoint] + """ + ids = await self.list_checkpoint_ids(workflow_id=None) + checkpoints: List[WorkflowCheckpoint] = [] + + for checkpoint_id in ids: + checkpoint = await self.load_checkpoint(checkpoint_id) + if checkpoint is not None: + if workflow_id is None or checkpoint.workflow_id == workflow_id: + checkpoints.append(checkpoint) + + return checkpoints + + async def delete_checkpoint(self, checkpoint_id: str) -> bool: + """Delete a checkpoint by ID. + + :param checkpoint_id: The checkpoint identifier. + :type checkpoint_id: str + :return: True if the checkpoint was deleted, False otherwise. + :rtype: bool + """ + item_id = CheckpointItemId( + session_id=self._session_id, + item_id=checkpoint_id, + ) + deleted = await self._client.delete_item(item_id) + if deleted: + logger.debug( + "Deleted checkpoint %s from Foundry session %s", + checkpoint_id, + self._session_id, + ) + return deleted + + def _serialize_checkpoint(self, checkpoint: WorkflowCheckpoint) -> bytes: + """Serialize a WorkflowCheckpoint to bytes. + + :param checkpoint: The workflow checkpoint. + :type checkpoint: WorkflowCheckpoint + :return: Serialized checkpoint data. + :rtype: bytes + """ + checkpoint_dict = checkpoint.to_dict() + return json.dumps(checkpoint_dict, ensure_ascii=False).encode("utf-8") + + def _deserialize_checkpoint(self, data: bytes) -> WorkflowCheckpoint: + """Deserialize bytes to WorkflowCheckpoint. + + :param data: Serialized checkpoint data. + :type data: bytes + :return: The workflow checkpoint. + :rtype: WorkflowCheckpoint + """ + checkpoint_dict = json.loads(data.decode("utf-8")) + return WorkflowCheckpoint.from_dict(checkpoint_dict) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py new file mode 100644 index 000000000000..2a43dbb5aee8 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/agent_thread_repository.py @@ -0,0 +1,185 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from abc import ABC, abstractmethod +import json +import os +from typing import Any, Optional, Union + +from agent_framework import AgentThread, AgentProtocol, WorkflowAgent + + +class AgentThreadRepository(ABC): + """ + AgentThread repository to manage saved thread messages of agent threads and workflows. + + :meta private: + """ + + @abstractmethod + async def get( + self, + conversation_id: Optional[str], + agent: Optional[Union[AgentProtocol, WorkflowAgent]] = None, + ) -> Optional[AgentThread]: + """Retrieve the saved thread for a given conversation ID. + + :param conversation_id: The conversation ID. + :type conversation_id: Optional[str] + :param agent: The agent instance. If provided, it can be used to deserialize the thread. + :type agent: Optional[Union[AgentProtocol, WorkflowAgent]] + + :return: The saved AgentThread if available, None otherwise. + :rtype: Optional[AgentThread] + """ + + @abstractmethod + async def set(self, conversation_id: Optional[str], thread: AgentThread) -> None: + """Save the thread for a given conversation ID. + + :param conversation_id: The conversation ID. + :type conversation_id: Optional[str] + :param thread: The thread to save. + :type thread: AgentThread + """ + + +class InMemoryAgentThreadRepository(AgentThreadRepository): + """In-memory implementation of AgentThreadRepository.""" + def __init__(self) -> None: + self._inventory: dict[str, AgentThread] = {} + + async def get( + self, + conversation_id: Optional[str], + agent: Optional[Union[AgentProtocol, WorkflowAgent]] = None, + ) -> Optional[AgentThread]: + """Retrieve the saved thread for a given conversation ID. + + :param conversation_id: The conversation ID. + :type conversation_id: Optional[str] + :param agent: The agent instance. It will be used for in-memory repository for interface consistency. + :type agent: Optional[Union[AgentProtocol, WorkflowAgent]] + :return: The saved AgentThread if available, None otherwise. + :rtype: Optional[AgentThread] + """ + if not conversation_id: + return None + if conversation_id in self._inventory: + return self._inventory[conversation_id] + return None + + async def set(self, conversation_id: Optional[str], thread: AgentThread) -> None: + """Save the thread for a given conversation ID. + + :param conversation_id: The conversation ID. + :type conversation_id: Optional[str] + :param thread: The thread to save. + :type thread: AgentThread + """ + if not conversation_id or not thread: + return + self._inventory[conversation_id] = thread + + +class SerializedAgentThreadRepository(AgentThreadRepository): + """Implementation of AgentThreadRepository with AgentThread serialization.""" + def __init__(self, agent: AgentProtocol) -> None: + """ + Initialize the repository with the given agent. + + :param agent: The agent instance. + :type agent: AgentProtocol + """ + self._agent = agent + + async def get( + self, + conversation_id: Optional[str], + agent: Optional[Union[AgentProtocol, WorkflowAgent]] = None, + ) -> Optional[AgentThread]: + """Retrieve the saved thread for a given conversation ID. + + :param conversation_id: The conversation ID. + :type conversation_id: Optional[str] + :param agent: The agent instance. If provided, it can be used to deserialize the thread. + Otherwise, the repository's agent will be used. + :type agent: Optional[Union[AgentProtocol, WorkflowAgent]] + + :return: The saved AgentThread if available, None otherwise. + :rtype: Optional[AgentThread] + """ + if not conversation_id: + return None + serialized_thread = await self.read_from_storage(conversation_id) + if serialized_thread: + agent_to_use = agent or self._agent + thread = await agent_to_use.deserialize_thread(serialized_thread) + return thread + return None + + async def set(self, conversation_id: Optional[str], thread: AgentThread) -> None: + """Save the thread for a given conversation ID. + + :param conversation_id: The conversation ID. + :type conversation_id: Optional[str] + :param thread: The thread to save. + :type thread: AgentThread + """ + if not conversation_id: + return + serialized_thread = await thread.serialize() + await self.write_to_storage(conversation_id, serialized_thread) + + async def read_from_storage(self, conversation_id: Optional[str]) -> Optional[Any]: + """Read the serialized thread from storage. + + :param conversation_id: The conversation ID. + :type conversation_id: Optional[str] + + :return: The serialized thread if available, None otherwise. + :rtype: Optional[Any] + """ + raise NotImplementedError("read_from_storage is not implemented.") + + async def write_to_storage(self, conversation_id: Optional[str], serialized_thread: Any) -> None: + """Write the serialized thread to storage. + + :param conversation_id: The conversation ID. + :type conversation_id: Optional[str] + :param serialized_thread: The serialized thread to save. + :type serialized_thread: Any + :return: None + :rtype: None + """ + raise NotImplementedError("write_to_storage is not implemented.") + + +class JsonLocalFileAgentThreadRepository(SerializedAgentThreadRepository): + """Json based implementation of AgentThreadRepository using local file storage.""" + def __init__(self, agent: AgentProtocol, storage_path: str) -> None: + super().__init__(agent) + self._storage_path = storage_path + os.makedirs(self._storage_path, exist_ok=True) + + async def read_from_storage(self, conversation_id: Optional[str]) -> Optional[Any]: + if not conversation_id: + return None + file_path = self._get_file_path(conversation_id) + if os.path.exists(file_path): + with open(file_path, "r", encoding="utf-8") as f: + serialized_thread = f.read() + if serialized_thread: + return json.loads(serialized_thread) + return None + + async def write_to_storage(self, conversation_id: Optional[str], serialized_thread: Any) -> None: + if not conversation_id: + return + serialized_str = json.dumps(serialized_thread) + file_path = self._get_file_path(conversation_id) + with open(file_path, "w", encoding="utf-8") as f: + f.write(serialized_str) + + def _get_file_path(self, conversation_id: str) -> str: + return os.path.join(self._storage_path, f"{conversation_id}.json") diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py new file mode 100644 index 000000000000..9848d01f6b10 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/azure/ai/agentserver/agentframework/persistence/checkpoint_repository.py @@ -0,0 +1,73 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from abc import ABC, abstractmethod +import os +from typing import Optional + +from agent_framework import ( + CheckpointStorage, + InMemoryCheckpointStorage, + FileCheckpointStorage, +) + +class CheckpointRepository(ABC): + """ + Repository interface for storing and retrieving checkpoints. + + :meta private: + """ + @abstractmethod + async def get_or_create(self, conversation_id: Optional[str]) -> Optional[CheckpointStorage]: + """Retrieve or create a checkpoint storage by conversation ID. + + :param conversation_id: The unique identifier for the checkpoint. + :type conversation_id: Optional[str] + :return: The CheckpointStorage if found or created, None otherwise. + :rtype: Optional[CheckpointStorage] + """ + + +class InMemoryCheckpointRepository(CheckpointRepository): + """In-memory implementation of CheckpointRepository.""" + def __init__(self) -> None: + self._inventory: dict[str, CheckpointStorage] = {} + + async def get_or_create(self, conversation_id: Optional[str]) -> Optional[CheckpointStorage]: + """Retrieve or create a checkpoint storage by conversation ID. + + :param conversation_id: The unique identifier for the checkpoint. + :type conversation_id: Optional[str] + :return: The CheckpointStorage if found or created, None otherwise. + :rtype: Optional[CheckpointStorage] + """ + if not conversation_id: + return None + if conversation_id not in self._inventory: + self._inventory[conversation_id] = InMemoryCheckpointStorage() + return self._inventory[conversation_id] + + +class FileCheckpointRepository(CheckpointRepository): + """File-based implementation of CheckpointRepository.""" + def __init__(self, storage_path: str) -> None: + self._storage_path = storage_path + self._inventory: dict[str, CheckpointStorage] = {} + os.makedirs(self._storage_path, exist_ok=True) + + async def get_or_create(self, conversation_id: Optional[str]) -> Optional[CheckpointStorage]: + """Retrieve or create a checkpoint storage by conversation ID. + + :param conversation_id: The unique identifier for the checkpoint. + :type conversation_id: Optional[str] + :return: The CheckpointStorage if found or created, None otherwise. + :rtype: Optional[CheckpointStorage] + """ + if not conversation_id: + return None + if conversation_id not in self._inventory: + self._inventory[conversation_id] = FileCheckpointStorage(self._get_dir_path(conversation_id)) + return self._inventory[conversation_id] + + def _get_dir_path(self, conversation_id: str) -> str: + return os.path.join(self._storage_path, conversation_id) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/cspell.json b/sdk/agentserver/azure-ai-agentserver-agentframework/cspell.json index 116acbc87af3..af30eb4f2846 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/cspell.json +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/cspell.json @@ -5,7 +5,11 @@ "mslearn", "envtemplate", "pysort", - "redef" + "redef", + "aifunction", + "ainvoke", + "hitl", + "HITL" ], "ignorePaths": [ "*.csv", diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/dev_requirements.txt b/sdk/agentserver/azure-ai-agentserver-agentframework/dev_requirements.txt index 6c036d7fb4e0..8bf2c1695f7f 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/dev_requirements.txt +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/dev_requirements.txt @@ -1,4 +1,5 @@ -e ../../../eng/tools/azure-sdk-tools ../azure-ai-agentserver-core python-dotenv -pywin32; sys_platform == 'win32' \ No newline at end of file +pywin32; sys_platform == 'win32' +pytest-asyncio \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/mypy.ini b/sdk/agentserver/azure-ai-agentserver-agentframework/mypy.ini new file mode 100644 index 000000000000..e0bb0b83e2ce --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/mypy.ini @@ -0,0 +1,5 @@ +[mypy] +explicit_package_bases = True + +[mypy-samples.*] +ignore_errors = true \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml index 814d1d6d1a1e..0ede428d03a6 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/pyproject.toml @@ -9,6 +9,7 @@ authors = [ ] license = "MIT" classifiers = [ + "Development Status :: 4 - Beta", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3", @@ -20,16 +21,19 @@ classifiers = [ keywords = ["azure", "azure sdk"] dependencies = [ - "azure-ai-agentserver-core", - "agent-framework-azure-ai==1.0.0b251007", - "agent-framework-core==1.0.0b251007", - "opentelemetry-exporter-otlp-proto-grpc>=1.36.0", + "azure-ai-agentserver-core==1.0.0b10", + "agent-framework-azure-ai>=1.0.0b251112,<=1.0.0b260107", + "agent-framework-core>=1.0.0b251112,<=1.0.0b260107", + "opentelemetry-exporter-otlp-proto-grpc>=1.36.0,<=1.39.0", ] [build-system] requires = ["setuptools>=69", "wheel"] build-backend = "setuptools.build_meta" +[project.urls] +repository = "https://github.com/Azure/azure-sdk-for-python" + [tool.setuptools.packages.find] exclude = [ "tests*", @@ -62,5 +66,5 @@ breaking = false # incompatible python version pyright = false verifytypes = false # incompatible python version for -core verify_keywords = false -mindependency = false # depends on -core package -whl_no_aio = false \ No newline at end of file +whl_no_aio = false +mypy = false diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/basic_simple/minimal_example.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/basic_simple/minimal_example.py index 15afa52f42b8..2ea0f19dd32a 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/basic_simple/minimal_example.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/basic_simple/minimal_example.py @@ -6,10 +6,10 @@ from agent_framework.azure import AzureOpenAIChatClient from azure.identity import DefaultAzureCredential from dotenv import load_dotenv +load_dotenv() from azure.ai.agentserver.agentframework import from_agent_framework -load_dotenv() def get_weather( diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/README.md b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/README.md new file mode 100644 index 000000000000..d9fe177e850f --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/README.md @@ -0,0 +1,81 @@ +# Chat Client With Foundry Tools + +This sample demonstrates how to attach `FoundryToolsChatMiddleware` to an Agent Framework chat client so that: + +- Foundry tools configured in your Azure AI Project are converted into Agent Framework `AIFunction` tools. +- The tools are injected automatically for each agent run. + +## What this sample does + +The script creates an Agent Framework agent using: + +- `AzureOpenAIChatClient` for model inference +- `FoundryToolsChatMiddleware` to resolve and inject Foundry tools +- `from_agent_framework(agent).run()` to start an AgentServer-compatible HTTP server + +## Prerequisites + +- Python 3.10+ +- An Azure AI Project endpoint +- A tool connection configured in that project (e.g. an MCP connection) +- Azure credentials available to `DefaultAzureCredential` + +## Setup + +1. Install dependencies: + +```bash +pip install -r requirements.txt +``` + +2. Update `.env` in this folder with your values. At minimum you need: + +```dotenv +AZURE_OPENAI_ENDPOINT=https://.openai.azure.com/ +AZURE_OPENAI_CHAT_DEPLOYMENT_NAME= +OPENAI_API_VERSION= + +AZURE_AI_PROJECT_ENDPOINT=https://.services.ai.azure.com/api/projects/ +AZURE_AI_PROJECT_TOOL_CONNECTION_ID= +``` + +Notes: + +- This sample uses `DefaultAzureCredential()`. Make sure you are signed in (e.g. `az login`) or otherwise configured. + +## Run + +```bash +python chat_client_with_foundry_tool.py +``` + +This starts a local Uvicorn server (it will keep running and wait for requests). If it looks "stuck" at startup, it may just be waiting for requests. + +## Key code + +The core pattern used by this sample: + +```python +agent = AzureOpenAIChatClient( + credential=DefaultAzureCredential(), + middleware=FoundryToolsChatMiddleware( + tools=[{"type": "web_search_preview"}, {"type": "mcp", "project_connection_id": tool_connection_id}], + ), +).create_agent( + name="FoundryToolAgent", + instructions="You are a helpful assistant with access to various tools.", +) + +from_agent_framework(agent).run() +``` + +## Troubleshooting + +- **No tools found**: verify `AZURE_AI_PROJECT_TOOL_CONNECTION_ID` points at an existing tool connection in your project. +- **Auth failures**: confirm `DefaultAzureCredential` can acquire a token (try `az login`). +- **Import errors / weird agent_framework circular import**: ensure you are running the sample from this folder (not from inside the package module directory) so the external `agent_framework` dependency is imported correctly. + +## Learn more + +- Azure AI Agent Service: https://learn.microsoft.com/azure/ai-services/agents/ +- Agent Framework: https://github.com/microsoft/agent-framework diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/chat_client_with_foundry_tool.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/chat_client_with_foundry_tool.py new file mode 100644 index 000000000000..d8c75259d29b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/chat_client_with_foundry_tool.py @@ -0,0 +1,34 @@ +# Copyright (c) Microsoft. All rights reserved. +"""Example showing how to use an agent factory function with ToolClient. + +This sample demonstrates how to pass a factory function to from_agent_framework +that receives a ToolClient and returns an AgentProtocol. This pattern allows +the agent to be created dynamically with access to tools from Azure AI Tool +Client at runtime. +""" + +import os +from dotenv import load_dotenv +from agent_framework.azure import AzureOpenAIChatClient + +from azure.ai.agentserver.agentframework import from_agent_framework, FoundryToolsChatMiddleware +from azure.identity import DefaultAzureCredential + +load_dotenv() + +def main(): + tool_connection_id = os.getenv("AZURE_AI_PROJECT_TOOL_CONNECTION_ID") + + agent = AzureOpenAIChatClient( + credential=DefaultAzureCredential(), + middleware=FoundryToolsChatMiddleware( + tools=[{"type": "web_search_preview"}, {"type": "mcp", "project_connection_id": tool_connection_id}] + )).create_agent( + name="FoundryToolAgent", + instructions="You are a helpful assistant with access to various tools.", + ) + + from_agent_framework(agent).run() + +if __name__ == "__main__": + main() diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/requirements.txt b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/requirements.txt new file mode 100644 index 000000000000..79caf276114f --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/chat_client_with_foundry_tool/requirements.txt @@ -0,0 +1,4 @@ +azure-ai-agentserver-agentframework +azure-identity +python-dotenv +agent-framework diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/.envtemplate b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/.envtemplate new file mode 100644 index 000000000000..bd646f163bb7 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/.envtemplate @@ -0,0 +1,3 @@ +AZURE_OPENAI_ENDPOINT=https://.cognitiveservices.azure.com/ +OPENAI_API_VERSION=2025-03-01-preview +AZURE_OPENAI_CHAT_DEPLOYMENT_NAME= \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/.gitignore b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/.gitignore new file mode 100644 index 000000000000..c21dfd88c196 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/.gitignore @@ -0,0 +1 @@ +thread_storage \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/README.md b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/README.md new file mode 100644 index 000000000000..165dac4eb483 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/README.md @@ -0,0 +1,96 @@ +# Human-in-the-Loop Agent Framework Sample + +This sample demonstrates how to host a Microsoft Agent Framework agent inside Azure AI Agent Server and escalate function-call responses to a human reviewer whenever approval is required. It is adapted from the [agent-framework sample](https://github.com/microsoft/agent-framework/blob/main/python/samples/getting_started/tools/ai_function_with_approval_and_threads.py). + +## Prerequisites + +- Python 3.10+ and `pip` +- Azure CLI logged in with `az login` (used by `AzureCliCredential`) +- An Azure OpenAI chat deployment + +### Environment configuration + +Copy `.envtemplate` to `.env` and fill in your Azure OpenAI details: + +``` +AZURE_OPENAI_ENDPOINT=https://.cognitiveservices.azure.com/ +OPENAI_API_VERSION=2025-03-01-preview +AZURE_OPENAI_CHAT_DEPLOYMENT_NAME= +``` + +`main.py` automatically loads the `.env` file before spinning up the server. + +## Thread persistence + +The sample uses `JsonLocalFileAgentThreadRepository` for `AgentThread` persistence, creating a JSON file per conversation ID under the sample directory. An in-memory alternative, `InMemoryAgentThreadRepository`, lives in the `azure.ai.agentserver.agentframework.persistence` module. + +To store thread messages elsewhere, inherit from `SerializedAgentThreadRepository` and override the following methods: + +- `read_from_storage(self, conversation_id: str) -> Optional[Any]` +- `write_to_storage(self, conversation_id: str, serialized_thread: Any)` + +These hooks let you plug in any backing store (blob storage, databases, etc.) without changing the rest of the sample. + +## Run the hosted agent + +From this directory start the adapter host (defaults to `http://0.0.0.0:8088`): + +```powershell +python main.py +``` + +## Send a user request + +Send a `POST` request to `http://0.0.0.0:8088/responses`: + +```json +{ + "agent": {"name": "local_agent", "type": "agent_reference"}, + "stream": false, + "input": "Add a dentist appointment on March 15th", +} +``` + +A response that requires a human decision looks like this (formatted for clarity): + +```json +{ + "conversation": {"id": ""}, + "output": [ + {...}, + { + "type": "function_call", + "id": "func_xxx", + "name": "__hosted_agent_adapter_hitl__", + "call_id": "", + "arguments": "{\"event_name\":\"Dentist Appointment\",\"date\":\"2024-03-15\"}" + } + ] +} +``` + +Capture these values from the response; you will need them to provide feedback: + +- `conversation.id` +- The `call_id` associated with `__hosted_agent_adapter_hitl__` + +## Provide human feedback + +Send a `CreateResponse` request with a `function_call_output` message that contains your decision (`approve`, `reject`, or additional guidance). Replace the placeholders before running the command: + +```json +{ + "agent": {"name": "local_agent", "type": "agent_reference"}, + "stream": false, + "conversation": {"id": ""}, + "input": [ + { + "call_id": "", + "output": "approve", + "type": "function_call_output", + } + ] +} +``` + +When the reviewer response is accepted, the worker emits the approved assistant response and the HTTP call returns the final output. diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/main.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/main.py new file mode 100644 index 000000000000..56dc5fca8860 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/main.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +from typing import Annotated, Any, Collection +from dotenv import load_dotenv + +load_dotenv() + +from agent_framework import ChatAgent, ChatMessage, ChatMessageStoreProtocol, ai_function +from agent_framework._threads import ChatMessageStoreState +from agent_framework.azure import AzureOpenAIChatClient + +from azure.ai.agentserver.agentframework import from_agent_framework +from azure.ai.agentserver.agentframework.persistence.agent_thread_repository import JsonLocalFileAgentThreadRepository + +""" +Tool Approvals with Threads + +This sample demonstrates using tool approvals with threads. +With threads, you don't need to manually pass previous messages - +the thread stores and retrieves them automatically. +""" + +class CustomChatMessageStore(ChatMessageStoreProtocol): + """Implementation of custom chat message store. + In real applications, this can be an implementation of relational database or vector store.""" + + def __init__(self, messages: Collection[ChatMessage] | None = None) -> None: + self._messages: list[ChatMessage] = [] + if messages: + self._messages.extend(messages) + + async def add_messages(self, messages: Collection[ChatMessage]) -> None: + self._messages.extend(messages) + + async def list_messages(self) -> list[ChatMessage]: + return self._messages + + @classmethod + async def deserialize(cls, serialized_store_state: Any, **kwargs: Any) -> "CustomChatMessageStore": + """Create a new instance from serialized state.""" + store = cls() + await store.update_from_state(serialized_store_state, **kwargs) + return store + + async def update_from_state(self, serialized_store_state: Any, **kwargs: Any) -> None: + """Update this instance from serialized state.""" + if serialized_store_state: + state = ChatMessageStoreState.from_dict(serialized_store_state, **kwargs) + if state.messages: + self._messages.extend(state.messages) + + async def serialize(self, **kwargs: Any) -> Any: + """Serialize this store's state.""" + state = ChatMessageStoreState(messages=self._messages) + return state.to_dict(**kwargs) + + +@ai_function(approval_mode="always_require") +def add_to_calendar( + event_name: Annotated[str, "Name of the event"], date: Annotated[str, "Date of the event"] +) -> str: + """Add an event to the calendar (requires approval).""" + print(f">>> EXECUTING: add_to_calendar(event_name='{event_name}', date='{date}')") + return f"Added '{event_name}' to calendar on {date}" + + +def build_agent(): + return ChatAgent( + chat_client=AzureOpenAIChatClient(), + name="CalendarAgent", + instructions="You are a helpful calendar assistant.", + tools=[add_to_calendar], + chat_message_store_factory=CustomChatMessageStore, + ) + + +async def main() -> None: + agent = build_agent() + thread_repository = JsonLocalFileAgentThreadRepository(agent=agent, storage_path="./thread_storage") + await from_agent_framework(agent, thread_repository=thread_repository).run_async() + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/requirements.txt b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/requirements.txt new file mode 100644 index 000000000000..c044abf99eb1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_ai_function/requirements.txt @@ -0,0 +1,5 @@ +python-dotenv>=1.0.0 +azure-identity +agent-framework-azure-ai +azure-ai-agentserver-core +azure-ai-agentserver-agentframework diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/.envtemplate b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/.envtemplate new file mode 100644 index 000000000000..bd646f163bb7 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/.envtemplate @@ -0,0 +1,3 @@ +AZURE_OPENAI_ENDPOINT=https://.cognitiveservices.azure.com/ +OPENAI_API_VERSION=2025-03-01-preview +AZURE_OPENAI_CHAT_DEPLOYMENT_NAME= \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/.gitignore b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/.gitignore new file mode 100644 index 000000000000..9a5b0b4f8d68 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/.gitignore @@ -0,0 +1 @@ +checkpoints \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/README.md b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/README.md new file mode 100644 index 000000000000..d7a6dd978fec --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/README.md @@ -0,0 +1,142 @@ +# Human-in-the-Loop Agent Framework Sample + +This sample shows how to host a Microsoft Agent Framework workflow inside Azure AI Agent Server while escalating responses to a real human when the reviewer executor decides that manual approval is required. The sample is created by [agent-framework](https://github.com/microsoft/agent-framework/blob/main/python/samples/getting_started/workflows/agents/workflow_as_agent_human_in_the_loop.py). + +## Prerequisites + +- Python 3.10+ and `pip` +- Azure CLI logged in with `az login` (used by `AzureCliCredential`) +- An Azure OpenAI chat deployment + +### Environment configuration + +Copy `.envtemplate` to `.env` and fill in your Azure OpenAI details: + +``` +AZURE_OPENAI_ENDPOINT=https://.cognitiveservices.azure.com/ +OPENAI_API_VERSION=2025-03-01-preview +AZURE_OPENAI_CHAT_DEPLOYMENT_NAME= +``` + +`main.py` automatically loads the `.env` file before spinning up the server. + +## Checkpoint persistence + +The adapter uses the `CheckpointRepository` interface to retrieve a `CheckpointStorage` per conversation. The storage keeps serialized workflow state so long-running conversations can resume after the host restarts. + +- Use `InMemoryCheckpointRepository()` for quick demos; checkpoints vanish once the Python process exits. +- `FileCheckpointRepository("")` is implemented to persist checkpoints in files. +- To back checkpoints with a different store (Redis, Blob, etc.), subclass `CheckpointRepository`, implement `get_or_create`, and pass your repository instance to `from_agent_framework(..., checkpoint_repository=)`. + +## Run the hosted workflow agent + +For Human-in-the-loop scenario, the `HumanReviewRequest` and `ReviewResponse` are provided by user. User should provide functions for these classes that allow adapter convert the data to request payloads. + +```py +@dataclass +class HumanReviewRequest: + """A request message type for escalation to a human reviewer.""" + + agent_request: ReviewRequest | None = None + + def convert_to_payload(self) -> str: # called by adapter + """Convert the HumanReviewRequest to a payload string.""" + request = self.agent_request + payload: dict[str, Any] = {"agent_request": None} + + if request: + payload["agent_request"] = { + "request_id": request.request_id, + "user_messages": [msg.to_dict() for msg in request.user_messages], + "agent_messages": [msg.to_dict() for msg in request.agent_messages], + } + + return json.dumps(payload) +``` + +```py +@dataclass +class ReviewResponse: + """Structured response from Reviewer back to Worker.""" + + request_id: str + feedback: str + approved: bool + + @staticmethod + def convert_from_payload(payload: str) -> "ReviewResponse": + """Convert a JSON payload string to a ReviewResponse instance.""" + data = json.loads(payload) + return ReviewResponse( + request_id=data["request_id"], + feedback=data["feedback"], + approved=data["approved"], + ) +``` + +From this directory start the adapter host (defaults to `http://0.0.0.0:8088`): + +```powershell +python main.py +``` + +The worker executor produces answers, the reviewer executor always escalates to a person, and the adapter exposes the whole workflow through the `/responses` endpoint. + + + +## Send a user request + +Send a `POST` request to `http://0.0.0.0:8088/responses` + +```json +{ + "agent": {"name": "local_agent", "type": "agent_reference"}, + "stream": false, + "input": "Write code for parallel reading 1 million Files on disk and write to a sorted output file.", +} +``` + +A response with human-review request looks like this (formatted for clarity): + +```json +{ + "conversation": {"id": ""}, + "output": [ + {...}, + { + "type": "function_call", + "id": "func_xxx", + "name": "__hosted_agent_adapter_hitl__", + "call_id": "", + "arguments": "{\"agent_request\":{\"request_id\":\"\",...}}" + } + ] +} +``` + +Capture three values from the response: + +- `conversation.id` +- The `call_id` of the `__hosted_agent_adapter_hitl__` function call +- The `request_id` inside the serialized `agent_request` + +## Provide human feedback + +Respond by sending a `CreateResponse` request with `function_call_output` message that carries your review decision. Replace the placeholders before running the command: + +```json +{ + "agent": {"name": "local_agent", "type": "agent_reference"}, + "stream": false, + "conversation": {"id": ""}, + "input": [ + { + "call_id": "", + "output": "{\"request_id\":\"\",\"approved\":true,\"feedback\":\"approve\"}", + "type": "function_call_output", + } + ] +} +``` + +Once the reviewer accepts the human feedback, the worker emits the approved assistant response and the HTTP call returns the final output. diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/main.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/main.py new file mode 100644 index 000000000000..cc89c941e65e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/main.py @@ -0,0 +1,119 @@ +# Copyright (c) Microsoft. All rights reserved. + +import asyncio +import json +from dataclasses import dataclass +from typing import Any + +from agent_framework.azure import AzureOpenAIChatClient +from azure.identity import AzureCliCredential +from dotenv import load_dotenv + +from agent_framework import ( # noqa: E402 + Executor, + WorkflowBuilder, + WorkflowContext, + handler, + response_handler, +) +from workflow_as_agent_reflection_pattern import ( # noqa: E402 + ReviewRequest, + ReviewResponse, + Worker, +) + +from azure.ai.agentserver.agentframework import from_agent_framework +from azure.ai.agentserver.agentframework.persistence import FileCheckpointRepository + +load_dotenv() + +@dataclass +class HumanReviewRequest: + """A request message type for escalation to a human reviewer.""" + + agent_request: ReviewRequest | None = None + + def convert_to_payload(self) -> str: + """Convert the HumanReviewRequest to a payload string.""" + request = self.agent_request + payload: dict[str, Any] = {"agent_request": None} + + if request: + payload["agent_request"] = { + "request_id": request.request_id, + "user_messages": [msg.to_dict() for msg in request.user_messages], + "agent_messages": [msg.to_dict() for msg in request.agent_messages], + } + + return json.dumps(payload, indent=2) + + +class ReviewerWithHumanInTheLoop(Executor): + """Executor that always escalates reviews to a human manager.""" + + def __init__(self, worker_id: str, reviewer_id: str | None = None) -> None: + unique_id = reviewer_id or f"{worker_id}-reviewer" + super().__init__(id=unique_id) + self._worker_id = worker_id + + @handler + async def review(self, request: ReviewRequest, ctx: WorkflowContext) -> None: + # In this simplified example, we always escalate to a human manager. + # See workflow_as_agent_reflection.py for an implementation + # using an automated agent to make the review decision. + print(f"Reviewer: Evaluating response for request {request.request_id[:8]}...") + print("Reviewer: Escalating to human manager...") + + # Forward the request to a human manager by sending a HumanReviewRequest. + await ctx.request_info( + request_data=HumanReviewRequest(agent_request=request), + response_type=ReviewResponse, + ) + + @response_handler + async def accept_human_review( + self, + original_request: HumanReviewRequest, + response: ReviewResponse, + ctx: WorkflowContext[ReviewResponse], + ) -> None: + # Accept the human review response and forward it back to the Worker. + print(f"Reviewer: Accepting human review for request {response.request_id[:8]}...") + print(f"Reviewer: Human feedback: {response.feedback}") + print(f"Reviewer: Human approved: {response.approved}") + print("Reviewer: Forwarding human review back to worker...") + await ctx.send_message(response, target_id=self._worker_id) + +def create_builder(): + # Build a workflow with bidirectional communication between Worker and Reviewer, + # and escalation paths for human review. + builder = ( + WorkflowBuilder() + .register_executor( + lambda: Worker( + id="sub-worker", + chat_client=AzureOpenAIChatClient(credential=AzureCliCredential()), + ), + name="worker", + ) + .register_executor( + lambda: ReviewerWithHumanInTheLoop(worker_id="sub-worker"), + name="reviewer", + ) + .add_edge("worker", "reviewer") # Worker sends requests to Reviewer + .add_edge("reviewer", "worker") # Reviewer sends feedback to Worker + .set_start_executor("worker") + ) + return builder + + +async def run_agent() -> None: + """Run the workflow inside the agent server adapter.""" + builder = create_builder() + await from_agent_framework( + builder, # pass workflow builder to adapter + checkpoint_repository=FileCheckpointRepository(storage_path="./checkpoints"), # for checkpoint storage + ).run_async() + +if __name__ == "__main__": + asyncio.run(run_agent()) \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/requirements.txt b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/requirements.txt new file mode 100644 index 000000000000..c044abf99eb1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/requirements.txt @@ -0,0 +1,5 @@ +python-dotenv>=1.0.0 +azure-identity +agent-framework-azure-ai +azure-ai-agentserver-core +azure-ai-agentserver-agentframework diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/workflow_as_agent_reflection_pattern.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/workflow_as_agent_reflection_pattern.py new file mode 100644 index 000000000000..ef2a286ba174 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/human_in_the_loop_workflow_agent/workflow_as_agent_reflection_pattern.py @@ -0,0 +1,139 @@ +# Copyright (c) Microsoft. All rights reserved. + +from dataclasses import dataclass +import json +from uuid import uuid4 + +from agent_framework import ( + AgentRunResponseUpdate, + AgentRunUpdateEvent, + ChatClientProtocol, + ChatMessage, + Contents, + Executor, + Role, + WorkflowContext, + handler, +) + +@dataclass +class ReviewRequest: + """Structured request passed from Worker to Reviewer for evaluation.""" + + request_id: str + user_messages: list[ChatMessage] + agent_messages: list[ChatMessage] + + +@dataclass +class ReviewResponse: + """Structured response from Reviewer back to Worker.""" + + request_id: str + feedback: str + approved: bool + + @staticmethod + def convert_from_payload(payload: str) -> "ReviewResponse": + """Convert a JSON payload string to a ReviewResponse instance.""" + data = json.loads(payload) + return ReviewResponse( + request_id=data["request_id"], + feedback=data["feedback"], + approved=data["approved"], + ) + + +PendingReviewState = tuple[ReviewRequest, list[ChatMessage]] + + +class Worker(Executor): + """Executor that generates responses and incorporates feedback when necessary.""" + + def __init__(self, id: str, chat_client: ChatClientProtocol) -> None: + super().__init__(id=id) + self._chat_client = chat_client + self._pending_requests: dict[str, PendingReviewState] = {} + + @handler + async def handle_user_messages(self, user_messages: list[ChatMessage], ctx: WorkflowContext[ReviewRequest]) -> None: + print("Worker: Received user messages, generating response...") + + # Initialize chat with system prompt. + messages = [ChatMessage(role=Role.SYSTEM, text="You are a helpful assistant.")] + messages.extend(user_messages) + + print("Worker: Calling LLM to generate response...") + response = await self._chat_client.get_response(messages=messages) + print(f"Worker: Response generated: {response.messages[-1].text}") + + # Add agent messages to context. + messages.extend(response.messages) + + # Create review request and send to Reviewer. + request = ReviewRequest(request_id=str(uuid4()), user_messages=user_messages, agent_messages=response.messages) + print(f"Worker: Sending response for review (ID: {request.request_id[:8]})") + await ctx.send_message(request) + + # Track request for possible retry. + self._pending_requests[request.request_id] = (request, messages) + + @handler + async def handle_review_response(self, review: ReviewResponse, ctx: WorkflowContext[ReviewRequest]) -> None: + print(f"Worker: Received review for request {review.request_id[:8]} - Approved: {review.approved}") + + if review.request_id not in self._pending_requests: + raise ValueError(f"Unknown request ID in review: {review.request_id}") + + request, messages = self._pending_requests.pop(review.request_id) + + if review.approved: + print("Worker: Response approved. Emitting to external consumer...") + contents: list[Contents] = [] + for message in request.agent_messages: + contents.extend(message.contents) + + # Emit approved result to external consumer via AgentRunUpdateEvent. + await ctx.add_event( + AgentRunUpdateEvent(self.id, data=AgentRunResponseUpdate(contents=contents, role=Role.ASSISTANT)) + ) + return + + print(f"Worker: Response not approved. Feedback: {review.feedback}") + print("Worker: Regenerating response with feedback...") + + # Incorporate review feedback. + messages.append(ChatMessage(role=Role.SYSTEM, text=review.feedback)) + messages.append( + ChatMessage(role=Role.SYSTEM, text="Please incorporate the feedback and regenerate the response.") + ) + messages.extend(request.user_messages) + + # Retry with updated prompt. + response = await self._chat_client.get_response(messages=messages) + print(f"Worker: New response generated: {response.messages[-1].text}") + + messages.extend(response.messages) + + # Send updated request for re-review. + new_request = ReviewRequest( + request_id=review.request_id, user_messages=request.user_messages, agent_messages=response.messages + ) + await ctx.send_message(new_request) + + # Track new request for further evaluation. + self._pending_requests[new_request.request_id] = (new_request, messages) + + async def on_checkpoint_save(self) -> dict: + """ + Persist pending requests during checkpointing. + In memory implementation for demonstration purposes. + """ + return {"pending_requests": self._pending_requests} + + async def on_checkpoint_restore(self, data: dict) -> None: + """ + Load pending requests from checkpoint data. + In memory implementation for demonstration purposes. + """ + self._pending_requests = data.get("pending_requests", {}) \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/README.md b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/README.md index 59bb6b9f19ec..cc82d1f19171 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/README.md +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/README.md @@ -1,287 +1,63 @@ -## Workflow Agent Reflection Sample (Python) +## Workflow Agent Simple Sample (Python) -This sample demonstrates how to wrap an Agent Framework workflow (with iterative review + improvement) as an agent using the Container Agents Adapter. It implements a "reflection" pattern consisting of two executors: +This sample hosts a two-step Agent Framework workflowβ€”`Writer` followed by `Reviewer`β€”through the Azure AI Agent Server Adapter. The writer creates content, the reviewer provides the final response, and the adapter exposes the workflow through the same HTTP surface as any hosted agent. -- Worker: Produces an initial answer (and revised answers after feedback) -- Reviewer: Evaluates the answer against quality criteria and either approves or returns constructive feedback +### What `workflow_agent_simple.py` does +- Builds a workflow with `WorkflowBuilder` +- Passes the builder to `from_agent_framework(...).run_async()` so the adapter spins up an HTTP server (defaults to `0.0.0.0:8088`). +- A example of passing with a factory of `Workflow` is shown in comments. -The workflow cycles until the Reviewer approves the response. Only approved content is emitted externally (streamed the same way as a normal agent response). This pattern is useful for quality‑controlled assistance, gated tool use, evaluative chains, or iterative refinement. - -### Key Concepts Shown -- `WorkflowBuilder` + `.as_agent()` to expose a workflow as a standard agent -- Bidirectional edges enabling cyclical review (Worker ↔ Reviewer) -- Structured output parsing (Pydantic model) for review feedback -- Emitting `AgentRunUpdateEvent` to stream only approved messages -- Managing pending requests and re‑submission with incorporated feedback - -File: `workflow_agent_simple.py` +Please note that the `WorkflowBuilder` or `Workflow` factory will be called for each incoming request. The `Workflow` will be converted to `WorkflowAgent` by `.as_agent()`. The workflow definition need ot be valid for `WorkflowAgent`. --- ## Prerequisites - -> **Azure sign-in:** Run `az login` before starting the sample so `DefaultAzureCredential` can acquire a CLI token. - -Dependencies used by `workflow_agent_simple.py`: -- agent-framework-azure-ai (published package with workflow abstractions) -- agents_adapter -- azure-identity (for `DefaultAzureCredential`) -- python-dotenv (loads `.env` for local credentials) -- pydantic (pulled transitively; listed for clarity) - -Install from PyPI (from the repo root: `container_agents/`): -```bash -pip install agent-framework-azure-ai azure-identity python-dotenv - -pip install -e src/adapter/python -``` +- Python 3.10+ +- Azure CLI authenticated with `az login` (required for `AzureCliCredential`). +- An Azure AI project that already hosts a chat model deployment supported by the Agent Framework Azure client. --- -## Additional Requirements - -1. Azure AI project with a model deployment (supports Microsoft hosted, Azure OpenAI, or custom models exposed via Azure AI Foundry). - ---- - -## Configuration - -Copy `.envtemplate` to `.env` and fill in real values: -``` -AZURE_AI_PROJECT_ENDPOINT= -AZURE_AI_MODEL_DEPLOYMENT_NAME= -AGENT_PROJECT_NAME= -``` -`AGENT_PROJECT_NAME` lets you override the default Azure AI agent project for this workflow; omit it to fall back to the SDK default. +## Setup +1. Copy `.envtemplate` to `.env` and fill in your project details: + ``` + AZURE_AI_PROJECT_ENDPOINT= + AZURE_AI_MODEL_DEPLOYMENT_NAME= + ``` +2. Install the sample dependencies: + ```bash + pip install -r requirements.txt + ``` --- ## Run the Workflow Agent - From this folder: ```bash python workflow_agent_simple.py ``` -The server (via the adapter) will start on `0.0.0.0:8088` by default. - ---- - -## Send a Non‑Streaming Request - -```bash -curl -sS \ - -H "Content-Type: application/json" \ - -X POST http://localhost:8088/runs \ - -d '{"input":"Explain the concept of reflection in this workflow sample.","stream":false}' -``` - -Sample output (non‑streaming): - -``` -Processing 1 million files in parallel and writing their contents into a sorted output file can be a computationally and resource-intensive task. To handle it effectively, you can use Python with libraries like `concurrent.futures` for parallelism and `heapq` for the sorting and merging. - -Below is an example implementation: - -import os -from concurrent.futures import ThreadPoolExecutor -import heapq - -def read_file(file_path): - """Read the content of a single file and return it as a list of lines.""" - with open(file_path, 'r') as file: - return file.readlines() - -def parallel_read_files(file_paths, max_workers=8): - """ - Read files in parallel and return all the lines in memory. - :param file_paths: List of file paths to read. - :param max_workers: Number of worker threads to use for parallelism. - """ - all_lines = [] - with ThreadPoolExecutor(max_workers=max_workers) as executor: - # Submit tasks to read each file in parallel - results = executor.map(read_file, file_paths) - # Collect the results - for lines in results: - all_lines.extend(lines) - return all_lines - -def write_sorted_output(lines, output_file_path): - """ - Write sorted lines to the output file. - :param lines: List of strings to be sorted and written. - :param output_file_path: File path to write the sorted result. - """ - sorted_lines = sorted(lines) - with open(output_file_path, 'w') as output_file: - output_file.writelines(sorted_lines) - -def main(directory_path, output_file_path): - """ - Main function to read files in parallel and write sorted output. - :param directory_path: Path to the directory containing input files. - :param output_file_path: File path to write the sorted output. - """ - # Get a list of all the file paths in the given directory - file_paths = [os.path.join(directory_path, f) for f in os.listdir(directory_path) if os.path.isfile(os.path.join(directory_path, f))] - - print(f"Found {len(file_paths)} files. Reading files in parallel...") - - # Read all lines from the files in parallel - all_lines = parallel_read_files(file_paths) - - print(f"Total lines read: {len(all_lines)}. Sorting and writing to output file...") - - # Write the sorted lines to the output file - write_sorted_output(all_lines, output_file_path) - - print(f"Sorted output written to: {output_file_path}") - -if __name__ == "__main__": - # Replace these paths with the appropriate input directory and output file path - input_directory = "path/to/input/directory" # Directory containing 1 million files - output_file = "path/to/output/sorted_output.txt" # Output file path - - main(input_directory, output_file) - -### Key Features and Steps: - -1. **Parallel Reading with `ThreadPoolExecutor`**: - - Files are read in parallel using threads to improve I/O performance since reading many files is mostly I/O-bound. - -2. **Sorting and Writing**: - - Once all lines are aggregated into memory, they are sorted using Python's `sorted()` function and written to the output file in one go. - -3. **Handles Large Number of Files**: - - The program uses threads to manage the potentially massive number of files in parallel, saving time instead of processing them serially. - -### Considerations: -- **Memory Usage**: This script reads all file contents into memory. If the total size of the files is too large, you may encounter memory issues. In such cases, consider processing the files in smaller chunks. -- **Sorting**: For extremely large data, consider using an external/merge sort technique to handle sorting in smaller chunks. -- **I/O Performance**: Ensure that your I/O subsystem and disk can handle the load. - -Let me know if you'd like an optimized version to handle larger datasets with limited memory! - -Usage (if provided): None -``` - ---- - -## Send a Streaming Request (Server-Sent Events) - -```bash -curl -N \ - -H "Content-Type: application/json" \ - -X POST http://localhost:8088/runs \ - -d '{"input":"How does the reviewer decide to approve?","stream":true}' -``` - -Sample output (streaming): - -``` -Here is a Python script that demonstrates parallel reading of 1 million files using `concurrent.futures` for parallelism and `heapq` to write the outputs to a sorted file. This approach ensures efficiency when dealing with such a large number of files. - - -import os -import heapq -from concurrent.futures import ThreadPoolExecutor - -def read_file(file_path): - """ - Read the content of a single file and return it as a list of lines. - """ - with open(file_path, 'r') as file: - return file.readlines() - -def parallel_read_files(file_paths, max_workers=4): - """ - Read multiple files in parallel. - """ - all_lines = [] - with ThreadPoolExecutor(max_workers=max_workers) as executor: - # Submit reading tasks to the thread pool - futures = [executor.submit(read_file, file_path) for file_path in file_paths] - - # Gather results as they are completed - for future in futures: - all_lines.extend(future.result()) - - return all_lines - -def write_sorted_output(lines, output_file): - """ - Write sorted lines to an output file. - """ - sorted_lines = sorted(lines) - with open(output_file, 'w') as file: - file.writelines(sorted_lines) - -if __name__ == "__main__": - # Set the directory containing your input files - input_directory = 'path_to_your_folder_with_files' - - # Get the list of all input files - file_paths = [os.path.join(input_directory, f) for f in os.listdir(input_directory) if os.path.isfile(os.path.join(input_directory, f))] - - # Specify the number of threads for parallel processing - max_threads = 8 # Adjust according to your system's capabilities - - # Step 1: Read all files in parallel - print("Reading files in parallel...") - all_lines = parallel_read_files(file_paths, max_workers=max_threads) - - # Step 2: Write the sorted data to the output file - output_file = 'sorted_output.txt' - print(f"Writing sorted output to {output_file}...") - write_sorted_output(all_lines, output_file) - - print("Operation complete.") - -[comment]: # ( cspell:ignore pysort ) - -### Key Points: -1. **Parallel Read**: The reading of files is handled using `concurrent.futures.ThreadPoolExecutor`, allowing multiple files to be processed simultaneously. - -2. **Sorted Output**: After collecting all lines from the files, the `sorted()` function is used to sort the content in memory. This ensures that the final output file will have all data in sorted order. - -3. **Adjustable Parallelism**: The `max_threads` parameter can be modified to control the number of threads used for file reading. The value should match your system's capabilities for optimal performance. - -4. **Large Data Handling**: If the data from 1 million files is too large to fit into memory, consider using an external merge sort algorithm or a library like `pysort` for efficient external sorting. - -Let me know if you'd like improvements or adjustments for more specific scenarios! -Final usage (if provided): None -``` - -> Only the final approved assistant content is emitted as normal output deltas; intermediate review feedback stays internal. - ---- -## How the Reflection Loop Works -1. User query enters the workflow (Worker start executor) -2. Worker produces an answer with model call -3. Reviewer evaluates using a structured schema (`feedback`, `approved`) -4. If not approved: Worker augments context with feedback + regeneration instruction, then re‑answers -5. Loop continues until `approved=True` -6. Approved content is emitted as `AgentRunResponseUpdate` (streamed externally) +The adapter starts the server on `http://0.0.0.0:8088` by default. --- -## Troubleshooting -| Issue | Resolution | -|-------|------------| -| `DefaultAzureCredential` errors | Run `az login` or configure a service principal. | -| Empty / no streaming | Confirm `stream` flag in request JSON and that the event loop is healthy. | -| Model 404 / deployment error | Verify `AZURE_AI_MODEL_DEPLOYMENT_NAME` exists in the Azure AI project configured by `AZURE_AI_PROJECT_ENDPOINT`. | -| `.env` not loading | Ensure `.env` sits beside the script (or set `dotenv_path`) and that `python-dotenv` is installed. | +## Send Requests +- **Non-streaming:** + ```bash + curl -sS \ + -H "Content-Type: application/json" \ + -X POST http://localhost:8088/runs \ + -d '{"input":"Create a slogan for a new electric SUV that is affordable and fun to drive","stream":false}' + ``` --- ## Related Resources - Agent Framework repo: https://github.com/microsoft/agent-framework -- Basic simple sample README (same folder structure) for installation reference +- Adapter package docs: `azure.ai.agentserver.agentframework` in this SDK --- ## License & Support -This sample follows the repository's LICENSE. For questions about unreleased Agent Framework features, contact the Agent Framework team via its GitHub repository. +This sample follows the repository LICENSE. For questions about the Agent Framework itself, open an issue in the Agent Framework GitHub repository. diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py index ce3cca956273..5de214c9ff09 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_agent_simple/workflow_agent_simple.py @@ -1,291 +1,51 @@ -# Copyright (c) Microsoft. All rights reserved. - import asyncio -from dataclasses import dataclass -from uuid import uuid4 - -from agent_framework import ( - AgentRunResponseUpdate, - AgentRunUpdateEvent, - BaseChatClient, - ChatMessage, - Contents, - Executor, - Role as ChatRole, - WorkflowBuilder, - WorkflowContext, - handler, -) -from agent_framework_azure_ai import AzureAIAgentClient -from azure.identity.aio import DefaultAzureCredential from dotenv import load_dotenv -from pydantic import BaseModel - -from azure.ai.agentserver.agentframework import from_agent_framework - -""" -The following sample demonstrates how to wrap a workflow as an agent using WorkflowAgent. - -This sample shows how to: -1. Create a workflow with a reflection pattern (Worker + Reviewer executors) -2. Wrap the workflow as an agent using the .as_agent() method -3. Stream responses from the workflow agent like a regular agent -4. Implement a review-retry mechanism where responses are iteratively improved - -The example implements a quality-controlled AI assistant where: -- Worker executor generates responses to user queries -- Reviewer executor evaluates the responses and provides feedback -- If not approved, the Worker incorporates feedback and regenerates the response -- The cycle continues until the response is approved -- Only approved responses are emitted to the external consumer - -Key concepts demonstrated: -- WorkflowAgent: Wraps a workflow to make it behave as an agent -- Bidirectional workflow with cycles (Worker ↔ Reviewer) -- AgentRunUpdateEvent: How workflows communicate with external consumers -- Structured output parsing for review feedback -- State management with pending requests tracking -""" - - -@dataclass -class ReviewRequest: - request_id: str - user_messages: list[ChatMessage] - agent_messages: list[ChatMessage] - - -@dataclass -class ReviewResponse: - request_id: str - feedback: str - approved: bool - load_dotenv() +from agent_framework import ChatAgent, Workflow, WorkflowBuilder +from agent_framework.azure import AzureAIAgentClient +from azure.identity.aio import AzureCliCredential -class Reviewer(Executor): - """An executor that reviews messages and provides feedback.""" - - def __init__(self, chat_client: BaseChatClient) -> None: - super().__init__(id="reviewer") - self._chat_client = chat_client - - @handler - async def review( - self, request: ReviewRequest, ctx: WorkflowContext[ReviewResponse] - ) -> None: - print( - f"πŸ” Reviewer: Evaluating response for request {request.request_id[:8]}..." - ) - - # Use the chat client to review the message and use structured output. - # NOTE: this can be modified to use an evaluation framework. - - class _Response(BaseModel): - feedback: str - approved: bool - - # Define the system prompt. - messages = [ - ChatMessage( - role=ChatRole.SYSTEM, - text="You are a reviewer for an AI agent, please provide feedback on the " - "following exchange between a user and the AI agent, " - "and indicate if the agent's responses are approved or not.\n" - "Use the following criteria for your evaluation:\n" - "- Relevance: Does the response address the user's query?\n" - "- Accuracy: Is the information provided correct?\n" - "- Clarity: Is the response easy to understand?\n" - "- Completeness: Does the response cover all aspects of the query?\n" - "Be critical in your evaluation and provide constructive feedback.\n" - "Do not approve until all criteria are met.", - ) - ] - - # Add user and agent messages to the chat history. - messages.extend(request.user_messages) - - # Add agent messages to the chat history. - messages.extend(request.agent_messages) - - # Add add one more instruction for the assistant to follow. - messages.append( - ChatMessage( - role=ChatRole.USER, - text="Please provide a review of the agent's responses to the user.", - ) - ) - - print("πŸ” Reviewer: Sending review request to LLM...") - # Get the response from the chat client. - response = await self._chat_client.get_response( - messages=messages, response_format=_Response - ) - - # Parse the response. - parsed = _Response.model_validate_json(response.messages[-1].text) - - print(f"πŸ” Reviewer: Review complete - Approved: {parsed.approved}") - print(f"πŸ” Reviewer: Feedback: {parsed.feedback}") - - # Send the review response. - await ctx.send_message( - ReviewResponse( - request_id=request.request_id, - feedback=parsed.feedback, - approved=parsed.approved, - ) - ) - - -class Worker(Executor): - """An executor that performs tasks for the user.""" - - def __init__(self, chat_client: BaseChatClient) -> None: - super().__init__(id="worker") - self._chat_client = chat_client - self._pending_requests: dict[str, tuple[ReviewRequest, list[ChatMessage]]] = {} - - @handler - async def handle_user_messages( - self, user_messages: list[ChatMessage], ctx: WorkflowContext[ReviewRequest] - ) -> None: - print("πŸ”§ Worker: Received user messages, generating response...") - - # Handle user messages and prepare a review request for the reviewer. - # Define the system prompt. - messages = [ - ChatMessage(role=ChatRole.SYSTEM, text="You are a helpful assistant.") - ] - - # Add user messages. - messages.extend(user_messages) - - print("πŸ”§ Worker: Calling LLM to generate response...") - # Get the response from the chat client. - response = await self._chat_client.get_response(messages=messages) - print(f"πŸ”§ Worker: Response generated: {response.messages[-1].text}") - - # Add agent messages. - messages.extend(response.messages) - - # Create the review request. - request = ReviewRequest( - request_id=str(uuid4()), - user_messages=user_messages, - agent_messages=response.messages, - ) - - print( - f"πŸ”§ Worker: Generated response, sending to reviewer (ID: {request.request_id[:8]})" - ) - # Send the review request. - await ctx.send_message(request) - - # Add to pending requests. - self._pending_requests[request.request_id] = (request, messages) - - @handler - async def handle_review_response( - self, review: ReviewResponse, ctx: WorkflowContext[ReviewRequest] - ) -> None: - print( - f"πŸ”§ Worker: Received review for request {review.request_id[:8]} - Approved: {review.approved}" - ) - - # Handle the review response. Depending on the approval status, - # either emit the approved response as AgentRunUpdateEvent, or - # retry given the feedback. - if review.request_id not in self._pending_requests: - raise ValueError( - f"Received review response for unknown request ID: {review.request_id}" - ) - # Remove the request from pending requests. - request, messages = self._pending_requests.pop(review.request_id) - - if review.approved: - print("βœ… Worker: Response approved! Emitting to external consumer...") - # If approved, emit the agent run response update to the workflow's - # external consumer. - contents: list[Contents] = [] - for message in request.agent_messages: - contents.extend(message.contents) - # Emitting an AgentRunUpdateEvent in a workflow wrapped by a WorkflowAgent - # will send the AgentRunResponseUpdate to the WorkflowAgent's - # event stream. - await ctx.add_event( - AgentRunUpdateEvent( - self.id, - data=AgentRunResponseUpdate( - contents=contents, role=ChatRole.ASSISTANT - ), - ) - ) - return - - print(f"❌ Worker: Response not approved. Feedback: {review.feedback}") - print("πŸ”§ Worker: Incorporating feedback and regenerating response...") - - # Construct new messages with feedback. - messages.append(ChatMessage(role=ChatRole.SYSTEM, text=review.feedback)) - - # Add additional instruction to address the feedback. - messages.append( - ChatMessage( - role=ChatRole.SYSTEM, - text="Please incorporate the feedback above, and provide a response to user's next message.", - ) - ) - messages.extend(request.user_messages) - - # Get the new response from the chat client. - response = await self._chat_client.get_response(messages=messages) - print( - f"πŸ”§ Worker: New response generated after feedback: {response.messages[-1].text}" - ) - - # Process the response. - messages.extend(response.messages) - - print( - f"πŸ”§ Worker: Generated improved response, sending for re-review (ID: {review.request_id[:8]})" - ) - # Send an updated review request. - new_request = ReviewRequest( - request_id=review.request_id, - user_messages=request.user_messages, - agent_messages=response.messages, - ) - await ctx.send_message(new_request) +from azure.ai.agentserver.agentframework import from_agent_framework - # Add to pending requests. - self._pending_requests[new_request.request_id] = (new_request, messages) +def create_writer_agent(client: AzureAIAgentClient) -> ChatAgent: + return client.create_agent( + name="Writer", + instructions=( + "You are an excellent content writer. You create new content and edit contents based on the feedback." + ), + ) -def build_agent(chat_client: BaseChatClient): - reviewer = Reviewer(chat_client=chat_client) - worker = Worker(chat_client=chat_client) - return ( - WorkflowBuilder() - .add_edge( - worker, reviewer - ) # <--- This edge allows the worker to send requests to the reviewer - .add_edge( - reviewer, worker - ) # <--- This edge allows the reviewer to send feedback back to the worker - .set_start_executor(worker) - .build() - .as_agent() # Convert the workflow to an agent. +def create_reviewer_agent(client: AzureAIAgentClient) -> ChatAgent: + return client.create_agent( + name="Reviewer", + instructions=( + "You are an excellent content reviewer. " + "Provide actionable feedback to the writer about the provided content. " + "Provide the feedback in the most concise manner possible." + ), ) async def main() -> None: - async with DefaultAzureCredential() as credential: - async with AzureAIAgentClient(async_credential=credential) as chat_client: - agent = build_agent(chat_client) - await from_agent_framework(agent).run_async() + async with AzureCliCredential() as cred, AzureAIAgentClient(credential=cred) as client: + builder = ( + WorkflowBuilder() + .register_agent(lambda: create_writer_agent(client), name="writer") + .register_agent(lambda: create_reviewer_agent(client), name="reviewer", output_response=True) + .set_start_executor("writer") + .add_edge("writer", "reviewer") + ) + + # Pass the WorkflowBuilder to the adapter and run it + # await from_agent_framework(workflow=builder).run_async() + + # Or create a factory function for the workflow pass the workflow factory to the adapter + def workflow_factory() -> Workflow: + return builder.build() + await from_agent_framework(workflow_factory).run_async() if __name__ == "__main__": diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_with_foundry_checkpoints/README.md b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_with_foundry_checkpoints/README.md new file mode 100644 index 000000000000..3034a4e7de6b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_with_foundry_checkpoints/README.md @@ -0,0 +1,89 @@ +# Workflow Agent with Foundry Managed Checkpoints + +This sample hosts a two-step Agent Framework workflowβ€”`Writer` followed by `Reviewer`β€”and uses +`FoundryCheckpointRepository` to persist workflow checkpoints in Azure AI Foundry managed storage. + +With Foundry managed checkpoints, workflow state is stored remotely so long-running conversations can +resume even after the host process restarts, without managing your own storage backend. + +### What `main.py` does + +- Builds a workflow with `WorkflowBuilder` (writer + reviewer) +- Creates a `FoundryCheckpointRepository` pointed at your Azure AI Foundry project +- Passes both to `from_agent_framework(..., checkpoint_repository=...)` so the adapter spins up an + HTTP server (defaults to `0.0.0.0:8088`) + +--- + +## Prerequisites + +- Python 3.10+ +- Azure CLI authenticated with `az login` (required for `AzureCliCredential`) +- An Azure AI Foundry project with a chat model deployment + +--- + +## Setup + +1. Create a `.env` file in this folder: + ``` + AZURE_AI_PROJECT_ENDPOINT=https://.services.ai.azure.com/api/projects/ + AZURE_AI_MODEL_DEPLOYMENT_NAME= + ``` + +2. Install dependencies: + ```bash + pip install azure-ai-agentserver-agentframework agent-framework-azure-ai azure-identity python-dotenv + ``` + +--- + +## Run the Workflow Agent + +From this folder: + +```bash +python main.py +``` + +The adapter starts the server on `http://0.0.0.0:8088` by default. + +--- + +## Send Requests + +**Non-streaming:** + +```bash +curl -sS \ + -H "Content-Type: application/json" \ + -X POST http://localhost:8088/responses \ + -d '{ + "agent": {"name": "local_agent", "type": "agent_reference"}, + "stream": false, + "input": "Write a short blog post about cloud-native AI applications", + "conversation": {"id": "test-conversation-1"} + }' +``` + +The `conversation.id` ties requests to the same checkpoint session. Subsequent requests with the same +ID will resume the workflow from its last checkpoint. + +--- + +## Checkpoint Repository Options + +The `checkpoint_repository` parameter in `from_agent_framework` accepts any `CheckpointRepository` implementation: + +| Repository | Use case | +|---|---| +| `InMemoryCheckpointRepository()` | Quick demos; checkpoints vanish when the process exits | +| `FileCheckpointRepository("")` | Local file-based persistence | +| `FoundryCheckpointRepository(project_endpoint, credential)` | Azure AI Foundry managed remote storage (this sample) | + +--- + +## Related Resources + +- Agent Framework repo: https://github.com/microsoft/agent-framework +- Adapter package docs: `azure.ai.agentserver.agentframework` in this SDK diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_with_foundry_checkpoints/main.py b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_with_foundry_checkpoints/main.py new file mode 100644 index 000000000000..586c31b8d4b7 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/samples/workflow_with_foundry_checkpoints/main.py @@ -0,0 +1,84 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Workflow Agent with Foundry Managed Checkpoints + +This sample demonstrates how to use FoundryCheckpointRepository with +a WorkflowBuilder agent to persist workflow checkpoints in Azure AI Foundry. + +Foundry managed checkpoints enable workflow state to be persisted across +requests, allowing workflows to be paused, resumed, and replayed. + +Prerequisites: + - Set AZURE_AI_PROJECT_ENDPOINT to your Azure AI Foundry project endpoint + e.g. "https://.services.ai.azure.com/api/projects/" + - Azure credentials configured (e.g. az login) +""" + +import asyncio +import os + +from dotenv import load_dotenv + +from agent_framework import ChatAgent, WorkflowBuilder +from agent_framework.azure import AzureAIAgentClient +from azure.identity.aio import AzureCliCredential + +from azure.ai.agentserver.agentframework import from_agent_framework +from azure.ai.agentserver.agentframework.persistence import FoundryCheckpointRepository + +load_dotenv() + + +def create_writer_agent(client: AzureAIAgentClient) -> ChatAgent: + """Create a writer agent that generates content.""" + return client.create_agent( + name="Writer", + instructions=( + "You are an excellent content writer. " + "You create new content and edit contents based on the feedback." + ), + ) + + +def create_reviewer_agent(client: AzureAIAgentClient) -> ChatAgent: + """Create a reviewer agent that provides feedback.""" + return client.create_agent( + name="Reviewer", + instructions=( + "You are an excellent content reviewer. " + "Provide actionable feedback to the writer about the provided content. " + "Provide the feedback in the most concise manner possible." + ), + ) + + +async def main() -> None: + """Run the workflow agent with Foundry managed checkpoints.""" + project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT", "") + + async with AzureCliCredential() as cred, AzureAIAgentClient(credential=cred) as client: + builder = ( + WorkflowBuilder() + .register_agent(lambda: create_writer_agent(client), name="writer") + .register_agent(lambda: create_reviewer_agent(client), name="reviewer", output_response=True) + .set_start_executor("writer") + .add_edge("writer", "reviewer") + ) + + # Use FoundryCheckpointRepository for Azure AI Foundry managed storage. + # This persists workflow checkpoints remotely, enabling pause/resume + # across requests and server restarts. + checkpoint_repository = FoundryCheckpointRepository( + project_endpoint=project_endpoint, + credential=cred, + ) + + await from_agent_framework( + builder, + checkpoint_repository=checkpoint_repository, + ).run_async() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/__init__.py deleted file mode 100644 index 4a5d26360bce..000000000000 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Unit tests package diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/conftest.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/conftest.py deleted file mode 100644 index a56a7164c0a3..000000000000 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/conftest.py +++ /dev/null @@ -1,15 +0,0 @@ -""" -Pytest configuration and shared fixtures for unit tests. -""" - -import sys -from pathlib import Path - -# Ensure package sources are importable during tests -tests_root = Path(__file__).resolve() -src_root = tests_root.parents[4] -packages_root = tests_root.parents[2] / "packages" - -for path in (packages_root, src_root): - if str(path) not in sys.path: - sys.path.insert(0, str(path)) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/conftest.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/conftest.py new file mode 100644 index 000000000000..cd00c924c030 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/conftest.py @@ -0,0 +1,31 @@ +""" +Pytest configuration and shared fixtures for unit tests. +""" + +# Workaround: importing agent_framework (via mcp) can fail with +# KeyError: 'pydantic.root_model' unless this module is imported first. +import pydantic.root_model # noqa: F401 + +import site +import sys +from pathlib import Path + + +# Ensure we don't import user-site packages that can conflict with the active +# environment (e.g., a user-installed cryptography wheel causing PyO3 errors). +try: + user_site = site.getusersitepackages() + if user_site: + sys.path[:] = [p for p in sys.path if p != user_site] +except Exception: + # Best-effort: if site isn't fully configured, proceed without filtering. + pass + +# Ensure package sources are importable during tests +tests_root = Path(__file__).resolve() +src_root = tests_root.parents[4] +packages_root = tests_root.parents[2] / "packages" + +for path in (packages_root, src_root): + if str(path) not in sys.path: + sys.path.insert(0, str(path)) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/mocks/__init__.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/mocks/__init__.py new file mode 100644 index 000000000000..4436d04866df --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/mocks/__init__.py @@ -0,0 +1,8 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Mock implementations for testing.""" + +from .mock_checkpoint_client import MockFoundryCheckpointClient + +__all__ = ["MockFoundryCheckpointClient"] diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/mocks/mock_checkpoint_client.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/mocks/mock_checkpoint_client.py new file mode 100644 index 000000000000..ffc1e2fcc4c1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/mocks/mock_checkpoint_client.py @@ -0,0 +1,156 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Mock implementation of FoundryCheckpointClient for testing.""" + +from typing import Any, Dict, List, Optional + +from azure.ai.agentserver.core.checkpoints.client import ( + CheckpointItem, + CheckpointItemId, + CheckpointSession, +) + + +class MockFoundryCheckpointClient: + """In-memory mock for FoundryCheckpointClient for unit testing. + + Stores checkpoints in memory without making any HTTP calls. + """ + + def __init__(self, endpoint: str = "https://mock.endpoint") -> None: + """Initialize the mock client. + + :param endpoint: The mock endpoint URL. + :type endpoint: str + """ + self._endpoint = endpoint + self._sessions: Dict[str, CheckpointSession] = {} + self._items: Dict[str, CheckpointItem] = {} + + def _item_key(self, item_id: CheckpointItemId) -> str: + """Generate a unique key for a checkpoint item. + + :param item_id: The checkpoint item identifier. + :type item_id: CheckpointItemId + :return: The unique key. + :rtype: str + """ + return f"{item_id.session_id}:{item_id.item_id}" + + # Session operations + + async def upsert_session(self, session: CheckpointSession) -> CheckpointSession: + """Create or update a checkpoint session. + + :param session: The checkpoint session to upsert. + :type session: CheckpointSession + :return: The upserted checkpoint session. + :rtype: CheckpointSession + """ + self._sessions[session.session_id] = session + return session + + async def read_session(self, session_id: str) -> Optional[CheckpointSession]: + """Read a checkpoint session by ID. + + :param session_id: The session identifier. + :type session_id: str + :return: The checkpoint session if found, None otherwise. + :rtype: Optional[CheckpointSession] + """ + return self._sessions.get(session_id) + + async def delete_session(self, session_id: str) -> None: + """Delete a checkpoint session. + + :param session_id: The session identifier. + :type session_id: str + """ + self._sessions.pop(session_id, None) + # Also delete all items in the session + keys_to_delete = [ + key for key, item in self._items.items() if item.session_id == session_id + ] + for key in keys_to_delete: + del self._items[key] + + # Item operations + + async def create_items(self, items: List[CheckpointItem]) -> List[CheckpointItem]: + """Create checkpoint items in batch. + + :param items: The checkpoint items to create. + :type items: List[CheckpointItem] + :return: The created checkpoint items. + :rtype: List[CheckpointItem] + """ + for item in items: + key = self._item_key(item.to_item_id()) + self._items[key] = item + return items + + async def read_item(self, item_id: CheckpointItemId) -> Optional[CheckpointItem]: + """Read a checkpoint item by ID. + + :param item_id: The checkpoint item identifier. + :type item_id: CheckpointItemId + :return: The checkpoint item if found, None otherwise. + :rtype: Optional[CheckpointItem] + """ + key = self._item_key(item_id) + return self._items.get(key) + + async def delete_item(self, item_id: CheckpointItemId) -> bool: + """Delete a checkpoint item. + + :param item_id: The checkpoint item identifier. + :type item_id: CheckpointItemId + :return: True if the item was deleted, False if not found. + :rtype: bool + """ + key = self._item_key(item_id) + if key in self._items: + del self._items[key] + return True + return False + + async def list_item_ids( + self, session_id: str, parent_id: Optional[str] = None + ) -> List[CheckpointItemId]: + """List checkpoint item IDs for a session. + + :param session_id: The session identifier. + :type session_id: str + :param parent_id: Optional parent item identifier for filtering. + :type parent_id: Optional[str] + :return: List of checkpoint item identifiers. + :rtype: List[CheckpointItemId] + """ + result = [] + for item in self._items.values(): + if item.session_id == session_id: + if parent_id is None or item.parent_id == parent_id: + result.append(item.to_item_id()) + return result + + # Context manager methods + + async def close(self) -> None: + """Close the client (no-op for mock).""" + pass + + async def __aenter__(self) -> "MockFoundryCheckpointClient": + """Enter the async context manager. + + :return: The client instance. + :rtype: MockFoundryCheckpointClient + """ + return self + + async def __aexit__(self, *exc_details: Any) -> None: + """Exit the async context manager. + + :param exc_details: Exception details if an exception occurred. + """ + pass diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_agent_framework_input_converter.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_agent_framework_input_converter.py index 3dab36131f8d..d52d0e481bd2 100644 --- a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_agent_framework_input_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_agent_framework_input_converter.py @@ -16,26 +16,30 @@ def converter() -> AgentFrameworkInputConverter: @pytest.mark.unit -def test_transform_none_returns_none(converter: AgentFrameworkInputConverter) -> None: - assert converter.transform_input(None) is None +@pytest.mark.asyncio +async def test_transform_none_returns_none(converter: AgentFrameworkInputConverter) -> None: + assert await converter.transform_input(None) is None @pytest.mark.unit -def test_transform_string_returns_same(converter: AgentFrameworkInputConverter) -> None: - assert converter.transform_input("hello") == "hello" +@pytest.mark.asyncio +async def test_transform_string_returns_same(converter: AgentFrameworkInputConverter) -> None: + assert await converter.transform_input("hello") == "hello" @pytest.mark.unit -def test_transform_implicit_user_message_with_string(converter: AgentFrameworkInputConverter) -> None: +@pytest.mark.asyncio +async def test_transform_implicit_user_message_with_string(converter: AgentFrameworkInputConverter) -> None: payload = [{"content": "How are you?"}] - result = converter.transform_input(payload) + result = await converter.transform_input(payload) assert result == "How are you?" @pytest.mark.unit -def test_transform_implicit_user_message_with_input_text_list(converter: AgentFrameworkInputConverter) -> None: +@pytest.mark.asyncio +async def test_transform_implicit_user_message_with_input_text_list(converter: AgentFrameworkInputConverter) -> None: payload = [ { "content": [ @@ -45,13 +49,14 @@ def test_transform_implicit_user_message_with_input_text_list(converter: AgentFr } ] - result = converter.transform_input(payload) + result = await converter.transform_input(payload) assert result == "Hello world" @pytest.mark.unit -def test_transform_explicit_message_returns_chat_message(converter: AgentFrameworkInputConverter) -> None: +@pytest.mark.asyncio +async def test_transform_explicit_message_returns_chat_message(converter: AgentFrameworkInputConverter) -> None: payload = [ { "type": "message", @@ -62,7 +67,7 @@ def test_transform_explicit_message_returns_chat_message(converter: AgentFramewo } ] - result = converter.transform_input(payload) + result = await converter.transform_input(payload) assert isinstance(result, ChatMessage) assert result.role == ChatRole.ASSISTANT @@ -70,7 +75,8 @@ def test_transform_explicit_message_returns_chat_message(converter: AgentFramewo @pytest.mark.unit -def test_transform_multiple_explicit_messages_returns_list(converter: AgentFrameworkInputConverter) -> None: +@pytest.mark.asyncio +async def test_transform_multiple_explicit_messages_returns_list(converter: AgentFrameworkInputConverter) -> None: payload = [ { "type": "message", @@ -86,7 +92,7 @@ def test_transform_multiple_explicit_messages_returns_list(converter: AgentFrame }, ] - result = converter.transform_input(payload) + result = await converter.transform_input(payload) assert isinstance(result, list) assert len(result) == 2 @@ -98,7 +104,8 @@ def test_transform_multiple_explicit_messages_returns_list(converter: AgentFrame @pytest.mark.unit -def test_transform_mixed_messages_coerces_to_strings(converter: AgentFrameworkInputConverter) -> None: +@pytest.mark.asyncio +async def test_transform_mixed_messages_coerces_to_strings(converter: AgentFrameworkInputConverter) -> None: payload = [ {"content": "First"}, { @@ -110,21 +117,23 @@ def test_transform_mixed_messages_coerces_to_strings(converter: AgentFrameworkIn }, ] - result = converter.transform_input(payload) + result = await converter.transform_input(payload) assert result == ["First", "Second"] @pytest.mark.unit -def test_transform_invalid_input_type_raises(converter: AgentFrameworkInputConverter) -> None: +@pytest.mark.asyncio +async def test_transform_invalid_input_type_raises(converter: AgentFrameworkInputConverter) -> None: with pytest.raises(Exception) as exc_info: - converter.transform_input({"content": "invalid"}) + await converter.transform_input({"content": "invalid"}) assert "Unsupported input type" in str(exc_info.value) @pytest.mark.unit -def test_transform_skips_non_text_entries(converter: AgentFrameworkInputConverter) -> None: +@pytest.mark.asyncio +async def test_transform_skips_non_text_entries(converter: AgentFrameworkInputConverter) -> None: payload = [ { "content": [ @@ -134,6 +143,6 @@ def test_transform_skips_non_text_entries(converter: AgentFrameworkInputConverte } ] - result = converter.transform_input(payload) + result = await converter.transform_input(payload) assert result is None diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_id_optional.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_id_optional.py new file mode 100644 index 000000000000..85fdc34f6498 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_conversation_id_optional.py @@ -0,0 +1,38 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from unittest.mock import Mock + +import pytest +from agent_framework import AgentThread, InMemoryCheckpointStorage + +from azure.ai.agentserver.agentframework.persistence.agent_thread_repository import ( + InMemoryAgentThreadRepository, +) +from azure.ai.agentserver.agentframework.persistence.checkpoint_repository import ( + InMemoryCheckpointRepository, +) + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_inmemory_thread_repository_ignores_missing_conversation_id() -> None: + repo = InMemoryAgentThreadRepository() + thread = Mock(spec=AgentThread) + + await repo.set(None, thread) + assert await repo.get(None) is None + + await repo.set("conv-1", thread) + assert await repo.get("conv-1") is thread + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_inmemory_checkpoint_repository_returns_none_without_conversation_id() -> None: + repo = InMemoryCheckpointRepository() + + assert await repo.get_or_create(None) is None + + storage = await repo.get_or_create("conv-1") + assert isinstance(storage, InMemoryCheckpointStorage) diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_foundry_checkpoint_repository.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_foundry_checkpoint_repository.py new file mode 100644 index 000000000000..ddbbe616b5e5 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_foundry_checkpoint_repository.py @@ -0,0 +1,111 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryCheckpointRepository.""" + +import pytest + +from azure.ai.agentserver.agentframework.persistence import ( + FoundryCheckpointRepository, + FoundryCheckpointStorage, +) + +from .mocks import MockFoundryCheckpointClient + + +class TestableFoundryCheckpointRepository(FoundryCheckpointRepository): + """Testable version that accepts a mock client.""" + + def __init__(self, client: MockFoundryCheckpointClient) -> None: + """Initialize with a mock client (bypass credential requirements).""" + self._client = client # type: ignore[assignment] + self._inventory = {} + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_get_or_create_returns_none_without_conversation_id() -> None: + """Test that get_or_create returns None when conversation_id is None.""" + client = MockFoundryCheckpointClient() + repo = TestableFoundryCheckpointRepository(client=client) + + result = await repo.get_or_create(None) + + assert result is None + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_get_or_create_returns_none_for_empty_string() -> None: + """Test that get_or_create returns None when conversation_id is empty.""" + client = MockFoundryCheckpointClient() + repo = TestableFoundryCheckpointRepository(client=client) + + result = await repo.get_or_create("") + + assert result is None + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_get_or_create_creates_storage_on_first_access() -> None: + """Test that get_or_create creates storage on first access.""" + client = MockFoundryCheckpointClient() + repo = TestableFoundryCheckpointRepository(client=client) + + storage = await repo.get_or_create("conv-123") + + assert storage is not None + assert isinstance(storage, FoundryCheckpointStorage) + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_get_or_create_creates_session_on_first_access() -> None: + """Test that get_or_create creates session on the backend.""" + client = MockFoundryCheckpointClient() + repo = TestableFoundryCheckpointRepository(client=client) + + await repo.get_or_create("conv-123") + + # Verify session was created + session = await client.read_session("conv-123") + assert session is not None + assert session.session_id == "conv-123" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_get_or_create_caches_storage_instances() -> None: + """Test that get_or_create returns cached storage on subsequent calls.""" + client = MockFoundryCheckpointClient() + repo = TestableFoundryCheckpointRepository(client=client) + + storage1 = await repo.get_or_create("conv-123") + storage2 = await repo.get_or_create("conv-123") + + assert storage1 is storage2 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_get_or_create_creates_separate_storage_per_conversation() -> None: + """Test that get_or_create creates separate storage per conversation.""" + client = MockFoundryCheckpointClient() + repo = TestableFoundryCheckpointRepository(client=client) + + storage1 = await repo.get_or_create("conv-1") + storage2 = await repo.get_or_create("conv-2") + + assert storage1 is not storage2 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_close_closes_client() -> None: + """Test that close closes the underlying client.""" + client = MockFoundryCheckpointClient() + repo = TestableFoundryCheckpointRepository(client=client) + + # Should not raise + await repo.close() diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_foundry_checkpoint_storage.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_foundry_checkpoint_storage.py new file mode 100644 index 000000000000..fb1b8d9c0e30 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_foundry_checkpoint_storage.py @@ -0,0 +1,165 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryCheckpointStorage.""" + +import pytest +from agent_framework import WorkflowCheckpoint + +from azure.ai.agentserver.agentframework.persistence import FoundryCheckpointStorage + +from .mocks import MockFoundryCheckpointClient + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_save_checkpoint_returns_checkpoint_id() -> None: + """Test that save_checkpoint returns the checkpoint ID.""" + client = MockFoundryCheckpointClient() + storage = FoundryCheckpointStorage(client=client, session_id="session-1") + + checkpoint = WorkflowCheckpoint( + checkpoint_id="cp-123", + workflow_id="wf-1", + ) + + result = await storage.save_checkpoint(checkpoint) + + assert result == "cp-123" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_load_checkpoint_returns_checkpoint() -> None: + """Test that load_checkpoint returns the checkpoint.""" + client = MockFoundryCheckpointClient() + storage = FoundryCheckpointStorage(client=client, session_id="session-1") + + checkpoint = WorkflowCheckpoint( + checkpoint_id="cp-123", + workflow_id="wf-1", + iteration_count=5, + ) + + await storage.save_checkpoint(checkpoint) + loaded = await storage.load_checkpoint("cp-123") + + assert loaded is not None + assert loaded.checkpoint_id == "cp-123" + assert loaded.workflow_id == "wf-1" + assert loaded.iteration_count == 5 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_load_checkpoint_returns_none_for_missing() -> None: + """Test that load_checkpoint returns None for missing checkpoint.""" + client = MockFoundryCheckpointClient() + storage = FoundryCheckpointStorage(client=client, session_id="session-1") + + loaded = await storage.load_checkpoint("nonexistent") + + assert loaded is None + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_list_checkpoint_ids_returns_all_ids() -> None: + """Test that list_checkpoint_ids returns all checkpoint IDs.""" + client = MockFoundryCheckpointClient() + storage = FoundryCheckpointStorage(client=client, session_id="session-1") + + cp1 = WorkflowCheckpoint(checkpoint_id="cp-1", workflow_id="wf-1") + cp2 = WorkflowCheckpoint(checkpoint_id="cp-2", workflow_id="wf-1") + cp3 = WorkflowCheckpoint(checkpoint_id="cp-3", workflow_id="wf-2") + + await storage.save_checkpoint(cp1) + await storage.save_checkpoint(cp2) + await storage.save_checkpoint(cp3) + + ids = await storage.list_checkpoint_ids() + + assert set(ids) == {"cp-1", "cp-2", "cp-3"} + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_list_checkpoint_ids_filters_by_workflow() -> None: + """Test that list_checkpoint_ids filters by workflow_id.""" + client = MockFoundryCheckpointClient() + storage = FoundryCheckpointStorage(client=client, session_id="session-1") + + cp1 = WorkflowCheckpoint(checkpoint_id="cp-1", workflow_id="wf-1") + cp2 = WorkflowCheckpoint(checkpoint_id="cp-2", workflow_id="wf-1") + cp3 = WorkflowCheckpoint(checkpoint_id="cp-3", workflow_id="wf-2") + + await storage.save_checkpoint(cp1) + await storage.save_checkpoint(cp2) + await storage.save_checkpoint(cp3) + + ids = await storage.list_checkpoint_ids(workflow_id="wf-1") + + assert set(ids) == {"cp-1", "cp-2"} + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_list_checkpoints_returns_all_checkpoints() -> None: + """Test that list_checkpoints returns all checkpoints.""" + client = MockFoundryCheckpointClient() + storage = FoundryCheckpointStorage(client=client, session_id="session-1") + + cp1 = WorkflowCheckpoint(checkpoint_id="cp-1", workflow_id="wf-1") + cp2 = WorkflowCheckpoint(checkpoint_id="cp-2", workflow_id="wf-2") + + await storage.save_checkpoint(cp1) + await storage.save_checkpoint(cp2) + + checkpoints = await storage.list_checkpoints() + + assert len(checkpoints) == 2 + ids = {cp.checkpoint_id for cp in checkpoints} + assert ids == {"cp-1", "cp-2"} + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_delete_checkpoint_returns_true_for_existing() -> None: + """Test that delete_checkpoint returns True for existing checkpoint.""" + client = MockFoundryCheckpointClient() + storage = FoundryCheckpointStorage(client=client, session_id="session-1") + + checkpoint = WorkflowCheckpoint(checkpoint_id="cp-123", workflow_id="wf-1") + await storage.save_checkpoint(checkpoint) + + deleted = await storage.delete_checkpoint("cp-123") + + assert deleted is True + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_delete_checkpoint_returns_false_for_missing() -> None: + """Test that delete_checkpoint returns False for missing checkpoint.""" + client = MockFoundryCheckpointClient() + storage = FoundryCheckpointStorage(client=client, session_id="session-1") + + deleted = await storage.delete_checkpoint("nonexistent") + + assert deleted is False + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_delete_checkpoint_removes_from_storage() -> None: + """Test that delete_checkpoint actually removes the checkpoint.""" + client = MockFoundryCheckpointClient() + storage = FoundryCheckpointStorage(client=client, session_id="session-1") + + checkpoint = WorkflowCheckpoint(checkpoint_id="cp-123", workflow_id="wf-1") + await storage.save_checkpoint(checkpoint) + + await storage.delete_checkpoint("cp-123") + loaded = await storage.load_checkpoint("cp-123") + + assert loaded is None diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_foundry_tools.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_foundry_tools.py new file mode 100644 index 000000000000..ba30e06cfe78 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_foundry_tools.py @@ -0,0 +1,188 @@ +import importlib +import inspect +from types import SimpleNamespace +from unittest.mock import AsyncMock + +from typing import Any + +import pytest +from agent_framework import AIFunction, ChatOptions +from pydantic import Field, create_model + +# Import schema models directly from client._models to avoid heavy azure.identity import +# chain triggered by azure.ai.agentserver.core.__init__.py +from azure.ai.agentserver.core.tools.client._models import ( + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, +) + +# Load _foundry_tools module directly without triggering parent __init__ which has heavy deps +import importlib.util +import sys +from pathlib import Path + +foundry_tools_path = Path(__file__).parent.parent.parent / "azure" / "ai" / "agentserver" / "agentframework" / "_foundry_tools.py" +spec = importlib.util.spec_from_file_location("_foundry_tools", foundry_tools_path) +foundry_tools_module = importlib.util.module_from_spec(spec) +sys.modules["_foundry_tools"] = foundry_tools_module +spec.loader.exec_module(foundry_tools_module) + +FoundryToolClient = foundry_tools_module.FoundryToolClient +FoundryToolsChatMiddleware = foundry_tools_module.FoundryToolsChatMiddleware +_attach_signature_from_pydantic_model = foundry_tools_module._attach_signature_from_pydantic_model + + +@pytest.mark.unit +def test_attach_signature_from_pydantic_model_required_and_optional() -> None: + InputModel = create_model( + "InputModel", + required_int=(int, Field(description="required")), + optional_str=(str | None, Field(default=None, description="optional")), + ) + + async def tool_func(**kwargs): + return kwargs + + _attach_signature_from_pydantic_model(tool_func, InputModel) + + sig = inspect.signature(tool_func) + assert list(sig.parameters.keys()) == ["required_int", "optional_str"] + assert all(p.kind is inspect.Parameter.KEYWORD_ONLY for p in sig.parameters.values()) + assert sig.parameters["required_int"].default is inspect._empty + assert sig.parameters["optional_str"].default is None + + # Ensure annotations are also attached. + assert tool_func.__annotations__["required_int"] is int + assert tool_func.__annotations__["return"] is Any + + +def _make_resolved_tool(*, name: str = "my_tool", description: str = "desc") -> ResolvedFoundryTool: + schema = SchemaDefinition( + properties={ + "a": SchemaProperty(type=SchemaType.STRING, description="field a"), + "b": SchemaProperty(type=SchemaType.INTEGER, description="field b"), + }, + required={"a"}, + ) + definition = FoundryHostedMcpTool(name=name) + details = FoundryToolDetails(name=name, description=description, input_schema=schema) + return ResolvedFoundryTool(definition=definition, details=details) + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_to_aifunction_builds_pydantic_model_and_invokes(monkeypatch: pytest.MonkeyPatch) -> None: + resolved_tool = _make_resolved_tool(name="echo", description="Echo tool") + + invoke = AsyncMock(return_value={"ok": True}) + server_context = SimpleNamespace( + tools=SimpleNamespace( + invoke=invoke, + catalog=SimpleNamespace(list=AsyncMock()), + ) + ) + monkeypatch.setattr( + foundry_tools_module.AgentServerContext, + "get", + classmethod(lambda cls: server_context), + ) + + client = FoundryToolClient(tools=[]) + ai_func = client._to_aifunction(resolved_tool) + assert isinstance(ai_func, AIFunction) + assert ai_func.name == "echo" + assert ai_func.description == "Echo tool" + + # Signature should be attached from schema (a required, b optional) + sig = inspect.signature(ai_func.func) + assert list(sig.parameters.keys()) == ["a", "b"] + assert sig.parameters["a"].default is inspect._empty + assert sig.parameters["b"].default is None + + result = await ai_func.func(a="hi", b=123) + assert result == {"ok": True} + invoke.assert_awaited_once_with(resolved_tool, {"a": "hi", "b": 123}) + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_list_tools_uses_catalog_and_converts(monkeypatch: pytest.MonkeyPatch) -> None: + allowed = [FoundryHostedMcpTool(name="allowed_tool")] + resolved = [_make_resolved_tool(name="allowed_tool", description="Allowed")] + + catalog_list = AsyncMock(return_value=resolved) + server_context = SimpleNamespace( + tools=SimpleNamespace( + catalog=SimpleNamespace(list=catalog_list), + invoke=AsyncMock(), + ) + ) + monkeypatch.setattr( + foundry_tools_module.AgentServerContext, + "get", + classmethod(lambda cls: server_context), + ) + + client = FoundryToolClient(tools=allowed) + functions = await client.list_tools() + + catalog_list.assert_awaited_once() + args, _kwargs = catalog_list.await_args + assert args[0] == list(allowed) + + assert len(functions) == 1 + assert isinstance(functions[0], AIFunction) + assert functions[0].name == "allowed_tool" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_middleware_process_creates_chat_options_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: + middleware = FoundryToolsChatMiddleware(tools=[]) + + async def dummy_tool(**kwargs): + return kwargs + + DummyInput = create_model("DummyInput") + injected = [AIFunction(name="t", description="d", func=dummy_tool, input_model=DummyInput)] + monkeypatch.setattr(middleware._foundry_tool_client, "list_tools", AsyncMock(return_value=injected)) + + context = SimpleNamespace(chat_options=None) + next_fn = AsyncMock() + + await middleware.process(context, next_fn) + + assert isinstance(context.chat_options, ChatOptions) + assert context.chat_options.tools == injected + next_fn.assert_awaited_once_with(context) + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_middleware_process_appends_to_existing_chat_options(monkeypatch: pytest.MonkeyPatch) -> None: + middleware = FoundryToolsChatMiddleware(tools=[]) + + async def dummy_tool(**kwargs): + return kwargs + + DummyInput = create_model("DummyInput") + injected = [AIFunction(name="t2", description="d2", func=dummy_tool, input_model=DummyInput)] + monkeypatch.setattr(middleware._foundry_tool_client, "list_tools", AsyncMock(return_value=injected)) + + # Existing ChatOptions with no tools should become injected + context = SimpleNamespace(chat_options=ChatOptions()) + next_fn = AsyncMock() + await middleware.process(context, next_fn) + assert context.chat_options.tools == injected + + # Existing ChatOptions with tools should be appended + existing = [AIFunction(name="t1", description="d1", func=dummy_tool, input_model=DummyInput)] + context = SimpleNamespace(chat_options=ChatOptions(tools=existing)) + next_fn = AsyncMock() + await middleware.process(context, next_fn) + assert context.chat_options.tools == existing + injected + assert next_fn.await_count == 1 diff --git a/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_from_agent_framework_managed.py b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_from_agent_framework_managed.py new file mode 100644 index 000000000000..c1616c475cb4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-agentframework/tests/unit_tests/test_from_agent_framework_managed.py @@ -0,0 +1,47 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for from_agent_framework with checkpoint repository.""" + +import pytest +from unittest.mock import Mock + +from azure.core.credentials_async import AsyncTokenCredential + + +@pytest.mark.unit +def test_checkpoint_repository_is_optional() -> None: + """Test that checkpoint_repository is optional and defaults to None.""" + from azure.ai.agentserver.agentframework import from_agent_framework + from agent_framework import WorkflowBuilder + + builder = WorkflowBuilder() + + # Should not raise + adapter = from_agent_framework(builder) + + assert adapter is not None + + +@pytest.mark.unit +def test_foundry_checkpoint_repository_passed_directly() -> None: + """Test that FoundryCheckpointRepository can be passed via checkpoint_repository.""" + from azure.ai.agentserver.agentframework import from_agent_framework + from azure.ai.agentserver.agentframework.persistence import FoundryCheckpointRepository + from agent_framework import WorkflowBuilder + + builder = WorkflowBuilder() + mock_credential = Mock(spec=AsyncTokenCredential) + + repo = FoundryCheckpointRepository( + project_endpoint="https://test.services.ai.azure.com/api/projects/test-project", + credential=mock_credential, + ) + + adapter = from_agent_framework( + builder, + checkpoint_repository=repo, + ) + + assert adapter is not None + assert adapter._checkpoint_repository is repo diff --git a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md index cfcf2445e256..8cd2e9b9a861 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-core/CHANGELOG.md @@ -1,5 +1,85 @@ # Release History + +## 1.0.0b10 (2026-01-27) + +### Bugs Fixed + +- Make AZURE_AI_PROJECTS_ENDPOINT optional. + +## 1.0.0b9 (2026-01-23) + +### Features Added + +- Integrated with Foundry Tools + + +## 1.0.0b8 (2026-01-21) + +### Features Added + +- Support keep alive for long-running streaming responses. + + +## 1.0.0b7 (2025-12-05) + +### Features Added + +- Update response with created_by + +### Bugs Fixed + +- Fixed error response handling in stream and non-stream modes + +## 1.0.0b6 (2025-11-26) + +### Features Added + +- Support Agent-framework greater than 251112 + + +## 1.0.0b5 (2025-11-16) + +### Features Added + +- Support Tools Oauth + +### Bugs Fixed + +- Fixed streaming generation issues. + + +## 1.0.0b4 (2025-11-13) + +### Features Added + +- Adapters support tools + +### Bugs Fixed + +- Pin azure-ai-projects and azure-ai-agents version to avoid version confliction + + +## 1.0.0b3 (2025-11-11) + +### Bugs Fixed + +- Fixed Id generator format. + +- Fixed trace initialization for agent-framework. + + +## 1.0.0b2 (2025-11-10) + +### Bugs Fixed + +- Fixed Id generator format. + +- Improved stream mode error messsage. + +- Updated application insights related configuration environment variables. + + ## 1.0.0b1 (2025-11-07) ### Features Added diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py index 895074d32ae3..88a13741bbac 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/__init__.py @@ -7,8 +7,9 @@ from .logger import configure as config_logging from .server.base import FoundryCBAgent from .server.common.agent_run_context import AgentRunContext +from .server._context import AgentServerContext config_logging() -__all__ = ["FoundryCBAgent", "AgentRunContext"] +__all__ = ["FoundryCBAgent", "AgentRunContext", "AgentServerContext"] __version__ = VERSION diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py index be71c81bd282..9ab0a006e0d0 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/_version.py @@ -6,4 +6,4 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -VERSION = "1.0.0b1" +VERSION = "1.0.0b10" diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/__init__.py new file mode 100644 index 000000000000..052f9894497b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/__init__.py @@ -0,0 +1,21 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) + +__all__ = [ + "AgentServerMetadata", + "PackageMetadata", + "RuntimeMetadata", + "get_current_app", + "set_current_app" +] + +from ._metadata import ( + AgentServerMetadata, + PackageMetadata, + RuntimeMetadata, + get_current_app, + set_current_app, +) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_builder.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_builder.py new file mode 100644 index 000000000000..c09c253ab09f --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_builder.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +class AgentServerBuilder: + pass diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_configuration.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_configuration.py new file mode 100644 index 000000000000..1f8a01d57639 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_configuration.py @@ -0,0 +1,42 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from dataclasses import dataclass, field + +from azure.core.credentials_async import AsyncTokenCredential + + +@dataclass(frozen=True) +class HttpServerConfiguration: + """Resolved configuration for the HTTP server. + + :ivar str host: The host address the server listens on. Defaults to '0.0.0.0'. + :ivar int port: The port number the server listens on. Defaults to 8088. + """ + + host: str = "0.0.0.0" + port: int = 8088 + + +class ToolsConfiguration: + """Resolved configuration for the Tools subsystem. + + :ivar int catalog_cache_ttl: The time-to-live (TTL) for the tool catalog cache in seconds. + Defaults to 600 seconds (10 minutes). + :ivar int catalog_cache_max_size: The maximum size of the tool catalog cache. + Defaults to 1024 entries. + """ + + catalog_cache_ttl: int = 600 + catalog_cache_max_size: int = 1024 + + +@dataclass(frozen=True) +class AgentServerConfiguration: + """Resolved configuration for the Agent Server application.""" + + project_endpoint: str + credential: AsyncTokenCredential + agent_name: str = "$default" + http: HttpServerConfiguration = field(default_factory=HttpServerConfiguration) + tools: ToolsConfiguration = field(default_factory=ToolsConfiguration) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_metadata.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_metadata.py new file mode 100644 index 000000000000..1ddc3d2d2e9e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_metadata.py @@ -0,0 +1,100 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from __future__ import annotations + +import os +import platform +from dataclasses import dataclass, field +from importlib.metadata import Distribution, PackageNotFoundError +from typing import Mapping + + +@dataclass(frozen=True, kw_only=True) +class PackageMetadata: + name: str + version: str + + @staticmethod + def from_dist(dist_name: str) -> "PackageMetadata": + try: + ver = Distribution.from_name(dist_name).version + except PackageNotFoundError: + ver = "" + + return PackageMetadata( + name=dist_name, + version=ver, + ) + + +@dataclass(frozen=True, kw_only=True) +class RuntimeMetadata: + python_version: str = field(default_factory=platform.python_version) + platform: str = field(default_factory=platform.platform) + host_name: str = "" + replica_name: str = "" + + @staticmethod + def from_aca_app_env() -> "RuntimeMetadata | None": + host_name = os.environ.get("CONTAINER_APP_REVISION_FQDN") + replica_name = os.environ.get("CONTAINER_APP_REPLICA_NAME") + + if not host_name and not replica_name: + return None + + return RuntimeMetadata( + host_name=host_name or "", + replica_name=replica_name or "", + ) + + @staticmethod + def resolve(host_name: str | None = None, replica_name: str | None = None) -> "RuntimeMetadata": + runtime = RuntimeMetadata.from_aca_app_env() + + override = RuntimeMetadata(host_name=host_name or "", replica_name=replica_name or "") + return runtime.merged_with(override) if runtime else override + + def merged_with(self, override: "RuntimeMetadata | None") -> "RuntimeMetadata": + if override is None: + return self + + return RuntimeMetadata( + python_version=override.python_version or self.python_version, + platform=override.platform or self.platform, + host_name=override.host_name or self.host_name, + replica_name=override.replica_name or self.replica_name, + ) + + +@dataclass(frozen=True) +class AgentServerMetadata: + package: PackageMetadata + runtime: RuntimeMetadata + + def as_user_agent(self, component: str | None = None) -> str: + component_value = f" {component}" if component else "" + return ( + f"{self.package.name}/{self.package.version} " + f"Python {self.runtime.python_version}{component_value} " + f"({self.runtime.platform})" + ) + + +_default = AgentServerMetadata( + package=PackageMetadata.from_dist("azure-ai-agentserver-core"), + runtime=RuntimeMetadata.resolve(), +) +_app: AgentServerMetadata = _default + + +def set_current_app(app: PackageMetadata, runtime: RuntimeMetadata | None = None) -> None: + global _app # pylint: disable=W0603 + resolved_runtime = RuntimeMetadata.resolve() + merged_runtime = resolved_runtime.merged_with(runtime) + _app = AgentServerMetadata(package=app, runtime=merged_runtime) + + +def get_current_app() -> AgentServerMetadata: + global _app # pylint: disable=W0602 + return _app diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_options.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_options.py new file mode 100644 index 000000000000..d70270261e7b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/application/_options.py @@ -0,0 +1,45 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Literal, TypedDict, Union + +from typing_extensions import NotRequired + +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential + + +class AgentServerOptions(TypedDict): + """Configuration options for the Agent Server. + + Attributes: + project_endpoint (str, optional): The endpoint URL for the project. Defaults to current project. + credential (Union[AsyncTokenCredential, TokenCredential], optional): The credential used for authentication. + Defaults to current project's managed identity. + """ + project_endpoint: NotRequired[str] + credential: NotRequired[Union[AsyncTokenCredential, TokenCredential]] + http: NotRequired["HttpServerOptions"] + tools: NotRequired["ToolsOptions"] + + +class HttpServerOptions(TypedDict): + """Configuration options for the HTTP server. + + Attributes: + host (str, optional): The host address the server listens on. + """ + host: NotRequired[Literal["127.0.0.1", "localhost", "0.0.0.0"]] + + +class ToolsOptions(TypedDict): + """Configuration options for the Tools subsystem. + + Attributes: + catalog_cache_ttl (int, optional): The time-to-live (TTL) for the tool catalog cache in seconds. + Defaults to 600 seconds (10 minutes). + catalog_cache_max_size (int, optional): The maximum size of the tool catalog cache. + Defaults to 1024 entries. + """ + catalog_cache_ttl: NotRequired[int] + catalog_cache_max_size: NotRequired[int] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/__init__.py new file mode 100644 index 000000000000..f9d6ed3d8aa8 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/__init__.py @@ -0,0 +1,18 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Checkpoint storage module for Azure AI Agent Server.""" + +from .client import FoundryCheckpointClient +from .client._models import ( + CheckpointItem, + CheckpointItemId, + CheckpointSession, +) + +__all__ = [ + "CheckpointItem", + "CheckpointItemId", + "CheckpointSession", + "FoundryCheckpointClient", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/__init__.py new file mode 100644 index 000000000000..34f30f16c5d9 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/__init__.py @@ -0,0 +1,18 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Checkpoint client module for Azure AI Agent Server.""" + +from ._client import FoundryCheckpointClient +from ._models import ( + CheckpointItem, + CheckpointItemId, + CheckpointSession, +) + +__all__ = [ + "CheckpointItem", + "CheckpointItemId", + "CheckpointSession", + "FoundryCheckpointClient", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_client.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_client.py new file mode 100644 index 000000000000..8c16c1151b9e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_client.py @@ -0,0 +1,153 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Asynchronous client for Azure AI Foundry checkpoint storage API.""" + +from typing import Any, AsyncContextManager, List, Optional + +from azure.core import AsyncPipelineClient +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.tracing.decorator_async import distributed_trace_async + +from ._configuration import FoundryCheckpointClientConfiguration +from ._models import CheckpointItem, CheckpointItemId, CheckpointSession +from .operations import CheckpointItemOperations, CheckpointSessionOperations + + +class FoundryCheckpointClient(AsyncContextManager["FoundryCheckpointClient"]): + """Asynchronous client for Azure AI Foundry checkpoint storage API. + + This client provides access to checkpoint storage for workflow state persistence, + enabling checkpoint save, load, list, and delete operations. + + :param endpoint: The fully qualified project endpoint for the Azure AI Foundry service. + Example: "https://.services.ai.azure.com/api/projects/" + :type endpoint: str + :param credential: Credential for authenticating requests to the service. + Use credentials from azure-identity like DefaultAzureCredential. + :type credential: ~azure.core.credentials_async.AsyncTokenCredential + """ + + def __init__( + self, + endpoint: str, + credential: "AsyncTokenCredential", + ) -> None: + """Initialize the asynchronous Azure AI Checkpoint Client. + + :param endpoint: The project endpoint URL (includes project context). + :type endpoint: str + :param credential: Credentials for authenticating requests. + :type credential: ~azure.core.credentials_async.AsyncTokenCredential + """ + config = FoundryCheckpointClientConfiguration(credential) + self._client: AsyncPipelineClient = AsyncPipelineClient( + base_url=endpoint, config=config + ) + self._sessions = CheckpointSessionOperations(self._client) + self._items = CheckpointItemOperations(self._client) + + # Session operations + + @distributed_trace_async + async def upsert_session(self, session: CheckpointSession) -> CheckpointSession: + """Create or update a checkpoint session. + + :param session: The checkpoint session to upsert. + :type session: CheckpointSession + :return: The upserted checkpoint session. + :rtype: CheckpointSession + """ + return await self._sessions.upsert(session) + + @distributed_trace_async + async def read_session(self, session_id: str) -> Optional[CheckpointSession]: + """Read a checkpoint session by ID. + + :param session_id: The session identifier. + :type session_id: str + :return: The checkpoint session if found, None otherwise. + :rtype: Optional[CheckpointSession] + """ + return await self._sessions.read(session_id) + + @distributed_trace_async + async def delete_session(self, session_id: str) -> None: + """Delete a checkpoint session. + + :param session_id: The session identifier. + :type session_id: str + """ + await self._sessions.delete(session_id) + + # Item operations + + @distributed_trace_async + async def create_items(self, items: List[CheckpointItem]) -> List[CheckpointItem]: + """Create checkpoint items in batch. + + :param items: The checkpoint items to create. + :type items: List[CheckpointItem] + :return: The created checkpoint items. + :rtype: List[CheckpointItem] + """ + return await self._items.create_batch(items) + + @distributed_trace_async + async def read_item(self, item_id: CheckpointItemId) -> Optional[CheckpointItem]: + """Read a checkpoint item by ID. + + :param item_id: The checkpoint item identifier. + :type item_id: CheckpointItemId + :return: The checkpoint item if found, None otherwise. + :rtype: Optional[CheckpointItem] + """ + return await self._items.read(item_id) + + @distributed_trace_async + async def delete_item(self, item_id: CheckpointItemId) -> bool: + """Delete a checkpoint item. + + :param item_id: The checkpoint item identifier. + :type item_id: CheckpointItemId + :return: True if the item was deleted, False if not found. + :rtype: bool + """ + return await self._items.delete(item_id) + + @distributed_trace_async + async def list_item_ids( + self, session_id: str, parent_id: Optional[str] = None + ) -> List[CheckpointItemId]: + """List checkpoint item IDs for a session. + + :param session_id: The session identifier. + :type session_id: str + :param parent_id: Optional parent item identifier for filtering. + :type parent_id: Optional[str] + :return: List of checkpoint item identifiers. + :rtype: List[CheckpointItemId] + """ + return await self._items.list_ids(session_id, parent_id) + + # Context manager methods + + async def close(self) -> None: + """Close the underlying HTTP pipeline.""" + await self._client.close() + + async def __aenter__(self) -> "FoundryCheckpointClient": + """Enter the async context manager. + + :return: The client instance. + :rtype: FoundryCheckpointClient + """ + await self._client.__aenter__() + return self + + async def __aexit__(self, *exc_details: Any) -> None: + """Exit the async context manager. + + :param exc_details: Exception details if an exception occurred. + """ + await self._client.__aexit__(*exc_details) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_configuration.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_configuration.py new file mode 100644 index 000000000000..cd9ed9ee7ff7 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_configuration.py @@ -0,0 +1,37 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Configuration for Azure AI Checkpoint Client.""" + +from azure.core.configuration import Configuration +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.pipeline import policies + +from ...application._metadata import get_current_app + + +class FoundryCheckpointClientConfiguration(Configuration): + """Configuration for Azure AI Checkpoint Client. + + Manages authentication, endpoint configuration, and policy settings for the + Azure AI Checkpoint Client. This class is used internally by the client and should + not typically be instantiated directly. + + :param credential: Azure TokenCredential for authentication. + :type credential: ~azure.core.credentials_async.AsyncTokenCredential + """ + + def __init__(self, credential: "AsyncTokenCredential") -> None: + super().__init__() + + self.retry_policy = policies.AsyncRetryPolicy() + self.logging_policy = policies.NetworkTraceLoggingPolicy() + self.request_id_policy = policies.RequestIdPolicy() + self.http_logging_policy = policies.HttpLoggingPolicy() + self.user_agent_policy = policies.UserAgentPolicy( + base_user_agent=get_current_app().as_user_agent("FoundryCheckpointClient") + ) + self.authentication_policy = policies.AsyncBearerTokenCredentialPolicy( + credential, "https://ai.azure.com/.default" + ) + self.redirect_policy = policies.AsyncRedirectPolicy() diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_models.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_models.py new file mode 100644 index 000000000000..b6e61aa2b6f6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/_models.py @@ -0,0 +1,201 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Data models for checkpoint storage API.""" + +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional + +from pydantic import BaseModel, Field + + +@dataclass +class CheckpointSession: + """Represents a checkpoint session. + + A session maps to a conversation and groups related checkpoints together. + + :ivar session_id: The session identifier (maps to conversation_id). + :ivar metadata: Optional metadata for the session. + """ + + session_id: str + metadata: Optional[Dict[str, Any]] = None + + +@dataclass +class CheckpointItemId: + """Identifier for a checkpoint item. + + :ivar session_id: The session identifier this checkpoint belongs to. + :ivar item_id: The unique checkpoint item identifier. + :ivar parent_id: Optional parent checkpoint identifier for hierarchical checkpoints. + """ + + session_id: str + item_id: str + parent_id: Optional[str] = None + + +@dataclass +class CheckpointItem: + """Represents a single checkpoint item. + + Contains the serialized checkpoint data along with identifiers. + + :ivar session_id: The session identifier this checkpoint belongs to. + :ivar item_id: The unique checkpoint item identifier. + :ivar data: Serialized checkpoint data as bytes. + :ivar parent_id: Optional parent checkpoint identifier. + """ + + session_id: str + item_id: str + data: bytes + parent_id: Optional[str] = None + + def to_item_id(self) -> CheckpointItemId: + """Convert to a CheckpointItemId. + + :return: The checkpoint item identifier. + :rtype: CheckpointItemId + """ + return CheckpointItemId( + session_id=self.session_id, + item_id=self.item_id, + parent_id=self.parent_id, + ) + + +# Pydantic models for API request/response serialization + + +class CheckpointSessionRequest(BaseModel): + """Request model for creating/updating a checkpoint session.""" + + session_id: str = Field(alias="sessionId") + metadata: Optional[Dict[str, Any]] = None + + model_config = {"populate_by_name": True} + + @classmethod + def from_session(cls, session: CheckpointSession) -> "CheckpointSessionRequest": + """Create a request from a CheckpointSession. + + :param session: The checkpoint session. + :type session: CheckpointSession + :return: The request model. + :rtype: CheckpointSessionRequest + """ + return cls( + session_id=session.session_id, + metadata=session.metadata, + ) + + +class CheckpointSessionResponse(BaseModel): + """Response model for checkpoint session operations.""" + + session_id: str = Field(alias="sessionId") + metadata: Optional[Dict[str, Any]] = None + etag: Optional[str] = None + + model_config = {"populate_by_name": True} + + def to_session(self) -> CheckpointSession: + """Convert to a CheckpointSession. + + :return: The checkpoint session. + :rtype: CheckpointSession + """ + return CheckpointSession( + session_id=self.session_id, + metadata=self.metadata, + ) + + +class CheckpointItemIdResponse(BaseModel): + """Response model for checkpoint item identifiers.""" + + session_id: str = Field(alias="sessionId") + item_id: str = Field(alias="itemId") + parent_id: Optional[str] = Field(default=None, alias="parentId") + + model_config = {"populate_by_name": True} + + def to_item_id(self) -> CheckpointItemId: + """Convert to a CheckpointItemId. + + :return: The checkpoint item identifier. + :rtype: CheckpointItemId + """ + return CheckpointItemId( + session_id=self.session_id, + item_id=self.item_id, + parent_id=self.parent_id, + ) + + +class CheckpointItemRequest(BaseModel): + """Request model for creating checkpoint items.""" + + session_id: str = Field(alias="sessionId") + item_id: str = Field(alias="itemId") + data: str # Base64-encoded bytes + parent_id: Optional[str] = Field(default=None, alias="parentId") + + model_config = {"populate_by_name": True} + + @classmethod + def from_item(cls, item: CheckpointItem) -> "CheckpointItemRequest": + """Create a request from a CheckpointItem. + + :param item: The checkpoint item. + :type item: CheckpointItem + :return: The request model. + :rtype: CheckpointItemRequest + """ + import base64 + + return cls( + session_id=item.session_id, + item_id=item.item_id, + data=base64.b64encode(item.data).decode("utf-8"), + parent_id=item.parent_id, + ) + + +class CheckpointItemResponse(BaseModel): + """Response model for checkpoint item operations.""" + + session_id: str = Field(alias="sessionId") + item_id: str = Field(alias="itemId") + data: str # Base64-encoded bytes + parent_id: Optional[str] = Field(default=None, alias="parentId") + etag: Optional[str] = None + + model_config = {"populate_by_name": True} + + def to_item(self) -> CheckpointItem: + """Convert to a CheckpointItem. + + :return: The checkpoint item. + :rtype: CheckpointItem + """ + import base64 + + return CheckpointItem( + session_id=self.session_id, + item_id=self.item_id, + data=base64.b64decode(self.data), + parent_id=self.parent_id, + ) + + +class ListCheckpointItemIdsResponse(BaseModel): + """Response model for listing checkpoint item identifiers.""" + + value: List[CheckpointItemIdResponse] = Field(default_factory=list) + next_link: Optional[str] = Field(default=None, alias="nextLink") + + model_config = {"populate_by_name": True} diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/operations/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/operations/__init__.py new file mode 100644 index 000000000000..42ba9bacd20b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/operations/__init__.py @@ -0,0 +1,12 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Checkpoint operations module for Azure AI Agent Server.""" + +from ._items import CheckpointItemOperations +from ._sessions import CheckpointSessionOperations + +__all__ = [ + "CheckpointItemOperations", + "CheckpointSessionOperations", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/operations/_items.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/operations/_items.py new file mode 100644 index 000000000000..8c3fa4331446 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/operations/_items.py @@ -0,0 +1,199 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Operations for checkpoint items.""" + +from typing import Any, ClassVar, Dict, List, Optional + +from azure.core import AsyncPipelineClient +from azure.core.exceptions import ResourceNotFoundError +from azure.core.pipeline.transport import HttpRequest +from azure.core.tracing.decorator_async import distributed_trace_async + +from ....tools.client.operations._base import BaseOperations +from .._models import ( + CheckpointItem, + CheckpointItemId, + CheckpointItemRequest, + CheckpointItemResponse, + ListCheckpointItemIdsResponse, +) + + +class CheckpointItemOperations(BaseOperations): + """Operations for managing checkpoint items.""" + + _API_VERSION: ClassVar[str] = "2025-11-15-preview" + + _HEADERS: ClassVar[Dict[str, str]] = { + "Content-Type": "application/json", + "Accept": "application/json", + } + + _QUERY_PARAMS: ClassVar[Dict[str, Any]] = {"api-version": _API_VERSION} + + def _items_path(self, item_id: Optional[str] = None) -> str: + """Get the API path for item operations. + + :param item_id: Optional item identifier. + :type item_id: Optional[str] + :return: The API path. + :rtype: str + """ + base = "/checkpoints/items" + return f"{base}/{item_id}" if item_id else base + + def _build_create_batch_request(self, items: List[CheckpointItem]) -> HttpRequest: + """Build the HTTP request for creating items in batch. + + :param items: The checkpoint items to create. + :type items: List[CheckpointItem] + :return: The HTTP request. + :rtype: HttpRequest + """ + request_models = [CheckpointItemRequest.from_item(item) for item in items] + return self._client.post( + self._items_path(), + params=self._QUERY_PARAMS, + headers=self._HEADERS, + content=[model.model_dump(by_alias=True) for model in request_models], + ) + + def _build_read_request(self, item_id: CheckpointItemId) -> HttpRequest: + """Build the HTTP request for reading an item. + + :param item_id: The checkpoint item identifier. + :type item_id: CheckpointItemId + :return: The HTTP request. + :rtype: HttpRequest + """ + params = dict(self._QUERY_PARAMS) + params["sessionId"] = item_id.session_id + if item_id.parent_id: + params["parentId"] = item_id.parent_id + return self._client.get( + self._items_path(item_id.item_id), + params=params, + headers=self._HEADERS, + ) + + def _build_delete_request(self, item_id: CheckpointItemId) -> HttpRequest: + """Build the HTTP request for deleting an item. + + :param item_id: The checkpoint item identifier. + :type item_id: CheckpointItemId + :return: The HTTP request. + :rtype: HttpRequest + """ + params = dict(self._QUERY_PARAMS) + params["sessionId"] = item_id.session_id + if item_id.parent_id: + params["parentId"] = item_id.parent_id + return self._client.delete( + self._items_path(item_id.item_id), + params=params, + headers=self._HEADERS, + ) + + def _build_list_ids_request( + self, session_id: str, parent_id: Optional[str] = None + ) -> HttpRequest: + """Build the HTTP request for listing item IDs. + + :param session_id: The session identifier. + :type session_id: str + :param parent_id: Optional parent item identifier. + :type parent_id: Optional[str] + :return: The HTTP request. + :rtype: HttpRequest + """ + params = dict(self._QUERY_PARAMS) + params["sessionId"] = session_id + if parent_id: + params["parentId"] = parent_id + return self._client.get( + self._items_path(), + params=params, + headers=self._HEADERS, + ) + + @distributed_trace_async + async def create_batch(self, items: List[CheckpointItem]) -> List[CheckpointItem]: + """Create checkpoint items in batch. + + :param items: The checkpoint items to create. + :type items: List[CheckpointItem] + :return: The created checkpoint items. + :rtype: List[CheckpointItem] + """ + if not items: + return [] + + request = self._build_create_batch_request(items) + response = await self._send_request(request) + async with response: + json_response = self._extract_response_json(response) + if isinstance(json_response, list): + return [ + CheckpointItemResponse.model_validate(item).to_item() + for item in json_response + ] + # Single item response + return [CheckpointItemResponse.model_validate(json_response).to_item()] + + @distributed_trace_async + async def read(self, item_id: CheckpointItemId) -> Optional[CheckpointItem]: + """Read a checkpoint item by ID. + + :param item_id: The checkpoint item identifier. + :type item_id: CheckpointItemId + :return: The checkpoint item if found, None otherwise. + :rtype: Optional[CheckpointItem] + """ + request = self._build_read_request(item_id) + try: + response = await self._send_request(request) + async with response: + json_response = self._extract_response_json(response) + item_response = CheckpointItemResponse.model_validate(json_response) + return item_response.to_item() + except ResourceNotFoundError: + return None + + @distributed_trace_async + async def delete(self, item_id: CheckpointItemId) -> bool: + """Delete a checkpoint item. + + :param item_id: The checkpoint item identifier. + :type item_id: CheckpointItemId + :return: True if the item was deleted, False if not found. + :rtype: bool + """ + request = self._build_delete_request(item_id) + try: + response = await self._send_request(request) + async with response: + pass # No response body expected + return True + except ResourceNotFoundError: + return False + + @distributed_trace_async + async def list_ids( + self, session_id: str, parent_id: Optional[str] = None + ) -> List[CheckpointItemId]: + """List checkpoint item IDs for a session. + + :param session_id: The session identifier. + :type session_id: str + :param parent_id: Optional parent item identifier for filtering. + :type parent_id: Optional[str] + :return: List of checkpoint item identifiers. + :rtype: List[CheckpointItemId] + """ + request = self._build_list_ids_request(session_id, parent_id) + response = await self._send_request(request) + async with response: + json_response = self._extract_response_json(response) + list_response = ListCheckpointItemIdsResponse.model_validate(json_response) + return [item.to_item_id() for item in list_response.value] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/operations/_sessions.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/operations/_sessions.py new file mode 100644 index 000000000000..b3587f8baf72 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/checkpoints/client/operations/_sessions.py @@ -0,0 +1,133 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Operations for checkpoint sessions.""" + +from typing import Any, ClassVar, Dict, Optional + +from azure.core import AsyncPipelineClient +from azure.core.exceptions import ResourceNotFoundError +from azure.core.pipeline.transport import HttpRequest +from azure.core.tracing.decorator_async import distributed_trace_async + +from ....tools.client.operations._base import BaseOperations +from .._models import ( + CheckpointSession, + CheckpointSessionRequest, + CheckpointSessionResponse, +) + + +class CheckpointSessionOperations(BaseOperations): + """Operations for managing checkpoint sessions.""" + + _API_VERSION: ClassVar[str] = "2025-11-15-preview" + + _HEADERS: ClassVar[Dict[str, str]] = { + "Content-Type": "application/json", + "Accept": "application/json", + } + + _QUERY_PARAMS: ClassVar[Dict[str, Any]] = {"api-version": _API_VERSION} + + def _session_path(self, session_id: Optional[str] = None) -> str: + """Get the API path for session operations. + + :param session_id: Optional session identifier. + :type session_id: Optional[str] + :return: The API path. + :rtype: str + """ + base = "/checkpoints/sessions" + return f"{base}/{session_id}" if session_id else base + + def _build_upsert_request(self, session: CheckpointSession) -> HttpRequest: + """Build the HTTP request for upserting a session. + + :param session: The checkpoint session. + :type session: CheckpointSession + :return: The HTTP request. + :rtype: HttpRequest + """ + request_model = CheckpointSessionRequest.from_session(session) + return self._client.put( + self._session_path(session.session_id), + params=self._QUERY_PARAMS, + headers=self._HEADERS, + content=request_model.model_dump(by_alias=True), + ) + + def _build_read_request(self, session_id: str) -> HttpRequest: + """Build the HTTP request for reading a session. + + :param session_id: The session identifier. + :type session_id: str + :return: The HTTP request. + :rtype: HttpRequest + """ + return self._client.get( + self._session_path(session_id), + params=self._QUERY_PARAMS, + headers=self._HEADERS, + ) + + def _build_delete_request(self, session_id: str) -> HttpRequest: + """Build the HTTP request for deleting a session. + + :param session_id: The session identifier. + :type session_id: str + :return: The HTTP request. + :rtype: HttpRequest + """ + return self._client.delete( + self._session_path(session_id), + params=self._QUERY_PARAMS, + headers=self._HEADERS, + ) + + @distributed_trace_async + async def upsert(self, session: CheckpointSession) -> CheckpointSession: + """Create or update a checkpoint session. + + :param session: The checkpoint session to upsert. + :type session: CheckpointSession + :return: The upserted checkpoint session. + :rtype: CheckpointSession + """ + request = self._build_upsert_request(session) + response = await self._send_request(request) + async with response: + json_response = self._extract_response_json(response) + session_response = CheckpointSessionResponse.model_validate(json_response) + return session_response.to_session() + + @distributed_trace_async + async def read(self, session_id: str) -> Optional[CheckpointSession]: + """Read a checkpoint session by ID. + + :param session_id: The session identifier. + :type session_id: str + :return: The checkpoint session if found, None otherwise. + :rtype: Optional[CheckpointSession] + """ + request = self._build_read_request(session_id) + try: + response = await self._send_request(request) + async with response: + json_response = self._extract_response_json(response) + session_response = CheckpointSessionResponse.model_validate(json_response) + return session_response.to_session() + except ResourceNotFoundError: + return None + + @distributed_trace_async + async def delete(self, session_id: str) -> None: + """Delete a checkpoint session. + + :param session_id: The session identifier. + :type session_id: str + """ + request = self._build_delete_request(session_id) + response = await self._send_request(request) + async with response: + pass # No response body expected diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/constants.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/constants.py index a13f23aa261e..ae6c04235ff1 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/constants.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/constants.py @@ -3,12 +3,14 @@ # --------------------------------------------------------- class Constants: # well-known environment variables - APPLICATION_INSIGHTS_CONNECTION_STRING = "_AGENT_RUNTIME_APP_INSIGHTS_CONNECTION_STRING" AZURE_AI_PROJECT_ENDPOINT = "AZURE_AI_PROJECT_ENDPOINT" AGENT_ID = "AGENT_ID" AGENT_NAME = "AGENT_NAME" AGENT_PROJECT_RESOURCE_ID = "AGENT_PROJECT_NAME" OTEL_EXPORTER_ENDPOINT = "OTEL_EXPORTER_ENDPOINT" + OTEL_EXPORTER_OTLP_PROTOCOL = "OTEL_EXPORTER_OTLP_PROTOCOL" AGENT_LOG_LEVEL = "AGENT_LOG_LEVEL" AGENT_DEBUG_ERRORS = "AGENT_DEBUG_ERRORS" - ENABLE_APPLICATION_INSIGHTS_LOGGER = "ENABLE_APPLICATION_INSIGHTS_LOGGER" + ENABLE_APPLICATION_INSIGHTS_LOGGER = "AGENT_APP_INSIGHTS_ENABLED" + AZURE_AI_WORKSPACE_ENDPOINT = "AZURE_AI_WORKSPACE_ENDPOINT" + AZURE_AI_TOOLS_ENDPOINT = "AZURE_AI_TOOLS_ENDPOINT" diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/logger.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/logger.py index f062398c0d3b..b9b0153f281a 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/logger.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/logger.py @@ -28,6 +28,8 @@ request_context = contextvars.ContextVar("request_context", default=None) +APPINSIGHT_CONNSTR_ENV_NAME = "APPLICATIONINSIGHTS_CONNECTION_STRING" + def get_dimensions(): env_values = {name: value for name, value in vars(Constants).items() if not name.startswith("_")} @@ -40,42 +42,51 @@ def get_dimensions(): return res -def get_project_endpoint(): +def get_project_endpoint(logger=None): + project_endpoint = os.environ.get(Constants.AZURE_AI_PROJECT_ENDPOINT) + if project_endpoint: + if logger: + logger.info(f"Using project endpoint from {Constants.AZURE_AI_PROJECT_ENDPOINT}: {project_endpoint}") + return project_endpoint project_resource_id = os.environ.get(Constants.AGENT_PROJECT_RESOURCE_ID) if project_resource_id: last_part = project_resource_id.split("/")[-1] parts = last_part.split("@") if len(parts) < 2: - print(f"invalid project resource id: {project_resource_id}") + if logger: + logger.warning(f"Invalid project resource id format: {project_resource_id}") return None account = parts[0] project = parts[1] - return f"https://{account}.services.ai.azure.com/api/projects/{project}" - print("environment variable AGENT_PROJECT_RESOURCE_ID not set.") + endpoint = f"https://{account}.services.ai.azure.com/api/projects/{project}" + if logger: + logger.info(f"Using project endpoint derived from {Constants.AGENT_PROJECT_RESOURCE_ID}: {endpoint}") + return endpoint return None -def get_application_insights_connstr(): +def get_application_insights_connstr(logger=None): try: - conn_str = os.environ.get(Constants.APPLICATION_INSIGHTS_CONNECTION_STRING) + conn_str = os.environ.get(APPINSIGHT_CONNSTR_ENV_NAME) if not conn_str: - print("environment variable APPLICATION_INSIGHTS_CONNECTION_STRING not set.") - project_endpoint = get_project_endpoint() + project_endpoint = get_project_endpoint(logger=logger) if project_endpoint: # try to get the project connected application insights from azure.ai.projects import AIProjectClient from azure.identity import DefaultAzureCredential - project_client = AIProjectClient(credential=DefaultAzureCredential(), endpoint=project_endpoint) conn_str = project_client.telemetry.get_application_insights_connection_string() - if not conn_str: - print(f"no connected application insights found for project:{project_endpoint}") - else: - os.environ[Constants.APPLICATION_INSIGHTS_CONNECTION_STRING] = conn_str + if not conn_str and logger: + logger.info(f"No Application Insights connection found for project: {project_endpoint}") + elif conn_str: + os.environ[APPINSIGHT_CONNSTR_ENV_NAME] = conn_str + elif logger: + logger.info("Application Insights not configured, telemetry export disabled.") return conn_str except Exception as e: - print(f"failed to get application insights with error: {e}") + if logger: + logger.warning(f"Failed to get Application Insights connection string, telemetry export disabled: {e}") return None @@ -102,8 +113,9 @@ def configure(log_config: dict = default_log_config): """ try: config.dictConfig(log_config) + app_logger = logging.getLogger("azure.ai.agentserver") - application_insights_connection_string = get_application_insights_connstr() + application_insights_connection_string = get_application_insights_connstr(logger=app_logger) enable_application_insights_logger = ( os.environ.get(Constants.ENABLE_APPLICATION_INSIGHTS_LOGGER, "true").lower() == "true" ) @@ -132,7 +144,6 @@ def configure(log_config: dict = default_log_config): handler.addFilter(custom_filter) # Only add to azure.ai.agentserver namespace to avoid infrastructure logs - app_logger = logging.getLogger("azure.ai.agentserver") app_logger.setLevel(get_log_level()) app_logger.addHandler(handler) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_create_response.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_create_response.py index a38f55408c7f..820d54c6cea0 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_create_response.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/models/_create_response.py @@ -10,3 +10,4 @@ class CreateResponse(response_create_params.ResponseCreateParamsBase, total=False): # type: ignore agent: Optional[_azure_ai_projects_models.AgentReference] stream: Optional[bool] + tools: Optional[list[_azure_ai_projects_models.Tool]] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/_context.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/_context.py new file mode 100644 index 000000000000..f86d1ae0d4ac --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/_context.py @@ -0,0 +1,32 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import AsyncContextManager, ClassVar, Optional + +from azure.ai.agentserver.core.tools import FoundryToolRuntime + + +class AgentServerContext(AsyncContextManager["AgentServerContext"]): + _INSTANCE: ClassVar[Optional["AgentServerContext"]] = None + + def __init__(self, tool_runtime: FoundryToolRuntime): + self._tool_runtime = tool_runtime + + self.__class__._INSTANCE = self + + @classmethod + def get(cls) -> "AgentServerContext": + if cls._INSTANCE is None: + raise ValueError("AgentServerContext has not been initialized.") + return cls._INSTANCE + + @property + def tools(self) -> FoundryToolRuntime: + return self._tool_runtime + + async def __aenter__(self) -> "AgentServerContext": + await self._tool_runtime.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_value, traceback) -> None: + await self._tool_runtime.__aexit__(exc_type, exc_value, traceback) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/_response_metadata.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/_response_metadata.py new file mode 100644 index 000000000000..fb73c2a3fe7d --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/_response_metadata.py @@ -0,0 +1,42 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from __future__ import annotations + +import json +from dataclasses import asdict +from typing import Dict + +from ..application._metadata import get_current_app +from ..models import Response as OpenAIResponse, ResponseStreamEvent +from ..models.projects import ( + ResponseCompletedEvent, + ResponseCreatedEvent, + ResponseInProgressEvent, +) + +HEADER_NAME = "x-aml-foundry-agents-metadata" +METADATA_KEY = "foundry_agents_metadata" + + +def _metadata_json() -> str: + payload = asdict(get_current_app()) + return json.dumps(payload) + + +def build_foundry_agents_metadata_headers() -> Dict[str, str]: + """Return header dict containing the foundry metadata header.""" + return {HEADER_NAME: _metadata_json()} + + +def attach_foundry_metadata_to_response(response: OpenAIResponse) -> None: + """Attach metadata into response.metadata[METADATA_KEY].""" + meta = response.metadata or {} + meta[METADATA_KEY] = _metadata_json() + response.metadata = meta + + +def try_attach_foundry_metadata_to_event(event: ResponseStreamEvent) -> None: + """Attach metadata to supported stream events; skip others.""" + if isinstance(event, (ResponseCreatedEvent, ResponseInProgressEvent, ResponseCompletedEvent)): + attach_foundry_metadata_to_response(event.response) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/base.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/base.py index 8915aadb172b..74688d7772f1 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/base.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/base.py @@ -2,17 +2,20 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- # pylint: disable=broad-exception-caught,unused-argument,logging-fstring-interpolation,too-many-statements,too-many-return-statements +# mypy: ignore-errors +import asyncio # pylint: disable=C4763 import inspect import json import os -import traceback +import time from abc import abstractmethod -from typing import Any, AsyncGenerator, Generator, Union +from typing import Any, AsyncGenerator, Generator, Optional, Union import uvicorn from opentelemetry import context as otel_context, trace from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator from starlette.applications import Starlette +from starlette.concurrency import iterate_in_threadpool from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.cors import CORSMiddleware from starlette.requests import Request @@ -20,21 +23,31 @@ from starlette.routing import Route from starlette.types import ASGIApp -from ..constants import Constants -from ..logger import get_logger, request_context -from ..models import ( - Response as OpenAIResponse, - ResponseStreamEvent, +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential +from azure.identity.aio import DefaultAzureCredential as AsyncDefaultTokenCredential + +from ._context import AgentServerContext +from ._response_metadata import ( + attach_foundry_metadata_to_response, + build_foundry_agents_metadata_headers, + try_attach_foundry_metadata_to_event, ) from .common.agent_run_context import AgentRunContext +from ..constants import Constants +from ..logger import APPINSIGHT_CONNSTR_ENV_NAME, get_logger, get_project_endpoint, request_context +from ..models import Response as OpenAIResponse, ResponseStreamEvent, projects as project_models +from ..tools import UserInfoContextMiddleware, create_tool_runtime +from ..utils._credential import AsyncTokenCredentialAdapter logger = get_logger() DEBUG_ERRORS = os.environ.get(Constants.AGENT_DEBUG_ERRORS, "false").lower() == "true" - +KEEP_ALIVE_INTERVAL = 15.0 # seconds class AgentRunContextMiddleware(BaseHTTPMiddleware): - def __init__(self, app: ASGIApp): + def __init__(self, app: ASGIApp, agent: Optional['FoundryCBAgent'] = None): super().__init__(app) + self.agent = agent async def dispatch(self, request: Request, call_next): if request.url.path in ("/runs", "/responses"): @@ -82,7 +95,13 @@ def set_run_context_to_context_var(self, run_context): class FoundryCBAgent: - def __init__(self): + def __init__(self, + credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] = None, + project_endpoint: Optional[str] = None) -> None: + self.credentials = AsyncTokenCredentialAdapter(credentials) if credentials else AsyncDefaultTokenCredential() + project_endpoint = get_project_endpoint(logger=logger) or project_endpoint + AgentServerContext(create_tool_runtime(project_endpoint, self.credentials)) + async def runs_endpoint(request): # Set up tracing context and span context = request.state.agent_run_context @@ -93,89 +112,64 @@ async def runs_endpoint(request): kind=trace.SpanKind.SERVER, ): try: - logger.info("Start processing CreateResponse request:") + logger.info("Start processing CreateResponse request.") context_carrier = {} TraceContextTextMapPropagator().inject(context_carrier) + ex = None resp = await self.agent_run(context) - - if inspect.isgenerator(resp): - # Prefetch first event to allow 500 status if generation fails immediately - try: - first_event = next(resp) - except Exception as e: # noqa: BLE001 - err_msg = str(e) if DEBUG_ERRORS else "Internal error" - logger.error("Generator initialization failed: %s\n%s", e, traceback.format_exc()) - return JSONResponse({"error": err_msg}, status_code=500) - - def gen(): - ctx = TraceContextTextMapPropagator().extract(carrier=context_carrier) - token = otel_context.attach(ctx) - error_sent = False - try: - # yield prefetched first event - yield _event_to_sse_chunk(first_event) - for event in resp: - yield _event_to_sse_chunk(event) - except Exception as e: # noqa: BLE001 - err_msg = str(e) if DEBUG_ERRORS else "Internal error" - logger.error("Error in non-async generator: %s\n%s", e, traceback.format_exc()) - payload = {"error": err_msg} - yield f"event: error\ndata: {json.dumps(payload)}\n\n" - yield "data: [DONE]\n\n" - error_sent = True - finally: - logger.info("End of processing CreateResponse request:") - otel_context.detach(token) - if not error_sent: - yield "data: [DONE]\n\n" - - return StreamingResponse(gen(), media_type="text/event-stream") - if inspect.isasyncgen(resp): - # Prefetch first async event to allow early 500 - try: - first_event = await resp.__anext__() - except StopAsyncIteration: - # No items produced; treat as empty successful stream - def empty_gen(): - yield "data: [DONE]\n\n" - - return StreamingResponse(empty_gen(), media_type="text/event-stream") - except Exception as e: # noqa: BLE001 - err_msg = str(e) if DEBUG_ERRORS else "Internal error" - logger.error("Async generator initialization failed: %s\n%s", e, traceback.format_exc()) - return JSONResponse({"error": err_msg}, status_code=500) - - async def gen_async(): - ctx = TraceContextTextMapPropagator().extract(carrier=context_carrier) - token = otel_context.attach(ctx) - error_sent = False - try: - # yield prefetched first event - yield _event_to_sse_chunk(first_event) - async for event in resp: - yield _event_to_sse_chunk(event) - except Exception as e: # noqa: BLE001 - err_msg = str(e) if DEBUG_ERRORS else "Internal error" - logger.error("Error in async generator: %s\n%s", e, traceback.format_exc()) - payload = {"error": err_msg} - yield f"event: error\ndata: {json.dumps(payload)}\n\n" - yield "data: [DONE]\n\n" - error_sent = True - finally: - logger.info("End of processing CreateResponse request.") - otel_context.detach(token) - if not error_sent: - yield "data: [DONE]\n\n" - - return StreamingResponse(gen_async(), media_type="text/event-stream") - logger.info("End of processing CreateResponse request.") - return JSONResponse(resp.as_dict()) except Exception as e: # TODO: extract status code from exception - logger.error(f"Error processing CreateResponse request: {traceback.format_exc()}") - return JSONResponse({"error": str(e)}, status_code=500) + logger.error(f"Error processing CreateResponse request: {e}", exc_info=True) + ex = e + + if not context.stream: + logger.info("End of processing CreateResponse request.") + result = resp if not ex else project_models.ResponseError( + code=project_models.ResponseErrorCode.SERVER_ERROR, + message=_format_error(ex)) + if not ex: + attach_foundry_metadata_to_response(result) + return JSONResponse(result.as_dict(), headers=self.create_response_headers()) + + async def gen_async(ex): + ctx = TraceContextTextMapPropagator().extract(carrier=context_carrier) + prev_ctx = otel_context.get_current() + otel_context.attach(ctx) + seq = 0 + try: + if ex: + return + it = iterate_in_threadpool(resp) if inspect.isgenerator(resp) else resp + # Wrap iterator with keep-alive mechanism + async for event in _iter_with_keep_alive(it): + if event is None: + # Keep-alive signal + yield _keep_alive_comment() + else: + try_attach_foundry_metadata_to_event(event) + seq += 1 + yield _event_to_sse_chunk(event) + logger.info("End of processing CreateResponse request.") + except Exception as e: # noqa: BLE001 + logger.error("Error in async generator: %s", e, exc_info=True) + ex = e + finally: + if ex: + err = project_models.ResponseErrorEvent( + sequence_number=seq + 1, + code=project_models.ResponseErrorCode.SERVER_ERROR, + message=_format_error(ex), + param="") + yield _event_to_sse_chunk(err) + otel_context.attach(prev_ctx) + + return StreamingResponse( + gen_async(ex), + media_type="text/event-stream", + headers=self.create_response_headers(), + ) async def liveness_endpoint(request): result = await self.agent_liveness(request) @@ -193,6 +187,7 @@ async def readiness_endpoint(request): ] self.app = Starlette(routes=routes) + UserInfoContextMiddleware.install(self.app) self.app.add_middleware( CORSMiddleware, allow_origins=["*"], @@ -200,12 +195,17 @@ async def readiness_endpoint(request): allow_methods=["*"], allow_headers=["*"], ) - self.app.add_middleware(AgentRunContextMiddleware) + self.app.add_middleware(AgentRunContextMiddleware, agent=self) @self.app.on_event("startup") - async def attach_appinsights_logger(): + async def on_startup(): import logging + # Log server started successfully + port = getattr(self, '_port', 'unknown') + logger.info(f"FoundryCBAgent server started successfully on port {port}") + + # Attach App Insights handler to uvicorn loggers for handler in logger.handlers: if handler.name == "appinsights_handler": for logger_name in ["uvicorn", "uvicorn.error", "uvicorn.access"]: @@ -222,6 +222,118 @@ async def agent_run( ) -> Union[OpenAIResponse, Generator[ResponseStreamEvent, Any, Any], AsyncGenerator[ResponseStreamEvent, Any]]: raise NotImplementedError + async def respond_with_oauth_consent(self, context, error) -> project_models.Response: + """Generate a response indicating that OAuth consent is required. + + :param context: The agent run context. + :type context: AgentRunContext + :param error: The OAuthConsentRequiredError instance. + :type error: OAuthConsentRequiredError + :return: A Response indicating the need for OAuth consent. + :rtype: project_models.Response + """ + output = [ + project_models.OAuthConsentRequestItemResource( + id=context.id_generator.generate_oauthreq_id(), + consent_link=error.consent_url, + server_label="server_label" + ) + ] + agent_id = context.get_agent_id_object() + conversation = context.get_conversation_object() + response = project_models.Response({ + "object": "response", + "id": context.response_id, + "agent": agent_id, + "conversation": conversation, + "metadata": context.request.get("metadata"), + "created_at": int(time.time()), + "output": output, + }) + return response + + async def respond_with_oauth_consent_astream(self, context, error) -> AsyncGenerator[ResponseStreamEvent, None]: + """Generate a response stream indicating that OAuth consent is required. + + :param context: The agent run context. + :type context: AgentRunContext + :param error: The OAuthConsentRequiredError instance. + :type error: OAuthConsentRequiredError + :return: An async generator yielding ResponseStreamEvent instances. + :rtype: AsyncGenerator[ResponseStreamEvent, None] + """ + sequence_number = 0 + agent_id = context.get_agent_id_object() + conversation = context.get_conversation_object() + + response = project_models.Response({ + "object": "response", + "id": context.response_id, + "agent": agent_id, + "conversation": conversation, + "metadata": context.request.get("metadata"), + "status": "in_progress", + "created_at": int(time.time()), + }) + yield project_models.ResponseCreatedEvent(sequence_number=sequence_number, response=response) + sequence_number += 1 + + response = project_models.Response({ + "object": "response", + "id": context.response_id, + "agent": agent_id, + "conversation": conversation, + "metadata": context.request.get("metadata"), + "status": "in_progress", + "created_at": int(time.time()), + }) + yield project_models.ResponseInProgressEvent(sequence_number=sequence_number, response=response) + + sequence_number += 1 + output_index = 0 + oauth_id = context.id_generator.generate_oauthreq_id() + item = project_models.OAuthConsentRequestItemResource({ + "id": oauth_id, + "type": "oauth_consent_request", + "consent_link": error.consent_url, + "server_label": "server_label", + }) + yield project_models.ResponseOutputItemAddedEvent(sequence_number=sequence_number, + output_index=output_index, item=item) + sequence_number += 1 + yield project_models.ResponseStreamEvent({ + "sequence_number": sequence_number, + "output_index": output_index, + "id": oauth_id, + "type": "response.oauth_consent_requested", + "consent_link": error.consent_url, + "server_label": "server_label", + }) + + sequence_number += 1 + yield project_models.ResponseOutputItemDoneEvent(sequence_number=sequence_number, + output_index=output_index, item=item) + sequence_number += 1 + output = [ + project_models.OAuthConsentRequestItemResource( + id= oauth_id, + consent_link=error.consent_url, + server_label="server_label" + ) + ] + + response = project_models.Response({ + "object": "response", + "id": context.response_id, + "agent": agent_id, + "conversation": conversation, + "metadata": context.request.get("metadata"), + "created_at": int(time.time()), + "status": "completed", + "output": output, + }) + yield project_models.ResponseCompletedEvent(sequence_number=sequence_number, response=response) + async def agent_liveness(self, request) -> Union[Response, dict]: return Response(status_code=200) @@ -241,6 +353,7 @@ async def run_async( self.init_tracing() config = uvicorn.Config(self.app, host="0.0.0.0", port=port, loop="asyncio") server = uvicorn.Server(config) + self._port = port logger.info(f"Starting FoundryCBAgent server async on port {port}") await server.serve() @@ -256,12 +369,13 @@ def run(self, port: int = int(os.environ.get("DEFAULT_AD_PORT", 8088))) -> None: :type port: int """ self.init_tracing() + self._port = port logger.info(f"Starting FoundryCBAgent server on port {port}") uvicorn.run(self.app, host="0.0.0.0", port=port) def init_tracing(self): exporter = os.environ.get(Constants.OTEL_EXPORTER_ENDPOINT) - app_insights_conn_str = os.environ.get(Constants.APPLICATION_INSIGHTS_CONNECTION_STRING) + app_insights_conn_str = os.environ.get(APPINSIGHT_CONNSTR_ENV_NAME) if exporter or app_insights_conn_str: from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider @@ -303,6 +417,11 @@ def setup_otlp_exporter(self, endpoint, provider): provider.add_span_processor(processor) logger.info(f"Tracing setup with OTLP exporter: {endpoint}") + def create_response_headers(self) -> dict[str, str]: + headers = {} + headers.update(build_foundry_agents_metadata_headers()) + return headers + def _event_to_sse_chunk(event: ResponseStreamEvent) -> str: event_data = json.dumps(event.as_dict()) @@ -311,5 +430,71 @@ def _event_to_sse_chunk(event: ResponseStreamEvent) -> str: return f"data: {event_data}\n\n" +def _keep_alive_comment() -> str: + """Generate a keep-alive SSE comment to maintain connection. + + :return: The keep-alive comment string. + :rtype: str + """ + return ": keep-alive\n\n" + + +async def _iter_with_keep_alive( + it: AsyncGenerator[ResponseStreamEvent, None] +) -> AsyncGenerator[Optional[ResponseStreamEvent], None]: + """Wrap an async iterator with keep-alive mechanism. + + If no event is received within KEEP_ALIVE_INTERVAL seconds, + yields None as a signal to send a keep-alive comment. + The original iterator is protected with asyncio.shield to ensure + it continues running even when timeout occurs. + + :param it: The async generator to wrap. + :type it: AsyncGenerator[ResponseStreamEvent, None] + :return: An async generator that yields events or None for keep-alive. + :rtype: AsyncGenerator[Optional[ResponseStreamEvent], None] + """ + it_anext = it.__anext__ + pending_task: Optional[asyncio.Task] = None + + while True: + try: + # If there's a pending task from previous timeout, wait for it first + if pending_task is not None: + event = await pending_task + pending_task = None + yield event + continue + + # Create a task for the next event + next_event_task = asyncio.create_task(it_anext()) + + try: + # Shield the task and wait with timeout + event = await asyncio.wait_for( + asyncio.shield(next_event_task), + timeout=KEEP_ALIVE_INTERVAL + ) + yield event + except asyncio.TimeoutError: + # Timeout occurred, but task continues due to shield + # Save task to check in next iteration + pending_task = next_event_task + yield None + + except StopAsyncIteration: + # Iterator exhausted + break + + +def _format_error(exc: Exception) -> str: + message = str(exc) + if message: + return message + if DEBUG_ERRORS: + return repr(exc) + return f"{type(exc)}: Internal error" + + def _to_response(result: Union[Response, dict]) -> Response: return result if isinstance(result, Response) else JSONResponse(result) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/agent_run_context.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/agent_run_context.py index 6fae56f0027d..87c32926bde4 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/agent_run_context.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/agent_run_context.py @@ -1,17 +1,22 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- +from typing import Optional + +from .id_generator.foundry_id_generator import FoundryIdGenerator +from .id_generator.id_generator import IdGenerator from ...logger import get_logger from ...models import CreateResponse from ...models.projects import AgentId, AgentReference, ResponseConversation1 -from .id_generator.foundry_id_generator import FoundryIdGenerator -from .id_generator.id_generator import IdGenerator logger = get_logger() class AgentRunContext: - def __init__(self, payload: dict): + """ + :meta private: + """ + def __init__(self, payload: dict) -> None: self._raw_payload = payload self._request = _deserialize_create_response(payload) self._id_generator = FoundryIdGenerator.from_request(payload) @@ -36,7 +41,7 @@ def response_id(self) -> str: return self._response_id @property - def conversation_id(self) -> str: + def conversation_id(self) -> Optional[str]: return self._conversation_id @property @@ -46,7 +51,7 @@ def stream(self) -> bool: def get_agent_id_object(self) -> AgentId: agent = self.request.get("agent") if not agent: - return None # type: ignore + return None # type: ignore return AgentId( { "type": agent.type, @@ -57,7 +62,7 @@ def get_agent_id_object(self) -> AgentId: def get_conversation_object(self) -> ResponseConversation1: if not self._conversation_id: - return None # type: ignore + return None # type: ignore return ResponseConversation1(id=self._conversation_id) @@ -67,10 +72,14 @@ def _deserialize_create_response(payload: dict) -> CreateResponse: raw_agent_reference = payload.get("agent") if raw_agent_reference: _deserialized["agent"] = _deserialize_agent_reference(raw_agent_reference) + + tools = payload.get("tools") + if tools: + _deserialized["tools"] = [tool for tool in tools] # pylint: disable=unnecessary-comprehension return _deserialized def _deserialize_agent_reference(payload: dict) -> AgentReference: if not payload: - return None # type: ignore + return None # type: ignore return AgentReference(**payload) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/constants.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/constants.py new file mode 100644 index 000000000000..6d4fb628a7f2 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/constants.py @@ -0,0 +1,6 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +# Reserved function name for HITL. +HUMAN_IN_THE_LOOP_FUNCTION_NAME = "__hosted_agent_adapter_hitl__" diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/foundry_id_generator.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/foundry_id_generator.py index 910a7c481daa..01ac72289e4e 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/foundry_id_generator.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/foundry_id_generator.py @@ -27,8 +27,12 @@ class FoundryIdGenerator(IdGenerator): def __init__(self, response_id: Optional[str], conversation_id: Optional[str]): self.response_id = response_id or self._new_id("resp") - self.conversation_id = conversation_id or self._new_id("conv") - self._partition_id = self._extract_partition_id(self.conversation_id) + self.conversation_id = conversation_id + partition_source = self.conversation_id or self.response_id + try: + self._partition_id = self._extract_partition_id(partition_source) + except ValueError: + self._partition_id = self._secure_entropy(18) @classmethod def from_request(cls, payload: dict) -> "FoundryIdGenerator": @@ -37,7 +41,7 @@ def from_request(cls, payload: dict) -> "FoundryIdGenerator": if isinstance(conv_id_raw, str): conv_id = conv_id_raw elif isinstance(conv_id_raw, dict): - conv_id = conv_id_raw.get("id", None) + conv_id = conv_id_raw.get("id", None) # type: ignore[assignment] else: conv_id = None return cls(response_id, conv_id) @@ -88,7 +92,7 @@ def _new_id( infix = infix or "" prefix_part = f"{prefix}{delimiter}" if prefix else "" - return f"{prefix_part}{entropy}{infix}{pkey}" + return f"{prefix_part}{infix}{pkey}{entropy}" @staticmethod def _secure_entropy(string_length: int) -> str: @@ -133,4 +137,4 @@ def _extract_partition_id( if len(segment) < string_length + partition_key_length: raise ValueError(f"Id '{id_str}' does not contain a valid id.") - return segment[-partition_key_length:] + return segment[:partition_key_length] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/id_generator.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/id_generator.py index 48f0d9add17d..5b602a7fc686 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/id_generator.py +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/server/common/id_generator/id_generator.py @@ -17,3 +17,6 @@ def generate_function_output_id(self) -> str: def generate_message_id(self) -> str: return self.generate("msg") + + def generate_oauthreq_id(self) -> str: + return self.generate("oauthreq") diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py new file mode 100644 index 000000000000..fa58e50368bf --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/__init__.py @@ -0,0 +1,82 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) + +from .client._client import FoundryToolClient +from ._exceptions import ( + ToolInvocationError, + OAuthConsentRequiredError, + UnableToResolveToolInvocationError, + InvalidToolFacadeError, +) +from .client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryTool, + FoundryToolDetails, + FoundryToolProtocol, + FoundryToolSource, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, + UserInfo, +) +from .runtime._catalog import ( + FoundryToolCatalog, + CachedFoundryToolCatalog, + DefaultFoundryToolCatalog, +) +from .runtime._facade import FoundryToolFacade, FoundryToolLike, ensure_foundry_tool +from .runtime._invoker import FoundryToolInvoker, DefaultFoundryToolInvoker +from .runtime._resolver import FoundryToolInvocationResolver, DefaultFoundryToolInvocationResolver +from .runtime._runtime import create_tool_runtime, FoundryToolRuntime, DefaultFoundryToolRuntime +from .runtime._starlette import UserInfoContextMiddleware +from .runtime._user import UserProvider, ContextVarUserProvider + +__all__ = [ + # Client + "FoundryToolClient", + # Exceptions + "ToolInvocationError", + "OAuthConsentRequiredError", + "UnableToResolveToolInvocationError", + "InvalidToolFacadeError", + # Models + "FoundryConnectedTool", + "FoundryHostedMcpTool", + "FoundryTool", + "FoundryToolDetails", + "FoundryToolProtocol", + "FoundryToolSource", + "ResolvedFoundryTool", + "SchemaDefinition", + "SchemaProperty", + "SchemaType", + "UserInfo", + # Catalog + "FoundryToolCatalog", + "CachedFoundryToolCatalog", + "DefaultFoundryToolCatalog", + # Facade + "FoundryToolFacade", + "FoundryToolLike", + "ensure_foundry_tool", + # Invoker + "FoundryToolInvoker", + "DefaultFoundryToolInvoker", + # Resolver + "FoundryToolInvocationResolver", + "DefaultFoundryToolInvocationResolver", + # Runtime + "create_tool_runtime", + "FoundryToolRuntime", + "DefaultFoundryToolRuntime", + # Starlette + "UserInfoContextMiddleware", + # User + "UserProvider", + "ContextVarUserProvider", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/_exceptions.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/_exceptions.py new file mode 100644 index 000000000000..a5fe7726e9f1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/_exceptions.py @@ -0,0 +1,74 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from .client._models import FoundryTool, ResolvedFoundryTool + + +class ToolInvocationError(RuntimeError): + """Raised when a tool invocation fails. + + :ivar ResolvedFoundryTool tool: The tool that failed during invocation. + + :param str message: Human-readable message describing the error. + :param ResolvedFoundryTool tool: The tool that failed during invocation. + + This exception is raised when an error occurs during the invocation of a tool, + providing details about the failure. + """ + + def __init__(self, message: str, tool: ResolvedFoundryTool): + super().__init__(message) + self.tool = tool + + +class OAuthConsentRequiredError(RuntimeError): + """Raised when the service requires end-user OAuth consent. + + This exception is raised when a tool or service operation requires explicit + OAuth consent from the end user before the operation can proceed. + + :ivar str message: Human-readable guidance returned by the service. + :ivar str consent_url: Link that the end user must visit to provide consent. + :ivar str project_connection_id: The project connection ID related to the consent request. + + :param str message: Human-readable guidance returned by the service. + :param str consent_url: Link that the end user must visit to provide the required consent. + :param str project_connection_id: The project connection ID related to the consent request. + """ + + def __init__(self, message: str, consent_url: str, project_connection_id: str): + super().__init__(message) + self.message = message + self.consent_url = consent_url + self.project_connection_id = project_connection_id + + +class UnableToResolveToolInvocationError(RuntimeError): + """Raised when a tool cannot be resolved. + + :ivar str message: Human-readable message describing the error. + :ivar FoundryTool tool: The tool that could not be resolved. + + :param str message: Human-readable message describing the error. + :param FoundryTool tool: The tool that could not be resolved. + + This exception is raised when a tool cannot be found or resolved + from the available tool sources. + """ + + def __init__(self, message: str, tool: FoundryTool): + super().__init__(message) + self.tool = tool + + +class InvalidToolFacadeError(RuntimeError): + """Raised when a tool facade is invalid. + + This exception is raised when a tool facade does not conform + to the expected structure or contains invalid data. + """ diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_client.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_client.py new file mode 100644 index 000000000000..12b647d7adc7 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_client.py @@ -0,0 +1,215 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import asyncio # pylint: disable=C4763 +import itertools +from collections import defaultdict +from typing import ( + Any, + AsyncContextManager, + Awaitable, Collection, + DefaultDict, + Dict, + Iterable, + List, + Mapping, + Optional, + Tuple, + cast, +) + +from azure.core import AsyncPipelineClient +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.tracing.decorator_async import distributed_trace_async + +from ._configuration import FoundryToolClientConfiguration +from ._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryTool, + FoundryToolDetails, + FoundryToolSource, + ResolvedFoundryTool, + UserInfo, +) +from .operations._foundry_connected_tools import FoundryConnectedToolsOperations +from .operations._foundry_hosted_mcp_tools import FoundryMcpToolsOperations +from .._exceptions import ToolInvocationError + + +class FoundryToolClient(AsyncContextManager["FoundryToolClient"]): # pylint: disable=C4748 + """Asynchronous client for aggregating tools from Azure AI MCP and Tools APIs. + + This client provides access to tools from both MCP (Model Context Protocol) servers + and Azure AI Tools API endpoints, enabling unified tool discovery and invocation. + + :param endpoint: + The fully qualified endpoint for the Azure AI Agents service. + Example: "https://.api.azureml.ms" + :type endpoint: str + :param credential: + Credential for authenticating requests to the service. + Use credentials from azure-identity like DefaultAzureCredential. + :type credential: ~azure.core.credentials.TokenCredential + :param api_version: The API version to use for this operation. + :type api_version: str or None + """ + + def __init__( # pylint: disable=C4718 + self, + endpoint: str, + credential: "AsyncTokenCredential", + ) -> None: + """Initialize the asynchronous Azure AI Tool Client. + + :param endpoint: The service endpoint URL. + :type endpoint: str + :param credential: Credentials for authenticating requests. + :type credential: ~azure.core.credentials.TokenCredential + :param api_version: The API version to use for this operation. + :type api_version: str or None + """ + # noinspection PyTypeChecker + config = FoundryToolClientConfiguration(credential) + self._client: AsyncPipelineClient = AsyncPipelineClient(base_url=endpoint, config=config) + + self._hosted_mcp_tools = FoundryMcpToolsOperations(self._client) + self._connected_tools = FoundryConnectedToolsOperations(self._client) + + @distributed_trace_async + async def list_tools( + self, + tools: Collection[FoundryTool], + agent_name: str, + user: Optional[UserInfo] = None, + **kwargs: Any + ) -> List[ResolvedFoundryTool]: + """List all available tools from configured sources. + + Retrieves tools from both MCP servers and Azure AI Tools API endpoints, + returning them as ResolvedFoundryTool instances ready for invocation. + + :param tools: Collection of FoundryTool instances to resolve. + :type tools: Collection[~FoundryTool] + :param user: Information about the user requesting the tools. + :type user: Optional[UserInfo] + :param agent_name: Name of the agent requesting the tools. + :type agent_name: str + + :return: List of resolved Foundry tools. + :rtype: List[ResolvedFoundryTool] + :raises ~azure.ai.agentserver.core.tools._exceptions.OAuthConsentRequiredError: + Raised when the service requires user OAuth consent. + :raises ~azure.core.exceptions.HttpResponseError: + Raised for HTTP communication failures. + + """ + _ = kwargs # Reserved for future use + resolved_tools: List[ResolvedFoundryTool] = [] + results = await self._list_tools_details_internal(tools, agent_name, user) + for definition, details in results: + resolved_tools.append(ResolvedFoundryTool(definition=definition, details=details)) + return resolved_tools + + @distributed_trace_async + async def list_tools_details( + self, + tools: Collection[FoundryTool], + agent_name: str, + user: Optional[UserInfo] = None, + **kwargs: Any + ) -> Mapping[str, List[FoundryToolDetails]]: + """List all available tools from configured sources. + + Retrieves tools from both MCP servers and Azure AI Tools API endpoints, + returning them as ResolvedFoundryTool instances ready for invocation. + + :param tools: Collection of FoundryTool instances to resolve. + :type tools: Collection[~FoundryTool] + :param user: Information about the user requesting the tools. + :type user: Optional[UserInfo] + :param agent_name: Name of the agent requesting the tools. + :type agent_name: str + + :return: Mapping of tool IDs to lists of FoundryToolDetails. + :rtype: Mapping[str, List[FoundryToolDetails]] + :raises ~azure.ai.agentserver.core.tools._exceptions.OAuthConsentRequiredError: + Raised when the service requires user OAuth consent. + :raises ~azure.core.exceptions.HttpResponseError: + Raised for HTTP communication failures. + + """ + _ = kwargs # Reserved for future use + resolved_tools: Dict[str, List[FoundryToolDetails]] = defaultdict(list) + results = await self._list_tools_details_internal(tools, agent_name, user) + for definition, details in results: + resolved_tools[definition.id].append(details) + return resolved_tools + + async def _list_tools_details_internal( + self, + tools: Collection[FoundryTool], + agent_name: str, + user: Optional[UserInfo] = None, + ) -> Iterable[Tuple[FoundryTool, FoundryToolDetails]]: + tools_by_source: DefaultDict[FoundryToolSource, List[FoundryTool]] = defaultdict(list) + for t in tools: + tools_by_source[t.source].append(t) + + listing_tools: List[Awaitable[Iterable[Tuple[FoundryTool, FoundryToolDetails]]]] = [] + if FoundryToolSource.HOSTED_MCP in tools_by_source: + hosted_mcp_tools = cast(List[FoundryHostedMcpTool], tools_by_source[FoundryToolSource.HOSTED_MCP]) + listing_tools.append(self._hosted_mcp_tools.list_tools(hosted_mcp_tools)) + if FoundryToolSource.CONNECTED in tools_by_source: + connected_tools = cast(List[FoundryConnectedTool], tools_by_source[FoundryToolSource.CONNECTED]) + listing_tools.append(self._connected_tools.list_tools(connected_tools, user, agent_name)) + iters = await asyncio.gather(*listing_tools) + return itertools.chain.from_iterable(iters) + + @distributed_trace_async + async def invoke_tool( + self, + tool: ResolvedFoundryTool, + arguments: Dict[str, Any], + agent_name: str, + user: Optional[UserInfo] = None, + **kwargs: Any + ) -> Any: + """Invoke a tool by instance, name, or descriptor. + + :param tool: Tool to invoke, specified as an AzureAITool instance, + tool name string, or FoundryTool. + :type tool: ResolvedFoundryTool + :param arguments: Arguments to pass to the tool. + :type arguments: Dict[str, Any] + :param user: Information about the user invoking the tool. + :type user: Optional[UserInfo] + :param agent_name: Name of the agent invoking the tool. + :type agent_name: str + :return: The result of invoking the tool. + :rtype: Any + :raises ~OAuthConsentRequiredError: + Raised when the service requires user OAuth consent. + :raises ~azure.core.exceptions.HttpResponseError: + Raised for HTTP communication failures. + :raises ~ToolInvocationError: + Raised when the tool invocation fails or source is not supported. + + """ + _ = kwargs # Reserved for future use + if tool.source is FoundryToolSource.HOSTED_MCP: + return await self._hosted_mcp_tools.invoke_tool(tool, arguments) + if tool.source is FoundryToolSource.CONNECTED: + return await self._connected_tools.invoke_tool(tool, arguments, user, agent_name) + raise ToolInvocationError(f"Unsupported tool source: {tool.source}", tool=tool) + + async def close(self) -> None: + """Close the underlying HTTP pipeline.""" + await self._client.close() + + async def __aenter__(self) -> "FoundryToolClient": + await self._client.__aenter__() + return self + + async def __aexit__(self, *exc_details: Any) -> None: + await self._client.__aexit__(*exc_details) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_configuration.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_configuration.py new file mode 100644 index 000000000000..e09c80ed83f8 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_configuration.py @@ -0,0 +1,35 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from azure.core.configuration import Configuration +from azure.core.credentials_async import AsyncTokenCredential +from azure.core.pipeline import policies + +from ...application._metadata import get_current_app + + +class FoundryToolClientConfiguration(Configuration): # pylint: disable=too-many-instance-attributes + """Configuration for Azure AI Tool Client. + + Manages authentication, endpoint configuration, and policy settings for the + Azure AI Tool Client. This class is used internally by the client and should + not typically be instantiated directly. + + :param credential: + Azure TokenCredential for authentication. + :type credential: ~azure.core.credentials.TokenCredential + """ + + def __init__(self, credential: "AsyncTokenCredential"): + super().__init__() + + self.retry_policy = policies.AsyncRetryPolicy() + self.logging_policy = policies.NetworkTraceLoggingPolicy() + self.request_id_policy = policies.RequestIdPolicy() + self.http_logging_policy = policies.HttpLoggingPolicy() + self.user_agent_policy = policies.UserAgentPolicy( + base_user_agent=get_current_app().as_user_agent("FoundryToolClient")) + self.authentication_policy = policies.AsyncBearerTokenCredentialPolicy( + credential, "https://ai.azure.com/.default" + ) + self.redirect_policy = policies.AsyncRedirectPolicy() diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_models.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_models.py new file mode 100644 index 000000000000..6302c6de5ade --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/_models.py @@ -0,0 +1,608 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from enum import Enum +from typing import ( + Annotated, + Any, + ClassVar, + Dict, + Iterable, + List, + Literal, + Mapping, + Optional, + Set, + Type, + Union, +) + +from pydantic import ( + AliasChoices, + AliasPath, + BaseModel, + Discriminator, + Field, + ModelWrapValidatorHandler, + Tag, + TypeAdapter, + model_validator, +) + +from azure.core import CaseInsensitiveEnumMeta + +from .._exceptions import OAuthConsentRequiredError + + +class FoundryToolSource(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Identifies the origin of a tool. + + Specifies whether a tool comes from an MCP (Model Context Protocol) server + or from the Azure AI Tools API (remote tools). + """ + + HOSTED_MCP = "hosted_mcp" + CONNECTED = "connected" + + +class FoundryToolProtocol(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """Identifies the protocol used by a connected tool.""" + + MCP = "mcp" + A2A = "a2a" + + +@dataclass(frozen=True, eq=False) +class FoundryTool(ABC): + """Definition of a foundry tool including its parameters.""" + source: FoundryToolSource = field(init=False) + + @property + @abstractmethod + def id(self) -> str: + """Unique identifier for the tool. + + :rtype: str + """ + raise NotImplementedError + + def __str__(self): + return self.id + + +@dataclass(frozen=True, eq=False) +class FoundryHostedMcpTool(FoundryTool): + """Foundry MCP tool definition. + + :ivar str name: Name of MCP tool. + :ivar Mapping[str, Any] configuration: Tools configuration. + """ + source: Literal[FoundryToolSource.HOSTED_MCP] = field(init=False, default=FoundryToolSource.HOSTED_MCP) + name: str + configuration: Optional[Mapping[str, Any]] = None + + @property + def id(self) -> str: + """Unique identifier for the tool. + + :rtype: str + """ + return f"{self.source}:{self.name}" + + +@dataclass(frozen=True, eq=False) +class FoundryConnectedTool(FoundryTool): + """Foundry connected tool definition. + + :ivar str project_connection_id: connection name of foundry tool. + """ + source: Literal[FoundryToolSource.CONNECTED] = field(init=False, default=FoundryToolSource.CONNECTED) + protocol: str + project_connection_id: str + + @property + def id(self) -> str: + return f"{self.source}:{self.protocol}:{self.project_connection_id}" + + +@dataclass(frozen=True) +class FoundryToolDetails: + """Details about a Foundry tool. + + :ivar str name: Name of the tool. + :ivar str description: Description of the tool. + :ivar SchemaDefinition input_schema: Input schema for the tool parameters. + :ivar Optional[SchemaDefinition] metadata: Optional metadata schema for the tool. + """ + name: str + description: str + input_schema: "SchemaDefinition" + metadata: Optional["SchemaDefinition"] = None + + +@dataclass(frozen=True) +class ResolvedFoundryTool: + """Resolved Foundry tool with definition and details. + + :ivar ToolDefinition definition: + Optional tool definition object, or None. + :ivar FoundryToolDetails details: + Details about the tool, including name, description, and input schema. + """ + + definition: FoundryTool + details: FoundryToolDetails + + @property + def id(self) -> str: + return f"{self.definition.id}:{self.details.name}" + + @property + def source(self) -> FoundryToolSource: + """Origin of the tool. + + :rtype: FoundryToolSource + """ + return self.definition.source + + @property + def name(self) -> str: + """Name of the tool. + + :rtype: str + """ + return self.details.name + + @property + def description(self) -> str: + """Description of the tool. + + :rtype: str + """ + return self.details.description + + @property + def input_schema(self) -> "SchemaDefinition": + """Input schema of the tool. + + :rtype: SchemaDefinition + """ + return self.details.input_schema + + @property + def metadata(self) -> Optional["SchemaDefinition"]: + """Metadata schema of the tool, if any.""" + return self.details.metadata + + +@dataclass(frozen=True) +class UserInfo: + """Represents user information. + + :ivar str object_id: User's object identifier. + :ivar str tenant_id: Tenant identifier. + """ + + object_id: str + tenant_id: str + + +class SchemaType(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """ + Enumeration of possible schema types. + + :ivar py_type: The corresponding Python runtime type for this schema type + (e.g., ``SchemaType.STRING.py_type is str``). + """ + + py_type: Type[Any] + """The corresponding Python runtime type for this schema type.""" + + STRING = ("string", str) + """Schema type for string values (maps to ``str``).""" + + NUMBER = ("number", float) + """Schema type for numeric values with decimals (maps to ``float``).""" + + INTEGER = ("integer", int) + """Schema type for integer values (maps to ``int``).""" + + BOOLEAN = ("boolean", bool) + """Schema type for boolean values (maps to ``bool``).""" + + ARRAY = ("array", list) + """Schema type for array values (maps to ``list``).""" + + OBJECT = ("object", dict) + """Schema type for object/dictionary values (maps to ``dict``).""" + + def __new__(cls, value: str, py_type: Type[Any]): + """ + Create an enum member whose value is the schema type string, while also + attaching the mapped Python type. + + :param value: The serialized schema type string (e.g. ``"string"``). + :type value: str + :param py_type: The mapped Python runtime type (e.g. ``str``). + :type py_type: Type[Any] + :return: The created enum member. + :rtype: SchemaType + """ + obj = str.__new__(cls, value) + obj._value_ = value + obj.py_type = py_type + return obj + + @classmethod + def from_python_type(cls, t: Type[Any]) -> "SchemaType": + """ + Get the matching :class:`SchemaType` for a given Python runtime type. + + :param t: A Python runtime type (e.g. ``str``, ``int``, ``float``). + :type t: Type[Any] + :returns: The corresponding :class:`SchemaType`. + :rtype: SchemaType + :raises ValueError: If ``t`` is not supported by this enumeration. + """ + for member in cls: + if member.py_type is t: + return member + raise ValueError(f"Unsupported python type: {t!r}") + + +class SchemaProperty(BaseModel): + """ + A JSON Schema-like description of a single property (field) or nested schema node. + + This model is intended to be recursively nestable via :attr:`items` (for arrays) + and :attr:`properties` (for objects). + + :ivar type: The schema node type (e.g., ``string``, ``object``, ``array``). + :ivar description: Optional human-readable description of the property. + :ivar items: The item schema for an ``array`` type. Typically set when + :attr:`type` is :data:`~SchemaType.ARRAY`. + :ivar properties: Nested properties for an ``object`` type. Typically set when + :attr:`type` is :data:`~SchemaType.OBJECT`. Keys are property names, values + are their respective schemas. + :ivar default: Optional default value for the property. + :ivar required: For an ``object`` schema node, the set of required property + names within :attr:`properties`. (This mirrors JSON Schema’s ``required`` + keyword; it is *not* β€œthis property is required in a parent object”.) + """ + + type: Optional[SchemaType] = None + """The schema node type (e.g., ``string``, ``object``, ``array``). May be ``None`` + if the upstream tool manifest supplies an empty or unrecognised type string.""" + description: Optional[str] = None + + @model_validator(mode="before") + @classmethod + def _coerce_empty_type(cls, data: Any) -> Any: + """Coerce an empty ``type`` string to ``None`` so that properties with + invalid or missing type information are still deserialised instead of + raising a validation error.""" + if isinstance(data, dict) and data.get("type") == "": + data = {**data, "type": None} + return data + items: Optional["SchemaProperty"] = None + properties: Optional[Mapping[str, "SchemaProperty"]] = None + default: Any = None + required: Optional[Set[str]] = None + + def has_default(self) -> bool: + """ + Check if the property has a default value defined. + + :return: True if a default value is set, False otherwise. + :rtype: bool + """ + return "default" in self.model_fields_set + + +class SchemaDefinition(BaseModel): + """ + A top-level JSON Schema-like definition for an object. + + :ivar type: The schema type of the root. Typically :data:`~SchemaType.OBJECT`. + :ivar properties: Mapping of top-level property names to their schemas. + :ivar required: Set of required top-level property names within + :attr:`properties`. + """ + + type: SchemaType = SchemaType.OBJECT + properties: Mapping[str, SchemaProperty] = field(default_factory=dict) # pylint: disable=E3701 + required: Optional[Set[str]] = None + + def extract_from(self, + datasource: Mapping[str, Any], + property_alias: Optional[Dict[str, List[str]]] = None) -> Dict[str, Any]: + return self._extract(datasource, self.properties, self.required, property_alias) + + @classmethod + def _extract(cls, + datasource: Mapping[str, Any], + properties: Mapping[str, SchemaProperty], + required: Optional[Set[str]] = None, + property_alias: Optional[Dict[str, List[str]]] = None) -> Dict[str, Any]: + result: Dict[str, Any] = {} + + for property_name, schema in properties.items(): + # Determine the keys to look for in the datasource + keys_to_check = [property_name] + if property_alias and property_name in property_alias: + keys_to_check.extend(property_alias[property_name]) + + # Find the first matching key in the datasource + value_found = False + for key in keys_to_check: + if key in datasource: + value = datasource[key] + value_found = True + break + + if not value_found and schema.has_default(): + value = schema.default + value_found = True + + if not value_found: + # If the property is required but not found, raise an error + if required and property_name in required: + raise KeyError(f"Required property '{property_name}' not found in datasource.") + # If not found and not required, skip to next property + continue + + # Process the value based on its schema type + if schema.type == SchemaType.OBJECT and schema.properties: + if isinstance(value, Mapping): + nested_value = cls._extract( + value, + schema.properties, + schema.required, + property_alias + ) + result[property_name] = nested_value + elif schema.type == SchemaType.ARRAY and schema.items: + if isinstance(value, Iterable): + nested_list = [] + for item in value: + if schema.items.type == SchemaType.OBJECT and schema.items.properties: + nested_item = SchemaDefinition._extract( + item, + schema.items.properties, + schema.items.required, + property_alias + ) + nested_list.append(nested_item) + else: + nested_list.append(item) + result[property_name] = nested_list + else: + result[property_name] = value + + return result + + +class RawFoundryHostedMcpTool(BaseModel): + """Pydantic model for a single MCP tool. + + :ivar str name: Unique name identifier of the tool. + :ivar Optional[str] title: Display title of the tool, defaults to name if not provided. + :ivar str description: Human-readable description of the tool. + :ivar SchemaDefinition input_schema: JSON schema for tool input parameters. + :ivar Optional[SchemaDefinition] meta: Optional metadata for the tool. + """ + + name: str + title: Optional[str] = None + description: str = "" + input_schema: SchemaDefinition = Field( + default_factory=SchemaDefinition, + validation_alias="inputSchema" + ) + meta: Optional[SchemaDefinition] = Field(default=None, validation_alias="_meta") + + def model_post_init(self, __context: Any) -> None: + if self.title is None: + self.title = self.name + + +class RawFoundryHostedMcpTools(BaseModel): + """Pydantic model for the result containing list of tools. + + :ivar List[RawFoundryHostedMcpTool] tools: List of MCP tool definitions. + """ + + tools: List[RawFoundryHostedMcpTool] = Field(default_factory=list) + + +class ListFoundryHostedMcpToolsResponse(BaseModel): + """Pydantic model for the complete MCP tools/list JSON-RPC response. + + :ivar str jsonrpc: JSON-RPC version, defaults to "2.0". + :ivar int id: Request identifier, defaults to 0. + :ivar RawFoundryHostedMcpTools result: Result containing the list of tools. + """ + + jsonrpc: str = "2.0" + id: int = 0 + result: RawFoundryHostedMcpTools = Field( + default_factory=RawFoundryHostedMcpTools + ) + + +class BaseConnectedToolsErrorResult(BaseModel, ABC): + """Base model for connected tools error responses.""" + + @abstractmethod + def as_exception(self) -> Exception: + """Convert the error result to an appropriate exception. + + :return: An exception representing the error. + :rtype: Exception + """ + raise NotImplementedError + + +class OAuthConsentRequiredErrorResult(BaseConnectedToolsErrorResult): + """Model for OAuth consent required error responses. + + :ivar Literal["OAuthConsentRequired"] type: Error type identifier. + :ivar Optional[str] consent_url: URL for user consent, if available. + :ivar Optional[str] message: Human-readable error message. + :ivar Optional[str] project_connection_id: Project connection ID related to the error. + """ + + type: Literal["OAuthConsentRequired"] + consent_url: str = Field( + validation_alias=AliasChoices( + AliasPath("toolResult", "consentUrl"), + AliasPath("toolResult", "message"), + ), + ) + message: str = Field( + validation_alias=AliasPath("toolResult", "message"), + ) + project_connection_id: str = Field( + validation_alias=AliasPath("toolResult", "projectConnectionId"), + ) + + def as_exception(self) -> Exception: + return OAuthConsentRequiredError(self.message, self.consent_url, self.project_connection_id) + + +class RawFoundryConnectedTool(BaseModel): + """Pydantic model for a single connected tool. + + :ivar str name: Name of the tool. + :ivar str description: Description of the tool. + :ivar Optional[SchemaDefinition] input_schema: Input schema for the tool parameters. + """ + name: str + description: str + input_schema: SchemaDefinition = Field( + default_factory=SchemaDefinition, + validation_alias="parameters", + ) + + +class RawFoundryConnectedRemoteServer(BaseModel): + """Pydantic model for a connected remote server. + + :ivar str protocol: Protocol used by the remote server. + :ivar str project_connection_id: Project connection ID of the remote server. + :ivar List[RawFoundryConnectedTool] tools: List of connected tools from this server. + """ + protocol: str = Field( + validation_alias=AliasPath("remoteServer", "protocol"), + ) + project_connection_id: str = Field( + validation_alias=AliasPath("remoteServer", "projectConnectionId"), + ) + tools: List[RawFoundryConnectedTool] = Field( + default_factory=list, + validation_alias="manifest", + ) + + +class ListConnectedToolsResult(BaseModel): + """Pydantic model for the result of listing connected tools. + + :ivar List[ConnectedRemoteServer] servers: List of connected remote servers. + """ + servers: List[RawFoundryConnectedRemoteServer] = Field( + default_factory=list, + validation_alias="tools", + ) + + +class ListFoundryConnectedToolsResponse(BaseModel): + """Pydantic model for the response of listing the connected tools. + + :ivar Optional[ConnectedToolsResult] result: Result containing connected tool servers. + :ivar Optional[BaseConnectedToolsErrorResult] error: Error result, if any. + """ + + result: Optional[ListConnectedToolsResult] = None + error: Optional[BaseConnectedToolsErrorResult] = None + + # noinspection DuplicatedCode + _TYPE_ADAPTER: ClassVar[TypeAdapter] = TypeAdapter( + Annotated[ + Union[ + Annotated[ + Annotated[ + Union[OAuthConsentRequiredErrorResult], + Field(discriminator="type") + ], + Tag("ErrorType") + ], + Annotated[ListConnectedToolsResult, Tag("ResultType")], + ], + Discriminator( + lambda payload: "ErrorType" if isinstance(payload, dict) and "type" in payload else "ResultType" + ), + ]) + + @model_validator(mode="wrap") + @classmethod + def _validator(cls, data: Any, handler: ModelWrapValidatorHandler) -> "ListFoundryConnectedToolsResponse": + parsed = cls._TYPE_ADAPTER.validate_python(data) + normalized = {} + if isinstance(parsed, ListConnectedToolsResult): + normalized["result"] = parsed + elif isinstance(parsed, BaseConnectedToolsErrorResult): + normalized["error"] = parsed # type: ignore[assignment] + return handler(normalized) + + +class InvokeConnectedToolsResult(BaseModel): + """Pydantic model for the result of invoking a connected tool. + + :ivar Any value: The result value from the tool invocation. + """ + value: Any = Field(validation_alias="toolResult") + + +class InvokeFoundryConnectedToolsResponse(BaseModel): + """Pydantic model for the response of invoking a connected tool. + + :ivar Optional[InvokeConnectedToolsResult] result: Result of the tool invocation. + :ivar Optional[BaseConnectedToolsErrorResult] error: Error result, if any. + """ + result: Optional[InvokeConnectedToolsResult] = None + error: Optional[BaseConnectedToolsErrorResult] = None + + # noinspection DuplicatedCode + _TYPE_ADAPTER: ClassVar[TypeAdapter] = TypeAdapter( + Annotated[ + Union[ + Annotated[ + Annotated[ + Union[OAuthConsentRequiredErrorResult], + Field(discriminator="type") + ], + Tag("ErrorType") + ], + Annotated[InvokeConnectedToolsResult, Tag("ResultType")], + ], + Discriminator( + lambda payload: "ErrorType" if isinstance(payload, dict) and + # handle other error types in the future + payload.get("type") == "OAuthConsentRequired" + else "ResultType" + ), + ]) + + @model_validator(mode="wrap") + @classmethod + def _validator(cls, data: Any, handler: ModelWrapValidatorHandler) -> "InvokeFoundryConnectedToolsResponse": + parsed: Union[InvokeConnectedToolsResult, BaseConnectedToolsErrorResult] = (cls._TYPE_ADAPTER + .validate_python(data)) + normalized: Dict[str, Any] = {} + if isinstance(parsed, InvokeConnectedToolsResult): + normalized["result"] = parsed + elif isinstance(parsed, BaseConnectedToolsErrorResult): + normalized["error"] = parsed + return handler(normalized) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_base.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_base.py new file mode 100644 index 000000000000..a3c552fe2575 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_base.py @@ -0,0 +1,73 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from __future__ import annotations + +from abc import ABC +import json +from typing import Any, ClassVar, MutableMapping, Type + +from azure.core import AsyncPipelineClient +from azure.core.exceptions import ClientAuthenticationError, HttpResponseError, ResourceExistsError, \ + ResourceNotFoundError, ResourceNotModifiedError, map_error +from azure.core.pipeline.transport import AsyncHttpResponse, HttpRequest + +ErrorMapping = MutableMapping[int, Type[HttpResponseError]] + + +class BaseOperations(ABC): + DEFAULT_ERROR_MAP: ClassVar[ErrorMapping] = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + + def __init__(self, client: AsyncPipelineClient, error_map: ErrorMapping | None = None) -> None: + self._client = client + self._error_map = self._prepare_error_map(error_map) + + @classmethod + def _prepare_error_map(cls, custom_error_map: ErrorMapping | None = None) -> MutableMapping: + """Prepare error map by merging default and custom error mappings. + + :param custom_error_map: Custom error mappings to merge + :return: Merged error map + """ + error_map = cls.DEFAULT_ERROR_MAP + if custom_error_map: + error_map = dict(cls.DEFAULT_ERROR_MAP) + error_map.update(custom_error_map) + return error_map + + async def _send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> AsyncHttpResponse: + """Send an HTTP request. + + :param request: HTTP request + :param stream: Stream to be used for HTTP requests + :param kwargs: Keyword arguments + + :return: Response object + """ + response: AsyncHttpResponse = await self._client.send_request(request, stream=stream, **kwargs) + self._handle_response_error(response) + return response + + def _handle_response_error(self, response: AsyncHttpResponse) -> None: + """Handle HTTP response errors. + + :param response: HTTP response to check + :raises HttpResponseError: If response status is not 200 + """ + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=self._error_map) + raise HttpResponseError(response=response) + + def _extract_response_json(self, response: AsyncHttpResponse) -> Any: + try: + payload_text = response.text() + payload_json = json.loads(payload_text) if payload_text else {} + except AttributeError: + payload_bytes = response.body() + payload_json = json.loads(payload_bytes.decode("utf-8")) if payload_bytes else {} + return payload_json \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_foundry_connected_tools.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_foundry_connected_tools.py new file mode 100644 index 000000000000..83138a17ad9a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_foundry_connected_tools.py @@ -0,0 +1,180 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from abc import ABC +from typing import Any, AsyncIterable, ClassVar, Dict, Iterable, List, Mapping, Optional, Tuple, cast + +from azure.core.pipeline.transport import HttpRequest +from azure.core.tracing.decorator_async import distributed_trace_async + +from ._base import BaseOperations +from .._models import FoundryConnectedTool, FoundryToolDetails, FoundryToolSource, InvokeFoundryConnectedToolsResponse, \ + ListFoundryConnectedToolsResponse, ResolvedFoundryTool, UserInfo +from ..._exceptions import ToolInvocationError + + +class BaseFoundryConnectedToolsOperations(BaseOperations, ABC): + """Base operations for Foundry connected tools.""" + + _API_VERSION: ClassVar[str] = "2025-11-15-preview" + + _HEADERS: ClassVar[Dict[str, str]] = { + "Content-Type": "application/json", + "Accept": "application/json", + } + + _QUERY_PARAMS: ClassVar[Dict[str, Any]] = { + "api-version": _API_VERSION + } + + @staticmethod + def _list_tools_path(agent_name: str) -> str: + return f"/agents/{agent_name}/tools/resolve" + + @staticmethod + def _invoke_tool_path(agent_name: str) -> str: + return f"/agents/{agent_name}/tools/invoke" + + def _build_list_tools_request( + self, + tools: List[FoundryConnectedTool], + user: Optional[UserInfo], + agent_name: str,) -> HttpRequest: + payload: Dict[str, Any] = { + "remoteServers": [ + { + "projectConnectionId": tool.project_connection_id, + "protocol": tool.protocol, + } for tool in tools + ], + } + if user: + payload["user"] = { + "objectId": user.object_id, + "tenantId": user.tenant_id, + } + return self._client.post( + self._list_tools_path(agent_name), + params=self._QUERY_PARAMS, + headers=self._HEADERS, + content=payload) + + @classmethod + def _convert_listed_tools( + cls, + resp: ListFoundryConnectedToolsResponse, + input_tools: List[FoundryConnectedTool]) -> Iterable[Tuple[FoundryConnectedTool, FoundryToolDetails]]: + if resp.error: + raise resp.error.as_exception() + if not resp.result: + return + + tool_map = {(tool.project_connection_id, tool.protocol): tool for tool in input_tools} + for server in resp.result.servers: + input_tool = tool_map.get((server.project_connection_id, server.protocol)) + if not input_tool: + continue + + for tool in server.tools: + details = FoundryToolDetails( + name=tool.name, + description=tool.description, + input_schema=tool.input_schema, + ) + yield input_tool, details + + def _build_invoke_tool_request( + self, + tool: ResolvedFoundryTool, + arguments: Dict[str, Any], + user: Optional[UserInfo], + agent_name: str) -> HttpRequest: + if tool.definition.source != FoundryToolSource.CONNECTED: + raise ToolInvocationError(f"Tool {tool.name} is not a Foundry connected tool.", tool=tool) + + tool_def = cast(FoundryConnectedTool, tool.definition) + payload: Dict[str, Any] = { + "toolName": tool.name, + "arguments": arguments, + "remoteServer": { + "projectConnectionId": tool_def.project_connection_id, + "protocol": tool_def.protocol, + }, + } + if user: + payload["user"] = { + "objectId": user.object_id, + "tenantId": user.tenant_id, + } + return self._client.post( + self._invoke_tool_path(agent_name), + params=self._QUERY_PARAMS, + headers=self._HEADERS, + content=payload) + + @classmethod + def _convert_invoke_result(cls, resp: InvokeFoundryConnectedToolsResponse) -> Any: + if resp.error: + raise resp.error.as_exception() + if not resp.result: + return None + return resp.result.value + + +class FoundryConnectedToolsOperations(BaseFoundryConnectedToolsOperations): + """Operations for managing Foundry connected tools.""" + + @distributed_trace_async + async def list_tools(self, + tools: List[FoundryConnectedTool], + user: Optional[UserInfo], + agent_name: str) -> Iterable[Tuple[FoundryConnectedTool, FoundryToolDetails]]: + """List connected tools. + + :param tools: List of connected tool definitions. + :type tools: List[FoundryConnectedTool] + :param user: User information for the request. Value can be None if running in local. + :type user: Optional[UserInfo] + :param agent_name: Name of the agent. + :type agent_name: str + :return: An async iterable of tuples containing the tool definition and its details. + :rtype: AsyncIterable[Tuple[FoundryConnectedTool, FoundryToolDetails]] + """ + if not tools: + return [] + + request = self._build_list_tools_request(tools, user, agent_name) + response = await self._send_request(request) + async with response: + json_response = self._extract_response_json(response) + tools_response = ListFoundryConnectedToolsResponse.model_validate(json_response) + return self._convert_listed_tools(tools_response, tools) + + + @distributed_trace_async + async def invoke_tool( + self, + tool: ResolvedFoundryTool, + arguments: Dict[str, Any], + user: Optional[UserInfo], + agent_name: str) -> Any: + """Invoke a connected tool. + + :param tool: Tool descriptor to invoke. + :type tool: ResolvedFoundryTool + :param arguments: Input arguments for the tool. + :type arguments: Mapping[str, Any] + :param user: User information for the request. Value can be None if running in local. + :type user: Optional[UserInfo] + :param agent_name: Name of the agent. + :type agent_name: str + :return: Result of the tool invocation. + :rtype: Any + """ + request = self._build_invoke_tool_request(tool, arguments, user, agent_name) + response = await self._send_request(request) + async with response: + json_response = self._extract_response_json(response) + invoke_response = InvokeFoundryConnectedToolsResponse.model_validate(json_response) + return self._convert_invoke_result(invoke_response) + \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_foundry_hosted_mcp_tools.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_foundry_hosted_mcp_tools.py new file mode 100644 index 000000000000..08587e274096 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/client/operations/_foundry_hosted_mcp_tools.py @@ -0,0 +1,168 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from abc import ABC +from typing import Any, AsyncIterable, ClassVar, Dict, Iterable, List, Tuple, cast + +from azure.core.pipeline.transport import HttpRequest +from azure.core.tracing.decorator_async import distributed_trace_async + +from ._base import BaseOperations +from .._models import FoundryHostedMcpTool, FoundryToolDetails, FoundryToolSource, ListFoundryHostedMcpToolsResponse, \ + ResolvedFoundryTool +from ..._exceptions import ToolInvocationError + + +class BaseFoundryHostedMcpToolsOperations(BaseOperations, ABC): + """Base operations for Foundry-hosted MCP tools.""" + + _PATH: ClassVar[str] = "/mcp_tools" + + _API_VERSION: ClassVar[str] = "2025-11-15-preview" + + _HEADERS: ClassVar[Dict[str, str]] = { + "Content-Type": "application/json", + "Accept": "application/json,text/event-stream", + "Connection": "keep-alive", + "Cache-Control": "no-cache", + } + + _QUERY_PARAMS: ClassVar[Dict[str, Any]] = { + "api-version": _API_VERSION + } + + _LIST_TOOLS_REQUEST_BODY: ClassVar[Dict[str, Any]] = { + "jsonrpc": "2.0", + "id": 1, + "method": "tools/list", + "params": {} + } + + _INVOKE_TOOL_REQUEST_BODY_TEMPLATE: ClassVar[Dict[str, Any]] = { + "jsonrpc": "2.0", + "id": 2, + "method": "tools/call", + } + + # Tool-specific property key overrides + # Format: {"tool_name": {"tool_def_key": "meta_schema_key"}} + _TOOL_PROPERTY_ALIAS: ClassVar[Dict[str, Dict[str, List[str]]]] = { + "_default": { + "imagegen_model_deployment_name": ["model_deployment_name"], + "model_deployment_name": ["model"], + "deployment_name": ["model"], + }, + "image_generation": { + "imagegen_model_deployment_name": ["model"] + }, + # Add more tool-specific mappings as needed + } + + def _build_list_tools_request(self) -> HttpRequest: + """Build request for listing MCP tools. + + :return: Request for listing MCP tools. + """ + return self._client.post(self._PATH, + params=self._QUERY_PARAMS, + headers=self._HEADERS, + content=self._LIST_TOOLS_REQUEST_BODY) + + @staticmethod + def _convert_listed_tools( + response: ListFoundryHostedMcpToolsResponse, + allowed_tools: List[FoundryHostedMcpTool]) -> Iterable[Tuple[FoundryHostedMcpTool, FoundryToolDetails]]: + + allowlist = {tool.name: tool for tool in allowed_tools} + for tool in response.result.tools: + definition = allowlist.get(tool.name) + if not definition: + continue + details = FoundryToolDetails( + name=tool.name, + description=tool.description, + metadata=tool.meta, + input_schema=tool.input_schema) + yield definition, details + + def _build_invoke_tool_request(self, tool: ResolvedFoundryTool, arguments: Dict[str, Any]) -> HttpRequest: + if tool.definition.source != FoundryToolSource.HOSTED_MCP: + raise ToolInvocationError(f"Tool {tool.name} is not a Foundry-hosted MCP tool.", tool=tool) + definition = cast(FoundryHostedMcpTool, tool.definition) + + payload = dict(self._INVOKE_TOOL_REQUEST_BODY_TEMPLATE) + payload["params"] = { + "name": tool.name, + "arguments": arguments + } + if tool.metadata and definition.configuration: + payload["_meta"] = tool.metadata.extract_from(definition.configuration, + self._resolve_property_alias(tool.name)) + + return self._client.post(self._PATH, + params=self._QUERY_PARAMS, + headers=self._HEADERS, + content=payload) + + @classmethod + def _resolve_property_alias(cls, tool_name: str) -> Dict[str, List[str]]: + """Get property key overrides for a specific tool. + + :param tool_name: Name of the tool. + :type tool_name: str + :return: Property key overrides. + :rtype: Dict[str, List[str]] + """ + overrides = dict(cls._TOOL_PROPERTY_ALIAS.get("_default", {})) + tool_specific = cls._TOOL_PROPERTY_ALIAS.get(tool_name, {}) + overrides.update(tool_specific) + return overrides + + +class FoundryMcpToolsOperations(BaseFoundryHostedMcpToolsOperations): + """Operations for Foundry-hosted MCP tools.""" + + @distributed_trace_async + async def list_tools( + self, + allowed_tools: List[FoundryHostedMcpTool] + ) -> Iterable[Tuple[FoundryHostedMcpTool, FoundryToolDetails]]: + """List MCP tools. + + :param allowed_tools: List of allowed MCP tools to filter. + :type allowed_tools: List[FoundryHostedMcpTool] + :return: An async iterable of tuples containing tool definitions and their details. + :rtype: AsyncIterable[Tuple[FoundryHostedMcpTool, FoundryToolDetails]] + """ + if not allowed_tools: + return [] + + request = self._build_list_tools_request() + response = await self._send_request(request) + async with response: + json_response = self._extract_response_json(response) + tools_response = ListFoundryHostedMcpToolsResponse.model_validate(json_response) + + return self._convert_listed_tools(tools_response, allowed_tools) + + @distributed_trace_async + async def invoke_tool( + self, + tool: ResolvedFoundryTool, + arguments: Dict[str, Any], + ) -> Any: + """Invoke an MCP tool. + + :param tool: Tool descriptor for the tool to invoke. + :type tool: ResolvedFoundryTool + :param arguments: Input arguments for the tool. + :type arguments: Dict[str, Any] + :return: Result of the tool invocation. + :rtype: Any + """ + request = self._build_invoke_tool_request(tool, arguments) + response = await self._send_request(request) + async with response: + json_response = self._extract_response_json(response) + invoke_response = json_response + return invoke_response diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_catalog.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_catalog.py new file mode 100644 index 000000000000..2d50089fef8f --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_catalog.py @@ -0,0 +1,143 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import asyncio # pylint: disable=C4763 +from abc import ABC, abstractmethod +from typing import Any, Awaitable, Collection, List, Mapping, MutableMapping, Optional, Union + +from cachetools import TTLCache # type: ignore[import-untyped] + +from ._facade import FoundryToolLike, ensure_foundry_tool +from ._user import UserProvider +from ..client._client import FoundryToolClient +from ..client._models import FoundryTool, FoundryToolDetails, FoundryToolSource, ResolvedFoundryTool, UserInfo + + +class FoundryToolCatalog(ABC): + """Base class for Foundry tool catalogs.""" + def __init__(self, user_provider: UserProvider): + self._user_provider = user_provider + + async def get(self, tool: FoundryToolLike) -> Optional[ResolvedFoundryTool]: + """Gets a Foundry tool by its definition. + + :param tool: The Foundry tool to resolve. + :type tool: FoundryToolLike + :return: The resolved Foundry tool. + :rtype: Optional[ResolvedFoundryTool] + """ + tools = await self.list([tool]) + return tools[0] if tools else None + + @abstractmethod + async def list(self, tools: List[FoundryToolLike]) -> List[ResolvedFoundryTool]: + """Lists all available Foundry tools. + + :param tools: The list of Foundry tools to resolve. + :type tools: List[FoundryToolLike] + :return: A list of resolved Foundry tools. + :rtype: List[ResolvedFoundryTool] + """ + raise NotImplementedError + + +_CachedValueType = Union[Awaitable[List[FoundryToolDetails]], List[FoundryToolDetails]] + + +class CachedFoundryToolCatalog(FoundryToolCatalog, ABC): + """Cached implementation of FoundryToolCatalog with concurrency-safe caching.""" + + def __init__(self, user_provider: UserProvider): + super().__init__(user_provider) + self._cache: MutableMapping[Any, _CachedValueType] = self._create_cache() + + def _create_cache(self) -> MutableMapping[Any, _CachedValueType]: + return TTLCache(maxsize=1024, ttl=600) + + def _get_key(self, user: Optional[UserInfo], tool: FoundryTool) -> Any: + if tool.source is FoundryToolSource.HOSTED_MCP: + return tool.id + return user, tool.id + + async def list(self, tools: List[FoundryToolLike]) -> List[ResolvedFoundryTool]: + user = await self._user_provider.get_user() + foundry_tools = {} + tools_to_fetch = {} + fetching_tasks = [] + for t in tools: + tool = ensure_foundry_tool(t) + key = self._get_key(user, tool) + foundry_tools[key] = tool + if key not in self._cache: + tools_to_fetch[key] = tool + elif (task := self._cache[key]) and isinstance(task, Awaitable): + fetching_tasks.append(task) + + # for tools that are not being listed, create a batch task, convert to per-tool resolving tasks, and cache them + if tools_to_fetch: + # Awaitable[Mapping[str, List[FoundryToolDetails]]] + fetched_tools = asyncio.create_task(self._fetch_tools(tools_to_fetch.values(), user)) + + for k, t in tools_to_fetch.items(): + # safe to write cache since it's the only runner in this event loop + task = asyncio.create_task(self._per_tool_fetching_task(k, t, fetched_tools)) + self._cache[k] = task + fetching_tasks.append(task) + + try: + # now we have every tool associated with a task + if fetching_tasks: + await asyncio.gather(*fetching_tasks) + except: + # exception can only be caused by fetching tasks, remove them from cache + for k, _ in tools_to_fetch.items(): + if k in self._cache: + del self._cache[k] + raise + + resolved_tools = [] + for key, tool in foundry_tools.items(): + # this acts like a lock - every task of the same tool waits for the same underlying fetch + task_or_value = self._cache[key] + details_list = (await task_or_value) if isinstance(task_or_value, Awaitable) else task_or_value + for details in details_list: + resolved_tools.append( + ResolvedFoundryTool( + definition=tool, + details=details + ) + ) + + return resolved_tools + + async def _per_tool_fetching_task( + self, + cache_key: Any, + tool: FoundryTool, + fetching: Awaitable[Mapping[str, List[FoundryToolDetails]]] + ) -> List[FoundryToolDetails]: + details = await fetching + details_list = details.get(tool.id, []) + # replace the task in cache with the actual value to optimize memory usage + self._cache[cache_key] = details_list + return details_list + + @abstractmethod + async def _fetch_tools(self, + tools: Collection[FoundryTool], + user: Optional[UserInfo]) -> Mapping[str, List[FoundryToolDetails]]: + raise NotImplementedError + + +class DefaultFoundryToolCatalog(CachedFoundryToolCatalog): + """Default implementation of FoundryToolCatalog.""" + + def __init__(self, client: FoundryToolClient, user_provider: UserProvider, agent_name: str): + super().__init__(user_provider) + self._client = client + self._agent_name = agent_name + + async def _fetch_tools(self, + tools: Collection[FoundryTool], + user: Optional[UserInfo]) -> Mapping[str, List[FoundryToolDetails]]: + return await self._client.list_tools_details(tools, self._agent_name, user) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py new file mode 100644 index 000000000000..71c4601cb525 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_facade.py @@ -0,0 +1,95 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import re +from typing import Any, Dict, Union + +from .. import FoundryConnectedTool, FoundryHostedMcpTool +from .._exceptions import InvalidToolFacadeError +from ..client._models import FoundryTool, FoundryToolProtocol + +# FoundryToolFacade: a β€œtool descriptor” bag. +# +# Reserved keys: +# Required: +# - "type": str Discriminator, e.g. "mcp" | "a2a" | "code_interpreter" | ... +# Optional: +# - "project_connection_id": str Project connection id of Foundry connected tools, +# required when "type" is "mcp" or "a2a". +# +# Custom keys: +# - Allowed, but MUST NOT shadow reserved keys. +FoundryToolFacade = Dict[str, Any] + +FoundryToolLike = Union[FoundryToolFacade, FoundryTool] + + +def ensure_foundry_tool(tool: FoundryToolLike) -> FoundryTool: + """Ensure the input is a FoundryTool instance. + + :param tool: The tool descriptor, either as a FoundryToolFacade or FoundryTool. + :type tool: FoundryToolLike + :return: The corresponding FoundryTool instance. + :rtype: FoundryTool + """ + if isinstance(tool, FoundryTool): + return tool + + tool = tool.copy() + tool_type = tool.pop("type", None) + if not isinstance(tool_type, str) or not tool_type: + raise InvalidToolFacadeError("FoundryToolFacade must have a valid 'type' field of type str.") + + try: + protocol = FoundryToolProtocol(tool_type) + project_connection_id = tool.pop("project_connection_id", None) + if not isinstance(project_connection_id, str) or not project_connection_id: + raise InvalidToolFacadeError(f"project_connection_id is required for tool protocol {protocol}.") + + # Parse the connection identifier to extract the connection name + connection_name = _parse_connection_id(project_connection_id) + return FoundryConnectedTool(protocol=protocol, project_connection_id=connection_name) + except ValueError: + return FoundryHostedMcpTool(name=tool_type, configuration=tool) + + +# Pattern for Azure resource ID format: +# /subscriptions//resourceGroups//providers/Microsoft.CognitiveServices/accounts +# //projects//connections/ +_RESOURCE_ID_PATTERN = re.compile( + r"^/subscriptions/[^/]+/resourceGroups/[^/]+/providers/Microsoft\.CognitiveServices/" + r"accounts/[^/]+/projects/[^/]+/connections/(?P[^/]+)$", + re.IGNORECASE, +) + + +def _parse_connection_id(connection_id: str) -> str: + """Parse the connection identifier and extract the connection name. + + Supports two formats: + 1. Simple name: "my-connection-name" + 2. Resource ID: "/subscriptions//resourceGroups//providers + /Microsoft.CognitiveServices/accounts//projects//connections/" + + :param connection_id: The connection identifier, either a simple name or a full resource ID. + :type connection_id: str + :return: The connection name extracted from the identifier. + :rtype: str + :raises InvalidToolFacadeError: If the connection_id format is invalid. + """ + if not connection_id: + raise InvalidToolFacadeError("Connection identifier cannot be empty.") + + # Check if it's a resource ID format (starts with /) + if connection_id.startswith("/"): + match = _RESOURCE_ID_PATTERN.match(connection_id) + if not match: + raise InvalidToolFacadeError( + f"Invalid resource ID format for connection: '{connection_id}'. " + "Expected format: /subscriptions//resourceGroups//providers/" + "Microsoft.CognitiveServices/accounts//projects//connections/" + ) + return match.group("name") + + # Otherwise, treat it as a simple connection name + return connection_id diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_invoker.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_invoker.py new file mode 100644 index 000000000000..d24c79dd4d12 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_invoker.py @@ -0,0 +1,69 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from abc import ABC, abstractmethod +from typing import Any, Dict + +from ._user import UserProvider +from ..client._client import FoundryToolClient +from ..client._models import ResolvedFoundryTool + + +class FoundryToolInvoker(ABC): + """Abstract base class for Foundry tool invokers.""" + + @property + @abstractmethod + def resolved_tool(self) -> ResolvedFoundryTool: + """Get the resolved tool definition. + + :return: The tool definition. + :rtype: ResolvedFoundryTool + """ + raise NotImplementedError + + @abstractmethod + async def invoke(self, arguments: Dict[str, Any]) -> Any: + """Invoke the tool with the given arguments. + + :param arguments: The arguments to pass to the tool. + :type arguments: Dict[str, Any] + :return: The result of the tool invocation + :rtype: Any + """ + raise NotImplementedError + + +class DefaultFoundryToolInvoker(FoundryToolInvoker): + """Default implementation of FoundryToolInvoker.""" + + def __init__(self, + resolved_tool: ResolvedFoundryTool, + client: FoundryToolClient, + user_provider: UserProvider, + agent_name: str): + self._resolved_tool = resolved_tool + self._client = client + self._user_provider = user_provider + self._agent_name = agent_name + + @property + def resolved_tool(self) -> ResolvedFoundryTool: + """Get the resolved tool definition. + + :return: The tool definition. + :rtype: ResolvedFoundryTool + """ + return self._resolved_tool + + async def invoke(self, arguments: Dict[str, Any]) -> Any: + """Invoke the tool with the given arguments. + + :param arguments: The arguments to pass to the tool + :type arguments: Dict[str, Any] + :return: The result of the tool invocation + :rtype: Any + """ + user = await self._user_provider.get_user() + result = await self._client.invoke_tool(self._resolved_tool, arguments, self._agent_name, user) + return result diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_resolver.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_resolver.py new file mode 100644 index 000000000000..9596124d9b55 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_resolver.py @@ -0,0 +1,60 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from abc import ABC, abstractmethod +from typing import Union + +from ._catalog import FoundryToolCatalog +from ._facade import FoundryToolLike, ensure_foundry_tool +from ._invoker import DefaultFoundryToolInvoker, FoundryToolInvoker +from ._user import UserProvider +from .. import FoundryToolClient +from .._exceptions import UnableToResolveToolInvocationError +from ..client._models import ResolvedFoundryTool + + +class FoundryToolInvocationResolver(ABC): + """Resolver for Foundry tool invocations.""" + + @abstractmethod + async def resolve(self, tool: Union[FoundryToolLike, ResolvedFoundryTool]) -> FoundryToolInvoker: + """Resolves a Foundry tool invocation. + + :param tool: The Foundry tool to resolve. + :type tool: Union[FoundryToolLike, ResolvedFoundryTool] + :return: The resolved Foundry tool invoker. + :rtype: FoundryToolInvoker + """ + raise NotImplementedError + + +class DefaultFoundryToolInvocationResolver(FoundryToolInvocationResolver): + """Default implementation of FoundryToolInvocationResolver.""" + + def __init__(self, + catalog: FoundryToolCatalog, + client: FoundryToolClient, + user_provider: UserProvider, + agent_name: str): + self._catalog = catalog + self._client = client + self._user_provider = user_provider + self._agent_name = agent_name + + async def resolve(self, tool: Union[FoundryToolLike, ResolvedFoundryTool]) -> FoundryToolInvoker: + """Resolves a Foundry tool invocation. + + :param tool: The Foundry tool to resolve. + :type tool: Union[FoundryToolLike, ResolvedFoundryTool] + :return: The resolved Foundry tool invoker. + :rtype: FoundryToolInvoker + """ + if isinstance(tool, ResolvedFoundryTool): + resolved_tool = tool + else: + foundry_tool = ensure_foundry_tool(tool) + resolved_tool = await self._catalog.get(foundry_tool) # type: ignore[assignment] + if not resolved_tool: + raise UnableToResolveToolInvocationError(f"Unable to resolve tool {foundry_tool} from catalog", + foundry_tool) + return DefaultFoundryToolInvoker(resolved_tool, self._client, self._user_provider, self._agent_name) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_runtime.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_runtime.py new file mode 100644 index 000000000000..6335d6ee7c58 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_runtime.py @@ -0,0 +1,147 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from __future__ import annotations + +import os +from abc import ABC, abstractmethod +from typing import Any, AsyncContextManager, ClassVar, Dict, Optional, Union + +from azure.core.credentials_async import AsyncTokenCredential + +from ._catalog import DefaultFoundryToolCatalog, FoundryToolCatalog +from ._facade import FoundryToolLike +from ._resolver import DefaultFoundryToolInvocationResolver, FoundryToolInvocationResolver +from ._user import ContextVarUserProvider, UserProvider +from ..client._models import ResolvedFoundryTool +from ..client._client import FoundryToolClient +from ...constants import Constants + + +def create_tool_runtime(project_endpoint: str | None, + credential: AsyncTokenCredential | None) -> "FoundryToolRuntime": + """Create a Foundry tool runtime. + Returns a DefaultFoundryToolRuntime if both project_endpoint and credential are provided, + otherwise returns a ThrowingFoundryToolRuntime which raises errors on usage. + + :param project_endpoint: The project endpoint. + :type project_endpoint: str | None + :param credential: The credential. + :type credential: AsyncTokenCredential | None + :return: The Foundry tool runtime. + :rtype: FoundryToolRuntime + """ + if project_endpoint and credential: + return DefaultFoundryToolRuntime(project_endpoint=project_endpoint, credential=credential) + return ThrowingFoundryToolRuntime() + +class FoundryToolRuntime(AsyncContextManager["FoundryToolRuntime"], ABC): + """Base class for Foundry tool runtimes.""" + + @property + @abstractmethod + def catalog(self) -> FoundryToolCatalog: + """The tool catalog. + + :return: The tool catalog. + :rtype: FoundryToolCatalog + """ + raise NotImplementedError + + @property + @abstractmethod + def invocation(self) -> FoundryToolInvocationResolver: + """The tool invocation resolver. + + :return: The tool invocation resolver. + :rtype: FoundryToolInvocationResolver + """ + raise NotImplementedError + + async def invoke(self, tool: Union[FoundryToolLike, ResolvedFoundryTool], arguments: Dict[str, Any]) -> Any: + """Invoke a tool with the given arguments. + + :param tool: The tool to invoke. + :type tool: Union[FoundryToolLike, ResolvedFoundryTool] + :param arguments: The arguments to pass to the tool. + :type arguments: Dict[str, Any] + :return: The result of the tool invocation. + :rtype: Any + """ + invoker = await self.invocation.resolve(tool) + return await invoker.invoke(arguments) + + +class DefaultFoundryToolRuntime(FoundryToolRuntime): + """Default implementation of FoundryToolRuntime.""" + + def __init__(self, + project_endpoint: str, + credential: "AsyncTokenCredential", + user_provider: Optional[UserProvider] = None): + # Do we need introduce DI here? + self._user_provider = user_provider or ContextVarUserProvider() + self._agent_name = os.getenv(Constants.AGENT_NAME, "$default") + self._client = FoundryToolClient(endpoint=project_endpoint, credential=credential) + self._catalog = DefaultFoundryToolCatalog(client=self._client, + user_provider=self._user_provider, + agent_name=self._agent_name) + self._invocation = DefaultFoundryToolInvocationResolver(catalog=self._catalog, + client=self._client, + user_provider=self._user_provider, + agent_name=self._agent_name) + + @property + def catalog(self) -> FoundryToolCatalog: + """The tool catalog. + + :rtype: FoundryToolCatalog + """ + return self._catalog + + @property + def invocation(self) -> FoundryToolInvocationResolver: + """The tool invocation resolver. + + :rtype: FoundryToolInvocationResolver + """ + return self._invocation + + async def __aenter__(self) -> "DefaultFoundryToolRuntime": + await self._client.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self._client.__aexit__(exc_type, exc_value, traceback) + + +class ThrowingFoundryToolRuntime(FoundryToolRuntime): + """A FoundryToolRuntime that raises errors on usage.""" + _ERROR_MESSAGE: ClassVar[str] = ("FoundryToolRuntime is not configured. " + "Please provide a valid project endpoint and credential.") + + @property + def catalog(self) -> FoundryToolCatalog: + """The tool catalog. + + :returns: The tool catalog. + :rtype: FoundryToolCatalog + :raises RuntimeError: Always raised to indicate the runtime is not configured. + """ + raise RuntimeError(self._ERROR_MESSAGE) + + @property + def invocation(self) -> FoundryToolInvocationResolver: + """The tool invocation resolver. + + :returns: The tool invocation resolver. + :rtype: FoundryToolInvocationResolver + :raises RuntimeError: Always raised to indicate the runtime is not configured. + """ + raise RuntimeError(self._ERROR_MESSAGE) + + async def __aenter__(self) -> "ThrowingFoundryToolRuntime": + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + pass diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py new file mode 100644 index 000000000000..80b25d78b20e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_starlette.py @@ -0,0 +1,66 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from contextvars import ContextVar +from typing import Awaitable, Callable, Optional + +from starlette.applications import Starlette +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.types import ASGIApp + +from ._user import ContextVarUserProvider, resolve_user_from_headers +from ..client._models import UserInfo + +_UserContextType = ContextVar[Optional[UserInfo]] +_ResolverType = Callable[[Request], Awaitable[Optional[UserInfo]]] + +class UserInfoContextMiddleware(BaseHTTPMiddleware): + """Middleware to set user information in a context variable for each request.""" + + def __init__(self, app: ASGIApp, user_info_var: _UserContextType, user_resolver: _ResolverType): + super().__init__(app) + self._user_info_var = user_info_var + self._user_resolver = user_resolver + + @classmethod + def install(cls, + app: Starlette, + user_context: Optional[_UserContextType] = None, + user_resolver: Optional[_ResolverType] = None): + """Install the middleware into a Starlette application. + + :param app: The Starlette application to install the middleware into. + :type app: Starlette + :param user_context: Optional context variable to use for storing user info. + If not provided, a default context variable will be used. + :type user_context: Optional[ContextVar[Optional[UserInfo]]] + :param user_resolver: Optional function to resolve user info from the request. + If not provided, a default resolver will be used. + :type user_resolver: Optional[Callable[[Request], Awaitable[Optional[UserInfo]]]] + + """ + app.add_middleware(UserInfoContextMiddleware, # type: ignore[arg-type] + user_info_var=user_context or ContextVarUserProvider.default_user_info_context, + user_resolver=user_resolver or cls._default_user_resolver) + + @staticmethod + async def _default_user_resolver(request: Request) -> Optional[UserInfo]: + return resolve_user_from_headers(request.headers) + + async def dispatch(self, request: Request, call_next): + """Process the incoming request, setting the user info in the context variable. + + :param request: The incoming Starlette request. + :type request: Request + :param call_next: The next middleware or endpoint to call. + :type call_next: Callable[[Request], Awaitable[Response]] + :return: The response from the next middleware or endpoint. + :rtype: Response + """ + user = await self._user_resolver(request) + token = self._user_info_var.set(user) + try: + return await call_next(request) + finally: + self._user_info_var.reset(token) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_user.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_user.py new file mode 100644 index 000000000000..f72b30c0d3d3 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/runtime/_user.py @@ -0,0 +1,60 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from contextvars import ContextVar +from abc import ABC, abstractmethod +from typing import ClassVar, Mapping, Optional + +from ..client._models import UserInfo + + +class UserProvider(ABC): + """Base class for user providers.""" + + @abstractmethod + async def get_user(self) -> Optional[UserInfo]: + """Get the user information. + + :return: The user information or None if not found. + :rtype: Optional[UserInfo] + """ + raise NotImplementedError + + +class ContextVarUserProvider(UserProvider): + """User provider that retrieves user information from a ContextVar.""" + default_user_info_context: ClassVar[ContextVar[UserInfo]] = ContextVar("user_info_context") + + def __init__(self, context: Optional[ContextVar[UserInfo]] = None): + self.context = context or self.default_user_info_context + + async def get_user(self) -> Optional[UserInfo]: + """Get the user information from the context variable. + + :return: The user information or None if not found. + :rtype: Optional[UserInfo] + """ + return self.context.get(None) + + +def resolve_user_from_headers(headers: Mapping[str, str], + object_id_header: str = "x-aml-oid", + tenant_id_header: str = "x-aml-tid") -> Optional[UserInfo]: + """Resolve user information from HTTP headers. + + :param headers: The HTTP headers. + :type headers: Mapping[str, str] + :param object_id_header: The header name for the object ID. + :type object_id_header: str + :param tenant_id_header: The header name for the tenant ID. + :type tenant_id_header: str + :return: The user information or None if not found. + :rtype: Optional[UserInfo] + """ + object_id = headers.get(object_id_header, "") + tenant_id = headers.get(tenant_id_header, "") + + if not object_id or not tenant_id: + return None + + return UserInfo(object_id=object_id, tenant_id=tenant_id) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/__init__.py new file mode 100644 index 000000000000..037fb1dc04de --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/__init__.py @@ -0,0 +1,11 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) + +from ._name_resolver import ToolNameResolver + +__all__ = [ + "ToolNameResolver", +] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/_name_resolver.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/_name_resolver.py new file mode 100644 index 000000000000..9f1b7874f52c --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/tools/utils/_name_resolver.py @@ -0,0 +1,37 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from ..client._models import ResolvedFoundryTool + + +class ToolNameResolver: + """Utility class for resolving tool names to be registered to model.""" + + def __init__(self): + self._count_by_name = {} + self._stable_names = {} + + def resolve(self, tool: ResolvedFoundryTool) -> str: + """Resolve a stable name for the given tool. + If the tool name has not been used before, use it as is. + If it has been used, append an underscore and a count to make it unique. + + :param tool: The tool to resolve the name for. + :type tool: ResolvedFoundryTool + :return: The resolved stable name for the tool. + :rtype: str + """ + final_name = self._stable_names.get(tool.id) + if final_name is not None: + return final_name + + dup_count = self._count_by_name.setdefault(tool.details.name, 0) + + if dup_count == 0: + final_name = tool.details.name + else: + final_name = f"{tool.details.name}_{dup_count}" + + self._stable_names[tool.id] = final_name + self._count_by_name[tool.details.name] = dup_count + 1 + return self._stable_names[tool.id] diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/_credential.py b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/_credential.py new file mode 100644 index 000000000000..398a8c46fd5d --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/azure/ai/agentserver/core/utils/_credential.py @@ -0,0 +1,100 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from __future__ import annotations + +import asyncio # pylint: disable=C4763 +import inspect +from types import TracebackType +from typing import Any, Type, cast + +from azure.core.credentials import AccessToken, TokenCredential +from azure.core.credentials_async import AsyncTokenCredential + + +async def _to_thread(func, *args, **kwargs): # pylint: disable=C4743 + """Compatibility wrapper for asyncio.to_thread (Python 3.8+). + + :param func: The function to run in a thread. + :type func: Callable + :param args: Positional arguments to pass to the function. + :type args: Any + :param kwargs: Keyword arguments to pass to the function. + :type kwargs: Any + :return: The result of the function call. + :rtype: Any + """ + if hasattr(asyncio, "to_thread"): + return await asyncio.to_thread(func, *args, **kwargs) # py>=3.9 + loop = asyncio.get_running_loop() + return await loop.run_in_executor(None, lambda: func(*args, **kwargs)) + + +class AsyncTokenCredentialAdapter(AsyncTokenCredential): + """ + AsyncTokenCredential adapter for either: + - azure.core.credentials.TokenCredential (sync) + - azure.core.credentials_async.AsyncTokenCredential (async) + """ + + def __init__(self, credential: TokenCredential | AsyncTokenCredential) -> None: + if not hasattr(credential, "get_token"): + raise TypeError("credential must have a get_token method") + self._credential = credential + self._is_async = isinstance(credential, AsyncTokenCredential) or inspect.iscoroutinefunction( + getattr(credential, "get_token", None) + ) + + async def get_token( + self, + *scopes: str, + claims: str | None = None, + tenant_id: str | None = None, + enable_cae: bool = False, + **kwargs: Any, + ) -> AccessToken: + if self._is_async: + cred = cast(AsyncTokenCredential, self._credential) + return await cred.get_token(*scopes, + claims=claims, + tenant_id=tenant_id, + enable_cae=enable_cae, + **kwargs) + return await _to_thread(self._credential.get_token, + *scopes, + claims=claims, + tenant_id=tenant_id, + enable_cae=enable_cae, + **kwargs) + + async def close(self) -> None: + """ + Best-effort resource cleanup: + - if underlying has async close(): await it + - else if underlying has sync close(): run it in a thread + """ + close_fn = getattr(self._credential, "close", None) + if close_fn is None: + return + + if inspect.iscoroutinefunction(close_fn): + await close_fn() + else: + await _to_thread(close_fn) + + async def __aenter__(self) -> "AsyncTokenCredentialAdapter": + enter = getattr(self._credential, "__aenter__", None) + if enter is not None and inspect.iscoroutinefunction(enter): + await enter() + return self + + async def __aexit__( + self, + exc_type: Type[BaseException] | None = None, + exc_value: BaseException | None = None, + traceback: TracebackType | None = None, + ) -> None: + aexit = getattr(self._credential, "__aexit__", None) + if aexit is not None and inspect.iscoroutinefunction(aexit): + return await aexit(exc_type, exc_value, traceback) + await self.close() diff --git a/sdk/agentserver/azure-ai-agentserver-core/cspell.json b/sdk/agentserver/azure-ai-agentserver-core/cspell.json index 126cadc0625c..d5003af37fe1 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/cspell.json +++ b/sdk/agentserver/azure-ai-agentserver-core/cspell.json @@ -16,7 +16,11 @@ "GETFL", "DETFL", "SETFL", - "Planifica" + "Planifica", + "ainvoke", + "oauthreq", + "hitl", + "HITL" ], "ignorePaths": [ "*.csv", diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.application.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.application.rst new file mode 100644 index 000000000000..415b7d3b2538 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.application.rst @@ -0,0 +1,7 @@ +azure.ai.agentserver.core.application package +============================================= + +.. automodule:: azure.ai.agentserver.core.application + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.openai.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.openai.rst new file mode 100644 index 000000000000..dd1cce6eecca --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.openai.rst @@ -0,0 +1,8 @@ +azure.ai.agentserver.core.models.openai package +=============================================== + +.. automodule:: azure.ai.agentserver.core.models.openai + :inherited-members: + :members: + :undoc-members: + :ignore-module-all: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.projects.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.projects.rst new file mode 100644 index 000000000000..38e0be4f331b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.projects.rst @@ -0,0 +1,8 @@ +azure.ai.agentserver.core.models.projects package +================================================= + +.. automodule:: azure.ai.agentserver.core.models.projects + :inherited-members: + :members: + :undoc-members: + :ignore-module-all: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.rst new file mode 100644 index 000000000000..008b280c64de --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.models.rst @@ -0,0 +1,17 @@ +azure.ai.agentserver.core.models package +======================================== + +.. automodule:: azure.ai.agentserver.core.models + :inherited-members: + :members: + :undoc-members: + :ignore-module-all: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + azure.ai.agentserver.core.models.openai + azure.ai.agentserver.core.models.projects diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.rst index da01b083b0b3..b8f1dadf3a73 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.rst +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.rst @@ -12,7 +12,11 @@ Subpackages .. toctree:: :maxdepth: 4 + azure.ai.agentserver.core.application + azure.ai.agentserver.core.models azure.ai.agentserver.core.server + azure.ai.agentserver.core.tools + azure.ai.agentserver.core.utils Submodules ---------- diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.rst index 26c4aaf4d15a..8fb5b52e4465 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.rst +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.server.common.rst @@ -24,3 +24,11 @@ azure.ai.agentserver.core.server.common.agent\_run\_context module :inherited-members: :members: :undoc-members: + +azure.ai.agentserver.core.server.common.constants module +-------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.core.server.common.constants + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.client.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.client.rst new file mode 100644 index 000000000000..8182914f69f9 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.client.rst @@ -0,0 +1,7 @@ +azure.ai.agentserver.core.tools.client package +============================================== + +.. automodule:: azure.ai.agentserver.core.tools.client + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.rst new file mode 100644 index 000000000000..c112ec2beabd --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.rst @@ -0,0 +1,18 @@ +azure.ai.agentserver.core.tools package +======================================= + +.. automodule:: azure.ai.agentserver.core.tools + :inherited-members: + :members: + :undoc-members: + :exclude-members: BaseModel,model_json_schema + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + azure.ai.agentserver.core.tools.client + azure.ai.agentserver.core.tools.runtime + azure.ai.agentserver.core.tools.utils diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.runtime.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.runtime.rst new file mode 100644 index 000000000000..c502d56b42f6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.runtime.rst @@ -0,0 +1,7 @@ +azure.ai.agentserver.core.tools.runtime package +=============================================== + +.. automodule:: azure.ai.agentserver.core.tools.runtime + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.utils.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.utils.rst new file mode 100644 index 000000000000..94d3f310e112 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.tools.utils.rst @@ -0,0 +1,7 @@ +azure.ai.agentserver.core.tools.utils package +============================================= + +.. automodule:: azure.ai.agentserver.core.tools.utils + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.utils.rst b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.utils.rst new file mode 100644 index 000000000000..5250167cf7e6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/doc/azure.ai.agentserver.core.utils.rst @@ -0,0 +1,7 @@ +azure.ai.agentserver.core.utils package +======================================= + +.. automodule:: azure.ai.agentserver.core.utils + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml index f574360722bb..55307e69540d 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-core/pyproject.toml @@ -8,6 +8,7 @@ authors = [ ] license = "MIT" classifiers = [ + "Development Status :: 4 - Beta", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3", @@ -19,22 +20,27 @@ classifiers = [ keywords = ["azure", "azure sdk"] dependencies = [ - "azure-monitor-opentelemetry>=1.5.0", - "azure-ai-projects", - "azure-ai-agents>=1.2.0b5", + "azure-monitor-opentelemetry>=1.5.0,<1.8.5", + "azure-ai-projects>=1.1.0b4", + "azure-ai-agents==1.2.0b5", "azure-core>=1.35.0", - "azure-identity", + "azure-identity>=1.25.1", "openai>=1.80.0", "opentelemetry-api>=1.35", "opentelemetry-exporter-otlp-proto-http", "starlette>=0.45.0", "uvicorn>=0.31.0", + "aiohttp>=3.13.0", # used by azure-identity aio + "cachetools>=6.0.0" ] [build-system] requires = ["setuptools>=69", "wheel"] build-backend = "setuptools.build_meta" +[project.urls] +repository = "https://github.com/Azure/azure-sdk-for-python" + [tool.setuptools.packages.find] exclude = [ "tests*", @@ -68,4 +74,6 @@ combine-as-imports = true [tool.azure-sdk-build] breaking = false # incompatible python version pyright = false -verifytypes = false \ No newline at end of file +verifytypes = false +latestdependency = false +dependencies = false \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-core/samples/simple_mock_agent/custom_mock_agent_test.py b/sdk/agentserver/azure-ai-agentserver-core/samples/simple_mock_agent/custom_mock_agent_test.py index 3d4187a188f2..f6d2c08bb0b9 100644 --- a/sdk/agentserver/azure-ai-agentserver-core/samples/simple_mock_agent/custom_mock_agent_test.py +++ b/sdk/agentserver/azure-ai-agentserver-core/samples/simple_mock_agent/custom_mock_agent_test.py @@ -97,7 +97,7 @@ async def agent_run(context: AgentRunContext): return response -my_agent = FoundryCBAgent() +my_agent = FoundryCBAgent(project_endpoint="mock-endpoint") my_agent.agent_run = agent_run if __name__ == "__main__": diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/common/test_foundry_id_generator.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/common/test_foundry_id_generator.py new file mode 100644 index 000000000000..a46f45f7c739 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/common/test_foundry_id_generator.py @@ -0,0 +1,27 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from azure.ai.agentserver.core.server.common.id_generator.foundry_id_generator import FoundryIdGenerator + + +def test_conversation_id_none_uses_response_partition(): + response_id = FoundryIdGenerator._new_id("resp") + generator = FoundryIdGenerator(response_id=response_id, conversation_id=None) + + assert generator.conversation_id is None + + expected_partition = FoundryIdGenerator._extract_partition_id(response_id) + generated_id = generator.generate("msg") + assert FoundryIdGenerator._extract_partition_id(generated_id) == expected_partition + + +def test_conversation_id_present_uses_conversation_partition(): + response_id = FoundryIdGenerator._new_id("resp") + conversation_id = FoundryIdGenerator._new_id("conv") + generator = FoundryIdGenerator(response_id=response_id, conversation_id=conversation_id) + + assert generator.conversation_id == conversation_id + + expected_partition = FoundryIdGenerator._extract_partition_id(conversation_id) + generated_id = generator.generate("msg") + assert FoundryIdGenerator._extract_partition_id(generated_id) == expected_partition diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_otel_context.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_otel_context.py new file mode 100644 index 000000000000..f5ee395ff48f --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_otel_context.py @@ -0,0 +1,14 @@ +import pytest +from opentelemetry import context as otel_context +from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator + + +@pytest.mark.asyncio +async def test_streaming_context_restore_uses_previous_context() -> None: + prev_ctx = otel_context.get_current() + ctx = TraceContextTextMapPropagator().extract(carrier={}) + + otel_context.attach(ctx) + otel_context.attach(prev_ctx) + + assert otel_context.get_current() is prev_ctx diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_response_metadata.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_response_metadata.py new file mode 100644 index 000000000000..c2e3bea53287 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/server/test_response_metadata.py @@ -0,0 +1,140 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import json + +from azure.ai.agentserver.core.application import ( + PackageMetadata, + RuntimeMetadata, + get_current_app, + set_current_app, +) +from azure.ai.agentserver.core.models import Response as OpenAIResponse +from azure.ai.agentserver.core.models.projects import ResponseCreatedEvent, ResponseErrorEvent +from azure.ai.agentserver.core.server._response_metadata import ( + METADATA_KEY, + attach_foundry_metadata_to_response, + build_foundry_agents_metadata_headers, + try_attach_foundry_metadata_to_event, +) + + +def _set_test_app(): + previous = get_current_app() + set_current_app( + PackageMetadata( + name="test-package", + version="1.2.3", + ), + RuntimeMetadata( + python_version="3.11.0", + platform="test-platform", + host_name="test-host", + replica_name="test-replica", + ), + ) + return previous + + +def _expected_payload() -> dict[str, dict[str, str]]: + return { + "package": { + "name": "test-package", + "version": "1.2.3", + }, + "runtime": { + "python_version": "3.11.0", + "platform": "test-platform", + "host_name": "test-host", + "replica_name": "test-replica", + }, + } + + +def test_build_foundry_agents_metadata_headers_returns_json(): + previous = _set_test_app() + try: + headers = build_foundry_agents_metadata_headers() + payload = json.loads(headers["x-aml-foundry-agents-metadata"]) + assert payload == _expected_payload() + finally: + set_current_app(previous.package, previous.runtime) + + +def test_attach_foundry_metadata_to_response_sets_metadata_key(): + previous = _set_test_app() + try: + response = OpenAIResponse({"object": "response", "id": "resp", "metadata": {}}) + attach_foundry_metadata_to_response(response) + assert METADATA_KEY in response.metadata + assert json.loads(response.metadata[METADATA_KEY]) == _expected_payload() + finally: + set_current_app(previous.package, previous.runtime) + + +def test_try_attach_foundry_metadata_to_event_attaches_for_supported_events(): + previous = _set_test_app() + try: + response = OpenAIResponse({"object": "response", "id": "resp", "metadata": {}}) + event = ResponseCreatedEvent({"sequence_number": 0, "response": response}) + try_attach_foundry_metadata_to_event(event) + assert METADATA_KEY in response.metadata + + unsupported = ResponseErrorEvent( + {"sequence_number": 1, "code": "server_error", "message": "boom", "param": ""} + ) + try_attach_foundry_metadata_to_event(unsupported) + finally: + set_current_app(previous.package, previous.runtime) + + +def test_runtime_metadata_merge_overrides_non_empty_fields(): + base = RuntimeMetadata( + python_version="3.10.0", + platform="base-platform", + host_name="base-host", + replica_name="base-replica", + ) + override = RuntimeMetadata( + python_version="", + platform="override-platform", + host_name="", + replica_name="override-replica", + ) + + merged = base.merged_with(override) + + assert merged.python_version == "3.10.0" + assert merged.platform == "override-platform" + assert merged.host_name == "base-host" + assert merged.replica_name == "override-replica" + + +def test_runtime_metadata_resolve_falls_back_when_env_missing(monkeypatch): + monkeypatch.delenv("CONTAINER_APP_REVISION_FQDN", raising=False) + monkeypatch.delenv("CONTAINER_APP_REPLICA_NAME", raising=False) + runtime = RuntimeMetadata.resolve() + + assert runtime.host_name == "" + assert runtime.replica_name == "" + assert runtime.python_version + assert runtime.platform + + +def test_runtime_metadata_resolve_aca_env(monkeypatch): + monkeypatch.setenv("CONTAINER_APP_REVISION_FQDN", "aca-host") + monkeypatch.setenv("CONTAINER_APP_REPLICA_NAME", "aca-replica") + runtime = RuntimeMetadata.resolve() + + assert runtime.host_name == "aca-host" + assert runtime.replica_name == "aca-replica" + + +def test_runtime_metadata_resolve_explicit_overrides(monkeypatch): + monkeypatch.setenv("CONTAINER_APP_REVISION_FQDN", "aca-host") + monkeypatch.setenv("CONTAINER_APP_REPLICA_NAME", "aca-replica") + + runtime = RuntimeMetadata.resolve(host_name="override-host", replica_name="override-replica") + + assert runtime.host_name == "override-host" + assert runtime.replica_name == "override-replica" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/__init__.py new file mode 100644 index 000000000000..d02a9af6c5f6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/__init__.py @@ -0,0 +1,4 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/__init__.py new file mode 100644 index 000000000000..d02a9af6c5f6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/__init__.py @@ -0,0 +1,4 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/test_foundry_connected_tools.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/test_foundry_connected_tools.py new file mode 100644 index 000000000000..e7273f37a7e7 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/test_foundry_connected_tools.py @@ -0,0 +1,479 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryConnectedToolsOperations - testing only public methods.""" +import pytest +from unittest.mock import AsyncMock, MagicMock + +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryToolDetails, +) +from azure.ai.agentserver.core.tools.client.operations._foundry_connected_tools import ( + FoundryConnectedToolsOperations, +) +from azure.ai.agentserver.core.tools._exceptions import OAuthConsentRequiredError, ToolInvocationError + +from ...conftest import create_mock_http_response + + +class TestFoundryConnectedToolsOperationsListTools: + """Tests for FoundryConnectedToolsOperations.list_tools public method.""" + + @pytest.mark.asyncio + async def test_list_tools_with_empty_list_returns_empty(self): + """Test list_tools returns empty when tools list is empty.""" + mock_client = AsyncMock() + ops = FoundryConnectedToolsOperations(mock_client) + + result = await ops.list_tools([], None, "test-agent") + + assert result == [] + # Should not make any HTTP request + mock_client.send_request.assert_not_called() + + @pytest.mark.asyncio + async def test_list_tools_returns_tools_from_server( + self, + sample_connected_tool, + sample_user_info + ): + """Test list_tools returns tools from server response.""" + mock_client = AsyncMock() + + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "remote_tool", + "description": "A remote connected tool", + "parameters": { + "type": "object", + "properties": { + "input": {"type": "string"} + } + } + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await ops.list_tools([sample_connected_tool], sample_user_info, "test-agent")) + + assert len(result) == 1 + definition, details = result[0] + assert definition == sample_connected_tool + assert isinstance(details, FoundryToolDetails) + assert details.name == "remote_tool" + assert details.description == "A remote connected tool" + + @pytest.mark.asyncio + async def test_list_tools_without_user_info(self, sample_connected_tool): + """Test list_tools works without user info (local execution).""" + mock_client = AsyncMock() + + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "tool_no_user", + "description": "Tool without user", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await ops.list_tools([sample_connected_tool], None, "test-agent")) + + assert len(result) == 1 + assert result[0][1].name == "tool_no_user" + + @pytest.mark.asyncio + async def test_list_tools_with_multiple_connections(self, sample_user_info): + """Test list_tools with multiple connected tool definitions.""" + mock_client = AsyncMock() + + tool1 = FoundryConnectedTool(protocol="mcp", project_connection_id="conn-1") + tool2 = FoundryConnectedTool(protocol="a2a", project_connection_id="conn-2") + + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": "mcp", + "projectConnectionId": "conn-1" + }, + "manifest": [ + { + "name": "tool_from_conn1", + "description": "From connection 1", + "parameters": {"type": "object", "properties": {}} + } + ] + }, + { + "remoteServer": { + "protocol": "a2a", + "projectConnectionId": "conn-2" + }, + "manifest": [ + { + "name": "tool_from_conn2", + "description": "From connection 2", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await ops.list_tools([tool1, tool2], sample_user_info, "test-agent")) + + assert len(result) == 2 + names = {r[1].name for r in result} + assert names == {"tool_from_conn1", "tool_from_conn2"} + + @pytest.mark.asyncio + async def test_list_tools_filters_by_connection_id(self, sample_user_info): + """Test list_tools only returns tools from requested connections.""" + mock_client = AsyncMock() + + requested_tool = FoundryConnectedTool(protocol="mcp", project_connection_id="requested-conn") + + # Server returns tools from multiple connections, but we only requested one + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": "mcp", + "projectConnectionId": "requested-conn" + }, + "manifest": [ + { + "name": "requested_tool", + "description": "Requested", + "parameters": {"type": "object", "properties": {}} + } + ] + }, + { + "remoteServer": { + "protocol": "mcp", + "projectConnectionId": "unrequested-conn" + }, + "manifest": [ + { + "name": "unrequested_tool", + "description": "Not requested", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await ops.list_tools([requested_tool], sample_user_info, "test-agent")) + + # Should only return tools from requested connection + assert len(result) == 1 + assert result[0][1].name == "requested_tool" + + @pytest.mark.asyncio + async def test_list_tools_multiple_tools_per_connection( + self, + sample_connected_tool, + sample_user_info + ): + """Test list_tools returns multiple tools from same connection.""" + mock_client = AsyncMock() + + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "tool_one", + "description": "First tool", + "parameters": {"type": "object", "properties": {}} + }, + { + "name": "tool_two", + "description": "Second tool", + "parameters": {"type": "object", "properties": {}} + }, + { + "name": "tool_three", + "description": "Third tool", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = list(await ops.list_tools([sample_connected_tool], sample_user_info, "test-agent")) + + assert len(result) == 3 + names = {r[1].name for r in result} + assert names == {"tool_one", "tool_two", "tool_three"} + + @pytest.mark.asyncio + async def test_list_tools_raises_oauth_consent_error( + self, + sample_connected_tool, + sample_user_info + ): + """Test list_tools raises OAuthConsentRequiredError when consent needed.""" + mock_client = AsyncMock() + + response_data = { + "type": "OAuthConsentRequired", + "toolResult": { + "consentUrl": "https://login.microsoftonline.com/consent", + "message": "User consent is required to access this resource", + "projectConnectionId": sample_connected_tool.project_connection_id + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + + with pytest.raises(OAuthConsentRequiredError) as exc_info: + list(await ops.list_tools([sample_connected_tool], sample_user_info, "test-agent")) + + assert exc_info.value.consent_url == "https://login.microsoftonline.com/consent" + assert "consent" in exc_info.value.message.lower() + + +class TestFoundryConnectedToolsOperationsInvokeTool: + """Tests for FoundryConnectedToolsOperations.invoke_tool public method.""" + + @pytest.mark.asyncio + async def test_invoke_tool_returns_result_value( + self, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool returns the result value from server.""" + mock_client = AsyncMock() + + expected_result = {"data": "some output", "status": "success"} + response_data = {"toolResult": expected_result} + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = await ops.invoke_tool( + sample_resolved_connected_tool, + {"input": "test"}, + sample_user_info, + "test-agent" + ) + + assert result == expected_result + + @pytest.mark.asyncio + async def test_invoke_tool_without_user_info(self, sample_resolved_connected_tool): + """Test invoke_tool works without user info (local execution).""" + mock_client = AsyncMock() + + response_data = {"toolResult": "local result"} + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = await ops.invoke_tool( + sample_resolved_connected_tool, + {}, + None, # No user info + "test-agent" + ) + + assert result == "local result" + + @pytest.mark.asyncio + async def test_invoke_tool_with_complex_arguments( + self, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool handles complex nested arguments.""" + mock_client = AsyncMock() + + response_data = {"toolResult": "processed"} + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + complex_args = { + "query": "search term", + "filters": { + "date_range": {"start": "2025-01-01", "end": "2025-12-31"}, + "categories": ["A", "B", "C"] + }, + "limit": 50 + } + + result = await ops.invoke_tool( + sample_resolved_connected_tool, + complex_args, + sample_user_info, + "test-agent" + ) + + assert result == "processed" + mock_client.send_request.assert_called_once() + + @pytest.mark.asyncio + async def test_invoke_tool_returns_none_for_empty_result( + self, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool returns None when server returns no result.""" + mock_client = AsyncMock() + + # Server returns empty response (no toolResult) + response_data = { + "toolResult": None + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + result = await ops.invoke_tool( + sample_resolved_connected_tool, + {}, + sample_user_info, + "test-agent" + ) + + assert result is None + + @pytest.mark.asyncio + async def test_invoke_tool_with_mcp_tool_raises_error( + self, + sample_resolved_mcp_tool, + sample_user_info + ): + """Test invoke_tool raises ToolInvocationError for non-connected tool.""" + mock_client = AsyncMock() + ops = FoundryConnectedToolsOperations(mock_client) + + with pytest.raises(ToolInvocationError) as exc_info: + await ops.invoke_tool( + sample_resolved_mcp_tool, + {}, + sample_user_info, + "test-agent" + ) + + assert "not a Foundry connected tool" in str(exc_info.value) + # Should not make any HTTP request + mock_client.send_request.assert_not_called() + + @pytest.mark.asyncio + async def test_invoke_tool_raises_oauth_consent_error( + self, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool raises OAuthConsentRequiredError when consent needed.""" + mock_client = AsyncMock() + + response_data = { + "type": "OAuthConsentRequired", + "toolResult": { + "consentUrl": "https://login.microsoftonline.com/oauth/consent", + "message": "Please provide consent to continue", + "projectConnectionId": "test-connection-id" + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + + with pytest.raises(OAuthConsentRequiredError) as exc_info: + await ops.invoke_tool( + sample_resolved_connected_tool, + {"input": "test"}, + sample_user_info, + "test-agent" + ) + + assert "https://login.microsoftonline.com/oauth/consent" in exc_info.value.consent_url + + @pytest.mark.asyncio + async def test_invoke_tool_with_different_agent_names( + self, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool uses correct agent name in request.""" + mock_client = AsyncMock() + + response_data = {"toolResult": "result"} + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryConnectedToolsOperations(mock_client) + + # Invoke with different agent names + for agent_name in ["agent-1", "my-custom-agent", "production-agent"]: + await ops.invoke_tool( + sample_resolved_connected_tool, + {}, + sample_user_info, + agent_name + ) + + # Verify the correct path was used + call_args = mock_client.post.call_args + assert agent_name in call_args[0][0] + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/test_foundry_hosted_mcp_tools.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/test_foundry_hosted_mcp_tools.py new file mode 100644 index 000000000000..473b27cc8768 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/operations/test_foundry_hosted_mcp_tools.py @@ -0,0 +1,309 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryMcpToolsOperations - testing only public methods.""" +import pytest +from unittest.mock import AsyncMock, MagicMock + +from azure.ai.agentserver.core.tools.client._models import ( + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, +) +from azure.ai.agentserver.core.tools.client.operations._foundry_hosted_mcp_tools import ( + FoundryMcpToolsOperations, +) +from azure.ai.agentserver.core.tools._exceptions import ToolInvocationError + +from ...conftest import create_mock_http_response + + +class TestFoundryMcpToolsOperationsListTools: + """Tests for FoundryMcpToolsOperations.list_tools public method.""" + + @pytest.mark.asyncio + async def test_list_tools_with_empty_list_returns_empty(self): + """Test list_tools returns empty when allowed_tools is empty.""" + mock_client = AsyncMock() + ops = FoundryMcpToolsOperations(mock_client) + + result = await ops.list_tools([]) + + assert result == [] + # Should not make any HTTP request + mock_client.send_request.assert_not_called() + + @pytest.mark.asyncio + async def test_list_tools_returns_matching_tools(self, sample_hosted_mcp_tool): + """Test list_tools returns tools that match the allowed list.""" + mock_client = AsyncMock() + + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Test MCP tool", + "inputSchema": { + "type": "object", + "properties": { + "query": {"type": "string", "description": "Search query"} + } + } + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = list(await ops.list_tools([sample_hosted_mcp_tool])) + + assert len(result) == 1 + definition, details = result[0] + assert definition == sample_hosted_mcp_tool + assert isinstance(details, FoundryToolDetails) + assert details.name == sample_hosted_mcp_tool.name + assert details.description == "Test MCP tool" + + @pytest.mark.asyncio + async def test_list_tools_filters_out_non_allowed_tools(self, sample_hosted_mcp_tool): + """Test list_tools only returns tools in the allowed list.""" + mock_client = AsyncMock() + + # Server returns multiple tools but only one is allowed + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Allowed tool", + "inputSchema": {"type": "object", "properties": {}} + }, + { + "name": "other_tool_not_in_list", + "description": "Not allowed tool", + "inputSchema": {"type": "object", "properties": {}} + }, + { + "name": "another_unlisted_tool", + "description": "Also not allowed", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = list(await ops.list_tools([sample_hosted_mcp_tool])) + + assert len(result) == 1 + assert result[0][1].name == sample_hosted_mcp_tool.name + + @pytest.mark.asyncio + async def test_list_tools_with_multiple_allowed_tools(self): + """Test list_tools with multiple tools in allowed list.""" + mock_client = AsyncMock() + + tool1 = FoundryHostedMcpTool(name="tool_one") + tool2 = FoundryHostedMcpTool(name="tool_two") + + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": "tool_one", + "description": "First tool", + "inputSchema": {"type": "object", "properties": {}} + }, + { + "name": "tool_two", + "description": "Second tool", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = list(await ops.list_tools([tool1, tool2])) + + assert len(result) == 2 + names = {r[1].name for r in result} + assert names == {"tool_one", "tool_two"} + + @pytest.mark.asyncio + async def test_list_tools_preserves_tool_metadata(self): + """Test list_tools preserves metadata from server response.""" + mock_client = AsyncMock() + + tool = FoundryHostedMcpTool(name="tool_with_meta") + + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": "tool_with_meta", + "description": "Tool with metadata", + "inputSchema": { + "type": "object", + "properties": { + "param1": {"type": "string"} + }, + "required": ["param1"] + }, + "_meta": { + "type": "object", + "properties": { + "model": {"type": "string"} + } + } + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = list(await ops.list_tools([tool])) + + assert len(result) == 1 + details = result[0][1] + assert details.metadata is not None + + +class TestFoundryMcpToolsOperationsInvokeTool: + """Tests for FoundryMcpToolsOperations.invoke_tool public method.""" + + @pytest.mark.asyncio + async def test_invoke_tool_returns_server_response(self, sample_resolved_mcp_tool): + """Test invoke_tool returns the response from server.""" + mock_client = AsyncMock() + + expected_response = { + "jsonrpc": "2.0", + "id": 2, + "result": { + "content": [{"type": "text", "text": "Hello World"}] + } + } + mock_response = create_mock_http_response(200, expected_response) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = await ops.invoke_tool(sample_resolved_mcp_tool, {"query": "test"}) + + assert result == expected_response + + @pytest.mark.asyncio + async def test_invoke_tool_with_empty_arguments(self, sample_resolved_mcp_tool): + """Test invoke_tool works with empty arguments.""" + mock_client = AsyncMock() + + expected_response = {"result": "success"} + mock_response = create_mock_http_response(200, expected_response) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = await ops.invoke_tool(sample_resolved_mcp_tool, {}) + + assert result == expected_response + + @pytest.mark.asyncio + async def test_invoke_tool_with_complex_arguments(self, sample_resolved_mcp_tool): + """Test invoke_tool handles complex nested arguments.""" + mock_client = AsyncMock() + + mock_response = create_mock_http_response(200, {"result": "ok"}) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + complex_args = { + "text": "sample text", + "options": { + "temperature": 0.7, + "max_tokens": 100 + }, + "tags": ["tag1", "tag2"] + } + + result = await ops.invoke_tool(sample_resolved_mcp_tool, complex_args) + + assert result == {"result": "ok"} + mock_client.send_request.assert_called_once() + + @pytest.mark.asyncio + async def test_invoke_tool_with_connected_tool_raises_error( + self, + sample_resolved_connected_tool + ): + """Test invoke_tool raises ToolInvocationError for non-MCP tool.""" + mock_client = AsyncMock() + ops = FoundryMcpToolsOperations(mock_client) + + with pytest.raises(ToolInvocationError) as exc_info: + await ops.invoke_tool(sample_resolved_connected_tool, {}) + + assert "not a Foundry-hosted MCP tool" in str(exc_info.value) + # Should not make any HTTP request + mock_client.send_request.assert_not_called() + + @pytest.mark.asyncio + async def test_invoke_tool_with_configuration_and_metadata(self): + """Test invoke_tool handles tool with configuration and metadata.""" + mock_client = AsyncMock() + + # Create tool with configuration + tool_def = FoundryHostedMcpTool( + name="image_generation", + configuration={"model_deployment_name": "dall-e-3"} + ) + + # Create tool details with metadata schema + meta_schema = SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "model": SchemaProperty(type=SchemaType.STRING) + } + ) + details = FoundryToolDetails( + name="image_generation", + description="Generate images", + input_schema=SchemaDefinition(type=SchemaType.OBJECT, properties={}), + metadata=meta_schema + ) + resolved_tool = ResolvedFoundryTool(definition=tool_def, details=details) + + mock_response = create_mock_http_response(200, {"result": "image_url"}) + mock_client.send_request.return_value = mock_response + mock_client.post.return_value = MagicMock() + + ops = FoundryMcpToolsOperations(mock_client) + result = await ops.invoke_tool(resolved_tool, {"prompt": "a cat"}) + + assert result == {"result": "image_url"} + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/test_client.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/test_client.py new file mode 100644 index 000000000000..de60f545e089 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/test_client.py @@ -0,0 +1,485 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryToolClient - testing only public methods.""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from azure.ai.agentserver.core.tools.client._client import FoundryToolClient +from azure.ai.agentserver.core.tools.client._models import ( + FoundryToolDetails, + FoundryToolSource, + ResolvedFoundryTool, +) +from azure.ai.agentserver.core.tools._exceptions import ToolInvocationError + +from ..conftest import create_mock_http_response + + +class TestFoundryToolClientInit: + """Tests for FoundryToolClient.__init__ public method.""" + + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + def test_init_with_valid_endpoint_and_credential(self, mock_pipeline_client_class, mock_credential): + """Test client can be initialized with valid endpoint and credential.""" + endpoint = "https://fake-project-endpoint.site" + + client = FoundryToolClient(endpoint, mock_credential) + + # Verify client was created with correct base_url + call_kwargs = mock_pipeline_client_class.call_args + assert call_kwargs[1]["base_url"] == endpoint + assert client is not None + + +class TestFoundryToolClientListTools: + """Tests for FoundryToolClient.list_tools public method.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_empty_collection_returns_empty_list( + self, + mock_pipeline_client_class, + mock_credential + ): + """Test list_tools returns empty list when given empty collection.""" + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + + result = await client.list_tools([], agent_name="test-agent") + + assert result == [] + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_with_single_mcp_tool_returns_resolved_tools( + self, + mock_pipeline_client_class, + mock_credential, + sample_hosted_mcp_tool + ): + """Test list_tools with a single MCP tool returns resolved tools.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # Mock HTTP response for MCP tools listing + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Test MCP tool description", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + result = await client.list_tools([sample_hosted_mcp_tool], agent_name="test-agent") + + assert len(result) == 1 + assert isinstance(result[0], ResolvedFoundryTool) + assert result[0].name == sample_hosted_mcp_tool.name + assert result[0].source == FoundryToolSource.HOSTED_MCP + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_with_single_connected_tool_returns_resolved_tools( + self, + mock_pipeline_client_class, + mock_credential, + sample_connected_tool, + sample_user_info + ): + """Test list_tools with a single connected tool returns resolved tools.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # Mock HTTP response for connected tools listing + response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "connected_test_tool", + "description": "Test connected tool", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + mock_response = create_mock_http_response(200, response_data) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + result = await client.list_tools( + [sample_connected_tool], + agent_name="test-agent", + user=sample_user_info + ) + + assert len(result) == 1 + assert isinstance(result[0], ResolvedFoundryTool) + assert result[0].name == "connected_test_tool" + assert result[0].source == FoundryToolSource.CONNECTED + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_with_mixed_tool_types_returns_all_resolved( + self, + mock_pipeline_client_class, + mock_credential, + sample_hosted_mcp_tool, + sample_connected_tool, + sample_user_info + ): + """Test list_tools with both MCP and connected tools returns all resolved tools.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # We need to return different responses based on the request + mcp_response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "MCP tool", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + connected_response_data = { + "tools": [ + { + "remoteServer": { + "protocol": sample_connected_tool.protocol, + "projectConnectionId": sample_connected_tool.project_connection_id + }, + "manifest": [ + { + "name": "connected_tool", + "description": "Connected tool", + "parameters": {"type": "object", "properties": {}} + } + ] + } + ] + } + + # Mock to return different responses for different requests + mock_client_instance.send_request.side_effect = [ + create_mock_http_response(200, mcp_response_data), + create_mock_http_response(200, connected_response_data) + ] + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + result = await client.list_tools( + [sample_hosted_mcp_tool, sample_connected_tool], + agent_name="test-agent", + user=sample_user_info + ) + + assert len(result) == 2 + sources = {tool.source for tool in result} + assert FoundryToolSource.HOSTED_MCP in sources + assert FoundryToolSource.CONNECTED in sources + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_filters_unlisted_mcp_tools( + self, + mock_pipeline_client_class, + mock_credential, + sample_hosted_mcp_tool + ): + """Test list_tools only returns tools that are in the allowed list.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # Server returns more tools than requested + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Requested tool", + "inputSchema": {"type": "object", "properties": {}} + }, + { + "name": "unrequested_tool", + "description": "This tool was not requested", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + result = await client.list_tools([sample_hosted_mcp_tool], agent_name="test-agent") + + # Should only return the requested tool + assert len(result) == 1 + assert result[0].name == sample_hosted_mcp_tool.name + + +class TestFoundryToolClientListToolsDetails: + """Tests for FoundryToolClient.list_tools_details public method.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_details_returns_mapping_structure( + self, + mock_pipeline_client_class, + mock_credential, + sample_hosted_mcp_tool + ): + """Test list_tools_details returns correct mapping structure.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Test tool", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + result = await client.list_tools_details([sample_hosted_mcp_tool], agent_name="test-agent") + + assert isinstance(result, dict) + assert sample_hosted_mcp_tool.id in result + assert len(result[sample_hosted_mcp_tool.id]) == 1 + assert isinstance(result[sample_hosted_mcp_tool.id][0], FoundryToolDetails) + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_list_tools_details_groups_multiple_tools_by_definition( + self, + mock_pipeline_client_class, + mock_credential, + sample_hosted_mcp_tool + ): + """Test list_tools_details groups multiple tools from same source by definition ID.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # Server returns multiple tools for the same MCP source + response_data = { + "jsonrpc": "2.0", + "id": 1, + "result": { + "tools": [ + { + "name": sample_hosted_mcp_tool.name, + "description": "Tool variant 1", + "inputSchema": {"type": "object", "properties": {}} + } + ] + } + } + mock_response = create_mock_http_response(200, response_data) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + result = await client.list_tools_details([sample_hosted_mcp_tool], agent_name="test-agent") + + # All tools should be grouped under the same definition ID + assert sample_hosted_mcp_tool.id in result + + +class TestFoundryToolClientInvokeTool: + """Tests for FoundryToolClient.invoke_tool public method.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_invoke_mcp_tool_returns_result( + self, + mock_pipeline_client_class, + mock_credential, + sample_resolved_mcp_tool + ): + """Test invoke_tool with MCP tool returns the invocation result.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + expected_result = {"result": {"content": [{"text": "Hello World"}]}} + mock_response = create_mock_http_response(200, expected_result) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + result = await client.invoke_tool( + sample_resolved_mcp_tool, + arguments={"input": "test"}, + agent_name="test-agent" + ) + + assert result == expected_result + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_invoke_connected_tool_returns_result( + self, + mock_pipeline_client_class, + mock_credential, + sample_resolved_connected_tool, + sample_user_info + ): + """Test invoke_tool with connected tool returns the invocation result.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + expected_value = {"output": "Connected tool result"} + response_data = {"toolResult": expected_value} + mock_response = create_mock_http_response(200, response_data) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + result = await client.invoke_tool( + sample_resolved_connected_tool, + arguments={"input": "test"}, + agent_name="test-agent", + user=sample_user_info + ) + + assert result == expected_value + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_invoke_tool_with_complex_arguments( + self, + mock_pipeline_client_class, + mock_credential, + sample_resolved_mcp_tool + ): + """Test invoke_tool correctly passes complex arguments.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + mock_response = create_mock_http_response(200, {"result": "success"}) + mock_client_instance.send_request.return_value = mock_response + mock_client_instance.post.return_value = MagicMock() + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + complex_args = { + "string_param": "value", + "number_param": 42, + "bool_param": True, + "list_param": [1, 2, 3], + "nested_param": {"key": "value"} + } + + result = await client.invoke_tool( + sample_resolved_mcp_tool, + arguments=complex_args, + agent_name="test-agent" + ) + + # Verify request was made + mock_client_instance.send_request.assert_called_once() + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_invoke_tool_with_unsupported_source_raises_error( + self, + mock_pipeline_client_class, + mock_credential, + sample_tool_details + ): + """Test invoke_tool raises ToolInvocationError for unsupported tool source.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + # Create a mock tool with unsupported source + mock_definition = MagicMock() + mock_definition.source = "unsupported_source" + mock_tool = MagicMock(spec=ResolvedFoundryTool) + mock_tool.definition = mock_definition + mock_tool.source = "unsupported_source" + mock_tool.details = sample_tool_details + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + + with pytest.raises(ToolInvocationError) as exc_info: + await client.invoke_tool( + mock_tool, + arguments={"input": "test"}, + agent_name="test-agent" + ) + + assert "Unsupported tool source" in str(exc_info.value) + + +class TestFoundryToolClientClose: + """Tests for FoundryToolClient.close public method.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_close_closes_underlying_client( + self, + mock_pipeline_client_class, + mock_credential + ): + """Test close() properly closes the underlying HTTP client.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + client = FoundryToolClient("https://fake-project-endpoint.site", mock_credential) + await client.close() + + mock_client_instance.close.assert_called_once() + + +class TestFoundryToolClientContextManager: + """Tests for FoundryToolClient async context manager protocol.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.client._client.AsyncPipelineClient") + async def test_async_context_manager_enters_and_exits( + self, + mock_pipeline_client_class, + mock_credential + ): + """Test client can be used as async context manager.""" + mock_client_instance = AsyncMock() + mock_pipeline_client_class.return_value = mock_client_instance + + async with FoundryToolClient("https://fake-project-endpoint.site", mock_credential) as client: + assert client is not None + mock_client_instance.__aenter__.assert_called_once() + + mock_client_instance.__aexit__.assert_called_once() + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/test_configuration.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/test_configuration.py new file mode 100644 index 000000000000..2f3c2710a3fc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/client/test_configuration.py @@ -0,0 +1,25 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryToolClientConfiguration.""" + +from azure.core.pipeline import policies + +from azure.ai.agentserver.core.tools.client._configuration import FoundryToolClientConfiguration + + +class TestFoundryToolClientConfiguration: + """Tests for FoundryToolClientConfiguration class.""" + + def test_init_creates_all_required_policies(self, mock_credential): + """Test that initialization creates all required pipeline policies.""" + config = FoundryToolClientConfiguration(mock_credential) + + assert isinstance(config.retry_policy, policies.AsyncRetryPolicy) + assert isinstance(config.logging_policy, policies.NetworkTraceLoggingPolicy) + assert isinstance(config.request_id_policy, policies.RequestIdPolicy) + assert isinstance(config.http_logging_policy, policies.HttpLoggingPolicy) + assert isinstance(config.user_agent_policy, policies.UserAgentPolicy) + assert isinstance(config.authentication_policy, policies.AsyncBearerTokenCredentialPolicy) + assert isinstance(config.redirect_policy, policies.AsyncRedirectPolicy) + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/conftest.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/conftest.py new file mode 100644 index 000000000000..8849ce8aafbf --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/conftest.py @@ -0,0 +1,127 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Shared fixtures for tools unit tests.""" +import json +from typing import Any, Dict, Optional +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, + UserInfo, +) + + +@pytest.fixture +def mock_credential(): + """Create a mock async token credential.""" + credential = AsyncMock() + credential.get_token = AsyncMock(return_value=MagicMock(token="test-token")) + return credential + + +@pytest.fixture +def sample_user_info(): + """Create a sample UserInfo instance.""" + return UserInfo(object_id="test-object-id", tenant_id="test-tenant-id") + + +@pytest.fixture +def sample_hosted_mcp_tool(): + """Create a sample FoundryHostedMcpTool.""" + return FoundryHostedMcpTool( + name="test_mcp_tool", + configuration={"model_deployment_name": "gpt-4"} + ) + + +@pytest.fixture +def sample_connected_tool(): + """Create a sample FoundryConnectedTool.""" + return FoundryConnectedTool( + protocol="mcp", + project_connection_id="test-connection-id" + ) + + +@pytest.fixture +def sample_schema_definition(): + """Create a sample SchemaDefinition.""" + return SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input parameter") + }, + required={"input"} + ) + + +@pytest.fixture +def sample_tool_details(sample_schema_definition): + """Create a sample FoundryToolDetails.""" + return FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=sample_schema_definition + ) + + +@pytest.fixture +def sample_resolved_mcp_tool(sample_hosted_mcp_tool, sample_tool_details): + """Create a sample ResolvedFoundryTool for MCP.""" + return ResolvedFoundryTool( + definition=sample_hosted_mcp_tool, + details=sample_tool_details + ) + + +@pytest.fixture +def sample_resolved_connected_tool(sample_connected_tool, sample_tool_details): + """Create a sample ResolvedFoundryTool for connected tools.""" + return ResolvedFoundryTool( + definition=sample_connected_tool, + details=sample_tool_details + ) + + +def create_mock_http_response( + status_code: int = 200, + json_data: Optional[Dict[str, Any]] = None +) -> AsyncMock: + """Create a mock HTTP response that simulates real Azure SDK response behavior. + + This mock matches the behavior expected by BaseOperations._extract_response_json, + where response.text() and response.body() are synchronous methods that return + the actual string/bytes values directly. + + :param status_code: HTTP status code. + :param json_data: JSON data to return. + :return: Mock response object. + """ + response = AsyncMock() + response.status_code = status_code + + if json_data is not None: + json_str = json.dumps(json_data) + json_bytes = json_str.encode("utf-8") + # text() and body() are synchronous methods in AsyncHttpResponse + # They must be MagicMock (not AsyncMock) to return values directly when called + response.text = MagicMock(return_value=json_str) + response.body = MagicMock(return_value=json_bytes) + else: + response.text = MagicMock(return_value="") + response.body = MagicMock(return_value=b"") + + # Support async context manager + response.__aenter__ = AsyncMock(return_value=response) + response.__aexit__ = AsyncMock(return_value=None) + + return response diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/__init__.py new file mode 100644 index 000000000000..964fac9d8a55 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/__init__.py @@ -0,0 +1,4 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Runtime unit tests package.""" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/conftest.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/conftest.py new file mode 100644 index 000000000000..52a371bdc958 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/conftest.py @@ -0,0 +1,39 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Shared fixtures for runtime unit tests. + +Common fixtures are inherited from the parent conftest.py automatically by pytest. +""" +from unittest.mock import AsyncMock + +import pytest + + +@pytest.fixture +def mock_foundry_tool_client(): + """Create a mock FoundryToolClient.""" + client = AsyncMock() + client.list_tools = AsyncMock(return_value=[]) + client.list_tools_details = AsyncMock(return_value={}) + client.invoke_tool = AsyncMock(return_value={"result": "success"}) + client.__aenter__ = AsyncMock(return_value=client) + client.__aexit__ = AsyncMock(return_value=None) + return client + + +@pytest.fixture +def mock_user_provider(sample_user_info): + """Create a mock UserProvider.""" + provider = AsyncMock() + provider.get_user = AsyncMock(return_value=sample_user_info) + return provider + + +@pytest.fixture +def mock_user_provider_none(): + """Create a mock UserProvider that returns None.""" + provider = AsyncMock() + provider.get_user = AsyncMock(return_value=None) + return provider + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_catalog.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_catalog.py new file mode 100644 index 000000000000..a7013a7537e1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_catalog.py @@ -0,0 +1,350 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _catalog.py - testing public methods of DefaultFoundryToolCatalog.""" +import asyncio +import pytest +from unittest.mock import AsyncMock + +from azure.ai.agentserver.core.tools import ensure_foundry_tool +from azure.ai.agentserver.core.tools.runtime._catalog import ( + DefaultFoundryToolCatalog, +) +from azure.ai.agentserver.core.tools.client._models import ( + FoundryToolDetails, + ResolvedFoundryTool, + UserInfo, +) + + +class TestFoundryToolCatalogGet: + """Tests for FoundryToolCatalog.get method.""" + + @pytest.mark.asyncio + async def test_get_returns_resolved_tool_when_found( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_tool_details, + sample_user_info + ): + """Test get returns a resolved tool when the tool is found.""" + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={sample_hosted_mcp_tool.id: [sample_tool_details]} + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.get(sample_hosted_mcp_tool) + + assert result is not None + assert isinstance(result, ResolvedFoundryTool) + assert result.details == sample_tool_details + + @pytest.mark.asyncio + async def test_get_returns_none_when_not_found( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool + ): + """Test get returns None when the tool is not found.""" + mock_foundry_tool_client.list_tools_details = AsyncMock(return_value={}) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.get(sample_hosted_mcp_tool) + + assert result is None + + +class TestDefaultFoundryToolCatalogList: + """Tests for DefaultFoundryToolCatalog.list method.""" + + @pytest.mark.asyncio + async def test_list_returns_empty_list_when_no_tools( + self, + mock_foundry_tool_client, + mock_user_provider + ): + """Test list returns empty list when no tools are provided.""" + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.list([]) + + assert result == [] + + @pytest.mark.asyncio + async def test_list_returns_resolved_tools( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_tool_details + ): + """Test list returns resolved tools.""" + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={sample_hosted_mcp_tool.id: [sample_tool_details]} + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.list([sample_hosted_mcp_tool]) + + assert len(result) == 1 + assert isinstance(result[0], ResolvedFoundryTool) + assert result[0].definition == sample_hosted_mcp_tool + assert result[0].details == sample_tool_details + + @pytest.mark.asyncio + async def test_list_multiple_tools_with_multiple_details( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_connected_tool, + sample_schema_definition + ): + """Test list returns all resolved tools when tools have multiple details.""" + details1 = FoundryToolDetails( + name="tool1", + description="First tool", + input_schema=sample_schema_definition + ) + details2 = FoundryToolDetails( + name="tool2", + description="Second tool", + input_schema=sample_schema_definition + ) + + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={ + sample_hosted_mcp_tool.id: [details1], + sample_connected_tool.id: [details2] + } + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.list([sample_hosted_mcp_tool, sample_connected_tool]) + + assert len(result) == 2 + names = {r.details.name for r in result} + assert names == {"tool1", "tool2"} + + @pytest.mark.asyncio + async def test_list_caches_results_for_hosted_mcp_tools( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_tool_details + ): + """Test that list caches results for hosted MCP tools.""" + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={sample_hosted_mcp_tool.id: [sample_tool_details]} + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + # First call + result1 = await catalog.list([sample_hosted_mcp_tool]) + # Second call should use cache + result2 = await catalog.list([sample_hosted_mcp_tool]) + + # Client should only be called once + assert mock_foundry_tool_client.list_tools_details.call_count == 1 + assert len(result1) == len(result2) == 1 + + @pytest.mark.asyncio + async def test_list_with_facade_dict( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_tool_details + ): + """Test list works with facade dictionaries.""" + facade = {"type": "custom_tool", "config": "value"} + expected_id = ensure_foundry_tool(facade).id + + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={expected_id: [sample_tool_details]} + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.list([facade]) + + assert len(result) == 1 + assert result[0].details == sample_tool_details + + @pytest.mark.asyncio + async def test_list_returns_multiple_details_per_tool( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_schema_definition + ): + """Test list returns multiple resolved tools when a tool has multiple details.""" + details1 = FoundryToolDetails( + name="function1", + description="First function", + input_schema=sample_schema_definition + ) + details2 = FoundryToolDetails( + name="function2", + description="Second function", + input_schema=sample_schema_definition + ) + + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={sample_hosted_mcp_tool.id: [details1, details2]} + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await catalog.list([sample_hosted_mcp_tool]) + + assert len(result) == 2 + names = {r.details.name for r in result} + assert names == {"function1", "function2"} + + @pytest.mark.asyncio + async def test_list_handles_exception_from_client( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool + ): + """Test list propagates exception from client and clears cache.""" + mock_foundry_tool_client.list_tools_details = AsyncMock( + side_effect=RuntimeError("Network error") + ) + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + with pytest.raises(RuntimeError, match="Network error"): + await catalog.list([sample_hosted_mcp_tool]) + + @pytest.mark.asyncio + async def test_list_connected_tool_cache_key_includes_user( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_connected_tool, + sample_tool_details, + sample_user_info + ): + """Test that connected tool cache key includes user info.""" + mock_foundry_tool_client.list_tools_details = AsyncMock( + return_value={sample_connected_tool.id: [sample_tool_details]} + ) + + # Create a new user provider returning a different user + other_user = UserInfo(object_id="other-oid", tenant_id="other-tid") + mock_user_provider2 = AsyncMock() + mock_user_provider2.get_user = AsyncMock(return_value=other_user) + + catalog1 = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + catalog2 = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider2, + agent_name="test-agent" + ) + + # Both catalogs should be able to list tools + result1 = await catalog1.list([sample_connected_tool]) + result2 = await catalog2.list([sample_connected_tool]) + + assert len(result1) == 1 + assert len(result2) == 1 + + +class TestCachedFoundryToolCatalogConcurrency: + """Tests for CachedFoundryToolCatalog concurrency handling.""" + + @pytest.mark.asyncio + async def test_concurrent_requests_share_single_fetch( + self, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_tool_details + ): + """Test that concurrent requests for the same tool share a single fetch.""" + call_count = 0 + fetch_event = asyncio.Event() + + async def slow_fetch(*args, **kwargs): + nonlocal call_count + call_count += 1 + await fetch_event.wait() + return {sample_hosted_mcp_tool.id: [sample_tool_details]} + + mock_foundry_tool_client.list_tools_details = slow_fetch + + catalog = DefaultFoundryToolCatalog( + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + # Start two concurrent requests + task1 = asyncio.create_task(catalog.list([sample_hosted_mcp_tool])) + task2 = asyncio.create_task(catalog.list([sample_hosted_mcp_tool])) + + # Allow tasks to start + await asyncio.sleep(0.01) + + # Release the fetch + fetch_event.set() + + results = await asyncio.gather(task1, task2) + + # Both should get results, but fetch should only be called once + assert len(results[0]) == 1 + assert len(results[1]) == 1 + assert call_count == 1 diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_facade.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_facade.py new file mode 100644 index 000000000000..c5377dc339a4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_facade.py @@ -0,0 +1,180 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _facade.py - testing public function ensure_foundry_tool.""" +import pytest + +from azure.ai.agentserver.core.tools.runtime._facade import ensure_foundry_tool +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryToolProtocol, + FoundryToolSource, +) +from azure.ai.agentserver.core.tools._exceptions import InvalidToolFacadeError + + +class TestEnsureFoundryTool: + """Tests for ensure_foundry_tool public function.""" + + def test_returns_same_instance_when_given_foundry_tool(self, sample_hosted_mcp_tool): + """Test that passing a FoundryTool returns the same instance.""" + result = ensure_foundry_tool(sample_hosted_mcp_tool) + + assert result is sample_hosted_mcp_tool + + def test_returns_same_instance_for_connected_tool(self, sample_connected_tool): + """Test that passing a FoundryConnectedTool returns the same instance.""" + result = ensure_foundry_tool(sample_connected_tool) + + assert result is sample_connected_tool + + def test_converts_facade_with_mcp_protocol_to_connected_tool(self): + """Test that a facade with 'mcp' protocol is converted to FoundryConnectedTool.""" + facade = { + "type": "mcp", + "project_connection_id": "my-connection" + } + + result = ensure_foundry_tool(facade) + + assert isinstance(result, FoundryConnectedTool) + assert result.protocol == FoundryToolProtocol.MCP + assert result.project_connection_id == "my-connection" + assert result.source == FoundryToolSource.CONNECTED + + def test_converts_facade_with_a2a_protocol_to_connected_tool(self): + """Test that a facade with 'a2a' protocol is converted to FoundryConnectedTool.""" + facade = { + "type": "a2a", + "project_connection_id": "my-a2a-connection" + } + + result = ensure_foundry_tool(facade) + + assert isinstance(result, FoundryConnectedTool) + assert result.protocol == FoundryToolProtocol.A2A + assert result.project_connection_id == "my-a2a-connection" + + def test_converts_facade_with_unknown_type_to_hosted_mcp_tool(self): + """Test that a facade with unknown type is converted to FoundryHostedMcpTool.""" + facade = { + "type": "my_custom_tool", + "some_config": "value123", + "another_config": True + } + + result = ensure_foundry_tool(facade) + + assert isinstance(result, FoundryHostedMcpTool) + assert result.name == "my_custom_tool" + assert result.configuration == {"some_config": "value123", "another_config": True} + assert result.source == FoundryToolSource.HOSTED_MCP + + def test_raises_error_when_type_is_missing(self): + """Test that InvalidToolFacadeError is raised when 'type' is missing.""" + facade = {"project_connection_id": "my-connection"} + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "type" in str(exc_info.value).lower() + + def test_raises_error_when_type_is_empty_string(self): + """Test that InvalidToolFacadeError is raised when 'type' is empty string.""" + facade = {"type": "", "project_connection_id": "my-connection"} + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "type" in str(exc_info.value).lower() + + def test_raises_error_when_type_is_not_string(self): + """Test that InvalidToolFacadeError is raised when 'type' is not a string.""" + facade = {"type": 123, "project_connection_id": "my-connection"} + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "type" in str(exc_info.value).lower() + + def test_raises_error_when_mcp_protocol_missing_connection_id(self): + """Test that InvalidToolFacadeError is raised when mcp protocol is missing project_connection_id.""" + facade = {"type": "mcp"} + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "project_connection_id" in str(exc_info.value) + + def test_raises_error_when_a2a_protocol_has_empty_connection_id(self): + """Test that InvalidToolFacadeError is raised when a2a protocol has empty project_connection_id.""" + facade = {"type": "a2a", "project_connection_id": ""} + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "project_connection_id" in str(exc_info.value) + + def test_parses_resource_id_format_connection_id(self): + """Test that resource ID format project_connection_id is parsed correctly.""" + resource_id = ( + "/subscriptions/sub-123/resourceGroups/rg-test/providers/" + "Microsoft.CognitiveServices/accounts/acc-test/projects/proj-test/connections/my-conn-name" + ) + facade = { + "type": "mcp", + "project_connection_id": resource_id + } + + result = ensure_foundry_tool(facade) + + assert isinstance(result, FoundryConnectedTool) + assert result.project_connection_id == "my-conn-name" + + def test_raises_error_for_invalid_resource_id_format(self): + """Test that InvalidToolFacadeError is raised for invalid resource ID format.""" + invalid_resource_id = "/subscriptions/sub-123/invalid/path" + facade = { + "type": "mcp", + "project_connection_id": invalid_resource_id + } + + with pytest.raises(InvalidToolFacadeError) as exc_info: + ensure_foundry_tool(facade) + + assert "Invalid resource ID format" in str(exc_info.value) + + def test_uses_simple_connection_name_as_is(self): + """Test that simple connection name is used as-is without parsing.""" + facade = { + "type": "mcp", + "project_connection_id": "simple-connection-name" + } + + result = ensure_foundry_tool(facade) + + assert isinstance(result, FoundryConnectedTool) + assert result.project_connection_id == "simple-connection-name" + + def test_original_facade_not_modified(self): + """Test that the original facade dictionary is not modified.""" + facade = { + "type": "my_tool", + "config_key": "config_value" + } + original_facade = facade.copy() + + ensure_foundry_tool(facade) + + assert facade == original_facade + + def test_hosted_mcp_tool_with_no_extra_configuration(self): + """Test that hosted MCP tool works with no extra configuration.""" + facade = {"type": "simple_tool"} + + result = ensure_foundry_tool(facade) + + assert isinstance(result, FoundryHostedMcpTool) + assert result.name == "simple_tool" + assert result.configuration == {} diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_invoker.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_invoker.py new file mode 100644 index 000000000000..b2a222c09d6e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_invoker.py @@ -0,0 +1,198 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _invoker.py - testing public methods of DefaultFoundryToolInvoker.""" +import pytest +from unittest.mock import AsyncMock + +from azure.ai.agentserver.core.tools.runtime._invoker import DefaultFoundryToolInvoker + + +class TestDefaultFoundryToolInvokerResolvedTool: + """Tests for DefaultFoundryToolInvoker.resolved_tool property.""" + + def test_resolved_tool_returns_tool_passed_at_init( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider + ): + """Test resolved_tool property returns the tool passed during initialization.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + assert invoker.resolved_tool is sample_resolved_mcp_tool + + def test_resolved_tool_returns_connected_tool( + self, + sample_resolved_connected_tool, + mock_foundry_tool_client, + mock_user_provider + ): + """Test resolved_tool property returns connected tool.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_connected_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + assert invoker.resolved_tool is sample_resolved_connected_tool + + +class TestDefaultFoundryToolInvokerInvoke: + """Tests for DefaultFoundryToolInvoker.invoke method.""" + + @pytest.mark.asyncio + async def test_invoke_calls_client_with_correct_arguments( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider, + sample_user_info + ): + """Test invoke calls client.invoke_tool with correct arguments.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + arguments = {"input": "test value", "count": 5} + + await invoker.invoke(arguments) + + mock_foundry_tool_client.invoke_tool.assert_called_once_with( + sample_resolved_mcp_tool, + arguments, + "test-agent", + sample_user_info + ) + + @pytest.mark.asyncio + async def test_invoke_returns_result_from_client( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider + ): + """Test invoke returns the result from client.invoke_tool.""" + expected_result = {"output": "test result", "status": "completed"} + mock_foundry_tool_client.invoke_tool = AsyncMock(return_value=expected_result) + + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + result = await invoker.invoke({"input": "test"}) + + assert result == expected_result + + @pytest.mark.asyncio + async def test_invoke_with_empty_arguments( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider, + sample_user_info + ): + """Test invoke works with empty arguments dictionary.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + await invoker.invoke({}) + + mock_foundry_tool_client.invoke_tool.assert_called_once_with( + sample_resolved_mcp_tool, + {}, + "test-agent", + sample_user_info + ) + + @pytest.mark.asyncio + async def test_invoke_with_none_user( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider_none + ): + """Test invoke works when user provider returns None.""" + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider_none, + agent_name="test-agent" + ) + + await invoker.invoke({"input": "test"}) + + mock_foundry_tool_client.invoke_tool.assert_called_once_with( + sample_resolved_mcp_tool, + {"input": "test"}, + "test-agent", + None + ) + + @pytest.mark.asyncio + async def test_invoke_propagates_client_exception( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider + ): + """Test invoke propagates exceptions from client.invoke_tool.""" + mock_foundry_tool_client.invoke_tool = AsyncMock( + side_effect=RuntimeError("Client error") + ) + + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + with pytest.raises(RuntimeError, match="Client error"): + await invoker.invoke({"input": "test"}) + + @pytest.mark.asyncio + async def test_invoke_with_complex_nested_arguments( + self, + sample_resolved_mcp_tool, + mock_foundry_tool_client, + mock_user_provider, + sample_user_info + ): + """Test invoke with complex nested argument structure.""" + complex_args = { + "nested": {"key1": "value1", "key2": 123}, + "list": [1, 2, 3], + "mixed": [{"a": 1}, {"b": 2}] + } + + invoker = DefaultFoundryToolInvoker( + resolved_tool=sample_resolved_mcp_tool, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + await invoker.invoke(complex_args) + + mock_foundry_tool_client.invoke_tool.assert_called_once_with( + sample_resolved_mcp_tool, + complex_args, + "test-agent", + sample_user_info + ) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_resolver.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_resolver.py new file mode 100644 index 000000000000..7bdaa8f957a9 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_resolver.py @@ -0,0 +1,202 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _resolver.py - testing public methods of DefaultFoundryToolInvocationResolver.""" +import pytest +from unittest.mock import AsyncMock, MagicMock + +from azure.ai.agentserver.core.tools.runtime._resolver import DefaultFoundryToolInvocationResolver +from azure.ai.agentserver.core.tools.runtime._invoker import DefaultFoundryToolInvoker +from azure.ai.agentserver.core.tools._exceptions import UnableToResolveToolInvocationError +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, +) + + +class TestDefaultFoundryToolInvocationResolverResolve: + """Tests for DefaultFoundryToolInvocationResolver.resolve method.""" + + @pytest.fixture + def mock_catalog(self, sample_resolved_mcp_tool): + """Create a mock FoundryToolCatalog.""" + catalog = AsyncMock() + catalog.get = AsyncMock(return_value=sample_resolved_mcp_tool) + catalog.list = AsyncMock(return_value=[sample_resolved_mcp_tool]) + return catalog + + @pytest.mark.asyncio + async def test_resolve_with_resolved_tool_returns_invoker_directly( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_resolved_mcp_tool + ): + """Test resolve returns invoker directly when given ResolvedFoundryTool.""" + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(sample_resolved_mcp_tool) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + assert invoker.resolved_tool is sample_resolved_mcp_tool + # Catalog should not be called when ResolvedFoundryTool is passed + mock_catalog.get.assert_not_called() + + @pytest.mark.asyncio + async def test_resolve_with_foundry_tool_uses_catalog( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool, + sample_resolved_mcp_tool + ): + """Test resolve uses catalog to resolve FoundryTool.""" + mock_catalog.get = AsyncMock(return_value=sample_resolved_mcp_tool) + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(sample_hosted_mcp_tool) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + mock_catalog.get.assert_called_once_with(sample_hosted_mcp_tool) + + @pytest.mark.asyncio + async def test_resolve_with_facade_dict_uses_catalog( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_resolved_connected_tool + ): + """Test resolve converts facade dict and uses catalog.""" + mock_catalog.get = AsyncMock(return_value=sample_resolved_connected_tool) + facade = { + "type": "mcp", + "project_connection_id": "test-connection" + } + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(facade) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + mock_catalog.get.assert_called_once() + # Verify the facade was converted to FoundryConnectedTool + call_arg = mock_catalog.get.call_args[0][0] + assert isinstance(call_arg, FoundryConnectedTool) + + @pytest.mark.asyncio + async def test_resolve_raises_error_when_tool_not_found_in_catalog( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_hosted_mcp_tool + ): + """Test resolve raises UnableToResolveToolInvocationError when catalog returns None.""" + mock_catalog.get = AsyncMock(return_value=None) + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + with pytest.raises(UnableToResolveToolInvocationError) as exc_info: + await resolver.resolve(sample_hosted_mcp_tool) + + assert exc_info.value.tool is sample_hosted_mcp_tool + assert "Unable to resolve tool" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_resolve_with_hosted_mcp_facade( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_resolved_mcp_tool + ): + """Test resolve with hosted MCP facade (unknown type becomes FoundryHostedMcpTool).""" + mock_catalog.get = AsyncMock(return_value=sample_resolved_mcp_tool) + facade = { + "type": "custom_mcp_tool", + "config_key": "config_value" + } + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(facade) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + # Verify the facade was converted to FoundryHostedMcpTool + call_arg = mock_catalog.get.call_args[0][0] + assert isinstance(call_arg, FoundryHostedMcpTool) + assert call_arg.name == "custom_mcp_tool" + + @pytest.mark.asyncio + async def test_resolve_returns_invoker_with_correct_agent_name( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_resolved_mcp_tool + ): + """Test resolve creates invoker with the correct agent name.""" + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="custom-agent-name" + ) + + invoker = await resolver.resolve(sample_resolved_mcp_tool) + + # Verify invoker was created with correct agent name by checking internal state + assert invoker._agent_name == "custom-agent-name" + + @pytest.mark.asyncio + async def test_resolve_with_connected_tool_directly( + self, + mock_catalog, + mock_foundry_tool_client, + mock_user_provider, + sample_connected_tool, + sample_resolved_connected_tool + ): + """Test resolve with FoundryConnectedTool directly.""" + mock_catalog.get = AsyncMock(return_value=sample_resolved_connected_tool) + + resolver = DefaultFoundryToolInvocationResolver( + catalog=mock_catalog, + client=mock_foundry_tool_client, + user_provider=mock_user_provider, + agent_name="test-agent" + ) + + invoker = await resolver.resolve(sample_connected_tool) + + assert isinstance(invoker, DefaultFoundryToolInvoker) + mock_catalog.get.assert_called_once_with(sample_connected_tool) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_runtime.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_runtime.py new file mode 100644 index 000000000000..a99935662941 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_runtime.py @@ -0,0 +1,401 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _runtime.py - testing public methods of DefaultFoundryToolRuntime.""" +import os +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from azure.ai.agentserver.core.tools.runtime._runtime import ( + create_tool_runtime, + DefaultFoundryToolRuntime, + ThrowingFoundryToolRuntime, +) +from azure.ai.agentserver.core.tools.runtime._catalog import DefaultFoundryToolCatalog +from azure.ai.agentserver.core.tools.runtime._resolver import DefaultFoundryToolInvocationResolver +from azure.ai.agentserver.core.tools.runtime._user import ContextVarUserProvider + + +class TestDefaultFoundryToolRuntimeInit: + """Tests for DefaultFoundryToolRuntime initialization.""" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_creates_client_with_endpoint_and_credential( + self, + mock_client_class, + mock_credential + ): + """Test initialization creates client with correct endpoint and credential.""" + endpoint = "https://test-project.azure.com" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint=endpoint, + credential=mock_credential + ) + + mock_client_class.assert_called_once_with( + endpoint=endpoint, + credential=mock_credential + ) + assert runtime is not None + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_uses_default_user_provider_when_none_provided( + self, + mock_client_class, + mock_credential + ): + """Test initialization uses ContextVarUserProvider when user_provider is None.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert isinstance(runtime._user_provider, ContextVarUserProvider) + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_uses_custom_user_provider( + self, + mock_client_class, + mock_credential, + mock_user_provider + ): + """Test initialization uses custom user provider when provided.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential, + user_provider=mock_user_provider + ) + + assert runtime._user_provider is mock_user_provider + + @patch.dict(os.environ, {"AGENT_NAME": "custom-agent"}) + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_reads_agent_name_from_environment( + self, + mock_client_class, + mock_credential + ): + """Test initialization reads agent name from environment variable.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert runtime._agent_name == "custom-agent" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_init_uses_default_agent_name_when_env_not_set( + self, + mock_client_class, + mock_credential + ): + """Test initialization uses default agent name when env var is not set.""" + mock_client_class.return_value = MagicMock() + + # Ensure AGENT_NAME is not set + env_copy = os.environ.copy() + if "AGENT_NAME" in env_copy: + del env_copy["AGENT_NAME"] + + with patch.dict(os.environ, env_copy, clear=True): + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert runtime._agent_name == "$default" + + +class TestDefaultFoundryToolRuntimeCatalog: + """Tests for DefaultFoundryToolRuntime.catalog property.""" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_catalog_returns_default_catalog( + self, + mock_client_class, + mock_credential + ): + """Test catalog property returns DefaultFoundryToolCatalog.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert isinstance(runtime.catalog, DefaultFoundryToolCatalog) + + +class TestDefaultFoundryToolRuntimeInvocation: + """Tests for DefaultFoundryToolRuntime.invocation property.""" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_invocation_returns_default_resolver( + self, + mock_client_class, + mock_credential + ): + """Test invocation property returns DefaultFoundryToolInvocationResolver.""" + mock_client_class.return_value = MagicMock() + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + assert isinstance(runtime.invocation, DefaultFoundryToolInvocationResolver) + + +class TestDefaultFoundryToolRuntimeInvoke: + """Tests for DefaultFoundryToolRuntime.invoke method.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_invoke_resolves_and_invokes_tool( + self, + mock_client_class, + mock_credential, + sample_resolved_mcp_tool + ): + """Test invoke resolves the tool and calls the invoker.""" + mock_client_instance = MagicMock() + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + # Mock the invocation resolver + mock_invoker = AsyncMock() + mock_invoker.invoke = AsyncMock(return_value={"result": "success"}) + runtime._invocation.resolve = AsyncMock(return_value=mock_invoker) + + result = await runtime.invoke(sample_resolved_mcp_tool, {"input": "test"}) + + assert result == {"result": "success"} + runtime._invocation.resolve.assert_called_once_with(sample_resolved_mcp_tool) + mock_invoker.invoke.assert_called_once_with({"input": "test"}) + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_invoke_with_facade_dict( + self, + mock_client_class, + mock_credential + ): + """Test invoke works with facade dictionary.""" + mock_client_instance = MagicMock() + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + facade = {"type": "custom_tool", "config": "value"} + + # Mock the invocation resolver + mock_invoker = AsyncMock() + mock_invoker.invoke = AsyncMock(return_value={"output": "done"}) + runtime._invocation.resolve = AsyncMock(return_value=mock_invoker) + + result = await runtime.invoke(facade, {"param": "value"}) + + assert result == {"output": "done"} + runtime._invocation.resolve.assert_called_once_with(facade) + + +class TestDefaultFoundryToolRuntimeContextManager: + """Tests for DefaultFoundryToolRuntime async context manager.""" + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_aenter_returns_runtime_and_enters_client( + self, + mock_client_class, + mock_credential + ): + """Test __aenter__ enters client and returns runtime.""" + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + async with runtime as r: + assert r is runtime + mock_client_instance.__aenter__.assert_called_once() + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_aexit_exits_client( + self, + mock_client_class, + mock_credential + ): + """Test __aexit__ exits client properly.""" + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + async with runtime: + pass + + mock_client_instance.__aexit__.assert_called_once() + + @pytest.mark.asyncio + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + async def test_aexit_called_on_exception( + self, + mock_client_class, + mock_credential + ): + """Test __aexit__ is called even when exception occurs.""" + mock_client_instance = AsyncMock() + mock_client_instance.__aenter__ = AsyncMock(return_value=mock_client_instance) + mock_client_instance.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client_instance + + runtime = DefaultFoundryToolRuntime( + project_endpoint="https://test.azure.com", + credential=mock_credential + ) + + with pytest.raises(ValueError): + async with runtime: + raise ValueError("Test error") + + mock_client_instance.__aexit__.assert_called_once() + + +class TestCreateToolRuntime: + """Tests for create_tool_runtime factory function.""" + + @patch("azure.ai.agentserver.core.tools.runtime._runtime.FoundryToolClient") + def test_create_tool_runtime_returns_default_runtime_with_valid_params( + self, + mock_client_class, + mock_credential + ): + """Test create_tool_runtime returns DefaultFoundryToolRuntime when both params are provided.""" + mock_client_class.return_value = MagicMock() + endpoint = "https://test-project.azure.com" + + runtime = create_tool_runtime(project_endpoint=endpoint, credential=mock_credential) + + assert isinstance(runtime, DefaultFoundryToolRuntime) + + def test_create_tool_runtime_returns_throwing_runtime_when_endpoint_is_none( + self, + mock_credential + ): + """Test create_tool_runtime returns ThrowingFoundryToolRuntime when endpoint is None.""" + runtime = create_tool_runtime(project_endpoint=None, credential=mock_credential) + + assert isinstance(runtime, ThrowingFoundryToolRuntime) + + def test_create_tool_runtime_returns_throwing_runtime_when_credential_is_none(self): + """Test create_tool_runtime returns ThrowingFoundryToolRuntime when credential is None.""" + runtime = create_tool_runtime(project_endpoint="https://test.azure.com", credential=None) + + assert isinstance(runtime, ThrowingFoundryToolRuntime) + + def test_create_tool_runtime_returns_throwing_runtime_when_both_are_none(self): + """Test create_tool_runtime returns ThrowingFoundryToolRuntime when both params are None.""" + runtime = create_tool_runtime(project_endpoint=None, credential=None) + + assert isinstance(runtime, ThrowingFoundryToolRuntime) + + def test_create_tool_runtime_returns_throwing_runtime_when_endpoint_is_empty_string( + self, + mock_credential + ): + """Test create_tool_runtime returns ThrowingFoundryToolRuntime when endpoint is empty string.""" + runtime = create_tool_runtime(project_endpoint="", credential=mock_credential) + + assert isinstance(runtime, ThrowingFoundryToolRuntime) + + +class TestThrowingFoundryToolRuntime: + """Tests for ThrowingFoundryToolRuntime.""" + + def test_catalog_raises_runtime_error(self): + """Test catalog property raises RuntimeError.""" + runtime = ThrowingFoundryToolRuntime() + + with pytest.raises(RuntimeError) as exc_info: + _ = runtime.catalog + + assert "FoundryToolRuntime is not configured" in str(exc_info.value) + assert "project endpoint and credential" in str(exc_info.value) + + def test_invocation_raises_runtime_error(self): + """Test invocation property raises RuntimeError.""" + runtime = ThrowingFoundryToolRuntime() + + with pytest.raises(RuntimeError) as exc_info: + _ = runtime.invocation + + assert "FoundryToolRuntime is not configured" in str(exc_info.value) + assert "project endpoint and credential" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_invoke_raises_runtime_error(self): + """Test invoke method raises RuntimeError (via invocation property).""" + runtime = ThrowingFoundryToolRuntime() + + with pytest.raises(RuntimeError) as exc_info: + await runtime.invoke({"type": "test"}, {"arg": "value"}) + + assert "FoundryToolRuntime is not configured" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_aenter_returns_self(self): + """Test __aenter__ returns the runtime instance.""" + runtime = ThrowingFoundryToolRuntime() + + async with runtime as r: + assert r is runtime + + @pytest.mark.asyncio + async def test_aexit_completes_successfully(self): + """Test __aexit__ completes without error.""" + runtime = ThrowingFoundryToolRuntime() + + # Should not raise any exception + async with runtime: + pass + + @pytest.mark.asyncio + async def test_context_manager_does_not_suppress_exceptions(self): + """Test context manager does not suppress exceptions.""" + runtime = ThrowingFoundryToolRuntime() + + with pytest.raises(ValueError): + async with runtime: + raise ValueError("Test error") + + def test_error_message_is_class_variable(self): + """Test _ERROR_MESSAGE is defined as a class variable.""" + assert hasattr(ThrowingFoundryToolRuntime, "_ERROR_MESSAGE") + assert isinstance(ThrowingFoundryToolRuntime._ERROR_MESSAGE, str) + diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_starlette.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_starlette.py new file mode 100644 index 000000000000..d1d72004d011 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_starlette.py @@ -0,0 +1,261 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _starlette.py - testing public methods of UserInfoContextMiddleware.""" +import pytest +from contextvars import ContextVar +from unittest.mock import AsyncMock, MagicMock + +from azure.ai.agentserver.core.tools.client._models import UserInfo + + +class TestUserInfoContextMiddlewareInstall: + """Tests for UserInfoContextMiddleware.install class method.""" + + def test_install_adds_middleware_to_starlette_app(self): + """Test install adds middleware to Starlette application.""" + # Import here to avoid requiring starlette when not needed + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_app = MagicMock() + + UserInfoContextMiddleware.install(mock_app) + + mock_app.add_middleware.assert_called_once() + call_args = mock_app.add_middleware.call_args + assert call_args[0][0] == UserInfoContextMiddleware + + def test_install_uses_default_context_when_none_provided(self): + """Test install uses default user context when none is provided.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + from azure.ai.agentserver.core.tools.runtime._user import ContextVarUserProvider + + mock_app = MagicMock() + + UserInfoContextMiddleware.install(mock_app) + + call_kwargs = mock_app.add_middleware.call_args[1] + assert call_kwargs["user_info_var"] is ContextVarUserProvider.default_user_info_context + + def test_install_uses_custom_context(self): + """Test install uses custom user context when provided.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_app = MagicMock() + custom_context = ContextVar("custom_context") + + UserInfoContextMiddleware.install(mock_app, user_context=custom_context) + + call_kwargs = mock_app.add_middleware.call_args[1] + assert call_kwargs["user_info_var"] is custom_context + + def test_install_uses_custom_resolver(self): + """Test install uses custom user resolver when provided.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_app = MagicMock() + + async def custom_resolver(request): + return UserInfo(object_id="custom-oid", tenant_id="custom-tid") + + UserInfoContextMiddleware.install(mock_app, user_resolver=custom_resolver) + + call_kwargs = mock_app.add_middleware.call_args[1] + assert call_kwargs["user_resolver"] is custom_resolver + + +class TestUserInfoContextMiddlewareDispatch: + """Tests for UserInfoContextMiddleware.dispatch method.""" + + @pytest.mark.asyncio + async def test_dispatch_sets_user_in_context(self): + """Test dispatch sets user info in context variable.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + user_info = UserInfo(object_id="test-oid", tenant_id="test-tid") + + async def mock_resolver(request): + return user_info + + # Create a simple mock app + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + captured_user = None + + async def call_next(request): + nonlocal captured_user + captured_user = user_context.get(None) + return MagicMock() + + await middleware.dispatch(mock_request, call_next) + + assert captured_user is user_info + + @pytest.mark.asyncio + async def test_dispatch_resets_context_after_request(self): + """Test dispatch resets context variable after request completes.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + original_user = UserInfo(object_id="original-oid", tenant_id="original-tid") + user_context.set(original_user) + + new_user = UserInfo(object_id="new-oid", tenant_id="new-tid") + + async def mock_resolver(request): + return new_user + + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + + async def call_next(request): + # During request, should have new_user + assert user_context.get(None) is new_user + return MagicMock() + + await middleware.dispatch(mock_request, call_next) + + # After request, context should be reset to original value + assert user_context.get(None) is original_user + + @pytest.mark.asyncio + async def test_dispatch_resets_context_on_exception(self): + """Test dispatch resets context even when call_next raises exception.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + original_user = UserInfo(object_id="original-oid", tenant_id="original-tid") + user_context.set(original_user) + + new_user = UserInfo(object_id="new-oid", tenant_id="new-tid") + + async def mock_resolver(request): + return new_user + + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + + async def call_next(request): + raise RuntimeError("Request failed") + + with pytest.raises(RuntimeError, match="Request failed"): + await middleware.dispatch(mock_request, call_next) + + # Context should still be reset to original + assert user_context.get(None) is original_user + + @pytest.mark.asyncio + async def test_dispatch_handles_none_user(self): + """Test dispatch handles None user from resolver.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + + async def mock_resolver(request): + return None + + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + captured_user = "not_set" + + async def call_next(request): + nonlocal captured_user + captured_user = user_context.get("default") + return MagicMock() + + await middleware.dispatch(mock_request, call_next) + + assert captured_user is None + + @pytest.mark.asyncio + async def test_dispatch_calls_resolver_with_request(self): + """Test dispatch calls user resolver with the request object.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + user_context = ContextVar("test_context") + captured_request = None + + async def mock_resolver(request): + nonlocal captured_request + captured_request = request + return UserInfo(object_id="oid", tenant_id="tid") + + mock_app = AsyncMock() + + middleware = UserInfoContextMiddleware( + app=mock_app, + user_info_var=user_context, + user_resolver=mock_resolver + ) + + mock_request = MagicMock() + mock_request.url = "https://test.com/api" + + async def call_next(request): + return MagicMock() + + await middleware.dispatch(mock_request, call_next) + + assert captured_request is mock_request + + +class TestUserInfoContextMiddlewareDefaultResolver: + """Tests for UserInfoContextMiddleware default resolver.""" + + @pytest.mark.asyncio + async def test_default_resolver_extracts_user_from_headers(self): + """Test default resolver extracts user info from request headers.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_request = MagicMock() + mock_request.headers = { + "x-aml-oid": "header-object-id", + "x-aml-tid": "header-tenant-id" + } + + result = await UserInfoContextMiddleware._default_user_resolver(mock_request) + + assert result is not None + assert result.object_id == "header-object-id" + assert result.tenant_id == "header-tenant-id" + + @pytest.mark.asyncio + async def test_default_resolver_returns_none_when_headers_missing(self): + """Test default resolver returns None when required headers are missing.""" + from azure.ai.agentserver.core.tools.runtime._starlette import UserInfoContextMiddleware + + mock_request = MagicMock() + mock_request.headers = {} + + result = await UserInfoContextMiddleware._default_user_resolver(mock_request) + + assert result is None diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_user.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_user.py new file mode 100644 index 000000000000..a909d9e5948a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/runtime/test_user.py @@ -0,0 +1,210 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _user.py - testing public methods of ContextVarUserProvider and resolve_user_from_headers.""" +import pytest +from contextvars import ContextVar + +from azure.ai.agentserver.core.tools.runtime._user import ( + ContextVarUserProvider, + resolve_user_from_headers, +) +from azure.ai.agentserver.core.tools.client._models import UserInfo + + +class TestContextVarUserProvider: + """Tests for ContextVarUserProvider public methods.""" + + @pytest.mark.asyncio + async def test_get_user_returns_none_when_context_not_set(self): + """Test get_user returns None when context variable is not set.""" + custom_context = ContextVar("test_user_context") + provider = ContextVarUserProvider(context=custom_context) + + result = await provider.get_user() + + assert result is None + + @pytest.mark.asyncio + async def test_get_user_returns_user_when_context_is_set(self, sample_user_info): + """Test get_user returns UserInfo when context variable is set.""" + custom_context = ContextVar("test_user_context") + custom_context.set(sample_user_info) + provider = ContextVarUserProvider(context=custom_context) + + result = await provider.get_user() + + assert result is sample_user_info + assert result.object_id == "test-object-id" + assert result.tenant_id == "test-tenant-id" + + @pytest.mark.asyncio + async def test_uses_default_context_when_none_provided(self, sample_user_info): + """Test that default context is used when no context is provided.""" + # Set value in default context + ContextVarUserProvider.default_user_info_context.set(sample_user_info) + provider = ContextVarUserProvider() + + result = await provider.get_user() + + assert result is sample_user_info + + @pytest.mark.asyncio + async def test_different_providers_share_same_default_context(self, sample_user_info): + """Test that different providers using default context share the same value.""" + ContextVarUserProvider.default_user_info_context.set(sample_user_info) + provider1 = ContextVarUserProvider() + provider2 = ContextVarUserProvider() + + result1 = await provider1.get_user() + result2 = await provider2.get_user() + + assert result1 is result2 is sample_user_info + + @pytest.mark.asyncio + async def test_custom_context_isolation(self, sample_user_info): + """Test that custom contexts are isolated from each other.""" + context1 = ContextVar("context1") + context2 = ContextVar("context2") + user2 = UserInfo(object_id="other-oid", tenant_id="other-tid") + + context1.set(sample_user_info) + context2.set(user2) + + provider1 = ContextVarUserProvider(context=context1) + provider2 = ContextVarUserProvider(context=context2) + + result1 = await provider1.get_user() + result2 = await provider2.get_user() + + assert result1 is sample_user_info + assert result2 is user2 + assert result1 is not result2 + + +class TestResolveUserFromHeaders: + """Tests for resolve_user_from_headers public function.""" + + def test_returns_user_info_when_both_headers_present(self): + """Test returns UserInfo when both object_id and tenant_id headers are present.""" + headers = { + "x-aml-oid": "user-object-id", + "x-aml-tid": "user-tenant-id" + } + + result = resolve_user_from_headers(headers) + + assert result is not None + assert isinstance(result, UserInfo) + assert result.object_id == "user-object-id" + assert result.tenant_id == "user-tenant-id" + + def test_returns_none_when_object_id_missing(self): + """Test returns None when object_id header is missing.""" + headers = {"x-aml-tid": "user-tenant-id"} + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_returns_none_when_tenant_id_missing(self): + """Test returns None when tenant_id header is missing.""" + headers = {"x-aml-oid": "user-object-id"} + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_returns_none_when_both_headers_missing(self): + """Test returns None when both headers are missing.""" + headers = {} + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_returns_none_when_object_id_is_empty(self): + """Test returns None when object_id is empty string.""" + headers = { + "x-aml-oid": "", + "x-aml-tid": "user-tenant-id" + } + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_returns_none_when_tenant_id_is_empty(self): + """Test returns None when tenant_id is empty string.""" + headers = { + "x-aml-oid": "user-object-id", + "x-aml-tid": "" + } + + result = resolve_user_from_headers(headers) + + assert result is None + + def test_custom_header_names(self): + """Test using custom header names for object_id and tenant_id.""" + headers = { + "custom-oid-header": "custom-object-id", + "custom-tid-header": "custom-tenant-id" + } + + result = resolve_user_from_headers( + headers, + object_id_header="custom-oid-header", + tenant_id_header="custom-tid-header" + ) + + assert result is not None + assert result.object_id == "custom-object-id" + assert result.tenant_id == "custom-tenant-id" + + def test_default_headers_not_matched_with_custom_headers(self): + """Test that default headers are not matched when custom headers are specified.""" + headers = { + "x-aml-oid": "default-object-id", + "x-aml-tid": "default-tenant-id" + } + + result = resolve_user_from_headers( + headers, + object_id_header="custom-oid", + tenant_id_header="custom-tid" + ) + + assert result is None + + def test_case_sensitive_header_matching(self): + """Test that header matching is case-sensitive.""" + headers = { + "X-AML-OID": "user-object-id", + "X-AML-TID": "user-tenant-id" + } + + # Default headers are lowercase, so these should not match + result = resolve_user_from_headers(headers) + + assert result is None + + def test_with_mapping_like_object(self): + """Test with a mapping-like object that supports .get().""" + class HeadersMapping: + def __init__(self, data): + self._data = data + + def get(self, key, default=""): + return self._data.get(key, default) + + headers = HeadersMapping({ + "x-aml-oid": "mapping-object-id", + "x-aml-tid": "mapping-tenant-id" + }) + + result = resolve_user_from_headers(headers) + + assert result is not None + assert result.object_id == "mapping-object-id" + assert result.tenant_id == "mapping-tenant-id" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/__init__.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/__init__.py new file mode 100644 index 000000000000..2d7503de198d --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/__init__.py @@ -0,0 +1,4 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Utils unit tests package.""" diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/conftest.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/conftest.py new file mode 100644 index 000000000000..abd2f5145c29 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/conftest.py @@ -0,0 +1,56 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Shared fixtures for utils unit tests. + +Common fixtures are inherited from the parent conftest.py automatically by pytest. +""" +from typing import Optional + +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaType, +) + + +def create_resolved_tool_with_name( + name: str, + tool_type: str = "mcp", + connection_id: Optional[str] = None +) -> ResolvedFoundryTool: + """Helper to create a ResolvedFoundryTool with a specific name. + + :param name: The name for the tool details. + :param tool_type: Either "mcp" or "connected". + :param connection_id: Connection ID for connected tools. If provided with tool_type="mcp", + will automatically use "connected" type to ensure unique tool IDs. + :return: A ResolvedFoundryTool instance. + """ + schema = SchemaDefinition( + type=SchemaType.OBJECT, + properties={}, + required=set() + ) + details = FoundryToolDetails( + name=name, + description=f"Tool named {name}", + input_schema=schema + ) + + # If connection_id is provided, use connected tool to ensure unique IDs + if connection_id is not None or tool_type == "connected": + definition = FoundryConnectedTool( + protocol="mcp", + project_connection_id=connection_id or f"conn-{name}" + ) + else: + definition = FoundryHostedMcpTool( + name=f"mcp-{name}", + configuration={} + ) + + return ResolvedFoundryTool(definition=definition, details=details) diff --git a/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/test_name_resolver.py b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/test_name_resolver.py new file mode 100644 index 000000000000..14340799253b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-core/tests/unit_tests/tools/utils/test_name_resolver.py @@ -0,0 +1,260 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for _name_resolver.py - testing public methods of ToolNameResolver.""" +from azure.ai.agentserver.core.tools.utils import ToolNameResolver +from azure.ai.agentserver.core.tools.client._models import ( + FoundryConnectedTool, + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, +) + +from .conftest import create_resolved_tool_with_name + + +class TestToolNameResolverResolve: + """Tests for ToolNameResolver.resolve method.""" + + def test_resolve_returns_tool_name_for_first_occurrence( + self, + sample_resolved_mcp_tool + ): + """Test resolve returns the original tool name for first occurrence.""" + resolver = ToolNameResolver() + + result = resolver.resolve(sample_resolved_mcp_tool) + + assert result == sample_resolved_mcp_tool.details.name + + def test_resolve_returns_same_name_for_same_tool( + self, + sample_resolved_mcp_tool + ): + """Test resolve returns the same name when called multiple times for same tool.""" + resolver = ToolNameResolver() + + result1 = resolver.resolve(sample_resolved_mcp_tool) + result2 = resolver.resolve(sample_resolved_mcp_tool) + result3 = resolver.resolve(sample_resolved_mcp_tool) + + assert result1 == result2 == result3 + assert result1 == sample_resolved_mcp_tool.details.name + + def test_resolve_appends_count_for_duplicate_names(self): + """Test resolve appends count for tools with duplicate names.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("my_tool", connection_id="conn-1") + tool2 = create_resolved_tool_with_name("my_tool", connection_id="conn-2") + tool3 = create_resolved_tool_with_name("my_tool", connection_id="conn-3") + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + result3 = resolver.resolve(tool3) + + assert result1 == "my_tool" + assert result2 == "my_tool_1" + assert result3 == "my_tool_2" + + def test_resolve_handles_multiple_unique_names(self): + """Test resolve handles multiple tools with unique names.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("tool_alpha") + tool2 = create_resolved_tool_with_name("tool_beta") + tool3 = create_resolved_tool_with_name("tool_gamma") + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + result3 = resolver.resolve(tool3) + + assert result1 == "tool_alpha" + assert result2 == "tool_beta" + assert result3 == "tool_gamma" + + def test_resolve_mixed_unique_and_duplicate_names(self): + """Test resolve handles a mix of unique and duplicate names.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("shared_name", connection_id="conn-1") + tool2 = create_resolved_tool_with_name("unique_name") + tool3 = create_resolved_tool_with_name("shared_name", connection_id="conn-2") + tool4 = create_resolved_tool_with_name("another_unique") + tool5 = create_resolved_tool_with_name("shared_name", connection_id="conn-3") + + assert resolver.resolve(tool1) == "shared_name" + assert resolver.resolve(tool2) == "unique_name" + assert resolver.resolve(tool3) == "shared_name_1" + assert resolver.resolve(tool4) == "another_unique" + assert resolver.resolve(tool5) == "shared_name_2" + + def test_resolve_returns_cached_name_after_duplicate_added(self): + """Test that resolving a tool again returns cached name even after duplicates are added.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("my_tool", connection_id="conn-1") + tool2 = create_resolved_tool_with_name("my_tool", connection_id="conn-2") + + # First resolution + first_result = resolver.resolve(tool1) + assert first_result == "my_tool" + + # Add duplicate + dup_result = resolver.resolve(tool2) + assert dup_result == "my_tool_1" + + # Resolve original again - should return cached value + second_result = resolver.resolve(tool1) + assert second_result == "my_tool" + + def test_resolve_with_connected_tool( + self, + sample_resolved_connected_tool + ): + """Test resolve works with connected tools.""" + resolver = ToolNameResolver() + + result = resolver.resolve(sample_resolved_connected_tool) + + assert result == sample_resolved_connected_tool.details.name + + def test_resolve_different_tools_same_details_name(self, sample_schema_definition): + """Test resolve handles different tool definitions with same details name.""" + resolver = ToolNameResolver() + + details = FoundryToolDetails( + name="shared_function", + description="A shared function", + input_schema=sample_schema_definition + ) + + mcp_def = FoundryHostedMcpTool(name="mcp_server", configuration={}) + connected_def = FoundryConnectedTool(protocol="mcp", project_connection_id="my-conn") + + tool1 = ResolvedFoundryTool(definition=mcp_def, details=details) + tool2 = ResolvedFoundryTool(definition=connected_def, details=details) + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + + assert result1 == "shared_function" + assert result2 == "shared_function_1" + + def test_resolve_empty_name(self): + """Test resolve handles tools with empty name.""" + resolver = ToolNameResolver() + + tool = create_resolved_tool_with_name("") + + result = resolver.resolve(tool) + + assert result == "" + + def test_resolve_special_characters_in_name(self): + """Test resolve handles tools with special characters in name.""" + resolver = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("my-tool_v1.0", connection_id="conn-1") + tool2 = create_resolved_tool_with_name("my-tool_v1.0", connection_id="conn-2") + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + + assert result1 == "my-tool_v1.0" + assert result2 == "my-tool_v1.0_1" + + def test_independent_resolver_instances(self): + """Test that different resolver instances maintain independent state.""" + resolver1 = ToolNameResolver() + resolver2 = ToolNameResolver() + + tool1 = create_resolved_tool_with_name("tool_name", connection_id="conn-1") + tool2 = create_resolved_tool_with_name("tool_name", connection_id="conn-2") + + # Both resolvers resolve tool1 first + assert resolver1.resolve(tool1) == "tool_name" + assert resolver2.resolve(tool1) == "tool_name" + + # resolver1 resolves tool2 as duplicate + assert resolver1.resolve(tool2) == "tool_name_1" + + # resolver2 has not seen tool2 yet in its context + # but tool2 has same name, so it should be duplicate + assert resolver2.resolve(tool2) == "tool_name_1" + + def test_resolve_many_duplicates(self): + """Test resolve handles many tools with the same name.""" + resolver = ToolNameResolver() + + tools = [ + create_resolved_tool_with_name("common_name", connection_id=f"conn-{i}") + for i in range(10) + ] + + results = [resolver.resolve(tool) for tool in tools] + + expected = ["common_name"] + [f"common_name_{i}" for i in range(1, 10)] + assert results == expected + + def test_resolve_uses_tool_id_for_caching(self, sample_schema_definition): + """Test that resolve uses tool.id for caching, not just name.""" + resolver = ToolNameResolver() + + # Create two tools with same definition but different details names + definition = FoundryHostedMcpTool(name="same_definition", configuration={}) + + details1 = FoundryToolDetails( + name="function_a", + description="Function A", + input_schema=sample_schema_definition + ) + details2 = FoundryToolDetails( + name="function_b", + description="Function B", + input_schema=sample_schema_definition + ) + + tool1 = ResolvedFoundryTool(definition=definition, details=details1) + tool2 = ResolvedFoundryTool(definition=definition, details=details2) + + result1 = resolver.resolve(tool1) + result2 = resolver.resolve(tool2) + + # Both should get their respective names since they have different tool.id + assert result1 == "function_a" + assert result2 == "function_b" + + def test_resolve_idempotent_for_same_tool_id(self, sample_schema_definition): + """Test that resolve is idempotent for the same tool id.""" + resolver = ToolNameResolver() + + definition = FoundryHostedMcpTool(name="my_mcp", configuration={}) + details = FoundryToolDetails( + name="my_function", + description="My function", + input_schema=sample_schema_definition + ) + tool = ResolvedFoundryTool(definition=definition, details=details) + + # Call resolve many times + results = [resolver.resolve(tool) for _ in range(5)] + + # All should return the same name + assert all(r == "my_function" for r in results) + + def test_resolve_interleaved_tool_resolutions(self): + """Test resolve with interleaved resolutions of different tools.""" + resolver = ToolNameResolver() + + toolA_1 = create_resolved_tool_with_name("A", connection_id="A-1") + toolA_2 = create_resolved_tool_with_name("A", connection_id="A-2") + toolB_1 = create_resolved_tool_with_name("B", connection_id="B-1") + toolA_3 = create_resolved_tool_with_name("A", connection_id="A-3") + toolB_2 = create_resolved_tool_with_name("B", connection_id="B-2") + + assert resolver.resolve(toolA_1) == "A" + assert resolver.resolve(toolB_1) == "B" + assert resolver.resolve(toolA_2) == "A_1" + assert resolver.resolve(toolA_3) == "A_2" + assert resolver.resolve(toolB_2) == "B_1" diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/CHANGELOG.md b/sdk/agentserver/azure-ai-agentserver-langgraph/CHANGELOG.md index cfcf2445e256..1d94fe7ba05e 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/CHANGELOG.md +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/CHANGELOG.md @@ -1,5 +1,89 @@ # Release History + +## 1.0.0b10 (2026-01-27) + +### Bugs Fixed + +- Make AZURE_AI_PROJECTS_ENDPOINT optional. + +## 1.0.0b9 (2026-01-23) + +### Features Added + +- Integrated with Foundry Tools + +- Support Human-in-the-Loop + +- Added Response API converters for request conversion and orchestration. + + +## 1.0.0b8 (2026-01-21) + +### Features Added + +- Support keep alive for long-running streaming responses. + + +## 1.0.0b7 (2025-12-05) + +### Features Added + +- Update response with created_by + +### Bugs Fixed + +- Fixed error response handling in stream and non-stream modes + +## 1.0.0b6 (2025-11-26) + +### Feature Added + +- Support Agent-framework greater than 251112 + + +## 1.0.0b5 (2025-11-16) + +### Feature Added + +- Support Tools Oauth + +### Bugs Fixed + +- Fixed streaming generation issues. + + +## 1.0.0b4 (2025-11-13) + +### Feature Added + +- Adapters support tools + +### Bugs Fixed + +- Pin azure-ai-projects and azure-ai-agents version to avoid version confliction + + +## 1.0.0b3 (2025-11-11) + +### Bugs Fixed + +- Fixed Id generator format. + +- Fixed trace initialization for agent-framework. + + +## 1.0.0b2 (2025-11-10) + +### Bugs Fixed + +- Fixed Id generator format. + +- Improved stream mode error messsage. + +- Updated application insights related configuration environment variables. + + ## 1.0.0b1 (2025-11-07) ### Features Added diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/README.md b/sdk/agentserver/azure-ai-agentserver-langgraph/README.md index 1c1eaab6837e..970c14df9a85 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/README.md +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/README.md @@ -28,12 +28,6 @@ if __name__ == "__main__": ``` -**Note** -If your langgraph agent was not using langgraph's builtin [MessageState](https://langchain-ai.github.io/langgraph/concepts/low_level/?h=messagesstate#messagesstate), you should implement your own `LanggraphStateConverter` and provide to `from_langgraph`. - -Reference this [example](https://github.com/Azure/azure-sdk-for-python/blob/main/sdk/agentserver/azure-ai-agentserver-langgraph/samples/custom_state/main.py) for more details. - - ## Troubleshooting First run your agent with azure-ai-agentserver-langgraph locally. diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/__init__.py index ed2e0d4d493a..a3eef4358564 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/__init__.py @@ -3,19 +3,43 @@ # --------------------------------------------------------- __path__ = __import__("pkgutil").extend_path(__path__, __name__) -from typing import TYPE_CHECKING, Optional +from typing import Optional, Union, TYPE_CHECKING +from azure.ai.agentserver.core.application import PackageMetadata, set_current_app + +from ._context import LanggraphRunContext from ._version import VERSION +from .langgraph import LangGraphAdapter if TYPE_CHECKING: # pragma: no cover - from . import models - - -def from_langgraph(agent, state_converter: Optional["models.LanggraphStateConverter"] = None): - from .langgraph import LangGraphAdapter - - return LangGraphAdapter(agent, state_converter=state_converter) - - -__all__ = ["from_langgraph"] + from langgraph.graph.state import CompiledStateGraph + from .models.response_api_converter import ResponseAPIConverter + from azure.core.credentials_async import AsyncTokenCredential + from azure.core.credentials import TokenCredential + + +def from_langgraph( + agent: "CompiledStateGraph", + /, + credentials: Optional[Union["AsyncTokenCredential", "TokenCredential"]] = None, + converter: Optional["ResponseAPIConverter"] = None, +) -> "LangGraphAdapter": + """Create a LangGraph adapter for Azure AI Agent Server. + + :param agent: The compiled LangGraph state graph. To use persistent checkpointing, + compile the graph with a checkpointer via ``builder.compile(checkpointer=saver)``. + :type agent: CompiledStateGraph + :param credentials: Azure credentials for authentication. + :type credentials: Optional[Union[AsyncTokenCredential, TokenCredential]] + :param converter: Custom response converter. + :type converter: Optional[ResponseAPIConverter] + :return: A LangGraphAdapter instance. + :rtype: LangGraphAdapter + """ + return LangGraphAdapter(agent, credentials=credentials, converter=converter) + + +__all__ = ["from_langgraph", "LanggraphRunContext"] __version__ = VERSION + +set_current_app(PackageMetadata.from_dist("azure-ai-agentserver-langgraph")) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_context.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_context.py new file mode 100644 index 000000000000..d037088b18a5 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_context.py @@ -0,0 +1,67 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import sys +from dataclasses import dataclass +from typing import Optional, Union + +from langchain_core.runnables import RunnableConfig +from langgraph.prebuilt import ToolRuntime +from langgraph.runtime import Runtime, get_runtime + +from azure.ai.agentserver.core import AgentRunContext +from .tools._context import FoundryToolContext + + +@dataclass +class LanggraphRunContext: + agent_run: AgentRunContext + + tools: FoundryToolContext + + def attach_to_config(self, config: RunnableConfig): + config["configurable"]["__foundry_hosted_agent_langgraph_run_context__"] = self + + @classmethod + def resolve(cls, + config: Optional[RunnableConfig] = None, + runtime: Optional[Union[Runtime, ToolRuntime]] = None) -> Optional["LanggraphRunContext"]: + """Resolve the LanggraphRunContext from either a RunnableConfig or a Runtime. + + :param config: Optional RunnableConfig to extract the context from. + :type config: Optional[RunnableConfig] + :param runtime: Optional Runtime or ToolRuntime to extract the context from. + :type runtime: Optional[Union[Runtime, ToolRuntime]] + + :return: An instance of LanggraphRunContext if found, otherwise None. + :rtype: Optional[LanggraphRunContext] + """ + context: Optional["LanggraphRunContext"] = None + if config: + context = cls.from_config(config) + if not context and (r := cls._resolve_runtime(runtime)): + context = cls.from_runtime(r) + return context + + @staticmethod + def _resolve_runtime( + runtime: Optional[Union[Runtime, ToolRuntime]] = None) -> Optional[Union[Runtime, ToolRuntime]]: + if runtime: + return runtime + if sys.version_info >= (3, 11): + return get_runtime(LanggraphRunContext) + return None + + @staticmethod + def from_config(config: RunnableConfig) -> Optional["LanggraphRunContext"]: + context = config["configurable"].get("__foundry_hosted_agent_langgraph_run_context__") + if isinstance(context, LanggraphRunContext): + return context + return None + + @staticmethod + def from_runtime(runtime: Union[Runtime, ToolRuntime]) -> Optional["LanggraphRunContext"]: + context = runtime.context + if isinstance(context, LanggraphRunContext): + return context + return None diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_exceptions.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_exceptions.py new file mode 100644 index 000000000000..27dc2e5f1e9b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_exceptions.py @@ -0,0 +1,13 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Optional + + +class LangGraphMissingConversationIdError(ValueError): + def __init__(self, message: Optional[str] = None) -> None: + super().__init__( + message + or "conversation.id is required when a LangGraph checkpointer is enabled. " + "Provide conversation.id or disable the checkpointer." + ) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_version.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_version.py index be71c81bd282..9ab0a006e0d0 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_version.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/_version.py @@ -6,4 +6,4 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -VERSION = "1.0.0b1" +VERSION = "1.0.0b10" diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/__init__.py new file mode 100644 index 000000000000..9e91582733d3 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/__init__.py @@ -0,0 +1,8 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Checkpoint saver implementations for LangGraph with Azure AI Foundry.""" + +from ._foundry_checkpoint_saver import FoundryCheckpointSaver + +__all__ = ["FoundryCheckpointSaver"] diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/_foundry_checkpoint_saver.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/_foundry_checkpoint_saver.py new file mode 100644 index 000000000000..abfeb4bf3fa8 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/_foundry_checkpoint_saver.py @@ -0,0 +1,606 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Foundry-backed checkpoint saver for LangGraph.""" + +import logging +from contextlib import AbstractAsyncContextManager +from types import TracebackType +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple, Union + +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential +from langchain_core.runnables import RunnableConfig +from langgraph.checkpoint.base import ( + BaseCheckpointSaver, + ChannelVersions, + Checkpoint, + CheckpointMetadata, + CheckpointTuple, + SerializerProtocol, + get_checkpoint_id, +) + +from azure.ai.agentserver.core.checkpoints.client import ( + CheckpointItem, + CheckpointItemId, + CheckpointSession, + FoundryCheckpointClient, +) + +from ._item_id import ItemType, ParsedItemId, make_item_id, parse_item_id + +logger = logging.getLogger(__name__) + + +class FoundryCheckpointSaver( + BaseCheckpointSaver[str], AbstractAsyncContextManager["FoundryCheckpointSaver"] +): + """Checkpoint saver backed by Azure AI Foundry checkpoint storage. + + Implements LangGraph's BaseCheckpointSaver interface using the + FoundryCheckpointClient for remote storage. + + This saver only supports async operations. Sync methods will raise + NotImplementedError. + + :param project_endpoint: The Azure AI Foundry project endpoint URL. + Example: "https://.services.ai.azure.com/api/projects/" + :type project_endpoint: str + :param credential: Credential for authentication. Must be an async credential. + :type credential: Union[AsyncTokenCredential, TokenCredential] + :param serde: Optional serializer protocol. Defaults to JsonPlusSerializer. + :type serde: Optional[SerializerProtocol] + + Example:: + + from azure.ai.agentserver.langgraph.checkpointer import FoundryCheckpointSaver + from azure.identity.aio import DefaultAzureCredential + + saver = FoundryCheckpointSaver( + project_endpoint="https://myresource.services.ai.azure.com/api/projects/my-project", + credential=DefaultAzureCredential(), + ) + + # Use with LangGraph + graph = builder.compile(checkpointer=saver) + """ + + def __init__( + self, + project_endpoint: str, + credential: Union[AsyncTokenCredential, TokenCredential], + *, + serde: Optional[SerializerProtocol] = None, + ) -> None: + """Initialize the Foundry checkpoint saver. + + :param project_endpoint: The Azure AI Foundry project endpoint URL. + :type project_endpoint: str + :param credential: Credential for authentication. Must be an async credential. + :type credential: Union[AsyncTokenCredential, TokenCredential] + :param serde: Optional serializer protocol. + :type serde: Optional[SerializerProtocol] + :raises TypeError: If credential is not an AsyncTokenCredential. + """ + super().__init__(serde=serde) + if not isinstance(credential, AsyncTokenCredential): + raise TypeError( + "FoundryCheckpointSaver requires an AsyncTokenCredential. " + "Please use an async credential like DefaultAzureCredential from azure.identity.aio." + ) + self._client = FoundryCheckpointClient(project_endpoint, credential) + self._session_cache: set[str] = set() + + async def __aenter__(self) -> "FoundryCheckpointSaver": + """Enter the async context manager. + + :return: The saver instance. + :rtype: FoundryCheckpointSaver + """ + await self._client.__aenter__() + return self + + async def __aexit__( + self, + exc_type: Optional[type[BaseException]], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + """Exit the async context manager. + + :param exc_type: Exception type if an exception occurred. + :param exc_val: Exception value if an exception occurred. + :param exc_tb: Exception traceback if an exception occurred. + """ + await self._client.__aexit__(exc_type, exc_val, exc_tb) + + async def _ensure_session(self, thread_id: str) -> None: + """Ensure a session exists for the thread. + + :param thread_id: The thread identifier. + :type thread_id: str + """ + if thread_id not in self._session_cache: + session = CheckpointSession(session_id=thread_id) + await self._client.upsert_session(session) + self._session_cache.add(thread_id) + + async def _get_latest_checkpoint_id( + self, thread_id: str, checkpoint_ns: str + ) -> Optional[str]: + """Find the latest checkpoint ID for a thread and namespace. + + :param thread_id: The thread identifier. + :type thread_id: str + :param checkpoint_ns: The checkpoint namespace. + :type checkpoint_ns: str + :return: The latest checkpoint ID, or None if not found. + :rtype: Optional[str] + """ + item_ids = await self._client.list_item_ids(thread_id) + + # Filter to checkpoint items in this namespace + checkpoint_ids: List[str] = [] + for item_id in item_ids: + try: + parsed = parse_item_id(item_id.item_id) + if parsed.item_type == "checkpoint" and parsed.checkpoint_ns == checkpoint_ns: + checkpoint_ids.append(parsed.checkpoint_id) + except ValueError: + continue + + if not checkpoint_ids: + return None + + # Return the latest (max) checkpoint ID + return max(checkpoint_ids) + + async def _load_pending_writes( + self, thread_id: str, checkpoint_ns: str, checkpoint_id: str + ) -> List[Tuple[str, str, Any]]: + """Load pending writes for a checkpoint. + + :param thread_id: The thread identifier. + :type thread_id: str + :param checkpoint_ns: The checkpoint namespace. + :type checkpoint_ns: str + :param checkpoint_id: The checkpoint identifier. + :type checkpoint_id: str + :return: List of pending writes as (task_id, channel, value) tuples. + :rtype: List[Tuple[str, str, Any]] + """ + item_ids = await self._client.list_item_ids(thread_id) + writes: List[Tuple[str, str, Any]] = [] + + for item_id in item_ids: + try: + parsed = parse_item_id(item_id.item_id) + if ( + parsed.item_type == "writes" + and parsed.checkpoint_ns == checkpoint_ns + and parsed.checkpoint_id == checkpoint_id + ): + item = await self._client.read_item(item_id) + if item: + task_id, channel, value, _ = self.serde.loads_typed(item.data) + writes.append((task_id, channel, value)) + except (ValueError, TypeError): + continue + + return writes + + async def _load_blobs( + self, thread_id: str, checkpoint_ns: str, checkpoint_id: str, versions: ChannelVersions + ) -> Dict[str, Any]: + """Load channel blobs for a checkpoint. + + :param thread_id: The thread identifier. + :type thread_id: str + :param checkpoint_ns: The checkpoint namespace. + :type checkpoint_ns: str + :param checkpoint_id: The checkpoint identifier. + :type checkpoint_id: str + :param versions: The channel versions to load. + :type versions: ChannelVersions + :return: Dictionary of channel values. + :rtype: Dict[str, Any] + """ + channel_values: Dict[str, Any] = {} + + for channel, version in versions.items(): + blob_item_id = make_item_id( + checkpoint_ns, checkpoint_id, "blob", f"{channel}:{version}" + ) + item_id = CheckpointItemId(session_id=thread_id, item_id=blob_item_id) + item = await self._client.read_item(item_id) + if item: + type_tag, data = self.serde.loads_typed(item.data) + if type_tag != "empty": + channel_values[channel] = data + + return channel_values + + # Async methods (primary implementation) + + async def aget_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Asynchronously get a checkpoint tuple by config. + + :param config: Configuration specifying which checkpoint to retrieve. + :type config: RunnableConfig + :return: The checkpoint tuple, or None if not found. + :rtype: Optional[CheckpointTuple] + """ + thread_id: str = config["configurable"]["thread_id"] + checkpoint_ns: str = config["configurable"].get("checkpoint_ns", "") + checkpoint_id = get_checkpoint_id(config) + + # Ensure session exists + await self._ensure_session(thread_id) + + # If no checkpoint_id, find the latest + if not checkpoint_id: + checkpoint_id = await self._get_latest_checkpoint_id(thread_id, checkpoint_ns) + if not checkpoint_id: + return None + + # Load the checkpoint item + item_id_str = make_item_id(checkpoint_ns, checkpoint_id, "checkpoint") + item = await self._client.read_item( + CheckpointItemId(session_id=thread_id, item_id=item_id_str) + ) + if not item: + return None + + # Deserialize checkpoint data + checkpoint_data = self.serde.loads_typed(item.data) + checkpoint: Checkpoint = checkpoint_data["checkpoint"] + metadata: CheckpointMetadata = checkpoint_data["metadata"] + + # Load channel values (blobs) + channel_values = await self._load_blobs( + thread_id, checkpoint_ns, checkpoint_id, checkpoint.get("channel_versions", {}) + ) + checkpoint = {**checkpoint, "channel_values": channel_values} + + # Load pending writes + pending_writes = await self._load_pending_writes(thread_id, checkpoint_ns, checkpoint_id) + + # Build parent config if parent exists + parent_config: Optional[RunnableConfig] = None + if item.parent_id: + try: + parent_parsed = parse_item_id(item.parent_id) + parent_config = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": parent_parsed.checkpoint_ns, + "checkpoint_id": parent_parsed.checkpoint_id, + } + } + except ValueError: + pass + + return CheckpointTuple( + config={ + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint_id, + } + }, + checkpoint=checkpoint, + metadata=metadata, + parent_config=parent_config, + pending_writes=pending_writes, + ) + + async def aput( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Asynchronously store a checkpoint. + + :param config: Configuration for the checkpoint. + :type config: RunnableConfig + :param checkpoint: The checkpoint to store. + :type checkpoint: Checkpoint + :param metadata: Additional metadata for the checkpoint. + :type metadata: CheckpointMetadata + :param new_versions: New channel versions as of this write. + :type new_versions: ChannelVersions + :return: Updated configuration with the checkpoint ID. + :rtype: RunnableConfig + """ + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"].get("checkpoint_ns", "") + checkpoint_id = checkpoint["id"] + + # Ensure session exists + await self._ensure_session(thread_id) + + # Determine parent + parent_checkpoint_id = config["configurable"].get("checkpoint_id") + parent_item_id: Optional[str] = None + if parent_checkpoint_id: + parent_item_id = make_item_id(checkpoint_ns, parent_checkpoint_id, "checkpoint") + + # Prepare checkpoint data (without channel_values - stored as blobs) + checkpoint_copy = checkpoint.copy() + channel_values: Dict[str, Any] = checkpoint_copy.pop("channel_values", {}) # type: ignore[misc] + + checkpoint_data = self.serde.dumps_typed({ + "checkpoint": checkpoint_copy, + "metadata": metadata, + }) + + # Create checkpoint item + item_id_str = make_item_id(checkpoint_ns, checkpoint_id, "checkpoint") + items: List[CheckpointItem] = [ + CheckpointItem( + session_id=thread_id, + item_id=item_id_str, + data=checkpoint_data, + parent_id=parent_item_id, + ) + ] + + # Create blob items for channel values with new versions + for channel, version in new_versions.items(): + if channel in channel_values: + blob_data = self.serde.dumps_typed(channel_values[channel]) + else: + blob_data = self.serde.dumps_typed(("empty", b"")) + + blob_item_id = make_item_id( + checkpoint_ns, checkpoint_id, "blob", f"{channel}:{version}" + ) + items.append( + CheckpointItem( + session_id=thread_id, + item_id=blob_item_id, + data=blob_data, + parent_id=item_id_str, + ) + ) + + await self._client.create_items(items) + + logger.debug( + "Saved checkpoint %s to Foundry session %s", + checkpoint_id, + thread_id, + ) + + return { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": checkpoint_ns, + "checkpoint_id": checkpoint_id, + } + } + + async def aput_writes( + self, + config: RunnableConfig, + writes: Sequence[Tuple[str, Any]], + task_id: str, + task_path: str = "", + ) -> None: + """Asynchronously store intermediate writes for a checkpoint. + + :param config: Configuration of the related checkpoint. + :type config: RunnableConfig + :param writes: List of writes to store as (channel, value) pairs. + :type writes: Sequence[Tuple[str, Any]] + :param task_id: Identifier for the task creating the writes. + :type task_id: str + :param task_path: Path of the task creating the writes. + :type task_path: str + """ + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"].get("checkpoint_ns", "") + checkpoint_id = config["configurable"]["checkpoint_id"] + + checkpoint_item_id = make_item_id(checkpoint_ns, checkpoint_id, "checkpoint") + + items: List[CheckpointItem] = [] + for idx, (channel, value) in enumerate(writes): + write_data = self.serde.dumps_typed((task_id, channel, value, task_path)) + write_item_id = make_item_id( + checkpoint_ns, checkpoint_id, "writes", f"{task_id}:{idx}" + ) + items.append( + CheckpointItem( + session_id=thread_id, + item_id=write_item_id, + data=write_data, + parent_id=checkpoint_item_id, + ) + ) + + if items: + await self._client.create_items(items) + logger.debug( + "Saved %d writes for checkpoint %s", + len(items), + checkpoint_id, + ) + + async def alist( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[Dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> AsyncIterator[CheckpointTuple]: + """Asynchronously list checkpoints matching filter criteria. + + :param config: Base configuration for filtering checkpoints. + :type config: Optional[RunnableConfig] + :param filter: Additional filtering criteria for metadata. + :type filter: Optional[Dict[str, Any]] + :param before: List checkpoints created before this configuration. + :type before: Optional[RunnableConfig] + :param limit: Maximum number of checkpoints to return. + :type limit: Optional[int] + :return: Async iterator of matching checkpoint tuples. + :rtype: AsyncIterator[CheckpointTuple] + """ + if not config: + return + + thread_id = config["configurable"]["thread_id"] + checkpoint_ns = config["configurable"].get("checkpoint_ns") + + # Get all items for this session + item_ids = await self._client.list_item_ids(thread_id) + + # Filter to checkpoint items only + checkpoint_items: List[Tuple[ParsedItemId, CheckpointItemId]] = [] + for item_id in item_ids: + try: + parsed = parse_item_id(item_id.item_id) + if parsed.item_type == "checkpoint": + # Filter by namespace if specified + if checkpoint_ns is None or parsed.checkpoint_ns == checkpoint_ns: + checkpoint_items.append((parsed, item_id)) + except ValueError: + continue + + # Sort by checkpoint_id in reverse order (newest first) + checkpoint_items.sort(key=lambda x: x[0].checkpoint_id, reverse=True) + + # Apply before cursor + if before: + before_id = get_checkpoint_id(before) + if before_id: + checkpoint_items = [ + (p, i) for p, i in checkpoint_items if p.checkpoint_id < before_id + ] + + # Apply limit + if limit: + checkpoint_items = checkpoint_items[:limit] + + # Load and yield each checkpoint + for parsed, _ in checkpoint_items: + tuple_config: RunnableConfig = { + "configurable": { + "thread_id": thread_id, + "checkpoint_ns": parsed.checkpoint_ns, + "checkpoint_id": parsed.checkpoint_id, + } + } + checkpoint_tuple = await self.aget_tuple(tuple_config) + if checkpoint_tuple: + # Apply metadata filter if provided + if filter: + if not all( + checkpoint_tuple.metadata.get(k) == v for k, v in filter.items() + ): + continue + yield checkpoint_tuple + + async def adelete_thread(self, thread_id: str) -> None: + """Delete all checkpoints and writes for a thread. + + :param thread_id: The thread ID whose checkpoints should be deleted. + :type thread_id: str + """ + await self._client.delete_session(thread_id) + self._session_cache.discard(thread_id) + logger.debug("Deleted session %s", thread_id) + + # Sync methods (raise NotImplementedError) + + def get_tuple(self, config: RunnableConfig) -> Optional[CheckpointTuple]: + """Sync version not supported - use aget_tuple instead. + + :raises NotImplementedError: Always raised. + """ + raise NotImplementedError( + "FoundryCheckpointSaver requires async usage. Use aget_tuple() instead." + ) + + def list( + self, + config: Optional[RunnableConfig], + *, + filter: Optional[Dict[str, Any]] = None, + before: Optional[RunnableConfig] = None, + limit: Optional[int] = None, + ) -> Iterator[CheckpointTuple]: + """Sync version not supported - use alist instead. + + :raises NotImplementedError: Always raised. + """ + raise NotImplementedError( + "FoundryCheckpointSaver requires async usage. Use alist() instead." + ) + + def put( + self, + config: RunnableConfig, + checkpoint: Checkpoint, + metadata: CheckpointMetadata, + new_versions: ChannelVersions, + ) -> RunnableConfig: + """Sync version not supported - use aput instead. + + :raises NotImplementedError: Always raised. + """ + raise NotImplementedError( + "FoundryCheckpointSaver requires async usage. Use aput() instead." + ) + + def put_writes( + self, + config: RunnableConfig, + writes: Sequence[Tuple[str, Any]], + task_id: str, + task_path: str = "", + ) -> None: + """Sync version not supported - use aput_writes instead. + + :raises NotImplementedError: Always raised. + """ + raise NotImplementedError( + "FoundryCheckpointSaver requires async usage. Use aput_writes() instead." + ) + + def delete_thread(self, thread_id: str) -> None: + """Sync version not supported - use adelete_thread instead. + + :raises NotImplementedError: Always raised. + """ + raise NotImplementedError( + "FoundryCheckpointSaver requires async usage. Use adelete_thread() instead." + ) + + def get_next_version(self, current: Optional[str], channel: None) -> str: + """Generate the next version ID for a channel. + + Uses string versions with format "{counter}.{random}". + + :param current: The current version identifier. + :type current: Optional[str] + :param channel: Deprecated argument, kept for backwards compatibility. + :return: The next version identifier. + :rtype: str + """ + import random as rand + + if current is None: + current_v = 0 + elif isinstance(current, int): + current_v = current + else: + current_v = int(current.split(".")[0]) + next_v = current_v + 1 + next_h = rand.random() + return f"{next_v:032}.{next_h:016}" diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/_item_id.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/_item_id.py new file mode 100644 index 000000000000..8758181ce2f2 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/checkpointer/_item_id.py @@ -0,0 +1,96 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Item ID utilities for composite checkpoint item identifiers.""" + +from dataclasses import dataclass +from typing import Literal + +ItemType = Literal["checkpoint", "writes", "blob"] + + +@dataclass +class ParsedItemId: + """Parsed components of a checkpoint item ID. + + :ivar checkpoint_ns: The checkpoint namespace. + :ivar checkpoint_id: The checkpoint identifier. + :ivar item_type: The type of item (checkpoint, writes, or blob). + :ivar sub_key: Additional key for writes or blobs. + """ + + checkpoint_ns: str + checkpoint_id: str + item_type: ItemType + sub_key: str + + +def _encode(s: str) -> str: + """URL-safe encode a string (escape colons and percent signs). + + :param s: The string to encode. + :type s: str + :return: The encoded string. + :rtype: str + """ + return s.replace("%", "%25").replace(":", "%3A") + + +def _decode(s: str) -> str: + """Decode a URL-safe encoded string. + + :param s: The encoded string. + :type s: str + :return: The decoded string. + :rtype: str + """ + return s.replace("%3A", ":").replace("%25", "%") + + +def make_item_id( + checkpoint_ns: str, + checkpoint_id: str, + item_type: ItemType, + sub_key: str = "", +) -> str: + """Create a composite item ID. + + Format: {checkpoint_ns}:{checkpoint_id}:{type}:{sub_key} + + :param checkpoint_ns: The checkpoint namespace. + :type checkpoint_ns: str + :param checkpoint_id: The checkpoint identifier. + :type checkpoint_id: str + :param item_type: The type of item (checkpoint, writes, or blob). + :type item_type: ItemType + :param sub_key: Additional key for writes or blobs. + :type sub_key: str + :return: The composite item ID. + :rtype: str + """ + return f"{_encode(checkpoint_ns)}:{_encode(checkpoint_id)}:{item_type}:{_encode(sub_key)}" + + +def parse_item_id(item_id: str) -> ParsedItemId: + """Parse a composite item ID back to components. + + :param item_id: The composite item ID to parse. + :type item_id: str + :return: The parsed item ID components. + :rtype: ParsedItemId + :raises ValueError: If the item ID format is invalid. + """ + parts = item_id.split(":", 3) + if len(parts) != 4: + raise ValueError(f"Invalid item_id format: {item_id}") + + item_type = parts[2] + if item_type not in ("checkpoint", "writes", "blob"): + raise ValueError(f"Invalid item_type in item_id: {item_type}") + + return ParsedItemId( + checkpoint_ns=_decode(parts[0]), + checkpoint_id=_decode(parts[1]), + item_type=item_type, # type: ignore[arg-type] + sub_key=_decode(parts[3]), + ) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/langgraph.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/langgraph.py index 0d2b60bac248..95e220ee978b 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/langgraph.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/langgraph.py @@ -1,25 +1,29 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -# pylint: disable=logging-fstring-interpolation,broad-exception-caught +# pylint: disable=logging-fstring-interpolation,broad-exception-caught,no-member # mypy: disable-error-code="assignment,arg-type" import os import re -from typing import Optional +from typing import Optional, TYPE_CHECKING -from langchain_core.runnables import RunnableConfig from langgraph.graph.state import CompiledStateGraph from azure.ai.agentserver.core.constants import Constants from azure.ai.agentserver.core.logger import get_logger from azure.ai.agentserver.core.server.base import FoundryCBAgent -from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext - -from .models import ( - LanggraphMessageStateConverter, - LanggraphStateConverter, -) +from azure.ai.agentserver.core import AgentRunContext +from azure.ai.agentserver.core.tools import OAuthConsentRequiredError # pylint:disable=import-error,no-name-in-module +from ._context import LanggraphRunContext +from ._exceptions import LangGraphMissingConversationIdError +from .models.response_api_converter import GraphInputArguments, ResponseAPIConverter +from .models.response_api_default_converter import ResponseAPIDefaultConverter from .models.utils import is_state_schema_valid +from .tools._context import FoundryToolContext +from .tools._resolver import FoundryLangChainToolResolver + +if TYPE_CHECKING: + from azure.core.credentials_async import AsyncTokenCredential logger = get_logger() @@ -29,33 +33,61 @@ class LangGraphAdapter(FoundryCBAgent): Adapter for LangGraph Agent. """ - def __init__(self, graph: CompiledStateGraph, state_converter: Optional[LanggraphStateConverter] = None): + def __init__( + self, + graph: CompiledStateGraph, + credentials: "Optional[AsyncTokenCredential]" = None, + converter: "Optional[ResponseAPIConverter]" = None, + ) -> None: """ - Initialize the LangGraphAdapter with a CompiledStateGraph. - - :param graph: The LangGraph StateGraph to adapt. - :type graph: CompiledStateGraph - :param state_converter: custom state converter. Required if graph state is not MessagesState. - :type state_converter: Optional[LanggraphStateConverter] + Initialize the LangGraphAdapter with a CompiledStateGraph or a function that returns one. + + :param graph: The LangGraph StateGraph to adapt, or a callable that takes ToolClient + and returns CompiledStateGraph (sync or async). + :type graph: Union[CompiledStateGraph, GraphFactory] + :param credentials: Azure credentials for authentication. + :type credentials: Optional[AsyncTokenCredential] + :param converter: custom response converter. + :type converter: Optional[ResponseAPIConverter] """ - super().__init__() - self.graph = graph + super().__init__(credentials=credentials) # pylint: disable=unexpected-keyword-arg + self._graph = graph + self._tool_resolver = FoundryLangChainToolResolver() self.azure_ai_tracer = None - if not state_converter: - if is_state_schema_valid(self.graph.builder.state_schema): - self.state_converter = LanggraphMessageStateConverter() + + if not converter: + if is_state_schema_valid(self._graph.builder.state_schema): + self.converter = ResponseAPIDefaultConverter(graph=self._graph) else: - raise ValueError("state_converter is required for non-MessagesState graph.") + raise ValueError("converter is required for non-MessagesState graph.") else: - self.state_converter = state_converter + self.converter = converter async def agent_run(self, context: AgentRunContext): - input_data = self.state_converter.request_to_state(context) - logger.debug(f"Converted input data: {input_data}") - if not context.stream: - response = await self.agent_run_non_stream(input_data, context) - return response - return self.agent_run_astream(input_data, context) + # Resolve graph - always resolve if it's a factory function to get fresh graph each time + # For factories, get a new graph instance per request to avoid concurrency issues + try: + self._check_missing_thread_id(context) + lg_run_context = await self.setup_lg_run_context(context) + input_arguments = await self.converter.convert_request(lg_run_context) + self.ensure_runnable_config(input_arguments, lg_run_context) + + if not context.stream: + response = await self.agent_run_non_stream(input_arguments) + return response + + return self.agent_run_astream(input_arguments) + except OAuthConsentRequiredError as e: + if not context.stream: + response = await self.respond_with_oauth_consent(context, e) + return response + return self.respond_with_oauth_consent_astream(context, e) + + async def setup_lg_run_context(self, agent_run_context: AgentRunContext) -> LanggraphRunContext: + resolved = await self._tool_resolver.resolve_from_registry() + return LanggraphRunContext( + agent_run_context, + FoundryToolContext(resolved)) def init_tracing_internal(self, exporter_endpoint=None, app_insights_conn_str=None): # set env vars for langsmith @@ -85,69 +117,73 @@ def get_trace_attributes(self): attrs["service.namespace"] = "azure.ai.agentserver.langgraph" return attrs - async def agent_run_non_stream(self, input_data: dict, context: AgentRunContext): + def _check_missing_thread_id(self, context: AgentRunContext) -> None: + checkpointer = getattr(self._graph, "checkpointer", None) + if checkpointer and checkpointer is not False and not context.conversation_id: + raise LangGraphMissingConversationIdError() + + async def agent_run_non_stream(self, input_arguments: GraphInputArguments): """ Run the agent with non-streaming response. - :param input_data: The input data to run the agent with. - :type input_data: dict - :param context: The context for the agent run. - :type context: AgentRunContext + :param input_arguments: The input data to run the agent with. + :type input_arguments: GraphInputArguments :return: The response of the agent run. :rtype: dict """ - try: - config = self.create_runnable_config(context) - stream_mode = self.state_converter.get_stream_mode(context) - result = await self.graph.ainvoke(input_data, config=config, stream_mode=stream_mode) - output = self.state_converter.state_to_response(result, context) + result = await self._graph.ainvoke(**input_arguments) + output = await self.converter.convert_response_non_stream(result, input_arguments["context"]) return output except Exception as e: - logger.error(f"Error during agent run: {e}") + logger.error(f"Error during agent run: {e}", exc_info=True) raise e - async def agent_run_astream(self, input_data: dict, context: AgentRunContext): + async def agent_run_astream(self, + input_arguments: GraphInputArguments): """ Run the agent with streaming response. - :param input_data: The input data to run the agent with. - :type input_data: dict - :param context: The context for the agent run. - :type context: AgentRunContext + :param input_arguments: The input data to run the agent with. + :type input_arguments: GraphInputArguments :return: An async generator yielding the response stream events. :rtype: AsyncGenerator[dict] """ try: - logger.info(f"Starting streaming agent run {context.response_id}") - config = self.create_runnable_config(context) - stream_mode = self.state_converter.get_stream_mode(context) - stream = self.graph.astream(input=input_data, config=config, stream_mode=stream_mode) - async for result in self.state_converter.state_to_response_stream(stream, context): - yield result + logger.info(f"Starting streaming agent run {input_arguments['context'].agent_run.response_id}") + stream = self._graph.astream(**input_arguments) + async for output_event in self.converter.convert_response_stream( + stream, + input_arguments["context"]): + yield output_event except Exception as e: - logger.error(f"Error during streaming agent run: {e}") + logger.error(f"Error during streaming agent run: {e}", exc_info=True) raise e - def create_runnable_config(self, context: AgentRunContext) -> RunnableConfig: + def ensure_runnable_config(self, input_arguments: GraphInputArguments, context: LanggraphRunContext): """ - Create a RunnableConfig from the converted request data. - - :param context: The context for the agent run. - :type context: AgentRunContext + Ensure the RunnableConfig is set in the input arguments. - :return: The RunnableConfig for the agent run. - :rtype: RunnableConfig + :param input_arguments: The input arguments for the agent run. + :type input_arguments: GraphInputArguments + :param context: The Langgraph run context. + :type context: LanggraphRunContext """ - config = RunnableConfig( - configurable={ - "thread_id": context.conversation_id, - }, - callbacks=[self.azure_ai_tracer] if self.azure_ai_tracer else None, - ) - return config + config = input_arguments.get("config", {}) + configurable = config.get("configurable", {}) + thread_id = input_arguments["context"].agent_run.conversation_id + if thread_id: + configurable["thread_id"] = thread_id + config["configurable"] = configurable + context.attach_to_config(config) + + callbacks = config.get("callbacks", []) # mypy: ignore-errors + if self.azure_ai_tracer and self.azure_ai_tracer not in callbacks: + callbacks.append(self.azure_ai_tracer) + config["callbacks"] = callbacks + input_arguments["config"] = config def format_otlp_endpoint(self, endpoint: str) -> str: m = re.match(r"^(https?://[^/]+)", endpoint) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/__init__.py index eb6285a6279b..d540fd20468c 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/__init__.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/__init__.py @@ -1,15 +1,3 @@ # --------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- -from .langgraph_request_converter import LangGraphRequestConverter -from .langgraph_response_converter import LangGraphResponseConverter -from .langgraph_state_converter import LanggraphMessageStateConverter, LanggraphStateConverter -from .langgraph_stream_response_converter import LangGraphStreamResponseConverter - -__all__ = [ - "LangGraphRequestConverter", - "LangGraphResponseConverter", - "LangGraphStreamResponseConverter", - "LanggraphStateConverter", - "LanggraphMessageStateConverter", -] diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/human_in_the_loop_helper.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/human_in_the_loop_helper.py new file mode 100644 index 000000000000..9f3c693800a1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/human_in_the_loop_helper.py @@ -0,0 +1,144 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +from typing import Optional, Union + +from langgraph.types import ( + Command, + Interrupt, + StateSnapshot, +) + +from azure.ai.agentserver.core.logger import get_logger +from azure.ai.agentserver.core.models import projects as project_models +from azure.ai.agentserver.core.models.openai import (ResponseInputItemParam, ResponseInputParam) +from .._context import LanggraphRunContext + +INTERRUPT_NODE_NAME = "__interrupt__" +logger = get_logger() + + +class HumanInTheLoopHelper: + """Helper class for managing human-in-the-loop interactions in LangGraph.""" + def __init__(self, context: LanggraphRunContext): + self.context = context + + def has_interrupt(self, state: Optional[StateSnapshot]) -> bool: + """Check if the LangGraph state contains an interrupt node. + + :param state: The LangGraph state snapshot. + :type state: Optional[StateSnapshot] + :return: True if the state contains an interrupt, False otherwise. + :rtype: bool + """ + if not state or not isinstance(state, StateSnapshot): + return False + return state.interrupts is not None and len(state.interrupts) > 0 + + def convert_interrupts(self, interrupts: tuple) -> list[project_models.ItemResource]: + """Convert LangGraph interrupts to ItemResource objects. + + :param interrupts: A tuple of interrupt objects. + :type interrupts: tuple + :return: A list of ItemResource objects. + :rtype: list[project_models.ItemResource] + """ + if not interrupts or not isinstance(interrupts, tuple): + return [] + result = [] + # should be only one interrupt for now + for interrupt_info in interrupts: + if not isinstance(interrupt_info, Interrupt): + # skip invalid interrupt + continue + item = self.convert_interrupt(interrupt_info) + if item: + result.append(item) + return result + + def convert_interrupt(self, interrupt_info: Interrupt) -> Optional[project_models.ItemResource]: + """Convert a single LangGraph Interrupt to an ItemResource object. + + :param interrupt_info: The interrupt information from LangGraph. + :type interrupt_info: Interrupt + + :return: The corresponding ItemResource object. + :rtype: project_models.ItemResource + """ + raise NotImplementedError("Subclasses must implement convert_interrupt method.") + + def validate_and_convert_human_feedback( + self, state: Optional[StateSnapshot], input_data: Union[str, ResponseInputParam] + ) -> Optional[Command]: + """Validate if the human feedback input corresponds to the interrupt in state. + If valid, convert the input to a LangGraph Command. + + :param state: The current LangGraph state snapshot. + :type state: Optional[StateSnapshot] + :param input_data: The human feedback input from the request. + :type input_data: Union[str, ResponseInputParam] + + :return: Command if valid feedback is provided, else None. + :rtype: Union[Command, None] + """ + # Validate interrupt exists in state + if not self.has_interrupt(state): + logger.info("No interrupt found in state.") + return None + + interrupt_obj = state.interrupts[0] # type: ignore[union-attr] # Assume single interrupt for simplicity + if not interrupt_obj or not isinstance(interrupt_obj, Interrupt): + logger.warning("No interrupt object found in state") + return None + + logger.info("Retrieved interrupt from state, validating and converting human feedback.") + + # Validate input format and extract item + item = self._validate_input_format(input_data, interrupt_obj) + if item is None: + return None + + return self.convert_input_item_to_command(item) + + def _validate_input_format( + self, input_data: Union[str, ResponseInputParam], interrupt_obj: Interrupt + ) -> Optional[ResponseInputItemParam]: + if isinstance(input_data, str): + logger.warning("Expecting function call output item, got string: %s", input_data) + return None + + if not isinstance(input_data, list): + logger.error("Unsupported interrupt input type: %s, %s", type(input_data), input_data) + return None + + if len(input_data) != 1: + logger.warning("Expected exactly one interrupt input item, got %d items.", len(input_data)) + return None + + item = input_data[0] + item_type = item.get("type", None) + if item_type != project_models.ItemType.FUNCTION_CALL_OUTPUT: + logger.warning( + "Invalid interrupt input item type: %s, expected FUNCTION_CALL_OUTPUT.", item_type + ) + return None + + if item.get("call_id") != interrupt_obj.id: + logger.warning( + "Interrupt input call_id %s does not match interrupt id %s.", + item.get("call_id"), interrupt_obj.id + ) + return None + + return item + + def convert_input_item_to_command(self, input_item: ResponseInputItemParam) -> Union[Command, None]: + """Convert ItemParams to a LangGraph Command for interrupt handling. + + :param input_item: The item parameters containing interrupt information. + :type input_item: ResponseInputItemParam + :return: The LangGraph Command. + :rtype: Union[Command, None] + """ + raise NotImplementedError("Subclasses must implement convert_request_to_command method.") diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/human_in_the_loop_json_helper.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/human_in_the_loop_json_helper.py new file mode 100644 index 000000000000..e1396ba90577 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/human_in_the_loop_json_helper.py @@ -0,0 +1,86 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +import json +from typing import Optional, Union + +from langgraph.types import ( + Command, + Interrupt, +) + +from azure.ai.agentserver.core.logger import get_logger +from azure.ai.agentserver.core.models import projects as project_models +from azure.ai.agentserver.core.models.openai import ( + ResponseInputItemParam, +) +from azure.ai.agentserver.core.server.common.constants import HUMAN_IN_THE_LOOP_FUNCTION_NAME + +from .human_in_the_loop_helper import HumanInTheLoopHelper + +logger = get_logger() + + +class HumanInTheLoopJsonHelper(HumanInTheLoopHelper): + """ + Helper class for managing human-in-the-loop interactions in LangGraph. + Interrupts are converted to FunctionToolCallItemResource objects. + Human feedback will be sent back as FunctionCallOutputItemParam. + All values are serialized as JSON strings. + """ + + def convert_interrupt(self, interrupt_info: Interrupt) -> Optional[project_models.ItemResource]: + if not isinstance(interrupt_info, Interrupt): + logger.warning("Interrupt is not of type Interrupt: %s", interrupt_info) + return None + name, call_id, arguments = self.interrupt_to_function_call(interrupt_info) + return project_models.FunctionToolCallItemResource( + call_id=call_id, + name=name, + arguments=arguments, + id=self.context.agent_run.id_generator.generate_function_call_id(), + status="in_progress", + ) + + def interrupt_to_function_call(self, interrupt: Interrupt) : + """ + Convert an Interrupt to a function call tuple. + + :param interrupt: The Interrupt object to convert. + :type interrupt: Interrupt + + :return: A tuple of (name, call_id, argument). + :rtype: tuple[str | None, str | None, str | None] + """ + if isinstance(interrupt.value, str): + arguments = interrupt.value + else: + try: + arguments = json.dumps(interrupt.value) + except Exception as e: # pragma: no cover - fallback # pylint: disable=broad-exception-caught + logger.error("Failed to serialize interrupt value to JSON: %s, error: %s", interrupt.value, e) + arguments = str(interrupt.value) + return HUMAN_IN_THE_LOOP_FUNCTION_NAME, interrupt.id, arguments + + def convert_input_item_to_command(self, input_item: ResponseInputItemParam) -> Union[Command, None]: + output_str = input_item.get("output") + if not isinstance(output_str, str): + logger.error("Invalid output type in function call output: %s", input_item) + return None + try: + output = json.loads(output_str) + except json.JSONDecodeError: + logger.error("Invalid JSON in function call output: %s", input_item) + return None + resume = output.get("resume", None) + update = output.get("update", None) + goto = output.get("goto", None) + if resume is None and update is None and goto is None: + logger.warning("No valid Command fields found in function call output: %s", input_item) + return None + return Command( + resume=resume, + update=update, + goto=goto, + ) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/langgraph_state_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/langgraph_state_converter.py deleted file mode 100644 index a1bc2181f919..000000000000 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/langgraph_state_converter.py +++ /dev/null @@ -1,143 +0,0 @@ -# --------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# --------------------------------------------------------- -# mypy: disable-error-code="call-overload,override" -"""Base interface for converting between LangGraph internal state and OpenAI-style responses. - -A LanggraphStateConverter implementation bridges: - 1. Incoming CreateResponse (wrapped in AgentRunContext) -> initial graph state - 2. Internal graph state -> final non-streaming Response - 3. Streaming graph state events -> ResponseStreamEvent sequence - 4. Declares which stream mode (if any) is supported for a given run context - -Concrete implementations should: - * Decide and document the shape of the state dict they return in request_to_state - * Handle aggregation, error mapping, and metadata propagation in state_to_response - * Incrementally translate async stream_state items in state_to_response_stream - -Do NOT perform network I/O directly inside these methods (other than awaiting the -provided async iterator). Keep them pure transformation layers so they are testable. -""" - -from __future__ import annotations - -import time -from abc import ABC, abstractmethod -from typing import Any, AsyncGenerator, AsyncIterator, Dict - -from azure.ai.agentserver.core.models import Response, ResponseStreamEvent -from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext - -from .langgraph_request_converter import LangGraphRequestConverter -from .langgraph_response_converter import LangGraphResponseConverter -from .langgraph_stream_response_converter import LangGraphStreamResponseConverter - - -class LanggraphStateConverter(ABC): - """ - Abstract base class for LangGraph state <-> response conversion. - - :meta private: - """ - - @abstractmethod - def get_stream_mode(self, context: AgentRunContext) -> str: - """Return a string indicating streaming mode for this run. - - Examples: "values", "updates", "messages", "custom", "debug". - Implementations may inspect context.request.stream or other flags. - Must be fast and side-effect free. - - :param context: The context for the agent run. - :type context: AgentRunContext - - :return: The streaming mode as a string. - :rtype: str - """ - - @abstractmethod - def request_to_state(self, context: AgentRunContext) -> Dict[str, Any]: - """Convert the incoming request (via context) to an initial LangGraph state. - - Return a serializable dict that downstream graph execution expects. - Should not mutate the context. Raise ValueError on invalid input. - - :param context: The context for the agent run. - :type context: AgentRunContext - - :return: The initial LangGraph state as a dictionary. - :rtype: Dict[str, Any] - """ - - @abstractmethod - def state_to_response(self, state: Any, context: AgentRunContext) -> Response: - """Convert a completed LangGraph state into a final non-streaming Response object. - - Implementations must construct and return an models.Response. - The returned object should include output items, usage (if available), - and reference the agent / conversation from context. - - :param state: The completed LangGraph state. - :type state: Any - :param context: The context for the agent run. - :type context: AgentRunContext - - :return: The final non-streaming Response object. - :rtype: Response - """ - - @abstractmethod - async def state_to_response_stream( - self, stream_state: AsyncIterator[Dict[str, Any] | Any], context: AgentRunContext - ) -> AsyncGenerator[ResponseStreamEvent, None]: - """Convert an async iterator of partial state updates into stream events. - - Yield ResponseStreamEvent objects in the correct order. Implementations - are responsible for emitting lifecycle events (created, in_progress, deltas, - completed, errors) consistent with the OpenAI Responses streaming contract. - - :param stream_state: An async iterator of partial LangGraph state updates. - :type stream_state: AsyncIterator[Dict[str, Any] | Any] - :param context: The context for the agent run. - :type context: AgentRunContext - - :return: An async generator yielding ResponseStreamEvent objects. - :rtype: AsyncGenerator[ResponseStreamEvent, None] - """ - - -class LanggraphMessageStateConverter(LanggraphStateConverter): - """Converter implementation for langgraph built-in MessageState.""" - - def get_stream_mode(self, context: AgentRunContext) -> str: - if context.request.get("stream"): - return "messages" - return "updates" - - def request_to_state(self, context: AgentRunContext) -> Dict[str, Any]: - converter = LangGraphRequestConverter(context.request) - return converter.convert() - - def state_to_response(self, state: Any, context: AgentRunContext) -> Response: - converter = LangGraphResponseConverter(context, state) - output = converter.convert() - - agent_id = context.get_agent_id_object() - conversation = context.get_conversation_object() - response = Response( - object="response", - id=context.response_id, - agent=agent_id, - conversation=conversation, - metadata=context.request.get("metadata"), - created_at=int(time.time()), - output=output, - ) - return response - - async def state_to_response_stream( - self, stream_state: AsyncIterator[Dict[str, Any] | Any], context: AgentRunContext - ) -> AsyncGenerator[ResponseStreamEvent, None]: - response_converter = LangGraphStreamResponseConverter(stream_state, context) - async for result in response_converter.convert(): - yield result diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/langgraph_stream_response_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/langgraph_stream_response_converter.py deleted file mode 100644 index cba1db014ed8..000000000000 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/langgraph_stream_response_converter.py +++ /dev/null @@ -1,74 +0,0 @@ -# --------------------------------------------------------- -# Copyright (c) Microsoft Corporation. All rights reserved. -# --------------------------------------------------------- -# pylint: disable=logging-fstring-interpolation -# mypy: disable-error-code="assignment,valid-type" -from typing import List - -from langchain_core.messages import AnyMessage - -from azure.ai.agentserver.core.logger import get_logger -from azure.ai.agentserver.core.models import ResponseStreamEvent -from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext - -from .response_event_generators import ( - ResponseEventGenerator, - ResponseStreamEventGenerator, - StreamEventState, -) - -logger = get_logger() - - -class LangGraphStreamResponseConverter: - def __init__(self, stream, context: AgentRunContext): - self.stream = stream - self.context = context - - self.stream_state = StreamEventState() - self.current_generator: ResponseEventGenerator = None - - async def convert(self): - async for message, _ in self.stream: - try: - if self.current_generator is None: - self.current_generator = ResponseStreamEventGenerator(logger, None) - - converted = self.try_process_message(message, self.context) - for event in converted: - yield event # yield each event separately - except Exception as e: - logger.error(f"Error converting message {message}: {e}") - raise ValueError(f"Error converting message {message}") from e - - logger.info("Stream ended, finalizing response.") - # finalize the stream - converted = self.try_process_message(None, self.context) - for event in converted: - yield event # yield each event separately - - def try_process_message(self, event: AnyMessage, context: AgentRunContext) -> List[ResponseStreamEvent]: - if event and not self.current_generator: - self.current_generator = ResponseStreamEventGenerator(logger, None) - - is_processed = False - next_processor = self.current_generator - returned_events = [] - while not is_processed: - is_processed, next_processor, processed_events = self.current_generator.try_process_message( - event, context, self.stream_state - ) - returned_events.extend(processed_events) - if not is_processed and next_processor == self.current_generator: - logger.warning( - f"Message can not be processed by current generator {type(self.current_generator).__name__}:" - + f" {type(event)}: {event}" - ) - break - if next_processor != self.current_generator: - logger.info( - f"Switching processor from {type(self.current_generator).__name__} " - + f"to {type(next_processor).__name__}" - ) - self.current_generator = next_processor - return returned_events diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_converter.py new file mode 100644 index 000000000000..32cbf93a4bfb --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_converter.py @@ -0,0 +1,100 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# mypy: disable-error-code="call-overload,override" +"""Base interface for converting between LangGraph internal state and OpenAI-style responses. + +A ResponseAPIConverter implementation bridges: + 1. Incoming CreateResponse (wrapped in AgentRunContext) -> GraphInputArguments to invoke graph + 2. Graph output -> final non-streaming Response + 3. Streaming graph output events -> ResponseStreamEvent sequence + +Concrete implementations should: + * Decide and document the shape of input arguments they return in convert_request + * Handle aggregation, error mapping, and metadata propagation in convert_response_non_stream + * Incrementally translate async stream_state items in convert_response_stream + +Do NOT perform network I/O directly inside these methods (other than awaiting the +provided async iterator). Keep them pure transformation layers so they are testable. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, AsyncIterable, AsyncIterator, Dict, TypedDict, Union + +from langchain_core.runnables import RunnableConfig +from langgraph.types import Command, StreamMode + +from azure.ai.agentserver.core.models import Response, ResponseStreamEvent +from .._context import LanggraphRunContext + + +class GraphInputArguments(TypedDict): + """TypedDict for LangGraph input arguments.""" + input: Union[Dict[str, Any], Command, None] + config: RunnableConfig + context: LanggraphRunContext + stream_mode: StreamMode + + +class ResponseAPIConverter(ABC): + """ + Abstract base class for LangGraph input/output <-> response conversion. + + Orchestrates the conversions + + :meta private: + """ + @abstractmethod + async def convert_request(self, context: LanggraphRunContext) -> GraphInputArguments: + """Convert the incoming request to a serializable dict for LangGraph. + + This is a convenience wrapper around request_to_state that only returns + dict states, raising ValueError if a Command is returned instead. + + :param context: The context for the agent run. + :type context: LanggraphRunContext + + :return: The initial LangGraph arguments + :rtype: GraphInputArguments + """ + + @abstractmethod + async def convert_response_non_stream( + self, output: Union[dict[str, Any], Any], context: LanggraphRunContext + ) -> Response: + """Convert the completed LangGraph state into a final non-streaming Response object. + + This is a convenience wrapper around state_to_response that retrieves + the current state snapshot asynchronously. + + :param output: The LangGraph output to convert. + :type output: Union[dict[str, Any], Any], + :param context: The context for the agent run. + :type context: LanggraphRunContext + + :return: The final non-streaming Response object. + :rtype: Response + """ + + @abstractmethod + async def convert_response_stream( + self, + output: AsyncIterator[Union[Dict[str, Any], Any]], + context: LanggraphRunContext, + ) -> AsyncIterable[ResponseStreamEvent]: + """Convert an async iterator of LangGraph stream events into stream events. + + This is a convenience wrapper around state_to_response_stream that retrieves + the current stream of state updates asynchronously. + + :param output: An async iterator yielding LangGraph stream events + :type output: AsyncIterator[Dict[str, Any] | Any] + :param context: The context for the agent run. + :type context: LanggraphRunContext + + :return: An async iterable yielding ResponseStreamEvent objects. + :rtype: AsyncIterable[ResponseStreamEvent] + """ + raise NotImplementedError diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_default_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_default_converter.py new file mode 100644 index 000000000000..912e3f9632c5 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_default_converter.py @@ -0,0 +1,159 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from __future__ import annotations + +import time +from collections.abc import Callable +from typing import Any, AsyncIterable, AsyncIterator, Dict, Optional, Union + +from langchain_core.runnables import RunnableConfig +from langgraph.graph.state import CompiledStateGraph +from langgraph.types import Command, StateSnapshot, StreamMode + +from azure.ai.agentserver.core.models import Response, ResponseStreamEvent +from .human_in_the_loop_helper import HumanInTheLoopHelper +from .human_in_the_loop_json_helper import HumanInTheLoopJsonHelper +from .response_api_converter import GraphInputArguments, ResponseAPIConverter +from .response_api_non_stream_response_converter import (ResponseAPIMessagesNonStreamResponseConverter, + ResponseAPINonStreamResponseConverter) +from .response_api_request_converter import ResponseAPIMessageRequestConverter, ResponseAPIRequestConverter +from .response_api_stream_response_converter import ResponseAPIMessagesStreamResponseConverter +from .._context import LanggraphRunContext + + +class ResponseAPIDefaultConverter(ResponseAPIConverter): + """ + Default implementation of ResponseAPIConverter for LangGraph input/output <-> Response API. + Orchestrates the conversions using default converters and HITL helper. + """ + def __init__(self, + graph: CompiledStateGraph, + create_request_converter: Callable[[LanggraphRunContext], ResponseAPIRequestConverter] | None = None, + create_stream_response_converter: Callable[ + [LanggraphRunContext], + ResponseAPIMessagesStreamResponseConverter + ] | None = None, + create_non_stream_response_converter: Callable[ + [LanggraphRunContext], + ResponseAPINonStreamResponseConverter + ] | None = None, + create_human_in_the_loop_helper: Callable[[LanggraphRunContext], HumanInTheLoopHelper] | None = None): + self._graph = graph + self._custom_request_converter_factory = create_request_converter + self._custom_stream_response_converter_factory = create_stream_response_converter + self._custom_non_stream_response_converter_factory = create_non_stream_response_converter + self._custom_human_in_the_loop_helper_factory = create_human_in_the_loop_helper + + async def convert_request(self, context: LanggraphRunContext) -> GraphInputArguments: + prev_state = await self._aget_state(context) + input_data = self._convert_request_input(context, prev_state) + stream_mode = self.get_stream_mode(context) + return GraphInputArguments( + input=input_data, + stream_mode=stream_mode, + config={}, + context=context, + ) + + async def convert_response_non_stream( + self, output: Union[dict[str, Any], Any], context: LanggraphRunContext) -> Response: + agent_run_context = context.agent_run + converter = self._create_non_stream_response_converter(context) + converted_output = converter.convert(output) + + agent_id = agent_run_context.get_agent_id_object() + conversation = agent_run_context.get_conversation_object() + response = Response( # type: ignore[call-overload] + object="response", + id=agent_run_context.response_id, + agent=agent_id, + conversation=conversation, + metadata=agent_run_context.request.get("metadata"), + created_at=int(time.time()), + output=converted_output, + ) + return response + + async def convert_response_stream( # type: ignore[override] + self, + output: AsyncIterator[Union[Dict[str, Any], Any]], + context: LanggraphRunContext, + ) -> AsyncIterable[ResponseStreamEvent]: + converter = self._create_stream_response_converter(context) + async for event in output: + converted_output = converter.convert(event) + for e in converted_output: + yield e + + state = await self._aget_state(context) + finalized_output = converter.finalize(state) # finalize the response with graph state after stream + for event in finalized_output: + yield event + + def get_stream_mode(self, context: LanggraphRunContext) -> StreamMode: + if context.agent_run.stream: + return "messages" + return "updates" + + def _create_request_converter(self, context: LanggraphRunContext) -> ResponseAPIRequestConverter: + if self._custom_request_converter_factory: + return self._custom_request_converter_factory(context) + data = context.agent_run.request + return ResponseAPIMessageRequestConverter(data) + + def _create_stream_response_converter( + self, context: LanggraphRunContext + ) -> ResponseAPIMessagesStreamResponseConverter: + if self._custom_stream_response_converter_factory: + return self._custom_stream_response_converter_factory(context) + hitl_helper = self._create_human_in_the_loop_helper(context) + return ResponseAPIMessagesStreamResponseConverter(context, hitl_helper=hitl_helper) + + def _create_non_stream_response_converter( + self, context: LanggraphRunContext + ) -> ResponseAPINonStreamResponseConverter: + if self._custom_non_stream_response_converter_factory: + return self._custom_non_stream_response_converter_factory(context) + hitl_helper = self._create_human_in_the_loop_helper(context) + return ResponseAPIMessagesNonStreamResponseConverter(context, hitl_helper) + + def _create_human_in_the_loop_helper(self, context: LanggraphRunContext) -> HumanInTheLoopHelper: + if self._custom_human_in_the_loop_helper_factory: + return self._custom_human_in_the_loop_helper_factory(context) + return HumanInTheLoopJsonHelper(context) + + def _convert_request_input( + self, context: LanggraphRunContext, prev_state: Optional[StateSnapshot] + ) -> Union[Dict[str, Any], Command]: + """ + Convert the CreateResponse input to LangGraph input format, handling HITL if needed. + + :param context: The context for the agent run. + :type context: LanggraphRunContext + :param prev_state: The previous LangGraph state snapshot. + :type prev_state: Optional[StateSnapshot] + + :return: The converted LangGraph input data or Command for HITL. + :rtype: Union[Dict[str, Any], Command] + """ + hitl_helper = self._create_human_in_the_loop_helper(context) + if hitl_helper: + command = hitl_helper.validate_and_convert_human_feedback( + prev_state, context.agent_run.request.get("input") + ) + if command is not None: + return command + converter = self._create_request_converter(context) + return converter.convert() + + async def _aget_state(self, context: LanggraphRunContext) -> Optional[StateSnapshot]: + if not (thread_id := context.agent_run.conversation_id): + return None + config = RunnableConfig( + configurable={"thread_id": thread_id}, + ) + if self._graph.checkpointer: + state = await self._graph.aget_state(config=config) + return state + return None diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/langgraph_response_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_non_stream_response_converter.py similarity index 62% rename from sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/langgraph_response_converter.py rename to sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_non_stream_response_converter.py index 62560279cdc6..7ec8bdf14f1a 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/langgraph_response_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_non_stream_response_converter.py @@ -4,41 +4,83 @@ # pylint: disable=logging-fstring-interpolation,broad-exception-caught,logging-not-lazy # mypy: disable-error-code="valid-type,call-overload,attr-defined" import copy -from typing import List +from abc import ABC, abstractmethod +from typing import Any, Collection, Iterable, List, Union from langchain_core import messages from langchain_core.messages import AnyMessage from azure.ai.agentserver.core.logger import get_logger from azure.ai.agentserver.core.models import projects as project_models -from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext - +from .human_in_the_loop_helper import ( + HumanInTheLoopHelper, + INTERRUPT_NODE_NAME, +) from .utils import extract_function_call +from .._context import LanggraphRunContext logger = get_logger() -class LangGraphResponseConverter: - def __init__(self, context: AgentRunContext, output): +class ResponseAPINonStreamResponseConverter(ABC): + """ + Abstract base class for converting Langgraph output to items in Response format. + One converter instance handles one response. + """ + @abstractmethod + def convert(self, output: dict[str, Any]) -> list[project_models.ItemResource]: + """ + Convert the Langgraph output to a list of ItemResource objects. + + :param output: The Langgraph output to be converted. + :type output: dict[str, Any] + + :return: A list of ItemResource objects representing the converted output. + :rtype: list[project_models.ItemResource] + """ + raise NotImplementedError + + +class ResponseAPIMessagesNonStreamResponseConverter(ResponseAPINonStreamResponseConverter): # pylint: disable=C4751 + """ + Convert Langgraph MessageState output to ItemResource objects. + """ + def __init__(self, + context: LanggraphRunContext, + hitl_helper: HumanInTheLoopHelper): self.context = context - self.output = output + self.hitl_helper = hitl_helper - def convert(self) -> list[project_models.ItemResource]: - res = [] - for step in self.output: + def convert(self, output: Union[dict[str, Any], Any]) -> list[project_models.ItemResource]: + res: list[project_models.ItemResource] = [] + if not isinstance(output, list): + logger.error(f"Expected output to be a list, got {type(output)}: {output}") + raise ValueError(f"Invalid output format. Expected a list, got {type(output)}.") + for step in output: for node_name, node_output in step.items(): - message_arr = node_output.get("messages") - if not message_arr: - logger.warning(f"No messages found in node {node_name} output: {node_output}") - continue - for message in message_arr: - try: - converted = self.convert_output_message(message) - res.append(converted) - except Exception as e: - logger.error(f"Error converting message {message}: {e}") + node_results = self._convert_node_output(node_name, node_output) + res.extend(node_results) return res + def _convert_node_output( + self, node_name: str, node_output: Any + ) -> Iterable[project_models.ItemResource]: + if node_name == INTERRUPT_NODE_NAME: + yield from self.hitl_helper.convert_interrupts(node_output) + else: + message_arr = node_output.get("messages") + if not message_arr or not isinstance(message_arr, Collection): + logger.warning(f"No messages found in node {node_name} output: {node_output}") + return + + for message in message_arr: + try: + converted = self.convert_output_message(message) + if converted: + yield converted + except Exception as e: + logger.error(f"Error converting message {message}: {e}") + def convert_output_message(self, output_message: AnyMessage): # pylint: disable=inconsistent-return-statements # Implement the conversion logic for inner inputs if isinstance(output_message, messages.HumanMessage): @@ -46,7 +88,7 @@ def convert_output_message(self, output_message: AnyMessage): # pylint: disable content=self.convert_MessageContent( output_message.content, role=project_models.ResponsesMessageRole.USER ), - id=self.context.id_generator.generate_message_id(), + id=self.context.agent_run.id_generator.generate_message_id(), status="completed", # temporary status, can be adjusted based on actual logic ) if isinstance(output_message, messages.SystemMessage): @@ -54,7 +96,7 @@ def convert_output_message(self, output_message: AnyMessage): # pylint: disable content=self.convert_MessageContent( output_message.content, role=project_models.ResponsesMessageRole.SYSTEM ), - id=self.context.id_generator.generate_message_id(), + id=self.context.agent_run.id_generator.generate_message_id(), status="completed", ) if isinstance(output_message, messages.AIMessage): @@ -71,22 +113,23 @@ def convert_output_message(self, output_message: AnyMessage): # pylint: disable call_id=call_id, name=name, arguments=argument, - id=self.context.id_generator.generate_function_call_id(), + id=self.context.agent_run.id_generator.generate_function_call_id(), status="completed", ) return project_models.ResponsesAssistantMessageItemResource( content=self.convert_MessageContent( output_message.content, role=project_models.ResponsesMessageRole.ASSISTANT ), - id=self.context.id_generator.generate_message_id(), + id=self.context.agent_run.id_generator.generate_message_id(), status="completed", ) if isinstance(output_message, messages.ToolMessage): return project_models.FunctionToolCallOutputItemResource( call_id=output_message.tool_call_id, output=output_message.content, - id=self.context.id_generator.generate_function_output_id(), + id=self.context.agent_run.id_generator.generate_function_output_id(), ) + logger.warning(f"Unsupported message type: {type(output_message)}, {output_message}") def convert_MessageContent( self, content, role: project_models.ResponsesMessageRole diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/langgraph_request_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_request_converter.py similarity index 92% rename from sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/langgraph_request_converter.py rename to sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_request_converter.py index d29a346b192b..f62d4c2b62f9 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/langgraph_request_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_request_converter.py @@ -3,6 +3,7 @@ # --------------------------------------------------------- # pylint: disable=logging-fstring-interpolation # mypy: ignore-errors +from abc import ABC, abstractmethod import json from typing import Dict, List @@ -37,7 +38,22 @@ } -class LangGraphRequestConverter: +class ResponseAPIRequestConverter(ABC): + """ + Convert CreateResponse to LangGraph request format. + """ + @abstractmethod + def convert(self) -> dict: + """ + Convert the CreateResponse to a LangGraph request format. + + :return: The converted LangGraph request. + :rtype: dict + """ + raise NotImplementedError + + +class ResponseAPIMessageRequestConverter(ResponseAPIRequestConverter): def __init__(self, data: CreateResponse): self.data: CreateResponse = data diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_stream_response_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_stream_response_converter.py new file mode 100644 index 000000000000..02f79c589a96 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_api_stream_response_converter.py @@ -0,0 +1,114 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +# pylint: disable=logging-fstring-interpolation,C4751 +# mypy: disable-error-code="assignment,valid-type" +from abc import ABC, abstractmethod +from typing import Any, List, Union + +from langchain_core.messages import AnyMessage + +from azure.ai.agentserver.core.logger import get_logger +from azure.ai.agentserver.core.models import ResponseStreamEvent +from .human_in_the_loop_helper import HumanInTheLoopHelper +from .response_event_generators import ( + ResponseEventGenerator, + ResponseStreamEventGenerator, + StreamEventState, +) +from .._context import LanggraphRunContext + +logger = get_logger() + + +class ResponseAPIStreamResponseConverter(ABC): + """ + Abstract base class for converting Langgraph streamed output to ResponseStreamEvent objects. + One converter instance handles one response stream. + """ + @abstractmethod + def convert(self, event: Union[AnyMessage, dict, Any, None]): + """ + Convert the Langgraph streamed output to ResponseStreamEvent objects. + + :param event: The event to convert. + :type event: Union[AnyMessage, dict, Any, None] + :return: An asynchronous generator yielding ResponseStreamEvent objects. + :rtype: AsyncGenerator[ResponseStreamEvent, None] + """ + raise NotImplementedError + + @abstractmethod + def finalize(self, graph_state=None): + """ + Finalize the conversion process after the stream ends. + + :param graph_state: The final graph state. + :type graph_state: Any + :return: An asynchronous generator yielding final ResponseStreamEvent objects. + :rtype: AsyncGenerator[ResponseStreamEvent, None] + """ + + +class ResponseAPIMessagesStreamResponseConverter(ResponseAPIStreamResponseConverter): + def __init__(self, context: LanggraphRunContext, *, hitl_helper: HumanInTheLoopHelper): + # self.stream = stream + self.context = context + self.hitl_helper = hitl_helper + + self.stream_state = StreamEventState() + self.current_generator: ResponseEventGenerator = None + + def convert(self, event: Union[AnyMessage, dict, Any, None]): + try: + if self.current_generator is None: + self.current_generator = ResponseStreamEventGenerator(logger, None, hitl_helper=self.hitl_helper) + if event is None or not hasattr(event, '__getitem__'): + raise ValueError(f"Event is not indexable: {event}") + message = event[0] # expect a tuple + converted = self.try_process_message(message, self.context) + return converted + except Exception as e: + logger.error(f"Error converting message {event}: {e}") + raise ValueError(f"Error converting message {event}") from e + + def finalize(self, graph_state=None): + logger.info("Stream ended, finalizing response.") + res = [] + # check and convert interrupts + if self.hitl_helper.has_interrupt(graph_state): + interrupt = graph_state.interrupts[0] # should have only one interrupt + converted = self.try_process_message(interrupt, self.context) + res.extend(converted) + # finalize the stream + converted = self.try_process_message(None, self.context) + res.extend(converted) + return res + + def try_process_message( + self, event: Union[AnyMessage, Any, None], context: LanggraphRunContext + ) -> List[ResponseStreamEvent]: + if event and not self.current_generator: + self.current_generator = ResponseStreamEventGenerator(logger, None, hitl_helper=self.hitl_helper) + + is_processed = False + next_processor = self.current_generator + returned_events = [] + while not is_processed: + is_processed, next_processor, processed_events = self.current_generator.try_process_message( + event, context, self.stream_state + ) + returned_events.extend(processed_events) + if not is_processed and next_processor == self.current_generator: + logger.warning( + f"Message can not be processed by current generator {type(self.current_generator).__name__}:" + + f" {type(event)}: {event}" + ) + break + if next_processor != self.current_generator: + logger.info( + f"Switching processor from {type(self.current_generator).__name__} " + + f"to {type(next_processor).__name__}" + ) + self.current_generator = next_processor + return returned_events diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/item_resource_helpers.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/item_resource_helpers.py index a1c97423d5ae..8502ec13069b 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/item_resource_helpers.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/item_resource_helpers.py @@ -2,13 +2,18 @@ # Copyright (c) Microsoft Corporation. All rights reserved. # --------------------------------------------------------- # mypy: disable-error-code="assignment" +from typing import Optional + +from langgraph.types import Interrupt + from azure.ai.agentserver.core.models import projects as project_models +from ..human_in_the_loop_helper import HumanInTheLoopHelper from ..utils import extract_function_call class ItemResourceHelper: - def __init__(self, item_type: str, item_id: str = None): + def __init__(self, item_type: str, item_id: Optional[str] = None): self.item_type = item_type self.item_id = item_id @@ -58,6 +63,31 @@ def get_aggregated_content(self): return self.create_item_resource(is_done=True) +class FunctionCallInterruptItemResourceHelper(ItemResourceHelper): + def __init__(self, + item_id: Optional[str] = None, + hitl_helper: Optional[HumanInTheLoopHelper] = None, + interrupt: Optional[Interrupt] = None): + super().__init__(project_models.ItemType.FUNCTION_CALL, item_id) + self.hitl_helper = hitl_helper + self.interrupt = interrupt + + def create_item_resource(self, is_done: bool): + if self.hitl_helper is None or self.interrupt is None: + return None + item_resource = self.hitl_helper.convert_interrupt(self.interrupt) + if item_resource is not None and not is_done: + if hasattr(item_resource, 'arguments'): + item_resource.arguments = "" # type: ignore[union-attr] + return item_resource + + def add_aggregate_content(self, item): + pass + + def get_aggregated_content(self): + return self.create_item_resource(is_done=True) + + class FunctionCallOutputItemResourceHelper(ItemResourceHelper): def __init__(self, item_id: str = None, call_id: str = None): super().__init__(project_models.ItemType.FUNCTION_CALL_OUTPUT, item_id) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_content_part_event_generator.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_content_part_event_generator.py index fe141887a2b2..4823de4411ae 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_content_part_event_generator.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_content_part_event_generator.py @@ -63,7 +63,7 @@ def try_process_message( return is_processed, next_processor, events - def on_start( # mypy: ignore[override] + def on_start( # mypy: ignore[override] self, event, run_details, stream_state: StreamEventState ) -> tuple[bool, List[project_models.ResponseStreamEvent]]: if self.started: @@ -81,8 +81,9 @@ def on_start( # mypy: ignore[override] return True, [start_event] - def on_end(self, message, context, stream_state: StreamEventState - ) -> List[project_models.ResponseStreamEvent]: # mypy: ignore[override] + def on_end( + self, message, context, stream_state: StreamEventState + ) -> List[project_models.ResponseStreamEvent]: # mypy: ignore[override] aggregated_content = self.item_content_helper.create_item_content() done_event = project_models.ResponseContentPartDoneEvent( item_id=self.item_id, diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_event_generator.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_event_generator.py index ee19ca74f4bb..cd161b99d152 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_event_generator.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_event_generator.py @@ -8,7 +8,7 @@ from langchain_core.messages import AnyMessage from azure.ai.agentserver.core.models import projects as project_models -from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext +from ..._context import LanggraphRunContext class StreamEventState: @@ -33,18 +33,18 @@ def __init__(self, logger, parent): self.parent = parent # parent generator def try_process_message( - self, - message: AnyMessage, # mypy: ignore[valid-type] - context: AgentRunContext, - stream_state: StreamEventState - ): # mypy: ignore[empty-body] + self, + message: AnyMessage, # mypy: ignore[valid-type] + context: LanggraphRunContext, + stream_state: StreamEventState, + ): # mypy: ignore[empty-body] """ Try to process the incoming message. :param message: The incoming message to process. :type message: AnyMessage :param context: The agent run context. - :type context: AgentRunContext + :type context: LanggraphRunContext :param stream_state: The current stream event state. :type stream_state: StreamEventState @@ -63,8 +63,8 @@ def on_start(self) -> tuple[bool, List[project_models.ResponseStreamEvent]]: return False, [] def on_end( - self, message: AnyMessage, context: AgentRunContext, stream_state: StreamEventState - ) -> tuple[bool, List[project_models.ResponseStreamEvent]]: + self, message: AnyMessage, context: LanggraphRunContext, stream_state: StreamEventState + ) -> tuple[bool, List[project_models.ResponseStreamEvent]]: """ Generate the ending events for this layer. TODO: handle different end conditions, e.g. normal end, error end, etc. @@ -72,7 +72,7 @@ def on_end( :param message: The incoming message to process. :type message: AnyMessage :param context: The agent run context. - :type context: AgentRunContext + :type context: LanggraphRunContext :param stream_state: The current stream event state. :type stream_state: StreamEventState diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_function_call_argument_event_generator.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_function_call_argument_event_generator.py index dbaed3ac9258..56c3bde68632 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_function_call_argument_event_generator.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_function_call_argument_event_generator.py @@ -3,28 +3,39 @@ # --------------------------------------------------------- # pylint: disable=unused-argument,name-too-long # mypy: ignore-errors -from typing import List +from typing import List, Union from langchain_core import messages as langgraph_messages from langchain_core.messages import AnyMessage +from langgraph.types import Interrupt from azure.ai.agentserver.core.models import projects as project_models -from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext - -from ..utils import extract_function_call from . import ResponseEventGenerator, StreamEventState +from ..human_in_the_loop_helper import HumanInTheLoopHelper +from ..utils import extract_function_call +from ..._context import LanggraphRunContext class ResponseFunctionCallArgumentEventGenerator(ResponseEventGenerator): - def __init__(self, logger, parent: ResponseEventGenerator, item_id, message_id, output_index: int): + def __init__( + self, + logger, + parent: ResponseEventGenerator, + item_id, + message_id, + output_index: int, + *, + hitl_helper: HumanInTheLoopHelper = None, + ): super().__init__(logger, parent) self.item_id = item_id self.output_index = output_index self.aggregated_content = "" self.message_id = message_id + self.hitl_helper = hitl_helper def try_process_message( - self, message, context: AgentRunContext, stream_state: StreamEventState + self, message, context: LanggraphRunContext, stream_state: StreamEventState ) -> tuple[bool, ResponseEventGenerator, List[project_models.ResponseStreamEvent]]: is_processed = False events = [] @@ -55,21 +66,31 @@ def on_start( return True, [] def process( - self, message: AnyMessage, run_details, stream_state: StreamEventState + self, message: Union[langgraph_messages.AnyMessage, Interrupt], run_details, stream_state: StreamEventState ) -> tuple[bool, ResponseEventGenerator, List[project_models.ResponseStreamEvent]]: - tool_call = self.get_tool_call_info(message) - if tool_call: - _, _, argument = extract_function_call(tool_call) - if argument: - argument_delta_event = project_models.ResponseFunctionCallArgumentsDeltaEvent( - item_id=self.item_id, - output_index=self.output_index, - delta=argument, - sequence_number=stream_state.sequence_number, - ) - stream_state.sequence_number += 1 - self.aggregated_content += argument - return True, self, [argument_delta_event] + if self.should_end(message): + return False, self, [] + + argument = None + if isinstance(message, Interrupt): + if self.hitl_helper: + _, _, argument = self.hitl_helper.interrupt_to_function_call(message) + else: + argument = None + else: + tool_call = self.get_tool_call_info(message) + if tool_call: + _, _, argument = extract_function_call(tool_call) + if argument: + argument_delta_event = project_models.ResponseFunctionCallArgumentsDeltaEvent( + item_id=self.item_id, + output_index=self.output_index, + delta=argument, + sequence_number=stream_state.sequence_number, + ) + stream_state.sequence_number += 1 + self.aggregated_content += argument + return True, self, [argument_delta_event] return False, self, [] def has_finish_reason(self, message: AnyMessage) -> bool: @@ -94,7 +115,7 @@ def should_end(self, event: AnyMessage) -> bool: return False def on_end( - self, message: AnyMessage, context: AgentRunContext, stream_state: StreamEventState + self, message: AnyMessage, context: LanggraphRunContext, stream_state: StreamEventState ) -> tuple[bool, List[project_models.ResponseStreamEvent]]: done_event = project_models.ResponseFunctionCallArgumentsDoneEvent( item_id=self.item_id, @@ -106,7 +127,7 @@ def on_end( self.parent.aggregate_content(self.aggregated_content) # pass aggregated content to parent return [done_event] - def get_tool_call_info(self, message: langgraph_messages.AnyMessage): + def get_tool_call_info(self, message: Union[langgraph_messages.AnyMessage, Interrupt]): if isinstance(message, langgraph_messages.AIMessageChunk): if message.tool_call_chunks: if len(message.tool_call_chunks) > 1: diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_output_item_event_generator.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_output_item_event_generator.py index a2606d1541c1..14eee3c571b2 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_output_item_event_generator.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_output_item_event_generator.py @@ -3,35 +3,39 @@ # --------------------------------------------------------- # pylint: disable=unused-argument # mypy: ignore-errors -from typing import List +from typing import List, Union from langchain_core import messages as langgraph_messages from langchain_core.messages import AnyMessage +from langgraph.types import Interrupt from azure.ai.agentserver.core.models import projects as project_models -from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext from azure.ai.agentserver.core.server.common.id_generator.id_generator import IdGenerator - from . import ResponseEventGenerator, StreamEventState, item_resource_helpers from .response_content_part_event_generator import ResponseContentPartEventGenerator from .response_function_call_argument_event_generator import ResponseFunctionCallArgumentEventGenerator +from ..human_in_the_loop_helper import HumanInTheLoopHelper +from ..._context import LanggraphRunContext class ResponseOutputItemEventGenerator(ResponseEventGenerator): - def __init__(self, logger, parent: ResponseEventGenerator, output_index: int, message_id: str = None): + def __init__(self, logger, parent: ResponseEventGenerator, + output_index: int, message_id: str = None, + *, hitl_helper: HumanInTheLoopHelper = None): super().__init__(logger, parent) self.output_index = output_index self.message_id = message_id self.item_resource_helper = None + self.hitl_helper = hitl_helper def try_process_message( - self, message: AnyMessage, context: AgentRunContext, stream_state: StreamEventState + self, message: Union[AnyMessage, Interrupt, None], context: LanggraphRunContext, stream_state: StreamEventState ) -> tuple[bool, ResponseEventGenerator, List[project_models.ResponseStreamEvent]]: is_processed = False next_processor = self events = [] if self.item_resource_helper is None: - if not self.try_create_item_resource_helper(message, context.id_generator): + if not self.try_create_item_resource_helper(message, context.agent_run.id_generator): # cannot create item resource, skip this message self.logger.warning(f"Cannot create item resource helper for message: {message}, skipping.") return True, self, [] @@ -65,7 +69,7 @@ def try_process_message( return is_processed, next_processor, events def on_start( - self, event: AnyMessage, context: AgentRunContext, stream_state: StreamEventState + self, event: Union[AnyMessage, Interrupt], context: LanggraphRunContext, stream_state: StreamEventState ) -> tuple[bool, List[project_models.ResponseStreamEvent]]: if self.started: return True, [] @@ -83,7 +87,7 @@ def on_start( self.started = True return True, [item_added_event] - def should_end(self, event: AnyMessage) -> bool: + def should_end(self, event: Union[AnyMessage, Interrupt]) -> bool: if event is None: self.logger.info("Received None event, ending processor.") return True @@ -92,7 +96,7 @@ def should_end(self, event: AnyMessage) -> bool: return False def on_end( - self, message: AnyMessage, context: AgentRunContext, stream_state: StreamEventState + self, message: Union[AnyMessage, Interrupt], context: LanggraphRunContext, stream_state: StreamEventState ) -> tuple[bool, List[project_models.ResponseStreamEvent]]: if not self.started: # should not happen return [] @@ -112,7 +116,7 @@ def aggregate_content(self, content): # aggregate content from child processor self.item_resource_helper.add_aggregate_content(content) - def try_create_item_resource_helper(self, event: AnyMessage, id_generator: IdGenerator): # pylint: disable=too-many-return-statements + def try_create_item_resource_helper(self, event: Union[AnyMessage, Interrupt], id_generator: IdGenerator): # pylint: disable=too-many-return-statements if isinstance(event, langgraph_messages.AIMessageChunk) and event.tool_call_chunks: self.item_resource_helper = item_resource_helpers.FunctionCallItemResourceHelper( item_id=id_generator.generate_function_call_id(), tool_call=event.tool_call_chunks[0] @@ -143,9 +147,16 @@ def try_create_item_resource_helper(self, event: AnyMessage, id_generator: IdGen item_id=id_generator.generate_function_output_id(), call_id=event.tool_call_id ) return True + if isinstance(event, Interrupt): + self.item_resource_helper = item_resource_helpers.FunctionCallInterruptItemResourceHelper( + item_id=id_generator.generate_function_output_id(), + hitl_helper=self.hitl_helper, + interrupt=event, + ) + return True return False - def create_child_processor(self, message: AnyMessage): + def create_child_processor(self, message: Union[AnyMessage, Interrupt]): if self.item_resource_helper is None: return None if self.item_resource_helper.item_type == project_models.ItemType.FUNCTION_CALL: @@ -155,6 +166,7 @@ def create_child_processor(self, message: AnyMessage): item_id=self.item_resource_helper.item_id, message_id=message.id, output_index=self.output_index, + hitl_helper=self.hitl_helper, ) if self.item_resource_helper.item_type == project_models.ItemType.MESSAGE: return ResponseContentPartEventGenerator( diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_output_text_event_generator.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_output_text_event_generator.py index b6be81ec7cb2..8d0e62650a2d 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_output_text_event_generator.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_output_text_event_generator.py @@ -6,12 +6,11 @@ from typing import List from azure.ai.agentserver.core.models import projects as project_models -from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext - from .response_event_generator import ( ResponseEventGenerator, StreamEventState, ) +from ..._context import LanggraphRunContext class ResponseOutputTextEventGenerator(ResponseEventGenerator): @@ -74,7 +73,7 @@ def process( self.aggregated_content += item stream_state.sequence_number += 1 res.append(chunk_event) - return True, self, res # mypy: ignore[return-value] + return True, self, res # mypy: ignore[return-value] return False, self, [] def has_finish_reason(self, message) -> bool: @@ -92,8 +91,8 @@ def should_end(self, message) -> bool: return True return False - def on_end( # mypy: ignore[override] - self, message, context: AgentRunContext, stream_state: StreamEventState + def on_end( # mypy: ignore[override] + self, message, context: LanggraphRunContext, stream_state: StreamEventState ) -> tuple[bool, List[project_models.ResponseStreamEvent]]: if not self.started: return False, [] diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_stream_event_generator.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_stream_event_generator.py index a6ad1cba7396..72737e0774bb 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_stream_event_generator.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/models/response_event_generators/response_stream_event_generator.py @@ -9,13 +9,12 @@ from langchain_core import messages as langgraph_messages from azure.ai.agentserver.core.models import projects as project_models -from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext - from .response_event_generator import ( ResponseEventGenerator, StreamEventState, ) from .response_output_item_event_generator import ResponseOutputItemEventGenerator +from ..._context import LanggraphRunContext class ResponseStreamEventGenerator(ResponseEventGenerator): @@ -24,23 +23,24 @@ class ResponseStreamEventGenerator(ResponseEventGenerator): Response stream event generator. """ - def __init__(self, logger, parent): + def __init__(self, logger, parent, *, hitl_helper=None): super().__init__(logger, parent) + self.hitl_helper = hitl_helper self.aggregated_contents: List[project_models.ItemResource] = [] def on_start( - self, context: AgentRunContext, stream_state: StreamEventState + self, context: LanggraphRunContext, stream_state: StreamEventState ) -> tuple[bool, List[project_models.ResponseStreamEvent]]: if self.started: return True, [] - agent_id = context.get_agent_id_object() - conversation = context.get_conversation_object() + agent_id = context.agent_run.get_agent_id_object() + conversation = context.agent_run.get_conversation_object() # response create event response_dict = { "object": "response", "agent_id": agent_id, "conversation": conversation, - "id": context.response_id, + "id": context.agent_run.response_id, "status": "in_progress", "created_at": int(time.time()), } @@ -55,7 +55,7 @@ def on_start( "object": "response", "agent_id": agent_id, "conversation": conversation, - "id": context.response_id, + "id": context.agent_run.response_id, "status": "in_progress", "created_at": int(time.time()), } @@ -74,7 +74,7 @@ def should_complete(self, event: langgraph_messages.AnyMessage) -> bool: return False def try_process_message( - self, message: langgraph_messages.AnyMessage, context: AgentRunContext, stream_state: StreamEventState + self, message: langgraph_messages.AnyMessage, context: LanggraphRunContext, stream_state: StreamEventState ) -> tuple[bool, ResponseEventGenerator, List[project_models.ResponseStreamEvent]]: is_processed = False next_processor = self @@ -87,7 +87,7 @@ def try_process_message( if message: # create a child processor next_processor = ResponseOutputItemEventGenerator( - self.logger, self, len(self.aggregated_contents), message.id + self.logger, self, len(self.aggregated_contents), message.id, hitl_helper=self.hitl_helper ) return is_processed, next_processor, events @@ -106,14 +106,15 @@ def should_end(self, event: langgraph_messages.AnyMessage) -> bool: return True return False - def on_end(self, message: langgraph_messages.AnyMessage, context: AgentRunContext, stream_state: StreamEventState): - agent_id = context.get_agent_id_object() - conversation = context.get_conversation_object() + def on_end(self, message: langgraph_messages.AnyMessage, context: LanggraphRunContext, + stream_state: StreamEventState): + agent_id = context.agent_run.get_agent_id_object() + conversation = context.agent_run.get_conversation_object() response_dict = { "object": "response", "agent_id": agent_id, "conversation": conversation, - "id": context.response_id, + "id": context.agent_run.response_id, "status": "completed", "created_at": int(time.time()), "output": self.aggregated_contents, diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/__init__.py new file mode 100644 index 000000000000..bd74de0f4e38 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/__init__.py @@ -0,0 +1,18 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) + +from ._builder import use_foundry_tools +from ._chat_model import FoundryToolLateBindingChatModel +from ._middleware import FoundryToolBindingMiddleware +from ._tool_node import FoundryToolCallWrapper, FoundryToolNodeWrappers + +__all__ = [ + "use_foundry_tools", + "FoundryToolLateBindingChatModel", + "FoundryToolBindingMiddleware", + "FoundryToolCallWrapper", + "FoundryToolNodeWrappers", +] diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_builder.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_builder.py new file mode 100644 index 000000000000..0ea9a2da80f2 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_builder.py @@ -0,0 +1,63 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import List, Optional, Union, overload + +from langchain_core.language_models import BaseChatModel + +from azure.ai.agentserver.core.tools import FoundryToolLike, ensure_foundry_tool +from ._chat_model import FoundryToolLateBindingChatModel +from ._middleware import FoundryToolBindingMiddleware +from ._resolver import get_registry + + +@overload +def use_foundry_tools(tools: List[FoundryToolLike], /) -> FoundryToolBindingMiddleware: # pylint: disable=C4743 + """Use foundry tools as middleware. + + :param tools: A list of foundry tools to bind. + :type tools: List[FoundryToolLike] + :return: A FoundryToolBindingMiddleware that binds the given tools. + :rtype: FoundryToolBindingMiddleware + """ + ... + + +@overload +def use_foundry_tools(model: BaseChatModel, tools: List[FoundryToolLike], /) -> FoundryToolLateBindingChatModel: # pylint: disable=C4743 + """Use foundry tools with a chat model. + + :param model: The chat model to bind the tools to. + :type model: BaseChatModel + :param tools: A list of foundry tools to bind. + :type tools: List[FoundryToolLike] + :return: A FoundryToolLateBindingChatModel that binds the given tools to the model. + :rtype: FoundryToolLateBindingChatModel + """ + ... + + +def use_foundry_tools( # pylint: disable=C4743 + model_or_tools: Union[BaseChatModel, List[FoundryToolLike]], + tools: Optional[List[FoundryToolLike]] = None, + /, +) -> Union[FoundryToolBindingMiddleware, FoundryToolLateBindingChatModel]: + """Use foundry tools with a chat model or as middleware. + + :param model_or_tools: The chat model to bind the tools to, or a list of foundry tools to bind as middleware. + :type model_or_tools: Union[BaseChatModel, List[FoundryToolLike]] + :param tools: A list of foundry tools to bind (required if model_or_tools is a chat model). + :type tools: Optional[List[FoundryToolLike]] + :return: A FoundryToolLateBindingChatModel or FoundryToolBindingMiddleware that binds the given tools. + :rtype: Union[FoundryToolBindingMiddleware, FoundryToolLateBindingChatModel] + """ + if isinstance(model_or_tools, BaseChatModel): + if tools is None: + raise ValueError("Tools must be provided when a model is given.") + foundry_tools = [ensure_foundry_tool(tool) for tool in tools] + get_registry().extend(foundry_tools) + return FoundryToolLateBindingChatModel(model_or_tools, runtime=None, foundry_tools=foundry_tools) + + foundry_tools = [ensure_foundry_tool(tool) for tool in model_or_tools] + get_registry().extend(foundry_tools) + return FoundryToolBindingMiddleware(foundry_tools) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_chat_model.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_chat_model.py new file mode 100644 index 000000000000..4ca422b88c41 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_chat_model.py @@ -0,0 +1,131 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from __future__ import annotations + +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence + +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models import BaseChatModel, LanguageModelInput +from langchain_core.messages import AIMessage, BaseMessage +from langchain_core.outputs import ChatResult +from langchain_core.runnables import Runnable, RunnableConfig +from langchain_core.tools import BaseTool +from langgraph.prebuilt import ToolNode +from langgraph.runtime import Runtime + +from azure.ai.agentserver.core.tools import FoundryToolLike +from ._tool_node import FoundryToolCallWrapper, FoundryToolNodeWrappers + + +class FoundryToolLateBindingChatModel(BaseChatModel): + """A ChatModel that supports late binding of Foundry tools during invocation. + + This ChatModel allows you to specify Foundry tools that will be resolved and bound + at the time of invocation, rather than at the time of model creation. + + :param delegate: The underlying chat model to delegate calls to. + :type delegate: BaseChatModel + :param foundry_tools: A list of Foundry tools to be resolved and bound during invocation. + :type foundry_tools: List[FoundryToolLike] + """ + + def __init__(self, delegate: BaseChatModel, runtime: Optional[Runtime], foundry_tools: List[FoundryToolLike]): + super().__init__() + self._delegate = delegate + self._runtime = runtime + self._foundry_tools_to_bind = foundry_tools + self._bound_tools: List[Dict[str, Any] | type | Callable | BaseTool] = [] + self._bound_kwargs: dict[str, Any] = {} + + @property + def tool_node(self) -> ToolNode: + """Get a ToolNode that uses this chat model's Foundry tool call wrappers. + + :return: A ToolNode with Foundry tool call wrappers. + :rtype: ToolNode + """ + return ToolNode([], **self.tool_node_wrapper) + + @property + def tool_node_wrapper(self) -> FoundryToolNodeWrappers: + """Get the Foundry tool call wrappers for this chat model. + + Example:: + >>> from langgraph.prebuilt import ToolNode + >>> foundry_tool_bound_chat_model = FoundryToolLateBindingChatModel(...) + >>> ToolNode([...], **foundry_tool_bound_chat_model.as_wrappers()) + + :return: The Foundry tool call wrappers. + :rtype: FoundryToolNodeWrappers + + """ + return FoundryToolCallWrapper(self._foundry_tools_to_bind).as_wrappers() + + def bind_tools(self, # pylint: disable=C4758 + tools: Sequence[ + Dict[str, Any] | type | Callable | BaseTool # noqa: UP006 + ], + *, + tool_choice: str | None = None, + **kwargs: Any) -> Runnable[LanguageModelInput, AIMessage]: + """Record tools to be bound later during invocation. + + :param tools: A sequence of tools to bind. + :type tools: Sequence[Dict[str, Any] | type | Callable | BaseTool] + :keyword tool_choice: Optional tool choice strategy. + :type tool_choice: str | None + :keyword kwargs: Additional keyword arguments for tool binding. + :type kwargs: Any + :return: A Runnable with the tools bound for later invocation. + :rtype: Runnable[LanguageModelInput, AIMessage] + """ + + self._bound_tools.extend(tools) + if tool_choice is not None: + self._bound_kwargs["tool_choice"] = tool_choice + self._bound_kwargs.update(kwargs) + + return self + + def _bound_delegate_for_call(self, config: Optional[RunnableConfig]) -> Runnable[LanguageModelInput, AIMessage]: + from .._context import LanggraphRunContext + + foundry_tools: Iterable[BaseTool] = [] + if context := LanggraphRunContext.resolve(config, self._runtime): + foundry_tools = context.tools.resolved_tools.get(self._foundry_tools_to_bind) + elif self._foundry_tools_to_bind: + raise RuntimeError("Unable to resolve foundry tools from context, " + "if you are running in python < 3.11, " + "make sure you are passing RunnableConfig when calling model.") + + all_tools = self._bound_tools.copy() + all_tools.extend(foundry_tools) + + if not all_tools: + return self._delegate + + bound_kwargs = self._bound_kwargs or {} + return self._delegate.bind_tools(all_tools, **bound_kwargs) + + def invoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Any: + return self._bound_delegate_for_call(config).invoke(input, config=config, **kwargs) + + async def ainvoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> Any: + return await self._bound_delegate_for_call(config).ainvoke(input, config=config, **kwargs) + + def stream(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any): + yield from self._bound_delegate_for_call(config).stream(input, config=config, **kwargs) + + async def astream(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any): + async for x in self._bound_delegate_for_call(config).astream(input, config=config, **kwargs): + yield x + + @property + def _llm_type(self) -> str: + return f"foundry_tool_binding_model({getattr(self._delegate, '_llm_type', type(self._delegate).__name__)})" + + def _generate(self, messages: list[BaseMessage], stop: list[str] | None = None, + run_manager: CallbackManagerForLLMRun | None = None, **kwargs: Any) -> ChatResult: + # should never be called as invoke/ainvoke/stream/astream are redirected to delegate + raise NotImplementedError() diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_context.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_context.py new file mode 100644 index 000000000000..53789efec1a4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_context.py @@ -0,0 +1,16 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from dataclasses import dataclass, field + +from ._resolver import ResolvedTools + + +@dataclass +class FoundryToolContext: + """Context for tool resolution. + + :param resolved_tools: The resolved tools of all registered foundry tools. + :type resolved_tools: ResolvedTools + """ + resolved_tools: ResolvedTools = field(default_factory=lambda: ResolvedTools([])) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_middleware.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_middleware.py new file mode 100644 index 000000000000..c226e51e72ac --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_middleware.py @@ -0,0 +1,122 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from __future__ import annotations + +from typing import Awaitable, Callable, ClassVar, List + +from langchain.agents.middleware import AgentMiddleware, ModelRequest, ModelResponse +from langchain.agents.middleware.types import ModelCallResult +from langchain_core.messages import ToolMessage +from langchain_core.tools import BaseTool, Tool +from langgraph.prebuilt.tool_node import ToolCallRequest +from langgraph.types import Command + +from azure.ai.agentserver.core.tools import FoundryToolLike + +from ._chat_model import FoundryToolLateBindingChatModel +from ._tool_node import FoundryToolCallWrapper + + +class FoundryToolBindingMiddleware(AgentMiddleware): + """Middleware that binds foundry tools to tool calls in the agent. + + :param foundry_tools: A list of foundry tools to bind. + :type foundry_tools: List[FoundryToolLike] + """ + _DummyToolName: ClassVar[str] = "__dummy_tool_by_foundry_middleware__" + + def __init__(self, foundry_tools: List[FoundryToolLike]): + super().__init__() + + # to ensure `create_agent()` will create a tool node when there are foundry tools to bind + # this tool will never be bound to model and called + self.tools = [self._dummy_tool()] if foundry_tools else [] + + self._foundry_tools_to_bind = foundry_tools + self._tool_call_wrapper = FoundryToolCallWrapper(self._foundry_tools_to_bind) + + @classmethod + def _dummy_tool(cls) -> BaseTool: + return Tool(name=cls._DummyToolName, + func=lambda x: None, + description="__dummy_tool_by_foundry_middleware__") + + def wrap_model_call(self, request: ModelRequest, + handler: Callable[[ModelRequest], ModelResponse]) -> ModelCallResult: + """Wrap the model call to use a FoundryToolBindingChatModel. + + :param request: The model request. + :type request: ModelRequest + :param handler: The model call handler. + :type handler: Callable[[ModelRequest], ModelResponse] + :return: The model call result. + :rtype: ModelCallResult + """ + return handler(self._wrap_model(request)) + + async def awrap_model_call(self, request: ModelRequest, + handler: Callable[[ModelRequest], Awaitable[ModelResponse]]) -> ModelCallResult: + """Asynchronously wrap the model call to use a FoundryToolBindingChatModel. + + :param request: The model request. + :type request: ModelRequest + :param handler: The model call handler. + :type handler: Callable[[ModelRequest], Awaitable[ModelResponse]] + :return: The model call result. + :rtype: ModelCallResult + """ + return await handler(self._wrap_model(request)) + + def _wrap_model(self, request: ModelRequest) -> ModelRequest: + """Wrap the model in the request with a FoundryToolBindingChatModel. + + :param request: The model request. + :type request: ModelRequest + :return: The modified model request. + :rtype: ModelRequest + """ + if not self._foundry_tools_to_bind: + return request + wrapper = FoundryToolLateBindingChatModel(request.model, request.runtime, self._foundry_tools_to_bind) + return request.override(model=wrapper, tools=self._remove_dummy_tool(request)) + + def _remove_dummy_tool(self, request: ModelRequest) -> list: + """Remove the dummy tool from the request's tools if present. + + :param request: The model request. + :type request: ModelRequest + :return: The list of tools without the dummy tool. + :rtype: list + """ + if not request.tools: + return [] + return [tool for tool in request.tools if not isinstance(tool, Tool) or tool.name != self._DummyToolName] + + def wrap_tool_call(self, request: ToolCallRequest, + handler: Callable[[ToolCallRequest], ToolMessage | Command]) -> ToolMessage | Command: + """Wrap the tool call to use FoundryToolCallWrapper. + + :param request: The tool call request. + :type request: ToolCallRequest + :param handler: The tool call handler. + :type handler: Callable[[ToolCallRequest], ToolMessage | Command] + :return: The tool call result. + :rtype: ToolMessage | Command + """ + return self._tool_call_wrapper.call_tool(request, handler) + + async def awrap_tool_call( + self, + request: ToolCallRequest, + handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]]) -> ToolMessage | Command: + """Asynchronously wrap the tool call to use FoundryToolCallWrapper. + + :param request: The tool call request. + :type request: ToolCallRequest + :param handler: The tool call handler. + :type handler: Callable[[ToolCallRequest], Awaitable[ToolMessage | Command]] + :return: The tool call result. + :rtype: ToolMessage | Command + """ + return await self._tool_call_wrapper.call_tool_async(request, handler) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_resolver.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_resolver.py new file mode 100644 index 000000000000..5c77b1339132 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_resolver.py @@ -0,0 +1,154 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from collections import defaultdict +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union, overload + +import logging + +from langchain_core.tools import BaseTool, StructuredTool +from pydantic import BaseModel, Field, create_model + +from azure.ai.agentserver.core import AgentServerContext +from azure.ai.agentserver.core.tools import FoundryToolLike, ResolvedFoundryTool, SchemaDefinition, ensure_foundry_tool +from azure.ai.agentserver.core.tools.utils import ToolNameResolver + + +class ResolvedTools(Iterable[BaseTool]): + """A resolved view of foundry tools into LangChain tools. + + :param tools: An iterable of tuples of resolved foundry tools and their corresponding LangChain tools. + :type tools: Iterable[Tuple[ResolvedFoundryTool, BaseTool]] + """ + def __init__(self, tools: Iterable[Tuple[ResolvedFoundryTool, BaseTool]]): + self._by_source_id: Dict[str, List[BaseTool]] = defaultdict(list) + for rt, t in tools: + self._by_source_id[rt.definition.id].append(t) + + @overload + def get(self, tool: FoundryToolLike, /) -> Iterable[BaseTool]: # pylint: disable=C4743 + """Get the LangChain tools for the given foundry tool. + + :param tool: The foundry tool to get the LangChain tools for. + :type tool: FoundryToolLike + :return: The list of LangChain tools for the given foundry tool. + :rtype: Iterable[BaseTool] + """ + ... + + @overload + def get(self, tools: Iterable[FoundryToolLike], /) -> Iterable[BaseTool]: # pylint: disable=C4743 + """Get the LangChain tools for the given foundry tools. + + :param tools: The foundry tools to get the LangChain tools for. + :type tools: Iterable[FoundryToolLike] + :return: The list of LangChain tools for the given foundry tools. + :rtype: Iterable[BaseTool] + """ + ... + + @overload + def get(self) -> Iterable[BaseTool]: + """Get all LangChain tools. + + :return: The list of all LangChain tools. + :rtype: Iterable[BaseTool] + """ + ... + + def get(self, tool: Union[FoundryToolLike, Iterable[FoundryToolLike], None] = None) -> Iterable[BaseTool]: + """Get the LangChain tools for the given foundry tool(s), or all tools if none is given. + + :param tool: The foundry tool or tools to get the LangChain tools for, or None to get all tools. + :type tool: Union[FoundryToolLike, Iterable[FoundryToolLike], None] + :return: The list of LangChain tools for the given foundry tool(s), or all tools if none is given. + :rtype: Iterable[BaseTool] + """ + if tool is None: + yield from self + return + + tool_list = [tool] if not isinstance(tool, Iterable) else tool # type: ignore[assignment] + for t in tool_list: + ft = ensure_foundry_tool(t) # type: ignore[arg-type] + yield from self._by_source_id.get(ft.id, []) + + def __iter__(self): + for tool_list in self._by_source_id.values(): + yield from tool_list + + +class FoundryLangChainToolResolver: + """Resolves foundry tools into LangChain tools. + + :param name_resolver: The tool name resolver. + :type name_resolver: Optional[ToolNameResolver] + """ + def __init__(self, name_resolver: Optional[ToolNameResolver] = None): + self._name_resolver = name_resolver or ToolNameResolver() + + async def resolve_from_registry(self) -> ResolvedTools: + """Resolve the foundry tools from the global registry into LangChain tools. + + :return: The resolved LangChain tools. + :rtype: Iterable[Tuple[ResolvedFoundryTool, BaseTool]] + """ + return await self.resolve(get_registry()) + + async def resolve(self, foundry_tools: List[FoundryToolLike]) -> ResolvedTools: + """Resolve the given foundry tools into LangChain tools. + + :param foundry_tools: The foundry tools to resolve. + :type foundry_tools: List[FoundryToolLike] + :return: The resolved LangChain tools. + :rtype: Iterable[Tuple[ResolvedFoundryTool, BaseTool]] + """ + context = AgentServerContext.get() + resolved_foundry_tools = await context.tools.catalog.list(foundry_tools) + return ResolvedTools(tools=((tool, self._create_structured_tool(tool)) for tool in resolved_foundry_tools)) + + def _create_structured_tool(self, resolved_tool: ResolvedFoundryTool) -> StructuredTool: + name = self._name_resolver.resolve(resolved_tool) + args_schema = self._create_pydantic_model(name, resolved_tool.input_schema) + + async def _tool_func(**kwargs: Any) -> str: + result = await AgentServerContext.get().tools.invoke(resolved_tool, kwargs) + if isinstance(result, dict): + import json + return json.dumps(result) + return str(result) + + return StructuredTool.from_function( + name=name, + description=resolved_tool.details.description, + coroutine=_tool_func, + args_schema=args_schema + ) + + @classmethod + def _create_pydantic_model(cls, tool_name: str, input_schema: SchemaDefinition) -> type[BaseModel]: + field_definitions: Dict[str, Any] = {} + required_fields = input_schema.required or set() + for prop_name, prop in input_schema.properties.items(): + if prop.type is None: + logging.getLogger(__name__).warning( + "Skipping field '%s' in tool '%s': unknown or empty schema type.", prop_name, tool_name) + continue + py_type = prop.type.py_type + default = ... if prop_name in required_fields else None + field_definitions[prop_name] = (py_type, Field(default, description=prop.description)) + + model_name = f"{tool_name.replace('-', '_').replace(' ', '_').title()}-Input" + return create_model(model_name, **field_definitions) # type: ignore[call-overload] + + +_tool_registry: List[FoundryToolLike] = [] + + +def get_registry() -> List[FoundryToolLike]: + """Get the global foundry tool registry. + + :return: The list of registered foundry tools. + :rtype: List[FoundryToolLike] + """ + return _tool_registry diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_tool_node.py b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_tool_node.py new file mode 100644 index 000000000000..1bfef8c39f81 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/azure/ai/agentserver/langgraph/tools/_tool_node.py @@ -0,0 +1,94 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from typing import Awaitable, Callable, List, TypedDict, Union + +from langchain_core.messages import ToolMessage +from langgraph.prebuilt.tool_node import AsyncToolCallWrapper, ToolCallRequest, ToolCallWrapper +from langgraph.types import Command + +from azure.ai.agentserver.core.tools import FoundryToolLike + +ToolInvocationResult = Union[ToolMessage, Command] +ToolInvocation = Callable[[ToolCallRequest], ToolInvocationResult] +AsyncToolInvocation = Callable[[ToolCallRequest], Awaitable[ToolInvocationResult]] + + +class FoundryToolNodeWrappers(TypedDict): + """A TypedDict for Foundry tool node wrappers. + + Example:: + >>> from langgraph.prebuilt import ToolNode + >>> call_wrapper = FoundryToolCallWrapper(...) + >>> ToolNode([...], **call_wrapper.as_wrappers()) + + :param wrap_tool_call: The synchronous tool call wrapper. + :type wrap_tool_call: ToolCallWrapper + :param awrap_tool_call: The asynchronous tool call wrapper. + :type awrap_tool_call: AsyncToolCallWrapper + """ + + wrap_tool_call: ToolCallWrapper # type: ignore[valid-type] + + awrap_tool_call: AsyncToolCallWrapper # type: ignore[valid-type] + + +class FoundryToolCallWrapper: + """A ToolCallWrapper that tries to resolve invokable foundry tools from context if tool is not resolved yet.""" + def __init__(self, foundry_tools: List[FoundryToolLike]): + self._allowed_foundry_tools = foundry_tools + + def as_wrappers(self) -> FoundryToolNodeWrappers: + """Get the wrappers as a TypedDict. + + :return: The wrappers as a TypedDict. + :rtype: FoundryToolNodeWrappers + """ + return FoundryToolNodeWrappers( + wrap_tool_call=self.call_tool, + awrap_tool_call=self.call_tool_async, + ) + + def call_tool(self, request: ToolCallRequest, invocation: ToolInvocation) -> ToolInvocationResult: + """Call the tool, resolving foundry tools from context if necessary. + + :param request: The tool call request. + :type request: ToolCallRequest + :param invocation: The tool invocation function. + :type invocation: ToolInvocation + :return: The result of the tool invocation. + :rtype: ToolInvocationResult + """ + return invocation(self._maybe_calling_foundry_tool(request)) + + async def call_tool_async(self, request: ToolCallRequest, invocation: AsyncToolInvocation) -> ToolInvocationResult: + """Call the tool, resolving foundry tools from context if necessary. + + :param request: The tool call request. + :type request: ToolCallRequest + :param invocation: The tool invocation function. + :type invocation: AsyncToolInvocation + :return: The result of the tool invocation. + :rtype: ToolInvocationResult + """ + return await invocation(self._maybe_calling_foundry_tool(request)) + + def _maybe_calling_foundry_tool(self, request: ToolCallRequest) -> ToolCallRequest: + from .._context import LanggraphRunContext + + if (request.tool + or not self._allowed_foundry_tools + or not (context := LanggraphRunContext.resolve(runtime=request.runtime))): + # tool is already resolved + return request + + tool_name = request.tool_call["name"] + for t in context.tools.resolved_tools.get(self._allowed_foundry_tools): + if t.name == tool_name: + return ToolCallRequest( + tool_call=request.tool_call, + tool=t, + state=request.state, + runtime=request.runtime, + ) + return request diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/cspell.json b/sdk/agentserver/azure-ai-agentserver-langgraph/cspell.json index 470408fb66cc..3201d7d662a3 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/cspell.json +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/cspell.json @@ -5,7 +5,10 @@ "mslearn", "envtemplate", "ainvoke", - "asetup" + "asetup", + "hitl", + "HITL", + "awrap" ], "ignorePaths": [ "*.csv", diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.models.response_event_generators.rst b/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.models.response_event_generators.rst new file mode 100644 index 000000000000..af7cc69bd859 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.models.response_event_generators.rst @@ -0,0 +1,74 @@ +azure.ai.agentserver.langgraph.models.response\_event\_generators package +========================================================================= + +.. automodule:: azure.ai.agentserver.langgraph.models.response_event_generators + :inherited-members: + :members: + :undoc-members: + +Submodules +---------- + +azure.ai.agentserver.langgraph.models.response\_event\_generators.item\_content\_helpers module +----------------------------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.response_event_generators.item_content_helpers + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_event\_generators.item\_resource\_helpers module +------------------------------------------------------------------------------------------------ + +.. automodule:: azure.ai.agentserver.langgraph.models.response_event_generators.item_resource_helpers + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_event\_generators.response\_content\_part\_event\_generator module +------------------------------------------------------------------------------------------------------------------ + +.. automodule:: azure.ai.agentserver.langgraph.models.response_event_generators.response_content_part_event_generator + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_event\_generators.response\_event\_generator module +--------------------------------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.response_event_generators.response_event_generator + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_event\_generators.response\_function\_call\_argument\_event\_generator module +----------------------------------------------------------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.response_event_generators.response_function_call_argument_event_generator + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_event\_generators.response\_output\_item\_event\_generator module +----------------------------------------------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.response_event_generators.response_output_item_event_generator + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_event\_generators.response\_output\_text\_event\_generator module +----------------------------------------------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.response_event_generators.response_output_text_event_generator + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_event\_generators.response\_stream\_event\_generator module +----------------------------------------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.response_event_generators.response_stream_event_generator + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.models.rst b/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.models.rst new file mode 100644 index 000000000000..aba857c3b64a --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.models.rst @@ -0,0 +1,82 @@ +azure.ai.agentserver.langgraph.models package +============================================= + +.. automodule:: azure.ai.agentserver.langgraph.models + :inherited-members: + :members: + :undoc-members: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + azure.ai.agentserver.langgraph.models.response_event_generators + +Submodules +---------- + +azure.ai.agentserver.langgraph.models.human\_in\_the\_loop\_helper module +------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.human_in_the_loop_helper + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.human\_in\_the\_loop\_json\_helper module +------------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.human_in_the_loop_json_helper + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_api\_converter module +--------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.response_api_converter + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_api\_default\_converter module +------------------------------------------------------------------------------ + +.. automodule:: azure.ai.agentserver.langgraph.models.response_api_default_converter + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_api\_non\_stream\_response\_converter module +-------------------------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.response_api_non_stream_response_converter + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_api\_request\_converter module +------------------------------------------------------------------------------ + +.. automodule:: azure.ai.agentserver.langgraph.models.response_api_request_converter + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.response\_api\_stream\_response\_converter module +--------------------------------------------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.response_api_stream_response_converter + :inherited-members: + :members: + :undoc-members: + +azure.ai.agentserver.langgraph.models.utils module +-------------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.models.utils + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.rst b/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.rst new file mode 100644 index 000000000000..deefeb67fa96 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.rst @@ -0,0 +1,27 @@ +azure.ai.agentserver.langgraph package +====================================== + +.. automodule:: azure.ai.agentserver.langgraph + :inherited-members: + :members: + :undoc-members: + +Subpackages +----------- + +.. toctree:: + :maxdepth: 4 + + azure.ai.agentserver.langgraph.models + azure.ai.agentserver.langgraph.tools + +Submodules +---------- + +azure.ai.agentserver.langgraph.langgraph module +----------------------------------------------- + +.. automodule:: azure.ai.agentserver.langgraph.langgraph + :inherited-members: + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.tools.rst b/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.tools.rst new file mode 100644 index 000000000000..17f7ef6d2ab7 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/doc/azure.ai.agentserver.langgraph.tools.rst @@ -0,0 +1,6 @@ +azure.ai.agentserver.langgraph.tools package +============================================ + +.. automodule:: azure.ai.agentserver.langgraph.tools + :members: + :undoc-members: diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml b/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml index 5552ff8233d2..f20fcfee1401 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/pyproject.toml @@ -8,6 +8,7 @@ authors = [ ] license = "MIT" classifiers = [ + "Development Status :: 4 - Beta", "Programming Language :: Python", "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3", @@ -19,18 +20,19 @@ classifiers = [ keywords = ["azure", "azure sdk"] dependencies = [ - "azure-ai-agentserver-core", - "langchain>0.3.5", - "langchain-openai>0.3.10", - "langchain-azure-ai[opentelemetry]>=0.1.4", - "langgraph>0.5.0", - "opentelemetry-exporter-otlp-proto-http", + "azure-ai-agentserver-core==1.0.0b10", + "langchain>=1.0.3", + "langchain-openai>=1.0.3", + "langchain-azure-ai[opentelemetry]~=1.0.0", ] [build-system] requires = ["setuptools>=69", "wheel"] build-backend = "setuptools.build_meta" +[project.urls] +repository = "https://github.com/Azure/azure-sdk-for-python" + [tool.setuptools.packages.find] exclude = [ "tests*", @@ -63,5 +65,6 @@ breaking = false # incompatible python version pyright = false verifytypes = false # incompatible python version for -core verify_keywords = false -mindependency = false # depends on -core package -whl_no_aio = false \ No newline at end of file +whl_no_aio = false +mypy = false +apistub = false \ No newline at end of file diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/samples/custom_state/README.md b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/custom_state/README.md index 1455a366a0ad..1c78acf0105d 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/samples/custom_state/README.md +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/custom_state/README.md @@ -1,10 +1,11 @@ -# Custom LangGraph State Converter (Mini RAG) Sample +# Custom LangGraph Converter (Mini RAG) Sample -This sample demonstrates how to host a LangGraph agent **with a custom internal state** using the `azure.ai.agentserver` SDK by supplying a custom `LanggraphStateConverter` (`RAGStateConverter`). It shows the minimal pattern required to adapt OpenAI Responses-style requests to a LangGraph state and back to an OpenAI-compatible response. +This sample demonstrates how to host a LangGraph agent **with a custom internal state** using the `azure.ai.agentserver` SDK by supplying a custom `ResponseAPIConverter` (`RAGStateConverter`). It shows the minimal pattern required to adapt OpenAI Responses-style requests to a LangGraph state and back to an OpenAI-compatible response. ## What It Shows -- Defining a custom state (`RAGState`) separate from the wire contract. -- Implementing `RAGStateConverter.request_to_state` and `state_to_response` to bridge request ↔ graph ↔ response. +- Defining a custom graph state (`RAGState`) separate from the wire contract. +- Implementing `ResponseAPIRequestConverter` to convert request payload to your graph input +- An overwrite of `ResponseAPIDefaultConverter.convert_response_non_stream` to bridge graph ↔ response. - A simple multi-step graph: intent analysis β†’ optional retrieval β†’ answer generation. - Lightweight retrieval (keyword scoring over an in‑memory knowledge base) with citation annotations added to the assistant message. - Graceful local fallback answer when Azure OpenAI credentials are absent. @@ -13,11 +14,12 @@ This sample demonstrates how to host a LangGraph agent **with a custom internal ## Flow Overview ``` CreateResponse request - -> RAGStateConverter.request_to_state - -> LangGraph executes nodes (analyze β†’ retrieve? β†’ answer) - -> Final state - -> RAGStateConverter.state_to_response - -> OpenAI-style response object + -> RAGStateConverter + -> RAGRequestConverter.convert + -> LangGraph executes nodes (analyze β†’ retrieve? β†’ answer) + -> Final state + -> RAGStateConverter.convert_response_non_stream + -> OpenAI-style response object ``` ## Running @@ -34,7 +36,7 @@ Optional environment variables for live model call: |------|--------| | Real retrieval | Replace `retrieve_docs` with embedding + vector / search backend. | | Richer answers | Introduce prompt templates or additional graph nodes. | -| Multi‑turn memory | Persist prior messages; include truncated history in `request_to_state`. | +| Multi‑turn memory | Persist prior messages; include truncated history in `RAGRequestConverter.convert`. | | Tool / function calls | Add nodes producing tool outputs and incorporate into final response. | | Better citations | Store offsets / URLs and expand annotation objects. | | Streaming support | (See below) | @@ -42,10 +44,10 @@ Optional environment variables for live model call: ### Adding Streaming 1. Allow `stream=True` in requests and propagate a flag into state. 2. Implement `get_stream_mode` (return appropriate mode, e.g. `events`). -3. Implement `state_to_response_stream` to yield `ResponseStreamEvent` objects (lifecycle + deltas) and finalize with a completed event. +3. Implement `convert_response_stream` to yield `ResponseStreamEvent` objects (lifecycle + deltas) and finalize with a completed event. 4. Optionally collect incremental model tokens during `generate_answer`. ## Key Takeaway -A custom `LanggraphStateConverter` is the seam where you map external request contracts to an internal graph-friendly state shape and then format the final (or streamed) result back to the OpenAI Responses schema. Start simple (non‑streaming), then layer retrieval sophistication, memory, tools, and streaming as needed. +A custom `ResponseAPIDefaultConverter` is the seam where you map external request contracts to an internal graph-friendly state shape and then format the final (or streamed) result back to the OpenAI Responses schema. Start simple (non‑streaming), then layer retrieval sophistication, memory, tools, and streaming as needed. Streaming is not supported in this sample out-of-the-box. diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/samples/custom_state/main.py b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/custom_state/main.py index 27f5bf0d5ee2..ec45dceccfc8 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/samples/custom_state/main.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/custom_state/main.py @@ -1,21 +1,19 @@ from __future__ import annotations -import os import json +import os import time from dataclasses import dataclass -from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, TypedDict +from typing import Any, AsyncIterable, AsyncIterator, Dict, List, TypedDict from dotenv import load_dotenv -from langgraph.graph import StateGraph, START, END +from langgraph.graph import END, START, StateGraph from openai import OpenAI, OpenAIError -from azure.ai.agentserver.core import AgentRunContext from azure.ai.agentserver.core.models import Response, ResponseStreamEvent -from azure.ai.agentserver.langgraph import from_langgraph -from azure.ai.agentserver.langgraph.models import ( - LanggraphStateConverter, -) +from azure.ai.agentserver.langgraph import LanggraphRunContext, from_langgraph +from azure.ai.agentserver.langgraph.models.response_api_default_converter import ResponseAPIDefaultConverter +from azure.ai.agentserver.langgraph.models.response_api_request_converter import ResponseAPIRequestConverter load_dotenv() @@ -99,16 +97,14 @@ def retrieve_docs(question: str, k: int = 2) -> List[Dict[str, Any]]: # --------------------------------------------------------------------------- # Custom Converter # --------------------------------------------------------------------------- -class RAGStateConverter(LanggraphStateConverter): - """Converter implementing mini RAG logic (non‑streaming only).""" +class RAGRequestConverter(ResponseAPIRequestConverter): + """Converter implementing mini RAG logic.""" - def get_stream_mode(self, context: AgentRunContext) -> str: # noqa: D401 - if context.request.get("stream", False): # type: ignore[attr-defined] - raise NotImplementedError("Streaming not supported in this sample.") - return "values" + def __init__(self, context: LanggraphRunContext): + self.context = context - def request_to_state(self, context: AgentRunContext) -> Dict[str, Any]: # noqa: D401 - req = context.request + def convert(self) -> dict: + req = self.context.agent_run.request user_input = req.get("input") if isinstance(user_input, list): for item in user_input: @@ -136,10 +132,20 @@ def request_to_state(self, context: AgentRunContext) -> Dict[str, Any]: # noqa: } print("initial state:", res) return res + - def state_to_response( - self, state: Dict[str, Any], context: AgentRunContext - ) -> Response: # noqa: D401 +class RAGStateConverter(ResponseAPIDefaultConverter): + """Converter implementing mini RAG logic (non‑streaming only).""" + def __init__(self, graph: StateGraph): + super().__init__(graph=graph, + create_request_converter=lambda context: RAGRequestConverter(context)) + + def get_stream_mode(self, context: LanggraphRunContext) -> str: # noqa: D401 + if context.agent_run.request.get("stream", False): # type: ignore[attr-defined] + raise NotImplementedError("Streaming not supported in this sample.") + return "values" + + async def convert_response_non_stream(self, state: Any, context: LanggraphRunContext) -> Response: final_answer = state.get("final_answer") or "(no answer generated)" print(f"convert state to response, state: {state}") citations = state.get("retrieved", []) @@ -163,20 +169,20 @@ def state_to_response( } base = { "object": "response", - "id": context.response_id, - "agent": context.get_agent_id_object(), - "conversation": context.get_conversation_object(), + "id": context.agent_run.response_id, + "agent": context.agent_run.get_agent_id_object(), + "conversation": context.agent_run.get_conversation_object(), "status": "completed", "created_at": int(time.time()), "output": [output_item], } return Response(**base) - async def state_to_response_stream( # noqa: D401 + async def convert_response_stream( # noqa: D401 self, stream_state: AsyncIterator[Dict[str, Any] | Any], - context: AgentRunContext, - ) -> AsyncGenerator[ResponseStreamEvent, None]: + context: LanggraphRunContext, + ) -> AsyncIterable[ResponseStreamEvent]: raise NotImplementedError("Streaming not supported in this sample.") @@ -289,5 +295,6 @@ def _build_graph(): # --------------------------------------------------------------------------- if __name__ == "__main__": graph = _build_graph() - converter = RAGStateConverter() - from_langgraph(graph, converter).run() + + converter = RAGStateConverter(graph=graph) + from_langgraph(graph, converter=converter).run() diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/samples/human_in_the_loop/.env-template b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/human_in_the_loop/.env-template new file mode 100644 index 000000000000..92b9c812a686 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/human_in_the_loop/.env-template @@ -0,0 +1,4 @@ +AZURE_OPENAI_API_KEY= +AZURE_OPENAI_ENDPOINT=https://.cognitiveservices.azure.com/ +OPENAI_API_VERSION=2025-03-01-preview +AZURE_OPENAI_CHAT_DEPLOYMENT_NAME= diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/samples/human_in_the_loop/README.md b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/human_in_the_loop/README.md new file mode 100644 index 000000000000..e5c5aa3734a2 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/human_in_the_loop/README.md @@ -0,0 +1,133 @@ +# LangGraph Human-in-the-Loop Sample + +This sample demonstrates how to create an intelligent agent with human-in-the-loop capabilities using LangGraph and Azure AI Agent Server. The agent can interrupt its execution to ask for human input when needed, making it ideal for scenarios requiring human judgment or additional information. + +## Overview + +The sample consists of several key components: + +- **LangGraph Agent**: An AI agent that can intelligently decide when to ask humans for input during task execution +- **Human Interrupt Mechanism**: Uses LangGraph's `interrupt()` function to pause execution and wait for human feedback +- **Azure AI Agent Server Adapter**: Hosts the LangGraph agent as an HTTP service + +## Files Description + +- `main.py` - The main LangGraph agent implementation with human-in-the-loop capabilities +- `requirements.txt` - Python dependencies for the sample + + + +## Setup + +1. **Environment Configuration** + + Create a `.env` file in this directory with your Azure OpenAI configuration: + ```env + AZURE_OPENAI_API_KEY=your_api_key_here + AZURE_OPENAI_ENDPOINT=your_endpoint_here + AZURE_OPENAI_CHAT_DEPLOYMENT_NAME=gpt-4o + ``` + + Alternatively, if you're using Azure Identity (without API key), ensure your Azure credentials are configured. + +2. **Install Dependencies** + + Install the required Python packages: + ```bash + pip install python-dotenv + pip install azure-ai-agentserver[langgraph] + ``` + +## Usage + +### Running the Agent Server + +1. Start the agent server: + ```bash + python main.py + ``` + The server will start on `http://localhost:8088` + +### Making Requests + +#### Initial Request (Triggering Human Input) + +Send a request that will cause the agent to ask for human input: + +```bash +curl -X POST http://localhost:8088/responses \ + -H "Content-Type: application/json" \ + -d '{ + "agent": { + "name": "local_agent", + "type": "agent_reference" + }, + "stream": false, + "input": "Ask the user where they are, then look up the weather there." + }' +``` + +**Response Structure:** + +The agent will respond with an interrupt request: + +```json +{ + "conversation": { + "id": "conv_abc123..." + }, + "output": [ + { + "type": "function_call", + "name": "__hosted_agent_adapter_interrupt__", + "call_id": "call_xyz789...", + "arguments": "{\"question\": \"Where are you located?\"}" + } + ] +} +``` + +#### Providing Human Feedback + +Resume the conversation by providing the human's response: + +```bash +curl -X POST http://localhost:8088/responses \ + -H "Content-Type: application/json" \ + -d '{ + "agent": { + "name": "local_agent", + "type": "agent_reference" + }, + "stream": false, + "input": [ + { + "type": "function_call_output", + "call_id": "call_xyz789...", + "output": "{\"resume\": \"San Francisco\"}" + } + ], + "conversation": { + "id": "conv_abc123..." + } + }' +``` + +**Final Response:** + +The agent will continue execution and provide the final result: + +```json +{ + "conversation": { + "id": "conv_abc123..." + }, + "output": [ + { + "type": "message", + "role": "assistant", + "content": "I looked up the weather in San Francisco. Result: It's sunny in San Francisco." + } + ] +} +``` diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/samples/human_in_the_loop/main.py b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/human_in_the_loop/main.py new file mode 100644 index 000000000000..aed1ff51049e --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/human_in_the_loop/main.py @@ -0,0 +1,204 @@ +""" +Human-in-the-Loop Agent Example + +This sample demonstrates how to create a LangGraph agent that can interrupt +execution to ask for human input when needed. The agent uses Azure OpenAI +and includes a custom tool for asking human questions. +""" + +import os + +from dotenv import load_dotenv +from pydantic import BaseModel + +from azure.identity import DefaultAzureCredential, get_bearer_token_provider +from langchain.chat_models import init_chat_model +from langchain_core.messages import ToolMessage +from langchain_core.tools import tool +from langgraph.checkpoint.memory import InMemorySaver +from langgraph.graph import END, START, MessagesState, StateGraph +from langgraph.prebuilt import ToolNode +from langgraph.types import interrupt + +from azure.ai.agentserver.langgraph import from_langgraph + +# Load environment variables +load_dotenv() + + +# ============================================================================= +# Configuration +# ============================================================================= + +DEPLOYMENT_NAME = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", "gpt-4o") +API_KEY = os.getenv("AZURE_OPENAI_API_KEY", "") + + +# ============================================================================= +# Model Initialization +# ============================================================================= + +def initialize_llm(): + """Initialize the language model with Azure OpenAI credentials.""" + if API_KEY: + return init_chat_model(f"azure_openai:{DEPLOYMENT_NAME}") + else: + credential = DefaultAzureCredential() + token_provider = get_bearer_token_provider( + credential, "https://cognitiveservices.azure.com/.default" + ) + return init_chat_model( + f"azure_openai:{DEPLOYMENT_NAME}", + azure_ad_token_provider=token_provider, + ) + + +llm = initialize_llm() + +# ============================================================================= +# Tools and Models +# ============================================================================= + +@tool +def search(query: str) -> str: + """ + Call to search the web for information. + + Args: + query: The search query string + + Returns: + Search results as a string + """ + # This is a placeholder for the actual implementation + return f"I looked up: {query}. Result: It's sunny in San Francisco." + + +class AskHuman(BaseModel): + """Schema for asking the human a question.""" + question: str + + +# Initialize tools and bind to model +tools = [search] +tool_node = ToolNode(tools) +model = llm.bind_tools(tools + [AskHuman]) + + +# ============================================================================= +# Graph Nodes +# ============================================================================= + +def call_model(state: MessagesState) -> dict: + """ + Call the language model with the current conversation state. + + Args: + state: The current messages state + + Returns: + Dictionary with the model's response message + """ + messages = state["messages"] + response = model.invoke(messages) + return {"messages": [response]} + + +def ask_human(state: MessagesState) -> dict: + """ + Interrupt execution to ask the human for input. + + Args: + state: The current messages state + + Returns: + Dictionary with the human's response as a tool message + """ + last_message = state["messages"][-1] + tool_call_id = last_message.tool_calls[0]["id"] + ask = AskHuman.model_validate(last_message.tool_calls[0]["args"]) + + # Interrupt and wait for human input + location = interrupt(ask.question) + + tool_message = ToolMessage(tool_call_id=tool_call_id, content=location) + return {"messages": [tool_message]} + + +# ============================================================================= +# Graph Logic +# ============================================================================= + +def should_continue(state: MessagesState) -> str: + """ + Determine the next step in the graph based on the last message. + + Args: + state: The current messages state + + Returns: + The name of the next node to execute, or END to finish + """ + messages = state["messages"] + last_message = messages[-1] + + # If there's no function call, we're done + if not last_message.tool_calls: + return END + + # If asking for human input, route to ask_human node + if last_message.tool_calls[0]["name"] == "AskHuman": + return "ask_human" + + # Otherwise, execute the tool call + return "action" + + +# ============================================================================= +# Graph Construction +# ============================================================================= + +def build_graph() -> StateGraph: + """ + Build and compile the LangGraph workflow. + + Returns: + Compiled StateGraph with checkpointing enabled + """ + workflow = StateGraph(MessagesState) + + # Add nodes + workflow.add_node("agent", call_model) + workflow.add_node("action", tool_node) + workflow.add_node("ask_human", ask_human) + + # Set entry point + workflow.add_edge(START, "agent") + + # Add conditional routing from agent + workflow.add_conditional_edges( + "agent", + should_continue, + path_map=["ask_human", "action", END], + ) + + # Add edges back to agent + workflow.add_edge("action", "agent") + workflow.add_edge("ask_human", "agent") + + # Compile with memory checkpointer + memory = InMemorySaver() + return workflow.compile(checkpointer=memory) + + +app = build_graph() + + +# ============================================================================= +# Main Entry Point +# ============================================================================= + +if __name__ == "__main__": + adapter = from_langgraph(app) + adapter.run() + diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/samples/human_in_the_loop/requirements.txt b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/human_in_the_loop/requirements.txt new file mode 100644 index 000000000000..8c3bb2198ef1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/human_in_the_loop/requirements.txt @@ -0,0 +1,3 @@ +python-dotenv>=1.0.0 +azure-ai-agentserver-core +azure-ai-agentserver-langgraph diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/samples/simple_agent_with_foundry_checkpointer/README.md b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/simple_agent_with_foundry_checkpointer/README.md new file mode 100644 index 000000000000..b56576c45fc0 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/simple_agent_with_foundry_checkpointer/README.md @@ -0,0 +1,80 @@ +# Simple LangGraph Agent with Foundry Managed Checkpointer + +This sample hosts a LangGraph ReAct-style agent and uses `FoundryCheckpointSaver` to persist +checkpoints in Azure AI Foundry managed storage. + +With Foundry managed checkpoints, graph state is stored remotely so conversations can resume across +requests and server restarts without self-managed storage. + +### What `main.py` does + +- Creates an `AzureChatOpenAI` model and two tools (`get_word_length`, `calculator`) +- Builds a LangGraph agent with `create_react_agent(..., checkpointer=saver)` +- Creates `FoundryCheckpointSaver(project_endpoint, credential)` and runs the server via + `from_langgraph(...).run_async()` + +--- + +## Prerequisites + +- Python 3.10+ +- Azure CLI authenticated with `az login` (required for `AzureCliCredential`) +- An Azure AI Foundry project endpoint +- An Azure OpenAI chat deployment (for example `gpt-4o`) + +--- + +## Setup + +1. Create a `.env` file in this folder: + ```env + AZURE_AI_PROJECT_ENDPOINT=https://.services.ai.azure.com/api/projects/ + AZURE_OPENAI_ENDPOINT=https://.openai.azure.com/ + AZURE_OPENAI_API_KEY= + OPENAI_API_VERSION=2025-03-01-preview + AZURE_OPENAI_CHAT_DEPLOYMENT_NAME=gpt-4o + ``` + +2. Install dependencies: + ```bash + pip install azure-ai-agentserver-langgraph python-dotenv azure-identity langgraph + ``` + +--- + +## Run the Agent + +From this folder: + +```bash +python main.py +``` + +The adapter starts the server on `http://0.0.0.0:8088` by default. + +--- + +## Send Requests + +Non-streaming example: + +```bash +curl -sS \ + -H "Content-Type: application/json" \ + -X POST http://localhost:8088/responses \ + -d '{ + "agent": {"name": "local_agent", "type": "agent_reference"}, + "stream": false, + "input": "What is (15 * 4) + 6?", + "conversation": {"id": "test-conversation-1"} + }' +``` + +Use the same `conversation.id` on follow-up requests to continue the checkpointed conversation state. + +--- + +## Related Resources + +- LangGraph docs: https://langchain-ai.github.io/langgraph/ +- Adapter package docs: `azure.ai.agentserver.langgraph` in this SDK diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/samples/simple_agent_with_foundry_checkpointer/main.py b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/simple_agent_with_foundry_checkpointer/main.py new file mode 100644 index 000000000000..10bbb3d22712 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/simple_agent_with_foundry_checkpointer/main.py @@ -0,0 +1,82 @@ +# Copyright (c) Microsoft. All rights reserved. + +""" +Simple Agent with Foundry Managed Checkpointer + +This sample demonstrates how to use FoundryCheckpointSaver with a LangGraph +agent to persist checkpoints in Azure AI Foundry. + +Foundry managed checkpoints enable graph state to be persisted across +requests, allowing conversations to be paused, resumed, and replayed. + +Prerequisites: + - Set AZURE_AI_PROJECT_ENDPOINT to your Azure AI Foundry project endpoint + e.g. "https://.services.ai.azure.com/api/projects/" + - Set AZURE_OPENAI_CHAT_DEPLOYMENT_NAME (defaults to "gpt-4o") + - Azure credentials configured (e.g. az login) +""" + +import asyncio +import os + +from dotenv import load_dotenv +from langchain_core.tools import tool +from langchain_openai import AzureChatOpenAI +from azure.identity.aio import AzureCliCredential + +from azure.ai.agentserver.langgraph import from_langgraph +from azure.ai.agentserver.langgraph.checkpointer import FoundryCheckpointSaver + +load_dotenv() + +deployment_name = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", "gpt-4o") +model = AzureChatOpenAI(model=deployment_name) + + +@tool +def get_word_length(word: str) -> int: + """Returns the length of a word.""" + return len(word) + + +@tool +def calculator(expression: str) -> str: + """Evaluates a mathematical expression.""" + try: + result = eval(expression) # noqa: S307 + return str(result) + except Exception as e: + return f"Error: {str(e)}" + + +tools = [get_word_length, calculator] + + +def create_agent(checkpointer): + """Create a react agent with the given checkpointer.""" + from langgraph.prebuilt import create_react_agent + + return create_react_agent(model, tools, checkpointer=checkpointer) + + +async def main() -> None: + """Run the agent with Foundry managed checkpoints.""" + project_endpoint = os.getenv("AZURE_AI_PROJECT_ENDPOINT", "") + + async with AzureCliCredential() as cred: + # Use FoundryCheckpointSaver for Azure AI Foundry managed storage. + # This persists graph checkpoints remotely, enabling pause/resume + # across requests and server restarts. + saver = FoundryCheckpointSaver( + project_endpoint=project_endpoint, + credential=cred, + ) + + # Pass the checkpointer via LangGraph's native compile/create API + executor = create_agent(checkpointer=saver) + + await from_langgraph(executor).run_async() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/samples/tool_client_example/graph_agent_tool.py b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/tool_client_example/graph_agent_tool.py new file mode 100644 index 000000000000..c4992ba71f46 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/tool_client_example/graph_agent_tool.py @@ -0,0 +1,104 @@ +import os + +from dotenv import load_dotenv +from langchain.chat_models import init_chat_model +from langchain_core.messages import SystemMessage, ToolMessage +from langchain_core.runnables import RunnableConfig +from langchain_core.tools import tool +from langgraph.graph import ( + END, + START, + MessagesState, + StateGraph, +) +from typing_extensions import Literal +from azure.identity import DefaultAzureCredential, get_bearer_token_provider + +from azure.ai.agentserver.langgraph import from_langgraph +from azure.ai.agentserver.langgraph.tools import use_foundry_tools + +load_dotenv() + +deployment_name = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", "gpt-4o") +credential = DefaultAzureCredential() +token_provider = get_bearer_token_provider( + credential, "https://cognitiveservices.azure.com/.default" +) +llm = init_chat_model( + f"azure_openai:{deployment_name}", + azure_ad_token_provider=token_provider, +) +llm_with_foundry_tools = use_foundry_tools(llm, [ + { + # use the python tool to calculate what is 4 * 3.82. and then find its square root and then find the square root of that result + "type": "code_interpreter" + }, + { + # Give me the Azure CLI commands to create an Azure Container App with a managed identity. search Microsoft Learn + "type": "mcp", + "project_connection_id": "MicrosoftLearn" + }, + # { + # "type": "mcp", + # "project_connection_id": "FoundryMCPServerpreview" + # } +]) + + +# Nodes +async def llm_call(state: MessagesState, config: RunnableConfig): + """LLM decides whether to call a tool or not""" + + return { + "messages": [ + await llm_with_foundry_tools.ainvoke( + [ + SystemMessage( + content="You are a helpful assistant tasked with performing arithmetic on a set of inputs." + ) + ] + + state["messages"], + config=config, + ) + ] + } + + +# Conditional edge function to route to the tool node or end based upon whether the LLM made a tool call +def should_continue(state: MessagesState) -> Literal["environment", END]: + """Decide if we should continue the loop or stop based upon whether the LLM made a tool call""" + + messages = state["messages"] + last_message = messages[-1] + # If the LLM makes a tool call, then perform an action + if last_message.tool_calls: + return "Action" + # Otherwise, we stop (reply to the user) + return END + + +# Build workflow +agent_builder = StateGraph(MessagesState) + +# Add nodes +agent_builder.add_node("llm_call", llm_call) +agent_builder.add_node("environment", llm_with_foundry_tools.tool_node) + +# Add edges to connect nodes +agent_builder.add_edge(START, "llm_call") +agent_builder.add_conditional_edges( + "llm_call", + should_continue, + { + "Action": "environment", + END: END, + }, +) +agent_builder.add_edge("environment", "llm_call") + +# Compile the agent +agent = agent_builder.compile() + +if __name__ == "__main__": + adapter = from_langgraph(agent) + adapter.run() diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/samples/tool_client_example/react_agent_tool.py b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/tool_client_example/react_agent_tool.py new file mode 100644 index 000000000000..111a2307c6af --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/samples/tool_client_example/react_agent_tool.py @@ -0,0 +1,47 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +import os + +from azure.identity import DefaultAzureCredential, get_bearer_token_provider +from dotenv import load_dotenv +from langchain.agents import create_agent +from langchain_core.tools import tool +from langchain_openai import AzureChatOpenAI +from langgraph.checkpoint.memory import MemorySaver + +from azure.ai.agentserver.langgraph import from_langgraph +from azure.ai.agentserver.langgraph.tools import use_foundry_tools + +load_dotenv() + +token_provider = get_bearer_token_provider( + DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" +) + +memory = MemorySaver() +deployment_name = os.getenv("AZURE_OPENAI_CHAT_DEPLOYMENT_NAME", "gpt-4o") +model = AzureChatOpenAI(model=deployment_name, azure_ad_token_provider=token_provider) + +foundry_tools = use_foundry_tools([ + { + # use the python tool to calculate what is 4 * 3.82. and then find its square root and then find the square root of that result + "type": "code_interpreter" + }, + { + # Give me the Azure CLI commands to create an Azure Container App with a managed identity. search Microsoft Learn + "type": "mcp", + "project_connection_id": "MicrosoftLearn" + }, + # { + # "type": "mcp", + # "project_connection_id": "FoundryMCPServerpreview" + # } +]) + + +agent_executor = create_agent(model, checkpointer=memory, middleware=[foundry_tools]) + +if __name__ == "__main__": + # host the langgraph agent + from_langgraph(agent_executor).run() diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/__init__.py deleted file mode 100644 index 4a5d26360bce..000000000000 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -# Unit tests package diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/conftest.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/conftest.py deleted file mode 100644 index 7f055e40010c..000000000000 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/conftest.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -Pytest configuration and shared fixtures for unit tests. -""" - -import sys -from pathlib import Path - -# Add the src directory to the Python path so we can import modules under test -src_path = Path(__file__).parent.parent.parent / "src" -sys.path.insert(0, str(src_path)) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/__init__.py new file mode 100644 index 000000000000..315126869940 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/__init__.py @@ -0,0 +1,4 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for the checkpointer module.""" diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/test_foundry_checkpoint_saver.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/test_foundry_checkpoint_saver.py new file mode 100644 index 000000000000..d25a34f6289d --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/test_foundry_checkpoint_saver.py @@ -0,0 +1,441 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryCheckpointSaver.""" + +import pytest +from unittest.mock import Mock + +from azure.core.credentials import TokenCredential +from azure.core.credentials_async import AsyncTokenCredential + +from azure.ai.agentserver.langgraph.checkpointer import FoundryCheckpointSaver +from azure.ai.agentserver.langgraph.checkpointer._foundry_checkpoint_saver import ( + BaseCheckpointSaver, +) +from azure.ai.agentserver.langgraph.checkpointer._item_id import make_item_id + +from ..mocks import MockFoundryCheckpointClient + + +class TestableFoundryCheckpointSaver(FoundryCheckpointSaver): + """Testable version that accepts a mock client directly (bypasses credential check).""" + + def __init__(self, client: MockFoundryCheckpointClient) -> None: + """Initialize with a mock client.""" + # Skip FoundryCheckpointSaver.__init__ and call BaseCheckpointSaver directly + BaseCheckpointSaver.__init__(self, serde=None) + self._client = client # type: ignore[assignment] + self._session_cache: set[str] = set() + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aget_tuple_returns_none_for_missing_checkpoint() -> None: + """Test that aget_tuple returns None when checkpoint doesn't exist.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + config = {"configurable": {"thread_id": "thread-1"}} + result = await saver.aget_tuple(config) + + assert result is None + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aput_creates_checkpoint_item() -> None: + """Test that aput creates a checkpoint item.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + checkpoint = { + "id": "cp-001", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + metadata = {"source": "test"} + + result = await saver.aput(config, checkpoint, metadata, {}) + + assert result["configurable"]["checkpoint_id"] == "cp-001" + assert result["configurable"]["thread_id"] == "thread-1" + + # Verify item was created + item_ids = await client.list_item_ids("thread-1") + assert len(item_ids) == 1 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aput_creates_blob_items_for_new_versions() -> None: + """Test that aput creates blob items for channel values.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + checkpoint = { + "id": "cp-001", + "channel_values": {"messages": ["hello", "world"]}, + "channel_versions": {"messages": "1"}, + "versions_seen": {}, + "pending_sends": [], + } + metadata = {"source": "test"} + new_versions = {"messages": "1"} + + await saver.aput(config, checkpoint, metadata, new_versions) + + # Should have 2 items: checkpoint + 1 blob + item_ids = await client.list_item_ids("thread-1") + assert len(item_ids) == 2 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aput_returns_config_with_checkpoint_id() -> None: + """Test that aput returns config with the correct checkpoint ID.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "ns1"}} + checkpoint = { + "id": "my-checkpoint-id", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + + result = await saver.aput(config, checkpoint, {}, {}) + + assert result["configurable"]["checkpoint_id"] == "my-checkpoint-id" + assert result["configurable"]["thread_id"] == "thread-1" + assert result["configurable"]["checkpoint_ns"] == "ns1" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aget_tuple_returns_checkpoint_with_data() -> None: + """Test that aget_tuple returns checkpoint data correctly.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + # Save a checkpoint first + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + checkpoint = { + "id": "cp-001", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + metadata = {"source": "test", "step": 1} + + await saver.aput(config, checkpoint, metadata, {}) + + # Now retrieve it + get_config = {"configurable": {"thread_id": "thread-1", "checkpoint_id": "cp-001"}} + result = await saver.aget_tuple(get_config) + + assert result is not None + assert result.checkpoint["id"] == "cp-001" + assert result.metadata["source"] == "test" + assert result.metadata["step"] == 1 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aget_tuple_returns_latest_without_checkpoint_id() -> None: + """Test that aget_tuple returns the latest checkpoint when no ID specified.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + # Save multiple checkpoints + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + + for i in range(3): + checkpoint = { + "id": f"cp-00{i}", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + config = await saver.aput(config, checkpoint, {"step": i}, {}) + + # Retrieve without specifying checkpoint_id + get_config = {"configurable": {"thread_id": "thread-1"}} + result = await saver.aget_tuple(get_config) + + assert result is not None + # Should get the latest (max checkpoint_id) + assert result.checkpoint["id"] == "cp-002" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aput_writes_creates_write_items() -> None: + """Test that aput_writes creates write items.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + # First create a checkpoint + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + checkpoint = { + "id": "cp-001", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + config = await saver.aput(config, checkpoint, {}, {}) + + # Now add writes + writes = [("channel1", "value1"), ("channel2", "value2")] + await saver.aput_writes(config, writes, task_id="task-1") + + # Should have 3 items: checkpoint + 2 writes + item_ids = await client.list_item_ids("thread-1") + assert len(item_ids) == 3 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aget_tuple_returns_pending_writes() -> None: + """Test that aget_tuple includes pending writes.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + # Create checkpoint and add writes + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + checkpoint = { + "id": "cp-001", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + config = await saver.aput(config, checkpoint, {}, {}) + + writes = [("channel1", "value1")] + await saver.aput_writes(config, writes, task_id="task-1") + + # Retrieve and check pending writes + result = await saver.aget_tuple(config) + + assert result is not None + assert result.pending_writes is not None + assert len(result.pending_writes) == 1 + assert result.pending_writes[0][1] == "channel1" + assert result.pending_writes[0][2] == "value1" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_alist_returns_checkpoints_in_order() -> None: + """Test that alist returns checkpoints in reverse order.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + # Save multiple checkpoints + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + for i in range(3): + checkpoint = { + "id": f"cp-00{i}", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + config = await saver.aput(config, checkpoint, {"step": i}, {}) + + # List checkpoints + list_config = {"configurable": {"thread_id": "thread-1"}} + results = [] + async for cp in saver.alist(list_config): + results.append(cp) + + assert len(results) == 3 + # Should be in reverse order (newest first) + assert results[0].checkpoint["id"] == "cp-002" + assert results[1].checkpoint["id"] == "cp-001" + assert results[2].checkpoint["id"] == "cp-000" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_alist_filters_by_namespace() -> None: + """Test that alist filters by checkpoint namespace.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + # Save checkpoints in different namespaces + for ns in ["ns1", "ns2"]: + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ns}} + checkpoint = { + "id": f"cp-{ns}", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + await saver.aput(config, checkpoint, {}, {}) + + # List only ns1 + list_config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": "ns1"}} + results = [] + async for cp in saver.alist(list_config): + results.append(cp) + + assert len(results) == 1 + assert results[0].checkpoint["id"] == "cp-ns1" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_alist_applies_limit() -> None: + """Test that alist respects the limit parameter.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + # Save multiple checkpoints + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + for i in range(5): + checkpoint = { + "id": f"cp-00{i}", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + config = await saver.aput(config, checkpoint, {}, {}) + + # List with limit + list_config = {"configurable": {"thread_id": "thread-1"}} + results = [] + async for cp in saver.alist(list_config, limit=2): + results.append(cp) + + assert len(results) == 2 + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_adelete_thread_deletes_session() -> None: + """Test that adelete_thread removes all checkpoints for a thread.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + # Create a checkpoint + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + checkpoint = { + "id": "cp-001", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + await saver.aput(config, checkpoint, {}, {}) + + # Verify it exists + item_ids = await client.list_item_ids("thread-1") + assert len(item_ids) == 1 + + # Delete the thread + await saver.adelete_thread("thread-1") + + # Verify it's gone + item_ids = await client.list_item_ids("thread-1") + assert len(item_ids) == 0 + + +@pytest.mark.unit +def test_sync_methods_raise_not_implemented() -> None: + """Test that sync methods raise NotImplementedError.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + config = {"configurable": {"thread_id": "thread-1"}} + checkpoint = {"id": "cp-001", "channel_values": {}, "channel_versions": {}} + + with pytest.raises(NotImplementedError, match="aget_tuple"): + saver.get_tuple(config) + + with pytest.raises(NotImplementedError, match="aput"): + saver.put(config, checkpoint, {}, {}) + + with pytest.raises(NotImplementedError, match="aput_writes"): + saver.put_writes(config, [], "task-1") + + with pytest.raises(NotImplementedError, match="alist"): + list(saver.list(config)) + + with pytest.raises(NotImplementedError, match="adelete_thread"): + saver.delete_thread("thread-1") + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aget_tuple_returns_parent_config() -> None: + """Test that aget_tuple includes parent config when checkpoint has parent.""" + client = MockFoundryCheckpointClient() + saver = TestableFoundryCheckpointSaver(client=client) + + # Save parent checkpoint + config = {"configurable": {"thread_id": "thread-1", "checkpoint_ns": ""}} + parent_checkpoint = { + "id": "cp-001", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + config = await saver.aput(config, parent_checkpoint, {}, {}) + + # Save child checkpoint + child_checkpoint = { + "id": "cp-002", + "channel_values": {}, + "channel_versions": {}, + "versions_seen": {}, + "pending_sends": [], + } + await saver.aput(config, child_checkpoint, {}, {}) + + # Retrieve child + get_config = {"configurable": {"thread_id": "thread-1", "checkpoint_id": "cp-002"}} + result = await saver.aget_tuple(get_config) + + assert result is not None + assert result.parent_config is not None + assert result.parent_config["configurable"]["checkpoint_id"] == "cp-001" + + +@pytest.mark.unit +def test_constructor_requires_async_credential() -> None: + """Test that FoundryCheckpointSaver raises TypeError for sync credentials.""" + mock_credential = Mock(spec=TokenCredential) + + with pytest.raises(TypeError, match="AsyncTokenCredential"): + FoundryCheckpointSaver( + project_endpoint="https://test.services.ai.azure.com/api/projects/test", + credential=mock_credential, + ) + + +@pytest.mark.unit +def test_constructor_accepts_async_credential() -> None: + """Test that FoundryCheckpointSaver accepts AsyncTokenCredential.""" + mock_credential = Mock(spec=AsyncTokenCredential) + + saver = FoundryCheckpointSaver( + project_endpoint="https://test.services.ai.azure.com/api/projects/test", + credential=mock_credential, + ) + + assert saver is not None + assert saver._client is not None diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/test_item_id.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/test_item_id.py new file mode 100644 index 000000000000..ddb11ffa62f0 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/checkpointer/test_item_id.py @@ -0,0 +1,125 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for item ID utilities.""" + +import pytest + +from azure.ai.agentserver.langgraph.checkpointer._item_id import ( + ParsedItemId, + make_item_id, + parse_item_id, +) + + +@pytest.mark.unit +def test_make_item_id_formats_correctly() -> None: + """Test that make_item_id creates correct composite IDs.""" + item_id = make_item_id("ns1", "cp-001", "checkpoint") + assert item_id == "ns1:cp-001:checkpoint:" + + +@pytest.mark.unit +def test_make_item_id_with_sub_key() -> None: + """Test that make_item_id includes sub_key correctly.""" + item_id = make_item_id("ns1", "cp-001", "writes", "task1:0") + assert item_id == "ns1:cp-001:writes:task1%3A0" + + +@pytest.mark.unit +def test_make_item_id_with_blob() -> None: + """Test blob item ID format.""" + item_id = make_item_id("", "cp-001", "blob", "messages:v2") + assert item_id == ":cp-001:blob:messages%3Av2" + + +@pytest.mark.unit +def test_parse_item_id_extracts_components() -> None: + """Test that parse_item_id extracts all components correctly.""" + item_id = "ns1:cp-001:checkpoint:" + parsed = parse_item_id(item_id) + + assert parsed.checkpoint_ns == "ns1" + assert parsed.checkpoint_id == "cp-001" + assert parsed.item_type == "checkpoint" + assert parsed.sub_key == "" + + +@pytest.mark.unit +def test_parse_item_id_extracts_sub_key() -> None: + """Test that parse_item_id extracts sub_key correctly.""" + item_id = "ns1:cp-001:writes:task1%3A0" + parsed = parse_item_id(item_id) + + assert parsed.checkpoint_ns == "ns1" + assert parsed.checkpoint_id == "cp-001" + assert parsed.item_type == "writes" + assert parsed.sub_key == "task1:0" + + +@pytest.mark.unit +def test_roundtrip_simple() -> None: + """Test roundtrip encoding/decoding of simple IDs.""" + original_ns = "namespace" + original_id = "checkpoint-123" + original_type = "checkpoint" + original_key = "" + + item_id = make_item_id(original_ns, original_id, original_type, original_key) + parsed = parse_item_id(item_id) + + assert parsed.checkpoint_ns == original_ns + assert parsed.checkpoint_id == original_id + assert parsed.item_type == original_type + assert parsed.sub_key == original_key + + +@pytest.mark.unit +def test_roundtrip_with_special_characters() -> None: + """Test roundtrip encoding/decoding with special characters (colons).""" + original_ns = "ns:with:colons" + original_id = "cp:123:abc" + original_type = "blob" + original_key = "channel:v1:extra" + + item_id = make_item_id(original_ns, original_id, original_type, original_key) + parsed = parse_item_id(item_id) + + assert parsed.checkpoint_ns == original_ns + assert parsed.checkpoint_id == original_id + assert parsed.item_type == original_type + assert parsed.sub_key == original_key + + +@pytest.mark.unit +def test_roundtrip_with_percent_signs() -> None: + """Test roundtrip encoding/decoding with percent signs.""" + original_ns = "ns%test" + original_id = "cp%123" + original_type = "checkpoint" + original_key = "key%value" + + item_id = make_item_id(original_ns, original_id, original_type, original_key) + parsed = parse_item_id(item_id) + + assert parsed.checkpoint_ns == original_ns + assert parsed.checkpoint_id == original_id + assert parsed.item_type == original_type + assert parsed.sub_key == original_key + + +@pytest.mark.unit +def test_parse_item_id_raises_on_invalid_format() -> None: + """Test that parse_item_id raises ValueError for invalid format.""" + with pytest.raises(ValueError, match="Invalid item_id format"): + parse_item_id("invalid:format") + + with pytest.raises(ValueError, match="Invalid item_id format"): + parse_item_id("only:two:parts") + + +@pytest.mark.unit +def test_parse_item_id_raises_on_invalid_type() -> None: + """Test that parse_item_id raises ValueError for invalid item type.""" + with pytest.raises(ValueError, match="Invalid item_type"): + parse_item_id("ns:cp:invalid:key") diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/mocks/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/mocks/__init__.py new file mode 100644 index 000000000000..4436d04866df --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/mocks/__init__.py @@ -0,0 +1,8 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Mock implementations for testing.""" + +from .mock_checkpoint_client import MockFoundryCheckpointClient + +__all__ = ["MockFoundryCheckpointClient"] diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/mocks/mock_checkpoint_client.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/mocks/mock_checkpoint_client.py new file mode 100644 index 000000000000..ffc1e2fcc4c1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/mocks/mock_checkpoint_client.py @@ -0,0 +1,156 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Mock implementation of FoundryCheckpointClient for testing.""" + +from typing import Any, Dict, List, Optional + +from azure.ai.agentserver.core.checkpoints.client import ( + CheckpointItem, + CheckpointItemId, + CheckpointSession, +) + + +class MockFoundryCheckpointClient: + """In-memory mock for FoundryCheckpointClient for unit testing. + + Stores checkpoints in memory without making any HTTP calls. + """ + + def __init__(self, endpoint: str = "https://mock.endpoint") -> None: + """Initialize the mock client. + + :param endpoint: The mock endpoint URL. + :type endpoint: str + """ + self._endpoint = endpoint + self._sessions: Dict[str, CheckpointSession] = {} + self._items: Dict[str, CheckpointItem] = {} + + def _item_key(self, item_id: CheckpointItemId) -> str: + """Generate a unique key for a checkpoint item. + + :param item_id: The checkpoint item identifier. + :type item_id: CheckpointItemId + :return: The unique key. + :rtype: str + """ + return f"{item_id.session_id}:{item_id.item_id}" + + # Session operations + + async def upsert_session(self, session: CheckpointSession) -> CheckpointSession: + """Create or update a checkpoint session. + + :param session: The checkpoint session to upsert. + :type session: CheckpointSession + :return: The upserted checkpoint session. + :rtype: CheckpointSession + """ + self._sessions[session.session_id] = session + return session + + async def read_session(self, session_id: str) -> Optional[CheckpointSession]: + """Read a checkpoint session by ID. + + :param session_id: The session identifier. + :type session_id: str + :return: The checkpoint session if found, None otherwise. + :rtype: Optional[CheckpointSession] + """ + return self._sessions.get(session_id) + + async def delete_session(self, session_id: str) -> None: + """Delete a checkpoint session. + + :param session_id: The session identifier. + :type session_id: str + """ + self._sessions.pop(session_id, None) + # Also delete all items in the session + keys_to_delete = [ + key for key, item in self._items.items() if item.session_id == session_id + ] + for key in keys_to_delete: + del self._items[key] + + # Item operations + + async def create_items(self, items: List[CheckpointItem]) -> List[CheckpointItem]: + """Create checkpoint items in batch. + + :param items: The checkpoint items to create. + :type items: List[CheckpointItem] + :return: The created checkpoint items. + :rtype: List[CheckpointItem] + """ + for item in items: + key = self._item_key(item.to_item_id()) + self._items[key] = item + return items + + async def read_item(self, item_id: CheckpointItemId) -> Optional[CheckpointItem]: + """Read a checkpoint item by ID. + + :param item_id: The checkpoint item identifier. + :type item_id: CheckpointItemId + :return: The checkpoint item if found, None otherwise. + :rtype: Optional[CheckpointItem] + """ + key = self._item_key(item_id) + return self._items.get(key) + + async def delete_item(self, item_id: CheckpointItemId) -> bool: + """Delete a checkpoint item. + + :param item_id: The checkpoint item identifier. + :type item_id: CheckpointItemId + :return: True if the item was deleted, False if not found. + :rtype: bool + """ + key = self._item_key(item_id) + if key in self._items: + del self._items[key] + return True + return False + + async def list_item_ids( + self, session_id: str, parent_id: Optional[str] = None + ) -> List[CheckpointItemId]: + """List checkpoint item IDs for a session. + + :param session_id: The session identifier. + :type session_id: str + :param parent_id: Optional parent item identifier for filtering. + :type parent_id: Optional[str] + :return: List of checkpoint item identifiers. + :rtype: List[CheckpointItemId] + """ + result = [] + for item in self._items.values(): + if item.session_id == session_id: + if parent_id is None or item.parent_id == parent_id: + result.append(item.to_item_id()) + return result + + # Context manager methods + + async def close(self) -> None: + """Close the client (no-op for mock).""" + pass + + async def __aenter__(self) -> "MockFoundryCheckpointClient": + """Enter the async context manager. + + :return: The client instance. + :rtype: MockFoundryCheckpointClient + """ + return self + + async def __aexit__(self, *exc_details: Any) -> None: + """Exit the async context manager. + + :param exc_details: Exception details if an exception occurred. + """ + pass diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_conversation_id_optional.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_conversation_id_optional.py new file mode 100644 index 000000000000..a337f494e595 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_conversation_id_optional.py @@ -0,0 +1,68 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +from types import SimpleNamespace + +import pytest + +from azure.ai.agentserver.langgraph._exceptions import LangGraphMissingConversationIdError +from azure.ai.agentserver.langgraph.langgraph import LangGraphAdapter +from azure.ai.agentserver.langgraph.models.response_api_default_converter import ResponseAPIDefaultConverter + + +class _DummyConverter: + async def convert_request(self, context): # pragma: no cover - guard should short-circuit first + raise AssertionError("convert_request should not be called for this test") + + async def convert_response_non_stream(self, output, context): # pragma: no cover - guard should short-circuit + raise AssertionError("convert_response_non_stream should not be called for this test") + + async def convert_response_stream(self, output, context): # pragma: no cover - guard should short-circuit + raise AssertionError("convert_response_stream should not be called for this test") + + +class _DummyGraph: + def __init__(self) -> None: + self.checkpointer = object() + self.last_config = None + + async def aget_state(self, config): + self.last_config = config + return "state" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aget_state_skips_without_conversation_id() -> None: + graph = _DummyGraph() + converter = ResponseAPIDefaultConverter(graph) # type: ignore[arg-type] + context = SimpleNamespace(agent_run=SimpleNamespace(conversation_id=None)) + + state = await converter._aget_state(context) # type: ignore[arg-type] + + assert state is None + assert graph.last_config is None + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_aget_state_uses_conversation_id() -> None: + graph = _DummyGraph() + converter = ResponseAPIDefaultConverter(graph) # type: ignore[arg-type] + context = SimpleNamespace(agent_run=SimpleNamespace(conversation_id="conv-1")) + + state = await converter._aget_state(context) # type: ignore[arg-type] + + assert state == "state" + assert graph.last_config["configurable"]["thread_id"] == "conv-1" + + +@pytest.mark.unit +@pytest.mark.asyncio +async def test_agent_run_requires_conversation_id_with_checkpointer_raises() -> None: + graph = _DummyGraph() + adapter = LangGraphAdapter(graph, converter=_DummyConverter()) # type: ignore[arg-type] + context = SimpleNamespace(conversation_id=None, stream=False) + + with pytest.raises(LangGraphMissingConversationIdError): + await adapter.agent_run(context) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_from_langgraph_managed.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_from_langgraph_managed.py new file mode 100644 index 000000000000..5a8b5cf2a1f4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_from_langgraph_managed.py @@ -0,0 +1,72 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for from_langgraph with checkpointer via compile.""" + +import pytest +from unittest.mock import Mock + +from azure.core.credentials_async import AsyncTokenCredential + + +@pytest.mark.unit +def test_from_langgraph_basic() -> None: + """Test that from_langgraph works without a checkpointer.""" + from azure.ai.agentserver.langgraph import from_langgraph + from langgraph.graph import StateGraph + from typing_extensions import TypedDict + + class State(TypedDict): + messages: list + + builder = StateGraph(State) + builder.add_node("node1", lambda x: x) + builder.set_entry_point("node1") + builder.set_finish_point("node1") + graph = builder.compile() + + adapter = from_langgraph(graph) + + assert adapter is not None + + +@pytest.mark.unit +def test_graph_with_foundry_checkpointer_via_compile() -> None: + """Test that FoundryCheckpointSaver can be set via builder.compile().""" + from azure.ai.agentserver.langgraph import from_langgraph + from azure.ai.agentserver.langgraph.checkpointer import FoundryCheckpointSaver + from langgraph.graph import StateGraph + from typing_extensions import TypedDict + + class State(TypedDict): + messages: list + + builder = StateGraph(State) + builder.add_node("node1", lambda x: x) + builder.add_node("node2", lambda x: x) + builder.add_edge("node1", "node2") + builder.set_entry_point("node1") + builder.set_finish_point("node2") + + mock_credential = Mock(spec=AsyncTokenCredential) + saver = FoundryCheckpointSaver( + project_endpoint="https://test.services.ai.azure.com/api/projects/test-project", + credential=mock_credential, + ) + + # User sets checkpointer via LangGraph's native compile() + graph = builder.compile( + checkpointer=saver, + interrupt_before=["node1"], + interrupt_after=["node2"], + debug=True, + ) + + adapter = from_langgraph(graph) + + # Verify checkpointer and compile parameters are preserved + assert adapter is not None + assert isinstance(adapter._graph.checkpointer, FoundryCheckpointSaver) + assert adapter._graph.interrupt_before_nodes == ["node1"] + assert adapter._graph.interrupt_after_nodes == ["node2"] + assert adapter._graph.debug is True diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_langgraph_request_converter.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_langgraph_request_converter.py index 84a8c8784d8b..b1894f7350d5 100644 --- a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_langgraph_request_converter.py +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/test_langgraph_request_converter.py @@ -3,7 +3,7 @@ from azure.ai.agentserver.core import models from azure.ai.agentserver.core.models import projects as project_models -from azure.ai.agentserver.langgraph import models as langgraph_models +from azure.ai.agentserver.langgraph.models.response_api_request_converter import ResponseAPIMessageRequestConverter @pytest.mark.unit @@ -16,7 +16,7 @@ def test_convert_implicit_user_message(): input=[implicit_user_message], ) - converter = langgraph_models.LangGraphRequestConverter(create_response) + converter = ResponseAPIMessageRequestConverter(create_response) res = converter.convert() assert "messages" in res @@ -34,7 +34,7 @@ def test_convert_implicit_user_message_with_contents(): ] create_response = models.CreateResponse(input=[{"content": input_data}]) - converter = langgraph_models.LangGraphRequestConverter(create_response) + converter = ResponseAPIMessageRequestConverter(create_response) res = converter.convert() assert "messages" in res @@ -61,7 +61,7 @@ def test_convert_item_param_message(): create_response = models.CreateResponse( input=input_data, ) - converter = langgraph_models.LangGraphRequestConverter(create_response) + converter = ResponseAPIMessageRequestConverter(create_response) res = converter.convert() assert "messages" in res @@ -103,7 +103,7 @@ def test_convert_item_param_function_call_and_function_call_output(): create_response = models.CreateResponse( input=input_data, ) - converter = langgraph_models.LangGraphRequestConverter(create_response) + converter = ResponseAPIMessageRequestConverter(create_response) res = converter.convert() assert "messages" in res assert len(res["messages"]) == len(input_data) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/__init__.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/__init__.py new file mode 100644 index 000000000000..28077537d94b --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/__init__.py @@ -0,0 +1,5 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- + +__path__ = __import__('pkgutil').extend_path(__path__, __name__) diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/conftest.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/conftest.py new file mode 100644 index 000000000000..7efc298559c1 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/conftest.py @@ -0,0 +1,271 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Shared fixtures for langgraph tools unit tests.""" +from typing import Any, Dict, List, Optional + +import pytest +from langchain_core.callbacks import CallbackManagerForLLMRun +from langchain_core.language_models import BaseChatModel +from langchain_core.messages import AIMessage +from langchain_core.outputs import ChatGeneration, ChatResult +from langchain_core.runnables import RunnableConfig +from langchain_core.tools import BaseTool, tool + +from azure.ai.agentserver.core.tools import ( + FoundryHostedMcpTool, + FoundryConnectedTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, +) +from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext +from azure.ai.agentserver.langgraph._context import LanggraphRunContext +from azure.ai.agentserver.langgraph.tools._context import FoundryToolContext +from azure.ai.agentserver.langgraph.tools._resolver import ResolvedTools + + +class FakeChatModel(BaseChatModel): + """A fake chat model for testing purposes that returns pre-configured responses.""" + + responses: List[AIMessage] = [] + tool_calls_list: List[List[Dict[str, Any]]] = [] + _call_count: int = 0 + _bound_tools: List[Any] = [] + _bound_kwargs: Dict[str, Any] = {} + + def __init__( + self, + responses: Optional[List[AIMessage]] = None, + tool_calls: Optional[List[List[Dict[str, Any]]]] = None, + **kwargs: Any, + ): + """Initialize the fake chat model. + + :param responses: List of AIMessage responses to return in sequence. + :param tool_calls: List of tool_calls lists corresponding to each response. + """ + super().__init__(**kwargs) + self.responses = responses or [] + self.tool_calls_list = tool_calls or [] + self._call_count = 0 + self._bound_tools = [] + self._bound_kwargs = {} + + @property + def _llm_type(self) -> str: + return "fake_chat_model" + + def _generate( + self, + messages: List[Any], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Generate a response.""" + response = self._get_next_response() + return ChatResult(generations=[ChatGeneration(message=response)]) + + def bind_tools( + self, + tools: List[Any], + **kwargs: Any, + ) -> "FakeChatModel": + """Bind tools to this model.""" + self._bound_tools = list(tools) + self._bound_kwargs.update(kwargs) + return self + + def invoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> AIMessage: + """Synchronously invoke the model.""" + return self._get_next_response() + + async def ainvoke(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any) -> AIMessage: + """Asynchronously invoke the model.""" + return self._get_next_response() + + def stream(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any): + """Stream the response.""" + yield self._get_next_response() + + async def astream(self, input: Any, config: Optional[RunnableConfig] = None, **kwargs: Any): + """Async stream the response.""" + yield self._get_next_response() + + def _get_next_response(self) -> AIMessage: + """Get the next response in sequence.""" + if self._call_count < len(self.responses): + response = self.responses[self._call_count] + else: + # Default response if no more configured + response = AIMessage(content="Default response") + + # Apply tool calls if configured + if self._call_count < len(self.tool_calls_list): + response = AIMessage( + content=response.content, + tool_calls=self.tool_calls_list[self._call_count], + ) + + self._call_count += 1 + return response + + +@pytest.fixture +def sample_schema_definition() -> SchemaDefinition: + """Create a sample SchemaDefinition.""" + return SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Search query"), + }, + required={"query"}, + ) + + +@pytest.fixture +def sample_code_interpreter_tool() -> FoundryHostedMcpTool: + """Create a sample code interpreter tool definition.""" + return FoundryHostedMcpTool( + name="code_interpreter", + configuration={}, + ) + + +@pytest.fixture +def sample_mcp_connected_tool() -> FoundryConnectedTool: + """Create a sample MCP connected tool definition.""" + return FoundryConnectedTool( + protocol="mcp", + project_connection_id="MicrosoftLearn", + ) + + +@pytest.fixture +def sample_tool_details(sample_schema_definition: SchemaDefinition) -> FoundryToolDetails: + """Create a sample FoundryToolDetails.""" + return FoundryToolDetails( + name="search", + description="Search for documents", + input_schema=sample_schema_definition, + ) + + +@pytest.fixture +def sample_resolved_tool( + sample_code_interpreter_tool: FoundryHostedMcpTool, + sample_tool_details: FoundryToolDetails, +) -> ResolvedFoundryTool: + """Create a sample resolved foundry tool.""" + return ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=sample_tool_details, + ) + + +@pytest.fixture +def mock_langchain_tool() -> BaseTool: + """Create a mock LangChain BaseTool.""" + @tool + def mock_tool(query: str) -> str: + """Mock tool for testing. + + :param query: The search query. + :return: Mock result. + """ + return f"Mock result for: {query}" + + return mock_tool + + +@pytest.fixture +def mock_async_langchain_tool() -> BaseTool: + """Create a mock async LangChain BaseTool.""" + @tool + async def mock_async_tool(query: str) -> str: + """Mock async tool for testing. + + :param query: The search query. + :return: Mock result. + """ + return f"Async mock result for: {query}" + + return mock_async_tool + + +@pytest.fixture +def sample_resolved_tools( + sample_code_interpreter_tool: FoundryHostedMcpTool, + mock_langchain_tool: BaseTool, +) -> ResolvedTools: + """Create a sample ResolvedTools instance.""" + resolved_foundry_tool = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="mock_tool", + description="Mock tool for testing", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Query"), + }, + required={"query"}, + ), + ), + ) + return ResolvedTools(tools=[(resolved_foundry_tool, mock_langchain_tool)]) + + +@pytest.fixture +def mock_agent_run_context() -> AgentRunContext: + """Create a mock AgentRunContext.""" + payload = { + "input": [{"role": "user", "content": "Hello"}], + "stream": False, + } + return AgentRunContext(payload=payload) + + +@pytest.fixture +def mock_foundry_tool_context(sample_resolved_tools: ResolvedTools) -> FoundryToolContext: + """Create a mock FoundryToolContext.""" + return FoundryToolContext(resolved_tools=sample_resolved_tools) + + +@pytest.fixture +def mock_langgraph_run_context( + mock_agent_run_context: AgentRunContext, + mock_foundry_tool_context: FoundryToolContext, +) -> LanggraphRunContext: + """Create a mock LanggraphRunContext.""" + return LanggraphRunContext( + agent_run=mock_agent_run_context, + tools=mock_foundry_tool_context, + ) + + +@pytest.fixture +def fake_chat_model_simple() -> FakeChatModel: + """Create a simple fake chat model.""" + return FakeChatModel( + responses=[AIMessage(content="Hello! How can I help you?")], + ) + + +@pytest.fixture +def fake_chat_model_with_tool_call() -> FakeChatModel: + """Create a fake chat model that makes a tool call.""" + return FakeChatModel( + responses=[ + AIMessage(content=""), # First response: tool call + AIMessage(content="The answer is 42."), # Second response: final answer + ], + tool_calls=[ + [{"id": "call_1", "name": "mock_tool", "args": {"query": "test query"}}], + [], # No tool calls in final response + ], + ) + diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_agent_integration.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_agent_integration.py new file mode 100644 index 000000000000..eea917e54fd4 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_agent_integration.py @@ -0,0 +1,416 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Integration-style unit tests for langgraph agents with foundry tools. + +These tests demonstrate the usage patterns similar to the tool_client_example samples, +but use mocked models and tools to avoid calling real services. +""" +import pytest + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.runnables import RunnableConfig +from langchain_core.tools import BaseTool, tool +from langgraph.graph import END, START, MessagesState, StateGraph +from langgraph.prebuilt import ToolNode +from typing_extensions import Literal + +from azure.ai.agentserver.core.tools import ( + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, +) +from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext +from azure.ai.agentserver.langgraph._context import LanggraphRunContext +from azure.ai.agentserver.langgraph.tools import use_foundry_tools +from azure.ai.agentserver.langgraph.tools._context import FoundryToolContext +from azure.ai.agentserver.langgraph.tools._chat_model import FoundryToolLateBindingChatModel +from azure.ai.agentserver.langgraph.tools._middleware import FoundryToolBindingMiddleware +from azure.ai.agentserver.langgraph.tools._resolver import ResolvedTools, get_registry + +from .conftest import FakeChatModel + + +@pytest.fixture(autouse=True) +def clear_registry(): + """Clear the global registry before and after each test.""" + registry = get_registry() + registry.clear() + yield + registry.clear() + + +@pytest.mark.unit +class TestGraphAgentWithFoundryTools: + """Tests demonstrating graph agent usage patterns similar to graph_agent_tool.py sample.""" + + def _create_mock_langgraph_context( + self, + foundry_tool: FoundryHostedMcpTool, + langchain_tool: BaseTool, + ) -> LanggraphRunContext: + """Create a mock LanggraphRunContext with resolved tools.""" + # Create resolved foundry tool + resolved_foundry_tool = ResolvedFoundryTool( + definition=foundry_tool, + details=FoundryToolDetails( + name=langchain_tool.name, + description=langchain_tool.description or "Mock tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Query"), + }, + required={"query"}, + ), + ), + ) + + # Create resolved tools + resolved_tools = ResolvedTools(tools=[(resolved_foundry_tool, langchain_tool)]) + + # Create context + payload = {"input": [{"role": "user", "content": "test"}], "stream": False} + agent_run_context = AgentRunContext(payload=payload) + tool_context = FoundryToolContext(resolved_tools=resolved_tools) + + return LanggraphRunContext(agent_run=agent_run_context, tools=tool_context) + + @pytest.mark.asyncio + async def test_graph_agent_with_foundry_tools_no_tool_call(self): + """Test a graph agent that uses foundry tools but doesn't make a tool call.""" + # Create a mock tool + @tool + def calculate(expression: str) -> str: + """Calculate a mathematical expression. + + :param expression: The expression to calculate. + :return: The result. + """ + return "42" + + # Create foundry tool definition + foundry_tool = FoundryHostedMcpTool(name="code_interpreter", configuration={}) + foundry_tools = [{"type": "code_interpreter"}] + + # Create mock model that returns simple response (no tool call) + mock_model = FakeChatModel( + responses=[AIMessage(content="The answer is 42.")], + ) + + # Create the foundry tool binding chat model + llm_with_foundry_tools = FoundryToolLateBindingChatModel( + delegate=mock_model, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Create context and attach + context = self._create_mock_langgraph_context(foundry_tool, calculate) + config: RunnableConfig = {"configurable": {}} + context.attach_to_config(config) + + # Define the LLM call node + async def llm_call(state: MessagesState, config: RunnableConfig): + return { + "messages": [ + await llm_with_foundry_tools.ainvoke( + [SystemMessage(content="You are a helpful assistant.")] + + state["messages"], + config=config, + ) + ] + } + + # Define routing function + def should_continue(state: MessagesState) -> Literal["tools", "__end__"]: + messages = state["messages"] + last_message = messages[-1] + if hasattr(last_message, 'tool_calls') and last_message.tool_calls: + return "tools" + return END + + # Build the graph + builder = StateGraph(MessagesState) + builder.add_node("llm_call", llm_call) + builder.add_node("tools", llm_with_foundry_tools.tool_node) + builder.add_edge(START, "llm_call") + builder.add_conditional_edges("llm_call", should_continue, {"tools": "tools", END: END}) + builder.add_edge("tools", "llm_call") + + graph = builder.compile() + + # Run the graph + result = await graph.ainvoke( + {"messages": [HumanMessage(content="What is 6 * 7?")]}, + config=config, + ) + + # Verify result + assert len(result["messages"]) == 2 # HumanMessage + AIMessage + assert result["messages"][-1].content == "The answer is 42." + + @pytest.mark.asyncio + async def test_graph_agent_with_tool_call(self): + """Test a graph agent that makes a tool call.""" + # Create a mock tool + @tool + def calculate(expression: str) -> str: + """Calculate a mathematical expression. + + :param expression: The expression to calculate. + :return: The result. + """ + return "42" + + # Create foundry tool definition + foundry_tool = FoundryHostedMcpTool(name="code_interpreter", configuration={}) + foundry_tools = [{"type": "code_interpreter"}] + + # Create mock model that makes a tool call, then returns final answer + mock_model = FakeChatModel( + responses=[ + AIMessage( + content="", + tool_calls=[{"id": "call_1", "name": "calculate", "args": {"expression": "6 * 7"}}], + ), + AIMessage(content="The answer is 42."), + ], + ) + + # Create the foundry tool binding chat model + llm_with_foundry_tools = FoundryToolLateBindingChatModel( + delegate=mock_model, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Create context with the calculate tool + context = self._create_mock_langgraph_context(foundry_tool, calculate) + config: RunnableConfig = {"configurable": {}} + context.attach_to_config(config) + + # Define the LLM call node + async def llm_call(state: MessagesState, config: RunnableConfig): + return { + "messages": [ + await llm_with_foundry_tools.ainvoke( + [SystemMessage(content="You are a helpful assistant.")] + + state["messages"], + config=config, + ) + ] + } + + # Define routing function + def should_continue(state: MessagesState) -> Literal["tools", "__end__"]: + messages = state["messages"] + last_message = messages[-1] + if hasattr(last_message, 'tool_calls') and last_message.tool_calls: + return "tools" + return END + + # Build the graph with a regular ToolNode (using the local tool directly for testing) + builder = StateGraph(MessagesState) + builder.add_node("llm_call", llm_call) + builder.add_node("tools", ToolNode([calculate])) + builder.add_edge(START, "llm_call") + builder.add_conditional_edges("llm_call", should_continue, {"tools": "tools", END: END}) + builder.add_edge("tools", "llm_call") + + graph = builder.compile() + + # Run the graph + result = await graph.ainvoke( + {"messages": [HumanMessage(content="What is 6 * 7?")]}, + config=config, + ) + + # Verify result - should have: HumanMessage, AIMessage (with tool call), ToolMessage, AIMessage (final) + assert len(result["messages"]) == 4 + assert result["messages"][-1].content == "The answer is 42." + + # Verify tool was called + tool_message = result["messages"][2] + assert isinstance(tool_message, ToolMessage) + assert tool_message.content == "42" + + +@pytest.mark.unit +class TestReactAgentWithFoundryTools: + """Tests demonstrating react agent usage patterns similar to react_agent_tool.py sample.""" + + @pytest.mark.asyncio + async def test_middleware_integration_with_foundry_tools(self): + """Test that FoundryToolBindingMiddleware correctly integrates with agents.""" + # Define foundry tools configuration + foundry_tools_config = [ + {"type": "code_interpreter"}, + {"type": "mcp", "project_connection_id": "MicrosoftLearn"}, + ] + + # Create middleware using use_foundry_tools + middleware = use_foundry_tools(foundry_tools_config) + + # Verify middleware is created correctly + assert isinstance(middleware, FoundryToolBindingMiddleware) + + # Verify dummy tool is created for the agent + assert len(middleware.tools) == 1 + assert middleware.tools[0].name == "__dummy_tool_by_foundry_middleware__" + + # Verify foundry tools are recorded + assert len(middleware._foundry_tools_to_bind) == 2 + + def test_use_foundry_tools_with_model(self): + """Test use_foundry_tools when used with a model directly.""" + foundry_tools = [{"type": "code_interpreter"}] + mock_model = FakeChatModel() + + result = use_foundry_tools(mock_model, foundry_tools) # type: ignore + + assert isinstance(result, FoundryToolLateBindingChatModel) + assert len(result._foundry_tools_to_bind) == 1 + assert isinstance(result._foundry_tools_to_bind[0], FoundryHostedMcpTool) + assert result._foundry_tools_to_bind[0].name == "code_interpreter" + + +@pytest.mark.unit +class TestLanggraphRunContextIntegration: + """Tests for LanggraphRunContext integration with langgraph.""" + + def test_context_attachment_to_config(self): + """Test that context is correctly attached to RunnableConfig.""" + # Create a mock context + payload = {"input": [{"role": "user", "content": "test"}], "stream": False} + agent_run_context = AgentRunContext(payload=payload) + tool_context = FoundryToolContext() + + context = LanggraphRunContext(agent_run=agent_run_context, tools=tool_context) + + # Create config and attach context + config: RunnableConfig = {"configurable": {}} + context.attach_to_config(config) + + # Verify context is attached + assert "__foundry_hosted_agent_langgraph_run_context__" in config["configurable"] + assert config["configurable"]["__foundry_hosted_agent_langgraph_run_context__"] is context + + def test_context_resolution_from_config(self): + """Test that context can be resolved from RunnableConfig.""" + # Create and attach context + payload = {"input": [{"role": "user", "content": "test"}], "stream": False} + agent_run_context = AgentRunContext(payload=payload) + tool_context = FoundryToolContext() + + context = LanggraphRunContext(agent_run=agent_run_context, tools=tool_context) + + config: RunnableConfig = {"configurable": {}} + context.attach_to_config(config) + + # Resolve context + resolved = LanggraphRunContext.resolve(config=config) + + assert resolved is context + + def test_context_resolution_returns_none_when_not_attached(self): + """Test that context resolution returns None when not attached. + + For Python < 3.11, LanggraphRunContext.resolve will return None. + For Python >= 3.11, it will try get_runtime() from langgraph which depends on + context propagation. Since we don't run inside langgraph in unit tests, + no one propagates the context for us, so RuntimeError is raised. + """ + import sys + config: RunnableConfig = {"configurable": {}} + + if sys.version_info >= (3, 11): + with pytest.raises(RuntimeError, match="outside of a runnable context"): + LanggraphRunContext.resolve(config=config) + else: + resolved = LanggraphRunContext.resolve(config=config) + assert resolved is None + + def test_from_config_returns_context(self): + """Test LanggraphRunContext.from_config method.""" + payload = {"input": [{"role": "user", "content": "test"}], "stream": False} + agent_run_context = AgentRunContext(payload=payload) + tool_context = FoundryToolContext() + + context = LanggraphRunContext(agent_run=agent_run_context, tools=tool_context) + + config: RunnableConfig = {"configurable": {}} + context.attach_to_config(config) + + result = LanggraphRunContext.from_config(config) + + assert result is context + + def test_from_config_returns_none_for_non_context_value(self): + """Test that from_config returns None when value is not LanggraphRunContext.""" + config: RunnableConfig = { + "configurable": { + "__foundry_hosted_agent_langgraph_run_context__": "not a context" + } + } + + result = LanggraphRunContext.from_config(config) + + assert result is None + + +@pytest.mark.unit +class TestToolsResolutionInGraph: + """Tests for tool resolution within langgraph execution.""" + + @pytest.mark.asyncio + async def test_foundry_tools_resolved_from_context_in_graph_node(self): + """Test that foundry tools are correctly resolved from context during graph execution.""" + # Create mock tool + @tool + def search(query: str) -> str: + """Search for information. + + :param query: The search query. + :return: Search results. + """ + return f"Results for: {query}" + + # Create foundry tool and context + foundry_tool = FoundryHostedMcpTool(name="code_interpreter", configuration={}) + resolved_foundry_tool = ResolvedFoundryTool( + definition=foundry_tool, + details=FoundryToolDetails( + name="search", + description="Search tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={"query": SchemaProperty(type=SchemaType.STRING, description="Query")}, + required={"query"}, + ), + ), + ) + + resolved_tools = ResolvedTools(tools=[(resolved_foundry_tool, search)]) + + # Create context + payload = {"input": [{"role": "user", "content": "test"}], "stream": False} + agent_run_context = AgentRunContext(payload=payload) + tool_context = FoundryToolContext(resolved_tools=resolved_tools) + lg_context = LanggraphRunContext(agent_run=agent_run_context, tools=tool_context) + + # Create config and attach context + config: RunnableConfig = {"configurable": {}} + lg_context.attach_to_config(config) + + # Verify tools can be resolved + resolved = LanggraphRunContext.resolve(config=config) + assert resolved is not None + + tools = list(resolved.tools.resolved_tools.get(foundry_tool)) + assert len(tools) == 1 + assert tools[0].name == "search" + diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_builder.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_builder.py new file mode 100644 index 000000000000..1a2a5af167be --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_builder.py @@ -0,0 +1,109 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for use_foundry_tools builder function.""" +import pytest +from typing import List + + +from azure.ai.agentserver.langgraph.tools._builder import use_foundry_tools +from azure.ai.agentserver.langgraph.tools._chat_model import FoundryToolLateBindingChatModel +from azure.ai.agentserver.langgraph.tools._middleware import FoundryToolBindingMiddleware +from azure.ai.agentserver.langgraph.tools._resolver import get_registry + +from .conftest import FakeChatModel + + +@pytest.fixture(autouse=True) +def clear_registry(): + """Clear the global registry before and after each test.""" + registry = get_registry() + registry.clear() + yield + registry.clear() + + +@pytest.mark.unit +class TestUseFoundryTools: + """Tests for use_foundry_tools function.""" + + def test_use_foundry_tools_with_tools_only_returns_middleware(self): + """Test that passing only tools returns FoundryToolBindingMiddleware.""" + tools = [{"type": "code_interpreter"}] + + result = use_foundry_tools(tools) + + assert isinstance(result, FoundryToolBindingMiddleware) + + def test_use_foundry_tools_with_model_and_tools_returns_chat_model(self): + """Test that passing model and tools returns FoundryToolLateBindingChatModel.""" + model = FakeChatModel() + tools = [{"type": "code_interpreter"}] + + result = use_foundry_tools(model, tools) # type: ignore + + assert isinstance(result, FoundryToolLateBindingChatModel) + + def test_use_foundry_tools_with_model_but_no_tools_raises_error(self): + """Test that passing model without tools raises ValueError.""" + model = FakeChatModel() + + with pytest.raises(ValueError, match="Tools must be provided"): + use_foundry_tools(model, None) # type: ignore + + def test_use_foundry_tools_registers_tools_in_global_registry(self): + """Test that tools are registered in the global registry.""" + tools = [ + {"type": "code_interpreter"}, + {"type": "mcp", "project_connection_id": "test"}, + ] + + use_foundry_tools(tools) + + registry = get_registry() + assert len(registry) == 2 + + def test_use_foundry_tools_with_model_registers_tools(self): + """Test that tools are registered when using with model.""" + model = FakeChatModel() + tools = [{"type": "code_interpreter"}] + + use_foundry_tools(model, tools) # type: ignore + + registry = get_registry() + assert len(registry) == 1 + + def test_use_foundry_tools_with_empty_tools_list(self): + """Test using with empty tools list.""" + tools: List = [] + + result = use_foundry_tools(tools) + + assert isinstance(result, FoundryToolBindingMiddleware) + assert len(get_registry()) == 0 + + def test_use_foundry_tools_with_mcp_tools(self): + """Test using with MCP connected tools.""" + tools = [ + { + "type": "mcp", + "project_connection_id": "MicrosoftLearn", + }, + ] + + result = use_foundry_tools(tools) + + assert isinstance(result, FoundryToolBindingMiddleware) + + def test_use_foundry_tools_with_mixed_tool_types(self): + """Test using with a mix of different tool types.""" + tools = [ + {"type": "code_interpreter"}, + {"type": "mcp", "project_connection_id": "MicrosoftLearn"}, + ] + + result = use_foundry_tools(tools) + + assert isinstance(result, FoundryToolBindingMiddleware) + assert len(get_registry()) == 2 + diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_chat_model.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_chat_model.py new file mode 100644 index 000000000000..046285f17562 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_chat_model.py @@ -0,0 +1,300 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryToolLateBindingChatModel.""" +import pytest +from typing import Any, List, Optional +from unittest.mock import MagicMock, patch + +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage +from langchain_core.runnables import RunnableConfig +from langchain_core.tools import BaseTool, tool + +from azure.ai.agentserver.core.tools import ( + FoundryHostedMcpTool, + FoundryToolDetails, + ResolvedFoundryTool, + SchemaDefinition, + SchemaProperty, + SchemaType, +) +from azure.ai.agentserver.langgraph._context import LanggraphRunContext +from azure.ai.agentserver.langgraph.tools._chat_model import FoundryToolLateBindingChatModel +from azure.ai.agentserver.langgraph.tools._context import FoundryToolContext +from azure.ai.agentserver.langgraph.tools._resolver import ResolvedTools + +from .conftest import FakeChatModel + + +@pytest.mark.unit +class TestFoundryToolLateBindingChatModel: + """Tests for FoundryToolLateBindingChatModel class.""" + + def test_llm_type_property(self): + """Test the _llm_type property returns correct value.""" + delegate = FakeChatModel() + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + assert "foundry_tool_binding_model" in model._llm_type + assert "fake_chat_model" in model._llm_type + + def test_bind_tools_records_tools(self): + """Test that bind_tools records tools for later use.""" + delegate = FakeChatModel() + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + @tool + def my_tool(x: str) -> str: + """My tool.""" + return x + + result = model.bind_tools([my_tool], tool_choice="auto") + + # Should return self for chaining + assert result is model + # Tools should be recorded + assert len(model._bound_tools) == 1 + assert model._bound_kwargs.get("tool_choice") == "auto" + + def test_bind_tools_multiple_times(self): + """Test binding tools multiple times accumulates them.""" + delegate = FakeChatModel() + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + @tool + def tool1(x: str) -> str: + """Tool 1.""" + return x + + @tool + def tool2(x: str) -> str: + """Tool 2.""" + return x + + model.bind_tools([tool1]) + model.bind_tools([tool2]) + + assert len(model._bound_tools) == 2 + + def test_tool_node_property(self): + """Test that tool_node property returns a ToolNode.""" + delegate = FakeChatModel() + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + tool_node = model.tool_node + + # Should return a ToolNode + assert tool_node is not None + + def test_tool_node_wrapper_property(self): + """Test that tool_node_wrapper returns correct wrappers.""" + delegate = FakeChatModel() + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + wrappers = model.tool_node_wrapper + + assert "wrap_tool_call" in wrappers + assert "awrap_tool_call" in wrappers + assert callable(wrappers["wrap_tool_call"]) + assert callable(wrappers["awrap_tool_call"]) + + def test_invoke_with_context( + self, + mock_langgraph_run_context: LanggraphRunContext, + sample_code_interpreter_tool: FoundryHostedMcpTool, + ): + """Test invoking model with context attached.""" + delegate = FakeChatModel( + responses=[AIMessage(content="Hello from model!")], + ) + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Attach context to config + config: RunnableConfig = {"configurable": {}} + mock_langgraph_run_context.attach_to_config(config) + + input_messages = [HumanMessage(content="Hello")] + result = model.invoke(input_messages, config=config) + + assert result.content == "Hello from model!" + + @pytest.mark.asyncio + async def test_ainvoke_with_context( + self, + mock_langgraph_run_context: LanggraphRunContext, + ): + """Test async invoking model with context attached.""" + delegate = FakeChatModel( + responses=[AIMessage(content="Async hello from model!")], + ) + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Attach context to config + config: RunnableConfig = {"configurable": {}} + mock_langgraph_run_context.attach_to_config(config) + + input_messages = [HumanMessage(content="Hello")] + result = await model.ainvoke(input_messages, config=config) + + assert result.content == "Async hello from model!" + + def test_invoke_without_context_and_no_foundry_tools(self): + """Test invoking model without context and no foundry tools. + + For Python < 3.11, LanggraphRunContext.resolve will return None. + For Python >= 3.11, it will try get_runtime() from langgraph which depends on + context propagation. Since we don't run inside langgraph in unit tests, + no one propagates the context for us, so RuntimeError is raised. + """ + import sys + delegate = FakeChatModel( + responses=[AIMessage(content="Hello!")], + ) + # No foundry tools + foundry_tools: List[Any] = [] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + config: RunnableConfig = {"configurable": {}} + input_messages = [HumanMessage(content="Hello")] + + if sys.version_info >= (3, 11): + with pytest.raises(RuntimeError, match="outside of a runnable context"): + model.invoke(input_messages, config=config) + else: + result = model.invoke(input_messages, config=config) + # Should work since no foundry tools need resolution + assert result.content == "Hello!" + + def test_invoke_without_context_raises_error_when_foundry_tools_present(self): + """Test that invoking without context raises error when foundry tools are set. + + For Python < 3.11, LanggraphRunContext.resolve will return None and we expect + 'Unable to resolve foundry tools from context' error. + For Python >= 3.11, it will try get_runtime() from langgraph which depends on + context propagation. Since we don't run inside langgraph in unit tests, + no one propagates the context for us, so RuntimeError is raised. + """ + import sys + delegate = FakeChatModel( + responses=[AIMessage(content="Hello!")], + ) + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + config: RunnableConfig = {"configurable": {}} + input_messages = [HumanMessage(content="Hello")] + + if sys.version_info >= (3, 11): + with pytest.raises(RuntimeError, match="outside of a runnable context"): + model.invoke(input_messages, config=config) + else: + with pytest.raises(RuntimeError, match="Unable to resolve foundry tools from context"): + model.invoke(input_messages, config=config) + + def test_stream_with_context( + self, + mock_langgraph_run_context: LanggraphRunContext, + ): + """Test streaming model with context attached.""" + delegate = FakeChatModel( + responses=[AIMessage(content="Streamed response!")], + ) + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Attach context to config + config: RunnableConfig = {"configurable": {}} + mock_langgraph_run_context.attach_to_config(config) + + input_messages = [HumanMessage(content="Hello")] + results = list(model.stream(input_messages, config=config)) + + assert len(results) == 1 + assert results[0].content == "Streamed response!" + + @pytest.mark.asyncio + async def test_astream_with_context( + self, + mock_langgraph_run_context: LanggraphRunContext, + ): + """Test async streaming model with context attached.""" + delegate = FakeChatModel( + responses=[AIMessage(content="Async streamed response!")], + ) + foundry_tools = [{"type": "code_interpreter"}] + + model = FoundryToolLateBindingChatModel( + delegate=delegate, # type: ignore + runtime=None, + foundry_tools=foundry_tools, + ) + + # Attach context to config + config: RunnableConfig = {"configurable": {}} + mock_langgraph_run_context.attach_to_config(config) + + input_messages = [HumanMessage(content="Hello")] + results = [] + async for chunk in model.astream(input_messages, config=config): + results.append(chunk) + + assert len(results) == 1 + assert results[0].content == "Async streamed response!" + diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_context.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_context.py new file mode 100644 index 000000000000..577d4e6e4e6f --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_context.py @@ -0,0 +1,36 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryToolContext.""" +import pytest + +from azure.ai.agentserver.langgraph.tools._context import FoundryToolContext +from azure.ai.agentserver.langgraph.tools._resolver import ResolvedTools + + +@pytest.mark.unit +class TestFoundryToolContext: + """Tests for FoundryToolContext class.""" + + def test_create_with_resolved_tools(self, sample_resolved_tools: ResolvedTools): + """Test creating FoundryToolContext with resolved tools.""" + context = FoundryToolContext(resolved_tools=sample_resolved_tools) + + assert context.resolved_tools is sample_resolved_tools + + def test_create_with_default_resolved_tools(self): + """Test creating FoundryToolContext with default empty resolved tools.""" + context = FoundryToolContext() + + # Default should be empty ResolvedTools + assert context.resolved_tools is not None + tools_list = list(context.resolved_tools) + assert len(tools_list) == 0 + + def test_resolved_tools_is_iterable(self, sample_resolved_tools: ResolvedTools): + """Test that resolved_tools can be iterated.""" + context = FoundryToolContext(resolved_tools=sample_resolved_tools) + + tools_list = list(context.resolved_tools) + assert len(tools_list) == 1 + diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_middleware.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_middleware.py new file mode 100644 index 000000000000..bf337c340ff6 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_middleware.py @@ -0,0 +1,217 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryToolBindingMiddleware.""" +import pytest +from typing import Any, List +from unittest.mock import AsyncMock, MagicMock + +from langchain.agents.middleware.types import ModelRequest +from langchain_core.messages import AIMessage, ToolMessage +from langchain_core.tools import tool +from langgraph.prebuilt.tool_node import ToolCallRequest + +from azure.ai.agentserver.langgraph.tools._middleware import FoundryToolBindingMiddleware + +from .conftest import FakeChatModel + + +@pytest.mark.unit +class TestFoundryToolBindingMiddleware: + """Tests for FoundryToolBindingMiddleware class.""" + + def test_init_with_foundry_tools_creates_dummy_tool(self): + """Test that initialization with foundry tools creates a dummy tool.""" + foundry_tools = [{"type": "code_interpreter"}] + + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Should have one dummy tool + assert len(middleware.tools) == 1 + assert middleware.tools[0].name == "__dummy_tool_by_foundry_middleware__" + + def test_init_without_foundry_tools_no_dummy_tool(self): + """Test that initialization without foundry tools creates no dummy tool.""" + foundry_tools: List[Any] = [] + + middleware = FoundryToolBindingMiddleware(foundry_tools) + + assert len(middleware.tools) == 0 + + def test_wrap_model_call_wraps_model_with_foundry_binding(self): + """Test that wrap_model_call wraps the model correctly.""" + foundry_tools = [{"type": "code_interpreter"}] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Create mock model and request + mock_model = FakeChatModel() + mock_runtime = MagicMock() + mock_request = MagicMock(spec=ModelRequest) + mock_request.model = mock_model + mock_request.runtime = mock_runtime + mock_request.tools = [] + + # Create a modified request to return + modified_request = MagicMock(spec=ModelRequest) + mock_request.override = MagicMock(return_value=modified_request) + + # Mock handler + expected_result = AIMessage(content="Result") + mock_handler = MagicMock(return_value=expected_result) + + result = middleware.wrap_model_call(mock_request, mock_handler) + + # Handler should be called with modified request + mock_handler.assert_called_once() + assert result == expected_result + + @pytest.mark.asyncio + async def test_awrap_model_call_wraps_model_async(self): + """Test that awrap_model_call wraps the model correctly in async.""" + foundry_tools = [{"type": "code_interpreter"}] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Create mock model and request + mock_model = FakeChatModel() + mock_runtime = MagicMock() + mock_request = MagicMock(spec=ModelRequest) + mock_request.model = mock_model + mock_request.runtime = mock_runtime + mock_request.tools = [] + + # Create a modified request to return + modified_request = MagicMock(spec=ModelRequest) + mock_request.override = MagicMock(return_value=modified_request) + + # Mock async handler + expected_result = AIMessage(content="Async Result") + mock_handler = AsyncMock(return_value=expected_result) + + result = await middleware.awrap_model_call(mock_request, mock_handler) + + # Handler should be called + mock_handler.assert_awaited_once() + assert result == expected_result + + def test_wrap_model_without_foundry_tools_returns_unchanged(self): + """Test that wrap_model returns unchanged request when no foundry tools.""" + foundry_tools: List[Any] = [] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + mock_model = FakeChatModel() + mock_request = MagicMock(spec=ModelRequest) + mock_request.model = mock_model + mock_request.tools = [] + + # Should not call override + mock_request.override = MagicMock() + + mock_handler = MagicMock(return_value=AIMessage(content="Result")) + + middleware.wrap_model_call(mock_request, mock_handler) + + # Handler should be called with original request + mock_handler.assert_called_once_with(mock_request) + + def test_remove_dummy_tool_from_request(self): + """Test that dummy tool is removed from the request tools.""" + foundry_tools = [{"type": "code_interpreter"}] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Create request with dummy tool + dummy = middleware._dummy_tool() + + @tool + def real_tool(x: str) -> str: + """Real tool.""" + return x + + mock_request = MagicMock(spec=ModelRequest) + mock_request.tools = [dummy, real_tool] + + # Call internal method + result = middleware._remove_dummy_tool(mock_request) + + # Should only have real_tool + assert len(result) == 1 + assert result[0] is real_tool + + def test_wrap_tool_call_delegates_to_wrapper(self): + """Test that wrap_tool_call delegates to FoundryToolCallWrapper. + + For Python < 3.11, LanggraphRunContext.resolve will return None. + For Python >= 3.11, it will try get_runtime() from langgraph which depends on + context propagation. Since we don't run inside langgraph in unit tests, + no one propagates the context for us, so RuntimeError is raised. + """ + import sys + foundry_tools = [{"type": "code_interpreter"}] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Create mock tool call request + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = None + mock_request.tool_call = {"name": "test_tool", "id": "call_1"} + mock_request.state = {} + mock_request.runtime = None + + # Mock handler + expected_result = ToolMessage(content="Result", tool_call_id="call_1") + mock_handler = MagicMock(return_value=expected_result) + + if sys.version_info >= (3, 11): + with pytest.raises(RuntimeError, match="outside of a runnable context"): + middleware.wrap_tool_call(mock_request, mock_handler) + else: + result = middleware.wrap_tool_call(mock_request, mock_handler) + # Handler should be called + mock_handler.assert_called_once() + assert result == expected_result + + @pytest.mark.asyncio + async def test_awrap_tool_call_delegates_to_wrapper_async(self): + """Test that awrap_tool_call delegates to FoundryToolCallWrapper async. + + For Python < 3.11, LanggraphRunContext.resolve will return None. + For Python >= 3.11, it will try get_runtime() from langgraph which depends on + context propagation. Since we don't run inside langgraph in unit tests, + no one propagates the context for us, so RuntimeError is raised. + """ + import sys + foundry_tools = [{"type": "code_interpreter"}] + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Create mock tool call request + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = None + mock_request.tool_call = {"name": "test_tool", "id": "call_1"} + mock_request.state = {} + mock_request.runtime = None + + # Mock async handler + expected_result = ToolMessage(content="Async Result", tool_call_id="call_1") + mock_handler = AsyncMock(return_value=expected_result) + + if sys.version_info >= (3, 11): + with pytest.raises(RuntimeError, match="outside of a runnable context"): + await middleware.awrap_tool_call(mock_request, mock_handler) + else: + result = await middleware.awrap_tool_call(mock_request, mock_handler) + # Handler should be awaited + mock_handler.assert_awaited_once() + assert result == expected_result + + def test_middleware_with_multiple_foundry_tools(self): + """Test middleware initialization with multiple foundry tools.""" + foundry_tools = [ + {"type": "code_interpreter"}, + {"type": "mcp", "project_connection_id": "test"}, + ] + + middleware = FoundryToolBindingMiddleware(foundry_tools) + + # Should still only have one dummy tool + assert len(middleware.tools) == 1 + # But should have all foundry tools registered + assert len(middleware._foundry_tools_to_bind) == 2 + diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_resolver.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_resolver.py new file mode 100644 index 000000000000..985ed4caec49 --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_resolver.py @@ -0,0 +1,502 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for ResolvedTools and FoundryLangChainToolResolver.""" +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from langchain_core.tools import BaseTool, StructuredTool, tool +from pydantic import BaseModel + +from azure.ai.agentserver.core.tools import (FoundryConnectedTool, FoundryHostedMcpTool, FoundryToolDetails, + ResolvedFoundryTool, SchemaDefinition, SchemaProperty, SchemaType) +from azure.ai.agentserver.langgraph.tools._resolver import ( + ResolvedTools, + FoundryLangChainToolResolver, + get_registry, +) + + +@pytest.mark.unit +class TestResolvedTools: + """Tests for ResolvedTools class.""" + + def test_create_empty_resolved_tools(self): + """Test creating an empty ResolvedTools.""" + resolved = ResolvedTools(tools=[]) + + tools_list = list(resolved) + assert len(tools_list) == 0 + + def test_create_with_single_tool( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + mock_langchain_tool: BaseTool, + ): + """Test creating ResolvedTools with a single tool.""" + resolved_foundry_tool = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + }, + required={"input"}, + ), + ), + ) + resolved = ResolvedTools(tools=[(resolved_foundry_tool, mock_langchain_tool)]) + + tools_list = list(resolved) + assert len(tools_list) == 1 + assert tools_list[0] is mock_langchain_tool + + def test_create_with_multiple_tools( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + sample_mcp_connected_tool: FoundryConnectedTool, + ): + """Test creating ResolvedTools with multiple tools.""" + @tool + def tool1(query: str) -> str: + """Tool 1.""" + return "result1" + + @tool + def tool2(query: str) -> str: + """Tool 2.""" + return "result2" + + resolved_tool1 = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="tool1", + description="Tool 1", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={"query": SchemaProperty(type=SchemaType.STRING, description="Query")}, + required={"query"}, + ), + ), + ) + resolved_tool2 = ResolvedFoundryTool( + definition=sample_mcp_connected_tool, + details=FoundryToolDetails( + name="tool2", + description="Tool 2", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={"query": SchemaProperty(type=SchemaType.STRING, description="Query")}, + required={"query"}, + ), + ), + ) + + resolved = ResolvedTools(tools=[ + (resolved_tool1, tool1), + (resolved_tool2, tool2), + ]) + + tools_list = list(resolved) + assert len(tools_list) == 2 + + def test_get_tool_by_foundry_tool_like( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + mock_langchain_tool: BaseTool, + ): + """Test getting tools by FoundryToolLike.""" + resolved_foundry_tool = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + }, + required={"input"}, + ), + ), + ) + resolved = ResolvedTools(tools=[(resolved_foundry_tool, mock_langchain_tool)]) + + # Get by the original foundry tool definition + tools = list(resolved.get(sample_code_interpreter_tool)) + assert len(tools) == 1 + assert tools[0] is mock_langchain_tool + + def test_get_tools_by_list_of_foundry_tools( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + sample_mcp_connected_tool: FoundryConnectedTool, + ): + """Test getting tools by a list of FoundryToolLike.""" + @tool + def tool1(query: str) -> str: + """Tool 1.""" + return "result1" + + @tool + def tool2(query: str) -> str: + """Tool 2.""" + return "result2" + + resolved_tool1 = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="tool1", + description="Tool 1", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={"query": SchemaProperty(type=SchemaType.STRING, description="Query")}, + required={"query"}, + ), + ), + ) + resolved_tool2 = ResolvedFoundryTool( + definition=sample_mcp_connected_tool, + details=FoundryToolDetails( + name="tool2", + description="Tool 2", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={"query": SchemaProperty(type=SchemaType.STRING, description="Query")}, + required={"query"}, + ), + ), + ) + + resolved = ResolvedTools(tools=[ + (resolved_tool1, tool1), + (resolved_tool2, tool2), + ]) + + # Get by list of foundry tools + tools = list(resolved.get([sample_code_interpreter_tool, sample_mcp_connected_tool])) + assert len(tools) == 2 + + def test_get_all_tools_when_no_filter( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + mock_langchain_tool: BaseTool, + ): + """Test getting all tools when no filter is provided.""" + resolved_foundry_tool = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + }, + required={"input"}, + ), + ), + ) + resolved = ResolvedTools(tools=[(resolved_foundry_tool, mock_langchain_tool)]) + + # Get all tools (no filter) + tools = list(resolved.get()) + assert len(tools) == 1 + + def test_get_returns_empty_for_unknown_tool( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + sample_mcp_connected_tool: FoundryConnectedTool, + mock_langchain_tool: BaseTool, + ): + """Test that get returns empty when requesting unknown tool.""" + resolved_foundry_tool = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + }, + required={"input"}, + ), + ), + ) + resolved = ResolvedTools(tools=[(resolved_foundry_tool, mock_langchain_tool)]) + + # Get by a different foundry tool (not in resolved) + tools = list(resolved.get(sample_mcp_connected_tool)) + assert len(tools) == 0 + + def test_iteration_over_resolved_tools( + self, + sample_code_interpreter_tool: FoundryHostedMcpTool, + mock_langchain_tool: BaseTool, + ): + """Test iterating over ResolvedTools.""" + resolved_foundry_tool = ResolvedFoundryTool( + definition=sample_code_interpreter_tool, + details=FoundryToolDetails( + name="test_tool", + description="A test tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + }, + required={"input"}, + ), + ), + ) + resolved = ResolvedTools(tools=[(resolved_foundry_tool, mock_langchain_tool)]) + + # Iterate using for loop + count = 0 + for t in resolved: + assert t is mock_langchain_tool + count += 1 + assert count == 1 + + +@pytest.mark.unit +class TestFoundryLangChainToolResolver: + """Tests for FoundryLangChainToolResolver class.""" + + def test_init_with_default_name_resolver(self): + """Test initialization with default name resolver.""" + resolver = FoundryLangChainToolResolver() + + assert resolver._name_resolver is not None + + def test_init_with_custom_name_resolver(self): + """Test initialization with custom name resolver.""" + from azure.ai.agentserver.core.tools.utils import ToolNameResolver + + custom_resolver = ToolNameResolver() + resolver = FoundryLangChainToolResolver(name_resolver=custom_resolver) + + assert resolver._name_resolver is custom_resolver + + def test_create_pydantic_model_with_required_fields(self): + """Test creating a Pydantic model with required fields.""" + input_schema = SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Search query"), + "limit": SchemaProperty(type=SchemaType.INTEGER, description="Max results"), + }, + required={"query"}, + ) + + model = FoundryLangChainToolResolver._create_pydantic_model("test_tool", input_schema) + + assert issubclass(model, BaseModel) + # Check that the model has the expected fields + assert "query" in model.model_fields + assert "limit" in model.model_fields + + def test_create_pydantic_model_with_no_required_fields(self): + """Test creating a Pydantic model with no required fields.""" + input_schema = SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Search query"), + }, + required=set(), + ) + + model = FoundryLangChainToolResolver._create_pydantic_model("optional_tool", input_schema) + + assert issubclass(model, BaseModel) + assert "query" in model.model_fields + # Optional field should have None as default + assert model.model_fields["query"].default is None + + def test_create_pydantic_model_with_special_characters_in_name(self): + """Test creating a Pydantic model with special characters in tool name.""" + input_schema = SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "input": SchemaProperty(type=SchemaType.STRING, description="Input"), + }, + required={"input"}, + ) + + model = FoundryLangChainToolResolver._create_pydantic_model("my-tool name", input_schema) + + assert issubclass(model, BaseModel) + # Name should be sanitized + assert "-Input" in model.__name__ or "Input" in model.__name__ + + def test_create_structured_tool(self): + """Test creating a StructuredTool from a resolved foundry tool.""" + resolver = FoundryLangChainToolResolver() + + foundry_tool = FoundryHostedMcpTool(name="test_tool", configuration={}) + resolved_tool = ResolvedFoundryTool( + definition=foundry_tool, + details=FoundryToolDetails( + name="search", + description="Search for documents", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Search query"), + }, + required={"query"}, + ), + ), + ) + + structured_tool = resolver._create_structured_tool(resolved_tool) + + assert isinstance(structured_tool, StructuredTool) + assert structured_tool.description == "Search for documents" + assert structured_tool.coroutine is not None # Should have async function + + @pytest.mark.asyncio + async def test_resolve_from_registry(self): + """Test resolving tools from the global registry.""" + resolver = FoundryLangChainToolResolver() + + # Mock the AgentServerContext + mock_context = MagicMock() + mock_catalog = AsyncMock() + mock_context.tools.catalog.list = mock_catalog + + foundry_tool = FoundryHostedMcpTool(name="test_tool", configuration={}) + resolved_foundry_tool = ResolvedFoundryTool( + definition=foundry_tool, + details=FoundryToolDetails( + name="search", + description="Search tool", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "query": SchemaProperty(type=SchemaType.STRING, description="Query"), + }, + required={"query"}, + ), + ), + ) + mock_catalog.return_value = [resolved_foundry_tool] + + # Add tool to registry + registry = get_registry() + registry.clear() + registry.append({"type": "code_interpreter"}) + + with patch("azure.ai.agentserver.langgraph.tools._resolver.AgentServerContext.get", return_value=mock_context): + result = await resolver.resolve_from_registry() + + assert isinstance(result, ResolvedTools) + mock_catalog.assert_called_once() + + # Clean up registry + registry.clear() + + @pytest.mark.asyncio + async def test_resolve_with_foundry_tools_list(self): + """Test resolving a list of foundry tools.""" + resolver = FoundryLangChainToolResolver() + + # Mock the AgentServerContext + mock_context = MagicMock() + mock_catalog = AsyncMock() + mock_context.tools.catalog.list = mock_catalog + + foundry_tool = FoundryHostedMcpTool(name="code_interpreter", configuration={}) + resolved_foundry_tool = ResolvedFoundryTool( + definition=foundry_tool, + details=FoundryToolDetails( + name="execute_code", + description="Execute code", + input_schema=SchemaDefinition( + type=SchemaType.OBJECT, + properties={ + "code": SchemaProperty(type=SchemaType.STRING, description="Code to execute"), + }, + required={"code"}, + ), + ), + ) + mock_catalog.return_value = [resolved_foundry_tool] + + foundry_tools = [{"type": "code_interpreter"}] + + with patch("azure.ai.agentserver.langgraph.tools._resolver.AgentServerContext.get", return_value=mock_context): + result = await resolver.resolve(foundry_tools) + + assert isinstance(result, ResolvedTools) + tools_list = list(result) + assert len(tools_list) == 1 + assert isinstance(tools_list[0], StructuredTool) + + @pytest.mark.asyncio + async def test_resolve_empty_list(self): + """Test resolving an empty list of foundry tools.""" + resolver = FoundryLangChainToolResolver() + + # Mock the AgentServerContext + mock_context = MagicMock() + mock_catalog = AsyncMock() + mock_context.tools.catalog.list = mock_catalog + mock_catalog.return_value = [] + + with patch("azure.ai.agentserver.langgraph.tools._resolver.AgentServerContext.get", return_value=mock_context): + result = await resolver.resolve([]) + + assert isinstance(result, ResolvedTools) + tools_list = list(result) + assert len(tools_list) == 0 + + +@pytest.mark.unit +class TestGetRegistry: + """Tests for the get_registry function.""" + + def test_get_registry_returns_list(self): + """Test that get_registry returns a list.""" + registry = get_registry() + + assert isinstance(registry, list) + + def test_registry_is_singleton(self): + """Test that get_registry returns the same list instance.""" + registry1 = get_registry() + registry2 = get_registry() + + assert registry1 is registry2 + + def test_registry_can_be_modified(self): + """Test that the registry can be modified.""" + registry = get_registry() + original_length = len(registry) + + registry.append({"type": "test_tool"}) + + assert len(registry) == original_length + 1 + + # Clean up + registry.pop() + + def test_registry_extend(self): + """Test extending the registry with multiple tools.""" + registry = get_registry() + registry.clear() + + tools = [ + {"type": "code_interpreter"}, + {"type": "mcp", "project_connection_id": "test"}, + ] + registry.extend(tools) + + assert len(registry) == 2 + + # Clean up + registry.clear() diff --git a/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_tool_node.py b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_tool_node.py new file mode 100644 index 000000000000..1c46e58785bc --- /dev/null +++ b/sdk/agentserver/azure-ai-agentserver-langgraph/tests/unit_tests/tools/test_tool_node.py @@ -0,0 +1,179 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# --------------------------------------------------------- +"""Unit tests for FoundryToolCallWrapper and FoundryToolNodeWrappers.""" +import pytest +from typing import Any, List +from unittest.mock import AsyncMock, MagicMock + +from langchain_core.messages import ToolMessage +from langchain_core.tools import tool +from langgraph.prebuilt.tool_node import ToolCallRequest +from langgraph.types import Command + +from azure.ai.agentserver.langgraph.tools._tool_node import ( + FoundryToolCallWrapper, + FoundryToolNodeWrappers, +) + + +@pytest.mark.unit +class TestFoundryToolCallWrapper: + """Tests for FoundryToolCallWrapper class.""" + + def test_as_wrappers_returns_typed_dict(self): + """Test that as_wrappers returns a FoundryToolNodeWrappers TypedDict.""" + foundry_tools = [{"type": "code_interpreter"}] + wrapper = FoundryToolCallWrapper(foundry_tools) + + result = wrapper.as_wrappers() + + assert isinstance(result, dict) + assert "wrap_tool_call" in result + assert "awrap_tool_call" in result + assert callable(result["wrap_tool_call"]) + assert callable(result["awrap_tool_call"]) + + def test_call_tool_with_already_resolved_tool(self): + """Test that call_tool passes through when tool is already resolved.""" + foundry_tools = [{"type": "code_interpreter"}] + wrapper = FoundryToolCallWrapper(foundry_tools) + + # Create request with tool already set + @tool + def existing_tool(x: str) -> str: + """Existing tool.""" + return f"Result: {x}" + + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = existing_tool + mock_request.tool_call = {"name": "existing_tool", "id": "call_1"} + + expected_result = ToolMessage(content="Result: test", tool_call_id="call_1") + mock_invocation = MagicMock(return_value=expected_result) + + result = wrapper.call_tool(mock_request, mock_invocation) + + # Should pass through original request + mock_invocation.assert_called_once_with(mock_request) + assert result == expected_result + + def test_call_tool_with_no_foundry_tools(self): + """Test that call_tool passes through when no foundry tools configured.""" + foundry_tools: List[Any] = [] + wrapper = FoundryToolCallWrapper(foundry_tools) + + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = None + mock_request.tool_call = {"name": "some_tool", "id": "call_1"} + + expected_result = ToolMessage(content="Result", tool_call_id="call_1") + mock_invocation = MagicMock(return_value=expected_result) + + result = wrapper.call_tool(mock_request, mock_invocation) + + mock_invocation.assert_called_once_with(mock_request) + assert result == expected_result + + @pytest.mark.asyncio + async def test_call_tool_async_with_already_resolved_tool(self): + """Test that call_tool_async passes through when tool is already resolved.""" + foundry_tools = [{"type": "code_interpreter"}] + wrapper = FoundryToolCallWrapper(foundry_tools) + + @tool + def existing_tool(x: str) -> str: + """Existing tool.""" + return f"Result: {x}" + + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = existing_tool + mock_request.tool_call = {"name": "existing_tool", "id": "call_1"} + + expected_result = ToolMessage(content="Async Result", tool_call_id="call_1") + mock_invocation = AsyncMock(return_value=expected_result) + + result = await wrapper.call_tool_async(mock_request, mock_invocation) + + mock_invocation.assert_awaited_once_with(mock_request) + assert result == expected_result + + @pytest.mark.asyncio + async def test_call_tool_async_with_no_foundry_tools(self): + """Test that call_tool_async passes through when no foundry tools configured.""" + foundry_tools: List[Any] = [] + wrapper = FoundryToolCallWrapper(foundry_tools) + + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = None + mock_request.tool_call = {"name": "some_tool", "id": "call_1"} + + expected_result = ToolMessage(content="Result", tool_call_id="call_1") + mock_invocation = AsyncMock(return_value=expected_result) + + result = await wrapper.call_tool_async(mock_request, mock_invocation) + + mock_invocation.assert_awaited_once_with(mock_request) + assert result == expected_result + + def test_call_tool_returns_command_result(self): + """Test that call_tool can return Command objects.""" + foundry_tools: List[Any] = [] + wrapper = FoundryToolCallWrapper(foundry_tools) + + mock_request = MagicMock(spec=ToolCallRequest) + mock_request.tool = None + mock_request.tool_call = {"name": "some_tool", "id": "call_1"} + + # Return a Command instead of ToolMessage + expected_result = Command(goto="next_node") + mock_invocation = MagicMock(return_value=expected_result) + + result = wrapper.call_tool(mock_request, mock_invocation) + + assert result == expected_result + assert isinstance(result, Command) + + +@pytest.mark.unit +class TestFoundryToolNodeWrappers: + """Tests for FoundryToolNodeWrappers TypedDict.""" + + def test_foundry_tool_node_wrappers_structure(self): + """Test that FoundryToolNodeWrappers has the expected structure.""" + foundry_tools = [{"type": "code_interpreter"}] + wrapper = FoundryToolCallWrapper(foundry_tools) + + wrappers: FoundryToolNodeWrappers = wrapper.as_wrappers() + + # Should have both sync and async wrappers + assert "wrap_tool_call" in wrappers + assert "awrap_tool_call" in wrappers + + # Should be the wrapper methods + assert wrappers["wrap_tool_call"] == wrapper.call_tool + assert wrappers["awrap_tool_call"] == wrapper.call_tool_async + + def test_wrappers_can_be_unpacked_to_tool_node(self): + """Test that wrappers can be unpacked as kwargs to ToolNode.""" + foundry_tools = [{"type": "code_interpreter"}] + wrapper = FoundryToolCallWrapper(foundry_tools) + + wrappers = wrapper.as_wrappers() + + # Should be usable as kwargs + assert len(wrappers) == 2 + + # This pattern is used: ToolNode([], **wrappers) + def mock_tool_node_init(tools, wrap_tool_call=None, awrap_tool_call=None): + return { + "tools": tools, + "wrap_tool_call": wrap_tool_call, + "awrap_tool_call": awrap_tool_call, + } + + result = mock_tool_node_init([], **wrappers) + + assert result["wrap_tool_call"] is not None + assert result["awrap_tool_call"] is not None + diff --git a/shared_requirements.txt b/shared_requirements.txt index b5e1b85f184a..630973de2eb2 100644 --- a/shared_requirements.txt +++ b/shared_requirements.txt @@ -96,4 +96,5 @@ opentelemetry-exporter-otlp-proto-grpc agent-framework-core langchain langchain-openai -azure-ai-language-questionanswering \ No newline at end of file +azure-ai-language-questionanswering +cachetools \ No newline at end of file