Skip to content

Commit 37cdbe3

Browse files
committed
Various nipticks and clarifications in subpop and pdp wrappers
1 parent 97e7334 commit 37cdbe3

File tree

1 file changed

+32
-30
lines changed

1 file changed

+32
-30
lines changed

dataikuapi/dss/ml.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -752,28 +752,30 @@ def get_definition(self):
752752
"""
753753
return self.definition
754754

755-
def excluded(self):
755+
def is_excluded(self):
756756
"""
757757
Whether modality has been excluded from analysis (e.g. too few rows in the subpopulation)
758758
"""
759-
return self.get("excluded")
759+
return self.get("excluded", False)
760760

761761
def get_perf(self):
762762
"""
763-
Gets the performance of the modality
763+
Gets the performance results of the modality
764764
"""
765-
if self.excluded():
765+
if self.is_excluded():
766766
raise ValueError("Excluded modalities do not have perf")
767767
return self.get("perf")
768768

769769

770770
class DSSSubpopulationModalityDefinition(object):
771771

772+
MISSING_VALUES = "__DSSSubpopulationModalidityDefinition__MISSINGVALUES"
773+
772774
def __init__(self, data):
773-
self.missing_values = data.get("missing_values")
775+
self.missing_values = data.get("missing_values", False)
774776
self.index = data.get("index")
775777

776-
def missing_values(self):
778+
def is_missing_values(self):
777779
return self.missing_values
778780

779781

@@ -789,7 +791,7 @@ def contains(self, value):
789791
lte = self.lte if self.lte is not None else float("inf")
790792
gt = self.gt if self.gt is not None else float("-inf")
791793
gte = self.gte if self.gte is not None else float("-inf")
792-
return gt < value and gte <= value and lte >= value
794+
return not self.missing_values and gt < value and gte <= value and lte >= value
793795

794796

795797
class DSSSubpopulationCategoryModalityDefinition(DSSSubpopulationModalityDefinition):
@@ -818,34 +820,34 @@ def get_computation_params(self):
818820
"""
819821
Gets computation params
820822
"""
821-
computation_params = {}
822-
computation_params["nbRecords"] = self.get("nbRecords")
823-
computation_params["randomState"] = self.get("randomState")
824-
computation_params["onSample"] = self.get("onSample")
825-
return computation_params
823+
return {
824+
"nbRecords": self.get("nbRecords"),
825+
"randomState": self.get("randomState"),
826+
"onSample": self.get("onSample")
827+
}
826828

827829
def list_modalities(self):
828830
"""
829831
List definitions of modalities
830832
"""
831833
return [m.definition for m in self.modalities]
832834

833-
def get_modality(self, definition=None, missing_values=False):
835+
def get_modality(self, definition=None):
834836
"""
835837
Retrieves modality from definition
836838
837839
:param definition: definition of modality to retrieve. Can be:
838840
* :class:`dataikuapi.dss.ml.DSSSubpopulationModalityDefinition`
841+
* `dataikuapi.dss.ml.DSSSubpopulationModalityDefinition.MISSING_VALUES`
842+
to retrieve modality corresponding to missing values
839843
* for category modality, can be a str corresponding to the value of the modality
840844
* for numeric modality, can be a number inside the modality
841-
:param missing_values: whether to retrieve modality corresponding to missing values. If True,
842-
`definition` is ignored
843-
845+
844846
:returns: the modality
845847
:rtype: :class:`dataikuapi.dss.ml.DSSSubpopulationModality`
846848
"""
847849

848-
if missing_values:
850+
if definition == DSSSubpopulationModalityDefinition.MISSING_VALUES:
849851
for m in self.modalities:
850852
if m.definition.missing_values:
851853
return m
@@ -860,7 +862,7 @@ def get_modality(self, definition=None, missing_values=False):
860862
for m in self.modalities:
861863
if m.definition.contains(definition):
862864
return m
863-
raise ValueError("Modality not found")
865+
raise ValueError("Modality not found: %s" % definition)
864866

865867
def get_raw(self):
866868
"""
@@ -904,11 +906,11 @@ def get_analysis(self, feature):
904906
"""
905907
Retrieves the subpopulation analysis for a particular feature
906908
"""
907-
if feature not in self.list_analyses():
909+
try:
910+
return next(analysis for analysis in self.analyses if analysis["feature"] == feature)
911+
except StopIteration:
908912
raise ValueError("Subpopulation analysis for feature '%s' cannot be found" % feature)
909913

910-
return next(analysis for analysis in self.analyses if analysis["feature"] == feature)
911-
912914

913915
class DSSPartialDependence(DSSExtendableDict):
914916
"""
@@ -924,11 +926,11 @@ def get_computation_params(self):
924926
"""
925927
Gets computation params
926928
"""
927-
computation_params = {}
928-
computation_params["nbRecords"] = self.get("nbRecords")
929-
computation_params["randomState"] = self.get("randomState")
930-
computation_params["onSample"] = self.get("onSample")
931-
return computation_params
929+
return {
930+
"nbRecords": self.get("nbRecords"),
931+
"randomState": self.get("randomState"),
932+
"onSample": self.get("onSample")
933+
}
932934

933935
def get_raw(self):
934936
"""
@@ -964,13 +966,13 @@ def list_partial_dependencies(self):
964966

965967
def get_partial_dependence(self, feature):
966968
"""
967-
Retrieves the subpopulation analysis for a particular feature
969+
Retrieves the partial dependencies for a particular feature
968970
"""
969-
if feature not in self.list_partial_dependencies():
971+
try:
972+
return next(pd for pd in self.partial_dependencies if pd["feature"] == feature)
973+
except StopIteration:
970974
raise ValueError("Partial dependence for feature '%s' cannot be found" % feature)
971975

972-
return next(pd for pd in self.partial_dependencies if pd["feature"] == feature)
973-
974976

975977
class DSSClustersFacts(object):
976978
def __init__(self, clusters_facts):

0 commit comments

Comments
 (0)