
Commit b0c1560

Merge pull request #9 from apaszke/master
A cleaner solution to grad sharing problem
2 parents: 15dd5e5 + 3e4b32b


3 files changed: +12 -22 lines changed


.gitignore

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+__pycache__
+*.pyc

model.py

Lines changed: 0 additions & 10 deletions
@@ -59,16 +59,6 @@ def __init__(self, num_inputs, action_space):
         self.lstm.bias_hh.data.fill_(0)
 
         self.train()
-        self.__dummy_backprob()
-
-    def __dummy_backprob(self):
-        # See: https://discuss.pytorch.org/t/problem-on-variable-grad-data/957/7
-        # An ugly hack until there is a better solution.
-        inputs = Variable(torch.randn(1, 1, 42, 42))
-        hx, cx = Variable(torch.randn(1, 256)), Variable(torch.randn(1, 256))
-        outputs = self((inputs, (hx, cx)))
-        loss = (outputs[0].mean() + outputs[1].mean()) * 0.0
-        loss.backward()
 
     def forward(self, inputs):
         inputs, (hx, cx) = inputs
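
The deleted __dummy_backprob method existed only to force PyTorch to allocate every parameter's .grad buffer: it ran a backward pass on a loss scaled by 0.0, so each gradient tensor got created (filled with zeros), and the old shared_param.grad.data assignment in train.py, also removed in this commit, would not hit a None gradient. A minimal sketch of that allocation behaviour, using current PyTorch and a hypothetical stand-in module instead of the repo's Variable-based ActorCritic:

import torch
import torch.nn as nn

# .grad buffers do not exist until a backward pass has run.
layer = nn.Linear(4, 2)
print(layer.weight.grad)        # None -- nothing allocated yet

# The removed hack forced allocation with a zero-scaled loss, purely so that
# later code could safely touch .grad on every parameter.
loss = layer(torch.randn(1, 4)).sum() * 0.0
loss.backward()
print(layer.weight.grad.shape)  # torch.Size([2, 4]) -- zero-filled buffer now exists

With the ensure_shared_grads helper introduced in train.py below, the gradients are attached lazily on the worker side, so the model no longer needs this pre-allocation step.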

train.py

Lines changed: 10 additions & 12 deletions
@@ -11,6 +11,13 @@
 from torchvision import datasets, transforms
 
 
+def ensure_shared_grads(model, shared_model):
+    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
+        if shared_param.grad is not None:
+            return
+        shared_param._grad = param.grad
+
+
 def train(rank, args, shared_model):
     torch.manual_seed(args.seed + rank)
 
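The new helper shares gradients lazily: the first time a worker reaches the update step, the shared model's gradients are still None in that process, so the loop installs the worker's own gradient tensors through the private _grad attribute; on later iterations the early return leaves that aliasing in place, and autograd keeps accumulating into the same local buffers. A single-process sketch of the aliasing, with toy nn.Linear modules standing in for ActorCritic (names here are illustrative):

import torch
import torch.nn as nn

def ensure_shared_grads(model, shared_model):
    # Same logic as the function added above.
    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
        if shared_param.grad is not None:
            return
        shared_param._grad = param.grad

local_model = nn.Linear(3, 1)    # worker's copy
shared_model = nn.Linear(3, 1)   # globally shared parameters

local_model(torch.randn(2, 3)).sum().backward()   # populate the worker's .grad buffers
ensure_shared_grads(local_model, shared_model)

# The shared parameters now reference the worker's gradient tensors, so an
# optimizer built over shared_model.parameters() steps with the worker's grads.
print(shared_model.weight.grad is local_model.weight.grad)   # True
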
@@ -19,10 +26,6 @@ def train(rank, args, shared_model):
 
     model = ActorCritic(env.observation_space.shape[0], env.action_space)
 
-    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
-        # Use gradients from the local model
-        shared_param.grad.data = param.grad.data
-
     optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)
 
     model.train()
@@ -102,14 +105,9 @@ def train(rank, args, shared_model):
             log_probs[i] * Variable(gae) - 0.01 * entropies[i]
 
         optimizer.zero_grad()
+
         (policy_loss + 0.5 * value_loss).backward()
+        torch.nn.utils.clip_grad_norm(model.parameters(), 40)
 
-        global_norm = 0
-        for param in model.parameters():
-            global_norm += param.grad.data.pow(2).sum()
-        global_norm = math.sqrt(global_norm)
-        ratio = 40 / global_norm
-        if ratio < 1:
-            for param in model.parameters():
-                param.grad.data.mul_(ratio)
+        ensure_shared_grads(model, shared_model)
         optimizer.step()
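
The hand-rolled clipping that was removed is the same computation torch.nn.utils.clip_grad_norm performs (renamed clip_grad_norm_ in later PyTorch releases): take the global L2 norm over all parameter gradients and rescale them in place when it exceeds the threshold, 40 here. A rough sketch of the equivalence under current PyTorch, with an nn.Linear standing in for the A3C model:

import torch
import torch.nn as nn

model = nn.Linear(8, 2)                     # stand-in for ActorCritic
model(torch.randn(4, 8)).sum().backward()   # populate gradients

# What the deleted loop computed by hand: a global L2 norm over all grads,
# followed by an in-place rescale whenever it exceeds the limit.
max_norm = 40.0
total_norm = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
if total_norm > max_norm:
    for p in model.parameters():
        p.grad.mul_(max_norm / total_norm)

# The library call used in the new code (modern, in-place spelling):
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)

The library version additionally guards the division with a small epsilon, but the clipping behaviour is otherwise the same as the removed loop.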
