Commit 9c509c5

Merge pull request #25 from JulesBelveze/feat/hydra-parser

feat: use hydra config parser

2 parents 889fbdd + b25c02c · commit 9c509c5

File tree

16 files changed (+311, -505 lines changed)

README.md

Lines changed: 29 additions & 21 deletions
@@ -6,35 +6,40 @@
 </p>
 
 This repository contains an autoencoder for multivariate time series forecasting.
-It features two attention mechanisms described in *[A Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction](https://arxiv.org/abs/1704.02971)* and was inspired by [Seanny123's repository](https://github.com/Seanny123/da-rnn).
+It features two attention mechanisms described
+in *[A Dual-Stage Attention-Based Recurrent Neural Network for Time Series Prediction](https://arxiv.org/abs/1704.02971)*
+and was inspired by [Seanny123's repository](https://github.com/Seanny123/da-rnn).
 
 ![Autoencoder architecture](autoenc_architecture.png)
+
 ## Download and dependencies
+
 To clone the repository please run:
+
 ```
 git clone https://github.com/JulesBelveze/time-series-autoencoder.git
 ```
 
 To install all the required dependencies please run:
+
 ```
 python3 -m venv .venv/tsa
 source .venv/tsa/bin/activate
 poetry install
 ```
 
 ## Usage
+
+The project uses [Hydra](https://hydra.cc/docs/intro/) as a configuration parser. You can simply change the parameters
+directly within your `.yaml` file, or you can override/set parameters using flags (for a complete guide please refer to
+the docs).
+
 ```
-python main.py [-h] [--batch-size BATCH_SIZE] [--output-size OUTPUT_SIZE]
-               [--label-col LABEL_COL] [--input-att INPUT_ATT]
-               [--temporal-att TEMPORAL_ATT] [--seq-len SEQ_LEN]
-               [--hidden-size-encoder HIDDEN_SIZE_ENCODER]
-               [--hidden-size-decoder HIDDEN_SIZE_DECODER]
-               [--reg-factor1 REG_FACTOR1] [--reg-factor2 REG_FACTOR2]
-               [--reg1 REG1] [--reg2 REG2] [--denoising DENOISING]
-               [--do-train DO_TRAIN] [--do-eval DO_EVAL]
-               [--data-path DATA_PATH] [--output-dir OUTPUT_DIR] [--ckpt CKPT]
+python3 main.py -cp=[PATH_TO_CONFIG_FOLDER] -cn=[CONFIG_NAME]
 ```
+
 Optional arguments:
+
 ```
 -h, --help            show this help message and exit
 --batch-size BATCH_SIZE
@@ -71,16 +76,19 @@ Optional arguments:
                       name of folder to output files
 --ckpt CKPT           checkpoint path for evaluation
 ```
-
-## Features
-* handles multivariate time series
-* attention mechanisms
-* denoising autoencoder
-* sparse autoencoder
-
-## Examples
-The `examples` folder contains scripts to train the model for both tasks:
-* reconstruction: the dataset can be found [here](https://gist.github.com/JulesBelveze/99ecdbea62f81ce647b131e7badbb24a)
-* forecasting: the dataset can be found [here](https://gist.github.com/JulesBelveze/e9997b9b0b68101029b461baf698bd72)
+
+## Features
+
+* handles multivariate time series
+* attention mechanisms
+* denoising autoencoder
+* sparse autoencoder
+
+## Examples
+
+The `examples` folder contains scripts to train the model for both tasks:
+
+* reconstruction: the dataset can be found [here](https://gist.github.com/JulesBelveze/99ecdbea62f81ce647b131e7badbb24a)
+* forecasting: the dataset can be found [here](https://gist.github.com/JulesBelveze/e9997b9b0b68101029b461baf698bd72)

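Since the new usage section leans on Hydra's override grammar, a concrete invocation may help. The following is illustrative only: the override keys come from the example configs added in this commit, while the config path and values are assumptions made for the example.

```
python3 main.py -cp=examples/forecasting -cn=config training.num_epochs=10 data.batch_size=32
```

Keys already present in the YAML are overridden with plain `key=value`; a key absent from the config must be appended with a leading `+`, per Hydra's standard override grammar.
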
examples/forecasting/config.yaml

Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
data:
  _target_: tsa.dataset.TimeSeriesDataset
  batch_size: 16
  categorical_cols: [ ]
  index_col: "Date_Time"
  target_col: [ "AH" ]
  data_path: "../data/AirQualityUCI.csv"
  prediction_window: 1
  seq_length: 3
  task:
    _target_: tsa.dataset.Tasks
    value: prediction

training:
  batch_size: ${data.batch_size}
  denoising: False
  directions: 1
  gradient_accumulation_steps: 1
  hidden_size_encoder: 64
  hidden_size_decoder: 64
  input_att: True
  lr: 1e-5
  lrs_step_size: 5000
  max_grad_norm: 0.1
  num_epochs: 100
  output_size: 1
  reg1: True
  reg2: False
  reg_factor1: 1e-4
  reg_factor2: 1e-4
  seq_len: ${data.seq_length}
  temporal_att: True

general:
  do_eval: True
  do_train: True
  logging_steps: 100

  output_dir: "output"
  save_steps: 5000
  eval_during_training: True
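
The `_target_` keys make these nodes directly instantiable. As a minimal sketch of what the new entry point does with the `data` node (assuming `tsa` is installed and the snippet is run from the repository root; the snippet itself is not part of the commit):

```python
# Hydra imports the dotted _target_ path and calls it with the remaining keys,
# recursing into the nested `task` node, which carries its own _target_.
from hydra.utils import instantiate
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/forecasting/config.yaml")
ts = instantiate(cfg.data)  # ~ tsa.dataset.TimeSeriesDataset(batch_size=16, ..., task=Tasks("prediction"))
train_iter, test_iter, nb_features = ts.get_loaders()
```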

examples/forecasting/config_forecasting.py

Lines changed: 0 additions & 60 deletions
This file was deleted.

examples/forecasting/main.py

Lines changed: 19 additions & 62 deletions
@@ -1,72 +1,29 @@
+import hydra
 import torch
-import argparse
-import pandas as pd
 import torch.nn as nn
-from tsa import TimeSeriesDataset, AutoEncForecast, train, evaluate
-from .config_forecasting import config
+from hydra.utils import instantiate
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-def parse_args():
-    """
-    Parse command line arguments.
+from tsa import AutoEncForecast, train, evaluate
+from tsa.utils import load_checkpoint
 
-    Args:
-    """
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--batch-size", default=config["batch_size"], type=int, help="batch size")
-    parser.add_argument("--output-size", default=config["output_size"], type=int,
-                        help="size of the output: default value to 1 for forecasting")
-    parser.add_argument("--label-col", default=config["label_col"], type=str, help="name of the target column")
-    parser.add_argument("--input-att", default=config["input_att"], type=lambda x: (str(x).lower() == "true"),
-                        help="whether or not to activate the input attention mechanism")
-    parser.add_argument("--temporal-att", default=config["temporal_att"], type=lambda x: (str(x).lower() == "true"),
-                        help="whether or not to activate the temporal attention mechanism")
-    parser.add_argument("--seq-len", default=config["seq_len"], type=int, help="window length to use for forecasting")
-    parser.add_argument("--hidden-size-encoder", default=config["hidden_size_encoder"], type=int,
-                        help="size of the encoder's hidden states")
-    parser.add_argument("--hidden-size-decoder", default=config["hidden_size_decoder"], type=int,
-                        help="size of the decoder's hidden states")
-    parser.add_argument("--reg-factor1", default=config["reg_factor1"], type=float,
-                        help="contribution factor of the L1 regularization if using a sparse autoencoder")
-    parser.add_argument("--reg-factor2", default=config["reg_factor2"], type=float,
-                        help="contribution factor of the L2 regularization if using a sparse autoencoder")
-    parser.add_argument("--reg1", default=config["reg1"], type=lambda x: (str(x).lower() == "true"),
-                        help="activate/deactivate L1 regularization")
-    parser.add_argument("--reg2", default=config["reg2"], type=lambda x: (str(x).lower() == "true"),
-                        help="activate/deactivate L2 regularization")
-    parser.add_argument("--denoising", default=config["denoising"], type=lambda x: (str(x).lower() == "true"),
-                        help="whether or not to use a denoising autoencoder")
-    parser.add_argument("--do-train", default=True, type=lambda x: (str(x).lower() == "true"),
-                        help="whether or not to train the model")
-    parser.add_argument("--do-eval", default=False, type=lambda x: (str(x).lower() == "true"),
-                        help="whether or not to evaluate the model")
-    parser.add_argument("--output-dir", default=config["output_dir"], help="name of folder to output files")
-    parser.add_argument("--ckpt", default=None, help="checkpoint path for evaluation")
-    return parser.parse_args()
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
-if __name__ == "__main__":
-    args = vars(parse_args())
-    config.update(args)
+@hydra.main(config_path="./", config_name="config")
+def run(cfg):
+    ts = instantiate(cfg.data)
+    train_iter, test_iter, nb_features = ts.get_loaders()
 
-    df = pd.read_csv("data/AirQualityUCI.csv", index_col=config["index_col"])
+    model = AutoEncForecast(cfg.training, input_size=nb_features).to(device)
+    criterion = nn.MSELoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.training.lr)
 
-    ts = TimeSeriesDataset(
-        data=df,
-        categorical_cols=config["categorical_cols"],
-        target_col=config["label_col"],
-        seq_length=config["seq_len"],
-        prediction_window=config["prediction_window"]
-    )
-    train_iter, test_iter, nb_features = ts.get_loaders(batch_size=config["batch_size"])
+    if cfg.general.do_train:
+        train(train_iter, test_iter, model, criterion, optimizer, cfg, ts)
+    if cfg.general.do_eval and cfg.general.get("ckpt", False):
+        model, _, loss, epoch = load_checkpoint(cfg.general.ckpt, model, optimizer, device)
+        evaluate(test_iter, loss, model, cfg, ts)
 
-    model = AutoEncForecast(config, input_size=nb_features).to(config["device"])
-    criterion = nn.MSELoss()
-    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
 
-    if config["do_eval"] and config["ckpt"]:
-        model, _, loss, epoch = load_checkpoint(config["ckpt"], model, optimizer, config["device"])
-        evaluate(test_iter, loss, model, config)
-    elif config["do_train"]:
-        train(train_iter, test_iter, model, criterion, optimizer, config)
+if __name__ == "__main__":
+    run()
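
The new entry point trains whenever `general.do_train` is set and only evaluates when a `ckpt` key exists under `general`. Neither example config defines `ckpt`, so an evaluation-only run would append it from the command line; a hypothetical invocation (the checkpoint path is made up for illustration):

```
python3 main.py general.do_train=False +general.ckpt=output/checkpoint.pt
```

The leading `+` is Hydra's syntax for adding a key that is not present in the loaded config.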

examples/reconstruction/config.yaml

Lines changed: 40 additions & 0 deletions

@@ -0,0 +1,40 @@
data:
  _target_: tsa.dataset.TimeSeriesDataset
  batch_size: 16
  categorical_cols: [ ]
  index_col: "Date_Time"
  target_col: [ ]
  data_path: "../data/AirQualityUCI.csv"
  prediction_window: 1
  seq_length: 3
  task:
    _target_: tsa.dataset.Tasks
    value: reconstruction

training:
  denoising: False
  directions: 1
  gradient_accumulation_steps: 1
  hidden_size_encoder: 64
  hidden_size_decoder: 64
  input_att: True
  lr: 1e-5
  lrs_step_size: 5000
  max_grad_norm: 0.1
  num_epochs: 100
  output_size: 13
  reg1: True
  reg2: False
  reg_factor1: 1e-4
  reg_factor2: 1e-4
  seq_len: ${data.seq_length}
  temporal_att: True

general:
  do_eval: True
  do_train: True
  logging_steps: 100

  output_dir: "output"
  save_steps: 5000
  eval_during_training: True
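
Compared to the forecasting config, this one leaves `target_col` empty, sets the task to `reconstruction`, and widens `output_size` to 13 (presumably the number of input features, so the decoder reconstructs them all). Both files rely on OmegaConf interpolation to keep `training.seq_len` in sync with `data.seq_length`; a quick way to inspect the resolved values (illustrative snippet, not part of the commit, run from the repository root):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("examples/reconstruction/config.yaml")
assert cfg.training.seq_len == cfg.data.seq_length  # ${data.seq_length} resolves to 3
print(OmegaConf.to_yaml(cfg, resolve=True))         # dump with interpolations expanded
```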

examples/reconstruction/config_reconstruction.py

Lines changed: 0 additions & 60 deletions
This file was deleted.
