@@ -1,72 +1,29 @@
+import hydra
 import torch
-import argparse
-import pandas as pd
 import torch.nn as nn
-from tsa import TimeSeriesDataset, AutoEncForecast, train, evaluate
-from .config_forecasting import config
+from hydra.utils import instantiate
 
-device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
-def parse_args():
-    """
-    Parse command line arguments.
+from tsa import AutoEncForecast, train, evaluate
+from tsa.utils import load_checkpoint
 
-    Args:
-    """
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--batch-size", default=config["batch_size"], type=int, help="batch size")
-    parser.add_argument("--output-size", default=config["output_size"], type=int,
-                        help="size of the ouput: default value to 1 for forecasting")
-    parser.add_argument("--label-col", default=config["label_col"], type=str, help="name of the target column")
-    parser.add_argument("--input-att", default=config["input_att"], type=lambda x: (str(x).lower() == "true"),
-                        help="whether or not activate the input attention mechanism")
-    parser.add_argument("--temporal-att", default=config["temporal_att"], type=lambda x: (str(x).lower() == "true"),
-                        help="whether or not activate the temporal attention mechanism")
-    parser.add_argument("--seq-len", default=config["seq_len"], type=int, help="window length to use for forecasting")
-    parser.add_argument("--hidden-size-encoder", default=config["hidden_size_encoder"], type=int,
-                        help="size of the encoder's hidden states")
-    parser.add_argument("--hidden-size-decoder", default=config["hidden_size_decoder"], type=int,
-                        help="size of the decoder's hidden states")
-    parser.add_argument("--reg-factor1", default=config["reg_factor1"], type=float,
-                        help="contribution factor of the L1 regularization if using a sparse autoencoder")
-    parser.add_argument("--reg-factor2", default=config["reg_factor2"], type=float,
-                        help="contribution factor of the L2 regularization if using a sparse autoencoder")
-    parser.add_argument("--reg1", default=config["reg1"], type=lambda x: (str(x).lower() == "true"),
-                        help="activate/deactivate L1 regularization")
-    parser.add_argument("--reg2", default=config["reg2"], type=lambda x: (str(x).lower() == "true"),
-                        help="activate/deactivate L2 regularization")
-    parser.add_argument("--denoising", default=config["denoising"], type=lambda x: (str(x).lower() == "true"),
-                        help="whether or not to use a denoising autoencoder")
-    parser.add_argument("--do-train", default=True, type=lambda x: (str(x).lower() == "true"),
-                        help="whether or not to train the model")
-    parser.add_argument("--do-eval", default=False, type=lambda x: (str(x).lower() == "true"),
-                        help="whether or not evaluating the mode")
-    parser.add_argument("--output-dir", default=config["output_dir"], help="name of folder to output files")
-    parser.add_argument("--ckpt", default=None, help="checkpoint path for evaluation")
-    return parser.parse_args()
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
-if __name__ == "__main__":
-    args = vars(parse_args())
-    config.update(args)
+@hydra.main(config_path="./", config_name="config")
+def run(cfg):
+    ts = instantiate(cfg.data)
+    train_iter, test_iter, nb_features = ts.get_loaders()
 
-    df = pd.read_csv("data/AirQualityUCI.csv", index_col=config["index_col"])
+    model = AutoEncForecast(cfg.training, input_size=nb_features).to(device)
+    criterion = nn.MSELoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=cfg.training.lr)
 
-    ts = TimeSeriesDataset(
-        data=df,
-        categorical_cols=config["categorical_cols"],
-        target_col=config["label_col"],
-        seq_length=config["seq_len"],
-        prediction_window=config["prediction_window"]
-    )
-    train_iter, test_iter, nb_features = ts.get_loaders(batch_size=config["batch_size"])
+    if cfg.general.do_train:
+        train(train_iter, test_iter, model, criterion, optimizer, cfg, ts)
+    if cfg.general.do_eval and cfg.general.get("ckpt", False):
+        model, _, loss, epoch = load_checkpoint(cfg.general.ckpt, model, optimizer, device)
+        evaluate(test_iter, loss, model, cfg, ts)
 
-    model = AutoEncForecast(config, input_size=nb_features).to(config["device"])
-    criterion = nn.MSELoss()
-    optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])
 
-    if config["do_eval"] and config["ckpt"]:
-        model, _, loss, epoch = load_checkpoint(config["ckpt"], model, optimizer, config["device"])
-        evaluate(test_iter, loss, model, config)
-    elif config["do_train"]:
-        train(train_iter, test_iter, model, criterion, optimizer, config)
+if __name__ == "__main__":
+    run()
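
The new entry point reads a Hydra config named config.yaml from the script's directory (config_path="./", config_name="config"). That file is not part of this diff, so the sketch below is only an inferred example: the group names (data, training, general) and the fields training.lr, general.do_train, general.do_eval and general.ckpt are taken from the script above, while the remaining keys are hypothetical placeholders mirroring the removed argparse defaults.

# Hypothetical config.yaml, inferred from the script above.
# Only the keys the script reads directly are certain; the rest are assumptions.
data:
  _target_: tsa.TimeSeriesDataset    # instantiate(cfg.data) needs a _target_ class path
  data_path: data/AirQualityUCI.csv  # assumed: the dataset now loads its own CSV
  target_col: target                 # placeholder column name
  seq_length: 10                     # placeholders mirroring the old constructor arguments
  prediction_window: 1
  batch_size: 16                     # assumed, since get_loaders() now takes no arguments
training:
  lr: 0.001                          # read by torch.optim.Adam
  hidden_size_encoder: 128           # assumed: AutoEncForecast(cfg.training, ...) reads these
  hidden_size_decoder: 128
general:
  do_train: true                     # read by the if-blocks in run()
  do_eval: false
  ckpt: null                         # checkpoint path, guarded by cfg.general.get("ckpt", False)
  output_dir: outputs                # assumed, mirrors the removed --output-dir flag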
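
Any of these keys can then be overridden from the command line with Hydra's dotted syntax, which replaces the removed argparse flags. Assuming the script is saved as train.py:

python train.py training.lr=0.001 general.do_eval=true general.ckpt=/path/to/checkpoint.pt

One caveat: Hydra versions of this vintage change the working directory to a per-run output folder by default, so relative paths in the config (such as a CSV path) may need to be resolved with hydra.utils.to_absolute_path inside the dataset code.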