Skip to content

Commit 122b4a4

Browse files
committed
refactor: support ops.mod_op for the sqlglot compiler
1 parent 9130a61 commit 122b4a4

File tree

4 files changed

+312
-1
lines changed

4 files changed

+312
-1
lines changed

bigframes/core/compile/sqlglot/expressions/numeric_ops.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,49 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
323323
return result
324324

325325

326+
@register_binary_op(ops.mod_op)
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
    """Compile ``ops.mod_op`` into a CASE expression wrapped around BigQuery ``MOD``.

    BigQuery's ``MOD`` returns a value with the sign of the dividend, while
    pandas uses the sign of the divisor, so the raw remainder is shifted by
    the divisor whenever the two signs disagree.

    Args:
        left: The dividend operand.
        right: The divisor operand.

    Returns:
        A ``sge.Case`` with three guarded branches (zero divisor, negative
        divisor with positive remainder, positive divisor with negative
        remainder) that defaults to the plain ``MOD`` expression.
    """
    dividend = _coerce_bool_to_int(left)
    divisor = _coerce_bool_to_int(right)

    # Both operands are cast together as soon as either side is a float.
    involves_float = any(
        operand.dtype == dtypes.FLOAT_DTYPE for operand in (left, right)
    )

    # BigQuery's MOD function doesn't support float types, so cast to BIGNUMERIC.
    if involves_float:
        dividend = sge.Cast(this=dividend, to="BIGNUMERIC")
        divisor = sge.Cast(this=divisor, to="BIGNUMERIC")

    raw_mod = sge.Mod(this=dividend, expression=divisor)

    # MOD(N, 0) errors in BigQuery, so a zero divisor is intercepted by the
    # first branch, which yields `factor * dividend` instead of calling MOD:
    # the NaN constant for float operands, the zero constant otherwise.
    zero_divisor_factor = constants._NAN if involves_float else constants._ZERO
    zero_divisor_branch = sge.If(
        this=sge.EQ(this=divisor, expression=constants._ZERO),
        true=zero_divisor_factor * dividend,
    )

    # Sign correction: when the remainder and the divisor have opposite
    # signs, adding the divisor moves the remainder onto the divisor's sign.
    negative_divisor_branch = sge.If(
        this=sge.and_(divisor < constants._ZERO, raw_mod > constants._ZERO),
        true=divisor + raw_mod,
    )
    positive_divisor_branch = sge.If(
        this=sge.and_(divisor > constants._ZERO, raw_mod < constants._ZERO),
        true=divisor + raw_mod,
    )

    return sge.Case(
        ifs=[
            zero_divisor_branch,
            negative_divisor_branch,
            positive_divisor_branch,
        ],
        default=raw_mod,
    )
367+
368+
326369
@register_binary_op(ops.mul_op)
327370
def _(left: TypedExpr, right: TypedExpr) -> sge.Expression:
328371
left_expr = _coerce_bool_to_int(left)

tests/system/small/engines/test_numeric_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def test_engines_project_floordiv_durations(
161161
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
162162

163163

164-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
164+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
165165
def test_engines_project_mod(
166166
scalars_array_value: array_value.ArrayValue,
167167
engine,
Lines changed: 252 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,252 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`int64_col` AS `bfcol_0`,
4+
`float64_col` AS `bfcol_1`,
5+
`rowindex` AS `bfcol_2`
6+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
7+
), `bfcte_1` AS (
8+
SELECT
9+
*,
10+
`bfcol_2` AS `bfcol_6`,
11+
`bfcol_0` AS `bfcol_7`,
12+
`bfcol_1` AS `bfcol_8`,
13+
CASE
14+
WHEN `bfcol_0` = CAST(0 AS INT64)
15+
THEN CAST(0 AS INT64) * `bfcol_0`
16+
WHEN `bfcol_0` < CAST(0 AS INT64)
17+
AND (
18+
MOD(`bfcol_0`, `bfcol_0`)
19+
) > CAST(0 AS INT64)
20+
THEN `bfcol_0` + (
21+
MOD(`bfcol_0`, `bfcol_0`)
22+
)
23+
WHEN `bfcol_0` > CAST(0 AS INT64)
24+
AND (
25+
MOD(`bfcol_0`, `bfcol_0`)
26+
) < CAST(0 AS INT64)
27+
THEN `bfcol_0` + (
28+
MOD(`bfcol_0`, `bfcol_0`)
29+
)
30+
ELSE MOD(`bfcol_0`, `bfcol_0`)
31+
END AS `bfcol_9`
32+
FROM `bfcte_0`
33+
), `bfcte_2` AS (
34+
SELECT
35+
*,
36+
`bfcol_6` AS `bfcol_14`,
37+
`bfcol_7` AS `bfcol_15`,
38+
`bfcol_8` AS `bfcol_16`,
39+
`bfcol_9` AS `bfcol_17`,
40+
CASE
41+
WHEN -`bfcol_7` = CAST(0 AS INT64)
42+
THEN CAST(0 AS INT64) * `bfcol_7`
43+
WHEN -`bfcol_7` < CAST(0 AS INT64)
44+
AND (
45+
MOD(`bfcol_7`, -`bfcol_7`)
46+
) > CAST(0 AS INT64)
47+
THEN -`bfcol_7` + (
48+
MOD(`bfcol_7`, -`bfcol_7`)
49+
)
50+
WHEN -`bfcol_7` > CAST(0 AS INT64)
51+
AND (
52+
MOD(`bfcol_7`, -`bfcol_7`)
53+
) < CAST(0 AS INT64)
54+
THEN -`bfcol_7` + (
55+
MOD(`bfcol_7`, -`bfcol_7`)
56+
)
57+
ELSE MOD(`bfcol_7`, -`bfcol_7`)
58+
END AS `bfcol_18`
59+
FROM `bfcte_1`
60+
), `bfcte_3` AS (
61+
SELECT
62+
*,
63+
`bfcol_14` AS `bfcol_24`,
64+
`bfcol_15` AS `bfcol_25`,
65+
`bfcol_16` AS `bfcol_26`,
66+
`bfcol_17` AS `bfcol_27`,
67+
`bfcol_18` AS `bfcol_28`,
68+
CASE
69+
WHEN 1 = CAST(0 AS INT64)
70+
THEN CAST(0 AS INT64) * `bfcol_15`
71+
WHEN 1 < CAST(0 AS INT64) AND (
72+
MOD(`bfcol_15`, 1)
73+
) > CAST(0 AS INT64)
74+
THEN 1 + (
75+
MOD(`bfcol_15`, 1)
76+
)
77+
WHEN 1 > CAST(0 AS INT64) AND (
78+
MOD(`bfcol_15`, 1)
79+
) < CAST(0 AS INT64)
80+
THEN 1 + (
81+
MOD(`bfcol_15`, 1)
82+
)
83+
ELSE MOD(`bfcol_15`, 1)
84+
END AS `bfcol_29`
85+
FROM `bfcte_2`
86+
), `bfcte_4` AS (
87+
SELECT
88+
*,
89+
`bfcol_24` AS `bfcol_36`,
90+
`bfcol_25` AS `bfcol_37`,
91+
`bfcol_26` AS `bfcol_38`,
92+
`bfcol_27` AS `bfcol_39`,
93+
`bfcol_28` AS `bfcol_40`,
94+
`bfcol_29` AS `bfcol_41`,
95+
CASE
96+
WHEN 0 = CAST(0 AS INT64)
97+
THEN CAST(0 AS INT64) * `bfcol_25`
98+
WHEN 0 < CAST(0 AS INT64) AND (
99+
MOD(`bfcol_25`, 0)
100+
) > CAST(0 AS INT64)
101+
THEN 0 + (
102+
MOD(`bfcol_25`, 0)
103+
)
104+
WHEN 0 > CAST(0 AS INT64) AND (
105+
MOD(`bfcol_25`, 0)
106+
) < CAST(0 AS INT64)
107+
THEN 0 + (
108+
MOD(`bfcol_25`, 0)
109+
)
110+
ELSE MOD(`bfcol_25`, 0)
111+
END AS `bfcol_42`
112+
FROM `bfcte_3`
113+
), `bfcte_5` AS (
114+
SELECT
115+
*,
116+
`bfcol_36` AS `bfcol_50`,
117+
`bfcol_37` AS `bfcol_51`,
118+
`bfcol_38` AS `bfcol_52`,
119+
`bfcol_39` AS `bfcol_53`,
120+
`bfcol_40` AS `bfcol_54`,
121+
`bfcol_41` AS `bfcol_55`,
122+
`bfcol_42` AS `bfcol_56`,
123+
CASE
124+
WHEN CAST(`bfcol_38` AS BIGNUMERIC) = CAST(0 AS INT64)
125+
THEN CAST('NaN' AS FLOAT64) * CAST(`bfcol_38` AS BIGNUMERIC)
126+
WHEN CAST(`bfcol_38` AS BIGNUMERIC) < CAST(0 AS INT64)
127+
AND (
128+
MOD(CAST(`bfcol_38` AS BIGNUMERIC), CAST(`bfcol_38` AS BIGNUMERIC))
129+
) > CAST(0 AS INT64)
130+
THEN CAST(`bfcol_38` AS BIGNUMERIC) + (
131+
MOD(CAST(`bfcol_38` AS BIGNUMERIC), CAST(`bfcol_38` AS BIGNUMERIC))
132+
)
133+
WHEN CAST(`bfcol_38` AS BIGNUMERIC) > CAST(0 AS INT64)
134+
AND (
135+
MOD(CAST(`bfcol_38` AS BIGNUMERIC), CAST(`bfcol_38` AS BIGNUMERIC))
136+
) < CAST(0 AS INT64)
137+
THEN CAST(`bfcol_38` AS BIGNUMERIC) + (
138+
MOD(CAST(`bfcol_38` AS BIGNUMERIC), CAST(`bfcol_38` AS BIGNUMERIC))
139+
)
140+
ELSE MOD(CAST(`bfcol_38` AS BIGNUMERIC), CAST(`bfcol_38` AS BIGNUMERIC))
141+
END AS `bfcol_57`
142+
FROM `bfcte_4`
143+
), `bfcte_6` AS (
144+
SELECT
145+
*,
146+
`bfcol_50` AS `bfcol_66`,
147+
`bfcol_51` AS `bfcol_67`,
148+
`bfcol_52` AS `bfcol_68`,
149+
`bfcol_53` AS `bfcol_69`,
150+
`bfcol_54` AS `bfcol_70`,
151+
`bfcol_55` AS `bfcol_71`,
152+
`bfcol_56` AS `bfcol_72`,
153+
`bfcol_57` AS `bfcol_73`,
154+
CASE
155+
WHEN CAST(-`bfcol_52` AS BIGNUMERIC) = CAST(0 AS INT64)
156+
THEN CAST('NaN' AS FLOAT64) * CAST(`bfcol_52` AS BIGNUMERIC)
157+
WHEN CAST(-`bfcol_52` AS BIGNUMERIC) < CAST(0 AS INT64)
158+
AND (
159+
MOD(CAST(`bfcol_52` AS BIGNUMERIC), CAST(-`bfcol_52` AS BIGNUMERIC))
160+
) > CAST(0 AS INT64)
161+
THEN CAST(-`bfcol_52` AS BIGNUMERIC) + (
162+
MOD(CAST(`bfcol_52` AS BIGNUMERIC), CAST(-`bfcol_52` AS BIGNUMERIC))
163+
)
164+
WHEN CAST(-`bfcol_52` AS BIGNUMERIC) > CAST(0 AS INT64)
165+
AND (
166+
MOD(CAST(`bfcol_52` AS BIGNUMERIC), CAST(-`bfcol_52` AS BIGNUMERIC))
167+
) < CAST(0 AS INT64)
168+
THEN CAST(-`bfcol_52` AS BIGNUMERIC) + (
169+
MOD(CAST(`bfcol_52` AS BIGNUMERIC), CAST(-`bfcol_52` AS BIGNUMERIC))
170+
)
171+
ELSE MOD(CAST(`bfcol_52` AS BIGNUMERIC), CAST(-`bfcol_52` AS BIGNUMERIC))
172+
END AS `bfcol_74`
173+
FROM `bfcte_5`
174+
), `bfcte_7` AS (
175+
SELECT
176+
*,
177+
`bfcol_66` AS `bfcol_84`,
178+
`bfcol_67` AS `bfcol_85`,
179+
`bfcol_68` AS `bfcol_86`,
180+
`bfcol_69` AS `bfcol_87`,
181+
`bfcol_70` AS `bfcol_88`,
182+
`bfcol_71` AS `bfcol_89`,
183+
`bfcol_72` AS `bfcol_90`,
184+
`bfcol_73` AS `bfcol_91`,
185+
`bfcol_74` AS `bfcol_92`,
186+
CASE
187+
WHEN CAST(1 AS BIGNUMERIC) = CAST(0 AS INT64)
188+
THEN CAST('NaN' AS FLOAT64) * CAST(`bfcol_68` AS BIGNUMERIC)
189+
WHEN CAST(1 AS BIGNUMERIC) < CAST(0 AS INT64)
190+
AND (
191+
MOD(CAST(`bfcol_68` AS BIGNUMERIC), CAST(1 AS BIGNUMERIC))
192+
) > CAST(0 AS INT64)
193+
THEN CAST(1 AS BIGNUMERIC) + (
194+
MOD(CAST(`bfcol_68` AS BIGNUMERIC), CAST(1 AS BIGNUMERIC))
195+
)
196+
WHEN CAST(1 AS BIGNUMERIC) > CAST(0 AS INT64)
197+
AND (
198+
MOD(CAST(`bfcol_68` AS BIGNUMERIC), CAST(1 AS BIGNUMERIC))
199+
) < CAST(0 AS INT64)
200+
THEN CAST(1 AS BIGNUMERIC) + (
201+
MOD(CAST(`bfcol_68` AS BIGNUMERIC), CAST(1 AS BIGNUMERIC))
202+
)
203+
ELSE MOD(CAST(`bfcol_68` AS BIGNUMERIC), CAST(1 AS BIGNUMERIC))
204+
END AS `bfcol_93`
205+
FROM `bfcte_6`
206+
), `bfcte_8` AS (
207+
SELECT
208+
*,
209+
`bfcol_84` AS `bfcol_104`,
210+
`bfcol_85` AS `bfcol_105`,
211+
`bfcol_86` AS `bfcol_106`,
212+
`bfcol_87` AS `bfcol_107`,
213+
`bfcol_88` AS `bfcol_108`,
214+
`bfcol_89` AS `bfcol_109`,
215+
`bfcol_90` AS `bfcol_110`,
216+
`bfcol_91` AS `bfcol_111`,
217+
`bfcol_92` AS `bfcol_112`,
218+
`bfcol_93` AS `bfcol_113`,
219+
CASE
220+
WHEN CAST(0 AS BIGNUMERIC) = CAST(0 AS INT64)
221+
THEN CAST('NaN' AS FLOAT64) * CAST(`bfcol_86` AS BIGNUMERIC)
222+
WHEN CAST(0 AS BIGNUMERIC) < CAST(0 AS INT64)
223+
AND (
224+
MOD(CAST(`bfcol_86` AS BIGNUMERIC), CAST(0 AS BIGNUMERIC))
225+
) > CAST(0 AS INT64)
226+
THEN CAST(0 AS BIGNUMERIC) + (
227+
MOD(CAST(`bfcol_86` AS BIGNUMERIC), CAST(0 AS BIGNUMERIC))
228+
)
229+
WHEN CAST(0 AS BIGNUMERIC) > CAST(0 AS INT64)
230+
AND (
231+
MOD(CAST(`bfcol_86` AS BIGNUMERIC), CAST(0 AS BIGNUMERIC))
232+
) < CAST(0 AS INT64)
233+
THEN CAST(0 AS BIGNUMERIC) + (
234+
MOD(CAST(`bfcol_86` AS BIGNUMERIC), CAST(0 AS BIGNUMERIC))
235+
)
236+
ELSE MOD(CAST(`bfcol_86` AS BIGNUMERIC), CAST(0 AS BIGNUMERIC))
237+
END AS `bfcol_114`
238+
FROM `bfcte_7`
239+
)
240+
SELECT
241+
`bfcol_104` AS `rowindex`,
242+
`bfcol_105` AS `int64_col`,
243+
`bfcol_106` AS `float64_col`,
244+
`bfcol_107` AS `int_mod_int`,
245+
`bfcol_108` AS `int_mod_int_neg`,
246+
`bfcol_109` AS `int_mod_1`,
247+
`bfcol_110` AS `int_mod_0`,
248+
`bfcol_111` AS `float_mod_float`,
249+
`bfcol_112` AS `float_mod_float_neg`,
250+
`bfcol_113` AS `float_mod_1`,
251+
`bfcol_114` AS `float_mod_0`
252+
FROM `bfcte_8`

tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,22 @@ def test_mul_numeric(scalar_types_df: bpd.DataFrame, snapshot):
287287
snapshot.assert_match(bf_df.sql, "out.sql")
288288

289289

290+
def test_mod_numeric(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``%`` on the int64 and float64 columns.

    Each source column is taken modulo itself, its negation, the literal 1
    and the literal 0; the resulting SQL is compared against ``out.sql``.
    """
    frame = scalar_types_df[["int64_col", "float64_col"]]

    # (source column, (self, negated self, literal 1, literal 0) output names).
    # Assignment order is significant: it fixes the column order of the
    # SQL captured in the snapshot.
    cases = (
        (
            "int64_col",
            ("int_mod_int", "int_mod_int_neg", "int_mod_1", "int_mod_0"),
        ),
        (
            "float64_col",
            (
                "float_mod_float",
                "float_mod_float_neg",
                "float_mod_1",
                "float_mod_0",
            ),
        ),
    )

    for source, (by_self, by_negated, by_one, by_zero) in cases:
        # The source column is re-read from the frame for every operand so
        # each expression is built from the frame's current state.
        frame[by_self] = frame[source] % frame[source]
        frame[by_negated] = frame[source] % -frame[source]
        frame[by_one] = frame[source] % 1
        frame[by_zero] = frame[source] % 0

    snapshot.assert_match(frame.sql, "out.sql")
304+
305+
290306
def test_sub_numeric(scalar_types_df: bpd.DataFrame, snapshot):
291307
bf_df = scalar_types_df[["int64_col", "bool_col"]]
292308

0 commit comments

Comments
 (0)