Skip to content

Commit 8d4b549

Browse files
author
Valentin Thorey
committed
Revert "Refactor setup_mlflow contextmanager"
This reverts commit 9e5c95b.
1 parent 9e5c95b commit 8d4b549

File tree

3 files changed

+60
-73
lines changed

3 files changed

+60
-73
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .utils import MLflowHandle
1+
from .utils import load_dss_mlflow_plugin
Lines changed: 21 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,26 @@
11
import os
22
import sys
33
import tempfile
4-
import shutil
5-
from base64 import b64encode
64

75

8-
class MLflowHandle:
9-
def __init__(self, client, project_key, managed_folder, host=None):
10-
""" Add the MLflow-plugin parts of dataikuapi to MLflow local setup.
11-
12-
This functions deals with
13-
1. importing dynamically the DSS Mlflow plugin:
14-
MLflow uses entrypoints==0.3 to load entrypoints from plugins at import time.
15-
We add dss-mlflow-plugin entrypoints dynamically by adding them in sys.path
16-
at call time.
17-
18-
2. Setup the correct environemnent variables to setup the plugin to work
19-
with DSS backend as the tracking backend and enable using DSS managed folder
20-
as artifact location.
21-
"""
22-
self.project_key = project_key
23-
self.client = client
24-
self.mlflow_env = {}
25-
26-
# Load DSS as the plugin
27-
self.tempdir = tempfile.mkdtemp()
28-
plugin_dir = os.path.join(self.tempdir, "dss-plugin-mlflow.egg-info")
29-
if not os.path.isdir(plugin_dir):
30-
os.makedirs(plugin_dir)
31-
with open(os.path.join(plugin_dir, "entry_points.txt"), "w") as f:
32-
f.write(
33-
"[mlflow.request_header_provider]\n"
34-
"unused=dataikuapi.dss_plugin_mlflow.header_provider:PluginDSSHeaderProvider\n"
35-
"[mlflow.artifact_repository]\n"
36-
"dss-managed-folder=dataikuapi.dss_plugin_mlflow.artifact_repository:PluginDSSManagedFolderArtifactRepository\n"
37-
)
38-
sys.path.insert(0, self.tempdir)
39-
40-
# Setup authentication
41-
if client._session.auth is not None:
42-
self.mlflow_env.update({
43-
"DSS_MLFLOW_HEADER": "Authorization",
44-
"DSS_MLFLOW_TOKEN": "Basic {}".format(
45-
b64encode("{}:".format(self.client._session.auth.username).encode("utf-8")).decode("utf-8")),
46-
"DSS_MLFLOW_APIKEY": self.client.api_key
47-
})
48-
elif self.internal_ticket:
49-
self.mlflow_env.update({
50-
"DSS_MLFLOW_HEADER": "X-DKU-APITicket",
51-
"DSS_MLFLOW_TOKEN": self.client.internal_ticket,
52-
"DSS_MLFLOW_INTERNAL_TICKET": self.client.internal_ticket
53-
})
54-
# Set host, tracking URI and project key
55-
self.mlflow_env.update({
56-
"DSS_MLFLOW_PROJECTKEY": project_key,
57-
"MLFLOW_TRACKING_URI": self.client.host + "/dip/publicapi" if host is None else host,
58-
"DSS_MLFLOW_HOST": self.client.host,
59-
})
60-
os.environ.update(self.mlflow_env)
61-
62-
def clear(self):
63-
shutil.rmtree(self.tempdir)
64-
for variable in self.mlflow_env:
65-
os.environ.pop(variable, None)
66-
67-
def import_mlflow(self):
68-
import mlflow
69-
return mlflow
70-
71-
def __enter__(self):
72-
return self.import_mlflow()
73-
74-
def __exit__(self, exc_type, exc_value, traceback):
75-
self.clear()
6+
def load_dss_mlflow_plugin():
7+
""" Function to dynamically add entrypoints for MLflow
8+
9+
MLflow uses entrypoints==0.3 to load entrypoints from plugins at import time.
10+
This function adds dss-mlflow-plugin entrypoints dynamically by adding them in sys.path
11+
at call time.
12+
"""
13+
tempdir = tempfile.mkdtemp()
14+
plugin_dir = os.path.join(tempdir, "dss-plugin-mlflow.egg-info")
15+
if not os.path.isdir(plugin_dir):
16+
os.makedirs(plugin_dir)
17+
with open(os.path.join(plugin_dir, "entry_points.txt"), "w") as f:
18+
f.write(
19+
"[mlflow.request_header_provider]\n"
20+
"unused=dataikuapi.dss_plugin_mlflow.header_provider:PluginDSSHeaderProvider\n"
21+
"[mlflow.artifact_repository]\n"
22+
"dss-managed-folder=dataikuapi.dss_plugin_mlflow.artifact_repository:PluginDSSManagedFolderArtifactRepository\n"
23+
)
24+
# Load plugin
25+
sys.path.insert(0, tempdir)
26+
return tempdir

dataikuapi/dssclient.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import json
2+
import os
3+
import shutil
24

35
from requests import Session
46
from requests import exceptions
57
from requests.auth import HTTPBasicAuth
8+
from contextlib import contextmanager
69

710
from dataikuapi.dss.notebook import DSSNotebook
8-
from .dss_plugin_mlflow import MLflowHandle
11+
from .dss_plugin_mlflow import load_dss_mlflow_plugin
912
from .dss.future import DSSFuture
1013
from .dss.projectfolder import DSSProjectFolder
1114
from .dss.project import DSSProject
@@ -18,6 +21,7 @@
1821
from .dss.apideployer import DSSAPIDeployer
1922
from .dss.projectdeployer import DSSProjectDeployer
2023
import os.path as osp
24+
from base64 import b64encode
2125
from .utils import DataikuException, dku_basestring_type
2226

2327
class DSSClient(object):
@@ -1091,6 +1095,7 @@ def get_object_discussions(self, project_key, object_type, object_id):
10911095
########################################################
10921096
# MLflow
10931097
########################################################
1098+
@contextmanager
10941099
def setup_mlflow(self, project_key, managed_folder="mlflow_artifacts", host=None):
10951100
"""
10961101
Setup the dss-plugin for MLflow
@@ -1099,7 +1104,38 @@ def setup_mlflow(self, project_key, managed_folder="mlflow_artifacts", host=None
10991104
:param str managed_folder: managed folder where artifacts are stored
11001105
:param str host: setup a custom host if the backend used is not DSS
11011106
"""
1102-
return MLflowHandle(client=self, project_key=project_key, managed_folder=managed_folder, host=host)
1107+
tempdir = load_dss_mlflow_plugin()
1108+
mlflow_env = {}
1109+
if self._session.auth is not None:
1110+
mlflow_env.update({
1111+
"DSS_MLFLOW_HEADER": "Authorization",
1112+
"DSS_MLFLOW_TOKEN": "Basic {}".format(
1113+
b64encode("{}:".format(self._session.auth.username).encode("utf-8")).decode("utf-8")),
1114+
"DSS_MLFLOW_APIKEY": self.api_key
1115+
})
1116+
elif self.internal_ticket:
1117+
mlflow_env.update({
1118+
"DSS_MLFLOW_HEADER": "X-DKU-APITicket",
1119+
"DSS_MLFLOW_TOKEN": self.internal_ticket,
1120+
"DSS_MLFLOW_INTERNAL_TICKET": self.internal_ticket
1121+
})
1122+
mlflow_env.update({
1123+
"DSS_MLFLOW_PROJECTKEY": project_key,
1124+
"MLFLOW_TRACKING_URI": self.host + "/dip/publicapi" if host is None else host,
1125+
"DSS_MLFLOW_HOST": self.host,
1126+
"DSS_MLFLOW_MANAGED_FOLDER": managed_folder,
1127+
})
1128+
os.environ.update(mlflow_env)
1129+
1130+
try:
1131+
import mlflow
1132+
yield mlflow
1133+
except Exception as e:
1134+
raise e
1135+
finally:
1136+
shutil.rmtree(tempdir)
1137+
for variable in mlflow_env:
1138+
os.environ.pop(variable, None)
11031139

11041140

11051141
class TemporaryImportHandle(object):

0 commit comments

Comments
 (0)