Skip to content

Commit b4e31ef

Browse files
committed
fix struct options
1 parent b355e47 commit b4e31ef

File tree

12 files changed

+141
-76
lines changed

12 files changed

+141
-76
lines changed

bigframes/bigquery/_operations/ml.py

Lines changed: 62 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,9 @@ def evaluate(
151151
model: Union[bigframes.ml.base.BaseEstimator, str],
152152
input_: Optional[Union[dataframe.DataFrame, str]] = None,
153153
*,
154-
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
154+
perform_aggregation: Optional[bool] = None,
155+
horizon: Optional[int] = None,
156+
confidence_level: Optional[float] = None,
155157
) -> dataframe.DataFrame:
156158
"""
157159
Evaluates a BigQuery ML model.
@@ -166,8 +168,22 @@ def evaluate(
166168
input_ (Union[bigframes.pandas.DataFrame, str], optional):
167169
The DataFrame or query to use for evaluation. If not provided, the
168170
evaluation data from training is used.
169-
options (Mapping[str, Union[str, int, float, bool, list]], optional):
170-
The OPTIONS clause, which specifies the model options.
171+
perform_aggregation (bool, optional):
172+
A BOOL value that indicates the level of evaluation for forecasting
173+
accuracy. If you specify TRUE, then the forecasting accuracy is on
174+
the time series level. If you specify FALSE, the forecasting
175+
accuracy is on the timestamp level. The default value is TRUE.
176+
horizon (int, optional):
177+
An INT64 value that specifies the number of forecasted time points
178+
against which the evaluation metrics are computed. The default value
179+
is the horizon value specified in the CREATE MODEL statement for the
180+
time series model, or 1000 if unspecified. When evaluating multiple
181+
time series at the same time, this parameter applies to each time
182+
series.
183+
confidence_level (float, optional):
184+
A FLOAT64 value that specifies the percentage of the future values
185+
that fall in the prediction interval. The default value is 0.95. The
186+
valid input range is ``[0, 1)``.
171187
172188
Returns:
173189
bigframes.pandas.DataFrame:
@@ -179,7 +195,9 @@ def evaluate(
179195
sql = bigframes.core.sql.ml.evaluate(
180196
model_name=model_name,
181197
table=table_sql,
182-
options=options,
198+
perform_aggregation=perform_aggregation,
199+
horizon=horizon,
200+
confidence_level=confidence_level,
183201
)
184202

185203
return session.read_gbq(sql)
@@ -190,7 +208,9 @@ def predict(
190208
model: Union[bigframes.ml.base.BaseEstimator, str],
191209
input_: Union[dataframe.DataFrame, str],
192210
*,
193-
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
211+
threshold: Optional[float] = None,
212+
keep_original_columns: Optional[bool] = None,
213+
trial_id: Optional[int] = None,
194214
) -> dataframe.DataFrame:
195215
"""
196216
Runs prediction on a BigQuery ML model.
@@ -204,8 +224,15 @@ def predict(
204224
The model to use for prediction.
205225
input_ (Union[bigframes.pandas.DataFrame, str]):
206226
The DataFrame or query to use for prediction.
207-
options (Mapping[str, Union[str, int, float, bool, list]], optional):
208-
The OPTIONS clause, which specifies the model options.
227+
threshold (float, optional):
228+
The threshold to use for classification models.
229+
keep_original_columns (bool, optional):
230+
Whether to keep the original columns in the output.
231+
trial_id (int, optional):
232+
An INT64 value that identifies the hyperparameter tuning trial that
233+
you want the function to evaluate. The function uses the optimal
234+
trial by default. Only specify this argument if you ran
235+
hyperparameter tuning when creating the model.
209236
210237
Returns:
211238
bigframes.pandas.DataFrame:
@@ -217,7 +244,9 @@ def predict(
217244
sql = bigframes.core.sql.ml.predict(
218245
model_name=model_name,
219246
table=table_sql,
220-
options=options,
247+
threshold=threshold,
248+
keep_original_columns=keep_original_columns,
249+
trial_id=trial_id,
221250
)
222251

223252
return session.read_gbq(sql)
@@ -228,7 +257,10 @@ def explain_predict(
228257
model: Union[bigframes.ml.base.BaseEstimator, str],
229258
input_: Union[dataframe.DataFrame, str],
230259
*,
231-
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
260+
top_k_features: Optional[int] = None,
261+
threshold: Optional[float] = None,
262+
integrated_gradients_num_steps: Optional[int] = None,
263+
approx_feature_contrib: Optional[bool] = None,
232264
) -> dataframe.DataFrame:
233265
"""
234266
Runs explainable prediction on a BigQuery ML model.
@@ -242,8 +274,19 @@ def explain_predict(
242274
The model to use for prediction.
243275
input_ (Union[bigframes.pandas.DataFrame, str]):
244276
The DataFrame or query to use for prediction.
245-
options (Mapping[str, Union[str, int, float, bool, list]], optional):
246-
The OPTIONS clause, which specifies the model options.
277+
top_k_features (int, optional):
278+
The number of top features to return.
279+
threshold (float, optional):
280+
The threshold for binary classification models.
281+
integrated_gradients_num_steps (int, optional):
282+
an INT64 value that specifies the number of steps to sample between
283+
the example being explained and its baseline. This value is used to
284+
approximate the integral in integrated gradients attribution
285+
methods. Increasing the value improves the precision of feature
286+
attributions, but can be slower and more computationally expensive.
287+
approx_feature_contrib (bool, optional):
288+
A BOOL value that indicates whether to use an approximate feature
289+
contribution method in the XGBoost model explanation.
247290
248291
Returns:
249292
bigframes.pandas.DataFrame:
@@ -255,7 +298,10 @@ def explain_predict(
255298
sql = bigframes.core.sql.ml.explain_predict(
256299
model_name=model_name,
257300
table=table_sql,
258-
options=options,
301+
top_k_features=top_k_features,
302+
threshold=threshold,
303+
integrated_gradients_num_steps=integrated_gradients_num_steps,
304+
approx_feature_contrib=approx_feature_contrib,
259305
)
260306

261307
return session.read_gbq(sql)
@@ -265,7 +311,7 @@ def explain_predict(
265311
def global_explain(
266312
model: Union[bigframes.ml.base.BaseEstimator, str],
267313
*,
268-
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
314+
class_level_explain: Optional[bool] = None,
269315
) -> dataframe.DataFrame:
270316
"""
271317
Gets global explanations for a BigQuery ML model.
@@ -277,8 +323,8 @@ def global_explain(
277323
Args:
278324
model (bigframes.ml.base.BaseEstimator or str):
279325
The model to get explanations from.
280-
options (Mapping[str, Union[str, int, float, bool, list]], optional):
281-
The OPTIONS clause, which specifies the model options.
326+
class_level_explain (bool, optional):
327+
Whether to return class-level explanations.
282328
283329
Returns:
284330
bigframes.pandas.DataFrame:
@@ -287,7 +333,7 @@ def global_explain(
287333
model_name, session = _get_model_name_and_session(model)
288334
sql = bigframes.core.sql.ml.global_explain(
289335
model_name=model_name,
290-
options=options,
336+
class_level_explain=class_level_explain,
291337
)
292338

293339
return session.read_gbq(sql)

bigframes/core/sql/ml.py

Lines changed: 64 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from __future__ import annotations
1616

17-
from typing import Mapping, Optional, Union
17+
from typing import Dict, Mapping, Optional, Union
1818

1919
import bigframes.core.compile.googlesql as googlesql
2020
import bigframes.core.sql
@@ -94,33 +94,48 @@ def create_model_ddl(
9494
ddl += f"AS (\n {', '.join(parts)}\n)"
9595
else:
9696
# Just training_data is treated as the query_statement
97-
ddl += f"AS {training_data}"
97+
ddl += f"AS {training_data}\n"
9898

9999
return ddl
100100

101101

102+
def _build_struct_sql(
103+
struct_options: Mapping[str, Union[str, int, float, bool]]
104+
) -> str:
105+
if not struct_options:
106+
return ""
107+
108+
rendered_options = []
109+
for option_name, option_value in struct_options.items():
110+
rendered_val = bigframes.core.sql.simple_literal(option_value)
111+
rendered_options.append(f"{rendered_val} AS {option_name}")
112+
return f", STRUCT({', '.join(rendered_options)})"
113+
114+
102115
def evaluate(
103116
model_name: str,
104117
*,
105118
table: Optional[str] = None,
106-
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
119+
perform_aggregation: Optional[bool] = None,
120+
horizon: Optional[int] = None,
121+
confidence_level: Optional[float] = None,
107122
) -> str:
108-
"""Encode the ML.EVALUATE statement.
109-
123+
"""Encode the ML.EVAluate statement.
110124
See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-evaluate for reference.
111125
"""
126+
struct_options: Dict[str, Union[str, int, float, bool]] = {}
127+
if perform_aggregation is not None:
128+
struct_options["perform_aggregation"] = perform_aggregation
129+
if horizon is not None:
130+
struct_options["horizon"] = horizon
131+
if confidence_level is not None:
132+
struct_options["confidence_level"] = confidence_level
133+
112134
sql = f"SELECT * FROM ML.EVALUATE(MODEL {googlesql.identifier(model_name)}"
113135
if table:
114136
sql += f", ({table})"
115-
if options:
116-
rendered_options = []
117-
for option_name, option_value in options.items():
118-
if isinstance(option_value, (list, tuple)):
119-
rendered_val = bigframes.core.sql.simple_literal(list(option_value))
120-
else:
121-
rendered_val = bigframes.core.sql.simple_literal(option_value)
122-
rendered_options.append(f"{option_name} = {rendered_val}")
123-
sql += f", OPTIONS({', '.join(rendered_options)})"
137+
138+
sql += _build_struct_sql(struct_options)
124139
sql += ")\n"
125140
return sql
126141

@@ -129,24 +144,25 @@ def predict(
129144
model_name: str,
130145
table: str,
131146
*,
132-
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
147+
threshold: Optional[float] = None,
148+
keep_original_columns: Optional[bool] = None,
149+
trial_id: Optional[int] = None,
133150
) -> str:
134151
"""Encode the ML.PREDICT statement.
135-
136152
See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-predict for reference.
137153
"""
154+
struct_options = {}
155+
if threshold is not None:
156+
struct_options["threshold"] = threshold
157+
if keep_original_columns is not None:
158+
struct_options["keep_original_columns"] = keep_original_columns
159+
if trial_id is not None:
160+
struct_options["trial_id"] = trial_id
161+
138162
sql = (
139163
f"SELECT * FROM ML.PREDICT(MODEL {googlesql.identifier(model_name)}, ({table})"
140164
)
141-
if options:
142-
rendered_options = []
143-
for option_name, option_value in options.items():
144-
if isinstance(option_value, (list, tuple)):
145-
rendered_val = bigframes.core.sql.simple_literal(list(option_value))
146-
else:
147-
rendered_val = bigframes.core.sql.simple_literal(option_value)
148-
rendered_options.append(f"{option_name} = {rendered_val}")
149-
sql += f", OPTIONS({', '.join(rendered_options)})"
165+
sql += _build_struct_sql(struct_options)
150166
sql += ")\n"
151167
return sql
152168

@@ -155,44 +171,45 @@ def explain_predict(
155171
model_name: str,
156172
table: str,
157173
*,
158-
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
174+
top_k_features: Optional[int] = None,
175+
threshold: Optional[float] = None,
176+
integrated_gradients_num_steps: Optional[int] = None,
177+
approx_feature_contrib: Optional[bool] = None,
159178
) -> str:
160179
"""Encode the ML.EXPLAIN_PREDICT statement.
161-
162180
See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-explain-predict for reference.
163181
"""
182+
struct_options: Dict[str, Union[str, int, float, bool]] = {}
183+
if top_k_features is not None:
184+
struct_options["top_k_features"] = top_k_features
185+
if threshold is not None:
186+
struct_options["threshold"] = threshold
187+
if integrated_gradients_num_steps is not None:
188+
struct_options[
189+
"integrated_gradients_num_steps"
190+
] = integrated_gradients_num_steps
191+
if approx_feature_contrib is not None:
192+
struct_options["approx_feature_contrib"] = approx_feature_contrib
193+
164194
sql = f"SELECT * FROM ML.EXPLAIN_PREDICT(MODEL {googlesql.identifier(model_name)}, ({table})"
165-
if options:
166-
rendered_options = []
167-
for option_name, option_value in options.items():
168-
if isinstance(option_value, (list, tuple)):
169-
rendered_val = bigframes.core.sql.simple_literal(list(option_value))
170-
else:
171-
rendered_val = bigframes.core.sql.simple_literal(option_value)
172-
rendered_options.append(f"{option_name} = {rendered_val}")
173-
sql += f", OPTIONS({', '.join(rendered_options)})"
195+
sql += _build_struct_sql(struct_options)
174196
sql += ")\n"
175197
return sql
176198

177199

178200
def global_explain(
179201
model_name: str,
180202
*,
181-
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
203+
class_level_explain: Optional[bool] = None,
182204
) -> str:
183205
"""Encode the ML.GLOBAL_EXPLAIN statement.
184-
185206
See https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-global-explain for reference.
186207
"""
208+
struct_options = {}
209+
if class_level_explain is not None:
210+
struct_options["class_level_explain"] = class_level_explain
211+
187212
sql = f"SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL {googlesql.identifier(model_name)}"
188-
if options:
189-
rendered_options = []
190-
for option_name, option_value in options.items():
191-
if isinstance(option_value, (list, tuple)):
192-
rendered_val = bigframes.core.sql.simple_literal(list(option_value))
193-
else:
194-
rendered_val = bigframes.core.sql.simple_literal(option_value)
195-
rendered_options.append(f"{option_name} = {rendered_val}")
196-
sql += f", OPTIONS({', '.join(rendered_options)})"
213+
sql += _build_struct_sql(struct_options)
197214
sql += ")\n"
198215
return sql
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
CREATE MODEL `my_project.my_dataset.my_model`
22
OPTIONS(model_type = 'LINEAR_REG', input_label_cols = ['label'])
3-
AS SELECT * FROM my_table
3+
AS SELECT * FROM my_table
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
CREATE MODEL IF NOT EXISTS `my_model`
22
OPTIONS(model_type = 'KMEANS')
3-
AS SELECT * FROM t
3+
AS SELECT * FROM t
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
CREATE MODEL `my_model`
22
OPTIONS(hidden_units = [32, 16], dropout = 0.2)
3-
AS SELECT * FROM t
3+
AS SELECT * FROM t
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
CREATE OR REPLACE MODEL `my_model`
22
OPTIONS(model_type = 'LOGISTIC_REG')
3-
AS SELECT * FROM t
3+
AS SELECT * FROM t
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
CREATE MODEL `my_model`
22
TRANSFORM (ML.STANDARD_SCALER(c1) OVER() AS c1_scaled, c2)
33
OPTIONS(model_type = 'LINEAR_REG')
4-
AS SELECT c1, c2, label FROM t
4+
AS SELECT c1, c2, label FROM t
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SELECT * FROM ML.EVALUATE(MODEL `my_model`, OPTIONS(threshold = 0.5))
1+
SELECT * FROM ML.EVALUATE(MODEL `my_model`, STRUCT(False AS perform_aggregation, 10 AS horizon, 0.95 AS confidence_level))
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `my_model`, (SELECT * FROM new_data), OPTIONS(top_k_features = 5))
1+
SELECT * FROM ML.EXPLAIN_PREDICT(MODEL `my_model`, (SELECT * FROM new_data), STRUCT(5 AS top_k_features))
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, OPTIONS(num_features = 10))
1+
SELECT * FROM ML.GLOBAL_EXPLAIN(MODEL `my_model`, STRUCT(True AS class_level_explain))

0 commit comments

Comments
 (0)