Skip to content

Commit d8fddab

Browse files
author
Valentin Thorey
authored
Merge pull request #198 from dataiku/feature/sc-77060-refactor-setup-mlflow-context-manager
Refactor setup_mlflow context manager to be used either as a context manager or using the provided MLflowHandle
2 parents 8d4b549 + 13caa15 commit d8fddab

File tree

3 files changed

+74
-60
lines changed

3 files changed

+74
-60
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .utils import load_dss_mlflow_plugin
1+
from .utils import MLflowHandle
Lines changed: 71 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,76 @@
11
import os
22
import sys
33
import tempfile
4+
import shutil
5+
from base64 import b64encode
46

57

6-
def load_dss_mlflow_plugin():
    """ Function to dynamically add entrypoints for MLflow

    MLflow uses entrypoints==0.3 to load entrypoints from plugins at import time.
    This function adds dss-mlflow-plugin entrypoints dynamically by adding them in sys.path
    at call time.

    The entry points are written to ``dss-plugin-mlflow.egg-info/entry_points.txt``
    inside a new temporary directory, which is then inserted at the front of
    ``sys.path``.

    :returns: path of the temporary directory that was created and added to ``sys.path``
    :rtype: str
    """
    tempdir = tempfile.mkdtemp()
    try:
        # mkdtemp() returns a new, empty directory, so the plugin directory
        # cannot exist yet: create it unconditionally.
        plugin_dir = os.path.join(tempdir, "dss-plugin-mlflow.egg-info")
        os.makedirs(plugin_dir)
        with open(os.path.join(plugin_dir, "entry_points.txt"), "w") as f:
            f.write(
                "[mlflow.request_header_provider]\n"
                "unused=dataikuapi.dss_plugin_mlflow.header_provider:PluginDSSHeaderProvider\n"
                "[mlflow.artifact_repository]\n"
                "dss-managed-folder=dataikuapi.dss_plugin_mlflow.artifact_repository:PluginDSSManagedFolderArtifactRepository\n"
            )
    except Exception:
        # The caller never gets the path when we fail, so it could not clean
        # up: remove the half-built directory here before propagating.
        shutil.rmtree(tempdir, ignore_errors=True)
        raise
    # Load plugin
    sys.path.insert(0, tempdir)
    return tempdir
8+
class MLflowHandle(object):
    """ Handle on a local MLflow setup configured to go through the dataikuapi MLflow plugin.

    Creating the handle registers the plugin entry points on ``sys.path`` and
    exports the plugin's environment variables. Call :meth:`clear` to undo
    both, or use the handle as a context manager: entering yields the
    ``mlflow`` module and exiting calls :meth:`clear`.
    """

    def __init__(self, client, project_key, managed_folder="mlflow_artifacts", host=None):
        """ Add the MLflow-plugin parts of dataikuapi to MLflow local setup.

        This method deals with
        1. importing dynamically the DSS MLflow plugin:
        MLflow uses entrypoints==0.3 to load entrypoints from plugins at import time.
        We add dss-mlflow-plugin entrypoints dynamically by adding them in sys.path
        at call time.

        2. Setting the environment variables that configure the plugin to work
        with DSS backend as the tracking backend and enable using DSS managed folder
        as artifact location.

        :param client: client whose ``host``, ``_session.auth``, ``api_key`` and
            ``internal_ticket`` are read to build the environment
        :param str project_key: exported as ``DSS_MLFLOW_PROJECTKEY``
        :param str managed_folder: exported as ``DSS_MLFLOW_MANAGED_FOLDER``
        :param str host: used as ``MLFLOW_TRACKING_URI`` instead of the client
            host followed by ``/dip/publicapi`` when not None
        """
        self.project_key = project_key
        self.client = client
        # Build the environment first: this only reads from the client, so a
        # client missing an attribute fails before any global state is touched.
        self.mlflow_env = self._build_env(managed_folder, host)

        # Load DSS as the plugin
        self.tempdir = tempfile.mkdtemp()
        plugin_dir = os.path.join(self.tempdir, "dss-plugin-mlflow.egg-info")
        # mkdtemp() returns a new, empty directory: the plugin dir cannot exist yet.
        os.makedirs(plugin_dir)
        with open(os.path.join(plugin_dir, "entry_points.txt"), "w") as f:
            f.write(
                "[mlflow.request_header_provider]\n"
                "unused=dataikuapi.dss_plugin_mlflow.header_provider:PluginDSSHeaderProvider\n"
                "[mlflow.artifact_repository]\n"
                "dss-managed-folder=dataikuapi.dss_plugin_mlflow.artifact_repository:PluginDSSManagedFolderArtifactRepository\n"
            )
        sys.path.insert(0, self.tempdir)

        # Remember the values we are about to overwrite (None = was unset) so
        # that clear() can put the environment back exactly as it found it.
        self._previous_env = dict(
            (variable, os.environ.get(variable)) for variable in self.mlflow_env
        )
        os.environ.update(self.mlflow_env)
        self._active = True

    def _build_env(self, managed_folder, host):
        """ Return the dict of environment variables to export for the plugin. """
        env = {}
        # Setup authentication
        if self.client._session.auth is not None:
            env.update({
                "DSS_MLFLOW_HEADER": "Authorization",
                "DSS_MLFLOW_TOKEN": "Basic {}".format(
                    b64encode("{}:".format(self.client._session.auth.username).encode("utf-8")).decode("utf-8")),
                "DSS_MLFLOW_APIKEY": self.client.api_key
            })
        elif self.client.internal_ticket:
            env.update({
                "DSS_MLFLOW_HEADER": "X-DKU-APITicket",
                "DSS_MLFLOW_TOKEN": self.client.internal_ticket,
                "DSS_MLFLOW_INTERNAL_TICKET": self.client.internal_ticket
            })
        # Set host, tracking URI and project key
        env.update({
            "DSS_MLFLOW_PROJECTKEY": self.project_key,
            "MLFLOW_TRACKING_URI": (self.client.host + "/dip/publicapi") if host is None else host,
            "DSS_MLFLOW_HOST": self.client.host,
            "DSS_MLFLOW_MANAGED_FOLDER": managed_folder
        })
        return env

    def clear(self):
        """ Undo what the constructor did.

        Removes the temporary plugin directory from ``sys.path`` and from disk,
        then restores every environment variable set by this handle to the
        value it had before (or unsets it if it was not set). Calling this
        method again after the first call does nothing.
        """
        if not self._active:
            return
        self._active = False
        # Drop the path entry too: otherwise sys.path keeps pointing at a
        # directory that no longer exists.
        if self.tempdir in sys.path:
            sys.path.remove(self.tempdir)
        try:
            shutil.rmtree(self.tempdir)
        finally:
            # Restore the environment even if the directory removal failed.
            for variable in self.mlflow_env:
                previous = self._previous_env.get(variable)
                if previous is None:
                    os.environ.pop(variable, None)
                else:
                    os.environ[variable] = previous

    def import_mlflow(self):
        """ Import and return the ``mlflow`` module.

        The import happens here rather than at module load so that it runs
        after the constructor has put the plugin entry points on ``sys.path``.
        """
        import mlflow
        return mlflow

    def __enter__(self):
        """ Return the ``mlflow`` module for use inside the ``with`` block. """
        return self.import_mlflow()

    def __exit__(self, exc_type, exc_value, traceback):
        """ Clear the setup on exit; exceptions from the block are not suppressed. """
        self.clear()

dataikuapi/dssclient.py

Lines changed: 2 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,11 @@
11
import json
2-
import os
3-
import shutil
42

53
from requests import Session
64
from requests import exceptions
75
from requests.auth import HTTPBasicAuth
8-
from contextlib import contextmanager
96

107
from dataikuapi.dss.notebook import DSSNotebook
11-
from .dss_plugin_mlflow import load_dss_mlflow_plugin
8+
from .dss_plugin_mlflow import MLflowHandle
129
from .dss.future import DSSFuture
1310
from .dss.projectfolder import DSSProjectFolder
1411
from .dss.project import DSSProject
@@ -21,7 +18,6 @@
2118
from .dss.apideployer import DSSAPIDeployer
2219
from .dss.projectdeployer import DSSProjectDeployer
2320
import os.path as osp
24-
from base64 import b64encode
2521
from .utils import DataikuException, dku_basestring_type
2622

2723
class DSSClient(object):
@@ -1095,7 +1091,6 @@ def get_object_discussions(self, project_key, object_type, object_id):
10951091
########################################################
10961092
# MLflow
10971093
########################################################
1098-
@contextmanager
10991094
def setup_mlflow(self, project_key, managed_folder="mlflow_artifacts", host=None):
11001095
"""
11011096
Setup the dss-plugin for MLflow
@@ -1104,38 +1099,7 @@ def setup_mlflow(self, project_key, managed_folder="mlflow_artifacts", host=None
11041099
:param str managed_folder: managed folder where artifacts are stored
11051100
:param str host: setup a custom host if the backend used is not DSS
11061101
"""
1107-
tempdir = load_dss_mlflow_plugin()
1108-
mlflow_env = {}
1109-
if self._session.auth is not None:
1110-
mlflow_env.update({
1111-
"DSS_MLFLOW_HEADER": "Authorization",
1112-
"DSS_MLFLOW_TOKEN": "Basic {}".format(
1113-
b64encode("{}:".format(self._session.auth.username).encode("utf-8")).decode("utf-8")),
1114-
"DSS_MLFLOW_APIKEY": self.api_key
1115-
})
1116-
elif self.internal_ticket:
1117-
mlflow_env.update({
1118-
"DSS_MLFLOW_HEADER": "X-DKU-APITicket",
1119-
"DSS_MLFLOW_TOKEN": self.internal_ticket,
1120-
"DSS_MLFLOW_INTERNAL_TICKET": self.internal_ticket
1121-
})
1122-
mlflow_env.update({
1123-
"DSS_MLFLOW_PROJECTKEY": project_key,
1124-
"MLFLOW_TRACKING_URI": self.host + "/dip/publicapi" if host is None else host,
1125-
"DSS_MLFLOW_HOST": self.host,
1126-
"DSS_MLFLOW_MANAGED_FOLDER": managed_folder,
1127-
})
1128-
os.environ.update(mlflow_env)
1129-
1130-
try:
1131-
import mlflow
1132-
yield mlflow
1133-
except Exception as e:
1134-
raise e
1135-
finally:
1136-
shutil.rmtree(tempdir)
1137-
for variable in mlflow_env:
1138-
os.environ.pop(variable, None)
1102+
return MLflowHandle(client=self, project_key=project_key, managed_folder=managed_folder, host=host)
11391103

11401104

11411105
class TemporaryImportHandle(object):

0 commit comments

Comments
 (0)