|
3 | 3 | import string |
4 | 4 | import time |
5 | 5 | from time import sleep |
6 | | -from typing import Any, Dict, List, Mapping, Optional |
| 6 | +from typing import Any, Dict, List, Literal, Mapping, Optional |
7 | 7 |
|
8 | 8 | import pytest |
9 | 9 | from langchain.agents import AgentType, initialize_agent |
|
30 | 30 | from langchain_core.language_models.llms import LLM |
31 | 31 | from langchain_core.output_parsers import StrOutputParser |
32 | 32 | from langchain_core.runnables.base import RunnableLambda |
33 | | -from langchain_core.tools import StructuredTool |
| 33 | +from langchain_core.tools import StructuredTool, tool |
34 | 34 | from langchain_openai import AzureChatOpenAI, ChatOpenAI, OpenAI |
| 35 | +from langgraph.checkpoint.memory import MemorySaver |
| 36 | +from langgraph.graph import END, START, MessagesState, StateGraph |
| 37 | +from langgraph.prebuilt import ToolNode |
35 | 38 | from pydantic.v1 import BaseModel, Field |
36 | 39 |
|
37 | 40 | from langfuse.callback import CallbackHandler |
| 41 | +from langfuse.callback.langchain import LANGSMITH_TAG_HIDDEN |
38 | 42 | from langfuse.client import Langfuse |
39 | 43 | from tests.api_wrapper import LangfuseAPI |
40 | 44 | from tests.utils import create_uuid, encode_file_to_base64, get_api |
@@ -2223,3 +2227,92 @@ def test_multimodal(): |
2223 | 2227 | "@@@langfuseMedia:type=image/jpeg|id=" |
2224 | 2228 | in trace.observations[0].input[0]["content"][1]["image_url"]["url"] |
2225 | 2229 | ) |
| 2230 | + |
| 2231 | + |
def test_langgraph():
    """Verify the Langfuse callback handler traces a LangGraph agent run.

    Builds a minimal tool-calling agent graph (model node + tool node),
    invokes it with a Langfuse ``CallbackHandler`` attached, then fetches the
    resulting trace and asserts that observations carrying the LangSmith
    "hidden" tag are demoted to DEBUG level while all others stay DEFAULT.
    """

    # Define the tools for the agent to use
    @tool
    def search(query: str):
        """Call to surf the web."""
        # Placeholder implementation; the LLM only sees the docstring.
        if "sf" in query.lower() or "san francisco" in query.lower():
            return "It's 60 degrees and foggy."
        return "It's 90 degrees and sunny."

    tools = [search]
    tool_node = ToolNode(tools)
    model = ChatOpenAI(model="gpt-4o-mini").bind_tools(tools)

    # Route to "tools" when the last message contains tool calls, else stop.
    # NOTE: PEP 586 requires Literal parameters to be literal values, so the
    # annotation spells out END's value ("__end__") instead of the constant.
    def should_continue(state: MessagesState) -> Literal["tools", "__end__"]:
        last_message = state["messages"][-1]
        # If the LLM made a tool call, route to the "tools" node.
        if last_message.tool_calls:
            return "tools"
        # Otherwise, stop (reply to the user).
        return END

    # Call the model on the accumulated conversation.
    def call_model(state: MessagesState):
        response = model.invoke(state["messages"])
        # Return a list: MessagesState appends it to the existing messages.
        return {"messages": [response]}

    # Two-node graph cycling between the agent and its tools,
    # entered at "agent".
    workflow = StateGraph(MessagesState)
    workflow.add_node("agent", call_model)
    workflow.add_node("tools", tool_node)
    workflow.add_edge(START, "agent")
    # After "agent" runs, should_continue decides the next node.
    workflow.add_conditional_edges("agent", should_continue)
    # After "tools" runs, control always returns to "agent".
    workflow.add_edge("tools", "agent")

    # In-memory checkpointer so state persists across runs of the same
    # thread_id (used in the invoke config below).
    checkpointer = MemorySaver()

    # Compile into a LangChain Runnable.
    app = workflow.compile(checkpointer=checkpointer)

    handler = CallbackHandler()

    # Run the graph with the Langfuse handler attached as a callback.
    final_state = app.invoke(
        {"messages": [HumanMessage(content="what is the weather in sf")]},
        config={"configurable": {"thread_id": 42}, "callbacks": [handler]},
    )
    print(final_state["messages"][-1].content)
    handler.flush()

    trace = get_api().trace.get(handler.get_trace_id())

    hidden_count = 0

    for observation in trace.observations:
        # metadata may be None on some observations — guard before .get().
        if LANGSMITH_TAG_HIDDEN in (observation.metadata or {}).get("tags", []):
            hidden_count += 1
            # Hidden observations must be demoted to DEBUG level.
            assert observation.level == "DEBUG"

        else:
            assert observation.level == "DEFAULT"

    # The graph run must have produced at least one hidden observation,
    # otherwise the demotion logic was never exercised.
    assert hidden_count > 0
0 commit comments