diff --git a/bigframes/core/compile/polars/lowering.py b/bigframes/core/compile/polars/lowering.py index 88e2d6e599..48d63e9ed9 100644 --- a/bigframes/core/compile/polars/lowering.py +++ b/bigframes/core/compile/polars/lowering.py @@ -12,15 +12,31 @@ # See the License for the specific language governing permissions and # limitations under the License. +import dataclasses + from bigframes import dtypes from bigframes.core import bigframe_node, expression from bigframes.core.rewrite import op_lowering -from bigframes.operations import numeric_ops +from bigframes.operations import comparison_ops, numeric_ops import bigframes.operations as ops # TODO: Would be more precise to actually have separate op set for polars ops (where they diverge from the original ops) +@dataclasses.dataclass +class CoerceArgsRule(op_lowering.OpLoweringRule): + op_type: type[ops.BinaryOp] + + @property + def op(self) -> type[ops.ScalarOp]: + return self.op_type + + def lower(self, expr: expression.OpExpression) -> expression.Expression: + assert isinstance(expr.op, self.op_type) + larg, rarg = _coerce_comparables(expr.children[0], expr.children[1]) + return expr.op.as_expr(larg, rarg) + + class LowerFloorDivRule(op_lowering.OpLoweringRule): @property def op(self) -> type[ops.ScalarOp]: @@ -40,7 +56,42 @@ def lower(self, expr: expression.OpExpression) -> expression.Expression: return ops.where_op.as_expr(zero_result, divisor_is_zero, expr) -POLARS_LOWERING_RULES = (LowerFloorDivRule(),) +def _coerce_comparables(expr1: expression.Expression, expr2: expression.Expression): + + target_type = dtypes.coerce_to_common(expr1.output_type, expr2.output_type) + if expr1.output_type != target_type: + expr1 = _lower_cast(ops.AsTypeOp(target_type), expr1) + if expr2.output_type != target_type: + expr2 = _lower_cast(ops.AsTypeOp(target_type), expr2) + return expr1, expr2 + + +# TODO: Need to handle bool->string cast to get capitalization correct +def _lower_cast(cast_op: ops.AsTypeOp, arg: expression.Expression): + if arg.output_type == dtypes.BOOL_DTYPE and dtypes.is_numeric(cast_op.to_type): + # bool -> decimal needs two-step cast + new_arg = ops.AsTypeOp(to_type=dtypes.INT_DTYPE).as_expr(arg) + return cast_op.as_expr(new_arg) + return cast_op.as_expr(arg) + + +LOWER_COMPARISONS = tuple( + CoerceArgsRule(op) + for op in ( + comparison_ops.EqOp, + comparison_ops.EqNullsMatchOp, + comparison_ops.NeOp, + comparison_ops.LtOp, + comparison_ops.GtOp, + comparison_ops.LeOp, + comparison_ops.GeOp, + ) +) + +POLARS_LOWERING_RULES = ( + *LOWER_COMPARISONS, + LowerFloorDivRule(), +) def lower_ops_to_polars(root: bigframe_node.BigFrameNode) -> bigframe_node.BigFrameNode: diff --git a/bigframes/core/compile/scalar_op_compiler.py b/bigframes/core/compile/scalar_op_compiler.py index 075089bb7a..30da6b2cb2 100644 --- a/bigframes/core/compile/scalar_op_compiler.py +++ b/bigframes/core/compile/scalar_op_compiler.py @@ -1498,6 +1498,7 @@ def eq_op( x: ibis_types.Value, y: ibis_types.Value, ): + x, y = _coerce_comparables(x, y) return x == y @@ -1507,6 +1508,7 @@ def eq_nulls_match_op( y: ibis_types.Value, ): """Variant of eq_op where nulls match each other. Only use where dtypes are known to be same.""" + x, y = _coerce_comparables(x, y) literal = ibis_types.literal("$NULL_SENTINEL$") if hasattr(x, "fill_null"): left = x.cast(ibis_dtypes.str).fill_null(literal) @@ -1523,6 +1525,7 @@ def ne_op( x: ibis_types.Value, y: ibis_types.Value, ): + x, y = _coerce_comparables(x, y) return x != y @@ -1534,6 +1537,17 @@ def _null_or_value(value: ibis_types.Value, where_value: ibis_types.BooleanValue ) +def _coerce_comparables( + x: ibis_types.Value, + y: ibis_types.Value, +): + if x.type().is_boolean() and not y.type().is_boolean(): + x = x.cast(ibis_dtypes.int64) + elif y.type().is_boolean() and not x.type().is_boolean(): + y = y.cast(ibis_dtypes.int64) + return x, y + + @scalar_op_compiler.register_binary_op(ops.and_op) def and_op( x: ibis_types.Value, @@ -1735,6 +1749,7 @@ def lt_op( x: ibis_types.Value, y: ibis_types.Value, ): + x, y = _coerce_comparables(x, y) return x < y @@ -1744,6 +1759,7 @@ def le_op( x: ibis_types.Value, y: ibis_types.Value, ): + x, y = _coerce_comparables(x, y) return x <= y @@ -1753,6 +1769,7 @@ def gt_op( x: ibis_types.Value, y: ibis_types.Value, ): + x, y = _coerce_comparables(x, y) return x > y @@ -1762,6 +1779,7 @@ def ge_op( x: ibis_types.Value, y: ibis_types.Value, ): + x, y = _coerce_comparables(x, y) return x >= y diff --git a/bigframes/session/polars_executor.py b/bigframes/session/polars_executor.py index 24acda35dc..ec00e38606 100644 --- a/bigframes/session/polars_executor.py +++ b/bigframes/session/polars_executor.py @@ -32,11 +32,20 @@ nodes.OrderByNode, nodes.ReversedNode, nodes.SelectionNode, + nodes.ProjectionNode, nodes.SliceNode, nodes.AggregateNode, ) -_COMPATIBLE_SCALAR_OPS = () +_COMPATIBLE_SCALAR_OPS = ( + bigframes.operations.eq_op, + bigframes.operations.eq_null_match_op, + bigframes.operations.ne_op, + bigframes.operations.gt_op, + bigframes.operations.lt_op, + bigframes.operations.ge_op, + bigframes.operations.le_op, +) _COMPATIBLE_AGG_OPS = (agg_ops.SizeOp, agg_ops.SizeUnaryOp) diff --git a/tests/system/small/engines/test_comparison_ops.py b/tests/system/small/engines/test_comparison_ops.py new file mode 100644 index 0000000000..fefff93f58 --- /dev/null +++ b/tests/system/small/engines/test_comparison_ops.py @@ -0,0 +1,70 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools + +import pytest + +from bigframes.core import array_value +import bigframes.operations as ops +from bigframes.session import polars_executor +from bigframes.testing.engine_utils import assert_equivalence_execution + +pytest.importorskip("polars") + +# Polars used as reference as its fast and local. Generally though, prefer gbq engine where they disagree. +REFERENCE_ENGINE = polars_executor.PolarsExecutor() + +# numeric domain + + +def apply_op_pairwise( + array: array_value.ArrayValue, op: ops.BinaryOp, excluded_cols=[] +) -> array_value.ArrayValue: + exprs = [] + for l_arg, r_arg in itertools.permutations(array.column_ids, 2): + if (l_arg in excluded_cols) or (r_arg in excluded_cols): + continue + try: + _ = op.output_type( + array.get_column_type(l_arg), array.get_column_type(r_arg) + ) + exprs.append(op.as_expr(l_arg, r_arg)) + except TypeError: + continue + assert len(exprs) > 0 + new_arr, _ = array.compute_values(exprs) + return new_arr + + +@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize( + "op", + [ + ops.eq_op, + ops.eq_null_match_op, + ops.ne_op, + ops.gt_op, + ops.lt_op, + ops.le_op, + ops.ge_op, + ], +) +def test_engines_project_comparison_op( + scalars_array_value: array_value.ArrayValue, engine, op +): + # exclude string cols as does not contain dates + # bool col actually doesn't work properly for bq engine + arr = apply_op_pairwise(scalars_array_value, op, excluded_cols=["string_col"]) + assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)