Skip to content

Commit 6b76e8f

Browse files
committed
feat: add tests for registering DataFrame as a view and handling TableProvider
1 parent 0ff5e90 commit 6b76e8f

File tree

2 files changed

+69
-0
lines changed

2 files changed

+69
-0
lines changed

python/tests/test_context.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
SessionConfig,
2828
SessionContext,
2929
SQLOptions,
30+
TableProvider,
3031
column,
3132
literal,
3233
)
@@ -330,6 +331,35 @@ def test_deregister_table(ctx, database):
330331
assert public.names() == {"csv1", "csv2"}
331332

332333

334+
def test_register_table_from_dataframe_into_view(ctx):
    """A view obtained from ``DataFrame.into_view`` can be registered and queried with SQL."""
    source_df = ctx.from_pydict({"a": [1, 2]})
    view = source_df.into_view()
    ctx.register_table("view_tbl", view)

    batches = ctx.sql("SELECT * FROM view_tbl").collect()
    rows = [batch.to_pydict() for batch in batches]

    assert rows == [{"a": [1, 2]}]
340+
341+
342+
def test_table_provider_from_capsule(ctx):
    """A provider rebuilt via ``TableProvider.from_capsule`` can be registered and queried."""
    source_df = ctx.from_pydict({"a": [1, 2]})
    view = source_df.into_view()

    # Round-trip the provider through its capsule representation.
    exported = view.__datafusion_table_provider__()
    rebuilt = TableProvider.from_capsule(exported)
    ctx.register_table("capsule_tbl", rebuilt)

    batches = ctx.sql("SELECT * FROM capsule_tbl").collect()
    rows = [batch.to_pydict() for batch in batches]

    assert rows == [{"a": [1, 2]}]
350+
351+
352+
def test_table_provider_from_capsule_invalid():
    """``TableProvider.from_capsule`` is expected to raise when handed a plain ``object()``."""
    not_a_capsule = object()
    with pytest.raises(Exception):  # noqa: B017
        TableProvider.from_capsule(not_a_capsule)
355+
356+
357+
def test_register_table_with_dataframe_errors(ctx):
    """Passing a DataFrame itself (not its view) to ``register_table`` is expected to raise."""
    raw_df = ctx.from_pydict({"a": [1]})
    with pytest.raises(Exception):  # noqa: B017
        ctx.register_table("bad", raw_df)
361+
362+
333363
def test_register_dataset(ctx):
334364
# create a RecordBatch and register it as a pyarrow.dataset.Dataset
335365
batch = pa.RecordBatch.from_arrays(

tests/dataframe_into_view.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
use std::sync::Arc;
2+
3+
use datafusion::arrow::array::Int32Array;
4+
use datafusion::arrow::datatypes::{DataType, Field, Schema};
5+
use datafusion::arrow::record_batch::RecordBatch;
6+
use datafusion::datasource::MemTable;
7+
use datafusion::prelude::SessionContext;
8+
use datafusion_python::dataframe::PyDataFrame;
9+
10+
/// Checks that `PyDataFrame::into_view` produces a table that can be
/// registered in a fresh `SessionContext` and queried through SQL.
#[test]
fn dataframe_into_view_returns_table_provider() {
    // Create an in-memory table with one non-nullable Int32 column holding three rows.
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
    let batch = RecordBatch::try_new(
        Arc::clone(&schema),
        vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
    )
    .unwrap();
    let table = MemTable::try_new(schema, vec![vec![batch]]).unwrap();

    // Build a DataFrame from the table and convert it into a view.
    let ctx = SessionContext::new();
    let df = ctx.read_table(Arc::new(table)).unwrap();
    let py_df = PyDataFrame::new(df);
    let provider = py_df.into_view().unwrap();

    // Register the view in a new context and ensure it can be queried.
    let ctx = SessionContext::new();
    ctx.register_table("view", provider.as_table().table())
        .unwrap();

    let rt = tokio::runtime::Runtime::new().unwrap();
    // `SessionContext::sql` is async, so its future must be awaited before the
    // resulting `Result<DataFrame>` can be unwrapped; both awaits run inside a
    // single async block driven by the runtime.
    let batches = rt
        .block_on(async {
            ctx.sql("SELECT * FROM view")
                .await
                .unwrap()
                .collect()
                .await
        })
        .unwrap();

    assert_eq!(batches.len(), 1);
    assert_eq!(batches[0].num_rows(), 3);
}

0 commit comments

Comments
 (0)