Skip to content

Commit a21361b

Browse files
committed
refactor: add ToArrayOp to the sqlglot compiler
1 parent 3702f56 commit a21361b

File tree

4 files changed

+76
-7
lines changed

4 files changed

+76
-7
lines changed

bigframes/core/compile/sqlglot/expressions/array_ops.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,10 @@
2222
from bigframes import operations as ops
2323
from bigframes.core.compile.sqlglot.expressions.typed_expr import TypedExpr
2424
import bigframes.core.compile.sqlglot.scalar_compiler as scalar_compiler
25+
import bigframes.dtypes as dtypes
2526

2627
# Shorthand for the scalar op compiler's unary-op registration hook; it is
# applied as a decorator to the single-operand handlers in this module.
register_unary_op = scalar_compiler.scalar_op_compiler.register_unary_op
27-
28-
29-
@register_unary_op(ops.ArrayToStringOp, pass_op=True)
def _(expr: TypedExpr, op: ops.ArrayToStringOp) -> sge.Expression:
    """Compile ``ArrayToStringOp`` into an ``sge.ArrayToString`` node.

    Args:
        expr: Compiled array operand whose elements are joined.
        op: The operation being compiled; ``op.delimiter`` is the separator
            placed between elements.

    Returns:
        An ``sge.ArrayToString`` expression over ``expr.expr``.
    """
    # Build the delimiter as a string-literal node instead of splicing it
    # between hand-written quotes, so the SQL generator escapes any quote,
    # backslash or newline in the delimiter rather than emitting broken SQL.
    return sge.ArrayToString(
        this=expr.expr, expression=sge.Literal.string(op.delimiter)
    )
28+
# Shorthand for the scalar op compiler's n-ary registration hook; it is
# applied as a decorator to handlers that take a variable number of operands.
register_nary_op = scalar_compiler.scalar_op_compiler.register_nary_op
3229

3330

3431
@register_unary_op(ops.ArrayIndexOp, pass_op=True)
@@ -66,3 +63,27 @@ def _(expr: TypedExpr, op: ops.ArraySliceOp) -> sge.Expression:
6663
)
6764

6865
return sge.array(selected_elements)
66+
67+
68+
@register_unary_op(ops.ArrayToStringOp, pass_op=True)
def _(expr: TypedExpr, op: ops.ArrayToStringOp) -> sge.Expression:
    """Compile ``ArrayToStringOp`` into an ``sge.ArrayToString`` node.

    Args:
        expr: Compiled array operand whose elements are joined.
        op: The operation being compiled; ``op.delimiter`` is the separator
            placed between elements.

    Returns:
        An ``sge.ArrayToString`` expression over ``expr.expr``.
    """
    # Build the delimiter as a string-literal node instead of splicing it
    # between hand-written quotes, so the SQL generator escapes any quote,
    # backslash or newline in the delimiter rather than emitting broken SQL.
    return sge.ArrayToString(
        this=expr.expr, expression=sge.Literal.string(op.delimiter)
    )
71+
72+
73+
@register_nary_op(ops.ToArrayOp)
def _(*exprs: TypedExpr) -> sge.Expression:
    """Compile ``ToArrayOp`` into an array literal of its operands.

    When at least one operand has a dtype that ``dtypes.is_numeric`` accepts
    with ``include_bool=False``, every operand is passed through
    ``_coerce_bool_to_int`` so boolean operands are cast to ``INT64`` before
    being placed in the array. Otherwise the operands are used unchanged.

    Args:
        *exprs: Compiled operands, in the order they appear in the array.

    Returns:
        An ``sge.Array`` whose elements are the (possibly coerced) operands.
    """
    has_non_bool_numeric = False
    for operand in exprs:
        if dtypes.is_numeric(operand.dtype, include_bool=False):
            has_non_bool_numeric = True
            break

    if not has_non_bool_numeric:
        # No non-boolean numeric operand: keep every element as compiled.
        return sge.Array(expressions=[operand.expr for operand in exprs])

    # Mixed with numerics: booleans are upcast so they sit alongside numbers.
    return sge.Array(
        expressions=[_coerce_bool_to_int(operand) for operand in exprs]
    )
83+
84+
85+
def _coerce_bool_to_int(typed_expr: TypedExpr) -> sge.Expression:
    """Return the operand's expression, cast to ``INT64`` if it is boolean.

    Args:
        typed_expr: Compiled operand carrying its expression and dtype.

    Returns:
        ``CAST(<expr> AS INT64)`` as an ``sge.Cast`` node when the dtype equals
        ``dtypes.BOOL_DTYPE``; otherwise the operand's expression unchanged.
    """
    inner = typed_expr.expr
    is_bool = typed_expr.dtype == dtypes.BOOL_DTYPE
    return sge.Cast(this=inner, to="INT64") if is_bool else inner

tests/system/small/engines/test_array_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
REFERENCE_ENGINE = polars_executor.PolarsExecutor()
2727

2828

29-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
29+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
3030
def test_engines_to_array_op(scalars_array_value: array_value.ArrayValue, engine):
3131
# BigQuery won't let you materialize arrays containing NULL elements, so use non-nullable expressions
3232
int64_non_null = ops.coalesce_op.as_expr("int64_col", expression.const(0))
@@ -46,7 +46,7 @@ def test_engines_to_array_op(scalars_array_value: array_value.ArrayValue, engine
4646
assert_equivalence_execution(arr.node, REFERENCE_ENGINE, engine)
4747

4848

49-
@pytest.mark.parametrize("engine", ["polars", "bq"], indirect=True)
49+
@pytest.mark.parametrize("engine", ["polars", "bq", "bq-sqlglot"], indirect=True)
5050
def test_engines_array_reduce_op(arrays_array_value: array_value.ArrayValue, engine):
5151
arr, _ = arrays_array_value.compute_values(
5252
[
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col`,
4+
`float64_col`,
5+
`int64_col`,
6+
`string_col`
7+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
8+
), `bfcte_1` AS (
9+
SELECT
10+
*,
11+
[COALESCE(`bool_col`, FALSE)] AS `bfcol_8`,
12+
[COALESCE(`int64_col`, 0)] AS `bfcol_9`,
13+
[COALESCE(`string_col`, ''), COALESCE(`string_col`, '')] AS `bfcol_10`,
14+
[
15+
COALESCE(`int64_col`, 0),
16+
CAST(COALESCE(`bool_col`, FALSE) AS INT64),
17+
COALESCE(`float64_col`, 0.0)
18+
] AS `bfcol_11`
19+
FROM `bfcte_0`
20+
)
21+
SELECT
22+
`bfcol_8` AS `bool_col`,
23+
`bfcol_9` AS `int64_col`,
24+
`bfcol_10` AS `strs_col`,
25+
`bfcol_11` AS `numeric_col`
26+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_array_ops.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import pytest
1616

1717
from bigframes import operations as ops
18+
from bigframes.core import expression
1819
from bigframes.operations._op_converters import convert_index, convert_slice
1920
import bigframes.pandas as bpd
2021
from bigframes.testing import utils
@@ -60,3 +61,24 @@ def test_array_slice_with_start_and_stop(repeated_types_df: bpd.DataFrame, snaps
6061
)
6162

6263
snapshot.assert_match(sql, "out.sql")
64+
65+
66+
def test_to_array_op(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL compiled for ``ToArrayOp`` over several operand mixes.

    Covers a single boolean, a single integer, two strings, and a mixed
    int/bool/float array.
    """
    source_columns = ["int64_col", "bool_col", "float64_col", "string_col"]
    bf_df = scalar_types_df[source_columns]

    # BigQuery won't let you materialize arrays containing NULL elements, so
    # wrap every input column in a COALESCE with a non-null default.
    defaults = (
        ("int64_col", 0),
        ("bool_col", False),
        ("float64_col", 0.0),
        ("string_col", ""),
    )
    non_null = {
        column: ops.coalesce_op.as_expr(column, expression.const(fill))
        for column, fill in defaults
    }

    # Output names and expressions are kept in matching order; the snapshot
    # depends on that order.
    output_names = ["bool_col", "int64_col", "strs_col", "numeric_col"]
    array_exprs = [
        ops.ToArrayOp().as_expr(non_null["bool_col"]),
        ops.ToArrayOp().as_expr(non_null["int64_col"]),
        ops.ToArrayOp().as_expr(non_null["string_col"], non_null["string_col"]),
        ops.ToArrayOp().as_expr(
            non_null["int64_col"], non_null["bool_col"], non_null["float64_col"]
        ),
    ]

    sql = utils._apply_ops_to_sql(bf_df, array_exprs, output_names)
    snapshot.assert_match(sql, "out.sql")

0 commit comments

Comments
 (0)