Skip to content

Commit 87795d0

Browse files
committed
adding tests and fixing missing value reconstruction
1 parent 5739cb3 commit 87795d0

File tree

4 files changed

+249
-152
lines changed

4 files changed

+249
-152
lines changed

clearbox_preprocessor/preprocessor.py

Lines changed: 54 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import polars.selectors as cs
66
from sklearn.preprocessing import LabelEncoder
77

8-
from tsfresh import extract_relevant_features
8+
from tsfresh import extract_relevant_features, extract_features
99

1010
from typing import List, Tuple, Literal, Dict
1111
import warnings
@@ -23,7 +23,7 @@
2323
class Preprocessor:
2424
ML_TASKS = {"classification", "regression", None}
2525
NUM_FILL_NULL_STRATEGIES = {"none", "interpolate","forward", "backward", "min", "max", "mean", "zero", "one"}
26-
SCALING_STRATEGIES = {"none", "normalize", "standardize", "quantile"}
26+
SCALING_STRATEGIES = {"none", "normalize", "standardize", "quantile", "kbins"}
2727
"""
2828
A class for preprocessing datasets based on polars, including feature selection, handling missing values, scaling,
2929
and time-series feature extraction.
@@ -128,6 +128,8 @@ def __init__(
128128
# Argument values check
129129
if cat_labels_threshold>1 or cat_labels_threshold<0:
130130
raise ValueError("Invalid value for cat_labels_threshold")
131+
if missing_values_threshold > 1 or missing_values_threshold < 0:
132+
raise ValueError("Invalid value for missing_values_threshold")
131133
if ml_task not in self.ML_TASKS:
132134
raise ValueError("Invalid value for ml_task")
133135
if target_column is not None and target_column not in data.columns:
@@ -184,15 +186,16 @@ def __init__(
184186
self.categorical_transformer = CategoricalTransformer(data, self)
185187

186188
if target_column is not None:
187-
match ml_task:
188-
case "classification":
189-
self.target_col_encoder = LabelEncoder()
190-
self.target_col_encoder.fit(data.select(pl.col(target_column)).collect().to_series())
191-
case "regression":
192-
self.target_col_encoder = [data.select(pl.col(target_column)).min().collect(),
193-
data.select(pl.col(target_column)).max().collect()]
194-
case None:
195-
pass
189+
if ml_task == "classification":
190+
self.target_col_encoder = LabelEncoder()
191+
self.target_col_encoder.fit(data.select(pl.col(target_column)).collect().to_series())
192+
elif ml_task == "regression":
193+
self.target_col_encoder = [data.select(pl.col(target_column)).min().collect(),
194+
data.select(pl.col(target_column)).max().collect()]
195+
elif ml_task is None:
196+
pass
197+
else:
198+
raise ValueError(f"Unsupported ml_task: {ml_task}")
196199

197200
def _infer_feature_types(
198201
self,
@@ -378,6 +381,11 @@ def transform(
378381
preprocessor = Preprocessor(real_data, scaling="standardize")
379382
transformed_data = preprocessor.transform(real_data)
380383
"""
384+
# Check data type compatibility
385+
data_is_pd = isinstance(data, pd.DataFrame)
386+
if data_is_pd != self.data_was_pd:
387+
sys.exit(f'Type mismatch: Preprocessor was initialized with {"pandas" if self.data_was_pd else "polars"} DataFrame but transform was called with {"pandas" if data_is_pd else "polars"} DataFrame')
388+
381389
# Transform data from Pandas.DataFrame or Polars.DataFrame to Polars.LazyFrame
382390
if isinstance(data, pd.DataFrame):
383391
data = pl.from_pandas(data).lazy()
@@ -435,16 +443,17 @@ def transform(
435443

436444
# Handling the target column
437445
if self.target_column is not None:
438-
match self.ml_task:
439-
case "classification":
440-
y_encoded = self.target_col_encoder.transform(data.select(pl.col(self.target_column)).to_series())
441-
data = data.with_columns(pl.Series(y_encoded).alias(self.target_column))
442-
case "regression":
443-
col_min = self.target_col_encoder[0][self.target_column].item()
444-
col_max = self.target_col_encoder[1][self.target_column].item()
445-
data = data.with_columns((pl.col(self.target_column) - col_min) / (col_max - col_min))
446-
case None:
447-
pass
446+
if self.ml_task == "classification":
447+
y_encoded = self.target_col_encoder.transform(data.select(pl.col(self.target_column)).to_series())
448+
data = data.with_columns(pl.Series(y_encoded).alias(self.target_column))
449+
elif self.ml_task == "regression":
450+
col_min = self.target_col_encoder[0][self.target_column].item()
451+
col_max = self.target_col_encoder[1][self.target_column].item()
452+
data = data.with_columns((pl.col(self.target_column) - col_min) / (col_max - col_min))
453+
elif self.ml_task is None:
454+
pass
455+
else:
456+
raise ValueError(f"Unsupported ml_task: {self.ml_task}")
448457

449458
if self.data_was_pd:
450459
data = data.to_pandas()
@@ -515,15 +524,14 @@ def inverse_transform(
515524
if len(self.categorical_features)>0:
516525
data = self.categorical_transformer.inverse_transform(data)
517526
if self.target_column is not None:
518-
match self.ml_task:
519-
case "classification":
520-
y_original = self.target_col_encoder.inverse_transform(data.select(pl.col(self.target_column)).to_series())
521-
data = data.with_columns(pl.Series(y_original).alias(self.target_column))
522-
case "regression":
523-
col_min = self.numerical_parameters[0][self.target_column].item()
524-
col_max = self.numerical_parameters[1][self.target_column].item()
525-
data = data.with_columns(pl.col(self.target_column) * (col_max - col_min) + col_min)
526-
527+
if self.ml_task == "classification":
528+
y_original = self.target_col_encoder.inverse_transform(data.select(pl.col(self.target_column)).to_series())
529+
data = data.with_columns(pl.Series(y_original).alias(self.target_column))
530+
elif self.ml_task == "regression":
531+
col_min = self.target_col_encoder[0][self.target_column].item()
532+
col_max = self.target_col_encoder[1][self.target_column].item()
533+
data = data.with_columns(pl.col(self.target_column) * (col_max - col_min) + col_min)
534+
527535
if self.data_was_pd:
528536
data = data.to_pandas()
529537
return data
@@ -547,7 +555,7 @@ def extract_ts_features(
547555
The label series associated with the data. It can be a Polars Series or a Pandas Series.
548556
time : str, optional
549557
The name of the time column used to sort the data. If not provided, the method
550-
will try to use ``self.time`` if available.
558+
will try to use ``self.time_id`` if available.
551559
column_id : str, optional
552560
The name of the ID column, if present in the data. This is used to distinguish
553561
different time-series within the same dataset.
@@ -564,14 +572,21 @@ def extract_ts_features(
564572
ValueError
565573
If the provided label series is not a Polars Series or a Pandas Series.
566574
ValueError
567-
If the time column name is not provided and ``self.time`` is not available.
575+
If the time column name is not provided and ``self.time_id`` is not available.
568576
569577
Notes
570578
-----
571579
- The function uses the ``extract_relevant_features`` method from the ``tsfresh`` library
572580
to extract features from the time-series data.
573581
- The method stores the filtered features in ``self.features_filtered`` for further use.
574582
"""
583+
# Check if time column is provided, else fallback to self.time_id if available
584+
if time is None:
585+
if self.time_id is not None:
586+
time = self.time_id
587+
else:
588+
raise ValueError("Time column name is required for time-series feature extraction.")
589+
575590
# Transform input dataframe into Pandas.DataFrame
576591
if isinstance(data, pl.LazyFrame):
577592
data_pd = data.collect().to_pandas()
@@ -591,15 +606,16 @@ def extract_ts_features(
591606
print("The labels series must be a Polars Series or a Pandas Series")
592607
return
593608

594-
if not self.time and not time:
595-
print("Please enter a name for the time column")
596-
return
597-
elif self.time and not time:
598-
time = self.time
599-
609+
# First try to extract relevant features
600610
features_filtered = extract_relevant_features(data_pd, y, column_sort=time, column_id=column_id)
601611
self.features_filtered = features_filtered
602612

613+
# If no relevant features were found, fall back to extracting all features
614+
if features_filtered.shape[1] == 0:
615+
# Extract all features without filtering for relevance
616+
all_features = extract_features(data_pd, column_sort=time, column_id=column_id)
617+
return all_features
618+
603619
return features_filtered
604620

605621
def get_features_sizes(self) -> Tuple[List[int], List[int]]:

0 commit comments

Comments
 (0)