def _check_collisions(cls: type, mixin: type) -> None:
    """Warn when public members of ``mixin`` already exist on ``cls``.

    A mixin member is any public name (no leading underscore) whose value,
    looked up on the mixin class, is either callable or a ``property``.
    Those names are compared against every public name reported by
    ``dir(cls)``; if any overlap, a single warning listing them is issued.

    Parameters
    ----------
    cls
        Class that is about to receive the mixin.
    mixin
        Mixin class whose public members are checked.
    """
    public_names = [name for name in dir(mixin) if not name.startswith('_')]

    # Methods and properties are collected separately because a property
    # object fetched from the class is not itself callable.
    callables = {
        name for name in public_names
        if callable(getattr(mixin, name, None))
    }
    properties = {
        name for name in public_names
        if isinstance(getattr(mixin, name, None), property)
    }
    provided = callables | properties

    taken = {name for name in dir(cls) if not name.startswith('_')}

    clashes = provided & taken
    if not clashes:
        return

    message = (
        f'Mixin {mixin.__name__} has methods that collide with '
        f'{cls.__name__}: {clashes}'
    )
    # stacklevel=3 attributes the warning two frames above this function.
    warnings.warn(message, stacklevel=3)
ir.Table. + + This is called automatically on import, but can be called + explicitly if needed. + """ + global _registered # noqa: PLW0603 + if _registered: + return + + try: + import ibis.expr.types as ir + from ibis.backends.singlestoredb import Backend + except ImportError as e: + raise ImportError( + 'ibis_extras requires ibis with singlestoredb backend. ' + 'Install with: pip install "singlestoredb[ibis]"', + ) from e + + # Check for collisions before adding mixins + _check_collisions(Backend, BackendExtensionsMixin) + _check_collisions(ir.Table, TableExtensionsMixin) + + # Add mixin to Backend + if BackendExtensionsMixin not in Backend.__bases__: + Backend.__bases__ = (BackendExtensionsMixin,) + Backend.__bases__ + + # Add mixin to ir.Table + if TableExtensionsMixin not in ir.Table.__bases__: + ir.Table.__bases__ = (TableExtensionsMixin,) + ir.Table.__bases__ + + _registered = True + + +def is_registered() -> bool: + """Check if extensions have been registered.""" + return _registered + + +# Auto-register on import +register() diff --git a/singlestoredb/ibis_extras/mixins.py b/singlestoredb/ibis_extras/mixins.py new file mode 100644 index 00000000..b031cc1a --- /dev/null +++ b/singlestoredb/ibis_extras/mixins.py @@ -0,0 +1,211 @@ +"""Mixin classes for SingleStoreDB extensions.""" +from __future__ import annotations + +from typing import Any +from typing import Protocol +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from contextlib import AbstractContextManager + + import ibis.expr.types as ir + + class _BackendProtocol(Protocol): + """Protocol defining backend interface used by BackendExtensionsMixin.""" + + _client: Any + + @property + def current_database(self) -> str: ... + def sql(self, query: str) -> ir.Table: ... + def raw_sql(self, query: str) -> AbstractContextManager[Any]: ... + + class _TableProtocol(Protocol): + """Protocol defining table interface used by TableExtensionsMixin.""" + + def get_name(self) -> str: ... 
def _get_table_backend_and_db(
    table: ir.Table,
) -> tuple[BackendExtensionsMixin, str | None]:
    """Resolve the SingleStoreDB backend and database that own ``table``.

    Parameters
    ----------
    table
        Table expression whose operation (``table.op()``) is inspected.

    Returns
    -------
    tuple
        ``(backend, database)``: ``backend`` is the operation's ``source``;
        ``database`` is ``op.namespace.database``, or ``None`` when the
        operation has no ``namespace`` or the namespace has no ``database``.

    Raises
    ------
    TypeError
        If the operation has no ``source``, or the source's ``name`` is not
        ``'singlestoredb'``.
    """
    op = table.op()

    # Fetch ``source`` and ``name`` with defaults so that operations lacking
    # them (presumably derived expressions such as filters or joins --
    # verify against ibis) reach the TypeError below instead of failing with
    # an AttributeError while the check or the error message is evaluated.
    source = getattr(op, 'source', None)
    if getattr(source, 'name', None) == 'singlestoredb':
        namespace = getattr(op, 'namespace', None)
        db = getattr(namespace, 'database', None)
        return source, db  # type: ignore[return-value]

    raise TypeError(
        f'This method only works with SingleStoreDB tables, '
        f"got {getattr(source, 'name', 'unknown')} backend",
    )
def optimize_table(self, table_name: str, *, database: str | None = None) -> None:
    """Run ``OPTIMIZE TABLE ... FULL`` against a single table.

    Parameters
    ----------
    table_name
        Name of the table to optimize.
    database
        Database that contains the table. When falsy,
        ``self.current_database`` is used instead.
    """
    target_db = database or self.current_database

    # Both identifiers are backtick-quoted before being placed in the
    # statement text.
    qualified_name = '.'.join(
        (_quote_identifier(target_db), _quote_identifier(table_name)),
    )

    # The context manager returned by raw_sql is entered and left right
    # away; nothing is read from it.
    with self.raw_sql(f'OPTIMIZE TABLE {qualified_name} FULL'):
        pass
class TableExtensionsMixin(_TableBase):
    """Table-level helpers that delegate to the owning SingleStoreDB backend.

    Every method first resolves the backend and database through
    ``_get_table_backend_and_db``, which raises ``TypeError`` for tables
    that do not come from a SingleStoreDB backend.
    """

    # The mixin declares no instance attributes of its own.
    __slots__ = ()

    def optimize(self) -> None:
        """Optimize this table via the backend's ``optimize_table``."""
        owner, database = _get_table_backend_and_db(self)
        owner.optimize_table(self.get_name(), database=database)

    def get_stats(self) -> dict[str, Any]:
        """Return this table's statistics via the backend's ``get_table_stats``."""
        owner, database = _get_table_backend_and_db(self)
        stats = owner.get_table_stats(self.get_name(), database=database)
        return stats

    def get_column_statistics(self, column: str | None = None) -> ir.Table:
        """Query the columnar segment index for this table.

        Parameters
        ----------
        column
            Column to restrict the statement to. When falsy, no ``COLUMNS``
            clause is added.

        Returns
        -------
        ir.Table
            Result of passing the ``SHOW COLUMNAR_SEGMENT_INDEX`` statement
            to the backend's ``sql`` method.
        """
        owner, database = _get_table_backend_and_db(self)

        # Use the backend's current database when the table carries none.
        schema = database or owner.current_database
        schema_sql = _quote_identifier(schema)
        table_sql = _quote_identifier(self.get_name())

        pieces = [f'SHOW COLUMNAR_SEGMENT_INDEX ON {schema_sql}.{table_sql}']
        if column:
            pieces.append(f' COLUMNS ({_quote_identifier(column)})')
        return owner.sql(''.join(pieces))