diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py
index 1aaab32dbe..ff3abc85b4 100644
--- a/pyiceberg/io/pyarrow.py
+++ b/pyiceberg/io/pyarrow.py
@@ -1385,9 +1385,7 @@ def _get_column_projection_values(
 def _task_to_record_batches(
     fs: FileSystem,
     task: FileScanTask,
-    bound_row_filter: BooleanExpression,
     projected_schema: Schema,
-    projected_field_ids: Set[int],
     positional_deletes: Optional[List[ChunkedArray]],
     case_sensitive: bool,
     name_mapping: Optional[NameMapping] = None,
@@ -1405,8 +1403,8 @@ def _task_to_record_batches(
     file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True)
 
     pyarrow_filter = None
-    if bound_row_filter is not AlwaysTrue():
-        translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive)
+    if task.residual is not AlwaysTrue():
+        translated_row_filter = translate_column_names(task.residual, file_schema, case_sensitive=case_sensitive)
         bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive)
         pyarrow_filter = expression_to_pyarrow(bound_file_filter)
 
@@ -1416,7 +1414,13 @@ def _task_to_record_batches(
         task.file, projected_schema, partition_spec, file_schema.field_ids
     )
 
-    file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False)
+    file_project_schema = prune_columns(
+        file_schema,
+        {
+            id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
+        }.union(extract_field_ids(task.residual)),
+        select_full_types=False,
+    )
 
     fragment_scanner = ds.Scanner.from_fragment(
         fragment=fragment,
@@ -1514,7 +1518,7 @@ class ArrowScan:
     _table_metadata: TableMetadata
     _io: FileIO
     _projected_schema: Schema
-    _bound_row_filter: BooleanExpression
+    _bound_row_filter: Optional[BooleanExpression]
     _case_sensitive: bool
     _limit: Optional[int]
     """Scan the Iceberg Table and create an Arrow construct.
@@ -1533,26 +1537,25 @@ def __init__(
         table_metadata: TableMetadata,
         io: FileIO,
         projected_schema: Schema,
-        row_filter: BooleanExpression,
+        row_filter: Optional[BooleanExpression] = None,
         case_sensitive: bool = True,
         limit: Optional[int] = None,
     ) -> None:
         self._table_metadata = table_metadata
         self._io = io
         self._projected_schema = projected_schema
-        self._bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
+        if row_filter is not None:
+            deprecation_message(
+                deprecated_in="0.9.0",
+                removed_in="0.10.0",
+                help_message="row_filter is marked as deprecated, and will be removed in 0.10.0. Please make sure to set the residual on the ScanTasks.",
+            )
+            self._bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)
+        else:
+            self._bound_row_filter = None
         self._case_sensitive = case_sensitive
         self._limit = limit
 
-    @property
-    def _projected_field_ids(self) -> Set[int]:
-        """Set of field IDs that should be projected from the data files."""
-        return {
-            id
-            for id in self._projected_schema.field_ids
-            if not isinstance(self._projected_schema.find_type(id), (MapType, ListType))
-        }.union(extract_field_ids(self._bound_row_filter))
-
     def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:
         """Scan the Iceberg table and return a pa.Table.
 
@@ -1573,7 +1576,10 @@ def to_table(self, tasks: Iterable[FileScanTask]) -> pa.Table:
         deletes_per_file = _read_all_delete_files(self._io, tasks)
         executor = ExecutorFactory.get_or_create()
 
-        def _table_from_scan_task(task: FileScanTask) -> pa.Table:
+        if self._bound_row_filter is not None:
+            tasks = [task._set_residual(expr=self._bound_row_filter) for task in tasks]
+
+        def _table_from_scan_task(task: FileScanTask) -> Optional[pa.Table]:
             batches = list(self._record_batches_from_scan_tasks_and_deletes([task], deletes_per_file))
             if len(batches) > 0:
                 return pa.Table.from_batches(batches)
@@ -1643,6 +1649,9 @@ def to_record_batches(self, tasks: Iterable[FileScanTask]) -> Iterator[pa.RecordBatch]:
             ResolveError: When a required field cannot be found in the file
             ValueError: When a field type in the file cannot be projected to the schema type
         """
+        if self._bound_row_filter is not None:
+            tasks = [task._set_residual(expr=self._bound_row_filter) for task in tasks]
+
         deletes_per_file = _read_all_delete_files(self._io, tasks)
         return self._record_batches_from_scan_tasks_and_deletes(tasks, deletes_per_file)
 
@@ -1656,9 +1665,7 @@ def _record_batches_from_scan_tasks_and_deletes(
             batches = _task_to_record_batches(
                 _fs_from_file_path(self._io, task.file.file_path),
                 task,
-                self._bound_row_filter,
                 self._projected_schema,
-                self._projected_field_ids,
                 deletes_per_file.get(task.file.file_path),
                 self._case_sensitive,
                 self._table_metadata.name_mapping(),
diff --git a/pyiceberg/table/__init__.py b/pyiceberg/table/__init__.py
index 78676a774a..0b3f0bd188 100644
--- a/pyiceberg/table/__init__.py
+++ b/pyiceberg/table/__init__.py
@@ -1632,6 +1632,10 @@ def __init__(
         self.length = length or data_file.file_size_in_bytes
         self.residual = residual
 
+    def _set_residual(self, expr: BooleanExpression) -> "FileScanTask":
+        self.residual = expr
+        return self
+
 
 def _open_manifest(
     io: FileIO,
@@ -1827,8 +1831,12 @@ def plan_files(self) -> Iterable[FileScanTask]:
                     data_entry,
                     positional_delete_entries,
                 ),
-                residual=residual_evaluators[data_entry.data_file.spec_id](data_entry.data_file).residual_for(
-                    data_entry.data_file.partition
+                residual=bind(
+                    self.table_metadata.schema(),
+                    residual_evaluators[data_entry.data_file.spec_id](data_entry.data_file).residual_for(
+                        data_entry.data_file.partition
+                    ),
+                    case_sensitive=self.case_sensitive,
                 ),
             )
             for data_entry in data_entries
@@ -1845,7 +1853,7 @@ def to_arrow(self) -> pa.Table:
         from pyiceberg.io.pyarrow import ArrowScan
 
         return ArrowScan(
-            self.table_metadata, self.io, self.projection(), self.row_filter, self.case_sensitive, self.limit
+            self.table_metadata, self.io, self.projection(), case_sensitive=self.case_sensitive, limit=self.limit
         ).to_table(self.plan_files())
 
     def to_arrow_batch_reader(self) -> pa.RecordBatchReader:
@@ -1938,7 +1946,6 @@ def count(self) -> int:
                     table_metadata=self.table_metadata,
                     io=self.io,
                     projected_schema=self.projection(),
-                    row_filter=self.row_filter,
                     case_sensitive=self.case_sensitive,
                 )
                 tbl = arrow_scan.to_table([task])
diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py
index e90f3a46fc..7b115c6a9f 100644
--- a/tests/io/test_pyarrow.py
+++ b/tests/io/test_pyarrow.py
@@ -972,7 +972,6 @@ def project(
         ),
         io=PyArrowFileIO(),
         projected_schema=schema,
-        row_filter=expr or AlwaysTrue(),
         case_sensitive=True,
     ).to_table(
         tasks=[
@@ -984,7 +983,8 @@ def project(
                     partition={},
                     record_count=3,
                     file_size_in_bytes=3,
-                )
+                ),
+                residual=expr or AlwaysTrue(),
             )
             for file in files
         ]
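
For context, a minimal sketch of the calling convention implied by this patch, assuming a catalog-loaded table: plan_files() binds the row filter's residual onto each FileScanTask, so ArrowScan can be built without the now-deprecated row_filter argument. The catalog name, table identifier, and filter column below are placeholders, not part of the change.

# Hypothetical usage sketch; "default", "default.events", and "status" are
# placeholder names, not taken from the patch.
from pyiceberg.catalog import load_catalog
from pyiceberg.expressions import EqualTo
from pyiceberg.io.pyarrow import ArrowScan

catalog = load_catalog("default")
tbl = catalog.load_table("default.events")

# The filter is applied at planning time: plan_files() attaches a bound
# residual expression to every FileScanTask it yields.
scan = tbl.scan(row_filter=EqualTo("status", "active"))
tasks = scan.plan_files()

# ArrowScan no longer needs row_filter; each task's residual drives both the
# pyarrow filter and the column pruning inside _task_to_record_batches.
arrow_table = ArrowScan(
    table_metadata=tbl.metadata,
    io=tbl.io,
    projected_schema=scan.projection(),
    case_sensitive=True,
).to_table(tasks)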