Skip to content

Commit 4f0b7bd

Browse files
committed
fix struct field op
1 parent 42ec975 commit 4f0b7bd

File tree

3 files changed

+35
-8
lines changed

3 files changed

+35
-8
lines changed

bigframes/core/compile/sqlglot/expressions/unary_compiler.py

Lines changed: 13 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -16,6 +16,8 @@
1616

1717
import typing
1818

19+
import pandas as pd
20+
import pyarrow as pa
1921
import sqlglot
2022
import sqlglot.expressions as sge
2123

@@ -561,7 +563,17 @@ def _(op: ops.StrftimeOp, expr: TypedExpr) -> sge.Expression:
561563

562564
@UNARY_OP_REGISTRATION.register(ops.StructFieldOp)
563565
def _(op: ops.StructFieldOp, expr: TypedExpr) -> sge.Expression:
564-
return sge.StructExtract(this=expr.expr, expression=sge.convert(op.name_or_index))
566+
if isinstance(op.name_or_index, str):
567+
name = op.name_or_index
568+
else:
569+
pa_type = typing.cast(pd.ArrowDtype, expr.dtype)
570+
pa_struct_type = typing.cast(pa.StructType, pa_type.pyarrow_dtype)
571+
name = pa_struct_type.field(op.name_or_index).name
572+
573+
return sge.Column(
574+
this=sge.to_identifier(name, quoted=True),
575+
catalog=expr.expr,
576+
)
565577

566578

567579
@UNARY_OP_REGISTRATION.register(ops.tan_op)
Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,13 @@
1+
WITH `bfcte_0` AS (
2+
SELECT
3+
`people` AS `bfcol_0`
4+
FROM `bigframes-dev`.`sqlglot_test`.`nested_structs_types`
5+
), `bfcte_1` AS (
6+
SELECT
7+
*,
8+
`bfcol_0`.`name` AS `bfcol_1`
9+
FROM `bfcte_0`
10+
)
11+
SELECT
12+
`bfcol_1` AS `people`
13+
FROM `bfcte_1`

tests/unit/core/compile/sqlglot/expressions/test_unary_compiler.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -468,14 +468,16 @@ def test_strftime(scalar_types_df: bpd.DataFrame, snapshot):
468468
snapshot.assert_match(sql, "out.sql")
469469

470470

471-
def test_struct_field(scalar_types_df: bpd.DataFrame, snapshot):
472-
bf_df = scalar_types_df[["int64_col"]]
473-
bf_df["struct_col"] = bpd.DataFrame(
474-
{"field1": bf_df["int64_col"], "field2": bf_df["int64_col"] * 2}
475-
).to_struct()
476-
sql = _apply_unary_op(bf_df, ops.StructFieldOp("field1"), "struct_col")
471+
def test_struct_field(nested_structs_types_df: bpd.DataFrame, snapshot):
472+
bf_df = nested_structs_types_df[["people"]]
473+
474+
# When a name string is provided.
475+
sql = _apply_unary_op(bf_df, ops.StructFieldOp("name"), "people")
476+
snapshot.assert_match(sql, "out.sql")
477477

478-
snapshot.assert_match(sql, "out_sql")
478+
# When an index integer is provided.
479+
sql = _apply_unary_op(bf_df, ops.StructFieldOp(0), "people")
480+
snapshot.assert_match(sql, "out.sql")
479481

480482

481483
def test_str_contains(scalar_types_df: bpd.DataFrame, snapshot):

0 commit comments

Comments
 (0)