diff --git a/sqlspec/adapters/asyncpg/__init__.py b/sqlspec/adapters/asyncpg/__init__.py index bcfad5bb..a1bbaaa4 100644 --- a/sqlspec/adapters/asyncpg/__init__.py +++ b/sqlspec/adapters/asyncpg/__init__.py @@ -1,5 +1,7 @@ """AsyncPG adapter for SQLSpec.""" +import sqlspec.adapters.asyncpg.dialect as dialect # noqa: F401 + from sqlspec.adapters.asyncpg._typing import AsyncpgConnection, AsyncpgPool, AsyncpgPreparedStatement from sqlspec.adapters.asyncpg.config import AsyncpgConfig, AsyncpgConnectionConfig, AsyncpgPoolConfig from sqlspec.adapters.asyncpg.core import default_statement_config diff --git a/sqlspec/adapters/asyncpg/config.py b/sqlspec/adapters/asyncpg/config.py index 80416835..2c199c89 100644 --- a/sqlspec/adapters/asyncpg/config.py +++ b/sqlspec/adapters/asyncpg/config.py @@ -102,6 +102,11 @@ class AsyncpgDriverFeatures(TypedDict): Defaults to True when pgvector-python is installed. Provides automatic conversion between Python objects and PostgreSQL vector types. Enables vector similarity operations and index support. + enable_paradedb: Enable ParadeDB (pg_search) extension detection. + When enabled and the pg_search extension is detected, the SQL dialect + switches to "paradedb" which supports search operators (@@@, &&&, etc.) + and inherits all pgvector distance operators. + Defaults to True. Independent of enable_pgvector. enable_cloud_sql: Enable Google Cloud SQL connector integration. Requires cloud-sql-python-connector package. Defaults to False (explicit opt-in required). @@ -146,6 +151,7 @@ class AsyncpgDriverFeatures(TypedDict): json_deserializer: NotRequired["Callable[[str], Any]"] enable_json_codecs: NotRequired[bool] enable_pgvector: NotRequired[bool] + enable_paradedb: NotRequired[bool] enable_cloud_sql: NotRequired[bool] cloud_sql_instance: NotRequired[str] cloud_sql_enable_iam_auth: NotRequired[bool] @@ -328,6 +334,7 @@ def __init__( self._cloud_sql_connector: Any | None = None self._alloydb_connector: Any | None = None self._pgvector_available: bool | None = None + self._paradedb_available: bool | None = None self._validate_connector_config() @@ -435,7 +442,43 @@ async def _create_pool(self) -> "Pool[Record]": config.setdefault("init", self._init_connection) - return await asyncpg_create_pool(**config) + pool = await asyncpg_create_pool(**config) + await self._detect_extensions(pool) + return pool + + async def _detect_extensions(self, pool: "Pool[Record]") -> None: + """Detect database extensions and update dialect accordingly. + + Args: + pool: Connection pool to acquire a connection from. + """ + extensions = [ + name + for name, enabled in [ + ("vector", self.driver_features.get("enable_pgvector", False)), + ("pg_search", self.driver_features.get("enable_paradedb", False)), + ] + if enabled + ] + if not extensions: + return + + connection = await pool.acquire() + try: + results = await connection.fetch( + "SELECT extname FROM pg_extension WHERE extname = ANY($1::text[])", + extensions, + ) + detected = {r["extname"] for r in results} + self._pgvector_available = "vector" in detected + self._paradedb_available = "pg_search" in detected + except Exception: + self._pgvector_available = False + self._paradedb_available = False + finally: + await pool.release(connection) + + self._update_dialect_for_extensions() async def _init_connection(self, connection: "AsyncpgConnection") -> None: """Initialize connection with JSON codecs, pgvector support, and user callback. @@ -450,22 +493,27 @@ async def _init_connection(self, connection: "AsyncpgConnection") -> None: decoder=self.driver_features.get("json_deserializer", from_json), ) - if self.driver_features.get("enable_pgvector", False): - if self._pgvector_available is None: - try: - result = await connection.fetchval("SELECT 1 FROM pg_extension WHERE extname = 'vector'") - self._pgvector_available = bool(result) - except Exception: - # If we can't query extensions, assume false to be safe and avoid errors - self._pgvector_available = False - - if self._pgvector_available: - await register_pgvector_support(connection) + if self._pgvector_available: + await register_pgvector_support(connection) # Call user-provided callback after internal setup if self._user_connection_hook is not None: await self._user_connection_hook(connection) + def _update_dialect_for_extensions(self) -> None: + """Update statement_config dialect based on detected extensions. + + Priority: paradedb > pgvector > postgres (default). + """ + current_dialect = getattr(self.statement_config, "dialect", "postgres") + if current_dialect != "postgres": + return + + if self._paradedb_available: + self.statement_config = self.statement_config.replace(dialect="paradedb") + elif self._pgvector_available: + self.statement_config = self.statement_config.replace(dialect="pgvector") + async def _close_pool(self) -> None: """Close the actual async connection pool and cleanup connectors.""" if self.connection_instance: diff --git a/sqlspec/adapters/asyncpg/core.py b/sqlspec/adapters/asyncpg/core.py index a5520c88..88927656 100644 --- a/sqlspec/adapters/asyncpg/core.py +++ b/sqlspec/adapters/asyncpg/core.py @@ -242,6 +242,7 @@ def apply_driver_features( deserializer = processed_features.setdefault("json_deserializer", from_json) processed_features.setdefault("enable_json_codecs", True) processed_features.setdefault("enable_pgvector", PGVECTOR_INSTALLED) + processed_features.setdefault("enable_paradedb", True) processed_features.setdefault("enable_cloud_sql", False) processed_features.setdefault("enable_alloydb", False) diff --git a/sqlspec/adapters/asyncpg/dialect/__init__.py b/sqlspec/adapters/asyncpg/dialect/__init__.py new file mode 100644 index 00000000..2cda30ff --- /dev/null +++ b/sqlspec/adapters/asyncpg/dialect/__init__.py @@ -0,0 +1,12 @@ +"""Asyncpg dialect submodule.""" + +from sqlglot.dialects.dialect import Dialect + +from sqlspec.adapters.asyncpg.dialect._paradedb import ParadeDB +from sqlspec.adapters.asyncpg.dialect._pgvector import PGVector + +# Register dialects with sqlglot +Dialect.classes["pgvector"] = PGVector +Dialect.classes["paradedb"] = ParadeDB + +__all__ = ("PGVector", "ParadeDB") diff --git a/sqlspec/adapters/asyncpg/dialect/_paradedb.py b/sqlspec/adapters/asyncpg/dialect/_paradedb.py new file mode 100644 index 00000000..893b4126 --- /dev/null +++ b/sqlspec/adapters/asyncpg/dialect/_paradedb.py @@ -0,0 +1,78 @@ +"""ParadeDB dialect extending PGVector with pg_search BM25/search operators. + +Adds support for ParadeDB search operators: +- @@@ : BM25 full-text search +- &&& : Boolean AND search +- ||| : Boolean OR search +- === : Exact term match +- ### : Score/rank retrieval +- ## : Snippet/highlight retrieval +- ##> : Snippet/highlight with options + +Also inherits pgvector distance operators from PGVector: +- <-> : L2 (Euclidean) distance +- <#> : Negative inner product +- <=> : Cosine distance +- <+> : L1 (Taxicab/Manhattan) distance +- <~> : Hamming distance (binary vectors) +- <%> : Jaccard distance (binary vectors) +""" + +from __future__ import annotations + +from sqlglot import exp +from sqlglot.tokens import TokenType + +from sqlspec.adapters.asyncpg.dialect._pgvector import PGVector, PGVectorGenerator, PGVectorParser, PGVectorTokenizer + +__all__ = ("ParadeDB",) + +_PARADEDB_SEARCH_TOKEN = TokenType.DAT + + +class SearchOperator(exp.Binary): + """ParadeDB search operation that preserves the original operator.""" + + arg_types = {"this": True, "expression": True, "operator": True} + + +class ParadeDBTokenizer(PGVectorTokenizer): + """Tokenizer with ParadeDB search operators and pgvector distance operators.""" + + KEYWORDS = { + **PGVectorTokenizer.KEYWORDS, + "@@@": _PARADEDB_SEARCH_TOKEN, + "&&&": _PARADEDB_SEARCH_TOKEN, + "|||": _PARADEDB_SEARCH_TOKEN, + "===": _PARADEDB_SEARCH_TOKEN, + "###": _PARADEDB_SEARCH_TOKEN, + "##": _PARADEDB_SEARCH_TOKEN, + "##>": _PARADEDB_SEARCH_TOKEN, + } + + +class ParadeDBParser(PGVectorParser): + """Parser with ParadeDB search operators and pgvector distance operators.""" + + FACTOR = { + **PGVectorParser.FACTOR, + _PARADEDB_SEARCH_TOKEN: SearchOperator, + } + + +class ParadeDBGenerator(PGVectorGenerator): + """Generator that renders ParadeDB search operators and pgvector distance operators.""" + + def searchoperator_sql(self, expression: SearchOperator) -> str: + op = expression.args.get("operator", "@@@") + left = self.sql(expression, "this") + right = self.sql(expression, "expression") + return f"{left} {op} {right}" + + +class ParadeDB(PGVector): + """ParadeDB dialect with pg_search and pgvector extension support.""" + + Tokenizer = ParadeDBTokenizer + Parser = ParadeDBParser + Generator = ParadeDBGenerator diff --git a/sqlspec/adapters/asyncpg/dialect/_pgvector.py b/sqlspec/adapters/asyncpg/dialect/_pgvector.py new file mode 100644 index 00000000..272ee705 --- /dev/null +++ b/sqlspec/adapters/asyncpg/dialect/_pgvector.py @@ -0,0 +1,99 @@ +"""PGVector dialect extending Postgres with vector distance operators. + +Adds support for pgvector distance operators: +- <-> : L2 (Euclidean) distance (already in base Postgres) +- <#> : Negative inner product +- <=> : Cosine distance +- <+> : L1 (Taxicab/Manhattan) distance +- <~> : Hamming distance (binary vectors) +- <%> : Jaccard distance (binary vectors) +""" + +from __future__ import annotations + +import typing as t + +from sqlglot import exp +from sqlglot.dialects.postgres import Postgres +from sqlglot.tokens import TokenType + +__all__ = ("PGVector",) + +# Use a single unused token type for all pgvector distance operators. +# The actual operator string is captured during parsing and stored in the expression. +# SQLGlot is not going to add extension operators, even as unused tokens, so this allows us +# to work around the limitation: https://github.com/tobymao/sqlglot/issues/6949 +_PGVECTOR_DISTANCE_TOKEN = TokenType.CARET_AT + + +class VectorDistance(exp.Binary): + """Vector distance operation that preserves the original operator.""" + + arg_types = {"this": True, "expression": True, "operator": True} + + +class PGVectorTokenizer(Postgres.Tokenizer): + """Tokenizer with pgvector distance operators.""" + + KEYWORDS = { + **Postgres.Tokenizer.KEYWORDS, + "<#>": _PGVECTOR_DISTANCE_TOKEN, + "<=>": _PGVECTOR_DISTANCE_TOKEN, + "<+>": _PGVECTOR_DISTANCE_TOKEN, + "<~>": _PGVECTOR_DISTANCE_TOKEN, + "<%>": _PGVECTOR_DISTANCE_TOKEN, + } + + +class PGVectorParser(Postgres.Parser): + """Parser that captures the original operator string for pgvector operations.""" + + FACTOR = { + **Postgres.Parser.FACTOR, + _PGVECTOR_DISTANCE_TOKEN: VectorDistance, + } + + def _parse_factor(self) -> t.Optional[exp.Expression]: + parse_method = self._parse_exponent if self.EXPONENT else self._parse_unary + this = self._parse_at_time_zone(parse_method()) + + while self._match_set(self.FACTOR): + klass = self.FACTOR[self._prev.token_type] + comments = self._prev_comments + operator_text = self._prev.text + expression = parse_method() + + if not expression and klass is exp.IntDiv and self._prev.text.isalpha(): + self._retreat(self._index - 1) + return this + + if "operator" in klass.arg_types: + this = self.expression( + klass, this=this, comments=comments, expression=expression, operator=operator_text + ) + else: + this = self.expression(klass, this=this, comments=comments, expression=expression) + + if isinstance(this, exp.Div): + this.set("typed", self.dialect.TYPED_DIVISION) + this.set("safe", self.dialect.SAFE_DIVISION) + + return this + + +class PGVectorGenerator(Postgres.Generator): + """Generator that renders pgvector distance operators.""" + + def vectordistance_sql(self, expression: VectorDistance) -> str: + op = expression.args.get("operator", "<->") + left = self.sql(expression, "this") + right = self.sql(expression, "expression") + return f"{left} {op} {right}" + + +class PGVector(Postgres): + """PostgreSQL dialect with pgvector extension support.""" + + Tokenizer = PGVectorTokenizer + Parser = PGVectorParser + Generator = PGVectorGenerator diff --git a/tests/unit/adapters/test_asyncpg/test_cloud_connectors.py b/tests/unit/adapters/test_asyncpg/test_cloud_connectors.py index 89e0ac2e..4d14a8d2 100644 --- a/tests/unit/adapters/test_asyncpg/test_cloud_connectors.py +++ b/tests/unit/adapters/test_asyncpg/test_cloud_connectors.py @@ -16,7 +16,8 @@ def disable_connectors_by_default(): """Disable both connectors by default for clean test state.""" with patch("sqlspec.adapters.asyncpg.config.CLOUD_SQL_CONNECTOR_INSTALLED", False): with patch("sqlspec.adapters.asyncpg.config.ALLOYDB_CONNECTOR_INSTALLED", False): - yield + with patch.object(AsyncpgConfig, "_detect_extensions", new_callable=AsyncMock): + yield @pytest.fixture diff --git a/tests/unit/dialects/test_paradedb.py b/tests/unit/dialects/test_paradedb.py new file mode 100644 index 00000000..2e6be63b --- /dev/null +++ b/tests/unit/dialects/test_paradedb.py @@ -0,0 +1,98 @@ +"""Dialect unit tests for the ParadeDB (PostgreSQL + pgvector + pg_search) dialect.""" + +from sqlglot import parse_one + +import sqlspec.adapters.asyncpg.dialect # noqa: F401 + + +def _render(sql: str) -> str: + return parse_one(sql, dialect="paradedb").sql(dialect="paradedb") + + +def test_bm25_search_operator() -> None: + sql = "SELECT * FROM mock_items WHERE description @@@ 'shoes'" + rendered = _render(sql) + assert "@@@" in rendered + + +def test_match_conjunction_operator() -> None: + sql = "SELECT * FROM mock_items WHERE description &&& 'running shoes'" + rendered = _render(sql) + assert "&&&" in rendered + + +def test_match_disjunction_operator() -> None: + sql = "SELECT * FROM mock_items WHERE description ||| 'shoes'" + rendered = _render(sql) + assert "|||" in rendered + + +def test_term_query_operator() -> None: + sql = "SELECT * FROM mock_items WHERE category === 'footwear'" + rendered = _render(sql) + assert "===" in rendered + + +def test_phrase_query_operator() -> None: + sql = "SELECT * FROM mock_items WHERE description ### 'running shoes'" + rendered = _render(sql) + assert "###" in rendered + + +def test_proximity_operator() -> None: + sql = "SELECT description, rating, category FROM mock_items WHERE description @@@ ('sleek' ## 1 ## 'shoes')" + rendered = _render(sql) + assert "##" in rendered + + +def test_directional_proximity_operator() -> None: + sql = "SELECT description, rating, category FROM mock_items WHERE description @@@ ('sleek' ##> 1 ##> 'shoes')" + rendered = _render(sql) + assert "##>" in rendered + + +def test_snippet_with_named_args() -> None: + sql = ( + "SELECT id, pdb.snippet(description, start_tag => '', end_tag => '') " + "FROM mock_items WHERE description ||| 'shoes' LIMIT 5" + ) + rendered = _render(sql) + assert "|||" in rendered + assert "pdb.snippet" in rendered + + +def test_pgvector_operators_still_work() -> None: + sql = "SELECT embedding <=> '[1,2,3]' AS cosine_dist, embedding <#> '[1,2,3]' AS inner_prod FROM items" + rendered = _render(sql) + assert "<=>" in rendered + assert "<#>" in rendered + + +def test_search_in_where_clause() -> None: + sql = "SELECT * FROM mock_items WHERE description @@@ 'shoes' AND active = TRUE" + rendered = _render(sql) + assert "@@@" in rendered + assert "WHERE" in rendered + + +def test_multiple_search_operators() -> None: + sql = ( + "SELECT description ## 'query' AS snippet, description ### 'query' AS score " + "FROM mock_items WHERE description @@@ 'query'" + ) + rendered = _render(sql) + assert "##" in rendered + assert "###" in rendered + assert "@@@" in rendered + + +def test_fuzzy_cast() -> None: + sql = "SELECT * FROM mock_items WHERE description @@@ 'runing shose'::pdb.fuzzy(2)" + rendered = _render(sql) + assert "@@@" in rendered + + +def test_prox_regex() -> None: + sql = "SELECT * FROM mock_items WHERE description @@@ pdb.prox_regex('sho.*', 2, 'run.*')" + rendered = _render(sql) + assert "@@@" in rendered diff --git a/tests/unit/dialects/test_pgvector.py b/tests/unit/dialects/test_pgvector.py new file mode 100644 index 00000000..5b5ad989 --- /dev/null +++ b/tests/unit/dialects/test_pgvector.py @@ -0,0 +1,60 @@ +"""Dialect unit tests for the PGVector (PostgreSQL + pgvector) dialect.""" + +from sqlglot import parse_one + +import sqlspec.adapters.asyncpg.dialect # noqa: F401 + + +def _render(sql: str) -> str: + return parse_one(sql, dialect="pgvector").sql(dialect="pgvector") + + +def test_cosine_distance_operator() -> None: + sql = "SELECT embedding <=> '[1,2,3]' FROM items" + rendered = _render(sql) + assert "<=>" in rendered + + +def test_negative_inner_product_operator() -> None: + sql = "SELECT embedding <#> '[1,2,3]' FROM items" + rendered = _render(sql) + assert "<#>" in rendered + + +def test_l1_distance_operator() -> None: + sql = "SELECT embedding <+> '[1,2,3]' FROM items" + rendered = _render(sql) + assert "<+>" in rendered + + +def test_hamming_distance_operator() -> None: + sql = "SELECT embedding <~> '[1,0,1]' FROM items" + rendered = _render(sql) + assert "<~>" in rendered + + +def test_jaccard_distance_operator() -> None: + sql = "SELECT embedding <%> '[1,0,1]' FROM items" + rendered = _render(sql) + assert "<%>" in rendered + + +def test_order_by_cosine_distance() -> None: + sql = "SELECT * FROM items ORDER BY embedding <=> '[1,2,3]'" + rendered = _render(sql) + assert "ORDER BY" in rendered + assert "<=>" in rendered + + +def test_distance_in_where_clause() -> None: + sql = "SELECT * FROM items WHERE embedding <=> '[1,2,3]' < 0.5" + rendered = _render(sql) + assert "<=>" in rendered + assert "WHERE" in rendered + + +def test_multiple_distance_operators() -> None: + sql = "SELECT embedding <=> '[1,2,3]' AS cosine_dist, embedding <#> '[1,2,3]' AS inner_prod FROM items" + rendered = _render(sql) + assert "<=>" in rendered + assert "<#>" in rendered