diff --git a/pyiceberg/expressions/visitors.py b/pyiceberg/expressions/visitors.py index 26241d2351..99cbc0fb66 100644 --- a/pyiceberg/expressions/visitors.py +++ b/pyiceberg/expressions/visitors.py @@ -861,6 +861,7 @@ class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]): Args: file_schema (Schema): The schema of the file. case_sensitive (bool): Whether to consider case when binding a reference to a field in a schema, defaults to True. + projected_field_values (Dict[str, Any]): Values for projected fields not present in the data file. Raises: TypeError: In the case of an UnboundPredicate. @@ -869,10 +870,12 @@ class _ColumnNameTranslator(BooleanExpressionVisitor[BooleanExpression]): file_schema: Schema case_sensitive: bool + projected_field_values: Dict[str, Any] - def __init__(self, file_schema: Schema, case_sensitive: bool) -> None: + def __init__(self, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[str, Any] = EMPTY_DICT) -> None: self.file_schema = file_schema self.case_sensitive = case_sensitive + self.projected_field_values = projected_field_values or {} def visit_true(self) -> BooleanExpression: return AlwaysTrue() @@ -897,9 +900,8 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi file_column_name = self.file_schema.find_column_name(field.field_id) if file_column_name is None: - # In the case of schema evolution, the column might not be present - # we can use the default value as a constant and evaluate it against - # the predicate + # In the case of schema evolution or column projection, the field might not be present in the file schema. + # we can use the projected value or the field's default value as a constant and evaluate it against the predicate pred: BooleanExpression if isinstance(predicate, BoundUnaryPredicate): pred = predicate.as_unbound(field.name) @@ -910,6 +912,14 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi else: raise ValueError(f"Unsupported predicate: {predicate}") + # In the order described by the "Column Projection" section of the Iceberg spec: + # https://iceberg.apache.org/spec/#column-projection + # Evaluate column projection first if it exists + if projected_field_value := self.projected_field_values.get(field.name): + if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(Record(projected_field_value)): + return AlwaysTrue() + + # Evaluate initial_default value return ( AlwaysTrue() if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(Record(field.initial_default)) @@ -926,8 +936,10 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi raise ValueError(f"Unsupported predicate: {predicate}") -def translate_column_names(expr: BooleanExpression, file_schema: Schema, case_sensitive: bool) -> BooleanExpression: - return visit(expr, _ColumnNameTranslator(file_schema, case_sensitive)) +def translate_column_names( + expr: BooleanExpression, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[str, Any] = EMPTY_DICT +) -> BooleanExpression: + return visit(expr, _ColumnNameTranslator(file_schema, case_sensitive, projected_field_values)) class _ExpressionFieldIDs(BooleanExpressionVisitor[Set[int]]): diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 2797371028..e6992843ca 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -1460,18 +1460,20 @@ def _task_to_record_batches( # the table format version. file_schema = pyarrow_to_schema(physical_schema, name_mapping, downcast_ns_timestamp_to_us=True) - pyarrow_filter = None - if bound_row_filter is not AlwaysTrue(): - translated_row_filter = translate_column_names(bound_row_filter, file_schema, case_sensitive=case_sensitive) - bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive) - pyarrow_filter = expression_to_pyarrow(bound_file_filter) - # Apply column projection rules # https://iceberg.apache.org/spec/#column-projection should_project_columns, projected_missing_fields = _get_column_projection_values( task.file, projected_schema, partition_spec, file_schema.field_ids ) + pyarrow_filter = None + if bound_row_filter is not AlwaysTrue(): + translated_row_filter = translate_column_names( + bound_row_filter, file_schema, case_sensitive=case_sensitive, projected_field_values=projected_missing_fields + ) + bound_file_filter = bind(file_schema, translated_row_filter, case_sensitive=case_sensitive) + pyarrow_filter = expression_to_pyarrow(bound_file_filter) + file_project_schema = prune_columns(file_schema, projected_field_ids, select_full_types=False) fragment_scanner = ds.Scanner.from_fragment( diff --git a/tests/expressions/test_visitors.py b/tests/expressions/test_visitors.py index 273bd24c9b..f02aadfe44 100644 --- a/tests/expressions/test_visitors.py +++ b/tests/expressions/test_visitors.py @@ -72,6 +72,7 @@ expression_to_plain_format, rewrite_not, rewrite_to_dnf, + translate_column_names, visit, visit_bound_predicate, ) @@ -79,6 +80,7 @@ from pyiceberg.schema import Accessor, Schema from pyiceberg.typedef import Record from pyiceberg.types import ( + BooleanType, DoubleType, FloatType, IcebergType, @@ -1623,3 +1625,282 @@ def test_expression_evaluator_null() -> None: assert expression_evaluator(schema, LessThan("a", 1), case_sensitive=True)(struct) is False assert expression_evaluator(schema, StartsWith("a", 1), case_sensitive=True)(struct) is False assert expression_evaluator(schema, NotStartsWith("a", 1), case_sensitive=True)(struct) is True + + +def test_translate_column_names_simple_case(table_schema_simple: Schema) -> None: + """Test translate_column_names with matching column names.""" + # Create a bound expression using the original schema + unbound_expr = EqualTo("foo", "test_value") + bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=table_schema_simple, case_sensitive=True)) + + # File schema has the same column names + file_schema = Schema( + NestedField(field_id=1, name="foo", field_type=StringType(), required=False), + NestedField(field_id=2, name="bar", field_type=IntegerType(), required=True), + NestedField(field_id=3, name="baz", field_type=BooleanType(), required=False), + schema_id=1, + ) + + # Translate column names + translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True) + + # Should return an unbound expression with the same column name since they match + assert isinstance(translated_expr, EqualTo) + assert translated_expr.term == Reference("foo") + assert translated_expr.literal == literal("test_value") + + +def test_translate_column_names_different_column_names() -> None: + """Test translate_column_names with different column names in file schema.""" + # Original schema + original_schema = Schema( + NestedField(field_id=1, name="original_name", field_type=StringType(), required=False), + schema_id=1, + ) + + # Create bound expression + unbound_expr = EqualTo("original_name", "test_value") + bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True)) + + # File schema has different column name but same field ID + file_schema = Schema( + NestedField(field_id=1, name="file_column_name", field_type=StringType(), required=False), + schema_id=1, + ) + + # Translate column names + translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True) + + # Should use the file schema's column name + assert isinstance(translated_expr, EqualTo) + assert translated_expr.term == Reference("file_column_name") + assert translated_expr.literal == literal("test_value") + + +def test_translate_column_names_missing_column() -> None: + """Test translate_column_names when column is missing from file schema (such as in schema evolution).""" + # Original schema + original_schema = Schema( + NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False), + NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False), + schema_id=1, + ) + + # Create bound expression for the missing column + unbound_expr = EqualTo("missing_col", 42) + bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True)) + + # File schema only has the existing column (field_id=1), missing field_id=2 + file_schema = Schema( + NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False), + schema_id=1, + ) + + # Translate column names + translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True) + + # missing_col's default initial_default (None) does not match the expression literal (42) + assert translated_expr == AlwaysFalse() + + +def test_translate_column_names_missing_column_match_null() -> None: + """Test translate_column_names when missing column matches null.""" + # Original schema + original_schema = Schema( + NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False), + NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False), + schema_id=1, + ) + + # Create bound expression for the missing column + unbound_expr = IsNull("missing_col") + bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True)) + + # File schema only has the existing column (field_id=1), missing field_id=2 + file_schema = Schema( + NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False), + schema_id=1, + ) + + # Translate column names + translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True) + + # Should evaluate to AlwaysTrue because the missing column is treated as null + # missing_col's default initial_default (None) satisfies the IsNull predicate + assert translated_expr == AlwaysTrue() + + +def test_translate_column_names_missing_column_with_initial_default() -> None: + """Test translate_column_names when missing column's initial_default matches expression.""" + # Original schema + original_schema = Schema( + NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False), + NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False, initial_default=42), + schema_id=1, + ) + + # Create bound expression for the missing column + unbound_expr = EqualTo("missing_col", 42) + bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True)) + + # File schema only has the existing column (field_id=1), missing field_id=2 + file_schema = Schema( + NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False), + schema_id=1, + ) + + # Translate column names + translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True) + + # Should evaluate to AlwaysTrue because the initial_default value (42) matches the literal (42) + assert translated_expr == AlwaysTrue() + + +def test_translate_column_names_missing_column_with_initial_default_mismatch() -> None: + """Test translate_column_names when missing column's initial_default doesn't match expression.""" + # Original schema + original_schema = Schema( + NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False, initial_default=10), + schema_id=1, + ) + + # Create bound expression that won't match the default value + unbound_expr = EqualTo("missing_col", 42) + bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True)) + + # File schema doesn't have this column + file_schema = Schema( + NestedField(field_id=1, name="other_col", field_type=StringType(), required=False), + schema_id=1, + ) + + # Translate column names + translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True) + + # Should evaluate to AlwaysFalse because initial_default value (10) doesn't match literal (42) + assert translated_expr == AlwaysFalse() + + +def test_translate_column_names_missing_column_with_projected_field_matches() -> None: + """Test translate_column_names with projected field value that matches expression.""" + # Original schema with a field that has no initial_default (defaults to None) + original_schema = Schema( + NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False), + NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False), + schema_id=1, + ) + + # Create bound expression for the missing column + unbound_expr = EqualTo("missing_col", 42) + bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True)) + + # File schema only has the existing column (field_id=1), missing field_id=2 + file_schema = Schema( + NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False), + schema_id=1, + ) + + # Projected column that is missing in the file schema + projected_field_values = {"missing_col": 42} + + # Translate column names + translated_expr = translate_column_names( + bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values + ) + + # Should evaluate to AlwaysTrue since projected field value matches the expression literal + # even though the field is missing in the file schema + assert translated_expr == AlwaysTrue() + + +def test_translate_column_names_missing_column_with_projected_field_mismatch() -> None: + """Test translate_column_names with projected field value that doesn't match expression.""" + # Original schema with a field that has no initial_default (defaults to None) + original_schema = Schema( + NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False), + NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False), + schema_id=1, + ) + + # Create bound expression for the missing column + unbound_expr = EqualTo("missing_col", 42) + bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True)) + + # File schema only has the existing column (field_id=1), missing field_id=2 + file_schema = Schema( + NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False), + schema_id=1, + ) + + # Projected column that is missing in the file schema + projected_field_values = {"missing_col": 1} + + # Translate column names + translated_expr = translate_column_names( + bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values + ) + + # Should evaluate to AlwaysFalse since projected field value does not match the expression literal + assert translated_expr == AlwaysFalse() + + +def test_translate_column_names_missing_column_projected_field_fallbacks_to_initial_default() -> None: + """Test translate_column_names when projected field value doesn't match but initial_default does.""" + # Original schema with a field that has an initial_default + original_schema = Schema( + NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False), + NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False, initial_default=42), + schema_id=1, + ) + + # Create bound expression for the missing column that would match initial_default + unbound_expr = EqualTo("missing_col", 42) + bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True)) + + # File schema only has the existing column (field_id=1), missing field_id=2 + file_schema = Schema( + NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False), + schema_id=1, + ) + + # Projected field value that differs from both the expression literal and initial_default + projected_field_values = {"missing_col": 10} # This doesn't match expression literal (42) + + # Translate column names + translated_expr = translate_column_names( + bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values + ) + + # Should evaluate to AlwaysTrue since projected field value doesn't match but initial_default does + assert translated_expr == AlwaysTrue() + + +def test_translate_column_names_missing_column_projected_field_matches_initial_default_mismatch() -> None: + """Test translate_column_names when both projected field value and initial_default doesn't match.""" + # Original schema with a field that has an initial_default that doesn't match the expression + original_schema = Schema( + NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False), + NestedField(field_id=2, name="missing_col", field_type=IntegerType(), required=False, initial_default=10), + schema_id=1, + ) + + # Create bound expression for the missing column + unbound_expr = EqualTo("missing_col", 42) + bound_expr = visit(unbound_expr, visitor=BindVisitor(schema=original_schema, case_sensitive=True)) + + # File schema only has the existing column (field_id=1), missing field_id=2 + file_schema = Schema( + NestedField(field_id=1, name="existing_col", field_type=StringType(), required=False), + schema_id=1, + ) + + # Projected field value that matches the expression literal + projected_field_values = {"missing_col": 10} # This doesn't match expression literal (42) + + # Translate column names + translated_expr = translate_column_names( + bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values + ) + + # Should evaluate to AlwaysFalse since both projected field value and initial_default does not match + assert translated_expr == AlwaysFalse() diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index 4f121ba3bc..ac16ef18f6 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -1197,6 +1197,16 @@ def test_identity_transform_column_projection(tmp_path: str, catalog: InMemoryCa }, schema=schema, ) + # Test that row filter works with partition value projection + assert table.scan(row_filter="partition_id = 1").to_arrow() == pa.table( + { + "other_field": ["foo", "bar", "baz"], + "partition_id": [1, 1, 1], + }, + schema=schema, + ) + # Test that row filter does not return any rows for a non-existing partition value + assert len(table.scan(row_filter="partition_id = -1").to_arrow()) == 0 def test_identity_transform_columns_projection(tmp_path: str, catalog: InMemoryCatalog) -> None: