@@ -940,41 +940,52 @@ def __init__(self, raw_settings, hyperparameter_search_params):
940940 super (PredictionAlgorithmSettings , self ).__init__ (raw_settings )
941941 self ._hyperparameter_search_params = hyperparameter_search_params
942942 self ._hyperparameters_registry = dict ()
943+ self ._attr_to_json_remapping = dict ()
943944
944- def __setattr__ (self , key , value ):
945- if not hasattr (self , key ):
945+ def __setattr__ (self , attr_name , value ):
946+ if not hasattr (self , attr_name ):
946947 # call from __init__
947- super (PredictionAlgorithmSettings , self ).__setattr__ (key , value )
948- elif key in self ._hyperparameters_registry :
949- # syntactic sugars
950- target = self ._hyperparameters_registry [key ]
951- if isinstance (target , (SingleValueHyperparameterSettings , SingleCategoryHyperparameterSettings )):
952- target .set_value (value )
953- elif isinstance (target , CategoricalHyperparameterSettings ):
954- target .set_values (value )
955- else :
956- raise Exception ("Invalid assignment of a NumericalHyperparameterSettings object" )
957- elif key == "lambda_" :
958- raise Exception ("Invalid assignment of a NumericalHyperparameterSettings object" )
948+ super (PredictionAlgorithmSettings , self ).__setattr__ (attr_name , value )
959949 else :
960- # other cases (properties setter, new attribute...)
961- super (PredictionAlgorithmSettings , self ).__setattr__ (key , value )
950+ if attr_name in self ._attr_to_json_remapping :
951+ # attribute name and json key mismatch (e.g. "lambda", "alphaMode")
952+ attr_name = self ._attr_to_json_remapping [attr_name ]
953+ if attr_name in self ._hyperparameters_registry :
954+ # syntactic sugars
955+ target = self ._hyperparameters_registry [attr_name ]
956+ if isinstance (target , (SingleValueHyperparameterSettings , SingleCategoryHyperparameterSettings )):
957+ target .set_value (value )
958+ elif isinstance (target , CategoricalHyperparameterSettings ):
959+ target .set_values (value )
960+ else :
961+ raise Exception ("Invalid assignment of a NumericalHyperparameterSettings object" )
962+ else :
963+ # other cases (properties setter, new attribute...)
964+ super (PredictionAlgorithmSettings , self ).__setattr__ (attr_name , value )
962965
963- def _register_numerical_hyperparameter (self , name ):
964- self . _hyperparameters_registry [ name ] = NumericalHyperparameterSettings ( name , self )
965- return self ._hyperparameters_registry [ name ]
966+ def _maybe_register_attr_json_mismatch (self , json_key , attr_name ):
967+ if attr_name is not None :
968+ self ._attr_to_json_remapping [ attr_name ] = json_key
966969
967- def _register_categorical_hyperparameter (self , name ):
968- self ._hyperparameters_registry [name ] = CategoricalHyperparameterSettings (name , self )
969- return self ._hyperparameters_registry [name ]
970+ def _register_numerical_hyperparameter (self , json_key , attr_name = None ):
971+ self ._maybe_register_attr_json_mismatch (json_key , attr_name )
972+ self ._hyperparameters_registry [json_key ] = NumericalHyperparameterSettings (json_key , self )
973+ return self ._hyperparameters_registry [json_key ]
970974
971- def _register_single_category_hyperparameter (self , name , accepted_values = None ):
972- self ._hyperparameters_registry [name ] = SingleCategoryHyperparameterSettings (name , self , accepted_values = accepted_values )
973- return self ._hyperparameters_registry [name ]
975+ def _register_categorical_hyperparameter (self , json_key , attr_name = None ):
976+ self ._maybe_register_attr_json_mismatch (json_key , attr_name )
977+ self ._hyperparameters_registry [json_key ] = CategoricalHyperparameterSettings (json_key , self )
978+ return self ._hyperparameters_registry [json_key ]
974979
975- def _register_single_value_hyperparameter (self , name , accepted_types = None ):
976- self ._hyperparameters_registry [name ] = SingleValueHyperparameterSettings (name , self , accepted_types = accepted_types )
977- return self ._hyperparameters_registry [name ]
980+ def _register_single_category_hyperparameter (self , json_key , accepted_values = None , attr_name = None ):
981+ self ._maybe_register_attr_json_mismatch (json_key , attr_name )
982+ self ._hyperparameters_registry [json_key ] = SingleCategoryHyperparameterSettings (json_key , self , accepted_values = accepted_values )
983+ return self ._hyperparameters_registry [json_key ]
984+
985+ def _register_single_value_hyperparameter (self , json_key , accepted_types = None , attr_name = None ):
986+ self ._maybe_register_attr_json_mismatch (json_key , attr_name )
987+ self ._hyperparameters_registry [json_key ] = SingleValueHyperparameterSettings (json_key , self , accepted_types = accepted_types )
988+ return self ._hyperparameters_registry [json_key ]
978989
979990 def _repr_html_ (self ):
980991 res = "<pre>" + self .__class__ .__name__ + "\n "
@@ -1045,7 +1056,7 @@ def __init__(self, raw_settings, hyperparameter_search_params):
10451056 self .colsample_bytree = self ._register_numerical_hyperparameter ("colsample_bytree" )
10461057 self .colsample_bylevel = self ._register_numerical_hyperparameter ("colsample_bylevel" )
10471058 self .alpha = self ._register_numerical_hyperparameter ("alpha" )
1048- self .lambda_ = self ._register_numerical_hyperparameter ("lambda" )
1059+ self .lambda_ = self ._register_numerical_hyperparameter ("lambda" , attr_name = "lambda_" )
10491060 self .booster = self ._register_categorical_hyperparameter ("booster" )
10501061 self .objective = self ._register_categorical_hyperparameter ("objective" )
10511062 self .n_estimators = self ._register_single_value_hyperparameter ("n_estimators" , accepted_types = [int ])
@@ -1101,15 +1112,15 @@ class RidgeRegressionSettings(PredictionAlgorithmSettings):
11011112 def __init__ (self , raw_settings , hyperparameter_search_params ):
11021113 super (RidgeRegressionSettings , self ).__init__ (raw_settings , hyperparameter_search_params )
11031114 self .alpha = self ._register_numerical_hyperparameter ("alpha" )
1104- self .alpha_mode = self ._register_single_category_hyperparameter ("alphaMode" , accepted_values = ["MANUAL" , "AUTO" ])
1115+ self .alpha_mode = self ._register_single_category_hyperparameter ("alphaMode" , accepted_values = ["MANUAL" , "AUTO" ], attr_name = "alpha_mode" )
11051116
11061117
11071118class LassoRegressionSettings (PredictionAlgorithmSettings ):
11081119
11091120 def __init__ (self , raw_settings , hyperparameter_search_params ):
11101121 super (LassoRegressionSettings , self ).__init__ (raw_settings , hyperparameter_search_params )
11111122 self .alpha = self ._register_numerical_hyperparameter ("alpha" )
1112- self .alpha_mode = self ._register_single_category_hyperparameter ("alphaMode" , accepted_values = ["MANUAL" , "AUTO_CV" , "AUTO_IC" ]) # TODO: enforce attribute name = parameter name ?
1123+ self .alpha_mode = self ._register_single_category_hyperparameter ("alphaMode" , accepted_values = ["MANUAL" , "AUTO_CV" , "AUTO_IC" ], attr_name = "alpha_mode" )
11131124
11141125
11151126class OLSSettings (PredictionAlgorithmSettings ):
0 commit comments