From 3709bbdd25e98d12acedc891d9deb2ca8e6c4427 Mon Sep 17 00:00:00 2001 From: rundef Date: Fri, 2 May 2025 14:14:45 -0400 Subject: [PATCH 1/2] Performance improvements and benchmarks --- benchmark.py | 125 ++++++++++++++++++++++ docs/benchmarks.rst | 29 +++++ docs/index.rst | 3 +- sqlalchemy_memory/base/indexes.py | 48 ++++++--- sqlalchemy_memory/base/pending_changes.py | 3 +- sqlalchemy_memory/base/session.py | 8 +- sqlalchemy_memory/base/store.py | 62 +++++++---- sqlalchemy_memory/helpers/ordered_set.py | 29 +++++ 8 files changed, 272 insertions(+), 35 deletions(-) create mode 100644 benchmark.py create mode 100644 docs/benchmarks.rst create mode 100644 sqlalchemy_memory/helpers/ordered_set.py diff --git a/benchmark.py b/benchmark.py new file mode 100644 index 0000000..20d7d56 --- /dev/null +++ b/benchmark.py @@ -0,0 +1,125 @@ +from sqlalchemy import create_engine, Column, Integer, String, Boolean, select, Index, update, delete +from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy_memory import MemorySession +import argparse +import time +import random +from faker import Faker + +try: + from sqlalchemy_memory import create_memory_engine +except ImportError: + create_memory_engine = None + +Base = declarative_base() +fake = Faker() +CATEGORIES = list("ABCDEFGHIJK") + +class Item(Base): + __tablename__ = "items" + + id = Column(Integer, primary_key=True) + name = Column(String) + active = Column(Boolean, index=True) + category = Column(String, index=True) + +def generate_items(n): + for _ in range(n): + yield Item( + name=fake.name(), + active=random.choice([True, False]), + category=random.choice(CATEGORIES) + ) + +def generate_random_select_query(): + clauses = [] + if random.random() < 0.5: + clauses.append(Item.active == random.choice([True, False])) + if random.random() < 0.5 or not clauses: + subset = random.sample(CATEGORIES, random.randint(1, 4)) + clauses.append(Item.category.in_(subset)) + return select(Item).where(*clauses) + +def inserts(Session, count): + insert_start = time.time() + with Session() as session: + session.add_all(generate_items(count)) + session.commit() + insert_duration = time.time() - insert_start + print(f"Inserted {count} items in {insert_duration:.2f} seconds.") + return insert_duration + +def selects(Session, count): + queries = [generate_random_select_query() for _ in range(count)] + + query_start = time.time() + with Session() as session: + for stmt in queries: + list(session.execute(stmt).scalars()) + query_duration = time.time() - query_start + print(f"Executed {count} select queries in {query_duration:.2f} seconds.") + return query_duration + +def updates(Session, random_ids): + update_start = time.time() + with Session() as session: + for rid in random_ids: + stmt = update(Item).where(Item.id == rid).values( + name=fake.name(), + category=random.choice(CATEGORIES), + active=random.choice([True, False]) + ) + session.execute(stmt) + session.commit() + update_duration = time.time() - update_start + print(f"Executed {len(random_ids)} updates in {update_duration:.2f} seconds.") + return update_duration + +def deletes(Session, random_ids): + delete_start = time.time() + with Session() as session: + for rid in random_ids: + stmt = delete(Item).where(Item.id == rid) + session.execute(stmt) + session.commit() + delete_duration = time.time() - delete_start + print(f"Deleted {len(random_ids)} items in {delete_duration:.2f} seconds.") + return delete_duration + +def run_benchmark(db_type="sqlite", count=100_000): + print(f"Running benchmark: type={db_type}, count={count}") + + if db_type == "sqlite": + engine = create_engine("sqlite:///:memory:", echo=False) + Session = sessionmaker(engine) + elif db_type == "memory": + engine = create_engine("memory://") + Session = sessionmaker( + engine, + class_=MemorySession, + expire_on_commit=False, + ) + else: + raise ValueError("Invalid --type. Use 'sqlite' or 'memory'.") + + Base.metadata.create_all(engine) + + elapsed = inserts(Session, count) + elapsed += selects(Session, 500) + + random_ids = random.sample(range(1, count + 1), 500) + elapsed += updates(Session, random_ids) + + random_ids = random.sample(range(1, count + 1), 500) + elapsed += deletes(Session, random_ids) + + print(f"Total runtime for {db_type}: {elapsed:.2f} seconds.") + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--type", choices=["sqlite", "memory"], required=True) + parser.add_argument("--count", type=int, default=10_000) + args = parser.parse_args() + run_benchmark(args.type, args.count) diff --git a/docs/benchmarks.rst b/docs/benchmarks.rst new file mode 100644 index 0000000..415d41b --- /dev/null +++ b/docs/benchmarks.rst @@ -0,0 +1,29 @@ +Benchmark Comparison (20,000 items) +=================================== + +This benchmark compares `sqlalchemy-memory` to `in-memory SQLite` using 20,000 inserted items and a series of 500 queries, updates, and deletions. + +As the results show, `sqlalchemy-memory` **excels in read-heavy workloads**, delivering significantly faster query performance. While SQLite performs slightly better on update and delete operations, the overall runtime of `sqlalchemy-memory` remains substantially lower, making it a strong choice for prototyping and simulation. + +.. list-table:: + :header-rows: 1 + :widths: 25 25 25 + + * - Operation + - SQLite (in-memory) + - sqlalchemy-memory + * - Insert + - 3.17 sec + - 2.70 sec + * - 500 Select Queries + - 26.37 sec + - 2.94 sec + * - 500 Updates + - 0.26 sec + - 1.12 sec + * - 500 Deletes + - 0.09 sec + - 0.90 sec + * - **Total Runtime** + - **29.89 sec** + - **7.66 sec** diff --git a/docs/index.rst b/docs/index.rst index cfb4bca..c1be23f 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -112,4 +112,5 @@ Quickstart: async example query update delete - commit_rollback \ No newline at end of file + commit_rollback + benchmarks \ No newline at end of file diff --git a/sqlalchemy_memory/base/indexes.py b/sqlalchemy_memory/base/indexes.py index 9ea1400..9809d40 100644 --- a/sqlalchemy_memory/base/indexes.py +++ b/sqlalchemy_memory/base/indexes.py @@ -3,8 +3,12 @@ from typing import Any, List from sqlalchemy.sql import operators +from ..helpers.ordered_set import OrderedSet + class IndexManager: + __slots__ = ('hash_index', 'range_index', 'table_indexes', 'columns_mapping', ) + def __init__(self): self.hash_index = HashIndex() self.range_index = RangeIndex() @@ -12,6 +16,7 @@ def __init__(self): self.table_indexes = {} self.columns_mapping = {} + def get_indexes(self, obj): """ Retrieve index from object's table as dict: indexname => list of column name @@ -21,18 +26,27 @@ def get_indexes(self, obj): if tablename not in self.table_indexes: self.table_indexes[tablename] = {} + pk_col_name = obj.__table__.primary_key.columns[0].name + for index in obj.__table__.indexes: if len(index.expressions) > 1: # Ignoring compound indexes for now ... continue + if index.name == pk_col_name: + pk_col_name = None + self.table_indexes[tablename][index.name] = [ col.name for col in index.expressions ] + if pk_col_name: + self.table_indexes[tablename][pk_col_name] = [pk_col_name] + return self.table_indexes[tablename] + def _column_to_index(self, tablename, colname): """ Get index name from tablename & column name @@ -51,6 +65,7 @@ def _column_to_index(self, tablename, colname): return self.columns_mapping[tablename][colname] + def _get_index_key(self, obj, columns): if len(columns) == 1: return getattr(obj, columns[0]) @@ -65,7 +80,7 @@ def on_insert(self, obj): self.hash_index.add(tablename, indexname, value, obj) self.range_index.add(tablename, indexname, value, obj) - + def on_delete(self, obj): tablename = obj.__tablename__ indexes = self.get_indexes(obj) @@ -145,6 +160,7 @@ def query(self, collection, tablename, colname, operator, value): in_range = self.range_index.query(tablename, indexname, gte=value[0], lte=value[1]) return list(set(collection) - set(in_range)) + def get_selectivity(self, tablename, colname, operator, value, total_count): """ Estimate selectivity: higher means worst filtering. @@ -187,23 +203,24 @@ class HashIndex: Maintains insertion order of objects. """ + __slots__ = ('index',) + def __init__(self): - self.index = defaultdict(lambda: defaultdict(lambda: defaultdict(list))) + self.index = defaultdict(lambda: defaultdict(lambda: defaultdict(OrderedSet))) + def add(self, tablename: str, indexname: str, value: Any, obj: Any): - self.index[tablename][indexname][value].append(obj) + self.index[tablename][indexname][value].add(obj) + def remove(self, tablename: str, indexname: str, value: Any, obj: Any): - lst = self.index[tablename][indexname][value] - try: - lst.remove(obj) - if not lst: - del self.index[tablename][indexname][value] - except ValueError: - pass + s = self.index[tablename][indexname][value] + s.discard(obj) + if not s: + del self.index[tablename][indexname][value] def query(self, tablename: str, indexname: str, value: Any) -> List[Any]: - return self.index[tablename][indexname].get(value, []) + return list(self.index[tablename][indexname].get(value, [])) class RangeIndex: @@ -215,12 +232,19 @@ class RangeIndex: index[tablename][indexname] = SortedDict { value: [obj1, obj2, ...] } """ + __slots__ = ('index',) + def __init__(self): self.index = defaultdict(lambda: defaultdict(SortedDict)) def add(self, tablename: str, indexname: str, value: Any, obj: Any): - self.index[tablename][indexname].setdefault(value, []).append(obj) + index = self.index[tablename][indexname] + if value in index: + index[value].append(obj) + else: + index[value] = [obj] + def remove(self, tablename: str, indexname: str, value: Any, obj: Any): col = self.index[tablename][indexname] if value in col: diff --git a/sqlalchemy_memory/base/pending_changes.py b/sqlalchemy_memory/base/pending_changes.py index d3851ea..fa93ccd 100644 --- a/sqlalchemy_memory/base/pending_changes.py +++ b/sqlalchemy_memory/base/pending_changes.py @@ -20,8 +20,7 @@ def rollback(self): def add(self, obj, **kwargs): tablename = obj.__tablename__ - if not any(id(x) == id(obj) for x in self._to_add[tablename]): - self._to_add[tablename].append(obj) + self._to_add[tablename].append(obj) def delete(self, obj): tablename = obj.__tablename__ diff --git a/sqlalchemy_memory/base/session.py b/sqlalchemy_memory/base/session.py index 215c1f0..c26739f 100644 --- a/sqlalchemy_memory/base/session.py +++ b/sqlalchemy_memory/base/session.py @@ -22,6 +22,10 @@ def __init__(self, *args, **kwargs): def add(self, obj, **kwargs): self.pending_changes.add(obj, **kwargs) + def add_all(self, instances, **kwargs): + for instance in instances: + self.add(instance, **kwargs) + def delete(self, obj): self.pending_changes.delete(obj) @@ -159,7 +163,7 @@ def _handle_update(self, statement: Update, **kwargs): pk_col_name = None for obj in collection: if pk_col_name is None: - pk_col_name = self.store._get_primary_key_name(obj) + pk_col_name = self.store._get_primary_key_name(obj.__table__) pk_value = getattr(obj, pk_col_name) self.update(tablename, pk_value, data) @@ -188,7 +192,7 @@ def merge(self, instance, **kwargs): Merge a possibly detached instance into the current session """ - pk_name = self.store._get_primary_key_name(instance) + pk_name = self.store._get_primary_key_name(instance.__table__) pk_value = getattr(instance, pk_name) existing = self.store.get_by_primary_key(instance, pk_value) diff --git a/sqlalchemy_memory/base/store.py b/sqlalchemy_memory/base/store.py index bb68af2..85ebcf2 100644 --- a/sqlalchemy_memory/base/store.py +++ b/sqlalchemy_memory/base/store.py @@ -24,6 +24,10 @@ def _reset(self): # Auto increment counter per table self._pk_counter = defaultdict(int) + # Caches + self.table_columns = {} + self.table_pk_name = {} + @property def dirty(self): return self.pending_changes.dirty @@ -37,7 +41,7 @@ def commit(self): continue data = self.data.get(tablename, []) - pk_col_name = self._get_primary_key_name(objs[0]) + pk_col_name = self._get_primary_key_name(objs[0].__table__) pk_values = set(getattr(obj, pk_col_name) for obj in objs) logger.debug(f"Deleting rows from table '{tablename}' with PK values={pk_values}") @@ -57,11 +61,16 @@ def commit(self): self.index_manager.on_delete(obj) # apply adds + added = set() for tablename, objs in self.pending_changes._to_add.items(): if tablename not in self.data: self.data[tablename] = [] for obj in objs: + if id(obj) in added: + continue + added.add(id(obj)) + pk_value = self._assign_primary_key_if_needed(obj) if pk_value in self.data_by_pk[tablename].keys(): raise Exception(f"Cannot have duplicate PK value {pk_value} for table '{tablename}'") @@ -109,17 +118,31 @@ def get_by_primary_key(self, entity, pk_value): return self.data_by_pk[tablename].get(pk_value) - def _get_primary_key_name(self, obj): + def _get_primary_key_name(self, table): """ Return the PK column name """ - pk_cols = obj.__table__.primary_key.columns + tablename = table.name + if tablename not in self.table_pk_name: + pk_cols = table.primary_key.columns + + if len(pk_cols) != 1: + raise NotImplementedError("Only single-column primary keys are supported.") + + col = list(pk_cols)[0] + self.table_pk_name[tablename] = col.name + + return self.table_pk_name[tablename] - if len(pk_cols) != 1: - raise NotImplementedError("Only single-column primary keys are supported.") + def _get_table_columns(self, table): + """ + Returns the table columns + """ + tablename = table.name + if tablename not in self.table_columns: + self.table_columns[tablename] = table.columns - col = list(pk_cols)[0] - return col.name + return self.table_columns[tablename] def _assign_primary_key_if_needed(self, obj): """ @@ -127,18 +150,18 @@ def _assign_primary_key_if_needed(self, obj): If user specifies an ID, use it and update the counter if necessary. If no ID is specified, assign the next available one. """ - pk_col_name = self._get_primary_key_name(obj) - table = obj.__tablename__ - current_id = getattr(obj, pk_col_name, None) + pk_col_name = self._get_primary_key_name(obj.__table__) + current_id = obj.__dict__.get(pk_col_name, None) + tablename = obj.__tablename__ if current_id is None: # Auto-assign next ID - self._pk_counter[table] += 1 - current_id = self._pk_counter[table] - setattr(obj, pk_col_name, current_id) + current_id = self._pk_counter[tablename] = self._pk_counter[tablename] + 1 + obj.__dict__[pk_col_name] = current_id + else: # Ensure auto-increment counter stays ahead - self._pk_counter[table] = max(self._pk_counter[table], current_id) + self._pk_counter[tablename] = max(self._pk_counter[tablename], current_id) return current_id @@ -147,7 +170,10 @@ def _apply_column_defaults(self, obj): Apply default and server_default values to an ORM object. """ - for column in obj.__table__.columns: + for column in self._get_table_columns(obj.__table__): + if column.default is None and column.server_default is None: + continue + attr_name = column.name current_value = getattr(obj, attr_name, None) @@ -163,15 +189,15 @@ def _apply_column_defaults(self, obj): else: value = column.default.arg - setattr(obj, attr_name, value) + obj.__dict__[attr_name] = value elif column.server_default is not None: if isinstance(column.server_default.arg, TextClause): text_value = column.server_default.arg.text - setattr(obj, attr_name, text_value) + obj.__dict__[attr_name] = text_value elif isinstance(column.server_default.arg, func.now().__class__): - setattr(obj, attr_name, datetime.utcnow()) + obj.__dict__[attr_name] = datetime.utcnow() else: raise Exception(f"Unhandled server_default type: {type(column.server_default)}") diff --git a/sqlalchemy_memory/helpers/ordered_set.py b/sqlalchemy_memory/helpers/ordered_set.py new file mode 100644 index 0000000..2289826 --- /dev/null +++ b/sqlalchemy_memory/helpers/ordered_set.py @@ -0,0 +1,29 @@ +from collections import OrderedDict + +class OrderedSet: + def __init__(self): + self._data = OrderedDict() + + def add(self, item): + self._data[item] = None + + def discard(self, item): + self._data.pop(item, None) + + def __contains__(self, item): + return item in self._data + + def __iter__(self): + return iter(self._data) + + def __len__(self): + return len(self._data) + + def __bool__(self): + return bool(self._data) + + def remove(self, item): + del self._data[item] + + def clear(self): + self._data.clear() From 7ec5ee7d1eb83a12af2861be82c9ca6097de2616 Mon Sep 17 00:00:00 2001 From: rundef Date: Fri, 2 May 2025 14:14:49 -0400 Subject: [PATCH 2/2] =?UTF-8?q?Bump=20version:=200.3.0=20=E2=86=92=200.3.1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- pyproject.toml | 2 +- sqlalchemy_memory/__init__.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 484a2e9..72ca5f4 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.3.0 +current_version = 0.3.1 commit = True tag = True diff --git a/pyproject.toml b/pyproject.toml index cca6509..f111648 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sqlalchemy-memory" -version = "0.3.0" +version = "0.3.1" dependencies = [ "sqlalchemy>=2.0,<3.0", "sortedcontainers>=2.4.0" diff --git a/sqlalchemy_memory/__init__.py b/sqlalchemy_memory/__init__.py index d1b3d78..e9f5cac 100644 --- a/sqlalchemy_memory/__init__.py +++ b/sqlalchemy_memory/__init__.py @@ -6,4 +6,4 @@ "AsyncMemorySession", ] -__version__ = '0.3.0' \ No newline at end of file +__version__ = '0.3.1' \ No newline at end of file