Skip to content

Commit 72876b0

Browse files
author
kgued
committed
set and unset split time ordering from split_params instead of ML settings
1 parent 876b7b4 commit 72876b0

File tree

1 file changed

+45
-23
lines changed

1 file changed

+45
-23
lines changed

dataikuapi/dss/ml.py

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,45 @@ def set_split_explicit(self, train_selection, test_selection, dataset_name=None,
119119
else:
120120
test_split["filter"] = test_filter
121121

122+
def set_order_by(self, feature_name, ascending=True):
    """
    Uses a variable to sort the data for train/test split and hyperparameter optimization

    The feature name is validated before any setting is modified, so a call
    that fails leaves the current configuration untouched.

    :param str feature_name: Name of the variable to use
    :param bool ascending: True iff the test set is expected to have larger time values than the train set

    :raises ValueError: if ``feature_name`` is not one of the features of this ML task
    :returns: self
    """
    # Validate first: raising after unset_order_by() would silently discard
    # an ordering that was already configured.
    if feature_name not in self.mltask_settings["preprocessing"]["per_feature"]:
        raise ValueError("Feature %s doesn't exist in this ML task, can't use as time" % feature_name)

    # Reset any previous ordering so the mode switches below start from the
    # non-time-series modes.
    self.unset_order_by()

    settings = self.mltask_settings

    time_settings = settings['time']
    time_settings['enabled'] = True
    time_settings['timeVariable'] = feature_name
    time_settings['ascending'] = ascending

    # The ordering feature gets the "DROP_ROW" missing-value handling.
    settings['preprocessing']['per_feature'][feature_name]['missing_handling'] = "DROP_ROW"

    # Only a single-dataset split carries a sort mode / sort column.
    split = settings['splitParams']
    if split['ttPolicy'] == "SPLIT_SINGLE_DATASET":
        split['ssdSplitMode'] = "SORTED"
        split['ssdColumn'] = feature_name

    # Switch the hyperparameter search mode to its time-series counterpart;
    # any other mode is left as it is.
    grid_search = settings['modeling']['gridSearchParams']
    if grid_search['mode'] == "KFOLD":
        grid_search['mode'] = "TIME_SERIES_KFOLD"
    elif grid_search['mode'] == "SHUFFLE":
        grid_search['mode'] = "TIME_SERIES_SINGLE_SPLIT"

    return self
144+
145+
def unset_order_by(self):
    """
    Remove time-based ordering.

    Disables the time settings and clears the time variable. For a
    single-dataset split, the split mode goes back to ``RANDOM`` and the sort
    column is cleared. A time-series hyperparameter search mode is mapped
    back to its regular counterpart; other modes are left unchanged.

    :returns: self
    """
    settings = self.mltask_settings

    time_settings = settings['time']
    time_settings['enabled'] = False
    time_settings['timeVariable'] = None

    split = settings['splitParams']
    if split['ttPolicy'] == "SPLIT_SINGLE_DATASET":
        split['ssdSplitMode'] = "RANDOM"
        split['ssdColumn'] = None

    grid_search = settings['modeling']['gridSearchParams']
    current_mode = grid_search['mode']
    # (time-series mode, mode to restore) pairs
    reverse_modes = (
        ("TIME_SERIES_KFOLD", "KFOLD"),
        ("TIME_SERIES_SINGLE_SPLIT", "SHUFFLE"),
    )
    for ordered_mode, plain_mode in reverse_modes:
        if current_mode == ordered_mode:
            grid_search['mode'] = plain_mode
            break

    return self
160+
122161

123162
class DSSMLTaskSettings(object):
124163
"""
def split_ordered_by(self, feature_name, ascending=True):
    """
    Uses a variable to sort the data for train/test split and hyperparameter optimization

    The work is delegated to ``self.split_params.set_order_by``.

    :param str feature_name: Name of the variable to use
    :param bool ascending: True iff the test set is expected to have larger time values than the train set

    :rtype: self
    """
    # Forward the caller's choice; passing a literal True here would make
    # the ``ascending`` argument of this method a no-op.
    self.split_params.set_order_by(feature_name, ascending=ascending)

    return self
388416

389417
def remove_ordered_split(self):
    """
    Remove time-based ordering.

    The work is delegated to ``self.split_params.unset_order_by``.

    :rtype: self
    """
    split_handler = self.split_params
    split_handler.unset_order_by()

    return self
404426

0 commit comments

Comments
 (0)