diff --git a/src/agentlab/llm/tracking.py b/src/agentlab/llm/tracking.py index 8e3d812a..6a08839b 100644 --- a/src/agentlab/llm/tracking.py +++ b/src/agentlab/llm/tracking.py @@ -1,7 +1,7 @@ -from functools import cache import os import threading from contextlib import contextmanager +from functools import cache import requests from langchain_community.callbacks.openai_info import MODEL_COST_PER_1K_TOKENS @@ -10,10 +10,13 @@ class LLMTracker: - def __init__(self): + def __init__(self, suffix=""): self.input_tokens = 0 self.output_tokens = 0 self.cost = 0.0 + self.input_tokens_key = "input_tokens_" + suffix if suffix else "input_tokens" + self.output_tokens_key = "output_tokens_" + suffix if suffix else "output_tokens" + self.cost_key = "cost_" + suffix if suffix else "cost" def __call__(self, input_tokens: int, output_tokens: int, cost: float): self.input_tokens += input_tokens @@ -23,9 +26,9 @@ def __call__(self, input_tokens: int, output_tokens: int, cost: float): @property def stats(self): return { - "input_tokens": self.input_tokens, - "output_tokens": self.output_tokens, - "cost": self.cost, + self.input_tokens_key: self.input_tokens, + self.output_tokens_key: self.output_tokens, + self.cost_key: self.cost, } def add_tracker(self, tracker: "LLMTracker"): @@ -36,12 +39,12 @@ def __repr__(self): @contextmanager -def set_tracker(): +def set_tracker(suffix=""): global TRACKER if not hasattr(TRACKER, "instance"): TRACKER.instance = None previous_tracker = TRACKER.instance # type: LLMTracker - TRACKER.instance = LLMTracker() + TRACKER.instance = LLMTracker(suffix) try: yield TRACKER.instance finally: @@ -52,9 +55,9 @@ def set_tracker(): TRACKER.instance = previous_tracker -def cost_tracker_decorator(get_action): +def cost_tracker_decorator(get_action, suffix=""): def wrapper(self, obs): - with set_tracker() as tracker: + with set_tracker(suffix) as tracker: action, agent_info = get_action(self, obs) agent_info.get("stats").update(tracker.stats) return action, agent_info