2 changes: 2 additions & 0 deletions sqlspec/adapters/asyncpg/__init__.py
@@ -1,5 +1,7 @@
"""AsyncPG adapter for SQLSpec."""

import sqlspec.adapters.asyncpg.dialect as dialect # noqa: F401

from sqlspec.adapters.asyncpg._typing import AsyncpgConnection, AsyncpgPool, AsyncpgPreparedStatement
from sqlspec.adapters.asyncpg.config import AsyncpgConfig, AsyncpgConnectionConfig, AsyncpgPoolConfig
from sqlspec.adapters.asyncpg.core import default_statement_config
72 changes: 60 additions & 12 deletions sqlspec/adapters/asyncpg/config.py
@@ -102,6 +102,11 @@ class AsyncpgDriverFeatures(TypedDict):
Defaults to True when pgvector-python is installed.
Provides automatic conversion between Python objects and PostgreSQL vector types.
Enables vector similarity operations and index support.
enable_paradedb: Enable ParadeDB (pg_search) extension detection.
When enabled and the pg_search extension is detected, the SQL dialect
switches to "paradedb", which supports the search operators (@@@, &&&, etc.)
and inherits all pgvector distance operators.
Defaults to True and is independent of enable_pgvector.
enable_cloud_sql: Enable Google Cloud SQL connector integration.
Requires cloud-sql-python-connector package.
Defaults to False (explicit opt-in required).
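
For illustration, a minimal configuration sketch that opts into both features. The `pool_config`/`driver_features` keyword shape shown here is assumed from the surrounding config classes and may differ slightly in practice:

```python
from sqlspec.adapters.asyncpg import AsyncpgConfig

# Sketch only: pool_config keys mirror asyncpg.create_pool(); the DSN is a placeholder.
config = AsyncpgConfig(
    pool_config={"dsn": "postgresql://localhost/app"},
    driver_features={
        "enable_pgvector": True,  # register vector codecs when the extension is installed
        "enable_paradedb": True,  # switch the dialect to "paradedb" when pg_search is detected
    },
)
```
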
@@ -146,6 +151,7 @@ class AsyncpgDriverFeatures(TypedDict):
json_deserializer: NotRequired["Callable[[str], Any]"]
enable_json_codecs: NotRequired[bool]
enable_pgvector: NotRequired[bool]
enable_paradedb: NotRequired[bool]
enable_cloud_sql: NotRequired[bool]
cloud_sql_instance: NotRequired[str]
cloud_sql_enable_iam_auth: NotRequired[bool]
@@ -328,6 +334,7 @@ def __init__(
self._cloud_sql_connector: Any | None = None
self._alloydb_connector: Any | None = None
self._pgvector_available: bool | None = None
self._paradedb_available: bool | None = None

self._validate_connector_config()

@@ -435,7 +442,43 @@ async def _create_pool(self) -> "Pool[Record]":

config.setdefault("init", self._init_connection)

return await asyncpg_create_pool(**config)
pool = await asyncpg_create_pool(**config)
await self._detect_extensions(pool)
return pool

async def _detect_extensions(self, pool: "Pool[Record]") -> None:
"""Detect database extensions and update dialect accordingly.

Args:
pool: Connection pool to acquire a connection from.
"""
extensions = [
name
for name, enabled in [
("vector", self.driver_features.get("enable_pgvector", False)),
("pg_search", self.driver_features.get("enable_paradedb", False)),
]
if enabled
]
if not extensions:
return

connection = await pool.acquire()
try:
results = await connection.fetch(
"SELECT extname FROM pg_extension WHERE extname = ANY($1::text[])",
extensions,
)
detected = {r["extname"] for r in results}
self._pgvector_available = "vector" in detected
self._paradedb_available = "pg_search" in detected
except Exception:
self._pgvector_available = False
self._paradedb_available = False
finally:
await pool.release(connection)

self._update_dialect_for_extensions()
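
For reference, the detection query above can be exercised directly with plain asyncpg; this standalone sketch (placeholder DSN) mirrors what `_detect_extensions` runs against the pool:

```python
import asyncio

import asyncpg


async def main() -> None:
    # Placeholder DSN; the query and extension names mirror _detect_extensions above.
    conn = await asyncpg.connect("postgresql://localhost/postgres")
    try:
        rows = await conn.fetch(
            "SELECT extname FROM pg_extension WHERE extname = ANY($1::text[])",
            ["vector", "pg_search"],
        )
        print({r["extname"] for r in rows})  # e.g. {'vector'} when only pgvector is installed
    finally:
        await conn.close()


asyncio.run(main())
```
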

async def _init_connection(self, connection: "AsyncpgConnection") -> None:
"""Initialize connection with JSON codecs, pgvector support, and user callback.
@@ -450,22 +493,27 @@ async def _init_connection(self, connection: "AsyncpgConnection") -> None:
decoder=self.driver_features.get("json_deserializer", from_json),
)

if self.driver_features.get("enable_pgvector", False):
if self._pgvector_available is None:
try:
result = await connection.fetchval("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
self._pgvector_available = bool(result)
except Exception:
# If we can't query extensions, assume false to be safe and avoid errors
self._pgvector_available = False

if self._pgvector_available:
await register_pgvector_support(connection)
if self._pgvector_available:
await register_pgvector_support(connection)

# Call user-provided callback after internal setup
if self._user_connection_hook is not None:
await self._user_connection_hook(connection)

def _update_dialect_for_extensions(self) -> None:
"""Update statement_config dialect based on detected extensions.

Priority: paradedb > pgvector > postgres (default).
"""
current_dialect = getattr(self.statement_config, "dialect", "postgres")
if current_dialect != "postgres":
return

if self._paradedb_available:
self.statement_config = self.statement_config.replace(dialect="paradedb")
elif self._pgvector_available:
self.statement_config = self.statement_config.replace(dialect="pgvector")

async def _close_pool(self) -> None:
"""Close the actual async connection pool and cleanup connectors."""
if self.connection_instance:
1 change: 1 addition & 0 deletions sqlspec/adapters/asyncpg/core.py
@@ -242,6 +242,7 @@ def apply_driver_features(
deserializer = processed_features.setdefault("json_deserializer", from_json)
processed_features.setdefault("enable_json_codecs", True)
processed_features.setdefault("enable_pgvector", PGVECTOR_INSTALLED)
processed_features.setdefault("enable_paradedb", True)
processed_features.setdefault("enable_cloud_sql", False)
processed_features.setdefault("enable_alloydb", False)

12 changes: 12 additions & 0 deletions sqlspec/adapters/asyncpg/dialect/__init__.py
@@ -0,0 +1,12 @@
"""Asyncpg dialect submodule."""

from sqlglot.dialects.dialect import Dialect

from sqlspec.adapters.asyncpg.dialect._paradedb import ParadeDB
from sqlspec.adapters.asyncpg.dialect._pgvector import PGVector

# Register dialects with sqlglot
Dialect.classes["pgvector"] = PGVector
Dialect.classes["paradedb"] = ParadeDB

__all__ = ("PGVector", "ParadeDB")
78 changes: 78 additions & 0 deletions sqlspec/adapters/asyncpg/dialect/_paradedb.py
@@ -0,0 +1,78 @@
"""ParadeDB dialect extending PGVector with pg_search BM25/search operators.

Adds support for ParadeDB search operators:
- @@@ : BM25 full-text search
- &&& : Boolean AND search
- ||| : Boolean OR search
- === : Exact term match
- ### : Score/rank retrieval
- ## : Snippet/highlight retrieval
- ##> : Snippet/highlight with options

Also inherits pgvector distance operators from PGVector:
- <-> : L2 (Euclidean) distance
- <#> : Negative inner product
- <=> : Cosine distance
- <+> : L1 (Taxicab/Manhattan) distance
- <~> : Hamming distance (binary vectors)
- <%> : Jaccard distance (binary vectors)
"""

from __future__ import annotations

from sqlglot import exp
from sqlglot.tokens import TokenType

from sqlspec.adapters.asyncpg.dialect._pgvector import PGVector, PGVectorGenerator, PGVectorParser, PGVectorTokenizer

__all__ = ("ParadeDB",)

_PARADEDB_SEARCH_TOKEN = TokenType.DAT


class SearchOperator(exp.Binary):
"""ParadeDB search operation that preserves the original operator."""

arg_types = {"this": True, "expression": True, "operator": True}


class ParadeDBTokenizer(PGVectorTokenizer):
"""Tokenizer with ParadeDB search operators and pgvector distance operators."""

KEYWORDS = {
**PGVectorTokenizer.KEYWORDS,
"@@@": _PARADEDB_SEARCH_TOKEN,
"&&&": _PARADEDB_SEARCH_TOKEN,
"|||": _PARADEDB_SEARCH_TOKEN,
"===": _PARADEDB_SEARCH_TOKEN,
"###": _PARADEDB_SEARCH_TOKEN,
"##": _PARADEDB_SEARCH_TOKEN,
"##>": _PARADEDB_SEARCH_TOKEN,
}


class ParadeDBParser(PGVectorParser):
"""Parser with ParadeDB search operators and pgvector distance operators."""

FACTOR = {
**PGVectorParser.FACTOR,
_PARADEDB_SEARCH_TOKEN: SearchOperator,
}


class ParadeDBGenerator(PGVectorGenerator):
"""Generator that renders ParadeDB search operators and pgvector distance operators."""

def searchoperator_sql(self, expression: SearchOperator) -> str:
op = expression.args.get("operator", "@@@")
left = self.sql(expression, "this")
right = self.sql(expression, "expression")
return f"{left} {op} {right}"


class ParadeDB(PGVector):
"""ParadeDB dialect with pg_search and pgvector extension support."""

Tokenizer = ParadeDBTokenizer
Parser = ParadeDBParser
Generator = ParadeDBGenerator
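
For illustration, a minimal round-trip sketch of the new dialect (assumes sqlglot is installed and that importing `sqlspec.adapters.asyncpg.dialect` has registered the "paradedb" name, as `dialect/__init__.py` above does):

```python
import sqlglot

import sqlspec.adapters.asyncpg.dialect  # noqa: F401  (registers "pgvector" and "paradedb")

sql = "SELECT id FROM docs WHERE body @@@ 'postgres' ORDER BY embedding <=> '[1,2,3]' LIMIT 5"
tree = sqlglot.parse_one(sql, read="paradedb")
# The @@@ search operator and the <=> distance operator are expected to round-trip unchanged.
print(tree.sql(dialect="paradedb"))
```
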
99 changes: 99 additions & 0 deletions sqlspec/adapters/asyncpg/dialect/_pgvector.py
@@ -0,0 +1,99 @@
"""PGVector dialect extending Postgres with vector distance operators.

Adds support for pgvector distance operators:
- <-> : L2 (Euclidean) distance (already in base Postgres)
- <#> : Negative inner product
- <=> : Cosine distance
- <+> : L1 (Taxicab/Manhattan) distance
- <~> : Hamming distance (binary vectors)
- <%> : Jaccard distance (binary vectors)
"""

from __future__ import annotations

import typing as t

from sqlglot import exp
from sqlglot.dialects.postgres import Postgres
from sqlglot.tokens import TokenType

__all__ = ("PGVector",)

# Use a single unused token type for all pgvector distance operators.
# The actual operator string is captured during parsing and stored in the expression.
# SQLGlot upstream will not add extension operators, even as unused token types, so this
# works around that limitation: https://github.com/tobymao/sqlglot/issues/6949
_PGVECTOR_DISTANCE_TOKEN = TokenType.CARET_AT


class VectorDistance(exp.Binary):
"""Vector distance operation that preserves the original operator."""

arg_types = {"this": True, "expression": True, "operator": True}


class PGVectorTokenizer(Postgres.Tokenizer):
"""Tokenizer with pgvector distance operators."""

KEYWORDS = {
**Postgres.Tokenizer.KEYWORDS,
"<#>": _PGVECTOR_DISTANCE_TOKEN,
"<=>": _PGVECTOR_DISTANCE_TOKEN,
"<+>": _PGVECTOR_DISTANCE_TOKEN,
"<~>": _PGVECTOR_DISTANCE_TOKEN,
"<%>": _PGVECTOR_DISTANCE_TOKEN,
}


class PGVectorParser(Postgres.Parser):
"""Parser that captures the original operator string for pgvector operations."""

FACTOR = {
**Postgres.Parser.FACTOR,
_PGVECTOR_DISTANCE_TOKEN: VectorDistance,
}

def _parse_factor(self) -> t.Optional[exp.Expression]:
parse_method = self._parse_exponent if self.EXPONENT else self._parse_unary
this = self._parse_at_time_zone(parse_method())

while self._match_set(self.FACTOR):
klass = self.FACTOR[self._prev.token_type]
comments = self._prev_comments
operator_text = self._prev.text
expression = parse_method()

if not expression and klass is exp.IntDiv and self._prev.text.isalpha():
self._retreat(self._index - 1)
return this

if "operator" in klass.arg_types:
this = self.expression(
klass, this=this, comments=comments, expression=expression, operator=operator_text
)
else:
this = self.expression(klass, this=this, comments=comments, expression=expression)

if isinstance(this, exp.Div):
this.set("typed", self.dialect.TYPED_DIVISION)
this.set("safe", self.dialect.SAFE_DIVISION)

return this


class PGVectorGenerator(Postgres.Generator):
"""Generator that renders pgvector distance operators."""

def vectordistance_sql(self, expression: VectorDistance) -> str:
op = expression.args.get("operator", "<->")
left = self.sql(expression, "this")
right = self.sql(expression, "expression")
return f"{left} {op} {right}"


class PGVector(Postgres):
"""PostgreSQL dialect with pgvector extension support."""

Tokenizer = PGVectorTokenizer
Parser = PGVectorParser
Generator = PGVectorGenerator
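
A similar sketch for the pgvector dialect, inspecting the operator string captured on the parsed node (a rough usage example, not part of the diff):

```python
import sqlglot

import sqlspec.adapters.asyncpg.dialect  # noqa: F401  (registers "pgvector")
from sqlspec.adapters.asyncpg.dialect._pgvector import VectorDistance

tree = sqlglot.parse_one("SELECT embedding <=> '[0.1, 0.2]' AS dist FROM items", read="pgvector")
node = tree.find(VectorDistance)
if node is not None:
    print(node.args.get("operator"))  # expected: "<=>"
print(tree.sql(dialect="pgvector"))
```
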
3 changes: 2 additions & 1 deletion tests/unit/adapters/test_asyncpg/test_cloud_connectors.py
@@ -16,7 +16,8 @@ def disable_connectors_by_default():
"""Disable both connectors by default for clean test state."""
with patch("sqlspec.adapters.asyncpg.config.CLOUD_SQL_CONNECTOR_INSTALLED", False):
with patch("sqlspec.adapters.asyncpg.config.ALLOYDB_CONNECTOR_INSTALLED", False):
yield
with patch.object(AsyncpgConfig, "_detect_extensions", new_callable=AsyncMock):
yield


@pytest.fixture