diff --git a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py index 98f1603be7..f519aef70d 100644 --- a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py @@ -27,6 +27,7 @@ import bigframes.core.compile.sqlglot.expressions.constants as constants from bigframes.core.compile.sqlglot.expressions.op_registration import OpRegistration from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr +import bigframes.dtypes as dtypes UNARY_OP_REGISTRATION = OpRegistration() @@ -420,7 +421,28 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: @UNARY_OP_REGISTRATION.register(ops.IsInOp) def _(op: ops.IsInOp, expr: TypedExpr) -> sge.Expression: - return sge.In(this=expr.expr, expressions=[sge.convert(v) for v in op.values]) + values = [] + is_numeric_expr = dtypes.is_numeric(expr.dtype) + for value in op.values: + if value is None: + continue + dtype = dtypes.bigframes_type(type(value)) + if expr.dtype == dtype or is_numeric_expr and dtypes.is_numeric(dtype): + values.append(sge.convert(value)) + + if op.match_nulls: + contains_nulls = any(_is_null(value) for value in op.values) + if contains_nulls: + return sge.Is(this=expr.expr, expression=sge.Null()) | sge.In( + this=expr.expr, expressions=values + ) + + if len(values) == 0: + return sge.convert(False) + + return sge.func( + "COALESCE", sge.In(this=expr.expr, expressions=values), sge.convert(False) + ) @UNARY_OP_REGISTRATION.register(ops.isalnum_op) @@ -767,7 +789,7 @@ def _(op: ops.ToTimedeltaOp, expr: TypedExpr) -> sge.Expression: factor = UNIT_TO_US_CONVERSION_FACTORS[op.unit] if factor != 1: value = sge.Mul(this=value, expression=sge.convert(factor)) - return sge.Interval(this=value, unit=sge.Identifier(this="MICROSECOND")) + return value @UNARY_OP_REGISTRATION.register(ops.UnixMicros) @@ -866,3 +888,9 @@ def _(op: ops.ZfillOp, expr: TypedExpr) -> sge.Expression: ], default=sge.func("LPAD", expr.expr, sge.convert(op.width), sge.convert("0")), ) + + +# Helpers +def _is_null(value) -> bool: + # float NaN/inf should be treated as distinct from 'true' null values + return typing.cast(bool, pd.isna(value)) and not isinstance(value, float) diff --git a/tests/system/small/engines/test_generic_ops.py b/tests/system/small/engines/test_generic_ops.py index 8deef3638e..14c6e9a454 100644 --- a/tests/system/small/engines/test_generic_ops.py +++ b/tests/system/small/engines/test_generic_ops.py @@ -392,7 +392,7 @@ def test_engines_invert_op(scalars_array_value: array_value.ArrayValue, engine): assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine) -@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True) +@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True) def test_engines_isin_op(scalars_array_value: array_value.ArrayValue, engine): arr, col_ids = scalars_array_value.compute_values( [ diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_mul_timedelta/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_mul_timedelta/out.sql index c8a8cf6cbf..f8752d0a60 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_mul_timedelta/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_mul_timedelta/out.sql @@ -11,7 +11,7 @@ WITH `bfcte_0` AS ( `bfcol_1` AS `bfcol_8`, `bfcol_2` AS `bfcol_9`, `bfcol_0` AS `bfcol_10`, - INTERVAL `bfcol_3` MICROSECOND AS `bfcol_11` + `bfcol_3` AS `bfcol_11` FROM `bfcte_0` ), `bfcte_2` AS ( SELECT diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_sub_timedelta/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_sub_timedelta/out.sql index 460f941d1b..2d615fcca6 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_sub_timedelta/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_binary_compiler/test_sub_timedelta/out.sql @@ -11,7 +11,7 @@ WITH `bfcte_0` AS ( `bfcol_1` AS `bfcol_8`, `bfcol_2` AS `bfcol_9`, `bfcol_0` AS `bfcol_10`, - INTERVAL `bfcol_3` MICROSECOND AS `bfcol_11` + `bfcol_3` AS `bfcol_11` FROM `bfcte_0` ), `bfcte_2` AS ( SELECT diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_is_in/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_is_in/out.sql index 36941df71b..7a1a2a743d 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_is_in/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_is_in/out.sql @@ -1,13 +1,32 @@ WITH `bfcte_0` AS ( SELECT - `int64_col` AS `bfcol_0` + `int64_col` AS `bfcol_0`, + `float64_col` AS `bfcol_1` FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` ), `bfcte_1` AS ( SELECT *, - `bfcol_0` IN (1, 2, 3) AS `bfcol_1` + COALESCE(`bfcol_0` IN (1, 2, 3), FALSE) AS `bfcol_2`, + ( + `bfcol_0` IS NULL + ) OR `bfcol_0` IN (123456) AS `bfcol_3`, + COALESCE(`bfcol_0` IN (1.0, 2.0, 3.0), FALSE) AS `bfcol_4`, + FALSE AS `bfcol_5`, + COALESCE(`bfcol_0` IN (2.5, 3), FALSE) AS `bfcol_6`, + FALSE AS `bfcol_7`, + COALESCE(`bfcol_0` IN (123456), FALSE) AS `bfcol_8`, + ( + `bfcol_1` IS NULL + ) OR `bfcol_1` IN (1, 2, 3) AS `bfcol_9` FROM `bfcte_0` ) SELECT - `bfcol_1` AS `int64_col` + `bfcol_2` AS `ints`, + `bfcol_3` AS `ints_w_null`, + `bfcol_4` AS `floats`, + `bfcol_5` AS `strings`, + `bfcol_6` AS `mixed`, + `bfcol_7` AS `empty`, + `bfcol_8` AS `ints_wo_match_nulls`, + `bfcol_9` AS `float_in_ints` FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_to_timedelta/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_to_timedelta/out.sql index 01ebebc455..057e6c778e 100644 --- a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_to_timedelta/out.sql +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_to_timedelta/out.sql @@ -8,7 +8,7 @@ WITH `bfcte_0` AS ( *, `bfcol_1` AS `bfcol_4`, `bfcol_0` AS `bfcol_5`, - INTERVAL `bfcol_0` MICROSECOND AS `bfcol_6` + `bfcol_0` AS `bfcol_6` FROM `bfcte_0` ), `bfcte_2` AS ( SELECT @@ -16,7 +16,7 @@ WITH `bfcte_0` AS ( `bfcol_4` AS `bfcol_10`, `bfcol_5` AS `bfcol_11`, `bfcol_6` AS `bfcol_12`, - INTERVAL (`bfcol_5` * 1000000) MICROSECOND AS `bfcol_13` + `bfcol_5` * 1000000 AS `bfcol_13` FROM `bfcte_1` ), `bfcte_3` AS ( SELECT @@ -25,7 +25,7 @@ WITH `bfcte_0` AS ( `bfcol_11` AS `bfcol_19`, `bfcol_12` AS `bfcol_20`, `bfcol_13` AS `bfcol_21`, - INTERVAL (`bfcol_11` * 604800000000) MICROSECOND AS `bfcol_22` + `bfcol_11` * 604800000000 AS `bfcol_22` FROM `bfcte_2` ) SELECT diff --git a/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py index 815bb84a9a..fced18f5be 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py @@ -370,12 +370,25 @@ def test_invert(scalar_types_df: bpd.DataFrame, snapshot): def test_is_in(scalar_types_df: bpd.DataFrame, snapshot): - col_name = "int64_col" - bf_df = scalar_types_df[[col_name]] - sql = _apply_unary_ops( - bf_df, [ops.IsInOp(values=(1, 2, 3)).as_expr(col_name)], [col_name] - ) + int_col = "int64_col" + float_col = "float64_col" + bf_df = scalar_types_df[[int_col, float_col]] + ops_map = { + "ints": ops.IsInOp(values=(1, 2, 3)).as_expr(int_col), + "ints_w_null": ops.IsInOp(values=(None, 123456)).as_expr(int_col), + "floats": ops.IsInOp(values=(1.0, 2.0, 3.0), match_nulls=False).as_expr( + int_col + ), + "strings": ops.IsInOp(values=("1.0", "2.0")).as_expr(int_col), + "mixed": ops.IsInOp(values=("1.0", 2.5, 3)).as_expr(int_col), + "empty": ops.IsInOp(values=()).as_expr(int_col), + "ints_wo_match_nulls": ops.IsInOp( + values=(None, 123456), match_nulls=False + ).as_expr(int_col), + "float_in_ints": ops.IsInOp(values=(1, 2, 3, None)).as_expr(float_col), + } + sql = _apply_unary_ops(bf_df, list(ops_map.values()), list(ops_map.keys())) snapshot.assert_match(sql, "out.sql")