Skip to content

Commit 5ce6228

Browse files
committed
bugfix
1 parent 58a3907 commit 5ce6228

File tree

2 files changed

+23
-22
lines changed

2 files changed

+23
-22
lines changed

clearbox_preprocessor/preprocessor.py

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from .utils.datetime_transformer import DatetimeTransformer
1717

1818
class Preprocessor:
19-
ML_TASKS = {"classification", "regression"}
19+
ML_TASKS = {"classification", "regression", None}
2020
NUM_FILL_NULL_STRATEGIES = {"interpolate","forward", "backward", "min", "max", "mean", "zero", "one"}
2121
SCALING_STRATEGIES = {"none", "normalize", "standardize", "quantile"}
2222
"""
@@ -117,7 +117,7 @@ def __init__(
117117
scaling: Literal["none", "normalize", "standardize", "quantile"] = "none",
118118
num_fill_null : Literal["interpolate","forward", "backward", "min", "max", "mean", "zero", "one"] = "mean",
119119
unseen_labels = 'ignore',
120-
ml_task: Literal["classification", "regression"] = "classification",
120+
ml_task: Literal["classification", "regression", None] = None,
121121
target_column: str = None,
122122
):
123123
# Argument values check
@@ -159,9 +159,9 @@ def __init__(
159159
self.ml_task = ml_task
160160

161161
if ml_task is not None and target_column is None:
162-
warnings.warn('The target column is not specified.')
162+
warnings.warn('The Machine Learning task was specified but the target column was not specified.')
163163
if target_column is not None and ml_task is None:
164-
warnings.warn('The Machine Learning task is not specified.')
164+
warnings.warn('The target column was not specified but the Machine Learning task was not specified.')
165165
if target_column is not None:
166166
self.excluded_col.append(target_column)
167167

@@ -178,15 +178,16 @@ def __init__(
178178
if len(self.categorical_features) > 0:
179179
self.categorical_transformer = CategoricalTransformer(data, self)
180180

181-
match ml_task:
182-
case "classification":
183-
self.target_col_encoder = LabelEncoder()
184-
self.target_col_encoder.fit(data.select(pl.col(target_column)).collect().to_series())
185-
case "regression":
186-
self.target_col_encoder = [data.select(pl.col(target_column)).min().collect(),
187-
data.select(pl.col(target_column)).max().collect()]
188-
case None:
189-
pass
181+
if target_column is not None:
182+
match ml_task:
183+
case "classification":
184+
self.target_col_encoder = LabelEncoder()
185+
self.target_col_encoder.fit(data.select(pl.col(target_column)).collect().to_series())
186+
case "regression":
187+
self.target_col_encoder = [data.select(pl.col(target_column)).min().collect(),
188+
data.select(pl.col(target_column)).max().collect()]
189+
case None:
190+
pass
190191

191192
def _infer_feature_types(
192193
self,
@@ -399,7 +400,7 @@ def transform(
399400
# Convert time columns to timestamp integers, fill null values by linear interpolation and scale time columns
400401
if len(self.datetime_features)>0:
401402
data = self.datetime_transformer.transform(data, self.time)
402-
if self.time not in self.datetime_transformer.datetime_formats.keys():
403+
if self.time is not None and self.time not in self.datetime_transformer.datetime_formats.keys():
403404
warnings.warn(f"The time column specified '{self.time}' was not detected as datetime type", UserWarning)
404405

405406
# Numerical features processing
@@ -647,14 +648,14 @@ def get_categorical_features(self) -> Tuple[str]:
647648
# real_data = pl.read_csv(os.path.join(file_path,"census_dataset_training.csv"))
648649

649650
# Time series
650-
# file_path = "https://raw.githubusercontent.com/Clearbox-AI/clearbox-synthetic-kit/main/tutorials/time_series/data/daily_delhi_climate"
651-
# path=os.path.join(file_path, "DailyDelhiClimateTrain.csv")
652-
# real_data = pl.read_csv(path)
651+
file_path = "https://raw.githubusercontent.com/Clearbox-AI/clearbox-synthetic-kit/main/tutorials/time_series/data/daily_delhi_climate"
652+
path=os.path.join(file_path, "DailyDelhiClimateTrain.csv")
653+
real_data = pl.read_csv(path)
653654

654-
file_path = "https://raw.githubusercontent.com/Clearbox-AI/clearbox-synthetic-kit/main/tests/resources/uci_adult_dataset"
655-
real_data = pd.read_csv(os.path.join(file_path,"dataset.csv"))
656-
# real_data["income"] = real_data["income"].map({"<=50K": 0, ">50K": 1})
655+
# file_path = "https://raw.githubusercontent.com/Clearbox-AI/clearbox-synthetic-kit/main/tests/resources/uci_adult_dataset"
656+
# real_data = pd.read_csv(os.path.join(file_path,"dataset.csv"))
657+
# # real_data["income"] = real_data["income"].map({"<=50K": 0, ">50K": 1})
657658

658-
preprocessor = Preprocessor(real_data, num_fill_null='forward', scaling='normalize', ml_task = "classification",target_column="income")
659+
preprocessor = Preprocessor(real_data, num_fill_null='forward', scaling='normalize')
659660
real_data_preprocessed = preprocessor.transform(real_data)
660661
df_inverse = preprocessor.inverse_transform(real_data_preprocessed)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
setup(
1313
name="clearbox-preprocessor",
14-
version="0.11.7",
14+
version="0.11.8",
1515
author="Dario Brunelli",
1616
author_email="dario@clearbox.ai",
1717
description="A fast polars based data pre-processor for ML datasets",

0 commit comments

Comments
 (0)