Skip to content

Commit da47ea4

Browse files
committed
chore: use hydra config parser
1 parent b76694a commit da47ea4

File tree

13 files changed

+304
-437
lines changed

13 files changed

+304
-437
lines changed

examples/forecasting/config.yaml

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Hydra configuration for the forecasting example.
data:
  categorical_cols: [ ]
  label_col: ["AH"]
  index_col: "Date_Time"
  prediction_window: 1
  seq_len: 3
  # The example script reads the loader batch size from `data.batch_size`.
  batch_size: 16

training:
  num_epochs: 100
  lr: 1e-5
  reg1: True
  reg2: False
  reg_factor1: 1e-4
  reg_factor2: 1e-4
  hidden_size_encoder: 64
  hidden_size_decoder: 64
  output_size: 1
  input_att: True
  temporal_att: True
  denoising: False
  directions: 1
  max_grad_norm: 0.1
  gradient_accumulation_steps: 1
  lrs_step_size: 5000

general:
  do_eval: True
  do_train: True
  logging_steps: 100

  output_dir: "output"
  save_steps: 5000
  eval_during_training: True
  # Checkpoint path. The script evaluates only when `do_eval` is true AND
  # `ckpt` is set; otherwise it trains when `do_train` is true.
  ckpt: null

examples/forecasting/config_forecasting.py

Lines changed: 0 additions & 60 deletions
This file was deleted.
Lines changed: 23 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,37 @@
1-
import torch
2-
import argparse
1+
import hydra
32
import pandas as pd
3+
import torch
44
import torch.nn as nn
5+
56
from tsa import TimeSeriesDataset, AutoEncForecast, train, evaluate
6-
from .config_forecasting import config
7+
from tsa.utils import load_checkpoint
78

89
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
910

10-
def parse_args():
11-
"""
12-
Parse command line arguments.
13-
14-
Args:
15-
"""
16-
parser = argparse.ArgumentParser()
17-
parser.add_argument("--batch-size", default=config["batch_size"], type=int, help="batch size")
18-
parser.add_argument("--output-size", default=config["output_size"], type=int,
19-
help="size of the ouput: default value to 1 for forecasting")
20-
parser.add_argument("--label-col", default=config["label_col"], type=str, help="name of the target column")
21-
parser.add_argument("--input-att", default=config["input_att"], type=lambda x: (str(x).lower() == "true"),
22-
help="whether or not activate the input attention mechanism")
23-
parser.add_argument("--temporal-att", default=config["temporal_att"], type=lambda x: (str(x).lower() == "true"),
24-
help="whether or not activate the temporal attention mechanism")
25-
parser.add_argument("--seq-len", default=config["seq_len"], type=int, help="window length to use for forecasting")
26-
parser.add_argument("--hidden-size-encoder", default=config["hidden_size_encoder"], type=int,
27-
help="size of the encoder's hidden states")
28-
parser.add_argument("--hidden-size-decoder", default=config["hidden_size_decoder"], type=int,
29-
help="size of the decoder's hidden states")
30-
parser.add_argument("--reg-factor1", default=config["reg_factor1"], type=float,
31-
help="contribution factor of the L1 regularization if using a sparse autoencoder")
32-
parser.add_argument("--reg-factor2", default=config["reg_factor2"], type=float,
33-
help="contribution factor of the L2 regularization if using a sparse autoencoder")
34-
parser.add_argument("--reg1", default=config["reg1"], type=lambda x: (str(x).lower() == "true"),
35-
help="activate/deactivate L1 regularization")
36-
parser.add_argument("--reg2", default=config["reg2"], type=lambda x: (str(x).lower() == "true"),
37-
help="activate/deactivate L2 regularization")
38-
parser.add_argument("--denoising", default=config["denoising"], type=lambda x: (str(x).lower() == "true"),
39-
help="whether or not to use a denoising autoencoder")
40-
parser.add_argument("--do-train", default=True, type=lambda x: (str(x).lower() == "true"),
41-
help="whether or not to train the model")
42-
parser.add_argument("--do-eval", default=False, type=lambda x: (str(x).lower() == "true"),
43-
help="whether or not evaluating the mode")
44-
parser.add_argument("--output-dir", default=config["output_dir"], help="name of folder to output files")
45-
parser.add_argument("--ckpt", default=None, help="checkpoint path for evaluation")
46-
return parser.parse_args()
47-
4811

49-
if __name__ == "__main__":
50-
args = vars(parse_args())
51-
config.update(args)
52-
53-
df = pd.read_csv("data/AirQualityUCI.csv", index_col=config["index_col"])
12+
@hydra.main(config_path="./", config_name="config")
def run(args):
    """
    Train or evaluate the forecasting autoencoder from a Hydra config.

    Args:
        args: configuration composed by Hydra from ``config.yaml``; the
            ``data``, ``training`` and ``general`` groups are read here.
    """
    df = pd.read_csv("data/AirQualityUCI.csv", index_col=args.data.index_col)

    # The forecasting config declares `batch_size` under `training`, while this
    # script expects it under `data`; accept either location.
    if "batch_size" in args.data:
        batch_size = args.data.batch_size
    else:
        batch_size = args.training.batch_size

    ts = TimeSeriesDataset(
        data=df,
        categorical_cols=args.data.categorical_cols,
        target_col=args.data.label_col,
        seq_length=args.data.seq_len,
        prediction_window=args.data.prediction_window
    )
    train_iter, test_iter, nb_features = ts.get_loaders(batch_size=batch_size)

    model = AutoEncForecast(args.training, input_size=nb_features).to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.training.lr)

    # `ckpt` is optional and absent from the default config, so test for the
    # key instead of reading the attribute unconditionally.
    ckpt = args.general.ckpt if "ckpt" in args.general else None

    if args.general.do_eval and ckpt:
        # NOTE(review): the `loss` restored from the checkpoint is passed where
        # `train` receives `criterion` -- confirm what `evaluate` expects.
        model, _, loss, epoch = load_checkpoint(ckpt, model, optimizer, device)
        evaluate(test_iter, loss, model, args, ts)
    elif args.general.do_train:
        train(train_iter, test_iter, model, criterion, optimizer, args, ts)


if __name__ == "__main__":
    run()
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
# Hydra configuration for the example script that follows it in this commit.
data:
  categorical_cols: [ ]
  label_col: []
  index_col: "Date_Time"
  prediction_window: 1
  seq_len: 3
  batch_size: 16

training:
  num_epochs: 100
  lr: 1e-5
  reg1: True
  reg2: False
  reg_factor1: 1e-4
  reg_factor2: 1e-4
  hidden_size_encoder: 64
  hidden_size_decoder: 64
  input_att: True
  temporal_att: True
  output_size: 1
  denoising: False
  max_grad_norm: 0.1
  lrs_step_size: 5000
  gradient_accumulation_steps: 1
  directions: 1

general:
  do_eval: True
  do_train: True
  logging_steps: 100

  output_dir: "output"
  save_steps: 5000
  eval_during_training: True
  # Checkpoint path. The script evaluates only when `do_eval` is true AND
  # `ckpt` is set; otherwise it trains when `do_train` is true.
  ckpt: null

examples/reconstruction/config_reconstruction.py

Lines changed: 0 additions & 60 deletions
This file was deleted.
Lines changed: 23 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,37 @@
1-
import torch
2-
import argparse
1+
import hydra
32
import pandas as pd
3+
import torch
44
import torch.nn as nn
5+
6+
from tsa.utils import load_checkpoint
57
from tsa import TimeSeriesDataset, AutoEncForecast, train, evaluate
6-
from .config_reconstruction import config
78

89
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
910

10-
def parse_args():
11-
"""
12-
Parse command line arguments.
13-
14-
Args:
15-
"""
16-
parser = argparse.ArgumentParser()
17-
parser.add_argument("--batch-size", default=config["batch_size"], type=int, help="batch size")
18-
parser.add_argument("--output-size", default=config["output_size"], type=int,
19-
help="size of the ouput: default value to 1 for forecasting")
20-
parser.add_argument("--label-col", default=config["label_col"], type=str, help="name of the target column")
21-
parser.add_argument("--input-att", default=config["input_att"], type=lambda x: (str(x).lower() == "true"),
22-
help="whether or not activate the input attention mechanism")
23-
parser.add_argument("--temporal-att", default=config["temporal_att"], type=lambda x: (str(x).lower() == "true"),
24-
help="whether or not activate the temporal attention mechanism")
25-
parser.add_argument("--seq-len", default=config["seq_len"], type=int, help="window length to use for forecasting")
26-
parser.add_argument("--hidden-size-encoder", default=config["hidden_size_encoder"], type=int,
27-
help="size of the encoder's hidden states")
28-
parser.add_argument("--hidden-size-decoder", default=config["hidden_size_decoder"], type=int,
29-
help="size of the decoder's hidden states")
30-
parser.add_argument("--reg-factor1", default=config["reg_factor1"], type=float,
31-
help="contribution factor of the L1 regularization if using a sparse autoencoder")
32-
parser.add_argument("--reg-factor2", default=config["reg_factor2"], type=float,
33-
help="contribution factor of the L2 regularization if using a sparse autoencoder")
34-
parser.add_argument("--reg1", default=config["reg1"], type=lambda x: (str(x).lower() == "true"),
35-
help="activate/deactivate L1 regularization")
36-
parser.add_argument("--reg2", default=config["reg2"], type=lambda x: (str(x).lower() == "true"),
37-
help="activate/deactivate L2 regularization")
38-
parser.add_argument("--denoising", default=config["denoising"], type=lambda x: (str(x).lower() == "true"),
39-
help="whether or not to use a denoising autoencoder")
40-
parser.add_argument("--do-train", default=True, type=lambda x: (str(x).lower() == "true"),
41-
help="whether or not to train the model")
42-
parser.add_argument("--do-eval", default=False, type=lambda x: (str(x).lower() == "true"),
43-
help="whether or not evaluating the mode")
44-
parser.add_argument("--output-dir", default=config["output_dir"], help="name of folder to output files")
45-
parser.add_argument("--ckpt", default=None, help="checkpoint path for evaluation")
46-
return parser.parse_args()
47-
4811

49-
if __name__ == "__main__":
50-
args = vars(parse_args())
51-
config.update(args)
52-
53-
df = pd.read_csv("data/AirQualityUCI.csv", index_col=config["index_col"])
12+
@hydra.main(config_path="./", config_name="config")
def run(args):
    """
    Train or evaluate the reconstruction autoencoder from a Hydra config.

    Args:
        args: configuration composed by Hydra from ``config.yaml``; the
            ``data``, ``training`` and ``general`` groups are read here.
    """
    df = pd.read_csv("data/AirQualityUCI.csv", index_col=args.data.index_col)

    ts = TimeSeriesDataset(
        data=df,
        categorical_cols=args.data.categorical_cols,
        target_col=args.data.label_col,
        seq_length=args.data.seq_len,
        prediction_window=args.data.prediction_window
    )
    train_iter, test_iter, nb_features = ts.get_loaders(batch_size=args.data.batch_size)

    # NOTE(review): the forecasting example passes `args.training` here while
    # this script passes the root config -- confirm which one the model expects.
    model = AutoEncForecast(args, input_size=nb_features).to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=args.training.lr)

    # `ckpt` is optional and absent from the default config, so test for the
    # key instead of reading the attribute unconditionally.
    ckpt = args.general.ckpt if "ckpt" in args.general else None

    if args.general.do_eval and ckpt:
        # NOTE(review): the `loss` restored from the checkpoint is passed where
        # `train` receives `criterion` -- confirm what `evaluate` expects.
        model, _, loss, epoch = load_checkpoint(ckpt, model, optimizer, device)
        evaluate(test_iter, loss, model, args, ts)
    elif args.general.do_train:
        train(train_iter, test_iter, model, criterion, optimizer, args, ts)


if __name__ == "__main__":
    run()

0 commit comments

Comments
 (0)