Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 28 additions & 16 deletions langfuse/callback/langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,24 +149,36 @@ def on_llm_new_token(

self.updated_completion_start_time_memo.add(run_id)

def get_langchain_run_name(self, serialized: Dict[str, Any], **kwargs: Any) -> str:
"""Retrieves the 'run_name' for an entity based on Langchain convention, prioritizing the 'name'
key in 'kwargs' or falling back to the 'name' or 'id' in 'serialized'. Defaults to "<unknown>"
if none are available.
def get_langchain_run_name(self, serialized: Optional[Dict[str, Any]], **kwargs: Any) -> str:
"""Retrieve the name of a serialized LangChain runnable.

The prioritization for the determination of the run name is as follows:
- The value assigned to the "name" key in `kwargs`.
- The value assigned to the "name" key in `serialized`.
- The last entry of the value assigned to the "id" key in `serialized`.
- "<unknown>".

Args:
serialized (Dict[str, Any]): A dictionary containing the entity's serialized data.
serialized (Optional[Dict[str, Any]]): A dictionary containing the runnable's serialized data.
**kwargs (Any): Additional keyword arguments, potentially including the 'name' override.

Returns:
str: The determined Langchain run name for the entity.
str: The determined name of the Langchain runnable.
"""
# Check if 'name' is in kwargs and not None, otherwise use default fallback logic
if "name" in kwargs and kwargs["name"] is not None:
return kwargs["name"]

# Fallback to serialized 'name', 'id', or "<unknown>"
return serialized.get("name", serialized.get("id", ["<unknown>"])[-1])
try:
return serialized["name"]
except (KeyError, TypeError):
pass

try:
return serialized["id"][-1]
except (KeyError, TypeError):
pass

return "<unknown>"

def on_retriever_error(
self,
Expand Down Expand Up @@ -196,7 +208,7 @@ def on_retriever_error(

def on_chain_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
inputs: Dict[str, Any],
*,
run_id: UUID,
Expand Down Expand Up @@ -289,7 +301,7 @@ def _deregister_langfuse_prompt(self, run_id: Optional[UUID]):

def __generate_trace_and_parent(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
inputs: Union[Dict[str, Any], List[str], str, None],
*,
run_id: UUID,
Expand Down Expand Up @@ -479,7 +491,7 @@ def on_chain_error(

def on_chat_model_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
messages: List[List[BaseMessage]],
*,
run_id: UUID,
Expand Down Expand Up @@ -508,7 +520,7 @@ def on_chat_model_start(

def on_llm_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
prompts: List[str],
*,
run_id: UUID,
Expand All @@ -535,7 +547,7 @@ def on_llm_start(

def on_tool_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
input_str: str,
*,
run_id: UUID,
Expand Down Expand Up @@ -573,7 +585,7 @@ def on_tool_start(

def on_retriever_start(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
query: str,
*,
run_id: UUID,
Expand Down Expand Up @@ -698,7 +710,7 @@ def on_tool_error(

def __on_llm_action(
self,
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
run_id: UUID,
prompts: List[str],
parent_run_id: Optional[UUID] = None,
Expand Down
23 changes: 17 additions & 6 deletions langfuse/extract_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


def _extract_model_name(
serialized: Dict[str, Any],
serialized: Optional[Dict[str, Any]],
**kwargs: Any,
):
"""Extracts the model name from the serialized or kwargs object. This is used to get the model names for Langfuse."""
Expand Down Expand Up @@ -106,13 +106,18 @@ def _extract_model_name(


def _extract_model_from_repr_by_pattern(
id: str, serialized: dict, pattern: str, default: Optional[str] = None
id: str, serialized: Optional[Dict[str, Any]], pattern: str, default: Optional[str] = None
):
if serialized is None:
return None

if serialized.get("id")[-1] == id:
if serialized.get("repr"):
extracted = _extract_model_with_regex(pattern, serialized.get("repr"))
return extracted if extracted else default if default else None

return None


def _extract_model_with_regex(pattern: str, text: str):
match = re.search(rf"{pattern}='(.*?)'", text)
Expand All @@ -123,21 +128,27 @@ def _extract_model_with_regex(pattern: str, text: str):

def _extract_model_by_path_for_id(
id: str,
serialized: dict,
serialized: Optional[Dict[str, Any]],
kwargs: dict,
keys: List[str],
select_from: str = Literal["serialized", "kwargs"],
select_from: Literal["serialized", "kwargs"],
):
if serialized is None and select_from == "serialized":
return None

if serialized.get("id")[-1] == id:
return _extract_model_by_path(serialized, kwargs, keys, select_from)


def _extract_model_by_path(
serialized: dict,
serialized: Optional[Dict[str, Any]],
kwargs: dict,
keys: List[str],
select_from: str = Literal["serialized", "kwargs"],
select_from: Literal["serialized", "kwargs"],
):
if serialized is None and select_from == "serialized":
return None

current_obj = kwargs if select_from == "kwargs" else serialized

for key in keys:
Expand Down
Loading