Skip to content

Commit 28f0ec6

Browse files
committed
more splitting options
1 parent 217698b commit 28f0ec6

File tree

1 file changed

+61
-0
lines changed

1 file changed

+61
-0
lines changed

dataikuapi/dss/ml.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,71 @@ def set_split_random(self, train_ratio = 0.8, selection = None, dataset_name=Non
3030
sp["ssdSelection"] = selection
3131

3232
sp["ssdTrainingRatio"] = train_ratio
33+
sp["kfold"] = False
3334

3435
if dataset_name is not None:
3536
sp["ssdDatasetSmartName"] = dataset_name
3637

38+
def set_split_kfold(self, n_folds=5, selection=None, dataset_name=None):
    """
    Sets the train/test split to k-fold splitting of an extract of a single dataset

    :param int n_folds: number of folds. Must be greater than 0
    :param object selection: A :class:`DSSDatasetSelectionBuilder` to build the settings of the extract of the dataset. May be None (won't be changed)
    :param str dataset_name: Name of dataset to split. If None, the main dataset used to create the ML Task will be used.
    :raises ValueError: if n_folds is not greater than 0
    """
    # Reject an invalid fold count before touching the settings, so that a bad
    # call cannot leave the split parameters half-updated.
    if n_folds <= 0:
        raise ValueError("n_folds must be greater than 0, got %r" % (n_folds,))

    sp = self.mltask_settings["splitParams"]
    sp["ttPolicy"] = "SPLIT_SINGLE_DATASET"

    if selection is not None:
        # A builder is converted with build(); anything else is stored as given.
        if isinstance(selection, DSSDatasetSelectionBuilder):
            sp["ssdSelection"] = selection.build()
        else:
            sp["ssdSelection"] = selection

    sp["kfold"] = True
    sp["nFolds"] = n_folds

    # Only override the dataset when one is explicitly given; otherwise the
    # existing "ssdDatasetSmartName" entry (if any) is left untouched.
    if dataset_name is not None:
        sp["ssdDatasetSmartName"] = dataset_name
59+
60+
def set_split_explicit(self, train_selection, test_selection, dataset_name=None, test_dataset_name=None):
    """
    Sets the train/test split to explicit extracts of one or two datasets

    The train and test extract settings are always replaced by fresh ones, so a
    selection configured by an earlier call is not carried over.

    :param object train_selection: A :class:`DSSDatasetSelectionBuilder` to build the settings of the extract of the train dataset, or an already-built selection (stored as given). If None, the train extract is written without a selection.
    :param object test_selection: A :class:`DSSDatasetSelectionBuilder` to build the settings of the extract of the test dataset, or an already-built selection (stored as given). If None, the test extract is written without a selection.
    :param str dataset_name: Name of dataset to use for the extracts. Mandatory, despite the None default.
    :param str test_dataset_name: Name of a second dataset to use for the test data extract. If None (or equal to dataset_name), both extracts are done from dataset_name
    :raises ValueError: if dataset_name is None
    """
    # Validate before reading or mutating the settings.
    if dataset_name is None:
        raise ValueError("For explicit splitting a dataset_name is mandatory")

    sp = self.mltask_settings["splitParams"]

    if test_dataset_name is None or test_dataset_name == dataset_name:
        # Single dataset: the dataset name is stored once at the top level.
        sp["ttPolicy"] = "EXPLICIT_FILTERING_SINGLE_DATASET"
        train_split = {}
        test_split = {}
        sp['efsdDatasetSmartName'] = dataset_name
        sp['efsdTrain'] = train_split
        sp['efsdTest'] = test_split
    else:
        # Two datasets: each extract carries its own dataset name.
        sp["ttPolicy"] = "EXPLICIT_FILTERING_TWO_DATASETS"
        train_split = {'datasetSmartName': dataset_name}
        test_split = {'datasetSmartName': test_dataset_name}
        sp['eftdTrain'] = train_split
        sp['eftdTest'] = test_split

    # The split dicts are already referenced from sp, so filling them in here
    # updates the settings in place. Train is handled before test.
    for split, selection in ((train_split, train_selection), (test_split, test_selection)):
        if selection is None:
            continue
        if isinstance(selection, DSSDatasetSelectionBuilder):
            selection = selection.build()
        split["selection"] = selection
96+
97+
3798
class DSSMLTaskSettings(object):
3899
"""
39100
Object to read and modify the settings of a ML task.

0 commit comments

Comments
 (0)