Skip to content

Commit 24f31b6

Browse files
committed
Fix rebase issues: improve tool call conversion and fix tests
- Fix `convert_oci_tool_call_to_langchain` to handle both the Generic/Meta format (arguments as a JSON string) and the Cohere format (parameters as a dict)
- Fix the `attribute_map` check to handle `None` values from mock objects
- Fix line-length issues in the optimization code and tests
- Update test mocks to use the correct OCI SDK format
1 parent fd52dc5 commit 24f31b6

File tree

2 files changed

+83
-47
lines changed

2 files changed

+83
-47
lines changed

libs/oci/langchain_oci/chat_models/oci_generative_ai.py

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -100,22 +100,37 @@ def remove_signature_from_tool_description(name: str, description: str) -> str:
100100
@staticmethod
101101
def convert_oci_tool_call_to_langchain(tool_call: Any) -> ToolCall:
102102
"""Convert an OCI tool call to a LangChain ToolCall."""
103-
parsed = json.loads(tool_call.arguments)
104-
105-
# If the parsed result is a string, it means the JSON was escaped, so parse again # noqa: E501
106-
if isinstance(parsed, str):
107-
try:
108-
parsed = json.loads(parsed)
109-
except json.JSONDecodeError:
110-
# If it's not valid JSON, keep it as a string
111-
pass
103+
# Check if this is a Generic/Meta format (has arguments as JSON string)
104+
# or Cohere format (has parameters as dict)
105+
attribute_map = getattr(tool_call, "attribute_map", None) or {}
106+
107+
if "arguments" in attribute_map and tool_call.arguments is not None:
108+
# Generic/Meta format: parse JSON arguments
109+
parsed = json.loads(tool_call.arguments)
110+
111+
# If the parsed result is a string, it means JSON was escaped
112+
if isinstance(parsed, str):
113+
try:
114+
parsed = json.loads(parsed)
115+
except json.JSONDecodeError:
116+
# If it's not valid JSON, keep it as a string
117+
pass
118+
args = parsed
119+
else:
120+
# Cohere format: parameters is already a dict
121+
args = tool_call.parameters
122+
123+
# Get tool call ID (generate one if not present)
124+
tool_id = (
125+
tool_call.id
126+
if "id" in attribute_map
127+
else uuid.uuid4().hex[:]
128+
)
112129

113130
return ToolCall(
114131
name=tool_call.name,
115-
args=parsed
116-
if "arguments" in tool_call.attribute_map
117-
else tool_call.parameters,
118-
id=tool_call.id if "id" in tool_call.attribute_map else uuid.uuid4().hex[:],
132+
args=args,
133+
id=tool_id,
119134
)
120135

121136

@@ -1423,9 +1438,8 @@ def _generate(
14231438
# Add formatted version to generation_info if not already present
14241439
# This avoids redundant formatting in chat_generation_info()
14251440
if "tool_calls" not in generation_info:
1426-
generation_info["tool_calls"] = self._provider.format_response_tool_calls(
1427-
raw_tool_calls
1428-
)
1441+
formatted = self._provider.format_response_tool_calls(raw_tool_calls)
1442+
generation_info["tool_calls"] = formatted
14291443
message = AIMessage(
14301444
content=content or "",
14311445
additional_kwargs=generation_info,

libs/oci/tests/unit_tests/chat_models/test_tool_call_optimization.py

Lines changed: 53 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def test_meta_tool_call_optimization() -> None:
2121
"""Test that tool calls are formatted once and cached for Meta models."""
2222
oci_gen_ai_client = MagicMock()
2323

24-
# Mock response with tool call
24+
# Mock response with tool call - using correct OCI format
2525
def mocked_response(*args): # type: ignore[no-untyped-def]
2626
return MockResponseDict(
2727
{
@@ -50,13 +50,13 @@ def mocked_response(*args): # type: ignore[no-untyped-def]
5050
MockResponseDict(
5151
{
5252
"id": "test_id_123",
53-
"type": "FUNCTION",
54-
"function": MockResponseDict(
55-
{
56-
"name": "get_weather",
57-
"arguments": '{"location": "San Francisco"}',
58-
}
59-
),
53+
"name": "get_weather",
54+
"arguments": '{"location": "San Francisco"}', # noqa: E501
55+
"attribute_map": {
56+
"id": "id",
57+
"name": "name",
58+
"arguments": "arguments", # noqa: E501
59+
},
6060
}
6161
)
6262
],
@@ -94,7 +94,9 @@ def mocked_response(*args): # type: ignore[no-untyped-def]
9494
oci_gen_ai_client.chat.side_effect = mocked_response
9595

9696
# Create LLM with mocked client
97-
llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client)
97+
llm = ChatOCIGenAI(
98+
model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client
99+
)
98100

99101
# Define a simple tool
100102
def get_weather(location: str) -> str:
@@ -105,7 +107,7 @@ def get_weather(location: str) -> str:
105107
llm_with_tools = llm.bind_tools([get_weather])
106108

107109
# Invoke
108-
response = llm_with_tools.invoke([HumanMessage(content="What's the weather in SF?")])
110+
response = llm_with_tools.invoke([HumanMessage(content="What's the weather?")])
109111

110112
# Verify tool_calls field is populated
111113
assert len(response.tool_calls) == 1, "Should have one tool call"
@@ -115,7 +117,9 @@ def get_weather(location: str) -> str:
115117
assert "id" in tool_call
116118

117119
# Verify additional_kwargs contains formatted tool calls
118-
assert "tool_calls" in response.additional_kwargs, "Should have tool_calls in additional_kwargs"
120+
assert (
121+
"tool_calls" in response.additional_kwargs
122+
), "Should have tool_calls in additional_kwargs"
119123
additional_tool_calls = response.additional_kwargs["tool_calls"]
120124
assert len(additional_tool_calls) == 1
121125
assert additional_tool_calls[0]["type"] == "function"
@@ -128,7 +132,7 @@ def test_cohere_tool_call_optimization() -> None:
128132
"""Test that tool calls are formatted once and cached for Cohere models."""
129133
oci_gen_ai_client = MagicMock()
130134

131-
# Mock response with tool call
135+
# Mock response with tool call - using correct Cohere format
132136
def mocked_response(*args): # type: ignore[no-untyped-def]
133137
return MockResponseDict(
134138
{
@@ -139,6 +143,10 @@ def mocked_response(*args): # type: ignore[no-untyped-def]
139143
{
140144
"text": "",
141145
"finish_reason": "TOOL_CALL",
146+
"documents": None,
147+
"citations": None,
148+
"search_queries": None,
149+
"is_search_required": None,
142150
"tool_calls": [
143151
MockResponseDict(
144152
{
@@ -181,7 +189,9 @@ def get_weather(location: str) -> str:
181189
llm_with_tools = llm.bind_tools([get_weather])
182190

183191
# Invoke
184-
response = llm_with_tools.invoke([HumanMessage(content="What's the weather in London?")])
192+
response = llm_with_tools.invoke(
193+
[HumanMessage(content="What's the weather in London?")]
194+
)
185195

186196
# Verify tool_calls field is populated
187197
assert len(response.tool_calls) == 1, "Should have one tool call"
@@ -193,7 +203,9 @@ def get_weather(location: str) -> str:
193203
assert len(tool_call["id"]) > 0, "Tool call ID should not be empty"
194204

195205
# Verify additional_kwargs contains formatted tool calls
196-
assert "tool_calls" in response.additional_kwargs, "Should have tool_calls in additional_kwargs"
206+
assert (
207+
"tool_calls" in response.additional_kwargs
208+
), "Should have tool_calls in additional_kwargs"
197209
additional_tool_calls = response.additional_kwargs["tool_calls"]
198210
assert len(additional_tool_calls) == 1
199211
assert additional_tool_calls[0]["type"] == "function"
@@ -205,7 +217,7 @@ def test_multiple_tool_calls_optimization() -> None:
205217
"""Test optimization with multiple tool calls."""
206218
oci_gen_ai_client = MagicMock()
207219

208-
# Mock response with multiple tool calls
220+
# Mock response with multiple tool calls - using correct OCI format
209221
def mocked_response(*args): # type: ignore[no-untyped-def]
210222
return MockResponseDict(
211223
{
@@ -233,25 +245,25 @@ def mocked_response(*args): # type: ignore[no-untyped-def]
233245
MockResponseDict(
234246
{
235247
"id": "call_1",
236-
"type": "FUNCTION",
237-
"function": MockResponseDict(
238-
{
239-
"name": "get_weather",
240-
"arguments": '{"location": "Tokyo"}',
241-
}
242-
),
248+
"name": "get_weather",
249+
"arguments": '{"location": "Tokyo"}', # noqa: E501
250+
"attribute_map": {
251+
"id": "id",
252+
"name": "name",
253+
"arguments": "arguments", # noqa: E501
254+
},
243255
}
244256
),
245257
MockResponseDict(
246258
{
247259
"id": "call_2",
248-
"type": "FUNCTION",
249-
"function": MockResponseDict(
250-
{
251-
"name": "get_population",
252-
"arguments": '{"city": "Tokyo"}',
253-
}
254-
),
260+
"name": "get_population", # noqa: E501
261+
"arguments": '{"city": "Tokyo"}', # noqa: E501
262+
"attribute_map": {
263+
"id": "id",
264+
"name": "name",
265+
"arguments": "arguments", # noqa: E501
266+
},
255267
}
256268
),
257269
],
@@ -262,6 +274,7 @@ def mocked_response(*args): # type: ignore[no-untyped-def]
262274
}
263275
)
264276
],
277+
"time_created": "2024-01-01T00:00:00Z",
265278
"usage": MockResponseDict(
266279
{
267280
"total_tokens": 200,
@@ -274,13 +287,20 @@ def mocked_response(*args): # type: ignore[no-untyped-def]
274287
}
275288
),
276289
"request_id": "test_request_789",
290+
"headers": MockResponseDict(
291+
{
292+
"content-length": "400",
293+
}
294+
),
277295
}
278296
)
279297

280298
oci_gen_ai_client.chat.side_effect = mocked_response
281299

282300
# Create LLM with mocked client
283-
llm = ChatOCIGenAI(model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client)
301+
llm = ChatOCIGenAI(
302+
model_id="meta.llama-3.3-70b-instruct", client=oci_gen_ai_client
303+
)
284304

285305
# Define tools
286306
def get_weather(location: str) -> str:
@@ -295,7 +315,9 @@ def get_population(city: str) -> int:
295315
llm_with_tools = llm.bind_tools([get_weather, get_population])
296316

297317
# Invoke
298-
response = llm_with_tools.invoke([HumanMessage(content="Weather and population of Tokyo?")])
318+
response = llm_with_tools.invoke(
319+
[HumanMessage(content="Weather and population of Tokyo?")]
320+
)
299321

300322
# Verify tool_calls field has both calls
301323
assert len(response.tool_calls) == 2, "Should have two tool calls"

0 commit comments

Comments
 (0)