From ef6c299f63eb361ab9a0347937193719fba91b20 Mon Sep 17 00:00:00 2001 From: Chelsea Lin Date: Wed, 17 Sep 2025 22:47:01 +0000 Subject: [PATCH] refactor: reorganize the sqlglot scalar compiler layout - part 3 --- bigframes/core/compile/sqlglot/__init__.py | 1 - .../sqlglot/expressions/binary_compiler.py | 241 --------------- .../compile/sqlglot/expressions/blob_ops.py | 6 + .../sqlglot/expressions/comparison_ops.py | 72 ++++- .../compile/sqlglot/expressions/json_ops.py | 6 + .../sqlglot/expressions/numeric_ops.py | 144 +++++++++ bigframes/testing/utils.py | 39 ++- .../test_mul_timedelta/out.sql | 43 --- .../test_obj_make_ref/out.sql | 0 .../test_eq_null_match/out.sql | 0 .../test_eq_numeric/out.sql | 0 .../test_ge_numeric/out.sql | 0 .../test_gt_numeric/out.sql | 0 .../test_le_numeric/out.sql | 0 .../test_lt_numeric/out.sql | 0 .../test_ne_numeric/out.sql | 0 .../test_add_timedelta/out.sql | 0 .../test_sub_timedelta/out.sql | 0 .../test_json_set/out.sql | 0 .../test_add_numeric/out.sql | 0 .../test_div_numeric/out.sql | 0 .../test_div_timedelta/out.sql | 0 .../test_floordiv_timedelta/out.sql | 0 .../test_mul_numeric/out.sql | 0 .../test_sub_numeric/out.sql | 0 .../test_add_string/out.sql | 0 .../expressions/test_binary_compiler.py | 278 ------------------ .../sqlglot/expressions/test_blob_ops.py | 5 + .../expressions/test_comparison_ops.py | 78 +++++ .../sqlglot/expressions/test_datetime_ops.py | 27 ++ .../sqlglot/expressions/test_json_ops.py | 10 + .../sqlglot/expressions/test_numeric_ops.py | 86 ++++++ .../sqlglot/expressions/test_string_ops.py | 8 + 33 files changed, 469 insertions(+), 575 deletions(-) delete mode 100644 bigframes/core/compile/sqlglot/expressions/binary_compiler.py delete mode 100644 tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_mul_timedelta/out.sql rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_blob_ops}/test_obj_make_ref/out.sql (100%) rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_comparison_ops}/test_eq_null_match/out.sql (100%) rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_comparison_ops}/test_eq_numeric/out.sql (100%) rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_comparison_ops}/test_ge_numeric/out.sql (100%) rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_comparison_ops}/test_gt_numeric/out.sql (100%) rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_comparison_ops}/test_le_numeric/out.sql (100%) rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_comparison_ops}/test_lt_numeric/out.sql (100%) rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_comparison_ops}/test_ne_numeric/out.sql (100%) rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_datetime_ops}/test_add_timedelta/out.sql (100%) rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_datetime_ops}/test_sub_timedelta/out.sql (100%) rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_json_ops}/test_json_set/out.sql (100%) rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_numeric_ops}/test_add_numeric/out.sql (100%) rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_numeric_ops}/test_div_numeric/out.sql (100%) rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_numeric_ops}/test_div_timedelta/out.sql (100%) rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_numeric_ops}/test_floordiv_timedelta/out.sql (100%) rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_numeric_ops}/test_mul_numeric/out.sql (100%) rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_numeric_ops}/test_sub_numeric/out.sql (100%) rename tests/unit/core/compile/sqlglot/expressions/snapshots/{test_binary_compiler => test_string_ops}/test_add_string/out.sql (100%) delete mode 100644 tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py diff --git a/bigframes/core/compile/sqlglot/__init__.py b/bigframes/core/compile/sqlglot/__init__.py index 5fe8099043..fdfb6f2161 100644 --- a/bigframes/core/compile/sqlglot/__init__.py +++ b/bigframes/core/compile/sqlglot/__init__.py @@ -15,7 +15,6 @@ from bigframes.core.compile.sqlglot.compiler import SQLGlotCompiler import bigframes.core.compile.sqlglot.expressions.array_ops # noqa: F401 -import bigframes.core.compile.sqlglot.expressions.binary_compiler # noqa: F401 import bigframes.core.compile.sqlglot.expressions.blob_ops # noqa: F401 import bigframes.core.compile.sqlglot.expressions.comparison_ops # noqa: F401 import bigframes.core.compile.sqlglot.expressions.date_ops # noqa: F401 diff --git a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py b/bigframes/core/compile/sqlglot/expressions/binary_compiler.py deleted file mode 100644 index b18d15cae6..0000000000 --- a/bigframes/core/compile/sqlglot/expressions/binary_compiler.py +++ /dev/null @@ -1,241 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from __future__ import annotations - -import bigframes_vendored.constants as bf_constants -import sqlglot.expressions as sge - -from bigframes import dtypes -from bigframes import operations as ops -import bigframes.core.compile.sqlglot.expressions.constants as constants -from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr -import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler - -register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op - -# TODO: add parenthesize for operators - - -@register_binary_op(ops.add_op) -def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - if left.dtype == dtypes.STRING_DTYPE and right.dtype == dtypes.STRING_DTYPE: - # String addition - return sge.Concat(expressions=[left.expr, right.expr]) - - if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype): - left_expr = _coerce_bool_to_int(left) - right_expr = _coerce_bool_to_int(right) - return sge.Add(this=left_expr, expression=right_expr) - - if ( - dtypes.is_time_or_date_like(left.dtype) - and right.dtype == dtypes.TIMEDELTA_DTYPE - ): - left_expr = _coerce_date_to_datetime(left) - return sge.TimestampAdd( - this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND") - ) - if ( - dtypes.is_time_or_date_like(right.dtype) - and left.dtype == dtypes.TIMEDELTA_DTYPE - ): - right_expr = _coerce_date_to_datetime(right) - return sge.TimestampAdd( - this=right_expr, expression=left.expr, unit=sge.Var(this="MICROSECOND") - ) - if left.dtype == dtypes.TIMEDELTA_DTYPE and right.dtype == dtypes.TIMEDELTA_DTYPE: - return sge.Add(this=left.expr, expression=right.expr) - - raise TypeError( - f"Cannot add type {left.dtype} and {right.dtype}. {bf_constants.FEEDBACK_LINK}" - ) - - -@register_binary_op(ops.eq_op) -def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr = _coerce_bool_to_int(left) - right_expr = _coerce_bool_to_int(right) - return sge.EQ(this=left_expr, expression=right_expr) - - -@register_binary_op(ops.eq_null_match_op) -def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr = left.expr - if right.dtype != dtypes.BOOL_DTYPE: - left_expr = _coerce_bool_to_int(left) - - right_expr = right.expr - if left.dtype != dtypes.BOOL_DTYPE: - right_expr = _coerce_bool_to_int(right) - - sentinel = sge.convert("$NULL_SENTINEL$") - left_coalesce = sge.Coalesce( - this=sge.Cast(this=left_expr, to="STRING"), expressions=[sentinel] - ) - right_coalesce = sge.Coalesce( - this=sge.Cast(this=right_expr, to="STRING"), expressions=[sentinel] - ) - return sge.EQ(this=left_coalesce, expression=right_coalesce) - - -@register_binary_op(ops.div_op) -def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr = _coerce_bool_to_int(left) - right_expr = _coerce_bool_to_int(right) - - result = sge.func("IEEE_DIVIDE", left_expr, right_expr) - if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype): - return sge.Cast(this=sge.Floor(this=result), to="INT64") - else: - return result - - -@register_binary_op(ops.floordiv_op) -def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr = _coerce_bool_to_int(left) - right_expr = _coerce_bool_to_int(right) - - result: sge.Expression = sge.Cast( - this=sge.Floor(this=sge.func("IEEE_DIVIDE", left_expr, right_expr)), to="INT64" - ) - - # DIV(N, 0) will error in bigquery, but needs to return `0` for int, and - # `inf`` for float in BQ so we short-circuit in this case. - # Multiplying left by zero propogates nulls. - zero_result = ( - constants._INF - if (left.dtype == dtypes.FLOAT_DTYPE or right.dtype == dtypes.FLOAT_DTYPE) - else constants._ZERO - ) - result = sge.Case( - ifs=[ - sge.If( - this=sge.EQ(this=right_expr, expression=constants._ZERO), - true=zero_result * left_expr, - ) - ], - default=result, - ) - - if dtypes.is_numeric(right.dtype) and left.dtype == dtypes.TIMEDELTA_DTYPE: - result = sge.Cast(this=sge.Floor(this=result), to="INT64") - - return result - - -@register_binary_op(ops.ge_op) -def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr = _coerce_bool_to_int(left) - right_expr = _coerce_bool_to_int(right) - return sge.GTE(this=left_expr, expression=right_expr) - - -@register_binary_op(ops.gt_op) -def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr = _coerce_bool_to_int(left) - right_expr = _coerce_bool_to_int(right) - return sge.GT(this=left_expr, expression=right_expr) - - -@register_binary_op(ops.JSONSet, pass_op=True) -def _(left: TypedExpr, right: TypedExpr, op) -> sge.Expression: - return sge.func("JSON_SET", left.expr, sge.convert(op.json_path), right.expr) - - -@register_binary_op(ops.lt_op) -def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr = _coerce_bool_to_int(left) - right_expr = _coerce_bool_to_int(right) - return sge.LT(this=left_expr, expression=right_expr) - - -@register_binary_op(ops.le_op) -def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr = _coerce_bool_to_int(left) - right_expr = _coerce_bool_to_int(right) - return sge.LTE(this=left_expr, expression=right_expr) - - -@register_binary_op(ops.mul_op) -def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr = _coerce_bool_to_int(left) - right_expr = _coerce_bool_to_int(right) - - result = sge.Mul(this=left_expr, expression=right_expr) - - if (dtypes.is_numeric(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE) or ( - left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype) - ): - return sge.Cast(this=sge.Floor(this=result), to="INT64") - else: - return result - - -@register_binary_op(ops.ne_op) -def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - left_expr = _coerce_bool_to_int(left) - right_expr = _coerce_bool_to_int(right) - return sge.NEQ(this=left_expr, expression=right_expr) - - -@register_binary_op(ops.obj_make_ref_op) -def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - return sge.func("OBJ.MAKE_REF", left.expr, right.expr) - - -@register_binary_op(ops.sub_op) -def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: - if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype): - left_expr = _coerce_bool_to_int(left) - right_expr = _coerce_bool_to_int(right) - return sge.Sub(this=left_expr, expression=right_expr) - - if ( - dtypes.is_time_or_date_like(left.dtype) - and right.dtype == dtypes.TIMEDELTA_DTYPE - ): - left_expr = _coerce_date_to_datetime(left) - return sge.TimestampSub( - this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND") - ) - if dtypes.is_time_or_date_like(left.dtype) and dtypes.is_time_or_date_like( - right.dtype - ): - left_expr = _coerce_date_to_datetime(left) - right_expr = _coerce_date_to_datetime(right) - return sge.TimestampDiff( - this=left_expr, expression=right_expr, unit=sge.Var(this="MICROSECOND") - ) - - if left.dtype == dtypes.TIMEDELTA_DTYPE and right.dtype == dtypes.TIMEDELTA_DTYPE: - return sge.Sub(this=left.expr, expression=right.expr) - - raise TypeError( - f"Cannot subtract type {left.dtype} and {right.dtype}. {bf_constants.FEEDBACK_LINK}" - ) - - -def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression: - """Coerce boolean expression to integer.""" - if typed_expr.dtype == dtypes.BOOL_DTYPE: - return sge.Cast(this=typed_expr.expr, to="INT64") - return typed_expr.expr - - -def _coerce_date_to_datetime(typed_expr: TypedExpr) -> sge.Expression: - """Coerce date expression to datetime.""" - if typed_expr.dtype == dtypes.DATE_DTYPE: - return sge.Cast(this=typed_expr.expr, to="DATETIME") - return typed_expr.expr diff --git a/bigframes/core/compile/sqlglot/expressions/blob_ops.py b/bigframes/core/compile/sqlglot/expressions/blob_ops.py index 58f905087d..03708f80c6 100644 --- a/bigframes/core/compile/sqlglot/expressions/blob_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/blob_ops.py @@ -21,6 +21,7 @@ import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op +register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op @register_unary_op(ops.obj_fetch_metadata_op) @@ -31,3 +32,8 @@ def _(expr: TypedExpr) -> sge.Expression: @register_unary_op(ops.ObjGetAccessUrl) def _(expr: TypedExpr) -> sge.Expression: return sge.func("OBJ.GET_ACCESS_URL", expr.expr) + + +@register_binary_op(ops.obj_make_ref_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + return sge.func("OBJ.MAKE_REF", left.expr, right.expr) diff --git a/bigframes/core/compile/sqlglot/expressions/comparison_ops.py b/bigframes/core/compile/sqlglot/expressions/comparison_ops.py index 3bf94cf8ab..eb08144b8a 100644 --- a/bigframes/core/compile/sqlglot/expressions/comparison_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/comparison_ops.py @@ -19,12 +19,13 @@ import pandas as pd import sqlglot.expressions as sge +from bigframes import dtypes from bigframes import operations as ops from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler -import bigframes.dtypes as dtypes register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op +register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op @register_unary_op(ops.IsInOp, pass_op=True) @@ -53,7 +54,76 @@ def _(expr: TypedExpr, op: ops.IsInOp) -> sge.Expression: ) +@register_binary_op(ops.eq_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) + return sge.EQ(this=left_expr, expression=right_expr) + + +@register_binary_op(ops.eq_null_match_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr = left.expr + if right.dtype != dtypes.BOOL_DTYPE: + left_expr = _coerce_bool_to_int(left) + + right_expr = right.expr + if left.dtype != dtypes.BOOL_DTYPE: + right_expr = _coerce_bool_to_int(right) + + sentinel = sge.convert("$NULL_SENTINEL$") + left_coalesce = sge.Coalesce( + this=sge.Cast(this=left_expr, to="STRING"), expressions=[sentinel] + ) + right_coalesce = sge.Coalesce( + this=sge.Cast(this=right_expr, to="STRING"), expressions=[sentinel] + ) + return sge.EQ(this=left_coalesce, expression=right_coalesce) + + +@register_binary_op(ops.ge_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) + return sge.GTE(this=left_expr, expression=right_expr) + + +@register_binary_op(ops.gt_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) + return sge.GT(this=left_expr, expression=right_expr) + + +@register_binary_op(ops.lt_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) + return sge.LT(this=left_expr, expression=right_expr) + + +@register_binary_op(ops.le_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) + return sge.LTE(this=left_expr, expression=right_expr) + + +@register_binary_op(ops.ne_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) + return sge.NEQ(this=left_expr, expression=right_expr) + + # Helpers def _is_null(value) -> bool: # float NaN/inf should be treated as distinct from 'true' null values return typing.cast(bool, pd.isna(value)) and not isinstance(value, float) + + +def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression: + """Coerce boolean expression to integer.""" + if typed_expr.dtype == dtypes.BOOL_DTYPE: + return sge.Cast(this=typed_expr.expr, to="INT64") + return typed_expr.expr diff --git a/bigframes/core/compile/sqlglot/expressions/json_ops.py b/bigframes/core/compile/sqlglot/expressions/json_ops.py index 754e8d80eb..442eb9fdf5 100644 --- a/bigframes/core/compile/sqlglot/expressions/json_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/json_ops.py @@ -21,6 +21,7 @@ import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op +register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op @register_unary_op(ops.JSONExtract, pass_op=True) @@ -66,3 +67,8 @@ def _(expr: TypedExpr) -> sge.Expression: @register_unary_op(ops.ToJSONString) def _(expr: TypedExpr) -> sge.Expression: return sge.func("TO_JSON_STRING", expr.expr) + + +@register_binary_op(ops.JSONSet, pass_op=True) +def _(left: TypedExpr, right: TypedExpr, op) -> sge.Expression: + return sge.func("JSON_SET", left.expr, sge.convert(op.json_path), right.expr) diff --git a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py index 09c08e2095..1a6447ceb7 100644 --- a/bigframes/core/compile/sqlglot/expressions/numeric_ops.py +++ b/bigframes/core/compile/sqlglot/expressions/numeric_ops.py @@ -14,14 +14,17 @@ from __future__ import annotations +import bigframes_vendored.constants as bf_constants import sqlglot.expressions as sge +from bigframes import dtypes from bigframes import operations as ops import bigframes.core.compile.sqlglot.expressions.constants as constants from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op +register_binary_op = scalar_compiler.scalar_op_compiler.register_binary_op @register_unary_op(ops.abs_op) @@ -238,3 +241,144 @@ def _(expr: TypedExpr) -> sge.Expression: @register_unary_op(ops.tanh_op) def _(expr: TypedExpr) -> sge.Expression: return sge.func("TANH", expr.expr) + + +@register_binary_op(ops.add_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + if left.dtype == dtypes.STRING_DTYPE and right.dtype == dtypes.STRING_DTYPE: + # String addition + return sge.Concat(expressions=[left.expr, right.expr]) + + if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype): + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) + return sge.Add(this=left_expr, expression=right_expr) + + if ( + dtypes.is_time_or_date_like(left.dtype) + and right.dtype == dtypes.TIMEDELTA_DTYPE + ): + left_expr = _coerce_date_to_datetime(left) + return sge.TimestampAdd( + this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND") + ) + if ( + dtypes.is_time_or_date_like(right.dtype) + and left.dtype == dtypes.TIMEDELTA_DTYPE + ): + right_expr = _coerce_date_to_datetime(right) + return sge.TimestampAdd( + this=right_expr, expression=left.expr, unit=sge.Var(this="MICROSECOND") + ) + if left.dtype == dtypes.TIMEDELTA_DTYPE and right.dtype == dtypes.TIMEDELTA_DTYPE: + return sge.Add(this=left.expr, expression=right.expr) + + raise TypeError( + f"Cannot add type {left.dtype} and {right.dtype}. {bf_constants.FEEDBACK_LINK}" + ) + + +@register_binary_op(ops.div_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) + + result = sge.func("IEEE_DIVIDE", left_expr, right_expr) + if left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype): + return sge.Cast(this=sge.Floor(this=result), to="INT64") + else: + return result + + +@register_binary_op(ops.floordiv_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) + + result: sge.Expression = sge.Cast( + this=sge.Floor(this=sge.func("IEEE_DIVIDE", left_expr, right_expr)), to="INT64" + ) + + # DIV(N, 0) will error in bigquery, but needs to return `0` for int, and + # `inf`` for float in BQ so we short-circuit in this case. + # Multiplying left by zero propogates nulls. + zero_result = ( + constants._INF + if (left.dtype == dtypes.FLOAT_DTYPE or right.dtype == dtypes.FLOAT_DTYPE) + else constants._ZERO + ) + result = sge.Case( + ifs=[ + sge.If( + this=sge.EQ(this=right_expr, expression=constants._ZERO), + true=zero_result * left_expr, + ) + ], + default=result, + ) + + if dtypes.is_numeric(right.dtype) and left.dtype == dtypes.TIMEDELTA_DTYPE: + result = sge.Cast(this=sge.Floor(this=result), to="INT64") + + return result + + +@register_binary_op(ops.mul_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) + + result = sge.Mul(this=left_expr, expression=right_expr) + + if (dtypes.is_numeric(left.dtype) and right.dtype == dtypes.TIMEDELTA_DTYPE) or ( + left.dtype == dtypes.TIMEDELTA_DTYPE and dtypes.is_numeric(right.dtype) + ): + return sge.Cast(this=sge.Floor(this=result), to="INT64") + else: + return result + + +@register_binary_op(ops.sub_op) +def _(left: TypedExpr, right: TypedExpr) -> sge.Expression: + if dtypes.is_numeric(left.dtype) and dtypes.is_numeric(right.dtype): + left_expr = _coerce_bool_to_int(left) + right_expr = _coerce_bool_to_int(right) + return sge.Sub(this=left_expr, expression=right_expr) + + if ( + dtypes.is_time_or_date_like(left.dtype) + and right.dtype == dtypes.TIMEDELTA_DTYPE + ): + left_expr = _coerce_date_to_datetime(left) + return sge.TimestampSub( + this=left_expr, expression=right.expr, unit=sge.Var(this="MICROSECOND") + ) + if dtypes.is_time_or_date_like(left.dtype) and dtypes.is_time_or_date_like( + right.dtype + ): + left_expr = _coerce_date_to_datetime(left) + right_expr = _coerce_date_to_datetime(right) + return sge.TimestampDiff( + this=left_expr, expression=right_expr, unit=sge.Var(this="MICROSECOND") + ) + + if left.dtype == dtypes.TIMEDELTA_DTYPE and right.dtype == dtypes.TIMEDELTA_DTYPE: + return sge.Sub(this=left.expr, expression=right.expr) + + raise TypeError( + f"Cannot subtract type {left.dtype} and {right.dtype}. {bf_constants.FEEDBACK_LINK}" + ) + + +def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression: + """Coerce boolean expression to integer.""" + if typed_expr.dtype == dtypes.BOOL_DTYPE: + return sge.Cast(this=typed_expr.expr, to="INT64") + return typed_expr.expr + + +def _coerce_date_to_datetime(typed_expr: TypedExpr) -> sge.Expression: + """Coerce date expression to datetime.""" + if typed_expr.dtype == dtypes.DATE_DTYPE: + return sge.Cast(this=typed_expr.expr, to="DATETIME") + return typed_expr.expr diff --git a/bigframes/testing/utils.py b/bigframes/testing/utils.py index d38e323d57..b4daab7aad 100644 --- a/bigframes/testing/utils.py +++ b/bigframes/testing/utils.py @@ -25,9 +25,10 @@ import pyarrow as pa # type: ignore import pytest -from bigframes.core import expression as expr +from bigframes import operations as ops +from bigframes.core import expression as ex import bigframes.functions._utils as bff_utils -import bigframes.pandas +import bigframes.pandas as bpd ML_REGRESSION_METRICS = [ "mean_absolute_error", @@ -67,17 +68,13 @@ # Prefer this function for tests that run in both ordered and unordered mode -def assert_dfs_equivalent( - pd_df: pd.DataFrame, bf_df: bigframes.pandas.DataFrame, **kwargs -): +def assert_dfs_equivalent(pd_df: pd.DataFrame, bf_df: bpd.DataFrame, **kwargs): bf_df_local = bf_df.to_pandas() ignore_order = not bf_df._session._strictly_ordered assert_pandas_df_equal(bf_df_local, pd_df, ignore_order=ignore_order, **kwargs) -def assert_series_equivalent( - pd_series: pd.Series, bf_series: bigframes.pandas.Series, **kwargs -): +def assert_series_equivalent(pd_series: pd.Series, bf_series: bpd.Series, **kwargs): bf_df_local = bf_series.to_pandas() ignore_order = not bf_series._session._strictly_ordered assert_series_equal(bf_df_local, pd_series, ignore_order=ignore_order, **kwargs) @@ -452,12 +449,12 @@ def get_function_name(func, package_requirements=None, is_row_processor=False): def _apply_unary_ops( - obj: bigframes.pandas.DataFrame, - ops_list: Sequence[expr.Expression], + obj: bpd.DataFrame, + ops_list: Sequence[ex.Expression], new_names: Sequence[str], ) -> str: """Applies a list of unary ops to the given DataFrame and returns the SQL - representing the resulting DataFrames.""" + representing the resulting DataFrame.""" array_value = obj._block.expr result, old_names = array_value.compute_values(ops_list) @@ -468,3 +465,23 @@ def _apply_unary_ops( sql = result.session._executor.to_sql(result, enable_cache=False) return sql + + +def _apply_binary_op( + obj: bpd.DataFrame, + op: ops.BinaryOp, + l_arg: str, + r_arg: Union[str, ex.Expression], +) -> str: + """Applies a binary op to the given DataFrame and return the SQL representing + the resulting DataFrame.""" + array_value = obj._block.expr + op_expr = op.as_expr(l_arg, r_arg) + result, col_ids = array_value.compute_values([op_expr]) + + # Rename columns for deterministic golden SQL results. + assert len(col_ids) == 1 + result = result.rename_columns({col_ids[0]: l_arg}).select_columns([l_arg]) + + sql = result.session._executor.to_sql(result, enable_cache=False) + return sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_mul_timedelta/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_mul_timedelta/out.sql deleted file mode 100644 index f8752d0a60..0000000000 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_mul_timedelta/out.sql +++ /dev/null @@ -1,43 +0,0 @@ -WITH `bfcte_0` AS ( - SELECT - `int64_col` AS `bfcol_0`, - `rowindex` AS `bfcol_1`, - `timestamp_col` AS `bfcol_2`, - `duration_col` AS `bfcol_3` - FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` -), `bfcte_1` AS ( - SELECT - *, - `bfcol_1` AS `bfcol_8`, - `bfcol_2` AS `bfcol_9`, - `bfcol_0` AS `bfcol_10`, - `bfcol_3` AS `bfcol_11` - FROM `bfcte_0` -), `bfcte_2` AS ( - SELECT - *, - `bfcol_8` AS `bfcol_16`, - `bfcol_9` AS `bfcol_17`, - `bfcol_10` AS `bfcol_18`, - `bfcol_11` AS `bfcol_19`, - CAST(FLOOR(`bfcol_11` * `bfcol_10`) AS INT64) AS `bfcol_20` - FROM `bfcte_1` -), `bfcte_3` AS ( - SELECT - *, - `bfcol_16` AS `bfcol_26`, - `bfcol_17` AS `bfcol_27`, - `bfcol_18` AS `bfcol_28`, - `bfcol_19` AS `bfcol_29`, - `bfcol_20` AS `bfcol_30`, - CAST(FLOOR(`bfcol_18` * `bfcol_19`) AS INT64) AS `bfcol_31` - FROM `bfcte_2` -) -SELECT - `bfcol_26` AS `rowindex`, - `bfcol_27` AS `timestamp_col`, - `bfcol_28` AS `int64_col`, - `bfcol_29` AS `duration_col`, - `bfcol_30` AS `timedelta_mul_numeric`, - `bfcol_31` AS `numeric_mul_timedelta` -FROM `bfcte_3` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_obj_make_ref/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_blob_ops/test_obj_make_ref/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_obj_make_ref/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_blob_ops/test_obj_make_ref/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_null_match/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_eq_null_match/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_null_match/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_eq_null_match/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_eq_numeric/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_eq_numeric/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_eq_numeric/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_ge_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_ge_numeric/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_ge_numeric/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_ge_numeric/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_gt_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_gt_numeric/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_gt_numeric/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_gt_numeric/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_le_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_le_numeric/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_le_numeric/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_le_numeric/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_lt_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_lt_numeric/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_lt_numeric/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_lt_numeric/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_ne_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_ne_numeric/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_ne_numeric/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_comparison_ops/test_ne_numeric/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_add_timedelta/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_add_timedelta/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_add_timedelta/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_add_timedelta/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_sub_timedelta/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_sub_timedelta/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_sub_timedelta/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_datetime_ops/test_sub_timedelta/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_json_set/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_json_ops/test_json_set/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_json_set/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_json_ops/test_json_set/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_add_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_add_numeric/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_add_numeric/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_add_numeric/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_div_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_div_numeric/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_div_numeric/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_div_numeric/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_div_timedelta/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_div_timedelta/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_div_timedelta/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_div_timedelta/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_floordiv_timedelta/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_floordiv_timedelta/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_floordiv_timedelta/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_floordiv_timedelta/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_mul_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_mul_numeric/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_mul_numeric/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_mul_numeric/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_sub_numeric/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_sub_numeric/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_sub_numeric/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_numeric_ops/test_sub_numeric/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_add_string/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_add_string/out.sql similarity index 100% rename from tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_add_string/out.sql rename to tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_add_string/out.sql diff --git a/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py b/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py deleted file mode 100644 index a2218d0afa..0000000000 --- a/tests/unit/core/compile/sqlglot/expressions/test_binary_compiler.py +++ /dev/null @@ -1,278 +0,0 @@ -# Copyright 2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import typing - -import pandas as pd -import pytest - -from bigframes import operations as ops -import bigframes.core.expression as ex -import bigframes.pandas as bpd - -pytest.importorskip("pytest_snapshot") - - -def _apply_binary_op( - obj: bpd.DataFrame, - op: ops.BinaryOp, - l_arg: str, - r_arg: typing.Union[str, ex.Expression], -) -> str: - array_value = obj._block.expr - op_expr = op.as_expr(l_arg, r_arg) - result, col_ids = array_value.compute_values([op_expr]) - - # Rename columns for deterministic golden SQL results. - assert len(col_ids) == 1 - result = result.rename_columns({col_ids[0]: l_arg}).select_columns([l_arg]) - - sql = result.session._executor.to_sql(result, enable_cache=False) - return sql - - -def test_add_numeric(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col", "bool_col"]] - - bf_df["int_add_int"] = bf_df["int64_col"] + bf_df["int64_col"] - bf_df["int_add_1"] = bf_df["int64_col"] + 1 - - bf_df["int_add_bool"] = bf_df["int64_col"] + bf_df["bool_col"] - bf_df["bool_add_int"] = bf_df["bool_col"] + bf_df["int64_col"] - - snapshot.assert_match(bf_df.sql, "out.sql") - - -def test_add_string(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["string_col"]] - sql = _apply_binary_op(bf_df, ops.add_op, "string_col", ex.const("a")) - - snapshot.assert_match(sql, "out.sql") - - -def test_add_timedelta(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["timestamp_col", "date_col"]] - timedelta = pd.Timedelta(1, unit="d") - - bf_df["date_add_timedelta"] = bf_df["date_col"] + timedelta - bf_df["timestamp_add_timedelta"] = bf_df["timestamp_col"] + timedelta - bf_df["timedelta_add_date"] = timedelta + bf_df["date_col"] - bf_df["timedelta_add_timestamp"] = timedelta + bf_df["timestamp_col"] - bf_df["timedelta_add_timedelta"] = timedelta + timedelta - - snapshot.assert_match(bf_df.sql, "out.sql") - - -def test_add_unsupported_raises(scalar_types_df: bpd.DataFrame): - with pytest.raises(TypeError): - _apply_binary_op(scalar_types_df, ops.add_op, "timestamp_col", "date_col") - - with pytest.raises(TypeError): - _apply_binary_op(scalar_types_df, ops.add_op, "int64_col", "string_col") - - -def test_div_numeric(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col", "bool_col", "float64_col"]] - - bf_df["int_div_int"] = bf_df["int64_col"] / bf_df["int64_col"] - bf_df["int_div_1"] = bf_df["int64_col"] / 1 - bf_df["int_div_0"] = bf_df["int64_col"] / 0.0 - - bf_df["int_div_float"] = bf_df["int64_col"] / bf_df["float64_col"] - bf_df["float_div_int"] = bf_df["float64_col"] / bf_df["int64_col"] - bf_df["float_div_0"] = bf_df["float64_col"] / 0.0 - - bf_df["int_div_bool"] = bf_df["int64_col"] / bf_df["bool_col"] - bf_df["bool_div_int"] = bf_df["bool_col"] / bf_df["int64_col"] - - snapshot.assert_match(bf_df.sql, "out.sql") - - -def test_div_timedelta(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["timestamp_col", "int64_col"]] - timedelta = pd.Timedelta(1, unit="d") - bf_df["timedelta_div_numeric"] = timedelta / bf_df["int64_col"] - - snapshot.assert_match(bf_df.sql, "out.sql") - - -def test_eq_null_match(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col", "bool_col"]] - sql = _apply_binary_op(bf_df, ops.eq_null_match_op, "int64_col", "bool_col") - snapshot.assert_match(sql, "out.sql") - - -def test_eq_numeric(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col", "bool_col"]] - - bf_df["int_ne_int"] = bf_df["int64_col"] == bf_df["int64_col"] - bf_df["int_ne_1"] = bf_df["int64_col"] == 1 - - bf_df["int_ne_bool"] = bf_df["int64_col"] == bf_df["bool_col"] - bf_df["bool_ne_int"] = bf_df["bool_col"] == bf_df["int64_col"] - - snapshot.assert_match(bf_df.sql, "out.sql") - - -def test_floordiv_numeric(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col", "bool_col", "float64_col"]] - - bf_df["int_div_int"] = bf_df["int64_col"] // bf_df["int64_col"] - bf_df["int_div_1"] = bf_df["int64_col"] // 1 - bf_df["int_div_0"] = bf_df["int64_col"] // 0.0 - - bf_df["int_div_float"] = bf_df["int64_col"] // bf_df["float64_col"] - bf_df["float_div_int"] = bf_df["float64_col"] // bf_df["int64_col"] - bf_df["float_div_0"] = bf_df["float64_col"] // 0.0 - - bf_df["int_div_bool"] = bf_df["int64_col"] // bf_df["bool_col"] - bf_df["bool_div_int"] = bf_df["bool_col"] // bf_df["int64_col"] - - -def test_floordiv_timedelta(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["timestamp_col", "date_col"]] - timedelta = pd.Timedelta(1, unit="d") - - bf_df["timedelta_div_numeric"] = timedelta // 2 - - snapshot.assert_match(bf_df.sql, "out.sql") - - -def test_gt_numeric(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col", "bool_col"]] - - bf_df["int_gt_int"] = bf_df["int64_col"] > bf_df["int64_col"] - bf_df["int_gt_1"] = bf_df["int64_col"] > 1 - - bf_df["int_gt_bool"] = bf_df["int64_col"] > bf_df["bool_col"] - bf_df["bool_gt_int"] = bf_df["bool_col"] > bf_df["int64_col"] - - snapshot.assert_match(bf_df.sql, "out.sql") - - -def test_ge_numeric(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col", "bool_col"]] - - bf_df["int_ge_int"] = bf_df["int64_col"] >= bf_df["int64_col"] - bf_df["int_ge_1"] = bf_df["int64_col"] >= 1 - - bf_df["int_ge_bool"] = bf_df["int64_col"] >= bf_df["bool_col"] - bf_df["bool_ge_int"] = bf_df["bool_col"] >= bf_df["int64_col"] - - snapshot.assert_match(bf_df.sql, "out.sql") - - -def test_json_set(json_types_df: bpd.DataFrame, snapshot): - bf_df = json_types_df[["json_col"]] - sql = _apply_binary_op( - bf_df, ops.JSONSet(json_path="$.a"), "json_col", ex.const(100) - ) - - snapshot.assert_match(sql, "out.sql") - - -def test_lt_numeric(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col", "bool_col"]] - - bf_df["int_lt_int"] = bf_df["int64_col"] < bf_df["int64_col"] - bf_df["int_lt_1"] = bf_df["int64_col"] < 1 - - bf_df["int_lt_bool"] = bf_df["int64_col"] < bf_df["bool_col"] - bf_df["bool_lt_int"] = bf_df["bool_col"] < bf_df["int64_col"] - - snapshot.assert_match(bf_df.sql, "out.sql") - - -def test_le_numeric(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col", "bool_col"]] - - bf_df["int_le_int"] = bf_df["int64_col"] <= bf_df["int64_col"] - bf_df["int_le_1"] = bf_df["int64_col"] <= 1 - - bf_df["int_le_bool"] = bf_df["int64_col"] <= bf_df["bool_col"] - bf_df["bool_le_int"] = bf_df["bool_col"] <= bf_df["int64_col"] - - snapshot.assert_match(bf_df.sql, "out.sql") - - -def test_sub_numeric(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col", "bool_col"]] - - bf_df["int_add_int"] = bf_df["int64_col"] - bf_df["int64_col"] - bf_df["int_add_1"] = bf_df["int64_col"] - 1 - - bf_df["int_add_bool"] = bf_df["int64_col"] - bf_df["bool_col"] - bf_df["bool_add_int"] = bf_df["bool_col"] - bf_df["int64_col"] - - snapshot.assert_match(bf_df.sql, "out.sql") - - -def test_sub_timedelta(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["timestamp_col", "duration_col", "date_col"]] - bf_df["duration_col"] = bpd.to_timedelta(bf_df["duration_col"], unit="us") - - bf_df["date_sub_timedelta"] = bf_df["date_col"] - bf_df["duration_col"] - bf_df["timestamp_sub_timedelta"] = bf_df["timestamp_col"] - bf_df["duration_col"] - bf_df["timestamp_sub_date"] = bf_df["date_col"] - bf_df["date_col"] - bf_df["date_sub_timestamp"] = bf_df["timestamp_col"] - bf_df["timestamp_col"] - bf_df["timedelta_sub_timedelta"] = bf_df["duration_col"] - bf_df["duration_col"] - - snapshot.assert_match(bf_df.sql, "out.sql") - - -def test_sub_unsupported_raises(scalar_types_df: bpd.DataFrame): - with pytest.raises(TypeError): - _apply_binary_op(scalar_types_df, ops.sub_op, "string_col", "string_col") - - with pytest.raises(TypeError): - _apply_binary_op(scalar_types_df, ops.sub_op, "int64_col", "string_col") - - -def test_mul_numeric(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col", "bool_col"]] - - bf_df["int_mul_int"] = bf_df["int64_col"] * bf_df["int64_col"] - bf_df["int_mul_1"] = bf_df["int64_col"] * 1 - - bf_df["int_mul_bool"] = bf_df["int64_col"] * bf_df["bool_col"] - bf_df["bool_mul_int"] = bf_df["bool_col"] * bf_df["int64_col"] - - snapshot.assert_match(bf_df.sql, "out.sql") - - -def test_mul_timedelta(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["timestamp_col", "int64_col", "duration_col"]] - bf_df["duration_col"] = bpd.to_timedelta(bf_df["duration_col"], unit="us") - - bf_df["timedelta_mul_numeric"] = bf_df["duration_col"] * bf_df["int64_col"] - bf_df["numeric_mul_timedelta"] = bf_df["int64_col"] * bf_df["duration_col"] - - snapshot.assert_match(bf_df.sql, "out.sql") - - -def test_obj_make_ref(scalar_types_df: bpd.DataFrame, snapshot): - blob_df = scalar_types_df["string_col"].str.to_blob() - snapshot.assert_match(blob_df.to_frame().sql, "out.sql") - - -def test_ne_numeric(scalar_types_df: bpd.DataFrame, snapshot): - bf_df = scalar_types_df[["int64_col", "bool_col"]] - - bf_df["int_ne_int"] = bf_df["int64_col"] != bf_df["int64_col"] - bf_df["int_ne_1"] = bf_df["int64_col"] != 1 - - bf_df["int_ne_bool"] = bf_df["int64_col"] != bf_df["bool_col"] - bf_df["bool_ne_int"] = bf_df["bool_col"] != bf_df["int64_col"] - - snapshot.assert_match(bf_df.sql, "out.sql") diff --git a/tests/unit/core/compile/sqlglot/expressions/test_blob_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_blob_ops.py index 7876a754ee..80aa22aaac 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_blob_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_blob_ops.py @@ -29,3 +29,8 @@ def test_obj_get_access_url(scalar_types_df: bpd.DataFrame, snapshot): blob_s = scalar_types_df["string_col"].str.to_blob() sql = blob_s.blob.read_url().to_frame().sql snapshot.assert_match(sql, "out.sql") + + +def test_obj_make_ref(scalar_types_df: bpd.DataFrame, snapshot): + blob_df = scalar_types_df["string_col"].str.to_blob() + snapshot.assert_match(blob_df.to_frame().sql, "out.sql") diff --git a/tests/unit/core/compile/sqlglot/expressions/test_comparison_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_comparison_ops.py index 9a901687fa..6c3eb64414 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_comparison_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_comparison_ops.py @@ -42,3 +42,81 @@ def test_is_in(scalar_types_df: bpd.DataFrame, snapshot): sql = utils._apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) snapshot.assert_match(sql, "out.sql") + + +def test_eq_null_match(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col"]] + sql = utils._apply_binary_op(bf_df, ops.eq_null_match_op, "int64_col", "bool_col") + snapshot.assert_match(sql, "out.sql") + + +def test_eq_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col"]] + + bf_df["int_ne_int"] = bf_df["int64_col"] == bf_df["int64_col"] + bf_df["int_ne_1"] = bf_df["int64_col"] == 1 + + bf_df["int_ne_bool"] = bf_df["int64_col"] == bf_df["bool_col"] + bf_df["bool_ne_int"] = bf_df["bool_col"] == bf_df["int64_col"] + + snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_gt_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col"]] + + bf_df["int_gt_int"] = bf_df["int64_col"] > bf_df["int64_col"] + bf_df["int_gt_1"] = bf_df["int64_col"] > 1 + + bf_df["int_gt_bool"] = bf_df["int64_col"] > bf_df["bool_col"] + bf_df["bool_gt_int"] = bf_df["bool_col"] > bf_df["int64_col"] + + snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_ge_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col"]] + + bf_df["int_ge_int"] = bf_df["int64_col"] >= bf_df["int64_col"] + bf_df["int_ge_1"] = bf_df["int64_col"] >= 1 + + bf_df["int_ge_bool"] = bf_df["int64_col"] >= bf_df["bool_col"] + bf_df["bool_ge_int"] = bf_df["bool_col"] >= bf_df["int64_col"] + + snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_lt_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col"]] + + bf_df["int_lt_int"] = bf_df["int64_col"] < bf_df["int64_col"] + bf_df["int_lt_1"] = bf_df["int64_col"] < 1 + + bf_df["int_lt_bool"] = bf_df["int64_col"] < bf_df["bool_col"] + bf_df["bool_lt_int"] = bf_df["bool_col"] < bf_df["int64_col"] + + snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_le_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col"]] + + bf_df["int_le_int"] = bf_df["int64_col"] <= bf_df["int64_col"] + bf_df["int_le_1"] = bf_df["int64_col"] <= 1 + + bf_df["int_le_bool"] = bf_df["int64_col"] <= bf_df["bool_col"] + bf_df["bool_le_int"] = bf_df["bool_col"] <= bf_df["int64_col"] + + snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_ne_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col"]] + + bf_df["int_ne_int"] = bf_df["int64_col"] != bf_df["int64_col"] + bf_df["int_ne_1"] = bf_df["int64_col"] != 1 + + bf_df["int_ne_bool"] = bf_df["int64_col"] != bf_df["bool_col"] + bf_df["bool_ne_int"] = bf_df["bool_col"] != bf_df["int64_col"] + + snapshot.assert_match(bf_df.sql, "out.sql") diff --git a/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py index 0a8aa320bb..91926e7bdd 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_datetime_ops.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pandas as pd import pytest from bigframes import operations as ops @@ -215,3 +216,29 @@ def test_iso_year(scalar_types_df: bpd.DataFrame, snapshot): sql = utils._apply_unary_ops(bf_df, [ops.iso_year_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") + + +def test_add_timedelta(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["timestamp_col", "date_col"]] + timedelta = pd.Timedelta(1, unit="d") + + bf_df["date_add_timedelta"] = bf_df["date_col"] + timedelta + bf_df["timestamp_add_timedelta"] = bf_df["timestamp_col"] + timedelta + bf_df["timedelta_add_date"] = timedelta + bf_df["date_col"] + bf_df["timedelta_add_timestamp"] = timedelta + bf_df["timestamp_col"] + bf_df["timedelta_add_timedelta"] = timedelta + timedelta + + snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_sub_timedelta(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["timestamp_col", "duration_col", "date_col"]] + bf_df["duration_col"] = bpd.to_timedelta(bf_df["duration_col"], unit="us") + + bf_df["date_sub_timedelta"] = bf_df["date_col"] - bf_df["duration_col"] + bf_df["timestamp_sub_timedelta"] = bf_df["timestamp_col"] - bf_df["duration_col"] + bf_df["timestamp_sub_date"] = bf_df["date_col"] - bf_df["date_col"] + bf_df["date_sub_timestamp"] = bf_df["timestamp_col"] - bf_df["timestamp_col"] + bf_df["timedelta_sub_timedelta"] = bf_df["duration_col"] - bf_df["duration_col"] + + snapshot.assert_match(bf_df.sql, "out.sql") diff --git a/tests/unit/core/compile/sqlglot/expressions/test_json_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_json_ops.py index ecbac10ef2..75206091e0 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_json_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_json_ops.py @@ -15,6 +15,7 @@ import pytest from bigframes import operations as ops +import bigframes.core.expression as ex import bigframes.pandas as bpd from bigframes.testing import utils @@ -97,3 +98,12 @@ def test_to_json_string(json_types_df: bpd.DataFrame, snapshot): ) snapshot.assert_match(sql, "out.sql") + + +def test_json_set(json_types_df: bpd.DataFrame, snapshot): + bf_df = json_types_df[["json_col"]] + sql = utils._apply_binary_op( + bf_df, ops.JSONSet(json_path="$.a"), "json_col", ex.const(100) + ) + + snapshot.assert_match(sql, "out.sql") diff --git a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py index 10fd4b2427..e0c41857e9 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_numeric_ops.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import pandas as pd import pytest from bigframes import operations as ops @@ -211,3 +212,88 @@ def test_tanh(scalar_types_df: bpd.DataFrame, snapshot): sql = utils._apply_unary_ops(bf_df, [ops.tanh_op.as_expr(col_name)], [col_name]) snapshot.assert_match(sql, "out.sql") + + +def test_add_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col"]] + + bf_df["int_add_int"] = bf_df["int64_col"] + bf_df["int64_col"] + bf_df["int_add_1"] = bf_df["int64_col"] + 1 + + bf_df["int_add_bool"] = bf_df["int64_col"] + bf_df["bool_col"] + bf_df["bool_add_int"] = bf_df["bool_col"] + bf_df["int64_col"] + + snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_div_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col", "float64_col"]] + + bf_df["int_div_int"] = bf_df["int64_col"] / bf_df["int64_col"] + bf_df["int_div_1"] = bf_df["int64_col"] / 1 + bf_df["int_div_0"] = bf_df["int64_col"] / 0.0 + + bf_df["int_div_float"] = bf_df["int64_col"] / bf_df["float64_col"] + bf_df["float_div_int"] = bf_df["float64_col"] / bf_df["int64_col"] + bf_df["float_div_0"] = bf_df["float64_col"] / 0.0 + + bf_df["int_div_bool"] = bf_df["int64_col"] / bf_df["bool_col"] + bf_df["bool_div_int"] = bf_df["bool_col"] / bf_df["int64_col"] + + snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_div_timedelta(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["timestamp_col", "int64_col"]] + timedelta = pd.Timedelta(1, unit="d") + bf_df["timedelta_div_numeric"] = timedelta / bf_df["int64_col"] + + snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_floordiv_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col", "float64_col"]] + + bf_df["int_div_int"] = bf_df["int64_col"] // bf_df["int64_col"] + bf_df["int_div_1"] = bf_df["int64_col"] // 1 + bf_df["int_div_0"] = bf_df["int64_col"] // 0.0 + + bf_df["int_div_float"] = bf_df["int64_col"] // bf_df["float64_col"] + bf_df["float_div_int"] = bf_df["float64_col"] // bf_df["int64_col"] + bf_df["float_div_0"] = bf_df["float64_col"] // 0.0 + + bf_df["int_div_bool"] = bf_df["int64_col"] // bf_df["bool_col"] + bf_df["bool_div_int"] = bf_df["bool_col"] // bf_df["int64_col"] + + +def test_floordiv_timedelta(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["timestamp_col", "date_col"]] + timedelta = pd.Timedelta(1, unit="d") + + bf_df["timedelta_div_numeric"] = timedelta // 2 + + snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_mul_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col"]] + + bf_df["int_mul_int"] = bf_df["int64_col"] * bf_df["int64_col"] + bf_df["int_mul_1"] = bf_df["int64_col"] * 1 + + bf_df["int_mul_bool"] = bf_df["int64_col"] * bf_df["bool_col"] + bf_df["bool_mul_int"] = bf_df["bool_col"] * bf_df["int64_col"] + + snapshot.assert_match(bf_df.sql, "out.sql") + + +def test_sub_numeric(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["int64_col", "bool_col"]] + + bf_df["int_add_int"] = bf_df["int64_col"] - bf_df["int64_col"] + bf_df["int_add_1"] = bf_df["int64_col"] - 1 + + bf_df["int_add_bool"] = bf_df["int64_col"] - bf_df["bool_col"] + bf_df["bool_add_int"] = bf_df["bool_col"] - bf_df["int64_col"] + + snapshot.assert_match(bf_df.sql, "out.sql") diff --git a/tests/unit/core/compile/sqlglot/expressions/test_string_ops.py b/tests/unit/core/compile/sqlglot/expressions/test_string_ops.py index 79c67a09ca..9121334811 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_string_ops.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_string_ops.py @@ -15,6 +15,7 @@ import pytest from bigframes import operations as ops +import bigframes.core.expression as ex import bigframes.pandas as bpd from bigframes.testing import utils @@ -303,3 +304,10 @@ def test_zfill(scalar_types_df: bpd.DataFrame, snapshot): bf_df, [ops.ZfillOp(width=10).as_expr(col_name)], [col_name] ) snapshot.assert_match(sql, "out.sql") + + +def test_add_string(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["string_col"]] + sql = utils._apply_binary_op(bf_df, ops.add_op, "string_col", ex.const("a")) + + snapshot.assert_match(sql, "out.sql")