From b76268c7f269bb06896faabff0205b1706502356 Mon Sep 17 00:00:00 2001
From: Chelsea Lin
Date: Thu, 23 Oct 2025 21:49:47 +0000
Subject: [PATCH] refactor: add parenthesization for binary operations

---

Reviewer notes (free-form commentary; not part of the commit message or
of any hunk below):

  * What changes: the `normalized_impl` wrapper built by
    `register_binary_op` now passes `args[0]` and `args[1]` through
    `_add_parentheses` before calling `impl`, in both the `pass_op` and
    the non-`pass_op` branch.
  * `_add_parentheses` returns a new `TypedExpr` holding
    `sge.paren(expr.expr, copy=False)` and the original `expr.dtype` when
    `type(expr.expr)` is a member of `SQLGLOT_NEEDS_PARENS`; otherwise it
    returns the argument unchanged.
  * The membership test is `type(...) in <set>`, i.e. an exact-class
    match: an instance of a subclass of a listed class is not wrapped.
  * `SQLGLOT_NEEDS_PARENS` is a plain `set` class attribute; nothing in
    this patch mutates it.
  * Only the binary registration path is touched here. NOTE(review):
    whether the other registration paths (the test module's context
    shows an `ops.NaryOp` implementation) need the same operand wrapping
    cannot be determined from this diff -- confirm.
  * NOTE(review): `sge.Not` and other infix forms are not listed in
    `SQLGLOT_NEEDS_PARENS`; presumably intentional -- confirm against
    the operators the compiler can actually emit as operands.
  * NOTE(review): the new test passes the string "int" as the second
    `TypedExpr` argument; confirm that is an acceptable stand-in for the
    dtype field.
  * The new test pins two renderings: "(a + b) * c" and "a * (b + c)".

 .../core/compile/sqlglot/scalar_compiler.py   | 43 ++++++++++++++++++-
 .../compile/sqlglot/test_scalar_compiler.py   | 37 ++++++++++++++++
 2 files changed, 78 insertions(+), 2 deletions(-)

diff --git a/bigframes/core/compile/sqlglot/scalar_compiler.py b/bigframes/core/compile/sqlglot/scalar_compiler.py
index 8167f40fc3..1da58871c7 100644
--- a/bigframes/core/compile/sqlglot/scalar_compiler.py
+++ b/bigframes/core/compile/sqlglot/scalar_compiler.py
@@ -31,6 +31,37 @@ class ScalarOpCompiler:
         typing.Callable[[typing.Sequence[TypedExpr], ops.RowOp], sge.Expression],
     ] = {}
 
+    # A set of SQLGlot classes that may need to be parenthesized
+    SQLGLOT_NEEDS_PARENS = {
+        # Numeric operations
+        sge.Add,
+        sge.Sub,
+        sge.Mul,
+        sge.Div,
+        sge.Mod,
+        sge.Pow,
+        # Comparison operations
+        sge.GTE,
+        sge.GT,
+        sge.LTE,
+        sge.LT,
+        sge.EQ,
+        sge.NEQ,
+        # Logical operations
+        sge.And,
+        sge.Or,
+        sge.Xor,
+        # Bitwise operations
+        sge.BitwiseAnd,
+        sge.BitwiseOr,
+        sge.BitwiseXor,
+        sge.BitwiseLeftShift,
+        sge.BitwiseRightShift,
+        sge.BitwiseNot,
+        # Other operations
+        sge.Is,
+    }
+
     @functools.singledispatchmethod
     def compile_expression(
         self,
@@ -110,10 +141,12 @@ def register_binary_op(
 
         def decorator(impl: typing.Callable[..., sge.Expression]):
             def normalized_impl(args: typing.Sequence[TypedExpr], op: ops.RowOp):
+                left = self._add_parentheses(args[0])
+                right = self._add_parentheses(args[1])
                 if pass_op:
-                    return impl(args[0], args[1], op)
+                    return impl(left, right, op)
                 else:
-                    return impl(args[0], args[1])
+                    return impl(left, right)
 
             self._register(key, normalized_impl)
             return impl
@@ -177,6 +210,12 @@ def _register(
             raise ValueError(f"Operation name {op_name} already registered")
         self._registry[op_name] = impl
 
+    @classmethod
+    def _add_parentheses(cls, expr: TypedExpr) -> TypedExpr:
+        if type(expr.expr) in cls.SQLGLOT_NEEDS_PARENS:
+            return TypedExpr(sge.paren(expr.expr, copy=False), expr.dtype)
+        return expr
+
 
 # Singleton compiler
 scalar_op_compiler = ScalarOpCompiler()
diff --git a/tests/unit/core/compile/sqlglot/test_scalar_compiler.py b/tests/unit/core/compile/sqlglot/test_scalar_compiler.py
index a2ee2c6331..14d7b47389 100644
--- a/tests/unit/core/compile/sqlglot/test_scalar_compiler.py
+++ b/tests/unit/core/compile/sqlglot/test_scalar_compiler.py
@@ -170,6 +170,43 @@ def _(*args: TypedExpr, op: ops.NaryOp) -> sge.Expression:
     mock_impl.assert_called_once_with(arg1, arg2, arg3, arg4, op=mock_op)
 
 
+def test_binary_op_parentheses():
+    compiler = scalar_compiler.ScalarOpCompiler()
+
+    class MockAddOp(ops.BinaryOp):
+        name = "mock_add_op"
+
+    class MockMulOp(ops.BinaryOp):
+        name = "mock_mul_op"
+
+    add_op = MockAddOp()
+    mul_op = MockMulOp()
+
+    @compiler.register_binary_op(add_op)
+    def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
+        return sge.Add(this=left.expr, expression=right.expr)
+
+    @compiler.register_binary_op(mul_op)
+    def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
+        return sge.Mul(this=left.expr, expression=right.expr)
+
+    a = TypedExpr(sge.Identifier(this="a"), "int")
+    b = TypedExpr(sge.Identifier(this="b"), "int")
+    c = TypedExpr(sge.Identifier(this="c"), "int")
+
+    # (a + b) * c
+    add_expr = compiler.compile_row_op(add_op, [a, b])
+    add_typed_expr = TypedExpr(add_expr, "int")
+    result1 = compiler.compile_row_op(mul_op, [add_typed_expr, c])
+    assert result1.sql() == "(a + b) * c"
+
+    # a * (b + c)
+    add_expr_2 = compiler.compile_row_op(add_op, [b, c])
+    add_typed_expr_2 = TypedExpr(add_expr_2, "int")
+    result2 = compiler.compile_row_op(mul_op, [a, add_typed_expr_2])
+    assert result2.sql() == "a * (b + c)"
+
+
 def test_register_duplicate_op_raises():
     compiler = scalar_compiler.ScalarOpCompiler()
 