Skip to content

Commit 6f87b9a

Browse files
committed
refactor: add parenthesization for binary operations
1 parent e0b2257 commit 6f87b9a

File tree

2 files changed

+81
-2
lines changed

2 files changed

+81
-2
lines changed

bigframes/core/compile/sqlglot/scalar_compiler.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,45 @@
2424
import bigframes.operations as ops
2525

2626

27+
# TODO: add parenthesize for operators
2728
class ScalarOpCompiler:
2829
# Mapping of operation name to implemenations
2930
_registry: dict[
3031
str,
3132
typing.Callable[[typing.Sequence[TypedExpr], ops.RowOp], sge.Expression],
3233
] = {}
3334

35+
# A set of SQLGlot classes that may need to be parenthesized
36+
SQLGLOT_NEEDS_PARENS = {
37+
# Numeric operations
38+
sge.Add,
39+
sge.Sub,
40+
sge.Mul,
41+
sge.Div,
42+
sge.Mod,
43+
sge.Pow,
44+
# Comparison operations
45+
sge.GTE,
46+
sge.GT,
47+
sge.LTE,
48+
sge.LT,
49+
sge.EQ,
50+
sge.NEQ,
51+
# Logical operations
52+
sge.And,
53+
sge.Or,
54+
sge.Xor,
55+
# Bitwise operations
56+
sge.BitwiseAnd,
57+
sge.BitwiseOr,
58+
sge.BitwiseXor,
59+
sge.BitwiseLeftShift,
60+
sge.BitwiseRightShift,
61+
sge.BitwiseNot,
62+
# Other operations
63+
sge.Is,
64+
}
65+
3466
@functools.singledispatchmethod
3567
def compile_expression(
3668
self,
@@ -110,10 +142,14 @@ def register_binary_op(
110142

111143
def decorator(impl: typing.Callable[..., sge.Expression]):
112144
def normalized_impl(args: typing.Sequence[TypedExpr], op: ops.RowOp):
145+
# TODO: If the op is associative, we can skip parentheses of the left
146+
# when left and right are of the same op type.
147+
left = self._add_parentheses(args[0])
148+
right = self._add_parentheses(args[1])
113149
if pass_op:
114-
return impl(args[0], args[1], op)
150+
return impl(left, right, op)
115151
else:
116-
return impl(args[0], args[1])
152+
return impl(left, right)
117153

118154
self._register(key, normalized_impl)
119155
return impl
@@ -177,6 +213,12 @@ def _register(
177213
raise ValueError(f"Operation name {op_name} already registered")
178214
self._registry[op_name] = impl
179215

216+
@classmethod
217+
def _add_parentheses(cls, expr: TypedExpr) -> TypedExpr:
218+
if type(expr.expr) in cls.SQLGLOT_NEEDS_PARENS:
219+
return TypedExpr(sge.paren(expr.expr, copy=False), expr.dtype)
220+
return expr
221+
180222

181223
# Singleton compiler
182224
scalar_op_compiler = ScalarOpCompiler()

tests/unit/core/compile/sqlglot/test_scalar_compiler.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,43 @@ def _(*args: TypedExpr, op: ops.NaryOp) -> sge.Expression:
170170
mock_impl.assert_called_once_with(arg1, arg2, arg3, arg4, op=mock_op)
171171

172172

173+
def test_binary_op_parentheses():
174+
compiler = scalar_compiler.ScalarOpCompiler()
175+
176+
class MockAddOp(ops.BinaryOp):
177+
name = "mock_add_op"
178+
179+
class MockMulOp(ops.BinaryOp):
180+
name = "mock_mul_op"
181+
182+
add_op = MockAddOp()
183+
mul_op = MockMulOp()
184+
185+
@compiler.register_binary_op(add_op)
186+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
187+
return sge.Add(this=left.expr, expression=right.expr)
188+
189+
@compiler.register_binary_op(mul_op)
190+
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
191+
return sge.Mul(this=left.expr, expression=right.expr)
192+
193+
a = TypedExpr(sge.Identifier(this="a"), "int")
194+
b = TypedExpr(sge.Identifier(this="b"), "int")
195+
c = TypedExpr(sge.Identifier(this="c"), "int")
196+
197+
# (a + b) * c
198+
add_expr = compiler.compile_row_op(add_op, [a, b])
199+
add_typed_expr = TypedExpr(add_expr, "int")
200+
result1 = compiler.compile_row_op(mul_op, [add_typed_expr, c])
201+
assert result1.sql() == "(a + b) * c"
202+
203+
# a * (b + c)
204+
add_expr_2 = compiler.compile_row_op(add_op, [b, c])
205+
add_typed_expr_2 = TypedExpr(add_expr_2, "int")
206+
result2 = compiler.compile_row_op(mul_op, [a, add_typed_expr_2])
207+
assert result2.sql() == "a * (b + c)"
208+
209+
173210
def test_register_duplicate_op_raises():
174211
compiler = scalar_compiler.ScalarOpCompiler()
175212

0 commit comments

Comments
 (0)