Skip to content

Commit 00612fc

Browse files
author
Louis P
authored
Merge pull request #82 from dataiku/feature/dss80-get-origin-mltask
Add two methods to facilitate navigation from SM's objects to analysis' objects
2 parents 83e9ef4 + f31ed4a commit 00612fc

File tree

2 files changed

+47
-1
lines changed

2 files changed

+47
-1
lines changed

dataikuapi/dss/ml.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import re
2+
13
from ..utils import DataikuException
24
from ..utils import DataikuUTF8CSVReader
35
from ..utils import DataikuStreamedHttpUTF8CSVReader
@@ -524,6 +526,22 @@ def save_user_meta(self):
524526
"PUT", "/projects/%s/savedmodels/%s/versions/%s/user-meta" % (self.saved_model.project_key,
525527
self.saved_model.sm_id, self.saved_model_version), body = um)
526528

529+
def get_origin_analysis_trained_model(self):
530+
"""
531+
Fetch details about the model in an analysis, this model has been exported from. Returns None if the
532+
deployed trained model does not have an origin analysis trained model.
533+
534+
:rtype: DSSTrainedModelDetails | None
535+
"""
536+
if self.saved_model is None:
537+
return self
538+
else:
539+
fmi = self.get_raw().get("smOrigin", {}).get("fullModelId")
540+
if fmi is not None:
541+
origin_ml_task = DSSMLTask.from_full_model_id(self.saved_model.client, fmi,
542+
project_key=self.saved_model.project_key)
543+
return origin_ml_task.get_trained_model_details(fmi)
544+
527545
class DSSTreeNode(object):
528546
def __init__(self, tree, i):
529547
self.tree = tree
@@ -1425,6 +1443,17 @@ def get_scatter_plots(self):
14251443

14261444

14271445
class DSSMLTask(object):
1446+
1447+
@staticmethod
1448+
def from_full_model_id(client, fmi, project_key=None):
1449+
match = re.match("^A-(\w+)-(\w+)-(\w+)-(s[0-9]+)-(pp[0-9]+(-part-(\w+)|-base)?)-(m[0-9]+)$", fmi)
1450+
if match is None:
1451+
return DataikuException("Invalid model id: {}".format(fmi))
1452+
else:
1453+
if project_key is None:
1454+
project_key = match.group(1)
1455+
return DSSMLTask(client, project_key, match.group(2), match.group(3))
1456+
14281457
"""A handle to interact with a MLTask for prediction or clustering in a DSS visual analysis"""
14291458
def __init__(self, client, project_key, analysis_id, mltask_id):
14301459
self.client = client

dataikuapi/dss/savedmodel.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .ml import DSSTrainedPredictionModelDetails, DSSTrainedClusteringModelDetails
1+
from dataikuapi.dss.ml import DSSMLTask
22
from .metrics import ComputedMetrics
33
from .ml import DSSTrainedClusteringModelDetails
44
from .ml import DSSTrainedPredictionModelDetails
@@ -15,6 +15,9 @@ def __init__(self, client, project_key, sm_id):
1515
self.project_key = project_key
1616
self.sm_id = sm_id
1717

18+
def get_definition(self):
19+
return self.client._perform_json(
20+
"GET", "/projects/%s/savedmodels/%s" % (self.project_key, self.sm_id))
1821

1922
########################################################
2023
# Versions
@@ -86,6 +89,20 @@ def delete_versions(self, versions, remove_intermediate=True):
8689
self.client._perform_empty(
8790
"POST", "/projects/%s/savedmodels/%s/actions/delete-versions" % (self.project_key, self.sm_id),
8891
body=body)
92+
93+
def get_origin_ml_task(self):
94+
"""
95+
Fetch the last ML task that has been exported to this saved model. Returns None if the saved model
96+
does not have an origin ml task.
97+
98+
:rtype: DSSMLTask | None
99+
"""
100+
fmi = self.get_definition().get("lastExportedFrom")
101+
if fmi is not None:
102+
origin_ml_task = DSSMLTask.from_full_model_id(self.client, fmi, project_key=self.project_key)
103+
return origin_ml_task.get_trained_model_details(fmi)
104+
105+
89106
########################################################
90107
# Metrics
91108
########################################################

0 commit comments

Comments
 (0)