Skip to content

Commit 8e0f4df

Browse files
committed
chore: use hydra config parser
1 parent 3e55e09 commit 8e0f4df

File tree

10 files changed

+183
-115
lines changed

10 files changed

+183
-115
lines changed

examples/forecasting/config.yaml

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,35 @@
11
data:
2+
_target_: tsa.dataset.TimeSeriesDataset
3+
batch_size: 16
24
categorical_cols: [ ]
3-
label_col: ["AH"]
45
index_col: "Date_Time"
6+
target_col: [ "AH" ]
7+
data_path: "../data/AirQualityUCI.csv"
58
prediction_window: 1
6-
seq_len: 3
9+
seq_length: 3
10+
task:
11+
_target_: tsa.dataset.Tasks
12+
value: prediction
713

814
training:
9-
num_epochs: 100
10-
batch_size: 16
15+
batch_size: ${data.batch_size}
16+
denoising: False
17+
directions: 1
18+
gradient_accumulation_steps: 1
19+
hidden_size_encoder: 64
20+
hidden_size_decoder: 64
21+
input_att: True
1122
lr: 1e-5
23+
lrs_step_size: 5000
24+
max_grad_norm: 0.1
25+
num_epochs: 100
26+
output_size: 1
1227
reg1: True
1328
reg2: False
1429
reg_factor1: 1e-4
1530
reg_factor2: 1e-4
16-
hidden_size_encoder: 64
17-
hidden_size_decoder: 64
18-
output_size: 1
19-
input_att: True
31+
seq_len: ${data.seq_length}
2032
temporal_att: True
21-
denoising: False
22-
directions: 1
23-
max_grad_norm: 0.1
24-
gradient_accumulation_steps: 1
25-
lrs_step_size: 5000
2633

2734
general:
2835
do_eval: True

examples/forecasting/run_forecasting.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,28 @@
11
import hydra
2-
import pandas as pd
32
import torch
43
import torch.nn as nn
4+
from hydra.utils import instantiate
55

6-
from tsa import TimeSeriesDataset, AutoEncForecast, train, evaluate
6+
from tsa import AutoEncForecast, train, evaluate
77
from tsa.utils import load_checkpoint
88

99
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1010

1111

1212
@hydra.main(config_path="./", config_name="config")
13-
def run(args):
14-
df = pd.read_csv("data/AirQualityUCI.csv", index_col=args.data.index_col)
15-
16-
ts = TimeSeriesDataset(
17-
data=df,
18-
categorical_cols=args.data.categorical_cols,
19-
target_col=args.data.label_col,
20-
seq_length=args.data.seq_len,
21-
prediction_window=args.data.prediction_window
22-
)
23-
train_iter, test_iter, nb_features = ts.get_loaders(batch_size=args.data.batch_size)
24-
25-
model = AutoEncForecast(args.training, input_size=nb_features).to(device)
13+
def run(cfg):
14+
ts = instantiate(cfg.data)
15+
train_iter, test_iter, nb_features = ts.get_loaders()
16+
17+
model = AutoEncForecast(cfg.training, input_size=nb_features).to(device)
2618
criterion = nn.MSELoss()
27-
optimizer = torch.optim.Adam(model.parameters(), lr=args.training.lr)
19+
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.training.lr)
2820

29-
if args.general.do_eval and args.general.ckpt:
30-
model, _, loss, epoch = load_checkpoint(args.general.ckpt, model, optimizer, device)
31-
evaluate(test_iter, loss, model, args, ts)
32-
elif args.general.do_train:
33-
train(train_iter, test_iter, model, criterion, optimizer, args, ts)
21+
if cfg.general.do_train:
22+
train(train_iter, test_iter, model, criterion, optimizer, cfg, ts)
23+
if cfg.general.do_eval and cfg.general.get("ckpt", False):
24+
model, _, loss, epoch = load_checkpoint(cfg.general.ckpt, model, optimizer, device)
25+
evaluate(test_iter, loss, model, cfg, ts)
3426

3527

3628
if __name__ == "__main__":

examples/reconstruction/config.yaml

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,34 @@
11
data:
2+
_target_: tsa.dataset.TimeSeriesDataset
3+
batch_size: 16
24
categorical_cols: [ ]
3-
label_col: []
45
index_col: "Date_Time"
6+
target_col: [ ]
7+
data_path: "../data/AirQualityUCI.csv"
58
prediction_window: 1
6-
seq_len: 3
7-
batch_size: 16
9+
seq_length: 3
10+
task:
11+
_target_: tsa.dataset.Tasks
12+
value: reconstruction
813

914
training:
10-
num_epochs: 100
15+
denoising: False
16+
directions: 1
17+
gradient_accumulation_steps: 1
18+
hidden_size_encoder: 64
19+
hidden_size_decoder: 64
20+
input_att: True
1121
lr: 1e-5
22+
lrs_step_size: 5000
23+
max_grad_norm: 0.1
24+
num_epochs: 100
25+
output_size: 13
1226
reg1: True
1327
reg2: False
1428
reg_factor1: 1e-4
1529
reg_factor2: 1e-4
16-
hidden_size_encoder: 64
17-
hidden_size_decoder: 64
18-
input_att: True
30+
seq_len: ${data.seq_length}
1931
temporal_att: True
20-
output_size: 1
21-
denoising: False
22-
max_grad_norm: 0.1
23-
lrs_step_size: 5000
24-
gradient_accumulation_steps: 1
25-
directions: 1
2632

2733
general:
2834
do_eval: True

examples/reconstruction/run_reconstruction.py

Lines changed: 13 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,28 @@
11
import hydra
2-
import pandas as pd
32
import torch
43
import torch.nn as nn
4+
from hydra.utils import instantiate
55

6+
from tsa import AutoEncForecast, train, evaluate
67
from tsa.utils import load_checkpoint
7-
from tsa import TimeSeriesDataset, AutoEncForecast, train, evaluate
88

99
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
1010

1111

1212
@hydra.main(config_path="./", config_name="config")
13-
def run(args):
14-
df = pd.read_csv("data/AirQualityUCI.csv", index_col=args.data.index_col)
15-
16-
ts = TimeSeriesDataset(
17-
data=df,
18-
categorical_cols=args.data.categorical_cols,
19-
target_col=args.data.label_col,
20-
seq_length=args.data.seq_len,
21-
prediction_window=args.data.prediction_window
22-
)
23-
train_iter, test_iter, nb_features = ts.get_loaders(batch_size=args.data.batch_size)
24-
25-
model = AutoEncForecast(args, input_size=nb_features).to(device)
13+
def run(cfg):
14+
ts = instantiate(cfg.data)
15+
train_iter, test_iter, nb_features = ts.get_loaders()
16+
17+
model = AutoEncForecast(cfg.training, input_size=nb_features).to(device)
2618
criterion = nn.MSELoss()
27-
optimizer = torch.optim.Adam(model.parameters(), lr=args.training.lr)
19+
optimizer = torch.optim.Adam(model.parameters(), lr=cfg.training.lr)
2820

29-
if args.general.do_eval and args.general.ckpt:
30-
model, _, loss, epoch = load_checkpoint(args.general.ckpt, model, optimizer, device)
31-
evaluate(test_iter, loss, model, args, ts)
32-
elif args.general.do_train:
33-
train(train_iter, test_iter, model, criterion, optimizer, args, ts)
21+
if cfg.general.do_eval and cfg.general.get("ckpt", False):
22+
model, _, loss, epoch = load_checkpoint(cfg.general.ckpt, model, optimizer, device)
23+
evaluate(test_iter, loss, model, cfg, ts)
24+
elif cfg.general.do_train:
25+
train(train_iter, test_iter, model, criterion, optimizer, cfg, ts)
3426

3527

3628
if __name__ == "__main__":

poetry.lock

Lines changed: 47 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ sklearn = "^0.0"
1212
matplotlib = "^3.3.4"
1313
tensorboardX = "^2.1"
1414
tqdm = "^4.59.0"
15-
hydra-core = "^1.1.2"
15+
hydra-core = "^1.2.0"
1616

1717
[tool.poetry.dev-dependencies]
1818

tsa/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from .config import config
21
from .dataset import TimeSeriesDataset
32
from .eval import evaluate
43
from .train import train

0 commit comments

Comments
 (0)