11from ..utils import DataikuException
22from ..utils import DataikuUTF8CSVReader
33from ..utils import DataikuStreamedHttpUTF8CSVReader
4+ from ..utils import DSSExtendableDict
45import json
56import time
67from .metrics import ComputedMetrics
@@ -606,8 +607,8 @@ def compute_subpopulation_analyses(self, split_by, wait=True, sample_size=1000,
606607 :param int n_jobs: number of cores used for parallel training. (-1 means 'all cores')
607608 :param bool debug_mode: if True, output all logs (slower)
608609
609- :returns: if wait is True, a dict containing the Subpopulation analyses, else a future to wait on the result
610- :rtype: dict or :class:`dataikuapi.dss.future.DSSFuture`
610+ :returns: if wait is True, an object containing the Subpopulation analyses, else a future to wait on the result
611+ :rtype: :class:`dataikuapi.dss.ml.DSSSubpopulationAnalyses` or :class:`dataikuapi.dss.future.DSSFuture`
611612 """
612613
613614 body = {
@@ -633,26 +634,30 @@ def compute_subpopulation_analyses(self, split_by, wait=True, sample_size=1000,
633634 )
634635 future = DSSFuture (self .saved_model .client , future_response .get ("jobId" , None ), future_response )
635636 if wait :
636- return future .wait_for_result ()
637+ return DSSSubpopulationAnalyses ( future .wait_for_result () )
637638 else :
638639 return future
639640
640641
def get_subpopulation_analyses(self):
    """
    Fetch every subpopulation analysis computed for this trained model

    :returns: the subpopulation analyses
    :rtype: :class:`dataikuapi.dss.ml.DSSSubpopulationAnalyses`
    """
    if self.mltask is None:
        # No ML task attached: address the model through its saved model version
        client = self.saved_model.client
        url = "/projects/%s/savedmodels/%s/versions/%s/subpopulation-analyses" % (
            self.saved_model.project_key,
            self.saved_model.sm_id,
            self.saved_model_version,
        )
    else:
        # Model attached to an ML task: address it through the task's identifiers
        client = self.mltask.client
        url = "/projects/%s/models/lab/%s/%s/models/%s/subpopulation-analyses" % (
            self.mltask.project_key,
            self.mltask.analysis_id,
            self.mltask.mltask_id,
            self.mltask_model_id,
        )

    return DSSSubpopulationAnalyses(client._perform_json("GET", url))
656661
657662 def compute_partial_dependencies (self , features , wait = True , sample_size = 1000 , random_state = 1337 , n_jobs = 1 , debug_mode = False ):
658663 """
@@ -712,6 +717,74 @@ def get_partial_dependencies(self):
712717 (self .saved_model .project_key , self .saved_model .sm_id , self .saved_model_version ),
713718 )
714719
class DSSSubpopulationAnalysis(DSSExtendableDict):
    """
    Object exposing the details of a single subpopulation analysis of a trained model

    Do not create this object directly, use :meth:`DSSSubpopulationAnalyses.get_analysis` instead
    """

    def __init__(self, analysis):
        super(DSSSubpopulationAnalysis, self).__init__(analysis)

    def get_computation_params(self):
        """
        Gets the computation params of this analysis

        :returns: a dict holding the "nbRecords", "randomState" and "onSample" entries,
                  each one looked up on this object with ``get``
        :rtype: dict
        """
        param_names = ("nbRecords", "randomState", "onSample")
        return {name: self.get(name) for name in param_names}

    def get_raw(self):
        """
        Gets the raw dictionary of the subpopulation analysis

        :returns: the ``internal_dict`` attribute of this object
                  (presumably set up by the ``DSSExtendableDict`` base class -- to be confirmed)
        """
        return self.internal_dict
745+
746+
class DSSSubpopulationAnalyses(DSSExtendableDict):
    """
    Object to read details of subpopulation analyses of a trained model

    Do not create this object directly, use :meth:`DSSTrainedPredictionModelDetails.get_subpopulation_analyses()` instead
    """

    def __init__(self, data):
        super(DSSSubpopulationAnalyses, self).__init__(data)
        # One wrapper per entry of "subpopulationAnalyses"; empty list when the key is absent
        self.analyses = [
            DSSSubpopulationAnalysis(analysis)
            for analysis in data.get("subpopulationAnalyses", [])
        ]

    def get_raw(self):
        """
        Gets the raw dictionary of subpopulation analyses

        :returns: the ``internal_dict`` attribute of this object
                  (presumably set up by the ``DSSExtendableDict`` base class -- to be confirmed)
        """
        return self.internal_dict

    def get_all_dataset(self):
        """
        Retrieve information and performance on the full dataset used to compute the subpopulation analyses

        :returns: the "allDataset" entry of the subpopulation analyses
        """
        return self.get("allDataset")

    def list_analyses(self):
        """
        Lists all features on which subpopulation analyses have been computed

        :returns: the "feature" entry of each analysis, in the order the analyses are stored
        :rtype: list
        """
        return [analysis["feature"] for analysis in self.analyses]

    def get_analysis(self, feature):
        """
        Retrieves the subpopulation analysis for a particular feature

        :param feature: the feature whose analysis is requested
        :returns: the first stored analysis whose "feature" entry equals ``feature``
        :rtype: :class:`dataikuapi.dss.ml.DSSSubpopulationAnalysis`
        :raises ValueError: if no analysis has been computed on ``feature``
        """
        # Single pass: return on the first match instead of building the full
        # feature list for a membership test and then scanning a second time.
        for analysis in self.analyses:
            if analysis["feature"] == feature:
                return analysis

        raise ValueError("Subpopulation analysis for feature '%s' cannot be found" % feature)
786+
787+
715788class DSSClustersFacts (object ):
716789 def __init__ (self , clusters_facts ):
717790 self .clusters_facts = clusters_facts
0 commit comments