Skip to content

Commit b872440

Browse files
committed
Fix syntactic sugar handling of attributes whose names in the PredictionAlgorithmSettings object are not identical to the json key
1 parent c545a1f commit b872440

File tree

1 file changed

+42
-31
lines changed

1 file changed

+42
-31
lines changed

dataikuapi/dss/ml.py

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -940,41 +940,52 @@ def __init__(self, raw_settings, hyperparameter_search_params):
940940
super(PredictionAlgorithmSettings, self).__init__(raw_settings)
941941
self._hyperparameter_search_params = hyperparameter_search_params
942942
self._hyperparameters_registry = dict()
943+
self._attr_to_json_remapping = dict()
943944

944-
def __setattr__(self, key, value):
945-
if not hasattr(self, key):
945+
def __setattr__(self, attr_name, value):
946+
if not hasattr(self, attr_name):
946947
# call from __init__
947-
super(PredictionAlgorithmSettings, self).__setattr__(key, value)
948-
elif key in self._hyperparameters_registry:
949-
# syntactic sugars
950-
target = self._hyperparameters_registry[key]
951-
if isinstance(target, (SingleValueHyperparameterSettings, SingleCategoryHyperparameterSettings)):
952-
target.set_value(value)
953-
elif isinstance(target, CategoricalHyperparameterSettings):
954-
target.set_values(value)
955-
else:
956-
raise Exception("Invalid assignment of a NumericalHyperparameterSettings object")
957-
elif key == "lambda_":
958-
raise Exception("Invalid assignment of a NumericalHyperparameterSettings object")
948+
super(PredictionAlgorithmSettings, self).__setattr__(attr_name, value)
959949
else:
960-
# other cases (properties setter, new attribute...)
961-
super(PredictionAlgorithmSettings, self).__setattr__(key, value)
950+
if attr_name in self._attr_to_json_remapping:
951+
# attribute name and json key mismatch (e.g. "lambda", "alphaMode")
952+
attr_name = self._attr_to_json_remapping[attr_name]
953+
if attr_name in self._hyperparameters_registry:
954+
# syntactic sugars
955+
target = self._hyperparameters_registry[attr_name]
956+
if isinstance(target, (SingleValueHyperparameterSettings, SingleCategoryHyperparameterSettings)):
957+
target.set_value(value)
958+
elif isinstance(target, CategoricalHyperparameterSettings):
959+
target.set_values(value)
960+
else:
961+
raise Exception("Invalid assignment of a NumericalHyperparameterSettings object")
962+
else:
963+
# other cases (properties setter, new attribute...)
964+
super(PredictionAlgorithmSettings, self).__setattr__(attr_name, value)
962965

963-
def _register_numerical_hyperparameter(self, name):
964-
self._hyperparameters_registry[name] = NumericalHyperparameterSettings(name, self)
965-
return self._hyperparameters_registry[name]
966+
def _maybe_register_attr_json_mismatch(self, json_key, attr_name):
967+
if attr_name is not None:
968+
self._attr_to_json_remapping[attr_name] = json_key
966969

967-
def _register_categorical_hyperparameter(self, name):
968-
self._hyperparameters_registry[name] = CategoricalHyperparameterSettings(name, self)
969-
return self._hyperparameters_registry[name]
970+
def _register_numerical_hyperparameter(self, json_key, attr_name=None):
971+
self._maybe_register_attr_json_mismatch(json_key, attr_name)
972+
self._hyperparameters_registry[json_key] = NumericalHyperparameterSettings(json_key, self)
973+
return self._hyperparameters_registry[json_key]
970974

971-
def _register_single_category_hyperparameter(self, name, accepted_values=None):
972-
self._hyperparameters_registry[name] = SingleCategoryHyperparameterSettings(name, self, accepted_values=accepted_values)
973-
return self._hyperparameters_registry[name]
975+
def _register_categorical_hyperparameter(self, json_key, attr_name=None):
976+
self._maybe_register_attr_json_mismatch(json_key, attr_name)
977+
self._hyperparameters_registry[json_key] = CategoricalHyperparameterSettings(json_key, self)
978+
return self._hyperparameters_registry[json_key]
974979

975-
def _register_single_value_hyperparameter(self, name, accepted_types=None):
976-
self._hyperparameters_registry[name] = SingleValueHyperparameterSettings(name, self, accepted_types=accepted_types)
977-
return self._hyperparameters_registry[name]
980+
def _register_single_category_hyperparameter(self, json_key, accepted_values=None, attr_name=None):
981+
self._maybe_register_attr_json_mismatch(json_key, attr_name)
982+
self._hyperparameters_registry[json_key] = SingleCategoryHyperparameterSettings(json_key, self, accepted_values=accepted_values)
983+
return self._hyperparameters_registry[json_key]
984+
985+
def _register_single_value_hyperparameter(self, json_key, accepted_types=None, attr_name=None):
986+
self._maybe_register_attr_json_mismatch(json_key, attr_name)
987+
self._hyperparameters_registry[json_key] = SingleValueHyperparameterSettings(json_key, self, accepted_types=accepted_types)
988+
return self._hyperparameters_registry[json_key]
978989

979990
def _repr_html_(self):
980991
res = "<pre>" + self.__class__.__name__ + "\n"
@@ -1045,7 +1056,7 @@ def __init__(self, raw_settings, hyperparameter_search_params):
10451056
self.colsample_bytree = self._register_numerical_hyperparameter("colsample_bytree")
10461057
self.colsample_bylevel = self._register_numerical_hyperparameter("colsample_bylevel")
10471058
self.alpha = self._register_numerical_hyperparameter("alpha")
1048-
self.lambda_ = self._register_numerical_hyperparameter("lambda")
1059+
self.lambda_ = self._register_numerical_hyperparameter("lambda", attr_name="lambda_")
10491060
self.booster = self._register_categorical_hyperparameter("booster")
10501061
self.objective = self._register_categorical_hyperparameter("objective")
10511062
self.n_estimators = self._register_single_value_hyperparameter("n_estimators", accepted_types=[int])
@@ -1101,15 +1112,15 @@ class RidgeRegressionSettings(PredictionAlgorithmSettings):
11011112
def __init__(self, raw_settings, hyperparameter_search_params):
11021113
super(RidgeRegressionSettings, self).__init__(raw_settings, hyperparameter_search_params)
11031114
self.alpha = self._register_numerical_hyperparameter("alpha")
1104-
self.alpha_mode = self._register_single_category_hyperparameter("alphaMode", accepted_values=["MANUAL", "AUTO"])
1115+
self.alpha_mode = self._register_single_category_hyperparameter("alphaMode", accepted_values=["MANUAL", "AUTO"], attr_name="alpha_mode")
11051116

11061117

11071118
class LassoRegressionSettings(PredictionAlgorithmSettings):
11081119

11091120
def __init__(self, raw_settings, hyperparameter_search_params):
11101121
super(LassoRegressionSettings, self).__init__(raw_settings, hyperparameter_search_params)
11111122
self.alpha = self._register_numerical_hyperparameter("alpha")
1112-
self.alpha_mode = self._register_single_category_hyperparameter("alphaMode", accepted_values=["MANUAL", "AUTO_CV", "AUTO_IC"]) # TODO: enforce attribute name = parameter name ?
1123+
self.alpha_mode = self._register_single_category_hyperparameter("alphaMode", accepted_values=["MANUAL", "AUTO_CV", "AUTO_IC"], attr_name="alpha_mode")
11131124

11141125

11151126
class OLSSettings(PredictionAlgorithmSettings):

0 commit comments

Comments
 (0)