Skip to content

Commit d2ac280

Browse files
committed
Merge PR #32 predictionType option creating MLTask
from feature/dss50-create-ml-task-with-prediction-type
2 parents db2fc7e + ca03f21 commit d2ac280

File tree

3 files changed

+51
-24
lines changed

3 files changed

+51
-24
lines changed

dataikuapi/dss/analysis.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -152,24 +152,24 @@ def set_definition(self, definition):
152152
# ML
153153
########################################################
154154

155-
def create_prediction_ml_task(self, target_variable,
156-
ml_backend_type = "PY_MEMORY",
157-
guess_policy = "DEFAULT"):
158-
159-
155+
def create_prediction_ml_task(self,
156+
target_variable,
157+
ml_backend_type="PY_MEMORY",
158+
guess_policy="DEFAULT",
159+
prediction_type=None,
160+
wait_guess_complete=True):
160161
"""Creates a new prediction task in this visual analysis lab
161162
for a dataset.
162163
163-
164-
The returned ML task will be in 'guessing' state, i.e. analyzing
165-
the input dataset to determine feature handling and algorithms.
166-
167-
You should wait for the guessing to be completed by calling
168-
``wait_guess_complete`` on the returned object before doing anything
169-
else (in particular calling ``train`` or ``get_settings``)
170-
164+
:param string target_variable: Variable to predict
171165
:param string ml_backend_type: ML backend to use, one of PY_MEMORY, MLLIB or H2O
172166
:param string guess_policy: Policy to use for setting the default parameters. Valid values are: DEFAULT, SIMPLE_FORMULA, DECISION_TREE, EXPLANATORY and PERFORMANCE
167+
:param string prediction_type: The type of prediction problem this is. If not provided the prediction type will be guessed. Valid values are: BINARY_CLASSIFICATION, REGRESSION, MULTICLASS
168+
:param boolean wait_guess_complete: if False, the returned ML task will be in 'guessing' state, i.e. analyzing the input dataset to determine feature handling and algorithms.
169+
You should wait for the guessing to be completed by calling
170+
``wait_guess_complete`` on the returned object before doing anything
171+
else (in particular calling ``train`` or ``get_settings``)
172+
:return: the created ML task, as a :class:`dataikuapi.dss.ml.DSSMLTask`
173173
"""
174174

175175
obj = {
@@ -178,9 +178,14 @@ def create_prediction_ml_task(self, target_variable,
178178
"backendType": ml_backend_type,
179179
"guessPolicy": guess_policy
180180
}
181-
181+
if prediction_type is not None:
182+
obj["predictionType"] = prediction_type
182183
ref = self.client._perform_json("POST", "/projects/%s/lab/%s/models/" % (self.project_key, self.analysis_id), body=obj)
183-
return DSSMLTask(self.client, self.project_key, self.analysis_id, ref["mlTaskId"])
184+
mltask = DSSMLTask(self.client, self.project_key, self.analysis_id, ref["mlTaskId"])
185+
186+
if wait_guess_complete:
187+
mltask.wait_guess_complete()
188+
return mltask
184189

185190
def create_clustering_ml_task(self,
186191
ml_backend_type = "PY_MEMORY",

dataikuapi/dss/ml.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def set_split_explicit(self, train_selection, test_selection, dataset_name=None,
7878
sp['efsdDatasetSmartName'] = dataset_name
7979
sp['efsdTrain'] = train_split
8080
sp['efsdTest'] = test_split
81-
else:
81+
else:
8282
sp["ttPolicy"] = "EXPLICIT_FILTERING_TWO_DATASETS"
8383
train_split ={'datasetSmartName' : dataset_name}
8484
test_split = {'datasetSmartName' : test_dataset_name}
@@ -373,7 +373,7 @@ def get_split_info(self):
373373
info['nSamples'] = nSamples[self.i] if nSamples is not None else None
374374
info['threshold'] = thresholds[self.i] if thresholds is not None else None
375375
return info
376-
376+
377377
class DSSTree(object):
378378
def __init__(self, tree, feature_names):
379379
self.tree = tree
@@ -677,7 +677,7 @@ def delete(self):
677677
"""
678678
return self.client._perform_json(
679679
"DELETE", "/projects/%s/models/lab/%s/%s/" % (self.project_key, self.analysis_id, self.mltask_id))
680-
680+
681681

682682
def wait_guess_complete(self):
683683
"""
@@ -700,7 +700,7 @@ def get_status(self):
700700
"""
701701
return self.client._perform_json(
702702
"GET", "/projects/%s/models/lab/%s/%s/status" % (self.project_key, self.analysis_id, self.mltask_id))
703-
703+
704704

705705
def get_settings(self):
706706
"""
@@ -921,3 +921,17 @@ def redeploy_to_flow(self, model_id, recipe_name=None, saved_model_id=None, acti
921921
"POST", "/projects/%s/models/lab/%s/%s/models/%s/actions/redeployToFlow" % (self.project_key, self.analysis_id, self.mltask_id, model_id),
922922
body = obj)
923923

924+
def guess(self, prediction_type=None):
925+
"""
926+
Guess the feature handling and the algorithms.
927+
:param string prediction_type: In case of a prediction problem, the prediction type can be specified. Valid values are BINARY_CLASSIFICATION, REGRESSION, MULTICLASS.
928+
"""
929+
obj = {}
930+
if prediction_type is not None:
931+
obj["predictionType"] = prediction_type
932+
933+
self.client._perform_empty(
934+
"PUT",
935+
"/projects/%s/models/lab/%s/%s/guess" % (self.project_key, self.analysis_id, self.mltask_id),
936+
params = obj)
937+

dataikuapi/dss/project.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -190,30 +190,38 @@ def create_dataset(self, dataset_name, type,
190190
########################################################
191191

192192
def create_prediction_ml_task(self, input_dataset, target_variable,
193-
ml_backend_type = "PY_MEMORY",
194-
guess_policy = "DEFAULT",
193+
ml_backend_type="PY_MEMORY",
194+
guess_policy="DEFAULT",
195+
prediction_type=None,
195196
wait_guess_complete=True):
196197

197198
"""Creates a new prediction task in a new visual analysis lab
198199
for a dataset.
199200
201+
:param string input_dataset: the dataset to use for training/testing the model
202+
:param string target_variable: the variable to predict
200203
:param string ml_backend_type: ML backend to use, one of PY_MEMORY, MLLIB or H2O
201204
:param string guess_policy: Policy to use for setting the default parameters. Valid values are: DEFAULT, SIMPLE_FORMULA, DECISION_TREE, EXPLANATORY and PERFORMANCE
205+
:param string prediction_type: The type of prediction problem this is. If not provided the prediction type will be guessed. Valid values are: BINARY_CLASSIFICATION, REGRESSION, MULTICLASS
202206
:param boolean wait_guess_complete: if False, the returned ML task will be in 'guessing' state, i.e. analyzing the input dataset to determine feature handling and algorithms.
203207
You should wait for the guessing to be completed by calling
204208
``wait_guess_complete`` on the returned object before doing anything
205209
else (in particular calling ``train`` or ``get_settings``)
206210
"""
207211
obj = {
208-
"inputDataset" : input_dataset,
209-
"taskType" : "PREDICTION",
210-
"targetVariable" : target_variable,
212+
"inputDataset": input_dataset,
213+
"taskType": "PREDICTION",
214+
"targetVariable": target_variable,
211215
"backendType": ml_backend_type,
212216
"guessPolicy": guess_policy
213217
}
214218

219+
if prediction_type is not None:
220+
obj["predictionType"] = prediction_type
221+
215222
ref = self.client._perform_json("POST", "/projects/%s/models/lab/" % self.project_key, body=obj)
216223
ret = DSSMLTask(self.client, self.project_key, ref["analysisId"], ref["mlTaskId"])
224+
217225
if wait_guess_complete:
218226
ret.wait_guess_complete()
219227
return ret

0 commit comments

Comments
 (0)