4D attention_mask support #27539

Merged on Dec 17, 2023 (14 commits)
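For orientation, here is a minimal sketch (not part of the diff) of the call pattern this PR enables: handing a ready-made 4D `attention_mask` of shape `(batch, 1, query_length, key_value_length)` straight to a causal LM so that several one-token continuations can share a packed prefix. The model name and the literal token ids are illustrative assumptions borrowed from the tests further down.

```python
import torch
from transformers import AutoModelForCausalLM

# Illustrative only: any causal LM works; "JackFram/llama-68m" is the small model used in the tests below.
model = AutoModelForCausalLM.from_pretrained("JackFram/llama-68m")

# One packed sequence: a shared 3-token prefix (BOS + "the cat") followed by 3 alternative next tokens.
input_ids = torch.tensor([[1, 278, 6635, 3290, 750, 338]])  # token ids copied from the test fixture
position_ids = torch.tensor([[0, 1, 2, 3, 3, 3]])  # the three endings all sit at position 3

# 4D mask, shape (batch, 1, query_len, key_value_len): causal over the prefix,
# while each ending attends to the prefix and itself but not to the other endings.
mask = torch.tensor(
    [[[[1, 0, 0, 0, 0, 0],
       [1, 1, 0, 0, 0, 0],
       [1, 1, 1, 0, 0, 0],
       [1, 1, 1, 1, 0, 0],
       [1, 1, 1, 0, 1, 0],
       [1, 1, 1, 0, 0, 1]]]]
)

logits = model(input_ids, attention_mask=mask, position_ids=position_ids).logits  # shape (1, 6, vocab_size)
```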
src/transformers/modeling_attn_mask_utils.py (29 additions, 2 deletions)
@@ -302,10 +302,22 @@ def _prepare_4d_causal_attention_mask(
     key_value_length = input_shape[-1] + past_key_values_length

     # 4d mask is passed through the layers
-    if attention_mask is not None:
+    if attention_mask is not None and len(attention_mask.shape) == 2:
         attention_mask = attn_mask_converter.to_4d(
             attention_mask, input_shape[-1], key_value_length=key_value_length, dtype=inputs_embeds.dtype
         )
+    elif attention_mask is not None and len(attention_mask.shape) == 4:
+        expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+        if tuple(attention_mask.shape) != expected_shape:
+            raise ValueError(
+                f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+            )
+        else:
+            # if the 4D mask has correct shape - invert it and fill with negative infinity
+            inverted_mask = 1.0 - attention_mask
+            attention_mask = inverted_mask.masked_fill(
+                inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+            )
     else:
         attention_mask = attn_mask_converter.to_causal_4d(
             input_shape[0], input_shape[-1], key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
@@ -340,7 +352,22 @@ def _prepare_4d_causal_attention_mask_for_sdpa(
     is_tracing = torch.jit.is_tracing()

     if attention_mask is not None:
-        if torch.all(attention_mask == 1):
+        # 4d mask is passed through
+        if len(attention_mask.shape) == 4:
+            expected_shape = (input_shape[0], 1, input_shape[1], key_value_length)
+            if tuple(attention_mask.shape) != expected_shape:
+                raise ValueError(
+                    f"Incorrect 4D attention_mask shape: {tuple(attention_mask.shape)}; expected: {expected_shape}."
+                )
+            else:
+                # if the 4D mask has correct shape - invert it and fill with negative infinity
+                inverted_mask = 1.0 - attention_mask.to(inputs_embeds.dtype)
+                attention_mask = inverted_mask.masked_fill(
+                    inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min
+                )
+            return attention_mask
+
+        elif torch.all(attention_mask == 1):
             if is_tracing:
                 pass
             elif query_length == 1:
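For reviewers skimming the two hunks above, here is a self-contained sketch of the transformation they both apply to a user-supplied 4D mask: check that it has shape `(batch_size, 1, query_length, key_value_length)`, then turn 1/0 "attend / don't attend" entries into an additive bias of 0 / dtype minimum. The helper name and signature below are illustrative, not the library API.

```python
import torch


def to_additive_4d_bias(mask_4d, batch_size, query_length, key_value_length, dtype):
    """Illustrative stand-in for the inversion logic added above (not the transformers API)."""
    expected_shape = (batch_size, 1, query_length, key_value_length)
    if tuple(mask_4d.shape) != expected_shape:
        raise ValueError(f"Incorrect 4D attention_mask shape: {tuple(mask_4d.shape)}; expected: {expected_shape}.")
    # 1 -> keep (bias 0.0), 0 -> mask out (bias = most negative value representable in `dtype`)
    inverted = 1.0 - mask_4d.to(dtype)
    return inverted.masked_fill(inverted.to(torch.bool), torch.finfo(dtype).min)


# A 2-token query over 5 key/value positions; the first query token cannot see the last key.
mask = torch.tensor([[[[1, 1, 1, 1, 0],
                       [1, 1, 1, 1, 1]]]])
bias = to_additive_4d_bias(mask, batch_size=1, query_length=2, key_value_length=5, dtype=torch.float32)
# bias is 0.0 where attention is allowed and torch.finfo(torch.float32).min where it is blocked.
```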
tests/test_modeling_utils.py (133 additions, 0 deletions)
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import copy
+import gc
 import glob
 import json
 import os
@@ -49,6 +50,7 @@
     require_tf,
     require_torch,
     require_torch_accelerator,
+    require_torch_gpu,
    require_torch_multi_accelerator,
     require_usr_bin_time,
     slow,
@@ -1850,3 +1852,134 @@ def test_not_available_sdpa(self):
             )

         self.assertTrue("PyTorch SDPA requirements in Transformers are not met" in str(cm.exception))


+@slow
+@require_torch_gpu
+class Mask4DTestBase(unittest.TestCase):
+    def tearDown(self):
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_test_data(self):
+        texts = ["the cat sat", "the cat had", "the cat is"]
+        encoded = [self.tokenizer.encode(t) for t in texts]
+        input_0 = torch.tensor(encoded, device=torch_device)
+        # tensor([[   1,  278, 6635, 3290],
+        #         [   1,  278, 6635,  750],
+        #         [   1,  278, 6635,  338]], device='cuda:0')
+
+        # Combining the common prefix with the unique ending tokens:
+        input_1 = torch.cat([input_0[0][:-1], input_0[:, -1]]).unsqueeze(0)
+        # tensor([[   1,  278, 6635, 3290,  750,  338]], device='cuda:0')
+
+        # Creating a 4D mask where the last 3 tokens do not attend to each other.
+        mask_1 = torch.tensor(
+            [
+                [
+                    [
+                        [1, 0, 0, 0, 0, 0],
+                        [1, 1, 0, 0, 0, 0],
+                        [1, 1, 1, 0, 0, 0],
+                        [1, 1, 1, 1, 0, 0],
+                        [1, 1, 1, 0, 1, 0],
+                        [1, 1, 1, 0, 0, 1],
+                    ]
+                ]
+            ],
+            device="cuda:0",
+            dtype=torch.int64,
+        )
+
+        # Creating a position_ids tensor. Note the repeated position at the end:
+        # the three alternative endings all occupy position 3.
+        position_ids_1 = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)
+
+        return input_0, input_1, mask_1, position_ids_1


+@slow
+@require_torch_gpu
+class Mask4DTestFP32(Mask4DTestBase):
+    def setUp(self):
+        model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
+        model_dtype = torch.float32
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(torch_device)
+
+    def test_attention(self):
+        """comparing outputs of attention layer"""
+        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+
+        hid_0 = self.model.model.embed_tokens(input_0)
+        outs_0 = self.model.model.layers[0].self_attn.forward(hid_0)[0]
+        # outs_0.shape == torch.Size([3, 4, 768])
+
+        hid_1 = self.model.model.embed_tokens(input_1)
+        outs_1 = self.model.model.layers[0].self_attn.forward(
+            hid_1, attention_mask=mask_1.bool(), position_ids=position_ids_1
+        )[0]
+        # outs_1.shape == torch.Size([1, 6, 768])
+
+        outs_0_last_tokens = outs_0[:, -1, :]  # last tokens in each batch line
+        outs_1_last_tokens = outs_1[0, -3:, :]  # last three tokens
+        assert torch.allclose(outs_0_last_tokens, outs_1_last_tokens)
+
+    def test_inner_model(self):
+        """comparing hidden outputs of whole inner model"""
+        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+
+        logits_0 = self.model.forward(input_0).logits
+        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
+
+        logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
+        logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
+        torch.testing.assert_close(
+            logits_0_last_tokens,
+            logits_1_last_tokens,
+        )
+
+    def test_causal_model_logits(self):
+        """comparing logits outputs of whole inner model"""
+        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+
+        logits_0 = self.model.forward(input_0).logits
+        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
+
+        logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
+        logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
+        torch.testing.assert_close(
+            logits_0_last_tokens,
+            logits_1_last_tokens,
+        )


+@slow
+@require_torch_gpu
+class Mask4DTestFP16(Mask4DTestBase):
+    test_attention = Mask4DTestFP32.test_attention
+
+    def setUp(self):
+        model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
+        model_dtype = torch.float16
+        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(torch_device)
+
+    def test_causal_model_logits(self):
+        """comparing logits outputs of whole inner model"""
+        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+
+        logits_0 = self.model.forward(input_0).logits
+        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
+
+        logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
+        logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
+
+        indices_0 = logits_0_last_tokens.sort(descending=True).indices
+        indices_1 = logits_1_last_tokens.sort(descending=True).indices
+
+        # checking logits, but note relaxed tolerances for FP16
+        torch.testing.assert_close(logits_0_last_tokens, logits_1_last_tokens, atol=0.02, rtol=0.001)
+
+        # checking tokens order for the top tokens
+        for token_ids_0, token_ids_1 in zip(indices_0, indices_1):
+            self.assertTrue(torch.equal(token_ids_0[:128], token_ids_1[:128]))
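A closing note on the test fixture, in case the hand-written `mask_1` looks opaque: it is simply a causal mask over the 3-token shared prefix with the three one-token continuations made invisible to one another, paired with `position_ids` that repeat position 3 for every continuation. Below is a hedged sketch of how such a mask could be built programmatically; `packed_prefix_mask` is a hypothetical helper, not something added by this PR.

```python
import torch


def packed_prefix_mask(prefix_len: int, n_branches: int) -> torch.Tensor:
    """Causal mask over a shared prefix, with `n_branches` one-token continuations hidden from each other."""
    total = prefix_len + n_branches
    mask = torch.tril(torch.ones(total, total, dtype=torch.int64))  # start from a plain causal mask
    branch = torch.arange(prefix_len, total)
    # each branch token keeps the prefix and itself, but drops the other branch tokens
    mask[branch.unsqueeze(1), branch.unsqueeze(0)] = torch.eye(n_branches, dtype=torch.int64)
    return mask[None, None]  # shape (1, 1, total, total), as expected by the 4D-mask code path


# Reproduces mask_1 from get_test_data(): 3 shared prefix tokens and 3 one-token branches.
expected = torch.tensor(
    [
        [1, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 0, 0],
        [1, 1, 1, 0, 1, 0],
        [1, 1, 1, 0, 0, 1],
    ]
)
assert torch.equal(packed_prefix_mask(3, 3)[0, 0], expected)
```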