Skip to content

Commit 2c11d20

Browse files
committed
Improvements and fixes in the subpopulation analysis and partial dependence (PDP) wrappers
1 parent 83d9bc1 commit 2c11d20

File tree

2 files changed

+45
-22
lines changed

2 files changed

+45
-22
lines changed

dataikuapi/dss/ml.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ def compute_subpopulation_analyses(self, split_by, wait=True, sample_size=1000,
600600
"""
601601
Launch computation of Subpopulation analyses for this trained model.
602602
603-
:param list split_by: columns on which subpopulation analyses are to be computed (one analysis per column)
603+
:param list|str split_by: column(s) on which subpopulation analyses are to be computed (one analysis per column)
604604
:param bool wait: if True, the call blocks until the computation is finished and returns the results directly
605605
:param int sample_size: number of records of the dataset to use for the computation
606606
:param int random_state: random state to use to build sample, for reproducibility
@@ -610,9 +610,8 @@ def compute_subpopulation_analyses(self, split_by, wait=True, sample_size=1000,
610610
:returns: if wait is True, an object containing the Subpopulation analyses, else a future to wait on the result
611611
:rtype: :class:`dataikuapi.dss.ml.DSSSubpopulationAnalyses` or :class:`dataikuapi.dss.future.DSSFuture`
612612
"""
613-
614613
body = {
615-
"features": split_by,
614+
"features": split_by if isinstance(split_by, list) else [split_by],
616615
"computationParams": {
617616
"sample_size": sample_size,
618617
"random_state": random_state,
@@ -663,7 +662,7 @@ def compute_partial_dependencies(self, features, wait=True, sample_size=1000, ra
663662
"""
664663
Launch computation of Partial dependencies for this trained model.
665664
666-
:param list features: features on which partial dependencies are to be computed
665+
:param list|str features: feature(s) on which partial dependencies are to be computed
667666
:param bool wait: if True, the call blocks until the computation is finished and returns the results directly
668667
:param int sample_size: number of records of the dataset to use for the computation
669668
:param int random_state: random state to use to build sample, for reproducibility
@@ -675,7 +674,7 @@ def compute_partial_dependencies(self, features, wait=True, sample_size=1000, ra
675674
"""
676675

677676
body = {
678-
"features": features,
677+
"features": features if isinstance(features, list) else [features],
679678
"computationParams": {
680679
"sample_size": sample_size,
681680
"random_state": random_state,
@@ -726,16 +725,16 @@ class DSSSubpopulationModality(DSSExtensibleDict):
726725
"""
727726
Object to read details of a subpopulation analysis modality
728727
729-
Do not create this object directly, use :meth:`DSSSubpopulationAnalysis.get_modality(definition)` instead
728+
Do not create this object directly, use :meth:`DSSSubpopulationAnalysis.get_modality_data(definition)` instead
730729
"""
731730

732-
def __init__(self, computed_as_type, data):
731+
def __init__(self, feature_name, computed_as_type, data):
733732
super(DSSSubpopulationModality, self).__init__(data)
734733

735734
if computed_as_type == "CATEGORY":
736-
self.definition = DSSSubpopulationCategoryModalityDefinition(data)
735+
self.definition = DSSSubpopulationCategoryModalityDefinition(feature_name, data)
737736
elif computed_as_type == "NUMERIC":
738-
self.definition = DSSSubpopulationNumericModalityDefinition(data)
737+
self.definition = DSSSubpopulationNumericModalityDefinition(feature_name, data)
739738

740739
def get_raw(self):
741740
"""
@@ -771,18 +770,19 @@ class DSSSubpopulationModalityDefinition(object):
771770

772771
MISSING_VALUES = "__DSSSubpopulationModalidityDefinition__MISSINGVALUES"
773772

774-
def __init__(self, data):
773+
def __init__(self, feature_name, data):
775774
self.missing_values = data.get("missing_values", False)
776775
self.index = data.get("index")
776+
self.feature_name = feature_name
777777

778778
def is_missing_values(self):
779779
return self.missing_values
780780

781781

782782
class DSSSubpopulationNumericModalityDefinition(DSSSubpopulationModalityDefinition):
783783

784-
def __init__(self, data):
785-
super(DSSSubpopulationNumericModalityDefinition, self).__init__(data)
784+
def __init__(self, feature_name, data):
785+
super(DSSSubpopulationNumericModalityDefinition, self).__init__(feature_name, data)
786786
self.lte = data.get("lte", None)
787787
self.gt = data.get("gt", None)
788788
self.gte = data.get("gte", None)
@@ -792,17 +792,40 @@ def contains(self, value):
792792
gt = self.gt if self.gt is not None else float("-inf")
793793
gte = self.gte if self.gte is not None else float("-inf")
794794
return not self.missing_values and gt < value and gte <= value and lte >= value
795+
796+
def __repr__(self):
797+
if self.missing_values:
798+
return "DSSSubpopulationNumericModalityDefinition(missing_values)"
799+
else:
800+
if self.gt is not None:
801+
repr_gt = "%s<" % self.gt
802+
elif self.gte is not None:
803+
repr_gt = "%s<=" % self.gte
804+
else:
805+
repr_gt = ""
795806

807+
if self.lte is not None:
808+
repr_lt = "<=%s" % self.lte
809+
else:
810+
repr_lt = ""
811+
812+
return "DSSSubpopulationNumericModalityDefinition(%s%s%s)" % (repr_gt, self.feature_name, repr_lt)
796813

797814
class DSSSubpopulationCategoryModalityDefinition(DSSSubpopulationModalityDefinition):
798815

799-
def __init__(self, data):
800-
super(DSSSubpopulationCategoryModalityDefinition, self).__init__(data)
816+
def __init__(self, feature_name, data):
817+
super(DSSSubpopulationCategoryModalityDefinition, self).__init__(feature_name, data)
801818
self.value = data.get("value", None)
802819

803820
def contains(self, value):
804821
return value == self.value
805822

823+
def __repr__(self):
824+
if self.missing_values:
825+
return "DSSSubpopulationCategoryModalityDefinition(missing_values)"
826+
else:
827+
return "DSSSubpopulationCategoryModalityDefinition(%s='%s')" % (self.feature_name, self.value)
828+
806829

807830
class DSSSubpopulationAnalysis(DSSExtensibleDict):
808831
"""
@@ -814,7 +837,7 @@ class DSSSubpopulationAnalysis(DSSExtensibleDict):
814837
def __init__(self, analysis):
815838
super(DSSSubpopulationAnalysis, self).__init__(analysis)
816839
self.computed_as_type = self.get("computed_as_type")
817-
self.modalities = [DSSSubpopulationModality(self.computed_as_type, m) for m in self.get("modalities", [])]
840+
self.modalities = [DSSSubpopulationModality(analysis.get("feature"), self.computed_as_type, m) for m in self.get("modalities", [])]
818841

819842
def get_computation_params(self):
820843
"""
@@ -832,7 +855,7 @@ def list_modalities(self):
832855
"""
833856
return [m.definition for m in self.modalities]
834857

835-
def get_modality(self, definition=None):
858+
def get_modality_data(self, definition=None):
836859
"""
837860
Retrieves modality from definition
838861
@@ -890,11 +913,11 @@ def get_raw(self):
890913
"""
891914
return self.internal_dict
892915

893-
def get_all_dataset(self):
916+
def get_global(self):
894917
"""
895-
Retrieve information and performance on the full dataset used to compute the subpopulation analyses
918+
Retrieves information and performance on the full dataset used to compute the subpopulation analyses
896919
"""
897-
return self.get("allDataset")
920+
return self.get("global")
898921

899922
def list_analyses(self):
900923
"""
@@ -958,7 +981,7 @@ def get_raw(self):
958981
"""
959982
return self.internal_dict
960983

961-
def list_partial_dependencies(self):
984+
def list_features(self):
962985
"""
963986
Lists all features on which partial dependencies have been computed
964987
"""

dataikuapi/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,8 @@ def __iter__(self):
121121
def __setitem__(self, key, value):
122122
self.internal_dict[key] = value
123123

124-
def __str__(self):
125-
return self.internal_dict.__str__()
124+
def __repr__(self):
125+
return self.__class__.__name__ + "(" + self.internal_dict.__repr__() + ")"
126126

127127
def __len__(self):
128128
return self.internal_dict.__len__()

0 commit comments

Comments
 (0)