|
| 1 | +import json |
| 2 | + |
1 | 3 | class DSSMLflowExtension(object): |
2 | 4 | """ |
3 | 5 | A handle to interact with specific endpoints of the DSS MLflow integration. |
@@ -131,3 +133,34 @@ def clean_experiment_tracking_db(self): |
131 | 133 | This call requires an API key with admin rights |
132 | 134 | """ |
133 | 135 | self.client._perform_raw("DELETE", "/api/2.0/mlflow/extension/clean-db/%s" % self.project_key) |
| 136 | + |
| 137 | + def set_run_classes(self, run_id, classes): |
| 138 | + """ |
| 139 | + Stores the classes of the target of classification models trained in the specified run. This information is leveraged |
| 140 | + to prefill the classes when deploying using the GUI an MLflow model as a version of a DSS Saved Model. |
| 141 | +
|
| 142 | + :param run_id: run_id for which to set the classes |
| 143 | + :type run_id: str |
| 144 | + :param classes: ordered list of classes |
| 145 | + :type classes: list(str) |
| 146 | + """ |
| 147 | + |
| 148 | + if not classes: |
| 149 | + raise ValueError('Parameter classes must be defined') |
| 150 | + if not isinstance(classes, list): |
| 151 | + raise ValueError('Wrong type for classes: {}'.format(type(classes))) |
| 152 | + for cur_class in classes: |
| 153 | + if not cur_class: |
| 154 | + raise ValueError('class can not be None') |
| 155 | + if not isinstance(cur_class, str): |
| 156 | + raise ValueError('Wrong type for class {}: {}'.format(cur_class, type(cur_class))) |
| 157 | + self.client._perform_http( |
| 158 | + "POST", "/api/2.0/mlflow/runs/set-tag", |
| 159 | + headers={"x-dku-mlflow-project-key": self.project_key}, |
| 160 | + body={ |
| 161 | + "run_id": run_id, |
| 162 | + "run_uuid": run_id, |
| 163 | + "key": "dku-ext.targetClasses", |
| 164 | + "value": json.dumps(classes) |
| 165 | + } |
| 166 | + ) |
0 commit comments