|
11 | 11 | from ignite.utils import manual_seed |
12 | 12 | from models import setup_model |
13 | 13 | from torch import nn, optim |
14 | | -from torch.optim.lr_scheduler import _LRScheduler, LambdaLR |
| 14 | +from torch.optim.lr_scheduler import LambdaLR |
15 | 15 | from trainers import setup_evaluator, setup_trainer |
16 | 16 | from utils import * |
17 | 17 | from vis import predictions_gt_images_handler |
18 | 18 |
|
| 19 | +try: |
| 20 | + from torch.optim.lr_scheduler import LRScheduler as PyTorchLRScheduler |
| 21 | +except ImportError: |
| 22 | + from torch.optim.lr_scheduler import _LRScheduler as PyTorchLRScheduler |
| 23 | + |
19 | 24 |
|
20 | 25 | def run(local_rank: int, config: Any): |
21 | 26 | # make a certain seed |
@@ -71,10 +76,10 @@ def run(local_rank: int, config: Any): |
71 | 76 | (config.output_dir / "config-lock.yaml").write_text(yaml.dump(config)) |
72 | 77 | trainer.logger = evaluator.logger = logger |
73 | 78 |
|
74 | | - if isinstance(lr_scheduler, _LRScheduler): |
| 79 | + if isinstance(lr_scheduler, PyTorchLRScheduler): |
75 | 80 | trainer.add_event_handler( |
76 | 81 | Events.ITERATION_COMPLETED, |
77 | | - lambda engine: cast(_LRScheduler, lr_scheduler).step(), |
| 82 | + lambda engine: cast(PyTorchLRScheduler, lr_scheduler).step(), |
78 | 83 | ) |
79 | 84 | elif isinstance(lr_scheduler, LRScheduler): |
80 | 85 | trainer.add_event_handler(Events.ITERATION_COMPLETED, lr_scheduler) |
|
0 commit comments