Skip to content

Commit e270259

Browse files
committed
Saved models API
1 parent c643875 commit e270259

File tree

2 files changed

+51
-7
lines changed

2 files changed

+51
-7
lines changed

dataikuapi/dss/ml.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,12 @@ def get_raw(self):
221221
"""
222222
return self.details
223223

224+
def get_raw_snippet(self):
225+
"""
226+
Gets the raw dictionary of trained model snippet
227+
"""
228+
return self.summary
229+
224230
def get_train_info(self):
225231
"""
226232
Returns various information about the train process (size of the train set, quick description, timing information)
@@ -399,6 +405,7 @@ def get_actual_modeling_params(self):
399405
return self.details["actualParams"]
400406

401407
class DSSMLTask(object):
408+
"""A handle to interact with a MLTask for prediction or clustering in a DSS visual analysis"""
402409
def __init__(self, client, project_key, analysis_id, mltask_id):
403410
self.client = client
404411
self.project_key = project_key
@@ -476,7 +483,7 @@ def get_trained_models_ids(self):
476483

477484
def get_trained_model_summary(self, id):
478485
"""
479-
Gets a summary of a trained model
486+
Gets a quick summary of a trained model, as a dict. For complete information and a structured object, use :meth:get_trained_model_details
480487
481488
:rtype: dict
482489
"""

dataikuapi/dss/savedmodel.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,20 @@
22
from ..utils import DataikuUTF8CSVReader
33
from ..utils import DataikuStreamedHttpUTF8CSVReader
44
import json
5+
from .ml import DSSTrainedPredictionModelDetails, DSSTrainedClusteringModelDetails
56
from .metrics import ComputedMetrics
67

78
class DSSSavedModel(object):
89
"""
9-
A saved model on the DSS instance
10+
A handle to interact with a saved model on the DSS instance.
11+
12+
Do not create this directly, use :meth:`dataikuapi.dss.DSSProject.get_saved_model`
1013
"""
1114
def __init__(self, client, project_key, sm_id):
1215
self.client = client
1316
self.project_key = project_key
1417
self.sm_id = sm_id
15-
18+
1619

1720
########################################################
1821
# Versions
@@ -22,12 +25,48 @@ def list_versions(self):
2225
"""
2326
Get the versions of this saved model
2427
25-
Returns:
26-
an list of the versions
28+
:return: a list of the versions, as a dict of object. Each object contains at least a "id" parameter, which can be passed to :meth:`get_metric_values`, :meth:`get_version_details` and :meth:`set_active_version`
29+
:rtype: list
2730
"""
2831
return self.client._perform_json(
2932
"GET", "/projects/%s/savedmodels/%s/versions" % (self.project_key, self.sm_id))
3033

34+
def get_active_version(self):
35+
"""
36+
Gets the active version of this saved model
37+
38+
:return: a dict representing the active version or None if no version is active. The dict contains at least a "id" parameter, which can be passed to :meth:`get_metric_values`, :meth:`get_version_details` and :meth:`set_active_version`
39+
:rtype: dict
40+
"""
41+
filtered = [x for x in self.list_versions() if x["active"]]
42+
if len(filtered) == 0:
43+
return None
44+
else:
45+
return filtered[0]
46+
47+
def get_version_details(self, version_id):
48+
"""
49+
Gets details for a version of a saved model
50+
51+
:param str version_id: Identifier of the version, as returned by :meth:`list_versions`
52+
53+
:return: A :class:`DSSTrainedPredictionModelDetails` representing the details of this trained model id
54+
:rtype: :class:`DSSTrainedPredictionModelDetails`
55+
"""
56+
details = self.client._perform_json(
57+
"GET", "/projects/%s/savedmodels/%s/versions/%s/details" % (self.project_key, self.sm_id, version_id))
58+
snippet = self.client._perform_json(
59+
"GET", "/projects/%s/savedmodels/%s/versions/%s/snippet" % (self.project_key, self.sm_id, version_id))
60+
61+
if "facts" in details:
62+
return DSSTrainedClusteringModelDetails(details, snippet)
63+
else:
64+
return DSSTrainedPredictionModelDetails(details, snippet)
65+
66+
def set_active_version(self, version_id):
67+
"""Sets a particular version of the saved model as the active one"""
68+
self.client._perform_empty(
69+
"POST", "/projects/%s/savedmodels/%s/versions/%s/actions/setActive" % (self.project_key, self.sm_id, version_id))
3170

3271
########################################################
3372
# Metrics
@@ -43,8 +82,6 @@ def get_metric_values(self, version_id):
4382
return ComputedMetrics(self.client._perform_json(
4483
"GET", "/projects/%s/savedmodels/%s/metrics/%s" % (self.project_key, self.sm_id, version_id)))
4584

46-
47-
4885

4986
########################################################
5087
# Usages

0 commit comments

Comments
 (0)