|
3 | 3 | from typing import Any, Dict, List, Optional |
4 | 4 |
|
5 | 5 | from config import CFG # config (keeps chunk_size etc if needed) |
| 6 | +import atexit |
| 7 | +import logging |
| 8 | +import threading |
| 9 | +import queue |
| 10 | + |
| 11 | +logging.basicConfig(level=logging.INFO) |
| 12 | +_LOG = logging.getLogger(__name__) |
| 13 | + |
| 14 | +# registry of DBWriter instances keyed by database path |
| 15 | +_WRITERS = {} |
| 16 | +_WRITERS_LOCK = threading.Lock() |
| 17 | + |
| 18 | +class _DBTask: |
| 19 | + def __init__(self, sql, params): |
| 20 | + self.sql = sql |
| 21 | + self.params = params |
| 22 | + self.event = threading.Event() |
| 23 | + self.rowid = None |
| 24 | + self.exception = None |
| 25 | + |
class DBWriter:
    """Serializes all writes to one SQLite database through a single
    background thread, avoiding 'database is locked' errors when many
    threads write concurrently.

    The connection is opened inside the worker thread, so it is never
    shared across threads. Fixes over the naive version:
    - the shutdown sentinel now calls ``task_done()``,
    - tasks still queued when the worker exits are failed explicitly so
      waiters wake immediately instead of blocking until their timeout,
    - enqueueing to a dead writer raises ``RuntimeError`` right away.
    """

    def __init__(self, database_path, timeout_seconds=30):
        """Start the worker thread immediately.

        database_path: passed verbatim to sqlite3.connect().
        timeout_seconds: sqlite connect() lock-wait timeout, in seconds.
        """
        self.database_path = database_path
        self._q = queue.Queue()
        self._stop = threading.Event()
        self._timeout_seconds = timeout_seconds
        self._log = logging.getLogger(__name__)
        self._thread = threading.Thread(
            target=self._worker,
            daemon=True,
            name=f"DBWriter-{database_path}",
        )
        self._thread.start()

    def _open_conn(self):
        """Open the worker's private connection with concurrency-friendly pragmas."""
        conn = sqlite3.connect(
            self.database_path,
            timeout=self._timeout_seconds,
            check_same_thread=False,
        )
        # WAL allows concurrent readers while this thread writes.
        conn.execute("PRAGMA journal_mode=WAL;")
        # Make busy timeout explicit (milliseconds).
        conn.execute("PRAGMA busy_timeout = 30000;")
        # NORMAL trades a little durability for fewer fsyncs under WAL.
        conn.execute("PRAGMA synchronous = NORMAL;")
        return conn

    def _worker(self):
        """Worker loop: pop tasks, execute, commit, wake the waiter."""
        conn = None
        try:
            conn = self._open_conn()
            cur = conn.cursor()
            while not self._stop.is_set():
                try:
                    task = self._q.get(timeout=0.5)
                except queue.Empty:
                    continue
                if task is None:
                    # Sentinel enqueued by stop(): account for it, then exit.
                    self._q.task_done()
                    break
                try:
                    cur.execute(task.sql, task.params)
                    conn.commit()
                    task.rowid = cur.lastrowid
                except Exception as e:
                    # Hand the failure to the waiting thread instead of dying.
                    task.exception = e
                    try:
                        conn.rollback()
                    except Exception:
                        pass
                    self._log.exception("Error executing DB task")
                finally:
                    task.event.set()
                    self._q.task_done()
        except Exception:
            self._log.exception("DBWriter thread initialization failed")
        finally:
            # Fail anything still queued so waiters don't hang until their
            # own wait_timeout expires.
            self._fail_pending()
            if conn:
                try:
                    conn.close()
                except Exception:
                    pass

    def _fail_pending(self):
        """Mark every still-queued task as failed and wake its waiter."""
        while True:
            try:
                task = self._q.get_nowait()
            except queue.Empty:
                return
            if task is not None:
                task.exception = RuntimeError(
                    f"DBWriter for {self.database_path} stopped before task ran"
                )
                task.event.set()
            self._q.task_done()

    def enqueue_and_wait(self, sql, params, wait_timeout=60.0):
        """
        Enqueue an SQL write and wait for the background thread to perform it.
        Returns the lastrowid or raises the exception raised during execution.
        Raises RuntimeError immediately if the worker thread is not running.
        """
        if not self._thread.is_alive():
            # Fail fast instead of letting the caller block for wait_timeout.
            raise RuntimeError(f"DBWriter for {self.database_path} is not running")
        task = _DBTask(sql, params)
        self._q.put(task)
        completed = task.event.wait(wait_timeout)
        if not completed:
            raise TimeoutError(f"Timed out waiting for DB write to {self.database_path}")
        if task.exception:
            # re-raise sqlite3.OperationalError or other exceptions
            raise task.exception
        return task.rowid

    def enqueue_no_wait(self, sql, params):
        """
        Fire-and-forget enqueue (no result returned). The returned task's
        event/rowid/exception can still be inspected by the caller later.
        """
        task = _DBTask(sql, params)
        self._q.put(task)
        return task

    def stop(self, wait=True):
        """Stop the worker thread. If wait=True, block until thread joins."""
        self._stop.set()
        # Sentinel wakes the worker immediately instead of the 0.5s poll.
        self._q.put(None)
        if wait:
            self._thread.join(timeout=5.0)
| 112 | + |
def _get_writer(database_path):
    """Return the process-wide DBWriter for *database_path*, creating one on
    first use — or recreating it if its worker thread has died.

    A writer whose thread has exited (worker-init failure, or a prior
    stop_all_writers()) can never service tasks again; handing it out would
    make every subsequent write block until its wait timeout.
    """
    with _WRITERS_LOCK:
        w = _WRITERS.get(database_path)
        if w is None or not w._thread.is_alive():
            w = DBWriter(database_path)
            _WRITERS[database_path] = w
        return w
| 120 | + |
def stop_all_writers():
    """Stop all DBWriter threads (called automatically at process exit)."""
    # Snapshot and empty the registry while holding the lock, then join the
    # writers outside it so a slow shutdown never blocks other lookups.
    with _WRITERS_LOCK:
        pending = list(_WRITERS.values())
        _WRITERS.clear()
    for writer in pending:
        try:
            writer.stop(wait=True)
        except Exception:
            _LOG.exception("Error stopping DBWriter")
| 132 | +# ensure cleanup at exit |
| 133 | +atexit.register(stop_all_writers) |
6 | 134 |
|
7 | 135 | # Simple connection helper: we open new connections per operation so the code is robust |
8 | 136 | # across threads. We set WAL journal mode for safer concurrency. |
@@ -108,21 +236,20 @@ def update_analysis_status(database_path: str, analysis_id: int, status: str) -> |
108 | 236 | conn.close() |
109 | 237 |
|
110 | 238 |
|
def store_file(database_path: str, analysis_id: int, path: str, content: str, language: str) -> int:
    """
    Insert a file record into the DB using a queued single-writer to avoid
    sqlite 'database is locked' errors in multithreaded scenarios.

    Args:
        database_path: target SQLite database file.
        analysis_id: owning analysis row id.
        path: file path to record.
        content: full file text (may be empty).
        language: detected language label.

    Returns:
        The new file row id (same semantics as the previous synchronous
        implementation, which returned ``int(cur.lastrowid)``).
    """
    # The first 512 chars are stored denormalized as a cheap preview.
    snippet = (content[:512] if content else "")
    sql = "INSERT INTO files (analysis_id, path, content, language, snippet) VALUES (?, ?, ?, ?, ?)"
    params = (analysis_id, path, content, language, snippet)

    writer = _get_writer(database_path)
    # Wait for the background writer to complete the insert so callers keep
    # the synchronous semantics they expect; coerce to int to honour the
    # declared -> int contract of the original implementation.
    return int(writer.enqueue_and_wait(sql, params, wait_timeout=60.0))
126 | 253 |
|
127 | 254 |
|
128 | 255 | def insert_chunk_row_with_null_embedding(database_path: str, file_id: int, path: str, chunk_index: int) -> int: |
|
0 commit comments