Skip to content

Commit ff85303

Browse files
committed
Add wrapper for partial dependencies results
1 parent 98f816f commit ff85303

File tree

1 file changed

+73
-6
lines changed

1 file changed

+73
-6
lines changed

dataikuapi/dss/ml.py

Lines changed: 73 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -670,8 +670,8 @@ def compute_partial_dependencies(self, features, wait=True, sample_size=1000, ra
670670
:param int n_jobs: number of cores used for parallel training. (-1 means 'all cores')
671671
:param bool debug_mode: if True, output all logs (slower)
672672
673-
:returns: if wait is True, a dict containing the Partial dependencies, else a future to wait on the result
674-
:rtype: dict or :class:`dataikuapi.dss.future.DSSFuture`
673+
:returns: if wait is True, an object containing the Partial dependencies, else a future to wait on the result
674+
:rtype: :class:`dataikuapi.dss.ml.DSSPartialDependencies` or :class:`dataikuapi.dss.future.DSSFuture`
675675
"""
676676

677677
body = {
@@ -697,25 +697,30 @@ def compute_partial_dependencies(self, features, wait=True, sample_size=1000, ra
697697
)
698698
future = DSSFuture(self.saved_model.client, future_response.get("jobId", None), future_response)
699699
if wait:
700-
return future.wait_for_result()
700+
return DSSPartialDependencies(future.wait_for_result())
701701
else:
702702
return future
703703

704704
def get_partial_dependencies(self):
705705
"""
706-
Retrieve all partial dependencies computed for this trained model as a dict
706+
Retrieve all partial dependencies computed for this trained model
707+
708+
:returns: the partial dependencies
709+
:rtype: :class:`dataikuapi.dss.ml.DSSPartialDependencies`
707710
"""
708711

709712
if self.mltask is not None:
710-
return self.mltask.client._perform_json(
713+
data = self.mltask.client._perform_json(
711714
"GET", "/projects/%s/models/lab/%s/%s/models/%s/partial-dependencies" %
712715
(self.mltask.project_key, self.mltask.analysis_id, self.mltask.mltask_id, self.mltask_model_id)
713716
)
714717
else:
715-
return self.saved_model.client._perform_json(
718+
data = self.saved_model.client._perform_json(
716719
"GET", "/projects/%s/savedmodels/%s/versions/%s/partial-dependencies" %
717720
(self.saved_model.project_key, self.saved_model.sm_id, self.saved_model_version),
718721
)
722+
return DSSPartialDependencies(data)
723+
719724

720725
class DSSSubpopulationAnalysis(DSSExtendableDict):
721726
"""
@@ -785,6 +790,68 @@ def get_analysis(self, feature):
785790
return next(analysis for analysis in self.analyses if analysis["feature"] == feature)
786791

787792

793+
class DSSPartialDependence(DSSExtendableDict):
794+
"""
795+
Object to read details of partial dependence of a trained model
796+
797+
Do not create this object directly, use :meth:`DSSPartialDependencies.get_partial_dependence(feature)` instead
798+
"""
799+
800+
def __init__(self, data):
801+
super(DSSPartialDependence, self).__init__(data)
802+
803+
def get_computation_params(self):
804+
"""
805+
Gets computation params
806+
"""
807+
computation_params = {}
808+
computation_params["nbRecords"] = self.get("nbRecords")
809+
computation_params["randomState"] = self.get("randomState")
810+
computation_params["onSample"] = self.get("onSample")
811+
return computation_params
812+
813+
def get_raw(self):
814+
"""
815+
Gets the raw dictionary of the partial dependence
816+
"""
817+
return self.internal_dict
818+
819+
820+
class DSSPartialDependencies(DSSExtendableDict):
821+
"""
822+
Object to read details of partial dependencies of a trained model
823+
824+
Do not create this object directly, use :meth:`DSSTrainedPredictionModelDetails.get_partial_dependencies()` instead
825+
"""
826+
827+
def __init__(self, data):
828+
super(DSSPartialDependencies, self).__init__(data)
829+
self.partial_dependencies = []
830+
for pd in data.get("partialDependencies", []):
831+
self.partial_dependencies.append(DSSPartialDependence(pd))
832+
833+
def get_raw(self):
834+
"""
835+
Gets the raw dictionary of partial dependencies
836+
"""
837+
return self.internal_dict
838+
839+
def list_partial_dependencies(self):
840+
"""
841+
Lists all features on which partial dependencies have been computed
842+
"""
843+
return [partial_dep["feature"] for partial_dep in self.partial_dependencies]
844+
845+
def get_partial_dependence(self, feature):
846+
"""
847+
Retrieves the subpopulation analysis for a particular feature
848+
"""
849+
if feature not in self.list_partial_dependencies():
850+
raise ValueError("Partial dependence for feature '%s' cannot be found" % feature)
851+
852+
return next(pd for pd in self.partial_dependencies if pd["feature"] == feature)
853+
854+
788855
class DSSClustersFacts(object):
789856
def __init__(self, clusters_facts):
790857
self.clusters_facts = clusters_facts

0 commit comments

Comments
 (0)