Skip to content

Commit 617f742

Browse files
committed
Add additional tests for csv read options
1 parent af8a205 commit 617f742

File tree

3 files changed

+134
-25
lines changed

3 files changed

+134
-25
lines changed

docs/source/user-guide/io/csv.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,6 @@ If you require additional control over how to read the CSV file, you can use
5555
.with_file_extension(".gz") # File extension other than .csv
5656
)
5757
df = ctx.read_csv("data.csv.gz", options=options)
58+
59+
Details for all CSV reading options can be found on the
60+
`DataFusion documentation site <https://datafusion.apache.org/library-user-guide/custom-table-providers.html>`_.

python/datafusion/options.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
import pyarrow as pa
2626

27+
from datafusion.expr import sort_list_to_raw_sort_list
28+
2729
if TYPE_CHECKING:
2830
from datafusion.expr import SortExpr
2931

@@ -208,6 +210,15 @@ def to_inner(self) -> options.CsvReadOptions:
208210
209211
This is intended for internal use only.
210212
"""
213+
file_sort_order = (
214+
[]
215+
if self.file_sort_order is None
216+
else [
217+
sort_list_to_raw_sort_list(sort_list)
218+
for sort_list in self.file_sort_order
219+
]
220+
)
221+
211222
return options.CsvReadOptions(
212223
has_header=self.has_header,
213224
delimiter=ord(self.delimiter[0]) if self.delimiter else ord(","),
@@ -223,7 +234,7 @@ def to_inner(self) -> options.CsvReadOptions:
223234
self.table_partition_cols
224235
),
225236
file_compression_type=self.file_compression_type or "",
226-
file_sort_order=self.file_sort_order or [],
237+
file_sort_order=file_sort_order,
227238
null_regex=self.null_regex,
228239
truncated_rows=self.truncated_rows,
229240
)

python/tests/test_context.py

Lines changed: 119 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import pyarrow.dataset as ds
2323
import pytest
2424
from datafusion import (
25+
CsvReadOptions,
2526
DataFrame,
2627
RuntimeEnvBuilder,
2728
SessionConfig,
@@ -626,6 +627,8 @@ def test_read_csv_list(ctx):
626627
def test_read_csv_compressed(ctx, tmp_path):
627628
test_data_path = pathlib.Path("testing/data/csv/aggregate_test_100.csv")
628629

630+
expected = ctx.read_csv(test_data_path).collect()
631+
629632
# File compression type
630633
gzip_path = tmp_path / "aggregate_test_100.csv.gz"
631634

@@ -636,7 +639,13 @@ def test_read_csv_compressed(ctx, tmp_path):
636639
gzipped_file.writelines(csv_file)
637640

638641
csv_df = ctx.read_csv(gzip_path, file_extension=".gz", file_compression_type="gz")
639-
csv_df.select(column("c1")).show()
642+
assert csv_df.collect() == expected
643+
644+
csv_df = ctx.read_csv(
645+
gzip_path,
646+
options=CsvReadOptions(file_extension=".gz", file_compression_type="gz"),
647+
)
648+
assert csv_df.collect() == expected
640649

641650

642651
def test_read_parquet(ctx):
@@ -735,43 +744,129 @@ def test_csv_read_options_builder_pattern():
735744
assert options.file_extension == ".tsv"
736745

737746

738-
@pytest.mark.parametrize(
739-
("as_read", "global_ctx"),
740-
[
741-
(True, True),
742-
(True, False),
743-
(False, False),
744-
],
745-
)
746-
def test_read_csv_with_options(tmp_path, as_read, global_ctx):
747-
"""Test reading CSV with CsvReadOptions."""
748-
from datafusion import CsvReadOptions, SessionContext
747+
def read_csv_with_options_inner(
    tmp_path: pathlib.Path,
    csv_content: str,
    options: CsvReadOptions,
    expected: pa.RecordBatch,
    as_read: bool,
    global_ctx: bool,
) -> None:
    """Write ``csv_content`` under a partition directory and read it back.

    The file is written to ``<tmp_path>/group=a/test.csv`` so that tests
    using ``table_partition_cols`` can resolve the ``group`` partition
    column from the directory name.

    Args:
        tmp_path: Base directory the CSV file is written under.
        csv_content: Raw text written to the CSV file.
        options: Read options forwarded to the CSV reader.
        expected: Record batch the collected result must equal.
        as_read: If True, read via ``read_csv``; otherwise register the
            path as a table and query it with SQL.
        global_ctx: If True, use the module-level ``datafusion.io.read_csv``
            helper instead of a method on the session context.
    """
    from datafusion import SessionContext

    # Create a test CSV file inside a hive-style partition directory.
    group_dir = tmp_path / "group=a"
    group_dir.mkdir(exist_ok=True)

    csv_path = group_dir / "test.csv"
    csv_path.write_text(csv_content)

    ctx = SessionContext()

    if as_read:
        if global_ctx:
            from datafusion.io import read_csv

            df = read_csv(str(tmp_path), options=options)
        else:
            df = ctx.read_csv(str(tmp_path), options=options)
    else:
        ctx.register_csv("test_table", str(tmp_path), options=options)
        df = ctx.sql("SELECT * FROM test_table")
    df.show()

    # Verify the data: exactly one batch, equal to the expected record batch.
    result = df.collect()
    assert len(result) == 1
    assert result[0] == expected
782+
783+
784+
@pytest.mark.parametrize(
    ("as_read", "global_ctx"),
    [
        (True, True),
        (True, False),
        (False, False),
    ],
)
def test_read_csv_with_options(tmp_path, as_read, global_ctx):
    """Test reading CSV with CsvReadOptions.

    Some of the read options are difficult to test in combination
    (such as ``schema`` and ``schema_infer_max_records``), so this runs
    multiple passes over the data with different option sets.
    """
    csv_content = "Alice;30;|New York; NY|\nBob;25\n#Charlie;35;Paris\nPhil;75;Detroit' MI\nKarin;50;|Stockholm\nSweden|"  # noqa: E501

    # file_sort_order doesn't impact reading, but is included here to
    # ensure all options parse correctly.
    options = CsvReadOptions(
        has_header=False,
        delimiter=";",
        quote="|",
        terminator="\n",
        escape="\\",
        comment="#",
        newlines_in_values=True,
        schema_infer_max_records=1,
        null_regex="[pP]+aris",
        truncated_rows=True,
        file_sort_order=[[column("column_1").sort(), column("column_2")], ["column_3"]],
    )

    expected = pa.RecordBatch.from_arrays(
        [
            pa.array(["Alice", "Bob", "Phil", "Karin"]),
            pa.array([30, 25, 75, 50]),
            pa.array(["New York; NY", None, "Detroit' MI", "Stockholm\nSweden"]),
        ],
        names=["column_1", "column_2", "column_3"],
    )

    read_csv_with_options_inner(
        tmp_path, csv_content, options, expected, as_read, global_ctx
    )

    # Second pass: supply an explicit schema instead of inferring one.
    schema = pa.schema(
        [
            pa.field("name", pa.string(), nullable=False),
            pa.field("age", pa.float32(), nullable=False),
            pa.field("location", pa.string(), nullable=True),
        ]
    )
    # Assign the return value so this works whether with_schema mutates in
    # place or returns a new builder (the original discarded the return).
    options = options.with_schema(schema)

    expected = pa.RecordBatch.from_arrays(
        [
            pa.array(["Alice", "Bob", "Phil", "Karin"]),
            pa.array([30.0, 25.0, 75.0, 50.0]),
            pa.array(["New York; NY", None, "Detroit' MI", "Stockholm\nSweden"]),
        ],
        schema=schema,
    )

    read_csv_with_options_inner(
        tmp_path, csv_content, options, expected, as_read, global_ctx
    )

    # Third pass: hive-style partition column read from the directory name.
    csv_content = "name,age\nAlice,30\nBob,25\nCharlie,35\nDiego,40\nEmily,15"

    expected = pa.RecordBatch.from_arrays(
        [
            pa.array(["Alice", "Bob", "Charlie", "Diego", "Emily"]),
            pa.array([30, 25, 35, 40, 15]),
            pa.array(["a", "a", "a", "a", "a"]),
        ],
        schema=pa.schema(
            [
                pa.field("name", pa.string(), nullable=True),
                pa.field("age", pa.int64(), nullable=True),
                pa.field("group", pa.string(), nullable=False),
            ]
        ),
    )
    options = CsvReadOptions(
        table_partition_cols=[("group", pa.string())],
    )

    read_csv_with_options_inner(
        tmp_path, csv_content, options, expected, as_read, global_ctx
    )

0 commit comments

Comments
 (0)