Skip to content

Commit e91f0b6

Browse files
author
kgued
committed
move sample weight functions to Prediction specific settings class
1 parent 7b7d573 commit e91f0b6

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

dataikuapi/dss/ml.py

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -209,22 +209,14 @@ def use_sample_weighting(self, feature_name):
209209
Uses a feature as sample weight
210210
:param str feature_name: Name of the feature to use
211211
"""
212-
self.remove_sample_weighting()
213-
if not feature_name in self.mltask_settings["preprocessing"]["per_feature"]:
214-
raise ValueError("Feature %s doesn't exist in this ML task, can't use as weight" % feature_name)
215-
self.mltask_settings['weight']['weightMethod'] = 'SAMPLE_WEIGHT'
216-
self.mltask_settings['weight']['sampleWeightVariable'] = feature_name
217-
self.mltask_settings['preprocessing']['per_feature'][feature_name]['role'] = 'WEIGHT'
212+
raise NotImplementedError("use_sample_weighting not available for class {}".format(self.__class__))
218213

219214
def remove_sample_weighting(self):
220215
"""
221216
Remove sample weighting. If a feature was used as weight, it's set back to being an input feature
222217
"""
223-
self.mltask_settings['weight']['weightMethod'] = 'NO_WEIGHTING'
224-
for feature_name in self.mltask_settings['preprocessing']['per_feature']:
225-
if self.mltask_settings['preprocessing']['per_feature'][feature_name]['role'] == 'WEIGHT':
226-
self.mltask_settings['preprocessing']['per_feature'][feature_name]['role'] = 'INPUT'
227-
218+
raise NotImplementedError("remove_sample_weighting not available for class {}".format(self.__class__))
219+
228220
def get_algorithm_settings(self, algorithm_name):
229221
"""
230222
Gets the training settings for a particular algorithm. This returns a reference to the
@@ -406,6 +398,27 @@ def remove_ordered_split(self):
406398
elif self.mltask_settings['modeling']['gridSearchParams']['mode'] == "TIME_SERIES_SINGLE_SPLIT":
407399
self.mltask_settings['modeling']['gridSearchParams']['mode'] = "SHUFFLE"
408400

401+
def use_sample_weighting(self, feature_name):
402+
"""
403+
Uses a feature as sample weight
404+
:param str feature_name: Name of the feature to use
405+
"""
406+
self.remove_sample_weighting()
407+
if not feature_name in self.mltask_settings["preprocessing"]["per_feature"]:
408+
raise ValueError("Feature %s doesn't exist in this ML task, can't use as weight" % feature_name)
409+
self.mltask_settings['weight']['weightMethod'] = 'SAMPLE_WEIGHT'
410+
self.mltask_settings['weight']['sampleWeightVariable'] = feature_name
411+
self.mltask_settings['preprocessing']['per_feature'][feature_name]['role'] = 'WEIGHT'
412+
413+
def remove_sample_weighting(self):
414+
"""
415+
Remove sample weighting. If a feature was used as weight, it's set back to being an input feature
416+
"""
417+
self.mltask_settings['weight']['weightMethod'] = 'NO_WEIGHTING'
418+
for feature_name in self.mltask_settings['preprocessing']['per_feature']:
419+
if self.mltask_settings['preprocessing']['per_feature'][feature_name]['role'] == 'WEIGHT':
420+
self.mltask_settings['preprocessing']['per_feature'][feature_name]['role'] = 'INPUT'
421+
409422

410423
class DSSClusteringMLTaskSettings(DSSMLTaskSettings):
411424
__doc__ = []

0 commit comments

Comments
 (0)