From 53ccb7662127b91b81955173382ba5f57ed526c3 Mon Sep 17 00:00:00 2001 From: rundef Date: Mon, 5 May 2025 17:52:55 -0400 Subject: [PATCH 1/4] Use generators to filter datasets --- README.md | 9 + benchmark.py | 59 ++++-- docs/benchmarks.rst | 29 +-- docs/index.rst | 2 + sqlalchemy_memory/base/indexes.py | 83 +++++--- sqlalchemy_memory/base/query.py | 325 +++++++++++++++++++++-------- sqlalchemy_memory/base/session.py | 66 +++--- sqlalchemy_memory/base/store.py | 7 +- sqlalchemy_memory/helpers/utils.py | 15 ++ tests/conftest.py | 7 + tests/models.py | 23 +- tests/test_advanced.py | 170 ++++++++++++--- tests/test_aggregation.py | 83 ++++++++ tests/test_basic.py | 1 + tests/test_comparison.py | 56 +++++ tests/test_indexes.py | 67 ++++-- tests/test_select.py | 133 ------------ 17 files changed, 775 insertions(+), 360 deletions(-) create mode 100644 sqlalchemy_memory/helpers/utils.py create mode 100644 tests/test_aggregation.py create mode 100644 tests/test_comparison.py delete mode 100644 tests/test_select.py diff --git a/README.md b/README.md index 92192ab..7d8e4cc 100644 --- a/README.md +++ b/README.md @@ -37,6 +37,15 @@ Data is kept purely in RAM and is **volatile**: it is **not persisted across app - **Commit/rollback support** - **Index support**: indexes are recognized and used for faster lookups - **Merge and `get()` support**: like real SQLAlchemy behavior +- **Lazy query evaluation**: supports generator pipelines and short-circuiting + - `first()`-style queries avoid scanning the full dataset + - Optimized for read-heavy workloads and streaming filters + +## Benchmark + +Curious how `sqlalchemy-memory` stacks up? + +[View Benchmark Results](https://sqlalchemy-memory.readthedocs.io/en/latest/benchmarks.html) comparing `sqlalchemy-memory` to `in-memory SQLite` ## Installation diff --git a/benchmark.py b/benchmark.py index 20d7d56..d26214c 100644 --- a/benchmark.py +++ b/benchmark.py @@ -1,16 +1,15 @@ -from sqlalchemy import create_engine, Column, Integer, String, Boolean, select, Index, update, delete +from sqlalchemy import create_engine, Column, Integer, String, Boolean, select, Float, update, delete, bindparam, literal from sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy.sql import operators +from sqlalchemy.sql.elements import BinaryExpression from sqlalchemy_memory import MemorySession import argparse import time import random from faker import Faker -try: - from sqlalchemy_memory import create_memory_engine -except ImportError: - create_memory_engine = None +random.seed(42) Base = declarative_base() fake = Faker() CATEGORIES = list("ABCDEFGHIJK") @@ -22,22 +21,46 @@ class Item(Base): name = Column(String) active = Column(Boolean, index=True) category = Column(String, index=True) + price = Column(Float, index=True) + cost = Column(Float) def generate_items(n): for _ in range(n): yield Item( name=fake.name(), active=random.choice([True, False]), - category=random.choice(CATEGORIES) + category=random.choice(CATEGORIES), + price=round(random.uniform(5, 500), 2), + cost=round(random.uniform(1, 300), 2), ) def generate_random_select_query(): clauses = [] + if random.random() < 0.5: - clauses.append(Item.active == random.choice([True, False])) - if random.random() < 0.5 or not clauses: + val = random.choice([True, False]) + op = random.choice([operators.eq, operators.ne]) + clauses.append(BinaryExpression(Item.active, literal(val), op)) + + if random.random() < 0.7: subset = random.sample(CATEGORIES, random.randint(1, 4)) - clauses.append(Item.category.in_(subset)) + op = random.choice([operators.in_op, operators.notin_op]) + param = bindparam("category_list", subset, expanding=True) + clauses.append(BinaryExpression(Item.category, param, op)) + + if random.random() < 0.6: + price_val = round(random.uniform(10, 400), 2) + op = random.choice([operators.gt, operators.lt, operators.le, operators.gt]) + clauses.append(BinaryExpression(Item.price, literal(price_val), op)) + + if random.random() < 0.3: + cost_val = round(random.uniform(10, 200), 2) + op = random.choice([operators.gt, operators.lt, operators.le, operators.gt]) + clauses.append(BinaryExpression(Item.cost, literal(cost_val), op)) + + if not clauses: + clauses.append(Item.active == True) + return select(Item).where(*clauses) def inserts(Session, count): @@ -49,15 +72,24 @@ def inserts(Session, count): print(f"Inserted {count} items in {insert_duration:.2f} seconds.") return insert_duration -def selects(Session, count): +def selects(Session, count, fetch_type): queries = [generate_random_select_query() for _ in range(count)] query_start = time.time() with Session() as session: for stmt in queries: - list(session.execute(stmt).scalars()) + if fetch_type == "limit": + stmt = stmt.limit(5) + + result = session.execute(stmt) + + if fetch_type == "first": + result.first() + else: + list(result.scalars()) + query_duration = time.time() - query_start - print(f"Executed {count} select queries in {query_duration:.2f} seconds.") + print(f"Executed {count} select queries ({fetch_type}) in {query_duration:.2f} seconds.") return query_duration def updates(Session, random_ids): @@ -105,7 +137,8 @@ def run_benchmark(db_type="sqlite", count=100_000): Base.metadata.create_all(engine) elapsed = inserts(Session, count) - elapsed += selects(Session, 500) + elapsed += selects(Session, 500, fetch_type="all") + elapsed += selects(Session, 500, fetch_type="limit") random_ids = random.sample(range(1, count + 1), 500) elapsed += updates(Session, random_ids) diff --git a/docs/benchmarks.rst b/docs/benchmarks.rst index 62655c5..2f7e5eb 100644 --- a/docs/benchmarks.rst +++ b/docs/benchmarks.rst @@ -5,6 +5,8 @@ This benchmark compares `sqlalchemy-memory` to `in-memory SQLite` using 20,000 i As the results show, `sqlalchemy-memory` **excels in read-heavy workloads**, delivering significantly faster query performance. While SQLite performs slightly better on update and delete operations, the overall runtime of `sqlalchemy-memory` remains substantially lower, making it a strong choice for prototyping and simulation. +`Check the benchmark script on GitHub `_ + .. list-table:: :header-rows: 1 :widths: 25 25 25 @@ -13,17 +15,20 @@ As the results show, `sqlalchemy-memory` **excels in read-heavy workloads**, del - SQLite (in-memory) - sqlalchemy-memory * - Insert - - 3.17 sec - - 2.70 sec - * - 500 Select Queries - - 26.37 sec - - 2.94 sec + - 3.30 sec + - **3.10 sec** + * - 500 Select Queries (all()) + - 30.07 sec + - **4.14 sec** + * - 500 Select Queries (limit(5)) + - **0.24** sec + - 0.30 sec * - 500 Updates - - 0.26 sec - - 1.12 sec + - 0.25 sec + - **0.19** sec * - 500 Deletes - - 0.09 sec - - 0.90 sec - * - **Total Runtime** - - **29.89 sec** - - **7.66 sec** + - **0.09** sec + - **0.09** sec + * - *Total Runtime* + - 33.95 sec + - **7.81 sec** diff --git a/docs/index.rst b/docs/index.rst index c1be23f..11dde90 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,6 +3,8 @@ Welcome to sqlalchemy-memory's documentation! `sqlalchemy-memory` is a pure in-memory backend for SQLAlchemy 2.0 that supports both sync and async modes, with full compatibility for SQLAlchemy Core and ORM. +📦 GitHub: https://github.com/rundef/sqlalchemy-memory + Quickstart: sync example ------------------------ diff --git a/sqlalchemy_memory/base/indexes.py b/sqlalchemy_memory/base/indexes.py index 9809d40..f1b1625 100644 --- a/sqlalchemy_memory/base/indexes.py +++ b/sqlalchemy_memory/base/indexes.py @@ -1,6 +1,7 @@ from collections import defaultdict from sortedcontainers import SortedDict -from typing import Any, List +from typing import Any, List, Generator +from itertools import chain from sqlalchemy.sql import operators from ..helpers.ordered_set import OrderedSet @@ -108,62 +109,84 @@ def on_update(self, obj, updates): self.hash_index.add(tablename, indexname, new_value, obj) self.range_index.add(tablename, indexname, new_value, obj) - def query(self, collection, tablename, colname, operator, value): + def query(self, collection, tablename, colname, operator, value, collection_is_full_table=False): indexname = self._column_to_index(tablename, colname) if not indexname: return None - # Use hash index for = / != / IN / NOT IN operators if operator == operators.eq: result = self.hash_index.query(tablename, indexname, value) - return list(set(result) & set(collection)) + if collection_is_full_table: + return result + return (item for item in collection if item in result) elif operator == operators.ne: - # All values except the given one excluded = self.hash_index.query(tablename, indexname, value) - return list(set(collection) - set(excluded)) + return (item for item in collection if item not in excluded) elif operator == operators.in_op: - result = [] - for v in value: - result.extend(self.hash_index.query(tablename, indexname, v)) - return list(set(result) & set(collection)) + result = chain.from_iterable( + self.hash_index.query(tablename, indexname, v) for v in value + ) + if collection_is_full_table: + return result + result = set(result) + return (item for item in collection if item in result) elif operator == operators.notin_op: - excluded = [] - for v in value: - excluded.extend(self.hash_index.query(tablename, indexname, v)) - return list(set(collection) - set(excluded)) + excluded = set(chain.from_iterable( + self.hash_index.query(tablename, indexname, v) for v in value + )) + return (item for item in collection if item not in excluded) - # Use range index - if operator == operators.gt: + elif operator == operators.gt: result = self.range_index.query(tablename, indexname, gt=value) - return list(set(result) & set(collection)) + if collection_is_full_table: + return result + result = set(result) + return (item for item in collection if item in result) elif operator == operators.ge: result = self.range_index.query(tablename, indexname, gte=value) - return list(set(result) & set(collection)) + if collection_is_full_table: + return result + result = set(result) + return (item for item in collection if item in result) elif operator == operators.lt: result = self.range_index.query(tablename, indexname, lt=value) - return list(set(result) & set(collection)) + if collection_is_full_table: + return result + result = set(result) + return (item for item in collection if item in result) elif operator == operators.le: result = self.range_index.query(tablename, indexname, lte=value) - return list(set(result) & set(collection)) + if collection_is_full_table: + return result + result = set(result) + return (item for item in collection if item in result) elif operator == operators.between_op and isinstance(value, (tuple, list)) and len(value) == 2: result = self.range_index.query(tablename, indexname, gte=value[0], lte=value[1]) - return list(set(result) & set(collection)) + if collection_is_full_table: + return result + result = set(result) + return (item for item in collection if item in result) elif operator == operators.not_between_op and isinstance(value, (tuple, list)) and len(value) == 2: - in_range = self.range_index.query(tablename, indexname, gte=value[0], lte=value[1]) - return list(set(collection) - set(in_range)) + in_range = set(self.range_index.query(tablename, indexname, gte=value[0], lte=value[1])) + return (item for item in collection if item not in in_range) def get_selectivity(self, tablename, colname, operator, value, total_count): """ - Estimate selectivity: higher means worst filtering. + Estimate the selectivity of a single WHERE condition. + + This method is used to rank or sort WHERE conditions by their estimated + filtering power. A lower selectivity value indicates that the condition + is expected to filter out more rows (i.e., fewer rows remain after applying it), + making it more selective. """ indexname = self._column_to_index(tablename, colname) @@ -220,7 +243,7 @@ def remove(self, tablename: str, indexname: str, value: Any, obj: Any): del self.index[tablename][indexname][value] def query(self, tablename: str, indexname: str, value: Any) -> List[Any]: - return list(self.index[tablename][indexname].get(value, [])) + return self.index[tablename][indexname].get(value, []) class RangeIndex: @@ -255,7 +278,7 @@ def remove(self, tablename: str, indexname: str, value: Any, obj: Any): except ValueError: pass - def query(self, tablename: str, indexname: str, gt=None, gte=None, lt=None, lte=None) -> List[Any]: + def query(self, tablename: str, indexname: str, gt=None, gte=None, lt=None, lte=None) -> Generator: sd = self.index[tablename][indexname] # Define range bounds @@ -264,14 +287,10 @@ def query(self, tablename: str, indexname: str, gt=None, gte=None, lt=None, lte= inclusive_min = gte is not None inclusive_max = lte is not None - irange = sd.irange( + keys = sd.irange( minimum=min_key, maximum=max_key, inclusive=(inclusive_min, inclusive_max) ) - result = [] - for key in irange: - result.extend(sd[key]) - - return result + return chain.from_iterable(sd[key] for key in keys) diff --git a/sqlalchemy_memory/base/query.py b/sqlalchemy_memory/base/query.py index 801a523..4668171 100644 --- a/sqlalchemy_memory/base/query.py +++ b/sqlalchemy_memory/base/query.py @@ -1,15 +1,23 @@ from sqlalchemy.sql.elements import ( UnaryExpression, BinaryExpression, BindParameter, ExpressionClauseList, BooleanClauseList, - Grouping, True_, False_, Null + Grouping, True_, False_, Null, + Label, ) from sqlalchemy.sql.functions import FunctionElement from sqlalchemy.sql import operators -from sqlalchemy.sql.annotation import AnnotatedTable +from sqlalchemy.sql.annotation import AnnotatedTable, AnnotatedColumn +from sqlalchemy.sql.schema import Table +from sqlalchemy.sql.functions import Function +from sqlalchemy.sql.selectable import Select, Join +from sqlalchemy.sql.dml import Delete, Update from sqlalchemy.orm.query import Query +from sqlalchemy.orm.decl_api import DeclarativeMeta from functools import cached_property +from itertools import tee, islice import fnmatch from ..logger import logger +from ..helpers.utils import _dedup_chain from .resolvers import DateResolver, JsonExtractResolver OPERATOR_ADAPTERS = { @@ -29,68 +37,117 @@ } class MemoryQuery(Query): - def __init__(self, entities, session): - super().__init__(entities, session) - self._model = entities[0] - - self._where_criteria = [] - self._order_by = [] - self._limit = None - self._offset = None + def __init__(self, statement, session): + self.session = session + self._statement = statement @property def store(self): return self.session.store + @property + def table(self): + if isinstance(self._statement, (Update, Delete)): + return self._statement.table + + if len(self._statement._from_obj) == 1: + return self._statement._from_obj[0] + + # Attempt to extract table from raw columns (quicker) + table = self._extract_table_from_raw_columns() + if table is not None: + return table + + from_clauses = self._statement.get_final_froms() + if len(from_clauses) != 1: + raise Exception(f"Only select statement with a single FROM clause are supported") + + from_clause = from_clauses[0] + + if isinstance(from_clause, Table): + return from_clause + + if isinstance(from_clause, Join): + return from_clause.left + + raise Exception(f"Unhandled SELECT FROM clause type: {type(from_clause)}") + @cached_property def tablename(self): - if isinstance(self._model, AnnotatedTable): - return self._model.name - return self._model.__tablename__ + return self.table.name + + @cached_property + def is_select(self): + return isinstance(self._statement, Select) + + @cached_property + def _limit(self): + if self.is_select and self._statement._limit_clause is not None: + return self._statement._limit_clause.value + + @cached_property + def _offset(self): + if self.is_select and self._statement._offset_clause is not None: + return self._statement._offset_clause.value + + @cached_property + def _order_by(self): + if self.is_select: + return self._statement._order_by_clauses + return [] + + @cached_property + def _where_criteria(self): + return self._statement._where_criteria + + def iter_items(self): + return self._execute_query() def first(self): - items = self._execute_query() - return items[0] if items else None + gen = self.iter_items() + try: + return next(gen) + except StopIteration: + return None def all(self): - items = self._execute_query() - return items + gen = self.iter_items() + gen = self._project(gen) + #print(items) + return list(gen) def filter(self, condition): - self._where_criteria.append(condition) + self._statement._where_criteria.append(condition) return self - def limit(self, value): - self._limit = value - return self + def _apply_boolean_condition(self, cond: BooleanClauseList, stream): + op = cond.operator # and_ or or_ - def offset(self, value): - self._offset = value - return self + if op is operators.and_: + # Apply filters sequentially to the current stream + for subcond in cond.clauses: + stream = self._apply_condition(subcond, stream) + return stream - def order_by(self, clause): - self._order_by.append(clause) - return self + op = cond.operator - def _apply_boolean_condition(self, cond: BooleanClauseList, collection): - op = cond.operator # and_ or or_ + if op is operators.and_: + for subcond in cond.clauses: + stream = self._apply_condition(subcond, stream) - # Recursively evaluate each sub-condition - subresults = [ - set(self._apply_condition(subcond, collection)) - for subcond in cond.clauses - ] + return stream - if op is operators.and_: - # Intersection: item must satisfy all sub-conditions - result = set.intersection(*subresults) elif op is operators.or_: - # Union: item can satisfy any sub-condition - result = set.union(*subresults) - else: - raise NotImplementedError(f"Unsupported BooleanClauseList operator: {op}") + # Materialize the stream once and tee for each OR branch + + streams = tee(stream, len(cond.clauses)) + substreams = [ + self._apply_condition(subcond, s) + for subcond, s in zip(cond.clauses, streams) + ] + return _dedup_chain(*substreams) - return list(result) + raise NotImplementedError(f"Unsupported BooleanClauseList op: {op}") def _resolve_rhs(self, rhs): if isinstance(rhs, BindParameter): @@ -109,7 +166,7 @@ def _resolve_rhs(self, rhs): else: raise NotImplementedError(f"Unsupported RHS: {type(rhs)}") - def _apply_binary_condition(self, cond: BinaryExpression, collection): + def _apply_binary_condition(self, cond: BinaryExpression, stream, is_first=False): # Extract the Python value it's being compared to value = self._resolve_rhs(cond.right) @@ -138,75 +195,66 @@ def _apply_binary_condition(self, cond: BinaryExpression, collection): op = cond.operator # Use index if available - index_result = self.store.query_index(collection, table_name, attr_name, op, value) + index_result = self.store.query_index(stream, table_name, attr_name, op, value, collection_is_full_table=is_first) if index_result is not None: return index_result if op in OPERATOR_ADAPTERS: op = OPERATOR_ADAPTERS[op](value) - return [ - item for item in collection - if op(accessor(item, attr_name), value) - ] + return (item for item in stream if op(accessor(item, attr_name), value)) - def _apply_condition(self, cond, collection): + def _apply_condition(self, cond, stream, is_first=False): if isinstance(cond, Grouping): # Unwrap - return self._apply_condition(cond.element, collection) + return self._apply_condition(cond.element, stream) if isinstance(cond, BinaryExpression): # Represent an expression that is ``LEFT RIGHT`` - return self._apply_binary_condition(cond, collection) + return self._apply_binary_condition(cond, stream, is_first=is_first) if isinstance(cond, BooleanClauseList): # and_ / or_ expressions - return self._apply_boolean_condition(cond, collection) + return self._apply_boolean_condition(cond, stream) raise NotImplementedError(f"Unsupported condition type: {type(cond)}") def _execute_query(self): - collection = self.store.data.get(self.tablename, []) - if not collection: + stream = iter(self.store.data.get(self.tablename, [])) + if not stream: logger.debug(f"Table '{self.tablename}' is empty") - return collection + return [] # Apply conditions conditions = sorted(self._where_criteria, key=self._get_condition_selectivity) - - for condition in conditions: - collection = self._apply_condition(condition, collection) - - if len(collection) == 0: - # No need to go further - return collection + for idx, condition in enumerate(conditions): + stream = self._apply_condition(condition, stream, is_first=(idx == 0)) # Apply order by - for clause in reversed(self._order_by or []): - reverse = False - - if isinstance(clause, UnaryExpression): - if clause.modifier is operators.desc_op: - reverse = True - elif clause.modifier is operators.asc_op: - reverse = False - col = clause.element - else: - col = clause - - collection = sorted(collection, key=lambda x: getattr(x, col.name), reverse=reverse) + if self._order_by: + stream = list(stream) + for clause in reversed(self._order_by): + col = clause.element if isinstance(clause, UnaryExpression) else clause + reverse = isinstance(clause, UnaryExpression) and clause.modifier is operators.desc_op + stream = sorted(stream, key=lambda x: getattr(x, col.name), reverse=reverse) - # Apply offset - if self._offset is not None: - collection = collection[self._offset:] + # Offset / limit + if self._limit or self._offset: + start = self._offset or 0 + stop = start + self._limit if self._limit else None + stream = islice(stream, start, stop) - # Apply limit - if self._limit is not None: - collection = collection[:self._limit] - - return collection + return stream def _get_condition_selectivity(self, cond): + """ + Estimate the selectivity of a single WHERE condition. + + This method is used to rank or sort WHERE conditions by their estimated + filtering power. A lower selectivity value indicates that the condition + is expected to filter out more rows (i.e., fewer rows remain after applying it), + making it more selective. + """ total_count = self.store.count(self.tablename) if not isinstance(cond, BinaryExpression): @@ -228,3 +276,114 @@ def _get_condition_selectivity(self, cond): value=value, total_count=total_count ) + + def _extract_table_from_column(self, c): + if isinstance(c, AnnotatedTable): + return c + + if isinstance(c, AnnotatedColumn): + return c.table + + if isinstance(c, DeclarativeMeta): + # Old session.query(...) api + return c.__table__ + + if isinstance(c, Label): + return self._extract_table_from_column(c.element) + + if isinstance(c, Function): + clause = next(iter(c.clauses)) + return self._extract_table_from_column(clause) + + def _extract_table_from_raw_columns(self): + _tables = [ + self._extract_table_from_column(c) + for c in self._statement._raw_columns + ] + + if len(set(_tables)) == 1: + return _tables[0] + + _tables = list(set(_tables)) + # Try to find a "root" table by checking if it has relationship to other tables + for candidate in _tables: + others = set(_tables) - {candidate} + candidate_columns = candidate.columns if hasattr(candidate, "columns") else [] + + foreign = [ + fk.column.table + for col in candidate_columns + for fk in col.foreign_keys + ] + if all(other in foreign for other in others): + return candidate + + def _project(self, stream): + """ + Apply SELECT column projection to the final collection. + + Supports raw columns, labels, and simple aggregates. + """ + + if not self.is_select: + return stream + + cols = self._statement._raw_columns + + # Bypass projection if this is a simple SELECT [table] + if all(isinstance(c, (AnnotatedTable, DeclarativeMeta, Join)) for c in cols): + return stream + + group_by = self._statement._group_by_clauses + + if group_by: + grouped = {} + for item in stream: + key = tuple(getattr(item, col.name) for col in group_by) + grouped.setdefault(key, []).append(item) + + result = [] + for key, group_items in grouped.items(): + row = [] + for col in cols: + value = self._evaluate_column(col, group_items) + row.append(value) + result.append(tuple(row)) + return result + + else: + return ( + tuple(self._evaluate_column(col, [item]) for col in cols) + for item in stream + ) + + def _evaluate_column(self, col, items): + """ + Evaluate a column or expression over one or many items. + """ + if isinstance(col, Label): + return self._evaluate_column(col.element, items) + + if isinstance(col, AnnotatedColumn): + return getattr(items[0], col.name) + + if isinstance(col, FunctionElement): + fn_name = col.name.lower() + col_expr = next(iter(col.clauses)) + values = [getattr(item, col_expr.name) for item in items] + + if fn_name == "count": + return len(values) + elif fn_name == "sum": + return sum(values) + elif fn_name == "min": + return min(values) + elif fn_name == "max": + return max(values) + elif fn_name == "avg": + return sum(values) / len(values) if values else None + else: + raise NotImplementedError(f"Function not supported: {fn_name}") + + raise NotImplementedError(f"Column type not handled: {type(col)}") + diff --git a/sqlalchemy_memory/base/session.py b/sqlalchemy_memory/base/session.py index c26739f..e4b7248 100644 --- a/sqlalchemy_memory/base/session.py +++ b/sqlalchemy_memory/base/session.py @@ -1,10 +1,12 @@ from sqlalchemy.orm import Session -from sqlalchemy.sql.selectable import Select +from sqlalchemy.sql.selectable import Select, SelectLabelStyle from sqlalchemy.sql.dml import Insert, Delete, Update from sqlalchemy.engine import IteratorResult from sqlalchemy.engine.cursor import SimpleResultMetaData from functools import lru_cache +from unittest.mock import MagicMock + from .query import MemoryQuery from .pending_changes import PendingChanges from ..logger import logger @@ -12,7 +14,6 @@ class MemorySession(Session): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._query_cls = MemoryQuery self._has_pending_merge = False self.store = self.get_bind().dialect._store @@ -46,56 +47,46 @@ def scalar(self, statement, **kwargs): @staticmethod @lru_cache(maxsize=256) - def _get_metadata_for_annotated_table(annotated_table): + def _get_metadata_for_table(table): """ Build minimal cursor metadata """ - col_names = [col.name for col in annotated_table._columns] + col_names = [col.name for col in table._columns] return SimpleResultMetaData([ (col_name, None, None, None, None, None, None) for col_name in col_names ]) + @staticmethod + def _get_metadata_from_columns(columns): + return SimpleResultMetaData([ + (getattr(col, "name", str(col)), None, None, None, None, None, None) + for col in columns + ]) def _handle_select(self, statement: Select, **kwargs): - entities = statement._raw_columns - if len(entities) != 1: - raise Exception("Only single‑entity SELECTs are supported") - # Execute the query - q = MemoryQuery(entities, self) - - # Apply WHERE - for cond in statement._where_criteria: - q = q.filter(cond) - - # Apply ORDER BY - for clause in statement._order_by_clauses: - q = q.order_by(clause) + q = MemoryQuery(statement, self) + results = q.iter_items() - # Apply LIMIT / OFFSET - if statement._limit_clause is not None: - q = q.limit(statement._limit_clause.value) - if statement._offset_clause is not None: - q = q.offset(statement._offset_clause.value) + metadata = self._get_metadata_from_columns(statement._raw_columns) - objs = q.all() + if statement._label_style is SelectLabelStyle.LABEL_STYLE_LEGACY_ORM: + """ + Support for legacy session.query(...) style + """ + it = IteratorResult(metadata, results) + it._real_result = MagicMock(_source_supports_scalars=True) + it._generate_rows = False + return it # Wrap each object in a single‑element tuple, so .scalars() yields it - wrapped = ((obj,) for obj in objs) - - metadata = MemorySession._get_metadata_for_annotated_table(entities[0]) - - return IteratorResult(metadata, wrapped) + results = ((r,) for r in results) + return IteratorResult(metadata, results) def _handle_delete(self, statement: Delete, **kwargs): - q = MemoryQuery([statement.table], self) - - for cond in statement._where_criteria: - q = q.filter(cond) - - collection = q.all() + collection = MemoryQuery(statement, self).all() for obj in collection: self.delete(obj) @@ -147,12 +138,7 @@ def _handle_insert(self, statement: Insert, params=None, **kwargs): return result def _handle_update(self, statement: Update, **kwargs): - q = MemoryQuery([statement.table], self) - - for cond in statement._where_criteria: - q = q.filter(cond) - - collection = q.all() + collection = MemoryQuery(statement, self).all() data = { col.name: bindparam.value diff --git a/sqlalchemy_memory/base/store.py b/sqlalchemy_memory/base/store.py index 85ebcf2..dc265f6 100644 --- a/sqlalchemy_memory/base/store.py +++ b/sqlalchemy_memory/base/store.py @@ -202,11 +202,8 @@ def _apply_column_defaults(self, obj): else: raise Exception(f"Unhandled server_default type: {type(column.server_default)}") - def query_index(self, collection, table_name, attr_name, op, value): - result = self.index_manager.query(collection, table_name, attr_name, op, value) - if result is not None: - logger.debug(f"Reduced '{table_name}' dataset from {len(collection)} items to {len(result)} by using index on '{attr_name}") - return result + def query_index(self, stream, table_name, attr_name, op, value, **kwargs): + return self.index_manager.query(stream, table_name, attr_name, op, value, **kwargs) def count(self, tablename): return len(self.data[tablename]) diff --git a/sqlalchemy_memory/helpers/utils.py b/sqlalchemy_memory/helpers/utils.py new file mode 100644 index 0000000..3261c80 --- /dev/null +++ b/sqlalchemy_memory/helpers/utils.py @@ -0,0 +1,15 @@ +from itertools import chain + +def _dedup_chain(*streams): + """ + Lazily merge multiple input iterators, yielding unique items only. + + This function performs a streamed union of all provided iterators, + ensuring that each item is yielded at most once while preserving order + across the combined streams. + """ + seen = set() + for item in chain(*streams): + if item not in seen: + seen.add(item) + yield item diff --git a/tests/conftest.py b/tests/conftest.py index 7a279de..3d8ba56 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -41,3 +41,10 @@ async def AsyncSessionFactory(): sync_session_class=MemorySession, expire_on_commit=False, ) + +@pytest.fixture +def sqlite_SessionFactory(): + engine = create_engine("sqlite:///:memory:", echo=False) + Base.metadata.create_all(engine) + + yield sessionmaker(engine) \ No newline at end of file diff --git a/tests/models.py b/tests/models.py index 94c7917..c3e9378 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,6 +1,7 @@ -from sqlalchemy.orm import declarative_base, mapped_column, Mapped -from sqlalchemy import JSON, func, text +from sqlalchemy.orm import declarative_base, mapped_column, Mapped, relationship +from sqlalchemy import JSON, func, text, ForeignKey from datetime import datetime +from typing import List Base = declarative_base() @@ -29,6 +30,21 @@ class Product(Base): def __repr__(self): return f"Product(id={self.id} name={self.name})" + +class Vendor(Base): + __tablename__ = "vendors" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(nullable=False) + + # Relationship back to products + products: Mapped[List["ProductWithIndex"]] = relationship( + back_populates="vendor", cascade="all, delete-orphan" + ) + + def __repr__(self): + return f"Vendor(id={self.id}, name='{self.name}')" + class ProductWithIndex(Base): __tablename__ = "products_with_index" id: Mapped[int] = mapped_column(primary_key=True) @@ -37,5 +53,8 @@ class ProductWithIndex(Base): category: Mapped[str] = mapped_column(index=True, nullable=False) price: Mapped[float] = mapped_column(default=True, index=True) + vendor_id: Mapped[int] = mapped_column(ForeignKey("vendors.id"), index=True) + vendor: Mapped["Vendor"] = relationship(back_populates="products") + def __repr__(self): return f"ProductWithIndex(id={self.id} name={self.name})" diff --git a/tests/test_advanced.py b/tests/test_advanced.py index a4309cb..cf6eb7a 100644 --- a/tests/test_advanced.py +++ b/tests/test_advanced.py @@ -1,8 +1,9 @@ -from sqlalchemy import select, func, and_, or_ +from sqlalchemy import select, func, and_, or_, not_, case +from sqlalchemy.orm import joinedload, selectinload from datetime import datetime, date import pytest -from models import Item, Product +from models import Item, Product, ProductWithIndex, Vendor class TestAdvanced: @pytest.mark.parametrize( @@ -195,7 +196,40 @@ def test_date_filter(self, SessionFactory, operator, value, expected_ids): results = session.execute(stmt).scalars().all() assert {item.id for item in results} == expected_ids - def test_and_or(self, SessionFactory): + @pytest.mark.parametrize("condition,expected_ids", [ + ( + (Product.id > 1) & ((Product.id < 4) | (Product.category == "A")), + {2, 3, 5}, + ), + ( + and_( + Product.id > 1, + or_( + Product.id < 4, + Product.category == "A" + ) + ), + {2, 3, 5}, + ), + ( + not_(Product.category == "A"), + {2, 3, 4} + ), + ( + or_( + and_( + not_(Product.category == "A"), # 2,3,4 + Product.id > 2 # 3,4 + ), + and_( + Product.category == "A", # 1,5 + not_(Product.id == 1) # 2,3,4,5 + ), + ), + {3, 4, 5}, + ), + ]) + def test_and_or_not(self, SessionFactory, condition, expected_ids): with SessionFactory() as session: session.add_all([ Product(id=1, name="foo", category="A"), @@ -208,29 +242,10 @@ def test_and_or(self, SessionFactory): stmt = ( select(Product) - .where( - (Product.id > 1) & ( - (Product.id < 4) | (Product.category == "A") - ) - ) + .where(condition) ) results = session.execute(stmt).scalars().all() - assert {item.id for item in results} == {2, 3, 5} - - stmt = ( - select(Product) - .where( - and_( - Product.id > 1, - or_( - Product.id < 4, - Product.category == "A" - ) - ) - ) - ) - results = session.execute(stmt).scalars().all() - assert {item.id for item in results} == {2, 3, 5} + assert {item.id for item in results} == expected_ids def test_session_inception(self, SessionFactory): with SessionFactory() as session1: @@ -240,3 +255,110 @@ def test_session_inception(self, SessionFactory): with SessionFactory() as session2: results = session2.execute(select(Item)).scalars().all() assert len(results) == 1 + + @pytest.mark.parametrize("query", [ + select(ProductWithIndex), + select(ProductWithIndex.id, ProductWithIndex.name), + + # Join shouldn't affect anything + lambda: select(ProductWithIndex).options(joinedload(ProductWithIndex.vendor)), + lambda: select(ProductWithIndex).options(selectinload(ProductWithIndex.vendor)), + lambda: select( + ProductWithIndex.id, + ProductWithIndex.name, + ).join(ProductWithIndex.vendor) + ]) + def test_select_subset_of_columns(self, SessionFactory, query): + with SessionFactory() as session: + vendor1 = Vendor(id=10, name="First vendor") + vendor2 = Vendor(id=20, name="Second vendor") + + session.add_all([ + vendor1, + vendor2, + ]) + + session.add_all([ + ProductWithIndex(id=1, name="foo", category="A", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=2, name="bar", category="B", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=3, name="foobar", category="B", vendor_id=20, vendor=vendor2), + ]) + session.commit() + + if callable(query): + query = query() + + results = session.execute(query).scalars().all() + + assert len(results) == 3 + + # We get the objects straight back, no column selection + assert { + r.id: (r.name, r.category) + for r in results + } == { + 1: ("foo", "A"), + 2: ("bar", "B"), + 3: ("foobar", "B"), + } + + @pytest.mark.parametrize("query, expected", [ + ( + # Labels + select( + ProductWithIndex.id.label("product_id"), + ProductWithIndex.name.label("product_name"), + Vendor.name.label("vendor_name"), + ), + [ + {"product_id": 1, "product_name": "foo", "vendor_name": "First vendor"}, + {"product_id": 2, "product_name": "bar", "vendor_name": "First vendor"}, + {"product_id": 3, "product_name": "foobar", "vendor_name": "Second vendor"}, + ] + ), + + ( + # Simple case() + select( + ProductWithIndex.id, + case( + ( + ProductWithIndex.id >= 3, "High" + ), + ( + ProductWithIndex.id < 2, "Low" + ), + else_="Medium" + ).label("test"), + ), + [ + {"id": 1, "test": "Low"}, + {"id": 2, "test": "Medium"}, + {"id": 3, "test": "High"}, + ] + ), + ]) + def test_select_expressions(self, SessionFactory, query, expected): + with SessionFactory() as session: + vendor1 = Vendor(id=10, name="First vendor") + vendor2 = Vendor(id=20, name="Second vendor") + + session.add_all([ + vendor1, + vendor2, + ]) + + session.add_all([ + ProductWithIndex(id=1, name="foo", category="A", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=2, name="bar", category="B", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=3, name="foobar", category="B", vendor_id=20, vendor=vendor2), + ]) + session.commit() + + results = session.execute(query).scalars().all() + + assert len(results) == len(expected) + for idx, (result, expected_result) in enumerate(zip(results, expected)): + for k, v in expected_result.items(): + assert hasattr(result, k), f"Expected {k} to be in result, but keys are {result.__dict__.keys()}" + assert getattr(result, k) == v, f"Expected {k} to be == {v} for item #{idx}" diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py new file mode 100644 index 0000000..fa1917b --- /dev/null +++ b/tests/test_aggregation.py @@ -0,0 +1,83 @@ +import pytest + +from sqlalchemy import func, select, case + +from models import ProductWithIndex, Vendor + +class TestAggregation: + @pytest.mark.parametrize("query,expected", [ + ( + select(func.count(ProductWithIndex.price).label("count")), + [{"count": 5}] + ), + ( + select( + func.count(ProductWithIndex.id).label("count"), + func.min(ProductWithIndex.id).label("minimum"), + func.max(ProductWithIndex.id).label("maximum"), + func.avg(ProductWithIndex.id).label("avg"), + func.sum(ProductWithIndex.id), + ), + [{ + "count": 3, + "minimum": 1, + "maximum": 3, + "avg": 2, + "wtf": 6, + }] + ), + ( + # More complex case() + select( + func.sum( + case([ + (ProductWithIndex.category == 'A', ProductWithIndex.vendor_id - ProductWithIndex.id), + (ProductWithIndex.category == 'B', ProductWithIndex.id - ProductWithIndex.vendor_id) + ]) * ProductWithIndex.id + ).label('final_value') + ), + [ + # ((10-1)*1) + ((2-10)*2) + ((3-20)*3) + {"final_value": -58} + ], + ) + ]) + def test_select_aggr(self, SessionFactory, query, expected): + with SessionFactory() as session: + vendor1 = Vendor(id=10, name="First vendor") + vendor2 = Vendor(id=20, name="Second vendor") + + session.add_all([ + vendor1, + vendor2, + ]) + + session.add_all([ + ProductWithIndex(id=1, name="foo", category="A", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=2, name="bar", category="B", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=3, name="foobar", category="B", vendor_id=20, vendor=vendor2), + ]) + session.commit() + + result = session.execute(query).one() + + print(result) + count_x, min_x, max_x, sum_x = result + print(f"Count: {count_x}, Min: {min_x}, Max: {max_x}, Sum: {sum_x}") + + def test_group_by(self): + return + stmt = select( + YourModel.category, + func.count(YourModel.id), + func.sum(YourModel.sales) + ).group_by(YourModel.category) + + def test_having(self): + return + subq = ( + select(YourModel.category, func.sum(YourModel.sales).label("total_sales")) + .group_by(YourModel.category) + ).subquery() + + stmt = select(subq).where(subq.c.total_sales > 1000) \ No newline at end of file diff --git a/tests/test_basic.py b/tests/test_basic.py index 7aa7e9a..30165d7 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,3 +1,4 @@ +import pytest from sqlalchemy import select from models import Item diff --git a/tests/test_comparison.py b/tests/test_comparison.py new file mode 100644 index 0000000..891f57b --- /dev/null +++ b/tests/test_comparison.py @@ -0,0 +1,56 @@ +import pytest +from sqlalchemy import select +from sqlalchemy.orm import selectinload, joinedload + +from models import ProductWithIndex, Vendor + +class TestComparison: + @pytest.mark.parametrize("query_lambda", [ + lambda s: s.execute(select(ProductWithIndex)), + lambda s: s.execute(select(ProductWithIndex.id, ProductWithIndex.name)), + lambda s: s.query(ProductWithIndex), + lambda s: s.query(ProductWithIndex.id, ProductWithIndex.name), # fails (OK.) + lambda s: s.execute(select(ProductWithIndex.id)).scalars(), # fails (OK. projection not yet done) + lambda s: s.execute(select(ProductWithIndex).options(selectinload(ProductWithIndex.vendor))), + lambda s: s.execute(select(ProductWithIndex).options(joinedload(ProductWithIndex.vendor))), + + lambda s: s.execute(select( + ProductWithIndex, + Vendor, + ).join(ProductWithIndex.vendor)), + + lambda s: s.execute( + select( + ProductWithIndex.id.label("product_id"), + ProductWithIndex.name.label("product_name"), + Vendor.name.label("vendor_name"), + ).join(ProductWithIndex.vendor) + ), + ]) + async def test_select_same_as_sqlite(self, SessionFactory, sqlite_SessionFactory, query_lambda): + with sqlite_SessionFactory() as session: + vendor1 = Vendor(id=10, name="First vendor") + session.add_all([ + vendor1, + ProductWithIndex(id=1, name="foo", category="A", vendor_id=10, vendor=vendor1), + ]) + session.commit() + + result_sqlite = list(query_lambda(session)) + + with SessionFactory() as session: + vendor1 = Vendor(id=10, name="First vendor") + session.add_all([ + vendor1, + ProductWithIndex(id=1, name="foo", category="A", vendor_id=10, vendor=vendor1), + ]) + session.commit() + + result = list(query_lambda(session)) + + assert type(result) == type(result_sqlite) + assert len(result) == len(result_sqlite) + assert len(result) > 0 + + for r1, r2 in zip(result, result_sqlite): + assert type(r1) == type(r2) diff --git a/tests/test_indexes.py b/tests/test_indexes.py index 17072d1..f6c5c67 100644 --- a/tests/test_indexes.py +++ b/tests/test_indexes.py @@ -74,6 +74,41 @@ def test_range_index(self, query_kwargs, expected_ids): results = index.query("products", "price_index", **query_kwargs) assert {r.id for r in results} == expected_ids + def test_index_manager(self): + mgr = IndexManager() + mgr.table_indexes = { + "products": { + "id": ["id"], + "price_index": ["price"] + } + } + + objs = [ + MagicMock(id=1, price=10, __tablename__="products"), + MagicMock(id=2, price=30, __tablename__="products"), + MagicMock(id=3, price=20, __tablename__="products"), + ] + + for obj in objs: + mgr.on_insert(obj) + + for full in [True, False]: + result = list(mgr.query(objs, "products", "id", operators.in_op, [2, 3], collection_is_full_table=full)) + assert len(result) == 2 + assert set(r.id for r in result) == {2, 3} + + for full in [True, False]: + result = list(mgr.query(objs, "products", "id", operators.notin_op, [1, 2], collection_is_full_table=full)) + assert len(result) == 1 + assert set(r.id for r in result) == {3} + + for full in [True, False]: + result = list(mgr.query(objs, "products", "id", operators.gt, 1, collection_is_full_table=full)) + assert len(result) == 2 + assert set(r.id for r in result) == {2, 3} + + + @pytest.mark.parametrize("query_kwargs,expected_ids", [ # All ES assets ({"gte": ("ES", -float("inf")), "lte": ("ES", float("inf"))}, {1, 2}), @@ -169,39 +204,39 @@ def test_synchronized_indexes(self, SessionFactory): store = session.store collection = store.data[tablename] - assert len(store.query_index(collection, tablename, "active", operators.eq, True)) == 2 - assert len(store.query_index(collection, tablename, "active", operators.ne, True)) == 0 - assert len(store.query_index(collection, tablename, "active", operators.eq, False)) == 0 - assert len(store.query_index(collection, tablename, "active", operators.ne, False)) == 2 + assert len(list(store.query_index(collection, tablename, "active", operators.eq, True))) == 2 + assert len(list(store.query_index(collection, tablename, "active", operators.ne, True))) == 0 + assert len(list(store.query_index(collection, tablename, "active", operators.eq, False))) == 0 + assert len(list(store.query_index(collection, tablename, "active", operators.ne, False))) == 2 - assert len(store.query_index(collection, tablename, "category", operators.eq, "A")) == 1 - assert len(store.query_index(collection, tablename, "category", operators.eq, "B")) == 1 - assert len(store.query_index(collection, tablename, "category", operators.eq, "Z")) == 0 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "A"))) == 1 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "B"))) == 1 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "Z"))) == 0 # Assert nothing was changed on rollback item = session.get(ProductWithIndex, 2) item.category = "Z" session.rollback() - assert len(store.query_index(collection, tablename, "category", operators.eq, "A")) == 1 - assert len(store.query_index(collection, tablename, "category", operators.eq, "B")) == 1 - assert len(store.query_index(collection, tablename, "category", operators.eq, "Z")) == 0 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "A"))) == 1 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "B"))) == 1 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "Z"))) == 0 # Assert index was synchronized after update item = session.get(ProductWithIndex, 2) item.category = "Z" session.commit() - assert len(store.query_index(collection, tablename, "category", operators.eq, "A")) == 1 - assert len(store.query_index(collection, tablename, "category", operators.eq, "B")) == 0 - assert len(store.query_index(collection, tablename, "category", operators.eq, "Z")) == 1 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "A"))) == 1 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "B"))) == 0 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "Z"))) == 1 # Assert nothing was changed on rollback session.delete(collection[0]) session.rollback() - assert len(store.query_index(collection, tablename, "active", operators.eq, True)) == 2 + assert len(list(store.query_index(collection, tablename, "active", operators.eq, True))) == 2 # Assert index was synchronized after deletion session.delete(collection[0]) session.commit() - assert len(store.query_index(collection, tablename, "active", operators.eq, True)) == 1 - assert len(store.query_index(collection, tablename, "active", operators.eq, False)) == 0 + assert len(list(store.query_index(collection, tablename, "active", operators.eq, True))) == 1 + assert len(list(store.query_index(collection, tablename, "active", operators.eq, False))) == 0 diff --git a/tests/test_select.py b/tests/test_select.py deleted file mode 100644 index 275092b..0000000 --- a/tests/test_select.py +++ /dev/null @@ -1,133 +0,0 @@ -""" - -## 1. Core Queries (SQLAlchemy 2.0 Core) - -```python -from sqlalchemy import ( - create_engine, - MetaData, - Table, Column, Integer, String, - select, insert, update, delete, -) -# 1) Setup -engine = create_engine("sqlite:///:memory:", future=True) -metadata = MetaData() - -users = Table( - "users", metadata, - Column("id", Integer, primary_key=True), - Column("name", String), - Column("age", Integer), -) - -metadata.create_all(engine) - -# 2) INSERT -with engine.begin() as conn: - conn.execute( - insert(users), - [ - {"name": "alice", "age": 30}, - {"name": "bob", "age": 25}, - ] - ) - -# 3) SELECT -with engine.connect() as conn: - stmt = select(users).where(users.c.age > 20) - result = conn.execute(stmt) - rows = result.all() # list of Row objects - for row in rows: - print(row.id, row.name, row.age) - -# 4) UPDATE -with engine.begin() as conn: - stmt = ( - update(users) - .where(users.c.name == "alice") - .values(age=31) - .returning(users.c.id, users.c.age) - ) - updated = conn.execute(stmt).all() - -# 5) DELETE -with engine.begin() as conn: - stmt = delete(users).where(users.c.age < 28) - result = conn.execute(stmt) - print("deleted rows:", result.rowcount) - - - - - - - - - -from sqlalchemy import select -from sqlalchemy.orm import ( - declarative_base, Mapped, mapped_column, - sessionmaker, -) - -Base = declarative_base() - -class Item(Base): - __tablename__ = "items" - id: Mapped[int] = mapped_column(primary_key=True) - x: Mapped[int] - y: Mapped[int | None] - -engine = create_engine("sqlite:///:memory:", future=True) -Session = sessionmaker(engine, future=True) - -Base.metadata.create_all(engine) -session = Session() - -# add some sample data -session.add_all([ - Item(x=1, y=None), - Item(x=1, y=10), - Item(x=2, y=None), -]) -session.commit() - -# Filter: x == 1 AND y IS NULL -stmt = select(Item).where( - Item.x == 1, - Item.y.is_(None) -) -items = session.scalars(stmt).all() -# -> returns all Item instances matching those conditions - - - -from sqlalchemy import func, select - -# Core style: -stmt_core = select( - func.max(users.c.age).label("max_age"), - func.min(users.c.age).label("min_age"), -) -with engine.connect() as conn: - max_age, min_age = conn.execute(stmt_core).one() - -# ORM style: -stmt_orm = select( - func.max(Item.x).label("max_x"), - func.min(Item.y).label("min_y"), -) -max_x, min_y = session.execute(stmt_orm).one() - - - - -from sqlalchemy import func -stmt = select(Item).where( - func.json_extract(Item.data, "$.key") == "value" -) - - - -stmt = select(Item).order_by(Item.id).limit(10).offset(20) -""" \ No newline at end of file From 84ad9db979c21d7cde302c2dfaa69465475609fe Mon Sep 17 00:00:00 2001 From: rundef Date: Mon, 12 May 2025 17:57:45 -0400 Subject: [PATCH 2/4] Finish aggregation support --- README.md | 20 ------- docs/query.rst | 1 + sqlalchemy_memory/base/query.py | 89 ++++++++++++++++++++++++++---- sqlalchemy_memory/base/session.py | 35 ++++-------- sqlalchemy_memory/helpers/utils.py | 5 ++ tests/test_advanced.py | 14 +++-- tests/test_aggregation.py | 80 +++++++++++++-------------- tests/test_comparison.py | 44 ++++++++++++--- tests/test_crud.py | 20 +++++++ 9 files changed, 199 insertions(+), 109 deletions(-) diff --git a/README.md b/README.md index 7d8e4cc..3fdb708 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,6 @@ Data is kept purely in RAM and is **volatile**: it is **not persisted across app - **Zero I/O overhead**: pure in‑RAM storage (`dict`/`list` under the hood) - **Commit/rollback support** - **Index support**: indexes are recognized and used for faster lookups -- **Merge and `get()` support**: like real SQLAlchemy behavior - **Lazy query evaluation**: supports generator pipelines and short-circuiting - `first()`-style queries avoid scanning the full dataset - Optimized for read-heavy workloads and streaming filters @@ -57,25 +56,6 @@ pip install sqlalchemy-memory [See the official documentation for usage examples](https://sqlalchemy-memory.readthedocs.io/en/latest/) - -## Status - -Currently supports basic functionality equivalent to: - -- SQLite in-memory behavior for ORM + Core queries - -- `declarative_base()` model support - -Coming soon: - -- `func.count()` / aggregations - -- Joins and relationships (limited) - -- Compound indexes - -- Better expression support in `update(...).values()` (e.g., +=) - ## Testing Simply run `make tests` diff --git a/docs/query.rst b/docs/query.rst index a972800..120a54c 100644 --- a/docs/query.rst +++ b/docs/query.rst @@ -15,6 +15,7 @@ Supported Functions - `DATE(column)` - `func.json_extract(col, '$.expr')` +- Aggregation functions: - Aggregations: `func.count()` / `func.sum()` / `func.min()` / `func.max()` / `func.avg()` Indexes ------- diff --git a/sqlalchemy_memory/base/query.py b/sqlalchemy_memory/base/query.py index 4668171..86628b5 100644 --- a/sqlalchemy_memory/base/query.py +++ b/sqlalchemy_memory/base/query.py @@ -1,7 +1,7 @@ from sqlalchemy.sql.elements import ( UnaryExpression, BinaryExpression, BindParameter, ExpressionClauseList, BooleanClauseList, Grouping, True_, False_, Null, - Label, + Label, Case, ) from sqlalchemy.sql.functions import FunctionElement from sqlalchemy.sql import operators @@ -101,7 +101,9 @@ def _where_criteria(self): return self._statement._where_criteria def iter_items(self): - return self._execute_query() + gen = self._execute_query() + gen = self._project(gen) + return gen def first(self): gen = self.iter_items() @@ -112,8 +114,6 @@ def first(self): def all(self): gen = self.iter_items() - gen = self._project(gen) - #print(items) return list(gen) def filter(self, condition): @@ -329,18 +329,22 @@ def _project(self, stream): return stream cols = self._statement._raw_columns + group_by = self._statement._group_by_clauses # Bypass projection if this is a simple SELECT [table] - if all(isinstance(c, (AnnotatedTable, DeclarativeMeta, Join)) for c in cols): + if not group_by and all(isinstance(c, (AnnotatedTable, DeclarativeMeta, Join)) for c in cols): return stream - group_by = self._statement._group_by_clauses - - if group_by: + if group_by or self._contains_aggregation_function(cols): grouped = {} - for item in stream: - key = tuple(getattr(item, col.name) for col in group_by) - grouped.setdefault(key, []).append(item) + if group_by: + for item in stream: + key = tuple(getattr(item, col.name) for col in group_by) + grouped.setdefault(key, []).append(item) + else: + grouped = { + "_all_": [item for item in stream] + } result = [] for key, group_items in grouped.items(): @@ -357,15 +361,45 @@ def _project(self, stream): for item in stream ) + def _contains_aggregation_function(self, cols): + for c in cols: + if isinstance(c, Label): + c = c.element + + if isinstance(c, FunctionElement): + if c.name.lower() in ["count", "sum", "min", "max", "avg"]: + return True + + return False + def _evaluate_column(self, col, items): """ Evaluate a column or expression over one or many items. """ + + if isinstance(col, AnnotatedTable): + return items[0] + if isinstance(col, Label): return self._evaluate_column(col.element, items) if isinstance(col, AnnotatedColumn): - return getattr(items[0], col.name) + + if self.tablename == col.table.name: + # Column belongs to the primary ORM model + return getattr(items[0], col.name) + + else: + # Column belongs to a related model + item = items[0] + rel_name = col.table.name # e.g., 'vendors' + # Find matching attribute on the main object + for attr_name in vars(item): + attr = getattr(item, attr_name, None) + if hasattr(attr, "__table__") and attr.__table__.name == rel_name: + return getattr(attr, col.name) + + raise ValueError(f"Could not find related model '{rel_name}' on '{type(item).__name__}'") if isinstance(col, FunctionElement): fn_name = col.name.lower() @@ -385,5 +419,36 @@ def _evaluate_column(self, col, items): else: raise NotImplementedError(f"Function not supported: {fn_name}") + if isinstance(col, Case): + for condition_expr, result_expr in col.whens: + condition_value = self._evaluate_expression(condition_expr, items) + if condition_value: + return self._evaluate_expression(result_expr, items) + + # No condition matched; return else_ + return self._evaluate_expression(col.else_, items) + raise NotImplementedError(f"Column type not handled: {type(col)}") + def _evaluate_expression(self, expr, items): + """ + Evaluate an expression (which might be a Grouping, BinaryExpression, BindParameter, etc.). + """ + + if isinstance(expr, BindParameter): + return expr.value + + if isinstance(expr, Grouping): + return self._evaluate_expression(expr.element, items) + + if isinstance(expr, BinaryExpression): + left = self._evaluate_expression(expr.left, items) + right = self._evaluate_expression(expr.right, items) + op = expr.operator + return op(left, right) + + if isinstance(expr, AnnotatedColumn): + return self._evaluate_column(expr, items) + + raise NotImplementedError(f"Unsupported expression type: {type(expr)}") + diff --git a/sqlalchemy_memory/base/session.py b/sqlalchemy_memory/base/session.py index e4b7248..3675b47 100644 --- a/sqlalchemy_memory/base/session.py +++ b/sqlalchemy_memory/base/session.py @@ -1,15 +1,17 @@ from sqlalchemy.orm import Session from sqlalchemy.sql.selectable import Select, SelectLabelStyle from sqlalchemy.sql.dml import Insert, Delete, Update -from sqlalchemy.engine import IteratorResult +from sqlalchemy.engine import IteratorResult, ChunkedIteratorResult from sqlalchemy.engine.cursor import SimpleResultMetaData -from functools import lru_cache +from sqlalchemy.sql.annotation import AnnotatedTable +from functools import partial from unittest.mock import MagicMock from .query import MemoryQuery from .pending_changes import PendingChanges from ..logger import logger +from ..helpers.utils import chunk_generator class MemorySession(Session): def __init__(self, *args, **kwargs): @@ -45,22 +47,10 @@ def scalars(self, statement, **kwargs): def scalar(self, statement, **kwargs): return self.execute(statement, **kwargs).scalar() - @staticmethod - @lru_cache(maxsize=256) - def _get_metadata_for_table(table): - """ - Build minimal cursor metadata - """ - col_names = [col.name for col in table._columns] - return SimpleResultMetaData([ - (col_name, None, None, None, None, None, None) - for col_name in col_names - ]) - @staticmethod def _get_metadata_from_columns(columns): return SimpleResultMetaData([ - (getattr(col, "name", str(col)), None, None, None, None, None, None) + getattr(col, "name", str(col)) for col in columns ]) @@ -71,7 +61,9 @@ def _handle_select(self, statement: Select, **kwargs): metadata = self._get_metadata_from_columns(statement._raw_columns) - if statement._label_style is SelectLabelStyle.LABEL_STYLE_LEGACY_ORM: + if statement._label_style is SelectLabelStyle.LABEL_STYLE_LEGACY_ORM and all( + isinstance(c, AnnotatedTable) for c in statement._raw_columns + ): """ Support for legacy session.query(...) style """ @@ -80,9 +72,9 @@ def _handle_select(self, statement: Select, **kwargs): it._generate_rows = False return it - # Wrap each object in a single‑element tuple, so .scalars() yields it - results = ((r,) for r in results) - return IteratorResult(metadata, results) + it = ChunkedIteratorResult(metadata, partial(chunk_generator, results)) + + return it def _handle_delete(self, statement: Delete, **kwargs): @@ -123,10 +115,7 @@ def _handle_insert(self, statement: Insert, params=None, **kwargs): # Handle RETURNING(...) if statement._returning: cols = list(statement._returning) - metadata = SimpleResultMetaData([ - (col.name, None, None, None, None, None, None) - for col in cols - ]) + metadata = self._get_metadata_from_columns(cols) rows = [ tuple(getattr(obj, col.name) for col in cols) for obj in instances diff --git a/sqlalchemy_memory/helpers/utils.py b/sqlalchemy_memory/helpers/utils.py index 3261c80..4e8a97e 100644 --- a/sqlalchemy_memory/helpers/utils.py +++ b/sqlalchemy_memory/helpers/utils.py @@ -13,3 +13,8 @@ def _dedup_chain(*streams): if item not in seen: seen.add(item) yield item + + +def chunk_generator(results, *a): + chunk = [(r,) if not isinstance(r, (list, tuple)) else r for r in results] + yield chunk diff --git a/tests/test_advanced.py b/tests/test_advanced.py index cf6eb7a..80d402b 100644 --- a/tests/test_advanced.py +++ b/tests/test_advanced.py @@ -1,6 +1,7 @@ from sqlalchemy import select, func, and_, or_, not_, case from sqlalchemy.orm import joinedload, selectinload from datetime import datetime, date +from sqlalchemy.sql.annotation import AnnotatedTable import pytest from models import Item, Product, ProductWithIndex, Vendor @@ -257,8 +258,8 @@ def test_session_inception(self, SessionFactory): assert len(results) == 1 @pytest.mark.parametrize("query", [ - select(ProductWithIndex), - select(ProductWithIndex.id, ProductWithIndex.name), + lambda: select(ProductWithIndex), + lambda: select(ProductWithIndex.id, ProductWithIndex.name, ProductWithIndex.category), # Join shouldn't affect anything lambda: select(ProductWithIndex).options(joinedload(ProductWithIndex.vendor)), @@ -266,6 +267,7 @@ def test_session_inception(self, SessionFactory): lambda: select( ProductWithIndex.id, ProductWithIndex.name, + ProductWithIndex.category, ).join(ProductWithIndex.vendor) ]) def test_select_subset_of_columns(self, SessionFactory, query): @@ -288,9 +290,10 @@ def test_select_subset_of_columns(self, SessionFactory, query): if callable(query): query = query() - results = session.execute(query).scalars().all() + results = session.execute(query) - assert len(results) == 3 + if isinstance(query._raw_columns[0], AnnotatedTable): + results = results.scalars() # We get the objects straight back, no column selection assert { @@ -355,7 +358,8 @@ def test_select_expressions(self, SessionFactory, query, expected): ]) session.commit() - results = session.execute(query).scalars().all() + results = session.execute(query) + results = list(results) assert len(results) == len(expected) for idx, (result, expected_result) in enumerate(zip(results, expected)): diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py index fa1917b..70fc1cb 100644 --- a/tests/test_aggregation.py +++ b/tests/test_aggregation.py @@ -5,44 +5,29 @@ from models import ProductWithIndex, Vendor class TestAggregation: - @pytest.mark.parametrize("query,expected", [ + @pytest.mark.parametrize("query_fn,expected", [ ( - select(func.count(ProductWithIndex.price).label("count")), - [{"count": 5}] + lambda: select(func.count(ProductWithIndex.price).label("count")), + {"count": 3} ), ( - select( + lambda: select( func.count(ProductWithIndex.id).label("count"), func.min(ProductWithIndex.id).label("minimum"), func.max(ProductWithIndex.id).label("maximum"), func.avg(ProductWithIndex.id).label("avg"), func.sum(ProductWithIndex.id), ), - [{ + { "count": 3, "minimum": 1, "maximum": 3, "avg": 2, - "wtf": 6, - }] + "sum": 6, + } ), - ( - # More complex case() - select( - func.sum( - case([ - (ProductWithIndex.category == 'A', ProductWithIndex.vendor_id - ProductWithIndex.id), - (ProductWithIndex.category == 'B', ProductWithIndex.id - ProductWithIndex.vendor_id) - ]) * ProductWithIndex.id - ).label('final_value') - ), - [ - # ((10-1)*1) + ((2-10)*2) + ((3-20)*3) - {"final_value": -58} - ], - ) ]) - def test_select_aggr(self, SessionFactory, query, expected): + def test_select_aggr(self, SessionFactory, query_fn, expected): with SessionFactory() as session: vendor1 = Vendor(id=10, name="First vendor") vendor2 = Vendor(id=20, name="Second vendor") @@ -59,25 +44,40 @@ def test_select_aggr(self, SessionFactory, query, expected): ]) session.commit() - result = session.execute(query).one() + result = session.execute(query_fn()).mappings().one() print(result) - count_x, min_x, max_x, sum_x = result - print(f"Count: {count_x}, Min: {min_x}, Max: {max_x}, Sum: {sum_x}") + assert result == expected + + def test_group_by(self, SessionFactory): + with (SessionFactory() as session): + session.add_all([ + ProductWithIndex(id=1, name="foo", category="A", vendor_id=10), + ProductWithIndex(id=2, name="bar", category="B", vendor_id=10), + ProductWithIndex(id=3, name="foobar", category="B", vendor_id=20), + ]) + session.commit() + + results = session.execute(select(ProductWithIndex).group_by(ProductWithIndex.vendor_id)) + results = results.scalars().all() + + assert len(results) == 2 + assert results[0].id == 1 + assert results[1].id == 3 + + results = session.execute( + select( + func.count(ProductWithIndex.id), + func.min(ProductWithIndex.id).label("minimum"), + ) + .group_by(ProductWithIndex.vendor_id) + ) + results = list(results) - def test_group_by(self): - return - stmt = select( - YourModel.category, - func.count(YourModel.id), - func.sum(YourModel.sales) - ).group_by(YourModel.category) + assert len(results) == 2 - def test_having(self): - return - subq = ( - select(YourModel.category, func.sum(YourModel.sales).label("total_sales")) - .group_by(YourModel.category) - ).subquery() + assert results[0] == (2, 1) + assert results[0].minimum == 1 - stmt = select(subq).where(subq.c.total_sales > 1000) \ No newline at end of file + assert results[1] == (1, 3) + assert results[1].minimum == 3 diff --git a/tests/test_comparison.py b/tests/test_comparison.py index 891f57b..57604d2 100644 --- a/tests/test_comparison.py +++ b/tests/test_comparison.py @@ -1,16 +1,21 @@ import pytest from sqlalchemy import select from sqlalchemy.orm import selectinload, joinedload +from collections.abc import Iterable from models import ProductWithIndex, Vendor +def is_iterable(obj): + return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)) + class TestComparison: @pytest.mark.parametrize("query_lambda", [ lambda s: s.execute(select(ProductWithIndex)), lambda s: s.execute(select(ProductWithIndex.id, ProductWithIndex.name)), lambda s: s.query(ProductWithIndex), - lambda s: s.query(ProductWithIndex.id, ProductWithIndex.name), # fails (OK.) - lambda s: s.execute(select(ProductWithIndex.id)).scalars(), # fails (OK. projection not yet done) + lambda s: s.query(ProductWithIndex.id, ProductWithIndex.name), + lambda s: s.execute(select(ProductWithIndex.id, ProductWithIndex.name)).scalars(), + lambda s: s.execute(select(ProductWithIndex.id, ProductWithIndex.name)).scalar(), lambda s: s.execute(select(ProductWithIndex).options(selectinload(ProductWithIndex.vendor))), lambda s: s.execute(select(ProductWithIndex).options(joinedload(ProductWithIndex.vendor))), @@ -26,6 +31,9 @@ class TestComparison: Vendor.name.label("vendor_name"), ).join(ProductWithIndex.vendor) ), + + lambda s: s.execute(select(ProductWithIndex).group_by(ProductWithIndex.category)), + lambda s: s.execute(select(ProductWithIndex.id, ProductWithIndex.name).group_by(ProductWithIndex.category)), ]) async def test_select_same_as_sqlite(self, SessionFactory, sqlite_SessionFactory, query_lambda): with sqlite_SessionFactory() as session: @@ -33,24 +41,42 @@ async def test_select_same_as_sqlite(self, SessionFactory, sqlite_SessionFactory session.add_all([ vendor1, ProductWithIndex(id=1, name="foo", category="A", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=2, name="foo", category="A", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=3, name="foo", category="B", vendor_id=10, vendor=vendor1), ]) session.commit() - result_sqlite = list(query_lambda(session)) + result_sqlite = query_lambda(session) + + _original_type = type(result_sqlite) + if is_iterable(result_sqlite): + result_sqlite = list(result_sqlite) with SessionFactory() as session: vendor1 = Vendor(id=10, name="First vendor") session.add_all([ vendor1, ProductWithIndex(id=1, name="foo", category="A", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=2, name="foo", category="A", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=3, name="foo", category="B", vendor_id=10, vendor=vendor1), ]) session.commit() - result = list(query_lambda(session)) + result = query_lambda(session) + _type = type(result) + + if is_iterable(result): + result = list(result) + + assert _type == _original_type + + if is_iterable(result): + assert len(result) == len(result_sqlite) + assert len(result) > 0 + + for r1, r2 in zip(result, result_sqlite): + assert type(r1) == type(r2) - assert type(result) == type(result_sqlite) - assert len(result) == len(result_sqlite) - assert len(result) > 0 + else: + assert result == result_sqlite - for r1, r2 in zip(result, result_sqlite): - assert type(r1) == type(r2) diff --git a/tests/test_crud.py b/tests/test_crud.py index 9a37e35..599cfa8 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -31,6 +31,26 @@ def test_insert(self, SessionFactory): assert items[2].id == 3 assert items[2].name == "fba" + def test_insert_returning(self, sqlite_SessionFactory, SessionFactory): + with sqlite_SessionFactory() as session: + stmt = insert(Item).values(name="foo").returning(Item.id, Item.name) + result = session.execute(stmt) + returned = result.first() + + assert returned is not None + assert returned.id > 0 + assert returned.name == "foo" + + # The row exists in the transaction, but not in the DB + item = session.get(Item, returned.id) + assert item is not None + + session.rollback() + + # Now the row is gone + item = session.get(Item, returned.id) + assert item is None + def test_update(self, SessionFactory): with SessionFactory() as session: From ae69e1ad0539202f527e6ed92993e64e3931cd54 Mon Sep 17 00:00:00 2001 From: rundef Date: Mon, 12 May 2025 17:59:06 -0400 Subject: [PATCH 3/4] . --- tests/test_aggregation.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py index 70fc1cb..4d6e0f8 100644 --- a/tests/test_aggregation.py +++ b/tests/test_aggregation.py @@ -46,11 +46,10 @@ def test_select_aggr(self, SessionFactory, query_fn, expected): result = session.execute(query_fn()).mappings().one() - print(result) assert result == expected def test_group_by(self, SessionFactory): - with (SessionFactory() as session): + with SessionFactory() as session: session.add_all([ ProductWithIndex(id=1, name="foo", category="A", vendor_id=10), ProductWithIndex(id=2, name="bar", category="B", vendor_id=10), From 97f587d7c358b29f0c6a1463f091587236046160 Mon Sep 17 00:00:00 2001 From: rundef Date: Mon, 12 May 2025 18:00:46 -0400 Subject: [PATCH 4/4] =?UTF-8?q?Bump=20version:=200.3.1=20=E2=86=92=200.4.0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .bumpversion.cfg | 2 +- pyproject.toml | 2 +- sqlalchemy_memory/__init__.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 72ca5f4..348e1d5 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.3.1 +current_version = 0.4.0 commit = True tag = True diff --git a/pyproject.toml b/pyproject.toml index f111648..d129ff9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sqlalchemy-memory" -version = "0.3.1" +version = "0.4.0" dependencies = [ "sqlalchemy>=2.0,<3.0", "sortedcontainers>=2.4.0" diff --git a/sqlalchemy_memory/__init__.py b/sqlalchemy_memory/__init__.py index e9f5cac..ec69fcb 100644 --- a/sqlalchemy_memory/__init__.py +++ b/sqlalchemy_memory/__init__.py @@ -6,4 +6,4 @@ "AsyncMemorySession", ] -__version__ = '0.3.1' \ No newline at end of file +__version__ = '0.4.0' \ No newline at end of file