Skip to content

Commit 3a5d413

Browse files
committed
fix: add tests for parallel branches with triggers
1 parent 071863c commit 3a5d413

File tree

1 file changed

+317
-13
lines changed

1 file changed

+317
-13
lines changed

tests/runtime/test_resumable.py

Lines changed: 317 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,11 @@
2020
from uipath_langchain.runtime.storage import SqliteResumableStorage
2121

2222

23-
class MockTriggerHandler:
24-
"""Mock implementation of UiPathResumeTriggerHandler."""
23+
class SequentialTriggerHandler:
24+
"""Mock implementation that fires triggers sequentially.
25+
26+
Resolves triggers one at a time across multiple resume calls.
27+
"""
2528

2629
def __init__(self):
2730
self.call_count = 0
@@ -59,9 +62,31 @@ async def read_trigger(self, trigger: UiPathResumeTrigger) -> Any:
5962
)
6063

6164

65+
class ParallelTriggerHandler:
66+
"""Mock implementation that fires all triggers immediately.
67+
68+
Resolves all triggers on the first resume call.
69+
"""
70+
71+
async def create_trigger(self, suspend_value: Any) -> UiPathResumeTrigger:
72+
"""Create a trigger from suspend value."""
73+
trigger = UiPathResumeTrigger(
74+
trigger_type=UiPathResumeTriggerType.API,
75+
trigger_name=UiPathResumeTriggerName.API,
76+
payload=suspend_value,
77+
)
78+
return trigger
79+
80+
async def read_trigger(self, trigger: UiPathResumeTrigger) -> Any:
81+
"""Read trigger and return immediate response."""
82+
assert trigger.payload is not None
83+
branch_name = trigger.payload.get("message", "unknown")
84+
return f"Response for {branch_name}"
85+
86+
6287
@pytest.mark.asyncio
63-
async def test_parallel_branches_with_multiple_interrupts_execution():
64-
"""Test graph execution with parallel branches and multiple interrupts."""
88+
async def test_parallel_branches_with_sequential_trigger_resolution():
89+
"""Test graph execution with parallel branches where triggers resolve sequentially."""
6590

6691
# Define state
6792
class State(TypedDict, total=False):
@@ -110,19 +135,19 @@ def branch_c(state: State) -> State:
110135
# Create base runtime
111136
base_runtime = UiPathLangGraphRuntime(
112137
graph=compiled_graph,
113-
runtime_id="parallel-test",
138+
runtime_id="parallel-sequential-test",
114139
entrypoint="test",
115140
)
116141

117142
# Create storage and trigger manager
118143
storage = SqliteResumableStorage(memory)
119144

120-
# Wrap with UiPathResumableRuntime
145+
# Wrap with UiPathResumableRuntime using sequential trigger handler
121146
runtime = UiPathResumableRuntime(
122147
delegate=base_runtime,
123148
storage=storage,
124-
trigger_manager=MockTriggerHandler(),
125-
runtime_id="parallel-test",
149+
trigger_manager=SequentialTriggerHandler(),
150+
runtime_id="parallel-sequential-test",
126151
)
127152

128153
# First execution - should hit all 3 interrupts
@@ -141,11 +166,11 @@ def branch_c(state: State) -> State:
141166
assert len(result.triggers) == 3
142167

143168
# Verify triggers were saved to storage
144-
saved_triggers = await storage.get_triggers("parallel-test")
169+
saved_triggers = await storage.get_triggers("parallel-sequential-test")
145170
assert saved_triggers is not None
146171
assert len(saved_triggers) == 3
147172

148-
# Resume 1: Resolve only first interrupt (no input, will restore from storage)
173+
# Resume 1: Resolve only first interrupt
149174
result_1 = await runtime.execute(
150175
input=None,
151176
options=UiPathExecuteOptions(resume=True),
@@ -157,7 +182,7 @@ def branch_c(state: State) -> State:
157182
assert len(result_1.triggers) == 2
158183

159184
# Verify only 2 triggers remain in storage
160-
saved_triggers = await storage.get_triggers("parallel-test")
185+
saved_triggers = await storage.get_triggers("parallel-sequential-test")
161186
assert saved_triggers is not None
162187
assert len(saved_triggers) == 2
163188

@@ -173,7 +198,7 @@ def branch_c(state: State) -> State:
173198
assert len(result_2.triggers) == 1
174199

175200
# Verify only 1 trigger remains in storage
176-
saved_triggers = await storage.get_triggers("parallel-test")
201+
saved_triggers = await storage.get_triggers("parallel-sequential-test")
177202
assert saved_triggers is not None
178203
assert len(saved_triggers) == 1
179204

@@ -188,7 +213,7 @@ def branch_c(state: State) -> State:
188213
assert result_3.output is not None
189214

190215
# Verify no triggers remain
191-
saved_triggers = await storage.get_triggers("parallel-test")
216+
saved_triggers = await storage.get_triggers("parallel-sequential-test")
192217
assert saved_triggers is None or len(saved_triggers) == 0
193218

194219
# Verify all branches completed
@@ -200,3 +225,282 @@ def branch_c(state: State) -> State:
200225
finally:
201226
if os.path.exists(temp_db.name):
202227
os.remove(temp_db.name)
228+
229+
230+
@pytest.mark.asyncio
231+
async def test_parallel_branches_with_parallel_trigger_resolution():
232+
"""Test graph execution with parallel branches where all triggers fire immediately."""
233+
234+
# Define state
235+
class State(TypedDict, total=False):
236+
branch_a_result: str | None
237+
branch_b_result: str | None
238+
branch_c_result: str | None
239+
240+
# Define nodes that interrupt
241+
def branch_a(state: State) -> State:
242+
result = interrupt({"message": "Branch A needs input"})
243+
return {"branch_a_result": f"A completed with: {result}"}
244+
245+
def branch_b(state: State) -> State:
246+
result = interrupt({"message": "Branch B needs input"})
247+
return {"branch_b_result": f"B completed with: {result}"}
248+
249+
def branch_c(state: State) -> State:
250+
result = interrupt({"message": "Branch C needs input"})
251+
return {"branch_c_result": f"C completed with: {result}"}
252+
253+
# Build graph with parallel branches
254+
graph = StateGraph(State)
255+
graph.add_node("branch_a", branch_a)
256+
graph.add_node("branch_b", branch_b)
257+
graph.add_node("branch_c", branch_c)
258+
259+
# All branches start in parallel
260+
graph.add_edge(START, "branch_a")
261+
graph.add_edge(START, "branch_b")
262+
graph.add_edge(START, "branch_c")
263+
264+
# All branches go to end
265+
graph.add_edge("branch_a", END)
266+
graph.add_edge("branch_b", END)
267+
graph.add_edge("branch_c", END)
268+
269+
# Create temporary database
270+
temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
271+
temp_db.close()
272+
273+
try:
274+
# Compile graph with checkpointer
275+
async with AsyncSqliteSaver.from_conn_string(temp_db.name) as memory:
276+
compiled_graph = graph.compile(checkpointer=memory)
277+
278+
# Create base runtime
279+
base_runtime = UiPathLangGraphRuntime(
280+
graph=compiled_graph,
281+
runtime_id="parallel-parallel-test",
282+
entrypoint="test",
283+
)
284+
285+
# Create storage and trigger manager
286+
storage = SqliteResumableStorage(memory)
287+
288+
# Wrap with UiPathResumableRuntime using parallel trigger handler
289+
runtime = UiPathResumableRuntime(
290+
delegate=base_runtime,
291+
storage=storage,
292+
trigger_manager=ParallelTriggerHandler(),
293+
runtime_id="parallel-parallel-test",
294+
)
295+
296+
# First execution - should hit all 3 interrupts
297+
result = await runtime.execute(
298+
input={
299+
"branch_a_result": None,
300+
"branch_b_result": None,
301+
"branch_c_result": None,
302+
},
303+
options=UiPathExecuteOptions(resume=False),
304+
)
305+
306+
# Should be suspended with 3 triggers
307+
assert result.status == UiPathRuntimeStatus.SUSPENDED
308+
assert result.triggers is not None
309+
assert len(result.triggers) == 3
310+
311+
# Verify triggers were saved to storage
312+
saved_triggers = await storage.get_triggers("parallel-parallel-test")
313+
assert saved_triggers is not None
314+
assert len(saved_triggers) == 3
315+
316+
# Resume: All triggers should resolve immediately
317+
result_resume = await runtime.execute(
318+
input=None,
319+
options=UiPathExecuteOptions(resume=True),
320+
)
321+
322+
# Should now be successful (all triggers resolved in one go)
323+
assert result_resume.status == UiPathRuntimeStatus.SUCCESSFUL
324+
assert result_resume.output is not None
325+
326+
# Verify no triggers remain
327+
saved_triggers = await storage.get_triggers("parallel-parallel-test")
328+
assert saved_triggers is None or len(saved_triggers) == 0
329+
330+
# Verify all branches completed
331+
output = result_resume.output
332+
assert "branch_a_result" in output
333+
assert "branch_b_result" in output
334+
assert "branch_c_result" in output
335+
336+
# Verify all branches got their responses
337+
assert "Response for Branch A needs input" in output["branch_a_result"]
338+
assert "Response for Branch B needs input" in output["branch_b_result"]
339+
assert "Response for Branch C needs input" in output["branch_c_result"]
340+
341+
finally:
342+
if os.path.exists(temp_db.name):
343+
os.remove(temp_db.name)
344+
345+
346+
@pytest.mark.asyncio
347+
async def test_two_branches_with_two_sequential_interrupts_each():
348+
"""Test graph execution with 2 parallel branches, each having 2 sequential interrupts."""
349+
350+
# Define state
351+
class State(TypedDict, total=False):
352+
branch_a_first_result: str | None
353+
branch_a_second_result: str | None
354+
branch_b_first_result: str | None
355+
branch_b_second_result: str | None
356+
357+
# Define nodes that interrupt twice sequentially
358+
def branch_a_first(state: State) -> State:
359+
result = interrupt({"message": "Branch A - First interrupt"})
360+
return {"branch_a_first_result": f"A-1 completed with: {result}"}
361+
362+
def branch_a_second(state: State) -> State:
363+
result = interrupt({"message": "Branch A - Second interrupt"})
364+
return {"branch_a_second_result": f"A-2 completed with: {result}"}
365+
366+
def branch_b_first(state: State) -> State:
367+
result = interrupt({"message": "Branch B - First interrupt"})
368+
return {"branch_b_first_result": f"B-1 completed with: {result}"}
369+
370+
def branch_b_second(state: State) -> State:
371+
result = interrupt({"message": "Branch B - Second interrupt"})
372+
return {"branch_b_second_result": f"B-2 completed with: {result}"}
373+
374+
# Build graph with parallel branches, each with sequential nodes
375+
graph = StateGraph(State)
376+
graph.add_node("branch_a_first", branch_a_first)
377+
graph.add_node("branch_a_second", branch_a_second)
378+
graph.add_node("branch_b_first", branch_b_first)
379+
graph.add_node("branch_b_second", branch_b_second)
380+
381+
# Branch A: START -> a_first -> a_second -> END
382+
graph.add_edge(START, "branch_a_first")
383+
graph.add_edge("branch_a_first", "branch_a_second")
384+
graph.add_edge("branch_a_second", END)
385+
386+
# Branch B: START -> b_first -> b_second -> END
387+
graph.add_edge(START, "branch_b_first")
388+
graph.add_edge("branch_b_first", "branch_b_second")
389+
graph.add_edge("branch_b_second", END)
390+
391+
# Create temporary database
392+
temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db")
393+
temp_db.close()
394+
395+
try:
396+
# Compile graph with checkpointer
397+
async with AsyncSqliteSaver.from_conn_string(temp_db.name) as memory:
398+
compiled_graph = graph.compile(checkpointer=memory)
399+
400+
# Create base runtime
401+
base_runtime = UiPathLangGraphRuntime(
402+
graph=compiled_graph,
403+
runtime_id="two-branches-sequential-test",
404+
entrypoint="test",
405+
)
406+
407+
# Create storage and trigger manager
408+
storage = SqliteResumableStorage(memory)
409+
410+
# Wrap with UiPathResumableRuntime using parallel trigger handler
411+
runtime = UiPathResumableRuntime(
412+
delegate=base_runtime,
413+
storage=storage,
414+
trigger_manager=ParallelTriggerHandler(),
415+
runtime_id="two-branches-sequential-test",
416+
)
417+
418+
# First execution - should hit first interrupt in both branches (2 total)
419+
result = await runtime.execute(
420+
input={
421+
"branch_a_first_result": None,
422+
"branch_a_second_result": None,
423+
"branch_b_first_result": None,
424+
"branch_b_second_result": None,
425+
},
426+
options=UiPathExecuteOptions(resume=False),
427+
)
428+
429+
# Should be suspended with 2 triggers (first interrupt from each branch)
430+
assert result.status == UiPathRuntimeStatus.SUSPENDED
431+
assert result.triggers is not None
432+
assert len(result.triggers) == 2
433+
434+
# Verify triggers were saved to storage
435+
saved_triggers = await storage.get_triggers("two-branches-sequential-test")
436+
assert saved_triggers is not None
437+
assert len(saved_triggers) == 2
438+
439+
# Verify we got the first interrupts from both branches
440+
trigger_messages = [t.payload.get("message") for t in result.triggers]
441+
assert "Branch A - First interrupt" in trigger_messages
442+
assert "Branch B - First interrupt" in trigger_messages
443+
444+
# Resume 1: Resolve first interrupts, will hit second interrupts
445+
result_1 = await runtime.execute(
446+
input=None,
447+
options=UiPathExecuteOptions(resume=True),
448+
)
449+
450+
# Should still be suspended with 2 triggers (second interrupt from each branch)
451+
assert result_1.status == UiPathRuntimeStatus.SUSPENDED
452+
assert result_1.triggers is not None
453+
assert len(result_1.triggers) == 2
454+
455+
# Verify we got the second interrupts from both branches
456+
trigger_messages = [t.payload.get("message") for t in result_1.triggers]
457+
assert "Branch A - Second interrupt" in trigger_messages
458+
assert "Branch B - Second interrupt" in trigger_messages
459+
460+
# Verify 2 triggers remain in storage
461+
saved_triggers = await storage.get_triggers("two-branches-sequential-test")
462+
assert saved_triggers is not None
463+
assert len(saved_triggers) == 2
464+
465+
# Resume 2: Resolve second interrupts, should complete
466+
result_2 = await runtime.execute(
467+
input=None,
468+
options=UiPathExecuteOptions(resume=True),
469+
)
470+
471+
# Should now be successful
472+
assert result_2.status == UiPathRuntimeStatus.SUCCESSFUL
473+
assert result_2.output is not None
474+
475+
# Verify no triggers remain
476+
saved_triggers = await storage.get_triggers("two-branches-sequential-test")
477+
assert saved_triggers is None or len(saved_triggers) == 0
478+
479+
# Verify all branch steps completed
480+
output = result_2.output
481+
assert "branch_a_first_result" in output
482+
assert "branch_a_second_result" in output
483+
assert "branch_b_first_result" in output
484+
assert "branch_b_second_result" in output
485+
486+
# Verify all steps got their responses
487+
assert (
488+
"Response for Branch A - First interrupt"
489+
in output["branch_a_first_result"]
490+
)
491+
assert (
492+
"Response for Branch A - Second interrupt"
493+
in output["branch_a_second_result"]
494+
)
495+
assert (
496+
"Response for Branch B - First interrupt"
497+
in output["branch_b_first_result"]
498+
)
499+
assert (
500+
"Response for Branch B - Second interrupt"
501+
in output["branch_b_second_result"]
502+
)
503+
504+
finally:
505+
if os.path.exists(temp_db.name):
506+
os.remove(temp_db.name)

0 commit comments

Comments
 (0)