Skip to content

Commit cc0ce56

Browse files
committed
add calls to get extra model details
1 parent 20a1ee4 commit cc0ce56

File tree

1 file changed

+150
-2
lines changed

1 file changed

+150
-2
lines changed

dataikuapi/dss/ml.py

Lines changed: 150 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,6 +325,116 @@ def save_user_meta(self):
325325
"PUT", "/projects/%s/savedmodels/%s/versions/%s/user-meta" % (self.saved_model.project_key,
326326
self.saved_model.sm_id, self.saved_model_version), body = um)
327327

328+
class DSSTreeNode(object):
329+
def __init__(self, tree, i):
330+
self.tree = tree
331+
self.i = i
332+
333+
def get_left_child(self):
334+
"""Gets a :class:`dataikuapi.dss.ml.DSSTreeNode` representing the left side of the tree node (or None)"""
335+
left = self.tree.tree['leftChild'][self.i]
336+
if left < 0:
337+
return None
338+
else:
339+
return DSSTreeNode(self.tree, left)
340+
341+
def get_right_child(self):
342+
"""Gets a :class:`dataikuapi.dss.ml.DSSTreeNode` representing the right side of the tree node (or None)"""
343+
left = self.tree.tree['rightChild'][self.i]
344+
if left < 0:
345+
return None
346+
else:
347+
return DSSTreeNode(self.tree, left)
348+
349+
def get_split_info(self):
350+
"""Gets the information on the split, as a dict"""
351+
info = {}
352+
features = self.tree.tree.get("feature", None)
353+
probas = self.tree.tree.get("probas", None)
354+
leftCategories = self.tree.tree.get("leftCategories", None)
355+
impurities = self.tree.tree.get("impurity", None)
356+
predicts = self.tree.tree.get("predict", None)
357+
thresholds = self.tree.tree.get("threshold", None)
358+
nSamples = self.tree.tree.get("nSamples", None)
359+
info['feature'] = self.tree.feature_names[features[self.i]] if features is not None else None
360+
info['probas'] = probas[self.i] if probas is not None else None
361+
info['leftCategories'] = leftCategories[self.i] if leftCategories is not None else None
362+
info['impurity'] = impurities[self.i] if impurities is not None else None
363+
info['predict'] = predicts[self.i] if predicts is not None else None
364+
info['nSamples'] = nSamples[self.i] if nSamples is not None else None
365+
info['threshold'] = thresholds[self.i] if thresholds is not None else None
366+
return info
367+
368+
class DSSTree(object):
369+
def __init__(self, tree, feature_names):
370+
self.tree = tree
371+
self.feature_names = feature_names
372+
373+
def get_raw(self):
374+
"""Gets the raw tree data structure"""
375+
return self.tree
376+
377+
def get_root(self):
378+
"""Gets a :class:`dataikuapi.dss.ml.DSSTreeNode` representing the root of the tree"""
379+
return DSSTreeNode(self, 0)
380+
381+
class DSSTreeSet(object):
382+
def __init__(self, trees):
383+
self.trees = trees
384+
385+
def get_raw(self):
386+
"""Gets the raw trees data structure"""
387+
return self.trees
388+
389+
def get_feature_names(self):
390+
"""Gets the list of feature names (after dummification) """
391+
return self.trees["featureNames"]
392+
393+
def get_trees(self):
394+
"""Gets the list of trees as :class:`dataikuapi.dss.ml.DSSTree` """
395+
return [DSSTree(t, self.trees["featureNames"]) for t in self.trees["trees"]]
396+
397+
class DSSCoefficientPaths(object):
398+
def __init__(self, paths):
399+
self.paths = paths
400+
401+
def get_raw(self):
402+
"""Gets the raw paths data structure"""
403+
return self.paths
404+
405+
def get_feature_names(self):
406+
"""Get the feature names (after dummification)"""
407+
return self.paths['features']
408+
409+
def get_coefficient_path(self, feature, class_index=0):
410+
"""Get the path of the feature"""
411+
i = self.paths['features'].index(feature)
412+
if i >= 0 and i < len(self.paths['path'][0][class_index]):
413+
n = len(self.paths['path'])
414+
return [self.paths['path'][j][class_index][i] for j in range(0, n)]
415+
else:
416+
return None
417+
418+
class DSSScatterPlots(object):
419+
def __init__(self, scatters):
420+
self.scatters = scatters
421+
422+
def get_raw(self):
423+
"""Gets the raw scatters data structure"""
424+
return self.scatters
425+
426+
def get_feature_names(self):
427+
"""Get the feature names (after dummification)"""
428+
feature_names = []
429+
for k in self.scatters['features']:
430+
feature_names.append(k)
431+
return feature_names
432+
433+
def get_scatter_plot(self, feature_x, feature_y):
434+
"""Get the scatter plot between feature_x and feature_y"""
435+
ret = {'cluster':self.scatters['cluster'], 'x':self.scatters['features'].get(feature_x, None), 'y':self.scatters['features'].get(feature_x, None)}
436+
return ret
437+
328438
class DSSTrainedPredictionModelDetails(DSSTrainedModelDetails):
329439
"""
330440
Object to read details of a trained prediction model
@@ -403,6 +513,32 @@ def get_actual_modeling_params(self):
403513
"""
404514
return self.details["actualParams"]
405515

516+
def get_trees(self):
517+
"""
518+
Gets the trees in the model (for tree-based models)
519+
520+
:return: a DSSTreeSet object to interact with the trees
521+
:rtype: :class:`dataikuapi.dss.ml.DSSTreeSet`
522+
"""
523+
data = self.mltask.client._perform_json(
524+
"GET", "/projects/%s/models/lab/%s/%s/models/%s/trees" % (self.mltask.project_key, self.mltask.analysis_id, self.mltask.mltask_id, self.mltask_model_id))
525+
if data is None:
526+
raise ValueError("This model has no tree data")
527+
return DSSTreeSet(data)
528+
529+
def get_coefficient_paths(self):
530+
"""
531+
Gets the coefficient paths for Lasso models
532+
533+
:return: a DSSCoefficientPaths object to interact with the coefficient paths
534+
:rtype: :class:`dataikuapi.dss.ml.DSSCoefficientPaths`
535+
"""
536+
data = self.mltask.client._perform_json(
537+
"GET", "/projects/%s/models/lab/%s/%s/models/%s/coef-paths" % (self.mltask.project_key, self.mltask.analysis_id, self.mltask.mltask_id, self.mltask_model_id))
538+
if data is None:
539+
raise ValueError("This model has no coefficient paths")
540+
return DSSCoefficientPaths(data)
541+
406542

407543
class DSSClustersFacts(object):
408544
def __init__(self, clusters_facts):
@@ -506,6 +642,18 @@ def get_actual_modeling_params(self):
506642
"""
507643
return self.details["actualParams"]
508644

645+
def get_scatter_plots(self):
646+
"""
647+
Gets the cluster scatter plot data
648+
649+
:return: a DSSScatterPlots object to interact with the scatter plots
650+
:rtype: :class:`dataikuapi.dss.ml.DSSScatterPlots`
651+
"""
652+
scatters = self.mltask.client._perform_json(
653+
"GET", "/projects/%s/models/lab/%s/%s/models/%s/scatter-plots" % (self.mltask.project_key, self.mltask.analysis_id, self.mltask.mltask_id, self.mltask_model_id))
654+
return DSSScatterPlots(scatters)
655+
656+
509657
class DSSMLTask(object):
510658
"""A handle to interact with a MLTask for prediction or clustering in a DSS visual analysis"""
511659
def __init__(self, client, project_key, analysis_id, mltask_id):
@@ -652,8 +800,8 @@ def get_trained_model_details(self, id):
652800
653801
:param str id: Identifier of the trained model, as returned by :meth:`get_trained_models_ids`
654802
655-
:return: A :class:`DSSTrainedPredictionModelDetails` representing the details of this trained model id
656-
:rtype: :class:`DSSTrainedPredictionModelDetails`
803+
:return: A :class:`DSSTrainedModelDetails` representing the details of this trained model id
804+
:rtype: :class:`DSSTrainedModelDetails`
657805
"""
658806
ret = self.client._perform_json(
659807
"GET", "/projects/%s/models/lab/%s/%s/models/%s/details" % (self.project_key, self.analysis_id, self.mltask_id,id))

0 commit comments

Comments
 (0)