Skip to content

Commit e2ab554

Browse files
committed
Add a set_run_classes method to MLflow extension, to store the classes of classification models
1 parent 926eda6 commit e2ab554

File tree

1 file changed

+33
-0
lines changed

1 file changed

+33
-0
lines changed

dataikuapi/dss/mlflow.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import json
2+
13
class DSSMLflowExtension(object):
24
"""
35
A handle to interact with specific endpoints of the DSS MLflow integration.
@@ -131,3 +133,34 @@ def clean_experiment_tracking_db(self):
131133
This call requires an API key with admin rights
132134
"""
133135
self.client._perform_raw("DELETE", "/api/2.0/mlflow/extension/clean-db/%s" % self.project_key)
136+
137+
def set_run_classes(self, run_id, classes):
138+
"""
139+
Stores the classes of the target of classification models trained in the specified run. This information is leveraged
140+
to prefill the classes when deploying using the GUI an MLflow model as a version of a DSS Saved Model.
141+
142+
:param run_id: run_id for which to set the classes
143+
:type run_id: str
144+
:param classes: ordered list of classes
145+
:type classes: list(str)
146+
"""
147+
148+
if not classes:
149+
raise ValueError('Parameter classes must be defined')
150+
if not isinstance(classes, list):
151+
raise ValueError('Wrong type for classes: {}'.format(type(classes)))
152+
for cur_class in classes:
153+
if not cur_class:
154+
raise ValueError('class can not be None')
155+
if not isinstance(cur_class, str):
156+
raise ValueError('Wrong type for class {}: {}'.format(cur_class, type(cur_class)))
157+
self.client._perform_http(
158+
"POST", "/api/2.0/mlflow/runs/set-tag",
159+
headers={"x-dku-mlflow-project-key": self.project_key},
160+
body={
161+
"run_id": run_id,
162+
"run_uuid": run_id,
163+
"key": "dku-ext.targetClasses",
164+
"value": json.dumps(classes)
165+
}
166+
)

0 commit comments

Comments
 (0)