Skip to content

Commit e98ecaa

Browse files
committed
refactor: add parenthesization for binary operations
1 parent e0b2257 commit e98ecaa

File tree

2 files changed

+80
-2
lines changed

2 files changed

+80
-2
lines changed

bigframes/core/compile/sqlglot/scalar_compiler.py

Lines changed: 43 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,37 @@ class ScalarOpCompiler:
3131
typing.Callable[[typing.Sequence[TypedExpr], ops.RowOp], sge.Expression],
3232
] = {}
3333

34+
# SQLGlot expression classes that are wrapped in parentheses when they are
# used as an operand of a registered binary op (see `_add_parentheses`, which
# looks up the operand's exact type in this collection).
# Stored as a frozenset: this is a class-level constant shared by every
# compiler instance, so it must not be mutated by accident.
SQLGLOT_NEEDS_PARENS: typing.FrozenSet[typing.Type[sge.Expression]] = frozenset(
    {
        # Numeric operations
        sge.Add,
        sge.Sub,
        sge.Mul,
        sge.Div,
        sge.Mod,
        sge.Pow,
        # Comparison operations
        sge.GTE,
        sge.GT,
        sge.LTE,
        sge.LT,
        sge.EQ,
        sge.NEQ,
        # Logical operations
        sge.And,
        sge.Or,
        sge.Xor,
        # Bitwise operations
        sge.BitwiseAnd,
        sge.BitwiseOr,
        sge.BitwiseXor,
        sge.BitwiseLeftShift,
        sge.BitwiseRightShift,
        sge.BitwiseNot,
        # Other operations
        sge.Is,
    }
)
64+
3465
@functools.singledispatchmethod
3566
def compile_expression(
3667
self,
@@ -110,10 +141,14 @@ def register_binary_op(
110141

111142
def decorator(impl: typing.Callable[..., sge.Expression]):
112143
def normalized_impl(args: typing.Sequence[TypedExpr], op: ops.RowOp):
144+
# TODO: If the op is associative, we can skip parentheses of the left
145+
# when left and right are of the same op type.
146+
left = self._add_parentheses(args[0])
147+
right = self._add_parentheses(args[1])
113148
if pass_op:
114-
return impl(args[0], args[1], op)
149+
return impl(left, right, op)
115150
else:
116-
return impl(args[0], args[1])
151+
return impl(left, right)
117152

118153
self._register(key, normalized_impl)
119154
return impl
@@ -177,6 +212,12 @@ def _register(
177212
raise ValueError(f"Operation name {op_name} already registered")
178213
self._registry[op_name] = impl
179214

215+
@classmethod
def _add_parentheses(cls, expr: TypedExpr) -> TypedExpr:
    """Return ``expr`` wrapped in parentheses if its node type calls for it.

    The wrapped SQLGlot node is parenthesized only when its exact type is
    listed in ``SQLGLOT_NEEDS_PARENS``; subclasses of listed types do not
    match. Any other expression is returned unchanged (the same object).

    Args:
        expr: The typed expression to inspect.

    Returns:
        A new ``TypedExpr`` with the parenthesized node and the original
        dtype, or ``expr`` itself when no parentheses are added.
    """
    node = expr.expr
    # Guard clause: node types outside the set pass through untouched.
    if type(node) not in cls.SQLGLOT_NEEDS_PARENS:
        return expr
    # copy=False is forwarded to sge.paren; presumably this reuses the
    # existing node instead of duplicating it -- confirm against SQLGlot docs.
    return TypedExpr(sge.paren(node, copy=False), expr.dtype)
220+
180221

181222
# Singleton compiler
182223
scalar_op_compiler = ScalarOpCompiler()

tests/unit/core/compile/sqlglot/test_scalar_compiler.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,43 @@ def _(*args: TypedExpr, op: ops.NaryOp) -> sge.Expression:
170170
mock_impl.assert_called_once_with(arg1, arg2, arg3, arg4, op=mock_op)
171171

172172

173+
def test_binary_op_parentheses():
    """An Add expression used as either operand of a Mul is parenthesized."""
    compiler = scalar_compiler.ScalarOpCompiler()

    class MockAddOp(ops.BinaryOp):
        name = "mock_add_op"

    class MockMulOp(ops.BinaryOp):
        name = "mock_mul_op"

    plus, times = MockAddOp(), MockMulOp()

    @compiler.register_binary_op(plus)
    def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
        return sge.Add(this=left.expr, expression=right.expr)

    @compiler.register_binary_op(times)
    def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
        return sge.Mul(this=left.expr, expression=right.expr)

    # Three plain identifier operands, all typed "int".
    a, b, c = (
        TypedExpr(sge.Identifier(this=column), "int") for column in ("a", "b", "c")
    )

    # Sum on the left-hand side of the product: (a + b) * c
    left_sum = TypedExpr(compiler.compile_row_op(plus, [a, b]), "int")
    product_of_sum = compiler.compile_row_op(times, [left_sum, c])
    assert product_of_sum.sql() == "(a + b) * c"

    # Sum on the right-hand side of the product: a * (b + c)
    right_sum = TypedExpr(compiler.compile_row_op(plus, [b, c]), "int")
    product_with_sum = compiler.compile_row_op(times, [a, right_sum])
    assert product_with_sum.sql() == "a * (b + c)"
208+
209+
173210
def test_register_duplicate_op_raises():
174211
compiler = scalar_compiler.ScalarOpCompiler()
175212

0 commit comments

Comments
 (0)