Skip to content

Commit 78cadef

Browse files
authored
fix(langchain): parse tool calls in input (#965)
1 parent 8bd61dc commit 78cadef

File tree

2 files changed

+67
-0
lines changed

2 files changed

+67
-0
lines changed

langfuse/callback/langchain.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,10 @@ def __on_llm_action(
704704
**kwargs: Any,
705705
):
706706
try:
707+
tools = kwargs.get("invocation_params", {}).get("tools", None)
708+
if tools and isinstance(tools, list):
709+
prompts.extend([{"role": "tool", "content": tool} for tool in tools])
710+
707711
self.__generate_trace_and_parent(
708712
serialized,
709713
inputs=prompts[0] if len(prompts) == 1 else prompts,

tests/test_langchain.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
SimpleSequentialChain,
1414
ConversationChain,
1515
)
16+
from langchain_core.tools import StructuredTool
1617
from langchain.chains.openai_functions import create_openai_fn_chain
1718
from langchain.chains.summarize import load_summarize_chain
1819
from langchain_community.document_loaders import TextLoader
@@ -2095,3 +2096,65 @@ def test_get_langchain_chat_prompt_with_precompiled_prompt():
20952096

20962097
assert system_message.content == "This is a dog."
20972098
assert user_message.content == "This is a langchain chain."
2099+
2100+
2101+
def test_callback_openai_functions_with_tools():
2102+
handler = CallbackHandler()
2103+
2104+
llm = ChatOpenAI(model="gpt-4", temperature=0, callbacks=[handler])
2105+
2106+
class StandardizedAddress(BaseModel):
2107+
street: str = Field(description="The street name and number")
2108+
city: str = Field(description="The city name")
2109+
state: str = Field(description="The state or province")
2110+
zip_code: str = Field(description="The postal code")
2111+
2112+
class GetWeather(BaseModel):
2113+
city: str = Field(description="The city name")
2114+
state: str = Field(description="The state or province")
2115+
zip_code: str = Field(description="The postal code")
2116+
2117+
address_tool = StructuredTool.from_function(
2118+
func=lambda **kwargs: StandardizedAddress(**kwargs),
2119+
name="standardize_address",
2120+
description="Standardize the given address",
2121+
args_schema=StandardizedAddress,
2122+
)
2123+
2124+
weather_tool = StructuredTool.from_function(
2125+
func=lambda **kwargs: GetWeather(**kwargs),
2126+
name="get_weather",
2127+
description="Get the weather for the given city",
2128+
args_schema=GetWeather,
2129+
)
2130+
2131+
messages = [
2132+
{
2133+
"role": "user",
2134+
"content": "Please standardize this address: 123 Main St, Springfield, IL 62701",
2135+
}
2136+
]
2137+
2138+
llm.bind_tools([address_tool, weather_tool]).invoke(messages)
2139+
2140+
handler.flush()
2141+
2142+
api = get_api()
2143+
trace = api.trace.get(handler.get_trace_id())
2144+
2145+
generations = list(filter(lambda x: x.type == "GENERATION", trace.observations))
2146+
assert len(generations) > 0
2147+
2148+
for generation in generations:
2149+
assert generation.input is not None
2150+
tool_messages = [msg for msg in generation.input if msg["role"] == "tool"]
2151+
assert len(tool_messages) == 2
2152+
assert any(
2153+
"standardize_address" == msg["content"]["function"]["name"]
2154+
for msg in tool_messages
2155+
)
2156+
assert any(
2157+
"get_weather" == msg["content"]["function"]["name"] for msg in tool_messages
2158+
)
2159+
2160+
assert generation.output is not None

0 commit comments

Comments
 (0)