2929
3030import pyarrow as pa
3131
32- from datafusion .catalog import Catalog , CatalogProvider , Table
32+ from datafusion .catalog import Catalog
3333from datafusion .dataframe import DataFrame
34- from datafusion .expr import SortKey , sort_list_to_raw_sort_list
34+ from datafusion .expr import sort_list_to_raw_sort_list
3535from datafusion .record_batch import RecordBatchStream
36- from datafusion .user_defined import AggregateUDF , ScalarUDF , TableFunction , WindowUDF
36+ from datafusion .utils import _normalize_table_provider
3737
3838from ._internal import RuntimeEnvBuilder as RuntimeEnvBuilderInternal
3939from ._internal import SessionConfig as SessionConfigInternal
4848 import pandas as pd
4949 import polars as pl # type: ignore[import]
5050
51+ from datafusion import TableProvider
52+ from datafusion .catalog import CatalogProvider , Table
53+ from datafusion .expr import SortKey
5154 from datafusion .plan import ExecutionPlan , LogicalPlan
55+ from datafusion .user_defined import (
56+ AggregateUDF ,
57+ ScalarUDF ,
58+ TableFunction ,
59+ WindowUDF ,
60+ )
5261
5362
5463class ArrowStreamExportable (Protocol ):
@@ -733,7 +742,7 @@ def from_polars(self, data: pl.DataFrame, name: str | None = None) -> DataFrame:
733742 # https://github.com/apache/datafusion-python/pull/1016#discussion_r1983239116
734743 # is the discussion on how we arrived at adding register_view
735744 def register_view (self , name : str , df : DataFrame ) -> None :
736- """Register a :py:class: `~datafusion.detaframe .DataFrame` as a view.
745+ """Register a :py:class:`~datafusion.dataframe .DataFrame` as a view.
737746
738747 Args:
739748 name (str): The name to register the view under.
@@ -742,16 +751,29 @@ def register_view(self, name: str, df: DataFrame) -> None:
742751 view = df .into_view ()
743752 self .ctx .register_table (name , view )
744753
745- def register_table (self , name : str , table : Table ) -> None :
746- """Register a :py:class: `~datafusion.catalog.Table` as a table.
754+ def register_table (
755+ self , name : str , table : Table | TableProvider | TableProviderExportable
756+ ) -> None :
757+ """Register a Table or TableProvider.
747758
748- The registered table can be referenced from SQL statement executed against.
759+ The registered table can be referenced from SQL statements executed against
760+ this context.
761+
762+ Plain :py:class:`~datafusion.dataframe.DataFrame` objects are not supported;
763+ convert them first with :meth:`datafusion.dataframe.DataFrame.into_view` or
764+ :meth:`datafusion.TableProvider.from_dataframe`.
765+
766+ Objects implementing ``__datafusion_table_provider__`` are also supported
767+ and treated as :py:class:`~datafusion.TableProvider` instances.
749768
750769 Args:
751770 name: Name of the resultant table.
752- table: DataFusion table to add to the session context.
771+ table: DataFusion :class:`Table`, :class:`TableProvider`, or any object
772+ implementing ``__datafusion_table_provider__`` to add to the session
773+ context.
753774 """
754- self .ctx .register_table (name , table .table )
775+ provider = _normalize_table_provider (table )
776+ self .ctx .register_table (name , provider )
755777
756778 def deregister_table (self , name : str ) -> None :
757779 """Remove a table from the session."""
@@ -771,14 +793,21 @@ def register_catalog_provider(
771793 self .ctx .register_catalog_provider (name , provider )
772794
773795 def register_table_provider (
774- self , name : str , provider : TableProviderExportable
796+ self , name : str , provider : Table | TableProvider | TableProviderExportable
775797 ) -> None :
776798 """Register a table provider.
777799
778- This table provider must have a method called ``__datafusion_table_provider__``
779- which returns a PyCapsule that exposes a ``FFI_TableProvider``.
800+ Deprecated: use :meth:`register_table` instead.
801+
802+ Objects implementing ``__datafusion_table_provider__`` are also supported
803+ and treated as :py:class:`~datafusion.TableProvider` instances.
780804 """
781- self .ctx .register_table_provider (name , provider )
805+ warnings .warn (
806+ "register_table_provider is deprecated; use register_table" ,
807+ DeprecationWarning ,
808+ stacklevel = 2 ,
809+ )
810+ self .register_table (name , provider )
782811
783812 def register_udtf (self , func : TableFunction ) -> None :
784813 """Register a user defined table function."""
0 commit comments