Skip to content

Commit 6e273dc

Browse files
committed
Add performance metrics and info for modalities of subpop
1 parent 2c11d20 commit 6e273dc

File tree

1 file changed

+32
-10
lines changed

1 file changed

+32
-10
lines changed

dataikuapi/dss/ml.py

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,8 @@ def compute_subpopulation_analyses(self, split_by, wait=True, sample_size=1000,
633633
)
634634
future = DSSFuture(self.saved_model.client, future_response.get("jobId", None), future_response)
635635
if wait:
636-
return DSSSubpopulationAnalyses(future.wait_for_result())
636+
prediction_type = self.details.get("coreParams", {}).get("prediction_type")
637+
return DSSSubpopulationAnalyses(future.wait_for_result(), prediction_type)
637638
else:
638639
return future
639640

@@ -656,7 +657,8 @@ def get_subpopulation_analyses(self):
656657
"GET", "/projects/%s/savedmodels/%s/versions/%s/subpopulation-analyses" %
657658
(self.saved_model.project_key, self.saved_model.sm_id, self.saved_model_version),
658659
)
659-
return DSSSubpopulationAnalyses(data)
660+
prediction_type = self.details.get("coreParams", {}).get("prediction_type")
661+
return DSSSubpopulationAnalyses(data, prediction_type)
660662

661663
def compute_partial_dependencies(self, features, wait=True, sample_size=1000, random_state=1337, n_jobs=1, debug_mode=False):
662664
"""
@@ -728,9 +730,10 @@ class DSSSubpopulationModality(DSSExtensibleDict):
728730
Do not create this object directly, use :meth:`DSSSubpopulationAnalysis.get_modality_data(definition)` instead
729731
"""
730732

731-
def __init__(self, feature_name, computed_as_type, data):
733+
def __init__(self, feature_name, computed_as_type, data, prediction_type):
732734
super(DSSSubpopulationModality, self).__init__(data)
733735

736+
self.prediction_type = prediction_type
734737
if computed_as_type == "CATEGORY":
735738
self.definition = DSSSubpopulationCategoryModalityDefinition(feature_name, data)
736739
elif computed_as_type == "NUMERIC":
@@ -757,13 +760,32 @@ def is_excluded(self):
757760
"""
758761
return self.get("excluded", False)
759762

760-
def get_perf(self):
763+
def get_performance_metrics(self):
761764
"""
762765
Gets the performance results of the modality
763766
"""
764767
if self.is_excluded():
765-
raise ValueError("Excluded modalities do not have perf")
766-
return self.get("perf")
768+
raise ValueError("Excluded modalities do not have performance metrics")
769+
return self.get("performanceMetrics")
770+
771+
def get_prediction_info(self):
772+
if self.is_excluded():
773+
raise ValueError("Excluded modalities do not have prediction info")
774+
global_metrics = self.get("perf").get("globalMetrics")
775+
if self.prediction_type == "BINARY_CLASSIFICATION":
776+
return {
777+
"predictedPositiveRatio": global_metrics["predictionAvg"][0],
778+
"actualPositiveRatio": global_metrics["targetAvg"][0],
779+
"testWeight": global_metrics["testWeight"]
780+
}
781+
elif self.prediction_type == "REGRESSION":
782+
return {
783+
"predictedAvg":global_metrics["predictionAvg"][0],
784+
"predictedStd":global_metrics["predictionStd"][0],
785+
"actualAvg":global_metrics["targetAvg"][0],
786+
"actualStd":global_metrics["targetStd"][0],
787+
"testWeight":global_metrics["testWeight"]
788+
}
767789

768790

769791
class DSSSubpopulationModalityDefinition(object):
@@ -834,10 +856,10 @@ class DSSSubpopulationAnalysis(DSSExtensibleDict):
834856
Do not create this object directly, use :meth:`DSSSubpopulationAnalyses.get_analysis(feature)` instead
835857
"""
836858

837-
def __init__(self, analysis):
859+
def __init__(self, analysis, prediction_type):
838860
super(DSSSubpopulationAnalysis, self).__init__(analysis)
839861
self.computed_as_type = self.get("computed_as_type")
840-
self.modalities = [DSSSubpopulationModality(analysis.get("feature"), self.computed_as_type, m) for m in self.get("modalities", [])]
862+
self.modalities = [DSSSubpopulationModality(analysis.get("feature"), self.computed_as_type, m, prediction_type) for m in self.get("modalities", [])]
841863

842864
def get_computation_params(self):
843865
"""
@@ -901,11 +923,11 @@ class DSSSubpopulationAnalyses(DSSExtensibleDict):
901923
Do not create this object directly, use :meth:`DSSTrainedPredictionModelDetails.get_subpopulation_analyses()` instead
902924
"""
903925

904-
def __init__(self, data):
926+
def __init__(self, data, prediction_type):
905927
super(DSSSubpopulationAnalyses, self).__init__(data)
906928
self.analyses = []
907929
for analysis in data.get("subpopulationAnalyses", []):
908-
self.analyses.append(DSSSubpopulationAnalysis(analysis))
930+
self.analyses.append(DSSSubpopulationAnalysis(analysis, prediction_type))
909931

910932
def get_raw(self):
911933
"""

0 commit comments

Comments
 (0)