Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 111 additions & 0 deletions singlestoredb/ibis_extras/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
"""SingleStoreDB extensions for Ibis.

This package adds SingleStoreDB-specific features to the Ibis backend.
Features are automatically registered on import.

Usage
-----
>>> import ibis
>>> import singlestoredb.ibis_extras # Auto-registers extensions
>>>
>>> con = ibis.singlestoredb.connect(host="...", database="...")
>>>
>>> # Variable accessors (from old ibis_singlestoredb)
>>> con.show.databases()
>>> con.globals["max_connections"]
>>> con.vars["autocommit"]
>>>
>>> # Backend methods
>>> con.get_storage_info()
>>> con.get_workload_metrics()
>>> con.optimize_table("users")
>>>
>>> # Table methods (work on any table from SingleStoreDB)
>>> t = con.table("users")
>>> t.optimize()
>>> t.get_stats()
"""
from __future__ import annotations

import warnings

from .mixins import BackendExtensionsMixin
from .mixins import TableExtensionsMixin

# Public names of this package: the two mixin classes plus the
# registration helpers defined in this module.
__all__ = [
    'BackendExtensionsMixin',
    'TableExtensionsMixin',
    'is_registered',
    'register',
]

# Module-level flag: flipped to True at the end of a successful register()
# call and reported by is_registered().
_registered = False


def _check_collisions(cls: type, mixin: type) -> None:
    """Warn when public members of ``mixin`` clash with names on ``cls``.

    A mixin member is any public (no leading underscore) attribute of
    ``mixin`` that is either callable or a ``property``. It collides when
    ``cls`` already exposes a public attribute with the same name.

    Parameters
    ----------
    cls
        Class that is about to receive the mixin.
    mixin
        Mixin class whose public members are checked.

    Notes
    -----
    Nothing is raised; a single ``warnings.warn`` call lists every
    colliding name in sorted order so the message is deterministic.
    If ``cls`` already inherits from ``mixin`` the check is skipped.
    """
    # Once the mixin is an ancestor of cls, every mixin member is visible
    # through dir(cls) and would be reported as a collision with itself
    # (e.g. after a module reload or when the registered flag is reset).
    if issubclass(cls, mixin):
        return

    mixin_members = set()
    for name in dir(mixin):
        if name.startswith('_'):
            continue
        member = getattr(mixin, name, None)
        # Properties are not callable, so they need their own test.
        if callable(member) or isinstance(member, property):
            mixin_members.add(name)

    existing_attrs = {name for name in dir(cls) if not name.startswith('_')}

    collisions = sorted(mixin_members & existing_attrs)
    if collisions:
        # Keep the brace-delimited look of a set repr, but in stable order.
        listing = ', '.join(repr(name) for name in collisions)
        warnings.warn(
            f'Mixin {mixin.__name__} has methods that collide with '
            f'{cls.__name__}: {{{listing}}}',
            # Frame 1 is this helper and frame 2 its caller, so the
            # warning is attributed one level above the caller.
            stacklevel=3,
        )


def register() -> None:
    """Attach the extension mixins to the Ibis Backend and ``ir.Table``.

    Runs automatically when this package is imported and may also be
    called explicitly. Once a call has completed, later calls return
    immediately.

    Raises
    ------
    ImportError
        If ``ibis.expr.types`` or ``ibis.backends.singlestoredb`` cannot
        be imported.
    """
    global _registered  # noqa: PLW0603

    # Nothing to do after a previous successful registration.
    if _registered:
        return

    try:
        import ibis.expr.types as ir
        from ibis.backends.singlestoredb import Backend
    except ImportError as exc:
        raise ImportError(
            'ibis_extras requires ibis with singlestoredb backend. '
            'Install with: pip install "singlestoredb[ibis]"',
        ) from exc

    targets = (
        (Backend, BackendExtensionsMixin),
        (ir.Table, TableExtensionsMixin),
    )

    # First pass: report name clashes before any class is modified.
    for target, mixin in targets:
        _check_collisions(target, mixin)

    # Second pass: put each mixin at the front of the target's bases,
    # unless it is already one of the direct bases.
    for target, mixin in targets:
        if mixin not in target.__bases__:
            target.__bases__ = (mixin, *target.__bases__)

    _registered = True


def is_registered() -> bool:
    """Check if extensions have been registered.

    Returns
    -------
    bool
        The current value of the module-level ``_registered`` flag.
    """
    return _registered


# Auto-register on import: importing this package calls register()
# immediately. register() raises ImportError when the ibis imports it needs
# fail, so that error surfaces at import time.
register()
211 changes: 211 additions & 0 deletions singlestoredb/ibis_extras/mixins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
"""Mixin classes for SingleStoreDB extensions."""
from __future__ import annotations

from typing import Any
from typing import Protocol
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Everything in this branch exists for static analysis only.
    from contextlib import AbstractContextManager

    import ibis.expr.types as ir

    class _BackendProtocol(Protocol):
        """Backend members that BackendExtensionsMixin reads or calls."""

        # Underlying client object; the accessor properties forward to it.
        _client: Any

        @property
        def current_database(self) -> str:
            ...

        def sql(self, query: str) -> ir.Table:
            ...

        def raw_sql(self, query: str) -> AbstractContextManager[Any]:
            ...

    class _TableProtocol(Protocol):
        """Table members that TableExtensionsMixin calls."""

        def get_name(self) -> str:
            ...

        def op(self) -> Any:
            ...

    _BackendBase: type = _BackendProtocol
    _TableBase: type = _TableProtocol
else:
    # At runtime the mixins derive from plain ``object``.
    _BackendBase = object
    _TableBase = object


def _quote_identifier(name: str) -> str:
    """Return ``name`` wrapped in backticks for use as a SQL identifier.

    Any backtick inside ``name`` is doubled (MySQL/SingleStore convention)
    so it cannot terminate the quoted identifier early.
    """
    # Splitting on the backtick and re-joining with two of them doubles
    # every occurrence.
    doubled = '``'.join(name.split('`'))
    return '`' + doubled + '`'


def _escape_string_literal(value: str) -> str:
    """Return ``value`` escaped for embedding in a single-quoted SQL literal.

    Each backslash becomes two backslashes and each single quote becomes
    two single quotes. The surrounding quotes are not added here.
    """
    # One translation pass handles both characters; neither replacement
    # produces the other character, so the result matches doing them in
    # sequence.
    replacements = str.maketrans({'\\': '\\\\', "'": "''"})
    return value.translate(replacements)


def _get_table_backend_and_db(
    table: ir.Table,
) -> tuple[BackendExtensionsMixin, str | None]:
    """Get SingleStoreDB backend and database from table.

    Parameters
    ----------
    table
        Table expression whose ``op()`` node is inspected.

    Returns
    -------
    tuple
        ``(backend, database)`` where ``backend`` is the node's ``source``
        and ``database`` is ``op.namespace.database`` when present,
        otherwise ``None``.

    Raises
    ------
    TypeError
        If the node has no ``source``, or the source's ``name`` is not
        ``'singlestoredb'``.
    """
    op = table.op()
    # Read ``source`` defensively: a node without a ``source`` attribute
    # must end in the TypeError below, not in an AttributeError raised
    # while the error message is being formatted.
    source = getattr(op, 'source', None)
    if getattr(source, 'name', None) == 'singlestoredb':
        db = getattr(getattr(op, 'namespace', None), 'database', None)
        return source, db  # type: ignore[return-value]
    raise TypeError(
        f'This method only works with SingleStoreDB tables, '
        f"got {getattr(source, 'name', 'unknown')} backend",
    )


class BackendExtensionsMixin(_BackendBase):
    """SingleStoreDB-specific additions for the Ibis backend class.

    Two groups of members live here: read-only properties that forward to
    the matching attribute of ``self._client``, and helper methods that
    issue SQL through ``self.sql`` / ``self.raw_sql``.
    """

    __slots__ = ()

    # --- Accessors carried over from the old ibis_singlestoredb package ---

    @property
    def show(self) -> Any:
        """Return the client's ``show`` accessor (SHOW commands)."""
        return self._client.show

    @property
    def globals(self) -> Any:
        """Return the client's accessor for global server variables."""
        return self._client.globals

    @property
    def locals(self) -> Any:
        """Return the client's accessor for local server variables."""
        return self._client.locals

    @property
    def cluster_globals(self) -> Any:
        """Return the client's accessor for cluster global variables."""
        return self._client.cluster_globals

    @property
    def cluster_locals(self) -> Any:
        """Return the client's accessor for cluster local variables."""
        return self._client.cluster_locals

    @property
    def vars(self) -> Any:
        """Return the client's accessor for server variables."""
        return self._client.vars

    @property
    def cluster_vars(self) -> Any:
        """Return the client's accessor for cluster variables."""
        return self._client.cluster_vars

    # --- Extension methods ---

    def get_storage_info(self, database: str | None = None) -> ir.Table:
        """Select the ``information_schema.table_statistics`` rows of a database.

        Parameters
        ----------
        database
            Database name. A falsy value selects ``self.current_database``.

        Returns
        -------
        ir.Table
            Result of ``self.sql`` for the filtered statistics query.
        """
        target = database or self.current_database
        db = _escape_string_literal(target)
        # S608: db is escaped via _escape_string_literal. The query text
        # is written flush-left so no indentation ends up in the SQL.
        query = f"""
SELECT * FROM information_schema.table_statistics
WHERE database_name = '{db}'
"""  # noqa: S608
        return self.sql(query)

    def get_workload_metrics(self) -> ir.Table:
        """Select all of ``information_schema.mv_workload_management_events``."""
        query = 'SELECT * FROM information_schema.mv_workload_management_events'
        return self.sql(query)

    def optimize_table(self, table_name: str, *, database: str | None = None) -> None:
        """Run ``OPTIMIZE TABLE ... FULL`` on one table.

        Parameters
        ----------
        table_name
            Name of the table to optimize.
        database
            Database name. A falsy value selects ``self.current_database``.
        """
        db = _quote_identifier(database or self.current_database)
        table = _quote_identifier(table_name)
        statement = f'OPTIMIZE TABLE {db}.{table} FULL'
        # Only entering and leaving the context manager matters here; the
        # object it yields is not used.
        with self.raw_sql(statement):
            pass

    def get_table_stats(
        self,
        table_name: str,
        *,
        database: str | None = None,
    ) -> dict[str, Any]:
        """Fetch the ``information_schema.table_statistics`` row of one table.

        Parameters
        ----------
        table_name
            Name of the table.
        database
            Database name. A falsy value selects ``self.current_database``.

        Returns
        -------
        dict
            First matching row as a dict, or ``{}`` when no row matches.
        """
        db = _escape_string_literal(database or self.current_database)
        table = _escape_string_literal(table_name)
        # S608: db and table are escaped via _escape_string_literal.
        query = f"""
SELECT * FROM information_schema.table_statistics
WHERE database_name = '{db}' AND table_name = '{table}'
"""  # noqa: S608
        frame = self.sql(query).execute()
        if not len(frame):
            return {}
        return frame.to_dict(orient='records')[0]


class TableExtensionsMixin(_TableBase):
    """SingleStoreDB-only helper methods for table expressions.

    Every method resolves the owning backend through
    ``_get_table_backend_and_db`` first, which raises ``TypeError`` for a
    table that does not come from the SingleStoreDB backend.
    """

    __slots__ = ()

    def optimize(self) -> None:
        """Optimize this table via the backend's ``optimize_table``."""
        owner, database = _get_table_backend_and_db(self)
        owner.optimize_table(self.get_name(), database=database)

    def get_stats(self) -> dict[str, Any]:
        """Return this table's statistics via the backend's ``get_table_stats``."""
        owner, database = _get_table_backend_and_db(self)
        return owner.get_table_stats(self.get_name(), database=database)

    def get_column_statistics(self, column: str | None = None) -> ir.Table:
        """Get column statistics (SingleStoreDB only).

        Parameters
        ----------
        column
            Specific column name; a falsy value leaves the ``COLUMNS``
            clause off the statement.

        Returns
        -------
        ir.Table
            Result of the backend's ``sql`` for the SHOW statement.
        """
        owner, database = _get_table_backend_and_db(self)
        # Fall back to the backend's current database when the table
        # carries no database of its own.
        db_part = _quote_identifier(database or owner.current_database)
        table_part = _quote_identifier(self.get_name())
        # NOTE(review): confirm the exact statement spelling against the
        # SingleStore SQL reference; it is reproduced here unchanged.
        statement = f'SHOW COLUMNAR_SEGMENT_INDEX ON {db_part}.{table_part}'
        if column:
            statement = f'{statement} COLUMNS ({_quote_identifier(column)})'
        return owner.sql(statement)