@@ -134,33 +134,67 @@ def clean_experiment_tracking_db(self):
134134 """
135135 self .client ._perform_raw ("DELETE" , "/api/2.0/mlflow/extension/clean-db/%s" % self .project_key )
136136
137- def set_run_classes (self , run_id , classes ):
137+ def set_run_inference_info (self , run_id , prediction_type , classes , code_env_name , target ):
138138 """
139- Stores the classes of the target of classification models trained in the specified run. This information is leveraged
140- to prefill the classes when deploying using the GUI an MLflow model as a version of a DSS Saved Model.
139+ Sets the prediction_type of the model, and the classes, if it is a classification model.
141140
141+ prediction_type must be one of:
142+ - NON_TABULAR if the model is not tabular
143+ - REGRESSION
144+ - BINARY_CLASSIFICATION
145+ - MULTICLASS
146+
147+ Classes must be specified if and only if the model is a BINARY_CLASSIFICATION or MULTICLASS model.
148+
149+ This information is leveraged to filter saved models on their prediction type and prefill the classes
150+ when deploying using the GUI an MLflow model as a version of a DSS Saved Model.
151+
152+ :param prediction_type: prediction type (see doc)
153+ :type prediction_type: str
142154 :param run_id: run_id for which to set the classes
143155 :type run_id: str
144- :param classes: ordered list of classes
156+ :param classes: ordered list of classes (not for all prediction types, see doc)
145157 :type classes: list(str)
146- """
158+ :param code_env_name: name of an adequate DSS python code environment
159+ :type code_env_name: str
160+ :param target: name of the target
161+ :type target: str
162+ """
163+ if prediction_type not in {"NON_TABULAR" , "REGRESSION" , "BINARY_CLASSIFICATION" , "MULTICLASS" }:
164+ raise ValueError ('Invalid prediction type: {}' .format (prediction_type ))
165+
166+ if classes and prediction_type not in {"BINARY_CLASSIFICATION" , "MULTICLASS" }:
167+ raise ValueError ('Classes can be specified only for BINARY_CLASSIFICATION or MULTICLASS prediction types' )
168+ if prediction_type in {"BINARY_CLASSIFICATION" , "MULTICLASS" }:
169+ if not classes :
170+ raise ValueError ('Classes must be specified for {} prediction type' .format (prediction_type ))
171+ if not isinstance (classes , list ):
172+ raise ValueError ('Wrong type for classes: {}' .format (type (classes )))
173+ for cur_class in classes :
174+ if not cur_class :
175+ raise ValueError ('class can not be None' )
176+ if not isinstance (cur_class , str ):
177+ raise ValueError ('Wrong type for class {}: {}' .format (cur_class , type (cur_class )))
178+
179+ if code_env_name and not isinstance (code_env_name , str ):
180+ raise ValueError ('code_env_name must be a string' )
181+ if target and not isinstance (target , str ):
182+ raise ValueError ('target must be a string' )
183+
184+ params = {
185+ "run_id" : run_id ,
186+ "prediction_type" : prediction_type
187+ }
188+
189+ if classes :
190+ params ["classes" ] = json .dumps (classes )
191+ if code_env_name :
192+ params ["code_env_name" ] = code_env_name
193+ if target :
194+ params ["target" ] = target
147195
148- if not classes :
149- raise ValueError ('Parameter classes must be defined' )
150- if not isinstance (classes , list ):
151- raise ValueError ('Wrong type for classes: {}' .format (type (classes )))
152- for cur_class in classes :
153- if not cur_class :
154- raise ValueError ('class can not be None' )
155- if not isinstance (cur_class , str ):
156- raise ValueError ('Wrong type for class {}: {}' .format (cur_class , type (cur_class )))
157196 self .client ._perform_http (
158- "POST" , "/api/2.0/mlflow/runs /set-tag " ,
197+ "POST" , "/api/2.0/mlflow/extension /set-run-inference-info " ,
159198 headers = {"x-dku-mlflow-project-key" : self .project_key },
160- body = {
161- "run_id" : run_id ,
162- "run_uuid" : run_id ,
163- "key" : "dku-ext.targetClasses" ,
164- "value" : json .dumps (classes )
165- }
199+ body = params
166200 )
0 commit comments