91 changes: 46 additions & 45 deletions benchmarks/runner.py
@@ -29,50 +29,51 @@ def run_benchmark_suite(
)
client = SQLiteVecClient(table="benchmark", db_path=db_path)

# Create table
dim = config["dimension"]
distance = config["distance"]
client.create_table(dim=dim, distance=distance)

# Generate data
texts = generate_texts(dataset_size)
embeddings = generate_embeddings(dataset_size, dim)
metadata = generate_metadata(dataset_size)

# Benchmark: Add
print(f" Benchmarking add ({dataset_size} records)...")
results.append(benchmark_add(client, texts, embeddings, metadata))

# Get rowids for subsequent operations
rowids = list(range(1, dataset_size + 1))

# Benchmark: Get Many
print(f" Benchmarking get_many ({dataset_size} records)...")
results.append(benchmark_get_many(client, rowids))

# Benchmark: Similarity Search
print(" Benchmarking similarity_search...")
query_emb = [0.5] * dim
iterations = config["similarity_search"]["iterations"]
for top_k in config["similarity_search"]["top_k_values"]:
results.append(
benchmark_similarity_search(client, query_emb, top_k, iterations)
)

# Benchmark: Update Many
print(f" Benchmarking update_many ({dataset_size} records)...")
new_texts = [f"updated_{i}" for i in range(dataset_size)]
results.append(benchmark_update_many(client, rowids, new_texts))

# Benchmark: Get All
print(f" Benchmarking get_all ({dataset_size} records)...")
batch_size = config["batch_size"]
results.append(benchmark_get_all(client, dataset_size, batch_size))

# Benchmark: Delete Many
print(f" Benchmarking delete_many ({dataset_size} records)...")
results.append(benchmark_delete_many(client, rowids))

client.close()
try:
# Create table
dim = config["dimension"]
distance = config["distance"]
client.create_table(dim=dim, distance=distance)

# Generate data
texts = generate_texts(dataset_size)
embeddings = generate_embeddings(dataset_size, dim)
metadata = generate_metadata(dataset_size)

# Benchmark: Add
print(f" Benchmarking add ({dataset_size} records)...")
results.append(benchmark_add(client, texts, embeddings, metadata))

# Get rowids for subsequent operations
rowids = list(range(1, dataset_size + 1))

# Benchmark: Get Many
print(f" Benchmarking get_many ({dataset_size} records)...")
results.append(benchmark_get_many(client, rowids))

# Benchmark: Similarity Search
print(" Benchmarking similarity_search...")
query_emb = [0.5] * dim
iterations = config["similarity_search"]["iterations"]
for top_k in config["similarity_search"]["top_k_values"]:
results.append(
benchmark_similarity_search(client, query_emb, top_k, iterations)
)

# Benchmark: Update Many
print(f" Benchmarking update_many ({dataset_size} records)...")
new_texts = [f"updated_{i}" for i in range(dataset_size)]
results.append(benchmark_update_many(client, rowids, new_texts))

# Benchmark: Get All
print(f" Benchmarking get_all ({dataset_size} records)...")
batch_size = config["batch_size"]
results.append(benchmark_get_all(client, dataset_size, batch_size))

# Benchmark: Delete Many
print(f" Benchmarking delete_many ({dataset_size} records)...")
results.append(benchmark_delete_many(client, rowids))
finally:
client.close()

return results
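The try/finally above guarantees the benchmark client's connection is released even when one of the benchmark steps raises. An equivalent shape, shown here only as a sketch (it assumes `SQLiteVecClient.close()` is the sole cleanup step, as in this diff, and that the class is importable from the package root), is to wrap the client in `contextlib.closing`:

```python
from contextlib import closing

from sqlite_vec_client import SQLiteVecClient  # import path assumed from the package name

# Sketch only: closing() calls client.close() on exit, mirroring the try/finally above.
with closing(SQLiteVecClient(table="benchmark", db_path="benchmark.db")) as client:
    client.create_table(dim=384, distance="cosine")  # hypothetical dim/distance values
    # ... run the benchmark steps here ...
```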
167 changes: 141 additions & 26 deletions sqlite_vec_client/base.py
@@ -63,6 +63,13 @@ def create_connection(db_path: str) -> sqlite3.Connection:
connection.enable_load_extension(True)
sqlite_vec.load(connection)
connection.enable_load_extension(False)

# Performance optimizations
connection.execute("PRAGMA journal_mode=WAL")
connection.execute("PRAGMA synchronous=NORMAL")
connection.execute("PRAGMA cache_size=-64000") # 64MB cache
connection.execute("PRAGMA temp_store=MEMORY")

logger.info(f"Successfully connected to database: {db_path}")
return connection
except sqlite3.Error as e:
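These PRAGMAs trade some durability for throughput: WAL lets readers proceed while a writer commits, `synchronous=NORMAL` skips the fsync on every transaction (with WAL this risks losing only the most recent commits on power loss, not corruption), a negative `cache_size` is measured in KiB (so -64000 is roughly 64 MB of page cache), and `temp_store=MEMORY` keeps temporary tables and indices in RAM. A standalone sketch for verifying the settings actually apply on a connection (plain `sqlite3`, hypothetical database path):

```python
import sqlite3

conn = sqlite3.connect("example.db")  # hypothetical file path; WAL needs a file-backed DB
for pragma in ("journal_mode=WAL", "synchronous=NORMAL",
               "cache_size=-64000", "temp_store=MEMORY"):
    conn.execute(f"PRAGMA {pragma}")

# Read each setting back: journal_mode reports 'wal', synchronous 1 (NORMAL),
# cache_size -64000, and temp_store 2 (MEMORY).
for name in ("journal_mode", "synchronous", "cache_size", "temp_store"):
    print(name, conn.execute(f"PRAGMA {name}").fetchone()[0])
conn.close()
```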
@@ -262,31 +269,33 @@ def add(
validate_embeddings_match(texts, embeddings, metadata)
logger.debug(f"Adding {len(texts)} records to table '{self.table}'")
try:
max_id = self.connection.execute(
f"SELECT max(rowid) as rowid FROM {self.table}"
).fetchone()["rowid"]

if max_id is None:
max_id = 0

if metadata is None:
metadata = [dict() for _ in texts]

data_input = [
(text, json.dumps(md), serialize_f32(embedding))
for text, md, embedding in zip(texts, metadata, embeddings)
]
self.connection.executemany(

cur = self.connection.cursor()

# Get max rowid before insert
max_before = cur.execute(
f"SELECT COALESCE(MAX(rowid), 0) FROM {self.table}"
).fetchone()[0]

cur.executemany(
f"""INSERT INTO {self.table}(text, metadata, text_embedding)
VALUES (?,?,?)""",
data_input,
)

# Calculate rowids from max_before
rowids = list(range(max_before + 1, max_before + len(texts) + 1))

if not self._in_transaction:
self.connection.commit()
results = self.connection.execute(
f"SELECT rowid FROM {self.table} WHERE rowid > {max_id}"
)
rowids = [row["rowid"] for row in results]

logger.info(f"Added {len(rowids)} records to table '{self.table}'")
return rowids
except sqlite3.OperationalError as e:
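The reworked `add()` reads `MAX(rowid)` once, inserts the whole batch with `executemany`, and derives the new rowids arithmetically instead of re-querying with `rowid > max_id`. This assumes the batch receives consecutive rowids starting at `max_before + 1`, i.e. no other writer inserts into the table between the `MAX()` read and the insert. A minimal standalone sketch of the pattern against a hypothetical `items` table:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE items (text TEXT)")

rows = [("a",), ("b",), ("c",)]
cur = conn.cursor()

# Read the current maximum rowid, insert the batch, then compute the new ids.
max_before = cur.execute("SELECT COALESCE(MAX(rowid), 0) FROM items").fetchone()[0]
cur.executemany("INSERT INTO items(text) VALUES (?)", rows)
conn.commit()

rowids = list(range(max_before + 1, max_before + len(rows) + 1))
print(rowids)  # [1, 2, 3] on an empty table
```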
@@ -447,15 +456,25 @@ def delete_many(self, rowids: list[int]) -> int:
if not rowids:
return 0
logger.debug(f"Deleting {len(rowids)} records")
placeholders = ",".join(["?"] * len(rowids))

# SQLite has a limit on SQL variables (typically 999 or 32766)
# Split into chunks to avoid "too many SQL variables" error
chunk_size = 500
cur = self.connection.cursor()
cur.execute(
f"DELETE FROM {self.table} WHERE rowid IN ({placeholders})",
rowids,
)
deleted_count = 0

for i in range(0, len(rowids), chunk_size):
chunk = rowids[i : i + chunk_size]
placeholders = ",".join(["?"] * len(chunk))
cur.execute(
f"DELETE FROM {self.table} WHERE rowid IN ({placeholders})",
chunk,
)
deleted_count += cur.rowcount

if not self._in_transaction:
self.connection.commit()
deleted_count = cur.rowcount

logger.info(f"Deleted {deleted_count} records from table '{self.table}'")
return deleted_count

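Chunking the `IN (...)` list keeps every statement under SQLite's bound-parameter ceiling (`SQLITE_MAX_VARIABLE_NUMBER`, 999 in older builds and 32766 since SQLite 3.32), which the previous single-statement version could exceed for large batches. The same idea in isolation, as a sketch against a hypothetical `items` table:

```python
import sqlite3


def delete_rowids(conn: sqlite3.Connection, table: str,
                  rowids: list[int], chunk_size: int = 500) -> int:
    """Delete rows in fixed-size chunks so no statement exceeds the variable limit."""
    cur = conn.cursor()
    deleted = 0
    for i in range(0, len(rowids), chunk_size):
        chunk = rowids[i:i + chunk_size]
        placeholders = ",".join("?" * len(chunk))
        cur.execute(f"DELETE FROM {table} WHERE rowid IN ({placeholders})", chunk)
        deleted += cur.rowcount  # rowcount is per-statement, so accumulate it
    conn.commit()
    return deleted
```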
@@ -475,10 +494,93 @@ def update_many(
if not updates:
return 0
logger.debug(f"Updating {len(updates)} records")
updated_count = 0

# Group updates by which fields are being updated
text_updates = []
metadata_updates = []
embedding_updates = []
full_updates = []

mixed_updates = []

for rowid, text, metadata, embedding in updates:
if self.update(rowid, text=text, metadata=metadata, embedding=embedding):
updated_count += 1
has_text = text is not None
has_metadata = metadata is not None
has_embedding = embedding is not None

if has_text and has_metadata and has_embedding:
if text is not None and metadata is not None and embedding is not None:
full_updates.append(
(text, json.dumps(metadata), serialize_f32(embedding), rowid)
)
elif has_text and not has_metadata and not has_embedding:
text_updates.append((text, rowid))
elif has_metadata and not has_text and not has_embedding:
metadata_updates.append((json.dumps(metadata), rowid))
elif has_embedding and not has_text and not has_metadata:
if embedding is not None:
embedding_updates.append((serialize_f32(embedding), rowid))
else:
# Mixed updates - store for individual execution
mixed_updates.append((rowid, text, metadata, embedding))

cur = self.connection.cursor()
updated_count = 0

# Batch execute grouped updates
if full_updates:
cur.executemany(
f"""
UPDATE {self.table}
SET text = ?, metadata = ?, text_embedding = ? WHERE rowid = ?
""",
full_updates,
)
updated_count += cur.rowcount

if text_updates:
cur.executemany(
f"UPDATE {self.table} SET text = ? WHERE rowid = ?", text_updates
)
updated_count += cur.rowcount

if metadata_updates:
cur.executemany(
f"UPDATE {self.table} SET metadata = ? WHERE rowid = ?",
metadata_updates,
)
updated_count += cur.rowcount

if embedding_updates:
cur.executemany(
f"UPDATE {self.table} SET text_embedding = ? WHERE rowid = ?",
embedding_updates,
)
updated_count += cur.rowcount

# Handle mixed updates individually
for rowid, text, metadata, embedding in mixed_updates:
sets = []
params: list[Any] = []
if text is not None:
sets.append("text = ?")
params.append(text)
if metadata is not None:
sets.append("metadata = ?")
params.append(json.dumps(metadata))
if embedding is not None:
sets.append("text_embedding = ?")
params.append(serialize_f32(embedding))
params.append(rowid)

if sets:
sql = f"UPDATE {self.table} SET " + ", ".join(sets) + " WHERE rowid = ?"
cur.execute(sql, params)
updated_count += cur.rowcount

if not self._in_transaction:
self.connection.commit()

logger.info(f"Updated {updated_count} records in table '{self.table}'")
return updated_count

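Grouping the tuples lets each homogeneous set of updates go through a single `executemany()` call (whose `rowcount` sums the modified rows across the batch), while rows touching an arbitrary mix of columns fall back to one statement each. A hypothetical call showing which bucket each tuple lands in, assuming `client` is an open `SQLiteVecClient`; the 3-dimensional embeddings are for brevity and assume a table created with `dim=3`:

```python
# Each tuple is (rowid, text, metadata, embedding); None means "leave unchanged".
updates = [
    (1, "new text", None, None),                  # text_updates
    (2, None, {"tag": "v2"}, None),               # metadata_updates
    (3, None, None, [0.1, 0.2, 0.3]),             # embedding_updates
    (4, "both", {"tag": "v2"}, [0.4, 0.5, 0.6]),  # full_updates
    (5, "text+meta", {"tag": "v3"}, None),        # mixed_updates (two of three fields)
]
updated = client.update_many(updates)  # returns the number of rows actually modified
```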
Expand All @@ -493,13 +595,26 @@ def get_all(self, batch_size: int = 100) -> Generator[Result, None, None]:
"""
validate_limit(batch_size)
logger.debug(f"Fetching all records with batch_size={batch_size}")
offset = 0
last_rowid = 0
cursor = self.connection.cursor()

while True:
batch = self.list_results(limit=batch_size, offset=offset)
if not batch:
cursor.execute(
f"""
SELECT rowid, text, metadata, text_embedding FROM {self.table}
WHERE rowid > ?
ORDER BY rowid ASC
LIMIT ?
""",
[last_rowid, batch_size],
)
rows = cursor.fetchall()
if not rows:
break
yield from batch
offset += batch_size

results = self.rows_to_results(rows)
yield from results
last_rowid = results[-1][0] # Get last rowid from batch

@contextmanager
def transaction(self) -> Generator[None, None, None]:
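The new `get_all()` uses keyset pagination: each batch resumes from the last rowid seen instead of re-scanning with `OFFSET`, so the cost of fetching a batch no longer grows with how far into the table the iteration is, and rows deleted mid-iteration cannot shift later rows out of a page. A standalone sketch of the same loop with plain `sqlite3` and a hypothetical `items` table:

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE items (text TEXT)")
conn.executemany("INSERT INTO items(text) VALUES (?)",
                 [(f"row {i}",) for i in range(10)])

last_rowid, batch_size = 0, 4
while True:
    rows = conn.execute(
        "SELECT rowid, text FROM items WHERE rowid > ? ORDER BY rowid LIMIT ?",
        (last_rowid, batch_size),
    ).fetchall()
    if not rows:
        break
    for rowid, text in rows:
        print(rowid, text)
    last_rowid = rows[-1][0]  # resume after the last row of this batch
```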