Skip to content

Commit 023ef7a

Browse files
committed
Force managed_folder to be a DSSManagedFolder in setup_mlflow
1 parent 0fd015c commit 023ef7a

File tree

2 files changed

+9
-13
lines changed

2 files changed

+9
-13
lines changed

dataikuapi/dss/project.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import time, warnings, sys, os.path as osp
1+
import warnings, os.path as osp
22

33
from ..dss_plugin_mlflow import MLflowHandle
44

@@ -1629,8 +1629,7 @@ def setup_mlflow(self, managed_folder, host=None):
16291629
"""
16301630
Setup the dss-plugin for MLflow
16311631
1632-
:param managed_folder: Managed folder where artifacts should be stored.
1633-
Either a managed folder id, or an instance of :class:`dataikuapi.dss.DSSManagedFolder`.
1632+
:param object managed_folder: a :class:`dataikuapi.dss.DSSManagedFolder` where MLflow artifacts should be stored.
16341633
:param str host: setup a custom host if the backend used is not DSS
16351634
"""
16361635
return MLflowHandle(client=self.client, project=self, managed_folder=managed_folder, host=host)

dataikuapi/dss_plugin_mlflow/utils.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from ..dss.utils import AnyLoc
21
from ..utils import DataikuException
2+
from ..dss.managedfolder import DSSManagedFolder
33
import logging
44
import os
55
import sys
@@ -62,16 +62,13 @@ def __init__(self, client, project, managed_folder, host=None):
6262
"DSS_MLFLOW_INTERNAL_TICKET": self.client.internal_ticket
6363
})
6464

65-
def full_id(mf):
66-
return mf.project.project_key + "." + mf.id
65+
if not isinstance(managed_folder, DSSManagedFolder):
66+
raise TypeError('managed_folder must a DSSManagedFolder.')
6767

68-
def get_managed_folder(smart_mf_id):
69-
return client.get_project(smart_mf_id.project_key).get_managed_folder(smart_mf_id.object_id)
70-
71-
mf_id = managed_folder if isinstance(managed_folder, str) else full_id(managed_folder)
72-
smart_mf_id = AnyLoc.from_ref(self.project_key, mf_id)
68+
mf_project = managed_folder.project.project_key
69+
mf_id = managed_folder.id
7370
try:
74-
get_managed_folder(smart_mf_id).get_definition()
71+
client.get_project(mf_project).get_managed_folder(mf_id).get_definition()
7572
except DataikuException as e:
7673
if "NotFoundException" in str(e):
7774
logging.error('The managed folder "%s" does not exist, please create it in your project flow before running this command.' % (mf_id))
@@ -82,7 +79,7 @@ def get_managed_folder(smart_mf_id):
8279
"DSS_MLFLOW_PROJECTKEY": self.project_key,
8380
"MLFLOW_TRACKING_URI": self.client.host + "/dip/publicapi" if host is None else host,
8481
"DSS_MLFLOW_HOST": self.client.host,
85-
"DSS_MLFLOW_MANAGED_FOLDER_ID": mf_id
82+
"DSS_MLFLOW_MANAGED_FOLDER_ID": mf_project + "." + mf_id
8683
})
8784

8885
os.environ.update(self.mlflow_env)

0 commit comments

Comments
 (0)