diff --git a/langfuse/callback/langchain.py b/langfuse/callback/langchain.py
index 8749ef914..c697c718d 100644
--- a/langfuse/callback/langchain.py
+++ b/langfuse/callback/langchain.py
@@ -1096,12 +1096,18 @@ def _parse_usage_model(usage: typing.Union[pydantic.BaseModel, dict]):
         for key, value in input_token_details.items():
             usage_model[f"input_{key}"] = value
 
+            if "input" in usage_model:
+                usage_model["input"] = max(0, usage_model["input"] - value)
+
     if "output_token_details" in usage_model:
         output_token_details = usage_model.pop("output_token_details", {})
 
         for key, value in output_token_details.items():
             usage_model[f"output_{key}"] = value
 
+            if "output" in usage_model:
+                usage_model["output"] = max(0, usage_model["output"] - value)
+
     return usage_model if usage_model else None
 
 
diff --git a/tests/test_langchain.py b/tests/test_langchain.py
index 86e49b970..90cdec1f2 100644
--- a/tests/test_langchain.py
+++ b/tests/test_langchain.py
@@ -2318,3 +2318,49 @@ def call_model(state: MessagesState):
     assert observation.level == "DEFAULT"
 
     assert hidden_count > 0
+
+
+def test_cached_token_usage():
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                (
+                    "This is a test prompt to reproduce the issue. "
+                    "The prompt needs 1024 tokens to enable cache." * 100
+                ),
+            ),
+            ("user", "Reply to this message {test_param}."),
+        ]
+    )
+    chat = ChatOpenAI(model="gpt-4o-mini")
+    chain = prompt | chat
+    handler = CallbackHandler()
+    config = {"callbacks": [handler]}
+
+    chain.invoke({"test_param": "in a funny way"}, config)
+
+    # invoke again to force cached token usage
+    chain.invoke({"test_param": "in a funny way"}, config)
+
+    handler.flush()
+
+    trace = get_api().trace.get(handler.get_trace_id())
+
+    generation = next((o for o in trace.observations if o.type == "GENERATION"))
+
+    assert generation.usage_details["input_cache_read"] > 0
+    assert (
+        generation.usage_details["input"]
+        + generation.usage_details["input_cache_read"]
+        + generation.usage_details["output"]
+        == generation.usage_details["total"]
+    )
+
+    assert generation.cost_details["input_cache_read"] > 0
+    assert (
+        generation.cost_details["input"]
+        + generation.cost_details["input_cache_read"]
+        + generation.cost_details["output"]
+        == generation.cost_details["total"]
+    )
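
For context on the library change: LangChain reports cache reads under `input_token_details` while the aggregate `input` count already includes them, so flattening the details into `input_cache_read` without adjusting `input` double-counts cached tokens and breaks the `input + input_cache_read + output == total` identity the test checks. Below is a minimal standalone sketch of that subtraction idea, not the library code; the function name and example payload values are assumptions for illustration.

```python
# Sketch only: mirrors the diff's subtraction logic on a hand-made usage dict.
def parse_usage_sketch(usage: dict) -> dict:
    usage = dict(usage)  # do not mutate the caller's payload

    for side in ("input", "output"):
        details = usage.pop(f"{side}_token_details", {})
        for key, value in details.items():
            usage[f"{side}_{key}"] = value
            # Detail tokens (e.g. cache reads) are already part of the
            # aggregate count, so subtract them to avoid double counting.
            if side in usage:
                usage[side] = max(0, usage[side] - value)

    return usage


if __name__ == "__main__":
    raw = {
        "input": 1200,
        "output": 50,
        "total": 1250,
        "input_token_details": {"cache_read": 1024},  # assumed example values
    }
    parsed = parse_usage_sketch(raw)
    # 1200 - 1024 = 176 uncached input tokens, so the identity holds again.
    assert (
        parsed["input"] + parsed["input_cache_read"] + parsed["output"]
        == parsed["total"]
    )
    print(parsed)
```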