@@ -633,7 +633,8 @@ def compute_subpopulation_analyses(self, split_by, wait=True, sample_size=1000,
633633 )
634634 future = DSSFuture (self .saved_model .client , future_response .get ("jobId" , None ), future_response )
635635 if wait :
636- return DSSSubpopulationAnalyses (future .wait_for_result ())
636+ prediction_type = self .details .get ("coreParams" , {}).get ("prediction_type" )
637+ return DSSSubpopulationAnalyses (future .wait_for_result (), prediction_type )
637638 else :
638639 return future
639640
@@ -656,7 +657,8 @@ def get_subpopulation_analyses(self):
656657 "GET" , "/projects/%s/savedmodels/%s/versions/%s/subpopulation-analyses" %
657658 (self .saved_model .project_key , self .saved_model .sm_id , self .saved_model_version ),
658659 )
659- return DSSSubpopulationAnalyses (data )
660+ prediction_type = self .details .get ("coreParams" , {}).get ("prediction_type" )
661+ return DSSSubpopulationAnalyses (data , prediction_type )
660662
661663 def compute_partial_dependencies (self , features , wait = True , sample_size = 1000 , random_state = 1337 , n_jobs = 1 , debug_mode = False ):
662664 """
@@ -728,9 +730,10 @@ class DSSSubpopulationModality(DSSExtensibleDict):
728730 Do not create this object directly, use :meth:`DSSSubpopulationAnalysis.get_modality_data(definition)` instead
729731 """
730732
731- def __init__ (self , feature_name , computed_as_type , data ):
733+ def __init__ (self , feature_name , computed_as_type , data , prediction_type ):
732734 super (DSSSubpopulationModality , self ).__init__ (data )
733735
736+ self .prediction_type = prediction_type
734737 if computed_as_type == "CATEGORY" :
735738 self .definition = DSSSubpopulationCategoryModalityDefinition (feature_name , data )
736739 elif computed_as_type == "NUMERIC" :
@@ -757,13 +760,32 @@ def is_excluded(self):
757760 """
758761 return self .get ("excluded" , False )
759762
760- def get_perf (self ):
763+ def get_performance_metrics (self ):
761764 """
762765 Gets the performance results of the modality
763766 """
764767 if self .is_excluded ():
765- raise ValueError ("Excluded modalities do not have perf" )
766- return self .get ("perf" )
768+ raise ValueError ("Excluded modalities do not have performance metrics" )
769+ return self .get ("performanceMetrics" )
770+
771+ def get_prediction_info (self ):
772+ if self .is_excluded ():
773+ raise ValueError ("Excluded modalities do not have prediction info" )
774+ global_metrics = self .get ("perf" ).get ("globalMetrics" )
775+ if self .prediction_type == "BINARY_CLASSIFICATION" :
776+ return {
777+ "predictedPositiveRatio" : global_metrics ["predictionAvg" ][0 ],
778+ "actualPositiveRatio" : global_metrics ["targetAvg" ][0 ],
779+ "testWeight" : global_metrics ["testWeight" ]
780+ }
781+ elif self .prediction_type == "REGRESSION" :
782+ return {
783+ "predictedAvg" :global_metrics ["predictionAvg" ][0 ],
784+ "predictedStd" :global_metrics ["predictionStd" ][0 ],
785+ "actualAvg" :global_metrics ["targetAvg" ][0 ],
786+ "actualStd" :global_metrics ["targetStd" ][0 ],
787+ "testWeight" :global_metrics ["testWeight" ]
788+ }
767789
768790
769791class DSSSubpopulationModalityDefinition (object ):
@@ -834,10 +856,10 @@ class DSSSubpopulationAnalysis(DSSExtensibleDict):
834856 Do not create this object directly, use :meth:`DSSSubpopulationAnalyses.get_analysis(feature)` instead
835857 """
836858
837- def __init__ (self , analysis ):
859+ def __init__ (self , analysis , prediction_type ):
838860 super (DSSSubpopulationAnalysis , self ).__init__ (analysis )
839861 self .computed_as_type = self .get ("computed_as_type" )
840- self .modalities = [DSSSubpopulationModality (analysis .get ("feature" ), self .computed_as_type , m ) for m in self .get ("modalities" , [])]
862+ self .modalities = [DSSSubpopulationModality (analysis .get ("feature" ), self .computed_as_type , m , prediction_type ) for m in self .get ("modalities" , [])]
841863
842864 def get_computation_params (self ):
843865 """
@@ -901,11 +923,11 @@ class DSSSubpopulationAnalyses(DSSExtensibleDict):
901923 Do not create this object directly, use :meth:`DSSTrainedPredictionModelDetails.get_subpopulation_analyses()` instead
902924 """
903925
904- def __init__ (self , data ):
926+ def __init__ (self , data , prediction_type ):
905927 super (DSSSubpopulationAnalyses , self ).__init__ (data )
906928 self .analyses = []
907929 for analysis in data .get ("subpopulationAnalyses" , []):
908- self .analyses .append (DSSSubpopulationAnalysis (analysis ))
930+ self .analyses .append (DSSSubpopulationAnalysis (analysis , prediction_type ))
909931
910932 def get_raw (self ):
911933 """
0 commit comments