1616# under the License.
1717# pylint:disable=redefined-outer-name
1818
19+ import multiprocessing
1920import os
2021import re
2122import threading
@@ -241,24 +242,27 @@ def test_add_files_parallelized(spark: SparkSession, session_catalog: Catalog, f
241242
242243 lock = threading .Lock ()
243244 unique_threads_seen = set ()
245+ cpu_count = multiprocessing .cpu_count ()
244246
245247 # patch the function parquet_file_to_data_file so we can track how many unique thread IDs
246248 # it was executed from
247249 with mock .patch ("pyiceberg.io.pyarrow.parquet_file_to_data_file" ) as patch_func :
248250
249- def mock_parquet_file_to_data_file (io : FileIO , table_metadata : TableMetadata , file_path : str , schema : Schema ) -> DataFile :
251+ def mock_parquet_file_to_data_file (io : FileIO , table_metadata : TableMetadata , file_path : str ) -> DataFile :
250252 lock .acquire ()
251253 thread_id = threading .get_ident () # the current thread ID
252254 unique_threads_seen .add (thread_id )
253255 lock .release ()
254- return real_parquet_file_to_data_file (io = io , table_metadata = table_metadata , file_path = file_path , schema = schema )
256+ return real_parquet_file_to_data_file (io = io , table_metadata = table_metadata , file_path = file_path )
255257
256258 patch_func .side_effect = mock_parquet_file_to_data_file
257259
258260 identifier = f"default.unpartitioned_table_schema_updates_v{ format_version } "
259261 tbl = _create_table (session_catalog , identifier , format_version )
260262
261- file_paths = [f"s3://warehouse/default/add_files_parallel/v{ format_version } /test-{ i } .parquet" for i in range (10 )]
263+ file_paths = [
264+ f"s3://warehouse/default/add_files_parallel/v{ format_version } /test-{ i } .parquet" for i in range (cpu_count * 2 )
265+ ]
262266 # write parquet files
263267 for file_path in file_paths :
264268 fo = tbl .io .new_output (file_path )
@@ -268,7 +272,14 @@ def mock_parquet_file_to_data_file(io: FileIO, table_metadata: TableMetadata, fi
268272
269273 tbl .add_files (file_paths = file_paths )
270274
271- assert len (unique_threads_seen ) == 10
275+ # during creation of the thread pool executor, when max_workers is not
276+ # specified, Python defaults to min(32, cpu_count + 4) threads in the
277+ # pool in this case
278+ # https://github.com/python/cpython/blob/e06bebb87e1b33f7251196e1ddb566f528c3fc98/Lib/concurrent/futures/thread.py#L173-L181
279+ # we check that at least cpu_count unique threads were seen. we don't
280+ # specify max_workers for the thread pool, and we can't check the exact
281+ # count without accessing private attributes of ThreadPoolExecutor
282+ assert len (unique_threads_seen ) >= cpu_count
272283
273284
274285@pytest .mark .integration
0 commit comments