diff --git a/packages/uipath-agent-framework/pyproject.toml b/packages/uipath-agent-framework/pyproject.toml index e25bf84..fa0034a 100644 --- a/packages/uipath-agent-framework/pyproject.toml +++ b/packages/uipath-agent-framework/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "uipath-agent-framework" -version = "0.0.6" +version = "0.0.7" description = "Python SDK that enables developers to build and deploy Microsoft Agent Framework agents to the UiPath Cloud Platform" readme = "README.md" requires-python = ">=3.11" diff --git a/packages/uipath-agent-framework/samples/hitl-workflow/main.py b/packages/uipath-agent-framework/samples/hitl-workflow/main.py index 2dc1d91..7feb89a 100644 --- a/packages/uipath-agent-framework/samples/hitl-workflow/main.py +++ b/packages/uipath-agent-framework/samples/hitl-workflow/main.py @@ -54,7 +54,7 @@ def issue_refund(order_id: str, amount: float, reason: str) -> str: return f"Refund of ${amount:.2f} issued for order {order_id}: {reason}" -client = UiPathOpenAIChatClient(model="gpt-5-mini-2025-08-07") +client = UiPathOpenAIChatClient() triage = client.as_agent( name="triage", diff --git a/packages/uipath-agent-framework/samples/hitl-workflow/pyproject.toml b/packages/uipath-agent-framework/samples/hitl-workflow/pyproject.toml index 7b9e2d2..084a923 100644 --- a/packages/uipath-agent-framework/samples/hitl-workflow/pyproject.toml +++ b/packages/uipath-agent-framework/samples/hitl-workflow/pyproject.toml @@ -18,4 +18,3 @@ dev = [ [tool.uv] prerelease = "allow" - diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/chat/openai.py b/packages/uipath-agent-framework/src/uipath_agent_framework/chat/openai.py index 6a4ca6c..7861b27 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/chat/openai.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/chat/openai.py @@ -67,7 +67,7 @@ class UiPathOpenAIChatClient(OpenAIChatClient): from uipath_agent_framework.chat import UiPathOpenAIChatClient - client = UiPathOpenAIChatClient(model="gpt-4o-mini") + client = UiPathOpenAIChatClient(model="gpt-4.1-mini-2025-04-14") agent = client.as_agent( name="assistant", instructions="You are a helpful assistant.", @@ -75,7 +75,7 @@ class UiPathOpenAIChatClient(OpenAIChatClient): ) """ - def __init__(self, model: str = "gpt-4o-mini", **kwargs: Any): + def __init__(self, model: str = "gpt-4.1-mini-2025-04-14", **kwargs: Any): uipath_url, token = get_uipath_config() gateway_url = build_gateway_url("openai", model, uipath_url) diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/messages.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/messages.py index e034119..37b0516 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/messages.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/messages.py @@ -184,12 +184,20 @@ def map_streaming_content( def close_message(self) -> list[UiPathConversationMessageEvent]: """Close the current message if open. Safety net for end of stream.""" + events: list[UiPathConversationMessageEvent] = [] + # Emit ToolCallEnd for any tool calls that were started but never + # completed (e.g. HITL suspension interrupted before function_result). + if self._pending_tool_calls: + for tool_call_id, message_id in self._pending_tool_calls.items(): + events.append( + self._make_tool_call_end_event(message_id, tool_call_id, {}) + ) + self._pending_tool_calls.clear() if self._message_started and self._current_message_id: - events = [self._make_message_end_event(self._current_message_id)] + events.append(self._make_message_end_event(self._current_message_id)) self._message_started = False self._current_message_id = None - return events - return [] + return events @staticmethod def _extract_text_from_content(content: Content) -> str: diff --git a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py index a2d373d..9595c84 100644 --- a/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py +++ b/packages/uipath-agent-framework/src/uipath_agent_framework/runtime/runtime.py @@ -6,6 +6,7 @@ from agent_framework import ( AgentExecutor, + AgentExecutorResponse, AgentResponse, AgentResponseUpdate, AgentSession, @@ -67,6 +68,12 @@ def __init__( self._last_breakpoint_node: str | None = None self._last_checkpoint_id: str | None = None self._resumed_from_checkpoint_id: str | None = None + # Track tool nodes that emitted STARTED but not yet COMPLETED. + # Persists across _stream_workflow() calls (same runtime instance + # reused by UiPathChatRuntime's while loop), allowing us to emit + # synthetic COMPLETED events on HITL resume when the framework + # doesn't surface function_result in output/executor_completed. + self._pending_tool_nodes: set[str] = set() # ------------------------------------------------------------------ # Checkpoint helpers @@ -92,11 +99,14 @@ async def _save_breakpoint_state( to run before re-arming its breakpoint. The count is incremented every time the same executor hits a breakpoint again (cyclic graphs, GroupChat orchestrators). + + checkpoint_id may be None when the breakpoint fired before any + new checkpoint was created (e.g. the very first executor of a + fresh turn). In that case the resume path will replay from + original_input with skip counts instead of restoring a checkpoint. """ if not self._resumable_storage: return - if checkpoint_id is None: - checkpoint_id = await self._get_latest_checkpoint_id() state = { "skip_nodes": dict(self._breakpoint_skip_nodes), "last_breakpoint_node": self._last_breakpoint_node, @@ -171,6 +181,37 @@ def _apply_session_to_executors(self, session: AgentSession) -> None: if isinstance(executor, AgentExecutor): executor._session = session + def _get_session_from_executors(self) -> AgentSession | None: + """Extract the most complete session from AgentExecutors in the workflow. + + After checkpoint restore each executor receives its own independent + session copy (unlike fresh runs where all executors share one object). + Only the executor that processed the HITL/breakpoint response will + have the updated conversation history. We return the session with the + most messages to ensure the complete history is persisted. + """ + workflow = self.agent.workflow + best_session: AgentSession | None = None + best_msg_count = -1 + for executor in workflow.executors.values(): + if isinstance(executor, AgentExecutor) and executor._session is not None: + msg_count = self._count_session_messages(executor._session) + if msg_count > best_msg_count: + best_msg_count = msg_count + best_session = executor._session + return best_session + + @staticmethod + def _count_session_messages(session: AgentSession) -> int: + """Count total messages across all provider keys in a session's state.""" + count = 0 + for value in session.state.values(): + if isinstance(value, dict) and "messages" in value: + messages = value["messages"] + if isinstance(messages, list): + count += len(messages) + return count + # ------------------------------------------------------------------ # HITL helpers (tool approval flow) # ------------------------------------------------------------------ @@ -270,13 +311,31 @@ async def execute( workflow = self.agent.workflow + # Capture the latest checkpoint BEFORE workflow.run() so we can + # detect whether a NEW checkpoint was created during this execution. + # Without this, breakpoints that fire before any new checkpoint + # (e.g. the first executor of turn 2) would save a stale + # checkpoint from the previous turn, causing the resume to + # restore completed state instead of replaying from input. + baseline_checkpoint_id = await self._get_latest_checkpoint_id() + if is_resuming and input is not None: # HITL resume: checkpoint restores executor state (including session) self._resume_responses = await self._convert_resume_responses(input) - # Inject breakpoints (no skip needed for HITL resume) + # Inject breakpoints with accumulated skip counts so that + # breakpoints don't re-fire on the same executor after HITL + # approval (prevents breakpoint→HITL→breakpoint loop). if options and options.breakpoints: - inject_breakpoint_middleware(self.agent, options.breakpoints) + await self._load_breakpoint_state() + inject_breakpoint_middleware( + self.agent, + options.breakpoints, + self._get_breakpoint_skip(), + ) + # _load_breakpoint_state sets _last_checkpoint_id as a + # side effect. Clear it so it doesn't contaminate later runs. + self._last_checkpoint_id = None if self._resume_responses: checkpoint_id = await self._get_latest_checkpoint_id() @@ -332,6 +391,12 @@ async def execute( checkpoint_storage=self._checkpoint_storage, ) + # After resume paths the checkpoint restores the session into + # executors directly, so the local ``session`` is still None. + # Extract it so it can be persisted after completion. + if session is None: + session = self._get_session_from_executors() + # Check for HITL suspension (framework's request_info mechanism) request_info_events = result.get_request_info_events() hitl_requests = { @@ -375,8 +440,21 @@ async def execute( ) self._last_breakpoint_node = node_id original_input = self._prepare_input(input) if not is_resuming else "" + # Only save checkpoint_id if it was created during THIS run. + # If latest == baseline, no new checkpoint was created (e.g. + # breakpoint on the first executor of a fresh turn) — save + # the checkpoint we resumed from (if any) so we don't lose + # it and replay from scratch on the next resume. + # For fresh turns _resumed_from_checkpoint_id is None, which + # correctly prevents using a stale checkpoint from the + # previous turn. + effective_checkpoint = ( + latest_checkpoint + if latest_checkpoint != baseline_checkpoint_id + else self._resumed_from_checkpoint_id + ) await self._save_breakpoint_state( - original_input, checkpoint_id=latest_checkpoint + original_input, checkpoint_id=effective_checkpoint ) return create_breakpoint_result(e) return self._create_suspended_result(e) @@ -400,9 +478,20 @@ async def stream( self._resume_responses = await self._convert_resume_responses(input) user_input = self._prepare_input(None) - # Inject breakpoints (no skip needed for HITL resume) + # Inject breakpoints with accumulated skip counts so that + # breakpoints don't re-fire on the same executor after HITL + # approval (prevents breakpoint→HITL→breakpoint loop). if options and options.breakpoints: - inject_breakpoint_middleware(self.agent, options.breakpoints) + await self._load_breakpoint_state() + inject_breakpoint_middleware( + self.agent, + options.breakpoints, + self._get_breakpoint_skip(), + ) + # _load_breakpoint_state sets _last_checkpoint_id as a + # side effect. Clear it so _stream_workflow doesn't + # mistake a subsequent fresh run for a breakpoint resume. + self._last_checkpoint_id = None elif is_resuming: # Breakpoint resume: restore original_input and session @@ -421,8 +510,9 @@ async def stream( ) else: - # Fresh run + # Fresh run — clear stale resume state from previous turns self._resume_responses = None + self._last_checkpoint_id = None user_input = self._prepare_input(input) # Load session for multi-turn conversation history @@ -462,6 +552,23 @@ async def _stream_workflow( phase=UiPathRuntimeStatePhase.STARTED, ) + # On HITL resume, emit COMPLETED for tool nodes that were left + # pending when the previous stream suspended. The framework + # doesn't surface function_result in output/executor_completed + # for handoff workflows, so we synthesize these events here. + if is_resuming and self._pending_tool_nodes: + for tool_node in list(self._pending_tool_nodes): + yield UiPathRuntimeStateEvent( + payload={}, + node_name=tool_node, + phase=UiPathRuntimeStatePhase.COMPLETED, + ) + self._pending_tool_nodes.clear() + + # Capture the latest checkpoint BEFORE workflow.run() so we can + # detect whether a NEW checkpoint was created during this execution. + baseline_checkpoint_id = await self._get_latest_checkpoint_id() + # Choose workflow.run() mode based on resume type if self._resume_responses: # HITL resume: pass responses to workflow with checkpoint @@ -495,10 +602,13 @@ async def _stream_workflow( request_info_map: dict[str, Any] = {} is_suspended = False - # Track executors whose tool events were emitted via output events. - # When the workflow filters output events (e.g. GroupChat), tool events - # are instead extracted from executor_completed data as a fallback. - executors_with_tool_outputs: set[str] = set() + # Track which tool event phases were emitted per executor via output + # events. When the workflow filters output events (e.g. GroupChat), + # tool events are extracted from executor_completed data as a fallback. + # Tracking phases (not just executor_ids) lets us handle HITL resume + # where function_call (STARTED) is in output but function_result + # (COMPLETED) is only in executor_completed. + executor_tool_phases: dict[str, set[UiPathRuntimeStatePhase]] = {} # Emit an early STARTED event for the start executor so the graph # visualization shows it immediately rather than after it finishes. @@ -534,17 +644,37 @@ async def _stream_workflow( phase=UiPathRuntimeStatePhase.STARTED, ) elif event.type == "executor_completed": - # When output events were filtered by the workflow (e.g. - # GroupChat where participants are not output executors), - # extract tool state events from the completed data instead. - if ( - event.executor_id - and event.executor_id not in executors_with_tool_outputs - ): + # Extract tool state events from executor_completed data, + # skipping phases already emitted via output events. + # This handles three scenarios: + # 1. GroupChat (no output events): emit all from completed + # 2. Normal (both in output): skip all from completed + # 3. HITL resume (only STARTED in output): emit COMPLETED + if event.executor_id: + emitted_phases = executor_tool_phases.get( + event.executor_id, set() + ) for tool_event in self._extract_tool_state_events( event.data, event.executor_id ): - yield tool_event + if tool_event.phase not in emitted_phases: + # Track pending tool nodes + if tool_event.node_name: + if ( + tool_event.phase + == UiPathRuntimeStatePhase.STARTED + ): + self._pending_tool_nodes.add( + tool_event.node_name + ) + elif ( + tool_event.phase + == UiPathRuntimeStatePhase.COMPLETED + ): + self._pending_tool_nodes.discard( + tool_event.node_name + ) + yield tool_event yield UiPathRuntimeStateEvent( payload=self._serialize_event_data( self._filter_completed_data(event.data) @@ -557,9 +687,16 @@ async def _stream_workflow( tool_events = self._extract_tool_state_events( event.data, executor_id ) - if tool_events: - executors_with_tool_outputs.add(executor_id) for tool_event in tool_events: + executor_tool_phases.setdefault(executor_id, set()).add( + tool_event.phase + ) + # Track pending tool nodes across stream iterations + if tool_event.node_name: + if tool_event.phase == UiPathRuntimeStatePhase.STARTED: + self._pending_tool_nodes.add(tool_event.node_name) + elif tool_event.phase == UiPathRuntimeStatePhase.COMPLETED: + self._pending_tool_nodes.discard(tool_event.node_name) yield tool_event for msg_event in self._extract_workflow_messages(event.data): yield UiPathRuntimeMessageEvent(payload=msg_event) @@ -581,6 +718,10 @@ async def _stream_workflow( for msg_event in self.chat.close_message(): yield UiPathRuntimeMessageEvent(payload=msg_event) + # After resume paths the checkpoint restores the session into + # executors directly, so the local ``session`` may still be None. + if session is None: + session = self._get_session_from_executors() if session is not None: await self._save_session(session) @@ -602,8 +743,16 @@ async def _stream_workflow( self._breakpoint_skip_nodes.get(node_id, 0) + 1 ) self._last_breakpoint_node = node_id + # Only save checkpoint_id if it was created during THIS run. + # Fall back to the checkpoint we resumed from (if any) to + # avoid replaying from scratch on the next resume. + effective_checkpoint = ( + latest_checkpoint + if latest_checkpoint != baseline_checkpoint_id + else self._resumed_from_checkpoint_id + ) await self._save_breakpoint_state( - user_input, checkpoint_id=latest_checkpoint + user_input, checkpoint_id=effective_checkpoint ) yield create_breakpoint_result(e) else: @@ -619,6 +768,10 @@ async def _stream_workflow( for msg_event in self.chat.close_message(): yield UiPathRuntimeMessageEvent(payload=msg_event) + # After resume paths the checkpoint restores the session into + # executors directly, so the local ``session`` may still be None. + if session is None: + session = self._get_session_from_executors() if session is not None: await self._save_session(session) @@ -681,7 +834,11 @@ def _extract_tool_state_events( """ contents: list[Any] = [] - if isinstance(data, AgentResponseUpdate): + if isinstance(data, AgentExecutorResponse): + return UiPathAgentFrameworkRuntime._extract_tool_state_events( + data.agent_response, executor_id + ) + elif isinstance(data, AgentResponseUpdate): contents = list(data.contents or []) elif isinstance(data, AgentResponse): for message in data.messages or []: @@ -724,7 +881,9 @@ def _extract_tool_state_events( def _extract_contents(data: Any) -> list[Any]: """Extract Content objects from any workflow data type.""" contents: list[Any] = [] - if isinstance(data, AgentResponseUpdate): + if isinstance(data, AgentExecutorResponse): + return UiPathAgentFrameworkRuntime._extract_contents(data.agent_response) + elif isinstance(data, AgentResponseUpdate): contents = list(data.contents or []) elif isinstance(data, AgentResponse): for message in data.messages or []: diff --git a/packages/uipath-agent-framework/tests/test_hitl_e2e.py b/packages/uipath-agent-framework/tests/test_hitl_e2e.py index 775f983..9887b8b 100644 --- a/packages/uipath-agent-framework/tests/test_hitl_e2e.py +++ b/packages/uipath-agent-framework/tests/test_hitl_e2e.py @@ -15,10 +15,11 @@ """ import asyncio +import json import os import tempfile -from typing import Any -from unittest.mock import AsyncMock, MagicMock +from typing import Any, cast +from unittest.mock import AsyncMock, MagicMock, Mock import pytest from agent_framework.openai import OpenAIChatClient @@ -32,7 +33,16 @@ from uipath.platform.resume_triggers import UiPathResumeTriggerHandler from uipath.runtime import UiPathResumableRuntime from uipath.runtime.chat.runtime import UiPathChatRuntime -from uipath.runtime.events import UiPathRuntimeEvent +from uipath.runtime.debug import ( + UiPathDebugProtocol, + UiPathDebugRuntime, +) +from uipath.runtime.events import ( + UiPathRuntimeEvent, + UiPathRuntimeStateEvent, + UiPathRuntimeStatePhase, +) +from uipath.runtime.events.state import UiPathRuntimeMessageEvent from uipath.runtime.result import UiPathRuntimeResult, UiPathRuntimeStatus from uipath.runtime.resumable.trigger import ( UiPathResumeTrigger, @@ -572,3 +582,988 @@ async def mock_create(**kwargs: Any): finally: await storage.dispose() os.unlink(tmp_path) + + async def test_multi_turn_after_hitl_resume(self): + """Second fresh turn after HITL resume should not fail with stale session. + + Reproduces the bug where the session saved after HITL resume was stale + (missing tool results), causing OpenAI to reject the next turn with: + "An assistant message with 'tool_calls' must be followed by tool messages" + + Flow: + Turn 1: triage -> billing -> transfer_funds (HITL approve) -> complete + Turn 2: triage -> text response -> complete (must not fail) + """ + call_count: dict[str, int] = {} + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = extract_system_text(messages) + + if "billing" in system_msg.lower(): + count = call_count.get("billing", 0) + call_count["billing"] = count + 1 + if count == 0: + return make_tool_call_response( + "transfer_funds", + arguments='{"from_account": "A", "to_account": "B", "amount": 100.0}', + stream=is_stream, + ) + else: + return make_mock_response("Transfer complete.", stream=is_stream) + elif "triage" in system_msg.lower(): + count = call_count.get("triage", 0) + call_count["triage"] = count + 1 + if count == 0: + # Turn 1: route to billing + return make_tool_call_response( + "handoff_to_billing_agent", stream=is_stream + ) + else: + # Turn 2: respond directly (no tool call) + return make_mock_response("How else can I help?", stream=is_stream) + else: + return make_mock_response("OK", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + agent = _build_hitl_agents(mock_openai) + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + chat_runtime, chat_bridge, storage = await _create_hitl_runtime_stack( + agent, "test-hitl-multi-turn", tmp_path, auto_approve=True + ) + + # Turn 1: HITL approval flow + result1 = await chat_runtime.execute({"messages": []}) + assert result1.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Turn 1 failed: {result1.status}" + ) + + # Verify session was saved with complete history (including tool results). + # The bug was that only the HITL-suspension session was persisted + # (missing tool results), causing the next turn to fail. + session_data = await storage.get_value( + "test-hitl-multi-turn", "session", "data" + ) + assert session_data is not None, "Session was not saved after HITL resume" + + session_state = session_data.get("state", {}) + # Find all messages in any provider key + all_messages = [] + for _provider_key, provider_data in session_state.items(): + if not isinstance(provider_data, dict): + continue + msgs = provider_data.get("messages", []) + if not isinstance(msgs, list): + continue + all_messages = msgs + break + + assert len(all_messages) > 0, ( + f"No messages found in session state. " + f"State keys: {list(session_state.keys())}. " + f"Full state: {json.dumps(session_data, indent=2, default=str)[:3000]}" + ) + + # Verify: every assistant message with a function_call must + # eventually be followed by a tool message with function_result + # for the same call_id. This is the exact invariant that OpenAI + # enforces and that breaks with stale sessions. + def get_function_call_ids(msg_data: dict[str, Any]) -> list[str]: + """Get call_ids of function_call contents in a message.""" + ids = [] + for c in msg_data.get("contents", []): + if isinstance(c, dict) and c.get("type") == "function_call": + cid = c.get("call_id") + if cid: + ids.append(cid) + return ids + + def get_function_result_ids(msg_data: dict[str, Any]) -> list[str]: + """Get call_ids of function_result contents in a message.""" + ids = [] + for c in msg_data.get("contents", []): + if isinstance(c, dict) and c.get("type") == "function_result": + cid = c.get("call_id") + if cid: + ids.append(cid) + return ids + + pending_call_ids: set[str] = set() + for msg in all_messages: + if not isinstance(msg, dict): + continue + pending_call_ids.update(get_function_call_ids(msg)) + for rid in get_function_result_ids(msg): + pending_call_ids.discard(rid) + + assert len(pending_call_ids) == 0, ( + f"Session has orphaned function_calls (no function_result): " + f"{pending_call_ids}. This will cause OpenAI to reject the next " + f"turn. Messages:\n" + + json.dumps(all_messages, indent=2, default=str)[:5000] + ) + + # Turn 2: fresh turn after HITL — must not fail with stale session + chat_bridge.interrupts.clear() + result2 = await chat_runtime.execute({"messages": []}) + assert result2.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Turn 2 failed: {result2.status}" + ) + finally: + await storage.dispose() + os.unlink(tmp_path) + + async def test_multi_turn_after_hitl_resume_streaming(self): + """Streaming: second fresh turn after HITL resume should not fail. + + Same as test_multi_turn_after_hitl_resume but using the streaming path. + """ + call_count: dict[str, int] = {} + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = extract_system_text(messages) + + if "billing" in system_msg.lower(): + count = call_count.get("billing", 0) + call_count["billing"] = count + 1 + if count == 0: + return make_tool_call_response( + "transfer_funds", + arguments='{"from_account": "X", "to_account": "Y", "amount": 200.0}', + stream=is_stream, + ) + else: + return make_mock_response("Transfer done.", stream=is_stream) + elif "triage" in system_msg.lower(): + count = call_count.get("triage", 0) + call_count["triage"] = count + 1 + if count == 0: + return make_tool_call_response( + "handoff_to_billing_agent", stream=is_stream + ) + else: + return make_mock_response("Anything else?", stream=is_stream) + else: + return make_mock_response("OK", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + agent = _build_hitl_agents(mock_openai) + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + chat_runtime, chat_bridge, storage = await _create_hitl_runtime_stack( + agent, "test-hitl-stream-multi", tmp_path, auto_approve=True + ) + + # Turn 1: streaming HITL flow + events1: list[UiPathRuntimeEvent] = [] + async for event in chat_runtime.stream({"messages": []}): + events1.append(event) + results1 = [e for e in events1 if isinstance(e, UiPathRuntimeResult)] + assert len(results1) >= 1 + assert results1[-1].status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Turn 1 failed: {results1[-1].status}" + ) + + # Turn 2: fresh streaming turn after HITL + chat_bridge.interrupts.clear() + events2: list[UiPathRuntimeEvent] = [] + async for event in chat_runtime.stream({"messages": []}): + events2.append(event) + results2 = [e for e in events2 if isinstance(e, UiPathRuntimeResult)] + assert len(results2) >= 1 + assert results2[-1].status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Turn 2 failed: {results2[-1].status}" + ) + finally: + await storage.dispose() + os.unlink(tmp_path) + + async def test_tool_node_completed_after_hitl_resume(self): + """Tool node should emit both STARTED and COMPLETED state events across HITL. + + Reproduces the bug where billing_agent_tools emitted STARTED (before HITL + suspension) but never COMPLETED (after HITL resume), because the framework + doesn't surface function_result in output/executor_completed events for + handoff workflows. The fix synthesizes COMPLETED at the start of resume. + + Expected event sequence: + customer_support STARTED + triage STARTED + triage COMPLETED + billing_agent STARTED + billing_agent_tools STARTED <-- before HITL suspension + [HITL interrupt + resume] + billing_agent_tools COMPLETED <-- synthesized on resume + billing_agent STARTED <-- resume executor + billing_agent COMPLETED + customer_support COMPLETED + """ + call_count: dict[str, int] = {} + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = extract_system_text(messages) + + if "billing" in system_msg.lower(): + count = call_count.get("billing", 0) + call_count["billing"] = count + 1 + if count == 0: + return make_tool_call_response( + "transfer_funds", + arguments='{"from_account": "A", "to_account": "B", "amount": 100.0}', + stream=is_stream, + ) + else: + return make_mock_response("Transfer complete.", stream=is_stream) + elif "triage" in system_msg.lower(): + return make_tool_call_response( + "handoff_to_billing_agent", stream=is_stream + ) + else: + return make_mock_response("OK", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + agent = _build_hitl_agents(mock_openai) + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + chat_runtime, chat_bridge, storage = await _create_hitl_runtime_stack( + agent, "test-hitl-tool-events", tmp_path, auto_approve=True + ) + + # Collect all events from streaming + events: list[UiPathRuntimeEvent] = [] + async for event in chat_runtime.stream({"messages": []}): + events.append(event) + + # Extract state events for analysis + state_events = [e for e in events if isinstance(e, UiPathRuntimeStateEvent)] + + # Build a summary: node_name -> list of phases + node_phases: dict[str, list[str]] = {} + for se in state_events: + if se.node_name: + node_phases.setdefault(se.node_name, []).append(se.phase.value) + + # billing_agent_tools MUST have both STARTED and COMPLETED + tools_node = "billing_agent_tools" + assert tools_node in node_phases, ( + f"{tools_node} not found in state events. " + f"Nodes seen: {list(node_phases.keys())}" + ) + tools_phases = node_phases[tools_node] + assert UiPathRuntimeStatePhase.STARTED.value in tools_phases, ( + f"{tools_node} missing STARTED. Phases: {tools_phases}" + ) + assert UiPathRuntimeStatePhase.COMPLETED.value in tools_phases, ( + f"{tools_node} missing COMPLETED after HITL resume. " + f"Phases: {tools_phases}. " + f"All state events: " + + ", ".join( + f"{e.node_name}:{e.phase.value}" + for e in state_events + if e.node_name + ) + ) + + # Verify COMPLETED comes after STARTED + started_idx = tools_phases.index(UiPathRuntimeStatePhase.STARTED.value) + completed_idx = tools_phases.index(UiPathRuntimeStatePhase.COMPLETED.value) + assert completed_idx > started_idx, ( + f"{tools_node} COMPLETED ({completed_idx}) should come after " + f"STARTED ({started_idx})" + ) + + # Final result should be successful + results = [e for e in events if isinstance(e, UiPathRuntimeResult)] + assert len(results) >= 1 + assert results[-1].status == UiPathRuntimeStatus.SUCCESSFUL + finally: + await storage.dispose() + os.unlink(tmp_path) + + +# --------------------------------------------------------------------------- +# Debug bridge mock +# --------------------------------------------------------------------------- + + +def _make_debug_bridge(**overrides: Any) -> UiPathDebugProtocol: + """Create a mock debug bridge with sensible defaults.""" + bridge: Mock = Mock(spec=UiPathDebugProtocol) + bridge.connect = AsyncMock() + bridge.disconnect = AsyncMock() + bridge.emit_execution_started = AsyncMock() + bridge.emit_execution_completed = AsyncMock() + bridge.emit_execution_error = AsyncMock() + bridge.emit_execution_suspended = AsyncMock() + bridge.emit_breakpoint_hit = AsyncMock() + bridge.emit_state_update = AsyncMock() + bridge.emit_execution_resumed = AsyncMock() + bridge.wait_for_resume = AsyncMock(return_value=None) + bridge.wait_for_terminate = AsyncMock() + bridge.get_breakpoints = Mock(return_value=[]) + for k, v in overrides.items(): + setattr(bridge, k, v) + return cast(UiPathDebugProtocol, bridge) + + +# --------------------------------------------------------------------------- +# Breakpoint + HITL combined tests +# --------------------------------------------------------------------------- + +# Safety limit: if the debug loop exceeds this many resume calls, +# the test fails — this means breakpoints are stuck in a loop. +MAX_RESUME_CALLS = 20 + + +@pytest.mark.asyncio(loop_scope="class") +class TestBreakpointAndHitlCombined: + """Tests for the interaction between breakpoints and HITL tool approval. + + Uses the full runtime stack: + UiPathDebugRuntime -> UiPathChatRuntime -> UiPathResumableRuntime + -> UiPathAgentFrameworkRuntime + + Reproduces the bug where breakpoint + @requires_approval on the same + node causes an infinite loop: breakpoint → continue → HITL → approve → + breakpoint (again!) → continue → HITL → ... + """ + + @pytest.fixture(autouse=True) + async def _settle_framework(self): + """Allow framework background tasks to complete between tests.""" + yield + await asyncio.sleep(0.2) + + async def test_breakpoint_then_hitl_does_not_loop(self): + """Breakpoint + HITL on same node must complete without looping. + + Flow: + 1. triage → billing_agent (breakpoint fires before billing_agent) + 2. User continues breakpoint + 3. billing_agent runs, calls transfer_funds (HITL fires) + 4. User approves tool via chat bridge + 5. Execution completes — NO more breakpoints + + Without the fix, step 5 would trigger another breakpoint on + billing_agent, creating an infinite breakpoint→HITL→breakpoint loop. + """ + call_count: dict[str, int] = {} + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = extract_system_text(messages) + + if "billing" in system_msg.lower(): + count = call_count.get("billing", 0) + call_count["billing"] = count + 1 + if count == 0: + return make_tool_call_response( + "transfer_funds", + arguments='{"from_account": "A", "to_account": "B", "amount": 100.0}', + stream=is_stream, + ) + else: + return make_mock_response("Transfer complete.", stream=is_stream) + elif "triage" in system_msg.lower(): + return make_tool_call_response( + "handoff_to_billing_agent", stream=is_stream + ) + else: + return make_mock_response("OK", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + agent = _build_hitl_agents(mock_openai) + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + storage = SqliteResumableStorage(tmp_path) + await storage.setup() + assert storage.checkpoint_storage is not None + + runtime_id = "test-bp-hitl-loop" + scoped_cs = ScopedCheckpointStorage(storage.checkpoint_storage, runtime_id) + + base_runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id=runtime_id, + checkpoint_storage=scoped_cs, + resumable_storage=storage, + ) + base_runtime.chat = MagicMock() + base_runtime.chat.map_messages_to_input.return_value = ( + "Transfer $100 from A to B" + ) + base_runtime.chat.map_streaming_content.return_value = [] + base_runtime.chat.close_message.return_value = [] + + resumable_runtime = UiPathResumableRuntime( + delegate=base_runtime, + storage=storage, + trigger_manager=UiPathResumeTriggerHandler(), + runtime_id=runtime_id, + ) + + chat_bridge = MockChatBridge(auto_approve=True) + chat_runtime = UiPathChatRuntime( + delegate=resumable_runtime, chat_bridge=chat_bridge + ) + + # Debug bridge: breakpoint on billing_agent + resume_count = [0] + + async def mock_wait_for_resume(*args: Any, **kwargs: Any) -> None: + resume_count[0] += 1 + if resume_count[0] > MAX_RESUME_CALLS: + raise AssertionError( + f"Infinite loop detected: {resume_count[0]} resumes. " + f"Breakpoint hits: " + f"{cast(AsyncMock, debug_bridge.emit_breakpoint_hit).await_count}" + ) + return None + + debug_bridge = _make_debug_bridge() + cast(Mock, debug_bridge.get_breakpoints).return_value = ["billing_agent"] + cast( + AsyncMock, debug_bridge.wait_for_resume + ).side_effect = mock_wait_for_resume + + debug_runtime = UiPathDebugRuntime( + delegate=chat_runtime, debug_bridge=debug_bridge + ) + + # Execute the full flow + result = await debug_runtime.execute({"messages": []}) + + bp_count = cast(AsyncMock, debug_bridge.emit_breakpoint_hit).await_count + + # Must complete successfully (not loop forever) + assert result.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Expected SUCCESSFUL, got {result.status}. " + f"Resumes: {resume_count[0]}, Breakpoint hits: {bp_count}" + ) + + # Breakpoint should have been hit at least once. + # The exact count depends on how many times the HandoffBuilder + # calls execute() during checkpoint restore, but it MUST NOT + # loop forever (which is caught by MAX_RESUME_CALLS). + assert bp_count >= 1, f"Expected at least 1 breakpoint hit, got {bp_count}" + + # HITL should have been handled by chat bridge + assert len(chat_bridge.interrupts) >= 1, ( + f"Expected at least 1 HITL interrupt, got {len(chat_bridge.interrupts)}" + ) + finally: + await storage.dispose() + os.unlink(tmp_path) + + async def test_breakpoint_on_all_nodes_with_hitl(self): + """Breakpoints on ALL nodes + HITL on billing_agent must complete. + + Same as above but with breakpoints="*", which means triage also + gets a breakpoint. Verifies the full sequence: + 1. Breakpoint on triage → continue + 2. triage hands off to billing_agent + 3. Breakpoint on billing_agent → continue + 4. billing_agent calls transfer_funds → HITL → approve + 5. Execution completes + """ + call_count: dict[str, int] = {} + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = extract_system_text(messages) + + if "billing" in system_msg.lower(): + count = call_count.get("billing", 0) + call_count["billing"] = count + 1 + if count == 0: + return make_tool_call_response( + "transfer_funds", + arguments='{"from_account": "X", "to_account": "Y", "amount": 50.0}', + stream=is_stream, + ) + else: + return make_mock_response("Transfer done.", stream=is_stream) + elif "triage" in system_msg.lower(): + return make_tool_call_response( + "handoff_to_billing_agent", stream=is_stream + ) + else: + return make_mock_response("OK", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + agent = _build_hitl_agents(mock_openai) + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + storage = SqliteResumableStorage(tmp_path) + await storage.setup() + assert storage.checkpoint_storage is not None + + runtime_id = "test-bp-all-hitl" + scoped_cs = ScopedCheckpointStorage(storage.checkpoint_storage, runtime_id) + + base_runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id=runtime_id, + checkpoint_storage=scoped_cs, + resumable_storage=storage, + ) + base_runtime.chat = MagicMock() + base_runtime.chat.map_messages_to_input.return_value = ( + "Transfer $50 from X to Y" + ) + base_runtime.chat.map_streaming_content.return_value = [] + base_runtime.chat.close_message.return_value = [] + + resumable_runtime = UiPathResumableRuntime( + delegate=base_runtime, + storage=storage, + trigger_manager=UiPathResumeTriggerHandler(), + runtime_id=runtime_id, + ) + + chat_bridge = MockChatBridge(auto_approve=True) + chat_runtime = UiPathChatRuntime( + delegate=resumable_runtime, chat_bridge=chat_bridge + ) + + resume_count = [0] + + async def mock_wait_for_resume(*args: Any, **kwargs: Any) -> None: + resume_count[0] += 1 + if resume_count[0] > MAX_RESUME_CALLS: + raise AssertionError( + f"Infinite loop detected: {resume_count[0]} resumes. " + f"Breakpoint hits: " + f"{cast(AsyncMock, debug_bridge.emit_breakpoint_hit).await_count}" + ) + return None + + debug_bridge = _make_debug_bridge() + cast(Mock, debug_bridge.get_breakpoints).return_value = "*" + cast( + AsyncMock, debug_bridge.wait_for_resume + ).side_effect = mock_wait_for_resume + + debug_runtime = UiPathDebugRuntime( + delegate=chat_runtime, debug_bridge=debug_bridge + ) + + result = await debug_runtime.execute({"messages": []}) + + bp_count = cast(AsyncMock, debug_bridge.emit_breakpoint_hit).await_count + + assert result.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Expected SUCCESSFUL, got {result.status}. " + f"Resumes: {resume_count[0]}, Breakpoint hits: {bp_count}" + ) + + # At least 1 breakpoint should have been hit + assert bp_count >= 1, f"Expected at least 1 breakpoint hit, got {bp_count}" + + # HITL should have been handled (1 interrupt for transfer_funds) + assert len(chat_bridge.interrupts) == 1, ( + f"Expected 1 HITL interrupt, got {len(chat_bridge.interrupts)}" + ) + finally: + await storage.dispose() + os.unlink(tmp_path) + + async def test_second_turn_after_breakpoint_and_hitl(self): + """Second turn after breakpoint+HITL must actually run the workflow. + + Reproduces the bug where a breakpoint on the first executor of + turn 2 saved a stale checkpoint_id from turn 1 (because no new + checkpoint existed yet). On resume, workflow.run(checkpoint_id=stale) + restored turn 1's completed state and returned immediately with + no work done — the OTEL trace showed empty workflow.run spans. + + Flow: + Turn 1: breakpoints=* → triage → billing → HITL → approve → complete + Turn 2: breakpoints=* → triage breakpoint → continue → triage + responds "How else can I help?" → complete + + The key assertion: triage's LLM must have been called during turn 2 + (call_count["triage"] increases). Without the fix, the resume + restores turn 1's state and triage is never called. + """ + call_count: dict[str, int] = {} + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = extract_system_text(messages) + + if "billing" in system_msg.lower(): + count = call_count.get("billing", 0) + call_count["billing"] = count + 1 + if count == 0: + return make_tool_call_response( + "transfer_funds", + arguments='{"from_account": "A", "to_account": "B", "amount": 100.0}', + stream=is_stream, + ) + else: + return make_mock_response("Transfer complete.", stream=is_stream) + elif "triage" in system_msg.lower(): + count = call_count.get("triage", 0) + call_count["triage"] = count + 1 + if count == 0: + return make_tool_call_response( + "handoff_to_billing_agent", stream=is_stream + ) + else: + # Turn 2: triage responds directly + return make_mock_response("How else can I help?", stream=is_stream) + else: + return make_mock_response("OK", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + agent = _build_hitl_agents(mock_openai) + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + storage = SqliteResumableStorage(tmp_path) + await storage.setup() + assert storage.checkpoint_storage is not None + + runtime_id = "test-bp-hitl-turn2" + scoped_cs = ScopedCheckpointStorage(storage.checkpoint_storage, runtime_id) + + base_runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id=runtime_id, + checkpoint_storage=scoped_cs, + resumable_storage=storage, + ) + base_runtime.chat = MagicMock() + base_runtime.chat.map_messages_to_input.return_value = ( + "Transfer $100 from A to B" + ) + base_runtime.chat.map_streaming_content.return_value = [] + base_runtime.chat.close_message.return_value = [] + + resumable_runtime = UiPathResumableRuntime( + delegate=base_runtime, + storage=storage, + trigger_manager=UiPathResumeTriggerHandler(), + runtime_id=runtime_id, + ) + + chat_bridge = MockChatBridge(auto_approve=True) + chat_runtime = UiPathChatRuntime( + delegate=resumable_runtime, chat_bridge=chat_bridge + ) + + resume_count = [0] + + async def mock_wait_for_resume(*args: Any, **kwargs: Any) -> None: + resume_count[0] += 1 + if resume_count[0] > MAX_RESUME_CALLS: + raise AssertionError( + f"Infinite loop detected: {resume_count[0]} resumes" + ) + return None + + debug_bridge = _make_debug_bridge() + cast(Mock, debug_bridge.get_breakpoints).return_value = "*" + cast( + AsyncMock, debug_bridge.wait_for_resume + ).side_effect = mock_wait_for_resume + + debug_runtime = UiPathDebugRuntime( + delegate=chat_runtime, debug_bridge=debug_bridge + ) + + # Turn 1: breakpoint + HITL + result1 = await debug_runtime.execute({"messages": []}) + assert result1.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Turn 1 failed: {result1.status}" + ) + + turn1_bp_count = cast( + AsyncMock, debug_bridge.emit_breakpoint_hit + ).await_count + assert turn1_bp_count >= 1, "Turn 1 should have hit at least 1 breakpoint" + + # Record LLM call counts after turn 1 + triage_calls_after_turn1 = call_count.get("triage", 0) + assert triage_calls_after_turn1 >= 1, ( + "Triage should have been called at least once during turn 1" + ) + + # Turn 2: fresh turn — should complete normally with breakpoints + chat_bridge.interrupts.clear() + resume_count[0] = 0 + cast(AsyncMock, debug_bridge.emit_breakpoint_hit).reset_mock() + + result2 = await debug_runtime.execute({"messages": []}) + assert result2.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Turn 2 failed: {result2.status}. " + f"Resumes: {resume_count[0]}, " + f"BPs: {cast(AsyncMock, debug_bridge.emit_breakpoint_hit).await_count}" + ) + + turn2_bp_count = cast( + AsyncMock, debug_bridge.emit_breakpoint_hit + ).await_count + # Turn 2 should also hit breakpoints (fresh run with breakpoints) + assert turn2_bp_count >= 1, ( + f"Turn 2 should have hit at least 1 breakpoint, got {turn2_bp_count}" + ) + + # KEY ASSERTION: triage's LLM was actually called during turn 2. + # Without the stale-checkpoint fix, the resume restores turn 1's + # completed state and triage is never invoked again. + triage_calls_after_turn2 = call_count.get("triage", 0) + assert triage_calls_after_turn2 > triage_calls_after_turn1, ( + f"Turn 2 must invoke triage LLM. " + f"Calls after turn 1: {triage_calls_after_turn1}, " + f"after turn 2: {triage_calls_after_turn2}. " + f"This means the resume used a stale checkpoint from turn 1." + ) + finally: + await storage.dispose() + os.unlink(tmp_path) + + async def test_no_duplicate_tool_calls_on_breakpoint_resume(self): + """Breakpoint resumes must not emit duplicate tool call message events. + + Full E2E streaming test with BOTH breakpoints=* AND HITL approval. + Reproduces the bug where checkpoint restore replays output events for + already-completed executors. The chat mapper processed each replayed + Content(type='function_call') as new, emitting duplicate ToolCallStart + events — visible as repeated 'handoff_to_billing_agent' in the UI. + + Full runtime stack: + UiPathDebugRuntime -> UiPathChatRuntime -> UiPathResumableRuntime + -> UiPathAgentFrameworkRuntime + + Expected flow (breakpoints="*"): + 1. Breakpoint on triage → debug bridge continue + 2. triage runs, calls handoff_to_billing_agent → handoff to billing + 3. Breakpoint on billing_agent → debug bridge continue + 4. billing_agent runs, calls transfer_funds → HITL suspends + 5. Chat bridge auto-approves → billing completes + 6. More breakpoints may fire on fan-out agents → continues + 7. Workflow completes SUCCESSFUL + + Assertions at every step: + - Breakpoints fired (emit_breakpoint_hit called) + - HITL handled (chat_bridge.interrupts) + - handoff_to_billing_agent ToolCallStart appears exactly ONCE + - transfer_funds ToolCallStart appears exactly ONCE + - Final result is SUCCESSFUL + """ + call_count: dict[str, int] = {} + + async def mock_create(**kwargs: Any): + messages = kwargs.get("messages", []) + is_stream = kwargs.get("stream", False) + system_msg = extract_system_text(messages) + + # Check triage BEFORE billing — triage's system prompt contains + # "billing_agent" in the handoff instructions, so a naive + # "billing" check would match both agents. + if "triage" in system_msg.lower(): + call_count["triage"] = call_count.get("triage", 0) + 1 + return make_tool_call_response( + "handoff_to_billing_agent", stream=is_stream + ) + elif "billing" in system_msg.lower(): + count = call_count.get("billing", 0) + call_count["billing"] = count + 1 + if count == 0: + return make_tool_call_response( + "transfer_funds", + arguments='{"from_account": "A", "to_account": "B", "amount": 100.0}', + stream=is_stream, + ) + else: + return make_mock_response("Transfer complete.", stream=is_stream) + else: + return make_mock_response("OK", stream=is_stream) + + mock_openai = AsyncMock() + mock_openai.chat.completions.create = mock_create + + agent = _build_hitl_agents(mock_openai) + + tmp_fd, tmp_path = tempfile.mkstemp(suffix=".db") + os.close(tmp_fd) + try: + storage = SqliteResumableStorage(tmp_path) + await storage.setup() + assert storage.checkpoint_storage is not None + + runtime_id = "test-no-dup-tool-calls" + scoped_cs = ScopedCheckpointStorage(storage.checkpoint_storage, runtime_id) + + base_runtime = UiPathAgentFrameworkRuntime( + agent=agent, + runtime_id=runtime_id, + checkpoint_storage=scoped_cs, + resumable_storage=storage, + ) + + resumable_runtime = UiPathResumableRuntime( + delegate=base_runtime, + storage=storage, + trigger_manager=UiPathResumeTriggerHandler(), + runtime_id=runtime_id, + ) + + chat_bridge = MockChatBridge(auto_approve=True) + chat_runtime = UiPathChatRuntime( + delegate=resumable_runtime, chat_bridge=chat_bridge + ) + + resume_count = [0] + + async def mock_wait_for_resume(*args: Any, **kwargs: Any) -> None: + resume_count[0] += 1 + if resume_count[0] > MAX_RESUME_CALLS: + raise AssertionError( + f"Infinite loop detected: {resume_count[0]} resumes" + ) + return None + + debug_bridge = _make_debug_bridge() + # Breakpoints on ALL nodes — maximizes checkpoint replays + cast(Mock, debug_bridge.get_breakpoints).return_value = "*" + cast( + AsyncMock, debug_bridge.wait_for_resume + ).side_effect = mock_wait_for_resume + + debug_runtime = UiPathDebugRuntime( + delegate=chat_runtime, debug_bridge=debug_bridge + ) + + # ---- Stream and collect all events ---- + all_events: list[UiPathRuntimeEvent] = [] + async for event in debug_runtime.stream({"messages": []}): + all_events.append(event) + + # ---- Step 1: Verify execution completed successfully ---- + results = [e for e in all_events if isinstance(e, UiPathRuntimeResult)] + assert len(results) >= 1, "No UiPathRuntimeResult events found" + final_result = results[-1] + assert final_result.status == UiPathRuntimeStatus.SUCCESSFUL, ( + f"Expected SUCCESSFUL, got {final_result.status}. " + f"Total results: {len(results)}, " + f"Resumes: {resume_count[0]}" + ) + + # ---- Step 2: Verify breakpoints fired ---- + bp_hit_count = cast(AsyncMock, debug_bridge.emit_breakpoint_hit).await_count + assert bp_hit_count >= 1, ( + f"Expected at least 1 breakpoint hit with breakpoints='*', " + f"got {bp_hit_count}" + ) + + # ---- Step 3: Verify HITL was handled ---- + assert len(chat_bridge.interrupts) >= 1, ( + f"Expected at least 1 HITL interrupt for transfer_funds, " + f"got {len(chat_bridge.interrupts)}" + ) + + # ---- Step 4: Verify LLM calls ---- + assert call_count.get("triage", 0) >= 1, ( + "Triage LLM should have been called at least once" + ) + assert call_count.get("billing", 0) >= 1, ( + "Billing LLM should have been called at least once" + ) + + # ---- Step 5: Count ToolCallStart events per tool name ---- + tool_call_starts: dict[str, int] = {} + tool_call_ends: dict[str, int] = {} + for event in all_events: + if not isinstance(event, UiPathRuntimeMessageEvent): + continue + payload = event.payload + if not hasattr(payload, "tool_call") or not payload.tool_call: + continue + tc = payload.tool_call + if hasattr(tc, "start") and tc.start: + name = tc.start.tool_name + tool_call_starts[name] = tool_call_starts.get(name, 0) + 1 + if hasattr(tc, "end") and tc.end: + tc_id = tc.tool_call_id or "unknown" + tool_call_ends[tc_id] = tool_call_ends.get(tc_id, 0) + 1 + + # ---- Step 6: Assert no duplicate tool calls ---- + assert "handoff_to_billing_agent" in tool_call_starts, ( + f"handoff_to_billing_agent not found in tool calls. " + f"Tool calls seen: {tool_call_starts}" + ) + assert tool_call_starts["handoff_to_billing_agent"] == 1, ( + f"handoff_to_billing_agent should appear exactly once, " + f"but appeared {tool_call_starts['handoff_to_billing_agent']} " + f"times. All tool calls: {tool_call_starts}. " + f"This means checkpoint replay emitted duplicate events." + ) + + assert "transfer_funds" in tool_call_starts, ( + f"transfer_funds not found in tool calls. " + f"Tool calls seen: {tool_call_starts}" + ) + assert tool_call_starts["transfer_funds"] == 1, ( + f"transfer_funds should appear exactly once, " + f"but appeared {tool_call_starts['transfer_funds']} times. " + f"All tool calls: {tool_call_starts}." + ) + + # ---- Step 7: Verify state events are well-formed ---- + state_events = [ + e for e in all_events if isinstance(e, UiPathRuntimeStateEvent) + ] + node_phases: dict[str, list[str]] = {} + for se in state_events: + if se.node_name: + node_phases.setdefault(se.node_name, []).append(se.phase.value) + # The top-level agent node must have started and completed + agent_node = "customer_support" + assert agent_node in node_phases, ( + f"Top-level agent '{agent_node}' not found in state events. " + f"Nodes: {list(node_phases.keys())}" + ) + assert "started" in node_phases[agent_node], ( + "Agent node missing STARTED phase" + ) + assert "completed" in node_phases[agent_node], ( + f"Agent node missing COMPLETED phase. Phases: {node_phases[agent_node]}" + ) + finally: + await storage.dispose() + os.unlink(tmp_path) diff --git a/packages/uipath-agent-framework/tests/test_streaming.py b/packages/uipath-agent-framework/tests/test_streaming.py index ac7e890..6c29267 100644 --- a/packages/uipath-agent-framework/tests/test_streaming.py +++ b/packages/uipath-agent-framework/tests/test_streaming.py @@ -640,6 +640,7 @@ async def test_checkpoint_storage_passed_to_workflow_run_stream(self): agent = WorkflowAgent(workflow=workflow, name="chat_wf") mock_checkpoint_storage = MagicMock() + mock_checkpoint_storage.get_latest = AsyncMock(return_value=None) captured_kwargs: list[dict[str, Any]] = [] def mock_run(**kwargs): @@ -679,6 +680,7 @@ async def test_checkpoint_storage_passed_to_workflow_run_execute(self): agent = WorkflowAgent(workflow=workflow, name="exec_wf") mock_checkpoint_storage = MagicMock() + mock_checkpoint_storage.get_latest = AsyncMock(return_value=None) captured_kwargs: list[dict[str, Any]] = [] async def mock_run(**kwargs): diff --git a/packages/uipath-agent-framework/uv.lock b/packages/uipath-agent-framework/uv.lock index dac2194..4a75135 100644 --- a/packages/uipath-agent-framework/uv.lock +++ b/packages/uipath-agent-framework/uv.lock @@ -2460,7 +2460,7 @@ wheels = [ [[package]] name = "uipath-agent-framework" -version = "0.0.6" +version = "0.0.7" source = { editable = "." } dependencies = [ { name = "agent-framework-core" },