[OPT] Run test in lower precision on GPU (huggingface#17353)
* [OPT] Run test only in half precision

* up

* up

* up

* up

* finish

* fix on GPU

* Update tests/models/opt/test_modeling_opt.py
patrickvonplaten authored May 19, 2022
1 parent 2b28229 commit e8714c0
Showing 2 changed files with 36 additions and 56 deletions.
src/transformers/models/opt/modeling_opt.py (23 additions, 54 deletions)
@@ -17,9 +17,8 @@
 from typing import List, Optional, Tuple, Union
 
 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint
-from torch import Tensor, nn
+from torch import nn
 from torch.nn import CrossEntropyLoss
 
 from ...activations import ACT2FN
@@ -86,52 +85,28 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
     return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min)
 
 
-def make_positions(mask, padding_idx: int):
-    """Replace non-padding symbols with their position numbers.
-    Position numbers begin at padding_idx+1. Padding symbols are ignored.
-    """
-    # The series of casts and type-conversions here are carefully
-    # balanced to both work with ONNX export and XLA. In particular XLA
-    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
-    # how to handle the dtype kwarg in cumsum.
-    positions = (torch.cumsum(mask, dim=1).type_as(mask) * mask).long() + padding_idx
-    return positions
-
-
 class OPTLearnedPositionalEmbedding(nn.Embedding):
     """
-    This module learns positional embeddings up to a fixed maximum size. Padding ids are ignored by either offsetting
-    based on padding_idx or by setting padding_idx to None and ensuring that the appropriate position ids are passed to
-    the forward function.
+    This module learns positional embeddings up to a fixed maximum size.
     """
 
-    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int = 1):
-        super().__init__(num_embeddings, embedding_dim, padding_idx)
-        self.onnx_trace = False
-        if self.padding_idx is not None:
-            self.max_positions = self.num_embeddings - self.padding_idx - 1
-        else:
-            self.max_positions = self.num_embeddings
-
-    def forward(self, attention_mask: Tensor, positions: Optional[Tensor] = None):
-        # attention_masks is expected to be of size [batch_size x seq_len].
-        if not ((positions is None) or (self.padding_idx is None)):
-            raise ValueError("If positions is pre-computed then padding_idx should not be set.")
-
-        if positions is None:
-            attention_mask = attention_mask.long()
-            positions = make_positions(attention_mask, self.padding_idx)
-
-        return F.embedding(
-            positions,
-            self.weight,
-            self.padding_idx,
-            self.max_norm,
-            self.norm_type,
-            self.scale_grad_by_freq,
-            self.sparse,
-        )
+    def __init__(self, num_embeddings: int, embedding_dim: int):
+        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
+        # and adjust num_embeddings appropriately. Other models don't have this hack
+        self.offset = 2
+        super().__init__(num_embeddings + self.offset, embedding_dim)
+
+    def forward(self, attention_mask: torch.LongTensor, past_key_values_length: int = 0):
+        """`input_ids_shape` is expected to be [bsz x seqlen]."""
+        attention_mask = attention_mask.long()
+
+        # create positions depending on attention_mask
+        positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
+
+        # cut positions if `past_key_values_length` is > 0
+        positions = positions[:, past_key_values_length:]
+
+        return super().forward(positions + self.offset)
 
 
 # Copied from transformers.models.bart.modeling_bart.BartAttention with Bart->OPT
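
For readers skimming the diff, a small illustrative sketch (not part of the commit) of what the new cumsum-based position computation produces for a left-padded batch; `offset = 2` mirrors `self.offset` above, so every padding slot collapses onto embedding index 1 while real tokens start at index 2:

# Standalone illustration of the position arithmetic in OPTLearnedPositionalEmbedding.forward
import torch

attention_mask = torch.tensor([[0, 0, 1, 1, 1],   # two left-padding tokens
                               [1, 1, 1, 1, 1]])  # no padding
offset = 2  # same value as self.offset in the class above

# cumulative count of real tokens, zeroed out on padding, then shifted so the first real token sits at position 0
positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
print(positions)           # tensor([[-1, -1,  0,  1,  2], [ 0,  1,  2,  3,  4]])
print(positions + offset)  # indices looked up in the embedding table: [[1, 1, 2, 3, 4], [2, 3, 4, 5, 6]]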
@@ -504,12 +479,7 @@ def __init__(self, config: OPTConfig):
         self.vocab_size = config.vocab_size
 
         self.embed_tokens = nn.Embedding(config.vocab_size, config.word_embed_proj_dim, self.padding_idx)
-
-        # OPT is set up so that if padding_idx is specified then offset the embedding ids by 2
-        if self.padding_idx is not None:
-            num_embeddings = config.max_position_embeddings + 2
-
-        self.embed_positions = OPTLearnedPositionalEmbedding(num_embeddings, config.hidden_size, self.padding_idx)
+        self.embed_positions = OPTLearnedPositionalEmbedding(config.max_position_embeddings, config.hidden_size)
 
         if config.word_embed_proj_dim != config.hidden_size:
             self.project_out = nn.Linear(config.hidden_size, config.word_embed_proj_dim, bias=False)
@@ -639,8 +609,7 @@ def forward(
         # embed positions
         if attention_mask is None:
            attention_mask = torch.ones(inputs_embeds.shape[:2], dtype=torch.bool, device=inputs_embeds.device)
-
-        positions = self.embed_positions(attention_mask)[:, past_key_values_length:, :]
+        pos_embeds = self.embed_positions(attention_mask, past_key_values_length)
 
         attention_mask = self._prepare_decoder_attention_mask(
             attention_mask, input_shape, inputs_embeds, past_key_values_length
@@ -649,8 +618,7 @@
         if self.project_in is not None:
             inputs_embeds = self.project_in(inputs_embeds)
 
-        hidden_states = inputs_embeds + positions
-
+        hidden_states = inputs_embeds + pos_embeds
         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
 
         # decoder layers
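
To make the role of `past_key_values_length` concrete, here is a hedged, standalone sketch (not from the commit) of the slicing that happens inside the new positional embedding during incremental decoding, where the attention mask covers the cached tokens as well as the newly fed one:

# Hypothetical single-step generation scenario: 3 cached tokens, 1 new token.
import torch

attention_mask = torch.tensor([[1, 1, 1, 1]])
past_key_values_length = 3

positions = (torch.cumsum(attention_mask, dim=1).type_as(attention_mask) * attention_mask).long() - 1
positions = positions[:, past_key_values_length:]  # keep only the new token's position
print(positions)  # tensor([[3]]) -> embedding index 5 once the offset of 2 is added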
@@ -671,6 +639,7 @@ def forward(
             # add LayerDrop (see https://arxiv.org/abs/1909.11556 for description)
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
+
             dropout_probability = random.uniform(0, 1)
             if self.training and (dropout_probability < self.layerdrop):
                 continue
tests/models/opt/test_modeling_opt.py (13 additions, 2 deletions)
@@ -271,14 +271,25 @@ class OPTModelIntegrationTests(unittest.TestCase):
     def test_inference_no_head(self):
         model = OPTModel.from_pretrained("facebook/opt-350m").to(torch_device)
         input_ids = _long_tensor([[0, 31414, 232, 328, 740, 1140, 12695, 69, 46078, 1588, 2]])
+
         with torch.no_grad():
             output = model(input_ids=input_ids).last_hidden_state
+
         expected_shape = torch.Size((1, 11, 512))
         self.assertEqual(output.shape, expected_shape)
         expected_slice = torch.tensor(
-            [[-0.2867, -1.9256, -0.3062], [-1.2711, -0.1337, -0.1897], [0.4109, 0.1187, -1.3142]], device=torch_device
+            [[-0.2873, -1.9242, -0.3059], [-1.2738, -0.1333, -0.1877], [0.4116, 0.1192, -1.3107]],
+            device=torch_device,
         )
-        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-3))
+        # Getting different logits results on GPU depending on PyTorch version (1.10+cu11.0 vs. 1.11+cu11.4)
+        # and results also differ between CPU and GPU. Only on CPU it seems to be deterministic.
+
+        # It's not because the weights are saved & loaded in FP16
+        # checked that the same happens when weights are stored in fp32 and loaded in fp32.
+        # The differences start to creep in in the first linear projection matrix project_in_dim
+        # It however also happens for BART (maybe related to training model in fp16?)
+        atol = 1e-2 if torch_device != "cpu" else 1e-3
+        assert_tensors_close(output[0, :3, :3], expected_slice, atol=atol)
 
 
 @require_torch
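
The updated test loosens the absolute tolerance only when running off-CPU. A minimal sketch of that pattern, assuming a `torch_device`-style string is available; the `check_close` helper below is hypothetical and merely stands in for `assert_tensors_close` from the transformers test suite:

# Device-dependent tolerance: GPU kernels and differing CUDA/PyTorch builds can shift
# logits slightly, so the comparison is stricter on CPU than on accelerators.
import torch

def check_close(actual: torch.Tensor, expected: torch.Tensor, device: str) -> None:
    atol = 1e-2 if device != "cpu" else 1e-3
    assert torch.allclose(actual, expected, atol=atol), "tensors differ beyond tolerance"

expected = torch.tensor([[-0.2873, -1.9242, -0.3059]])
check_close(expected + 5e-4, expected, device="cpu")   # passes: within 1e-3 on CPU
check_close(expected + 5e-3, expected, device="cuda")  # passes: within the looser 1e-2 off-CPU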