@@ -134,23 +134,23 @@ def clean_experiment_tracking_db(self):
134134 """
135135 self .client ._perform_raw ("DELETE" , "/api/2.0/mlflow/extension/clean-db/%s" % self .project_key )
136136
137- def set_run_inference_info (self , run_id , prediction_type , classes , code_env_name , target ):
137+ def set_run_inference_info (self , run_id , model_type , classes = None , code_env_name = None , target = None ):
138138 """
139- Sets the prediction_type of the model, and the classes, if it is a classification model .
139+ Sets the type of the model, and optionally other information useful to deploy or evaluate it .
140140
141- prediction_type must be one of:
142- - NON_TABULAR if the model is not tabular
141+ model_type must be one of:
143142 - REGRESSION
144143 - BINARY_CLASSIFICATION
145144 - MULTICLASS
145+ - OTHER
146146
147147 Classes must be specified if and only if the model is a BINARY_CLASSIFICATION or MULTICLASS model.
148148
149149 This information is leveraged to filter saved models on their prediction type and prefill the classes
150150 when deploying using the GUI an MLflow model as a version of a DSS Saved Model.
151151
152- :param prediction_type : prediction type (see doc)
153- :type prediction_type : str
152+ :param model_type : prediction type (see doc)
153+ :type model_type : str
154154 :param run_id: run_id for which to set the classes
155155 :type run_id: str
156156 :param classes: ordered list of classes (not for all prediction types, see doc)
@@ -160,18 +160,18 @@ def set_run_inference_info(self, run_id, prediction_type, classes, code_env_name
160160 :param target: name of the target
161161 :type target: str
162162 """
163- if prediction_type not in {"NON_TABULAR " , "REGRESSION " , "BINARY_CLASSIFICATION " , "MULTICLASS " }:
164- raise ValueError ('Invalid prediction type: {}' .format (prediction_type ))
163+ if model_type not in {"REGRESSION " , "BINARY_CLASSIFICATION " , "MULTICLASS " , "OTHER " }:
164+ raise ValueError ('Invalid prediction type: {}' .format (model_type ))
165165
166- if classes and prediction_type not in {"BINARY_CLASSIFICATION" , "MULTICLASS" }:
166+ if classes and model_type not in {"BINARY_CLASSIFICATION" , "MULTICLASS" }:
167167 raise ValueError ('Classes can be specified only for BINARY_CLASSIFICATION or MULTICLASS prediction types' )
168- if prediction_type in {"BINARY_CLASSIFICATION" , "MULTICLASS" }:
168+ if model_type in {"BINARY_CLASSIFICATION" , "MULTICLASS" }:
169169 if not classes :
170- raise ValueError ('Classes must be specified for {} prediction type' .format (prediction_type ))
170+ raise ValueError ('Classes must be specified for {} prediction type' .format (model_type ))
171171 if not isinstance (classes , list ):
172172 raise ValueError ('Wrong type for classes: {}' .format (type (classes )))
173173 for cur_class in classes :
174- if not cur_class :
174+ if cur_class is None :
175175 raise ValueError ('class can not be None' )
176176 if not isinstance (cur_class , str ):
177177 raise ValueError ('Wrong type for class {}: {}' .format (cur_class , type (cur_class )))
@@ -183,7 +183,7 @@ def set_run_inference_info(self, run_id, prediction_type, classes, code_env_name
183183
184184 params = {
185185 "run_id" : run_id ,
186- "prediction_type" : prediction_type
186+ "prediction_type" : model_type
187187 }
188188
189189 if classes :
0 commit comments