Skip to content

Commit d590f20

Browse files
committed
Add integration tests for tool calling with Meta and Cohere models
Tests verify that the fix prevents infinite tool-calling loops for: Meta Llama 4 Scout, Meta Llama 3.3 70B, Cohere Command A, and Cohere Command R Plus. Each test confirms that, after receiving tool results, the model generates a final response without making additional tool calls. Signed-off-by: Federico Kamelhar <federico.kamelhar@oracle.com>
1 parent f69514c commit d590f20

File tree

1 file changed

+190
-0
lines changed

1 file changed

+190
-0
lines changed
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
# Copyright (c) 2025 Oracle and/or its affiliates.
2+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
3+
4+
"""Integration tests for tool calling with OCI Generative AI chat models.
5+
6+
These tests verify that tool calling works correctly without infinite loops
7+
for both Meta and Cohere models after receiving tool results.
8+
"""
9+
10+
import os

import pytest
from langchain.tools import StructuredTool
from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_oci.chat_models import ChatOCIGenAI
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
17+
18+
19+
def get_weather(city: str) -> str:
    """Return a canned current-weather string for *city* (case-insensitive).

    Unknown cities yield a "not available" message rather than raising.
    """
    canned_reports = {
        "chicago": "Sunny, 65°F",
        "new york": "Cloudy, 60°F",
        "san francisco": "Foggy, 58°F",
    }
    key = city.lower()
    if key in canned_reports:
        return canned_reports[key]
    return f"Weather data not available for {city}"
27+
28+
29+
@pytest.fixture
def weather_tool():
    """Wrap get_weather as a StructuredTool the agent can bind and call."""
    tool = StructuredTool.from_function(
        func=get_weather,
        name="get_weather",
        description="Get the current weather for a given city name.",
    )
    return tool
37+
38+
39+
def create_agent(model_id: str, weather_tool: StructuredTool):
    """Build and compile a LangGraph agent wired to the weather tool.

    The graph loops model -> tools -> model for as long as the latest AI
    message requests a tool call, then terminates with the final answer.
    Auth/endpoint settings are read from the OCI_REGION and OCI_COMP
    environment variables and the local ~/.oci/config profile.
    """
    region = os.getenv("OCI_REGION", "us-chicago-1")
    llm = ChatOCIGenAI(
        model_id=model_id,
        service_endpoint=f"https://inference.generativeai.{region}.oci.oraclecloud.com",
        compartment_id=os.getenv("OCI_COMP"),
        model_kwargs={"temperature": 0.3, "max_tokens": 512, "top_p": 0.9},
        auth_type="SECURITY_TOKEN",
        auth_profile="DEFAULT",
        auth_file_location=os.path.expanduser("~/.oci/config"),
        disable_streaming="tool_calling",
    )

    llm_with_tools = llm.bind_tools([weather_tool])
    executor_node = ToolNode(tools=[weather_tool])

    def call_model(state: MessagesState):
        """Invoke the tool-bound model on the accumulated conversation."""
        reply = llm_with_tools.invoke(state["messages"])
        return {"messages": [reply]}

    def should_continue(state: MessagesState):
        """Route to the tool node while the last message requests a tool."""
        last = state["messages"][-1]
        # Truthy tool_calls (when present) means the model wants another tool run.
        return "tools" if getattr(last, "tool_calls", None) else END

    graph = StateGraph(MessagesState)
    graph.add_node("call_model", call_model)
    graph.add_node("tools", executor_node)
    graph.add_edge(START, "call_model")
    graph.add_conditional_edges("call_model", should_continue, ["tools", END])
    graph.add_edge("tools", "call_model")

    return graph.compile()
78+
79+
80+
@pytest.mark.requires("oci")
@pytest.mark.parametrize(
    "model_id",
    [
        "meta.llama-4-scout-17b-16e-instruct",
        "meta.llama-3.3-70b-instruct",
        "cohere.command-a-03-2025",
        "cohere.command-r-plus-08-2024",
    ],
)
def test_tool_calling_no_infinite_loop(model_id: str, weather_tool: StructuredTool):
    """Test that tool calling works without infinite loops.

    This test verifies that after a tool is called and results are returned,
    the model generates a final response without making additional tool calls,
    preventing infinite loops.

    The fix sets tool_choice='none' when ToolMessages are present in the
    conversation history, which tells the model to stop calling tools.
    """
    agent = create_agent(model_id, weather_tool)

    # Invoke the agent
    result = agent.invoke(
        {
            "messages": [
                SystemMessage(
                    content="You are a helpful assistant. Use the available tools when needed to answer questions accurately."
                ),
                HumanMessage(content="What's the weather in Chicago?"),
            ]
        }
    )

    messages = result["messages"]

    # Verify the conversation structure
    assert len(messages) >= 4, "Should have at least: System, Human, AI (tool call), Tool, AI (final)"

    # Find tool messages. isinstance also matches subclasses, unlike the
    # fragile type(...).__name__ string comparison it replaces.
    tool_messages = [msg for msg in messages if isinstance(msg, ToolMessage)]
    assert len(tool_messages) >= 1, "Should have at least one tool result"

    # Find AI messages with tool calls (AIMessage.tool_calls defaults to [],
    # so no hasattr guard is needed)
    ai_tool_calls = [
        msg for msg in messages
        if isinstance(msg, AIMessage) and msg.tool_calls
    ]
    # The model should call the tool, but after receiving results, should not call again
    # Allow flexibility - some models might make 1 call, others might need 2, but should stop
    assert len(ai_tool_calls) <= 2, f"Model made too many tool calls ({len(ai_tool_calls)}), possible infinite loop"

    # Verify final message is an AI response without tool calls
    final_message = messages[-1]
    assert isinstance(final_message, AIMessage), "Final message should be AIMessage"
    assert final_message.content, "Final message should have content"
    assert not final_message.tool_calls, \
        "Final message should not have tool_calls (infinite loop prevention)"

    # Verify the answer mentions the weather
    assert "65" in final_message.content or "sunny" in final_message.content.lower(), \
        "Final response should mention the weather data"
142+
143+
144+
@pytest.mark.requires("oci")
def test_meta_llama_tool_calling(weather_tool: StructuredTool):
    """Specific test for Meta Llama models to ensure fix works."""
    model_id = "meta.llama-4-scout-17b-16e-instruct"
    agent = create_agent(model_id, weather_tool)

    result = agent.invoke(
        {
            "messages": [
                SystemMessage(content="You are a helpful assistant."),
                HumanMessage(content="Check the weather in San Francisco."),
            ]
        }
    )

    messages = result["messages"]
    final_message = messages[-1]

    # Meta Llama was specifically affected by infinite loops.
    # Verify it stops after receiving tool results. isinstance replaces the
    # fragile type-name string check; AIMessage.tool_calls defaults to [],
    # so no hasattr guard is needed.
    assert isinstance(final_message, AIMessage)
    assert not final_message.tool_calls
    assert "foggy" in final_message.content.lower() or "58" in final_message.content
167+
168+
169+
@pytest.mark.requires("oci")
def test_cohere_tool_calling(weather_tool: StructuredTool):
    """Specific test for Cohere models to ensure they work correctly."""
    model_id = "cohere.command-a-03-2025"
    agent = create_agent(model_id, weather_tool)

    result = agent.invoke(
        {
            "messages": [
                SystemMessage(content="You are a helpful assistant."),
                HumanMessage(content="What's the weather like in New York?"),
            ]
        }
    )

    messages = result["messages"]
    final_message = messages[-1]

    # Cohere models should handle tool calling naturally. isinstance replaces
    # the fragile type-name string check; AIMessage.tool_calls defaults to [],
    # so no hasattr guard is needed.
    assert isinstance(final_message, AIMessage)
    assert not final_message.tool_calls
    assert "60" in final_message.content or "cloudy" in final_message.content.lower()

0 commit comments

Comments (0)