Skip to content

Commit 54e8938

Browse files
committed
add pytableprovider
1 parent 917d8a8 commit 54e8938

File tree

3 files changed

+28
-4
lines changed

3 files changed

+28
-4
lines changed

python/datafusion/context.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,7 @@ def from_pylist(
636636

637637
def from_pydict(
638638
self, data: dict[str, list[Any]], name: str | None = None
639-
) -> DataFramee
639+
) -> DataFrame:
640640
"""Create a :py:class:`~datafusion.dataframe.DataFrame` from a dictionary.
641641
642642
Args:

python/tests/test_view.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@ def test_register_filtered_dataframe():
2222

2323
# Filter the DataFrame (for example, keep rows where a > 2)
2424
df_filtered = df.filter(col("a") > literal(2))
25-
df_filtered = df_filtered.into_view()
25+
view = df_filtered.into_view()
2626

2727
# Register the filtered DataFrame as a table called "view1"
28-
ctx.register_table("view1", df_filtered)
28+
ctx.register_table("view1", view)
2929

3030
# Now run a SQL query against the registered table "view1"
3131
df_view = ctx.sql("SELECT * FROM view1")

src/dataframe.rs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ use datafusion::arrow::util::pretty;
3030
use datafusion::common::UnnestOptions;
3131
use datafusion::config::{CsvOptions, TableParquetOptions};
3232
use datafusion::dataframe::{DataFrame, DataFrameWriteOptions};
33+
use datafusion::datasource::TableProvider;
3334
use datafusion::execution::SendableRecordBatchStream;
3435
use datafusion::parquet::basic::{BrotliLevel, Compression, GzipLevel, ZstdLevel};
3536
use datafusion::prelude::*;
36-
use datafusion::sql::sqlparser::ast::Table;
3737
use pyo3::exceptions::PyValueError;
3838
use pyo3::prelude::*;
3939
use pyo3::pybacked::PyBackedStr;
@@ -51,6 +51,21 @@ use crate::{
5151
expr::{sort_expr::PySortExpr, PyExpr},
5252
};
5353

54+
/// Wrapper around a shared [`TableProvider`] trait object.
///
/// The `pyclass` attribute exposes it to Python under the name `TableProvider`
/// in the `datafusion` module.
#[pyclass(name = "TableProvider", module = "datafusion")]
55+
pub struct PyTableProvider {
56+
/// The wrapped provider, held behind an `Arc` as a trait object.
provider: Arc<dyn TableProvider>,
57+
}
58+
59+
// Inherent methods for `PyTableProvider`. This `impl` block carries no
// `#[pymethods]` attribute, so these are plain Rust methods.
impl PyTableProvider {
60+
/// Wraps an existing provider handle in a `PyTableProvider`.
pub fn new(provider: Arc<dyn TableProvider>) -> Self {
61+
Self { provider }
62+
}
63+
64+
/// Returns another `Arc` handle to the wrapped provider.
pub fn get_provider(&self) -> Arc<dyn TableProvider> {
65+
// Cloning the `Arc` yields a new handle to the same provider object.
self.provider.clone()
66+
}
67+
}
68+
5469
/// A PyDataFrame is a representation of a logical plan and an API to compose statements.
5570
/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
5671
/// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment.
@@ -90,6 +105,15 @@ impl PyDataFrame {
90105
}
91106
}
92107

108+
/// Convert this DataFrame into a view (i.e. a TableProvider) that can be registered.
///
/// Returns the provider wrapped in a [`PyTableProvider`]. The only return in
/// this body is `Ok(..)`; the `PyDataFusionResult` return type is kept as the
/// method's declared signature.
109+
fn into_view(&self) -> PyDataFusionResult<PyTableProvider> {
110+
// Call the underlying Rust DataFrame::into_view method.
111+
// Note that the Rust method consumes its receiver, so it is invoked on a clone
112+
// obtained via `self.df.as_ref().clone()`; `self.df` itself is left intact.
113+
let table_provider = self.df.as_ref().clone().into_view();
114+
Ok(PyTableProvider::new(table_provider))
115+
}
116+
93117
fn __repr__(&self, py: Python) -> PyDataFusionResult<String> {
94118
let df = self.df.as_ref().clone().limit(0, Some(10))?;
95119
let batches = wait_for_future(py, df.collect())?;

0 commit comments

Comments
 (0)