Skip to content

Commit 95ac49f

Browse files
committed
refactor: fix str.extract doctest
1 parent 69fa7f4 commit 95ac49f

File tree

3 files changed

+19
-9
lines changed

3 files changed

+19
-9
lines changed

bigframes/core/compile/sqlglot/expressions/string_ops.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,12 +48,14 @@ def _(expr: TypedExpr, op: ops.StrExtractOp) -> sge.Expression:
4848
# Cannot use BigQuery's REGEXP_EXTRACT function, which only allows one
4949
# capturing group.
5050
pat_expr = sge.convert(op.pat)
51-
if op.n != 0:
52-
pat_expr = sge.func("CONCAT", sge.convert(".*?"), pat_expr, sge.convert(".*"))
53-
else:
51+
if op.n == 0:
5452
pat_expr = sge.func("CONCAT", sge.convert(".*?("), pat_expr, sge.convert(").*"))
53+
n = 1
54+
else:
55+
pat_expr = sge.func("CONCAT", sge.convert(".*?"), pat_expr, sge.convert(".*"))
56+
n = op.n
5557

56-
rex_replace = sge.func("REGEXP_REPLACE", expr.expr, pat_expr, sge.convert(r"\1"))
58+
rex_replace = sge.func("REGEXP_REPLACE", expr.expr, pat_expr, sge.convert(f"\\{n}"))
5759
rex_contains = sge.func("REGEXP_CONTAINS", expr.expr, sge.convert(op.pat))
5860
return sge.If(this=rex_contains, true=rex_replace, false=sge.null())
5961

tests/unit/core/compile/sqlglot/expressions/snapshots/test_string_ops/test_str_extract/out.sql

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,19 @@ WITH `bfcte_0` AS (
55
), `bfcte_1` AS (
66
SELECT
77
*,
8+
IF(
9+
REGEXP_CONTAINS(`string_col`, '([a-z]*)'),
10+
REGEXP_REPLACE(`string_col`, CONCAT('.*?(', '([a-z]*)', ').*'), '\\1'),
11+
NULL
12+
) AS `bfcol_1`,
813
IF(
914
REGEXP_CONTAINS(`string_col`, '([a-z]*)'),
1015
REGEXP_REPLACE(`string_col`, CONCAT('.*?', '([a-z]*)', '.*'), '\\1'),
1116
NULL
12-
) AS `bfcol_1`
17+
) AS `bfcol_2`
1318
FROM `bfcte_0`
1419
)
1520
SELECT
16-
`bfcol_1` AS `string_col`
21+
`bfcol_1` AS `zero`,
22+
`bfcol_2` AS `one`
1723
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_string_ops.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -260,9 +260,11 @@ def test_str_contains_regex(scalar_types_df: bpd.DataFrame, snapshot):
260260
def test_str_extract(scalar_types_df: bpd.DataFrame, snapshot):
261261
col_name = "string_col"
262262
bf_df = scalar_types_df[[col_name]]
263-
sql = utils._apply_ops_to_sql(
264-
bf_df, [ops.StrExtractOp(r"([a-z]*)", 1).as_expr(col_name)], [col_name]
265-
)
263+
ops_map = {
264+
"zero": ops.StrExtractOp(r"([a-z]*)", 0).as_expr(col_name),
265+
"one": ops.StrExtractOp(r"([a-z]*)", 1).as_expr(col_name),
266+
}
267+
sql = utils._apply_ops_to_sql(bf_df, list(ops_map.values()), list(ops_map.keys()))
266268

267269
snapshot.assert_match(sql, "out.sql")
268270

0 commit comments

Comments
 (0)