Skip to content

Commit b44d520

Browse files
committed
fix the sqlgot compiler
1 parent 4252e90 commit b44d520

File tree

4 files changed

+47
-56
lines changed

4 files changed

+47
-56
lines changed

bigframes/bigquery/_operations/geo.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,7 @@ def st_regionstats(
685685
band: Optional[str] = None,
686686
include: Optional[str] = None,
687687
options: Optional[Mapping[str, Union[str, int, float]]] = None,
688-
) -> bigframes.dataframe.DataFrame:
688+
) -> bigframes.series.Series:
689689
"""Returns statistics summarizing the pixel values of the raster image
690690
referenced by raster_id that intersect with geography.
691691
@@ -726,14 +726,13 @@ def st_regionstats(
726726
documentation for a list of available options.
727727
728728
Returns:
729-
bigframes.dataframe.DataFrame:
730-
A dataframe containing the computed statistics.
729+
bigframes.pandas.Series:
730+
A STRUCT Series containing the computed statistics.
731731
"""
732732
op = ops.StRegionStatsOp(
733733
raster_id=raster_id,
734734
band=band,
735735
include=include,
736736
options=json.dumps(options) if options else None,
737737
)
738-
df = geography._apply_unary_op(op)
739-
return df[df.columns[0]].struct.explode()
738+
return geography._apply_unary_op(op)

bigframes/core/compile/sqlglot/expressions/geo_ops.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -89,15 +89,11 @@ def _(
8989
geography: TypedExpr,
9090
op: ops.StRegionStatsOp,
9191
):
92-
args = [geography.expr] # TODO: get raster, band, include from op.
92+
args = [geography.expr, sge.convert(op.raster_id)]
93+
if op.band:
94+
args.append(sge.Kwarg(this="band", expression=sge.convert(op.band)))
95+
if op.include:
96+
args.append(sge.Kwarg(this="include", expression=sge.convert(op.include)))
9397
if op.options:
94-
args.append(
95-
sge.Anonymous(
96-
this="_",
97-
expressions=[
98-
sge.Identifier(this="OPTIONS"),
99-
sge.Anonymous(this="JSON", expressions=[sge.convert(op.options)]),
100-
],
101-
)
102-
)
98+
args.append(sge.Kwarg(this="options", expression=sge.convert(op.options)))
10399
return sge.func("ST_REGIONSTATS", *args)
Lines changed: 28 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,36 @@
1-
WITH `bfcte_1` AS (
1+
WITH `bfcte_0` AS (
22
SELECT
33
*
4-
FROM UNNEST(ARRAY<STRUCT<`bfcol_0` INT64, `bfcol_1` STRING>>[STRUCT(0, 'POINT(1 1)')])
5-
), `bfcte_0` AS (
6-
SELECT
7-
*
8-
FROM UNNEST(ARRAY<STRUCT<`bfcol_2` INT64, `bfcol_3` STRING>>[STRUCT(0, 'raster_uri')])
9-
), `bfcte_2` AS (
10-
SELECT
11-
`bfcol_2` AS `bfcol_4`,
12-
`bfcol_3` AS `bfcol_5`
13-
FROM `bfcte_0`
14-
), `bfcte_3` AS (
15-
SELECT
16-
*
17-
FROM `bfcte_1`
18-
LEFT JOIN `bfcte_2`
19-
ON COALESCE(`bfcol_0`, 0) = COALESCE(`bfcol_4`, 0)
20-
AND COALESCE(`bfcol_0`, 1) = COALESCE(`bfcol_4`, 1)
21-
), `bfcte_4` AS (
4+
FROM UNNEST(ARRAY<STRUCT<`bfcol_0` STRING, `bfcol_1` INT64>>[STRUCT('POINT(1 1)', 0)])
5+
), `bfcte_1` AS (
226
SELECT
237
*,
24-
ST_REGIONSTATS(`bfcol_1`, `bfcol_5`, 'band1', _(OPTIONS, JSON('{"scale": 100}'))) AS `bfcol_8`
25-
FROM `bfcte_3`
26-
), `bfcte_5` AS (
8+
ST_REGIONSTATS(
9+
`bfcol_0`,
10+
'ee://some/raster/uri',
11+
band => 'band1',
12+
include => 'some equation',
13+
options => '{"scale": 100}'
14+
) AS `bfcol_2`
15+
FROM `bfcte_0`
16+
), `bfcte_2` AS (
2717
SELECT
2818
*,
29-
`bfcol_8`.`min` AS `bfcol_10`,
30-
`bfcol_8`.`max` AS `bfcol_11`,
31-
`bfcol_8`.`sum` AS `bfcol_12`,
32-
`bfcol_8`.`count` AS `bfcol_13`,
33-
`bfcol_8`.`mean` AS `bfcol_14`,
34-
`bfcol_8`.`area` AS `bfcol_15`
35-
FROM `bfcte_4`
19+
`bfcol_2`.`min` AS `bfcol_5`,
20+
`bfcol_2`.`max` AS `bfcol_6`,
21+
`bfcol_2`.`sum` AS `bfcol_7`,
22+
`bfcol_2`.`count` AS `bfcol_8`,
23+
`bfcol_2`.`mean` AS `bfcol_9`,
24+
`bfcol_2`.`area` AS `bfcol_10`
25+
FROM `bfcte_1`
3626
)
3727
SELECT
38-
`bfcol_10` AS `min`,
39-
`bfcol_11` AS `max`,
40-
`bfcol_12` AS `sum`,
41-
`bfcol_13` AS `count`,
42-
`bfcol_14` AS `mean`,
43-
`bfcol_15` AS `area`
44-
FROM `bfcte_5`
28+
`bfcol_5` AS `min`,
29+
`bfcol_6` AS `max`,
30+
`bfcol_7` AS `sum`,
31+
`bfcol_8` AS `count`,
32+
`bfcol_9` AS `mean`,
33+
`bfcol_10` AS `area`
34+
FROM `bfcte_2`
35+
ORDER BY
36+
`bfcol_1` ASC NULLS LAST

tests/unit/core/compile/sqlglot/test_geo_compiler.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,14 +16,18 @@
1616

1717
import bigframes.bigquery as bbq
1818
import bigframes.geopandas as gpd
19-
import bigframes.pandas as bpd
2019

2120
pytest.importorskip("pytest_snapshot")
2221

2322

2423
def test_st_regionstats(compiler_session, snapshot):
2524
geos = gpd.GeoSeries(["POINT(1 1)"], session=compiler_session)
26-
rasters = bpd.Series(["raster_uri"], dtype="string", session=compiler_session)
27-
df = bbq.st_regionstats(geos, rasters, "band1", {"scale": 100})
28-
assert "area" in df.columns
29-
snapshot.assert_match(df.sql, "out.sql")
25+
result = bbq.st_regionstats(
26+
geos,
27+
"ee://some/raster/uri",
28+
band="band1",
29+
include="some equation",
30+
options={"scale": 100},
31+
)
32+
assert "area" in result.struct.dtypes.index
33+
snapshot.assert_match(result.struct.explode().sql, "out.sql")

0 commit comments

Comments
 (0)