Skip to content

Commit 6c5ecc1

Browse files
committed
beef up the listing of trained models
1 parent 89b9358 commit 6c5ecc1

File tree

1 file changed

+29
-8
lines changed

1 file changed

+29
-8
lines changed

dataikuapi/dss/ml.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -510,7 +510,8 @@ def wait_train_complete(self):
510510
break
511511
time.sleep(2)
512512

513-
def get_trained_models_ids(self):
513+
514+
def get_trained_models_ids(self, session_id=None, algorithm=None):
514515
"""
515516
Gets the list of trained model identifiers for this ML task.
516517
@@ -519,23 +520,43 @@ def get_trained_models_ids(self):
519520
:return: A list of model identifiers
520521
:rtype: list of strings
521522
"""
522-
status = self.get_status()
523-
return [x["id"] for x in status["fullModelIds"]]
523+
full_model_ids = self.get_status()["fullModelIds"]
524+
if session_id is not None:
525+
full_model_ids = [fmi for fmi in full_model_ids if fmi.get('fullModelId', {}).get('sessionId', '') == session_id]
526+
model_ids = [x["id"] for x in full_model_ids]
527+
if algorithm is not None:
528+
# algorithm is in the snippets
529+
model_ids = [fmi for fmi, s in self.get_trained_model_snippet(ids=model_ids).iteritems() if s.get("algorithm", "") == algorithm]
530+
return model_ids
524531

525532

526-
def get_trained_model_snippet(self, id):
533+
def get_trained_model_snippet(self, id=None, ids=None):
527534
"""
528535
Gets a quick summary of a trained model, as a dict. For complete information and a structured object, use :meth:get_trained_model_details
529536
537+
:param str id: a model id
538+
:param list ids: a list of model ids
539+
530540
:rtype: dict
531541
"""
532-
obj = {
533-
"modelsIds" : [id]
534-
}
542+
if id is not None:
543+
obj = {
544+
"modelsIds" : [id]
545+
}
546+
elif ids is not None:
547+
obj = {
548+
"modelsIds" : ids
549+
}
550+
else:
551+
obj = {}
552+
535553
ret = self.client._perform_json(
536554
"POST", "/projects/%s/models/lab/%s/%s/models-snippets" % (self.project_key, self.analysis_id, self.mltask_id),
537555
body = obj)
538-
return ret[id]
556+
if id is not None:
557+
return ret[id]
558+
else:
559+
return ret
539560

540561
def get_trained_model_details(self, id):
541562
"""

0 commit comments

Comments
 (0)