@@ -46,17 +46,21 @@ use crate::utils::{get_tokio_runtime, wait_for_future};
4646use datafusion:: arrow:: datatypes:: { DataType , Schema , SchemaRef } ;
4747use datafusion:: arrow:: pyarrow:: PyArrowType ;
4848use datafusion:: arrow:: record_batch:: RecordBatch ;
49- use datafusion:: common:: ScalarValue ;
49+ use datafusion:: catalog_common:: TableReference ;
50+ use datafusion:: common:: { exec_err, ScalarValue } ;
5051use datafusion:: datasource:: file_format:: file_compression_type:: FileCompressionType ;
5152use datafusion:: datasource:: file_format:: parquet:: ParquetFormat ;
5253use datafusion:: datasource:: listing:: {
5354 ListingOptions , ListingTable , ListingTableConfig , ListingTableUrl ,
5455} ;
5556use datafusion:: datasource:: MemTable ;
5657use datafusion:: datasource:: TableProvider ;
57- use datafusion:: execution:: context:: { SQLOptions , SessionConfig , SessionContext , TaskContext } ;
58+ use datafusion:: execution:: context:: {
59+ DataFilePaths , SQLOptions , SessionConfig , SessionContext , TaskContext ,
60+ } ;
5861use datafusion:: execution:: disk_manager:: DiskManagerConfig ;
5962use datafusion:: execution:: memory_pool:: { FairSpillPool , GreedyMemoryPool , UnboundedMemoryPool } ;
63+ use datafusion:: execution:: options:: ReadOptions ;
6064use datafusion:: execution:: runtime_env:: { RuntimeConfig , RuntimeEnv } ;
6165use datafusion:: physical_plan:: SendableRecordBatchStream ;
6266use datafusion:: prelude:: {
@@ -621,7 +625,7 @@ impl PySessionContext {
621625 pub fn register_csv (
622626 & mut self ,
623627 name : & str ,
624- path : PathBuf ,
628+ path : & Bound < ' _ , PyAny > ,
625629 schema : Option < PyArrowType < Schema > > ,
626630 has_header : bool ,
627631 delimiter : & str ,
@@ -630,9 +634,6 @@ impl PySessionContext {
630634 file_compression_type : Option < String > ,
631635 py : Python ,
632636 ) -> PyResult < ( ) > {
633- let path = path
634- . to_str ( )
635- . ok_or_else ( || PyValueError :: new_err ( "Unable to convert path to a string" ) ) ?;
636637 let delimiter = delimiter. as_bytes ( ) ;
637638 if delimiter. len ( ) != 1 {
638639 return Err ( PyValueError :: new_err (
@@ -648,8 +649,15 @@ impl PySessionContext {
648649 . file_compression_type ( parse_file_compression_type ( file_compression_type) ?) ;
649650 options. schema = schema. as_ref ( ) . map ( |x| & x. 0 ) ;
650651
651- let result = self . ctx . register_csv ( name, path, options) ;
652- wait_for_future ( py, result) . map_err ( DataFusionError :: from) ?;
652+ if path. is_instance_of :: < PyList > ( ) {
653+ let paths = path. extract :: < Vec < String > > ( ) ?;
654+ let result = self . register_csv_from_multiple_paths ( name, paths, options) ;
655+ wait_for_future ( py, result) . map_err ( DataFusionError :: from) ?;
656+ } else {
657+ let path = path. extract :: < String > ( ) ?;
658+ let result = self . ctx . register_csv ( name, & path, options) ;
659+ wait_for_future ( py, result) . map_err ( DataFusionError :: from) ?;
660+ }
653661
654662 Ok ( ( ) )
655663 }
@@ -981,6 +989,46 @@ impl PySessionContext {
981989 async fn _table ( & self , name : & str ) -> datafusion:: common:: Result < DataFrame > {
982990 self . ctx . table ( name) . await
983991 }
992+
993+ async fn register_csv_from_multiple_paths (
994+ & self ,
995+ name : & str ,
996+ table_paths : Vec < String > ,
997+ options : CsvReadOptions < ' _ > ,
998+ ) -> datafusion:: common:: Result < ( ) > {
999+ let table_paths = table_paths. to_urls ( ) ?;
1000+ let session_config = self . ctx . copied_config ( ) ;
1001+ let listing_options =
1002+ options. to_listing_options ( & session_config, self . ctx . copied_table_options ( ) ) ;
1003+
1004+ let option_extension = listing_options. file_extension . clone ( ) ;
1005+
1006+ if table_paths. is_empty ( ) {
1007+ return exec_err ! ( "No table paths were provided" ) ;
1008+ }
1009+
1010+ // check if the file extension matches the expected extension
1011+ for path in & table_paths {
1012+ let file_path = path. as_str ( ) ;
1013+ if !file_path. ends_with ( option_extension. clone ( ) . as_str ( ) ) && !path. is_collection ( ) {
1014+ return exec_err ! (
1015+ "File path '{file_path}' does not match the expected extension '{option_extension}'"
1016+ ) ;
1017+ }
1018+ }
1019+
1020+ let resolved_schema = options
1021+ . get_resolved_schema ( & session_config, self . ctx . state ( ) , table_paths[ 0 ] . clone ( ) )
1022+ . await ?;
1023+
1024+ let config = ListingTableConfig :: new_with_multi_paths ( table_paths)
1025+ . with_listing_options ( listing_options)
1026+ . with_schema ( resolved_schema) ;
1027+ let table = ListingTable :: try_new ( config) ?;
1028+ self . ctx
1029+ . register_table ( TableReference :: Bare { table : name. into ( ) } , Arc :: new ( table) ) ?;
1030+ Ok ( ( ) )
1031+ }
9841032}
9851033
9861034pub fn convert_table_partition_cols (
0 commit comments