Skip to content

Commit 8464243

Browse files
committed
Add syntactic sugars for PredictionAlgorithmSettings subclasses hyperparameters
1 parent c223677 commit 8464243

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

dataikuapi/dss/ml.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -877,11 +877,11 @@ def __init__(self, name, algo_settings, accepted_values=None):
877877

878878
def __repr__(self):
879879
if self.accepted_values is not None:
880-
return self.__class__.__name__ + "(hyperparameter=\"{}\", value={}, accepted_values={})".format(self.name,
880+
return self.__class__.__name__ + "(hyperparameter=\"{}\", value=\"{}\", accepted_values={})".format(self.name,
881881
self._algo_settings[self.name],
882882
self.accepted_values)
883883
else:
884-
return self.__class__.__name__ + "(hyperparameter=\"{}\", value={})".format(self.name, self._algo_settings[self.name])
884+
return self.__class__.__name__ + "(hyperparameter=\"{}\", value=\"{}\")".format(self.name, self._algo_settings[self.name])
885885

886886
__str__ = __repr__
887887

@@ -923,6 +923,26 @@ def __init__(self, raw_settings, hyperparameter_search_params):
923923
self._hyperparameter_search_params = hyperparameter_search_params
924924
self._hyperparameters_registry = dict()
925925

926+
def __setattr__(self, key, value):
927+
if key in {"_hyperparameters_registry", "_hyperparameter_search_params"} or key in {"enabled", "strategy"}:
928+
# attributes and properties of the PredictionAlgorithmSettings object must be handled separately
929+
super(PredictionAlgorithmSettings, self).__setattr__(key, value)
930+
elif key in self._hyperparameters_registry or (key=="lambda_" and "lambda" in self._hyperparameters_registry):
931+
if isinstance(value, HyperparameterSettings):
932+
# call from a PredictionAlgorithmSettings child's __init__
933+
super(PredictionAlgorithmSettings, self).__setattr__(key, value)
934+
else:
935+
# syntactic sugars
936+
target = self._hyperparameters_registry[key]
937+
if isinstance(target, (SingleValueHyperparameterSettings, SingleCategoryHyperparameterSettings)):
938+
target.set_value(value)
939+
elif isinstance(target, CategoricalHyperparameterSettings):
940+
target.set_values(value)
941+
else:
942+
raise Exception("Invalid type for assignment of a NumericalHyperparameterSettings object")
943+
else:
944+
raise Exception("Unknown hyperparameter: \"{}\"".format(key))
945+
926946
def _register_numerical_hyperparameter(self, name):
927947
self._hyperparameters_registry[name] = NumericalHyperparameterSettings(name, self)
928948
return self._hyperparameters_registry[name]

0 commit comments

Comments
 (0)