Skip to content

Commit 5191416

Browse files
committed
Simplify weight handling
1 parent a8ca347 commit 5191416

File tree

1 file changed

+9
-16
lines changed

1 file changed

+9
-16
lines changed

dataikuapi/dss/ml.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -413,11 +413,18 @@ def use_sample_weighting(self, feature_name):
413413
def set_weighting(self, method, feature_name=None):
414414
"""
415415
Sets the method to weight samples.
416+
417+
If there was a WEIGHT feature declared previously, it will be set back as an INPUT feature first.
418+
416419
:param str method: Method to use. One of NO_WEIGHTING, SAMPLE_WEIGHT (must give a feature name),
417420
CLASS_WEIGHT or CLASS_AND_SAMPLE_WEIGHT (must give a feature name)
418421
:param str feature_name: Name of the feature to use as sample weight
419422
"""
420-
self.unset_weighting()
423+
424+
# First, if there was a WEIGHT feature, restore it as INPUT
425+
for feature_name in self.mltask_settings['preprocessing']['per_feature']:
426+
if self.mltask_settings['preprocessing']['per_feature'][feature_name]['role'] == 'WEIGHT':
427+
self.mltask_settings['preprocessing']['per_feature'][feature_name]['role'] = 'INPUT'
421428

422429
if method == "NO_WEIGHTING":
423430
self.mltask_settings['weight']['weightMethod'] = method
@@ -455,23 +462,9 @@ def remove_sample_weighting(self):
455462
"""
456463
Deprecated. Use unset_weighting() instead
457464
"""
458-
warnings.warn("remove_sample_weighting() is deprecated, please use unset_weighting() instead", DeprecationWarning)
465+
warnings.warn("remove_sample_weighting() is deprecated, please use set_weigthing(method=\"NO_WEIGHTING\") instead", DeprecationWarning)
459466
return self.unset_weighting()
460467

461-
def unset_weighting(self):
462-
"""
463-
Remove sample weighting. If a feature was used as weight, it's set back to being an input feature
464-
465-
:rtype: self
466-
"""
467-
self.mltask_settings['weight']['weightMethod'] = 'NO_WEIGHTING'
468-
for feature_name in self.mltask_settings['preprocessing']['per_feature']:
469-
if self.mltask_settings['preprocessing']['per_feature'][feature_name]['role'] == 'WEIGHT':
470-
self.mltask_settings['preprocessing']['per_feature'][feature_name]['role'] = 'INPUT'
471-
472-
return self
473-
474-
475468
class DSSClusteringMLTaskSettings(DSSMLTaskSettings):
476469
__doc__ = []
477470
algorithm_remap = {

0 commit comments

Comments
 (0)