Commit 15dd5e5 (parent: abe0de2)

Fix a problem with the recent version of PyTorch

File tree

3 files changed: +11 −5 lines

README.md (1 addition, 4 deletions)

@@ -1,8 +1,5 @@
 # pytorch-a3c
 
-NEED TO USE V-0.1.9 (or lower) OF PYTORCH, AND NOT V-0.1.10 BECAUSE OF THIS ISSUE:
-https://discuss.pytorch.org/t/problem-on-variable-grad-data/957/7
-
 This is a PyTorch implementation of Asynchronous Advantage Actor Critic (A3C) from ["Asynchronous Methods for Deep Reinforcement Learning"](https://arxiv.org/pdf/1602.01783v1.pdf).
 
 This implementation is inspired by [Universe Starter Agent](https://github.com/openai/universe-starter-agent).
@@ -14,7 +11,7 @@ Contributions are very welcome. If you know how to make this code better, don't
 
 ## Usage
 ```
-python main.py --env-name "PongDeterministic-v3" --num-processes 16
+OMP_NUM_THREADS=1 python main.py --env-name "PongDeterministic-v3" --num-processes 16
 ```
 
 This code runs evaluation in a separate thread in addition to 16 processes.

main.py (0 additions, 1 deletion)

@@ -38,7 +38,6 @@
 args = parser.parse_args()
 
 torch.manual_seed(args.seed)
-torch.set_num_threads(1)
 
 env = create_atari_env(args.env_name)
 shared_model = ActorCritic(
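The removed `torch.set_num_threads(1)` call and the `OMP_NUM_THREADS=1` environment variable added to the README's launch command serve the same purpose: capping PyTorch's intra-op thread count at one per process, so that 16 parallel A3C workers do not oversubscribe the CPU with competing OpenMP thread pools. A minimal sketch of the in-code variant (not part of this commit; shown with the current PyTorch API):

```python
import torch

# Cap PyTorch's intra-op parallelism at one thread for this process.
# With many worker processes, letting each spawn a full set of OpenMP
# threads causes CPU oversubscription; one thread each runs faster.
torch.set_num_threads(1)
print(torch.get_num_threads())  # → 1
```

Setting the environment variable before launch, as the updated README does, has the advantage of taking effect before any library initializes its thread pool.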

model.py (10 additions, 0 deletions)

@@ -59,6 +59,16 @@ def __init__(self, num_inputs, action_space):
         self.lstm.bias_hh.data.fill_(0)
 
         self.train()
+        self.__dummy_backprob()
+
+    def __dummy_backprob(self):
+        # See: https://discuss.pytorch.org/t/problem-on-variable-grad-data/957/7
+        # An ugly hack until there is a better solution.
+        inputs = Variable(torch.randn(1, 1, 42, 42))
+        hx, cx = Variable(torch.randn(1, 256)), Variable(torch.randn(1, 256))
+        outputs = self((inputs, (hx, cx)))
+        loss = (outputs[0].mean() + outputs[1].mean()) * 0.0
+        loss.backward()
 
     def forward(self, inputs):
         inputs, (hx, cx) = inputs
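The `__dummy_backprob` hack works because PyTorch allocates a parameter's `.grad` buffer lazily, on the first `backward()` call; multiplying the loss by `0.0` forces that allocation without perturbing the weights, so the buffers exist before the workers start sharing gradients. A minimal sketch of the same trick on a stand-in module (modern PyTorch, where `Variable` is no longer needed; the `nn.Linear` model here is hypothetical, not the repo's `ActorCritic`):

```python
import torch
import torch.nn as nn

# Stand-in for ActorCritic (hypothetical; any nn.Module behaves the same).
model = nn.Linear(4, 2)

# Before any backward pass, no .grad buffers exist yet.
assert all(p.grad is None for p in model.parameters())

# Dummy backward pass: scaling the loss by 0.0 keeps every gradient at
# zero while still forcing PyTorch to allocate the .grad buffers.
loss = model(torch.randn(1, 4)).mean() * 0.0
loss.backward()

# Every parameter now owns a zero-filled gradient buffer.
assert all(p.grad is not None for p in model.parameters())
assert all(bool(torch.all(p.grad == 0)) for p in model.parameters())
```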
