diff --git a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py index 3d527f2a2f..98f1603be7 100644 --- a/bigframes/core/compile/sqlglot/expressions/unary_compiler.py +++ b/bigframes/core/compile/sqlglot/expressions/unary_compiler.py @@ -14,6 +14,7 @@ from __future__ import annotations +import functools import typing import pandas as pd @@ -292,6 +293,18 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="DAYOFYEAR"), expression=expr.expr) +@UNARY_OP_REGISTRATION.register(ops.EndsWithOp) +def _(op: ops.EndsWithOp, expr: TypedExpr) -> sge.Expression: + if not op.pat: + return sge.false() + + def to_endswith(pat: str) -> sge.Expression: + return sge.func("ENDS_WITH", expr.expr, sge.convert(pat)) + + conditions = [to_endswith(pat) for pat in op.pat] + return functools.reduce(lambda x, y: sge.Or(this=x, expression=y), conditions) + + @UNARY_OP_REGISTRATION.register(ops.exp_op) def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.Case( @@ -633,6 +646,18 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) +@UNARY_OP_REGISTRATION.register(ops.StartsWithOp) +def _(op: ops.StartsWithOp, expr: TypedExpr) -> sge.Expression: + if not op.pat: + return sge.false() + + def to_startswith(pat: str) -> sge.Expression: + return sge.func("STARTS_WITH", expr.expr, sge.convert(pat)) + + conditions = [to_startswith(pat) for pat in op.pat] + return functools.reduce(lambda x, y: sge.Or(this=x, expression=y), conditions) + + @UNARY_OP_REGISTRATION.register(ops.StrStripOp) def _(op: ops.StrStripOp, expr: TypedExpr) -> sge.Expression: return sge.Trim(this=sge.convert(op.to_strip), expression=expr.expr) @@ -656,6 +681,11 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: ) +@UNARY_OP_REGISTRATION.register(ops.StringSplitOp) +def _(op: ops.StringSplitOp, expr: TypedExpr) -> sge.Expression: + return sge.Split(this=expr.expr, expression=sge.convert(op.pat)) + + @UNARY_OP_REGISTRATION.register(ops.StrGetOp) def _(op: ops.StrGetOp, expr: TypedExpr) -> sge.Expression: return sge.Substring( @@ -808,3 +838,31 @@ def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: @UNARY_OP_REGISTRATION.register(ops.year_op) def _(op: ops.base_ops.UnaryOp, expr: TypedExpr) -> sge.Expression: return sge.Extract(this=sge.Identifier(this="YEAR"), expression=expr.expr) + + +@UNARY_OP_REGISTRATION.register(ops.ZfillOp) +def _(op: ops.ZfillOp, expr: TypedExpr) -> sge.Expression: + return sge.Case( + ifs=[ + sge.If( + this=sge.EQ( + this=sge.Substring( + this=expr.expr, start=sge.convert(1), length=sge.convert(1) + ), + expression=sge.convert("-"), + ), + true=sge.Concat( + expressions=[ + sge.convert("-"), + sge.func( + "LPAD", + sge.Substring(this=expr.expr, start=sge.convert(1)), + sge.convert(op.width - 1), + sge.convert("0"), + ), + ] + ), + ) + ], + default=sge.func("LPAD", expr.expr, sge.convert(op.width), sge.convert("0")), + ) diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/multiple_patterns.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/multiple_patterns.sql new file mode 100644 index 0000000000..f224471e79 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/multiple_patterns.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + ENDS_WITH(`bfcol_0`, 'ab') OR ENDS_WITH(`bfcol_0`, 'cd') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/no_pattern.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/no_pattern.sql new file mode 100644 index 0000000000..e9f61ddd7c --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/no_pattern.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + FALSE AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/single_pattern.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/single_pattern.sql new file mode 100644 index 0000000000..a4e259f0b2 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_endswith/single_pattern.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + ENDS_WITH(`bfcol_0`, 'ab') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/multiple_patterns.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/multiple_patterns.sql new file mode 100644 index 0000000000..061b57e208 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/multiple_patterns.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + STARTS_WITH(`bfcol_0`, 'ab') OR STARTS_WITH(`bfcol_0`, 'cd') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/no_pattern.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/no_pattern.sql new file mode 100644 index 0000000000..e9f61ddd7c --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/no_pattern.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + FALSE AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/single_pattern.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/single_pattern.sql new file mode 100644 index 0000000000..726ce05b8c --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_startswith/single_pattern.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + STARTS_WITH(`bfcol_0`, 'ab') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_string_split/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_string_split/out.sql new file mode 100644 index 0000000000..fea0d6eaf1 --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_string_split/out.sql @@ -0,0 +1,13 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + SPLIT(`bfcol_0`, ',') AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_zfill/out.sql b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_zfill/out.sql new file mode 100644 index 0000000000..e5d70ab44b --- /dev/null +++ b/tests/unit/core/compile/sqlglot/expressions/snapshots/test_unary_compiler/test_zfill/out.sql @@ -0,0 +1,17 @@ +WITH `bfcte_0` AS ( + SELECT + `string_col` AS `bfcol_0` + FROM `bigframes-dev`.`sqlglot_test`.`scalar_types` +), `bfcte_1` AS ( + SELECT + *, + CASE + WHEN SUBSTRING(`bfcol_0`, 1, 1) = '-' + THEN CONCAT('-', LPAD(SUBSTRING(`bfcol_0`, 1), 9, '0')) + ELSE LPAD(`bfcol_0`, 10, '0') + END AS `bfcol_1` + FROM `bfcte_0` +) +SELECT + `bfcol_1` AS `string_col` +FROM `bfcte_1` \ No newline at end of file diff --git a/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py b/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py index 2a3297a46c..f011721ee5 100644 --- a/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py +++ b/tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py @@ -125,6 +125,18 @@ def test_dayofyear(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_endswith(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["string_col"]] + sql = _apply_unary_op(bf_df, ops.EndsWithOp(pat=("ab",)), "string_col") + snapshot.assert_match(sql, "single_pattern.sql") + + sql = _apply_unary_op(bf_df, ops.EndsWithOp(pat=("ab", "cd")), "string_col") + snapshot.assert_match(sql, "multiple_patterns.sql") + + sql = _apply_unary_op(bf_df, ops.EndsWithOp(pat=()), "string_col") + snapshot.assert_match(sql, "no_pattern.sql") + + def test_exp(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["float64_col"]] sql = _apply_unary_op(bf_df, ops.exp_op, "float64_col") @@ -501,6 +513,18 @@ def test_sqrt(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_startswith(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["string_col"]] + sql = _apply_unary_op(bf_df, ops.StartsWithOp(pat=("ab",)), "string_col") + snapshot.assert_match(sql, "single_pattern.sql") + + sql = _apply_unary_op(bf_df, ops.StartsWithOp(pat=("ab", "cd")), "string_col") + snapshot.assert_match(sql, "multiple_patterns.sql") + + sql = _apply_unary_op(bf_df, ops.StartsWithOp(pat=()), "string_col") + snapshot.assert_match(sql, "no_pattern.sql") + + def test_str_get(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["string_col"]] sql = _apply_unary_op(bf_df, ops.StrGetOp(1), "string_col") @@ -650,6 +674,12 @@ def test_sinh(scalar_types_df: bpd.DataFrame, snapshot): snapshot.assert_match(sql, "out.sql") +def test_string_split(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["string_col"]] + sql = _apply_unary_op(bf_df, ops.StringSplitOp(pat=","), "string_col") + snapshot.assert_match(sql, "out.sql") + + def test_tan(scalar_types_df: bpd.DataFrame, snapshot): bf_df = scalar_types_df[["float64_col"]] sql = _apply_unary_op(bf_df, ops.tan_op, "float64_col") @@ -790,3 +820,9 @@ def test_year(scalar_types_df: bpd.DataFrame, snapshot): sql = _apply_unary_op(bf_df, ops.year_op, "timestamp_col") snapshot.assert_match(sql, "out.sql") + + +def test_zfill(scalar_types_df: bpd.DataFrame, snapshot): + bf_df = scalar_types_df[["string_col"]] + sql = _apply_unary_op(bf_df, ops.ZfillOp(width=10), "string_col") + snapshot.assert_match(sql, "out.sql")