Skip to content

Commit db867d3

Browse files
committed
chore: implement add_op and del_op compilers
1 parent 0468a4d commit db867d3

File tree

2 files changed

+45
-3
lines changed

2 files changed

+45
-3
lines changed

bigframes/core/compile/sqlglot/expressions/binary_compiler.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import bigframes_vendored.constants as constants
1718
import sqlglot.expressions as sge
1819

1920
from bigframes import dtypes
@@ -35,8 +36,49 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
3536
# String addition
3637
return sge.Concat(expressions=[left.expr, right.expr])
3738

38-
# Numerical addition
39-
return sge.Add(this=left.expr, expression=right.expr)
39+
if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
40+
left_expr = left.expr
41+
if left.dtype == dtypes.BOOL_DTYPE:
42+
left_expr = sge.Cast(this=left_expr, to="INT64")
43+
right_expr = right.expr
44+
if right.dtype == dtypes.BOOL_DTYPE:
45+
right_expr = sge.Cast(this=right_expr, to="INT64")
46+
return sge.Add(this=left_expr, expression=right_expr)
47+
48+
if (
49+
dtypes.is_datetime_like(left.dtype) or dtypes.is_date_like(left.dtype)
50+
) and right.dtype == dtypes.TIMEDELTA_DTYPE:
51+
left_expr = left.expr
52+
if left.dtype == dtypes.DATE_DTYPE:
53+
left_expr = sge.Cast(this=left_expr, to="DATETIME")
54+
return sge.TimestampAdd(
55+
this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND")
56+
)
57+
if (
58+
dtypes.is_datetime_like(right.dtype) or dtypes.is_date_like(right.dtype)
59+
) and left.dtype == dtypes.TIMEDELTA_DTYPE:
60+
right_expr = right.expr
61+
if right.dtype == dtypes.DATE_DTYPE:
62+
right_expr = sge.Cast(this=right_expr, to="DATETIME")
63+
return sge.TimestampAdd(
64+
this=right_expr, expression=left.expr, unit=sge.Var(this="MICROSECOND")
65+
)
66+
67+
if left.dtype == dtypes.DATE_DTYPE and right.dtype == dtypes.INT_DTYPE:
68+
return sge.DateAdd(
69+
this=left.expr, expression=right.expr, unit=sge.Var(this="DAY")
70+
)
71+
72+
if right.dtype == dtypes.DATE_DTYPE and left.dtype == dtypes.INT_DTYPE:
73+
return sge.DateAdd(
74+
this=right.expr, expression=left.expr, unit=sge.Var(this="DAY")
75+
)
76+
if left.dtype == dtypes.TIMEDELTA_DTYPE and right.dtype == dtypes.TIMEDELTA_DTYPE:
77+
return sge.Add(this=left.expr, expression=right.expr)
78+
79+
raise TypeError(
80+
f"Cannot add type {left.dtype} and {right.dtype}. {constants.FEEDBACK_LINK}"
81+
)
4082

4183

4284
@BINARY_OP_REGISTRATION.register(ops.ge_op)

tests/system/small/engines/test_numeric_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def apply_op_pairwise(
5353
return new_arr
5454

5555

56-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
56+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
5757
def test_engines_project_add(
5858
scalars_array_value: array_value.ArrayValue,
5959
engine,

0 commit comments

Comments
 (0)