Skip to content

Commit c28018c

Browse files
committed
fix struct field op
1 parent 42ec975 commit c28018c

File tree

3 files changed

+21
-8
lines changed

3 files changed

+21
-8
lines changed

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,10 @@ def _(op: ops.StrftimeOp, expr: TypedExpr) -> sge.Expression:
561561

562562
@UNARY_OP_REGISTRATION.register(ops.StructFieldOp)
563563
def _(op: ops.StructFieldOp, expr: TypedExpr) -> sge.Expression:
564-
return sge.StructExtract(this=expr.expr, expression=sge.convert(op.name_or_index))
564+
return sge.Column(
565+
this=sge.to_identifier(op.name_or_index, quoted=True),
566+
catalog=expr.expr,
567+
)
565568

566569

567570
@UNARY_OP_REGISTRATION.register(ops.tan_op)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`people` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`nested_structs_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
`bfcol_0`.`name` AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `people`
13+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -468,14 +468,11 @@ def test_strftime(scalar_types_df: bpd.DataFrame, snapshot):
468468
snapshot.assert_match(sql, "out.sql")
469469

470470

471-
def test_struct_field(scalar_types_df: bpd.DataFrame, snapshot):
472-
bf_df = scalar_types_df[["int64_col"]]
473-
bf_df["struct_col"] = bpd.DataFrame(
474-
{"field1": bf_df["int64_col"], "field2": bf_df["int64_col"] * 2}
475-
).to_struct()
476-
sql = _apply_unary_op(bf_df, ops.StructFieldOp("field1"), "struct_col")
471+
def test_struct_field(nested_structs_types_df: bpd.DataFrame, snapshot):
472+
bf_df = nested_structs_types_df[["people"]]
473+
sql = _apply_unary_op(bf_df, ops.StructFieldOp("name"), "people")
477474

478-
snapshot.assert_match(sql, "out_sql")
475+
snapshot.assert_match(sql, "out.sql")
479476

480477

481478
def test_str_contains(scalar_types_df: bpd.DataFrame, snapshot):

0 commit comments

Comments
 (0)