
Commit f822495

feat: make register_csv accept a list of paths (#883)
1 parent 2df33d3 commit f822495

File tree

3 files changed: +99 -11 lines changed

python/datafusion/context.py

Lines changed: 8 additions & 3 deletions
@@ -714,7 +714,7 @@ def register_parquet(
     def register_csv(
         self,
         name: str,
-        path: str | pathlib.Path,
+        path: str | pathlib.Path | list[str | pathlib.Path],
         schema: pyarrow.Schema | None = None,
         has_header: bool = True,
         delimiter: str = ",",
@@ -728,7 +728,7 @@ def register_csv(
 
         Args:
             name: Name of the table to register.
-            path: Path to the CSV file.
+            path: Path to the CSV file. It also accepts a list of Paths.
             schema: An optional schema representing the CSV file. If None, the
                 CSV reader will try to infer it based on data in file.
             has_header: Whether the CSV file have a header. If schema inference
@@ -741,9 +741,14 @@ def register_csv(
                 selected for data input.
             file_compression_type: File compression type.
         """
+        if isinstance(path, list):
+            path = [str(p) for p in path]
+        else:
+            path = str(path)
+
         self.ctx.register_csv(
             name,
-            str(path),
+            path,
             schema,
             has_header,
             delimiter,
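The upshot for the Python API: register_csv now takes either a single path or a list of paths, and a list is registered as one table backed by all of the files. A minimal usage sketch (the file names here are hypothetical, and the listed files are assumed to share a schema):

from datafusion import SessionContext

ctx = SessionContext()

# Single file: unchanged behavior.
ctx.register_csv("sales_q1", "data/q1.csv")

# New: several files under one table name. str and pathlib.Path entries may
# be mixed, since each one is normalized with str() before crossing into Rust.
ctx.register_csv("sales", ["data/q1.csv", "data/q2.csv"])

ctx.sql("SELECT count(*) FROM sales").show()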

python/datafusion/tests/test_sql.py

Lines changed: 35 additions & 0 deletions
@@ -104,6 +104,41 @@ def test_register_csv(ctx, tmp_path):
         ctx.register_csv("csv4", path, file_compression_type="rar")
 
 
+def test_register_csv_list(ctx, tmp_path):
+    path = tmp_path / "test.csv"
+
+    int_values = [1, 2, 3, 4]
+    table = pa.Table.from_arrays(
+        [
+            int_values,
+            ["a", "b", "c", "d"],
+            [1.1, 2.2, 3.3, 4.4],
+        ],
+        names=["int", "str", "float"],
+    )
+    write_csv(table, path)
+    ctx.register_csv("csv", path)
+
+    csv_df = ctx.table("csv")
+    expected_count = csv_df.count() * 2
+    ctx.register_csv(
+        "double_csv",
+        path=[
+            path,
+            path,
+        ],
+    )
+
+    double_csv_df = ctx.table("double_csv")
+    actual_count = double_csv_df.count()
+    assert actual_count == expected_count
+
+    int_sum = ctx.sql("select sum(int) from double_csv").to_pydict()[
+        "sum(double_csv.int)"
+    ][0]
+    assert int_sum == 2 * sum(int_values)
+
+
 def test_register_parquet(ctx, tmp_path):
     path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
     ctx.register_parquet("t", path)
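The test registers the same file twice under one table name, so the row count and the sum over int must come out exactly doubled; that confirms multiple paths are scanned as a plain union rather than deduplicated. (ctx and tmp_path are the test suite's existing pytest fixtures.)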

src/context.rs

Lines changed: 56 additions & 8 deletions
@@ -46,17 +46,21 @@ use crate::utils::{get_tokio_runtime, wait_for_future};
 use datafusion::arrow::datatypes::{DataType, Schema, SchemaRef};
 use datafusion::arrow::pyarrow::PyArrowType;
 use datafusion::arrow::record_batch::RecordBatch;
-use datafusion::common::ScalarValue;
+use datafusion::catalog_common::TableReference;
+use datafusion::common::{exec_err, ScalarValue};
 use datafusion::datasource::file_format::file_compression_type::FileCompressionType;
 use datafusion::datasource::file_format::parquet::ParquetFormat;
 use datafusion::datasource::listing::{
     ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
 };
 use datafusion::datasource::MemTable;
 use datafusion::datasource::TableProvider;
-use datafusion::execution::context::{SQLOptions, SessionConfig, SessionContext, TaskContext};
+use datafusion::execution::context::{
+    DataFilePaths, SQLOptions, SessionConfig, SessionContext, TaskContext,
+};
 use datafusion::execution::disk_manager::DiskManagerConfig;
 use datafusion::execution::memory_pool::{FairSpillPool, GreedyMemoryPool, UnboundedMemoryPool};
+use datafusion::execution::options::ReadOptions;
 use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
 use datafusion::physical_plan::SendableRecordBatchStream;
 use datafusion::prelude::{
@@ -621,7 +625,7 @@ impl PySessionContext {
     pub fn register_csv(
         &mut self,
         name: &str,
-        path: PathBuf,
+        path: &Bound<'_, PyAny>,
         schema: Option<PyArrowType<Schema>>,
         has_header: bool,
         delimiter: &str,
@@ -630,9 +634,6 @@ impl PySessionContext {
         file_compression_type: Option<String>,
         py: Python,
     ) -> PyResult<()> {
-        let path = path
-            .to_str()
-            .ok_or_else(|| PyValueError::new_err("Unable to convert path to a string"))?;
         let delimiter = delimiter.as_bytes();
         if delimiter.len() != 1 {
             return Err(PyValueError::new_err(
@@ -648,8 +649,15 @@ impl PySessionContext {
             .file_compression_type(parse_file_compression_type(file_compression_type)?);
         options.schema = schema.as_ref().map(|x| &x.0);
 
-        let result = self.ctx.register_csv(name, path, options);
-        wait_for_future(py, result).map_err(DataFusionError::from)?;
+        if path.is_instance_of::<PyList>() {
+            let paths = path.extract::<Vec<String>>()?;
+            let result = self.register_csv_from_multiple_paths(name, paths, options);
+            wait_for_future(py, result).map_err(DataFusionError::from)?;
+        } else {
+            let path = path.extract::<String>()?;
+            let result = self.ctx.register_csv(name, &path, options);
+            wait_for_future(py, result).map_err(DataFusionError::from)?;
+        }
 
         Ok(())
     }
@@ -981,6 +989,46 @@ impl PySessionContext {
     async fn _table(&self, name: &str) -> datafusion::common::Result<DataFrame> {
         self.ctx.table(name).await
     }
+
+    async fn register_csv_from_multiple_paths(
+        &self,
+        name: &str,
+        table_paths: Vec<String>,
+        options: CsvReadOptions<'_>,
+    ) -> datafusion::common::Result<()> {
+        let table_paths = table_paths.to_urls()?;
+        let session_config = self.ctx.copied_config();
+        let listing_options =
+            options.to_listing_options(&session_config, self.ctx.copied_table_options());
+
+        let option_extension = listing_options.file_extension.clone();
+
+        if table_paths.is_empty() {
+            return exec_err!("No table paths were provided");
+        }
+
+        // check if the file extension matches the expected extension
+        for path in &table_paths {
+            let file_path = path.as_str();
+            if !file_path.ends_with(option_extension.clone().as_str()) && !path.is_collection() {
+                return exec_err!(
+                    "File path '{file_path}' does not match the expected extension '{option_extension}'"
+                );
+            }
+        }
+
+        let resolved_schema = options
+            .get_resolved_schema(&session_config, self.ctx.state(), table_paths[0].clone())
+            .await?;
+
+        let config = ListingTableConfig::new_with_multi_paths(table_paths)
+            .with_listing_options(listing_options)
+            .with_schema(resolved_schema);
+        let table = ListingTable::try_new(config)?;
+        self.ctx
+            .register_table(TableReference::Bare { table: name.into() }, Arc::new(table))?;
+        Ok(())
+    }
 }
 
 pub fn convert_table_partition_cols(
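The new register_csv_from_multiple_paths helper mirrors what SessionContext::register_csv does internally, but builds its ListingTableConfig with new_with_multi_paths and resolves the schema from the first path, so all listed files must share that schema. A user-visible side effect is the extension check: a path that does not end with the configured file extension (".csv" by default) is rejected unless it points at a directory. A sketch of that failure mode from Python, with hypothetical file names:

# Hypothetical paths: expected to raise with a message like
# "File path 'data/part1.txt' does not match the expected extension '.csv'".
try:
    ctx.register_csv("bad", ["data/part1.txt", "data/part2.txt"])
except Exception as err:
    print(err)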

0 commit comments