Skip to content

Commit 9bff0e7

Browse files
authored
fix: target scale (#19)
1 parent 2c00b37 commit 9bff0e7

File tree

4 files changed

+44 additions, -32 deletions (lines changed)

4 files changed

+44 additions, -32 deletions (lines changed)

tsa/dataset.py

Lines changed: 22 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -7,47 +7,50 @@
77

88
class TimeSeriesDataset(object):
99
def __init__(self, data, categorical_cols, target_col, seq_length, prediction_window=1):
10-
'''
10+
"""
1111
:param data: dataset of type pandas.DataFrame
1212
:param categorical_cols: name of the categorical columns, if None pass empty list
1313
:param target_col: name of the targeted column
1414
:param seq_length: window length to use
1515
:param prediction_window: window length to predict
16-
'''
16+
"""
1717
self.data = data
1818
self.categorical_cols = categorical_cols
1919
self.numerical_cols = list(set(data.columns) - set(categorical_cols) - set(target_col))
2020
self.target_col = target_col
2121
self.seq_length = seq_length
2222
self.prediction_window = prediction_window
23-
self.preprocessor = None
24-
25-
def preprocess_data(self):
26-
'''Preprocessing function'''
27-
X = self.data.drop(self.target_col, axis=1)
28-
y = self.data[self.target_col]
2923

3024
self.preprocessor = ColumnTransformer(
3125
[("scaler", StandardScaler(), self.numerical_cols),
3226
("encoder", OneHotEncoder(), self.categorical_cols)],
3327
remainder="passthrough"
3428
)
29+
if self.target_col:
30+
self.y_scaler = StandardScaler()
31+
32+
def preprocess_data(self):
33+
"""Preprocessing function"""
34+
X = self.data.drop(self.target_col, axis=1)
35+
y = self.data[self.target_col]
3536

3637
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=0.8, shuffle=False)
3738
X_train = self.preprocessor.fit_transform(X_train)
3839
X_test = self.preprocessor.transform(X_test)
3940

4041
if self.target_col:
41-
return X_train, X_test, y_train.values, y_test.values
42+
y_train = self.y_scaler.fit_transform(y_train)
43+
y_test = self.y_scaler.transform(y_test)
44+
return X_train, X_test, y_train, y_test
4245
return X_train, X_test
4346

4447
def frame_series(self, X, y=None):
45-
'''
48+
"""
4649
Function used to prepare the data for time series prediction
4750
:param X: set of features
4851
:param y: targeted value to predict
4952
:return: TensorDataset
50-
'''
53+
"""
5154
nb_obs, nb_features = X.shape
5255
features, target, y_hist = [], [], []
5356

@@ -69,11 +72,11 @@ def frame_series(self, X, y=None):
6972
return TensorDataset(features_var)
7073

7174
def get_loaders(self, batch_size: int):
72-
'''
75+
"""
7376
Preprocess and frame the dataset
7477
:param batch_size: batch size
7578
:return: DataLoaders associated to training and testing data
76-
'''
79+
"""
7780
X_train, X_test, y_train, y_test = self.preprocess_data()
7881
nb_features = X_train.shape[1]
7982

@@ -83,3 +86,9 @@ def get_loaders(self, batch_size: int):
8386
train_iter = DataLoader(train_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
8487
test_iter = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
8588
return train_iter, test_iter, nb_features
89+
90+
def invert_scale(self, predictions):
91+
if isinstance(predictions, torch.Tensor):
92+
predictions = predictions.numpy()
93+
unscaled = self.y_scaler.inverse_transform(predictions)
94+
return torch.Tensor(unscaled)

tsa/eval.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from tqdm import tqdm
88

99

10-
def evaluate(test_iter, criterion, model, config):
10+
def evaluate(test_iter, criterion, model, config, ts):
1111
"""
1212
Evaluate the model on the given test set.
1313
@@ -39,18 +39,21 @@ def evaluate(test_iter, criterion, model, config):
3939
targets.append(target.squeeze(1).cpu())
4040
attentions.append(att.cpu())
4141

42+
predictions, targets = torch.cat(predictions), torch.cat(targets)
4243

4344
if config['do_eval']:
45+
preds, targets = ts.invert_scale(predictions), ts.invert_scale(targets)
46+
4447
plt.figure()
45-
plt.plot(torch.cat(predictions), linewidth=.3)
46-
plt.plot(torch.cat(targets), linewidth=.3)
48+
plt.plot(preds, linewidth=.3)
49+
plt.plot(targets, linewidth=.3)
4750
plt.savefig("{}/preds.png".format(config["output_dir"]))
4851

49-
torch.save(torch.cat(targets), os.path.join(config['output_dir'], "targets.pt"))
50-
torch.save(torch.cat(predictions), os.path.join(config['output_dir'], "predictions.pt"))
51-
torch.save(torch.cat(attentions), os.path.join(config['output_dir'], "attentions.pt"))
52+
torch.save(targets, os.path.join(config['output_dir'], "targets.pt"))
53+
torch.save(predictions, os.path.join(config['output_dir'], "predictions.pt"))
54+
torch.save(attentions, os.path.join(config['output_dir'], "attentions.pt"))
5255

53-
results = get_eval_report(eval_loss / len(test_iter), torch.cat(predictions), torch.cat(targets))
56+
results = get_eval_report(eval_loss / len(test_iter), predictions, targets)
5457
file_eval = os.path.join(config['output_dir'], "eval_results.txt")
5558
with open(file_eval, "w") as f:
5659
f.write("********* EVAL REPORT ********\n")

tsa/main.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,11 @@
33
import argparse
44
import pandas as pd
55
import torch.nn as nn
6-
from config import config
7-
from dataset import TimeSeriesDataset
8-
from model import AutoEncForecast
9-
from train import train
10-
from eval import evaluate
6+
from .config import config
7+
from .dataset import TimeSeriesDataset
8+
from .model import AutoEncForecast
9+
from .train import train
10+
from .eval import evaluate
1111

1212

1313
def parse_args():
@@ -40,11 +40,11 @@ def parse_args():
4040
help="activate/deactivate L2 regularization")
4141
parser.add_argument("--denoising", default=config["denoising"], type=lambda x: (str(x).lower() == "true"),
4242
help="whether or not to use a denoising autoencoder")
43-
parser.add_argument("--do-train", default=False, type=lambda x: (str(x).lower() == "true"),
43+
parser.add_argument("--do-train", default=True, type=lambda x: (str(x).lower() == "true"),
4444
help="whether or not to train the model")
45-
parser.add_argument("--do-eval", default=False, type=lambda x: (str(x).lower() == "true"),
45+
parser.add_argument("--do-eval", default=True, type=lambda x: (str(x).lower() == "true"),
4646
help="whether or not evaluating the mode")
47-
parser.add_argument("--data-path", default="data.csv", help="path to data file")
47+
parser.add_argument("--data-path", default='nflx.csv', help="path to data file")
4848
parser.add_argument("--output-dir", default=config["output_dir"], help="name of folder to output files")
4949
parser.add_argument("--ckpt", default=None, help="checkpoint path for evaluation")
5050
return parser.parse_args()
@@ -91,9 +91,9 @@ def run(args):
9191

9292
if config["do_eval"] and config["ckpt"]:
9393
model, _, loss, epoch = load_checkpoint(config["ckpt"], model, optimizer, config["device"])
94-
evaluate(test_iter, loss, model, config)
94+
evaluate(test_iter, loss, model, config, ts)
9595
elif config["do_train"]:
96-
train(train_iter, test_iter, model, criterion, optimizer, config)
96+
train(train_iter, test_iter, model, criterion, optimizer, config, ts)
9797

9898

9999
if __name__ == "__main__":

tsa/train.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from .eval import evaluate
99

1010

11-
def train(train_iter, test_iter, model, criterion, optimizer, config):
11+
def train(train_iter, test_iter, model, criterion, optimizer, config, ts):
1212
"""
1313
Training function.
1414
@@ -65,7 +65,7 @@ def train(train_iter, test_iter, model, criterion, optimizer, config):
6565

6666
if global_step % config['logging_steps'] == 0:
6767
if config['eval_during_training']:
68-
results = evaluate(test_iter, criterion, model, config)
68+
results = evaluate(test_iter, criterion, model, config, ts)
6969
for key, val in results.items():
7070
tb_writer_test.add_scalar("eval_{}".format(key), val, global_step)
7171

0 commit comments

Comments
 (0)