diff --git a/docs/content/pypaimon/data-evolution.md b/docs/content/pypaimon/data-evolution.md
index f941136e9d63..e7bf021f9630 100644
--- a/docs/content/pypaimon/data-evolution.md
+++ b/docs/content/pypaimon/data-evolution.md
@@ -196,3 +196,8 @@ commit.close()
 
 - **Row order matters**: the batches you write must have the **same number of rows** as the batches you read, in the same order for that shard.
 - **Parallelism**: run multiple shards by calling `new_shard_updator(shard_idx, num_shards)` for each shard.
+
+## Read After Partial Shard Update
+
+- **Full table read**: rows from updated shards have the new column; rows from other shards have null for that column.
+- **Per-shard read** (`with_shard(shard_idx, num_shards)`): read only the shard(s) you need; rows carry the new column where it was written and null elsewhere.
diff --git a/paimon-python/pypaimon/globalindex/data_evolution_batch_scan.py b/paimon-python/pypaimon/globalindex/data_evolution_batch_scan.py
new file mode 100644
index 000000000000..2d6733af580e
--- /dev/null
+++ b/paimon-python/pypaimon/globalindex/data_evolution_batch_scan.py
@@ -0,0 +1,69 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################ + + +from typing import Optional + +from pypaimon.common.predicate import Predicate +from pypaimon.table.special_fields import SpecialFields + + +class DataEvolutionBatchScan: + @staticmethod + def remove_row_id_filter(predicate: Optional[Predicate]) -> Optional[Predicate]: + if predicate is None: + return None + return DataEvolutionBatchScan._remove(predicate) + + @staticmethod + def _remove(predicate: Predicate) -> Optional[Predicate]: + if predicate.method == 'and': + new_children = [] + for p in predicate.literals: + sub = DataEvolutionBatchScan._remove(p) + if sub is not None: + new_children.append(sub) + if not new_children: + return None + if len(new_children) == 1: + return new_children[0] + return Predicate( + method='and', + index=predicate.index, + field=predicate.field, + literals=new_children + ) + if predicate.method == 'or': + new_children = [] + for p in predicate.literals: + sub = DataEvolutionBatchScan._remove(p) + if sub is None: + return None + new_children.append(sub) + if len(new_children) == 1: + return new_children[0] + return Predicate( + method='or', + index=predicate.index, + field=predicate.field, + literals=new_children + ) + # Leaf: remove if _ROW_ID + if predicate.field == SpecialFields.ROW_ID.name: + return None + return predicate diff --git a/paimon-python/pypaimon/globalindex/range.py b/paimon-python/pypaimon/globalindex/range.py index 19b9b40e9439..e73b2462b37f 100644 --- a/paimon-python/pypaimon/globalindex/range.py +++ b/paimon-python/pypaimon/globalindex/range.py @@ -153,6 +153,14 @@ def merge_sorted_as_possible(ranges: List['Range']) -> List['Range']: return result + @staticmethod + def to_ranges(ids: List[int]) -> List['Range']: + if not ids: + return [] + sorted_ids = sorted(set(ids)) + ranges = [Range(i, i) for i in sorted_ids] + return Range.sort_and_merge_overlap(ranges, merge=True, adjacent=True) + @staticmethod def sort_and_merge_overlap(ranges: List['Range'], merge: bool = True, adjacent: bool = True) -> List['Range']: """ diff --git a/paimon-python/pypaimon/read/read_builder.py b/paimon-python/pypaimon/read/read_builder.py index be489c9bac38..2986960221bf 100644 --- a/paimon-python/pypaimon/read/read_builder.py +++ b/paimon-python/pypaimon/read/read_builder.py @@ -67,7 +67,8 @@ def new_read(self) -> TableRead: return TableRead( table=self.table, predicate=self._predicate, - read_type=self.read_type() + read_type=self.read_type(), + projection=self._projection, ) def new_predicate_builder(self) -> PredicateBuilder: diff --git a/paimon-python/pypaimon/read/reader/concat_batch_reader.py b/paimon-python/pypaimon/read/reader/concat_batch_reader.py index 4318f883eb2e..9486ba276ac2 100644 --- a/paimon-python/pypaimon/read/reader/concat_batch_reader.py +++ b/paimon-python/pypaimon/read/reader/concat_batch_reader.py @@ -141,6 +141,8 @@ class DataEvolutionMergeReader(RecordBatchReader): - The fourth field comes from batch1, and it is at offset 1 in batch1. - The fifth field comes from batch2, and it is at offset 1 in batch2. - The sixth field comes from batch1, and it is at offset 0 in batch1. + + When row_offsets[i] == -1 (no file provides that field), output a column of nulls using schema. 
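+
+    For instance, the fourth/fifth/sixth fields above correspond to
+    row_offsets[3:6] == [1, 2, 1] and field_offsets[3:6] == [1, 1, 0]
+    (illustrative values), and the null column is typed by schema.field(i).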
""" def __init__( @@ -207,14 +209,36 @@ def read_arrow_batch(self) -> Optional[RecordBatch]: for i in range(len(self.row_offsets)): batch_index = self.row_offsets[i] field_index = self.field_offsets[i] + field_name = self.schema.field(i).name if self.schema else None + column = None + if batch_index >= 0 and batches[batch_index] is not None: - columns.append(batches[batch_index].column(field_index).slice(0, min_rows)) - else: + src_batch = batches[batch_index] + if field_name is not None and field_name in src_batch.schema.names: + column = src_batch.column( + src_batch.schema.get_field_index(field_name) + ).slice(0, min_rows) + elif field_index < src_batch.num_columns: + column = src_batch.column(field_index).slice(0, min_rows) + + if column is None and field_name is not None: + for b in batches: + if b is not None and field_name in b.schema.names: + column = b.column(b.schema.get_field_index(field_name)).slice( + 0, min_rows + ) + break + + if column is not None: + columns.append(column) + elif self.schema is not None and i < len(self.schema): columns.append(pa.nulls(min_rows, type=self.schema.field(i).type)) for i in range(len(self.readers)): if batches[i] is not None and batches[i].num_rows > min_rows: - self._buffers[i] = batches[i].slice(min_rows, batches[i].num_rows - min_rows) + self._buffers[i] = batches[i].slice( + min_rows, batches[i].num_rows - min_rows + ) return pa.RecordBatch.from_arrays(columns, schema=self.schema) diff --git a/paimon-python/pypaimon/read/reader/data_file_batch_reader.py b/paimon-python/pypaimon/read/reader/data_file_batch_reader.py index c9e51785a056..dd1053d6c125 100644 --- a/paimon-python/pypaimon/read/reader/data_file_batch_reader.py +++ b/paimon-python/pypaimon/read/reader/data_file_batch_reader.py @@ -16,7 +16,7 @@ # limitations under the License. 
################################################################################ -from typing import List, Optional +from typing import List, Optional, Tuple import pyarrow as pa from pyarrow import RecordBatch @@ -48,6 +48,33 @@ def __init__(self, format_reader: RecordBatchReader, index_mapping: List[int], p self.first_row_id = first_row_id self.max_sequence_number = max_sequence_number self.system_fields = system_fields + self.requested_field_names = [field.name for field in fields] if fields else None + self.fields = fields + + def _align_to_requested_names( + self, + inter_arrays: List, + inter_names: List, + requested_field_names: List[str], + num_rows: int, + ) -> Tuple[List, List]: + name_to_idx = {n: i for i, n in enumerate(inter_names)} + ordered_arrays = [] + ordered_names = [] + for name in requested_field_names: + idx = name_to_idx.get(name) + if idx is None and name.startswith("_KEY_") and name[5:] in name_to_idx: + idx = name_to_idx[name[5:]] + if idx is not None: + ordered_arrays.append(inter_arrays[idx]) + ordered_names.append(name) + else: + field = self.schema_map.get(name) + ordered_arrays.append( + pa.nulls(num_rows, type=field.type) if field is not None else pa.nulls(num_rows) + ) + ordered_names.append(name) + return ordered_arrays, ordered_names def read_arrow_batch(self, start_idx=None, end_idx=None) -> Optional[RecordBatch]: if isinstance(self.format_reader, FormatBlobReader): @@ -57,11 +84,27 @@ def read_arrow_batch(self, start_idx=None, end_idx=None) -> Optional[RecordBatch if record_batch is None: return None + num_rows = record_batch.num_rows if self.partition_info is None and self.index_mapping is None: if self.row_tracking_enabled and self.system_fields: record_batch = self._assign_row_tracking(record_batch) + if self.requested_field_names is not None: + inter_arrays = list(record_batch.columns) + inter_names = list(record_batch.schema.names) + ordered_arrays, ordered_names = self._align_to_requested_names( + inter_arrays, inter_names, self.requested_field_names, num_rows + ) + record_batch = pa.RecordBatch.from_arrays(ordered_arrays, ordered_names) return record_batch + if (self.partition_info is None and self.index_mapping is not None + and not self.requested_field_names): + ncol = record_batch.num_columns + if len(self.index_mapping) == ncol and self.index_mapping == list(range(ncol)): + if self.row_tracking_enabled and self.system_fields: + record_batch = self._assign_row_tracking(record_batch) + return record_batch + inter_arrays = [] inter_names = [] num_rows = record_batch.num_rows @@ -75,41 +118,133 @@ def read_arrow_batch(self, start_idx=None, end_idx=None) -> Optional[RecordBatch inter_names.append(partition_field.name) else: real_index = self.partition_info.get_real_index(i) - if real_index < record_batch.num_columns: + name = ( + self.requested_field_names[i] + if self.requested_field_names and i < len(self.requested_field_names) + else f"_col_{i}" + ) + batch_names = record_batch.schema.names + col_idx = None + if name in batch_names: + col_idx = record_batch.schema.get_field_index(name) + elif name.startswith("_KEY_") and name[5:] in batch_names: + col_idx = record_batch.schema.get_field_index(name[5:]) + if col_idx is not None: + inter_arrays.append(record_batch.column(col_idx)) + inter_names.append(name) + elif real_index < record_batch.num_columns: inter_arrays.append(record_batch.column(real_index)) - inter_names.append(record_batch.schema.field(real_index).name) + inter_names.append(name) + else: + field = self.schema_map.get(name) + 
inter_arrays.append( + pa.nulls(num_rows, type=field.type) if field is not None else pa.nulls(num_rows) + ) + inter_names.append(name) else: - inter_arrays = record_batch.columns - inter_names = record_batch.schema.names + inter_arrays = list(record_batch.columns) + inter_names = list(record_batch.schema.names) + + if self.requested_field_names is not None: + inter_arrays, inter_names = self._align_to_requested_names( + inter_arrays, inter_names, self.requested_field_names, num_rows + ) - if self.index_mapping is not None: + if self.index_mapping is not None and not ( + self.requested_field_names is not None and inter_names == self.requested_field_names): mapped_arrays = [] mapped_names = [] + partition_names = ( + set(pf.name for pf in self.partition_info.partition_fields) + if self.partition_info else set() + ) + non_partition_indices = [idx for idx, name in enumerate(inter_names) if name not in partition_names] for i, real_index in enumerate(self.index_mapping): - if 0 <= real_index < len(inter_arrays): - mapped_arrays.append(inter_arrays[real_index]) - mapped_names.append(inter_names[real_index]) + if 0 <= real_index < len(non_partition_indices): + actual_index = non_partition_indices[real_index] + mapped_arrays.append(inter_arrays[actual_index]) + mapped_names.append(inter_names[actual_index]) else: - null_array = pa.nulls(num_rows) + name = ( + self.requested_field_names[i] + if self.requested_field_names and i < len(self.requested_field_names) + else f"null_col_{i}" + ) + field = self.schema_map.get(name) + null_array = pa.nulls(num_rows, type=field.type) if field is not None else pa.nulls(num_rows) mapped_arrays.append(null_array) - mapped_names.append(f"null_col_{i}") + mapped_names.append(name) + + if self.partition_info: + partition_arrays_map = { + inter_names[i]: inter_arrays[i] + for i in range(len(inter_names)) + if inter_names[i] in partition_names + } + + if self.requested_field_names: + final_arrays = [] + final_names = [] + mapped_name_to_array = {name: arr for name, arr in zip(mapped_names, mapped_arrays)} + + for name in self.requested_field_names: + if name in mapped_name_to_array: + final_arrays.append(mapped_name_to_array[name]) + final_names.append(name) + elif name in partition_arrays_map: + final_arrays.append(partition_arrays_map[name]) + final_names.append(name) + else: + # Field not in file (e.g. 
index_mapping -1): output null column + field = self.schema_map.get(name) + null_arr = pa.nulls(num_rows, type=field.type) if field is not None else pa.nulls(num_rows) + final_arrays.append(null_arr) + final_names.append(name) + + inter_arrays = final_arrays + inter_names = final_names + else: + mapped_name_set = set(mapped_names) + for name, arr in partition_arrays_map.items(): + if name not in mapped_name_set: + mapped_arrays.append(arr) + mapped_names.append(name) + inter_arrays = mapped_arrays + inter_names = mapped_names + else: + inter_arrays = mapped_arrays + inter_names = mapped_names if self.system_primary_key: for i in range(len(self.system_primary_key)): - if not mapped_names[i].startswith("_KEY_"): - mapped_names[i] = f"_KEY_{mapped_names[i]}" + if i < len(inter_names) and not inter_names[i].startswith("_KEY_"): + inter_names[i] = f"_KEY_{inter_names[i]}" - inter_arrays = mapped_arrays - inter_names = mapped_names + if self.requested_field_names is not None and len(inter_arrays) < len(self.requested_field_names): + for name in self.requested_field_names[len(inter_arrays):]: + field = self.schema_map.get(name) + inter_arrays.append( + pa.nulls(num_rows, type=field.type) if field is not None else pa.nulls(num_rows) + ) + inter_names.append(name) + + for i, name in enumerate(inter_names): + target_field = self.schema_map.get(name) + if target_field is not None and inter_arrays[i].type != target_field.type: + try: + inter_arrays[i] = inter_arrays[i].cast(target_field.type) + except (pa.ArrowInvalid, pa.ArrowNotImplementedError): + inter_arrays[i] = pa.nulls(num_rows, type=target_field.type) # to contains 'not null' property final_fields = [] for i, name in enumerate(inter_names): array = inter_arrays[i] target_field = self.schema_map.get(name) - if not target_field: - target_field = pa.field(name, array.type) - final_fields.append(target_field) + if target_field is None: + final_fields.append(pa.field(name, array.type)) + else: + final_fields.append(target_field) final_schema = pa.schema(final_fields) record_batch = pa.RecordBatch.from_arrays(inter_arrays, schema=final_schema) @@ -122,31 +257,34 @@ def read_arrow_batch(self, start_idx=None, end_idx=None) -> Optional[RecordBatch def _assign_row_tracking(self, record_batch: RecordBatch) -> RecordBatch: """Assign row tracking meta fields (_ROW_ID and _SEQUENCE_NUMBER).""" arrays = list(record_batch.columns) + num_cols = len(arrays) # Handle _ROW_ID field if SpecialFields.ROW_ID.name in self.system_fields.keys(): idx = self.system_fields[SpecialFields.ROW_ID.name] - # Create a new array that fills with computed row IDs - arrays[idx] = pa.array(range(self.first_row_id, self.first_row_id + record_batch.num_rows), type=pa.int64()) + if idx < num_cols: + if self.first_row_id is None: + raise ValueError( + "Row tracking requires first_row_id on the file; " + "got None. Ensure file metadata has first_row_id when reading _ROW_ID." 
+                )
+                arrays[idx] = pa.array(
+                    range(self.first_row_id, self.first_row_id + record_batch.num_rows),
+                    type=pa.int64())
 
         # Handle _SEQUENCE_NUMBER field
         if SpecialFields.SEQUENCE_NUMBER.name in self.system_fields.keys():
             idx = self.system_fields[SpecialFields.SEQUENCE_NUMBER.name]
             # Create a new array that fills with max_sequence_number
-            arrays[idx] = pa.repeat(self.max_sequence_number, record_batch.num_rows)
+            if idx < num_cols:
+                arrays[idx] = pa.repeat(self.max_sequence_number, record_batch.num_rows)
 
         names = record_batch.schema.names
-        table = None
-        for i, name in enumerate(names):
-            field = pa.field(
-                name, arrays[i].type,
-                nullable=record_batch.schema.field(name).nullable
-            )
-            if table is None:
-                table = pa.table({name: arrays[i]}, schema=pa.schema([field]))
-            else:
-                table = table.append_column(field, arrays[i])
-        return table.to_batches()[0]
+        fields = [
+            pa.field(name, arrays[i].type, nullable=record_batch.schema.field(name).nullable)
+            for i, name in enumerate(names)
+        ]
+        return pa.RecordBatch.from_arrays(arrays, schema=pa.schema(fields))
 
     def close(self) -> None:
         self.format_reader.close()
diff --git a/paimon-python/pypaimon/read/reader/predicate_filter_record_batch_reader.py b/paimon-python/pypaimon/read/reader/predicate_filter_record_batch_reader.py
new file mode 100644
index 000000000000..ac388231ee9b
--- /dev/null
+++ b/paimon-python/pypaimon/read/reader/predicate_filter_record_batch_reader.py
@@ -0,0 +1,66 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################ + +from typing import Optional + +import pyarrow as pa + +from pypaimon.common.predicate import Predicate +from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader +from pypaimon.table.row.offset_row import OffsetRow + + +class PredicateFilterRecordBatchReader(RecordBatchReader): + def __init__(self, reader: RecordBatchReader, predicate: Predicate): + self.reader = reader + self.predicate = predicate + + def read_arrow_batch(self) -> Optional[pa.RecordBatch]: + while True: + batch = self.reader.read_arrow_batch() + if batch is None: + return None + if batch.num_rows == 0: + return batch + ncols = batch.num_columns + nrows = batch.num_rows + mask = [] + row_tuple = [None] * ncols + offset_row = OffsetRow(row_tuple, 0, ncols) + for i in range(nrows): + for j in range(ncols): + row_tuple[j] = batch.column(j)[i].as_py() + offset_row.replace(tuple(row_tuple)) + try: + mask.append(self.predicate.test(offset_row)) + except IndexError: + raise + except (TypeError, ValueError): + mask.append(False) + if any(mask): + filtered = batch.filter(pa.array(mask)) + if filtered.num_rows > 0: + return filtered + # no rows passed predicate in this batch, continue to next batch + continue + + def return_batch_pos(self) -> int: + return getattr(self.reader, 'return_batch_pos', lambda: 0)() + + def close(self) -> None: + self.reader.close() diff --git a/paimon-python/pypaimon/read/scanner/file_scanner.py b/paimon-python/pypaimon/read/scanner/file_scanner.py index 0f53ef0cd33d..2ca5f398c63c 100755 --- a/paimon-python/pypaimon/read/scanner/file_scanner.py +++ b/paimon-python/pypaimon/read/scanner/file_scanner.py @@ -37,6 +37,54 @@ from pypaimon.manifest.simple_stats_evolutions import SimpleStatsEvolutions +def _row_ranges_from_predicate(predicate: Optional[Predicate]) -> Optional[List]: + """ + Extract row ID ranges from predicate for data evolution push-down. + Supports _ROW_ID with 'equal' and 'in', and 'and'/'or' of those. + Returns None if predicate cannot be converted to row ID ranges. 
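+
+    Example (illustrative): _ROW_ID IN (3, 4, 5, 9) yields [Range(3, 5), Range(9, 9)];
+    _ROW_ID == 7 AND _ROW_ID IN (3, 7) yields [Range(7, 7)]; a leaf predicate on any
+    other field yields None.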
+ """ + from pypaimon.globalindex.range import Range + from pypaimon.table.special_fields import SpecialFields + + if predicate is None: + return None + + def visit(p: Predicate): + if p.method == 'and': + result = None + for child in p.literals: + sub = visit(child) + if sub is None: + continue + result = Range.and_(result, sub) if result is not None else sub + if not result: + return result + return result + if p.method == 'or': + parts = [] + for child in p.literals: + sub = visit(child) + if sub is None: + return None + parts.extend(sub) + if not parts: + return [] + return Range.sort_and_merge_overlap(parts, merge=True, adjacent=True) + if p.field != SpecialFields.ROW_ID.name: + return None + if p.method == 'equal': + if not p.literals: + return [] + return Range.to_ranges([int(p.literals[0])]) + if p.method == 'in': + if not p.literals: + return [] + return Range.to_ranges([int(x) for x in p.literals]) + return None + + return visit(predicate) + + def _filter_manifest_files_by_row_ranges( manifest_files: List[ManifestFileMeta], row_ranges: List) -> List[ManifestFileMeta]: @@ -186,6 +234,8 @@ def _create_data_evolution_split_generator(self): row_ranges = global_index_result.results().to_range_list() if isinstance(global_index_result, ScoredGlobalIndexResult): score_getter = global_index_result.score_getter() + if row_ranges is None and self.predicate is not None: + row_ranges = _row_ranges_from_predicate(self.predicate) manifest_files = self.manifest_scanner() @@ -343,6 +393,10 @@ def _filter_manifest_entry(self, entry: ManifestEntry) -> bool: else: if not self.predicate: return True + from pypaimon.globalindex.data_evolution_batch_scan import DataEvolutionBatchScan + predicate_for_stats = DataEvolutionBatchScan.remove_row_id_filter(self.predicate) + if predicate_for_stats is None: + return True if entry.file.value_stats_cols is None and entry.file.write_cols is not None: stats_fields = entry.file.write_cols else: @@ -352,7 +406,7 @@ def _filter_manifest_entry(self, entry: ManifestEntry) -> bool: entry.file.row_count, stats_fields ) - return self.predicate.test_by_simple_stats( + return predicate_for_stats.test_by_simple_stats( evolved_stats, entry.file.row_count ) diff --git a/paimon-python/pypaimon/read/split_read.py b/paimon-python/pypaimon/read/split_read.py index 2088310aa4c0..2c654c1b7757 100644 --- a/paimon-python/pypaimon/read/split_read.py +++ b/paimon-python/pypaimon/read/split_read.py @@ -19,7 +19,7 @@ import os from abc import ABC, abstractmethod from functools import partial -from typing import Callable, List, Optional, Tuple +from typing import Callable, List, Optional, Set, Tuple from pypaimon.common.options.core_options import CoreOptions from pypaimon.common.predicate import Predicate @@ -40,6 +40,7 @@ from pypaimon.read.reader.field_bunch import BlobBunch, DataBunch, FieldBunch from pypaimon.read.reader.filter_record_reader import FilterRecordReader from pypaimon.read.reader.format_avro_reader import FormatAvroReader +from pypaimon.read.reader.predicate_filter_record_batch_reader import PredicateFilterRecordBatchReader from pypaimon.read.reader.row_range_filter_record_reader import RowIdFilterRecordBatchReader from pypaimon.read.reader.format_blob_reader import FormatBlobReader from pypaimon.read.reader.format_lance_reader import FormatLanceReader @@ -54,7 +55,7 @@ from pypaimon.read.reader.sort_merge_reader import SortMergeReaderWithMinHeap from pypaimon.read.split import Split from pypaimon.read.sliced_split import SlicedSplit -from pypaimon.schema.data_types 
import DataField, PyarrowFieldParser +from pypaimon.schema.data_types import AtomicType, DataField, PyarrowFieldParser from pypaimon.table.special_fields import SpecialFields from pypaimon.globalindex.indexed_split import IndexedSplit @@ -104,14 +105,34 @@ def _push_down_predicate(self) -> Optional[Predicate]: def create_reader(self) -> RecordReader: """Create a record reader for the given split.""" + def _get_blob_column_names(self) -> Set[str]: + out = set() + for field in self.table.table_schema.fields: + t = field.type + if isinstance(t, AtomicType) and t.type.upper() == "BLOB": + out.add(field.name) + return out + def file_reader_supplier(self, file: DataFileMeta, for_merge_read: bool, - read_fields: List[str], row_tracking_enabled: bool) -> RecordBatchReader: + read_fields: List[str], row_tracking_enabled: bool, + use_requested_field_names: bool = True) -> RecordBatchReader: (read_file_fields, read_arrow_predicate) = self._get_fields_and_predicate(file.schema_id, read_fields) - # Use external_path if available, otherwise use file_path file_path = file.external_path if file.external_path else file.file_path _, extension = os.path.splitext(file_path) file_format = extension[1:] + blob_column_names = self._get_blob_column_names() + is_blob_file = file_format == CoreOptions.FILE_FORMAT_BLOB + if not is_blob_file and blob_column_names: + read_file_fields = [f for f in read_file_fields if f not in blob_column_names] + + if getattr(file, "write_cols", None): + read_file_fields = list(read_file_fields) + for col in file.write_cols: + if col in blob_column_names and not is_blob_file: + continue + if col in read_fields and col not in read_file_fields: + read_file_fields.append(col) batch_size = self.table.options.read_batch_size() @@ -121,8 +142,12 @@ def file_reader_supplier(self, file: DataFileMeta, for_merge_read: bool, self.read_fields, read_arrow_predicate, batch_size=batch_size) elif file_format == CoreOptions.FILE_FORMAT_BLOB: blob_as_descriptor = CoreOptions.blob_as_descriptor(self.table.options) + blob_full_fields = ( + SpecialFields.row_type_with_row_tracking(self.table.table_schema.fields) + if row_tracking_enabled else self.table.table_schema.fields + ) format_reader = FormatBlobReader(self.table.file_io, file_path, read_file_fields, - self.read_fields, read_arrow_predicate, blob_as_descriptor, + blob_full_fields, read_arrow_predicate, blob_as_descriptor, batch_size=batch_size) elif file_format == CoreOptions.FILE_FORMAT_LANCE: format_reader = FormatLanceReader(self.table.file_io, file_path, read_file_fields, @@ -133,20 +158,74 @@ def file_reader_supplier(self, file: DataFileMeta, for_merge_read: bool, else: raise ValueError(f"Unexpected file format: {file_format}") - index_mapping = self.create_index_mapping() - partition_info = self._create_partition_info() - system_fields = SpecialFields.find_system_fields(self.read_fields) + write_cols = getattr(file, "write_cols", None) + if write_cols: + num_cols = len(read_file_fields) if is_blob_file else len(read_fields) + if num_cols > 0: + index_mapping = list(range(num_cols)) if num_cols > 0 else None + else: + index_mapping = None + else: + index_mapping = self.create_index_mapping() + table_schema_fields = ( SpecialFields.row_type_with_row_tracking(self.table.table_schema.fields) if row_tracking_enabled else self.table.table_schema.fields ) + if for_merge_read: + fields = self.read_fields + elif is_blob_file: + field_map = {field.name: field for field in table_schema_fields} + requested_fields = [] + for field_name in 
read_file_fields: + if field_name in field_map: + requested_fields.append(field_map[field_name]) + fields = requested_fields if requested_fields else table_schema_fields + elif use_requested_field_names and write_cols: + field_map = {field.name: field for field in table_schema_fields} + requested_fields = [] + for field_name in read_fields: + if field_name in field_map: + requested_fields.append(field_map[field_name]) + fields = requested_fields if requested_fields else table_schema_fields + else: + field_map = {field.name: field for field in table_schema_fields} + requested_fields = [field_map[f.name] for f in self.read_fields if f.name in field_map] + fields = requested_fields if requested_fields else table_schema_fields + + system_fields = SpecialFields.find_system_fields(fields) + + field_map = {field.name: field for field in table_schema_fields} + actual_read_fields_for_partition = [] + for field_name in read_file_fields: + if field_name in field_map: + actual_read_fields_for_partition.append(field_map[field_name]) + + if ( + not for_merge_read + and self.table.partition_keys + and actual_read_fields_for_partition + and fields is table_schema_fields + ): + partition_row = self.split.partition + full_partition_and_file = list(partition_row.fields) + actual_read_fields_for_partition + available_names = {f.name for f in full_partition_and_file} + fields = [f for f in self.read_fields if f.name in available_names] + if not fields: + fields = full_partition_and_file + + partition_info = self._create_partition_info( + actual_read_fields=actual_read_fields_for_partition if actual_read_fields_for_partition else None, + output_fields=fields + ) + if for_merge_read: return DataFileBatchReader( format_reader, index_mapping, partition_info, self.trimmed_primary_key, - table_schema_fields, + fields, file.max_sequence_number, file.first_row_id, row_tracking_enabled, @@ -157,7 +236,7 @@ def file_reader_supplier(self, file: DataFileMeta, for_merge_read: bool, index_mapping, partition_info, None, - table_schema_fields, + fields, file.max_sequence_number, file.first_row_id, row_tracking_enabled, @@ -299,30 +378,45 @@ def _get_trimmed_fields(self, read_data_fields: List[DataField], return trimmed_mapping, trimmed_fields - def _create_partition_info(self): + def _create_partition_info( + self, + actual_read_fields: Optional[List[DataField]] = None, + output_fields: Optional[List[DataField]] = None): if not self.table.partition_keys: return None - partition_mapping = self._construct_partition_mapping() + partition_mapping = self._construct_partition_mapping(actual_read_fields, output_fields) if not partition_mapping: return None return PartitionInfo(partition_mapping, self.split.partition) - def _construct_partition_mapping(self) -> List[int]: - _, trimmed_fields = self._get_trimmed_fields( - self._get_read_data_fields(), self._get_all_data_fields() - ) + def _construct_partition_mapping( + self, + actual_read_fields: Optional[List[DataField]] = None, + output_fields: Optional[List[DataField]] = None) -> List[int]: + if actual_read_fields is not None: + read_data_fields = actual_read_fields + else: + read_data_fields = self._get_read_data_fields() + + if output_fields is not None: + fields_to_map = output_fields + else: + fields_to_map = read_data_fields + + actual_read_field_names = {field.name: idx for idx, field in enumerate(read_data_fields)} partition_names = self.table.partition_keys + num_record_batch_cols = len(read_data_fields) - mapping = [0] * (len(trimmed_fields) + 1) - p_count = 0 + mapping = 
[0] * (len(fields_to_map) + 1) - for i, field in enumerate(trimmed_fields): + for i, field in enumerate(fields_to_map): if field.name in partition_names: partition_index = partition_names.index(field.name) mapping[i] = -(partition_index + 1) - p_count += 1 + elif field.name in actual_read_field_names: + mapping[i] = actual_read_field_names[field.name] + 1 else: - mapping[i] = (i - p_count) + 1 + mapping[i] = num_record_batch_cols + 1 return mapping @@ -394,8 +488,12 @@ def _get_all_data_fields(self): class MergeFileSplitRead(SplitRead): + def create_index_mapping(self): + return None + def kv_reader_supplier(self, file: DataFileMeta, dv_factory: Optional[Callable] = None) -> RecordReader: - file_batch_reader = self.file_reader_supplier(file, True, self._get_final_read_data_fields(), False) + merge_read_fields = [f.name for f in self._get_read_data_fields()] + file_batch_reader = self.file_reader_supplier(file, True, merge_read_fields, False) dv = dv_factory() if dv_factory else None if dv: return ApplyDeletionVectorReader( @@ -449,6 +547,9 @@ def __init__( actual_split = split.data_split() super().__init__(table, predicate, read_type, actual_split, row_tracking_enabled) + def _push_down_predicate(self) -> Optional[Predicate]: + return None + def create_reader(self) -> RecordReader: files = self.split.files suppliers = [] @@ -460,14 +561,19 @@ def create_reader(self) -> RecordReader: if len(need_merge_files) == 1 or not self.read_fields: # No need to merge fields, just create a single file reader suppliers.append( - lambda f=need_merge_files[0]: self._create_file_reader(f, self._get_final_read_data_fields()) + lambda f=need_merge_files[0]: self._create_file_reader( + f, self._get_final_read_data_fields(), use_requested_field_names=False + ) ) else: suppliers.append( lambda files=need_merge_files: self._create_union_reader(files) ) - return ConcatBatchReader(suppliers) + reader = ConcatBatchReader(suppliers) + if self.predicate is not None: + reader = PredicateFilterRecordBatchReader(reader, self.predicate) + return reader def _split_by_row_id(self, files: List[DataFileMeta]) -> List[List[DataFileMeta]]: """Split files by firstRowId for data evolution.""" @@ -516,6 +622,14 @@ def _create_union_reader(self, need_merge_files: List[DataFileMeta]) -> RecordRe # Split field bunches fields_files = self._split_field_bunches(need_merge_files) + def _bunch_sort_key(bunch: FieldBunch) -> tuple: + first_file = bunch.files()[0] + max_seq = max(f.max_sequence_number for f in bunch.files()) + is_partial = 1 if (first_file.write_cols and len(first_file.write_cols) > 0) else 0 + return (max_seq, is_partial) + + fields_files = sorted(fields_files, key=_bunch_sort_key, reverse=True) + # Validate row counts and first row IDs row_count = fields_files[0].row_count() first_row_id = fields_files[0].files()[0].first_row_id @@ -533,51 +647,115 @@ def _create_union_reader(self, need_merge_files: List[DataFileMeta]) -> RecordRe file_record_readers = [None] * len(fields_files) read_field_index = [field.id for field in all_read_fields] - # Initialize offsets + # Initialize offsets and per-bunch read_fields (built in two passes) row_offsets = [-1] * len(all_read_fields) field_offsets = [-1] * len(all_read_fields) + read_fields_per_bunch = [[] for _ in range(len(fields_files))] + # Pass 1: Assign from partial bunches (write_cols) by name first. 
This ensures columns + for i, bunch in enumerate(fields_files): + first_file = bunch.files()[0] + if not (first_file.write_cols and len(first_file.write_cols) > 0): + continue + for j, field in enumerate(all_read_fields): + if row_offsets[j] == -1 and field.name in first_file.write_cols: + # Do not assign non-blob fields to a blob bunch (blob file only has blob column) + if self._is_blob_file(first_file.file_name) and field.name != first_file.write_cols[0]: + continue + row_offsets[j] = i + field_offsets[j] = len(read_fields_per_bunch[i]) + read_fields_per_bunch[i].append(field) + + # Pass 2: Assign remaining fields by field id (full-schema base and system fields) for i, bunch in enumerate(fields_files): first_file = bunch.files()[0] - - # Get field IDs for this bunch if self._is_blob_file(first_file.file_name): - # For blob files, we need to get the field ID from the write columns field_ids = [self._get_field_id_from_write_cols(first_file)] elif first_file.write_cols: field_ids = self._get_field_ids_from_write_cols(first_file.write_cols) else: - # For regular files, get all field IDs from the schema - field_ids = [field.id for field in self.table.fields] - - read_fields = [] + schema = self.table.schema_manager.get_schema(first_file.schema_id) + schema_fields = ( + SpecialFields.row_type_with_row_tracking(schema.fields) + if self.row_tracking_enabled else schema.fields + ) + field_ids = [field.id for field in schema_fields] + read_fields = list(read_fields_per_bunch[i]) for j, read_field_id in enumerate(read_field_index): + if row_offsets[j] != -1: + continue for field_id in field_ids: if read_field_id == field_id: - if row_offsets[j] == -1: - row_offsets[j] = i - field_offsets[j] = len(read_fields) - read_fields.append(all_read_fields[j]) + row_offsets[j] = i + field_offsets[j] = len(read_fields) + read_fields.append(all_read_fields[j]) break + read_fields_per_bunch[i] = read_fields + # Post-pass: any data field still unassigned, take from a partial bunch by name + table_field_names = {f.name for f in self.table.fields} + for i, field in enumerate(all_read_fields): + if row_offsets[i] != -1: + continue + if field.name not in table_field_names: + continue + for bi, bunch in enumerate(fields_files): + first_file = bunch.files()[0] + if not first_file.write_cols or field.name not in first_file.write_cols: + continue + # Do not assign non-blob fields to a blob bunch (blob file only has blob column) + if self._is_blob_file(first_file.file_name) and field.name != first_file.write_cols[0]: + continue + row_offsets[i] = bi + field_offsets[i] = len(read_fields_per_bunch[bi]) + read_fields_per_bunch[bi].append(field) + break + + write_cols_tuples = [ + tuple(f.files()[0].write_cols or ()) + for f in fields_files + if not self._is_blob_file(f.files()[0].file_name) + ] + all_same_write_cols = len(set(write_cols_tuples)) <= 1 if write_cols_tuples else True + use_requested_field_names = not all_same_write_cols + + table_field_names_set = {f.name for f in self.table.fields} + for i, bunch in enumerate(fields_files): + read_fields = list(read_fields_per_bunch[i]) if not read_fields: file_record_readers[i] = None else: + if not self._is_blob_file(bunch.files()[0].file_name): + schema = self.table.schema_manager.get_schema(bunch.files()[0].schema_id) + schema_fields = ( + SpecialFields.row_type_with_row_tracking(schema.fields) + if self.row_tracking_enabled else schema.fields + ) + blob_column_names = self._get_blob_column_names() + read_field_names_set = {f.name for f in read_fields} + for f in 
schema_fields: + if f.name in blob_column_names: + continue + if f.name in table_field_names_set and f.name not in read_field_names_set: + read_fields.append(f) + read_field_names_set.add(f.name) read_field_names = self._remove_partition_fields(read_fields) table_fields = self.read_fields - self.read_fields = read_fields # create reader based on read_fields + self.read_fields = read_fields batch_size = self.table.options.read_batch_size() - # Create reader for this bunch if len(bunch.files()) == 1: - suppliers = [lambda r=self._create_file_reader( - bunch.files()[0], read_field_names - ): r] + suppliers = [ + partial(self._create_file_reader, file=bunch.files()[0], + read_fields=read_field_names, + use_requested_field_names=use_requested_field_names) + ] file_record_readers[i] = MergeAllBatchReader(suppliers, batch_size=batch_size) else: - # Create concatenated reader for multiple files suppliers = [ partial(self._create_file_reader, file=file, - read_fields=read_field_names) for file in bunch.files() + read_fields=read_field_names, + use_requested_field_names=use_requested_field_names) + for file in bunch.files() ] file_record_readers[i] = MergeAllBatchReader(suppliers, batch_size=batch_size) self.read_fields = table_fields @@ -591,21 +769,37 @@ def _create_union_reader(self, need_merge_files: List[DataFileMeta]) -> RecordRe output_schema = PyarrowFieldParser.from_paimon_schema(all_read_fields) return DataEvolutionMergeReader(row_offsets, field_offsets, file_record_readers, schema=output_schema) - def _create_file_reader(self, file: DataFileMeta, read_fields: [str]) -> Optional[RecordReader]: + def _create_file_reader(self, file: DataFileMeta, read_fields: [str], + use_requested_field_names: bool = True) -> Optional[RecordReader]: """Create a file reader for a single file.""" + shard_file_idx_map = ( + self.split.shard_file_idx_map() if isinstance(self.split, SlicedSplit) else {} + ) + begin_pos, end_pos = 0, 0 + if file.file_name in shard_file_idx_map: + (begin_pos, end_pos) = shard_file_idx_map[file.file_name] + if (begin_pos, end_pos) == (-1, -1): + return None + def create_record_reader(): return self.file_reader_supplier( file=file, for_merge_read=False, read_fields=read_fields, - row_tracking_enabled=True) + row_tracking_enabled=True, + use_requested_field_names=use_requested_field_names) + + base = create_record_reader() + if file.file_name in shard_file_idx_map: + base = ShardBatchReader(base, begin_pos, end_pos) if self.row_ranges is None: - return create_record_reader() + return base file_range = Range(file.first_row_id, file.first_row_id + file.row_count - 1) row_ranges = Range.and_(self.row_ranges, [file_range]) if len(row_ranges) == 0: return EmptyRecordBatchReader() - return RowIdFilterRecordBatchReader(create_record_reader(), file.first_row_id, row_ranges) + first_row_id = file.first_row_id + (begin_pos if file.file_name in shard_file_idx_map else 0) + return RowIdFilterRecordBatchReader(base, first_row_id, row_ranges) def _split_field_bunches(self, need_merge_files: List[DataFileMeta]) -> List[FieldBunch]: """Split files into field bunches.""" diff --git a/paimon-python/pypaimon/read/table_read.py b/paimon-python/pypaimon/read/table_read.py index e0a442f30582..1b7143948f52 100644 --- a/paimon-python/pypaimon/read/table_read.py +++ b/paimon-python/pypaimon/read/table_read.py @@ -28,17 +28,20 @@ SplitRead) from pypaimon.schema.data_types import DataField, PyarrowFieldParser from pypaimon.table.row.offset_row import OffsetRow +from pypaimon.table.special_fields import 
SpecialFields class TableRead: """Implementation of TableRead for native Python reading.""" - def __init__(self, table, predicate: Optional[Predicate], read_type: List[DataField]): + def __init__(self, table, predicate: Optional[Predicate], read_type: List[DataField], + projection: Optional[List[str]] = None): from pypaimon.table.file_store_table import FileStoreTable self.table: FileStoreTable = table self.predicate = predicate self.read_type = read_type + self.projection = projection def to_iterator(self, splits: List[Split]) -> Iterator: def _record_generator(): @@ -59,15 +62,21 @@ def to_arrow_batch_reader(self, splits: List[Split]) -> pyarrow.ipc.RecordBatchR @staticmethod def _try_to_pad_batch_by_schema(batch: pyarrow.RecordBatch, target_schema): - if batch.schema.names == target_schema.names: + if batch.schema.names == target_schema.names and len(batch.schema.names) == len(target_schema.names): return batch - columns = [] num_rows = batch.num_rows + columns = [] + batch_column_names = batch.schema.names # py36: use schema.names (no RecordBatch.column_names) for field in target_schema: - if field.name in batch.column_names: + if field.name in batch_column_names: col = batch.column(field.name) + if col.type != field.type: + try: + col = col.cast(field.type) + except (pyarrow.ArrowInvalid, pyarrow.ArrowNotImplementedError): + col = pyarrow.nulls(num_rows, type=field.type) else: col = pyarrow.nulls(num_rows, type=field.type) columns.append(col) @@ -78,6 +87,17 @@ def to_arrow(self, splits: List[Split]) -> Optional[pyarrow.Table]: batch_reader = self.to_arrow_batch_reader(splits) schema = PyarrowFieldParser.from_paimon_schema(self.read_type) + + if self.projection is None: + table_field_names = set(f.name for f in self.table.fields) + output_schema_fields = [ + field for field in schema + if not SpecialFields.is_system_field(field.name) or field.name in table_field_names + ] + output_schema = pyarrow.schema(output_schema_fields) + else: + output_schema = schema + table_list = [] for batch in iter(batch_reader.read_next_batch, None): if batch.num_rows == 0: @@ -85,9 +105,15 @@ def to_arrow(self, splits: List[Split]) -> Optional[pyarrow.Table]: table_list.append(self._try_to_pad_batch_by_schema(batch, schema)) if not table_list: - return pyarrow.Table.from_arrays([pyarrow.array([], type=field.type) for field in schema], schema=schema) - else: - return pyarrow.Table.from_batches(table_list) + empty_arrays = [pyarrow.array([], type=field.type) for field in output_schema] + return pyarrow.Table.from_arrays(empty_arrays, schema=output_schema) + + concat_arrays = [ + pyarrow.concat_arrays([b.column(field.name) for b in table_list]) + for field in output_schema + ] + single_batch = pyarrow.RecordBatch.from_arrays(concat_arrays, schema=output_schema) + return pyarrow.Table.from_batches([single_batch], schema=output_schema) def _arrow_batch_generator(self, splits: List[Split], schema: pyarrow.Schema) -> Iterator[pyarrow.RecordBatch]: chunk_size = 65536 @@ -196,9 +222,11 @@ def _create_split_read(self, split: Split) -> SplitRead: row_tracking_enabled=False ) elif self.table.options.data_evolution_enabled(): + from pypaimon.globalindex.data_evolution_batch_scan import DataEvolutionBatchScan + predicate_for_read = DataEvolutionBatchScan.remove_row_id_filter(self.predicate) return DataEvolutionSplitRead( table=self.table, - predicate=None, # Never push predicate to split read + predicate=predicate_for_read, read_type=self.read_type, split=split, row_tracking_enabled=True diff --git 
a/paimon-python/pypaimon/tests/data_evolution_test.py b/paimon-python/pypaimon/tests/data_evolution_test.py index a9f0de508b06..787637387579 100644 --- a/paimon-python/pypaimon/tests/data_evolution_test.py +++ b/paimon-python/pypaimon/tests/data_evolution_test.py @@ -22,7 +22,13 @@ import pyarrow as pa from pypaimon import CatalogFactory, Schema +from pypaimon.common.predicate_builder import PredicateBuilder from pypaimon.manifest.manifest_list_manager import ManifestListManager +from pypaimon.read.reader.iface.record_batch_reader import RecordBatchReader +from pypaimon.read.reader.predicate_filter_record_batch_reader import ( + PredicateFilterRecordBatchReader, +) +from pypaimon.schema.data_types import AtomicType, DataField from pypaimon.snapshot.snapshot_manager import SnapshotManager @@ -86,6 +92,54 @@ def test_basic(self): ('f1', pa.int16()), ])) self.assertEqual(actual_data, expect_data) + self.assertEqual( + len(actual_data.schema), len(expect_data.schema), + 'Read output column count must match schema') + self.assertEqual( + actual_data.schema.names, expect_data.schema.names, + 'Read output column names must match schema') + + def test_partitioned_read_requested_column_missing_in_file(self): + pa_schema = pa.schema([('f0', pa.int32()), ('f1', pa.string()), ('dt', pa.string())]) + schema = Schema.from_pyarrow_schema( + pa_schema, + partition_keys=['dt'], + options={'row-tracking.enabled': 'true', 'data-evolution.enabled': 'true'} + ) + self.catalog.create_table('default.test_partition_missing_col', schema, False) + table = self.catalog.get_table('default.test_partition_missing_col') + wb = table.new_batch_write_builder() + + tw1 = wb.new_write() + tc1 = wb.new_commit() + tw1.write_arrow(pa.Table.from_pydict( + {'f0': [1, 2], 'f1': ['a', 'b'], 'dt': ['p1', 'p1']}, + schema=pa_schema + )) + tc1.commit(tw1.prepare_commit()) + tw1.close() + tc1.close() + + tw2 = wb.new_write().with_write_type(['f0', 'dt']) + tc2 = wb.new_commit() + # Row key extractor uses table column indices; pass table-ordered data with null for f1 + tw2.write_arrow(pa.Table.from_pydict( + {'f0': [3, 4], 'f1': [None, None], 'dt': ['p1', 'p1']}, + schema=pa_schema + )) + tc2.commit(tw2.prepare_commit()) + tw2.close() + tc2.close() + + actual = table.new_read_builder().new_read().to_arrow(table.new_read_builder().new_scan().plan().splits()) + self.assertEqual(len(actual.schema), 3, 'Must have f0, f1, dt (no silent drop when f1 missing in file)') + self.assertEqual(actual.schema.names, ['f0', 'f1', 'dt']) + self.assertEqual(actual.num_rows, 4) + f1_col = actual.column('f1') + self.assertEqual(f1_col[0].as_py(), 'a') + self.assertEqual(f1_col[1].as_py(), 'b') + self.assertIsNone(f1_col[2].as_py()) + self.assertIsNone(f1_col[3].as_py()) # assert manifest file meta contains min and max row id manifest_list_manager = ManifestListManager(table) @@ -225,6 +279,14 @@ def test_with_slice(self): [2, 1001, 2001], "with_slice(1, 4) should return id in (2, 1001, 2001). 
Got ids=%s" % ids, ) + scan_oob = rb.new_scan().with_slice(10, 12) + splits_oob = scan_oob.plan().splits() + result_oob = rb.new_read().to_pandas(splits_oob) + self.assertEqual( + len(result_oob), + 0, + "with_slice(10, 12) on 6 rows should return 0 rows (out of bounds), got %d" % len(result_oob), + ) # Out-of-bounds slice: 6 rows total, slice(10, 12) should return 0 rows scan_oob = rb.new_scan().with_slice(10, 12) @@ -320,6 +382,8 @@ def test_multiple_appends(self): 'f2': ['b'] * 100 + ['y'] + ['d'], }, schema=simple_pa_schema) self.assertEqual(actual, expect) + self.assertEqual(len(actual.schema), len(expect.schema), 'Merge read output column count must match schema') + self.assertEqual(actual.schema.names, expect.schema.names, 'Merge read output column names must match schema') def test_disorder_cols_append(self): simple_pa_schema = pa.schema([ @@ -689,6 +753,7 @@ def test_read_row_tracking_metadata(self): pa.field('_SEQUENCE_NUMBER', pa.int64(), nullable=False), ])) self.assertEqual(actual_data, expect_data) + self.assertEqual(len(actual_data.schema), len(expect_data.schema), 'Read output column count must match schema') # write 2 table_write = write_builder.new_write().with_write_type(['f0']) @@ -724,6 +789,7 @@ def test_read_row_tracking_metadata(self): pa.field('_SEQUENCE_NUMBER', pa.int64(), nullable=False), ])) self.assertEqual(actual_data, expect_data) + self.assertEqual(len(actual_data.schema), len(expect_data.schema), 'Read output column count must match schema') def test_from_arrays_without_schema(self): schema = pa.schema([ @@ -742,3 +808,35 @@ def test_from_arrays_without_schema(self): rebuilt = pa.RecordBatch.from_arrays(arrays, names=batch.schema.names) self.assertTrue(rebuilt.schema.field('_ROW_ID').nullable) self.assertTrue(rebuilt.schema.field('_SEQUENCE_NUMBER').nullable) + + def test_with_filter(self): + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array(['a', 'b', 'c'])], + names=['f0', 'f1'], + ) + fields = [ + DataField(0, 'f0', AtomicType('BIGINT')), + DataField(1, 'f1', AtomicType('STRING')), + DataField(2, 'f2', AtomicType('INT')), + ] + pb = PredicateBuilder(fields) + predicate = pb.greater_than('f2', 5) + + class _SingleBatchReader(RecordBatchReader): + def __init__(self, b): + self._batch, self._returned = b, False + + def read_arrow_batch(self): + if self._returned: + return None + self._returned = True + return self._batch + + def close(self): + pass + + filtered_reader = PredicateFilterRecordBatchReader( + _SingleBatchReader(batch), predicate + ) + with self.assertRaises(IndexError): + filtered_reader.read_arrow_batch() diff --git a/paimon-python/pypaimon/tests/shard_table_updator_test.py b/paimon-python/pypaimon/tests/shard_table_updator_test.py index 641f545b4711..1ce1e2f8bf8c 100644 --- a/paimon-python/pypaimon/tests/shard_table_updator_test.py +++ b/paimon-python/pypaimon/tests/shard_table_updator_test.py @@ -80,7 +80,7 @@ def test_compute_column_d_equals_c_plus_b_minus_a(self): # Step 3: Use ShardTableUpdator to compute d = c + b - a table_update = write_builder.new_update() - table_update.with_read_projection(['a', 'b', 'c']) + table_update.with_read_projection(['a', 'b', 'c', '_ROW_ID']) table_update.with_update_type(['d']) shard_updator = table_update.new_shard_updator(0, 1) @@ -93,7 +93,13 @@ def test_compute_column_d_equals_c_plus_b_minus_a(self): a_values = batch.column('a').to_pylist() b_values = batch.column('b').to_pylist() c_values = batch.column('c').to_pylist() - + row_id_values = batch.column('_ROW_ID').to_pylist() + 
self.assertEqual( + row_id_values, + list(range(len(a_values))), + '_ROW_ID should be [0, 1, 2, ...] for sequential rows', + ) + d_values = [c + b - a for a, b, c in zip(a_values, b_values, c_values)] # Create batch with d column @@ -316,5 +322,269 @@ def test_compute_column_with_existing_column(self): self.assertEqual(actual, expected) print("\n✅ Test passed! Column d = c + b - a computed correctly!") + def test_partial_shard_update_full_read_schema_unified(self): + table_schema = pa.schema([ + ('a', pa.int32()), + ('b', pa.int32()), + ('c', pa.int32()), + ('d', pa.int32()), + ]) + schema = Schema.from_pyarrow_schema( + table_schema, + options={'row-tracking.enabled': 'true', 'data-evolution.enabled': 'true'}, + ) + name = self._create_unique_table_name() + self.catalog.create_table(name, schema, False) + table = self.catalog.get_table(name) + + # Two commits => two files (two first_row_id ranges) + for start, end in [(1, 10), (10, 20)]: + wb = table.new_batch_write_builder() + tw = wb.new_write().with_write_type(['a', 'b', 'c']) + tc = wb.new_commit() + data = pa.Table.from_pydict({ + 'a': list(range(start, end + 1)), + 'b': [i * 10 for i in range(start, end + 1)], + 'c': [i * 100 for i in range(start, end + 1)], + }, schema=pa.schema([ + ('a', pa.int32()), ('b', pa.int32()), ('c', pa.int32()), + ])) + tw.write_arrow(data) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + # Only shard 0 runs => only first file gets d + wb = table.new_batch_write_builder() + upd = wb.new_update() + upd.with_read_projection(['a', 'b', 'c']) + upd.with_update_type(['d']) + shard0 = upd.new_shard_updator(0, 2) + reader = shard0.arrow_reader() + for batch in iter(reader.read_next_batch, None): + a_ = batch.column('a').to_pylist() + b_ = batch.column('b').to_pylist() + c_ = batch.column('c').to_pylist() + d_ = [c + b - a for a, b, c in zip(a_, b_, c_)] + shard0.update_by_arrow_batch(pa.RecordBatch.from_pydict( + {'d': d_}, schema=pa.schema([('d', pa.int32())]), + )) + tc = wb.new_commit() + tc.commit(shard0.prepare_commit()) + tc.close() + + rb = table.new_read_builder() + tr = rb.new_read() + actual = tr.to_arrow(rb.new_scan().plan().splits()) + self.assertEqual(actual.num_rows, 21) + d_col = actual.column('d') + # First 10 rows (shard 0): d = c+b-a + for i in range(10): + self.assertEqual(d_col[i].as_py(), (i + 1) * 100 + (i + 1) * 10 - (i + 1)) + # Rows 10-20 (shard 1 not run): d is null + for i in range(10, 21): + self.assertIsNone(d_col[i].as_py()) + + def test_with_shard_read_after_partial_shard_update(self): + table_schema = pa.schema([ + ('a', pa.int32()), + ('b', pa.int32()), + ('c', pa.int32()), + ('d', pa.int32()), + ]) + schema = Schema.from_pyarrow_schema( + table_schema, + options={'row-tracking.enabled': 'true', 'data-evolution.enabled': 'true'}, + ) + name = self._create_unique_table_name() + self.catalog.create_table(name, schema, False) + table = self.catalog.get_table(name) + + for start, end in [(1, 10), (10, 20)]: + wb = table.new_batch_write_builder() + tw = wb.new_write().with_write_type(['a', 'b', 'c']) + tc = wb.new_commit() + data = pa.Table.from_pydict({ + 'a': list(range(start, end + 1)), + 'b': [i * 10 for i in range(start, end + 1)], + 'c': [i * 100 for i in range(start, end + 1)], + }, schema=pa.schema([ + ('a', pa.int32()), ('b', pa.int32()), ('c', pa.int32()), + ])) + tw.write_arrow(data) + tc.commit(tw.prepare_commit()) + tw.close() + tc.close() + + wb = table.new_batch_write_builder() + upd = wb.new_update() + upd.with_read_projection(['a', 'b', 'c']) + 
upd.with_update_type(['d']) + shard0 = upd.new_shard_updator(0, 2) + reader = shard0.arrow_reader() + for batch in iter(reader.read_next_batch, None): + a_ = batch.column('a').to_pylist() + b_ = batch.column('b').to_pylist() + c_ = batch.column('c').to_pylist() + d_ = [c + b - a for a, b, c in zip(a_, b_, c_)] + shard0.update_by_arrow_batch(pa.RecordBatch.from_pydict( + {'d': d_}, schema=pa.schema([('d', pa.int32())]), + )) + tc = wb.new_commit() + tc.commit(shard0.prepare_commit()) + tc.close() + + rb = table.new_read_builder() + tr = rb.new_read() + + splits_0 = rb.new_scan().with_shard(0, 2).plan().splits() + result_0 = tr.to_arrow(splits_0) + self.assertEqual(result_0.num_rows, 10) + d_col_0 = result_0.column('d') + for i in range(10): + self.assertEqual( + d_col_0[i].as_py(), + (i + 1) * 100 + (i + 1) * 10 - (i + 1), + "Shard 0 row %d: d should be c+b-a" % i, + ) + + splits_1 = rb.new_scan().with_shard(1, 2).plan().splits() + result_1 = tr.to_arrow(splits_1) + self.assertEqual(result_1.num_rows, 11) + d_col_1 = result_1.column('d') + for i in range(11): + self.assertIsNone(d_col_1[i].as_py(), "Shard 1 row %d: d should be null" % i) + + full_splits = rb.new_scan().plan().splits() + full_result = tr.to_arrow(full_splits) + self.assertEqual( + result_0.num_rows + result_1.num_rows, + full_result.num_rows, + "Shard 0 + Shard 1 row count should equal full scan (21)", + ) + + rb_filter = table.new_read_builder() + rb_filter.with_projection(['a', 'b', 'c', 'd', '_ROW_ID']) + pb = rb_filter.new_predicate_builder() + pred_row_id = pb.is_in('_ROW_ID', [0, 1, 2, 3, 4]) + rb_filter.with_filter(pred_row_id) + tr_filter = rb_filter.new_read() + splits_row_id = rb_filter.new_scan().plan().splits() + result_row_id = tr_filter.to_arrow(splits_row_id) + self.assertEqual(result_row_id.num_rows, 5, "Filter _ROW_ID in [0..4] should return 5 rows") + a_col = result_row_id.column('a') + d_col_r = result_row_id.column('d') + for i in range(5): + self.assertEqual(a_col[i].as_py(), i + 1) + self.assertEqual( + d_col_r[i].as_py(), + (i + 1) * 100 + (i + 1) * 10 - (i + 1), + "Filter-by-_row_id row %d: d should be c+b-a" % i, + ) + + rb_slice = table.new_read_builder() + tr_slice = rb_slice.new_read() + slice_0 = rb_slice.new_scan().with_slice(0, 10).plan().splits() + result_slice_0 = tr_slice.to_arrow(slice_0) + self.assertEqual(result_slice_0.num_rows, 10, "with_slice(0, 10) should return 10 rows") + d_s0 = result_slice_0.column('d') + for i in range(10): + self.assertEqual( + d_s0[i].as_py(), + (i + 1) * 100 + (i + 1) * 10 - (i + 1), + "Slice [0,10) row %d: d should be c+b-a" % i, + ) + slice_1 = rb_slice.new_scan().with_slice(10, 21).plan().splits() + result_slice_1 = tr_slice.to_arrow(slice_1) + self.assertEqual(result_slice_1.num_rows, 11, "with_slice(10, 21) should return 11 rows") + d_s1 = result_slice_1.column('d') + for i in range(11): + self.assertIsNone(d_s1[i].as_py(), "Slice [10,21) row %d: d should be null" % i) + + cross_slice = rb_slice.new_scan().with_slice(5, 16).plan().splits() + result_cross = tr_slice.to_arrow(cross_slice) + self.assertEqual( + result_cross.num_rows, 11, + "Cross-shard with_slice(5, 16) should return 11 rows (5 from file1 + 6 from file2)", + ) + a_cross = result_cross.column('a') + d_cross = result_cross.column('d') + for i in range(5): + self.assertEqual(a_cross[i].as_py(), 6 + i) + self.assertEqual( + d_cross[i].as_py(), + (6 + i) * 100 + (6 + i) * 10 - (6 + i), + "Cross-shard slice row %d (from file1): d should be c+b-a" % i, + ) + for i in range(5, 11): + 
self.assertEqual(a_cross[i].as_py(), 10 + (i - 5)) + self.assertIsNone(d_cross[i].as_py(), "Cross-shard slice row %d (from file2): d null" % i) + + rb_col = table.new_read_builder() + rb_col.with_projection(['a', 'b', 'c', 'd']) + pb_col = rb_col.new_predicate_builder() + pred_d = pb_col.is_in('d', [109, 218]) # d = c+b-a for a=1,2 + rb_col.with_filter(pred_d) + tr_col = rb_col.new_read() + splits_d = rb_col.new_scan().plan().splits() + result_d = tr_col.to_arrow(splits_d) + self.assertEqual(result_d.num_rows, 2, "Filter d in [109, 218] should return 2 rows") + a_d = result_d.column('a') + d_d = result_d.column('d') + self.assertEqual(a_d[0].as_py(), 1) + self.assertEqual(d_d[0].as_py(), 109) + self.assertEqual(a_d[1].as_py(), 2) + self.assertEqual(d_d[1].as_py(), 218) + + def test_read_projection(self): + table_schema = pa.schema([ + ('a', pa.int32()), + ('b', pa.int32()), + ('c', pa.int32()), + ]) + schema = Schema.from_pyarrow_schema( + table_schema, + options={'row-tracking.enabled': 'true', 'data-evolution.enabled': 'true'} + ) + name = self._create_unique_table_name('read_proj') + self.catalog.create_table(name, schema, False) + table = self.catalog.get_table(name) + + write_builder = table.new_batch_write_builder() + table_write = write_builder.new_write().with_write_type(['a', 'b', 'c']) + table_commit = write_builder.new_commit() + init_data = pa.Table.from_pydict( + {'a': [1, 2, 3], 'b': [10, 20, 30], 'c': [100, 200, 300]}, + schema=pa.schema([('a', pa.int32()), ('b', pa.int32()), ('c', pa.int32())]) + ) + table_write.write_arrow(init_data) + cmts = table_write.prepare_commit() + for cmt in cmts: + for nf in cmt.new_files: + nf.first_row_id = 0 + table_commit.commit(cmts) + table_write.close() + table_commit.close() + + table_update = write_builder.new_update() + table_update.with_read_projection(['a', 'b', 'c']) + table_update.with_update_type(['a']) + shard_updator = table_update.new_shard_updator(0, 1) + reader = shard_updator.arrow_reader() + + batch = reader.read_next_batch() + self.assertIsNotNone(batch, "Should have at least one batch") + actual_columns = set(batch.schema.names) + + expected_columns = {'a', 'b', 'c'} + self.assertEqual( + actual_columns, + expected_columns, + "with_read_projection(['a','b','c']) should return only a,b,c; " + "got %s. _ROW_ID and _SEQUENCE_NUMBER should NOT be returned when not in projection." + % actual_columns + ) + + if __name__ == '__main__': unittest.main()