Skip to content

Commit d51ba05

Browse files
authored
Merge pull request #212 from dataiku/feature/sc-78242-experiment-tracking-run-details
Store the managed folder id in the MLflow artifact URI
2 parents ea2dac7 + f175f48 commit d51ba05

File tree

4 files changed

+45
-17
lines changed

4 files changed

+45
-17
lines changed

dataikuapi/dss/project.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1608,14 +1608,14 @@ def get_app_manifest(self):
16081608

16091609
# MLflow experiment tracking
16101610
########################################################
1611-
def setup_mlflow(self, managed_folder="mlflow_artifacts", host=None):
1611+
def setup_mlflow(self, managed_folder_name="mlflow_artifacts", host=None):
16121612
"""
16131613
Setup the dss-plugin for MLflow
16141614
1615-
:param str managed_folder: managed folder where artifacts are stored
1615+
:param str managed_folder_name: name of the managed folder where artifacts should be stored
16161616
:param str host: setup a custom host if the backend used is not DSS
16171617
"""
1618-
return MLflowHandle(client=self.client, project_key=self.project_key, managed_folder=managed_folder, host=host)
1618+
return MLflowHandle(client=self.client, project_key=self.project_key, managed_folder_name=managed_folder_name, host=host)
16191619

16201620
def get_mlflow_extension(self):
16211621
"""

dataikuapi/dss_plugin_mlflow/artifact_repository.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,22 @@
22
import posixpath
33
import tempfile
44
import urllib
5+
import re
56
from dataikuapi import DSSClient
67

78

89
def parse_dss_managed_folder_uri(uri):
9-
"""Parse an S3 URI, returning (bucket, path)"""
1010
parsed = urllib.parse.urlparse(uri)
1111
if parsed.scheme != "dss-managed-folder":
1212
raise Exception("Not a DSS Managed Folder URI: %s" % uri)
13-
return os.path.normpath(parsed.path)
14-
13+
pattern = re.compile("^(\w+\.)?\w{8}")
14+
if not parsed.netloc or not pattern.match(parsed.netloc):
15+
raise Exception("Could not find a managed folder id in URI: %s" % uri)
16+
return parsed
1517

1618
class PluginDSSManagedFolderArtifactRepository:
1719

1820
def __init__(self, artifact_uri):
19-
self.base_artifact_path = parse_dss_managed_folder_uri(artifact_uri)
2021
if os.environ.get("DSS_MLFLOW_APIKEY") is not None:
2122
self.client = DSSClient(
2223
os.environ.get("DSS_MLFLOW_HOST"),
@@ -28,14 +29,19 @@ def __init__(self, artifact_uri):
2829
internal_ticket=os.environ.get("DSS_MLFLOW_INTERNAL_TICKET")
2930
)
3031
self.project = self.client.get_project(os.environ.get("DSS_MLFLOW_PROJECTKEY"))
31-
managed_folders = [
32-
x["id"] for x in self.project.list_managed_folders()
33-
if x["name"] == os.environ.get("DSS_MLFLOW_MANAGED_FOLDER")
34-
]
35-
if len(managed_folders) > 0:
36-
self.managed_folder = self.project.get_managed_folder(managed_folders[0])
32+
parsed_uri = parse_dss_managed_folder_uri(artifact_uri)
33+
self.managed_folder = self.__get_managed_folder(parsed_uri.netloc)
34+
self.base_artifact_path = os.path.normpath(parsed_uri.path)
35+
36+
def __get_managed_folder(self, managed_folder_smart_id):
37+
chunks = managed_folder_smart_id.split('.')
38+
if len(chunks) == 1:
39+
return self.project.get_managed_folder(chunks[0])
40+
elif len(chunks) == 2:
41+
project = self.client.get_project(chunks[0])
42+
return project.get_managed_folder(chunks[1])
3743
else:
38-
self.managed_folder = self.project.create_managed_folder(os.environ.get("DSS_MLFLOW_MANAGED_FOLDER"))
44+
raise Exception("Invalid managed folder id: %s" % managed_folder_smart_id)
3945

4046
def log_artifact(self, local_file, artifact_path=None):
4147
"""

dataikuapi/dss_plugin_mlflow/header_provider.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ def request_headers(self):
1111
headers = {
1212
os.environ.get("DSS_MLFLOW_HEADER"): os.environ.get("DSS_MLFLOW_TOKEN"),
1313
"x-dku-mlflow-project-key": os.environ.get("DSS_MLFLOW_PROJECTKEY"),
14+
"x-dku-mlflow-managed-folder-id": os.environ.get("DSS_MLFLOW_MANAGED_FOLDER_ID"),
1415
}
1516
return headers

dataikuapi/dss_plugin_mlflow/utils.py

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
class MLflowHandle:
9-
def __init__(self, client, project_key, managed_folder="mlflow_artifacts", host=None):
9+
def __init__(self, client, project_key, managed_folder_name, host=None):
1010
""" Add the MLflow-plugin parts of dataikuapi to MLflow local setup.
1111
1212
This method deals with
@@ -61,9 +61,30 @@ def __init__(self, client, project_key, managed_folder="mlflow_artifacts", host=
6161
self.mlflow_env.update({
6262
"DSS_MLFLOW_PROJECTKEY": project_key,
6363
"MLFLOW_TRACKING_URI": self.client.host + "/dip/publicapi" if host is None else host,
64-
"DSS_MLFLOW_HOST": self.client.host,
65-
"DSS_MLFLOW_MANAGED_FOLDER": managed_folder
64+
"DSS_MLFLOW_HOST": self.client.host
6665
})
66+
67+
# Get or create managed folder with name 'managed_folder_name'
68+
project = self.client.get_project(project_key)
69+
managed_folders = [
70+
x["id"] for x in project.list_managed_folders()
71+
if x["name"] == managed_folder_name
72+
]
73+
managed_folder = None
74+
if len(managed_folders) > 0:
75+
managed_folder = project.get_managed_folder(managed_folders[0])
76+
else:
77+
managed_folder = project.create_managed_folder(managed_folder_name)
78+
79+
80+
# TODO: don't allow to pass a managed folder name, everything is already
81+
# wired in artifact_repository and in the backend for DSS_MLFLOW_MANAGED_FOLDER_ID to contain
82+
# a smart ref. So, instead of managed_folder_name, we should take a DSSManagedFolder
83+
# or a managed folder id/smart ref and set DSS_MLFLOW_MANAGED_FOLDER_ID to that (smart) id.
84+
self.mlflow_env.update({
85+
"DSS_MLFLOW_MANAGED_FOLDER_ID": managed_folder.id
86+
})
87+
6788
os.environ.update(self.mlflow_env)
6889

6990
def clear(self):

0 commit comments

Comments
 (0)