Skip to content

Commit 17372a3

Browse files
committed
Updates to MLFlow import API
1 parent f76d119 commit 17372a3

File tree

2 files changed

+104
-7
lines changed

2 files changed

+104
-7
lines changed

dataikuapi/dss/project.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -699,18 +699,21 @@ def get_saved_model(self, sm_id):
699699
"""
700700
return DSSSavedModel(self.client, self.project_key, sm_id)
701701

702-
def create_mlflow_pyfunc_model(self, id):
703-
"""Starts creation of a new external saved model for storing 3rd party models
702+
def create_mlflow_pyfunc_model(self, id, prediction_type = None):
    """
    Creates a new external saved model for storing and managing MLFlow models

    :param string id: Identifier for the new saved model in the flow. Must be exactly 8 characters long
    :param string prediction_type: Optional (but needed for most operations). One of BINARY_CLASSIFICATION, MULTICLASS or REGRESSION

    :raises ValueError: if ``id`` is not 8 characters long, or if ``prediction_type`` is
                        neither None nor one of the values listed above
    :return: the handle returned by :meth:`get_saved_model` for the newly created saved model
    """
    if len(id) != 8:
        raise ValueError("model id must be 8 characters long")

    # Fail fast on a typo rather than creating a saved model with an unusable prediction type
    allowed_prediction_types = ("BINARY_CLASSIFICATION", "MULTICLASS", "REGRESSION")
    if prediction_type is not None and prediction_type not in allowed_prediction_types:
        raise ValueError("prediction_type must be one of %s, got %r" % (", ".join(allowed_prediction_types), prediction_type))

    model = {
        "id": id,
        "savedModelType" : "MLFLOW_PYFUNC",
        "predictionType" : prediction_type
    }

    self.client._perform_empty("POST", "/projects/%s/savedmodels/" % self.project_key, body = model)
    return self.get_saved_model(id)
716719

dataikuapi/dss/savedmodel.py

Lines changed: 96 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,14 +117,35 @@ def get_origin_ml_task(self):
117117
if fmi is not None:
118118
return DSSMLTask.from_full_model_id(self.client, fmi, project_key=self.project_key)
119119

120-
def import_version_from_folder(self, version_id, folder):
120+
def import_mlflow_version_from_path(self, version_id, path):
    """
    Create a new version for this saved model from a path containing a MLFlow model.

    Requires the saved model to have been created using :meth:`dataikuapi.dss.project.DSSProject.create_mlflow_pyfunc_model`.

    The folder is zipped into a private temporary directory, the archive is uploaded, and the
    temporary directory is removed afterwards, whether or not the upload succeeded.

    :param str version_id: Identifier of the version to create
    :param str path: An absolute path on the local filesystem. Must be a folder, and must contain a MLFlow model

    :return: a :class:`MLFlowVersionHandler` in order to interact with the new MLFlow model version
    """
    # TODO: Add a check that it's indeed a MLFlow model folder
    import os
    import shutil
    import tempfile

    # Build the archive in a directory of our own rather than in the current working directory:
    # this avoids clashes between concurrent imports and does not require the CWD to be writable
    tmp_dir = tempfile.mkdtemp()
    try:
        archive_path = shutil.make_archive(os.path.join(tmp_dir, "tmpmodel"), "zip", path)

        with open(archive_path, "rb") as fp:
            self.client._perform_empty("POST", "/projects/%s/savedmodels/%s/versions/%s" % (self.project_key, self.sm_id, version_id),
                                       files={"file":("tmpmodel.zip", fp)})
    finally:
        # Never leave the archive behind, even when zipping or uploading raised
        shutil.rmtree(tmp_dir, ignore_errors=True)

    return self.get_mlflow_version_handler(version_id)
142+
143+
def get_mlflow_version_handler(self, version_id):
    """
    Builds the handler used to work with one MLFlow model version of this saved model

    :param version_id: identifier of the model version, handed unchanged to the handler
    :return: a :class:`MLFlowVersionHandler` bound to this saved model and to ``version_id``
    """
    handler = MLFlowVersionHandler(self, version_id)
    return handler
148+
128149
########################################################
129150
# Metrics
130151
########################################################
@@ -211,6 +232,69 @@ def delete(self):
211232
"""
212233
return self.client._perform_empty("DELETE", "/projects/%s/savedmodels/%s" % (self.project_key, self.sm_id))
213234

235+
class MLFlowVersionHandler:
    """Handler to interact with an imported MLFlow model version"""

    def __init__(self, saved_model, version_id):
        """Do not call this, use :meth:`DSSSavedModel.get_mlflow_version_handler`"""
        self.version_id = version_id
        self.saved_model = saved_model

    def _external_ml_url(self, suffix):
        """Builds the path of an ``external-ml`` endpoint of this model version, ending with ``suffix``"""
        owner = self.saved_model
        return "/projects/%s/savedmodels/%s/versions/%s/external-ml/%s" % (owner.project_key, owner.sm_id, self.version_id, suffix)

    def set_core_metadata(self,
                          target_column_name, class_labels = None,
                          get_features_from_dataset=None, features_list = None,
                          output_style="AUTO_DETECT"):
        """
        Sets metadata for this MLFlow model version

        The current metadata is fetched, the entries matching the non-None arguments are
        overwritten, and the result is sent back.

        In addition to target_column_name, one of get_features_from_dataset or features_list must be passed in order
        to be able to evaluate performance. Note that ``features_list`` and ``output_style`` are accepted
        but are not sent yet (see the TODOs in the body).

        :param str target_column_name: name of the target column. Mandatory in order to be able to evaluate performance
        :param list class_labels: List of strings, ordered class labels. Mandatory in order to be able to evaluate performance on classification models
        :param str get_features_from_dataset: Name of a dataset to get feature names from
        :param list features_list: List of {"name": "feature_name", "type": "feature_type"}
        """
        client = self.saved_model.client
        metadata_url = self._external_ml_url("metadata")

        metadata = client._perform_json("GET", metadata_url)

        # Each label is wrapped in its own {"label": ...} object, in the order given
        wrapped_labels = None if class_labels is None else [{"label": label} for label in class_labels]

        # Only the arguments that were actually provided overwrite the fetched metadata
        candidate_updates = (
            ("targetColumnName", target_column_name),
            ("classLabels", wrapped_labels),
            ("gatherFeaturesFromDataset", get_features_from_dataset),
        )
        for key, value in candidate_updates:
            if value is not None:
                metadata[key] = value

        # TODO: add support for get_features_from_signature
        # TODO: Add support for features_list, with validation

        client._perform_empty("PUT", metadata_url, body=metadata)

    def evaluate(self, dataset_ref):
        """
        Evaluates the performance of this model version on a particular dataset.
        After calling this, the "result screens" of the MLFlow model version will be available
        (confusion matrix, error distribution, performance metrics, ...)
        and more information will be available when calling :meth:`DSSSavedModel.get_version_details`

        :meth:`set_core_metadata` must be called before you can evaluate a dataset

        :param str dataset_ref: Name of the evaluation dataset to use (either a dataset name or "PROJECT.datasetName")
        """
        # TODO Add support for handling a DSSDataset or dataiku.Dataset as dataset_ref
        evaluate_url = self._external_ml_url("actions/evaluate")
        payload = {"datasetRef" : dataset_ref}
        self.saved_model.client._perform_empty("POST", evaluate_url, body=payload)
297+
214298

215299
class DSSSavedModelSettings:
216300
"""
@@ -225,3 +309,13 @@ def __init__(self, saved_model, settings):
225309

226310
def get_raw(self):
    """
    Gives access to the raw settings held by this object

    :return: the ``settings`` attribute itself (the same object, not a copy)
    """
    raw_settings = self.settings
    return raw_settings
312+
313+
@property
def prediction_metrics_settings(self):
    """
    The settings of evaluation metrics for a prediction saved model

    :return: the object stored under ``miniTask`` / ``modeling`` / ``metrics`` in the raw settings
    """
    modeling = self.settings["miniTask"]["modeling"]
    return modeling["metrics"]
317+
318+
def save(self):
    """
    Saves the settings of this saved model

    Sends the raw settings with a PUT on the saved model's endpoint.
    """
    owner = self.saved_model
    url = "/projects/%s/savedmodels/%s" % (owner.project_key, owner.sm_id)
    owner.client._perform_empty("PUT", url, body=self.settings)

0 commit comments

Comments
 (0)