From a28b70bf8e52636364d007f2db6a406189215b3d Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Fri, 20 Jun 2025 23:04:38 +0000 Subject: [PATCH 1/3] test: Add cross-validation for equals op between engines --- bigframes/session/polars_executor.py | 3 +- .../small/engines/test_comparison_ops.py | 53 +++++++++++++++++++ 2 files changed, 55 insertions(+), 1 deletion(-) create mode 100644 tests/system/small/engines/test_comparison_ops.py diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py index 6f1f35764c..6aa705c624 100644 --- a/bigframes/session/polars_executor.py +++ b/bigframes/session/polars_executor.py @@ -31,9 +31,10 @@ nodes.OrderByNode, nodes.ReversedNode, nodes.SelectionNode, + nodes.ProjectionNode, ) -_COMPATIBLE_SCALAR_OPS = () +_COMPATIBLE_SCALAR_OPS = (bigframes.operations.eq_op,) def _get_expr_ops(expr: expression.Expression) -> set[bigframes.operations.ScalarOp]: diff --git a/tests/system/small/engines/test_comparison_ops.py b/tests/system/small/engines/test_comparison_ops.py new file mode 100644 index 0000000000..2c7a5efbba --- /dev/null +++ b/tests/system/small/engines/test_comparison_ops.py @@ -0,0 +1,53 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools + +import pytest + +from bigframes.core import array_value +import bigframes.operations as ops +from bigframes.session import polars_executor +from bigframes.testing.engine_utils import assert_equivalence_execution + +pytest.importorskip("polars") + +# Polars used as reference as its fast and local. Generally though, prefer gbq engine where they disagree. +REFERENCE_ENGINE = polars_executor.PolarsExecutor() + + +def apply_op_pairwise( + array: array_value.ArrayValue, op: ops.BinaryOp +) -> array_value.ArrayValue: + exprs = [] + for l_arg, r_arg in itertools.permutations(array.column_ids, 2): + try: + _ = op.output_type( + array.get_column_type(l_arg), array.get_column_type(r_arg) + ) + exprs.append(op.as_expr(l_arg, r_arg)) + except TypeError: + continue + assert len(exprs) > 0 + new_arr, _ = array.compute_values(exprs) + return new_arr + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +def test_engines_project_eq_op( + scalars_array_value: array_value.ArrayValue, + engine, +): + arr = apply_op_pairwise(scalars_array_value, ops.eq_op) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) From 6ea9b94a185683fca712ef3074fc4a7e945adfe2 Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Thu, 26 Jun 2025 00:40:42 +0000 Subject: [PATCH 2/3] add lowering rule for eq_op --- bigframes/core/compile/polars/lowering.py | 38 ++++++++++++++++++- .../small/engines/test_comparison_ops.py | 9 ++++- 2 files changed, 43 insertions(+), 4 deletions(-) diff --git a/bigframes/core/compile/polars/lowering.py b/bigframes/core/compile/polars/lowering.py index 88e2d6e599..fd46a6b6b4 100644 --- a/bigframes/core/compile/polars/lowering.py +++ b/bigframes/core/compile/polars/lowering.py @@ -15,12 +15,22 @@ from bigframes import dtypes from bigframes.core import bigframe_node, expression from bigframes.core.rewrite import op_lowering -from bigframes.operations import numeric_ops +from bigframes.operations import comparison_ops, numeric_ops import bigframes.operations as ops # TODO: Would be more precise to actually have separate op set for polars ops (where they diverge from the original ops) +class LowerEqRule(op_lowering.OpLoweringRule): + @property + def op(self) -> type[ops.ScalarOp]: + return comparison_ops.EqOp + + def lower(self, expr: expression.OpExpression) -> expression.Expression: + larg, rarg = _coerce_comparables(expr.children[0], expr.children[1]) + return ops.eq_op.as_expr(larg, rarg) + + class LowerFloorDivRule(op_lowering.OpLoweringRule): @property def op(self) -> type[ops.ScalarOp]: @@ -40,7 +50,31 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression: return ops.where_op.as_expr(zero_result, divisor_is_zero, expr) -POLARS_LOWERING_RULES = (LowerFloorDivRule(),) +def _coerce_comparables(expr1: expression.Expression, expr2: expression.Expression): + + target_type = dtypes.coerce_to_common(expr1.output_type, expr2.output_type) + if expr1.output_type != target_type: + expr1 = _lower_cast(ops.AsTypeOp(target_type), expr1) + if expr2.output_type != target_type: + expr2 = _lower_cast(ops.AsTypeOp(target_type), expr2) + return expr1, expr2 + + +def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression): + if arg.output_type == dtypes.BOOL_DTYPE and dtypes.is_numeric(cast_op.to_type): + # bool -> decimal needs two-step cast + new_arg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg) + return cast_op.as_expr(new_arg) + if arg.output_type == dtypes.BOOL_DTYPE and cast_op.to_type == dtypes.STRING_DTYPE: + new_arg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg) + return cast_op.as_expr(new_arg) + return cast_op.as_expr(arg) + + +POLARS_LOWERING_RULES = ( + LowerEqRule(), + LowerFloorDivRule(), +) def lower_ops_to_polars(root: bigframe_node.BigFrameNode) -> bigframe_node.BigFrameNode: diff --git a/tests/system/small/engines/test_comparison_ops.py b/tests/system/small/engines/test_comparison_ops.py index 2c7a5efbba..2b69b998da 100644 --- a/tests/system/small/engines/test_comparison_ops.py +++ b/tests/system/small/engines/test_comparison_ops.py @@ -28,10 +28,12 @@ def apply_op_pairwise( - array: array_value.ArrayValue, op: ops.BinaryOp + array: array_value.ArrayValue, op: ops.BinaryOp, excluded_cols=[] ) -> array_value.ArrayValue: exprs = [] for l_arg, r_arg in itertools.permutations(array.column_ids, 2): + if (l_arg in excluded_cols) or (r_arg in excluded_cols): + continue try: _ = op.output_type( array.get_column_type(l_arg), array.get_column_type(r_arg) @@ -49,5 +51,8 @@ def test_engines_project_eq_op( scalars_array_value: array_value.ArrayValue, engine, ): - arr = apply_op_pairwise(scalars_array_value, ops.eq_op) + # exclude string cols as does not contain dates + arr = apply_op_pairwise( + scalars_array_value, ops.eq_op, excluded_cols=["string_col"] + ) assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) From d9f38d1dcd4e60e026dd43b552654b579962155b Mon Sep 17 00:00:00 2001 From: Trevor Bergeron Date: Thu, 26 Jun 2025 20:02:24 +0000 Subject: [PATCH 3/3] generalize to comparison ops --- bigframes/core/compile/polars/lowering.py | 31 ++++++++++++++----- bigframes/core/compile/scalar_op_compiler.py | 18 +++++++++++ bigframes/session/polars_executor.py | 10 +++++- .../small/engines/test_comparison_ops.py | 24 ++++++++++---- 4 files changed, 69 insertions(+), 14 deletions(-) diff --git a/bigframes/core/compile/polars/lowering.py b/bigframes/core/compile/polars/lowering.py index fd46a6b6b4..48d63e9ed9 100644 --- a/bigframes/core/compile/polars/lowering.py +++ b/bigframes/core/compile/polars/lowering.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses + from bigframes import dtypes from bigframes.core import bigframe_node, expression from bigframes.core.rewrite import op_lowering @@ -21,14 +23,18 @@ # TODO: Would be more precise to actually have separate op set for polars ops (where they diverge from the original ops) -class LowerEqRule(op_lowering.OpLoweringRule): +@dataclasses.dataclass +class CoerceArgsRule(op_lowering.OpLoweringRule): + op_type: type[ops.BinaryOp] + @property def op(self) -> type[ops.ScalarOp]: - return comparison_ops.EqOp + return self.op_type def lower(self, expr: expression.OpExpression) -> expression.Expression: + assert isinstance(expr.op, self.op_type) larg, rarg = _coerce_comparables(expr.children[0], expr.children[1]) - return ops.eq_op.as_expr(larg, rarg) + return expr.op.as_expr(larg, rarg) class LowerFloorDivRule(op_lowering.OpLoweringRule): @@ -60,19 +66,30 @@ def _coerce_comparables(expr1: expression.Expression, expr2: expression.Expressi return expr1, expr2 +# TODO: Need to handle bool->string cast to get capitalization correct def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression): if arg.output_type == dtypes.BOOL_DTYPE and dtypes.is_numeric(cast_op.to_type): # bool -> decimal needs two-step cast new_arg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg) return cast_op.as_expr(new_arg) - if arg.output_type == dtypes.BOOL_DTYPE and cast_op.to_type == dtypes.STRING_DTYPE: - new_arg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg) - return cast_op.as_expr(new_arg) return cast_op.as_expr(arg) +LOWER_COMPARISONS = tuple( + CoerceArgsRule(op) + for op in ( + comparison_ops.EqOp, + comparison_ops.EqNullsMatchOp, + comparison_ops.NeOp, + comparison_ops.LtOp, + comparison_ops.GtOp, + comparison_ops.LeOp, + comparison_ops.GeOp, + ) +) + POLARS_LOWERING_RULES = ( - LowerEqRule(), + *LOWER_COMPARISONS, LowerFloorDivRule(), ) diff --git a/bigframes/core/compile/scalar_op_compiler.py b/bigframes/core/compile/scalar_op_compiler.py index 075089bb7a..30da6b2cb2 100644 --- a/bigframes/core/compile/scalar_op_compiler.py +++ b/bigframes/core/compile/scalar_op_compiler.py @@ -1498,6 +1498,7 @@ def eq_op( x: ibis_types.Value, y: ibis_types.Value, ): + x, y = _coerce_comparables(x, y) return x == y @@ -1507,6 +1508,7 @@ def eq_nulls_match_op( y: ibis_types.Value, ): """Variant of eq_op where nulls match each other. Only use where dtypes are known to be same.""" + x, y = _coerce_comparables(x, y) literal = ibis_types.literal("$NULL_SENTINEL$") if hasattr(x, "fill_null"): left = x.cast(ibis_dtypes.str).fill_null(literal) @@ -1523,6 +1525,7 @@ def ne_op( x: ibis_types.Value, y: ibis_types.Value, ): + x, y = _coerce_comparables(x, y) return x != y @@ -1534,6 +1537,17 @@ def _null_or_value(value: ibis_types.Value, where_value: ibis_types.BooleanValue ) +def _coerce_comparables( + x: ibis_types.Value, + y: ibis_types.Value, +): + if x.type().is_boolean() and not y.type().is_boolean(): + x = x.cast(ibis_dtypes.int64) + elif y.type().is_boolean() and not x.type().is_boolean(): + y = y.cast(ibis_dtypes.int64) + return x, y + + @scalar_op_compiler.register_binary_op(ops.and_op) def and_op( x: ibis_types.Value, @@ -1735,6 +1749,7 @@ def lt_op( x: ibis_types.Value, y: ibis_types.Value, ): + x, y = _coerce_comparables(x, y) return x < y @@ -1744,6 +1759,7 @@ def le_op( x: ibis_types.Value, y: ibis_types.Value, ): + x, y = _coerce_comparables(x, y) return x <= y @@ -1753,6 +1769,7 @@ def gt_op( x: ibis_types.Value, y: ibis_types.Value, ): + x, y = _coerce_comparables(x, y) return x > y @@ -1762,6 +1779,7 @@ def ge_op( x: ibis_types.Value, y: ibis_types.Value, ): + x, y = _coerce_comparables(x, y) return x >= y diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py index a95b7751ba..0a74cfc96d 100644 --- a/bigframes/session/polars_executor.py +++ b/bigframes/session/polars_executor.py @@ -35,7 +35,15 @@ nodes.SliceNode, ) -_COMPATIBLE_SCALAR_OPS = (bigframes.operations.eq_op,) +_COMPATIBLE_SCALAR_OPS = ( + bigframes.operations.eq_op, + bigframes.operations.eq_null_match_op, + bigframes.operations.ne_op, + bigframes.operations.gt_op, + bigframes.operations.lt_op, + bigframes.operations.ge_op, + bigframes.operations.le_op, +) def _get_expr_ops(expr: expression.Expression) -> set[bigframes.operations.ScalarOp]: diff --git a/tests/system/small/engines/test_comparison_ops.py b/tests/system/small/engines/test_comparison_ops.py index 2b69b998da..fefff93f58 100644 --- a/tests/system/small/engines/test_comparison_ops.py +++ b/tests/system/small/engines/test_comparison_ops.py @@ -26,6 +26,8 @@ # Polars used as reference as its fast and local. Generally though, prefer gbq engine where they disagree. REFERENCE_ENGINE = polars_executor.PolarsExecutor() +# numeric domain + def apply_op_pairwise( array: array_value.ArrayValue, op: ops.BinaryOp, excluded_cols=[] @@ -47,12 +49,22 @@ def apply_op_pairwise( @pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) -def test_engines_project_eq_op( - scalars_array_value: array_value.ArrayValue, - engine, +@pytest.mark.parametrize( + "op", + [ + ops.eq_op, + ops.eq_null_match_op, + ops.ne_op, + ops.gt_op, + ops.lt_op, + ops.le_op, + ops.ge_op, + ], +) +def test_engines_project_comparison_op( + scalars_array_value: array_value.ArrayValue, engine, op ): # exclude string cols as does not contain dates - arr = apply_op_pairwise( - scalars_array_value, ops.eq_op, excluded_cols=["string_col"] - ) + # bool col actually doesn't work properly for bq engine + arr = apply_op_pairwise(scalars_array_value, op, excluded_cols=["string_col"]) assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)