22import posixpath
33import tempfile
44import urllib
5+ import re
56from dataikuapi import DSSClient
67
78
89def parse_dss_managed_folder_uri (uri ):
9- """Parse an S3 URI, returning (bucket, path)"""
1010 parsed = urllib .parse .urlparse (uri )
1111 if parsed .scheme != "dss-managed-folder" :
1212 raise Exception ("Not a DSS Managed Folder URI: %s" % uri )
13- return os .path .normpath (parsed .path )
14-
13+ pattern = re .compile ("^(\w+\.)?\w{8}" )
14+ if not parsed .netloc or not pattern .match (parsed .netloc ):
15+ raise Exception ("Could not find a managed folder id in URI: %s" % uri )
16+ return parsed
1517
1618class PluginDSSManagedFolderArtifactRepository :
1719
1820 def __init__ (self , artifact_uri ):
19- self .base_artifact_path = parse_dss_managed_folder_uri (artifact_uri )
2021 if os .environ .get ("DSS_MLFLOW_APIKEY" ) is not None :
2122 self .client = DSSClient (
2223 os .environ .get ("DSS_MLFLOW_HOST" ),
@@ -28,14 +29,19 @@ def __init__(self, artifact_uri):
2829 internal_ticket = os .environ .get ("DSS_MLFLOW_INTERNAL_TICKET" )
2930 )
3031 self .project = self .client .get_project (os .environ .get ("DSS_MLFLOW_PROJECTKEY" ))
31- managed_folders = [
32- x ["id" ] for x in self .project .list_managed_folders ()
33- if x ["name" ] == os .environ .get ("DSS_MLFLOW_MANAGED_FOLDER" )
34- ]
35- if len (managed_folders ) > 0 :
36- self .managed_folder = self .project .get_managed_folder (managed_folders [0 ])
32+ parsed_uri = parse_dss_managed_folder_uri (artifact_uri )
33+ self .managed_folder = self .__get_managed_folder (parsed_uri .netloc )
34+ self .base_artifact_path = os .path .normpath (parsed_uri .path )
35+
36+ def __get_managed_folder (self , managed_folder_smart_id ):
37+ chunks = managed_folder_smart_id .split ('.' )
38+ if len (chunks ) == 1 :
39+ return self .project .get_managed_folder (chunks [0 ])
40+ elif len (chunks ) == 2 :
41+ project = self .client .get_project (chunks [0 ])
42+ return project .get_managed_folder (chunks [1 ])
3743 else :
38- self . managed_folder = self . project . create_managed_folder ( os . environ . get ( "DSS_MLFLOW_MANAGED_FOLDER" ) )
44+ raise Exception ( "Invalid managed folder id: %s" % managed_folder_smart_id )
3945
4046 def log_artifact (self , local_file , artifact_path = None ):
4147 """
0 commit comments