Skip to content

Commit b4fad34

Browse files
committed
Merge remote-tracking branch 'origin/release/6.0' into release/7.0
2 parents 09c5530 + 6ccde56 commit b4fad34

File tree

2 files changed

+72
-5
lines changed

2 files changed

+72
-5
lines changed

dataikuapi/dss/ml.py

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,6 @@ def use_sample_weighting(self, feature_name):
227227
self.mltask_settings['weight']['sampleWeightVariable'] = feature_name
228228
self.mltask_settings['preprocessing']['per_feature'][feature_name]['role'] = 'WEIGHT'
229229

230-
231230
def remove_sample_weighting(self):
232231
"""
233232
Remove sample weighting. If a feature was used as weight, it's set back to being an input feature
@@ -271,6 +270,35 @@ def set_algorithm_enabled(self, algorithm_name, enabled):
271270
"""
272271
self.get_algorithm_settings(algorithm_name)["enabled"] = enabled
273272

273+
def disable_all_algorithms(self):
    """
    Disables all algorithms

    This sets ``enabled`` to False on:

    - every built-in algorithm of the class-level ``algorithm_remap`` whose settings
      key is present in the modeling settings
    - every Custom MLLib model
    - every Custom Python model
    - every plugin model

    Sections that are absent from the modeling settings are skipped.
    """
    modeling = self.mltask_settings["modeling"]

    # Only the settings keys matter here, not the public algorithm names.
    # Several names may share one key; disabling the same entry twice is harmless.
    for settings_key in self.__class__.algorithm_remap.values():
        if settings_key in modeling:
            modeling[settings_key]["enabled"] = False

    # The custom and plugin sections get the same tolerance as the built-in
    # algorithms above: a missing (or null) section means nothing to disable.
    for custom_mllib in modeling.get("custom_mllib") or []:
        custom_mllib["enabled"] = False
    for custom_python in modeling.get("custom_python") or []:
        custom_python["enabled"] = False
    for plugin in (modeling.get("plugin") or {}).values():
        plugin["enabled"] = False
287+
288+
def get_all_possible_algorithm_names(self):
    """
    Returns the list of possible algorithm names, i.e. the list of valid
    identifiers for :meth:`set_algorithm_enabled` and :meth:`get_algorithm_settings`

    This does not include Custom Python models, Custom MLLib models, plugin models.
    This includes all possible algorithms, regardless of the prediction kind (regression/classification)
    or engine, so some algorithms may be irrelevant

    :returns: the list of algorithm names as a list of strings
    :rtype: list of string
    """
    # Materialize the keys so that callers get a real list, as documented,
    # rather than a dict view tied to the class-level mapping.
    return list(self.__class__.algorithm_remap.keys())
301+
274302
def set_metric(self, metric=None, custom_metric=None, custom_metric_greater_is_better=True, custom_metric_use_probas=False):
275303
"""
276304
Sets the score metric to optimize for a prediction ML Task
@@ -297,20 +325,43 @@ def save(self):
297325
class DSSPredictionMLTaskSettings(DSSMLTaskSettings):
    __doc__ = []

    # Maps each public algorithm name to the key under which its settings
    # are looked up in the "modeling" section of the ML task settings.
    # Names that are not listed here are not part of this mapping.
    algorithm_remap = {
        "RANDOM_FOREST_CLASSIFICATION": "random_forest_classification",
        "RANDOM_FOREST_REGRESSION": "random_forest_regression",
        "EXTRA_TREES": "extra_trees",
        "GBT_CLASSIFICATION": "gbt_classification",
        "GBT_REGRESSION": "gbt_regression",
        "DECISION_TREE_CLASSIFICATION": "decision_tree_classification",
        "DECISION_TREE_REGRESSION": "decision_tree_regression",
        "RIDGE_REGRESSION": "ridge_regression",
        "LASSO_REGRESSION": "lasso_regression",
        "LEASTSQUARE_REGRESSION": "leastsquare_regression",
        "SGD_REGRESSION": "sgd_regression",
        "KNN": "knn",
        "LOGISTIC_REGRESSION": "logistic_regression",
        "NEURAL_NETWORK": "neural_network",
        "SVC_CLASSIFICATION": "svc_classifier",
        "SVM_REGRESSION": "svm_regression",
        "SGD_CLASSIFICATION": "sgd_classifier",
        "LARS": "lars_params",
        # Both XGBoost names share a single settings key
        "XGBOOST_CLASSIFICATION": "xgboost",
        "XGBOOST_REGRESSION": "xgboost",
        # SPARKLING_* names
        "SPARKLING_DEEP_LEARNING": "deep_learning_sparkling",
        "SPARKLING_GBM": "gbm_sparkling",
        "SPARKLING_RF": "rf_sparkling",
        "SPARKLING_GLM": "glm_sparkling",
        "SPARKLING_NB": "nb_sparkling",
        # MLLIB_* names
        "MLLIB_LOGISTIC_REGRESSION": "mllib_logit",
        "MLLIB_NAIVE_BAYES": "mllib_naive_bayes",
        "MLLIB_LINEAR_REGRESSION": "mllib_linreg",
        "MLLIB_RANDOM_FOREST": "mllib_rf",
        "MLLIB_GBT": "mllib_gbt",
        "MLLIB_DECISION_TREE": "mllib_dt",
        # VERTICA_* names
        "VERTICA_LINEAR_REGRESSION": "vertica_linear_regression",
        "VERTICA_LOGISTIC_REGRESSION": "vertica_logistic_regression",
        "KERAS_CODE": "keras",
    }
313363

364+
314365
class DSSClusteringMLTaskSettings(DSSMLTaskSettings):
315366
__doc__ = []
316367
algorithm_remap = {
@@ -529,6 +580,22 @@ def get_performance_metrics(self):
529580
return clean_snippet
530581

531582

583+
def get_hyperparameter_search_points(self):
    """
    Gets the list of points in the hyperparameter search space that have been tested.

    Returns a list of dict. Each entry in the list represents a point.

    For each point, the dict contains at least:
        - "score": the average value of the optimization metric over all the folds at this point
        - "params": a dict of the parameters at this point. This dict has the same structure
          as the params of the best parameters

    :returns: the content of the "gridCells" entry of the "iperf" section of the model details
    :rtype: list of dict
    :raises ValueError: if the "iperf" section has no "gridCells" entry
    """
    iperf = self.details["iperf"]

    if "gridCells" not in iperf:
        raise ValueError("No hyperparameter search result, maybe this model did not perform hyperparameter optimization")
    return iperf["gridCells"]
598+
532599
def get_preprocessing_settings(self):
533600
"""
534601
Gets the preprocessing settings that were used to train this model

dataikuapi/dss/project.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def get_metadata(self):
142142
Get the metadata attached to this project. The metadata contains label, description,
143143
checklists, tags and custom metadata of the project.
144144
145-
For more information on available metadata, please see https://doc.dataiku.com/dss/api/5.0/rest/
145+
For more information on available metadata, please see https://doc.dataiku.com/dss/api/6.0/rest/
146146
147147
:returns: a dict object containing the project metadata.
148148
:rtype: dict

0 commit comments

Comments
 (0)