@@ -670,8 +670,8 @@ def compute_partial_dependencies(self, features, wait=True, sample_size=1000, ra
670670 :param int n_jobs: number of cores used for parallel training. (-1 means 'all cores')
671671 :param bool debug_mode: if True, output all logs (slower)
672672
673- :returns: if wait is True, a dict containing the Partial dependencies, else a future to wait on the result
674- :rtype: dict or :class:`dataikuapi.dss.future.DSSFuture`
673+ :returns: if wait is True, an object containing the Partial dependencies, else a future to wait on the result
674+ :rtype: :class:`dataikuapi.dss.ml.DSSPartialDependencies` or :class:`dataikuapi.dss.future.DSSFuture`
675675 """
676676
677677 body = {
@@ -697,25 +697,30 @@ def compute_partial_dependencies(self, features, wait=True, sample_size=1000, ra
697697 )
698698 future = DSSFuture (self .saved_model .client , future_response .get ("jobId" , None ), future_response )
699699 if wait :
700- return future .wait_for_result ()
700+ return DSSPartialDependencies ( future .wait_for_result ())
701701 else :
702702 return future
703703
704704 def get_partial_dependencies (self ):
705705 """
706- Retrieve all partial dependencies computed for this trained model as a dict
706+ Retrieve all partial dependencies computed for this trained model
707+
708+ :returns: the partial dependencies
709+ :rtype: :class:`dataikuapi.dss.ml.DSSPartialDependencies`
707710 """
708711
709712 if self .mltask is not None :
710- return self .mltask .client ._perform_json (
713+ data = self .mltask .client ._perform_json (
711714 "GET" , "/projects/%s/models/lab/%s/%s/models/%s/partial-dependencies" %
712715 (self .mltask .project_key , self .mltask .analysis_id , self .mltask .mltask_id , self .mltask_model_id )
713716 )
714717 else :
715- return self .saved_model .client ._perform_json (
718+ data = self .saved_model .client ._perform_json (
716719 "GET" , "/projects/%s/savedmodels/%s/versions/%s/partial-dependencies" %
717720 (self .saved_model .project_key , self .saved_model .sm_id , self .saved_model_version ),
718721 )
722+ return DSSPartialDependencies (data )
723+
719724
720725class DSSSubpopulationAnalysis (DSSExtendableDict ):
721726 """
@@ -785,6 +790,68 @@ def get_analysis(self, feature):
785790 return next (analysis for analysis in self .analyses if analysis ["feature" ] == feature )
786791
787792
793+ class DSSPartialDependence (DSSExtendableDict ):
794+ """
795+ Object to read details of partial dependence of a trained model
796+
797+ Do not create this object directly, use :meth:`DSSPartialDependencies.get_partial_dependence(feature)` instead
798+ """
799+
800+ def __init__ (self , data ):
801+ super (DSSPartialDependence , self ).__init__ (data )
802+
803+ def get_computation_params (self ):
804+ """
805+ Gets computation params
806+ """
807+ computation_params = {}
808+ computation_params ["nbRecords" ] = self .get ("nbRecords" )
809+ computation_params ["randomState" ] = self .get ("randomState" )
810+ computation_params ["onSample" ] = self .get ("onSample" )
811+ return computation_params
812+
813+ def get_raw (self ):
814+ """
815+ Gets the raw dictionary of the partial dependence
816+ """
817+ return self .internal_dict
818+
819+
820+ class DSSPartialDependencies (DSSExtendableDict ):
821+ """
822+ Object to read details of partial dependencies of a trained model
823+
824+ Do not create this object directly, use :meth:`DSSTrainedPredictionModelDetails.get_partial_dependencies()` instead
825+ """
826+
827+ def __init__ (self , data ):
828+ super (DSSPartialDependencies , self ).__init__ (data )
829+ self .partial_dependencies = []
830+ for pd in data .get ("partialDependencies" , []):
831+ self .partial_dependencies .append (DSSPartialDependence (pd ))
832+
833+ def get_raw (self ):
834+ """
835+ Gets the raw dictionary of partial dependencies
836+ """
837+ return self .internal_dict
838+
839+ def list_partial_dependencies (self ):
840+ """
841+ Lists all features on which partial dependencies have been computed
842+ """
843+ return [partial_dep ["feature" ] for partial_dep in self .partial_dependencies ]
844+
845+ def get_partial_dependence (self , feature ):
846+ """
847+ Retrieves the subpopulation analysis for a particular feature
848+ """
849+ if feature not in self .list_partial_dependencies ():
850+ raise ValueError ("Partial dependence for feature '%s' cannot be found" % feature )
851+
852+ return next (pd for pd in self .partial_dependencies if pd ["feature" ] == feature )
853+
854+
788855class DSSClustersFacts (object ):
789856 def __init__ (self , clusters_facts ):
790857 self .clusters_facts = clusters_facts
0 commit comments