|
4 | 4 | import json |
5 | 5 | import time |
6 | 6 | from .metrics import ComputedMetrics |
7 | | -from .utils import DSSDatasetSelectionBuilder |
| 7 | +from .utils import DSSDatasetSelectionBuilder, DSSFilterBuilder |
8 | 8 |
|
9 | 9 | class PredictionSplitParamsHandler(object): |
10 | 10 | """Object to modify the train/test splitting params.""" |
@@ -57,14 +57,16 @@ def set_split_kfold(self, n_folds = 5, selection = None, dataset_name=None): |
57 | 57 | if dataset_name is not None: |
58 | 58 | sp["ssdDatasetSmartName"] = dataset_name |
59 | 59 |
|
60 | | - def set_split_explicit(self, train_selection, test_selection, dataset_name=None, test_dataset_name=None): |
| 60 | + def set_split_explicit(self, train_selection, test_selection, dataset_name=None, test_dataset_name=None, train_filter=None, test_filter=None): |
61 | 61 | """ |
62 | 62 | Sets the train/test split to explicit extract of one or two dataset |
63 | 63 |
|
64 | 64 | :param object train_selection: A :class:`DSSDatasetSelectionBuilder` to build the settings of the extract of the train dataset. May be None (won't be changed) |
65 | 65 | :param object test_selection: A :class:`DSSDatasetSelectionBuilder` to build the settings of the extract of the test dataset. May be None (won't be changed) |
66 | 66 | :param str dataset_name: Name of dataset to use for the extracts. If None, the main dataset used to create the ML Task will be used. |
67 | 67 | :param str test_dataset_name: Name of a second dataset to use for the test data extract. If None, both extracts are done from dataset_name |
| 68 | + :param object train_filter: A :class:`DSSFilterBuilder` to build the settings of the filter of the train dataset. May be None (won't be changed) |
| 69 | + :param object test_filter: A :class:`DSSFilterBuilder` to build the settings of the filter of the test dataset. May be None (won't be changed) |
68 | 70 | """ |
69 | 71 | sp = self.mltask_settings["splitParams"] |
70 | 72 | if dataset_name is None: |
@@ -94,6 +96,17 @@ def set_split_explicit(self, train_selection, test_selection, dataset_name=None, |
94 | 96 | else: |
95 | 97 | test_split["selection"] = test_selection |
96 | 98 |
|
| 99 | + if train_filter is not None: |
| 100 | + if isinstance(train_filter, DSSFilterBuilder): |
| 101 | + train_split["filter"] = train_filter.build() |
| 102 | + else: |
| 103 | + train_split["filter"] = train_filter |
| 104 | + if test_filter is not None: |
| 105 | + if isinstance(test_filter, DSSFilterBuilder): |
| 106 | + test_split["filter"] = test_filter.build() |
| 107 | + else: |
| 108 | + test_split["filter"] = test_filter |
| 109 | + |
97 | 110 |
|
98 | 111 | class DSSMLTaskSettings(object): |
99 | 112 | """ |
|
0 commit comments