1+ from ..utils import DataikuException
2+ from ..dss .managedfolder import DSSManagedFolder
3+ import logging
14import os
25import sys
36import tempfile
69
710
811class MLflowHandle :
9- def __init__ (self , client , project_key , managed_folder_name , host = None ):
12+ def __init__ (self , client , project , managed_folder , host = None ):
1013 """ Add the MLflow-plugin parts of dataikuapi to MLflow local setup.
1114
1215 This method deals with
@@ -19,7 +22,8 @@ def __init__(self, client, project_key, managed_folder_name, host=None):
1922 with DSS backend as the tracking backend and enable using DSS managed folder
2023 as artifact location.
2124 """
22- self .project_key = project_key
25+ self .project = project
26+ self .project_key = project .project_key
2327 self .client = client
2428 self .mlflow_env = {}
2529
@@ -57,32 +61,25 @@ def __init__(self, client, project_key, managed_folder_name, host=None):
5761 "DSS_MLFLOW_TOKEN" : self .client .internal_ticket ,
5862 "DSS_MLFLOW_INTERNAL_TICKET" : self .client .internal_ticket
5963 })
60- # Set host, tracking URI and project key
61- self .mlflow_env .update ({
62- "DSS_MLFLOW_PROJECTKEY" : project_key ,
63- "MLFLOW_TRACKING_URI" : self .client .host + "/dip/publicapi" if host is None else host ,
64- "DSS_MLFLOW_HOST" : self .client .host
65- })
6664
67- # Get or create managed folder with name 'managed_folder_name'
68- project = self .client .get_project (project_key )
69- managed_folders = [
70- x ["id" ] for x in project .list_managed_folders ()
71- if x ["name" ] == managed_folder_name
72- ]
73- managed_folder = None
74- if len (managed_folders ) > 0 :
75- managed_folder = project .get_managed_folder (managed_folders [0 ])
76- else :
77- managed_folder = project .create_managed_folder (managed_folder_name )
65+ if not isinstance (managed_folder , DSSManagedFolder ):
66+ raise TypeError ('managed_folder must a DSSManagedFolder.' )
7867
68+ mf_project = managed_folder .project .project_key
69+ mf_id = managed_folder .id
70+ try :
71+ client .get_project (mf_project ).get_managed_folder (mf_id ).get_definition ()
72+ except DataikuException as e :
73+ if "NotFoundException" in str (e ):
74+ logging .error ('The managed folder "%s" does not exist, please create it in your project flow before running this command.' % (mf_id ))
75+ raise
7976
80- # TODO: don't allow to pass a managed folder name, everything is already
81- # wired in artifact_repository and in the backend for DSS_MLFLOW_MANAGED_FOLDER_ID to contain
82- # a smart ref. So, instead of managed_folder_name, we should take a DSSManagedFolder
83- # or a managed folder id/smart ref and set DSS_MLFLOW_MANAGED_FOLDER_ID to that (smart) id.
77+ # Set host, tracking URI, project key and managed_folder_id
8478 self .mlflow_env .update ({
85- "DSS_MLFLOW_MANAGED_FOLDER_ID" : managed_folder .id
79+ "DSS_MLFLOW_PROJECTKEY" : self .project_key ,
80+ "MLFLOW_TRACKING_URI" : self .client .host + "/dip/publicapi" if host is None else host ,
81+ "DSS_MLFLOW_HOST" : self .client .host ,
82+ "DSS_MLFLOW_MANAGED_FOLDER_ID" : mf_project + "." + mf_id
8683 })
8784
8885 os .environ .update (self .mlflow_env )
0 commit comments