Skip to content

Commit 11ef45b

Browse files
committed
Add wrapper for global results of subpop analysis
1 parent 6e273dc commit 11ef45b

File tree

1 file changed

+43
-1
lines changed

1 file changed

+43
-1
lines changed

dataikuapi/dss/ml.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -723,6 +723,44 @@ def get_partial_dependencies(self):
723723
return DSSPartialDependencies(data)
724724

725725

726+
class DSSSubpopulationGlobal(DSSExtensibleDict):
727+
"""
728+
Object to read details of performance on global dataset used for subpopulation analyses.
729+
730+
Do not create this object directly, use :meth:`DSSSubpopulationAnalyses.get_global()` instead
731+
"""
732+
733+
def __init__(self, data, prediction_type):
734+
super(DSSSubpopulationGlobal, self).__init__(data)
735+
self.prediction_type = prediction_type
736+
737+
def get_performance_metrics(self):
738+
"""
739+
Gets the performance results of the global dataset used for the subpopulation analysis
740+
"""
741+
return self.get("performanceMetrics")
742+
743+
def get_prediction_info(self):
744+
"""
745+
Gets the prediction info of the global dataset used for the subpopulation analysis
746+
"""
747+
global_metrics = self.get("perf").get("globalMetrics")
748+
if self.prediction_type == "BINARY_CLASSIFICATION":
749+
return {
750+
"predictedPositiveRatio": global_metrics["predictionAvg"][0],
751+
"actualPositiveRatio": global_metrics["targetAvg"][0],
752+
"testWeight": global_metrics["testWeight"]
753+
}
754+
elif self.prediction_type == "REGRESSION":
755+
return {
756+
"predictedAvg":global_metrics["predictionAvg"][0],
757+
"predictedStd":global_metrics["predictionStd"][0],
758+
"actualAvg":global_metrics["targetAvg"][0],
759+
"actualStd":global_metrics["targetStd"][0],
760+
"testWeight":global_metrics["testWeight"]
761+
}
762+
763+
726764
class DSSSubpopulationModality(DSSExtensibleDict):
727765
"""
728766
Object to read details of a subpopulation analysis modality
@@ -769,6 +807,9 @@ def get_performance_metrics(self):
769807
return self.get("performanceMetrics")
770808

771809
def get_prediction_info(self):
810+
"""
811+
Gets the prediction info of the modality
812+
"""
772813
if self.is_excluded():
773814
raise ValueError("Excluded modalities do not have prediction info")
774815
global_metrics = self.get("perf").get("globalMetrics")
@@ -925,6 +966,7 @@ class DSSSubpopulationAnalyses(DSSExtensibleDict):
925966

926967
def __init__(self, data, prediction_type):
927968
super(DSSSubpopulationAnalyses, self).__init__(data)
969+
self.prediction_type = prediction_type
928970
self.analyses = []
929971
for analysis in data.get("subpopulationAnalyses", []):
930972
self.analyses.append(DSSSubpopulationAnalysis(analysis, prediction_type))
@@ -939,7 +981,7 @@ def get_global(self):
939981
"""
940982
Retrieves information and performance on the full dataset used to compute the subpopulation analyses
941983
"""
942-
return self.get("global")
984+
return DSSSubpopulationGlobal(self.get("global"), self.prediction_type)
943985

944986
def list_analyses(self):
945987
"""

0 commit comments

Comments
 (0)