
Commit 5623deb
[Mesh-TF] Add is_training as an arg to mtf.dropout
PiperOrigin-RevId: 361088273
afrozenator authored and copybara-github committed Mar 5, 2021
1 parent e19130b
Showing 2 changed files with 26 additions and 6 deletions.
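
The change applies one pattern throughout both files: derive an is_training flag from hparams.mode (defaulting to TRAIN when no mode is set) and thread it into mtf.dropout and the mtf.layers attention and feed-forward helpers as a new positional argument, so dropout can be switched off outside training. Below is a minimal sketch of that pattern, mirroring the first hunk; it assumes mesh_tensorflow is importable as mtf, that x is an mtf.Tensor, and that hparams carries mode and layer_prepostprocess_dropout.

import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf


def layer_prepostprocess_dropout(x, hparams):
  """Dropout over the batch and model dims, active only in TRAIN mode."""
  batch_dim = x.shape.dims[0]
  model_dim = x.shape.dims[-1]
  # Default to TRAIN when hparams carries no mode, as the diff does.
  mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
  is_training = mode == tf.estimator.ModeKeys.TRAIN
  # After this change, is_training is the second positional argument of
  # mtf.dropout, which lets it skip dropout outside TRAIN mode.
  return mtf.dropout(
      x, is_training,
      keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
      noise_shape=mtf.Shape([batch_dim, model_dim]))

The same two-line derivation of is_training is repeated at each call site in the hunks below.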
14 changes: 13 additions & 1 deletion tensor2tensor/models/mtf_image_transformer.py
@@ -243,8 +243,10 @@ def import_to_batch_by_length(x, name):
 def layer_prepostprocess_dropout(x, hparams):
   batch_dim = x.shape.dims[0]
   model_dim = x.shape.dims[-1]
+  mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+  is_training = mode == tf.estimator.ModeKeys.TRAIN
   return mtf.dropout(
-      x,
+      x, is_training,
       keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
       noise_shape=mtf.Shape([batch_dim, model_dim]))

@@ -259,6 +261,8 @@ def local_attention1d_spatial_decoder(x, kv_dim, heads_dim,
   x = mtf.reshape(
       x, mtf.Shape([batch_dim, num_w_blocks_dim, blocks_w_dim, model_dim]))
   # [ self attention - ffn - residual + dropout] x n
+  mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+  is_training = mode == tf.estimator.ModeKeys.TRAIN
   for layer in range(hparams.num_decoder_layers):
     layer_name = "decoder_layer_%d" % layer
     with tf.variable_scope(layer_name):
@@ -268,6 +272,7 @@ def local_attention1d_spatial_decoder(x, kv_dim, heads_dim,
               mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
               kv_dim,
               heads_dim,
+              is_training,
               memory_w_dim=blocks_w_dim,
               mask_right=True,
               name="self_att"), hparams)
@@ -276,6 +281,7 @@ def local_attention1d_spatial_decoder(x, kv_dim, heads_dim,
           mtf.layers.dense_relu_dense(
               mtf.layers.layer_norm(x, model_dim, name="layer_norm_ffn"),
               feedforward_dim,
+              is_training,
               hparams.dropout,
               dropout_broadcast_dims=[length_dim]), hparams)

@@ -305,6 +311,8 @@ def local_attention2d_spatial_decoder(x, kv_dim, heads_dim,
           batch_dim, num_h_blocks_dim, num_w_blocks_dim,
           blocks_h_dim, blocks_w_dim, model_dim
       ]))
+  mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+  is_training = mode == tf.estimator.ModeKeys.TRAIN
   # Image Transformer Decoder
   # [ self attention - ffn - residual + dropout] x n
   for layer in range(hparams.num_decoder_layers):
@@ -316,6 +324,7 @@ def local_attention2d_spatial_decoder(x, kv_dim, heads_dim,
               mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
               kv_dim,
               heads_dim,
+              is_training,
               memory_h_dim=num_h_blocks_dim,
               memory_w_dim=num_w_blocks_dim,
               name="self_att"), hparams)
@@ -336,6 +345,8 @@ def local_attention1d_masked_decoder(x, kv_dim, heads_dim,
   """Image Transformer decoder with local1D masked layers."""
   print(x)
   _, length_dim, model_dim = x.shape.dims
+  mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+  is_training = mode == tf.estimator.ModeKeys.TRAIN
   for layer in range(hparams.num_decoder_layers):
     layer_name = "decoder_layer_%d" % layer
     with tf.variable_scope(layer_name):
@@ -347,6 +358,7 @@ def local_attention1d_masked_decoder(x, kv_dim, heads_dim,
               mtf.layers.layer_norm(x, model_dim, name="layer_norm_att"),
               kv_dim,
               heads_dim,
+              is_training,
               window_size=hparams.block_length,
               length_per_split=length_per_split,
               name="self_att"), hparams)
18 changes: 13 additions & 5 deletions tensor2tensor/models/mtf_transformer.py
@@ -242,6 +242,8 @@ def _mtf_model_fn(self, features, mesh):
     hparams = self._hparams
     extra_losses = []
     targets = tf.to_int32(features["targets"])
+    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+    is_training = mode == tf.estimator.ModeKeys.TRAIN
     if len(targets.get_shape()) > 2:
       tf.logging.info("targets = %s" % targets)
       targets = tf.squeeze(targets, [2, 3])
@@ -289,7 +291,7 @@ def pad_to_max_length(x):

     def layer_prepostprocess_dropout(x):
       return mtf.dropout(
-          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
+          x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
           noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))

     (inputs_embedding_var,
@@ -426,10 +428,11 @@ def _feedforward_layer(self, x, layer_type, losses=None):
       ValueError: if hparams make no sense
     """
     hparams = self._hparams
-
+    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+    is_training = mode == tf.estimator.ModeKeys.TRAIN
     if layer_type == "drd":
       return mtf.layers.dense_relu_dense(
-          x, self.feedforward_dim, dropout=hparams.relu_dropout,
+          x, self.feedforward_dim, is_training, dropout=hparams.relu_dropout,
           dropout_broadcast_dims=[self.length_dim],
           master_dtype=self.master_dtype,
           slice_dtype=self.slice_dtype)
@@ -493,11 +496,13 @@ def _layer_stack(self,
     """
     hparams = self._hparams
     is_incremental = (step_num is not None)
+    mode = getattr(hparams, "mode", tf.estimator.ModeKeys.TRAIN)
+    is_training = mode == tf.estimator.ModeKeys.TRAIN
     def layer_prepostprocess_dropout(x):
       if is_incremental:
         return x
       return mtf.dropout(
-          x, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
+          x, is_training, keep_prob=1.0 - hparams.layer_prepostprocess_dropout,
           noise_shape=mtf.Shape(self.batch_dims + [self.model_dim]))
     num_layers = len(layers)
     num_layer_norms = num_layers + 1
@@ -540,6 +545,7 @@ def normalize(x):
             mtf.layers.multihead_attention(
                 normalize(x), None,
                 self_attention_mask, self.kv_dim, self.heads_dim,
+                is_training,
                 dropout=hparams.attention_dropout,
                 dropout_broadcast_dims=[self.length_dim],
                 master_dtype=self.master_dtype,
@@ -560,6 +566,7 @@ def normalize(x):
             mtf.layers.multihead_attention(
                 normalize(x), encoder_output,
                 encdec_attention_mask, self.kv_dim, self.heads_dim,
+                is_training,
                 dropout=hparams.attention_dropout,
                 dropout_broadcast_dims=[self.length_dim],
                 master_dtype=self.master_dtype,
@@ -582,7 +589,7 @@ def normalize(x):
         x += layer_prepostprocess_dropout(
             mtf.layers.masked_local_attention_1d(
                 normalize(x),
-                self.kv_dim, self.heads_dim,
+                self.kv_dim, self.heads_dim, is_training,
                 window_size=hparams.local_attention_window_size,
                 master_dtype=self.master_dtype,
                 slice_dtype=self.slice_dtype,
@@ -601,6 +608,7 @@ def normalize(x):
                 compression_factor=hparams.compression_factor,
                 kv_channels=self.kv_dim,
                 heads=self.heads_dim,
+                is_training=is_training,
                 dropout=hparams.attention_dropout,
                 dropout_broadcast_dims=[self.length_dim],
                 master_dtype=self.master_dtype,
