Commit 741ecf3

Merge PR #49 "post train computations (subpop, pdp)" from feature/api-for-ml-post-train-computations

2 parents: b831413 + 27477a6

File tree: 1 file changed (+119, −0)

dataikuapi/dss/ml.py — 119 additions, 0 deletions

@@ -5,6 +5,7 @@
 import time
 from .metrics import ComputedMetrics
 from .utils import DSSDatasetSelectionBuilder, DSSFilterBuilder
+from .future import DSSFuture
 
 class PredictionSplitParamsHandler(object):
     """Object to modify the train/test splitting params."""
@@ -592,6 +593,124 @@ def get_scoring_pmml_stream(self):
             "GET", "/projects/%s/savedmodels/%s/versions/%s/scoring-pmml" %
             (self.saved_model.project_key, self.saved_model.sm_id, self.saved_model_version))
 
+    ## Post-train computations
+
+    def compute_subpopulation_analyses(self, split_by, wait=True, sample_size=1000, random_state=1337, n_jobs=1, debug_mode=False):
+        """
+        Launch computation of Subpopulation analyses for this trained model.
+
+        :param list split_by: columns on which subpopulation analyses are to be computed (one analysis per column)
+        :param bool wait: if True, the call blocks until the computation is finished and returns the results directly
+        :param int sample_size: number of records of the dataset to use for the computation
+        :param int random_state: random state to use to build sample, for reproducibility
+        :param int n_jobs: number of cores used for parallel training. (-1 means 'all cores')
+        :param bool debug_mode: if True, output all logs (slower)
+
+        :returns: if wait is True, a dict containing the Subpopulation analyses, else a future to wait on the result
+        :rtype: dict or :class:`dataikuapi.dss.future.DSSFuture`
+        """
+
+        body = {
+            "features": split_by,
+            "computationParams": {
+                "sample_size": sample_size,
+                "random_state": random_state,
+                "n_jobs": n_jobs,
+                "debug_mode": debug_mode,
+            }}
+        if self.mltask is not None:
+            future_response = self.mltask.client._perform_json(
+                "POST", "/projects/%s/models/lab/%s/%s/models/%s/subpopulation-analyses" %
+                (self.mltask.project_key, self.mltask.analysis_id, self.mltask.mltask_id, self.mltask_model_id),
+                body=body
+            )
+            future = DSSFuture(self.mltask.client, future_response.get("jobId", None), future_response)
+        else:
+            future_response = self.saved_model.client._perform_json(
+                "POST", "/projects/%s/savedmodels/%s/versions/%s/subpopulation-analyses" %
+                (self.saved_model.project_key, self.saved_model.sm_id, self.saved_model_version),
+                body=body
+            )
+            future = DSSFuture(self.saved_model.client, future_response.get("jobId", None), future_response)
+        if wait:
+            return future.wait_for_result()
+        else:
+            return future
+
+
+    def get_subpopulation_analyses(self):
+        """
+        Retrieve all subpopulation analyses computed for this trained model as a dict
+        """
+
+        if self.mltask is not None:
+            return self.mltask.client._perform_json(
+                "GET", "/projects/%s/models/lab/%s/%s/models/%s/subpopulation-analyses" %
+                (self.mltask.project_key, self.mltask.analysis_id, self.mltask.mltask_id, self.mltask_model_id)
+            )
+        else:
+            return self.saved_model.client._perform_json(
+                "GET", "/projects/%s/savedmodels/%s/versions/%s/subpopulation-analyses" %
+                (self.saved_model.project_key, self.saved_model.sm_id, self.saved_model_version),
+            )
+
+    def compute_partial_dependencies(self, features, wait=True, sample_size=1000, random_state=1337, n_jobs=1, debug_mode=False):
+        """
+        Launch computation of Partial dependencies for this trained model.
+
+        :param list features: features on which partial dependencies are to be computed
+        :param bool wait: if True, the call blocks until the computation is finished and returns the results directly
+        :param int sample_size: number of records of the dataset to use for the computation
+        :param int random_state: random state to use to build sample, for reproducibility
+        :param int n_jobs: number of cores used for parallel training. (-1 means 'all cores')
+        :param bool debug_mode: if True, output all logs (slower)
+
+        :returns: if wait is True, a dict containing the Partial dependencies, else a future to wait on the result
+        :rtype: dict or :class:`dataikuapi.dss.future.DSSFuture`
+        """
+
+        body = {
+            "features": features,
+            "computationParams": {
+                "sample_size": sample_size,
+                "random_state": random_state,
+                "n_jobs": n_jobs,
+                "debug_mode": debug_mode,
+            }}
+        if self.mltask is not None:
+            future_response = self.mltask.client._perform_json(
+                "POST", "/projects/%s/models/lab/%s/%s/models/%s/partial-dependencies" %
+                (self.mltask.project_key, self.mltask.analysis_id, self.mltask.mltask_id, self.mltask_model_id),
+                body=body
+            )
+            future = DSSFuture(self.mltask.client, future_response.get("jobId", None), future_response)
+        else:
+            future_response = self.saved_model.client._perform_json(
+                "POST", "/projects/%s/savedmodels/%s/versions/%s/partial-dependencies" %
+                (self.saved_model.project_key, self.saved_model.sm_id, self.saved_model_version),
+                body=body
+            )
+            future = DSSFuture(self.saved_model.client, future_response.get("jobId", None), future_response)
+        if wait:
+            return future.wait_for_result()
+        else:
+            return future
+
+    def get_partial_dependencies(self):
+        """
+        Retrieve all partial dependencies computed for this trained model as a dict
+        """
+
+        if self.mltask is not None:
+            return self.mltask.client._perform_json(
+                "GET", "/projects/%s/models/lab/%s/%s/models/%s/partial-dependencies" %
+                (self.mltask.project_key, self.mltask.analysis_id, self.mltask.mltask_id, self.mltask_model_id)
+            )
+        else:
+            return self.saved_model.client._perform_json(
+                "GET", "/projects/%s/savedmodels/%s/versions/%s/partial-dependencies" %
+                (self.saved_model.project_key, self.saved_model.sm_id, self.saved_model_version),
+            )
 
 class DSSClustersFacts(object):
     def __init__(self, clusters_facts):
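
Usage sketch (not part of this commit): the new methods are called on a trained-model details handle. The snippet below assumes the existing dataikuapi entry points DSSClient, get_project, get_saved_model and get_version_details; the host, API key, project key, model id, version and column names are placeholders.

    import dataikuapi

    # Connect to the DSS instance (placeholder host and API key)
    client = dataikuapi.DSSClient("https://dss.example.com:11200", "my-api-key")
    project = client.get_project("MYPROJECT")

    # Details handle of a deployed saved model version (placeholder ids)
    sm = project.get_saved_model("my_saved_model_id")
    details = sm.get_version_details("1")

    # Blocking call (wait=True by default): returns the analyses as a dict
    analyses = details.compute_subpopulation_analyses(split_by=["gender", "age_bucket"],
                                                      sample_size=5000)

    # Previously computed analyses can be fetched again without recomputing
    analyses = details.get_subpopulation_analyses()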

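
The partial-dependence methods follow the same pattern and also support the non-blocking mode: with wait=False the call returns a dataikuapi.dss.future.DSSFuture to wait on. Another sketch, reusing the details handle from the previous snippet (the same calls are available when the handle comes from a Lab ML task, as the diff's self.mltask branch shows); feature names are placeholders.

    # Non-blocking call: returns a DSSFuture immediately instead of the result dict
    future = details.compute_partial_dependencies(features=["age", "income"], wait=False)

    # ... do other work while the DSS backend computes ...
    pdp = future.wait_for_result()  # dict with the partial dependencies

    # Previously computed results remain retrievable without recomputing
    pdp = details.get_partial_dependencies()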