Skip to content

Commit 97e7334

Browse files
committed
Add helper to list/retrieve modalities from subpop analyses
1 parent ff85303 commit 97e7334

File tree

1 file changed

+120
-0
lines changed

1 file changed

+120
-0
lines changed

dataikuapi/dss/ml.py

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,86 @@ def get_partial_dependencies(self):
722722
return DSSPartialDependencies(data)
723723

724724

725+
class DSSSubpopulationModality(DSSExtendableDict):
    """
    Read-only view over a single modality of a subpopulation analysis

    Do not create this object directly, use :meth:`DSSSubpopulationAnalysis.get_modality(definition)` instead
    """

    def __init__(self, computed_as_type, data):
        super(DSSSubpopulationModality, self).__init__(data)

        # The kind of definition follows the type the feature was computed as.
        # For any other type, no ``definition`` attribute is set on the object.
        if computed_as_type == "NUMERIC":
            self.definition = DSSSubpopulationNumericModalityDefinition(data)
        elif computed_as_type == "CATEGORY":
            self.definition = DSSSubpopulationCategoryModalityDefinition(data)

    def get_raw(self):
        """
        Gets the raw dictionary backing this modality

        :returns: the ``internal_dict`` attribute of this object
        """
        raw = self.internal_dict
        return raw

    def get_definition(self):
        """
        Gets the definition attached to this modality at construction time

        :returns: definition
        :rtype: :class:`dataikuapi.dss.ml.DSSSubpopulationModalityDefinition`
        """
        definition = self.definition
        return definition

    def excluded(self):
        """
        Whether modality has been excluded from analysis (e.g. too few rows in the subpopulation)

        :returns: the value stored under the ``"excluded"`` key (``None`` when the key is absent)
        """
        is_excluded = self.get("excluded")
        return is_excluded

    def get_perf(self):
        """
        Gets the performance of the modality

        :returns: the value stored under the ``"perf"`` key
        :raises ValueError: if the modality is flagged as excluded
        """
        if not self.excluded():
            return self.get("perf")
        raise ValueError("Excluded modalities do not have perf")
768+
769+
770+
class DSSSubpopulationModalityDefinition(object):
    """
    Base definition of a subpopulation analysis modality

    Holds the fields shared by every kind of modality, read from the raw modality dict:

    * ``missing_values``: value of the ``"missing_values"`` key (``None`` when absent)
    * ``index``: value of the ``"index"`` key (``None`` when absent)
    """

    def __init__(self, data):
        # ``missing_values`` must remain a plain attribute: it is read as such by
        # DSSSubpopulationAnalysis.get_modality. A method with the same name cannot
        # coexist with it, because the instance attribute always wins the lookup and
        # calling ``definition.missing_values()`` would raise TypeError. The callable
        # accessor is therefore exposed as ``is_missing_values``.
        self.missing_values = data.get("missing_values")
        self.index = data.get("index")

    def is_missing_values(self):
        """
        Whether this definition is flagged as the "missing values" modality

        :returns: True if the ``missing_values`` attribute is truthy, False otherwise
        :rtype: bool
        """
        return bool(self.missing_values)
778+
779+
780+
class DSSSubpopulationNumericModalityDefinition(DSSSubpopulationModalityDefinition):
    """
    Definition of a modality of a numeric feature, expressed as a range

    The range is described by up to three optional bounds read from the raw modality dict:

    * ``lte``: inclusive upper bound
    * ``gt``: exclusive lower bound
    * ``gte``: inclusive lower bound

    A bound that is absent (``None``) does not constrain the range.
    """

    def __init__(self, data):
        super(DSSSubpopulationNumericModalityDefinition, self).__init__(data)
        self.lte = data.get("lte", None)
        self.gt = data.get("gt", None)
        self.gte = data.get("gte", None)

    def contains(self, value):
        """
        Whether a number falls inside the range of this modality

        :param value: the number to test
        :returns: True if ``value`` satisfies every bound that is set, False otherwise.
            Always False for the "missing values" modality and for a ``None`` value.
        :rtype: bool
        """
        # The "missing values" modality presumably carries no bounds at all; without
        # this guard it would behave as an unbounded range and match every number.
        if self.missing_values:
            return False
        # None cannot be ordered against numbers on Python 3 (TypeError); it is simply
        # not inside any range.
        if value is None:
            return False

        lte = self.lte if self.lte is not None else float("inf")
        gt = self.gt if self.gt is not None else float("-inf")
        gte = self.gte if self.gte is not None else float("-inf")
        return gt < value and gte <= value and lte >= value
793+
794+
795+
class DSSSubpopulationCategoryModalityDefinition(DSSSubpopulationModalityDefinition):
    """
    Definition of a modality of a categorical feature, identified by a single value
    """

    def __init__(self, data):
        parent = super(DSSSubpopulationCategoryModalityDefinition, self)
        parent.__init__(data)
        # ``value`` is None when the raw dict has no "value" key
        self.value = data.get("value")

    def contains(self, value):
        """
        Whether the given value is the one this modality stands for

        :param value: the category value to test
        :returns: result of comparing ``value`` with the modality's own ``value`` for equality
        """
        is_same_category = (value == self.value)
        return is_same_category
803+
804+
725805
class DSSSubpopulationAnalysis(DSSExtendableDict):
726806
"""
727807
Object to read details of a subpopulation analysis of a trained model
@@ -731,6 +811,8 @@ class DSSSubpopulationAnalysis(DSSExtendableDict):
731811

732812
def __init__(self, analysis):
    super(DSSSubpopulationAnalysis, self).__init__(analysis)
    # Type the analysed feature was computed as; it is handed to every modality
    # wrapper built below.
    self.computed_as_type = self.get("computed_as_type")
    # Wrap each raw modality dict; an analysis without a "modalities" key yields
    # an empty list.
    raw_modalities = self.get("modalities", [])
    self.modalities = [
        DSSSubpopulationModality(self.computed_as_type, raw_modality)
        for raw_modality in raw_modalities
    ]
734816

735817
def get_computation_params(self):
736818
"""
@@ -741,6 +823,44 @@ def get_computation_params(self):
741823
computation_params["randomState"] = self.get("randomState")
742824
computation_params["onSample"] = self.get("onSample")
743825
return computation_params
826+
827+
def list_modalities(self):
    """
    List definitions of modalities

    :returns: the ``definition`` attribute of each modality, in the order of ``self.modalities``
    :rtype: list
    """
    definitions = []
    for modality in self.modalities:
        definitions.append(modality.definition)
    return definitions
832+
833+
def get_modality(self, definition=None, missing_values=False):
    """
    Retrieves modality from definition

    :param definition: definition of modality to retrieve. Can be:
        * :class:`dataikuapi.dss.ml.DSSSubpopulationModalityDefinition`
        * for category modality, can be a str corresponding to the value of the modality
        * for numeric modality, can be a number inside the modality
    :param missing_values: whether to retrieve modality corresponding to missing values. If True,
        `definition` is ignored

    :returns: the modality
    :rtype: :class:`dataikuapi.dss.ml.DSSSubpopulationModality`
    :raises ValueError: if no modality matches the request
    """
    # 1. Explicit request for the "missing values" modality: `definition` is ignored.
    if missing_values:
        for modality in self.modalities:
            if modality.definition.missing_values:
                return modality
        raise ValueError("No 'missing values' modality found")

    # 2. A definition object is matched on its index; the first match wins.
    if isinstance(definition, DSSSubpopulationModalityDefinition):
        for modality in self.modalities:
            if modality.definition.index == definition.index:
                return modality
        # The message is built from the requested definition's index. The previous
        # code referenced an undefined name here and raised NameError instead.
        raise ValueError("Modality with index '%s' not found" % (definition.index,))

    # 3. Otherwise `definition` is a plain value: return the first modality whose
    #    definition contains it.
    for modality in self.modalities:
        if modality.definition.contains(definition):
            return modality
    raise ValueError("Modality not found")
744864

745865
def get_raw(self):
746866
"""

0 commit comments

Comments
 (0)