@@ -11,6 +11,13 @@
 from torchvision import datasets, transforms
 
 
+def ensure_shared_grads(model, shared_model):
+    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
+        if shared_param.grad is not None:
+            return
+        shared_param._grad = param.grad
+
+
 def train(rank, args, shared_model):
     torch.manual_seed(args.seed + rank)
 
@@ -19,10 +26,6 @@ def train(rank, args, shared_model):
 
     model = ActorCritic(env.observation_space.shape[0], env.action_space)
 
-    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
-        # Use gradients from the local model
-        shared_param.grad.data = param.grad.data
-
     optimizer = optim.Adam(shared_model.parameters(), lr=args.lr)
 
     model.train()
@@ -102,14 +105,9 @@ def train(rank, args, shared_model):
                 log_probs[i] * Variable(gae) - 0.01 * entropies[i]
 
         optimizer.zero_grad()
+
         (policy_loss + 0.5 * value_loss).backward()
+        torch.nn.utils.clip_grad_norm(model.parameters(), 40)
 
-        global_norm = 0
-        for param in model.parameters():
-            global_norm += param.grad.data.pow(2).sum()
-        global_norm = math.sqrt(global_norm)
-        ratio = 40 / global_norm
-        if ratio < 1:
-            for param in model.parameters():
-                param.grad.data.mul_(ratio)
+        ensure_shared_grads(model, shared_model)
         optimizer.step()
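
For context, a minimal, self-contained sketch of the worker update this diff settles on: backpropagate through the local copy, clip its gradients, hand them to the shared model via ensure_shared_grads, then step an optimizer built over the shared parameters. The nn.Linear models, dummy loss, input sizes, and learning rate below are placeholders, not the repo's ActorCritic or A3C loss; the clipping helper is written with the trailing underscore used by newer PyTorch releases.

import torch
import torch.nn as nn
import torch.optim as optim


def ensure_shared_grads(model, shared_model):
    # Hand the local gradients to the shared model. If the shared grads are
    # already populated (they alias the local grad tensors after an earlier
    # update in this process), skip the copy.
    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
        if shared_param.grad is not None:
            return
        shared_param._grad = param.grad


# Placeholder models: the real code uses ActorCritic and runs one train()
# worker per process.
shared_model = nn.Linear(4, 2)
shared_model.share_memory()                              # parameters shared across A3C workers
local_model = nn.Linear(4, 2)
local_model.load_state_dict(shared_model.state_dict())   # sync the local copy before a rollout

optimizer = optim.Adam(shared_model.parameters(), lr=1e-4)

loss = local_model(torch.randn(8, 4)).pow(2).mean()      # stand-in for policy_loss + 0.5 * value_loss

optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(local_model.parameters(), 40)  # same 40-norm cap as the diff
ensure_shared_grads(local_model, shared_model)
optimizer.step()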