Skip to content

Commit 7f85e92

Browse files
authored
Merge pull request #173 from dataiku/experiment/mlflow-models
Experiment/mlflow models
2 parents cebcd93 + 17372a3 commit 7f85e92

File tree

2 files changed

+119
-0
lines changed

2 files changed

+119
-0
lines changed

dataikuapi/dss/project.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,24 @@ def get_saved_model(self, sm_id):
709709
"""
710710
return DSSSavedModel(self.client, self.project_key, sm_id)
711711

712+
def create_mlflow_pyfunc_model(self, id, prediction_type = None):
    """
    Create a new external saved model for storing and managing MLFlow models

    :param string id: Identifier for the new saved model in the flow. Must be exactly 8 characters long
    :param string prediction_type: Optional (but needed for most operations). One of BINARY_CLASSIFICATION, MULTICLASS or REGRESSION
    :raises ValueError: if ``id`` is not 8 characters long
    :return: the result of :meth:`get_saved_model` for the newly created saved model
    """
    required_id_length = 8
    if len(id) == required_id_length:
        endpoint = "/projects/%s/savedmodels/" % self.project_key
        # Key order mirrors the payload sent by the original implementation
        payload = dict(id=id, savedModelType="MLFLOW_PYFUNC", predictionType=prediction_type)
        self.client._perform_empty("POST", endpoint, body=payload)
        return self.get_saved_model(id)

    # Only reached when the identifier does not have the required length
    raise ValueError("model id must be 8 characters long")
729+
712730
########################################################
713731
# Managed folders
714732
########################################################

dataikuapi/dss/savedmodel.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,34 @@ def get_origin_ml_task(self):
117117
if fmi is not None:
118118
return DSSMLTask.from_full_model_id(self.client, fmi, project_key=self.project_key)
119119

120+
def import_mlflow_version_from_path(self, version_id, path):
    """
    Create a new version for this saved model from a path containing a MLFlow model.

    Requires the saved model to have been created using :meth:`dataikuapi.dss.project.DSSProject.create_mlflow_pyfunc_model`.

    The folder is zipped into a private temporary directory, the archive is uploaded, and the
    temporary directory is removed afterwards whether or not the upload succeeded. Nothing is
    written to the current working directory.

    :param str version_id: Identifier of the version to create
    :param str path: An absolute path on the local filesystem. Must be a folder, and must contain a MLFlow model
    :raises ValueError: if ``path`` is not an existing folder

    :return: a :class:`MLFlowVersionHandler` in order to interact with the new MLFlow model version
    """
    # Function-scope imports, following the original implementation's local ``import shutil``
    import os
    import shutil
    import tempfile

    if not os.path.isdir(path):
        raise ValueError("path must be an existing folder containing a MLFlow model: %s" % path)

    # TODO: Add a check that it's indeed a MLFlow model folder

    # mkdtemp + rmtree (rather than TemporaryDirectory) keeps this usable on Python 2 as well
    archive_dir = tempfile.mkdtemp()
    try:
        # make_archive appends the ".zip" extension and returns the full archive path
        archive_path = shutil.make_archive(os.path.join(archive_dir, "tmpmodel"), "zip", path)

        with open(archive_path, "rb") as fp:
            # The multipart file name stays "tmpmodel.zip", as sent by the original code
            self.client._perform_empty("POST", "/projects/%s/savedmodels/%s/versions/%s" % (self.project_key, self.sm_id, version_id),
                    files={"file":("tmpmodel.zip", fp)})
    finally:
        # Always drop the temporary archive, including when the upload raised
        shutil.rmtree(archive_dir, ignore_errors=True)

    return self.get_mlflow_version_handler(version_id)
142+
143+
def get_mlflow_version_handler(self, version_id):
    """
    Build a :class:`MLFlowVersionHandler` to interact with a MLFlow model version

    :param version_id: Identifier of the version; handed to the handler unchanged
    :return: a :class:`MLFlowVersionHandler` bound to this saved model and ``version_id``
    """
    handler = MLFlowVersionHandler(saved_model=self, version_id=version_id)
    return handler
120148

121149
########################################################
122150
# Metrics
@@ -204,6 +232,69 @@ def delete(self):
204232
"""
205233
return self.client._perform_empty("DELETE", "/projects/%s/savedmodels/%s" % (self.project_key, self.sm_id))
206234

235+
class MLFlowVersionHandler:
    """Handle on a single imported MLFlow model version of a saved model"""

    def __init__(self, saved_model, version_id):
        """Do not call this, use :meth:`DSSSavedModel.get_mlflow_version_handler`"""
        self.version_id = version_id
        self.saved_model = saved_model

    def set_core_metadata(self,
                          target_column_name, class_labels = None,
                          get_features_from_dataset=None, features_list = None,
                          output_style="AUTO_DETECT"):
        """
        Sets metadata for this MLFlow model version

        The current metadata is fetched, the provided (non-None) values are written into it,
        and the result is sent back.

        In addition to target_column_name, one of get_features_from_dataset or features_list must be passed in order
        to be able to evaluate performance

        :param str target_column_name: name of the target column. Mandatory in order to be able to evaluate performance
        :param list class_labels: List of strings, ordered class labels. Mandatory in order to be able to evaluate performance on classification models
        :param str get_features_from_dataset: Name of a dataset to get feature names from
        :param list features_list: List of {"name": "feature_name", "type": "feature_type"}. Accepted but not used by this implementation yet
        :param str output_style: Accepted but not used by this implementation
        """
        owner = self.saved_model
        metadata_url = "/projects/%s/savedmodels/%s/versions/%s/external-ml/metadata" % (owner.project_key, owner.sm_id, self.version_id)

        metadata = owner.client._perform_json("GET", metadata_url)

        # Collect only the values that were actually provided, in a fixed order
        changes = {}
        if target_column_name is not None:
            changes["targetColumnName"] = target_column_name
        if class_labels is not None:
            changes["classLabels"] = [dict(label=label) for label in class_labels]
        if get_features_from_dataset is not None:
            changes["gatherFeaturesFromDataset"] = get_features_from_dataset
        metadata.update(changes)

        # TODO: add support for get_features_from_signature=False,
        # TODO: Add support for features_list, with validation

        owner.client._perform_empty("PUT", metadata_url, body=metadata)

    def evaluate(self, dataset_ref):
        """
        Evaluates the performance of this model version on a particular dataset.
        After calling this, the "result screens" of the MLFlow model version will be available
        (confusion matrix, error distribution, performance metrics, ...)
        and more information will be available when calling :meth:`DSSSavedModel.get_version_details`

        :meth:`set_core_metadata` must be called before you can evaluate a dataset

        :param str dataset_ref: Name of the evaluation dataset to use (either a dataset name or "PROJECT.datasetName")
        """
        # TODO Add support for handling a DSSDataset or dataiku.Dataset as dataset_ref
        owner = self.saved_model
        evaluate_url = "/projects/%s/savedmodels/%s/versions/%s/external-ml/actions/evaluate" % (owner.project_key, owner.sm_id, self.version_id)
        owner.client._perform_empty("POST", evaluate_url, body={"datasetRef" : dataset_ref})
297+
207298

208299
class DSSSavedModelSettings:
209300
"""
@@ -218,3 +309,13 @@ def __init__(self, saved_model, settings):
218309

219310
def get_raw(self):
    """
    Return the raw settings held by this object

    :return: the ``settings`` attribute itself (not a copy)
    """
    raw_settings = self.settings
    return raw_settings
312+
313+
@property
def prediction_metrics_settings(self):
    """The settings of evaluation metrics for a prediction saved model"""
    # Walk down the raw settings: miniTask -> modeling -> metrics
    modeling_settings = self.settings["miniTask"]["modeling"]
    return modeling_settings["metrics"]
317+
318+
def save(self):
    """Saves the settings of this saved model"""
    owner = self.saved_model
    endpoint = "/projects/%s/savedmodels/%s" % (owner.project_key, owner.sm_id)
    # The whole raw settings object is sent back as the request body
    owner.client._perform_empty("PUT", endpoint, body=self.settings)

0 commit comments

Comments
 (0)