
Commit dabb26f

ML API: add synchronous method for create / train / ensemble
1 parent: aa5d4ea

File tree: 3 files changed (+101, -37 lines)

  dataikuapi/dss/ml.py
  dataikuapi/dss/project.py
  dataikuapi/dssclient.py


dataikuapi/dss/ml.py

Lines changed: 57 additions & 11 deletions
@@ -664,7 +664,8 @@ def __init__(self, client, project_key, analysis_id, mltask_id):
 
     def wait_guess_complete(self):
         """
-        Waits for guess to be complete. This should be called immediately after the creation of a new ML Task,
+        Waits for guess to be complete. This should be called immediately after the creation of a new ML Task
+        (if the ML Task was created with wait_guess_complete=False),
         before calling ``get_settings`` or ``train``
         """
         while True:
@@ -699,22 +700,49 @@ def get_settings(self):
         else:
             return DSSClusteringMLTaskSettings(self.client, self.project_key, self.analysis_id, self.mltask_id, settings)
 
-    def start_train(self, session_name=None, session_description=None):
+    def train(self, session_name=None, session_description=None):
         """
-        Starts asynchronously a new train session for this ML Task.
-
+        Trains models for this ML Task
+
         :param str session_name: name for the session
         :param str session_description: description for the session
-
-        This returns immediately, before train is complete. To wait for train to complete, use ``wait_train_complete()``
+
+        This method waits for train to complete. If you want to train asynchronously, use :meth:`start_train` and :meth:`wait_train_complete`
+
+        This method returns the list of trained model identifiers. It returns models that have been trained for this train
+        session, not all trained models for this ML task. To get all identifiers for all models trained across all training sessions,
+        use :meth:`get_trained_models_ids`
+
+        These identifiers can be used for :meth:`get_trained_model_snippet`, :meth:`get_trained_model_details` and :meth:`deploy_to_flow`
+
+        :return: A list of model identifiers
+        :rtype: list of strings
         """
-        session_info = {
-            "sessionName" : session_name,
-            "sessionDescription" : session_description
-        }
+        train_ret = self.start_train(session_name, session_description)
+        self.wait_train_complete()
+        return self.get_trained_models_ids(session_id = train_ret["sessionId"])
 
-        return self.client._perform_json(
-            "POST", "/projects/%s/models/lab/%s/%s/train" % (self.project_key, self.analysis_id, self.mltask_id), body=session_info)
+    def ensemble(self, model_ids=[], method=None):
+        """
+        Create an ensemble model of a set of models
+
+        :param list model_ids: A list of model identifiers
+        :param str method: the ensembling method. One of: AVERAGE, PROBA_AVERAGE, MEDIAN, VOTE, LINEAR_MODEL, LOGISTIC_MODEL
+
+        This method waits for the ensemble train to complete. If you want to train asynchronously, use :meth:`start_ensembling` and :meth:`wait_train_complete`
+
+        This method returns the identifier of the trained ensemble.
+        To get all identifiers for all models trained across all training sessions,
+        use :meth:`get_trained_models_ids`
+
+        This identifier can be used for :meth:`get_trained_model_snippet`, :meth:`get_trained_model_details` and :meth:`deploy_to_flow`
+
+        :return: A model identifier
+        :rtype: string
+        """
+        train_ret = self.start_ensembling(model_ids, method)
+        self.wait_train_complete()
+        return train_ret
 
     def start_ensembling(self, model_ids=[], method=None):
         """
@@ -736,6 +764,24 @@ def start_ensembling(self, model_ids=[], method=None):
         return self.client._perform_json(
             "POST", "/projects/%s/models/lab/%s/%s/ensemble" % (self.project_key, self.analysis_id, self.mltask_id), body=ensembling_request)['id']
 
+
+    def start_train(self, session_name=None, session_description=None):
+        """
+        Starts asynchronously a new train session for this ML Task.
+
+        :param str session_name: name for the session
+        :param str session_description: description for the session
+
+        This returns immediately, before train is complete. To wait for train to complete, use ``wait_train_complete()``
+        """
+        session_info = {
+            "sessionName" : session_name,
+            "sessionDescription" : session_description
+        }
+
+        return self.client._perform_json(
+            "POST", "/projects/%s/models/lab/%s/%s/train" % (self.project_key, self.analysis_id, self.mltask_id), body=session_info)
+
     def wait_train_complete(self):
         """
         Waits for train to be complete.
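
Taken together, the ml.py changes make ``train`` and ``ensemble`` blocking convenience wrappers around ``start_train`` / ``start_ensembling`` plus ``wait_train_complete``. A minimal usage sketch follows; the ``mltask`` handle, session labels and choice of ensembling method are illustrative and not part of this commit.

# Assumes an existing DSSMLTask handle, e.g. the one returned by
# create_prediction_ml_task() in the project.py changes below.

# Synchronous train: blocks until the session finishes and returns only the
# identifiers of the models trained in this session.
ids = mltask.train(session_name="baseline", session_description="first run")

# Synchronous ensemble of those models: blocks until the ensemble is trained
# and returns a single model identifier.
ensemble_id = mltask.ensemble(model_ids=ids, method="AVERAGE")

# The identifiers can then be passed to get_trained_model_snippet(),
# get_trained_model_details() or deploy_to_flow().

# The previous asynchronous flow is still available:
mltask.start_train(session_name="async session")
mltask.wait_train_complete()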

dataikuapi/dss/project.py

Lines changed: 27 additions & 24 deletions
@@ -133,25 +133,30 @@ def get_dataset(self, dataset_name):
         """
         Get a handle to interact with a specific dataset
 
-        Args:
-            dataset_name: the name of the desired dataset
+        :param string dataset_name: the name of the desired dataset
 
-        Returns:
-            A :class:`dataikuapi.dss.dataset.DSSDataset` dataset handle
+        :returns: A :class:`dataikuapi.dss.dataset.DSSDataset` dataset handle
         """
         return DSSDataset(self.client, self.project_key, dataset_name)
 
     def create_dataset(self, dataset_name, type,
                 params={}, formatType=None, formatParams={}):
         """
-        Create a new dataset in the project, and return a handle to interact with it
+        Create a new dataset in the project, and return a handle to interact with it.
+
+        The precise structure of ``params`` and ``formatParams`` depends on the specific dataset
+        type and dataset format type. To know which fields exist for a given dataset type and format type,
+        create a dataset from the UI, and use :meth:`get_dataset` to retrieve the configuration
+        of the dataset and inspect it. Then reproduce a similar structure in the :meth:`create_dataset` call.
+
+        Not all settings of a dataset can be set at creation time (for example partitioning). After creation,
+        you'll have the ability to modify the dataset
 
-        Args:
-            dataset_name: the name for the new dataset
-            type: the type of the dataset
-            params: the parameters for the type, as a JSON object
-            formatType: an optional format to create the dataset with
-            formatParams: the parameters to the format, as a JSON object
+        :param string dataset_name: the name for the new dataset
+        :param string type: the type of the dataset
+        :param dict params: the parameters for the type, as a JSON object
+        :param string formatType: an optional format to create the dataset with (only for file-oriented datasets)
+        :param string formatParams: the parameters to the format, as a JSON object (only for file-oriented datasets)
 
         Returns:
             A :class:`dataikuapi.dss.dataset.DSSDataset` dataset handle
@@ -173,25 +178,20 @@ def create_dataset(self, dataset_name, type,
     ########################################################
 
     def create_prediction_ml_task(self, input_dataset, target_variable,
-            ml_backend_type = "PY_MEMORY",
-            guess_policy = "DEFAULT"):
-
+            ml_backend_type = "PY_MEMORY",
+            guess_policy = "DEFAULT",
+            wait_guess_complete=True):
 
         """Creates a new prediction task in a new visual analysis lab
         for a dataset.
 
-
-        The returned ML task will be in 'guessing' state, i.e. analyzing
-        the input dataset to determine feature handling and algorithms.
-
-        You should wait for the guessing to be completed by calling
-        ``wait_guess_complete`` on the returned object before doing anything
-        else (in particular calling ``train`` or ``get_settings``)
-
         :param string ml_backend_type: ML backend to use, one of PY_MEMORY, MLLIB or H2O
         :param string guess_policy: Policy to use for setting the default parameters. Valid values are: DEFAULT, SIMPLE_FORMULA, DECISION_TREE, EXPLANATORY and PERFORMANCE
+        :param boolean wait_guess_complete: if False, the returned ML task will be in 'guessing' state, i.e. analyzing the input dataset to determine feature handling and algorithms.
+            You should wait for the guessing to be completed by calling
+            ``wait_guess_complete`` on the returned object before doing anything
+            else (in particular calling ``train`` or ``get_settings``)
         """
-
         obj = {
             "inputDataset" : input_dataset,
             "taskType" : "PREDICTION",
@@ -201,7 +201,10 @@ def create_prediction_ml_task(self, input_dataset, target_variable,
         }
 
         ref = self.client._perform_json("POST", "/projects/%s/models/lab/" % self.project_key, body=obj)
-        return DSSMLTask(self.client, self.project_key, ref["analysisId"], ref["mlTaskId"])
+        ret = DSSMLTask(self.client, self.project_key, ref["analysisId"], ref["mlTaskId"])
+        if wait_guess_complete:
+            ret.wait_guess_complete()
+        return ret
 
     def create_clustering_ml_task(self, input_dataset,
                 ml_backend_type = "PY_MEMORY",
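
On the project side, the new ``wait_guess_complete`` flag makes the common path a one-liner. A short sketch under stated assumptions: the ``client`` / ``project`` handles already exist (e.g. ``project = client.get_project("MYPROJECT")``), and the dataset name, type and params are placeholders whose exact structure depends on the dataset type, as the updated docstring explains.

# Hypothetical dataset creation; params/formatParams vary by dataset type.
dataset = project.create_dataset("mydataset", "Filesystem",
        params={"connection": "filesystem_root", "path": "input/data"},
        formatType="csv")

# With the new default wait_guess_complete=True, the returned ML task has
# already finished guessing and can be used immediately.
mltask = project.create_prediction_ml_task("mydataset", "target_column",
        guess_policy="DEFAULT")
settings = mltask.get_settings()
model_ids = mltask.train()

# Passing wait_guess_complete=False restores the previous behaviour; the
# caller must then call mltask.wait_guess_complete() before get_settings()
# or train().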

dataikuapi/dssclient.py

Lines changed: 17 additions & 2 deletions
@@ -596,19 +596,31 @@ def get_general_settings(self):
     ########################################################
 
     def create_project_from_bundle_local_archive(self, archive_path):
+        """
+        Create a project from a bundle archive.
+        Warning: this method can only be used on an automation node.
+
+        :param string archive_path: Path on the local machine where the archive is
+        """
         return self._perform_json("POST",
                 "/projectsFromBundle/fromArchive",
                 params = { "archivePath" : osp.abspath(archive_path) })
 
     def create_project_from_bundle_archive(self, fp):
+        """
+        Create a project from a bundle archive (as a file object)
+        Warning: this method can only be used on an automation node.
+
+        :param string fp: A file-like object pointing to a bundle archive zip
+        """
         files = {'file': fp }
         return self._perform_json("POST",
                 "/projectsFromBundle/", files=files)
 
-
     def prepare_project_import(self, f):
         """
-        Prepares import of a project archive
+        Prepares import of a project archive.
+        Warning: this method can only be used on a design node.
 
         :param file-like fp: the input stream, as a file-like object
         :returns: a :class:`TemporaryImportHandle` to interact with the prepared import
@@ -624,6 +636,9 @@ def prepare_project_import(self, f):
     ########################################################
 
     def catalog_index_connections(self, connection_names=[], all_connections=False, indexing_mode="FULL"):
+        """
+        Triggers an indexing of multiple connections in the data catalog
+        """
         return self._perform_json("POST", "/catalog/index", body={
             "connectionNames": connection_names,
             "indexAllConnections": all_connections,
