Remove hard-coded uses of float32 to fix mixed precision use (hugging…
schmidek committed Aug 21, 2020
1 parent 9e8c494 commit 6ac59c3
Showing 2 changed files with 13 additions and 12 deletions.
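
The two files below apply the same pattern: TensorFlow does not promote dtypes automatically, so a tensor hard-cast to tf.float32 can no longer be combined with float16 activations once a mixed precision policy is active. A minimal, illustrative sketch of the failure and of the fix pattern (this snippet is not taken from the commit):

import tensorflow as tf

# Under a float16 compute dtype, adding a hard-coded float32 tensor to
# float16 activations fails, because TensorFlow does not promote dtypes.
activations = tf.ones((2, 3), dtype=tf.float16)   # e.g. outputs of a float16 compute layer
mask = tf.cast(tf.ones((2, 3)), tf.float32)       # hard-coded float32, as in the old code

try:
    _ = activations + mask
except tf.errors.InvalidArgumentError as err:
    print("dtype mismatch:", type(err).__name__)

# The pattern adopted throughout the diff: derive the cast target from the
# tensor the value is combined with, instead of hard-coding tf.float32.
mask = tf.cast(mask, activations.dtype)
print((activations + mask).dtype)  # float16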
11 changes: 6 additions & 5 deletions src/transformers/modeling_tf_bert.py
@@ -215,8 +215,8 @@ def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, tra
         if inputs_embeds is None:
             inputs_embeds = tf.gather(self.word_embeddings, input_ids)
 
-        position_embeddings = self.position_embeddings(position_ids)
-        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+        position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
+        token_type_embeddings = tf.cast(self.token_type_embeddings(token_type_ids), inputs_embeds.dtype)
         embeddings = inputs_embeds + position_embeddings + token_type_embeddings
         embeddings = self.LayerNorm(embeddings)
         embeddings = self.dropout(embeddings, training=training)
@@ -281,7 +281,7 @@ def call(self, hidden_states, attention_mask, head_mask, output_attentions, trai
         attention_scores = tf.matmul(
             query_layer, key_layer, transpose_b=True
         )  # (batch size, num_heads, seq_len_q, seq_len_k)
-        dk = tf.cast(shape_list(key_layer)[-1], tf.float32)  # scale attention_scores
+        dk = tf.cast(shape_list(key_layer)[-1], attention_scores.dtype)  # scale attention_scores
         attention_scores = attention_scores / tf.math.sqrt(dk)
 
         if attention_mask is not None:
@@ -613,6 +613,8 @@ def call(
         if token_type_ids is None:
             token_type_ids = tf.fill(input_shape, 0)
 
+        embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
+
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, to_seq_length]
         # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
@@ -626,7 +628,7 @@
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
 
-        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
+        extended_attention_mask = tf.cast(extended_attention_mask, embedding_output.dtype)
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
 
         # Prepare head mask if needed
@@ -640,7 +642,6 @@ def call(
             head_mask = [None] * self.num_hidden_layers
             # head_mask = tf.constant([0] * self.num_hidden_layers)
 
-        embedding_output = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
         encoder_outputs = self.encoder(
             embedding_output,
             extended_attention_mask,
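
The BERT hunks above all follow one rule: values previously cast to tf.float32 (the sqrt(d_k) scaling factor, the additive attention mask) now take their dtype from the tensor they are combined with, and the embedding output is computed first so its dtype is available for the mask cast. A hedged, standalone sketch of that pattern (the function and variable names here are illustrative, not from the repository):

import tensorflow as tf

def masked_attention_scores(query, key, attention_mask):
    # Scaled dot-product scores plus an additive mask, with every constant's
    # dtype derived from the tensors it touches rather than hard-coded float32.
    scores = tf.matmul(query, key, transpose_b=True)
    dk = tf.cast(tf.shape(key)[-1], scores.dtype)    # was tf.float32
    scores = scores / tf.math.sqrt(dk)
    mask = tf.cast(attention_mask, scores.dtype)     # was tf.float32
    return scores + (1.0 - mask) * -10000.0

q = tf.random.normal((1, 2, 4, 8), dtype=tf.float16)
k = tf.random.normal((1, 2, 4, 8), dtype=tf.float16)
m = tf.ones((1, 1, 1, 4))
print(masked_attention_scores(q, k, m).dtype)  # float16, no float32 left behind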
14 changes: 7 additions & 7 deletions src/transformers/modeling_tf_electra.py
@@ -134,8 +134,8 @@ def _embedding(self, input_ids, position_ids, token_type_ids, inputs_embeds, tra
 
         if inputs_embeds is None:
             inputs_embeds = tf.gather(self.word_embeddings, input_ids)
-        position_embeddings = self.position_embeddings(position_ids)
-        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+        position_embeddings = tf.cast(self.position_embeddings(position_ids), inputs_embeds.dtype)
+        token_type_embeddings = tf.cast(self.token_type_embeddings(token_type_ids), inputs_embeds.dtype)
 
         embeddings = inputs_embeds + position_embeddings + token_type_embeddings
         embeddings = self.LayerNorm(embeddings)
@@ -194,7 +194,7 @@ class TFElectraPreTrainedModel(TFBertPreTrainedModel):
     config_class = ElectraConfig
     base_model_prefix = "electra"
 
-    def get_extended_attention_mask(self, attention_mask, input_shape):
+    def get_extended_attention_mask(self, attention_mask, input_shape, dtype):
         if attention_mask is None:
             attention_mask = tf.fill(input_shape, 1)
 
@@ -211,7 +211,7 @@ def get_extended_attention_mask(self, attention_mask, input_shape):
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
 
-        extended_attention_mask = tf.cast(extended_attention_mask, tf.float32)
+        extended_attention_mask = tf.cast(extended_attention_mask, dtype)
         extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
 
         return extended_attention_mask
@@ -314,11 +314,11 @@ def call(
         if token_type_ids is None:
            token_type_ids = tf.fill(input_shape, 0)
 
-        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
-        head_mask = self.get_head_mask(head_mask)
-
         hidden_states = self.embeddings(input_ids, position_ids, token_type_ids, inputs_embeds, training=training)
 
+        extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape, hidden_states.dtype)
+        head_mask = self.get_head_mask(head_mask)
+
         if hasattr(self, "embeddings_project"):
             hidden_states = self.embeddings_project(hidden_states, training=training)
 
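
With the casts above derived from the surrounding tensors, the TF BERT and ELECTRA models can run under a Keras mixed precision policy. A usage sketch, assuming TF >= 2.4 for tf.keras.mixed_precision.set_global_policy (TF 2.3, current when this commit landed, exposes the same functionality as tf.keras.mixed_precision.experimental.set_policy):

import tensorflow as tf
from transformers import BertTokenizer, TFBertModel

# Enable a float16 compute / float32 variable policy before building the model.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = TFBertModel.from_pretrained("bert-base-uncased")

inputs = tokenizer("mixed precision smoke test", return_tensors="tf")
outputs = model(inputs)
print(outputs[0].dtype)  # expected float16 once the hard-coded casts are gone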
