diff --git a/.gitignore b/.gitignore
index 47d38baef..2ec4f07d4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,10 +6,13 @@ temp/*
 chroma-collections.parquet
 chroma-embeddings.parquet
 .DS_Store
+.venv/
+.env
 .env*
 notebook
 SDK/*
 log/*
 logs/
+results/
 parts/*
 json_results/*
diff --git a/README.md b/README.md
index 879a67efc..60e059386 100644
--- a/README.md
+++ b/README.md
@@ -149,10 +149,13 @@ pip3 install --upgrade -r requirements.txt
 
 ### 2. Set your OpenAI API key
 
-Create a `.env` file in the root directory and add your API key:
+Create a `.env` file in the root directory and add your API key (supports `CHATGPT_API_KEY` or `OPENAI_API_KEY`).
+If you use an OpenAI-compatible endpoint/proxy, also set a base URL (include `/v1`):
 
 ```bash
 CHATGPT_API_KEY=your_openai_key_here
+CHATGPT_BASE_URL=https://your-base-url/v1
+CHATGPT_MODEL=gpt-4o-2024-11-20
 ```
 
 ### 3. Run PageIndex on your PDF
@@ -160,6 +163,10 @@ CHATGPT_API_KEY=your_openai_key_here
 ```bash
 python3 run_pageindex.py --pdf_path /path/to/your/document.pdf
 ```
+If your file path contains spaces, wrap it in quotes:
+```bash
+python3 run_pageindex.py --pdf_path "/path/with spaces/document.pdf"
+```
 
 Optional parameters
 
@@ -173,7 +180,8 @@ You can customize the processing with additional optional arguments:
 --max-tokens-per-node      Max tokens per node (default: 20000)
 --if-add-node-id           Add node ID (yes/no, default: yes)
 --if-add-node-summary      Add node summary (yes/no, default: yes)
---if-add-doc-description   Add doc description (yes/no, default: yes)
+--if-add-doc-description   Add doc description (yes/no, default: no)
+--if-add-node-text         Add node text (yes/no, default: no)
 ```
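Since the key/base-URL lookup is the part most likely to trip users, here is the resolution order in one self-contained sketch. It mirrors the `or`-fallbacks this PR adds to `pageindex/utils.py` further down; nothing here is new API:

```python
import os

# Resolution order added in pageindex/utils.py below:
# CHATGPT_* takes precedence, OPENAI_* is the fallback.
api_key = os.getenv("CHATGPT_API_KEY") or os.getenv("OPENAI_API_KEY")
base_url = os.getenv("CHATGPT_BASE_URL") or os.getenv("OPENAI_BASE_URL")  # include /v1 for proxies
model = os.getenv("CHATGPT_MODEL") or os.getenv("OPENAI_MODEL") or "gpt-4o-2024-11-20"
```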
diff --git a/pageindex/page_index.py b/pageindex/page_index.py
index 882fb5dea..1ffbdb448 100644
--- a/pageindex/page_index.py
+++ b/pageindex/page_index.py
@@ -166,34 +166,35 @@ def extract_toc_content(content, model=None):
     Directly return the full table of contents content. Do not output anything else."""
 
     response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt)
-
+    if finish_reason == "error" and not response:
+        raise Exception("Failed to extract table of contents (API error)")
+
     if_complete = check_if_toc_transformation_is_complete(content, response, model)
     if if_complete == "yes" and finish_reason == "finished":
         return response
-
-    chat_history = [
-        {"role": "user", "content": prompt},
-        {"role": "assistant", "content": response},
-    ]
-    prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure"""
-    new_response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, chat_history=chat_history)
-    response = response + new_response
-    if_complete = check_if_toc_transformation_is_complete(content, response, model)
-
+
+    max_attempts = 5
+    attempt = 0
+    continue_prompt = "please continue the generation of table of contents, directly output the remaining part of the structure"
+    while not (if_complete == "yes" and finish_reason == "finished"):
+        attempt += 1
+        if attempt > max_attempts:
+            raise Exception("Failed to complete table of contents after maximum retries")
+
         chat_history = [
-        {"role": "user", "content": prompt},
-        {"role": "assistant", "content": response},
+            {"role": "user", "content": prompt},
+            {"role": "assistant", "content": response},
         ]
-        prompt = f"""please continue the generation of table of contents , directly output the remaining part of the structure"""
-        new_response, finish_reason = ChatGPT_API_with_finish_reason(model=model, prompt=prompt, chat_history=chat_history)
+        new_response, finish_reason = ChatGPT_API_with_finish_reason(
+            model=model, prompt=continue_prompt, chat_history=chat_history
+        )
+        if finish_reason == "error" and not new_response:
+            raise Exception("Failed to continue table of contents extraction (API error)")
+
         response = response + new_response
         if_complete = check_if_toc_transformation_is_complete(content, response, model)
-
-        # Optional: Add a maximum retry limit to prevent infinite loops
-        if len(chat_history) > 5:  # Arbitrary limit of 10 attempts
-            raise Exception('Failed to complete table of contents after maximum retries')
-
+
     return response
 
 def detect_page_index(toc_content, model=None):
@@ -414,6 +415,29 @@ def add_page_offset_to_toc_json(data, offset):
 
     return data
 
+
+def sanitize_toc_page_numbers(toc_items):
+    """
+    Best-effort cleanup for parsed TOC page numbers.
+
+    Some PDFs have TOCs with inconsistent numbering (resets/spikes). Since downstream
+    logic assumes page numbers are monotonic, treat non-monotonic integers as invalid
+    so they can be re-inferred later (e.g. via `process_none_page_numbers` / verification).
+ """ + last_page = None + for item in toc_items: + page = item.get("page") + if not isinstance(page, int): + continue + if page <= 0: + item["page"] = None + continue + if last_page is not None and page < last_page: + item["page"] = None + continue + last_page = page + return toc_items + + def page_list_to_group_text(page_contents, token_lengths, max_tokens=20000, overlap_page=1): num_tokens = sum(token_lengths) @@ -565,14 +589,14 @@ def generate_toc_init(part, model=None): else: raise Exception(f'finish reason: {finish_reason}') -def process_no_toc(page_list, start_index=1, model=None, logger=None): +def process_no_toc(page_list, start_index=1, model=None, logger=None, max_tokens=20000): page_contents=[] token_lengths=[] for page_index in range(start_index, start_index+len(page_list)): page_text = f"\n{page_list[page_index-start_index][0]}\n\n\n" page_contents.append(page_text) token_lengths.append(count_tokens(page_text, model)) - group_texts = page_list_to_group_text(page_contents, token_lengths) + group_texts = page_list_to_group_text(page_contents, token_lengths, max_tokens=max_tokens) logger.info(f'len(group_texts): {len(group_texts)}') toc_with_page_number= generate_toc_init(group_texts[0], model) @@ -586,7 +610,7 @@ def process_no_toc(page_list, start_index=1, model=None, logger=None): return toc_with_page_number -def process_toc_no_page_numbers(toc_content, toc_page_list, page_list, start_index=1, model=None, logger=None): +def process_toc_no_page_numbers(toc_content, toc_page_list, page_list, start_index=1, model=None, logger=None, max_tokens=20000): page_contents=[] token_lengths=[] toc_content = toc_transformer(toc_content, model) @@ -596,7 +620,7 @@ def process_toc_no_page_numbers(toc_content, toc_page_list, page_list, start_in page_contents.append(page_text) token_lengths.append(count_tokens(page_text, model)) - group_texts = page_list_to_group_text(page_contents, token_lengths) + group_texts = page_list_to_group_text(page_contents, token_lengths, max_tokens=max_tokens) logger.info(f'len(group_texts): {len(group_texts)}') toc_with_page_number=copy.deepcopy(toc_content) @@ -613,6 +637,7 @@ def process_toc_no_page_numbers(toc_content, toc_page_list, page_list, start_in def process_toc_with_page_numbers(toc_content, toc_page_list, page_list, toc_check_page_num=None, model=None, logger=None): toc_with_page_number = toc_transformer(toc_content, model) + toc_with_page_number = sanitize_toc_page_numbers(toc_with_page_number) logger.info(f'toc_with_page_number: {toc_with_page_number}') toc_no_page_number = remove_page_number(copy.deepcopy(toc_with_page_number)) @@ -674,11 +699,11 @@ def process_none_page_numbers(toc_items, page_list, start_index=1, model=None): continue item_copy = copy.deepcopy(item) - del item_copy['page'] + item_copy.pop('page', None) result = add_page_number_to_toc(page_contents, item_copy, model) if isinstance(result[0]['physical_index'], str) and result[0]['physical_index'].startswith('').strip()) - del item['page'] + item.pop('page', None) return toc_items @@ -758,13 +783,13 @@ async def fix_incorrect_toc(toc_with_page_number, page_list, incorrect_results, incorrect_results_and_range_logs = [] # Helper function to process and check a single incorrect item async def process_and_check_item(incorrect_item): - list_index = incorrect_item['list_index'] + toc_list_index = incorrect_item['list_index'] # Check if list_index is valid - if list_index < 0 or list_index >= len(toc_with_page_number): + if toc_list_index < 0 or toc_list_index >= len(toc_with_page_number): 
@@ -758,13 +783,13 @@ async def fix_incorrect_toc(toc_with_page_number, page_list, incorrect_results, start_index=1, model=None):
     incorrect_results_and_range_logs = []
 
     # Helper function to process and check a single incorrect item
     async def process_and_check_item(incorrect_item):
-        list_index = incorrect_item['list_index']
+        toc_list_index = incorrect_item['list_index']
 
         # Check if list_index is valid
-        if list_index < 0 or list_index >= len(toc_with_page_number):
+        if toc_list_index < 0 or toc_list_index >= len(toc_with_page_number):
             # Return an invalid result for out-of-bounds indices
             return {
-                'list_index': list_index,
+                'list_index': toc_list_index,
                 'title': incorrect_item['title'],
                 'physical_index': incorrect_item.get('physical_index'),
                 'is_valid': False
             }
@@ -772,7 +797,7 @@ async def process_and_check_item(incorrect_item):
 
         # Find the previous correct item
         prev_correct = None
-        for i in range(list_index-1, -1, -1):
+        for i in range(toc_list_index - 1, -1, -1):
             if i not in incorrect_indices and i >= 0 and i < len(toc_with_page_number):
                 physical_index = toc_with_page_number[i].get('physical_index')
                 if physical_index is not None:
@@ -784,7 +809,7 @@
 
         # Find the next correct item
         next_correct = None
-        for i in range(list_index+1, len(toc_with_page_number)):
+        for i in range(toc_list_index + 1, len(toc_with_page_number)):
             if i not in incorrect_indices and i >= 0 and i < len(toc_with_page_number):
                 physical_index = toc_with_page_number[i].get('physical_index')
                 if physical_index is not None:
@@ -795,7 +820,7 @@
             next_correct = end_index
 
         incorrect_results_and_range_logs.append({
-            'list_index': list_index,
+            'list_index': toc_list_index,
             'title': incorrect_item['title'],
             'prev_correct': prev_correct,
             'next_correct': next_correct
@@ -804,9 +829,9 @@
         page_contents=[]
         for page_index in range(prev_correct, next_correct+1):
             # Add bounds checking to prevent IndexError
-            list_index = page_index - start_index
-            if list_index >= 0 and list_index < len(page_list):
-                page_text = f"<physical_index_{page_index}>\n{page_list[list_index][0]}\n<physical_index_{page_index}>\n\n"
+            page_list_index = page_index - start_index
+            if 0 <= page_list_index < len(page_list):
+                page_text = f"<physical_index_{page_index}>\n{page_list[page_list_index][0]}\n<physical_index_{page_index}>\n\n"
                 page_contents.append(page_text)
             else:
                 continue
@@ -820,7 +845,7 @@
         check_result = await check_title_appearance(check_item, page_list, start_index, model)
 
         return {
-            'list_index': list_index,
+            'list_index': toc_list_index,
             'title': incorrect_item['title'],
             'physical_index': physical_index_int,
             'is_valid': check_result['answer'] == 'yes'
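The `list_index` to `toc_list_index` / `page_list_index` renames above fix a real shadowing bug: the page-window loop overwrote the outer `list_index`, so the returned dict reported the wrong index. A minimal repro of the pattern (simplified names, not the project code):

```python
# Minimal repro of the shadowing bug fixed above:
def buggy(list_index, page_list):
    for page_index in range(3):
        list_index = page_index - 1      # inner loop reuses the outer name...
    return {"list_index": list_index}    # ...so the caller's value is lost

print(buggy(7, []))  # {'list_index': 1}, not {'list_index': 7}
```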
@@ -951,13 +976,36 @@ async def verify_toc(page_list, list_result, start_index=1, N=None, model=None):
 
 async def meta_processor(page_list, mode=None, toc_content=None, toc_page_list=None, start_index=1, opt=None, logger=None):
     print(mode)
     print(f'start_index: {start_index}')
-
-    if mode == 'process_toc_with_page_numbers':
-        toc_with_page_number = process_toc_with_page_numbers(toc_content, toc_page_list, page_list, toc_check_page_num=opt.toc_check_page_num, model=opt.model, logger=logger)
-    elif mode == 'process_toc_no_page_numbers':
-        toc_with_page_number = process_toc_no_page_numbers(toc_content, toc_page_list, page_list, model=opt.model, logger=logger)
-    else:
-        toc_with_page_number = process_no_toc(page_list, start_index=start_index, model=opt.model, logger=logger)
+    max_tokens = getattr(opt, "max_token_num_each_node", 20000)
+    try:
+        if mode == 'process_toc_with_page_numbers':
+            toc_with_page_number = process_toc_with_page_numbers(toc_content, toc_page_list, page_list, toc_check_page_num=opt.toc_check_page_num, model=opt.model, logger=logger)
+        elif mode == 'process_toc_no_page_numbers':
+            toc_with_page_number = process_toc_no_page_numbers(
+                toc_content,
+                toc_page_list,
+                page_list,
+                start_index=start_index,
+                model=opt.model,
+                logger=logger,
+                max_tokens=max_tokens,
+            )
+        else:
+            toc_with_page_number = process_no_toc(
+                page_list,
+                start_index=start_index,
+                model=opt.model,
+                logger=logger,
+                max_tokens=max_tokens,
+            )
+    except Exception as e:
+        if logger:
+            logger.error({"mode": mode, "start_index": start_index, "error": str(e)})
+        if mode == 'process_toc_with_page_numbers':
+            return await meta_processor(page_list, mode='process_toc_no_page_numbers', toc_content=toc_content, toc_page_list=toc_page_list, start_index=start_index, opt=opt, logger=logger)
+        if mode == 'process_toc_no_page_numbers':
+            return await meta_processor(page_list, mode='process_no_toc', start_index=start_index, opt=opt, logger=logger)
+        return [{"structure": "1", "title": "Document", "physical_index": start_index}]
 
     toc_with_page_number = [item for item in toc_with_page_number if item.get('physical_index') is not None]
@@ -971,7 +1019,7 @@
     accuracy, incorrect_results = await verify_toc(page_list, toc_with_page_number, start_index=start_index, model=opt.model)
 
     logger.info({
-        'mode': 'process_toc_with_page_numbers',
+        'mode': mode,
         'accuracy': accuracy,
         'incorrect_results': incorrect_results
     })
@@ -986,7 +1034,9 @@
     elif mode == 'process_toc_no_page_numbers':
         return await meta_processor(page_list, mode='process_no_toc', start_index=start_index, opt=opt, logger=logger)
     else:
-        raise Exception('Processing failed')
+        if logger:
+            logger.error({"mode": mode, "start_index": start_index, "accuracy": accuracy, "incorrect_results_count": len(incorrect_results)})
+        return [{"structure": "1", "title": "Document", "physical_index": start_index}]
 
 
 async def process_large_node_recursively(node, page_list, opt=None, logger=None):
@@ -1022,15 +1072,25 @@ async def tree_parser(page_list, opt, doc=None, logger=None):
     check_toc_result = check_toc(page_list, opt)
     logger.info(check_toc_result)
 
-    if check_toc_result.get("toc_content") and check_toc_result["toc_content"].strip() and check_toc_result["page_index_given_in_toc"] == "yes":
-        toc_with_page_number = await meta_processor(
-            page_list,
-            mode='process_toc_with_page_numbers',
-            start_index=1,
-            toc_content=check_toc_result['toc_content'],
-            toc_page_list=check_toc_result['toc_page_list'],
-            opt=opt,
-            logger=logger)
+    if check_toc_result.get("toc_content") and check_toc_result["toc_content"].strip():
+        if check_toc_result["page_index_given_in_toc"] == "yes":
+            toc_with_page_number = await meta_processor(
+                page_list,
+                mode='process_toc_with_page_numbers',
+                start_index=1,
+                toc_content=check_toc_result['toc_content'],
+                toc_page_list=check_toc_result['toc_page_list'],
+                opt=opt,
+                logger=logger)
+        else:
+            toc_with_page_number = await meta_processor(
+                page_list,
+                mode='process_toc_no_page_numbers',
+                start_index=1,
+                toc_content=check_toc_result['toc_content'],
+                toc_page_list=check_toc_result['toc_page_list'],
+                opt=opt,
+                logger=logger)
     else:
         toc_with_page_number = await meta_processor(
             page_list,
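Taken together, the `meta_processor` and `tree_parser` changes replace a hard failure with graceful degradation. A sketch of the chain (data only, not the literal code; the logic is in the hunks above):

```python
# Degradation order now implemented by meta_processor: each exception or
# low verify_toc accuracy steps down to the next mode in the list.
FALLBACK_CHAIN = [
    "process_toc_with_page_numbers",  # TOC found, page numbers printed in it
    "process_toc_no_page_numbers",    # TOC found, but no usable page numbers
    "process_no_toc",                 # no TOC at all: segment pages, generate one
]
LAST_RESORT = [{"structure": "1", "title": "Document", "physical_index": 1}]  # diff uses start_index here
```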
@@ -1066,7 +1126,7 @@ def page_index_main(doc, opt=None):
         raise ValueError("Unsupported input type. Expected a PDF file path or BytesIO object.")
 
     print('Parsing PDF...')
-    page_list = get_page_tokens(doc)
+    page_list = get_page_tokens(doc, model=getattr(opt, "model", None))
 
     logger.info({'total_page_number': len(page_list)})
     logger.info({'total_token': sum([page[1] for page in page_list])})
@@ -1141,4 +1201,4 @@ def validate_and_truncate_physical_indices(toc_with_page_number, page_list_length):
     if truncated_items:
         print(f"Truncated {len(truncated_items)} TOC items that exceeded document length")
 
-    return toc_with_page_number
\ No newline at end of file
+    return toc_with_page_number
diff --git a/pageindex/utils.py b/pageindex/utils.py
index dc7acd888..68a79a4a2 100644
--- a/pageindex/utils.py
+++ b/pageindex/utils.py
@@ -5,7 +5,10 @@ from datetime import datetime
 import time
 import json
-import PyPDF2
+try:
+    import pypdf as PyPDF2
+except ImportError:
+    import PyPDF2
 import copy
 import asyncio
 import pymupdf
@@ -16,24 +19,73 @@ import yaml
 from pathlib import Path
 from types import SimpleNamespace as config
+from typing import Optional
 
-CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY")
+CHATGPT_API_KEY = os.getenv("CHATGPT_API_KEY") or os.getenv("OPENAI_API_KEY")
+CHATGPT_BASE_URL = os.getenv("CHATGPT_BASE_URL") or os.getenv("OPENAI_BASE_URL")
+CHATGPT_MODEL = os.getenv("CHATGPT_MODEL") or os.getenv("OPENAI_MODEL")
+
+DEFAULT_MODEL = "gpt-4o-2024-11-20"
+
+
+def _default_model_from_env() -> Optional[str]:
+    return os.getenv("CHATGPT_MODEL") or os.getenv("OPENAI_MODEL")
+
+
+def _resolve_model_name(model: Optional[str]) -> str:
+    if model is None or model == DEFAULT_MODEL:
+        return _default_model_from_env() or DEFAULT_MODEL
+    return model
+
+
+def _tiktoken_encoder(model: Optional[str]):
+    model_name = _resolve_model_name(model)
+    try:
+        return tiktoken.encoding_for_model(model_name)
+    except Exception:
+        return tiktoken.get_encoding("cl100k_base")
+
+
+_SYNC_OPENAI_CLIENTS = {}
+_ASYNC_OPENAI_CLIENTS = {}
+
+
+def _client_cache_key(api_key: Optional[str], base_url: Optional[str]):
+    return (api_key or "", base_url or "")
+
+
+def _openai_client(api_key: Optional[str] = CHATGPT_API_KEY, base_url: Optional[str] = CHATGPT_BASE_URL):
+    key = _client_cache_key(api_key, base_url)
+    client = _SYNC_OPENAI_CLIENTS.get(key)
+    if client is None:
+        client = openai.OpenAI(api_key=api_key, base_url=base_url) if base_url else openai.OpenAI(api_key=api_key)
+        _SYNC_OPENAI_CLIENTS[key] = client
+    return client
+
+
+def _openai_async_client(api_key: Optional[str] = CHATGPT_API_KEY, base_url: Optional[str] = CHATGPT_BASE_URL):
+    key = _client_cache_key(api_key, base_url)
+    client = _ASYNC_OPENAI_CLIENTS.get(key)
+    if client is None:
+        client = openai.AsyncOpenAI(api_key=api_key, base_url=base_url) if base_url else openai.AsyncOpenAI(api_key=api_key)
+        _ASYNC_OPENAI_CLIENTS[key] = client
+    return client
 
 def count_tokens(text, model=None):
     if not text:
         return 0
-    enc = tiktoken.encoding_for_model(model)
+    enc = _tiktoken_encoder(model)
     tokens = enc.encode(text)
     return len(tokens)
 
 
 def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+    model = _resolve_model_name(model)
+    client = _openai_client(api_key=api_key)
 
     for i in range(max_retries):
         try:
             if chat_history:
-                messages = chat_history
-                messages.append({"role": "user", "content": prompt})
+                messages = list(chat_history) + [{"role": "user", "content": prompt}]
             else:
                 messages = [{"role": "user", "content": prompt}]
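To illustrate the new client cache: clients are keyed by `(api_key or "", base_url or "")`, so repeated calls with the same settings reuse one instance. A quick check (values below are made up):

```python
# Same key tuple -> cached instance is reused.
c1 = _openai_client(api_key="sk-test", base_url="https://proxy.example/v1")
c2 = _openai_client(api_key="sk-test", base_url="https://proxy.example/v1")
assert c1 is c2
c3 = _openai_client(api_key="sk-test")   # no base_url -> different key -> new client
assert c3 is not c1
```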
@@ -54,18 +106,18 @@ def ChatGPT_API_with_finish_reason(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
                 time.sleep(1)  # Wait for 1 second before retrying
             else:
                 logging.error('Max retries reached for prompt: ' + prompt)
-                return "Error"
+                return "", "error"
 
 
 def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
     max_retries = 10
-    client = openai.OpenAI(api_key=api_key)
+    model = _resolve_model_name(model)
+    client = _openai_client(api_key=api_key)
 
     for i in range(max_retries):
         try:
             if chat_history:
-                messages = chat_history
-                messages.append({"role": "user", "content": prompt})
+                messages = list(chat_history) + [{"role": "user", "content": prompt}]
             else:
                 messages = [{"role": "user", "content": prompt}]
 
@@ -88,16 +140,17 @@ def ChatGPT_API(model, prompt, api_key=CHATGPT_API_KEY, chat_history=None):
 
 async def ChatGPT_API_async(model, prompt, api_key=CHATGPT_API_KEY):
     max_retries = 10
+    model = _resolve_model_name(model)
     messages = [{"role": "user", "content": prompt}]
 
     for i in range(max_retries):
         try:
-            async with openai.AsyncOpenAI(api_key=api_key) as client:
-                response = await client.chat.completions.create(
-                    model=model,
-                    messages=messages,
-                    temperature=0,
-                )
-                return response.choices[0].message.content
+            client = _openai_async_client(api_key=api_key)
+            response = await client.chat.completions.create(
+                model=model,
+                messages=messages,
+                temperature=0,
+            )
+            return response.choices[0].message.content
         except Exception as e:
             print('************* Retrying *************')
             logging.error(f"Error: {e}")
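The model-resolution rules in `_resolve_model_name`, shown as a hypothetical example (model names are placeholders, not recommendations):

```python
# Hypothetical environment: CHATGPT_MODEL=gpt-4o-mini
_resolve_model_name(None)                 # -> "gpt-4o-mini"  (env default wins)
_resolve_model_name("gpt-4o-2024-11-20")  # -> "gpt-4o-mini"  (explicit DEFAULT_MODEL is
                                          #     treated as unset so env can override it)
_resolve_model_name("o3-mini")            # -> "o3-mini"      (any other explicit model kept)
```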
+ """ + node_map = {} + for node in structure_to_list(tree): + if not isinstance(node, dict): + continue + node_id = node.get("node_id") + if not node_id: + continue + + start_index = node.get("start_index") + end_index = node.get("end_index") + + if include_page_ranges: + if max_page is not None and isinstance(end_index, int): + end_index = min(end_index, max_page) + node_map[node_id] = {"node": node, "start_index": start_index, "end_index": end_index} + else: + node_copy = dict(node) + if start_index is not None and "page_index" not in node_copy: + node_copy["page_index"] = start_index + node_map[node_id] = node_copy + + return node_map + def get_leaf_nodes(structure): if isinstance(structure, dict): @@ -410,8 +495,9 @@ def add_preface_if_needed(data): -def get_page_tokens(pdf_path, model="gpt-4o-2024-11-20", pdf_parser="PyPDF2"): - enc = tiktoken.encoding_for_model(model) +def get_page_tokens(pdf_path, model=DEFAULT_MODEL, pdf_parser="PyPDF2"): + model_name = _resolve_model_name(model) + enc = _tiktoken_encoder(model_name) if pdf_parser == "PyPDF2": pdf_reader = PyPDF2.PdfReader(pdf_path) page_list = [] @@ -709,4 +795,7 @@ def load(self, user_opt=None) -> config: self._validate_keys(user_dict) merged = {**self._default_dict, **user_dict} - return config(**merged) \ No newline at end of file + env_default_model = _default_model_from_env() + if env_default_model and "model" not in user_dict: + merged["model"] = env_default_model + return config(**merged) diff --git a/requirements.txt b/requirements.txt index 463db58f1..a695e9880 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ openai==1.101.0 pymupdf==1.26.4 -PyPDF2==3.0.1 +pypdf>=4.0.0 python-dotenv==1.1.0 tiktoken==0.11.0 pyyaml==6.0.2 diff --git a/run_pageindex.py b/run_pageindex.py index 107024505..a5beca266 100644 --- a/run_pageindex.py +++ b/run_pageindex.py @@ -10,7 +10,8 @@ parser.add_argument('--pdf_path', type=str, help='Path to the PDF file') parser.add_argument('--md_path', type=str, help='Path to the Markdown file') - parser.add_argument('--model', type=str, default='gpt-4o-2024-11-20', help='Model to use') + default_model = ConfigLoader().load().model + parser.add_argument('--model', type=str, default=default_model, help='Model to use') parser.add_argument('--toc-check-pages', type=int, default=20, help='Number of pages to check for table of contents (PDF only)') @@ -130,4 +131,4 @@ with open(output_file, 'w', encoding='utf-8') as f: json.dump(toc_with_page_number, f, indent=2, ensure_ascii=False) - print(f'Tree structure saved to: {output_file}') \ No newline at end of file + print(f'Tree structure saved to: {output_file}')