Skip to content

Commit d695351

Browse files
author
Louis Pouillot
committed
Add optional parameters to force prediction type in ML pred task creation, and add a method to re-guess the task parameters with an option to force pred type
1 parent db2fc7e commit d695351

File tree

2 files changed

+39
-9
lines changed

2 files changed

+39
-9
lines changed

dataikuapi/dss/ml.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def set_split_explicit(self, train_selection, test_selection, dataset_name=None,
7878
sp['efsdDatasetSmartName'] = dataset_name
7979
sp['efsdTrain'] = train_split
8080
sp['efsdTest'] = test_split
81-
else:
81+
else:
8282
sp["ttPolicy"] = "EXPLICIT_FILTERING_TWO_DATASETS"
8383
train_split ={'datasetSmartName' : dataset_name}
8484
test_split = {'datasetSmartName' : test_dataset_name}
@@ -373,7 +373,7 @@ def get_split_info(self):
373373
info['nSamples'] = nSamples[self.i] if nSamples is not None else None
374374
info['threshold'] = thresholds[self.i] if thresholds is not None else None
375375
return info
376-
376+
377377
class DSSTree(object):
378378
def __init__(self, tree, feature_names):
379379
self.tree = tree
@@ -677,7 +677,7 @@ def delete(self):
677677
"""
678678
return self.client._perform_json(
679679
"DELETE", "/projects/%s/models/lab/%s/%s/" % (self.project_key, self.analysis_id, self.mltask_id))
680-
680+
681681

682682
def wait_guess_complete(self):
683683
"""
@@ -700,7 +700,7 @@ def get_status(self):
700700
"""
701701
return self.client._perform_json(
702702
"GET", "/projects/%s/models/lab/%s/%s/status" % (self.project_key, self.analysis_id, self.mltask_id))
703-
703+
704704

705705
def get_settings(self):
706706
"""
@@ -921,3 +921,25 @@ def redeploy_to_flow(self, model_id, recipe_name=None, saved_model_id=None, acti
921921
"POST", "/projects/%s/models/lab/%s/%s/models/%s/actions/redeployToFlow" % (self.project_key, self.analysis_id, self.mltask_id, model_id),
922922
body = obj)
923923

924+
def start_guess(self,
                prediction_type=None,
                wait_guess_complete=True):
    """
    Guess the feature handling and the algorithms.

    :param string prediction_type: In case of a prediction problem, the prediction type can be
        specified. Valid values are BINARY_CLASSIFICATION, REGRESSION, MULTICLASS.
        When None, no prediction type is sent with the request.
    :param boolean wait_guess_complete: if False, this method returns right after the guess
        request has been issued, without waiting for the guessing to complete.
        You should then call ``wait_guess_complete`` on this object before doing
        anything else (in particular calling ``train`` or ``get_settings``)
    :return: None
    """
    # The forced prediction type, if any, travels as a query parameter.
    query = {} if prediction_type is None else {"predictionType": prediction_type}
    endpoint = "/projects/%s/models/lab/%s/%s/guess" % (
        self.project_key, self.analysis_id, self.mltask_id)
    self.client._perform_empty("PUT", endpoint, params=query)

    if not wait_guess_complete:
        return
    # The boolean parameter shadows the method name only as a local; the
    # attribute lookup on ``self`` still resolves to the method.
    self.wait_guess_complete()

dataikuapi/dss/project.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,30 +190,38 @@ def create_dataset(self, dataset_name, type,
190190
########################################################
191191

192192
def create_prediction_ml_task(self, input_dataset, target_variable,
                              ml_backend_type="PY_MEMORY",
                              guess_policy="DEFAULT",
                              prediction_type=None,
                              wait_guess_complete=True):
    """Creates a new prediction task in a new visual analysis lab
    for a dataset.

    :param string input_dataset: the dataset to use for training/testing the model
    :param string target_variable: the variable to predict
    :param string ml_backend_type: ML backend to use, one of PY_MEMORY, MLLIB or H2O
    :param string guess_policy: Policy to use for setting the default parameters.
        Valid values are: DEFAULT, SIMPLE_FORMULA, DECISION_TREE, EXPLANATORY and PERFORMANCE
    :param string prediction_type: The type of prediction problem this is. If not provided,
        the prediction type will be guessed. Valid values are: BINARY_CLASSIFICATION,
        REGRESSION, MULTICLASS
    :param boolean wait_guess_complete: if False, the returned ML task will be in 'guessing'
        state, i.e. analyzing the input dataset to determine feature handling and algorithms.
        You should wait for the guessing to be completed by calling
        ``wait_guess_complete`` on the returned object before doing anything
        else (in particular calling ``train`` or ``get_settings``)
    :return: a DSSMLTask built from the analysis id and ML task id found in the response
    """
    payload = {
        "inputDataset": input_dataset,
        "taskType": "PREDICTION",
        "targetVariable": target_variable,
        "backendType": ml_backend_type,
        "guessPolicy": guess_policy,
    }
    # Only send a prediction type when the caller explicitly forces one.
    if prediction_type is not None:
        payload["predictionType"] = prediction_type

    created = self.client._perform_json(
        "POST", "/projects/%s/models/lab/" % self.project_key, body=payload)
    ml_task = DSSMLTask(self.client, self.project_key,
                        created["analysisId"], created["mlTaskId"])

    if wait_guess_complete:
        ml_task.wait_guess_complete()
    return ml_task

0 commit comments

Comments
 (0)