[tests] remove some flash attention class tests (#35817)
remove the FlashAttention class-name checks from the tests
ArthurZucker authored Jan 23, 2025
1 parent 2c3a44f commit 8736e91
Showing 1 changed file with 0 additions and 31 deletions.
tests/test_modeling_common.py (31 changes: 0 additions & 31 deletions)
@@ -4628,43 +4628,12 @@ def test_flash_attn_2_from_config(self):
             dummy_input = torch.LongTensor([[0, 2, 3, 4], [0, 2, 3, 4]]).to(torch_device)
             dummy_attention_mask = torch.LongTensor([[1, 1, 1, 1], [0, 1, 1, 1]]).to(torch_device)

-            fa2_correctly_converted = False
-
-            for _, module in fa2_model.named_modules():
-                if "FlashAttention" in module.__class__.__name__:
-                    fa2_correctly_converted = True
-                    break
-
-            fa2_correctly_converted = (
-                fa2_correctly_converted
-                if not model_class._supports_flex_attn
-                else fa2_model.config._attn_implementation == "flash_attention_2"
-            )
-            self.assertTrue(fa2_correctly_converted)
-
             _ = fa2_model(input_ids=dummy_input, attention_mask=dummy_attention_mask)

             with tempfile.TemporaryDirectory() as tmpdirname:
                 fa2_model.save_pretrained(tmpdirname)

                 model_from_pretrained = model_class.from_pretrained(tmpdirname)

                 self.assertTrue(model_from_pretrained.config._attn_implementation != "flash_attention_2")

-                fa2_correctly_converted = False
-
-                for _, module in model_from_pretrained.named_modules():
-                    if "FlashAttention" in module.__class__.__name__:
-                        fa2_correctly_converted = True
-                        break
-
-                fa2_correctly_converted = (
-                    fa2_correctly_converted
-                    if not model_class._supports_flex_attn
-                    else model_from_pretrained.config._attn_implementation == "flash_attention_2"
-                )
-                self.assertFalse(fa2_correctly_converted)
-
     def _get_custom_4d_mask_test_data(self):
         # Sequence in which all but the last token is the same
         input_ids = torch.tensor(
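
For context, the deleted check walked every submodule and matched on the class name, whereas the assertion the test keeps only reads the config's `_attn_implementation` field. Below is a minimal, self-contained sketch of both patterns using toy `torch.nn` modules; `ToyFlashAttention`, `ToyModel`, and its `attn_implementation` attribute are hypothetical stand-ins for illustration, not transformers APIs.

import torch.nn as nn

class ToyFlashAttention(nn.Module):
    # Hypothetical stand-in for an attention layer whose class name
    # contains "FlashAttention", which the deleted check scanned for.
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

class ToyModel(nn.Module):
    # Hypothetical stand-in for a model; `attn_implementation` mimics
    # the role of config._attn_implementation in transformers.
    def __init__(self, attn_implementation="flash_attention_2"):
        super().__init__()
        self.attn_implementation = attn_implementation
        self.attn = ToyFlashAttention()

def has_flash_attention_module(model: nn.Module) -> bool:
    # The deleted pattern: scan named_modules() for a class-name match.
    # Brittle, because an attention implementation can be active without
    # any module carrying "FlashAttention" in its name.
    return any("FlashAttention" in m.__class__.__name__ for _, m in model.named_modules())

model = ToyModel()
print(has_flash_attention_module(model))                 # True: class-name scan
print(model.attn_implementation == "flash_attention_2")  # True: config-style check the test keeps

The deleted ternary already special-cased this (when `model_class._supports_flex_attn` it fell back to comparing `config._attn_implementation`), so dropping the module scan in favor of the config assertion removes the redundant, name-dependent path.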
