1919
2020from __future__ import annotations
2121
22+ import inspect
23+ import re
2224import warnings
23- from typing import TYPE_CHECKING , Any , Protocol
25+ from typing import TYPE_CHECKING , Any , Iterator , Protocol
2426
2527try :
2628 from warnings import deprecated # Python 3.13+
4143from ._internal import SQLOptions as SQLOptionsInternal
4244from ._internal import expr as expr_internal
4345
46+ _MISSING_TABLE_PATTERN = re .compile (r"(?i)(?:table|view) '([^']+)' not found" )
47+
4448if TYPE_CHECKING :
4549 import pathlib
4650 from collections .abc import Sequence
@@ -483,6 +487,8 @@ def __init__(
483487 self ,
484488 config : SessionConfig | None = None ,
485489 runtime : RuntimeEnvBuilder | None = None ,
490+ * ,
491+ auto_register_python_variables : bool = False ,
486492 ) -> None :
487493 """Main interface for executing queries with DataFusion.
488494
@@ -493,6 +499,9 @@ def __init__(
493499 Args:
494500 config: Session configuration options.
495501 runtime: Runtime configuration options.
502+ auto_register_python_variables: Automatically register Arrow-like
503+ Python objects referenced in SQL queries when they are available
504+ in the caller's scope.
496505
497506 Example usage:
498507
@@ -508,6 +517,7 @@ def __init__(
508517 runtime = runtime .config_internal if runtime is not None else None
509518
510519 self .ctx = SessionContextInternal (config , runtime )
520+ self ._auto_register_python_variables = auto_register_python_variables
511521
512522 def __repr__ (self ) -> str :
513523 """Print a string representation of the Session Context."""
@@ -534,8 +544,18 @@ def enable_url_table(self) -> SessionContext:
534544 klass = self .__class__
535545 obj = klass .__new__ (klass )
536546 obj .ctx = self .ctx .enable_url_table ()
547+ obj ._auto_register_python_variables = self ._auto_register_python_variables
537548 return obj
538549
550+ @property
551+ def auto_register_python_variables (self ) -> bool :
552+ """Toggle automatic registration of Python variables in SQL queries."""
553+ return self ._auto_register_python_variables
554+
555+ @auto_register_python_variables .setter
556+ def auto_register_python_variables (self , enabled : bool ) -> None :
557+ self ._auto_register_python_variables = bool (enabled )
558+
539559 def register_object_store (
540560 self , schema : str , store : Any , host : str | None = None
541561 ) -> None :
@@ -600,9 +620,12 @@ def sql(self, query: str, options: SQLOptions | None = None) -> DataFrame:
600620 Returns:
601621 DataFrame representation of the SQL query.
602622 """
603- if options is None :
604- return DataFrame (self .ctx .sql (query ))
605- return DataFrame (self .ctx .sql_with_options (query , options .options_internal ))
623+ options_internal = None if options is None else options .options_internal
624+ return self ._sql_with_retry (
625+ query ,
626+ options_internal ,
627+ self ._auto_register_python_variables ,
628+ )
606629
607630 def sql_with_options (self , query : str , options : SQLOptions ) -> DataFrame :
608631 """Create a :py:class:`~datafusion.dataframe.DataFrame` from SQL query text.
@@ -619,6 +642,138 @@ def sql_with_options(self, query: str, options: SQLOptions) -> DataFrame:
619642 """
620643 return self .sql (query , options )
621644
645+ def _sql_with_retry (
646+ self ,
647+ query : str ,
648+ options_internal : SQLOptionsInternal | None ,
649+ allow_retry : bool ,
650+ ) -> DataFrame :
651+ try :
652+ if options_internal is None :
653+ return DataFrame (self .ctx .sql (query ))
654+ return DataFrame (self .ctx .sql_with_options (query , options_internal ))
655+ except Exception as exc :
656+ if not allow_retry or not self ._handle_missing_table_error (exc ):
657+ raise
658+ return self ._sql_with_retry (query , options_internal , allow_retry )
659+
660+ def _handle_missing_table_error (self , error : Exception ) -> bool :
661+ missing_tables = self ._extract_missing_table_names (error )
662+ if not missing_tables :
663+ return False
664+
665+ registered_any = False
666+ attempted : set [str ] = set ()
667+ for raw_name in missing_tables :
668+ for candidate in self ._candidate_table_names (raw_name ):
669+ if candidate in attempted :
670+ continue
671+ attempted .add (candidate )
672+
673+ value = self ._lookup_python_variable (candidate )
674+ if value is None :
675+ continue
676+ if self ._register_python_value (candidate , value ):
677+ registered_any = True
678+ break
679+ return registered_any
680+
681+ def _candidate_table_names (self , identifier : str ) -> Iterator [str ]:
682+ cleaned = identifier .strip ().strip ('"' )
683+ if not cleaned :
684+ return
685+
686+ seen : set [str ] = set ()
687+ candidates = [cleaned ]
688+ if "." in cleaned :
689+ candidates .append (cleaned .rsplit ("." , 1 )[- 1 ])
690+
691+ for candidate in candidates :
692+ normalized = candidate .strip ()
693+ if not normalized or normalized in seen :
694+ continue
695+ seen .add (normalized )
696+ yield normalized
697+
698+ def _extract_missing_table_names (self , error : Exception ) -> set [str ]:
699+ names : set [str ] = set ()
700+ attribute = getattr (error , "missing_table_names" , None )
701+ if attribute is not None :
702+ if isinstance (attribute , (list , tuple , set , frozenset )):
703+ for item in attribute :
704+ if item is None :
705+ continue
706+ for candidate in self ._candidate_table_names (str (item )):
707+ names .add (candidate )
708+ elif attribute is not None :
709+ for candidate in self ._candidate_table_names (str (attribute )):
710+ names .add (candidate )
711+ if names :
712+ return names
713+
714+ message = str (error )
715+ return {match .group (1 ) for match in _MISSING_TABLE_PATTERN .finditer (message )}
716+
717+ def _lookup_python_variable (self , name : str ) -> Any | None :
718+ frame = inspect .currentframe ()
719+ outer = frame .f_back if frame is not None else None
720+ lower_name = name .lower ()
721+
722+ try :
723+ while outer is not None :
724+ for mapping in (outer .f_locals , outer .f_globals ):
725+ if not mapping :
726+ continue
727+ if name in mapping :
728+ value = mapping [name ]
729+ if value is not None :
730+ return value
731+ # allow outer scopes to provide a non-``None`` value
732+ continue
733+ for key , value in mapping .items ():
734+ if value is None :
735+ continue
736+ if key == name or key .lower () == lower_name :
737+ return value
738+ outer = outer .f_back
739+ finally :
740+ del outer
741+ del frame
742+ return None
743+
744+ def _register_python_value (self , table_name : str , value : Any ) -> bool :
745+ if value is None :
746+ return False
747+
748+ registered = False
749+ if isinstance (value , DataFrame ):
750+ self .register_view (table_name , value )
751+ registered = True
752+ elif isinstance (value , Table ):
753+ self .register_table (table_name , value )
754+ registered = True
755+ else :
756+ provider = getattr (value , "__datafusion_table_provider__" , None )
757+ if callable (provider ):
758+ self .register_table_provider (table_name , value )
759+ registered = True
760+ elif hasattr (value , "__arrow_c_stream__" ) or hasattr (
761+ value , "__arrow_c_array__"
762+ ):
763+ self .from_arrow (value , name = table_name )
764+ registered = True
765+ else :
766+ module_name = getattr (type (value ), "__module__" , "" ) or ""
767+ class_name = getattr (type (value ), "__name__" , "" ) or ""
768+ if module_name .startswith ("pandas." ) and class_name == "DataFrame" :
769+ self .from_pandas (value , name = table_name )
770+ registered = True
771+ elif module_name .startswith ("polars" ) and class_name == "DataFrame" :
772+ self .from_polars (value , name = table_name )
773+ registered = True
774+
775+ return registered
776+
622777 def create_dataframe (
623778 self ,
624779 partitions : list [list [pa .RecordBatch ]],
0 commit comments