Skip to content

Commit 93e6c87

Browse files
authored
Merge PR #148 train & list ML Task Queues
from feature/dss100-mltask-queues
2 parents 93299ff + af15dd8 commit 93e6c87

File tree

2 files changed

+39
-5
lines changed

2 files changed

+39
-5
lines changed

dataikuapi/dss/ml.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3348,7 +3348,7 @@ def get_settings(self):
33483348
else:
33493349
return DSSClusteringMLTaskSettings(self.client, self.project_key, self.analysis_id, self.mltask_id, settings)
33503350

3351-
def train(self, session_name=None, session_description=None):
3351+
def train(self, session_name=None, session_description=None, run_queue=False):
33523352
"""
33533353
Trains models for this ML Task
33543354
@@ -3366,7 +3366,7 @@ def train(self, session_name=None, session_description=None):
33663366
:return: A list of model identifiers
33673367
:rtype: list of strings
33683368
"""
3369-
train_ret = self.start_train(session_name, session_description)
3369+
train_ret = self.start_train(session_name, session_description, run_queue)
33703370
self.wait_train_complete()
33713371
return self.get_trained_models_ids(session_id = train_ret["sessionId"])
33723372

@@ -3395,7 +3395,7 @@ def ensemble(self, model_ids=None, method=None):
33953395
return train_ret
33963396

33973397

3398-
def start_train(self, session_name=None, session_description=None):
3398+
def start_train(self, session_name=None, session_description=None, run_queue=False):
33993399
"""
34003400
Starts asynchronously a new train session for this ML Task.
34013401
@@ -3406,7 +3406,8 @@ def start_train(self, session_name=None, session_description=None):
34063406
"""
34073407
session_info = {
34083408
"sessionName" : session_name,
3409-
"sessionDescription" : session_description
3409+
"sessionDescription" : session_description,
3410+
"runQueue": run_queue
34103411
}
34113412

34123413
return self.client._perform_json(
@@ -3521,6 +3522,16 @@ def delete_trained_model(self, model_id):
35213522
self.client._perform_empty(
35223523
"DELETE", "/projects/%s/models/lab/%s/%s/models/%s" % (self.project_key, self.analysis_id, self.mltask_id, model_id))
35233524

3525+
def train_queue(self):
3526+
"""
3527+
Trains this MLTask's queue
3528+
3529+
:return: A dict including the next sessionID to be trained in the queue
3530+
:rtype dict
3531+
"""
3532+
return self.client._perform_json(
3533+
"POST", "/projects/%s/models/lab/%s/%s/actions/train-queue" % (self.project_key, self.analysis_id, self.mltask_id))
3534+
35243535
def deploy_to_flow(self, model_id, model_name, train_dataset, test_dataset=None, redo_optimization=True):
35253536
"""
35263537
Deploys a trained model from this ML Task to a saved model + train recipe in the Flow.
@@ -3606,3 +3617,17 @@ def guess(self, prediction_type=None, reguess_level=None):
36063617
"PUT",
36073618
"/projects/%s/models/lab/%s/%s/guess" % (self.project_key, self.analysis_id, self.mltask_id),
36083619
params = obj)
3620+
3621+
3622+
class DSSMLTaskQueues(object):
3623+
"""
3624+
Iterable listing of MLTask queues
3625+
"""
3626+
def __init__(self, data):
3627+
self.data = data
3628+
3629+
def __iter__(self):
3630+
return self.data["queues"].__iter__()
3631+
3632+
def get_raw(self):
3633+
return self.data

dataikuapi/dss/project.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .macro import DSSMacro
1717
from .wiki import DSSWiki
1818
from .discussion import DSSObjectDiscussions
19-
from .ml import DSSMLTask
19+
from .ml import DSSMLTask, DSSMLTaskQueues
2020
from .analysis import DSSAnalysis
2121
from .flow import DSSProjectFlow
2222
from .app import DSSAppManifest
@@ -638,6 +638,15 @@ def get_ml_task(self, analysis_id, mltask_id):
638638
"""
639639
return DSSMLTask(self.client, self.project_key, analysis_id, mltask_id)
640640

641+
def list_mltask_queues(self):
642+
"""
643+
List non-empty ML task queues in this project
644+
645+
:returns: an iterable :class:`DSSMLTaskQueues` listing of MLTask queues (each a dict)
646+
:rtype: :class:`DSSMLTaskQueues`
647+
"""
648+
data = self.client._perform_json("GET", "/projects/%s/models/labs/mltask-queues" % self.project_key)
649+
return DSSMLTaskQueues(data)
641650

642651
def create_analysis(self, input_dataset):
643652
"""

0 commit comments

Comments
 (0)