Skip to content

Commit 92c69aa

Browse files
feat: Add BigFrames.bigquery.st_regionstats method
This commit adds the `BigFrames.bigquery.st_regionstats` method, which allows users to compute statistics for a raster band within a given geography. The implementation includes: - A new `StRegionStatsOp` in `bigframes/operations/geo_ops.py`. - Compiler implementations for both the SQLGlot and Ibis backends. - A unit test with a SQL snapshot. - A system test that demonstrates the use of the new function. This commit also addresses feedback from the code review, including: - Adding `area` to the output struct of `st_regionstats`. - Making the `options` parameter a positional argument. - Adding comments to explain the use of `pass_op`. - Converting the unit test to a pytest-style function. - Moving the sample code to a system test.
1 parent ab46078 commit 92c69aa

File tree

8 files changed

+102
-91
lines changed

8 files changed

+102
-91
lines changed

bigframes/bigquery/_operations/geo.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -683,7 +683,6 @@ def st_regionstats(
683683
geography: bigframes.geopandas.GeoSeries,
684684
raster: bigframes.series.Series,
685685
band: str,
686-
*,
687686
options: Mapping[str, Union[str, int, float]] = {},
688687
) -> bigframes.dataframe.DataFrame:
689688
"""Computes statistics for a raster band within a given geography.

bigframes/core/compile/ibis_compiler/scalar_op_compiler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,9 @@ def register_ternary_op(
169169
Args:
170170
op_ref (TernaryOp or TernaryOp type):
171171
Class or instance of operator that is implemented by the decorated function.
172+
pass_op (bool):
173+
Set to true if implementation takes the operator object as the last argument.
174+
This is needed for parameterized ops where parameters are part of op object.
172175
"""
173176
key = typing.cast(str, op_ref.name)
174177

@@ -296,5 +299,7 @@ def st_regionstats(
296299
if op.options:
297300
args.append(bigframes_vendored.ibis.literal(op.options, type="json"))
298301
return bigframes_vendored.ibis.remote_function(
299-
"st_regionstats", args, output_type="struct<min: float, max: float, sum: float, count: int, mean: float>" # type: ignore
302+
"st_regionstats",
303+
args,
304+
output_type="struct<min: float, max: float, sum: float, count: int, mean: float, area: float>", # type: ignore
300305
)

bigframes/core/compile/sqlglot/scalar_compiler.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -201,11 +201,14 @@ def compile_st_regionstats(
201201
args = [geography.expr, raster.expr, band.expr]
202202
if op.options:
203203
args.append(
204-
sge.EQ(
205-
this=sge.Identifier(this="OPTIONS"),
206-
expression=sge.Anonymous(
207-
this="JSON", expressions=[sge.convert(op.options)]
208-
),
204+
sge.Anonymous(
205+
this="_",
206+
expressions=[
207+
sge.Identifier(this="OPTIONS"),
208+
sge.Anonymous(
209+
this="JSON", expressions=[sge.convert(op.options)]
210+
),
211+
],
209212
)
210213
)
211214
return sge.func("ST_REGIONSTATS", *args)

bigframes/operations/geo_ops.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
142142
("sum", dtypes.FLOAT_DTYPE),
143143
("count", dtypes.INT_DTYPE),
144144
("mean", dtypes.FLOAT_DTYPE),
145+
("area", dtypes.FLOAT_DTYPE),
145146
]
146147
)
147148

samples/snippets/wildfire_risk.py

Lines changed: 0 additions & 75 deletions
This file was deleted.
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Inspired by the SQL at https://cloud.google.com/blog/products/data-analytics/a-closer-look-at-earth-engine-in-bigquery
16+
17+
import bigframes.bigquery as bbq
18+
import bigframes.pandas as bpd
19+
20+
21+
def test_wildfire_risk(session):
22+
# Step 1: Select inputs from datasets that we've subscribed to
23+
wildfire_raster = bpd.read_gbq("wildfire_risk_to_community_v0_mosaic.fire")[
24+
"assets.image.href"
25+
]
26+
places = bpd.read_gbq("bigquery-public-data.geo_us_census_places.places_colorado")[
27+
["place_id", "place_name", "place_geom"]
28+
]
29+
places = places.rename(columns={"place_geom": "geo"})
30+
31+
# Step 2: Compute the weather forecast using WeatherNext Graph forecast data
32+
weather_forecast = bpd.read_gbq("weathernext_graph_forecasts.59572747_4_0")
33+
weather_forecast = weather_forecast[
34+
weather_forecast["init_time"] == "2025-04-28 00:00:00+00:00"
35+
]
36+
weather_forecast = weather_forecast.explode("forecast")
37+
wind_speed = (
38+
weather_forecast["forecast"]["10m_u_component_of_wind"] ** 2
39+
+ weather_forecast["forecast"]["10m_v_component_of_wind"] ** 2
40+
) ** 0.5
41+
weather_forecast = weather_forecast.assign(wind_speed=wind_speed)
42+
weather_forecast = weather_forecast[weather_forecast["forecast"]["hours"] < 24]
43+
weather_forecast = weather_forecast.merge(
44+
places, how="inner", left_on="geography_polygon", right_on="geo"
45+
)
46+
weather_forecast = weather_forecast.groupby("place_id").agg(
47+
place_name=("place_name", "first"),
48+
geo=("geo", "first"),
49+
average_wind_speed=("wind_speed", "mean"),
50+
maximum_wind_speed=("wind_speed", "max"),
51+
)
52+
53+
# Step 3: Combine with wildfire risk for each community
54+
wildfire_risk = weather_forecast.assign(
55+
wildfire_likelihood=bbq.st_regionstats(
56+
weather_forecast["geo"],
57+
wildfire_raster,
58+
"BP",
59+
options={"scale": 1000},
60+
)["mean"],
61+
wildfire_consequence=bbq.st_regionstats(
62+
weather_forecast["geo"],
63+
wildfire_raster,
64+
"CRPS",
65+
options={"scale": 1000},
66+
)["mean"],
67+
)
68+
69+
# Step 4: Compute a simple composite index of relative wildfire risk.
70+
relative_risk = (
71+
wildfire_risk["wildfire_likelihood"].rank(pct=True)
72+
+ wildfire_risk["wildfire_consequence"].rank(pct=True)
73+
+ wildfire_risk["average_wind_speed"].rank(pct=True)
74+
) / 3 * 100
75+
wildfire_risk = wildfire_risk.assign(relative_risk=relative_risk)
76+
assert wildfire_risk is not None

tests/unit/core/compile/sqlglot/snapshots/test_geo_compiler/test_st_regionstats/out.sql

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ WITH `bfcte_1` AS (
2121
), `bfcte_4` AS (
2222
SELECT
2323
*,
24-
ST_REGIONSTATS(`bfcol_1`, `bfcol_5`, 'band1', OPTIONS = JSON('{"scale": 100}')) AS `bfcol_8`
24+
ST_REGIONSTATS(`bfcol_1`, `bfcol_5`, 'band1', _(OPTIONS, JSON('{"scale": 100}'))) AS `bfcol_8`
2525
FROM `bfcte_3`
2626
), `bfcte_5` AS (
2727
SELECT
@@ -30,13 +30,15 @@ WITH `bfcte_1` AS (
3030
`bfcol_8`.`max` AS `bfcol_11`,
3131
`bfcol_8`.`sum` AS `bfcol_12`,
3232
`bfcol_8`.`count` AS `bfcol_13`,
33-
`bfcol_8`.`mean` AS `bfcol_14`
33+
`bfcol_8`.`mean` AS `bfcol_14`,
34+
`bfcol_8`.`area` AS `bfcol_15`
3435
FROM `bfcte_4`
3536
)
3637
SELECT
3738
`bfcol_10` AS `min`,
3839
`bfcol_11` AS `max`,
3940
`bfcol_12` AS `sum`,
4041
`bfcol_13` AS `count`,
41-
`bfcol_14` AS `mean`
42+
`bfcol_14` AS `mean`,
43+
`bfcol_15` AS `area`
4244
FROM `bfcte_5`

tests/unit/core/compile/sqlglot/test_geo_compiler.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
pytest.importorskip("pytest_snapshot")
2222

2323

24-
class TestGeoCompiler:
25-
def test_st_regionstats(self, compiler_session, snapshot):
26-
geos = gpd.GeoSeries(["POINT(1 1)"], session=compiler_session)
27-
rasters = bpd.Series(["raster_uri"], dtype="string", session=compiler_session)
28-
df = bbq.st_regionstats(geos, rasters, "band1", options={"scale": 100})
29-
snapshot.assert_match(df.sql, "out.sql")
24+
def test_st_regionstats(compiler_session, snapshot):
25+
geos = gpd.GeoSeries(["POINT(1 1)"], session=compiler_session)
26+
rasters = bpd.Series(["raster_uri"], dtype="string", session=compiler_session)
27+
df = bbq.st_regionstats(geos, rasters, "band1", {"scale": 100})
28+
assert "area" in df.columns
29+
snapshot.assert_match(df.sql, "out.sql")

0 commit comments

Comments
 (0)