|
13 | 13 | SimpleSequentialChain, |
14 | 14 | ConversationChain, |
15 | 15 | ) |
| 16 | +from langchain_core.tools import StructuredTool |
16 | 17 | from langchain.chains.openai_functions import create_openai_fn_chain |
17 | 18 | from langchain.chains.summarize import load_summarize_chain |
18 | 19 | from langchain_community.document_loaders import TextLoader |
@@ -2095,3 +2096,65 @@ def test_get_langchain_chat_prompt_with_precompiled_prompt(): |
2095 | 2096 |
|
2096 | 2097 | assert system_message.content == "This is a dog." |
2097 | 2098 | assert user_message.content == "This is a langchain chain." |
| 2099 | + |
| 2100 | + |
| 2101 | +def test_callback_openai_functions_with_tools(): |
| 2102 | + handler = CallbackHandler() |
| 2103 | + |
| 2104 | + llm = ChatOpenAI(model="gpt-4", temperature=0, callbacks=[handler]) |
| 2105 | + |
| 2106 | + class StandardizedAddress(BaseModel): |
| 2107 | + street: str = Field(description="The street name and number") |
| 2108 | + city: str = Field(description="The city name") |
| 2109 | + state: str = Field(description="The state or province") |
| 2110 | + zip_code: str = Field(description="The postal code") |
| 2111 | + |
| 2112 | + class GetWeather(BaseModel): |
| 2113 | + city: str = Field(description="The city name") |
| 2114 | + state: str = Field(description="The state or province") |
| 2115 | + zip_code: str = Field(description="The postal code") |
| 2116 | + |
| 2117 | + address_tool = StructuredTool.from_function( |
| 2118 | + func=lambda **kwargs: StandardizedAddress(**kwargs), |
| 2119 | + name="standardize_address", |
| 2120 | + description="Standardize the given address", |
| 2121 | + args_schema=StandardizedAddress, |
| 2122 | + ) |
| 2123 | + |
| 2124 | + weather_tool = StructuredTool.from_function( |
| 2125 | + func=lambda **kwargs: GetWeather(**kwargs), |
| 2126 | + name="get_weather", |
| 2127 | + description="Get the weather for the given city", |
| 2128 | + args_schema=GetWeather, |
| 2129 | + ) |
| 2130 | + |
| 2131 | + messages = [ |
| 2132 | + { |
| 2133 | + "role": "user", |
| 2134 | + "content": "Please standardize this address: 123 Main St, Springfield, IL 62701", |
| 2135 | + } |
| 2136 | + ] |
| 2137 | + |
| 2138 | + llm.bind_tools([address_tool, weather_tool]).invoke(messages) |
| 2139 | + |
| 2140 | + handler.flush() |
| 2141 | + |
| 2142 | + api = get_api() |
| 2143 | + trace = api.trace.get(handler.get_trace_id()) |
| 2144 | + |
| 2145 | + generations = list(filter(lambda x: x.type == "GENERATION", trace.observations)) |
| 2146 | + assert len(generations) > 0 |
| 2147 | + |
| 2148 | + for generation in generations: |
| 2149 | + assert generation.input is not None |
| 2150 | + tool_messages = [msg for msg in generation.input if msg["role"] == "tool"] |
| 2151 | + assert len(tool_messages) == 2 |
| 2152 | + assert any( |
| 2153 | + "standardize_address" == msg["content"]["function"]["name"] |
| 2154 | + for msg in tool_messages |
| 2155 | + ) |
| 2156 | + assert any( |
| 2157 | + "get_weather" == msg["content"]["function"]["name"] for msg in tool_messages |
| 2158 | + ) |
| 2159 | + |
| 2160 | + assert generation.output is not None |
0 commit comments