diff --git a/benchmarks/runner.py b/benchmarks/runner.py
index 5ea10c0..c2d57af 100644
--- a/benchmarks/runner.py
+++ b/benchmarks/runner.py
@@ -29,50 +29,51 @@ def run_benchmark_suite(
     )
     client = SQLiteVecClient(table="benchmark", db_path=db_path)
 
-    # Create table
-    dim = config["dimension"]
-    distance = config["distance"]
-    client.create_table(dim=dim, distance=distance)
-
-    # Generate data
-    texts = generate_texts(dataset_size)
-    embeddings = generate_embeddings(dataset_size, dim)
-    metadata = generate_metadata(dataset_size)
-
-    # Benchmark: Add
-    print(f" Benchmarking add ({dataset_size} records)...")
-    results.append(benchmark_add(client, texts, embeddings, metadata))
-
-    # Get rowids for subsequent operations
-    rowids = list(range(1, dataset_size + 1))
-
-    # Benchmark: Get Many
-    print(f" Benchmarking get_many ({dataset_size} records)...")
-    results.append(benchmark_get_many(client, rowids))
-
-    # Benchmark: Similarity Search
-    print(" Benchmarking similarity_search...")
-    query_emb = [0.5] * dim
-    iterations = config["similarity_search"]["iterations"]
-    for top_k in config["similarity_search"]["top_k_values"]:
-        results.append(
-            benchmark_similarity_search(client, query_emb, top_k, iterations)
-        )
-
-    # Benchmark: Update Many
-    print(f" Benchmarking update_many ({dataset_size} records)...")
-    new_texts = [f"updated_{i}" for i in range(dataset_size)]
-    results.append(benchmark_update_many(client, rowids, new_texts))
-
-    # Benchmark: Get All
-    print(f" Benchmarking get_all ({dataset_size} records)...")
-    batch_size = config["batch_size"]
-    results.append(benchmark_get_all(client, dataset_size, batch_size))
-
-    # Benchmark: Delete Many
-    print(f" Benchmarking delete_many ({dataset_size} records)...")
-    results.append(benchmark_delete_many(client, rowids))
-
-    client.close()
+    try:
+        # Create table
+        dim = config["dimension"]
+        distance = config["distance"]
+        client.create_table(dim=dim, distance=distance)
+
+        # Generate data
+        texts = generate_texts(dataset_size)
+        embeddings = generate_embeddings(dataset_size, dim)
+        metadata = generate_metadata(dataset_size)
+
+        # Benchmark: Add
+        print(f" Benchmarking add ({dataset_size} records)...")
+        results.append(benchmark_add(client, texts, embeddings, metadata))
+
+        # Get rowids for subsequent operations
+        rowids = list(range(1, dataset_size + 1))
+
+        # Benchmark: Get Many
+        print(f" Benchmarking get_many ({dataset_size} records)...")
+        results.append(benchmark_get_many(client, rowids))
+
+        # Benchmark: Similarity Search
+        print(" Benchmarking similarity_search...")
+        query_emb = [0.5] * dim
+        iterations = config["similarity_search"]["iterations"]
+        for top_k in config["similarity_search"]["top_k_values"]:
+            results.append(
+                benchmark_similarity_search(client, query_emb, top_k, iterations)
+            )
+
+        # Benchmark: Update Many
+        print(f" Benchmarking update_many ({dataset_size} records)...")
+        new_texts = [f"updated_{i}" for i in range(dataset_size)]
+        results.append(benchmark_update_many(client, rowids, new_texts))
+
+        # Benchmark: Get All
+        print(f" Benchmarking get_all ({dataset_size} records)...")
+        batch_size = config["batch_size"]
+        results.append(benchmark_get_all(client, dataset_size, batch_size))
+
+        # Benchmark: Delete Many
+        print(f" Benchmarking delete_many ({dataset_size} records)...")
+        results.append(benchmark_delete_many(client, rowids))
+    finally:
+        client.close()
 
     return results
diff --git a/sqlite_vec_client/base.py b/sqlite_vec_client/base.py
index 5c0ed89..6dd607a 100644
--- a/sqlite_vec_client/base.py
+++ b/sqlite_vec_client/base.py
@@ -63,6 +63,13 @@ def create_connection(db_path: str) -> sqlite3.Connection:
         connection.enable_load_extension(True)
         sqlite_vec.load(connection)
         connection.enable_load_extension(False)
+
+        # Performance optimizations
+        connection.execute("PRAGMA journal_mode=WAL")
+        connection.execute("PRAGMA synchronous=NORMAL")
+        connection.execute("PRAGMA cache_size=-64000")  # 64MB cache
+        connection.execute("PRAGMA temp_store=MEMORY")
+
         logger.info(f"Successfully connected to database: {db_path}")
         return connection
     except sqlite3.Error as e:
@@ -262,13 +269,6 @@ def add(
         validate_embeddings_match(texts, embeddings, metadata)
         logger.debug(f"Adding {len(texts)} records to table '{self.table}'")
         try:
-            max_id = self.connection.execute(
-                f"SELECT max(rowid) as rowid FROM {self.table}"
-            ).fetchone()["rowid"]
-
-            if max_id is None:
-                max_id = 0
-
             if metadata is None:
                 metadata = [dict() for _ in texts]
 
@@ -276,17 +276,26 @@ def add(
                 (text, json.dumps(md), serialize_f32(embedding))
                 for text, md, embedding in zip(texts, metadata, embeddings)
             ]
-            self.connection.executemany(
+
+            cur = self.connection.cursor()
+
+            # Get max rowid before insert
+            max_before = cur.execute(
+                f"SELECT COALESCE(MAX(rowid), 0) FROM {self.table}"
+            ).fetchone()[0]
+
+            cur.executemany(
                 f"""INSERT INTO {self.table}(text, metadata, text_embedding)
                 VALUES (?,?,?)""",
                 data_input,
             )
+
+            # Calculate rowids from max_before
+            rowids = list(range(max_before + 1, max_before + len(texts) + 1))
+
             if not self._in_transaction:
                 self.connection.commit()
-            results = self.connection.execute(
-                f"SELECT rowid FROM {self.table} WHERE rowid > {max_id}"
-            )
-            rowids = [row["rowid"] for row in results]
+
             logger.info(f"Added {len(rowids)} records to table '{self.table}'")
             return rowids
         except sqlite3.OperationalError as e:
@@ -447,15 +456,25 @@ def delete_many(self, rowids: list[int]) -> int:
         if not rowids:
             return 0
         logger.debug(f"Deleting {len(rowids)} records")
-        placeholders = ",".join(["?"] * len(rowids))
+
+        # SQLite has a limit on SQL variables (typically 999 or 32766)
+        # Split into chunks to avoid "too many SQL variables" error
+        chunk_size = 500
         cur = self.connection.cursor()
-        cur.execute(
-            f"DELETE FROM {self.table} WHERE rowid IN ({placeholders})",
-            rowids,
-        )
+        deleted_count = 0
+
+        for i in range(0, len(rowids), chunk_size):
+            chunk = rowids[i : i + chunk_size]
+            placeholders = ",".join(["?"] * len(chunk))
+            cur.execute(
+                f"DELETE FROM {self.table} WHERE rowid IN ({placeholders})",
+                chunk,
+            )
+            deleted_count += cur.rowcount
+
         if not self._in_transaction:
             self.connection.commit()
-        deleted_count = cur.rowcount
+
         logger.info(f"Deleted {deleted_count} records from table '{self.table}'")
         return deleted_count
 
@@ -475,10 +494,93 @@ def update_many(
         if not updates:
             return 0
         logger.debug(f"Updating {len(updates)} records")
-        updated_count = 0
+
+        # Group updates by which fields are being updated
+        text_updates = []
+        metadata_updates = []
+        embedding_updates = []
+        full_updates = []
+
+        mixed_updates = []
+
         for rowid, text, metadata, embedding in updates:
-            if self.update(rowid, text=text, metadata=metadata, embedding=embedding):
-                updated_count += 1
+            has_text = text is not None
+            has_metadata = metadata is not None
+            has_embedding = embedding is not None
+
+            if has_text and has_metadata and has_embedding:
+                if text is not None and metadata is not None and embedding is not None:
+                    full_updates.append(
+                        (text, json.dumps(metadata), serialize_f32(embedding), rowid)
+                    )
+            elif has_text and not has_metadata and not has_embedding:
+                text_updates.append((text, rowid))
+            elif has_metadata and not has_text and not has_embedding:
+                metadata_updates.append((json.dumps(metadata), rowid))
+            elif has_embedding and not has_text and not has_metadata:
+                if embedding is not None:
+                    embedding_updates.append((serialize_f32(embedding), rowid))
+            else:
+                # Mixed updates - store for individual execution
+                mixed_updates.append((rowid, text, metadata, embedding))
+
+        cur = self.connection.cursor()
+        updated_count = 0
+
+        # Batch execute grouped updates
+        if full_updates:
+            cur.executemany(
+                f"""
+                UPDATE {self.table}
+                SET text = ?, metadata = ?, text_embedding = ? WHERE rowid = ?
+                """,
+                full_updates,
+            )
+            updated_count += cur.rowcount
+
+        if text_updates:
+            cur.executemany(
+                f"UPDATE {self.table} SET text = ? WHERE rowid = ?", text_updates
+            )
+            updated_count += cur.rowcount
+
+        if metadata_updates:
+            cur.executemany(
+                f"UPDATE {self.table} SET metadata = ? WHERE rowid = ?",
+                metadata_updates,
+            )
+            updated_count += cur.rowcount
+
+        if embedding_updates:
+            cur.executemany(
+                f"UPDATE {self.table} SET text_embedding = ? WHERE rowid = ?",
+                embedding_updates,
+            )
+            updated_count += cur.rowcount
+
+        # Handle mixed updates individually
+        for rowid, text, metadata, embedding in mixed_updates:
+            sets = []
+            params: list[Any] = []
+            if text is not None:
+                sets.append("text = ?")
+                params.append(text)
+            if metadata is not None:
+                sets.append("metadata = ?")
+                params.append(json.dumps(metadata))
+            if embedding is not None:
+                sets.append("text_embedding = ?")
+                params.append(serialize_f32(embedding))
+            params.append(rowid)
+
+            if sets:
+                sql = f"UPDATE {self.table} SET " + ", ".join(sets) + " WHERE rowid = ?"
+                cur.execute(sql, params)
+                updated_count += cur.rowcount
+
+        if not self._in_transaction:
+            self.connection.commit()
+
         logger.info(f"Updated {updated_count} records in table '{self.table}'")
         return updated_count
 
@@ -493,13 +595,26 @@ def get_all(self, batch_size: int = 100) -> Generator[Result, None, None]:
         """
        validate_limit(batch_size)
         logger.debug(f"Fetching all records with batch_size={batch_size}")
-        offset = 0
+        last_rowid = 0
+        cursor = self.connection.cursor()
+
         while True:
-            batch = self.list_results(limit=batch_size, offset=offset)
-            if not batch:
+            cursor.execute(
+                f"""
+                SELECT rowid, text, metadata, text_embedding FROM {self.table}
+                WHERE rowid > ?
+                ORDER BY rowid ASC
+                LIMIT ?
+                """,
+                [last_rowid, batch_size],
+            )
+            rows = cursor.fetchall()
+            if not rows:
                 break
-            yield from batch
-            offset += batch_size
+
+            results = self.rows_to_results(rows)
+            yield from results
+            last_rowid = results[-1][0]  # Get last rowid from batch
 
     @contextmanager
     def transaction(self) -> Generator[None, None, None]:
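
Usage sketch (reviewer note; not part of the patch). The snippet below exercises the changed code paths end to end. `SQLiteVecClient`, `create_table(dim=..., distance=...)`, `add`, `get_all`, `delete_many`, and `close` all appear in the diffs above; the import path, the "cosine" distance value, the keyword-argument call style, and the assumption that each yielded `Result` unpacks as `(rowid, text, metadata, embedding)` (the column order of the new `get_all` SELECT) are mine.

    from sqlite_vec_client import SQLiteVecClient  # import path assumed

    client = SQLiteVecClient(table="demo", db_path="demo.db")
    try:
        client.create_table(dim=4, distance="cosine")  # distance value assumed

        texts = [f"doc_{i}" for i in range(2000)]
        embeddings = [[0.1, 0.2, 0.3, 0.4] for _ in texts]

        # add() now derives the new rowids from MAX(rowid) taken before the
        # INSERT, so no post-insert SELECT is issued.
        rowids = client.add(texts=texts, embeddings=embeddings)

        # get_all() now pages by rowid (keyset pagination) instead of
        # LIMIT/OFFSET, so later batches no longer re-scan earlier rows.
        for rowid, text, metadata, embedding in client.get_all(batch_size=500):
            pass  # process each record as it streams

        # delete_many() now splits the 2000 ids into chunks of 500, staying
        # below SQLite's bound-variable limit: four statements, one commit.
        deleted = client.delete_many(rowids)
        assert deleted == len(rowids)
    finally:
        client.close()

One design note: computing the inserted rowids as `range(max_before + 1, ...)` is only correct while no other writer can interleave inserts between the MAX(rowid) read and the executemany(); running the add inside the client's `transaction()` context (or with a single writer, which WAL mode already encourages) satisfies that.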