diff --git a/.bumpversion.cfg b/.bumpversion.cfg index 72ca5f4..348e1d5 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 0.3.1 +current_version = 0.4.0 commit = True tag = True diff --git a/README.md b/README.md index 92192ab..3fdb708 100644 --- a/README.md +++ b/README.md @@ -36,7 +36,15 @@ Data is kept purely in RAM and is **volatile**: it is **not persisted across app - **Zero I/O overhead**: pure in‑RAM storage (`dict`/`list` under the hood) - **Commit/rollback support** - **Index support**: indexes are recognized and used for faster lookups -- **Merge and `get()` support**: like real SQLAlchemy behavior +- **Lazy query evaluation**: supports generator pipelines and short-circuiting + - `first()`-style queries avoid scanning the full dataset + - Optimized for read-heavy workloads and streaming filters + +## Benchmark + +Curious how `sqlalchemy-memory` stacks up? + +[View Benchmark Results](https://sqlalchemy-memory.readthedocs.io/en/latest/benchmarks.html) comparing `sqlalchemy-memory` to `in-memory SQLite` ## Installation @@ -48,25 +56,6 @@ pip install sqlalchemy-memory [See the official documentation for usage examples](https://sqlalchemy-memory.readthedocs.io/en/latest/) - -## Status - -Currently supports basic functionality equivalent to: - -- SQLite in-memory behavior for ORM + Core queries - -- `declarative_base()` model support - -Coming soon: - -- `func.count()` / aggregations - -- Joins and relationships (limited) - -- Compound indexes - -- Better expression support in `update(...).values()` (e.g., +=) - ## Testing Simply run `make tests` diff --git a/benchmark.py b/benchmark.py index 20d7d56..d26214c 100644 --- a/benchmark.py +++ b/benchmark.py @@ -1,16 +1,15 @@ -from sqlalchemy import create_engine, Column, Integer, String, Boolean, select, Index, update, delete +from sqlalchemy import create_engine, Column, Integer, String, Boolean, select, Float, update, delete, bindparam, literal from 
sqlalchemy.orm import declarative_base, sessionmaker +from sqlalchemy.sql import operators +from sqlalchemy.sql.elements import BinaryExpression from sqlalchemy_memory import MemorySession import argparse import time import random from faker import Faker -try: - from sqlalchemy_memory import create_memory_engine -except ImportError: - create_memory_engine = None +random.seed(42) Base = declarative_base() fake = Faker() CATEGORIES = list("ABCDEFGHIJK") @@ -22,22 +21,46 @@ class Item(Base): name = Column(String) active = Column(Boolean, index=True) category = Column(String, index=True) + price = Column(Float, index=True) + cost = Column(Float) def generate_items(n): for _ in range(n): yield Item( name=fake.name(), active=random.choice([True, False]), - category=random.choice(CATEGORIES) + category=random.choice(CATEGORIES), + price=round(random.uniform(5, 500), 2), + cost=round(random.uniform(1, 300), 2), ) def generate_random_select_query(): clauses = [] + if random.random() < 0.5: - clauses.append(Item.active == random.choice([True, False])) - if random.random() < 0.5 or not clauses: + val = random.choice([True, False]) + op = random.choice([operators.eq, operators.ne]) + clauses.append(BinaryExpression(Item.active, literal(val), op)) + + if random.random() < 0.7: subset = random.sample(CATEGORIES, random.randint(1, 4)) - clauses.append(Item.category.in_(subset)) + op = random.choice([operators.in_op, operators.notin_op]) + param = bindparam("category_list", subset, expanding=True) + clauses.append(BinaryExpression(Item.category, param, op)) + + if random.random() < 0.6: + price_val = round(random.uniform(10, 400), 2) + op = random.choice([operators.gt, operators.lt, operators.le, operators.ge]) + clauses.append(BinaryExpression(Item.price, literal(price_val), op)) + + if random.random() < 0.3: + cost_val = round(random.uniform(10, 200), 2) + op = random.choice([operators.gt, operators.lt, operators.le, operators.ge]) + clauses.append(BinaryExpression(Item.cost, 
literal(cost_val), op)) + + if not clauses: + clauses.append(Item.active == True) + return select(Item).where(*clauses) def inserts(Session, count): @@ -49,15 +72,24 @@ def inserts(Session, count): print(f"Inserted {count} items in {insert_duration:.2f} seconds.") return insert_duration -def selects(Session, count): +def selects(Session, count, fetch_type): queries = [generate_random_select_query() for _ in range(count)] query_start = time.time() with Session() as session: for stmt in queries: - list(session.execute(stmt).scalars()) + if fetch_type == "limit": + stmt = stmt.limit(5) + + result = session.execute(stmt) + + if fetch_type == "first": + result.first() + else: + list(result.scalars()) + query_duration = time.time() - query_start - print(f"Executed {count} select queries in {query_duration:.2f} seconds.") + print(f"Executed {count} select queries ({fetch_type}) in {query_duration:.2f} seconds.") return query_duration def updates(Session, random_ids): @@ -105,7 +137,8 @@ def run_benchmark(db_type="sqlite", count=100_000): Base.metadata.create_all(engine) elapsed = inserts(Session, count) - elapsed += selects(Session, 500) + elapsed += selects(Session, 500, fetch_type="all") + elapsed += selects(Session, 500, fetch_type="limit") random_ids = random.sample(range(1, count + 1), 500) elapsed += updates(Session, random_ids) diff --git a/docs/benchmarks.rst b/docs/benchmarks.rst index 62655c5..2f7e5eb 100644 --- a/docs/benchmarks.rst +++ b/docs/benchmarks.rst @@ -5,6 +5,8 @@ This benchmark compares `sqlalchemy-memory` to `in-memory SQLite` using 20,000 i As the results show, `sqlalchemy-memory` **excels in read-heavy workloads**, delivering significantly faster query performance. While SQLite performs slightly better on update and delete operations, the overall runtime of `sqlalchemy-memory` remains substantially lower, making it a strong choice for prototyping and simulation. +`Check the benchmark script on GitHub `_ + .. 
list-table:: :header-rows: 1 :widths: 25 25 25 @@ -13,17 +15,20 @@ As the results show, `sqlalchemy-memory` **excels in read-heavy workloads**, del - SQLite (in-memory) - sqlalchemy-memory * - Insert - - 3.17 sec - - 2.70 sec - * - 500 Select Queries - - 26.37 sec - - 2.94 sec + - 3.30 sec + - **3.10 sec** + * - 500 Select Queries (all()) + - 30.07 sec + - **4.14 sec** + * - 500 Select Queries (limit(5)) + - **0.24** sec + - 0.30 sec * - 500 Updates - - 0.26 sec - - 1.12 sec + - 0.25 sec + - **0.19** sec * - 500 Deletes - - 0.09 sec - - 0.90 sec - * - **Total Runtime** - - **29.89 sec** - - **7.66 sec** + - **0.09** sec + - **0.09** sec + * - *Total Runtime* + - 33.95 sec + - **7.81 sec** diff --git a/docs/index.rst b/docs/index.rst index c1be23f..11dde90 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,6 +3,8 @@ Welcome to sqlalchemy-memory's documentation! `sqlalchemy-memory` is a pure in-memory backend for SQLAlchemy 2.0 that supports both sync and async modes, with full compatibility for SQLAlchemy Core and ORM. 
+📦 GitHub: https://github.com/rundef/sqlalchemy-memory + Quickstart: sync example ------------------------ diff --git a/docs/query.rst b/docs/query.rst index a972800..120a54c 100644 --- a/docs/query.rst +++ b/docs/query.rst @@ -15,6 +15,7 @@ Supported Functions - `DATE(column)` - `func.json_extract(col, '$.expr')` +- Aggregation functions: - Aggregations: `func.count()` / `func.sum()` / `func.min()` / `func.max()` / `func.avg()` Indexes ------- diff --git a/pyproject.toml b/pyproject.toml index f111648..d129ff9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sqlalchemy-memory" -version = "0.3.1" +version = "0.4.0" dependencies = [ "sqlalchemy>=2.0,<3.0", "sortedcontainers>=2.4.0" diff --git a/sqlalchemy_memory/__init__.py b/sqlalchemy_memory/__init__.py index e9f5cac..ec69fcb 100644 --- a/sqlalchemy_memory/__init__.py +++ b/sqlalchemy_memory/__init__.py @@ -6,4 +6,4 @@ "AsyncMemorySession", ] -__version__ = '0.3.1' \ No newline at end of file +__version__ = '0.4.0' \ No newline at end of file diff --git a/sqlalchemy_memory/base/indexes.py b/sqlalchemy_memory/base/indexes.py index 9809d40..f1b1625 100644 --- a/sqlalchemy_memory/base/indexes.py +++ b/sqlalchemy_memory/base/indexes.py @@ -1,6 +1,7 @@ from collections import defaultdict from sortedcontainers import SortedDict -from typing import Any, List +from typing import Any, List, Generator +from itertools import chain from sqlalchemy.sql import operators from ..helpers.ordered_set import OrderedSet @@ -108,62 +109,84 @@ def on_update(self, obj, updates): self.hash_index.add(tablename, indexname, new_value, obj) self.range_index.add(tablename, indexname, new_value, obj) - def query(self, collection, tablename, colname, operator, value): + def query(self, collection, tablename, colname, operator, value, collection_is_full_table=False): indexname = self._column_to_index(tablename, colname) if not indexname: return None - # Use hash index for 
= / != / IN / NOT IN operators if operator == operators.eq: result = self.hash_index.query(tablename, indexname, value) - return list(set(result) & set(collection)) + if collection_is_full_table: + return result + return (item for item in collection if item in result) elif operator == operators.ne: - # All values except the given one excluded = self.hash_index.query(tablename, indexname, value) - return list(set(collection) - set(excluded)) + return (item for item in collection if item not in excluded) elif operator == operators.in_op: - result = [] - for v in value: - result.extend(self.hash_index.query(tablename, indexname, v)) - return list(set(result) & set(collection)) + result = chain.from_iterable( + self.hash_index.query(tablename, indexname, v) for v in value + ) + if collection_is_full_table: + return result + result = set(result) + return (item for item in collection if item in result) elif operator == operators.notin_op: - excluded = [] - for v in value: - excluded.extend(self.hash_index.query(tablename, indexname, v)) - return list(set(collection) - set(excluded)) + excluded = set(chain.from_iterable( + self.hash_index.query(tablename, indexname, v) for v in value + )) + return (item for item in collection if item not in excluded) - # Use range index - if operator == operators.gt: + elif operator == operators.gt: result = self.range_index.query(tablename, indexname, gt=value) - return list(set(result) & set(collection)) + if collection_is_full_table: + return result + result = set(result) + return (item for item in collection if item in result) elif operator == operators.ge: result = self.range_index.query(tablename, indexname, gte=value) - return list(set(result) & set(collection)) + if collection_is_full_table: + return result + result = set(result) + return (item for item in collection if item in result) elif operator == operators.lt: result = self.range_index.query(tablename, indexname, lt=value) - return list(set(result) & set(collection)) + if 
collection_is_full_table: + return result + result = set(result) + return (item for item in collection if item in result) elif operator == operators.le: result = self.range_index.query(tablename, indexname, lte=value) - return list(set(result) & set(collection)) + if collection_is_full_table: + return result + result = set(result) + return (item for item in collection if item in result) elif operator == operators.between_op and isinstance(value, (tuple, list)) and len(value) == 2: result = self.range_index.query(tablename, indexname, gte=value[0], lte=value[1]) - return list(set(result) & set(collection)) + if collection_is_full_table: + return result + result = set(result) + return (item for item in collection if item in result) elif operator == operators.not_between_op and isinstance(value, (tuple, list)) and len(value) == 2: - in_range = self.range_index.query(tablename, indexname, gte=value[0], lte=value[1]) - return list(set(collection) - set(in_range)) + in_range = set(self.range_index.query(tablename, indexname, gte=value[0], lte=value[1])) + return (item for item in collection if item not in in_range) def get_selectivity(self, tablename, colname, operator, value, total_count): """ - Estimate selectivity: higher means worst filtering. + Estimate the selectivity of a single WHERE condition. + + This method is used to rank or sort WHERE conditions by their estimated + filtering power. A lower selectivity value indicates that the condition + is expected to filter out more rows (i.e., fewer rows remain after applying it), + making it more selective. 
""" indexname = self._column_to_index(tablename, colname) @@ -220,7 +243,7 @@ def remove(self, tablename: str, indexname: str, value: Any, obj: Any): del self.index[tablename][indexname][value] def query(self, tablename: str, indexname: str, value: Any) -> List[Any]: - return list(self.index[tablename][indexname].get(value, [])) + return self.index[tablename][indexname].get(value, []) class RangeIndex: @@ -255,7 +278,7 @@ def remove(self, tablename: str, indexname: str, value: Any, obj: Any): except ValueError: pass - def query(self, tablename: str, indexname: str, gt=None, gte=None, lt=None, lte=None) -> List[Any]: + def query(self, tablename: str, indexname: str, gt=None, gte=None, lt=None, lte=None) -> Generator: sd = self.index[tablename][indexname] # Define range bounds @@ -264,14 +287,10 @@ def query(self, tablename: str, indexname: str, gt=None, gte=None, lt=None, lte= inclusive_min = gte is not None inclusive_max = lte is not None - irange = sd.irange( + keys = sd.irange( minimum=min_key, maximum=max_key, inclusive=(inclusive_min, inclusive_max) ) - result = [] - for key in irange: - result.extend(sd[key]) - - return result + return chain.from_iterable(sd[key] for key in keys) diff --git a/sqlalchemy_memory/base/query.py b/sqlalchemy_memory/base/query.py index 801a523..86628b5 100644 --- a/sqlalchemy_memory/base/query.py +++ b/sqlalchemy_memory/base/query.py @@ -1,15 +1,23 @@ from sqlalchemy.sql.elements import ( UnaryExpression, BinaryExpression, BindParameter, ExpressionClauseList, BooleanClauseList, - Grouping, True_, False_, Null + Grouping, True_, False_, Null, + Label, Case, ) from sqlalchemy.sql.functions import FunctionElement from sqlalchemy.sql import operators -from sqlalchemy.sql.annotation import AnnotatedTable +from sqlalchemy.sql.annotation import AnnotatedTable, AnnotatedColumn +from sqlalchemy.sql.schema import Table +from sqlalchemy.sql.functions import Function +from sqlalchemy.sql.selectable import Select, Join +from sqlalchemy.sql.dml 
import Delete, Update from sqlalchemy.orm.query import Query +from sqlalchemy.orm.decl_api import DeclarativeMeta from functools import cached_property +from itertools import tee, islice import fnmatch from ..logger import logger +from ..helpers.utils import _dedup_chain from .resolvers import DateResolver, JsonExtractResolver OPERATOR_ADAPTERS = { @@ -29,68 +37,117 @@ } class MemoryQuery(Query): - def __init__(self, entities, session): - super().__init__(entities, session) - self._model = entities[0] - - self._where_criteria = [] - self._order_by = [] - self._limit = None - self._offset = None + def __init__(self, statement, session): + self.session = session + self._statement = statement @property def store(self): return self.session.store + @property + def table(self): + if isinstance(self._statement, (Update, Delete)): + return self._statement.table + + if len(self._statement._from_obj) == 1: + return self._statement._from_obj[0] + + # Attempt to extract table from raw columns (quicker) + table = self._extract_table_from_raw_columns() + if table is not None: + return table + + from_clauses = self._statement.get_final_froms() + if len(from_clauses) != 1: + raise Exception(f"Only select statement with a single FROM clause are supported") + + from_clause = from_clauses[0] + + if isinstance(from_clause, Table): + return from_clause + + if isinstance(from_clause, Join): + return from_clause.left + + raise Exception(f"Unhandled SELECT FROM clause type: {type(from_clause)}") + @cached_property def tablename(self): - if isinstance(self._model, AnnotatedTable): - return self._model.name - return self._model.__tablename__ + return self.table.name + + @cached_property + def is_select(self): + return isinstance(self._statement, Select) + + @cached_property + def _limit(self): + if self.is_select and self._statement._limit_clause is not None: + return self._statement._limit_clause.value + + @cached_property + def _offset(self): + if self.is_select and 
self._statement._offset_clause is not None: + return self._statement._offset_clause.value + + @cached_property + def _order_by(self): + if self.is_select: + return self._statement._order_by_clauses + return [] + + @cached_property + def _where_criteria(self): + return self._statement._where_criteria + + def iter_items(self): + gen = self._execute_query() + gen = self._project(gen) + return gen def first(self): - items = self._execute_query() - return items[0] if items else None + gen = self.iter_items() + try: + return next(gen) + except StopIteration: + return None def all(self): - items = self._execute_query() - return items + gen = self.iter_items() + return list(gen) def filter(self, condition): - self._where_criteria.append(condition) + self._statement._where_criteria.append(condition) return self - def limit(self, value): - self._limit = value - return self + def _apply_boolean_condition(self, cond: BooleanClauseList, stream): + op = cond.operator # and_ or or_ - def offset(self, value): - self._offset = value - return self + if op is operators.and_: + # Apply filters sequentially to the current stream + for subcond in cond.clauses: + stream = self._apply_condition(subcond, stream) + return stream - def order_by(self, clause): - self._order_by.append(clause) - return self + op = cond.operator - def _apply_boolean_condition(self, cond: BooleanClauseList, collection): - op = cond.operator # and_ or or_ + if op is operators.and_: + for subcond in cond.clauses: + stream = self._apply_condition(subcond, stream) - # Recursively evaluate each sub-condition - subresults = [ - set(self._apply_condition(subcond, collection)) - for subcond in cond.clauses - ] + return stream - if op is operators.and_: - # Intersection: item must satisfy all sub-conditions - result = set.intersection(*subresults) elif op is operators.or_: - # Union: item can satisfy any sub-condition - result = set.union(*subresults) - else: - raise NotImplementedError(f"Unsupported BooleanClauseList 
operator: {op}") + # Materialize the stream once and tee for each OR branch + + streams = tee(stream, len(cond.clauses)) + substreams = [ + self._apply_condition(subcond, s) + for subcond, s in zip(cond.clauses, streams) + ] + return _dedup_chain(*substreams) - return list(result) + raise NotImplementedError(f"Unsupported BooleanClauseList op: {op}") def _resolve_rhs(self, rhs): if isinstance(rhs, BindParameter): @@ -109,7 +166,7 @@ def _resolve_rhs(self, rhs): else: raise NotImplementedError(f"Unsupported RHS: {type(rhs)}") - def _apply_binary_condition(self, cond: BinaryExpression, collection): + def _apply_binary_condition(self, cond: BinaryExpression, stream, is_first=False): # Extract the Python value it's being compared to value = self._resolve_rhs(cond.right) @@ -138,75 +195,66 @@ def _apply_binary_condition(self, cond: BinaryExpression, collection): op = cond.operator # Use index if available - index_result = self.store.query_index(collection, table_name, attr_name, op, value) + index_result = self.store.query_index(stream, table_name, attr_name, op, value, collection_is_full_table=is_first) if index_result is not None: return index_result if op in OPERATOR_ADAPTERS: op = OPERATOR_ADAPTERS[op](value) - return [ - item for item in collection - if op(accessor(item, attr_name), value) - ] + return (item for item in stream if op(accessor(item, attr_name), value)) - def _apply_condition(self, cond, collection): + def _apply_condition(self, cond, stream, is_first=False): if isinstance(cond, Grouping): # Unwrap - return self._apply_condition(cond.element, collection) + return self._apply_condition(cond.element, stream) if isinstance(cond, BinaryExpression): # Represent an expression that is ``LEFT RIGHT`` - return self._apply_binary_condition(cond, collection) + return self._apply_binary_condition(cond, stream, is_first=is_first) if isinstance(cond, BooleanClauseList): # and_ / or_ expressions - return self._apply_boolean_condition(cond, collection) + return 
self._apply_boolean_condition(cond, stream) raise NotImplementedError(f"Unsupported condition type: {type(cond)}") def _execute_query(self): - collection = self.store.data.get(self.tablename, []) - if not collection: + stream = iter(self.store.data.get(self.tablename, [])) + if not stream: logger.debug(f"Table '{self.tablename}' is empty") - return collection + return [] # Apply conditions conditions = sorted(self._where_criteria, key=self._get_condition_selectivity) - - for condition in conditions: - collection = self._apply_condition(condition, collection) - - if len(collection) == 0: - # No need to go further - return collection + for idx, condition in enumerate(conditions): + stream = self._apply_condition(condition, stream, is_first=(idx == 0)) # Apply order by - for clause in reversed(self._order_by or []): - reverse = False - - if isinstance(clause, UnaryExpression): - if clause.modifier is operators.desc_op: - reverse = True - elif clause.modifier is operators.asc_op: - reverse = False - col = clause.element - else: - col = clause + if self._order_by: + stream = list(stream) + for clause in reversed(self._order_by): + col = clause.element if isinstance(clause, UnaryExpression) else clause + reverse = isinstance(clause, UnaryExpression) and clause.modifier is operators.desc_op + stream = sorted(stream, key=lambda x: getattr(x, col.name), reverse=reverse) - collection = sorted(collection, key=lambda x: getattr(x, col.name), reverse=reverse) + # Offset / limit + if self._limit or self._offset: + start = self._offset or 0 + stop = start + self._limit if self._limit else None + stream = islice(stream, start, stop) - # Apply offset - if self._offset is not None: - collection = collection[self._offset:] - - # Apply limit - if self._limit is not None: - collection = collection[:self._limit] - - return collection + return stream def _get_condition_selectivity(self, cond): + """ + Estimate the selectivity of a single WHERE condition. 
+ + This method is used to rank or sort WHERE conditions by their estimated + filtering power. A lower selectivity value indicates that the condition + is expected to filter out more rows (i.e., fewer rows remain after applying it), + making it more selective. + """ total_count = self.store.count(self.tablename) if not isinstance(cond, BinaryExpression): @@ -228,3 +276,179 @@ def _get_condition_selectivity(self, cond): value=value, total_count=total_count ) + + def _extract_table_from_column(self, c): + if isinstance(c, AnnotatedTable): + return c + + if isinstance(c, AnnotatedColumn): + return c.table + + if isinstance(c, DeclarativeMeta): + # Old session.query(...) api + return c.__table__ + + if isinstance(c, Label): + return self._extract_table_from_column(c.element) + + if isinstance(c, Function): + clause = next(iter(c.clauses)) + return self._extract_table_from_column(clause) + + def _extract_table_from_raw_columns(self): + _tables = [ + self._extract_table_from_column(c) + for c in self._statement._raw_columns + ] + + if len(set(_tables)) == 1: + return _tables[0] + + _tables = list(set(_tables)) + # Try to find a "root" table by checking if it has relationship to other tables + for candidate in _tables: + others = set(_tables) - {candidate} + candidate_columns = candidate.columns if hasattr(candidate, "columns") else [] + + foreign = [ + fk.column.table + for col in candidate_columns + for fk in col.foreign_keys + ] + if all(other in foreign for other in others): + return candidate + + def _project(self, stream): + """ + Apply SELECT column projection to the final collection. + + Supports raw columns, labels, and simple aggregates. 
+ """ + + if not self.is_select: + return stream + + cols = self._statement._raw_columns + group_by = self._statement._group_by_clauses + + # Bypass projection if this is a simple SELECT [table] + if not group_by and all(isinstance(c, (AnnotatedTable, DeclarativeMeta, Join)) for c in cols): + return stream + + if group_by or self._contains_aggregation_function(cols): + grouped = {} + if group_by: + for item in stream: + key = tuple(getattr(item, col.name) for col in group_by) + grouped.setdefault(key, []).append(item) + else: + grouped = { + "_all_": [item for item in stream] + } + + result = [] + for key, group_items in grouped.items(): + row = [] + for col in cols: + value = self._evaluate_column(col, group_items) + row.append(value) + result.append(tuple(row)) + return result + + else: + return ( + tuple(self._evaluate_column(col, [item]) for col in cols) + for item in stream + ) + + def _contains_aggregation_function(self, cols): + for c in cols: + if isinstance(c, Label): + c = c.element + + if isinstance(c, FunctionElement): + if c.name.lower() in ["count", "sum", "min", "max", "avg"]: + return True + + return False + + def _evaluate_column(self, col, items): + """ + Evaluate a column or expression over one or many items. 
+ """ + + if isinstance(col, AnnotatedTable): + return items[0] + + if isinstance(col, Label): + return self._evaluate_column(col.element, items) + + if isinstance(col, AnnotatedColumn): + + if self.tablename == col.table.name: + # Column belongs to the primary ORM model + return getattr(items[0], col.name) + + else: + # Column belongs to a related model + item = items[0] + rel_name = col.table.name # e.g., 'vendors' + # Find matching attribute on the main object + for attr_name in vars(item): + attr = getattr(item, attr_name, None) + if hasattr(attr, "__table__") and attr.__table__.name == rel_name: + return getattr(attr, col.name) + + raise ValueError(f"Could not find related model '{rel_name}' on '{type(item).__name__}'") + + if isinstance(col, FunctionElement): + fn_name = col.name.lower() + col_expr = next(iter(col.clauses)) + values = [getattr(item, col_expr.name) for item in items] + + if fn_name == "count": + return len(values) + elif fn_name == "sum": + return sum(values) + elif fn_name == "min": + return min(values) + elif fn_name == "max": + return max(values) + elif fn_name == "avg": + return sum(values) / len(values) if values else None + else: + raise NotImplementedError(f"Function not supported: {fn_name}") + + if isinstance(col, Case): + for condition_expr, result_expr in col.whens: + condition_value = self._evaluate_expression(condition_expr, items) + if condition_value: + return self._evaluate_expression(result_expr, items) + + # No condition matched; return else_ + return self._evaluate_expression(col.else_, items) + + raise NotImplementedError(f"Column type not handled: {type(col)}") + + def _evaluate_expression(self, expr, items): + """ + Evaluate an expression (which might be a Grouping, BinaryExpression, BindParameter, etc.). 
+ """ + + if isinstance(expr, BindParameter): + return expr.value + + if isinstance(expr, Grouping): + return self._evaluate_expression(expr.element, items) + + if isinstance(expr, BinaryExpression): + left = self._evaluate_expression(expr.left, items) + right = self._evaluate_expression(expr.right, items) + op = expr.operator + return op(left, right) + + if isinstance(expr, AnnotatedColumn): + return self._evaluate_column(expr, items) + + raise NotImplementedError(f"Unsupported expression type: {type(expr)}") + diff --git a/sqlalchemy_memory/base/session.py b/sqlalchemy_memory/base/session.py index c26739f..3675b47 100644 --- a/sqlalchemy_memory/base/session.py +++ b/sqlalchemy_memory/base/session.py @@ -1,18 +1,21 @@ from sqlalchemy.orm import Session -from sqlalchemy.sql.selectable import Select +from sqlalchemy.sql.selectable import Select, SelectLabelStyle from sqlalchemy.sql.dml import Insert, Delete, Update -from sqlalchemy.engine import IteratorResult +from sqlalchemy.engine import IteratorResult, ChunkedIteratorResult from sqlalchemy.engine.cursor import SimpleResultMetaData -from functools import lru_cache +from sqlalchemy.sql.annotation import AnnotatedTable +from functools import partial + +from unittest.mock import MagicMock from .query import MemoryQuery from .pending_changes import PendingChanges from ..logger import logger +from ..helpers.utils import chunk_generator class MemorySession(Session): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self._query_cls = MemoryQuery self._has_pending_merge = False self.store = self.get_bind().dialect._store @@ -45,57 +48,37 @@ def scalar(self, statement, **kwargs): return self.execute(statement, **kwargs).scalar() @staticmethod - @lru_cache(maxsize=256) - def _get_metadata_for_annotated_table(annotated_table): - """ - Build minimal cursor metadata - """ - col_names = [col.name for col in annotated_table._columns] + def _get_metadata_from_columns(columns): return SimpleResultMetaData([ 
- (col_name, None, None, None, None, None, None) - for col_name in col_names + getattr(col, "name", str(col)) + for col in columns ]) - def _handle_select(self, statement: Select, **kwargs): - entities = statement._raw_columns - if len(entities) != 1: - raise Exception("Only single‑entity SELECTs are supported") - # Execute the query - q = MemoryQuery(entities, self) - - # Apply WHERE - for cond in statement._where_criteria: - q = q.filter(cond) + q = MemoryQuery(statement, self) + results = q.iter_items() - # Apply ORDER BY - for clause in statement._order_by_clauses: - q = q.order_by(clause) + metadata = self._get_metadata_from_columns(statement._raw_columns) - # Apply LIMIT / OFFSET - if statement._limit_clause is not None: - q = q.limit(statement._limit_clause.value) - if statement._offset_clause is not None: - q = q.offset(statement._offset_clause.value) + if statement._label_style is SelectLabelStyle.LABEL_STYLE_LEGACY_ORM and all( + isinstance(c, AnnotatedTable) for c in statement._raw_columns + ): + """ + Support for legacy session.query(...) style + """ + it = IteratorResult(metadata, results) + it._real_result = MagicMock(_source_supports_scalars=True) + it._generate_rows = False + return it - objs = q.all() + it = ChunkedIteratorResult(metadata, partial(chunk_generator, results)) - # Wrap each object in a single‑element tuple, so .scalars() yields it - wrapped = ((obj,) for obj in objs) - - metadata = MemorySession._get_metadata_for_annotated_table(entities[0]) - - return IteratorResult(metadata, wrapped) + return it def _handle_delete(self, statement: Delete, **kwargs): - q = MemoryQuery([statement.table], self) - - for cond in statement._where_criteria: - q = q.filter(cond) - - collection = q.all() + collection = MemoryQuery(statement, self).all() for obj in collection: self.delete(obj) @@ -132,10 +115,7 @@ def _handle_insert(self, statement: Insert, params=None, **kwargs): # Handle RETURNING(...) 
if statement._returning: cols = list(statement._returning) - metadata = SimpleResultMetaData([ - (col.name, None, None, None, None, None, None) - for col in cols - ]) + metadata = self._get_metadata_from_columns(cols) rows = [ tuple(getattr(obj, col.name) for col in cols) for obj in instances @@ -147,12 +127,7 @@ def _handle_insert(self, statement: Insert, params=None, **kwargs): return result def _handle_update(self, statement: Update, **kwargs): - q = MemoryQuery([statement.table], self) - - for cond in statement._where_criteria: - q = q.filter(cond) - - collection = q.all() + collection = MemoryQuery(statement, self).all() data = { col.name: bindparam.value diff --git a/sqlalchemy_memory/base/store.py b/sqlalchemy_memory/base/store.py index 85ebcf2..dc265f6 100644 --- a/sqlalchemy_memory/base/store.py +++ b/sqlalchemy_memory/base/store.py @@ -202,11 +202,8 @@ def _apply_column_defaults(self, obj): else: raise Exception(f"Unhandled server_default type: {type(column.server_default)}") - def query_index(self, collection, table_name, attr_name, op, value): - result = self.index_manager.query(collection, table_name, attr_name, op, value) - if result is not None: - logger.debug(f"Reduced '{table_name}' dataset from {len(collection)} items to {len(result)} by using index on '{attr_name}") - return result + def query_index(self, stream, table_name, attr_name, op, value, **kwargs): + return self.index_manager.query(stream, table_name, attr_name, op, value, **kwargs) def count(self, tablename): return len(self.data[tablename]) diff --git a/sqlalchemy_memory/helpers/utils.py b/sqlalchemy_memory/helpers/utils.py new file mode 100644 index 0000000..4e8a97e --- /dev/null +++ b/sqlalchemy_memory/helpers/utils.py @@ -0,0 +1,20 @@ +from itertools import chain + +def _dedup_chain(*streams): + """ + Lazily merge multiple input iterators, yielding unique items only. 
+ + This function performs a streamed union of all provided iterators, + ensuring that each item is yielded at most once while preserving order + across the combined streams. + """ + seen = set() + for item in chain(*streams): + if item not in seen: + seen.add(item) + yield item + + +def chunk_generator(results, *a): + chunk = [(r,) if not isinstance(r, (list, tuple)) else r for r in results] + yield chunk diff --git a/tests/conftest.py b/tests/conftest.py index 7a279de..3d8ba56 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -41,3 +41,10 @@ async def AsyncSessionFactory(): sync_session_class=MemorySession, expire_on_commit=False, ) + +@pytest.fixture +def sqlite_SessionFactory(): + engine = create_engine("sqlite:///:memory:", echo=False) + Base.metadata.create_all(engine) + + yield sessionmaker(engine) \ No newline at end of file diff --git a/tests/models.py b/tests/models.py index 94c7917..c3e9378 100644 --- a/tests/models.py +++ b/tests/models.py @@ -1,6 +1,7 @@ -from sqlalchemy.orm import declarative_base, mapped_column, Mapped -from sqlalchemy import JSON, func, text +from sqlalchemy.orm import declarative_base, mapped_column, Mapped, relationship +from sqlalchemy import JSON, func, text, ForeignKey from datetime import datetime +from typing import List Base = declarative_base() @@ -29,6 +30,21 @@ class Product(Base): def __repr__(self): return f"Product(id={self.id} name={self.name})" + +class Vendor(Base): + __tablename__ = "vendors" + + id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(nullable=False) + + # Relationship back to products + products: Mapped[List["ProductWithIndex"]] = relationship( + back_populates="vendor", cascade="all, delete-orphan" + ) + + def __repr__(self): + return f"Vendor(id={self.id}, name='{self.name}')" + class ProductWithIndex(Base): __tablename__ = "products_with_index" id: Mapped[int] = mapped_column(primary_key=True) @@ -37,5 +53,8 @@ class ProductWithIndex(Base): category: 
Mapped[str] = mapped_column(index=True, nullable=False) price: Mapped[float] = mapped_column(default=True, index=True) + vendor_id: Mapped[int] = mapped_column(ForeignKey("vendors.id"), index=True) + vendor: Mapped["Vendor"] = relationship(back_populates="products") + def __repr__(self): return f"ProductWithIndex(id={self.id} name={self.name})" diff --git a/tests/test_advanced.py b/tests/test_advanced.py index a4309cb..80d402b 100644 --- a/tests/test_advanced.py +++ b/tests/test_advanced.py @@ -1,8 +1,10 @@ -from sqlalchemy import select, func, and_, or_ +from sqlalchemy import select, func, and_, or_, not_, case +from sqlalchemy.orm import joinedload, selectinload from datetime import datetime, date +from sqlalchemy.sql.annotation import AnnotatedTable import pytest -from models import Item, Product +from models import Item, Product, ProductWithIndex, Vendor class TestAdvanced: @pytest.mark.parametrize( @@ -195,7 +197,40 @@ def test_date_filter(self, SessionFactory, operator, value, expected_ids): results = session.execute(stmt).scalars().all() assert {item.id for item in results} == expected_ids - def test_and_or(self, SessionFactory): + @pytest.mark.parametrize("condition,expected_ids", [ + ( + (Product.id > 1) & ((Product.id < 4) | (Product.category == "A")), + {2, 3, 5}, + ), + ( + and_( + Product.id > 1, + or_( + Product.id < 4, + Product.category == "A" + ) + ), + {2, 3, 5}, + ), + ( + not_(Product.category == "A"), + {2, 3, 4} + ), + ( + or_( + and_( + not_(Product.category == "A"), # 2,3,4 + Product.id > 2 # 3,4 + ), + and_( + Product.category == "A", # 1,5 + not_(Product.id == 1) # 2,3,4,5 + ), + ), + {3, 4, 5}, + ), + ]) + def test_and_or_not(self, SessionFactory, condition, expected_ids): with SessionFactory() as session: session.add_all([ Product(id=1, name="foo", category="A"), @@ -208,29 +243,10 @@ def test_and_or(self, SessionFactory): stmt = ( select(Product) - .where( - (Product.id > 1) & ( - (Product.id < 4) | (Product.category == "A") - ) - ) + 
.where(condition) ) results = session.execute(stmt).scalars().all() - assert {item.id for item in results} == {2, 3, 5} - - stmt = ( - select(Product) - .where( - and_( - Product.id > 1, - or_( - Product.id < 4, - Product.category == "A" - ) - ) - ) - ) - results = session.execute(stmt).scalars().all() - assert {item.id for item in results} == {2, 3, 5} + assert {item.id for item in results} == expected_ids def test_session_inception(self, SessionFactory): with SessionFactory() as session1: @@ -240,3 +256,113 @@ def test_session_inception(self, SessionFactory): with SessionFactory() as session2: results = session2.execute(select(Item)).scalars().all() assert len(results) == 1 + + @pytest.mark.parametrize("query", [ + lambda: select(ProductWithIndex), + lambda: select(ProductWithIndex.id, ProductWithIndex.name, ProductWithIndex.category), + + # Join shouldn't affect anything + lambda: select(ProductWithIndex).options(joinedload(ProductWithIndex.vendor)), + lambda: select(ProductWithIndex).options(selectinload(ProductWithIndex.vendor)), + lambda: select( + ProductWithIndex.id, + ProductWithIndex.name, + ProductWithIndex.category, + ).join(ProductWithIndex.vendor) + ]) + def test_select_subset_of_columns(self, SessionFactory, query): + with SessionFactory() as session: + vendor1 = Vendor(id=10, name="First vendor") + vendor2 = Vendor(id=20, name="Second vendor") + + session.add_all([ + vendor1, + vendor2, + ]) + + session.add_all([ + ProductWithIndex(id=1, name="foo", category="A", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=2, name="bar", category="B", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=3, name="foobar", category="B", vendor_id=20, vendor=vendor2), + ]) + session.commit() + + if callable(query): + query = query() + + results = session.execute(query) + + if isinstance(query._raw_columns[0], AnnotatedTable): + results = results.scalars() + + # We get the objects straight back, no column selection + assert { + r.id: (r.name, r.category) + for 
r in results + } == { + 1: ("foo", "A"), + 2: ("bar", "B"), + 3: ("foobar", "B"), + } + + @pytest.mark.parametrize("query, expected", [ + ( + # Labels + select( + ProductWithIndex.id.label("product_id"), + ProductWithIndex.name.label("product_name"), + Vendor.name.label("vendor_name"), + ), + [ + {"product_id": 1, "product_name": "foo", "vendor_name": "First vendor"}, + {"product_id": 2, "product_name": "bar", "vendor_name": "First vendor"}, + {"product_id": 3, "product_name": "foobar", "vendor_name": "Second vendor"}, + ] + ), + + ( + # Simple case() + select( + ProductWithIndex.id, + case( + ( + ProductWithIndex.id >= 3, "High" + ), + ( + ProductWithIndex.id < 2, "Low" + ), + else_="Medium" + ).label("test"), + ), + [ + {"id": 1, "test": "Low"}, + {"id": 2, "test": "Medium"}, + {"id": 3, "test": "High"}, + ] + ), + ]) + def test_select_expressions(self, SessionFactory, query, expected): + with SessionFactory() as session: + vendor1 = Vendor(id=10, name="First vendor") + vendor2 = Vendor(id=20, name="Second vendor") + + session.add_all([ + vendor1, + vendor2, + ]) + + session.add_all([ + ProductWithIndex(id=1, name="foo", category="A", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=2, name="bar", category="B", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=3, name="foobar", category="B", vendor_id=20, vendor=vendor2), + ]) + session.commit() + + results = session.execute(query) + results = list(results) + + assert len(results) == len(expected) + for idx, (result, expected_result) in enumerate(zip(results, expected)): + for k, v in expected_result.items(): + assert hasattr(result, k), f"Expected {k} to be in result, but keys are {result.__dict__.keys()}" + assert getattr(result, k) == v, f"Expected {k} to be == {v} for item #{idx}" diff --git a/tests/test_aggregation.py b/tests/test_aggregation.py new file mode 100644 index 0000000..4d6e0f8 --- /dev/null +++ b/tests/test_aggregation.py @@ -0,0 +1,82 @@ +import pytest + +from sqlalchemy import func, 
select, case + +from models import ProductWithIndex, Vendor + +class TestAggregation: + @pytest.mark.parametrize("query_fn,expected", [ + ( + lambda: select(func.count(ProductWithIndex.price).label("count")), + {"count": 3} + ), + ( + lambda: select( + func.count(ProductWithIndex.id).label("count"), + func.min(ProductWithIndex.id).label("minimum"), + func.max(ProductWithIndex.id).label("maximum"), + func.avg(ProductWithIndex.id).label("avg"), + func.sum(ProductWithIndex.id), + ), + { + "count": 3, + "minimum": 1, + "maximum": 3, + "avg": 2, + "sum": 6, + } + ), + ]) + def test_select_aggr(self, SessionFactory, query_fn, expected): + with SessionFactory() as session: + vendor1 = Vendor(id=10, name="First vendor") + vendor2 = Vendor(id=20, name="Second vendor") + + session.add_all([ + vendor1, + vendor2, + ]) + + session.add_all([ + ProductWithIndex(id=1, name="foo", category="A", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=2, name="bar", category="B", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=3, name="foobar", category="B", vendor_id=20, vendor=vendor2), + ]) + session.commit() + + result = session.execute(query_fn()).mappings().one() + + assert result == expected + + def test_group_by(self, SessionFactory): + with SessionFactory() as session: + session.add_all([ + ProductWithIndex(id=1, name="foo", category="A", vendor_id=10), + ProductWithIndex(id=2, name="bar", category="B", vendor_id=10), + ProductWithIndex(id=3, name="foobar", category="B", vendor_id=20), + ]) + session.commit() + + results = session.execute(select(ProductWithIndex).group_by(ProductWithIndex.vendor_id)) + results = results.scalars().all() + + assert len(results) == 2 + assert results[0].id == 1 + assert results[1].id == 3 + + results = session.execute( + select( + func.count(ProductWithIndex.id), + func.min(ProductWithIndex.id).label("minimum"), + ) + .group_by(ProductWithIndex.vendor_id) + ) + results = list(results) + + assert len(results) == 2 + + assert results[0] == (2, 
1) + assert results[0].minimum == 1 + + assert results[1] == (1, 3) + assert results[1].minimum == 3 diff --git a/tests/test_basic.py b/tests/test_basic.py index 7aa7e9a..30165d7 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,3 +1,4 @@ +import pytest from sqlalchemy import select from models import Item diff --git a/tests/test_comparison.py b/tests/test_comparison.py new file mode 100644 index 0000000..57604d2 --- /dev/null +++ b/tests/test_comparison.py @@ -0,0 +1,82 @@ +import pytest +from sqlalchemy import select +from sqlalchemy.orm import selectinload, joinedload +from collections.abc import Iterable + +from models import ProductWithIndex, Vendor + +def is_iterable(obj): + return isinstance(obj, Iterable) and not isinstance(obj, (str, bytes)) + +class TestComparison: + @pytest.mark.parametrize("query_lambda", [ + lambda s: s.execute(select(ProductWithIndex)), + lambda s: s.execute(select(ProductWithIndex.id, ProductWithIndex.name)), + lambda s: s.query(ProductWithIndex), + lambda s: s.query(ProductWithIndex.id, ProductWithIndex.name), + lambda s: s.execute(select(ProductWithIndex.id, ProductWithIndex.name)).scalars(), + lambda s: s.execute(select(ProductWithIndex.id, ProductWithIndex.name)).scalar(), + lambda s: s.execute(select(ProductWithIndex).options(selectinload(ProductWithIndex.vendor))), + lambda s: s.execute(select(ProductWithIndex).options(joinedload(ProductWithIndex.vendor))), + + lambda s: s.execute(select( + ProductWithIndex, + Vendor, + ).join(ProductWithIndex.vendor)), + + lambda s: s.execute( + select( + ProductWithIndex.id.label("product_id"), + ProductWithIndex.name.label("product_name"), + Vendor.name.label("vendor_name"), + ).join(ProductWithIndex.vendor) + ), + + lambda s: s.execute(select(ProductWithIndex).group_by(ProductWithIndex.category)), + lambda s: s.execute(select(ProductWithIndex.id, ProductWithIndex.name).group_by(ProductWithIndex.category)), + ]) + async def test_select_same_as_sqlite(self, SessionFactory, 
sqlite_SessionFactory, query_lambda): + with sqlite_SessionFactory() as session: + vendor1 = Vendor(id=10, name="First vendor") + session.add_all([ + vendor1, + ProductWithIndex(id=1, name="foo", category="A", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=2, name="foo", category="A", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=3, name="foo", category="B", vendor_id=10, vendor=vendor1), + ]) + session.commit() + + result_sqlite = query_lambda(session) + + _original_type = type(result_sqlite) + if is_iterable(result_sqlite): + result_sqlite = list(result_sqlite) + + with SessionFactory() as session: + vendor1 = Vendor(id=10, name="First vendor") + session.add_all([ + vendor1, + ProductWithIndex(id=1, name="foo", category="A", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=2, name="foo", category="A", vendor_id=10, vendor=vendor1), + ProductWithIndex(id=3, name="foo", category="B", vendor_id=10, vendor=vendor1), + ]) + session.commit() + + result = query_lambda(session) + _type = type(result) + + if is_iterable(result): + result = list(result) + + assert _type == _original_type + + if is_iterable(result): + assert len(result) == len(result_sqlite) + assert len(result) > 0 + + for r1, r2 in zip(result, result_sqlite): + assert type(r1) == type(r2) + + else: + assert result == result_sqlite + diff --git a/tests/test_crud.py b/tests/test_crud.py index 9a37e35..599cfa8 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -31,6 +31,26 @@ def test_insert(self, SessionFactory): assert items[2].id == 3 assert items[2].name == "fba" + def test_insert_returning(self, sqlite_SessionFactory, SessionFactory): + with sqlite_SessionFactory() as session: + stmt = insert(Item).values(name="foo").returning(Item.id, Item.name) + result = session.execute(stmt) + returned = result.first() + + assert returned is not None + assert returned.id > 0 + assert returned.name == "foo" + + # The row exists in the transaction, but not in the DB + item = session.get(Item, 
returned.id) + assert item is not None + + session.rollback() + + # Now the row is gone + item = session.get(Item, returned.id) + assert item is None + def test_update(self, SessionFactory): with SessionFactory() as session: diff --git a/tests/test_indexes.py b/tests/test_indexes.py index 17072d1..f6c5c67 100644 --- a/tests/test_indexes.py +++ b/tests/test_indexes.py @@ -74,6 +74,41 @@ def test_range_index(self, query_kwargs, expected_ids): results = index.query("products", "price_index", **query_kwargs) assert {r.id for r in results} == expected_ids + def test_index_manager(self): + mgr = IndexManager() + mgr.table_indexes = { + "products": { + "id": ["id"], + "price_index": ["price"] + } + } + + objs = [ + MagicMock(id=1, price=10, __tablename__="products"), + MagicMock(id=2, price=30, __tablename__="products"), + MagicMock(id=3, price=20, __tablename__="products"), + ] + + for obj in objs: + mgr.on_insert(obj) + + for full in [True, False]: + result = list(mgr.query(objs, "products", "id", operators.in_op, [2, 3], collection_is_full_table=full)) + assert len(result) == 2 + assert set(r.id for r in result) == {2, 3} + + for full in [True, False]: + result = list(mgr.query(objs, "products", "id", operators.notin_op, [1, 2], collection_is_full_table=full)) + assert len(result) == 1 + assert set(r.id for r in result) == {3} + + for full in [True, False]: + result = list(mgr.query(objs, "products", "id", operators.gt, 1, collection_is_full_table=full)) + assert len(result) == 2 + assert set(r.id for r in result) == {2, 3} + + + @pytest.mark.parametrize("query_kwargs,expected_ids", [ # All ES assets ({"gte": ("ES", -float("inf")), "lte": ("ES", float("inf"))}, {1, 2}), @@ -169,39 +204,39 @@ def test_synchronized_indexes(self, SessionFactory): store = session.store collection = store.data[tablename] - assert len(store.query_index(collection, tablename, "active", operators.eq, True)) == 2 - assert len(store.query_index(collection, tablename, "active", operators.ne, 
True)) == 0 - assert len(store.query_index(collection, tablename, "active", operators.eq, False)) == 0 - assert len(store.query_index(collection, tablename, "active", operators.ne, False)) == 2 + assert len(list(store.query_index(collection, tablename, "active", operators.eq, True))) == 2 + assert len(list(store.query_index(collection, tablename, "active", operators.ne, True))) == 0 + assert len(list(store.query_index(collection, tablename, "active", operators.eq, False))) == 0 + assert len(list(store.query_index(collection, tablename, "active", operators.ne, False))) == 2 - assert len(store.query_index(collection, tablename, "category", operators.eq, "A")) == 1 - assert len(store.query_index(collection, tablename, "category", operators.eq, "B")) == 1 - assert len(store.query_index(collection, tablename, "category", operators.eq, "Z")) == 0 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "A"))) == 1 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "B"))) == 1 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "Z"))) == 0 # Assert nothing was changed on rollback item = session.get(ProductWithIndex, 2) item.category = "Z" session.rollback() - assert len(store.query_index(collection, tablename, "category", operators.eq, "A")) == 1 - assert len(store.query_index(collection, tablename, "category", operators.eq, "B")) == 1 - assert len(store.query_index(collection, tablename, "category", operators.eq, "Z")) == 0 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "A"))) == 1 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "B"))) == 1 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "Z"))) == 0 # Assert index was synchronized after update item = session.get(ProductWithIndex, 2) item.category = "Z" session.commit() - assert len(store.query_index(collection, 
tablename, "category", operators.eq, "A")) == 1 - assert len(store.query_index(collection, tablename, "category", operators.eq, "B")) == 0 - assert len(store.query_index(collection, tablename, "category", operators.eq, "Z")) == 1 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "A"))) == 1 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "B"))) == 0 + assert len(list(store.query_index(collection, tablename, "category", operators.eq, "Z"))) == 1 # Assert nothing was changed on rollback session.delete(collection[0]) session.rollback() - assert len(store.query_index(collection, tablename, "active", operators.eq, True)) == 2 + assert len(list(store.query_index(collection, tablename, "active", operators.eq, True))) == 2 # Assert index was synchronized after deletion session.delete(collection[0]) session.commit() - assert len(store.query_index(collection, tablename, "active", operators.eq, True)) == 1 - assert len(store.query_index(collection, tablename, "active", operators.eq, False)) == 0 + assert len(list(store.query_index(collection, tablename, "active", operators.eq, True))) == 1 + assert len(list(store.query_index(collection, tablename, "active", operators.eq, False))) == 0 diff --git a/tests/test_select.py b/tests/test_select.py deleted file mode 100644 index 275092b..0000000 --- a/tests/test_select.py +++ /dev/null @@ -1,133 +0,0 @@ -""" - -## 1. 
Core Queries (SQLAlchemy 2.0 Core) - -```python -from sqlalchemy import ( - create_engine, - MetaData, - Table, Column, Integer, String, - select, insert, update, delete, -) -# 1) Setup -engine = create_engine("sqlite:///:memory:", future=True) -metadata = MetaData() - -users = Table( - "users", metadata, - Column("id", Integer, primary_key=True), - Column("name", String), - Column("age", Integer), -) - -metadata.create_all(engine) - -# 2) INSERT -with engine.begin() as conn: - conn.execute( - insert(users), - [ - {"name": "alice", "age": 30}, - {"name": "bob", "age": 25}, - ] - ) - -# 3) SELECT -with engine.connect() as conn: - stmt = select(users).where(users.c.age > 20) - result = conn.execute(stmt) - rows = result.all() # list of Row objects - for row in rows: - print(row.id, row.name, row.age) - -# 4) UPDATE -with engine.begin() as conn: - stmt = ( - update(users) - .where(users.c.name == "alice") - .values(age=31) - .returning(users.c.id, users.c.age) - ) - updated = conn.execute(stmt).all() - -# 5) DELETE -with engine.begin() as conn: - stmt = delete(users).where(users.c.age < 28) - result = conn.execute(stmt) - print("deleted rows:", result.rowcount) - - - - - - - - - -from sqlalchemy import select -from sqlalchemy.orm import ( - declarative_base, Mapped, mapped_column, - sessionmaker, -) - -Base = declarative_base() - -class Item(Base): - __tablename__ = "items" - id: Mapped[int] = mapped_column(primary_key=True) - x: Mapped[int] - y: Mapped[int | None] - -engine = create_engine("sqlite:///:memory:", future=True) -Session = sessionmaker(engine, future=True) - -Base.metadata.create_all(engine) -session = Session() - -# add some sample data -session.add_all([ - Item(x=1, y=None), - Item(x=1, y=10), - Item(x=2, y=None), -]) -session.commit() - -# Filter: x == 1 AND y IS NULL -stmt = select(Item).where( - Item.x == 1, - Item.y.is_(None) -) -items = session.scalars(stmt).all() -# -> returns all Item instances matching those conditions - - - -from sqlalchemy 
import func, select - -# Core style: -stmt_core = select( - func.max(users.c.age).label("max_age"), - func.min(users.c.age).label("min_age"), -) -with engine.connect() as conn: - max_age, min_age = conn.execute(stmt_core).one() - -# ORM style: -stmt_orm = select( - func.max(Item.x).label("max_x"), - func.min(Item.y).label("min_y"), -) -max_x, min_y = session.execute(stmt_orm).one() - - - - -from sqlalchemy import func -stmt = select(Item).where( - func.json_extract(Item.data, "$.key") == "value" -) - - - -stmt = select(Item).order_by(Item.id).limit(10).offset(20) -""" \ No newline at end of file