Skip to content

Commit 5cc5faa

Browse files
committed
Initial support for model public API
1 parent 33d81f0 commit 5cc5faa

File tree

2 files changed

+179
-0
lines changed

2 files changed

+179
-0
lines changed

dataikuapi/dss/ml.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
from ..utils import DataikuException
2+
from ..utils import DataikuUTF8CSVReader
3+
from ..utils import DataikuStreamedHttpUTF8CSVReader
4+
import json
5+
import time
6+
from .metrics import ComputedMetrics
7+
8+
class DSSMLTaskSettings(object):
    """Object to read and modify the settings of a ML Task.

    The settings are held as the raw structure given to the constructor.
    All helpers below mutate that structure in place; nothing is sent to
    the server until ``save`` is called.
    """

    # Algorithms whose entry under "modeling" is not simply the
    # lower-cased algorithm name.
    _ALGORITHM_REMAP = {
        "SVC_CLASSIFICATION": "svc_classifier",
        "SGD_CLASSIFICATION": "sgd_classifier",
        "SPARKLING_DEEP_LEARNING": "deep_learning_sparkling",
        "SPARKLING_GBM": "gbm_sparkling",
        "SPARKLING_RF": "rf_sparkling",
        "SPARKLING_GLM": "glm_sparkling",
        "SPARKLING_NB": "nb_sparkling",
        "XGBOOST_CLASSIFICATION": "xgboost",
        "XGBOOST_REGRESSION": "xgboost",
        "MLLIB_LOGISTIC_REGRESSION": "mllib_logit",
        "MLLIB_LINEAR_REGRESSION": "mllib_linreg",
        "MLLIB_RANDOM_FOREST": "mllib_rf",
    }

    def __init__(self, client, project_key, analysis_id, mltask_id, mltask_settings):
        """Stores the client, the identifiers of the ML Task and its raw settings.

        :param client: API client used by ``save`` to post the settings back
        :param project_key: key of the project containing the ML Task
        :param analysis_id: identifier of the analysis containing the ML Task
        :param mltask_id: identifier of the ML Task
        :param mltask_settings: raw settings of the ML Task
        """
        self.client = client
        self.project_key = project_key
        self.analysis_id = analysis_id
        self.mltask_id = mltask_id
        self.mltask_settings = mltask_settings

    def get_raw(self):
        """Gets the raw settings.

        This returns a reference to the raw settings, not a copy.
        """
        return self.mltask_settings

    def get_feature_preprocessing(self, feature_name):
        """Gets the preprocessing settings of a feature.

        This returns a reference to the entry stored under
        ``preprocessing`` / ``per_feature`` in the raw settings, not a copy.

        :param feature_name: name of the feature
        :return: the preprocessing settings of that feature
        """
        return self.mltask_settings["preprocessing"]["per_feature"][feature_name]

    def reject_feature(self, feature_name):
        """Sets the role of a feature to ``REJECT``.

        :param feature_name: name of the feature
        """
        self.get_feature_preprocessing(feature_name)["role"] = "REJECT"

    def use_feature(self, feature_name):
        """Sets the role of a feature to ``INPUT``.

        :param feature_name: name of the feature
        """
        self.get_feature_preprocessing(feature_name)["role"] = "INPUT"

    def get_algorithm_settings(self, algorithm_name):
        """Gets the settings of an algorithm.

        The name is first translated through the remapping table, then
        lower-cased, and the result is used as key under ``modeling`` in
        the raw settings. This returns a reference, not a copy.

        :param algorithm_name: name of the algorithm, e.g. ``XGBOOST_REGRESSION``
        :return: the settings of that algorithm
        """
        settings_key = self._ALGORITHM_REMAP.get(algorithm_name, algorithm_name)
        return self.mltask_settings["modeling"][settings_key.lower()]

    def set_algorithm_enabled(self, algorithm_name, enabled):
        """Sets the ``enabled`` flag of an algorithm.

        :param algorithm_name: name of the algorithm, as for ``get_algorithm_settings``
        :param enabled: new value of the flag
        """
        self.get_algorithm_settings(algorithm_name)["enabled"] = enabled

    def save(self):
        """Saves back these settings to the ML Task"""
        # No debug dump of the settings here: a library call must not write
        # the whole settings document to stdout each time it saves.
        self.client._perform_empty(
            "POST",
            "/projects/%s/models/lab/%s/%s/settings" % (self.project_key, self.analysis_id, self.mltask_id),
            body=self.mltask_settings)
62+
63+
class DSSMLTask(object):
    """Handle to interact with a ML Task of a visual analysis lab.

    Every method issues its request through the ``client`` given to the
    constructor, on a path built from the project key, the analysis id and
    the ML Task id.
    """

    def __init__(self, client, project_key, analysis_id, mltask_id):
        self.client = client
        self.project_key = project_key
        self.analysis_id = analysis_id
        self.mltask_id = mltask_id

    def _endpoint(self, suffix):
        """Builds the API path of this ML Task followed by ``suffix``."""
        return "/projects/%s/models/lab/%s/%s/%s" % (
            self.project_key, self.analysis_id, self.mltask_id, suffix)

    def wait_guess_complete(self):
        """
        Waits for guess to be complete. This should be called immediately after the creation of a new ML Task,
        before calling ``get_settings`` or ``start_train``

        The status is fetched every 0.2 second until its ``guessing`` entry is False.
        """
        # Explicit comparison with False: a status without a "guessing"
        # entry is not taken as "guess complete" and keeps the loop waiting.
        while self.get_status().get("guessing", "???") != False:
            time.sleep(0.2)

    def wait_train_complete(self):
        """
        Waits for train to be complete.

        The status is fetched every 2 seconds until its ``training`` entry is False.
        """
        # Same convention as wait_guess_complete: a missing "training"
        # entry keeps the loop waiting.
        while self.get_status().get("training", "???") != False:
            time.sleep(2)

    def get_status(self):
        """
        Gets the status of this ML Task

        :return: a dict
        """
        return self.client._perform_json("GET", self._endpoint("status"))

    def get_settings(self):
        """
        Gets the settings of this ML Task

        :return: a DSSMLTaskSettings object to interact with the settings
        """
        raw_settings = self.client._perform_json("GET", self._endpoint("settings"))
        return DSSMLTaskSettings(
            self.client, self.project_key, self.analysis_id, self.mltask_id, raw_settings)

    def start_train(self):
        """Starts asynchronously a new train session for this ML Task.

        This returns immediately, before train is complete. To wait for train to complete,
        call ``wait_train_complete``, or poll on ``get_status`` until ``training`` is False"""
        self.client._perform_empty("POST", self._endpoint("train"))

    def get_trained_models_ids(self):
        """Gets the ``id`` of each entry listed under ``fullModelIds`` in the status.

        :return: a list of ids, in the order of the status
        """
        full_ids = self.get_status()["fullModelIds"]
        return [full_id["id"] for full_id in full_ids]

    def get_trained_model_summary(self, id):
        """Gets the summary of a single model.

        :param id: id of the model, for example one returned by ``get_trained_models_ids``
        :return: the entry for ``id`` in the summaries returned by the server
        """
        summaries = self.client._perform_json(
            "POST", self._endpoint("models-summaries"), body={"modelsIds": [id]})
        return summaries[id]

    def deploy_to_flow(self, model_id, model_name, train_dataset, test_dataset=None, redo_optimization=True):
        """Requests the ``deployToFlow`` action on a model of this ML Task.

        :param model_id: id of the model to deploy
        :param model_name: sent as ``modelName``
        :param train_dataset: sent as ``trainDatasetRef``
        :param test_dataset: sent as ``testDatasetRef`` (None by default)
        :param redo_optimization: sent as ``redoOptimization`` (True by default)
        :return: the JSON answer of the server
        """
        deployment_request = {
            "trainDatasetRef": train_dataset,
            "testDatasetRef": test_dataset,
            "modelName": model_name,
            "redoOptimization": redo_optimization,
        }
        action_path = self._endpoint("models/%s/actions/deployToFlow" % (model_id,))
        return self.client._perform_json("POST", action_path, body=deployment_request)
144+

dataikuapi/dss/project.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .future import DSSFuture
1212
from .notebook import DSSNotebook
1313
from .macro import DSSMacro
14+
from .ml import DSSMLTask
1415
from dataikuapi.utils import DataikuException
1516

1617

@@ -167,6 +168,40 @@ def create_dataset(self, dataset_name, type,
167168
body = obj)
168169
return DSSDataset(self.client, self.project_key, dataset_name)
169170

171+
########################################################
172+
# ML
173+
########################################################
174+
175+
def create_prediction_lab_task(self, input_dataset, target_variable,
                               ml_backend_type="PY_MEMORY",
                               guess_policy="DEFAULT"):
    """Creates a new prediction task in a new visual analysis lab
    for a dataset.

    The returned ML task will be in 'guessing' state, i.e. analyzing
    the input dataset to determine feature handling and algorithms.

    You should wait for the guessing to be completed by calling
    ``wait_guess_complete`` on the returned object before doing anything
    else (in particular calling ``start_train`` or ``get_settings``)

    :param input_dataset: sent as ``inputDataset``
    :param target_variable: sent as ``targetVariable``
    :param ml_backend_type: sent as ``backendType`` (``PY_MEMORY`` by default)
    :param guess_policy: sent as ``guessPolicy`` (``DEFAULT`` by default)
    :return: a DSSMLTask built from the ``analysisId`` and ``mlTaskId`` of the server answer
    """
    creation_request = {
        "inputDataset": input_dataset,
        "taskType": "PREDICTION",
        "targetVariable": target_variable,
        "backendType": ml_backend_type,
        "guessPolicy": guess_policy,
    }
    lab_path = "/projects/%s/models/lab/" % self.project_key
    created = self.client._perform_json("POST", lab_path, body=creation_request)
    return DSSMLTask(self.client, self.project_key, created["analysisId"], created["mlTaskId"])
202+
203+
204+
170205

171206
########################################################
172207
# Saved models

0 commit comments

Comments
 (0)