Skip to content

Commit b7dda42

Browse files
committed
Add syntactic sugar for numerical hyperparameters
1 parent 3e60674 commit b7dda42

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

dataikuapi/dss/ml.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -829,17 +829,28 @@ def set_range(self, min=None, max=None, nb_values=None):
829829

830830
@property
831831
def range(self):
832-
return Range(self)
832+
return RangeSettings(self)
833833

834834

835835
class Range(object):
836836

837+
def __init__(self, min, max, nb_values=3):
838+
self.min = min
839+
self.max = max
840+
self.nb_values = nb_values
841+
842+
def __repr__(self):
843+
return "Range(min={}, max={}, nb_values={})".format(self.min, self.max, self.nb_values)
844+
845+
846+
class RangeSettings(object):
847+
837848
def __init__(self, numerical_hyperparameter_settings):
838849
self._numerical_hyperparameter_settings = numerical_hyperparameter_settings
839850
self._range_dict = self._numerical_hyperparameter_settings._algo_settings[numerical_hyperparameter_settings.name]["range"]
840851

841852
def __repr__(self):
842-
return "Range(min={}, max={}, nb_values={})".format(self.min, self.max, self.nb_values)
853+
return "RangeSettings(min={}, max={}, nb_values={})".format(self.min, self.max, self.nb_values)
843854

844855
@property
845856
def min(self):
@@ -1047,7 +1058,10 @@ def __setattr__(self, attr_name, value):
10471058
elif isinstance(target, CategoricalHyperparameterSettings):
10481059
target.set_values(value)
10491060
elif isinstance(target, NumericalHyperparameterSettings):
1050-
raise Exception("Invalid assignment of a NumericalHyperparameterSettings object")
1061+
if isinstance(value, list):
1062+
target.set_explicit_values(values=value)
1063+
elif isinstance(value, Range):
1064+
target.set_range(min=value.min, max=value.max, nb_values=value.nb_values)
10511065
else:
10521066
# simple parameter
10531067
assert isinstance(value, type(target)), "Invalid type {} for parameter {}: expected {}".format(type(value), attr_name, type(target))

0 commit comments

Comments
 (0)