@@ -2318,3 +2318,49 @@ def call_model(state: MessagesState):
     assert observation.level == "DEFAULT"

     assert hidden_count > 0
+
+
+def test_cached_token_usage():
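+    # Repeat the system text so the prompt exceeds the 1024-token minimum that
+    # OpenAI prompt caching requires, so the second call reports cached input tokens.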
+    prompt = ChatPromptTemplate.from_messages(
+        [
+            (
+                "system",
+                (
+                    "This is a test prompt to reproduce the issue. "
+                    "The prompt needs 1024 tokens to enable cache." * 100
+                ),
+            ),
+            ("user", "Reply to this message {test_param}."),
+        ]
+    )
+    chat = ChatOpenAI(model="gpt-4o-mini")
+    chain = prompt | chat
+    handler = CallbackHandler()
+    config = {"callbacks": [handler]}
+
+    chain.invoke({"test_param": "in a funny way"}, config)
+
+    # Invoke again to force cached token usage
+    chain.invoke({"test_param": "in a funny way"}, config)
+
+    handler.flush()
+
+    trace = get_api().trace.get(handler.get_trace_id())
+
+    generation = next((o for o in trace.observations if o.type == "GENERATION"))
+
+    assert generation.usage_details["input_cache_read"] > 0
+    assert (
+        generation.usage_details["input"]
+        + generation.usage_details["input_cache_read"]
+        + generation.usage_details["output"]
+        == generation.usage_details["total"]
+    )
+
+    assert generation.cost_details["input_cache_read"] > 0
+    assert (
+        generation.cost_details["input"]
+        + generation.cost_details["input_cache_read"]
+        + generation.cost_details["output"]
+        == generation.cost_details["total"]
+    )