Skip to content

Commit 9b095f2

Browse files
committed
ML API: add helpers to get all algorithm names and to disable all algorithms
1 parent c0bf745 commit 9b095f2

File tree

1 file changed

+55
-4
lines changed

1 file changed

+55
-4
lines changed

dataikuapi/dss/ml.py

Lines changed: 55 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,6 @@ def use_sample_weighting(self, feature_name):
191191
self.mltask_settings['weight']['sampleWeightVariable'] = feature_name
192192
self.mltask_settings['preprocessing']['per_feature'][feature_name]['role'] = 'WEIGHT'
193193

194-
195194
def remove_sample_weighting(self):
196195
"""
197196
Remove sample weighting. If a feature was used as weight, it's set back to being an input feature
@@ -235,6 +234,35 @@ def set_algorithm_enabled(self, algorithm_name, enabled):
235234
"""
236235
self.get_algorithm_settings(algorithm_name)["enabled"] = enabled
237236

237+
def disable_all_algorithms(self):
238+
"""Disables all algorithms"""
239+
240+
for algorithm_name in self.__class__.algorithm_remap.keys():
241+
key = self.__class__.algorithm_remap[algorithm_name]
242+
if key in self.mltask_settings["modeling"]:
243+
self.mltask_settings["modeling"][key]["enabled"] = False
244+
245+
for custom_mllib in self.mltask_settings["modeling"]["custom_mllib"]:
246+
custom_mllib["enabled"] = False
247+
for custom_python in self.mltask_settings["modeling"]["custom_python"]:
248+
custom_python["enabled"] = False
249+
for plugin in self.mltask_settings["modeling"]["plugin"].values():
250+
plugin["enabled"] = False
251+
252+
def get_all_possible_algorithm_names():
253+
"""
254+
Returns the list of possible algorithm names, i.e. the list of valid
255+
identifiers for :meth:`set_algorithm_enabled` and :meth:`get_algorithm_settings`
256+
257+
This does not include Custom Python models, Custom MLLib models, plugin models.
258+
This includes all possible algorithms, regardless of the prediction kind (regression/classification)
259+
or engine, so some algorithms may be irrelevant
260+
261+
:returns: the list of algorithm names as a list of strings
262+
:rtype: list of string
263+
"""
264+
return self.__class__.algorithm_remap.keys()
265+
238266
def set_metric(self, metric=None, custom_metric=None, custom_metric_greater_is_better=True, custom_metric_use_probas=False):
239267
"""
240268
Sets the score metric to optimize for a prediction ML Task
@@ -261,20 +289,43 @@ def save(self):
261289
class DSSPredictionMLTaskSettings(DSSMLTaskSettings):
262290
__doc__ = []
263291
algorithm_remap = {
292+
"RANDOM_FOREST_CLASSIFICATION": "random_forest_classification",
293+
"RANDOM_FOREST_REGRESSION" : "random_forest_regression",
294+
"EXTRA_TREES": "extra_trees",
295+
"GBT_CLASSIFICATION" : "gbt_classification",
296+
"GBT_REGRESSION" : "gbt_regression",
297+
"DECISION_TREE_CLASSIFICATION" : "decision_tree_classification",
298+
"DECISION_TREE_REGRESSION" : "decision_tree_regression",
299+
"RIDGE_REGRESSION": "ridge_regression",
300+
"LASSO_REGRESSION" : "lasso_regression",
301+
"LEASTSQUARE_REGRESSION": "leastsquare_regression",
302+
"SGD_REGRESSION" : "sgd_regression",
303+
"KNN": "knn",
304+
"LOGISTIC_REGRESSION" : "logistic_regression",
305+
"NEURAL_NETWORK" :"neural_network",
264306
"SVC_CLASSIFICATION" : "svc_classifier",
307+
"SVM_REGRESSION" : "svm_regression",
265308
"SGD_CLASSIFICATION" : "sgd_classifier",
309+
"LARS" : "lars_params",
310+
"XGBOOST_CLASSIFICATION" : "xgboost",
311+
"XGBOOST_REGRESSION" : "xgboost",
266312
"SPARKLING_DEEP_LEARNING" : "deep_learning_sparkling",
267313
"SPARKLING_GBM" : "gbm_sparkling",
268314
"SPARKLING_RF" : "rf_sparkling",
269315
"SPARKLING_GLM" : "glm_sparkling",
270316
"SPARKLING_NB" : "nb_sparkling",
271-
"XGBOOST_CLASSIFICATION" : "xgboost",
272-
"XGBOOST_REGRESSION" : "xgboost",
273317
"MLLIB_LOGISTIC_REGRESSION" : "mllib_logit",
318+
"MLLIB_NAIVE_BAYES" : "mllib_naive_bayes",
274319
"MLLIB_LINEAR_REGRESSION" : "mllib_linreg",
275-
"MLLIB_RANDOM_FOREST" : "mllib_rf"
320+
"MLLIB_RANDOM_FOREST" : "mllib_rf",
321+
"MLLIB_GBT": "mllib_gbt",
322+
"MLLIB_DECISION_TREE" : "mllib_dt",
323+
"VERTICA_LINEAR_REGRESSION" : "vertica_linear_regression",
324+
"VERTICA_LOGISTIC_REGRESSION" : "vertica_logistic_regression",
325+
"KERAS_CODE" : "keras"
276326
}
277327

328+
278329
class DSSClusteringMLTaskSettings(DSSMLTaskSettings):
279330
__doc__ = []
280331
algorithm_remap = {

0 commit comments

Comments
 (0)