@@ -32,15 +32,17 @@ class TensorFlowModel(base.Predictor):
3232 """Imported TensorFlow model.
3333
3434 Args:
35+ model_path (str):
36+ GCS path that holds the model files.
3537 session (BigQuery Session):
3638 BQ session to create the model
37- model_path (str):
38- GCS path that holds the model files."""
39+ """
3940
4041 def __init__ (
4142 self ,
43+ model_path : str ,
44+ * ,
4245 session : Optional [bigframes .Session ] = None ,
43- model_path : Optional [str ] = None ,
4446 ):
4547 self .session = session or bpd .get_global_session ()
4648 self .model_path = model_path
@@ -59,7 +61,7 @@ def _from_bq(
5961 ) -> TensorFlowModel :
6062 assert model .model_type == "TENSORFLOW"
6163
62- tf_model = cls (session = session , model_path = None )
64+ tf_model = cls (session = session , model_path = "" )
6365 tf_model ._bqml_model = core .BqmlModel (session , model )
6466 return tf_model
6567
@@ -109,15 +111,17 @@ class ONNXModel(base.Predictor):
109111 """Imported Open Neural Network Exchange (ONNX) model.
110112
111113 Args:
114+ model_path (str):
115+ Cloud Storage path that holds the model files.
112116 session (BigQuery Session):
113117 BQ session to create the model
114- model_path (str):
115- Cloud Storage path that holds the model files."""
118+ """
116119
117120 def __init__ (
118121 self ,
122+ model_path : str ,
123+ * ,
119124 session : Optional [bigframes .Session ] = None ,
120- model_path : Optional [str ] = None ,
121125 ):
122126 self .session = session or bpd .get_global_session ()
123127 self .model_path = model_path
@@ -134,7 +138,7 @@ def _create_bqml_model(self):
134138 def _from_bq (cls , session : bigframes .Session , model : bigquery .Model ) -> ONNXModel :
135139 assert model .model_type == "ONNX"
136140
137- onnx_model = cls (session = session , model_path = None )
141+ onnx_model = cls (session = session , model_path = "" )
138142 onnx_model ._bqml_model = core .BqmlModel (session , model )
139143 return onnx_model
140144
@@ -189,8 +193,8 @@ class XGBoostModel(base.Predictor):
189193 https://cloud.google.com/bigquery/docs/reference/standard-sql/bigqueryml-syntax-create-xgboost#limitations
190194
191195 Args:
192- session (BigQuery Session ):
193- BQ session to create the model
196+ model_path (str ):
197+ Cloud Storage path that holds the model files.
194198 input (Dict, default None):
195199 Specify the model input schema information when you
196200 create the XGBoost model. The input should be the format of
@@ -203,15 +207,17 @@ class XGBoostModel(base.Predictor):
203207 {field_name: field_type}. Output is optional only if feature_names
204208 and feature_types are both specified in the model file. Supported types
205209 are "bool", "string", "int64", "float64", "array<bool>", "array<string>", "array<int64>", "array<float64>".
206- model_path (str):
207- Cloud Storage path that holds the model files."""
210+ session (BigQuery Session):
211+ BQ session to create the model
212+ """
208213
209214 def __init__ (
210215 self ,
211- session : Optional [bigframes .Session ] = None ,
216+ model_path : str ,
217+ * ,
212218 input : Mapping [str , str ] = {},
213219 output : Mapping [str , str ] = {},
214- model_path : Optional [str ] = None ,
220+ session : Optional [bigframes . Session ] = None ,
215221 ):
216222 self .session = session or bpd .get_global_session ()
217223 self .model_path = model_path
@@ -248,7 +254,7 @@ def _from_bq(
248254 ) -> XGBoostModel :
249255 assert model .model_type == "XGBOOST"
250256
251- xgboost_model = cls (session = session , model_path = None )
257+ xgboost_model = cls (session = session , model_path = "" )
252258 xgboost_model ._bqml_model = core .BqmlModel (session , model )
253259 return xgboost_model
254260
0 commit comments