@@ -723,6 +723,44 @@ def get_partial_dependencies(self):
723723 return DSSPartialDependencies (data )
724724
725725
726+ class DSSSubpopulationGlobal (DSSExtensibleDict ):
727+ """
728+ Object to read details of performance on global dataset used for subpopulation analyses.
729+
730+ Do not create this object directly, use :meth:`DSSSubpopulationAnalyses.get_global()` instead
731+ """
732+
733+ def __init__ (self , data , prediction_type ):
734+ super (DSSSubpopulationGlobal , self ).__init__ (data )
735+ self .prediction_type = prediction_type
736+
737+ def get_performance_metrics (self ):
738+ """
739+ Gets the performance results of the global dataset used for the subpopulation analysis
740+ """
741+ return self .get ("performanceMetrics" )
742+
743+ def get_prediction_info (self ):
744+ """
745+ Gets the prediction info of the global dataset used for the subpopulation analysis
746+ """
747+ global_metrics = self .get ("perf" ).get ("globalMetrics" )
748+ if self .prediction_type == "BINARY_CLASSIFICATION" :
749+ return {
750+ "predictedPositiveRatio" : global_metrics ["predictionAvg" ][0 ],
751+ "actualPositiveRatio" : global_metrics ["targetAvg" ][0 ],
752+ "testWeight" : global_metrics ["testWeight" ]
753+ }
754+ elif self .prediction_type == "REGRESSION" :
755+ return {
756+ "predictedAvg" :global_metrics ["predictionAvg" ][0 ],
757+ "predictedStd" :global_metrics ["predictionStd" ][0 ],
758+ "actualAvg" :global_metrics ["targetAvg" ][0 ],
759+ "actualStd" :global_metrics ["targetStd" ][0 ],
760+ "testWeight" :global_metrics ["testWeight" ]
761+ }
762+
763+
726764class DSSSubpopulationModality (DSSExtensibleDict ):
727765 """
728766 Object to read details of a subpopulation analysis modality
@@ -769,6 +807,9 @@ def get_performance_metrics(self):
769807 return self .get ("performanceMetrics" )
770808
771809 def get_prediction_info (self ):
810+ """
811+ Gets the prediction info of the modality
812+ """
772813 if self .is_excluded ():
773814 raise ValueError ("Excluded modalities do not have prediction info" )
774815 global_metrics = self .get ("perf" ).get ("globalMetrics" )
@@ -925,6 +966,7 @@ class DSSSubpopulationAnalyses(DSSExtensibleDict):
925966
926967 def __init__ (self , data , prediction_type ):
927968 super (DSSSubpopulationAnalyses , self ).__init__ (data )
969+ self .prediction_type = prediction_type
928970 self .analyses = []
929971 for analysis in data .get ("subpopulationAnalyses" , []):
930972 self .analyses .append (DSSSubpopulationAnalysis (analysis , prediction_type ))
@@ -939,7 +981,7 @@ def get_global(self):
939981 """
940982 Retrieves information and performance on the full dataset used to compute the subpopulation analyses
941983 """
942- return self .get ("global" )
984+ return DSSSubpopulationGlobal ( self .get ("global" ), self . prediction_type )
943985
944986 def list_analyses (self ):
945987 """
0 commit comments