|
16 | 16 | # under the License. |
17 | 17 | # pylint:disable=redefined-outer-name |
18 | 18 |
|
| 19 | +import multiprocessing |
19 | 20 | import os |
20 | 21 | import re |
| 22 | +import threading |
21 | 23 | from datetime import date |
22 | 24 | from typing import Iterator |
| 25 | +from unittest import mock |
23 | 26 |
|
24 | 27 | import pyarrow as pa |
25 | 28 | import pyarrow.parquet as pq |
|
31 | 34 | from pyiceberg.exceptions import NoSuchTableError |
32 | 35 | from pyiceberg.io import FileIO |
33 | 36 | from pyiceberg.io.pyarrow import UnsupportedPyArrowTypeException, _pyarrow_schema_ensure_large_types |
| 37 | +from pyiceberg.manifest import DataFile |
34 | 38 | from pyiceberg.partitioning import UNPARTITIONED_PARTITION_SPEC, PartitionField, PartitionSpec |
35 | 39 | from pyiceberg.schema import Schema |
36 | 40 | from pyiceberg.table import Table |
| 41 | +from pyiceberg.table.metadata import TableMetadata |
37 | 42 | from pyiceberg.transforms import BucketTransform, IdentityTransform, MonthTransform |
38 | 43 | from pyiceberg.types import ( |
39 | 44 | BooleanType, |
@@ -229,6 +234,54 @@ def test_add_files_to_unpartitioned_table_raises_has_field_ids( |
229 | 234 | tbl.add_files(file_paths=file_paths) |
230 | 235 |
|
231 | 236 |
|
| 237 | +@pytest.mark.integration |
| 238 | +def test_add_files_parallelized(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None: |
| 239 | + from pyiceberg.io.pyarrow import parquet_file_to_data_file |
| 240 | + |
| 241 | + real_parquet_file_to_data_file = parquet_file_to_data_file |
| 242 | + |
| 243 | + lock = threading.Lock() |
| 244 | + unique_threads_seen = set() |
| 245 | + cpu_count = multiprocessing.cpu_count() |
| 246 | + |
| 247 | +    # patch the function parquet_file_to_data_file so we can track how many unique thread IDs |
| 248 | + # it was executed from |
| 249 | + with mock.patch("pyiceberg.io.pyarrow.parquet_file_to_data_file") as patch_func: |
| 250 | + |
| 251 | + def mock_parquet_file_to_data_file(io: FileIO, table_metadata: TableMetadata, file_path: str) -> DataFile: |
| 252 | + lock.acquire() |
| 253 | + thread_id = threading.get_ident() # the current thread ID |
| 254 | + unique_threads_seen.add(thread_id) |
| 255 | + lock.release() |
| 256 | + return real_parquet_file_to_data_file(io=io, table_metadata=table_metadata, file_path=file_path) |
| 257 | + |
| 258 | + patch_func.side_effect = mock_parquet_file_to_data_file |
| 259 | + |
| 260 | +        identifier = f"default.unpartitioned_table_add_files_parallelized_v{format_version}" |
| 261 | + tbl = _create_table(session_catalog, identifier, format_version) |
| 262 | + |
| 263 | + file_paths = [ |
| 264 | + f"s3://warehouse/default/add_files_parallel/v{format_version}/test-{i}.parquet" for i in range(cpu_count * 2) |
| 265 | + ] |
| 266 | + # write parquet files |
| 267 | + for file_path in file_paths: |
| 268 | + fo = tbl.io.new_output(file_path) |
| 269 | + with fo.create(overwrite=True) as fos: |
| 270 | + with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer: |
| 271 | + writer.write_table(ARROW_TABLE) |
| 272 | + |
| 273 | + tbl.add_files(file_paths=file_paths) |
| 274 | + |
| 275 | +    # during creation of the ThreadPoolExecutor, when max_workers is not |
| 276 | +    # specified, Python defaults to min(32, cpu_count + 4) threads in |
| 277 | +    # the pool: |
| 278 | + # https://github.com/python/cpython/blob/e06bebb87e1b33f7251196e1ddb566f528c3fc98/Lib/concurrent/futures/thread.py#L173-L181 |
| 279 | +    # we check that at least cpu_count unique threads were seen; we don't |
| 280 | +    # specify max_workers for the thread pool, and we can't check the exact |
| 281 | +    # count without accessing private attributes of ThreadPoolExecutor |
| 282 | + assert len(unique_threads_seen) >= cpu_count |
| 283 | + |
| 284 | + |
232 | 285 | @pytest.mark.integration |
233 | 286 | def test_add_files_to_unpartitioned_table_with_schema_updates( |
234 | 287 | spark: SparkSession, session_catalog: Catalog, format_version: int |
|