22from ..utils import DataikuUTF8CSVReader
33from ..utils import DataikuStreamedHttpUTF8CSVReader
44import json
5+ from .ml import DSSTrainedPredictionModelDetails , DSSTrainedClusteringModelDetails
56from .metrics import ComputedMetrics
67
78class DSSSavedModel (object ):
89 """
9- A saved model on the DSS instance
10+ A handle to interact with a saved model on the DSS instance.
11+
12+ Do not create this directly, use :meth:`dataikuapi.dss.DSSProject.get_saved_model`
1013 """
1114 def __init__ (self , client , project_key , sm_id ):
1215 self .client = client
1316 self .project_key = project_key
1417 self .sm_id = sm_id
15-
18+
1619
1720 ########################################################
1821 # Versions
@@ -22,12 +25,48 @@ def list_versions(self):
2225 """
2326 Get the versions of this saved model
2427
25- Returns:
26- an list of the versions
28+ :return: a list of the versions, as a dict of object. Each object contains at least a "id" parameter, which can be passed to :meth:`get_metric_values`, :meth:`get_version_details` and :meth:`set_active_version`
29+ :rtype: list
2730 """
2831 return self .client ._perform_json (
2932 "GET" , "/projects/%s/savedmodels/%s/versions" % (self .project_key , self .sm_id ))
3033
34+ def get_active_version (self ):
35+ """
36+ Gets the active version of this saved model
37+
38+ :return: a dict representing the active version or None if no version is active. The dict contains at least a "id" parameter, which can be passed to :meth:`get_metric_values`, :meth:`get_version_details` and :meth:`set_active_version`
39+ :rtype: dict
40+ """
41+ filtered = [x for x in self .list_versions () if x ["active" ]]
42+ if len (filtered ) == 0 :
43+ return None
44+ else :
45+ return filtered [0 ]
46+
47+ def get_version_details (self , version_id ):
48+ """
49+ Gets details for a version of a saved model
50+
51+ :param str version_id: Identifier of the version, as returned by :meth:`list_versions`
52+
53+ :return: A :class:`DSSTrainedPredictionModelDetails` representing the details of this trained model id
54+ :rtype: :class:`DSSTrainedPredictionModelDetails`
55+ """
56+ details = self .client ._perform_json (
57+ "GET" , "/projects/%s/savedmodels/%s/versions/%s/details" % (self .project_key , self .sm_id , version_id ))
58+ snippet = self .client ._perform_json (
59+ "GET" , "/projects/%s/savedmodels/%s/versions/%s/snippet" % (self .project_key , self .sm_id , version_id ))
60+
61+ if "facts" in details :
62+ return DSSTrainedClusteringModelDetails (details , snippet )
63+ else :
64+ return DSSTrainedPredictionModelDetails (details , snippet )
65+
66+ def set_active_version (self , version_id ):
67+ """Sets a particular version of the saved model as the active one"""
68+ self .client ._perform_empty (
69+ "POST" , "/projects/%s/savedmodels/%s/versions/%s/actions/setActive" % (self .project_key , self .sm_id , version_id ))
3170
3271 ########################################################
3372 # Metrics
@@ -43,8 +82,6 @@ def get_metric_values(self, version_id):
4382 return ComputedMetrics (self .client ._perform_json (
4483 "GET" , "/projects/%s/savedmodels/%s/metrics/%s" % (self .project_key , self .sm_id , version_id )))
4584
46-
47-
4885
4986 ########################################################
5087 # Usages
0 commit comments