Skip to content

Commit 90171d9

Browse files
committed
feat: add automatic variable registration for Arrow-compatible Python objects in SQL queries
1 parent 3c4609d commit 90171d9

File tree

4 files changed

+329
-11
lines changed

4 files changed

+329
-11
lines changed

docs/source/user-guide/sql.rst

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,4 +36,29 @@ DataFusion also offers a SQL API, read the full reference `here <https://arrow.a
3636
df = ctx.sql('SELECT "Attack"+"Defense", "Attack"-"Defense" FROM pokemon')
3737
3838
# collect and convert to pandas DataFrame
39-
df.to_pandas()
39+
df.to_pandas()
40+
41+
Automatic variable registration
42+
-------------------------------
43+
44+
You can opt-in to DataFusion automatically registering Arrow-compatible Python
45+
objects that appear in SQL queries. This removes the need to call
46+
``register_*`` helpers explicitly when working with in-memory data structures.
47+
48+
.. code-block:: python
49+
50+
import pyarrow as pa
51+
from datafusion import SessionContext
52+
53+
ctx = SessionContext(auto_register_python_variables=True)
54+
55+
orders = pa.Table.from_pydict({"item": ["apple", "pear"], "qty": [5, 2]})
56+
57+
result = ctx.sql("SELECT item, qty FROM orders WHERE qty > 2")
58+
print(result.to_pandas())
59+
60+
The feature inspects the call stack for variables whose names match missing
61+
tables and registers them if they expose Arrow data (including pandas and
62+
Polars DataFrames). Existing contexts can enable or disable the behavior at
63+
runtime through the :py:attr:`SessionContext.auto_register_python_variables`
64+
property.

python/datafusion/context.py

Lines changed: 159 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,10 @@
1919

2020
from __future__ import annotations
2121

22+
import inspect
23+
import re
2224
import warnings
23-
from typing import TYPE_CHECKING, Any, Protocol
25+
from typing import TYPE_CHECKING, Any, Iterator, Protocol
2426

2527
try:
2628
from warnings import deprecated # Python 3.13+
@@ -41,6 +43,8 @@
4143
from ._internal import SQLOptions as SQLOptionsInternal
4244
from ._internal import expr as expr_internal
4345

46+
_MISSING_TABLE_PATTERN = re.compile(r"(?i)(?:table|view) '([^']+)' not found")
47+
4448
if TYPE_CHECKING:
4549
import pathlib
4650
from collections.abc import Sequence
@@ -483,6 +487,8 @@ def __init__(
483487
self,
484488
config: SessionConfig | None = None,
485489
runtime: RuntimeEnvBuilder | None = None,
490+
*,
491+
auto_register_python_variables: bool = False,
486492
) -> None:
487493
"""Main interface for executing queries with DataFusion.
488494
@@ -493,6 +499,9 @@ def __init__(
493499
Args:
494500
config: Session configuration options.
495501
runtime: Runtime configuration options.
502+
auto_register_python_variables: Automatically register Arrow-like
503+
Python objects referenced in SQL queries when they are available
504+
in the caller's scope.
496505
497506
Example usage:
498507
@@ -508,6 +517,7 @@ def __init__(
508517
runtime = runtime.config_internal if runtime is not None else None
509518

510519
self.ctx = SessionContextInternal(config, runtime)
520+
self._auto_register_python_variables = auto_register_python_variables
511521

512522
def __repr__(self) -> str:
513523
"""Print a string representation of the Session Context."""
@@ -534,8 +544,18 @@ def enable_url_table(self) -> SessionContext:
534544
klass = self.__class__
535545
obj = klass.__new__(klass)
536546
obj.ctx = self.ctx.enable_url_table()
547+
obj._auto_register_python_variables = self._auto_register_python_variables
537548
return obj
538549

550+
@property
551+
def auto_register_python_variables(self) -> bool:
552+
"""Toggle automatic registration of Python variables in SQL queries."""
553+
return self._auto_register_python_variables
554+
555+
@auto_register_python_variables.setter
556+
def auto_register_python_variables(self, enabled: bool) -> None:
557+
self._auto_register_python_variables = bool(enabled)
558+
539559
def register_object_store(
540560
self, schema: str, store: Any, host: str | None = None
541561
) -> None:
@@ -600,9 +620,12 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
600620
Returns:
601621
DataFrame representation of the SQL query.
602622
"""
603-
if options is None:
604-
return DataFrame(self.ctx.sql(query))
605-
return DataFrame(self.ctx.sql_with_options(query, options.options_internal))
623+
options_internal = None if options is None else options.options_internal
624+
return self._sql_with_retry(
625+
query,
626+
options_internal,
627+
self._auto_register_python_variables,
628+
)
606629

607630
def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
608631
"""Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text.
@@ -619,6 +642,138 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
619642
"""
620643
return self.sql(query, options)
621644

645+
def _sql_with_retry(
646+
self,
647+
query: str,
648+
options_internal: SQLOptionsInternal | None,
649+
allow_retry: bool,
650+
) -> DataFrame:
651+
try:
652+
if options_internal is None:
653+
return DataFrame(self.ctx.sql(query))
654+
return DataFrame(self.ctx.sql_with_options(query, options_internal))
655+
except Exception as exc:
656+
if not allow_retry or not self._handle_missing_table_error(exc):
657+
raise
658+
return self._sql_with_retry(query, options_internal, allow_retry)
659+
660+
def _handle_missing_table_error(self, error: Exception) -> bool:
661+
missing_tables = self._extract_missing_table_names(error)
662+
if not missing_tables:
663+
return False
664+
665+
registered_any = False
666+
attempted: set[str] = set()
667+
for raw_name in missing_tables:
668+
for candidate in self._candidate_table_names(raw_name):
669+
if candidate in attempted:
670+
continue
671+
attempted.add(candidate)
672+
673+
value = self._lookup_python_variable(candidate)
674+
if value is None:
675+
continue
676+
if self._register_python_value(candidate, value):
677+
registered_any = True
678+
break
679+
return registered_any
680+
681+
def _candidate_table_names(self, identifier: str) -> Iterator[str]:
682+
cleaned = identifier.strip().strip('"')
683+
if not cleaned:
684+
return
685+
686+
seen: set[str] = set()
687+
candidates = [cleaned]
688+
if "." in cleaned:
689+
candidates.append(cleaned.rsplit(".", 1)[-1])
690+
691+
for candidate in candidates:
692+
normalized = candidate.strip()
693+
if not normalized or normalized in seen:
694+
continue
695+
seen.add(normalized)
696+
yield normalized
697+
698+
def _extract_missing_table_names(self, error: Exception) -> set[str]:
699+
names: set[str] = set()
700+
attribute = getattr(error, "missing_table_names", None)
701+
if attribute is not None:
702+
if isinstance(attribute, (list, tuple, set, frozenset)):
703+
for item in attribute:
704+
if item is None:
705+
continue
706+
for candidate in self._candidate_table_names(str(item)):
707+
names.add(candidate)
708+
elif attribute is not None:
709+
for candidate in self._candidate_table_names(str(attribute)):
710+
names.add(candidate)
711+
if names:
712+
return names
713+
714+
message = str(error)
715+
return {match.group(1) for match in _MISSING_TABLE_PATTERN.finditer(message)}
716+
717+
def _lookup_python_variable(self, name: str) -> Any | None:
718+
frame = inspect.currentframe()
719+
outer = frame.f_back if frame is not None else None
720+
lower_name = name.lower()
721+
722+
try:
723+
while outer is not None:
724+
for mapping in (outer.f_locals, outer.f_globals):
725+
if not mapping:
726+
continue
727+
if name in mapping:
728+
value = mapping[name]
729+
if value is not None:
730+
return value
731+
# allow outer scopes to provide a non-``None`` value
732+
continue
733+
for key, value in mapping.items():
734+
if value is None:
735+
continue
736+
if key == name or key.lower() == lower_name:
737+
return value
738+
outer = outer.f_back
739+
finally:
740+
del outer
741+
del frame
742+
return None
743+
744+
def _register_python_value(self, table_name: str, value: Any) -> bool:
745+
if value is None:
746+
return False
747+
748+
registered = False
749+
if isinstance(value, DataFrame):
750+
self.register_view(table_name, value)
751+
registered = True
752+
elif isinstance(value, Table):
753+
self.register_table(table_name, value)
754+
registered = True
755+
else:
756+
provider = getattr(value, "__datafusion_table_provider__", None)
757+
if callable(provider):
758+
self.register_table_provider(table_name, value)
759+
registered = True
760+
elif hasattr(value, "__arrow_c_stream__") or hasattr(
761+
value, "__arrow_c_array__"
762+
):
763+
self.from_arrow(value, name=table_name)
764+
registered = True
765+
else:
766+
module_name = getattr(type(value), "__module__", "") or ""
767+
class_name = getattr(type(value), "__name__", "") or ""
768+
if module_name.startswith("pandas.") and class_name == "DataFrame":
769+
self.from_pandas(value, name=table_name)
770+
registered = True
771+
elif module_name.startswith("polars") and class_name == "DataFrame":
772+
self.from_polars(value, name=table_name)
773+
registered = True
774+
775+
return registered
776+
622777
def create_dataframe(
623778
self,
624779
partitions: list[list[pa.RecordBatch]],

python/tests/test_context.py

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,69 @@ def test_from_pylist(ctx):
255255
assert df.collect()[0].num_rows == 3
256256

257257

258+
def test_sql_missing_table_without_auto_register(ctx):
259+
arrow_table = pa.Table.from_pydict({"value": [1, 2, 3]}) # noqa: F841
260+
261+
with pytest.raises(Exception, match="not found") as excinfo:
262+
ctx.sql("SELECT * FROM arrow_table").collect()
263+
264+
missing = getattr(excinfo.value, "missing_table_names", None)
265+
assert missing is not None
266+
assert "arrow_table" in set(ctx._extract_missing_table_names(excinfo.value))
267+
268+
269+
def test_sql_auto_register_arrow_table():
270+
ctx = SessionContext(auto_register_python_variables=True)
271+
arrow_table = pa.Table.from_pydict({"value": [1, 2, 3]}) # noqa: F841
272+
273+
result = ctx.sql(
274+
"SELECT SUM(value) AS total FROM arrow_table",
275+
).collect()
276+
277+
assert ctx.table_exist("arrow_table")
278+
assert result[0].column(0).to_pylist()[0] == 6
279+
280+
281+
def test_sql_auto_register_arrow_outer_scope():
282+
ctx = SessionContext()
283+
ctx.auto_register_python_variables = True
284+
arrow_table = pa.Table.from_pydict({"value": [1, 2, 3, 4]}) # noqa: F841
285+
286+
def run_query():
287+
return ctx.sql(
288+
"SELECT COUNT(*) AS total_rows FROM arrow_table",
289+
).collect()
290+
291+
result = run_query()
292+
assert result[0].column(0).to_pylist()[0] == 4
293+
294+
295+
def test_sql_auto_register_pandas_dataframe():
296+
pd = pytest.importorskip("pandas")
297+
298+
ctx = SessionContext(auto_register_python_variables=True)
299+
pandas_df = pd.DataFrame({"value": [1, 2, 3, 4]}) # noqa: F841
300+
301+
result = ctx.sql(
302+
"SELECT AVG(value) AS avg_value FROM pandas_df",
303+
).collect()
304+
305+
assert pytest.approx(result[0].column(0).to_pylist()[0]) == 2.5
306+
307+
308+
def test_sql_auto_register_polars_dataframe():
309+
pl = pytest.importorskip("polars")
310+
311+
ctx = SessionContext(auto_register_python_variables=True)
312+
polars_df = pl.DataFrame({"value": [2, 4, 6]}) # noqa: F841
313+
314+
result = ctx.sql(
315+
"SELECT MIN(value) AS min_value FROM polars_df",
316+
).collect()
317+
318+
assert result[0].column(0).to_pylist()[0] == 2
319+
320+
258321
def test_from_pydict(ctx):
259322
# create a dataframe from Python dictionary
260323
data = {"a": [1, 2, 3], "b": [4, 5, 6]}
@@ -484,7 +547,7 @@ def test_table_exist(ctx):
484547

485548

486549
def test_table_not_found(ctx):
487-
from uuid import uuid4
550+
from uuid import uuid4 # noqa: PLC0415
488551

489552
with pytest.raises(KeyError):
490553
ctx.table(f"not-found-{uuid4()}")

0 commit comments

Comments
 (0)