Skip to content

Commit 98f816f

Browse files
committed
Implement wrappers for subpopulation results
1 parent 741ecf3 commit 98f816f

File tree

2 files changed

+141
-6
lines changed

2 files changed

+141
-6
lines changed

dataikuapi/dss/ml.py

Lines changed: 79 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from ..utils import DataikuException
22
from ..utils import DataikuUTF8CSVReader
33
from ..utils import DataikuStreamedHttpUTF8CSVReader
4+
from ..utils import DSSExtendableDict
45
import json
56
import time
67
from .metrics import ComputedMetrics
@@ -606,8 +607,8 @@ def compute_subpopulation_analyses(self, split_by, wait=True, sample_size=1000,
606607
:param int n_jobs: number of cores used for parallel training. (-1 means 'all cores')
607608
:param bool debug_mode: if True, output all logs (slower)
608609
609-
:returns: if wait is True, a dict containing the Subpopulation analyses, else a future to wait on the result
610-
:rtype: dict or :class:`dataikuapi.dss.future.DSSFuture`
610+
:returns: if wait is True, an object containing the Subpopulation analyses, else a future to wait on the result
611+
:rtype: :class:`dataikuapi.dss.ml.DSSSubpopulationAnalyses` or :class:`dataikuapi.dss.future.DSSFuture`
611612
"""
612613

613614
body = {
@@ -633,26 +634,30 @@ def compute_subpopulation_analyses(self, split_by, wait=True, sample_size=1000,
633634
)
634635
future = DSSFuture(self.saved_model.client, future_response.get("jobId", None), future_response)
635636
if wait:
636-
return future.wait_for_result()
637+
return DSSSubpopulationAnalyses(future.wait_for_result())
637638
else:
638639
return future
639640

640641

641642
def get_subpopulation_analyses(self):
642643
"""
643-
Retrieve all subpopulation analyses computed for this trained model as a dict
644+
Retrieve all subpopulation analyses computed for this trained model
645+
646+
:returns: the subpopulation analyses
647+
:rtype: :class:`dataikuapi.dss.ml.DSSSubpopulationAnalyses`
644648
"""
645649

646650
if self.mltask is not None:
647-
return self.mltask.client._perform_json(
651+
data = self.mltask.client._perform_json(
648652
"GET", "/projects/%s/models/lab/%s/%s/models/%s/subpopulation-analyses" %
649653
(self.mltask.project_key, self.mltask.analysis_id, self.mltask.mltask_id, self.mltask_model_id)
650654
)
651655
else:
652-
return self.saved_model.client._perform_json(
656+
data = self.saved_model.client._perform_json(
653657
"GET", "/projects/%s/savedmodels/%s/versions/%s/subpopulation-analyses" %
654658
(self.saved_model.project_key, self.saved_model.sm_id, self.saved_model_version),
655659
)
660+
return DSSSubpopulationAnalyses(data)
656661

657662
def compute_partial_dependencies(self, features, wait=True, sample_size=1000, random_state=1337, n_jobs=1, debug_mode=False):
658663
"""
@@ -712,6 +717,74 @@ def get_partial_dependencies(self):
712717
(self.saved_model.project_key, self.saved_model.sm_id, self.saved_model_version),
713718
)
714719

720+
class DSSSubpopulationAnalysis(DSSExtendableDict):
    """
    Dict-like view over a single subpopulation analysis of a trained model

    Do not create this object directly, use :meth:`DSSSubpopulationAnalyses.get_analysis(feature)` instead
    """

    def __init__(self, analysis):
        """
        :param dict analysis: the raw dictionary describing the analysis
        """
        super(DSSSubpopulationAnalysis, self).__init__(analysis)

    def get_computation_params(self):
        """
        Gets computation params

        :returns: a dict holding the ``nbRecords``, ``randomState`` and ``onSample`` entries
            of this analysis (None for an entry that is absent)
        :rtype: dict
        """
        param_names = ("nbRecords", "randomState", "onSample")
        return dict((name, self.get(name)) for name in param_names)

    def get_raw(self):
        """
        Gets the raw dictionary of the subpopulation analysis

        :returns: the wrapped dictionary itself (not a copy)
        :rtype: dict
        """
        return self.internal_dict
745+
746+
747+
class DSSSubpopulationAnalyses(DSSExtendableDict):
    """
    Object to read details of subpopulation analyses of a trained model

    Do not create this object directly, use :meth:`DSSTrainedPredictionModelDetails.get_subpopulation_analyses()` instead
    """

    def __init__(self, data):
        """
        :param dict data: raw dictionary of the subpopulation analyses. Its ``subpopulationAnalyses``
            entry, when present, is a list with one raw dictionary per analysis. None is accepted
            and yields an object without any analysis
        """
        super(DSSSubpopulationAnalyses, self).__init__(data)
        # Go through self.get() rather than data.get(): the base class replaces a None payload
        # by an empty dict, and "or []" also covers an entry explicitly set to None
        self.analyses = [
            DSSSubpopulationAnalysis(analysis)
            for analysis in (self.get("subpopulationAnalyses") or [])
        ]

    def get_raw(self):
        """
        Gets the raw dictionary of subpopulation analyses

        :returns: the wrapped dictionary itself (not a copy)
        :rtype: dict
        """
        return self.internal_dict

    def get_all_dataset(self):
        """
        Retrieve information and performance on the full dataset used to compute the subpopulation analyses

        :returns: the ``allDataset`` entry of the raw dictionary, or None if there is no such entry
        """
        return self.get("allDataset")

    def list_analyses(self):
        """
        Lists all features on which subpopulation analyses have been computed

        :returns: the ``feature`` entry of each analysis, in the order of the analyses
        :rtype: list
        """
        return [analysis["feature"] for analysis in self.analyses]

    def get_analysis(self, feature):
        """
        Retrieves the subpopulation analysis for a particular feature

        :param feature: value of the ``feature`` entry of the wanted analysis
        :returns: the first analysis computed on this feature
        :rtype: :class:`DSSSubpopulationAnalysis`
        :raises ValueError: if no analysis matches the feature
        """
        # Single pass: return on the first match instead of checking membership and then searching again
        for analysis in self.analyses:
            if analysis["feature"] == feature:
                return analysis
        raise ValueError("Subpopulation analysis for feature '%s' cannot be found" % feature)
786+
787+
715788
class DSSClustersFacts(object):
716789
def __init__(self, clusters_facts):
717790
self.clusters_facts = clusters_facts

dataikuapi/utils.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,3 +94,65 @@ def str_to_bool(s):
9494
doublequote=True):
9595
yield [none_if_throws(caster)(val)
9696
for (caster, val) in dku_zip_longest(casters, uncasted_tuple)]
97+
98+
class DSSExtendableDict(dict):
    """
    Dict-like base class whose content lives in a wrapped dictionary

    The mapping operations overridden here are forwarded to ``internal_dict``. A dictionary
    passed to the constructor is wrapped as is (not copied), so a change made through this
    object is visible in that dictionary, and the other way round.

    NOTE(review): the storage inherited from ``dict`` is never filled. Code that bypasses
    these overrides and reads that storage directly would see an empty mapping -- confirm
    before handing instances to such code.
    """

    def __init__(self, orig_dict=None):
        """
        :param dict orig_dict: dictionary to wrap; a new empty dictionary is used when None
        """
        super(DSSExtendableDict, self).__init__()
        if orig_dict is None:
            self.internal_dict = dict()
        else:
            self.internal_dict = orig_dict

    def __getitem__(self, key):
        return self.internal_dict[key]

    def __setitem__(self, key, value):
        self.internal_dict[key] = value

    def __delitem__(self, key):
        # Without this override, "del obj[key]" targets the empty inherited storage and raises KeyError
        del self.internal_dict[key]

    def __iter__(self):
        return iter(self.internal_dict)

    def __len__(self):
        return len(self.internal_dict)

    def __contains__(self, key):
        return key in self.internal_dict

    def __str__(self):
        return str(self.internal_dict)

    def __repr__(self):
        # Without this override, repr() shows the empty inherited storage, i.e. always "{}"
        return repr(self.internal_dict)

    def __eq__(self, other):
        # Compare the wrapped content; the inherited comparison only looks at the empty inherited storage
        if isinstance(other, DSSExtendableDict):
            other = other.internal_dict
        return self.internal_dict == other

    def __ne__(self, other):
        # Spelled out because Python 2 does not derive __ne__ from __eq__
        return not self.__eq__(other)

    def clear(self):
        self.internal_dict.clear()

    def copy(self):
        """
        :returns: a shallow copy of the wrapped dictionary (not a :class:`DSSExtendableDict`)
        """
        return self.internal_dict.copy()

    @classmethod
    def fromkeys(cls, sequence, value=None):
        """
        Builds a plain dict whose keys come from sequence, all mapped to value

        Declared as a classmethod, like ``dict.fromkeys``, so that it can be called on the
        class as well as on an instance.

        :rtype: dict
        """
        return dict.fromkeys(sequence, value)

    def get(self, key, value=None):
        return self.internal_dict.get(key, value)

    def items(self):
        return self.internal_dict.items()

    def keys(self):
        return self.internal_dict.keys()

    def values(self):
        return self.internal_dict.values()

    def popitem(self):
        return self.internal_dict.popitem()

    def pop(self, key, *argv):
        return self.internal_dict.pop(key, *argv)

    def setdefault(self, key, default_value=None):
        return self.internal_dict.setdefault(key, default_value)

    def update(self, *args, **kwargs):
        # A DSSExtendableDict argument is unwrapped so that its wrapped content is what gets merged
        if len(args) == 1 and isinstance(args[0], DSSExtendableDict):
            self.internal_dict.update(args[0].internal_dict, **kwargs)
        else:
            self.internal_dict.update(*args, **kwargs)

0 commit comments

Comments
 (0)