
Commit 0ef2378

fix(langchain): cached token usage (#1121)
1 parent: 80c6c55

File tree: 2 files changed (+52, −0)

2 files changed

+52
-0
lines changed

langfuse/callback/langchain.py

Lines changed: 6 additions & 0 deletions
@@ -1096,12 +1096,18 @@ def _parse_usage_model(usage: typing.Union[pydantic.BaseModel, dict]):
         for key, value in input_token_details.items():
             usage_model[f"input_{key}"] = value
 
+            if "input" in usage_model:
+                usage_model["input"] = max(0, usage_model["input"] - value)
+
     if "output_token_details" in usage_model:
         output_token_details = usage_model.pop("output_token_details", {})
 
         for key, value in output_token_details.items():
             usage_model[f"output_{key}"] = value
 
+            if "output" in usage_model:
+                usage_model["output"] = max(0, usage_model["output"] - value)
+
     return usage_model if usage_model else None
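The added lines keep cached tokens from being double counted: LangChain's usage metadata reports token details such as cache reads both inside the top-level input/output totals and again under input_token_details / output_token_details, so once the detailed values are broken out into their own keys they are subtracted from the plain counts. Below is a minimal standalone sketch of that subtraction with made-up numbers; parse_usage_sketch is a hypothetical name for illustration, not the library helper.

def parse_usage_sketch(usage: dict) -> dict:
    # Simplified stand-in for the dict handling in _parse_usage_model.
    usage_model = dict(usage)

    if "input_token_details" in usage_model:
        for key, value in usage_model.pop("input_token_details", {}).items():
            usage_model[f"input_{key}"] = value
            if "input" in usage_model:
                # Detailed tokens (e.g. cache reads) now have their own key,
                # so remove them from the plain "input" count.
                usage_model["input"] = max(0, usage_model["input"] - value)

    if "output_token_details" in usage_model:
        for key, value in usage_model.pop("output_token_details", {}).items():
            usage_model[f"output_{key}"] = value
            if "output" in usage_model:
                usage_model["output"] = max(0, usage_model["output"] - value)

    return usage_model


# Illustrative numbers only: 1000 prompt tokens, 800 of them read from the cache.
print(parse_usage_sketch(
    {"input": 1000, "output": 50, "total": 1050,
     "input_token_details": {"cache_read": 800}}
))
# {'input': 200, 'output': 50, 'total': 1050, 'input_cache_read': 800}
# 200 + 800 + 50 == 1050, which is the invariant the new test below asserts.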

tests/test_langchain.py

Lines changed: 46 additions & 0 deletions
@@ -2318,3 +2318,49 @@ def call_model(state: MessagesState):
     assert observation.level == "DEFAULT"
 
     assert hidden_count > 0
+
+
+def test_cached_token_usage():
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                (
+                    "This is a test prompt to reproduce the issue. "
+                    "The prompt needs 1024 tokens to enable cache." * 100
+                ),
+            ),
+            ("user", "Reply to this message {test_param}."),
+        ]
+    )
+    chat = ChatOpenAI(model="gpt-4o-mini")
+    chain = prompt | chat
+    handler = CallbackHandler()
+    config = {"callbacks": [handler]}
+
+    chain.invoke({"test_param": "in a funny way"}, config)
+
+    # invoke again to force cached token usage
+    chain.invoke({"test_param": "in a funny way"}, config)
+
+    handler.flush()
+
+    trace = get_api().trace.get(handler.get_trace_id())
+
+    generation = next((o for o in trace.observations if o.type == "GENERATION"))
+
+    assert generation.usage_details["input_cache_read"] > 0
+    assert (
+        generation.usage_details["input"]
+        + generation.usage_details["input_cache_read"]
+        + generation.usage_details["output"]
+        == generation.usage_details["total"]
+    )
+
+    assert generation.cost_details["input_cache_read"] > 0
+    assert (
+        generation.cost_details["input"]
+        + generation.cost_details["input_cache_read"]
+        + generation.cost_details["output"]
+        == generation.cost_details["total"]
+    )
