From 4a31d71c37c49d8b0fa56844ddd8410e4f0d0e05 Mon Sep 17 00:00:00 2001 From: Cristian Pufu Date: Sat, 21 Feb 2026 12:23:37 +0200 Subject: [PATCH 1/2] fix: persist session after HITL resume and emit tool completed events - After HITL resume, checkpoint restore creates separate session copies per executor. Extract the most complete session (highest message count) and persist it to KV storage so the next turn has valid history. - Handle AgentExecutorResponse wrapper in _extract_tool_state_events and _extract_contents so function_result from executor_completed data is properly found. - Emit ToolCallEnd in close_message() for pending tool calls interrupted by HITL suspension (clears stale _pending_tool_calls state). - Track pending tool nodes (STARTED without COMPLETED) across stream iterations and synthesize COMPLETED events on HITL resume. Co-Authored-By: Claude Opus 4.6 --- .../uipath-agent-framework/pyproject.toml | 2 +- .../samples/hitl-workflow/main.py | 2 +- .../samples/hitl-workflow/pyproject.toml | 3 + .../src/uipath_agent_framework/chat/openai.py | 4 +- .../runtime/messages.py | 14 +- .../uipath_agent_framework/runtime/runtime.py | 131 ++++++- .../tests/test_hitl_e2e.py | 323 ++++++++++++++++++ packages/uipath-agent-framework/uv.lock | 2 +- 8 files changed, 457 insertions(+), 24 deletions(-) diff --git a/packages/uipath-agent-framework/pyproject.toml b/packages/uipath-agent-framework/pyproject.toml index e25bf84..fa0034a 100644 --- a/packages/uipath-agent-framework/pyproject.toml +++ b/packages/uipath-agent-framework/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "uipath-agent-framework" -version = "0.0.6" +version = "0.0.7" description = "Python SDK that enables developers to build and deploy Microsoft Agent Framework agents to the UiPath Cloud Platform" readme = "README.md" requires-python = ">=3.11" diff --git a/packages/uipath-agent-framework/samples/hitl-workflow/main.py b/packages/uipath-agent-framework/samples/hitl-workflow/main.py index 2dc1d91..7feb89a 100644 --- a/packages/uipath-agent-framework/samples/hitl-workflow/main.py +++ b/packages/uipath-agent-framework/samples/hitl-workflow/main.py @@ -54,7 +54,7 @@ def issue_refund(order_id: str, amount: float, reason: str) -> str: return f"Refund of ${amount:.2f} issued for order {order_id}: {reason}" -client = UiPathOpenAIChatClient(model="gpt-5-mini-2025-08-07") +client = UiPathOpenAIChatClient() triage = client.as_agent( name="triage", diff --git a/packages/uipath-agent-framework/samples/hitl-workflow/pyproject.toml b/packages/uipath-agent-framework/samples/hitl-workflow/pyproject.toml index 7b9e2d2..713932b 100644 --- a/packages/uipath-agent-framework/samples/hitl-workflow/pyproject.toml +++ b/packages/uipath-agent-framework/samples/hitl-workflow/pyproject.toml @@ -19,3 +19,6 @@ dev = [ [tool.uv] prerelease = "allow" +[tool.uv.sources] +uipath-dev = { path = "../../../../../uipath-dev-python", editable = true } +uipath-agent-framework = { path = "../../", editable = true } diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/chat/openai.py b/packages/uipath-agent-framework/src/uipath_agent_framework/chat/openai.py index 6a4ca6c..7861b27 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/chat/openai.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/chat/openai.py @@ -67,7 +67,7 @@ class UiPathOpenAIChatClient(OpenAIChatClient): from uipath_agent_framework.chat import UiPathOpenAIChatClient - client = UiPathOpenAIChatClient(model="gpt-4o-mini") + client = UiPathOpenAIChatClient(model="gpt-4.1-mini-2025-04-14") agent = client.as_agent( name="assistant", instructions="You are a helpful assistant.", @@ -75,7 +75,7 @@ class UiPathOpenAIChatClient(OpenAIChatClient): ) """ - def __init__(self, model: str = "gpt-4o-mini", **kwargs: Any): + def __init__(self, model: str = "gpt-4.1-mini-2025-04-14", **kwargs: Any): uipath_url, token = get_uipath_config() gateway_url = build_gateway_url("openai", model, uipath_url) diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/messages.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/messages.py index e034119..37b0516 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/messages.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/messages.py @@ -184,12 +184,20 @@ def map_streaming_content( def close_message(self) -> list[UiPathConversationMessageEvent]: """Close the current message if open. Safety net for end of stream.""" + events: list[UiPathConversationMessageEvent] = [] + # Emit ToolCallEnd for any tool calls that were started but never + # completed (e.g. HITL suspension interrupted before function_result). + if self._pending_tool_calls: + for tool_call_id, message_id in self._pending_tool_calls.items(): + events.append( + self._make_tool_call_end_event(message_id, tool_call_id, {}) + ) + self._pending_tool_calls.clear() if self._message_started and self._current_message_id: - events = [self._make_message_end_event(self._current_message_id)] + events.append(self._make_message_end_event(self._current_message_id)) self._message_started = False self._current_message_id = None - return events - return [] + return events @staticmethod def _extract_text_from_content(content: Content) -> str: diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py index a2d373d..a8ff68b 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py @@ -6,6 +6,7 @@ from agent_framework import ( AgentExecutor, + AgentExecutorResponse, AgentResponse, AgentResponseUpdate, AgentSession, @@ -67,6 +68,12 @@ def __init__( self._last_breakpoint_node: str | None = None self._last_checkpoint_id: str | None = None self._resumed_from_checkpoint_id: str | None = None + # Track tool nodes that emitted STARTED but not yet COMPLETED. + # Persists across _stream_workflow() calls (same runtime instance + # reused by UiPathChatRuntime's while loop), allowing us to emit + # synthetic COMPLETED events on HITL resume when the framework + # doesn't surface function_result in output/executor_completed. + self._pending_tool_nodes: set[str] = set() # ------------------------------------------------------------------ # Checkpoint helpers @@ -171,6 +178,37 @@ def _apply_session_to_executors(self, session: AgentSession) -> None: if isinstance(executor, AgentExecutor): executor._session = session + def _get_session_from_executors(self) -> AgentSession | None: + """Extract the most complete session from AgentExecutors in the workflow. + + After checkpoint restore each executor receives its own independent + session copy (unlike fresh runs where all executors share one object). + Only the executor that processed the HITL/breakpoint response will + have the updated conversation history. We return the session with the + most messages to ensure the complete history is persisted. + """ + workflow = self.agent.workflow + best_session: AgentSession | None = None + best_msg_count = -1 + for executor in workflow.executors.values(): + if isinstance(executor, AgentExecutor) and executor._session is not None: + msg_count = self._count_session_messages(executor._session) + if msg_count > best_msg_count: + best_msg_count = msg_count + best_session = executor._session + return best_session + + @staticmethod + def _count_session_messages(session: AgentSession) -> int: + """Count total messages across all provider keys in a session's state.""" + count = 0 + for value in session.state.values(): + if isinstance(value, dict) and "messages" in value: + messages = value["messages"] + if isinstance(messages, list): + count += len(messages) + return count + # ------------------------------------------------------------------ # HITL helpers (tool approval flow) # ------------------------------------------------------------------ @@ -332,6 +370,12 @@ async def execute( checkpoint_storage=self._checkpoint_storage, ) + # After resume paths the checkpoint restores the session into + # executors directly, so the local ``session`` is still None. + # Extract it so it can be persisted after completion. + if session is None: + session = self._get_session_from_executors() + # Check for HITL suspension (framework's request_info mechanism) request_info_events = result.get_request_info_events() hitl_requests = { @@ -462,6 +506,19 @@ async def _stream_workflow( phase=UiPathRuntimeStatePhase.STARTED, ) + # On HITL resume, emit COMPLETED for tool nodes that were left + # pending when the previous stream suspended. The framework + # doesn't surface function_result in output/executor_completed + # for handoff workflows, so we synthesize these events here. + if is_resuming and self._pending_tool_nodes: + for tool_node in list(self._pending_tool_nodes): + yield UiPathRuntimeStateEvent( + payload={}, + node_name=tool_node, + phase=UiPathRuntimeStatePhase.COMPLETED, + ) + self._pending_tool_nodes.clear() + # Choose workflow.run() mode based on resume type if self._resume_responses: # HITL resume: pass responses to workflow with checkpoint @@ -495,10 +552,13 @@ async def _stream_workflow( request_info_map: dict[str, Any] = {} is_suspended = False - # Track executors whose tool events were emitted via output events. - # When the workflow filters output events (e.g. GroupChat), tool events - # are instead extracted from executor_completed data as a fallback. - executors_with_tool_outputs: set[str] = set() + # Track which tool event phases were emitted per executor via output + # events. When the workflow filters output events (e.g. GroupChat), + # tool events are extracted from executor_completed data as a fallback. + # Tracking phases (not just executor_ids) lets us handle HITL resume + # where function_call (STARTED) is in output but function_result + # (COMPLETED) is only in executor_completed. + executor_tool_phases: dict[str, set[UiPathRuntimeStatePhase]] = {} # Emit an early STARTED event for the start executor so the graph # visualization shows it immediately rather than after it finishes. @@ -534,17 +594,36 @@ async def _stream_workflow( phase=UiPathRuntimeStatePhase.STARTED, ) elif event.type == "executor_completed": - # When output events were filtered by the workflow (e.g. - # GroupChat where participants are not output executors), - # extract tool state events from the completed data instead. - if ( - event.executor_id - and event.executor_id not in executors_with_tool_outputs - ): + # Extract tool state events from executor_completed data, + # skipping phases already emitted via output events. + # This handles three scenarios: + # 1. GroupChat (no output events): emit all from completed + # 2. Normal (both in output): skip all from completed + # 3. HITL resume (only STARTED in output): emit COMPLETED + if event.executor_id: + emitted_phases = executor_tool_phases.get( + event.executor_id, set() + ) for tool_event in self._extract_tool_state_events( event.data, event.executor_id ): - yield tool_event + if tool_event.phase not in emitted_phases: + # Track pending tool nodes + if ( + tool_event.phase + == UiPathRuntimeStatePhase.STARTED + ): + self._pending_tool_nodes.add( + tool_event.node_name + ) + elif ( + tool_event.phase + == UiPathRuntimeStatePhase.COMPLETED + ): + self._pending_tool_nodes.discard( + tool_event.node_name + ) + yield tool_event yield UiPathRuntimeStateEvent( payload=self._serialize_event_data( self._filter_completed_data(event.data) @@ -557,9 +636,15 @@ async def _stream_workflow( tool_events = self._extract_tool_state_events( event.data, executor_id ) - if tool_events: - executors_with_tool_outputs.add(executor_id) for tool_event in tool_events: + executor_tool_phases.setdefault( + executor_id, set() + ).add(tool_event.phase) + # Track pending tool nodes across stream iterations + if tool_event.phase == UiPathRuntimeStatePhase.STARTED: + self._pending_tool_nodes.add(tool_event.node_name) + elif tool_event.phase == UiPathRuntimeStatePhase.COMPLETED: + self._pending_tool_nodes.discard(tool_event.node_name) yield tool_event for msg_event in self._extract_workflow_messages(event.data): yield UiPathRuntimeMessageEvent(payload=msg_event) @@ -581,6 +666,10 @@ async def _stream_workflow( for msg_event in self.chat.close_message(): yield UiPathRuntimeMessageEvent(payload=msg_event) + # After resume paths the checkpoint restores the session into + # executors directly, so the local ``session`` may still be None. + if session is None: + session = self._get_session_from_executors() if session is not None: await self._save_session(session) @@ -619,6 +708,10 @@ async def _stream_workflow( for msg_event in self.chat.close_message(): yield UiPathRuntimeMessageEvent(payload=msg_event) + # After resume paths the checkpoint restores the session into + # executors directly, so the local ``session`` may still be None. + if session is None: + session = self._get_session_from_executors() if session is not None: await self._save_session(session) @@ -681,7 +774,11 @@ def _extract_tool_state_events( """ contents: list[Any] = [] - if isinstance(data, AgentResponseUpdate): + if isinstance(data, AgentExecutorResponse): + return UiPathAgentFrameworkRuntime._extract_tool_state_events( + data.agent_response, executor_id + ) + elif isinstance(data, AgentResponseUpdate): contents = list(data.contents or []) elif isinstance(data, AgentResponse): for message in data.messages or []: @@ -724,7 +821,9 @@ def _extract_tool_state_events( def _extract_contents(data: Any) -> list[Any]: """Extract Content objects from any workflow data type.""" contents: list[Any] = [] - if isinstance(data, AgentResponseUpdate): + if isinstance(data, AgentExecutorResponse): + return UiPathAgentFrameworkRuntime._extract_contents(data.agent_response) + elif isinstance(data, AgentResponseUpdate): contents = list(data.contents or []) elif isinstance(data, AgentResponse): for message in data.messages or []: diff --git a/packages/uipath-agent-framework/tests/test_hitl_e2e.py b/packages/uipath-agent-framework/tests/test_hitl_e2e.py index 775f983..cefe734 100644 --- a/packages/uipath-agent-framework/tests/test_hitl_e2e.py +++ b/packages/uipath-agent-framework/tests/test_hitl_e2e.py @@ -15,6 +15,7 @@ """ import asyncio +import json import os import tempfile from typing import Any @@ -33,6 +34,7 @@ from uipath.runtime import UiPathResumableRuntime from uipath.runtime.chat.runtime import UiPathChatRuntime from uipath.runtime.events import UiPathRuntimeEvent +from uipath.runtime.events.state import UiPathRuntimeStateEvent, UiPathRuntimeStatePhase from uipath.runtime.result import UiPathRuntimeResult, UiPathRuntimeStatus from uipath.runtime.resumable.trigger import ( UiPathResumeTrigger, @@ -572,3 +574,324 @@ async def mock_create(**kwargs: Any): finally: await storage.dispose() os.unlink(tmp_path) + + async def test_multi_turn_after_hitl_resume(self): + """Second fresh turn after HITL resume should not fail with stale session. + + Reproduces the bug where the session saved after HITL resume was stale + (missing tool results), causing OpenAI to reject the next turn with: + "An assistant message with 'tool_calls' must be followed by tool messages" + + Flow: + Turn 1: triage -> billing -> transfer_funds (HITL approve) -> complete + Turn 2: triage -> text response -> complete (must not fail) + """ + call_count: dict[str, int] = {} + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = extract_system_text(messages) + + if "billing" in system_msg.lower(): + count = call_count.get("billing", 0) + call_count["billing"] = count + 1 + if count == 0: + return make_tool_call_response( + "transfer_funds", + arguments='{"from_account": "A", "to_account": "B", "amount": 100.0}', + stream=is_stream, + ) + else: + return make_mock_response("Transfer complete.", stream=is_stream) + elif "triage" in system_msg.lower(): + count = call_count.get("triage", 0) + call_count["triage"] = count + 1 + if count == 0: + # Turn 1: route to billing + return make_tool_call_response( + "handoff_to_billing_agent", stream=is_stream + ) + else: + # Turn 2: respond directly (no tool call) + return make_mock_response( + "How else can I help?", stream=is_stream + ) + else: + return make_mock_response("OK", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + agent = _build_hitl_agents(mock_openai) + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + chat_runtime, chat_bridge, storage = await _create_hitl_runtime_stack( + agent, "test-hitl-multi-turn", tmp_path, auto_approve=True + ) + + # Turn 1: HITL approval flow + result1 = await chat_runtime.execute({"messages": []}) + assert result1.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Turn 1 failed: {result1.status}" + ) + + # Verify session was saved with complete history (including tool results). + # The bug was that only the HITL-suspension session was persisted + # (missing tool results), causing the next turn to fail. + session_data = await storage.get_value( + "test-hitl-multi-turn", "session", "data" + ) + assert session_data is not None, "Session was not saved after HITL resume" + + session_state = session_data.get("state", {}) + # Find all messages in any provider key + all_messages = [] + for provider_key, provider_data in session_state.items(): + if not isinstance(provider_data, dict): + continue + msgs = provider_data.get("messages", []) + if not isinstance(msgs, list): + continue + all_messages = msgs + break + + assert len(all_messages) > 0, ( + f"No messages found in session state. " + f"State keys: {list(session_state.keys())}. " + f"Full state: {json.dumps(session_data, indent=2, default=str)[:3000]}" + ) + + # Verify: every assistant message with a function_call must + # eventually be followed by a tool message with function_result + # for the same call_id. This is the exact invariant that OpenAI + # enforces and that breaks with stale sessions. + def get_function_call_ids(msg_data: dict) -> list[str]: + """Get call_ids of function_call contents in a message.""" + ids = [] + for c in msg_data.get("contents", []): + if isinstance(c, dict) and c.get("type") == "function_call": + cid = c.get("call_id") + if cid: + ids.append(cid) + return ids + + def get_function_result_ids(msg_data: dict) -> list[str]: + """Get call_ids of function_result contents in a message.""" + ids = [] + for c in msg_data.get("contents", []): + if isinstance(c, dict) and c.get("type") == "function_result": + cid = c.get("call_id") + if cid: + ids.append(cid) + return ids + + pending_call_ids: set[str] = set() + for i, msg in enumerate(all_messages): + if not isinstance(msg, dict): + continue + pending_call_ids.update(get_function_call_ids(msg)) + for rid in get_function_result_ids(msg): + pending_call_ids.discard(rid) + + assert len(pending_call_ids) == 0, ( + f"Session has orphaned function_calls (no function_result): " + f"{pending_call_ids}. This will cause OpenAI to reject the next " + f"turn. Messages:\n" + + json.dumps(all_messages, indent=2, default=str)[:5000] + ) + + # Turn 2: fresh turn after HITL — must not fail with stale session + chat_bridge.interrupts.clear() + result2 = await chat_runtime.execute({"messages": []}) + assert result2.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Turn 2 failed: {result2.status}" + ) + finally: + await storage.dispose() + os.unlink(tmp_path) + + async def test_multi_turn_after_hitl_resume_streaming(self): + """Streaming: second fresh turn after HITL resume should not fail. + + Same as test_multi_turn_after_hitl_resume but using the streaming path. + """ + call_count: dict[str, int] = {} + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = extract_system_text(messages) + + if "billing" in system_msg.lower(): + count = call_count.get("billing", 0) + call_count["billing"] = count + 1 + if count == 0: + return make_tool_call_response( + "transfer_funds", + arguments='{"from_account": "X", "to_account": "Y", "amount": 200.0}', + stream=is_stream, + ) + else: + return make_mock_response("Transfer done.", stream=is_stream) + elif "triage" in system_msg.lower(): + count = call_count.get("triage", 0) + call_count["triage"] = count + 1 + if count == 0: + return make_tool_call_response( + "handoff_to_billing_agent", stream=is_stream + ) + else: + return make_mock_response( + "Anything else?", stream=is_stream + ) + else: + return make_mock_response("OK", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + agent = _build_hitl_agents(mock_openai) + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + chat_runtime, chat_bridge, storage = await _create_hitl_runtime_stack( + agent, "test-hitl-stream-multi", tmp_path, auto_approve=True + ) + + # Turn 1: streaming HITL flow + events1: list[UiPathRuntimeEvent] = [] + async for event in chat_runtime.stream({"messages": []}): + events1.append(event) + results1 = [e for e in events1 if isinstance(e, UiPathRuntimeResult)] + assert len(results1) >= 1 + assert results1[-1].status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Turn 1 failed: {results1[-1].status}" + ) + + # Turn 2: fresh streaming turn after HITL + chat_bridge.interrupts.clear() + events2: list[UiPathRuntimeEvent] = [] + async for event in chat_runtime.stream({"messages": []}): + events2.append(event) + results2 = [e for e in events2 if isinstance(e, UiPathRuntimeResult)] + assert len(results2) >= 1 + assert results2[-1].status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Turn 2 failed: {results2[-1].status}" + ) + finally: + await storage.dispose() + os.unlink(tmp_path) + + async def test_tool_node_completed_after_hitl_resume(self): + """Tool node should emit both STARTED and COMPLETED state events across HITL. + + Reproduces the bug where billing_agent_tools emitted STARTED (before HITL + suspension) but never COMPLETED (after HITL resume), because the framework + doesn't surface function_result in output/executor_completed events for + handoff workflows. The fix synthesizes COMPLETED at the start of resume. + + Expected event sequence: + customer_support STARTED + triage STARTED + triage COMPLETED + billing_agent STARTED + billing_agent_tools STARTED <-- before HITL suspension + [HITL interrupt + resume] + billing_agent_tools COMPLETED <-- synthesized on resume + billing_agent STARTED <-- resume executor + billing_agent COMPLETED + customer_support COMPLETED + """ + call_count: dict[str, int] = {} + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = extract_system_text(messages) + + if "billing" in system_msg.lower(): + count = call_count.get("billing", 0) + call_count["billing"] = count + 1 + if count == 0: + return make_tool_call_response( + "transfer_funds", + arguments='{"from_account": "A", "to_account": "B", "amount": 100.0}', + stream=is_stream, + ) + else: + return make_mock_response("Transfer complete.", stream=is_stream) + elif "triage" in system_msg.lower(): + return make_tool_call_response( + "handoff_to_billing_agent", stream=is_stream + ) + else: + return make_mock_response("OK", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + agent = _build_hitl_agents(mock_openai) + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + chat_runtime, chat_bridge, storage = await _create_hitl_runtime_stack( + agent, "test-hitl-tool-events", tmp_path, auto_approve=True + ) + + # Collect all events from streaming + events: list[UiPathRuntimeEvent] = [] + async for event in chat_runtime.stream({"messages": []}): + events.append(event) + + # Extract state events for analysis + state_events = [ + e for e in events if isinstance(e, UiPathRuntimeStateEvent) + ] + + # Build a summary: node_name -> list of phases + node_phases: dict[str, list[str]] = {} + for se in state_events: + if se.node_name: + node_phases.setdefault(se.node_name, []).append(se.phase.value) + + # billing_agent_tools MUST have both STARTED and COMPLETED + tools_node = "billing_agent_tools" + assert tools_node in node_phases, ( + f"{tools_node} not found in state events. " + f"Nodes seen: {list(node_phases.keys())}" + ) + tools_phases = node_phases[tools_node] + assert UiPathRuntimeStatePhase.STARTED.value in tools_phases, ( + f"{tools_node} missing STARTED. Phases: {tools_phases}" + ) + assert UiPathRuntimeStatePhase.COMPLETED.value in tools_phases, ( + f"{tools_node} missing COMPLETED after HITL resume. " + f"Phases: {tools_phases}. " + f"All state events: " + + ", ".join( + f"{e.node_name}:{e.phase.value}" + for e in state_events + if e.node_name + ) + ) + + # Verify COMPLETED comes after STARTED + started_idx = tools_phases.index(UiPathRuntimeStatePhase.STARTED.value) + completed_idx = tools_phases.index(UiPathRuntimeStatePhase.COMPLETED.value) + assert completed_idx > started_idx, ( + f"{tools_node} COMPLETED ({completed_idx}) should come after " + f"STARTED ({started_idx})" + ) + + # Final result should be successful + results = [e for e in events if isinstance(e, UiPathRuntimeResult)] + assert len(results) >= 1 + assert results[-1].status == UiPathRuntimeStatus.SUCCESSFUL + finally: + await storage.dispose() + os.unlink(tmp_path) diff --git a/packages/uipath-agent-framework/uv.lock b/packages/uipath-agent-framework/uv.lock index dac2194..4a75135 100644 --- a/packages/uipath-agent-framework/uv.lock +++ b/packages/uipath-agent-framework/uv.lock @@ -2460,7 +2460,7 @@ wheels = [ [[package]] name = "uipath-agent-framework" -version = "0.0.6" +version = "0.0.7" source = { editable = "." } dependencies = [ { name = "agent-framework-core" }, From 83c36bf54837ba913c66f03a8c565b1304a1e5dd Mon Sep 17 00:00:00 2001 From: Cristian Pufu Date: Sat, 21 Feb 2026 13:02:15 +0200 Subject: [PATCH 2/2] fix: prevent breakpoint+HITL infinite loop and duplicate events MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Load breakpoint skip counts on HITL resume so the executor passes through without re-firing the breakpoint (prevents the infinite breakpoint → HITL → breakpoint loop on same node) - Clear _last_checkpoint_id after loading breakpoint state to prevent the next fresh turn from being mistaken for a breakpoint resume - Detect stale checkpoints from previous turns by capturing a baseline before workflow.run() and comparing after breakpoint fires - Fall back to _resumed_from_checkpoint_id (instead of None) when no new checkpoint was created, preventing replay from scratch that caused duplicate handoff_to_billing_agent events (4x) Co-Authored-By: Claude Opus 4.6 --- .../samples/hitl-workflow/pyproject.toml | 4 - .../uipath_agent_framework/runtime/runtime.py | 120 ++- .../tests/test_hitl_e2e.py | 706 +++++++++++++++++- .../tests/test_streaming.py | 2 + 4 files changed, 781 insertions(+), 51 deletions(-) diff --git a/packages/uipath-agent-framework/samples/hitl-workflow/pyproject.toml b/packages/uipath-agent-framework/samples/hitl-workflow/pyproject.toml index 713932b..084a923 100644 --- a/packages/uipath-agent-framework/samples/hitl-workflow/pyproject.toml +++ b/packages/uipath-agent-framework/samples/hitl-workflow/pyproject.toml @@ -18,7 +18,3 @@ dev = [ [tool.uv] prerelease = "allow" - -[tool.uv.sources] -uipath-dev = { path = "../../../../../uipath-dev-python", editable = true } -uipath-agent-framework = { path = "../../", editable = true } diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py index a8ff68b..9595c84 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py @@ -99,11 +99,14 @@ async def _save_breakpoint_state( to run before re-arming its breakpoint. The count is incremented every time the same executor hits a breakpoint again (cyclic graphs, GroupChat orchestrators). + + checkpoint_id may be None when the breakpoint fired before any + new checkpoint was created (e.g. the very first executor of a + fresh turn). In that case the resume path will replay from + original_input with skip counts instead of restoring a checkpoint. """ if not self._resumable_storage: return - if checkpoint_id is None: - checkpoint_id = await self._get_latest_checkpoint_id() state = { "skip_nodes": dict(self._breakpoint_skip_nodes), "last_breakpoint_node": self._last_breakpoint_node, @@ -308,13 +311,31 @@ async def execute( workflow = self.agent.workflow + # Capture the latest checkpoint BEFORE workflow.run() so we can + # detect whether a NEW checkpoint was created during this execution. + # Without this, breakpoints that fire before any new checkpoint + # (e.g. the first executor of turn 2) would save a stale + # checkpoint from the previous turn, causing the resume to + # restore completed state instead of replaying from input. + baseline_checkpoint_id = await self._get_latest_checkpoint_id() + if is_resuming and input is not None: # HITL resume: checkpoint restores executor state (including session) self._resume_responses = await self._convert_resume_responses(input) - # Inject breakpoints (no skip needed for HITL resume) + # Inject breakpoints with accumulated skip counts so that + # breakpoints don't re-fire on the same executor after HITL + # approval (prevents breakpoint→HITL→breakpoint loop). if options and options.breakpoints: - inject_breakpoint_middleware(self.agent, options.breakpoints) + await self._load_breakpoint_state() + inject_breakpoint_middleware( + self.agent, + options.breakpoints, + self._get_breakpoint_skip(), + ) + # _load_breakpoint_state sets _last_checkpoint_id as a + # side effect. Clear it so it doesn't contaminate later runs. + self._last_checkpoint_id = None if self._resume_responses: checkpoint_id = await self._get_latest_checkpoint_id() @@ -419,8 +440,21 @@ async def execute( ) self._last_breakpoint_node = node_id original_input = self._prepare_input(input) if not is_resuming else "" + # Only save checkpoint_id if it was created during THIS run. + # If latest == baseline, no new checkpoint was created (e.g. + # breakpoint on the first executor of a fresh turn) — save + # the checkpoint we resumed from (if any) so we don't lose + # it and replay from scratch on the next resume. + # For fresh turns _resumed_from_checkpoint_id is None, which + # correctly prevents using a stale checkpoint from the + # previous turn. + effective_checkpoint = ( + latest_checkpoint + if latest_checkpoint != baseline_checkpoint_id + else self._resumed_from_checkpoint_id + ) await self._save_breakpoint_state( - original_input, checkpoint_id=latest_checkpoint + original_input, checkpoint_id=effective_checkpoint ) return create_breakpoint_result(e) return self._create_suspended_result(e) @@ -444,9 +478,20 @@ async def stream( self._resume_responses = await self._convert_resume_responses(input) user_input = self._prepare_input(None) - # Inject breakpoints (no skip needed for HITL resume) + # Inject breakpoints with accumulated skip counts so that + # breakpoints don't re-fire on the same executor after HITL + # approval (prevents breakpoint→HITL→breakpoint loop). if options and options.breakpoints: - inject_breakpoint_middleware(self.agent, options.breakpoints) + await self._load_breakpoint_state() + inject_breakpoint_middleware( + self.agent, + options.breakpoints, + self._get_breakpoint_skip(), + ) + # _load_breakpoint_state sets _last_checkpoint_id as a + # side effect. Clear it so _stream_workflow doesn't + # mistake a subsequent fresh run for a breakpoint resume. + self._last_checkpoint_id = None elif is_resuming: # Breakpoint resume: restore original_input and session @@ -465,8 +510,9 @@ async def stream( ) else: - # Fresh run + # Fresh run — clear stale resume state from previous turns self._resume_responses = None + self._last_checkpoint_id = None user_input = self._prepare_input(input) # Load session for multi-turn conversation history @@ -519,6 +565,10 @@ async def _stream_workflow( ) self._pending_tool_nodes.clear() + # Capture the latest checkpoint BEFORE workflow.run() so we can + # detect whether a NEW checkpoint was created during this execution. + baseline_checkpoint_id = await self._get_latest_checkpoint_id() + # Choose workflow.run() mode based on resume type if self._resume_responses: # HITL resume: pass responses to workflow with checkpoint @@ -609,20 +659,21 @@ async def _stream_workflow( ): if tool_event.phase not in emitted_phases: # Track pending tool nodes - if ( - tool_event.phase - == UiPathRuntimeStatePhase.STARTED - ): - self._pending_tool_nodes.add( - tool_event.node_name - ) - elif ( - tool_event.phase - == UiPathRuntimeStatePhase.COMPLETED - ): - self._pending_tool_nodes.discard( - tool_event.node_name - ) + if tool_event.node_name: + if ( + tool_event.phase + == UiPathRuntimeStatePhase.STARTED + ): + self._pending_tool_nodes.add( + tool_event.node_name + ) + elif ( + tool_event.phase + == UiPathRuntimeStatePhase.COMPLETED + ): + self._pending_tool_nodes.discard( + tool_event.node_name + ) yield tool_event yield UiPathRuntimeStateEvent( payload=self._serialize_event_data( @@ -637,14 +688,15 @@ async def _stream_workflow( event.data, executor_id ) for tool_event in tool_events: - executor_tool_phases.setdefault( - executor_id, set() - ).add(tool_event.phase) + executor_tool_phases.setdefault(executor_id, set()).add( + tool_event.phase + ) # Track pending tool nodes across stream iterations - if tool_event.phase == UiPathRuntimeStatePhase.STARTED: - self._pending_tool_nodes.add(tool_event.node_name) - elif tool_event.phase == UiPathRuntimeStatePhase.COMPLETED: - self._pending_tool_nodes.discard(tool_event.node_name) + if tool_event.node_name: + if tool_event.phase == UiPathRuntimeStatePhase.STARTED: + self._pending_tool_nodes.add(tool_event.node_name) + elif tool_event.phase == UiPathRuntimeStatePhase.COMPLETED: + self._pending_tool_nodes.discard(tool_event.node_name) yield tool_event for msg_event in self._extract_workflow_messages(event.data): yield UiPathRuntimeMessageEvent(payload=msg_event) @@ -691,8 +743,16 @@ async def _stream_workflow( self._breakpoint_skip_nodes.get(node_id, 0) + 1 ) self._last_breakpoint_node = node_id + # Only save checkpoint_id if it was created during THIS run. + # Fall back to the checkpoint we resumed from (if any) to + # avoid replaying from scratch on the next resume. + effective_checkpoint = ( + latest_checkpoint + if latest_checkpoint != baseline_checkpoint_id + else self._resumed_from_checkpoint_id + ) await self._save_breakpoint_state( - user_input, checkpoint_id=latest_checkpoint + user_input, checkpoint_id=effective_checkpoint ) yield create_breakpoint_result(e) else: diff --git a/packages/uipath-agent-framework/tests/test_hitl_e2e.py b/packages/uipath-agent-framework/tests/test_hitl_e2e.py index cefe734..9887b8b 100644 --- a/packages/uipath-agent-framework/tests/test_hitl_e2e.py +++ b/packages/uipath-agent-framework/tests/test_hitl_e2e.py @@ -18,8 +18,8 @@ import json import os import tempfile -from typing import Any -from unittest.mock import AsyncMock, MagicMock +from typing import Any, cast +from unittest.mock import AsyncMock, MagicMock, Mock import pytest from agent_framework.openai import OpenAIChatClient @@ -33,8 +33,16 @@ from uipath.platform.resume_triggers import UiPathResumeTriggerHandler from uipath.runtime import UiPathResumableRuntime from uipath.runtime.chat.runtime import UiPathChatRuntime -from uipath.runtime.events import UiPathRuntimeEvent -from uipath.runtime.events.state import UiPathRuntimeStateEvent, UiPathRuntimeStatePhase +from uipath.runtime.debug import ( + UiPathDebugProtocol, + UiPathDebugRuntime, +) +from uipath.runtime.events import ( + UiPathRuntimeEvent, + UiPathRuntimeStateEvent, + UiPathRuntimeStatePhase, +) +from uipath.runtime.events.state import UiPathRuntimeMessageEvent from uipath.runtime.result import UiPathRuntimeResult, UiPathRuntimeStatus from uipath.runtime.resumable.trigger import ( UiPathResumeTrigger, @@ -614,9 +622,7 @@ async def mock_create(**kwargs: Any): ) else: # Turn 2: respond directly (no tool call) - return make_mock_response( - "How else can I help?", stream=is_stream - ) + return make_mock_response("How else can I help?", stream=is_stream) else: return make_mock_response("OK", stream=is_stream) @@ -649,7 +655,7 @@ async def mock_create(**kwargs: Any): session_state = session_data.get("state", {}) # Find all messages in any provider key all_messages = [] - for provider_key, provider_data in session_state.items(): + for _provider_key, provider_data in session_state.items(): if not isinstance(provider_data, dict): continue msgs = provider_data.get("messages", []) @@ -668,7 +674,7 @@ async def mock_create(**kwargs: Any): # eventually be followed by a tool message with function_result # for the same call_id. This is the exact invariant that OpenAI # enforces and that breaks with stale sessions. - def get_function_call_ids(msg_data: dict) -> list[str]: + def get_function_call_ids(msg_data: dict[str, Any]) -> list[str]: """Get call_ids of function_call contents in a message.""" ids = [] for c in msg_data.get("contents", []): @@ -678,7 +684,7 @@ def get_function_call_ids(msg_data: dict) -> list[str]: ids.append(cid) return ids - def get_function_result_ids(msg_data: dict) -> list[str]: + def get_function_result_ids(msg_data: dict[str, Any]) -> list[str]: """Get call_ids of function_result contents in a message.""" ids = [] for c in msg_data.get("contents", []): @@ -689,7 +695,7 @@ def get_function_result_ids(msg_data: dict) -> list[str]: return ids pending_call_ids: set[str] = set() - for i, msg in enumerate(all_messages): + for msg in all_messages: if not isinstance(msg, dict): continue pending_call_ids.update(get_function_call_ids(msg)) @@ -744,9 +750,7 @@ async def mock_create(**kwargs: Any): "handoff_to_billing_agent", stream=is_stream ) else: - return make_mock_response( - "Anything else?", stream=is_stream - ) + return make_mock_response("Anything else?", stream=is_stream) else: return make_mock_response("OK", stream=is_stream) @@ -849,9 +853,7 @@ async def mock_create(**kwargs: Any): events.append(event) # Extract state events for analysis - state_events = [ - e for e in events if isinstance(e, UiPathRuntimeStateEvent) - ] + state_events = [e for e in events if isinstance(e, UiPathRuntimeStateEvent)] # Build a summary: node_name -> list of phases node_phases: dict[str, list[str]] = {} @@ -895,3 +897,673 @@ async def mock_create(**kwargs: Any): finally: await storage.dispose() os.unlink(tmp_path) + + +# --------------------------------------------------------------------------- +# Debug bridge mock +# --------------------------------------------------------------------------- + + +def _make_debug_bridge(**overrides: Any) -> UiPathDebugProtocol: + """Create a mock debug bridge with sensible defaults.""" + bridge: Mock = Mock(spec=UiPathDebugProtocol) + bridge.connect = AsyncMock() + bridge.disconnect = AsyncMock() + bridge.emit_execution_started = AsyncMock() + bridge.emit_execution_completed = AsyncMock() + bridge.emit_execution_error = AsyncMock() + bridge.emit_execution_suspended = AsyncMock() + bridge.emit_breakpoint_hit = AsyncMock() + bridge.emit_state_update = AsyncMock() + bridge.emit_execution_resumed = AsyncMock() + bridge.wait_for_resume = AsyncMock(return_value=None) + bridge.wait_for_terminate = AsyncMock() + bridge.get_breakpoints = Mock(return_value=[]) + for k, v in overrides.items(): + setattr(bridge, k, v) + return cast(UiPathDebugProtocol, bridge) + + +# --------------------------------------------------------------------------- +# Breakpoint + HITL combined tests +# --------------------------------------------------------------------------- + +# Safety limit: if the debug loop exceeds this many resume calls, +# the test fails — this means breakpoints are stuck in a loop. +MAX_RESUME_CALLS = 20 + + +@pytest.mark.asyncio(loop_scope="class") +class TestBreakpointAndHitlCombined: + """Tests for the interaction between breakpoints and HITL tool approval. + + Uses the full runtime stack: + UiPathDebugRuntime -> UiPathChatRuntime -> UiPathResumableRuntime + -> UiPathAgentFrameworkRuntime + + Reproduces the bug where breakpoint + @requires_approval on the same + node causes an infinite loop: breakpoint → continue → HITL → approve → + breakpoint (again!) → continue → HITL → ... + """ + + @pytest.fixture(autouse=True) + async def _settle_framework(self): + """Allow framework background tasks to complete between tests.""" + yield + await asyncio.sleep(0.2) + + async def test_breakpoint_then_hitl_does_not_loop(self): + """Breakpoint + HITL on same node must complete without looping. + + Flow: + 1. triage → billing_agent (breakpoint fires before billing_agent) + 2. User continues breakpoint + 3. billing_agent runs, calls transfer_funds (HITL fires) + 4. User approves tool via chat bridge + 5. Execution completes — NO more breakpoints + + Without the fix, step 5 would trigger another breakpoint on + billing_agent, creating an infinite breakpoint→HITL→breakpoint loop. + """ + call_count: dict[str, int] = {} + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = extract_system_text(messages) + + if "billing" in system_msg.lower(): + count = call_count.get("billing", 0) + call_count["billing"] = count + 1 + if count == 0: + return make_tool_call_response( + "transfer_funds", + arguments='{"from_account": "A", "to_account": "B", "amount": 100.0}', + stream=is_stream, + ) + else: + return make_mock_response("Transfer complete.", stream=is_stream) + elif "triage" in system_msg.lower(): + return make_tool_call_response( + "handoff_to_billing_agent", stream=is_stream + ) + else: + return make_mock_response("OK", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + agent = _build_hitl_agents(mock_openai) + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + storage = SqliteResumableStorage(tmp_path) + await storage.setup() + assert storage.checkpoint_storage is not None + + runtime_id = "test-bp-hitl-loop" + scoped_cs = ScopedCheckpointStorage(storage.checkpoint_storage, runtime_id) + + base_runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id=runtime_id, + checkpoint_storage=scoped_cs, + resumable_storage=storage, + ) + base_runtime.chat = MagicMock() + base_runtime.chat.map_messages_to_input.return_value = ( + "Transfer $100 from A to B" + ) + base_runtime.chat.map_streaming_content.return_value = [] + base_runtime.chat.close_message.return_value = [] + + resumable_runtime = UiPathResumableRuntime( + delegate=base_runtime, + storage=storage, + trigger_manager=UiPathResumeTriggerHandler(), + runtime_id=runtime_id, + ) + + chat_bridge = MockChatBridge(auto_approve=True) + chat_runtime = UiPathChatRuntime( + delegate=resumable_runtime, chat_bridge=chat_bridge + ) + + # Debug bridge: breakpoint on billing_agent + resume_count = [0] + + async def mock_wait_for_resume(*args: Any, **kwargs: Any) -> None: + resume_count[0] += 1 + if resume_count[0] > MAX_RESUME_CALLS: + raise AssertionError( + f"Infinite loop detected: {resume_count[0]} resumes. " + f"Breakpoint hits: " + f"{cast(AsyncMock, debug_bridge.emit_breakpoint_hit).await_count}" + ) + return None + + debug_bridge = _make_debug_bridge() + cast(Mock, debug_bridge.get_breakpoints).return_value = ["billing_agent"] + cast( + AsyncMock, debug_bridge.wait_for_resume + ).side_effect = mock_wait_for_resume + + debug_runtime = UiPathDebugRuntime( + delegate=chat_runtime, debug_bridge=debug_bridge + ) + + # Execute the full flow + result = await debug_runtime.execute({"messages": []}) + + bp_count = cast(AsyncMock, debug_bridge.emit_breakpoint_hit).await_count + + # Must complete successfully (not loop forever) + assert result.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Expected SUCCESSFUL, got {result.status}. " + f"Resumes: {resume_count[0]}, Breakpoint hits: {bp_count}" + ) + + # Breakpoint should have been hit at least once. + # The exact count depends on how many times the HandoffBuilder + # calls execute() during checkpoint restore, but it MUST NOT + # loop forever (which is caught by MAX_RESUME_CALLS). + assert bp_count >= 1, f"Expected at least 1 breakpoint hit, got {bp_count}" + + # HITL should have been handled by chat bridge + assert len(chat_bridge.interrupts) >= 1, ( + f"Expected at least 1 HITL interrupt, got {len(chat_bridge.interrupts)}" + ) + finally: + await storage.dispose() + os.unlink(tmp_path) + + async def test_breakpoint_on_all_nodes_with_hitl(self): + """Breakpoints on ALL nodes + HITL on billing_agent must complete. + + Same as above but with breakpoints="*", which means triage also + gets a breakpoint. Verifies the full sequence: + 1. Breakpoint on triage → continue + 2. triage hands off to billing_agent + 3. Breakpoint on billing_agent → continue + 4. billing_agent calls transfer_funds → HITL → approve + 5. Execution completes + """ + call_count: dict[str, int] = {} + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = extract_system_text(messages) + + if "billing" in system_msg.lower(): + count = call_count.get("billing", 0) + call_count["billing"] = count + 1 + if count == 0: + return make_tool_call_response( + "transfer_funds", + arguments='{"from_account": "X", "to_account": "Y", "amount": 50.0}', + stream=is_stream, + ) + else: + return make_mock_response("Transfer done.", stream=is_stream) + elif "triage" in system_msg.lower(): + return make_tool_call_response( + "handoff_to_billing_agent", stream=is_stream + ) + else: + return make_mock_response("OK", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + agent = _build_hitl_agents(mock_openai) + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + storage = SqliteResumableStorage(tmp_path) + await storage.setup() + assert storage.checkpoint_storage is not None + + runtime_id = "test-bp-all-hitl" + scoped_cs = ScopedCheckpointStorage(storage.checkpoint_storage, runtime_id) + + base_runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id=runtime_id, + checkpoint_storage=scoped_cs, + resumable_storage=storage, + ) + base_runtime.chat = MagicMock() + base_runtime.chat.map_messages_to_input.return_value = ( + "Transfer $50 from X to Y" + ) + base_runtime.chat.map_streaming_content.return_value = [] + base_runtime.chat.close_message.return_value = [] + + resumable_runtime = UiPathResumableRuntime( + delegate=base_runtime, + storage=storage, + trigger_manager=UiPathResumeTriggerHandler(), + runtime_id=runtime_id, + ) + + chat_bridge = MockChatBridge(auto_approve=True) + chat_runtime = UiPathChatRuntime( + delegate=resumable_runtime, chat_bridge=chat_bridge + ) + + resume_count = [0] + + async def mock_wait_for_resume(*args: Any, **kwargs: Any) -> None: + resume_count[0] += 1 + if resume_count[0] > MAX_RESUME_CALLS: + raise AssertionError( + f"Infinite loop detected: {resume_count[0]} resumes. " + f"Breakpoint hits: " + f"{cast(AsyncMock, debug_bridge.emit_breakpoint_hit).await_count}" + ) + return None + + debug_bridge = _make_debug_bridge() + cast(Mock, debug_bridge.get_breakpoints).return_value = "*" + cast( + AsyncMock, debug_bridge.wait_for_resume + ).side_effect = mock_wait_for_resume + + debug_runtime = UiPathDebugRuntime( + delegate=chat_runtime, debug_bridge=debug_bridge + ) + + result = await debug_runtime.execute({"messages": []}) + + bp_count = cast(AsyncMock, debug_bridge.emit_breakpoint_hit).await_count + + assert result.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Expected SUCCESSFUL, got {result.status}. " + f"Resumes: {resume_count[0]}, Breakpoint hits: {bp_count}" + ) + + # At least 1 breakpoint should have been hit + assert bp_count >= 1, f"Expected at least 1 breakpoint hit, got {bp_count}" + + # HITL should have been handled (1 interrupt for transfer_funds) + assert len(chat_bridge.interrupts) == 1, ( + f"Expected 1 HITL interrupt, got {len(chat_bridge.interrupts)}" + ) + finally: + await storage.dispose() + os.unlink(tmp_path) + + async def test_second_turn_after_breakpoint_and_hitl(self): + """Second turn after breakpoint+HITL must actually run the workflow. + + Reproduces the bug where a breakpoint on the first executor of + turn 2 saved a stale checkpoint_id from turn 1 (because no new + checkpoint existed yet). On resume, workflow.run(checkpoint_id=stale) + restored turn 1's completed state and returned immediately with + no work done — the OTEL trace showed empty workflow.run spans. + + Flow: + Turn 1: breakpoints=* → triage → billing → HITL → approve → complete + Turn 2: breakpoints=* → triage breakpoint → continue → triage + responds "How else can I help?" → complete + + The key assertion: triage's LLM must have been called during turn 2 + (call_count["triage"] increases). Without the fix, the resume + restores turn 1's state and triage is never called. + """ + call_count: dict[str, int] = {} + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = extract_system_text(messages) + + if "billing" in system_msg.lower(): + count = call_count.get("billing", 0) + call_count["billing"] = count + 1 + if count == 0: + return make_tool_call_response( + "transfer_funds", + arguments='{"from_account": "A", "to_account": "B", "amount": 100.0}', + stream=is_stream, + ) + else: + return make_mock_response("Transfer complete.", stream=is_stream) + elif "triage" in system_msg.lower(): + count = call_count.get("triage", 0) + call_count["triage"] = count + 1 + if count == 0: + return make_tool_call_response( + "handoff_to_billing_agent", stream=is_stream + ) + else: + # Turn 2: triage responds directly + return make_mock_response("How else can I help?", stream=is_stream) + else: + return make_mock_response("OK", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + agent = _build_hitl_agents(mock_openai) + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + storage = SqliteResumableStorage(tmp_path) + await storage.setup() + assert storage.checkpoint_storage is not None + + runtime_id = "test-bp-hitl-turn2" + scoped_cs = ScopedCheckpointStorage(storage.checkpoint_storage, runtime_id) + + base_runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id=runtime_id, + checkpoint_storage=scoped_cs, + resumable_storage=storage, + ) + base_runtime.chat = MagicMock() + base_runtime.chat.map_messages_to_input.return_value = ( + "Transfer $100 from A to B" + ) + base_runtime.chat.map_streaming_content.return_value = [] + base_runtime.chat.close_message.return_value = [] + + resumable_runtime = UiPathResumableRuntime( + delegate=base_runtime, + storage=storage, + trigger_manager=UiPathResumeTriggerHandler(), + runtime_id=runtime_id, + ) + + chat_bridge = MockChatBridge(auto_approve=True) + chat_runtime = UiPathChatRuntime( + delegate=resumable_runtime, chat_bridge=chat_bridge + ) + + resume_count = [0] + + async def mock_wait_for_resume(*args: Any, **kwargs: Any) -> None: + resume_count[0] += 1 + if resume_count[0] > MAX_RESUME_CALLS: + raise AssertionError( + f"Infinite loop detected: {resume_count[0]} resumes" + ) + return None + + debug_bridge = _make_debug_bridge() + cast(Mock, debug_bridge.get_breakpoints).return_value = "*" + cast( + AsyncMock, debug_bridge.wait_for_resume + ).side_effect = mock_wait_for_resume + + debug_runtime = UiPathDebugRuntime( + delegate=chat_runtime, debug_bridge=debug_bridge + ) + + # Turn 1: breakpoint + HITL + result1 = await debug_runtime.execute({"messages": []}) + assert result1.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Turn 1 failed: {result1.status}" + ) + + turn1_bp_count = cast( + AsyncMock, debug_bridge.emit_breakpoint_hit + ).await_count + assert turn1_bp_count >= 1, "Turn 1 should have hit at least 1 breakpoint" + + # Record LLM call counts after turn 1 + triage_calls_after_turn1 = call_count.get("triage", 0) + assert triage_calls_after_turn1 >= 1, ( + "Triage should have been called at least once during turn 1" + ) + + # Turn 2: fresh turn — should complete normally with breakpoints + chat_bridge.interrupts.clear() + resume_count[0] = 0 + cast(AsyncMock, debug_bridge.emit_breakpoint_hit).reset_mock() + + result2 = await debug_runtime.execute({"messages": []}) + assert result2.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Turn 2 failed: {result2.status}. " + f"Resumes: {resume_count[0]}, " + f"BPs: {cast(AsyncMock, debug_bridge.emit_breakpoint_hit).await_count}" + ) + + turn2_bp_count = cast( + AsyncMock, debug_bridge.emit_breakpoint_hit + ).await_count + # Turn 2 should also hit breakpoints (fresh run with breakpoints) + assert turn2_bp_count >= 1, ( + f"Turn 2 should have hit at least 1 breakpoint, got {turn2_bp_count}" + ) + + # KEY ASSERTION: triage's LLM was actually called during turn 2. + # Without the stale-checkpoint fix, the resume restores turn 1's + # completed state and triage is never invoked again. + triage_calls_after_turn2 = call_count.get("triage", 0) + assert triage_calls_after_turn2 > triage_calls_after_turn1, ( + f"Turn 2 must invoke triage LLM. " + f"Calls after turn 1: {triage_calls_after_turn1}, " + f"after turn 2: {triage_calls_after_turn2}. " + f"This means the resume used a stale checkpoint from turn 1." + ) + finally: + await storage.dispose() + os.unlink(tmp_path) + + async def test_no_duplicate_tool_calls_on_breakpoint_resume(self): + """Breakpoint resumes must not emit duplicate tool call message events. + + Full E2E streaming test with BOTH breakpoints=* AND HITL approval. + Reproduces the bug where checkpoint restore replays output events for + already-completed executors. The chat mapper processed each replayed + Content(type='function_call') as new, emitting duplicate ToolCallStart + events — visible as repeated 'handoff_to_billing_agent' in the UI. + + Full runtime stack: + UiPathDebugRuntime -> UiPathChatRuntime -> UiPathResumableRuntime + -> UiPathAgentFrameworkRuntime + + Expected flow (breakpoints="*"): + 1. Breakpoint on triage → debug bridge continue + 2. triage runs, calls handoff_to_billing_agent → handoff to billing + 3. Breakpoint on billing_agent → debug bridge continue + 4. billing_agent runs, calls transfer_funds → HITL suspends + 5. Chat bridge auto-approves → billing completes + 6. More breakpoints may fire on fan-out agents → continues + 7. Workflow completes SUCCESSFUL + + Assertions at every step: + - Breakpoints fired (emit_breakpoint_hit called) + - HITL handled (chat_bridge.interrupts) + - handoff_to_billing_agent ToolCallStart appears exactly ONCE + - transfer_funds ToolCallStart appears exactly ONCE + - Final result is SUCCESSFUL + """ + call_count: dict[str, int] = {} + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = extract_system_text(messages) + + # Check triage BEFORE billing — triage's system prompt contains + # "billing_agent" in the handoff instructions, so a naive + # "billing" check would match both agents. + if "triage" in system_msg.lower(): + call_count["triage"] = call_count.get("triage", 0) + 1 + return make_tool_call_response( + "handoff_to_billing_agent", stream=is_stream + ) + elif "billing" in system_msg.lower(): + count = call_count.get("billing", 0) + call_count["billing"] = count + 1 + if count == 0: + return make_tool_call_response( + "transfer_funds", + arguments='{"from_account": "A", "to_account": "B", "amount": 100.0}', + stream=is_stream, + ) + else: + return make_mock_response("Transfer complete.", stream=is_stream) + else: + return make_mock_response("OK", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + agent = _build_hitl_agents(mock_openai) + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + storage = SqliteResumableStorage(tmp_path) + await storage.setup() + assert storage.checkpoint_storage is not None + + runtime_id = "test-no-dup-tool-calls" + scoped_cs = ScopedCheckpointStorage(storage.checkpoint_storage, runtime_id) + + base_runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id=runtime_id, + checkpoint_storage=scoped_cs, + resumable_storage=storage, + ) + + resumable_runtime = UiPathResumableRuntime( + delegate=base_runtime, + storage=storage, + trigger_manager=UiPathResumeTriggerHandler(), + runtime_id=runtime_id, + ) + + chat_bridge = MockChatBridge(auto_approve=True) + chat_runtime = UiPathChatRuntime( + delegate=resumable_runtime, chat_bridge=chat_bridge + ) + + resume_count = [0] + + async def mock_wait_for_resume(*args: Any, **kwargs: Any) -> None: + resume_count[0] += 1 + if resume_count[0] > MAX_RESUME_CALLS: + raise AssertionError( + f"Infinite loop detected: {resume_count[0]} resumes" + ) + return None + + debug_bridge = _make_debug_bridge() + # Breakpoints on ALL nodes — maximizes checkpoint replays + cast(Mock, debug_bridge.get_breakpoints).return_value = "*" + cast( + AsyncMock, debug_bridge.wait_for_resume + ).side_effect = mock_wait_for_resume + + debug_runtime = UiPathDebugRuntime( + delegate=chat_runtime, debug_bridge=debug_bridge + ) + + # ---- Stream and collect all events ---- + all_events: list[UiPathRuntimeEvent] = [] + async for event in debug_runtime.stream({"messages": []}): + all_events.append(event) + + # ---- Step 1: Verify execution completed successfully ---- + results = [e for e in all_events if isinstance(e, UiPathRuntimeResult)] + assert len(results) >= 1, "No UiPathRuntimeResult events found" + final_result = results[-1] + assert final_result.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Expected SUCCESSFUL, got {final_result.status}. " + f"Total results: {len(results)}, " + f"Resumes: {resume_count[0]}" + ) + + # ---- Step 2: Verify breakpoints fired ---- + bp_hit_count = cast(AsyncMock, debug_bridge.emit_breakpoint_hit).await_count + assert bp_hit_count >= 1, ( + f"Expected at least 1 breakpoint hit with breakpoints='*', " + f"got {bp_hit_count}" + ) + + # ---- Step 3: Verify HITL was handled ---- + assert len(chat_bridge.interrupts) >= 1, ( + f"Expected at least 1 HITL interrupt for transfer_funds, " + f"got {len(chat_bridge.interrupts)}" + ) + + # ---- Step 4: Verify LLM calls ---- + assert call_count.get("triage", 0) >= 1, ( + "Triage LLM should have been called at least once" + ) + assert call_count.get("billing", 0) >= 1, ( + "Billing LLM should have been called at least once" + ) + + # ---- Step 5: Count ToolCallStart events per tool name ---- + tool_call_starts: dict[str, int] = {} + tool_call_ends: dict[str, int] = {} + for event in all_events: + if not isinstance(event, UiPathRuntimeMessageEvent): + continue + payload = event.payload + if not hasattr(payload, "tool_call") or not payload.tool_call: + continue + tc = payload.tool_call + if hasattr(tc, "start") and tc.start: + name = tc.start.tool_name + tool_call_starts[name] = tool_call_starts.get(name, 0) + 1 + if hasattr(tc, "end") and tc.end: + tc_id = tc.tool_call_id or "unknown" + tool_call_ends[tc_id] = tool_call_ends.get(tc_id, 0) + 1 + + # ---- Step 6: Assert no duplicate tool calls ---- + assert "handoff_to_billing_agent" in tool_call_starts, ( + f"handoff_to_billing_agent not found in tool calls. " + f"Tool calls seen: {tool_call_starts}" + ) + assert tool_call_starts["handoff_to_billing_agent"] == 1, ( + f"handoff_to_billing_agent should appear exactly once, " + f"but appeared {tool_call_starts['handoff_to_billing_agent']} " + f"times. All tool calls: {tool_call_starts}. " + f"This means checkpoint replay emitted duplicate events." + ) + + assert "transfer_funds" in tool_call_starts, ( + f"transfer_funds not found in tool calls. " + f"Tool calls seen: {tool_call_starts}" + ) + assert tool_call_starts["transfer_funds"] == 1, ( + f"transfer_funds should appear exactly once, " + f"but appeared {tool_call_starts['transfer_funds']} times. " + f"All tool calls: {tool_call_starts}." + ) + + # ---- Step 7: Verify state events are well-formed ---- + state_events = [ + e for e in all_events if isinstance(e, UiPathRuntimeStateEvent) + ] + node_phases: dict[str, list[str]] = {} + for se in state_events: + if se.node_name: + node_phases.setdefault(se.node_name, []).append(se.phase.value) + # The top-level agent node must have started and completed + agent_node = "customer_support" + assert agent_node in node_phases, ( + f"Top-level agent '{agent_node}' not found in state events. " + f"Nodes: {list(node_phases.keys())}" + ) + assert "started" in node_phases[agent_node], ( + "Agent node missing STARTED phase" + ) + assert "completed" in node_phases[agent_node], ( + f"Agent node missing COMPLETED phase. Phases: {node_phases[agent_node]}" + ) + finally: + await storage.dispose() + os.unlink(tmp_path) diff --git a/packages/uipath-agent-framework/tests/test_streaming.py b/packages/uipath-agent-framework/tests/test_streaming.py index ac7e890..6c29267 100644 --- a/packages/uipath-agent-framework/tests/test_streaming.py +++ b/packages/uipath-agent-framework/tests/test_streaming.py @@ -640,6 +640,7 @@ async def test_checkpoint_storage_passed_to_workflow_run_stream(self): agent = WorkflowAgent(workflow=workflow, name="chat_wf") mock_checkpoint_storage = MagicMock() + mock_checkpoint_storage.get_latest = AsyncMock(return_value=None) captured_kwargs: list[dict[str, Any]] = [] def mock_run(**kwargs): @@ -679,6 +680,7 @@ async def test_checkpoint_storage_passed_to_workflow_run_execute(self): agent = WorkflowAgent(workflow=workflow, name="exec_wf") mock_checkpoint_storage = MagicMock() + mock_checkpoint_storage.get_latest = AsyncMock(return_value=None) captured_kwargs: list[dict[str, Any]] = [] async def mock_run(**kwargs):