Skip to content

Commit ce09b4a

Browse files
committed
refactor: add sqlscalarop to the sqlglot compiler
1 parent 9fad51c commit ce09b4a

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

bigframes/core/compile/sqlglot/expressions/generic_ops.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import sqlglot as sg
1718
import sqlglot.expressions as sge
1819

1920
from bigframes import dtypes
@@ -80,6 +81,17 @@ def _(expr: TypedExpr) -> sge.Expression:
8081
return sge.BitwiseNot(this=sge.paren(expr.expr))
8182

8283

84+
@register_nary_op(ops.SqlScalarOp, pass_op=True)
def _(*operands: TypedExpr, op: ops.SqlScalarOp) -> sge.Expression:
    """Compile a SqlScalarOp by filling its SQL template with the operands.

    Each operand is rendered to BigQuery SQL, substituted positionally into
    ``op.sql_template`` with ``str.format``, and the resulting text is parsed
    back into a sqlglot expression.

    Args:
        *operands: Compiled input expressions, in template placeholder order.
        op: The operation carrying the ``sql_template`` format string.

    Returns:
        The sqlglot expression parsed from the filled-in template.
    """
    # Expression kinds whose rendered SQL is already a single unit; anything
    # else (e.g. `a + b`) is parenthesized so that operators in the template
    # cannot regroup it after the text is re-parsed: `{0} * 2` must become
    # `(a + b) * 2`, not `a + b * 2`.
    self_delimited = (
        sge.Column,
        sge.Identifier,
        sge.Literal,
        sge.Null,
        sge.Boolean,
        sge.Paren,
        sge.Func,
    )
    rendered_operands = []
    for operand in operands:
        operand_expr = operand.expr
        if not isinstance(operand_expr, self_delimited):
            operand_expr = sge.paren(operand_expr)
        rendered_operands.append(operand_expr.sql(dialect="bigquery"))

    # TODO: can we include a string in the sqlglot expression without parsing?
    return sg.parse_one(
        op.sql_template.format(*rendered_operands),
        dialect="bigquery",
    )
93+
94+
8395
@register_unary_op(ops.isnull_op)
def _(expr: TypedExpr) -> sge.Expression:
    """Compile ``isnull_op`` to an ``IS NULL`` test on the operand."""
    operand = expr.expr
    null_literal = sge.Null()
    return sge.Is(this=operand, expression=null_literal)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`bool_col` AS `bfcol_0`,
4+
`bytes_col` AS `bfcol_1`
5+
FROM `bigframes-dev`.`sqlglot_test`.`scalar_types`
6+
), `bfcte_1` AS (
7+
SELECT
8+
*,
9+
CAST(`bfcol_0` AS INT64) + BYTE_LENGTH(`bfcol_1`) AS `bfcol_2`
10+
FROM `bfcte_0`
11+
)
12+
SELECT
13+
`bfcol_2` AS `bool_col`
14+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_generic_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,17 @@ def test_notnull(scalar_types_df: bpd.DataFrame, snapshot):
261261
snapshot.assert_match(sql, "out.sql")
262262

263263

264+
def test_sql_scalar_op(scalar_types_df: bpd.DataFrame, snapshot):
    """Snapshot the SQL generated for a two-operand SqlScalarOp template."""
    column_names = ["bool_col", "bytes_col"]
    bf_df = scalar_types_df[column_names]
    # The template consumes the columns positionally: {0} -> bool_col,
    # {1} -> bytes_col.
    scalar_op = ops.SqlScalarOp(
        dtypes.INT_DTYPE, "CAST({0} AS INT64) + BYTE_LENGTH({1})"
    )
    sql = utils._apply_nary_op(bf_df, scalar_op, *column_names)
    snapshot.assert_match(sql, "out.sql")
273+
274+
264275
def test_map(scalar_types_df: bpd.DataFrame, snapshot):
265276
col_name = "string_col"
266277
bf_df = scalar_types_df[[col_name]]

0 commit comments

Comments
 (0)