Skip to content

Commit 8da9c54

Browse files
author
Valentin Thorey
authored
Create dedicated class for MLflow custom endpoints (#203)
* Create dedicated class for MLflow custom endpoint * Remove unused parameter
1 parent bb2ac4d commit 8da9c54

File tree

3 files changed

+34
-5
lines changed

3 files changed

+34
-5
lines changed

dataikuapi/dss/mlflow.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
class DSSMLflowExtension(object):
2+
"""
3+
A handle to interact with specific endpoints of the DSS MLflow integration.
4+
5+
Do not create this directly, use :meth:`dataikuapi.dss.DSSProject.get_mlflow_extension`
6+
"""
7+
8+
def __init__(self, client, project_key):
9+
self.client = client
10+
self.project = client.get_project(project_key)
11+
self.project_key = project_key
12+
13+
def list_models(self, run_id):
14+
response = self.client._perform_http(
15+
"GET", "/api/2.0/mlflow/extension/models/{}".format(run_id),
16+
headers={"x-dku-mlflow-project-key": self.project_key}
17+
)
18+
return response.json()

dataikuapi/dss/project.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .managedfolder import DSSManagedFolder
1313
from .savedmodel import DSSSavedModel
1414
from .modelevaluationstore import DSSModelEvaluationStore
15+
from .mlflow import DSSMLflowExtension
1516
from .job import DSSJob, DSSJobWaiter
1617
from .scenario import DSSScenario, DSSScenarioListItem
1718
from .continuousactivity import DSSContinuousActivity
@@ -33,8 +34,8 @@ class DSSProject(object):
3334
Do not create this class directly, instead use :meth:`dataikuapi.DSSClient.get_project`
3435
"""
3536
def __init__(self, client, project_key):
36-
self.client = client
37-
self.project_key = project_key
37+
self.client = client
38+
self.project_key = project_key
3839

3940
def get_summary(self):
4041
"""
@@ -1608,6 +1609,15 @@ def setup_mlflow(self, managed_folder="mlflow_artifacts", host=None):
16081609
"""
16091610
return MLflowHandle(client=self.client, project_key=self.project_key, managed_folder=managed_folder, host=host)
16101611

1612+
def get_mlflow_extension(self):
1613+
"""
1614+
Get a handle to interact with the extension of MLflow provided by DSS
1615+
1616+
:returns: A :class:`dataikuapi.dss.mlflow.DSSMLflowExtension` Mlflow Extension handle
1617+
1618+
"""
1619+
return DSSMLflowExtension(client=self.client, project_key=self.project_key)
1620+
16111621
def clean_experiment_tracking_db(self):
16121622
"""
16131623
Cleans the experiments, runs, params, metrics, tags, etc. for this project

dataikuapi/dssclient.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1028,7 +1028,7 @@ def set_license(self, license):
10281028
# Internal Request handling
10291029
########################################################
10301030

1031-
def _perform_http(self, method, path, params=None, body=None, stream=False, files=None, raw_body=None):
1031+
def _perform_http(self, method, path, params=None, body=None, stream=False, files=None, raw_body=None, headers=None):
10321032
if body is not None:
10331033
body = json.dumps(body)
10341034
if raw_body is not None:
@@ -1038,8 +1038,9 @@ def _perform_http(self, method, path, params=None, body=None, stream=False, file
10381038
http_res = self._session.request(
10391039
method, "%s/dip/publicapi%s" % (self.host, path),
10401040
params=params, data=body,
1041-
files = files,
1042-
stream = stream)
1041+
files=files,
1042+
stream=stream,
1043+
headers=headers)
10431044
http_res.raise_for_status()
10441045
return http_res
10451046
except exceptions.HTTPError:

0 commit comments

Comments
 (0)