Skip to content

Commit fc9b79d

Browse files
committed
Enable add tests migrated Hive tables
1 parent 8042d82 commit fc9b79d

File tree

3 files changed

+21
-40
lines changed

3 files changed

+21
-40
lines changed

pyiceberg/expressions/visitors.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -915,17 +915,13 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
915915

916916
# In the order described by the "Column Projection" section of the Iceberg spec:
917917
# https://iceberg.apache.org/spec/#column-projection
918-
# Evaluate column projection first if it exists
919-
if field_id in self.projected_field_values:
920-
if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(
921-
Record(self.projected_field_values[field_id])
922-
):
923-
return AlwaysTrue()
924-
925-
# Evaluate initial_default value
918+
# Evaluate column projection first if it exists, otherwise default to the initial-default-value
919+
field_value = (
920+
self.projected_field_values[field_id] if field.field_id in self.projected_field_values else field.initial_default
921+
)
926922
return (
927923
AlwaysTrue()
928-
if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(Record(field.initial_default))
924+
if expression_evaluator(Schema(field), pred, case_sensitive=self.case_sensitive)(Record(field_value))
929925
else AlwaysFalse()
930926
)
931927

@@ -940,7 +936,7 @@ def visit_bound_predicate(self, predicate: BoundPredicate[L]) -> BooleanExpressi
940936

941937

942938
def translate_column_names(
943-
expr: BooleanExpression, file_schema: Schema, case_sensitive: bool, projected_field_values: Dict[int, Any] = EMPTY_DICT
939+
expr: BooleanExpression, file_schema: Schema, case_sensitive: bool = True, projected_field_values: Dict[int, Any] = EMPTY_DICT
944940
) -> BooleanExpression:
945941
return visit(expr, _ColumnNameTranslator(file_schema, case_sensitive, projected_field_values))
946942

tests/expressions/test_visitors.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1750,7 +1750,7 @@ def test_translate_column_names_missing_column_match_explicit_null() -> None:
17501750
)
17511751

17521752
# Translate column names
1753-
translated_expr = translate_column_names(bound_expr, file_schema, case_sensitive=True, projected_field_values={2: None})
1753+
translated_expr = translate_column_names(bound_expr, file_schema, projected_field_values={2: None})
17541754

17551755
# Should evaluate to AlwaysTrue because the missing column is treated as null
17561756
# missing_col's default initial_default (None) satisfies the IsNull predicate
@@ -1828,12 +1828,7 @@ def test_translate_column_names_missing_column_with_projected_field_matches() ->
18281828
)
18291829

18301830
# Projected column that is missing in the file schema
1831-
projected_field_values = {2: 42}
1832-
1833-
# Translate column names
1834-
translated_expr = translate_column_names(
1835-
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
1836-
)
1831+
translated_expr = translate_column_names(bound_expr, file_schema, projected_field_values={2: 42})
18371832

18381833
# Should evaluate to AlwaysTrue since projected field value matches the expression literal
18391834
# even though the field is missing in the file schema
@@ -1860,12 +1855,7 @@ def test_translate_column_names_missing_column_with_projected_field_mismatch() -
18601855
)
18611856

18621857
# Projected column that is missing in the file schema
1863-
projected_field_values = {2: 1}
1864-
1865-
# Translate column names
1866-
translated_expr = translate_column_names(
1867-
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
1868-
)
1858+
translated_expr = translate_column_names(bound_expr, file_schema, projected_field_values={2: 1})
18691859

18701860
# Should evaluate to AlwaysFalse since projected field value does not match the expression literal
18711861
assert translated_expr == AlwaysFalse()
@@ -1891,15 +1881,14 @@ def test_translate_column_names_missing_column_projected_field_fallbacks_to_init
18911881
)
18921882

18931883
# Projected field value that differs from both the expression literal and initial_default
1894-
projected_field_values = {2: 10} # This doesn't match expression literal (42)
1895-
1896-
# Translate column names
18971884
translated_expr = translate_column_names(
1898-
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
1885+
bound_expr,
1886+
file_schema,
1887+
projected_field_values={2: 10}, # This doesn't match expression literal (42)
18991888
)
19001889

1901-
# Should evaluate to AlwaysTrue since projected field value doesn't match but initial_default does
1902-
assert translated_expr == AlwaysTrue()
1890+
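# Should evaluate to AlwaysFalse since the projected field value (10) doesn't match the expression literal (42); the projected value takes precedence over initial_default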
1891+
assert translated_expr == AlwaysFalse()
19031892

19041893

19051894
def test_translate_column_names_missing_column_projected_field_matches_initial_default_mismatch() -> None:
@@ -1922,11 +1911,10 @@ def test_translate_column_names_missing_column_projected_field_matches_initial_d
19221911
)
19231912

19241913
# Projected field value that does not match the expression literal
1925-
projected_field_values = {2: 10} # This doesn't match expression literal (42)
1926-
1927-
# Translate column names
19281914
translated_expr = translate_column_names(
1929-
bound_expr, file_schema, case_sensitive=True, projected_field_values=projected_field_values
1915+
bound_expr,
1916+
file_schema,
1917+
projected_field_values={2: 10}, # This doesn't match expression literal (42)
19301918
)
19311919

19321920
# Should evaluate to AlwaysFalse since neither the projected field value nor the initial_default matches

tests/integration/test_hive_migration.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
import time
18+
from datetime import date
1819

1920
import pytest
2021
from pyspark.sql import SparkSession
@@ -75,12 +76,8 @@ def test_migrate_table(
7576
tbl = session_catalog_hive.load_table(dst_table_identifier)
7677
assert tbl.schema().column_names == ["number", "dt"]
7778

78-
# TODO: Returns the primitive type (int), rather than the logical type
79-
# assert set(tbl.scan().to_arrow().column(1).combine_chunks().tolist()) == {'2022-01-01', '2023-01-01'}
80-
79+
assert set(tbl.scan().to_arrow().column(1).combine_chunks().tolist()) == {date(2023, 1, 1), date(2022, 1, 1)}
8180
assert tbl.scan(row_filter="number > 3").to_arrow().column(0).combine_chunks().tolist() == [4, 5, 6]
82-
8381
assert tbl.scan(row_filter="dt == '2023-01-01'").to_arrow().column(0).combine_chunks().tolist() == [4, 5, 6]
84-
85-
# TODO: Issue with filtering the projected column
86-
# assert tbl.scan(row_filter="dt == '2022-01-01'").to_arrow().column(0).combine_chunks().tolist() == [1, 2, 3]
82+
assert tbl.scan(row_filter="dt == '2022-01-01'").to_arrow().column(0).combine_chunks().tolist() == [1, 2, 3]
83+
assert tbl.scan(row_filter="dt < '2022-02-01'").to_arrow().column(0).combine_chunks().tolist() == [1, 2, 3]

0 commit comments

Comments
 (0)