From 122b4a468ea2173b0a280816cfae072628ced5e7 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Tue, 30 Sep 2025 18:10:28 +0000 Subject: [PATCH] refactor: support ops.mod_op for the sqlglot compiler --- .../sqlglot/expressions/numeric_ops.py | 43 +++ .../system/small/engines/test_numeric_ops.py | 2 +- .../test_numeric_ops/test_mod_numeric/out.sql | 252 ++++++++++++++++++ .../sqlglot/expressions/test_numeric_ops.py | 16 ++ 4 files changed, 312 insertions(+), 1 deletion(-) create mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_mod_numeric/out.sql diff --git a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py index 1a6447ceb7..d86df93921 100644 --- a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py @@ -323,6 +323,49 @@ def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: return result +@register_binary_op(ops.mod_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + # In BigQuery returned value has the same sign as X. 
In pandas, the sign of y is used, so we need to flip the result if sign(x) != sign(y) + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) + + # BigQuery MOD function doesn't support float types, so cast to BIGNUMERIC + if left.dtype == dtypes.FLOAT_DTYPE or right.dtype == dtypes.FLOAT_DTYPE: + left_expr = sge.Cast(this=left_expr, to="BIGNUMERIC") + right_expr = sge.Cast(this=right_expr, to="BIGNUMERIC") + + # MOD(N, 0) errors in BigQuery; yield 0 (int) or NaN (float) times x instead + bq_mod = sge.Mod(this=left_expr, expression=right_expr) + zero_result = ( + constants._NAN + if (left.dtype == dtypes.FLOAT_DTYPE or right.dtype == dtypes.FLOAT_DTYPE) + else constants._ZERO + ) + return sge.Case( + ifs=[ + sge.If( + this=sge.EQ(this=right_expr, expression=constants._ZERO), + true=zero_result * left_expr, + ), + sge.If( + this=sge.and_( + right_expr < constants._ZERO, + bq_mod > constants._ZERO, + ), + true=right_expr + bq_mod, + ), + sge.If( + this=sge.and_( + right_expr > constants._ZERO, + bq_mod < constants._ZERO, + ), + true=right_expr + bq_mod, + ), + ], + default=bq_mod, + ) + + @register_binary_op(ops.mul_op) def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: left_expr = _coerce_bool_to_int(left) diff --git a/tests/system/small/engines/test_numeric_ops.py b/tests/system/small/engines/test_numeric_ops.py index 7928922e41..ef0f8d9d0d 100644 --- a/tests/system/small/engines/test_numeric_ops.py +++ b/tests/system/small/engines/test_numeric_ops.py @@ -161,7 +161,7 @@ def test_engines_project_floordiv_durations( assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_project_mod( scalars_array_value: array_value.ArrayValue, engine, diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_mod_numeric/out.sql 
b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_mod_numeric/out.sql new file mode 100644 index 0000000000..7913b43aa6 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_mod_numeric/out.sql @@ -0,0 +1,252 @@ +WITH `bfcte_0` AS ( + SELECT + `int64_col` AS `bfcol_0`, + `float64_col` AS `bfcol_1`, + `rowindex` AS `bfcol_2` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + `bfcol_2` AS `bfcol_6`, + `bfcol_0` AS `bfcol_7`, + `bfcol_1` AS `bfcol_8`, + CASE + WHEN `bfcol_0` = CAST(0 AS INT64) + THEN CAST(0 AS INT64) * `bfcol_0` + WHEN `bfcol_0` < CAST(0 AS INT64) + AND ( + MOD(`bfcol_0`, `bfcol_0`) + ) > CAST(0 AS INT64) + THEN `bfcol_0` + ( + MOD(`bfcol_0`, `bfcol_0`) + ) + WHEN `bfcol_0` > CAST(0 AS INT64) + AND ( + MOD(`bfcol_0`, `bfcol_0`) + ) < CAST(0 AS INT64) + THEN `bfcol_0` + ( + MOD(`bfcol_0`, `bfcol_0`) + ) + ELSE MOD(`bfcol_0`, `bfcol_0`) + END AS `bfcol_9` + FROM `bfcte_0` +), `bfcte_2` AS ( + SELECT + *, + `bfcol_6` AS `bfcol_14`, + `bfcol_7` AS `bfcol_15`, + `bfcol_8` AS `bfcol_16`, + `bfcol_9` AS `bfcol_17`, + CASE + WHEN -`bfcol_7` = CAST(0 AS INT64) + THEN CAST(0 AS INT64) * `bfcol_7` + WHEN -`bfcol_7` < CAST(0 AS INT64) + AND ( + MOD(`bfcol_7`, -`bfcol_7`) + ) > CAST(0 AS INT64) + THEN -`bfcol_7` + ( + MOD(`bfcol_7`, -`bfcol_7`) + ) + WHEN -`bfcol_7` > CAST(0 AS INT64) + AND ( + MOD(`bfcol_7`, -`bfcol_7`) + ) < CAST(0 AS INT64) + THEN -`bfcol_7` + ( + MOD(`bfcol_7`, -`bfcol_7`) + ) + ELSE MOD(`bfcol_7`, -`bfcol_7`) + END AS `bfcol_18` + FROM `bfcte_1` +), `bfcte_3` AS ( + SELECT + *, + `bfcol_14` AS `bfcol_24`, + `bfcol_15` AS `bfcol_25`, + `bfcol_16` AS `bfcol_26`, + `bfcol_17` AS `bfcol_27`, + `bfcol_18` AS `bfcol_28`, + CASE + WHEN 1 = CAST(0 AS INT64) + THEN CAST(0 AS INT64) * `bfcol_15` + WHEN 1 < CAST(0 AS INT64) AND ( + MOD(`bfcol_15`, 1) + ) > CAST(0 AS INT64) + THEN 1 + ( + MOD(`bfcol_15`, 1) + ) + WHEN 1 > CAST(0 AS INT64) AND ( + 
MOD(`bfcol_15`, 1) + ) < CAST(0 AS INT64) + THEN 1 + ( + MOD(`bfcol_15`, 1) + ) + ELSE MOD(`bfcol_15`, 1) + END AS `bfcol_29` + FROM `bfcte_2` +), `bfcte_4` AS ( + SELECT + *, + `bfcol_24` AS `bfcol_36`, + `bfcol_25` AS `bfcol_37`, + `bfcol_26` AS `bfcol_38`, + `bfcol_27` AS `bfcol_39`, + `bfcol_28` AS `bfcol_40`, + `bfcol_29` AS `bfcol_41`, + CASE + WHEN 0 = CAST(0 AS INT64) + THEN CAST(0 AS INT64) * `bfcol_25` + WHEN 0 < CAST(0 AS INT64) AND ( + MOD(`bfcol_25`, 0) + ) > CAST(0 AS INT64) + THEN 0 + ( + MOD(`bfcol_25`, 0) + ) + WHEN 0 > CAST(0 AS INT64) AND ( + MOD(`bfcol_25`, 0) + ) < CAST(0 AS INT64) + THEN 0 + ( + MOD(`bfcol_25`, 0) + ) + ELSE MOD(`bfcol_25`, 0) + END AS `bfcol_42` + FROM `bfcte_3` +), `bfcte_5` AS ( + SELECT + *, + `bfcol_36` AS `bfcol_50`, + `bfcol_37` AS `bfcol_51`, + `bfcol_38` AS `bfcol_52`, + `bfcol_39` AS `bfcol_53`, + `bfcol_40` AS `bfcol_54`, + `bfcol_41` AS `bfcol_55`, + `bfcol_42` AS `bfcol_56`, + CASE + WHEN CAST(`bfcol_38` AS BIGNUMERIC) = CAST(0 AS INT64) + THEN CAST('NaN' AS FLOAT64) * CAST(`bfcol_38` AS BIGNUMERIC) + WHEN CAST(`bfcol_38` AS BIGNUMERIC) < CAST(0 AS INT64) + AND ( + MOD(CAST(`bfcol_38` AS BIGNUMERIC), CAST(`bfcol_38` AS BIGNUMERIC)) + ) > CAST(0 AS INT64) + THEN CAST(`bfcol_38` AS BIGNUMERIC) + ( + MOD(CAST(`bfcol_38` AS BIGNUMERIC), CAST(`bfcol_38` AS BIGNUMERIC)) + ) + WHEN CAST(`bfcol_38` AS BIGNUMERIC) > CAST(0 AS INT64) + AND ( + MOD(CAST(`bfcol_38` AS BIGNUMERIC), CAST(`bfcol_38` AS BIGNUMERIC)) + ) < CAST(0 AS INT64) + THEN CAST(`bfcol_38` AS BIGNUMERIC) + ( + MOD(CAST(`bfcol_38` AS BIGNUMERIC), CAST(`bfcol_38` AS BIGNUMERIC)) + ) + ELSE MOD(CAST(`bfcol_38` AS BIGNUMERIC), CAST(`bfcol_38` AS BIGNUMERIC)) + END AS `bfcol_57` + FROM `bfcte_4` +), `bfcte_6` AS ( + SELECT + *, + `bfcol_50` AS `bfcol_66`, + `bfcol_51` AS `bfcol_67`, + `bfcol_52` AS `bfcol_68`, + `bfcol_53` AS `bfcol_69`, + `bfcol_54` AS `bfcol_70`, + `bfcol_55` AS `bfcol_71`, + `bfcol_56` AS `bfcol_72`, + `bfcol_57` AS `bfcol_73`, + CASE + WHEN 
CAST(-`bfcol_52` AS BIGNUMERIC) = CAST(0 AS INT64) + THEN CAST('NaN' AS FLOAT64) * CAST(`bfcol_52` AS BIGNUMERIC) + WHEN CAST(-`bfcol_52` AS BIGNUMERIC) < CAST(0 AS INT64) + AND ( + MOD(CAST(`bfcol_52` AS BIGNUMERIC), CAST(-`bfcol_52` AS BIGNUMERIC)) + ) > CAST(0 AS INT64) + THEN CAST(-`bfcol_52` AS BIGNUMERIC) + ( + MOD(CAST(`bfcol_52` AS BIGNUMERIC), CAST(-`bfcol_52` AS BIGNUMERIC)) + ) + WHEN CAST(-`bfcol_52` AS BIGNUMERIC) > CAST(0 AS INT64) + AND ( + MOD(CAST(`bfcol_52` AS BIGNUMERIC), CAST(-`bfcol_52` AS BIGNUMERIC)) + ) < CAST(0 AS INT64) + THEN CAST(-`bfcol_52` AS BIGNUMERIC) + ( + MOD(CAST(`bfcol_52` AS BIGNUMERIC), CAST(-`bfcol_52` AS BIGNUMERIC)) + ) + ELSE MOD(CAST(`bfcol_52` AS BIGNUMERIC), CAST(-`bfcol_52` AS BIGNUMERIC)) + END AS `bfcol_74` + FROM `bfcte_5` +), `bfcte_7` AS ( + SELECT + *, + `bfcol_66` AS `bfcol_84`, + `bfcol_67` AS `bfcol_85`, + `bfcol_68` AS `bfcol_86`, + `bfcol_69` AS `bfcol_87`, + `bfcol_70` AS `bfcol_88`, + `bfcol_71` AS `bfcol_89`, + `bfcol_72` AS `bfcol_90`, + `bfcol_73` AS `bfcol_91`, + `bfcol_74` AS `bfcol_92`, + CASE + WHEN CAST(1 AS BIGNUMERIC) = CAST(0 AS INT64) + THEN CAST('NaN' AS FLOAT64) * CAST(`bfcol_68` AS BIGNUMERIC) + WHEN CAST(1 AS BIGNUMERIC) < CAST(0 AS INT64) + AND ( + MOD(CAST(`bfcol_68` AS BIGNUMERIC), CAST(1 AS BIGNUMERIC)) + ) > CAST(0 AS INT64) + THEN CAST(1 AS BIGNUMERIC) + ( + MOD(CAST(`bfcol_68` AS BIGNUMERIC), CAST(1 AS BIGNUMERIC)) + ) + WHEN CAST(1 AS BIGNUMERIC) > CAST(0 AS INT64) + AND ( + MOD(CAST(`bfcol_68` AS BIGNUMERIC), CAST(1 AS BIGNUMERIC)) + ) < CAST(0 AS INT64) + THEN CAST(1 AS BIGNUMERIC) + ( + MOD(CAST(`bfcol_68` AS BIGNUMERIC), CAST(1 AS BIGNUMERIC)) + ) + ELSE MOD(CAST(`bfcol_68` AS BIGNUMERIC), CAST(1 AS BIGNUMERIC)) + END AS `bfcol_93` + FROM `bfcte_6` +), `bfcte_8` AS ( + SELECT + *, + `bfcol_84` AS `bfcol_104`, + `bfcol_85` AS `bfcol_105`, + `bfcol_86` AS `bfcol_106`, + `bfcol_87` AS `bfcol_107`, + `bfcol_88` AS `bfcol_108`, + `bfcol_89` AS `bfcol_109`, + `bfcol_90` AS 
`bfcol_110`, + `bfcol_91` AS `bfcol_111`, + `bfcol_92` AS `bfcol_112`, + `bfcol_93` AS `bfcol_113`, + CASE + WHEN CAST(0 AS BIGNUMERIC) = CAST(0 AS INT64) + THEN CAST('NaN' AS FLOAT64) * CAST(`bfcol_86` AS BIGNUMERIC) + WHEN CAST(0 AS BIGNUMERIC) < CAST(0 AS INT64) + AND ( + MOD(CAST(`bfcol_86` AS BIGNUMERIC), CAST(0 AS BIGNUMERIC)) + ) > CAST(0 AS INT64) + THEN CAST(0 AS BIGNUMERIC) + ( + MOD(CAST(`bfcol_86` AS BIGNUMERIC), CAST(0 AS BIGNUMERIC)) + ) + WHEN CAST(0 AS BIGNUMERIC) > CAST(0 AS INT64) + AND ( + MOD(CAST(`bfcol_86` AS BIGNUMERIC), CAST(0 AS BIGNUMERIC)) + ) < CAST(0 AS INT64) + THEN CAST(0 AS BIGNUMERIC) + ( + MOD(CAST(`bfcol_86` AS BIGNUMERIC), CAST(0 AS BIGNUMERIC)) + ) + ELSE MOD(CAST(`bfcol_86` AS BIGNUMERIC), CAST(0 AS BIGNUMERIC)) + END AS `bfcol_114` + FROM `bfcte_7` +) +SELECT + `bfcol_104` AS `rowindex`, + `bfcol_105` AS `int64_col`, + `bfcol_106` AS `float64_col`, + `bfcol_107` AS `int_mod_int`, + `bfcol_108` AS `int_mod_int_neg`, + `bfcol_109` AS `int_mod_1`, + `bfcol_110` AS `int_mod_0`, + `bfcol_111` AS `float_mod_float`, + `bfcol_112` AS `float_mod_float_neg`, + `bfcol_113` AS `float_mod_1`, + `bfcol_114` AS `float_mod_0` +FROM `bfcte_8` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py index e0c41857e9..231d9d5bf0 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py @@ -287,6 +287,22 @@ def test_mul_numeric(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(bf_df.sql, "out.sql") +def test_mod_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "float64_col"]] + + bf_df["int_mod_int"] = bf_df["int64_col"] % bf_df["int64_col"] + bf_df["int_mod_int_neg"] = bf_df["int64_col"] % -bf_df["int64_col"] + bf_df["int_mod_1"] = bf_df["int64_col"] % 1 + bf_df["int_mod_0"] = 
bf_df["int64_col"] % 0 + + bf_df["float_mod_float"] = bf_df["float64_col"] % bf_df["float64_col"] + bf_df["float_mod_float_neg"] = bf_df["float64_col"] % -bf_df["float64_col"] + bf_df["float_mod_1"] = bf_df["float64_col"] % 1 + bf_df["float_mod_0"] = bf_df["float64_col"] % 0 + + snapshot.assert_match(bf_df.sql, "out.sql") + + def test_sub_numeric(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["int64_col", "bool_col"]]