Skip to content

Commit 2a0edf4

Browse files
committed
feat(langgraph): extract original tool_call_id from runtime for consistency
Add _extract_tool_call_id function to handle MCP tool runtime objects and prioritize original tool_call_id over run_id for consistent event tracking. Update converter to use extracted ID or fallback to run_id when not available. Add comprehensive tests for tool start/end events with runtime tool_call_id. fixes inconsistent tool call ID tracking when using MCP tools that inject runtime objects with original LLM-generated tool_call_id 为 runtime 对象添加提取原始 tool_call_id 的功能以确保一致性 添加 _extract_tool_call_id 函数来处理 MCP 工具 runtime 对象,并优先使用原始 tool_call_id 而不是 run_id 来确保事件跟踪的一致性。更新转换器以使用提取的 ID, 当不可用时回退到 run_id。为带有 runtime tool_call_id 的工具开始/结束事件添加 全面的测试。 修复使用注入包含原始 LLM 生成 tool_call_id 的 runtime 对象的 MCP 工具时 工具调用 ID 跟踪不一致的问题 Change-Id: I2838c7b88ea8c01c87b39038d2b92a06bea89167 Signed-off-by: OhYee <oyohyee@oyohyee.com>
1 parent 175ec93 commit 2a0edf4

File tree

2 files changed

+133
-6
lines changed

2 files changed

+133
-6
lines changed

agentrun/integration/langgraph/agent_converter.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,31 @@ def _filter_tool_input(tool_input: Any) -> Any:
114114
return filtered
115115

116116

117+
def _extract_tool_call_id(tool_input: Any) -> Optional[str]:
118+
"""从工具输入中提取原始的 tool_call_id。
119+
120+
MCP 工具会在 input 中注入 runtime 对象,其中包含 LLM 返回的原始 tool_call_id。
121+
使用这个 ID 可以保证工具调用事件的 ID 一致性。
122+
123+
Args:
124+
tool_input: 工具输入(可能是 dict 或其他类型)
125+
126+
Returns:
127+
tool_call_id 或 None
128+
"""
129+
if not isinstance(tool_input, dict):
130+
return None
131+
132+
# 尝试从 runtime 对象中提取 tool_call_id
133+
runtime = tool_input.get("runtime")
134+
if runtime is not None and hasattr(runtime, "tool_call_id"):
135+
tc_id = runtime.tool_call_id
136+
if isinstance(tc_id, str) and tc_id:
137+
return tc_id
138+
139+
return None
140+
141+
117142
def _extract_content(chunk: Any) -> Optional[str]:
118143
"""从 chunk 中提取文本内容"""
119144
if chunk is None:
@@ -538,13 +563,18 @@ def _convert_astream_events_event(
538563
run_id = event_dict.get("run_id", "")
539564
tool_name = event_dict.get("name", "")
540565
tool_input_raw = data.get("input", {})
566+
# 优先使用 runtime 中的原始 tool_call_id,保证 ID 一致性
567+
tool_call_id = _extract_tool_call_id(tool_input_raw) or run_id
541568
# 过滤掉内部字段(如 MCP 注入的 runtime)
542569
tool_input = _filter_tool_input(tool_input_raw)
543570

544-
if run_id:
571+
if tool_call_id:
545572
yield AgentResult(
546573
event=EventType.TOOL_CALL_START,
547-
data={"tool_call_id": run_id, "tool_call_name": tool_name},
574+
data={
575+
"tool_call_id": tool_call_id,
576+
"tool_call_name": tool_name,
577+
},
548578
)
549579
if tool_input:
550580
args_str = (
@@ -554,25 +584,28 @@ def _convert_astream_events_event(
554584
)
555585
yield AgentResult(
556586
event=EventType.TOOL_CALL_ARGS,
557-
data={"tool_call_id": run_id, "delta": args_str},
587+
data={"tool_call_id": tool_call_id, "delta": args_str},
558588
)
559589

560590
# 4. 工具结束
561591
elif event_type == "on_tool_end":
562592
run_id = event_dict.get("run_id", "")
563593
output = data.get("output", "")
594+
tool_input_raw = data.get("input", {})
595+
# 优先使用 runtime 中的原始 tool_call_id,保证 ID 一致性
596+
tool_call_id = _extract_tool_call_id(tool_input_raw) or run_id
564597

565-
if run_id:
598+
if tool_call_id:
566599
yield AgentResult(
567600
event=EventType.TOOL_CALL_RESULT,
568601
data={
569-
"tool_call_id": run_id,
602+
"tool_call_id": tool_call_id,
570603
"result": _format_tool_output(output),
571604
},
572605
)
573606
yield AgentResult(
574607
event=EventType.TOOL_CALL_END,
575-
data={"tool_call_id": run_id},
608+
data={"tool_call_id": tool_call_id},
576609
)
577610

578611
# 5. LLM 结束

tests/unittests/integration/test_langchain_convert.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,100 @@ def __str__(self):
307307
assert "internal" not in delta
308308
assert "__pregel" not in delta
309309

310+
def test_on_tool_start_uses_runtime_tool_call_id(self):
311+
"""测试 on_tool_start 使用 runtime 中的原始 tool_call_id 而非 run_id
312+
313+
MCP 工具会在 input.runtime 中注入 tool_call_id,这是 LLM 返回的原始 ID。
314+
应该优先使用这个 ID,以保证工具调用事件的 ID 一致性。
315+
"""
316+
317+
class FakeToolRuntime:
318+
"""模拟 MCP 的 ToolRuntime 对象"""
319+
320+
def __init__(self, tool_call_id: str):
321+
self.tool_call_id = tool_call_id
322+
323+
original_tool_call_id = "call_original_from_llm_12345"
324+
325+
event = {
326+
"event": "on_tool_start",
327+
"name": "get_weather",
328+
"run_id": (
329+
"run_id_different_from_tool_call_id"
330+
), # run_id 与 tool_call_id 不同
331+
"data": {
332+
"input": {
333+
"city": "北京",
334+
"runtime": FakeToolRuntime(original_tool_call_id),
335+
}
336+
},
337+
}
338+
339+
results = list(convert(event))
340+
341+
# TOOL_CALL_START + TOOL_CALL_ARGS
342+
assert len(results) == 2
343+
344+
# 应该使用 runtime 中的原始 tool_call_id,而不是 run_id
345+
assert results[0].event == EventType.TOOL_CALL_START
346+
assert results[0].data["tool_call_id"] == original_tool_call_id
347+
assert results[0].data["tool_call_name"] == "get_weather"
348+
349+
assert results[1].event == EventType.TOOL_CALL_ARGS
350+
assert results[1].data["tool_call_id"] == original_tool_call_id
351+
352+
def test_on_tool_end_uses_runtime_tool_call_id(self):
353+
"""测试 on_tool_end 使用 runtime 中的原始 tool_call_id 而非 run_id"""
354+
355+
class FakeToolRuntime:
356+
"""模拟 MCP 的 ToolRuntime 对象"""
357+
358+
def __init__(self, tool_call_id: str):
359+
self.tool_call_id = tool_call_id
360+
361+
original_tool_call_id = "call_original_from_llm_67890"
362+
363+
event = {
364+
"event": "on_tool_end",
365+
"run_id": "run_id_different_from_tool_call_id",
366+
"data": {
367+
"output": {"weather": "晴天", "temp": 25},
368+
"input": {
369+
"city": "北京",
370+
"runtime": FakeToolRuntime(original_tool_call_id),
371+
},
372+
},
373+
}
374+
375+
results = list(convert(event))
376+
377+
# TOOL_CALL_RESULT + TOOL_CALL_END
378+
assert len(results) == 2
379+
380+
# 应该使用 runtime 中的原始 tool_call_id
381+
assert results[0].event == EventType.TOOL_CALL_RESULT
382+
assert results[0].data["tool_call_id"] == original_tool_call_id
383+
384+
assert results[1].event == EventType.TOOL_CALL_END
385+
assert results[1].data["tool_call_id"] == original_tool_call_id
386+
387+
def test_on_tool_start_fallback_to_run_id(self):
388+
"""测试当 runtime 中没有 tool_call_id 时,回退使用 run_id"""
389+
event = {
390+
"event": "on_tool_start",
391+
"name": "get_time",
392+
"run_id": "run_789",
393+
"data": {"input": {"timezone": "Asia/Shanghai"}}, # 没有 runtime
394+
}
395+
396+
results = list(convert(event))
397+
398+
assert len(results) == 2
399+
assert results[0].event == EventType.TOOL_CALL_START
400+
# 应该回退使用 run_id
401+
assert results[0].data["tool_call_id"] == "run_789"
402+
assert results[1].data["tool_call_id"] == "run_789"
403+
310404
def test_on_chain_stream_model_node(self):
311405
"""测试 on_chain_stream 事件(model 节点)"""
312406
msg = create_mock_ai_message("你好!有什么可以帮你的吗?")

0 commit comments

Comments
 (0)