@@ -325,6 +325,116 @@ def save_user_meta(self):
325325 "PUT" , "/projects/%s/savedmodels/%s/versions/%s/user-meta" % (self .saved_model .project_key ,
326326 self .saved_model .sm_id , self .saved_model_version ), body = um )
327327
328+ class DSSTreeNode (object ):
329+ def __init__ (self , tree , i ):
330+ self .tree = tree
331+ self .i = i
332+
333+ def get_left_child (self ):
334+ """Gets a :class:`dataikuapi.dss.ml.DSSTreeNode` representing the left side of the tree node (or None)"""
335+ left = self .tree .tree ['leftChild' ][self .i ]
336+ if left < 0 :
337+ return None
338+ else :
339+ return DSSTreeNode (self .tree , left )
340+
341+ def get_right_child (self ):
342+ """Gets a :class:`dataikuapi.dss.ml.DSSTreeNode` representing the right side of the tree node (or None)"""
343+ left = self .tree .tree ['rightChild' ][self .i ]
344+ if left < 0 :
345+ return None
346+ else :
347+ return DSSTreeNode (self .tree , left )
348+
349+ def get_split_info (self ):
350+ """Gets the information on the split, as a dict"""
351+ info = {}
352+ features = self .tree .tree .get ("feature" , None )
353+ probas = self .tree .tree .get ("probas" , None )
354+ leftCategories = self .tree .tree .get ("leftCategories" , None )
355+ impurities = self .tree .tree .get ("impurity" , None )
356+ predicts = self .tree .tree .get ("predict" , None )
357+ thresholds = self .tree .tree .get ("threshold" , None )
358+ nSamples = self .tree .tree .get ("nSamples" , None )
359+ info ['feature' ] = self .tree .feature_names [features [self .i ]] if features is not None else None
360+ info ['probas' ] = probas [self .i ] if probas is not None else None
361+ info ['leftCategories' ] = leftCategories [self .i ] if leftCategories is not None else None
362+ info ['impurity' ] = impurities [self .i ] if impurities is not None else None
363+ info ['predict' ] = predicts [self .i ] if predicts is not None else None
364+ info ['nSamples' ] = nSamples [self .i ] if nSamples is not None else None
365+ info ['threshold' ] = thresholds [self .i ] if thresholds is not None else None
366+ return info
367+
368+ class DSSTree (object ):
369+ def __init__ (self , tree , feature_names ):
370+ self .tree = tree
371+ self .feature_names = feature_names
372+
373+ def get_raw (self ):
374+ """Gets the raw tree data structure"""
375+ return self .tree
376+
377+ def get_root (self ):
378+ """Gets a :class:`dataikuapi.dss.ml.DSSTreeNode` representing the root of the tree"""
379+ return DSSTreeNode (self , 0 )
380+
381+ class DSSTreeSet (object ):
382+ def __init__ (self , trees ):
383+ self .trees = trees
384+
385+ def get_raw (self ):
386+ """Gets the raw trees data structure"""
387+ return self .trees
388+
389+ def get_feature_names (self ):
390+ """Gets the list of feature names (after dummification) """
391+ return self .trees ["featureNames" ]
392+
393+ def get_trees (self ):
394+ """Gets the list of trees as :class:`dataikuapi.dss.ml.DSSTree` """
395+ return [DSSTree (t , self .trees ["featureNames" ]) for t in self .trees ["trees" ]]
396+
397+ class DSSCoefficientPaths (object ):
398+ def __init__ (self , paths ):
399+ self .paths = paths
400+
401+ def get_raw (self ):
402+ """Gets the raw paths data structure"""
403+ return self .paths
404+
405+ def get_feature_names (self ):
406+ """Get the feature names (after dummification)"""
407+ return self .paths ['features' ]
408+
409+ def get_coefficient_path (self , feature , class_index = 0 ):
410+ """Get the path of the feature"""
411+ i = self .paths ['features' ].index (feature )
412+ if i >= 0 and i < len (self .paths ['path' ][0 ][class_index ]):
413+ n = len (self .paths ['path' ])
414+ return [self .paths ['path' ][j ][class_index ][i ] for j in range (0 , n )]
415+ else :
416+ return None
417+
418+ class DSSScatterPlots (object ):
419+ def __init__ (self , scatters ):
420+ self .scatters = scatters
421+
422+ def get_raw (self ):
423+ """Gets the raw scatters data structure"""
424+ return self .scatters
425+
426+ def get_feature_names (self ):
427+ """Get the feature names (after dummification)"""
428+ feature_names = []
429+ for k in self .scatters ['features' ]:
430+ feature_names .append (k )
431+ return feature_names
432+
433+ def get_scatter_plot (self , feature_x , feature_y ):
434+ """Get the scatter plot between feature_x and feature_y"""
435+ ret = {'cluster' :self .scatters ['cluster' ], 'x' :self .scatters ['features' ].get (feature_x , None ), 'y' :self .scatters ['features' ].get (feature_x , None )}
436+ return ret
437+
328438class DSSTrainedPredictionModelDetails (DSSTrainedModelDetails ):
329439 """
330440 Object to read details of a trained prediction model
@@ -403,6 +513,32 @@ def get_actual_modeling_params(self):
403513 """
404514 return self .details ["actualParams" ]
405515
516+ def get_trees (self ):
517+ """
518+ Gets the trees in the model (for tree-based models)
519+
520+ :return: a DSSTreeSet object to interact with the trees
521+ :rtype: :class:`dataikuapi.dss.ml.DSSTreeSet`
522+ """
523+ data = self .mltask .client ._perform_json (
524+ "GET" , "/projects/%s/models/lab/%s/%s/models/%s/trees" % (self .mltask .project_key , self .mltask .analysis_id , self .mltask .mltask_id , self .mltask_model_id ))
525+ if data is None :
526+ raise ValueError ("This model has no tree data" )
527+ return DSSTreeSet (data )
528+
529+ def get_coefficient_paths (self ):
530+ """
531+ Gets the coefficient paths for Lasso models
532+
533+ :return: a DSSCoefficientPaths object to interact with the coefficient paths
534+ :rtype: :class:`dataikuapi.dss.ml.DSSCoefficientPaths`
535+ """
536+ data = self .mltask .client ._perform_json (
537+ "GET" , "/projects/%s/models/lab/%s/%s/models/%s/coef-paths" % (self .mltask .project_key , self .mltask .analysis_id , self .mltask .mltask_id , self .mltask_model_id ))
538+ if data is None :
539+ raise ValueError ("This model has no coefficient paths" )
540+ return DSSCoefficientPaths (data )
541+
406542
407543class DSSClustersFacts (object ):
408544 def __init__ (self , clusters_facts ):
@@ -506,6 +642,18 @@ def get_actual_modeling_params(self):
506642 """
507643 return self .details ["actualParams" ]
508644
645+ def get_scatter_plots (self ):
646+ """
647+ Gets the cluster scatter plot data
648+
649+ :return: a DSSScatterPlots object to interact with the scatter plots
650+ :rtype: :class:`dataikuapi.dss.ml.DSSScatterPlots`
651+ """
652+ scatters = self .mltask .client ._perform_json (
653+ "GET" , "/projects/%s/models/lab/%s/%s/models/%s/scatter-plots" % (self .mltask .project_key , self .mltask .analysis_id , self .mltask .mltask_id , self .mltask_model_id ))
654+ return DSSScatterPlots (scatters )
655+
656+
509657class DSSMLTask (object ):
510658 """A handle to interact with a MLTask for prediction or clustering in a DSS visual analysis"""
511659 def __init__ (self , client , project_key , analysis_id , mltask_id ):
@@ -652,8 +800,8 @@ def get_trained_model_details(self, id):
652800
653801 :param str id: Identifier of the trained model, as returned by :meth:`get_trained_models_ids`
654802
655- :return: A :class:`DSSTrainedPredictionModelDetails ` representing the details of this trained model id
656- :rtype: :class:`DSSTrainedPredictionModelDetails `
803+ :return: A :class:`DSSTrainedModelDetails ` representing the details of this trained model id
804+ :rtype: :class:`DSSTrainedModelDetails `
657805 """
658806 ret = self .client ._perform_json (
659807 "GET" , "/projects/%s/models/lab/%s/%s/models/%s/details" % (self .project_key , self .analysis_id , self .mltask_id ,id ))
0 commit comments