2222import pyarrow .dataset as ds
2323import pytest
2424from datafusion import (
25+ CsvReadOptions ,
2526 DataFrame ,
2627 RuntimeEnvBuilder ,
2728 SessionConfig ,
@@ -626,6 +627,8 @@ def test_read_csv_list(ctx):
626627def test_read_csv_compressed (ctx , tmp_path ):
627628 test_data_path = pathlib .Path ("testing/data/csv/aggregate_test_100.csv" )
628629
630+ expected = ctx .read_csv (test_data_path ).collect ()
631+
629632 # File compression type
630633 gzip_path = tmp_path / "aggregate_test_100.csv.gz"
631634
@@ -636,7 +639,13 @@ def test_read_csv_compressed(ctx, tmp_path):
636639 gzipped_file .writelines (csv_file )
637640
638641 csv_df = ctx .read_csv (gzip_path , file_extension = ".gz" , file_compression_type = "gz" )
639- csv_df .select (column ("c1" )).show ()
642+ assert csv_df .collect () == expected
643+
644+ csv_df = ctx .read_csv (
645+ gzip_path ,
646+ options = CsvReadOptions (file_extension = ".gz" , file_compression_type = "gz" ),
647+ )
648+ assert csv_df .collect () == expected
640649
641650
642651def test_read_parquet (ctx ):
@@ -735,43 +744,129 @@ def test_csv_read_options_builder_pattern():
735744 assert options .file_extension == ".tsv"
736745
737746
def read_csv_with_options_inner(
    tmp_path: pathlib.Path,
    csv_content: str,
    options: CsvReadOptions,
    expected: pa.RecordBatch,
    as_read: bool,
    global_ctx: bool,
) -> None:
    """Write ``csv_content`` to a partitioned directory, read it back, and
    assert the collected result equals ``expected``.

    The file is placed at ``tmp_path/group=a/test.csv`` so options using
    ``table_partition_cols`` can resolve the ``group`` hive partition.

    Args:
        tmp_path: Temporary directory the CSV file is created under.
        csv_content: Raw text written to the CSV file.
        options: Read options forwarded to the CSV reader.
        expected: Record batch the single collected batch must equal.
        as_read: If True, read via ``read_csv``; otherwise register the
            directory as a table and query it with SQL.
        global_ctx: Only meaningful when ``as_read`` is True — use the
            module-level ``datafusion.io.read_csv`` instead of the
            session-context method.
    """
    from datafusion import SessionContext

    # Hive-style partition directory: the directory name supplies the
    # value of the "group" partition column when requested by options.
    group_dir = tmp_path / "group=a"
    group_dir.mkdir(exist_ok=True)

    csv_path = group_dir / "test.csv"
    csv_path.write_text(csv_content)

    ctx = SessionContext()

    if as_read:
        if global_ctx:
            from datafusion.io import read_csv

            df = read_csv(str(tmp_path), options=options)
        else:
            df = ctx.read_csv(str(tmp_path), options=options)
    else:
        ctx.register_csv("test_table", str(tmp_path), options=options)
        df = ctx.sql("SELECT * FROM test_table")

    # Verify the data. (A leftover debug `df.show()` was removed here: it
    # printed noise into the test log and executed the plan a second time
    # on every parametrized run.)
    result = df.collect()
    assert len(result) == 1
    assert result[0] == expected
783+
@pytest.mark.parametrize(
    ("as_read", "global_ctx"),
    [
        (True, True),
        (True, False),
        (False, False),
    ],
)
def test_read_csv_with_options(tmp_path, as_read, global_ctx):
    """Test reading CSV with CsvReadOptions."""
    # Case 1: infer the schema while exercising every parsing-related
    # option at once. Some options are hard to combine (an explicit schema
    # vs. schema_infer_max_records), so the cases below re-run the read
    # with different settings. file_sort_order does not affect what is
    # read; it is included only to confirm the option parses correctly.
    csv_content = "Alice;30;|New York; NY|\nBob;25\n#Charlie;35;Paris\nPhil;75;Detroit' MI\nKarin;50;|Stockholm\nSweden|"  # noqa: E501

    sort_order = [[column("column_1").sort(), column("column_2")], ["column_3"]]
    options = CsvReadOptions(
        has_header=False,
        delimiter=";",
        quote="|",
        terminator="\n",
        escape="\\",
        comment="#",
        newlines_in_values=True,
        schema_infer_max_records=1,
        null_regex="[pP]+aris",
        truncated_rows=True,
        file_sort_order=sort_order,
    )

    names = pa.array(["Alice", "Bob", "Phil", "Karin"])
    ages = pa.array([30, 25, 75, 50])
    locations = pa.array(["New York; NY", None, "Detroit' MI", "Stockholm\nSweden"])
    expected = pa.RecordBatch.from_arrays(
        [names, ages, locations],
        names=["column_1", "column_2", "column_3"],
    )

    read_csv_with_options_inner(
        tmp_path, csv_content, options, expected, as_read, global_ctx
    )

    # Case 2: same file, now with an explicit schema (ages read as float32).
    schema = pa.schema(
        [
            pa.field("name", pa.string(), nullable=False),
            pa.field("age", pa.float32(), nullable=False),
            pa.field("location", pa.string(), nullable=True),
        ]
    )
    # NOTE(review): assumes with_schema mutates `options` in place — confirm
    # the builder semantics; if it returns a fresh object this call is a no-op.
    options.with_schema(schema)

    expected = pa.RecordBatch.from_arrays(
        [
            names,
            pa.array([30.0, 25.0, 75.0, 50.0]),
            locations,
        ],
        schema=schema,
    )

    read_csv_with_options_inner(
        tmp_path, csv_content, options, expected, as_read, global_ctx
    )

    # Case 3: hive-partitioned read — the "group" column comes from the
    # directory name (group=a), not from the file contents.
    csv_content = "name,age\nAlice,30\nBob,25\nCharlie,35\nDiego,40\nEmily,15"
    options = CsvReadOptions(
        table_partition_cols=[("group", pa.string())],
    )

    expected = pa.RecordBatch.from_arrays(
        [
            pa.array(["Alice", "Bob", "Charlie", "Diego", "Emily"]),
            pa.array([30, 25, 35, 40, 15]),
            pa.array(["a", "a", "a", "a", "a"]),
        ],
        schema=pa.schema(
            [
                pa.field("name", pa.string(), nullable=True),
                pa.field("age", pa.int64(), nullable=True),
                pa.field("group", pa.string(), nullable=False),
            ]
        ),
    )

    read_csv_with_options_inner(
        tmp_path, csv_content, options, expected, as_read, global_ctx
    )