Skip to content

Commit 7b7d573

Browse files
author
kgued
committed
move split functions to Prediction ml task settings instead of abstract ml task settings
1 parent 3b17b94 commit 7b7d573

File tree

1 file changed

+59
-37
lines changed

1 file changed

+59
-37
lines changed

dataikuapi/dss/ml.py

Lines changed: 59 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -142,61 +142,27 @@ def get_raw(self):
142142
"""
143143
return self.mltask_settings
144144

145-
@property
146-
def split_params(self):
147-
"""
148-
Gets a handle to modify train/test splitting params.
149-
150-
:rtype: :class:`PredictionSplitParamsHandler`
151-
"""
152-
return self.get_split_params()
153-
154145
def get_split_params(self):
155146
"""
156147
Gets a handle to modify train/test splitting params.
157148
158149
:rtype: :class:`PredictionSplitParamsHandler`
159150
"""
160-
return PredictionSplitParamsHandler(self.mltask_settings)
161-
162-
@split_params.setter
163-
def split_params(self, value):
164-
raise AttributeError("split_params reference cannot be overwritten, get a handle and modify it with a set method instead")
151+
raise NotImplementedError("get_split_params not available for class {}".format(self.__class__))
165152

166153
def split_ordered_by(self, feature_name, ascending=True):
167154
"""
168155
Uses a variable to sort the data for train/test split and hyperparameter optimization
169156
:param str feature_name: Name of the variable to use
170157
:param bool ascending: True iff the test set is expected to have larger time values than the train set
171158
"""
172-
self.remove_ordered_split()
173-
if not feature_name in self.mltask_settings["preprocessing"]["per_feature"]:
174-
raise ValueError("Feature %s doesn't exist in this ML task, can't use as time" % feature_name)
175-
self.mltask_settings['time']['enabled'] = True
176-
self.mltask_settings['time']['timeVariable'] = feature_name
177-
self.mltask_settings['time']['ascending'] = ascending
178-
self.mltask_settings['preprocessing']['per_feature'][feature_name]['missing_handling'] = "DROP_ROW"
179-
if self.mltask_settings['splitParams']['ttPolicy'] == "SPLIT_SINGLE_DATASET":
180-
self.mltask_settings['splitParams']['ssdSplitMode'] = "SORTED"
181-
self.mltask_settings['splitParams']['ssdColumn'] = feature_name
182-
if self.mltask_settings['modeling']['gridSearchParams']['mode'] == "KFOLD":
183-
self.mltask_settings['modeling']['gridSearchParams']['mode'] = "TIME_SERIES_KFOLD"
184-
elif self.mltask_settings['modeling']['gridSearchParams']['mode'] == "SHUFFLE":
185-
self.mltask_settings['modeling']['gridSearchParams']['mode'] = "TIME_SERIES_SINGLE_SPLIT"
159+
raise NotImplementedError("split_ordered_by not available for class {}".format(self.__class__))
186160

187161
def remove_ordered_split(self):
188162
"""
189163
Remove time-based ordering.
190164
"""
191-
self.mltask_settings['time']['enabled'] = False
192-
self.mltask_settings['time']['timeVariable'] = None
193-
if self.mltask_settings['splitParams']['ttPolicy'] == "SPLIT_SINGLE_DATASET":
194-
self.mltask_settings['splitParams']['ssdSplitMode'] = "RANDOM"
195-
self.mltask_settings['splitParams']['ssdColumn'] = None
196-
if self.mltask_settings['modeling']['gridSearchParams']['mode'] == "TIME_SERIES_KFOLD":
197-
self.mltask_settings['modeling']['gridSearchParams']['mode'] = "KFOLD"
198-
elif self.mltask_settings['modeling']['gridSearchParams']['mode'] == "TIME_SERIES_SINGLE_SPLIT":
199-
self.mltask_settings['modeling']['gridSearchParams']['mode'] = "SHUFFLE"
165+
raise NotImplementedError("remove_ordered_split not available for class {}".format(self.__class__))
200166

201167
def get_feature_preprocessing(self, feature_name):
202168
"""
@@ -384,6 +350,62 @@ class DSSPredictionMLTaskSettings(DSSMLTaskSettings):
384350
"KERAS_CODE" : "keras"
385351
}
386352

353+
@property
354+
def split_params(self):
355+
"""
356+
Gets a handle to modify train/test splitting params.
357+
358+
:rtype: :class:`PredictionSplitParamsHandler`
359+
"""
360+
return self.get_split_params()
361+
362+
def get_split_params(self):
363+
"""
364+
Gets a handle to modify train/test splitting params.
365+
366+
:rtype: :class:`PredictionSplitParamsHandler`
367+
"""
368+
return PredictionSplitParamsHandler(self.mltask_settings)
369+
370+
@split_params.setter
371+
def split_params(self, value):
372+
raise AttributeError("split_params reference cannot be overwritten, get a handle and modify it with a set method instead")
373+
374+
def split_ordered_by(self, feature_name, ascending=True):
375+
"""
376+
Uses a variable to sort the data for train/test split and hyperparameter optimization
377+
:param str feature_name: Name of the variable to use
378+
:param bool ascending: True iff the test set is expected to have larger time values than the train set
379+
"""
380+
self.remove_ordered_split()
381+
if not feature_name in self.mltask_settings["preprocessing"]["per_feature"]:
382+
raise ValueError("Feature %s doesn't exist in this ML task, can't use as time" % feature_name)
383+
self.mltask_settings['time']['enabled'] = True
384+
self.mltask_settings['time']['timeVariable'] = feature_name
385+
self.mltask_settings['time']['ascending'] = ascending
386+
self.mltask_settings['preprocessing']['per_feature'][feature_name]['missing_handling'] = "DROP_ROW"
387+
if self.mltask_settings['splitParams']['ttPolicy'] == "SPLIT_SINGLE_DATASET":
388+
self.mltask_settings['splitParams']['ssdSplitMode'] = "SORTED"
389+
self.mltask_settings['splitParams']['ssdColumn'] = feature_name
390+
if self.mltask_settings['modeling']['gridSearchParams']['mode'] == "KFOLD":
391+
self.mltask_settings['modeling']['gridSearchParams']['mode'] = "TIME_SERIES_KFOLD"
392+
elif self.mltask_settings['modeling']['gridSearchParams']['mode'] == "SHUFFLE":
393+
self.mltask_settings['modeling']['gridSearchParams']['mode'] = "TIME_SERIES_SINGLE_SPLIT"
394+
395+
def remove_ordered_split(self):
396+
"""
397+
Remove time-based ordering.
398+
"""
399+
self.mltask_settings['time']['enabled'] = False
400+
self.mltask_settings['time']['timeVariable'] = None
401+
if self.mltask_settings['splitParams']['ttPolicy'] == "SPLIT_SINGLE_DATASET":
402+
self.mltask_settings['splitParams']['ssdSplitMode'] = "RANDOM"
403+
self.mltask_settings['splitParams']['ssdColumn'] = None
404+
if self.mltask_settings['modeling']['gridSearchParams']['mode'] == "TIME_SERIES_KFOLD":
405+
self.mltask_settings['modeling']['gridSearchParams']['mode'] = "KFOLD"
406+
elif self.mltask_settings['modeling']['gridSearchParams']['mode'] == "TIME_SERIES_SINGLE_SPLIT":
407+
self.mltask_settings['modeling']['gridSearchParams']['mode'] = "SHUFFLE"
408+
387409

388410
class DSSClusteringMLTaskSettings(DSSMLTaskSettings):
389411
__doc__ = []

0 commit comments

Comments
 (0)