Skip to content

Commit 73523ba

Browse files
committed
Add time-based ordering shortcuts for DSSMLTaskSettings
1 parent 3b01c7e commit 73523ba

File tree

1 file changed

+36
-0
lines changed

1 file changed

+36
-0
lines changed

dataikuapi/dss/ml.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,42 @@ def get_split_params(self):
138138
"""
139139
return PredictionSplitParamsHandler(self.mltask_settings)
140140

141+
def use_time_variable(self, feature_name, ascending=True):
142+
"""
143+
Uses a variable to sort the data for train/test split and hyperparameter optimization
144+
:param str feature_name: Name of the variable to use
145+
:param bool ascending: True iff the test set is expected to have larger time values than the train set
146+
"""
147+
self.remove_time_variable()
148+
if not feature_name in self.mltask_settings["preprocessing"]["per_feature"]:
149+
raise ValueError("Feature %s doesn't exist in this ML task, can't use as time" % feature_name)
150+
self.mltask_settings['time']['enabled'] = True
151+
self.mltask_settings['time']['timeVariable'] = feature_name
152+
self.mltask_settings['time']['ascending'] = ascending
153+
self.mltask_settings['preprocessing']['per_feature'][feature_name]['missing_handling'] = "DROP_ROW"
154+
self.mltask_settings['preprocessing']['per_feature'][feature_name]['numerical_handling'] = "REGULAR"
155+
self.mltask_settings['preprocessing']['per_feature'][feature_name]['rescaling'] = "NONE"
156+
if self.mltask_settings['splitParams']['ttPolicy'] == "SPLIT_SINGLE_DATASET":
157+
self.mltask_settings['splitParams']['ssdSplitMode'] = "SORTED"
158+
self.mltask_settings['splitParams']['ssdColumn'] = feature_name
159+
if self.mltask_settings['modeling']['gridSearchParams']['mode'] == "KFOLD":
160+
self.mltask_settings['modeling']['gridSearchParams']['mode'] = "TIME_SERIES_KFOLD"
161+
elif self.mltask_settings['modeling']['gridSearchParams']['mode'] == "SHUFFLE":
162+
self.mltask_settings['modeling']['gridSearchParams']['mode'] = "TIME_SERIES_SINGLE_SPLIT"
163+
164+
def remove_time_variable(self):
165+
"""
166+
Remove time-based ordering.
167+
"""
168+
self.mltask_settings['time']['enabled'] = False
169+
self.mltask_settings['time']['timeVariable'] = None
170+
if self.mltask_settings['splitParams']['ttPolicy'] == "SPLIT_SINGLE_DATASET":
171+
self.mltask_settings['splitParams']['ssdSplitMode'] = "RANDOM"
172+
self.mltask_settings['splitParams']['ssdColumn'] = None
173+
if self.mltask_settings['modeling']['gridSearchParams']['mode'] == "TIME_SERIES_KFOLD":
174+
self.mltask_settings['modeling']['gridSearchParams']['mode'] = "KFOLD"
175+
elif self.mltask_settings['modeling']['gridSearchParams']['mode'] == "TIME_SERIES_SINGLE_SPLIT":
176+
self.mltask_settings['modeling']['gridSearchParams']['mode'] = "SHUFFLE"
141177

142178
def get_feature_preprocessing(self, feature_name):
143179
"""

0 commit comments

Comments
 (0)