Skip to content

Commit fdee53f

Browse files
feat: Add BigQuery ML CREATE MODEL support
- Refactor `bigframes.core.sql` to a package.
- Add `bigframes.core.sql.ml` for DDL generation.
- Add `bigframes.bigquery.ml` module with `create_model` function.
- Add unit tests for SQL generation.
- Use `_start_query_ml_ddl` for execution.
- Return the created model object using `read_gbq_model`.
- Remove `query` argument, simplify SQL generation logic.
1 parent 54050d5 commit fdee53f

File tree

3 files changed

+28
-30
lines changed

3 files changed

+28
-30
lines changed

bigframes/bigquery/_operations/ml.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,16 @@
2323
import bigframes.ml.base
2424
import bigframes.session
2525

26+
27+
# Helper to convert DataFrame to SQL string
28+
def _to_sql(df_or_sql: Union[dataframe.DataFrame, str]) -> str:
29+
if isinstance(df_or_sql, str):
30+
return df_or_sql
31+
# It's a DataFrame
32+
sql, _, _ = df_or_sql._to_sql_query(include_index=False)
33+
return sql
34+
35+
2636
@log_adapter.method_logger(custom_base_name="bigquery_ml")
2737
def create_model(
2838
model_name: str,
@@ -34,7 +44,6 @@ def create_model(
3444
output_schema: Optional[Mapping[str, str]] = None,
3545
connection_name: Optional[str] = None,
3646
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
37-
query: Optional[Union[dataframe.DataFrame, str]] = None,
3847
training_data: Optional[Union[dataframe.DataFrame, str]] = None,
3948
custom_holiday: Optional[Union[dataframe.DataFrame, str]] = None,
4049
session: Optional[bigframes.session.Session] = None,
@@ -44,22 +53,13 @@ def create_model(
4453
"""
4554
import bigframes.pandas as bpd
4655

47-
# Helper to convert DataFrame to SQL string
48-
def _to_sql(df_or_sql: Union[dataframe.DataFrame, str]) -> str:
49-
if isinstance(df_or_sql, str):
50-
return df_or_sql
51-
# It's a DataFrame
52-
sql, _, _ = df_or_sql._to_sql_query(include_index=True)
53-
return sql
54-
55-
query_statement = _to_sql(query) if query is not None else None
5656
training_data_sql = _to_sql(training_data) if training_data is not None else None
5757
custom_holiday_sql = _to_sql(custom_holiday) if custom_holiday is not None else None
5858

5959
# Determine session from DataFrames if not provided
6060
if session is None:
6161
# Try to get session from inputs
62-
dfs = [obj for obj in [query, training_data, custom_holiday] if hasattr(obj, "_session")]
62+
dfs = [obj for obj in [training_data, custom_holiday] if hasattr(obj, "_session")]
6363
if dfs:
6464
session = dfs[0]._session
6565

@@ -72,7 +72,6 @@ def _to_sql(df_or_sql: Union[dataframe.DataFrame, str]) -> str:
7272
output_schema=output_schema,
7373
connection_name=connection_name,
7474
options=options,
75-
query_statement=query_statement,
7675
training_data=training_data_sql,
7776
custom_holiday=custom_holiday_sql,
7877
)

bigframes/core/sql/ml.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,13 @@ def create_model_ddl(
3030
output_schema: Optional[Mapping[str, str]] = None,
3131
connection_name: Optional[str] = None,
3232
options: Optional[Mapping[str, Union[str, int, float, bool, list]]] = None,
33-
query_statement: Optional[str] = None,
3433
training_data: Optional[str] = None,
3534
custom_holiday: Optional[str] = None,
3635
) -> str:
37-
"""Encode the CREATE MODEL statement."""
36+
"""Encode the CREATE MODEL statement.
37+
38+
See https://docs.cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create for reference.
39+
"""
3840

3941
if replace:
4042
create = "CREATE OR REPLACE MODEL "
@@ -83,18 +85,15 @@ def create_model_ddl(
8385

8486
# [AS {query_statement | ( training_data AS (query_statement), custom_holiday AS (holiday_statement) )}]
8587

86-
if query_statement and (training_data or custom_holiday):
87-
raise ValueError("Cannot specify both `query_statement` and (`training_data` or `custom_holiday`).")
88-
89-
if query_statement:
90-
ddl += f"AS {query_statement}"
91-
elif training_data:
92-
# specialized AS clause
93-
parts = []
94-
parts.append(f"training_data AS ({training_data})")
88+
if training_data:
9589
if custom_holiday:
90+
# When custom_holiday is present, we need named clauses
91+
parts = []
92+
parts.append(f"training_data AS ({training_data})")
9693
parts.append(f"custom_holiday AS ({custom_holiday})")
97-
98-
ddl += f"AS (\n {', '.join(parts)}\n)"
94+
ddl += f"AS (\n {', '.join(parts)}\n)"
95+
else:
96+
# Just training_data is treated as the query_statement
97+
ddl += f"AS {training_data}"
9998

10099
return ddl

tests/unit/core/sql/test_ml.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ def test_create_model_basic(snapshot):
1919
sql = bigframes.core.sql.ml.create_model_ddl(
2020
model_name="my_project.my_dataset.my_model",
2121
options={"model_type": "LINEAR_REG", "input_label_cols": ["label"]},
22-
query_statement="SELECT * FROM my_table",
22+
training_data="SELECT * FROM my_table",
2323
)
2424
snapshot.assert_match(sql, "create_model_basic.sql")
2525

@@ -28,7 +28,7 @@ def test_create_model_replace(snapshot):
2828
model_name="my_model",
2929
replace=True,
3030
options={"model_type": "LOGISTIC_REG"},
31-
query_statement="SELECT * FROM t",
31+
training_data="SELECT * FROM t",
3232
)
3333
snapshot.assert_match(sql, "create_model_replace.sql")
3434

@@ -37,7 +37,7 @@ def test_create_model_if_not_exists(snapshot):
3737
model_name="my_model",
3838
if_not_exists=True,
3939
options={"model_type": "KMEANS"},
40-
query_statement="SELECT * FROM t",
40+
training_data="SELECT * FROM t",
4141
)
4242
snapshot.assert_match(sql, "create_model_if_not_exists.sql")
4343

@@ -46,7 +46,7 @@ def test_create_model_transform(snapshot):
4646
model_name="my_model",
4747
transform=["ML.STANDARD_SCALER(c1) OVER() AS c1_scaled", "c2"],
4848
options={"model_type": "LINEAR_REG"},
49-
query_statement="SELECT c1, c2, label FROM t",
49+
training_data="SELECT c1, c2, label FROM t",
5050
)
5151
snapshot.assert_match(sql, "create_model_transform.sql")
5252

@@ -81,6 +81,6 @@ def test_create_model_list_option(snapshot):
8181
sql = bigframes.core.sql.ml.create_model_ddl(
8282
model_name="my_model",
8383
options={"hidden_units": [32, 16], "dropout": 0.2},
84-
query_statement="SELECT * FROM t",
84+
training_data="SELECT * FROM t",
8585
)
8686
snapshot.assert_match(sql, "create_model_list_option.sql")

0 commit comments

Comments
 (0)