From 6201c9c14e7b390e841bf6c61b54911104ebac0a Mon Sep 17 00:00:00 2001 From: DankerMu Date: Tue, 27 Jan 2026 08:26:39 +0800 Subject: [PATCH] Add token usage reporting --- README.md | 4 +- pageindex/utils.py | 100 ++++++++++++++++++++++++++++++++---- run_pageindex.py | 14 ++++- tests/test_usage_tracker.py | 21 ++++++++ 4 files changed, 128 insertions(+), 11 deletions(-) create mode 100644 tests/test_usage_tracker.py diff --git a/README.md b/README.md index 879a67efc..d9f675fb8 100644 --- a/README.md +++ b/README.md @@ -173,7 +173,9 @@ You can customize the processing with additional optional arguments: --max-tokens-per-node Max tokens per node (default: 20000) --if-add-node-id Add node ID (yes/no, default: yes) --if-add-node-summary Add node summary (yes/no, default: yes) ---if-add-doc-description Add doc description (yes/no, default: yes) +--if-add-doc-description Add doc description (yes/no, default: no) +--if-add-node-text Add node text (yes/no, default: no) +--if-report-usage Report token usage (yes/no, default: no) ``` diff --git a/pageindex/utils.py b/pageindex/utils.py index dc7acd888..7fac705fc 100644 --- a/pageindex/utils.py +++ b/pageindex/utils.py @@ -8,6 +8,7 @@ import PyPDF2 import copy import asyncio +import threading import pymupdf from io import BytesIO from dotenv import load_dotenv @@ -19,6 +20,81 @@ CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY") + +class UsageTracker: + def __init__(self): + self._lock = threading.Lock() + self.calls = 0 + self.prompt_tokens = 0 + self.completion_tokens = 0 + self.total_tokens = 0 + self.by_model = {} + + def add(self, usage, model=None): + if not usage: + return + + if isinstance(usage, dict): + prompt_tokens = usage.get("prompt_tokens") + completion_tokens = usage.get("completion_tokens") + total_tokens = usage.get("total_tokens") + else: + prompt_tokens = getattr(usage, "prompt_tokens", None) + completion_tokens = getattr(usage, "completion_tokens", None) + total_tokens = getattr(usage, "total_tokens", None) + + if prompt_tokens is None and completion_tokens is None and total_tokens is None: + return + + with self._lock: + self.calls += 1 + if prompt_tokens is not None: + self.prompt_tokens += int(prompt_tokens) + if completion_tokens is not None: + self.completion_tokens += int(completion_tokens) + if total_tokens is not None: + self.total_tokens += int(total_tokens) + + model_key = model or "unknown" + if model_key not in self.by_model: + self.by_model[model_key] = { + "calls": 0, + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + } + model_row = self.by_model[model_key] + model_row["calls"] += 1 + if prompt_tokens is not None: + model_row["prompt_tokens"] += int(prompt_tokens) + if completion_tokens is not None: + model_row["completion_tokens"] += int(completion_tokens) + if total_tokens is not None: + model_row["total_tokens"] += int(total_tokens) + + def to_dict(self): + with self._lock: + return { + "calls": self.calls, + "prompt_tokens": self.prompt_tokens, + "completion_tokens": self.completion_tokens, + "total_tokens": self.total_tokens, + "by_model": copy.deepcopy(self.by_model), + } + + +_USAGE_TRACKER = None + + +def set_usage_tracker(tracker): + global _USAGE_TRACKER + _USAGE_TRACKER = tracker + + +def get_usage_tracker(): + return _USAGE_TRACKER + + def count_tokens(text, model=None): if not text: return 0 @@ -26,14 +102,14 @@ def count_tokens(text, model=None): tokens = enc.encode(text) return len(tokens) -def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None): 
+def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None, usage_tracker=None):
     max_retries = 10
     client = openai.OpenAI(api_key=api_key)
+    tracker = usage_tracker or get_usage_tracker()
     for i in range(max_retries):
         try:
             if chat_history:
-                messages = chat_history
-                messages.append({"role": "user", "content": prompt})
+                messages = list(chat_history) + [{"role": "user", "content": prompt}]
             else:
                 messages = [{"role": "user", "content": prompt}]
 
@@ -42,6 +118,8 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_
                 messages=messages,
                 temperature=0,
             )
+            if tracker and getattr(response, "usage", None) is not None:
+                tracker.add(response.usage, model=model)
             if response.choices[0].finish_reason == "length":
                 return response.choices[0].message.content, "max_output_reached"
             else:
@@ -58,14 +136,14 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_
 
 
 
-def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
+def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None, usage_tracker=None):
     max_retries = 10
     client = openai.OpenAI(api_key=api_key)
+    tracker = usage_tracker or get_usage_tracker()
     for i in range(max_retries):
         try:
             if chat_history:
-                messages = chat_history
-                messages.append({"role": "user", "content": prompt})
+                messages = list(chat_history) + [{"role": "user", "content": prompt}]
             else:
                 messages = [{"role": "user", "content": prompt}]
 
@@ -74,7 +152,8 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
                 messages=messages,
                 temperature=0,
             )
-
+            if tracker and getattr(response, "usage", None) is not None:
+                tracker.add(response.usage, model=model)
             return response.choices[0].message.content
         except Exception as e:
             print('************* Retrying *************')
@@ -86,9 +165,10 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
     return "Error"
 
 
-async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
+async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY, usage_tracker=None):
     max_retries = 10
     messages = [{"role": "user", "content": prompt}]
+    tracker = usage_tracker or get_usage_tracker()
     for i in range(max_retries):
         try:
             async with openai.AsyncOpenAI(api_key=api_key) as client:
@@ -97,6 +177,8 @@ async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
                     messages=messages,
                     temperature=0,
                 )
+                if tracker and getattr(response, "usage", None) is not None:
+                    tracker.add(response.usage, model=model)
                 return response.choices[0].message.content
         except Exception as e:
             print('************* Retrying *************')
@@ -709,4 +791,4 @@ def load(self, user_opt=None) -> config:
         self._validate_keys(user_dict)
         merged = {**self._default_dict, **user_dict}
 
-        return config(**merged)
\ No newline at end of file
+        return config(**merged)
diff --git a/run_pageindex.py b/run_pageindex.py
index 107024505..4630e604e 100644
--- a/run_pageindex.py
+++ b/run_pageindex.py
@@ -27,6 +27,9 @@
                         help='Whether to add doc description to the doc')
     parser.add_argument('--if-add-node-text', type=str, default='no',
                         help='Whether to add text to the node')
+
+    parser.add_argument('--if-report-usage', type=str, default='no',
+                        help='Whether to report token usage for this run (yes/no)')
 
     # Markdown specific arguments
     parser.add_argument('--if-thinning', type=str, default='no',
@@ -63,6 +66,12 @@
         if_add_node_text=args.if_add_node_text
     )
 
+    usage_tracker = None
+    if args.if_report_usage.lower() == 'yes':
+        from pageindex.utils import UsageTracker, set_usage_tracker
+        usage_tracker = UsageTracker()
+        set_usage_tracker(usage_tracker)
+
     # Process the PDF
     toc_with_page_number = page_index_main(args.pdf_path, opt)
     print('Parsing done, saving to file...')
@@ -77,6 +86,9 @@
         json.dump(toc_with_page_number, f, indent=2)
 
     print(f'Tree structure saved to: {output_file}')
+    if usage_tracker:
+        print("Token usage:")
+        print(json.dumps(usage_tracker.to_dict(), indent=2))
 
 elif args.md_path:
     # Validate Markdown file
@@ -130,4 +142,4 @@
     with open(output_file, 'w', encoding='utf-8') as f:
         json.dump(toc_with_page_number, f, indent=2, ensure_ascii=False)
 
-    print(f'Tree structure saved to: {output_file}')
\ No newline at end of file
+    print(f'Tree structure saved to: {output_file}')
diff --git a/tests/test_usage_tracker.py b/tests/test_usage_tracker.py
new file mode 100644
index 000000000..2916aa9ad
--- /dev/null
+++ b/tests/test_usage_tracker.py
@@ -0,0 +1,21 @@
+import unittest
+
+from pageindex.utils import UsageTracker
+
+
+class TestUsageTracker(unittest.TestCase):
+    def test_add_dict_usage(self) -> None:
+        tracker = UsageTracker()
+        tracker.add({"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, model="test-model")
+
+        self.assertEqual(tracker.calls, 1)
+        self.assertEqual(tracker.prompt_tokens, 10)
+        self.assertEqual(tracker.completion_tokens, 5)
+        self.assertEqual(tracker.total_tokens, 15)
+        self.assertEqual(tracker.by_model["test-model"]["calls"], 1)
+
+    def test_ignores_missing_usage(self) -> None:
+        tracker = UsageTracker()
+        tracker.add(None, model="test-model")
+        tracker.add({}, model="test-model")
+        self.assertEqual(tracker.calls, 0)
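
Usage note: the tracker can also be driven programmatically, without the
--if-report-usage flag. A minimal sketch, assuming only names this patch
introduces (UsageTracker, set_usage_tracker, ChatGPT_API) and a configured
CHATGPT_API_KEY; the model name and prompt below are placeholders, not values
from the patch:

    # Sketch: record token usage for ad-hoc calls outside run_pageindex.py.
    import json

    from pageindex.utils import ChatGPT_API, UsageTracker, set_usage_tracker

    tracker = UsageTracker()
    set_usage_tracker(tracker)  # module-level fallback consulted by all three API helpers

    # Any call made while the tracker is registered reports its response.usage.
    ChatGPT_API(model="gpt-4o-mini", prompt="Reply with a single word.")

    # Aggregated totals, overall and broken down per model.
    print(json.dumps(tracker.to_dict(), indent=2))

Passing usage_tracker= to an individual helper overrides the global fallback for
that call, which scopes measurement to a single request. The threading.Lock in
UsageTracker.add() keeps the shared counters consistent if the helpers are
invoked concurrently from multiple threads alongside ChatGPT_API_async.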