1+ from ..dss .utils import AnyLoc
2+ from ..utils import DataikuException
3+ import logging
14import os
25import sys
36import tempfile
69
710
811class MLflowHandle :
9- def __init__ (self , client , project_key , managed_folder_name , host = None ):
12+ def __init__ (self , client , project , managed_folder , host = None ):
1013 """ Add the MLflow-plugin parts of dataikuapi to MLflow local setup.
1114
1215 This method deals with
@@ -19,7 +22,8 @@ def __init__(self, client, project_key, managed_folder_name, host=None):
1922 with DSS backend as the tracking backend and enable using DSS managed folder
2023 as artifact location.
2124 """
22- self .project_key = project_key
25+ self .project = project
26+ self .project_key = project .project_key
2327 self .client = client
2428 self .mlflow_env = {}
2529
@@ -57,32 +61,28 @@ def __init__(self, client, project_key, managed_folder_name, host=None):
5761 "DSS_MLFLOW_TOKEN" : self .client .internal_ticket ,
5862 "DSS_MLFLOW_INTERNAL_TICKET" : self .client .internal_ticket
5963 })
60- # Set host, tracking URI and project key
61- self .mlflow_env .update ({
62- "DSS_MLFLOW_PROJECTKEY" : project_key ,
63- "MLFLOW_TRACKING_URI" : self .client .host + "/dip/publicapi" if host is None else host ,
64- "DSS_MLFLOW_HOST" : self .client .host
65- })
6664
67- # Get or create managed folder with name 'managed_folder_name'
68- project = self .client .get_project (project_key )
69- managed_folders = [
70- x ["id" ] for x in project .list_managed_folders ()
71- if x ["name" ] == managed_folder_name
72- ]
73- managed_folder = None
74- if len (managed_folders ) > 0 :
75- managed_folder = project .get_managed_folder (managed_folders [0 ])
76- else :
77- managed_folder = project .create_managed_folder (managed_folder_name )
78-
79-
80- # TODO: don't allow to pass a managed folder name, everything is already
81- # wired in artifact_repository and in the backend for DSS_MLFLOW_MANAGED_FOLDER_ID to contain
82- # a smart ref. So, instead of managed_folder_name, we should take a DSSManagedFolder
83- # or a managed folder id/smart ref and set DSS_MLFLOW_MANAGED_FOLDER_ID to that (smart) id.
65+ def full_id (mf ):
66+ return mf .project .project_key + "." + mf .id
67+
68+ def get_managed_folder (smart_mf_id ):
69+ return client .get_project (smart_mf_id .project_key ).get_managed_folder (smart_mf_id .object_id )
70+
71+ mf_id = managed_folder if isinstance (managed_folder , str ) else full_id (managed_folder )
72+ smart_mf_id = AnyLoc .from_ref (self .project_key , mf_id )
73+ try :
74+ get_managed_folder (smart_mf_id ).get_definition ()
75+ except DataikuException as e :
76+ if "NotFoundException" in str (e ):
77+ logging .error ('The managed folder "%s" does not exist, please create it in your project flow before running this command.' % (mf_id ))
78+ raise
79+
80+ # Set host, tracking URI, project key and managed_folder_id
8481 self .mlflow_env .update ({
85- "DSS_MLFLOW_MANAGED_FOLDER_ID" : managed_folder .id
82+ "DSS_MLFLOW_PROJECTKEY" : self .project_key ,
83+ "MLFLOW_TRACKING_URI" : self .client .host + "/dip/publicapi" if host is None else host ,
84+ "DSS_MLFLOW_HOST" : self .client .host ,
85+ "DSS_MLFLOW_MANAGED_FOLDER_ID" : mf_id
8686 })
8787
8888 os .environ .update (self .mlflow_env )
0 commit comments