Skip to content

Commit f2c80bb

Browse files
committed
add sub_op implementations
1 parent db867d3 commit f2c80bb

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

bigframes/core/compile/sqlglot/expressions/binary_compiler.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,51 @@ def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
8181
)
8282

8383

84+
@BINARY_OP_REGISTRATION.register(ops.sub_op)
85+
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
86+
if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype):
87+
left_expr = left.expr
88+
if left.dtype == dtypes.BOOL_DTYPE:
89+
left_expr = sge.Cast(this=left_expr, to="INT64")
90+
right_expr = right.expr
91+
if right.dtype == dtypes.BOOL_DTYPE:
92+
right_expr = sge.Cast(this=right_expr, to="INT64")
93+
return sge.Sub(this=left_expr, expression=right_expr)
94+
95+
if (
96+
dtypes.is_datetime_like(left.dtype) or dtypes.is_date_like(left.dtype)
97+
) and right.dtype == dtypes.TIMEDELTA_DTYPE:
98+
left_expr = left.expr
99+
if left.dtype == dtypes.DATE_DTYPE:
100+
left_expr = sge.Cast(this=left_expr, to="DATETIME")
101+
return sge.TimestampSub(
102+
this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND")
103+
)
104+
if (dtypes.is_datetime_like(left.dtype) or dtypes.is_date_like(left.dtype)) and (
105+
dtypes.is_datetime_like(right.dtype) or dtypes.is_date_like(right.dtype)
106+
):
107+
left_expr = left.expr
108+
if left.dtype == dtypes.DATE_DTYPE:
109+
left_expr = sge.Cast(this=left_expr, to="DATETIME")
110+
right_expr = right.expr
111+
if right.dtype == dtypes.DATE_DTYPE:
112+
right_expr = sge.Cast(this=right_expr, to="DATETIME")
113+
return sge.TimestampDiff(
114+
this=left_expr, expression=right_expr, unit=sge.Var(this="MICROSECOND")
115+
)
116+
117+
if left.dtype == dtypes.DATE_DTYPE and right.dtype == dtypes.INT_DTYPE:
118+
return sge.DateSub(
119+
this=left.expr, expression=right.expr, unit=sge.Var(this="DAY")
120+
)
121+
if left.dtype == dtypes.TIMEDELTA_DTYPE and right.dtype == dtypes.TIMEDELTA_DTYPE:
122+
return sge.Sub(this=left.expr, expression=right.expr)
123+
124+
raise TypeError(
125+
f"Cannot subtract type {left.dtype} and {right.dtype}. {constants.FEEDBACK_LINK}"
126+
)
127+
128+
84129
@BINARY_OP_REGISTRATION.register(ops.ge_op)
85130
def _(op, left: TypedExpr, right: TypedExpr) -> sge.Expression:
86131
return sge.GTE(this=left.expr, expression=right.expr)

tests/system/small/engines/test_numeric_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def test_engines_project_add(
6262
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
6363

6464

65-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
65+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
6666
def test_engines_project_sub(
6767
scalars_array_value: array_value.ArrayValue,
6868
engine,

0 commit comments

Comments
 (0)