Skip to content

Commit f082a00

Browse files
author
Louis Pouillot
committed
Add explanations parameters to APINodeClient.predict_record(s)
1 parent 166a248 commit f082a00

File tree

1 file changed

+32
-5
lines changed

1 file changed

+32
-5
lines changed

dataikuapi/apinode_client.py

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ def __init__(self, uri, service_id, api_key=None):
1616
"""
1717
DSSBaseClient.__init__(self, "%s/%s" % (uri, "public/api/v1/%s" % service_id), api_key)
1818

19-
def predict_record(self, endpoint_id, features, forced_generation=None, dispatch_key=None, context=None):
19+
def predict_record(self, endpoint_id, features, forced_generation=None, dispatch_key=None, context=None,
20+
with_explanations=None, explanation_method=None, n_explanations=None, n_explanations_mc_steps=None):
2021
"""
2122
Predicts a single record on a DSS API node endpoint (standard or custom prediction)
2223
@@ -25,12 +26,25 @@ def predict_record(self, endpoint_id, features, forced_generation=None, dispatch
2526
:param forced_generation: See documentation about multi-version prediction
2627
:param dispatch_key: See documentation about multi-version prediction
2728
:param context: Optional, Python dictionary of additional context information. The context information is logged, but not directly used.
29+
:param with_explanations: Whether individual explanations should be computed for each records.
30+
Explanations must be enabled for the prediction endpoint.
31+
:param explanation_method: Optional, method to compute those explanations. If None, will use the value configured in the endpoint.
32+
:param n_explanations: Optional, number of explanations to output per prediction. If None, will use the value configured in the endpoint.
33+
:param n_explanations_mc_steps: Optional, precision parameter for SHAPLEY method, higher means more precise but slower (between 25 and 400).
34+
If None, will use the value configured in the endpoint.
2835
2936
:return: a Python dict of the API answer. The answer contains a "result" key (itself a dict)
3037
"""
31-
obj = {
32-
"features" :features
38+
obj = {
39+
"features": features,
40+
"explanations": {
41+
"enabled": with_explanations,
42+
"method": explanation_method,
43+
"nExplanations": n_explanations,
44+
"nMonteCarloSteps": n_explanations_mc_steps
45+
}
3346
}
47+
3448
if context is not None:
3549
obj["context"] = context
3650
if forced_generation is not None:
@@ -40,14 +54,21 @@ def predict_record(self, endpoint_id, features, forced_generation=None, dispatch
4054

4155
return self._perform_json("POST", "%s/predict" % endpoint_id, body = obj)
4256

43-
def predict_records(self, endpoint_id, records, forced_generation=None, dispatch_key=None):
57+
def predict_records(self, endpoint_id, records, forced_generation=None, dispatch_key=None, with_explanations=None,
58+
explanation_method=None, n_explanations=None, n_explanations_mc_steps=None):
4459
"""
4560
Predicts a batch of records on a DSS API node endpoint (standard or custom prediction)
4661
4762
:param str endpoint_id: Identifier of the endpoint to query
4863
:param records: Python list of records. Each record must be a Python dict. Each record must contain a "features" dict (see predict_record) and optionally a "context" dict.
4964
:param forced_generation: See documentation about multi-version prediction
5065
:param dispatch_key: See documentation about multi-version prediction
66+
:param with_explanations: Whether individual explanations should be computed for each records.
67+
Explanations must be enabled for the prediction endpoint.
68+
:param explanation_method: Optional, method to compute those explanations. If None, will use the value configured in the endpoint.
69+
:param n_explanations: Optional, number of explanations to output per prediction. If None, will use the value configured in the endpoint.
70+
:param n_explanations_mc_steps: Optional, precision parameter for SHAPLEY method, higher means more precise but slower (between 25 and 400).
71+
If None, will use the value configured in the endpoint.
5172
5273
:return: a Python dict of the API answer. The answer contains a "results" key (which is an array of result objects)
5374
"""
@@ -57,7 +78,13 @@ def predict_records(self, endpoint_id, records, forced_generation=None, dispatch
5778
raise ValueError("Each record must contain a 'features' dict")
5879

5980
obj = {
60-
"items" : records
81+
"items": records,
82+
"explanations": {
83+
"enabled": with_explanations,
84+
"method": explanation_method,
85+
"nExplanations": n_explanations,
86+
"nMonteCarloSteps": n_explanations_mc_steps
87+
}
6188
}
6289

6390
if forced_generation is not None:

0 commit comments

Comments
 (0)