Skip to content

Commit f9ca741

Browse files
djouallah authored and claude committed
feat: Add truncated_rows parameter to register_csv and read_csv
Exposes the truncated_rows parameter from Rust DataFusion to Python bindings. This enables reading CSV files with inconsistent column counts by creating a union schema and filling missing columns with nulls. The parameter was added to DataFusion Rust in PR apache/datafusion#17553 and is now available in datafusion 51.0.0.

Changes:
- Add truncated_rows parameter to SessionContext.register_csv()
- Add truncated_rows parameter to SessionContext.read_csv()
- Add comprehensive tests for both methods
- Update docstrings with parameter documentation

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
1 parent eaa3f79 commit f9ca741

File tree

4 files changed

+143
-4
lines changed

4 files changed

+143
-4
lines changed

python/datafusion/context.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -927,6 +927,7 @@ def register_csv(
927927
schema_infer_max_records: int = 1000,
928928
file_extension: str = ".csv",
929929
file_compression_type: str | None = None,
930+
truncated_rows: bool = False,
930931
) -> None:
931932
"""Register a CSV file as a table.
932933
@@ -946,6 +947,10 @@ def register_csv(
946947
file_extension: File extension; only files with this extension are
947948
selected for data input.
948949
file_compression_type: File compression type.
950+
truncated_rows: Allow reading CSV files with inconsistent column
951+
counts by creating a union schema. Missing columns are filled
952+
with nulls. Default is False. Useful for evolving datasets
953+
where newer files have additional columns.
949954
"""
950955
path = [str(p) for p in path] if isinstance(path, list) else str(path)
951956

@@ -958,6 +963,7 @@ def register_csv(
958963
schema_infer_max_records,
959964
file_extension,
960965
file_compression_type,
966+
truncated_rows,
961967
)
962968

963969
def register_json(
@@ -1123,6 +1129,7 @@ def read_csv(
11231129
file_extension: str = ".csv",
11241130
table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None,
11251131
file_compression_type: str | None = None,
1132+
truncated_rows: bool = False,
11261133
) -> DataFrame:
11271134
"""Read a CSV data source.
11281135
@@ -1140,6 +1147,10 @@ def read_csv(
11401147
selected for data input.
11411148
table_partition_cols: Partition columns.
11421149
file_compression_type: File compression type.
1150+
truncated_rows: Allow reading CSV files with inconsistent column
1151+
counts by creating a union schema. Missing columns are filled
1152+
with nulls. Default is False. Useful for evolving datasets
1153+
where newer files have additional columns.
11431154
11441155
Returns:
11451156
DataFrame representation of the read CSV files
@@ -1160,6 +1171,7 @@ def read_csv(
11601171
file_extension,
11611172
table_partition_cols,
11621173
file_compression_type,
1174+
truncated_rows,
11631175
)
11641176
)
11651177

python/tests/test_context.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import pyarrow as pa
2222
import pyarrow.dataset as ds
2323
import pytest
24+
from pyarrow.csv import write_csv
2425
from datafusion import (
2526
DataFrame,
2627
RuntimeEnvBuilder,
@@ -639,6 +640,65 @@ def test_read_csv_compressed(ctx, tmp_path):
639640
csv_df.select(column("c1")).show()
640641

641642

643+
def test_read_csv_truncated_rows(ctx, tmp_path):
644+
# Create CSV file with 3 columns
645+
path1 = tmp_path / "file1.csv"
646+
table1 = pa.Table.from_arrays(
647+
[
648+
[1, 2],
649+
["a", "b"],
650+
[1.1, 2.2],
651+
],
652+
names=["int", "str", "float"],
653+
)
654+
write_csv(table1, path1)
655+
656+
# Create CSV file with 5 columns
657+
path2 = tmp_path / "file2.csv"
658+
table2 = pa.Table.from_arrays(
659+
[
660+
[3, 4],
661+
["c", "d"],
662+
[3.3, 4.4],
663+
["x", "y"],
664+
[10, 20],
665+
],
666+
names=["int", "str", "float", "extra1", "extra2"],
667+
)
668+
write_csv(table2, path2)
669+
670+
# Read with truncated_rows=True to handle mismatched columns
671+
df = ctx.read_csv([path1, path2], truncated_rows=True)
672+
result = df.collect()
673+
result_table = pa.Table.from_batches(result)
674+
675+
# Should have 5 columns (union schema)
676+
assert len(result_table.schema) == 5
677+
assert result_table.schema.names == ["int", "str", "float", "extra1", "extra2"]
678+
679+
# Should have 4 rows total (2 from each file)
680+
assert result_table.num_rows == 4
681+
682+
# Convert to dict for easier validation
683+
result_dict = result_table.to_pydict()
684+
685+
# Check that rows from file1 have nulls for extra1 and extra2
686+
assert result_dict["int"] == [1, 2, 3, 4]
687+
assert result_dict["str"] == ["a", "b", "c", "d"]
688+
assert result_dict["float"] == [1.1, 2.2, 3.3, 4.4]
689+
690+
# First two rows should have None for extra1 and extra2
691+
assert result_dict["extra1"][0] is None
692+
assert result_dict["extra1"][1] is None
693+
assert result_dict["extra1"][2] == "x"
694+
assert result_dict["extra1"][3] == "y"
695+
696+
assert result_dict["extra2"][0] is None
697+
assert result_dict["extra2"][1] is None
698+
assert result_dict["extra2"][2] == 10
699+
assert result_dict["extra2"][3] == 20
700+
701+
642702
def test_read_parquet(ctx):
643703
parquet_df = ctx.read_parquet(path="parquet/data/alltypes_plain.parquet")
644704
parquet_df.show()

python/tests/test_sql.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,67 @@ def test_register_csv_list(ctx, tmp_path):
137137
assert int_sum == 2 * sum(int_values)
138138

139139

140+
def test_register_csv_truncated_rows(ctx, tmp_path):
141+
# Create CSV file with 3 columns
142+
path1 = tmp_path / "file1.csv"
143+
table1 = pa.Table.from_arrays(
144+
[
145+
[1, 2],
146+
["a", "b"],
147+
[1.1, 2.2],
148+
],
149+
names=["int", "str", "float"],
150+
)
151+
write_csv(table1, path1)
152+
153+
# Create CSV file with 5 columns
154+
path2 = tmp_path / "file2.csv"
155+
table2 = pa.Table.from_arrays(
156+
[
157+
[3, 4],
158+
["c", "d"],
159+
[3.3, 4.4],
160+
["x", "y"],
161+
[10, 20],
162+
],
163+
names=["int", "str", "float", "extra1", "extra2"],
164+
)
165+
write_csv(table2, path2)
166+
167+
# Register with truncated_rows=True to handle mismatched columns
168+
ctx.register_csv("mixed", [path1, path2], truncated_rows=True)
169+
170+
# Verify the table exists and has correct schema
171+
result = ctx.sql("SELECT * FROM mixed").collect()
172+
result_table = pa.Table.from_batches(result)
173+
174+
# Should have 5 columns (union schema)
175+
assert len(result_table.schema) == 5
176+
assert result_table.schema.names == ["int", "str", "float", "extra1", "extra2"]
177+
178+
# Should have 4 rows total (2 from each file)
179+
assert result_table.num_rows == 4
180+
181+
# Convert to dict for easier validation
182+
result_dict = result_table.to_pydict()
183+
184+
# Check that rows from file1 have nulls for extra1 and extra2
185+
assert result_dict["int"] == [1, 2, 3, 4]
186+
assert result_dict["str"] == ["a", "b", "c", "d"]
187+
assert result_dict["float"] == [1.1, 2.2, 3.3, 4.4]
188+
189+
# First two rows should have None for extra1 and extra2
190+
assert result_dict["extra1"][0] is None
191+
assert result_dict["extra1"][1] is None
192+
assert result_dict["extra1"][2] == "x"
193+
assert result_dict["extra1"][3] == "y"
194+
195+
assert result_dict["extra2"][0] is None
196+
assert result_dict["extra2"][1] is None
197+
assert result_dict["extra2"][2] == 10
198+
assert result_dict["extra2"][3] == 20
199+
200+
140201
def test_register_http_csv(ctx):
141202
url = "https://raw.githubusercontent.com/ibis-project/testing-data/refs/heads/master/csv/diamonds.csv"
142203
ctx.register_object_store("", Http(url))

src/context.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -715,7 +715,8 @@ impl PySessionContext {
715715
delimiter=",",
716716
schema_infer_max_records=1000,
717717
file_extension=".csv",
718-
file_compression_type=None))]
718+
file_compression_type=None,
719+
truncated_rows=false))]
719720
pub fn register_csv(
720721
&self,
721722
name: &str,
@@ -726,6 +727,7 @@ impl PySessionContext {
726727
schema_infer_max_records: usize,
727728
file_extension: &str,
728729
file_compression_type: Option<String>,
730+
truncated_rows: bool,
729731
py: Python,
730732
) -> PyDataFusionResult<()> {
731733
let delimiter = delimiter.as_bytes();
@@ -740,7 +742,8 @@ impl PySessionContext {
740742
.delimiter(delimiter[0])
741743
.schema_infer_max_records(schema_infer_max_records)
742744
.file_extension(file_extension)
743-
.file_compression_type(parse_file_compression_type(file_compression_type)?);
745+
.file_compression_type(parse_file_compression_type(file_compression_type)?)
746+
.truncated_rows(truncated_rows);
744747
options.schema = schema.as_ref().map(|x| &x.0);
745748

746749
if path.is_instance_of::<PyList>() {
@@ -969,7 +972,8 @@ impl PySessionContext {
969972
schema_infer_max_records=1000,
970973
file_extension=".csv",
971974
table_partition_cols=vec![],
972-
file_compression_type=None))]
975+
file_compression_type=None,
976+
truncated_rows=false))]
973977
pub fn read_csv(
974978
&self,
975979
path: &Bound<'_, PyAny>,
@@ -980,6 +984,7 @@ impl PySessionContext {
980984
file_extension: &str,
981985
table_partition_cols: Vec<(String, PyArrowType<DataType>)>,
982986
file_compression_type: Option<String>,
987+
truncated_rows: bool,
983988
py: Python,
984989
) -> PyDataFusionResult<PyDataFrame> {
985990
let delimiter = delimiter.as_bytes();
@@ -1000,7 +1005,8 @@ impl PySessionContext {
10001005
.map(|(name, ty)| (name, ty.0))
10011006
.collect::<Vec<(String, DataType)>>(),
10021007
)
1003-
.file_compression_type(parse_file_compression_type(file_compression_type)?);
1008+
.file_compression_type(parse_file_compression_type(file_compression_type)?)
1009+
.truncated_rows(truncated_rows);
10041010
options.schema = schema.as_ref().map(|x| &x.0);
10051011

10061012
if path.is_instance_of::<PyList>() {

0 commit comments

Comments (0)