Skip to content

Commit 969a4e5

Browse files
committed
Renaming prediction_type to model_type. Renaming NON_TABULAR to OTHER
1 parent 02dedae commit 969a4e5

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

dataikuapi/dss/mlflow.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -134,23 +134,23 @@ def clean_experiment_tracking_db(self):
134134
"""
135135
self.client._perform_raw("DELETE", "/api/2.0/mlflow/extension/clean-db/%s" % self.project_key)
136136

137-
def set_run_inference_info(self, run_id, prediction_type, classes, code_env_name, target):
137+
def set_run_inference_info(self, run_id, model_type, classes=None, code_env_name=None, target=None):
138138
"""
139-
Sets the prediction_type of the model, and the classes, if it is a classification model.
139+
Sets the type of the model, and optionally other information useful to deploy or evaluate it.
140140
141-
prediction_type must be one of:
142-
- NON_TABULAR if the model is not tabular
141+
model_type must be one of:
143142
- REGRESSION
144143
- BINARY_CLASSIFICATION
145144
- MULTICLASS
145+
- OTHER
146146
147147
Classes must be specified if and only if the model is a BINARY_CLASSIFICATION or MULTICLASS model.
148148
149149
This information is leveraged to filter saved models on their prediction type and prefill the classes
150150
when deploying using the GUI an MLflow model as a version of a DSS Saved Model.
151151
152-
:param prediction_type: prediction type (see doc)
153-
:type prediction_type: str
152+
:param model_type: prediction type (see doc)
153+
:type model_type: str
154154
:param run_id: run_id for which to set the classes
155155
:type run_id: str
156156
:param classes: ordered list of classes (not for all prediction types, see doc)
@@ -160,18 +160,18 @@ def set_run_inference_info(self, run_id, prediction_type, classes, code_env_name
160160
:param target: name of the target
161161
:type target: str
162162
"""
163-
if prediction_type not in {"NON_TABULAR", "REGRESSION", "BINARY_CLASSIFICATION", "MULTICLASS"}:
164-
raise ValueError('Invalid prediction type: {}'.format(prediction_type))
163+
if model_type not in {"REGRESSION", "BINARY_CLASSIFICATION", "MULTICLASS", "OTHER"}:
164+
raise ValueError('Invalid prediction type: {}'.format(model_type))
165165

166-
if classes and prediction_type not in {"BINARY_CLASSIFICATION", "MULTICLASS"}:
166+
if classes and model_type not in {"BINARY_CLASSIFICATION", "MULTICLASS"}:
167167
raise ValueError('Classes can be specified only for BINARY_CLASSIFICATION or MULTICLASS prediction types')
168-
if prediction_type in {"BINARY_CLASSIFICATION", "MULTICLASS"}:
168+
if model_type in {"BINARY_CLASSIFICATION", "MULTICLASS"}:
169169
if not classes:
170-
raise ValueError('Classes must be specified for {} prediction type'.format(prediction_type))
170+
raise ValueError('Classes must be specified for {} prediction type'.format(model_type))
171171
if not isinstance(classes, list):
172172
raise ValueError('Wrong type for classes: {}'.format(type(classes)))
173173
for cur_class in classes:
174-
if not cur_class:
174+
if cur_class is None:
175175
raise ValueError('class can not be None')
176176
if not isinstance(cur_class, str):
177177
raise ValueError('Wrong type for class {}: {}'.format(cur_class, type(cur_class)))
@@ -183,7 +183,7 @@ def set_run_inference_info(self, run_id, prediction_type, classes, code_env_name
183183

184184
params = {
185185
"run_id": run_id,
186-
"prediction_type": prediction_type
186+
"prediction_type": model_type
187187
}
188188

189189
if classes:

0 commit comments

Comments
 (0)