@@ -113,7 +113,8 @@ async def agent_run(self, context: AgentRunContext):
else:
graph = self._resolved_graph

-        input_data = self.state_converter.request_to_state(context)
+        prev_state = await self.get_prev_state(graph, context)
+        input_data = self.state_converter.request_to_state(context, prev_state)
logger.debug(f"Converted input data: {input_data}")
if not context.stream:
try:
@@ -265,6 +266,7 @@ async def agent_run_non_stream(self, input_data: dict, context: AgentRunContext,
config = self.create_runnable_config(context)
stream_mode = self.state_converter.get_stream_mode(context)
result = await graph.ainvoke(input_data, config=config, stream_mode=stream_mode)
logger.info("State after invoke: %s", result)
output = self.state_converter.state_to_response(result, context)
return output
except Exception as e:
@@ -311,7 +313,7 @@ async def agent_run_astream(
logger.debug("Closed tool_client after streaming completed")
except Exception as e:
logger.warning(f"Error closing tool_client in stream: {e}")

def create_runnable_config(self, context: AgentRunContext) -> RunnableConfig:
"""
Create a RunnableConfig from the converted request data.
@@ -323,13 +325,46 @@ def create_runnable_config(self, context: AgentRunContext) -> RunnableConfig:
:rtype: RunnableConfig
"""
config = RunnableConfig(
-            configurable={
-                "thread_id": context.conversation_id,
-            },
+            configurable=self.create_configurable(context),
callbacks=[self.azure_ai_tracer] if self.azure_ai_tracer else None,
)
return config

async def get_prev_state(self, graph: CompiledStateGraph, context: AgentRunContext):
"""
Get the previous state from the graph using the context.

:param graph: The compiled graph instance.
:type graph: CompiledStateGraph
:param context: The context for the agent run.
:type context: AgentRunContext

:return: The previous state of the graph.
:rtype: StateSnapshot | None
"""
if context.conversation_id and graph.checkpointer:
config = self.create_configurable(context)
prev_state = await graph.aget_state(
config=RunnableConfig(configurable=config)
)
logger.info(f"Retrieved previous state for thread {context.conversation_id}")
return prev_state
return None

def create_configurable(self, context: AgentRunContext) -> dict:
"""
Create a configurable dict from the context.

:param context: The context for the agent run.
:type context: AgentRunContext

:return: The configurable dict containing conversation_id.
:rtype: dict
"""
return {
"thread_id": context.conversation_id,
}

def format_otlp_endpoint(self, endpoint: str) -> str:
m = re.match(r"^(https?://[^/]+)", endpoint)
if m:
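Taken together, these changes let a turn resume an interrupted graph run instead of always starting a fresh one. A minimal sketch of the flow they enable, assuming a MessagesState-style graph compiled with a checkpointer and the conversation id used as the LangGraph thread_id (the run_turn helper is illustrative, not part of this PR):

from langgraph.types import Command

async def run_turn(graph, thread_id: str, user_input):
    config = {"configurable": {"thread_id": thread_id}}
    # Mirrors get_prev_state(): a checkpointer plus a thread_id yields a StateSnapshot.
    prev_state = await graph.aget_state(config)
    if prev_state.interrupts:
        # Resume the paused run; ainvoke accepts a Command in place of input.
        return await graph.ainvoke(Command(resume=user_input), config=config)
    # Fresh turn: pass the initial state as usual.
    return await graph.ainvoke({"messages": [("user", user_input)]}, config=config)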
@@ -0,0 +1,170 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

import json
from typing import Optional, Union

from langgraph.types import (
Command,
Interrupt,
StateSnapshot,
)

from azure.ai.agentserver.core.constants import Constants
from azure.ai.agentserver.core.logger import get_logger
from azure.ai.agentserver.core.models import projects as project_models
from azure.ai.agentserver.core.models.openai import (
ResponseInputParam,
ResponseInputItemParam,
)
from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext

logger = get_logger()

INTERRUPT_NODE_NAME = "__interrupt__"
INTERRUPT_TOOL_NAME = "__hosted_agent_adapter_interrupt__"


class LanggraphHumanInTheLoopHelper:
"""Helper class for managing human-in-the-loop interactions in LangGraph."""
    def __init__(self, context: Optional[AgentRunContext] = None):
self.context = context

def has_interrupt(self, state: StateSnapshot) -> bool:
"""Check if the LangGraph state contains an interrupt node."""
if not state or not isinstance(state, StateSnapshot):
return False
return state.interrupts is not None and len(state.interrupts) > 0

def convert_interrupts(self, interrupts: tuple) -> list[project_models.ItemResource]:
"""Convert LangGraph interrupts to ItemResource objects."""
if not interrupts or not isinstance(interrupts, tuple):
return []
result = []
# should be only one interrupt for now
for interrupt_info in interrupts:
item = self.convert_interrupt(interrupt_info)
if item:
result.append(item)
return result

    def convert_interrupt(self, interrupt_info: Interrupt) -> Optional[project_models.ItemResource]:
"""Convert a single LangGraph Interrupt to an ItemResource object.

:param interrupt_info: The interrupt information from LangGraph.
:type interrupt_info: Interrupt

        :return: The corresponding ItemResource object, or None if the interrupt cannot be converted.
        :rtype: Optional[project_models.ItemResource]
"""
raise NotImplementedError("Subclasses must implement convert_interrupt method.")

def validate_and_convert_human_feedback(
self, state: StateSnapshot, input: Union[str, ResponseInputParam]
) -> Union[Command, None]:
"""Validate if the human feedback input corresponds to the interrupt in state.
If valid, convert the input to a LangGraph Command.

:param state: The current LangGraph state snapshot.
:type state: StateSnapshot
:param input: The human feedback input from the request.
:type input: Union[str, ResponseInputParam]

:return: Command if valid feedback is provided, else None.
:rtype: Union[Command, None]
"""
raise NotImplementedError("Subclasses must implement validate_and_convert_human_feedback method.")

def convert_input_item_to_command(self, input: ResponseInputItemParam) -> Union[Command, None]:
"""Convert ItemParams to a LangGraph Command for interrupt handling.

:param input: The item parameters containing interrupt information.
:type input: ResponseInputItemParam
:return: The LangGraph Command.
:rtype: Union[Command, None]
"""
raise NotImplementedError("Subclasses must implement convert_request_to_command method.")


class LanggraphHumanInTheLoopDefaultHelper(LanggraphHumanInTheLoopHelper):
"""
Default helper class for managing human-in-the-loop interactions in LangGraph.
Interrupts are converted to FunctionToolCallItemResource objects.
Human feedback will be sent back as FunctionCallOutputItemParam.
All values are serialized as JSON strings.
"""

    def convert_interrupt(self, interrupt_info: Interrupt) -> Optional[project_models.ItemResource]:
if not isinstance(interrupt_info, Interrupt):
logger.warning(f"Interrupt info is not of type Interrupt: {interrupt_info}")
return None
if isinstance(interrupt_info.value, str):
arguments = interrupt_info.value
else:
arguments = json.dumps(interrupt_info.value)
return project_models.FunctionToolCallItemResource(
call_id=interrupt_info.id,
name=INTERRUPT_TOOL_NAME,
arguments=arguments,
id=self.context.id_generator.generate_function_call_id(),
status="inprogress",
)

def validate_and_convert_human_feedback(
self, state: StateSnapshot, input: Union[str, ResponseInputParam]
) -> Union[Command, None]:
if not self.has_interrupt(state):
# No interrupt in state
logger.info("No interrupt found in state.")
return None
interrupt_obj = state.interrupts[0] # Assume single interrupt for simplicity
if not interrupt_obj or not isinstance(interrupt_obj, Interrupt):
logger.warning(f"No interrupt object found in state")
return None

logger.info(f"Retrived interrupt from state, validating and convert human feedback.")
if isinstance(input, str):
# expect a list of function call output items
logger.warning(f"Expecting function call output item, got string: {input}")
return None
if isinstance(input, list):
if len(input) != 1:
# expect exactly one function call output item
logger.warning(f"Expected exactly one interrupt input item, got {len(input)} items.")
return None
item = input[0]
# validate item type
item_type = item.get("type", None)
if item_type != project_models.ItemType.FUNCTION_CALL_OUTPUT:
logger.warning(f"Invalid interrupt input item type: {item_type}, expected FUNCTION_CALL_OUTPUT.")
return None

# validate call_id matches
if item.get("call_id") != interrupt_obj.id:
logger.warning(f"Interrupt input call_id {item.call_id} does not match interrupt id {interrupt_obj.id}.")
return None

return self.convert_input_item_to_command(item)
else:
logger.error(f"Unsupported interrupt input type: {type(input)}, {input}")
return None

def convert_input_item_to_command(self, input: ResponseInputItemParam) -> Union[Command, None]:
output_str = input.get("output")
try:
output = json.loads(output_str)
        except (TypeError, json.JSONDecodeError) as e:
            logger.error(f"Invalid JSON in function call output: {output_str} ({e})")
return None
resume = output.get("resume")
update = output.get("update")
goto = output.get("goto")
        if resume is None and update is None and goto is None:
logger.warning(f"No valid command fields found in function call output: {output}")
return None
return Command(
resume=resume,
update=update,
goto=goto,
)
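Concretely, the round trip this default helper implements: an Interrupt surfaces to the client as a function tool call named __hosted_agent_adapter_interrupt__, and the client answers with exactly one function_call_output item whose output field is a JSON string carrying any of the resume/update/goto Command fields. A hedged sketch of such a feedback item (the literal type string assumes the value of ItemType.FUNCTION_CALL_OUTPUT; the call_id placeholder stands for the interrupt id returned in the previous response):

import json

feedback_item = {
    "type": "function_call_output",
    "call_id": "<interrupt-id-from-previous-response>",
    # resume=False is meaningful (e.g. a rejection), which is why the field
    # checks above test against None rather than falsiness.
    "output": json.dumps({"resume": {"approved": True}}),
}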
@@ -4,39 +4,51 @@
# pylint: disable=logging-fstring-interpolation,broad-exception-caught,logging-not-lazy
# mypy: disable-error-code="valid-type,call-overload,attr-defined"
import copy
import json
from typing import List

from langchain_core import messages
from langchain_core.messages import AnyMessage
from langgraph.types import Interrupt

from azure.ai.agentserver.core.logger import get_logger
from azure.ai.agentserver.core.models import projects as project_models
from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext

from .langgraph_hitl_helper import (
INTERRUPT_NODE_NAME,
LanggraphHumanInTheLoopHelper,
)
from .utils import extract_function_call

logger = get_logger()


class LangGraphResponseConverter:
-    def __init__(self, context: AgentRunContext, output):
+    def __init__(self, context: AgentRunContext, output, hitl_helper: LanggraphHumanInTheLoopHelper):
self.context = context
self.output = output
self.hitl_helper = hitl_helper

def convert(self) -> list[project_models.ItemResource]:
res = []
for step in self.output:
for node_name, node_output in step.items():
-                message_arr = node_output.get("messages")
-                if not message_arr:
-                    logger.warning(f"No messages found in node {node_name} output: {node_output}")
-                    continue
-                for message in message_arr:
-                    try:
-                        converted = self.convert_output_message(message)
-                        res.append(converted)
-                    except Exception as e:
-                        logger.error(f"Error converting message {message}: {e}")
+                if node_name == INTERRUPT_NODE_NAME:
+                    interrupt_messages = self.hitl_helper.convert_interrupts(node_output)
+                    res.extend(interrupt_messages)
+                else:
+                    message_arr = node_output.get("messages")
+                    if not message_arr or not isinstance(message_arr, list):
+                        logger.warning(f"No messages found in node {node_name} output: {node_output}")
+                        continue
+                    for message in message_arr:
+                        try:
+                            converted = self.convert_output_message(message)
+                            if converted:
+                                res.append(converted)
+                        except Exception as e:
+                            logger.error(f"Error converting message {message}: {e}")
return res

def convert_output_message(self, output_message: AnyMessage): # pylint: disable=inconsistent-return-statements
@@ -87,6 +99,7 @@ def convert_output_message(self, output_message: AnyMessage): # pylint: disable
output=output_message.content,
id=self.context.id_generator.generate_function_output_id(),
)
logger.warning(f"Unsupported message type: {type(output_message)}, {output_message}")

def convert_MessageContent(
self, content, role: project_models.ResponsesMessageRole
@@ -134,3 +147,4 @@ def convert_MessageContentItem(
content_dict["annotations"] = [] # annotation is required for output_text

return project_models.ItemContent(content_dict)
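For orientation, the two updates-mode step shapes that convert() now distinguishes look roughly like this (a sketch with illustrative values; the Interrupt constructor arguments assume a langgraph version whose Interrupt carries value and id, as the code above does):

from langchain_core.messages import AIMessage
from langgraph.types import Interrupt

# A regular node update: each message goes through convert_output_message.
step_regular = {"agent": {"messages": [AIMessage(content="All done.")]}}

# An interrupt update: the tuple of Interrupts is handed to hitl_helper.convert_interrupts.
step_interrupt = {"__interrupt__": (Interrupt(value={"question": "Approve?"}, id="intr-1"),)}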

@@ -23,11 +23,15 @@

import time
from abc import ABC, abstractmethod
-from typing import Any, AsyncGenerator, AsyncIterator, Dict
+from typing import Any, AsyncGenerator, AsyncIterator, Dict, Union

from langgraph.types import Command, Interrupt, StateSnapshot

from azure.ai.agentserver.core.models import Response, ResponseStreamEvent
from azure.ai.agentserver.core.models import projects as project_models
from azure.ai.agentserver.core.server.common.agent_run_context import AgentRunContext

from .langgraph_hitl_helper import LanggraphHumanInTheLoopDefaultHelper
from .langgraph_request_converter import LangGraphRequestConverter
from .langgraph_response_converter import LangGraphResponseConverter
from .langgraph_stream_response_converter import LangGraphStreamResponseConverter
@@ -56,17 +60,21 @@ def get_stream_mode(self, context: AgentRunContext) -> str:
"""

@abstractmethod
-    def request_to_state(self, context: AgentRunContext) -> Dict[str, Any]:
+    def request_to_state(
+        self, context: AgentRunContext, prev_state: StateSnapshot
+    ) -> Union[Dict[str, Any], Command]:
"""Convert the incoming request (via context) to an initial LangGraph state.

Return a serializable dict that downstream graph execution expects.
Should not mutate the context. Raise ValueError on invalid input.

:param context: The context for the agent run.
:type context: AgentRunContext
+        :param prev_state: The previous LangGraph state if resuming a conversation.
+        :type prev_state: StateSnapshot

-        :return: The initial LangGraph state as a dictionary.
-        :rtype: Dict[str, Any]
+        :return: The initial LangGraph state as a dictionary or a Command object.
+        :rtype: Union[Dict[str, Any], Command]
"""

@abstractmethod
@@ -114,12 +122,22 @@ def get_stream_mode(self, context: AgentRunContext) -> str:
return "messages"
return "updates"

-    def request_to_state(self, context: AgentRunContext) -> Dict[str, Any]:
+    def request_to_state(self, context: AgentRunContext, prev_state: StateSnapshot) -> Union[Dict[str, Any], Command]:
hitl_helper = LanggraphHumanInTheLoopDefaultHelper(context)
command = hitl_helper.validate_and_convert_human_feedback(
prev_state, context.request.get("input")
)
if command is not None:
return command
converter = LangGraphRequestConverter(context.request)
return converter.convert()

def state_to_response(self, state: Any, context: AgentRunContext) -> Response:
-        converter = LangGraphResponseConverter(context, state)
+        converter = LangGraphResponseConverter(
+            context,
+            state,
+            hitl_helper=LanggraphHumanInTheLoopDefaultHelper(context),
+        )
output = converter.convert()

agent_id = context.get_agent_id_object()
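Because graph.ainvoke accepts either an initial state mapping or a Command, the callers in the first file need no branching on what request_to_state returns; condensed, with names taken from the surrounding code:

prev_state = await self.get_prev_state(graph, context)
# Either {"messages": [...]} for a fresh turn, or Command(resume=...) when valid
# human feedback answers a pending interrupt; ainvoke handles both transparently.
input_data = self.state_converter.request_to_state(context, prev_state)
result = await graph.ainvoke(input_data, config=config, stream_mode=stream_mode)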
@@ -0,0 +1,4 @@
AZURE_OPENAI_API_KEY=<api-key>
AZURE_OPENAI_ENDPOINT=https://<endpoint-name>.cognitiveservices.azure.com/
OPENAI_API_VERSION=2025-03-01-preview
AZURE_OPENAI_CHAT_DEPLOYMENT_NAME=<deployment-name>
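These variables line up with what langchain's Azure OpenAI integration reads from the environment. A minimal consumption sketch, assuming langchain-openai and python-dotenv are installed (the deployment variable is sample-specific, so it is read explicitly):

import os

from dotenv import load_dotenv
from langchain_openai import AzureChatOpenAI

load_dotenv()  # picks up AZURE_OPENAI_API_KEY, AZURE_OPENAI_ENDPOINT, OPENAI_API_VERSION
llm = AzureChatOpenAI(azure_deployment=os.environ["AZURE_OPENAI_CHAT_DEPLOYMENT_NAME"])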