Skip to content

Commit 4680b60

Browse files
author
kgued
committed
Refactor the weighting implementation, add all 4 weighting methods, and deprecate the base-class DSSMLTaskSettings weighting methods
1 parent cd87725 commit 4680b60

File tree

1 file changed

+72
-14
lines changed

1 file changed

+72
-14
lines changed

dataikuapi/dss/ml.py

Lines changed: 72 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -251,16 +251,15 @@ def use_feature(self, feature_name):
251251

252252
def use_sample_weighting(self, feature_name):
    """
    Deprecated. Will be removed from DSSMLTaskSettings class

    This implementation never configures anything: it always raises.

    :param str feature_name: Name of the feature to use (unused here)
    :raises NotImplementedError: always, with the class of this object in the message
    """
    raise NotImplementedError("use_sample_weighting() not available for class {}".format(self.__class__))
258257

259258
def remove_sample_weighting(self):
    """
    Deprecated. Will be removed from DSSMLTaskSettings class

    This implementation never changes anything: it always raises.

    :raises NotImplementedError: always, with the class of this object in the message
    """
    raise NotImplementedError("remove_sample_weighting() not available for class {}".format(self.__class__))
264263

265264
def get_algorithm_settings(self, algorithm_name):
266265
"""
@@ -387,6 +386,22 @@ class DSSPredictionMLTaskSettings(DSSMLTaskSettings):
387386
"KERAS_CODE" : "keras"
388387
}
389388

389+
class PredictionTypes:
    """
    Names of the prediction types, as plain string constants
    """
    BINARY = "BINARY_CLASSIFICATION"
    REGRESSION = "REGRESSION"
    MULTICLASS = "MULTICLASS"
393+
394+
def __init__(self, client, project_key, analysis_id, mltask_id, mltask_settings):
    """
    Builds the settings object and checks that its prediction type is a known one

    All arguments are forwarded unchanged to ``DSSMLTaskSettings.__init__``.

    :raises ValueError: if the prediction type is not one of the ``PredictionTypes`` values
    """
    DSSMLTaskSettings.__init__(self, client, project_key, analysis_id, mltask_id, mltask_settings)

    # Read the type once; there is no ``self.prediction_type`` attribute, so the
    # error message must be built from the value returned by the accessor.
    prediction_type = self.get_prediction_type()
    known_types = [self.PredictionTypes.BINARY, self.PredictionTypes.REGRESSION, self.PredictionTypes.MULTICLASS]
    if prediction_type not in known_types:
        raise ValueError("Unknown prediction type: {}".format(prediction_type))

    # Prediction types for which class-based weighting is accepted by set_weighting()
    self.classification_prediction_types = [self.PredictionTypes.BINARY, self.PredictionTypes.MULTICLASS]
401+
402+
def get_prediction_type(self):
    """
    Returns the value stored under the 'predictionType' key of the ML task settings

    :raises KeyError: if the settings have no 'predictionType' entry
    """
    return self.mltask_settings['predictionType']
404+
390405
@property
391406
def split_params(self):
392407
"""
@@ -416,7 +431,7 @@ def split_ordered_by(self, feature_name, ascending=True):
416431
417432
:rtype: self
418433
"""
419-
warnings.warn("split_ordered_by is deprecated, please use split_params.set_order_by() instead", DeprecationWarning)
434+
warnings.warn("split_ordered_by() is deprecated, please use split_params.set_order_by() instead", DeprecationWarning)
420435
self.split_params.set_order_by(feature_name, ascending=True)
421436

422437
return self
@@ -427,34 +442,77 @@ def remove_ordered_split(self):
427442
428443
:rtype: self
429444
"""
430-
warnings.warn("remove_ordered_split is deprecated, please use split_params.unset_order_by() instead", DeprecationWarning)
445+
warnings.warn("remove_ordered_split() is deprecated, please use split_params.unset_order_by() instead", DeprecationWarning)
431446
self.split_params.unset_order_by()
432447

433448
return self
434449

435450
def use_sample_weighting(self, feature_name):
    """
    Deprecated. Use set_weighting() instead

    Emits a DeprecationWarning, then delegates to
    ``set_weighting(method='SAMPLE_WEIGHT', feature_name=feature_name)``.

    :param str feature_name: Name of the feature to use as sample weight

    :return: whatever set_weighting() returns
    """
    # stacklevel=2 attributes the warning to the caller of this deprecated method
    warnings.warn("use_sample_weighting() is deprecated, please use set_weighting() instead", DeprecationWarning, stacklevel=2)
    return self.set_weighting(method='SAMPLE_WEIGHT', feature_name=feature_name)
456+
457+
def set_weighting(self, method, feature_name=None):
    """
    Sets the weighting method of this ML task

    All arguments are validated before any setting is modified, so a rejected
    call leaves the current weighting configuration untouched.

    :param str method: One of "NO_WEIGHTING", "SAMPLE_WEIGHT", "CLASS_WEIGHT" or "CLASS_AND_SAMPLE_WEIGHT"
    :param str feature_name: Name of the feature to use as sample weight. Required by
        "SAMPLE_WEIGHT" and "CLASS_AND_SAMPLE_WEIGHT", not used by the other methods

    :raises ValueError: if the method is unknown, if a class-based method is requested for a
        prediction type that is not in ``self.classification_prediction_types``, or if the
        weight feature does not exist in this ML task

    :rtype: self
    """
    if method not in ("NO_WEIGHTING", "SAMPLE_WEIGHT", "CLASS_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT"):
        raise ValueError("Unknown weighting method: {}".format(method))

    needs_classification = method in ("CLASS_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT")
    needs_weight_feature = method in ("SAMPLE_WEIGHT", "CLASS_AND_SAMPLE_WEIGHT")

    if needs_classification and self.get_prediction_type() not in self.classification_prediction_types:
        raise ValueError("Weighting method: {} not compatible with prediction type: {}, should be in {}".format(method, self.get_prediction_type(), self.classification_prediction_types))

    if needs_weight_feature and feature_name not in self.mltask_settings["preprocessing"]["per_feature"]:
        raise ValueError("Feature %s doesn't exist in this ML task, can't use as weight" % feature_name)

    # Only reset the previous weighting once the request is known to be valid
    self.unset_weighting()

    self.mltask_settings['weight']['weightMethod'] = method
    if needs_weight_feature:
        self.mltask_settings['weight']['sampleWeightVariable'] = feature_name
        self.mltask_settings['preprocessing']['per_feature'][feature_name]['role'] = 'WEIGHT'

    return self
448495

449496
def remove_sample_weighting(self):
    """
    Deprecated. Use unset_weighting() instead

    Emits a DeprecationWarning, then delegates to ``unset_weighting()``.

    :return: whatever unset_weighting() returns
    """
    # stacklevel=2 attributes the warning to the caller of this deprecated method
    warnings.warn("remove_sample_weighting() is deprecated, please use unset_weighting() instead", DeprecationWarning, stacklevel=2)
    return self.unset_weighting()
502+
503+
def unset_weighting(self):
    """
    Remove weighting. If a feature was used as weight, it's set back to being an input feature

    Sets the weight method to 'NO_WEIGHTING' and switches every feature whose role
    is 'WEIGHT' back to 'INPUT'. Other settings are left as they are.

    :rtype: self
    """
    self.mltask_settings['weight']['weightMethod'] = 'NO_WEIGHTING'

    # Only the per-feature settings dicts are needed, not the feature names
    for feature_settings in self.mltask_settings['preprocessing']['per_feature'].values():
        if feature_settings['role'] == 'WEIGHT':
            feature_settings['role'] = 'INPUT'

    return self
515+
458516

459517
class DSSClusteringMLTaskSettings(DSSMLTaskSettings):
460518
__doc__ = []

0 commit comments

Comments
 (0)