@@ -242,6 +242,8 @@ def _mtf_model_fn(self, features, mesh):
     hparams = self._hparams
     extra_losses = []
     targets = tf.to_int32(features["targets"])
+    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+    is_training = mode == tf.estimator.ModeKeys.TRAIN
     if len(targets.get_shape()) > 2:
       tf.logging.info("targets = %s" % targets)
       targets = tf.squeeze(targets, [2, 3])
@@ -289,7 +291,7 @@ def pad_to_max_length(x):
 
     def layer_prepostprocess_dropout(x):
       return mtf.dropout(
-          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
+          x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
           noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))
 
     (inputs_embedding_var,
@@ -426,10 +428,11 @@ def _feedforward_layer(self, x, layer_type, losses=None):
       ValueError: if hparams make no sense
     """
     hparams = self._hparams
-
+    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+    is_training = mode == tf.estimator.ModeKeys.TRAIN
     if layer_type == "drd":
       return mtf.layers.dense_relu_dense(
-          x, self.feedforward_dim, dropout=hparams.relu_dropout,
+          x, self.feedforward_dim, is_training, dropout=hparams.relu_dropout,
           dropout_broadcast_dims=[self.length_dim],
           master_dtype=self.master_dtype,
           slice_dtype=self.slice_dtype)
@@ -493,11 +496,13 @@ def _layer_stack(self,
493496 """
494497 hparams = self ._hparams
495498 is_incremental = (step_num is not None )
499+ mode = getattr (hparams , "mode" , tf .estimator .ModeKeys .TRAIN )
500+ is_training = mode == tf .estimator .ModeKeys .TRAIN
496501 def layer_prepostprocess_dropout (x ):
497502 if is_incremental :
498503 return x
499504 return mtf .dropout (
500- x , keep_prob = 1.0 - hparams .layer_prepostprocess_dropout ,
505+ x , is_training , keep_prob = 1.0 - hparams .layer_prepostprocess_dropout ,
501506 noise_shape = mtf .Shape (self .batch_dims + [self .model_dim ]))
502507 num_layers = len (layers )
503508 num_layer_norms = num_layers + 1
@@ -540,6 +545,7 @@ def normalize(x):
             mtf.layers.multihead_attention(
                 normalize(x), None,
                 self_attention_mask, self.kv_dim, self.heads_dim,
+                is_training,
                 dropout=hparams.attention_dropout,
                 dropout_broadcast_dims=[self.length_dim],
                 master_dtype=self.master_dtype,
@@ -560,6 +566,7 @@ def normalize(x):
             mtf.layers.multihead_attention(
                 normalize(x), encoder_output,
                 encdec_attention_mask, self.kv_dim, self.heads_dim,
+                is_training,
                 dropout=hparams.attention_dropout,
                 dropout_broadcast_dims=[self.length_dim],
                 master_dtype=self.master_dtype,
@@ -582,7 +589,7 @@ def normalize(x):
         x += layer_prepostprocess_dropout(
             mtf.layers.masked_local_attention_1d(
                 normalize(x),
-                self.kv_dim, self.heads_dim,
+                self.kv_dim, self.heads_dim, is_training,
                 window_size=hparams.local_attention_window_size,
                 master_dtype=self.master_dtype,
                 slice_dtype=self.slice_dtype,
@@ -601,6 +608,7 @@ def normalize(x):
                 compression_factor=hparams.compression_factor,
                 kv_channels=self.kv_dim,
                 heads=self.heads_dim,
+                is_training=is_training,
                 dropout=hparams.attention_dropout,
                 dropout_broadcast_dims=[self.length_dim],
                 master_dtype=self.master_dtype,
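
Each hunk above applies the same pattern: derive an is_training flag from hparams.mode, defaulting to TRAIN when the attribute is absent, and pass it explicitly into the mtf dropout and attention calls so dropout is only active during training. A minimal, self-contained sketch of that derivation (the helper name get_is_training is hypothetical and not part of this change):

import tensorflow as tf

def get_is_training(hparams):
  # Mirrors the getattr() fallback used in the hunks above: if hparams
  # carries no "mode" attribute, assume training so behavior is unchanged.
  mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
  return mode == tf.estimator.ModeKeys.TRAIN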