Skip to content

Commit 02dedae

Browse files
committed
Renaming set_run_classes to set_run_inference_info. Adding prediction_type, code_env-name and target.
1 parent e2ab554 commit 02dedae

File tree

1 file changed

+55
-21
lines changed

1 file changed

+55
-21
lines changed

dataikuapi/dss/mlflow.py

Lines changed: 55 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -134,33 +134,67 @@ def clean_experiment_tracking_db(self):
134134
"""
135135
self.client._perform_raw("DELETE", "/api/2.0/mlflow/extension/clean-db/%s" % self.project_key)
136136

137-
def set_run_classes(self, run_id, classes):
137+
def set_run_inference_info(self, run_id, prediction_type, classes, code_env_name, target):
138138
"""
139-
Stores the classes of the target of classification models trained in the specified run. This information is leveraged
140-
to prefill the classes when deploying using the GUI an MLflow model as a version of a DSS Saved Model.
139+
Sets the prediction_type of the model, and the classes, if it is a classification model.
141140
141+
prediction_type must be one of:
142+
- NON_TABULAR if the model is not tabular
143+
- REGRESSION
144+
- BINARY_CLASSIFICATION
145+
- MULTICLASS
146+
147+
Classes must be specified if and only if the model is a BINARY_CLASSIFICATION or MULTICLASS model.
148+
149+
This information is leveraged to filter saved models on their prediction type and prefill the classes
150+
when deploying using the GUI an MLflow model as a version of a DSS Saved Model.
151+
152+
:param prediction_type: prediction type (see doc)
153+
:type prediction_type: str
142154
:param run_id: run_id for which to set the classes
143155
:type run_id: str
144-
:param classes: ordered list of classes
156+
:param classes: ordered list of classes (not for all prediction types, see doc)
145157
:type classes: list(str)
146-
"""
158+
:param code_env_name: name of an adequate DSS python code environment
159+
:type code_env_name: str
160+
:param target: name of the target
161+
:type target: str
162+
"""
163+
if prediction_type not in {"NON_TABULAR", "REGRESSION", "BINARY_CLASSIFICATION", "MULTICLASS"}:
164+
raise ValueError('Invalid prediction type: {}'.format(prediction_type))
165+
166+
if classes and prediction_type not in {"BINARY_CLASSIFICATION", "MULTICLASS"}:
167+
raise ValueError('Classes can be specified only for BINARY_CLASSIFICATION or MULTICLASS prediction types')
168+
if prediction_type in {"BINARY_CLASSIFICATION", "MULTICLASS"}:
169+
if not classes:
170+
raise ValueError('Classes must be specified for {} prediction type'.format(prediction_type))
171+
if not isinstance(classes, list):
172+
raise ValueError('Wrong type for classes: {}'.format(type(classes)))
173+
for cur_class in classes:
174+
if not cur_class:
175+
raise ValueError('class can not be None')
176+
if not isinstance(cur_class, str):
177+
raise ValueError('Wrong type for class {}: {}'.format(cur_class, type(cur_class)))
178+
179+
if code_env_name and not isinstance(code_env_name, str):
180+
raise ValueError('code_env_name must be a string')
181+
if target and not isinstance(target, str):
182+
raise ValueError('target must be a string')
183+
184+
params = {
185+
"run_id": run_id,
186+
"prediction_type": prediction_type
187+
}
188+
189+
if classes:
190+
params["classes"] = json.dumps(classes)
191+
if code_env_name:
192+
params["code_env_name"] = code_env_name
193+
if target:
194+
params["target"] = target
147195

148-
if not classes:
149-
raise ValueError('Parameter classes must be defined')
150-
if not isinstance(classes, list):
151-
raise ValueError('Wrong type for classes: {}'.format(type(classes)))
152-
for cur_class in classes:
153-
if not cur_class:
154-
raise ValueError('class can not be None')
155-
if not isinstance(cur_class, str):
156-
raise ValueError('Wrong type for class {}: {}'.format(cur_class, type(cur_class)))
157196
self.client._perform_http(
158-
"POST", "/api/2.0/mlflow/runs/set-tag",
197+
"POST", "/api/2.0/mlflow/extension/set-run-inference-info",
159198
headers={"x-dku-mlflow-project-key": self.project_key},
160-
body={
161-
"run_id": run_id,
162-
"run_uuid": run_id,
163-
"key": "dku-ext.targetClasses",
164-
"value": json.dumps(classes)
165-
}
199+
body=params
166200
)

0 commit comments

Comments
 (0)