@@ -204,29 +204,23 @@ class DSSClusteringMLTaskSettings(DSSMLTaskSettings):
204204 "DBSCAN" : "db_scan_clustering" ,
205205 }
206206
207- class DSSTrainedPredictionModelDetails (object ):
208- """
209- Object to read details of a trained prediction model
210207
211- Do not create this object directly, use :meth:`DSSMLTask.get_trained_model_details()` instead
212- """
213208
214- def __init__ (self , details , summary ):
209+ class DSSTrainedModelDetails (object ):
210+ def __init__ (self , details , summary , saved_model = None , saved_model_version = None , mltask = None , mltask_model_id = None ):
215211 self .details = details
216212 self .summary = summary
213+ self .saved_model = saved_model
214+ self .saved_model_version = saved_model_version
215+ self .mltask = mltask
216+ self .mltask_model_id = mltask_model_id
217217
218218 def get_raw (self ):
219219 """
220220 Gets the raw dictionary of trained model details
221221 """
222222 return self .details
223223
224- def get_raw_snippet (self ):
225- """
226- Gets the raw dictionary of trained model snippet
227- """
228- return self .summary
229-
230224 def get_train_info (self ):
231225 """
232226 Returns various information about the train process (size of the train set, quick description, timing information)
@@ -235,6 +229,44 @@ def get_train_info(self):
235229 """
236230 return self .details ["trainInfo" ]
237231
232+ def get_user_meta (self ):
233+ """
234+ Gets the user-accessible metadata (name, description, cluster labels, classification threshold)
235+ Returns the original object, not a copy. Changes to the returned object are persisted to DSS by calling
236+ :meth:`save_user_meta`
237+
238+ """
239+ return self .details ["userMeta" ]
240+
241+ def save_user_meta (self ):
242+ um = self .details ["userMeta" ]
243+
244+ if self .mltask is not None :
245+ self .mltask .client ._perform_empty (
246+ "PUT" , "/projects/%s/models/lab/%s/%s/models/%s/user-meta" % (self .mltask .project_key ,
247+ self .mltask .analysis_id , self .mltask .mltask_id , self .mltask_model_id ), body = um )
248+ else :
249+ self .saved_model .client ._perform_empty (
250+ "PUT" , "/projects/%s/savedmodels/%s/versions/%s/user-meta" % (self .saved_model .project_key ,
251+ self .saved_model .sm_id , self .saved_model_version ), body = um )
252+
253+ class DSSTrainedPredictionModelDetails (DSSTrainedModelDetails ):
254+ """
255+ Object to read details of a trained prediction model
256+
257+ Do not create this object directly, use :meth:`DSSMLTask.get_trained_model_details()` instead
258+ """
259+
260+ def __init__ (self , details , summary , saved_model = None , saved_model_version = None , mltask = None , mltask_model_id = None ):
261+ DSSTrainedModelDetails .__init__ (self , details , summary , saved_model , saved_model_version , mltask , mltask_model_id )
262+
263+ def get_raw_snippet (self ):
264+ """
265+ Gets the raw dictionary of trained model snippet
266+ """
267+ return self .summary
268+
269+
238270 def get_roc_curve_data (self ):
239271 roc = self .details .get ("perf" , {}).get ("rocVizData" ,{})
240272 if roc is None :
@@ -303,6 +335,7 @@ def get_actual_modeling_params(self):
303335 """
304336 return self .details ["actualParams" ]
305337
338+
306339class DSSClustersFacts (object ):
307340 def __init__ (self , clusters_facts ):
308341 self .clusters_facts = clusters_facts
@@ -331,16 +364,17 @@ def get_facts_for_cluster_and_feature(self, cluster_index, feature_name):
331364 """
332365 return [x for x in self .get_facts_for_cluster (cluster_index ) if x ["feature_label" ] == feature_name ]
333366
334- class DSSTrainedClusteringModelDetails (object ):
367+
368+ class DSSTrainedClusteringModelDetails (DSSTrainedModelDetails ):
335369 """
336370 Object to read details of a trained clustering model
337371
338372 Do not create this object directly, use :meth:`DSSMLTask.get_trained_model_details()` instead
339373 """
340374
341- def __init__ (self , details , summary ):
342- self . details = details
343- self . summary = summary
375+ def __init__ (self , details , summary , saved_model = None , saved_model_version = None , mltask = None , mltask_model_id = None ):
376+ DSSTrainedModelDetails . __init__ ( self , details , summary , saved_model , saved_model_version , mltask , mltask_model_id )
377+
344378
345379 def get_raw (self ):
346380 """
@@ -511,9 +545,9 @@ def get_trained_model_details(self, id):
511545
512546
513547 if "facts" in ret :
514- return DSSTrainedClusteringModelDetails (ret , summary )
548+ return DSSTrainedClusteringModelDetails (ret , summary , mltask = self , mltask_model_id = id )
515549 else :
516- return DSSTrainedPredictionModelDetails (ret , summary )
550+ return DSSTrainedPredictionModelDetails (ret , summary , mltask = self , mltask_model_id = id )
517551
518552 def deploy_to_flow (self , model_id , model_name , train_dataset , test_dataset = None , redo_optimization = True ):
519553 """
0 commit comments