Skip to content

Commit c1adfd9

Browse files
feat: Add BigQuery ML CREATE MODEL support
- Refactor `bigframes.core.sql` to a package.
- Add `bigframes.core.sql.ml` for DDL generation.
- Add `bigframes.bigquery.ml` module with `create_model` function.
- Add unit tests for SQL generation.
- Use `_start_query_ml_ddl` for execution.
- Return the created model object using `read_gbq_model`.
- Remove `query` argument, simplify SQL generation logic.
- Fix linting and mypy errors.
1 parent fdee53f commit c1adfd9

File tree

2 files changed

+12
-7
lines changed

2 files changed

+12
-7
lines changed

bigframes/bigquery/_operations/ml.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -17,8 +17,8 @@
1717
import typing
1818
from typing import Mapping, Optional, Union
1919

20-
import bigframes.core.sql.ml
2120
import bigframes.core.log_adapter as log_adapter
21+
import bigframes.core.sql.ml
2222
import bigframes.dataframe as dataframe
2323
import bigframes.ml.base
2424
import bigframes.session
@@ -59,7 +59,11 @@ def create_model(
5959
# Determine session from DataFrames if not provided
6060
if session is None:
6161
# Try to get session from inputs
62-
dfs = [obj for obj in [training_data, custom_holiday] if hasattr(obj, "_session")]
62+
dfs = [
63+
obj
64+
for obj in [training_data, custom_holiday]
65+
if isinstance(obj, dataframe.DataFrame)
66+
]
6367
if dfs:
6468
session = dfs[0]._session
6569

bigframes/core/sql/ml.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -20,6 +20,7 @@
2020
import bigframes.core.compile.googlesql as googlesql
2121
import bigframes.core.sql
2222

23+
2324
def create_model_ddl(
2425
model_name: str,
2526
*,
@@ -63,9 +64,9 @@ def create_model_ddl(
6364
# [REMOTE WITH CONNECTION {connection_name | DEFAULT}]
6465
if connection_name:
6566
if connection_name.upper() == "DEFAULT":
66-
ddl += "REMOTE WITH CONNECTION DEFAULT\n"
67+
ddl += "REMOTE WITH CONNECTION DEFAULT\n"
6768
else:
68-
ddl += f"REMOTE WITH CONNECTION {googlesql.identifier(connection_name)}\n"
69+
ddl += f"REMOTE WITH CONNECTION {googlesql.identifier(connection_name)}\n"
6970

7071
# [OPTIONS(model_option_list)]
7172
if options:
@@ -77,7 +78,7 @@ def create_model_ddl(
7778
# if value is list, it is [val1, val2]
7879
rendered_val = bigframes.core.sql.simple_literal(list(option_value))
7980
else:
80-
rendered_val = bigframes.core.sql.simple_literal(option_value)
81+
rendered_val = bigframes.core.sql.simple_literal(option_value)
8182

8283
rendered_options.append(f"{option_name} = {rendered_val}")
8384

@@ -93,7 +94,7 @@ def create_model_ddl(
9394
parts.append(f"custom_holiday AS ({custom_holiday})")
9495
ddl += f"AS (\n {', '.join(parts)}\n)"
9596
else:
96-
# Just training_data is treated as the query_statement
97-
ddl += f"AS {training_data}"
97+
# Just training_data is treated as the query_statement
98+
ddl += f"AS {training_data}"
9899

99100
return ddl

0 commit comments

Comments
 (0)