4 changes: 3 additions & 1 deletion README.md
@@ -173,7 +173,9 @@ You can customize the processing with additional optional arguments:
 --max-tokens-per-node       Max tokens per node (default: 20000)
 --if-add-node-id            Add node ID (yes/no, default: yes)
 --if-add-node-summary       Add node summary (yes/no, default: yes)
---if-add-doc-description    Add doc description (yes/no, default: yes)
+--if-add-doc-description    Add doc description (yes/no, default: no)
+--if-add-node-text          Add node text (yes/no, default: no)
+--if-report-usage           Report token usage (yes/no, default: no)
 ```
 </details>

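For review context, a minimal sketch of what `--if-report-usage yes` wires up, using only names this PR introduces (`UsageTracker`, `set_usage_tracker`, `to_dict`); the actual processing call is elided, and any PageIndex pipeline entry point would do:

```python
# Sketch, not part of the diff: install a process-wide tracker, run any
# PageIndex processing, then dump the aggregated token counts.
import json

from pageindex.utils import UsageTracker, set_usage_tracker

tracker = UsageTracker()
set_usage_tracker(tracker)  # the ChatGPT_API* helpers now report into it

# ... run page_index_main(...) or any other pipeline here ...

print(json.dumps(tracker.to_dict(), indent=2))
```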
100 changes: 91 additions & 9 deletions pageindex/utils.py
@@ -8,6 +8,7 @@
 import PyPDF2
 import copy
 import asyncio
+import threading
 import pymupdf
 from io import BytesIO
 from dotenv import load_dotenv
@@ -19,21 +20,96 @@
 
 CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
 
+
+class UsageTracker:
+    def __init__(self):
+        self._lock = threading.Lock()
+        self.calls = 0
+        self.prompt_tokens = 0
+        self.completion_tokens = 0
+        self.total_tokens = 0
+        self.by_model = {}
+
+    def add(self, usage, model=None):
+        if not usage:
+            return
+
+        if isinstance(usage, dict):
+            prompt_tokens = usage.get("prompt_tokens")
+            completion_tokens = usage.get("completion_tokens")
+            total_tokens = usage.get("total_tokens")
+        else:
+            prompt_tokens = getattr(usage, "prompt_tokens", None)
+            completion_tokens = getattr(usage, "completion_tokens", None)
+            total_tokens = getattr(usage, "total_tokens", None)
+
+        if prompt_tokens is None and completion_tokens is None and total_tokens is None:
+            return
+
+        with self._lock:
+            self.calls += 1
+            if prompt_tokens is not None:
+                self.prompt_tokens += int(prompt_tokens)
+            if completion_tokens is not None:
+                self.completion_tokens += int(completion_tokens)
+            if total_tokens is not None:
+                self.total_tokens += int(total_tokens)
+
+            model_key = model or "unknown"
+            if model_key not in self.by_model:
+                self.by_model[model_key] = {
+                    "calls": 0,
+                    "prompt_tokens": 0,
+                    "completion_tokens": 0,
+                    "total_tokens": 0,
+                }
+            model_row = self.by_model[model_key]
+            model_row["calls"] += 1
+            if prompt_tokens is not None:
+                model_row["prompt_tokens"] += int(prompt_tokens)
+            if completion_tokens is not None:
+                model_row["completion_tokens"] += int(completion_tokens)
+            if total_tokens is not None:
+                model_row["total_tokens"] += int(total_tokens)
+
+    def to_dict(self):
+        with self._lock:
+            return {
+                "calls": self.calls,
+                "prompt_tokens": self.prompt_tokens,
+                "completion_tokens": self.completion_tokens,
+                "total_tokens": self.total_tokens,
+                "by_model": copy.deepcopy(self.by_model),
+            }
+
+
+_USAGE_TRACKER = None
+
+
+def set_usage_tracker(tracker):
+    global _USAGE_TRACKER
+    _USAGE_TRACKER = tracker
+
+
+def get_usage_tracker():
+    return _USAGE_TRACKER
+
+
 def count_tokens(text, model=None):
     if not text:
         return 0
     enc = tiktoken.encoding_for_model(model)
     tokens = enc.encode(text)
     return len(tokens)
 
-def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
+def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None, usage_tracker=None):
     max_retries = 10
     client = openai.OpenAI(api_key=api_key)
+    tracker = usage_tracker or get_usage_tracker()
     for i in range(max_retries):
         try:
             if chat_history:
-                messages = chat_history
-                messages.append({"role": "user", "content": prompt})
+                messages = list(chat_history) + [{"role": "user", "content": prompt}]
             else:
                 messages = [{"role": "user", "content": prompt}]
 
@@ -42,6 +118,8 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
                 messages=messages,
                 temperature=0,
             )
+            if tracker and getattr(response, "usage", None) is not None:
+                tracker.add(response.usage, model=model)
             if response.choices[0].finish_reason == "length":
                 return response.choices[0].message.content, "max_output_reached"
             else:
@@ -58,14 +136,14 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
 
 
 
-def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
+def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None, usage_tracker=None):
     max_retries = 10
     client = openai.OpenAI(api_key=api_key)
+    tracker = usage_tracker or get_usage_tracker()
     for i in range(max_retries):
         try:
             if chat_history:
-                messages = chat_history
-                messages.append({"role": "user", "content": prompt})
+                messages = list(chat_history) + [{"role": "user", "content": prompt}]
             else:
                 messages = [{"role": "user", "content": prompt}]
 
@@ -74,7 +152,8 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
                 messages=messages,
                 temperature=0,
             )
-
+            if tracker and getattr(response, "usage", None) is not None:
+                tracker.add(response.usage, model=model)
             return response.choices[0].message.content
         except Exception as e:
             print('************* Retrying *************')
@@ -86,9 +165,10 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
     return "Error"
 
 
-async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
+async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY, usage_tracker=None):
     max_retries = 10
     messages = [{"role": "user", "content": prompt}]
+    tracker = usage_tracker or get_usage_tracker()
     for i in range(max_retries):
         try:
             async with openai.AsyncOpenAI(api_key=api_key) as client:
@@ -97,6 +177,8 @@ async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY, usage_tracker=None):
                     messages=messages,
                     temperature=0,
                 )
+                if tracker and getattr(response, "usage", None) is not None:
+                    tracker.add(response.usage, model=model)
                 return response.choices[0].message.content
         except Exception as e:
             print('************* Retrying *************')
@@ -709,4 +791,4 @@ def load(self, user_opt=None) -> config:
 
         self._validate_keys(user_dict)
         merged = {**self._default_dict, **user_dict}
-        return config(**merged)
\ No newline at end of file
+        return config(**merged)
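The lock in `UsageTracker` is load-bearing: `ChatGPT_API` can run on worker threads while `ChatGPT_API_async` reports from the event loop, all into one shared tracker. A standalone check of the counting invariant (hypothetical token numbers, no API calls):

```python
# Concurrent adds must not lose counts; every increment happens under
# UsageTracker._lock, so totals stay exact across threads.
import threading

from pageindex.utils import UsageTracker

tracker = UsageTracker()

def worker():
    for _ in range(1000):
        tracker.add({"prompt_tokens": 1, "completion_tokens": 1, "total_tokens": 2}, model="m")

threads = [threading.Thread(target=worker) for _ in range(8)]
for t in threads:
    t.start()
for t in threads:
    t.join()

assert tracker.calls == 8000 and tracker.total_tokens == 16000
assert tracker.by_model["m"]["calls"] == 8000
```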
14 changes: 13 additions & 1 deletion run_pageindex.py
@@ -27,6 +27,9 @@
                         help='Whether to add doc description to the doc')
     parser.add_argument('--if-add-node-text', type=str, default='no',
                         help='Whether to add text to the node')
+
+    parser.add_argument('--if-report-usage', type=str, default='no',
+                        help='Whether to report token usage for this run (yes/no)')
 
     # Markdown specific arguments
     parser.add_argument('--if-thinning', type=str, default='no',
@@ -63,6 +66,12 @@
             if_add_node_text=args.if_add_node_text
         )
 
+        usage_tracker = None
+        if args.if_report_usage.lower() == 'yes':
+            from pageindex.utils import UsageTracker, set_usage_tracker
+            usage_tracker = UsageTracker()
+            set_usage_tracker(usage_tracker)
+
         # Process the PDF
         toc_with_page_number = page_index_main(args.pdf_path, opt)
         print('Parsing done, saving to file...')
@@ -77,6 +86,9 @@
             json.dump(toc_with_page_number, f, indent=2)
 
         print(f'Tree structure saved to: {output_file}')
+        if usage_tracker:
+            print("Token usage:")
+            print(json.dumps(usage_tracker.to_dict(), indent=2))
 
     elif args.md_path:
         # Validate Markdown file
@@ -130,4 +142,4 @@
         with open(output_file, 'w', encoding='utf-8') as f:
             json.dump(toc_with_page_number, f, indent=2, ensure_ascii=False)
 
-        print(f'Tree structure saved to: {output_file}')
\ No newline at end of file
+        print(f'Tree structure saved to: {output_file}')
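Besides the global hook that `--if-report-usage` installs, every helper also accepts an explicit `usage_tracker=`, so accounting can be scoped to one call. A sketch, assuming a valid `CHATGPT_API_KEY` and a placeholder model name:

```python
# Per-call accounting via the new usage_tracker parameter; this bypasses
# whatever set_usage_tracker() installed globally. "gpt-4o" is a placeholder.
from pageindex.utils import ChatGPT_API, UsageTracker

local_tracker = UsageTracker()
answer = ChatGPT_API("gpt-4o", "Say hi", usage_tracker=local_tracker)
print(answer, local_tracker.to_dict()["total_tokens"])
```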
21 changes: 21 additions & 0 deletions tests/test_usage_tracker.py
@@ -0,0 +1,21 @@
+import unittest
+
+from pageindex.utils import UsageTracker
+
+
+class TestUsageTracker(unittest.TestCase):
+    def test_add_dict_usage(self) -> None:
+        tracker = UsageTracker()
+        tracker.add({"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}, model="test-model")
+
+        self.assertEqual(tracker.calls, 1)
+        self.assertEqual(tracker.prompt_tokens, 10)
+        self.assertEqual(tracker.completion_tokens, 5)
+        self.assertEqual(tracker.total_tokens, 15)
+        self.assertEqual(tracker.by_model["test-model"]["calls"], 1)
+
+    def test_ignores_missing_usage(self) -> None:
+        tracker = UsageTracker()
+        tracker.add(None, model="test-model")
+        tracker.add({}, model="test-model")
+        self.assertEqual(tracker.calls, 0)
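The suite covers the dict path and the missing-usage guard; the attribute path, which a real OpenAI `response.usage` object takes, could be exercised the same way. A possible extra case, with `types.SimpleNamespace` standing in for the response object:

```python
# Hypothetical extra test: attribute-style usage objects, the shape the
# OpenAI client returns. SimpleNamespace stands in for response.usage.
import unittest
from types import SimpleNamespace

from pageindex.utils import UsageTracker


class TestObjectUsage(unittest.TestCase):
    def test_add_object_usage(self) -> None:
        tracker = UsageTracker()
        tracker.add(
            SimpleNamespace(prompt_tokens=3, completion_tokens=4, total_tokens=7),
            model="test-model",
        )
        self.assertEqual(tracker.prompt_tokens, 3)
        self.assertEqual(tracker.total_tokens, 7)
```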