|
1 | 1 | import os |
2 | 2 | import sys |
3 | 3 | import tempfile |
| 4 | +import shutil |
| 5 | +from base64 import b64encode |
4 | 6 |
|
5 | 7 |
|
def load_dss_mlflow_plugin():
    """Dynamically add the DSS entry points for MLflow.

    MLflow uses entrypoints==0.3 to load entrypoints from plugins at import time.
    This function creates a fresh temporary directory containing
    ``dss-plugin-mlflow.egg-info/entry_points.txt`` (declaring the DSS request
    header provider and the DSS managed-folder artifact repository) and inserts
    that directory at the front of ``sys.path`` at call time.

    If the directory or the entry points file cannot be written, the temporary
    directory is removed again, ``sys.path`` is left untouched and the original
    exception propagates.

    :returns: path of the temporary directory that was added to ``sys.path``.
        This function never deletes it; cleaning it up is left to the caller.
    :rtype: str
    """
    # Local import: only needed on the failure path, and this keeps the
    # function independent of the module-level import list.
    import shutil

    tempdir = tempfile.mkdtemp()
    try:
        plugin_dir = os.path.join(tempdir, "dss-plugin-mlflow.egg-info")
        if not os.path.isdir(plugin_dir):
            os.makedirs(plugin_dir)
        with open(os.path.join(plugin_dir, "entry_points.txt"), "w") as f:
            f.write(
                "[mlflow.request_header_provider]\n"
                "unused=dataikuapi.dss_plugin_mlflow.header_provider:PluginDSSHeaderProvider\n"
                "[mlflow.artifact_repository]\n"
                "dss-managed-folder=dataikuapi.dss_plugin_mlflow.artifact_repository:PluginDSSManagedFolderArtifactRepository\n"
            )
    except Exception:
        # Do not leak a half-built plugin directory when setup fails.
        shutil.rmtree(tempdir, ignore_errors=True)
        raise
    # Load plugin: only expose the directory once it is fully populated.
    sys.path.insert(0, tempdir)
    return tempdir
class MLflowHandle:
    """Handle that wires the MLflow-plugin parts of dataikuapi into the local MLflow setup.

    Creating a handle has process-wide side effects (a new ``sys.path`` entry
    and several environment variables); :meth:`clear` undoes them. The handle
    can be used as a context manager: entering it imports and returns the
    ``mlflow`` module, leaving it calls :meth:`clear`.
    """

    def __init__(self, client, project_key, managed_folder="mlflow_artifacts", host=None):
        """ Add the MLflow-plugin parts of dataikuapi to MLflow local setup.

        This method deals with
        1. importing dynamically the DSS MLflow plugin:
        MLflow uses entrypoints==0.3 to load entrypoints from plugins at import time.
        We add dss-mlflow-plugin entrypoints dynamically by adding them in sys.path
        at call time.

        2. Setup the correct environment variables to setup the plugin to work
        with DSS backend as the tracking backend and enable using DSS managed folder
        as artifact location.

        If any step fails, everything already set up (temporary directory,
        ``sys.path`` entry, environment variables) is rolled back before the
        exception propagates.

        :param client: DSS client; its ``host``, ``_session.auth``, ``api_key``
            and ``internal_ticket`` attributes are read
        :param str project_key: exported as ``DSS_MLFLOW_PROJECTKEY``
        :param str managed_folder: exported as ``DSS_MLFLOW_MANAGED_FOLDER``
        :param str host: when given, used verbatim as ``MLFLOW_TRACKING_URI``;
            otherwise the tracking URI is ``client.host + "/dip/publicapi"``.
            ``DSS_MLFLOW_HOST`` is always ``client.host``.
        """
        self.project_key = project_key
        self.client = client
        self.tempdir = None
        # Built before touching the filesystem so that a client missing an
        # expected attribute fails without leaving anything behind.
        self.mlflow_env = self._build_env(managed_folder, host)

        # Load DSS as the plugin
        self.tempdir = self._install_plugin()
        try:
            os.environ.update(self.mlflow_env)
        except Exception:
            # e.g. a None value rejected by os.environ: undo the partial setup.
            self.clear()
            raise

    def _build_env(self, managed_folder, host):
        """Return the environment variables that configure the DSS MLflow plugin."""
        env = {}
        # Setup authentication
        if self.client._session.auth is not None:
            # Basic auth with the session username and an empty password.
            credentials = "{}:".format(self.client._session.auth.username)
            env.update({
                "DSS_MLFLOW_HEADER": "Authorization",
                "DSS_MLFLOW_TOKEN": "Basic {}".format(
                    b64encode(credentials.encode("utf-8")).decode("utf-8")),
                "DSS_MLFLOW_APIKEY": self.client.api_key
            })
        elif self.client.internal_ticket:
            env.update({
                "DSS_MLFLOW_HEADER": "X-DKU-APITicket",
                "DSS_MLFLOW_TOKEN": self.client.internal_ticket,
                "DSS_MLFLOW_INTERNAL_TICKET": self.client.internal_ticket
            })
        # Set host, tracking URI and project key
        if host is None:
            tracking_uri = self.client.host + "/dip/publicapi"
        else:
            tracking_uri = host
        env.update({
            "DSS_MLFLOW_PROJECTKEY": self.project_key,
            "MLFLOW_TRACKING_URI": tracking_uri,
            "DSS_MLFLOW_HOST": self.client.host,
            "DSS_MLFLOW_MANAGED_FOLDER": managed_folder
        })
        return env

    def _install_plugin(self):
        """Write the plugin entry points to a new temp dir, put it first on sys.path and return it."""
        tempdir = tempfile.mkdtemp()
        try:
            plugin_dir = os.path.join(tempdir, "dss-plugin-mlflow.egg-info")
            if not os.path.isdir(plugin_dir):
                os.makedirs(plugin_dir)
            with open(os.path.join(plugin_dir, "entry_points.txt"), "w") as f:
                f.write(
                    "[mlflow.request_header_provider]\n"
                    "unused=dataikuapi.dss_plugin_mlflow.header_provider:PluginDSSHeaderProvider\n"
                    "[mlflow.artifact_repository]\n"
                    "dss-managed-folder=dataikuapi.dss_plugin_mlflow.artifact_repository:PluginDSSManagedFolderArtifactRepository\n"
                )
        except Exception:
            # Do not leak a half-built plugin directory.
            shutil.rmtree(tempdir, ignore_errors=True)
            raise
        sys.path.insert(0, tempdir)
        return tempdir

    def clear(self):
        """Undo the setup done by ``__init__``.

        Removes the temporary plugin directory from ``sys.path`` and from disk,
        and removes every environment variable this handle exported (they are
        removed, not restored to an earlier value). Safe to call more than once.
        """
        if self.tempdir is not None:
            # Drop the entry so sys.path does not keep pointing at a deleted directory.
            while self.tempdir in sys.path:
                sys.path.remove(self.tempdir)
            # ignore_errors: the directory is already gone on a repeated call.
            shutil.rmtree(self.tempdir, ignore_errors=True)
        for variable in self.mlflow_env:
            os.environ.pop(variable, None)

    def import_mlflow(self):
        """Import and return the ``mlflow`` module.

        The import is deferred to call time so that it happens after the plugin
        directory has been put on ``sys.path`` by ``__init__``.
        """
        import mlflow
        return mlflow

    def __enter__(self):
        """Return the ``mlflow`` module (not the handle itself)."""
        return self.import_mlflow()

    def __exit__(self, exc_type, exc_value, traceback):
        """Clean up via :meth:`clear`; exceptions from the ``with`` body are not suppressed."""
        self.clear()
0 commit comments