Skip to content

Commit 67cf94b

Browse files
Valentin Thorey and lpenet authored
Add DSS plugin for MLflow (#191)
* Add DSS plugin for MLflow * Add doc to mlflow plugin * Fix inconsistent wording * Prevent creating multiple tempdir for mlflow plugin * Set host while setting up mlflow plugin * Rename MLflow plugin for consistency * Add artifact part of the mlflow plugin * Add log_artifact to mlflow plugin * Add log_artifacts and list_artifacts * Adapt wording to backend * Add delete and download to artifact plugin * Add context manager for dss plugin * Revert formatting * Fix typo Co-authored-by: Ludovic Pénet <ludovic.penet@dataiku.com> Co-authored-by: Ludovic Pénet <ludovic.penet@dataiku.com>
1 parent b78b172 commit 67cf94b

File tree

5 files changed

+337
-0
lines changed

5 files changed

+337
-0
lines changed
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .utils import load_dss_mlflow_plugin
Lines changed: 244 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,244 @@
1+
import os
2+
import posixpath
3+
import tempfile
4+
import urllib
5+
from dataikuapi import DSSClient
6+
7+
8+
def parse_dss_managed_folder_uri(uri):
    """Return the normalized path component of a ``dss-managed-folder`` URI.

    :param uri: URI whose scheme must be ``dss-managed-folder``.
    :return: The URI path, normalized with POSIX rules (``..`` segments and
             trailing slashes are collapsed).
    :raises Exception: if the URI scheme is not ``dss-managed-folder``.
    """
    # Import the submodule explicitly: a bare ``import urllib`` does not
    # guarantee that ``urllib.parse`` has been loaded.
    from urllib.parse import urlparse

    parsed = urlparse(uri)
    if parsed.scheme != "dss-managed-folder":
        raise Exception("Not a DSS Managed Folder URI: %s" % uri)
    # The path addresses a remote location, so it must keep forward slashes
    # whatever the local platform is.
    return posixpath.normpath(parsed.path)
14+
15+
16+
class PluginDSSManagedFolderArtifactRepository:
    """MLflow artifact repository storing artifacts in a DSS managed folder.

    Connection settings are read from the ``DSS_MLFLOW_HOST``,
    ``DSS_MLFLOW_APIKEY`` / ``DSS_MLFLOW_INTERNAL_TICKET``,
    ``DSS_MLFLOW_PROJECTKEY`` and ``DSS_MLFLOW_MANAGED_FOLDER`` environment
    variables. Artifact paths accepted by the public methods are relative to
    the artifact root parsed from ``artifact_uri``.
    """

    def __init__(self, artifact_uri):
        """
        Connect to DSS and resolve (or create) the managed folder used for artifacts.

        :param artifact_uri: ``dss-managed-folder`` URI whose path is the artifact root.
        """
        self.base_artifact_path = parse_dss_managed_folder_uri(artifact_uri)
        host = os.environ.get("DSS_MLFLOW_HOST")
        api_key = os.environ.get("DSS_MLFLOW_APIKEY")
        if api_key is not None:
            self.client = DSSClient(host, api_key=api_key)
        else:
            self.client = DSSClient(
                host, internal_ticket=os.environ.get("DSS_MLFLOW_INTERNAL_TICKET")
            )
        self.project = self.client.get_project(os.environ.get("DSS_MLFLOW_PROJECTKEY"))
        folder_name = os.environ.get("DSS_MLFLOW_MANAGED_FOLDER")
        folder_ids = [
            folder["id"] for folder in self.project.list_managed_folders()
            if folder["name"] == folder_name
        ]
        if folder_ids:
            self.managed_folder = self.project.get_managed_folder(folder_ids[0])
        else:
            # No folder with that name yet: create it.
            self.managed_folder = self.project.create_managed_folder(folder_name)

    def _remote_path(self, artifact_path=None):
        """Return the path inside the managed folder for a root-relative ``artifact_path``."""
        if artifact_path:
            # Remote paths always use forward slashes, whatever the local OS.
            return posixpath.join(self.base_artifact_path, artifact_path)
        return self.base_artifact_path

    def log_artifact(self, local_file, artifact_path=None):
        """
        Log a local file as an artifact, optionally taking an ``artifact_path`` to place it in
        within the run's artifacts.

        :param local_file: Path to artifact to log
        :param artifact_path: Directory within the run's artifact directory in which to log the
                              artifact.
        """
        remote_file = posixpath.join(
            self._remote_path(artifact_path), os.path.basename(local_file)
        )
        # Binary mode keeps the content byte-exact; the context manager
        # guarantees the handle is closed once the upload returns.
        with open(local_file, "rb") as f:
            self.managed_folder.put_file(remote_file, f)

    def log_artifacts(self, local_dir, artifact_path=None):
        """
        Log the files in the specified local directory as artifacts, optionally taking
        an ``artifact_path`` to place them in within the run's artifacts.

        :param local_dir: Directory of local artifacts to log
        :param artifact_path: Directory within the run's artifact directory in which to log the
                              artifacts
        """
        self.managed_folder.upload_folder(self._remote_path(artifact_path), local_dir)

    def list_artifacts(self, path=None):
        """
        Return the artifacts located directly under ``path``. If ``path`` is a file (or does
        not exist), returns an empty list.

        :param path: Source path that contains the desired artifacts, relative to the artifact
                     root. ``None`` lists the artifact root itself.

        :return: List of ``mlflow.entities.FileInfo`` sorted by path; each ``path`` is relative
                 to the artifact root. Directories are inferred from the presence of deeper
                 entries in the managed folder listing.
        """
        from pathlib import PurePosixPath
        from mlflow.entities import FileInfo

        listing_dir = PurePosixPath(self._remote_path(path))
        children = {}
        for item in self.managed_folder.list_contents().get("items", []):
            item_path = PurePosixPath(item["path"])
            if listing_dir not in item_path.parents:
                continue
            # First path component below the listed directory: the direct child.
            child_name = item_path.relative_to(listing_dir).parts[0]
            child_path = posixpath.relpath(
                str(listing_dir / child_name), self.base_artifact_path
            )
            if item_path.parent == listing_dir:
                # NOTE(review): assumes listing items expose their byte size under
                # "size" -- falls back to None when the key is absent.
                children[child_path] = FileInfo(child_path, False, item.get("size"))
            else:
                children.setdefault(child_path, FileInfo(child_path, True, None))
        return [children[key] for key in sorted(children)]

    def _is_directory(self, artifact_path):
        """Return True when at least one artifact is stored under ``artifact_path``."""
        return bool(self.list_artifacts(artifact_path))

    def _create_download_destination(self, src_artifact_path, dst_local_dir_path=None):
        """
        Creates a local filesystem location to be used as a destination for downloading the
        artifact specified by `src_artifact_path`. For example, if `src_artifact_path` is
        `dir1/file1.txt`, then the resulting destination path is
        `<dst_local_dir_path>/dir1/file1.txt`. Missing local parent directories are created.

        :param src_artifact_path: A relative, POSIX-style path referring to an artifact stored
                                  within the repository's artifact root location.
        :param dst_local_dir_path: Path to the local directory under which the destination is
                                   created.
        :return: The local filesystem path to be used as a destination for downloading the
                 artifact specified by `src_artifact_path`.
        """
        src_artifact_path = src_artifact_path.rstrip("/")  # Ensure correct dirname for trailing '/'
        local_dir_path = os.path.join(dst_local_dir_path, posixpath.dirname(src_artifact_path))
        os.makedirs(local_dir_path, exist_ok=True)
        return os.path.join(dst_local_dir_path, src_artifact_path)

    def download_artifacts(self, artifact_path, dst_path=None):
        """
        Download an artifact file or directory to a local directory, and return a local path
        for it. The caller is responsible for managing the lifecycle of the downloaded
        artifacts.

        :param artifact_path: Relative source path to the desired artifacts.
        :param dst_path: Path of the local filesystem destination directory to which to
                         download the specified artifacts. This directory must already exist.
                         If unspecified, a new temporary directory is created.

        :return: Absolute path of the local filesystem location containing the desired artifacts.
        :raises MlflowException: if ``dst_path`` does not exist or is not a directory.
        """
        from mlflow.exceptions import MlflowException
        from mlflow.protos.databricks_pb2 import INVALID_PARAMETER_VALUE, RESOURCE_DOES_NOT_EXIST

        def download_artifact(src_artifact_path, dst_local_dir_path):
            """Download one file artifact under ``dst_local_dir_path`` and return its local path."""
            local_destination_file_path = self._create_download_destination(
                src_artifact_path=src_artifact_path, dst_local_dir_path=dst_local_dir_path
            )
            self._download_file(
                remote_file_path=src_artifact_path, local_path=local_destination_file_path
            )
            return local_destination_file_path

        def download_artifact_dir(src_artifact_dir_path, dst_local_dir_path):
            """Recursively download a directory artifact and return its local path."""
            local_dir = os.path.join(dst_local_dir_path, src_artifact_dir_path)
            dir_content = [  # prevent infinite loop, sometimes the dir is recursively included
                file_info
                for file_info in self.list_artifacts(src_artifact_dir_path)
                if file_info.path != "." and file_info.path != src_artifact_dir_path
            ]
            if not dir_content:  # empty dir
                os.makedirs(local_dir, exist_ok=True)
                return local_dir
            for file_info in dir_content:
                if file_info.is_dir:
                    download_artifact_dir(
                        src_artifact_dir_path=file_info.path,
                        dst_local_dir_path=dst_local_dir_path,
                    )
                else:
                    download_artifact(
                        src_artifact_path=file_info.path, dst_local_dir_path=dst_local_dir_path
                    )
            return local_dir

        if dst_path is None:
            dst_path = tempfile.mkdtemp()
        dst_path = os.path.abspath(dst_path)

        if not os.path.exists(dst_path):
            raise MlflowException(
                message=(
                    "The destination path for downloaded artifacts does not"
                    " exist! Destination path: {dst_path}".format(dst_path=dst_path)
                ),
                error_code=RESOURCE_DOES_NOT_EXIST,
            )
        if not os.path.isdir(dst_path):
            raise MlflowException(
                message=(
                    "The destination path for downloaded artifacts must be a directory!"
                    " Destination path: {dst_path}".format(dst_path=dst_path)
                ),
                error_code=INVALID_PARAMETER_VALUE,
            )

        # Anything with stored content below it is treated as a directory.
        if self._is_directory(artifact_path):
            return download_artifact_dir(
                src_artifact_dir_path=artifact_path, dst_local_dir_path=dst_path
            )
        return download_artifact(src_artifact_path=artifact_path, dst_local_dir_path=dst_path)

    def _download_file(self, remote_file_path, local_path):
        """
        Download the file at the specified relative remote path and save
        it at the specified local path.

        :param remote_file_path: Source path to the remote file, relative to the root
                                 directory of the artifact repository.
        :param local_path: The path to which to save the downloaded file.
        """
        # Resolve against the artifact root, exactly as uploads and listings do.
        full_remote_path = self._remote_path(remote_file_path)
        with self.managed_folder.get_file(full_remote_path) as remote_file:
            with open(local_path, "wb") as local_file:
                for chunk in remote_file:
                    local_file.write(chunk)

    def delete_artifacts(self, artifact_path=None):
        """
        Delete the artifact at the specified location.

        NOTE(review): this issues a single ``delete_file`` call on the managed folder;
        whether that removes a directory recursively is not visible here -- confirm.

        :param artifact_path: Path of the artifact to delete, relative to the artifact root.
                              When omitted, the artifact root itself is targeted.
        """
        self.managed_folder.delete_file(self._remote_path(artifact_path))
236+
237+
238+
def verify_artifact_path(artifact_path):
    """Reject an artifact path that MLflow's ``path_not_unique`` check flags.

    Empty or ``None`` paths are accepted without any check.

    :param artifact_path: Artifact path to validate.
    :raises MlflowException: if ``path_not_unique(artifact_path)`` is truthy.
    """
    from mlflow.exceptions import MlflowException
    from mlflow.utils.validation import path_not_unique, bad_path_message

    if not artifact_path:
        return
    if not path_not_unique(artifact_path):
        return
    details = bad_path_message(artifact_path)
    raise MlflowException(
        "Invalid artifact path: '%s'. %s" % (artifact_path, details)
    )
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import os
2+
3+
class PluginDSSHeaderProvider:
    """MLflow request-header provider that builds DSS headers from environment variables."""

    def in_context(self):
        """Return True when every environment variable needed for the headers is set.

        :return: ``True`` if ``DSS_MLFLOW_HEADER``, ``DSS_MLFLOW_TOKEN`` and
                 ``DSS_MLFLOW_PROJECTKEY`` are all present in the environment,
                 ``False`` otherwise.
        """
        env_variables = {"DSS_MLFLOW_HEADER", "DSS_MLFLOW_TOKEN", "DSS_MLFLOW_PROJECTKEY"}
        # Explicit boolean on both paths (no implicit ``None`` when a variable is missing).
        return env_variables.issubset(os.environ)

    def request_headers(self):
        """Return the headers to add to each request.

        :return: dict mapping the header named by ``DSS_MLFLOW_HEADER`` to the value of
                 ``DSS_MLFLOW_TOKEN``, plus ``x-dku-mlflow-project-key`` set to the value
                 of ``DSS_MLFLOW_PROJECTKEY``.
        """
        return {
            os.environ.get("DSS_MLFLOW_HEADER"): os.environ.get("DSS_MLFLOW_TOKEN"),
            "x-dku-mlflow-project-key": os.environ.get("DSS_MLFLOW_PROJECTKEY"),
        }
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import os
2+
import sys
3+
import tempfile
4+
5+
6+
def load_dss_mlflow_plugin():
    """Register the dss-mlflow-plugin entrypoints for MLflow at call time.

    MLflow uses entrypoints==0.3 to load entrypoints from plugins at import time.
    This function writes an ``entry_points.txt`` file inside a
    ``dss-plugin-mlflow.egg-info`` directory created in a fresh temporary
    directory, then puts that temporary directory at the front of ``sys.path``.

    :return: Path of the temporary directory that was created. Nothing here
             removes it; cleanup is left to the caller.
    """
    plugin_root = tempfile.mkdtemp()
    egg_info_dir = os.path.join(plugin_root, "dss-plugin-mlflow.egg-info")
    if not os.path.isdir(egg_info_dir):
        os.makedirs(egg_info_dir)

    entry_points = "".join([
        "[mlflow.request_header_provider]\n",
        "unused=dataikuapi.dss_plugin_mlflow.header_provider:PluginDSSHeaderProvider\n",
        "[mlflow.artifact_repository]\n",
        "dss-managed-folder=dataikuapi.dss_plugin_mlflow.artifact_repository:PluginDSSManagedFolderArtifactRepository\n",
    ])
    entry_points_file = os.path.join(egg_info_dir, "entry_points.txt")
    with open(entry_points_file, "w") as handle:
        handle.write(entry_points)

    # Make the generated metadata discoverable before anything else on the path.
    sys.path.insert(0, plugin_root)
    return plugin_root

dataikuapi/dssclient.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
import json
2+
import os
3+
import shutil
24

35
from requests import Session
46
from requests import exceptions
57
from requests.auth import HTTPBasicAuth
8+
from contextlib import contextmanager
69

710
from dataikuapi.dss.notebook import DSSNotebook
11+
from .dss_plugin_mlflow import load_dss_mlflow_plugin
812
from .dss.future import DSSFuture
913
from .dss.projectfolder import DSSProjectFolder
1014
from .dss.project import DSSProject
@@ -17,6 +21,7 @@
1721
from .dss.apideployer import DSSAPIDeployer
1822
from .dss.projectdeployer import DSSProjectDeployer
1923
import os.path as osp
24+
from base64 import b64encode
2025
from .utils import DataikuException, dku_basestring_type
2126

2227
class DSSClient(object):
@@ -1087,6 +1092,52 @@ def get_object_discussions(self, project_key, object_type, object_id):
10871092
"""
10881093
return DSSObjectDiscussions(self, project_key, object_type, object_id)
10891094

1095+
########################################################
1096+
# MLflow
1097+
########################################################
1098+
@contextmanager
def setup_mlflow(self, project_key, managed_folder="mlflow_artifacts", host=None):
    """
    Setup the dss-plugin for MLflow

    Context manager: on entry it loads the plugin entrypoints and exports the
    ``DSS_MLFLOW_*`` and ``MLFLOW_TRACKING_URI`` environment variables, then yields
    the imported ``mlflow`` module. On exit, including when the body raises, the
    environment variables are put back to their previous state, the plugin
    directory is dropped from ``sys.path`` and it is deleted.

    :param str project_key: identifier of the project to access
    :param str managed_folder: managed folder where artifacts are stored
    :param str host: setup a custom host if the backend used is not DSS; when omitted
                     the tracking URI is ``self.host + "/dip/publicapi"``
    """
    import sys

    tempdir = load_dss_mlflow_plugin()
    mlflow_env = {}
    previous_env = {}
    try:
        if self._session.auth is not None:
            # Basic auth token built from the session's auth username and an empty password.
            mlflow_env.update({
                "DSS_MLFLOW_HEADER": "Authorization",
                "DSS_MLFLOW_TOKEN": "Basic {}".format(
                    b64encode("{}:".format(self._session.auth.username).encode("utf-8")).decode("utf-8")),
                "DSS_MLFLOW_APIKEY": self.api_key
            })
        elif self.internal_ticket:
            mlflow_env.update({
                "DSS_MLFLOW_HEADER": "X-DKU-APITicket",
                "DSS_MLFLOW_TOKEN": self.internal_ticket,
                "DSS_MLFLOW_INTERNAL_TICKET": self.internal_ticket
            })
        mlflow_env.update({
            "DSS_MLFLOW_PROJECTKEY": project_key,
            "MLFLOW_TRACKING_URI": self.host + "/dip/publicapi" if host is None else host,
            "DSS_MLFLOW_HOST": self.host,
            "DSS_MLFLOW_MANAGED_FOLDER": managed_folder,
        })
        # Snapshot the current values first so they can be restored on exit
        # instead of being unconditionally deleted.
        previous_env = {name: os.environ.get(name) for name in mlflow_env}
        os.environ.update(mlflow_env)

        import mlflow
        yield mlflow
    finally:
        for name, old_value in previous_env.items():
            if old_value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = old_value
        # The directory is about to disappear: do not leave a dead sys.path entry behind.
        if tempdir in sys.path:
            sys.path.remove(tempdir)
        # Best-effort removal so a cleanup failure cannot mask an error from the body.
        shutil.rmtree(tempdir, ignore_errors=True)
1139+
1140+
10901141
class TemporaryImportHandle(object):
10911142
def __init__(self, client, import_id):
10921143
self.client = client

0 commit comments

Comments
 (0)