Skip to content

Commit 402e4fa

Browse files
committed
save user meta
1 parent e270259 commit 402e4fa

File tree

2 files changed

+54
-20
lines changed

2 files changed

+54
-20
lines changed

dataikuapi/dss/ml.py

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -204,29 +204,23 @@ class DSSClusteringMLTaskSettings(DSSMLTaskSettings):
204204
"DBSCAN" : "db_scan_clustering",
205205
}
206206

207-
class DSSTrainedPredictionModelDetails(object):
208-
"""
209-
Object to read details of a trained prediction model
210207

211-
Do not create this object directly, use :meth:`DSSMLTask.get_trained_model_details()` instead
212-
"""
213208

214-
def __init__(self, details, summary):
209+
class DSSTrainedModelDetails(object):
210+
def __init__(self, details, summary, saved_model=None, saved_model_version=None, mltask=None, mltask_model_id=None):
215211
self.details = details
216212
self.summary = summary
213+
self.saved_model = saved_model
214+
self.saved_model_version = saved_model_version
215+
self.mltask = mltask
216+
self.mltask_model_id = mltask_model_id
217217

218218
def get_raw(self):
219219
"""
220220
Gets the raw dictionary of trained model details
221221
"""
222222
return self.details
223223

224-
def get_raw_snippet(self):
225-
"""
226-
Gets the raw dictionary of trained model snippet
227-
"""
228-
return self.summary
229-
230224
def get_train_info(self):
231225
"""
232226
Returns various information about the train process (size of the train set, quick description, timing information)
@@ -235,6 +229,44 @@ def get_train_info(self):
235229
"""
236230
return self.details["trainInfo"]
237231

232+
def get_user_meta(self):
233+
"""
234+
Gets the user-accessible metadata (name, description, cluster labels, classification threshold)
235+
Returns the original object, not a copy. Changes to the returned object are persisted to DSS by calling
236+
:meth:`save_user_meta`
237+
238+
"""
239+
return self.details["userMeta"]
240+
241+
def save_user_meta(self):
242+
um = self.details["userMeta"]
243+
244+
if self.mltask is not None:
245+
self.mltask.client._perform_empty(
246+
"PUT", "/projects/%s/models/lab/%s/%s/models/%s/user-meta" % (self.mltask.project_key,
247+
self.mltask.analysis_id, self.mltask.mltask_id, self.mltask_model_id), body = um)
248+
else:
249+
self.saved_model.client._perform_empty(
250+
"PUT", "/projects/%s/savedmodels/%s/versions/%s/user-meta" % (self.saved_model.project_key,
251+
self.saved_model.sm_id, self.saved_model_version), body = um)
252+
253+
class DSSTrainedPredictionModelDetails(DSSTrainedModelDetails):
254+
"""
255+
Object to read details of a trained prediction model
256+
257+
Do not create this object directly, use :meth:`DSSMLTask.get_trained_model_details()` instead
258+
"""
259+
260+
def __init__(self, details, summary, saved_model=None, saved_model_version=None, mltask=None, mltask_model_id=None):
261+
DSSTrainedModelDetails.__init__(self, details, summary, saved_model, saved_model_version, mltask, mltask_model_id)
262+
263+
def get_raw_snippet(self):
264+
"""
265+
Gets the raw dictionary of trained model snippet
266+
"""
267+
return self.summary
268+
269+
238270
def get_roc_curve_data(self):
239271
roc = self.details.get("perf", {}).get("rocVizData",{})
240272
if roc is None:
@@ -303,6 +335,7 @@ def get_actual_modeling_params(self):
303335
"""
304336
return self.details["actualParams"]
305337

338+
306339
class DSSClustersFacts(object):
307340
def __init__(self, clusters_facts):
308341
self.clusters_facts = clusters_facts
@@ -331,16 +364,17 @@ def get_facts_for_cluster_and_feature(self, cluster_index, feature_name):
331364
"""
332365
return [x for x in self.get_facts_for_cluster(cluster_index) if x["feature_label"] == feature_name]
333366

334-
class DSSTrainedClusteringModelDetails(object):
367+
368+
class DSSTrainedClusteringModelDetails(DSSTrainedModelDetails):
335369
"""
336370
Object to read details of a trained clustering model
337371
338372
Do not create this object directly, use :meth:`DSSMLTask.get_trained_model_details()` instead
339373
"""
340374

341-
def __init__(self, details, summary):
342-
self.details = details
343-
self.summary = summary
375+
def __init__(self, details, summary, saved_model=None, saved_model_version=None, mltask=None, mltask_model_id=None):
376+
DSSTrainedModelDetails.__init__(self, details, summary, saved_model, saved_model_version, mltask, mltask_model_id)
377+
344378

345379
def get_raw(self):
346380
"""
@@ -511,9 +545,9 @@ def get_trained_model_details(self, id):
511545

512546

513547
if "facts" in ret:
514-
return DSSTrainedClusteringModelDetails(ret, summary)
548+
return DSSTrainedClusteringModelDetails(ret, summary, mltask=self, mltask_model_id=id)
515549
else:
516-
return DSSTrainedPredictionModelDetails(ret, summary)
550+
return DSSTrainedPredictionModelDetails(ret, summary, mltask=self, mltask_model_id=id)
517551

518552
def deploy_to_flow(self, model_id, model_name, train_dataset, test_dataset=None, redo_optimization=True):
519553
"""

dataikuapi/dss/savedmodel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,9 +59,9 @@ def get_version_details(self, version_id):
5959
"GET", "/projects/%s/savedmodels/%s/versions/%s/snippet" % (self.project_key, self.sm_id, version_id))
6060

6161
if "facts" in details:
62-
return DSSTrainedClusteringModelDetails(details, snippet)
62+
return DSSTrainedClusteringModelDetails(details, snippet, saved_model=self, saved_model_version=version_id)
6363
else:
64-
return DSSTrainedPredictionModelDetails(details, snippet)
64+
return DSSTrainedPredictionModelDetails(details, snippet, saved_model=self, saved_model_version=version_id)
6565

6666
def set_active_version(self, version_id):
6767
"""Sets a particular version of the saved model as the active one"""

0 commit comments

Comments
 (0)