
Remove ambiguous padding_mask and instead use a 2D->4D Attn Mask Mapper #26792

Merged: 28 commits merged into main on Oct 23, 2023

Conversation

@patrickvonplaten (Contributor) commented on Oct 13, 2023

What does this PR do?

For models that have Flash Attention 2 (FA2) implemented, we currently pass both padding_mask and attention_mask to both the vanilla attention class (e.g. LlamaAttention) and the FA2 class (e.g. LlamaFlashAttention2).

However, padding_mask is not used by LlamaAttention and attention_mask is not used by LlamaFlashAttention2. Conceptually the two masks are the same; the only difference is that attention_mask is a 4D mask while padding_mask is a 2D mask.

Passing both masks around and having both as concepts in our codebase is ambiguous and hurts readability. In this PR, I propose to remove the concept of padding_mask completely and instead pass either a 2D or a 4D attention_mask, depending on whether FA2 is used.

Note: An additional benefit of this PR is that it improves performance when using FA2, as we no longer create a 4D attention mask in that case.
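
To make the distinction concrete, here is a minimal sketch (illustration only, not the PR's implementation) of how a 2D padding mask relates to the 4D additive mask that the attention layers consume:

import torch

# 2D "padding" mask: shape (batch, seq_len); 1 = attend, 0 = padding
mask_2d = torch.tensor([[0, 1, 1, 1]])  # first token is padding

batch_size, seq_len = mask_2d.shape
dtype = torch.float32
min_value = torch.finfo(dtype).min

# 4D additive mask: shape (batch, 1, tgt_len, src_len); 0 = attend, large negative = masked
causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)  # future positions
padding = mask_2d[:, None, None, :] == 0                                         # padded key positions

mask_4d = torch.zeros(batch_size, 1, seq_len, seq_len, dtype=dtype)
mask_4d = mask_4d.masked_fill(causal[None, None] | padding, min_value)
print(mask_4d.shape)  # torch.Size([1, 1, 4, 4])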

Benchmarks:

The following script was used to benchmark the effect this mask implementation has on forward and generate.

#!/usr/bin/env python3
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import torch

DEVICE = "cuda:1"

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf", torch_dtype=torch.float16, low_cpu_mem_usage=True)
model.to(DEVICE)


# forward
print("Forward benchmarks")
print(50 * "=")

for batch_size in (1, 4, 16):
    for input_seq in (4, 16, 256):
        input_ids = torch.ones((batch_size, input_seq), dtype=torch.long, device=DEVICE)
        attention_mask = torch.ones_like(input_ids)
        attention_mask[0, 3] = 0

        times = []
        for _ in range(3):
            start_time = time.time()
            with torch.no_grad():
                logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
            times.append(time.time() - start_time)

        result = min(times)

        print(f"Forward bsz={batch_size}, input_seq={input_seq}: {result}")


# generate
print("Generate benchmarks")
print(50 * "=")

for batch_size in (1, 16):
    for input_seq in (4, 256):
        input_ids = torch.ones((batch_size, input_seq), dtype=torch.long, device=DEVICE)
        attention_mask = torch.ones_like(input_ids)
        attention_mask[0, 3] = 0

        times = []
        for _ in range(3):
            start_time = time.time()
            out = model.generate(input_ids=input_ids, max_new_tokens=256, do_sample=False)
            times.append(time.time() - start_time)

        result = min(times)

        print(f"Generate bsz={batch_size}, input_seq={input_seq}: {result}")

This PR:

Forward benchmarks
==================================================
Forward bsz=1, input_seq=4: 0.012479066848754883
Forward bsz=1, input_seq=16: 0.011297464370727539
Forward bsz=1, input_seq=256: 0.01240849494934082
Forward bsz=4, input_seq=4: 0.011190414428710938
Forward bsz=4, input_seq=16: 0.013025283813476562
Forward bsz=4, input_seq=256: 0.03526663780212402
Forward bsz=16, input_seq=4: 0.01126551628112793
Forward bsz=16, input_seq=16: 0.012389421463012695
Forward bsz=16, input_seq=256: 0.1560053825378418
Generate benchmarks
==================================================
Generate bsz=1, input_seq=4: 4.527426719665527
Generate bsz=1, input_seq=256: 4.667049169540405
Generate bsz=16, input_seq=4: 5.524803400039673
Generate bsz=16, input_seq=256: 7.931211709976196

Current main:

Forward benchmarks
==================================================
Forward bsz=1, input_seq=4: 0.017528295516967773
Forward bsz=1, input_seq=16: 0.012105464935302734
Forward bsz=1, input_seq=256: 0.01315617561340332
Forward bsz=4, input_seq=4: 0.011912107467651367
Forward bsz=4, input_seq=16: 0.013910531997680664
Forward bsz=4, input_seq=256: 0.035504817962646484
Forward bsz=16, input_seq=4: 0.012083053588867188
Forward bsz=16, input_seq=16: 0.012537956237792969
Forward bsz=16, input_seq=256: 0.15653300285339355
Generate benchmarks
==================================================
Generate bsz=1, input_seq=4: 4.554980516433716
Generate bsz=1, input_seq=256: 4.695344686508179
Generate bsz=16, input_seq=4: 5.55778431892395
Generate bsz=16, input_seq=256: 7.969247102737427

=> We don't see any drop in performance at all.

I've verified that the following tests all pass on a single GPU (RTX4090):
FA2:

RUN_SLOW=1 pytest -m flash_attn_test tests/models/llama/test_modeling_llama.py

and all Llama fast tests:

CUDA_VISIBLE_DEVICES="0" RUN_SLOW=1 pytest tests/models/llama/test_modeling_llama.py

@HuggingFaceDocBuilderDev commented on Oct 13, 2023

The documentation is not available anymore as the PR was closed or merged.

@fxmarty (Contributor) left a comment

Thank you! I am happy with it, just wondering whether changing the attention_mask input from being 4D to 2D in LlamaDecoderLayer & LlamaAttention is considered a breaking change or not.

@patrickvonplaten patrickvonplaten changed the title [Attn Mask Converter] refactor attn mask [WIP][Attn Mask Converter] refactor attn mask Oct 13, 2023
@patrickvonplaten (Contributor, Author) replied:

> Thank you! I am happy with it, just wondering whether changing the attention_mask input from being 4D to 2D in LlamaDecoderLayer & LlamaAttention is considered a breaking change or not.

To me these are internal classes and thus changing the format of the attention_mask is OK. @LysandreJik @younesbelkada @ArthurZucker what do you think?

@younesbelkada (Contributor) left a comment

Thanks a lot, this looks much cleaner indeed!
Regarding your comment: as those are internal classes I think it is fine, as long as the changes are explicitly detailed in the next release notes.

@ArthurZucker (Collaborator) left a comment

Makes a lot of sense indeed, we have had the past logic for quite a while and an update is welcome!

@@ -41,6 +41,49 @@
from .configuration_llama import LlamaConfig


class AttentionMask2DTo4D:

Collaborator:

Do we plan to move this to the modeling utils or is this going to be here for all models?

Contributor (Author):

Either # Copied from or we move it to a utils file. Both would work for me.

@@ -589,13 +634,13 @@ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_l


class LlamaDecoderLayer(nn.Module):
- def __init__(self, config: LlamaConfig):
+ def __init__(self, config: LlamaConfig, mask_converter=None):

Collaborator:

If we support passing the mask converter here but not in the parent classes it's kind of pointless, no?
Wondering which one would be best:

  1. Pass the mask converter class to all classes.
  2. Only have it in the attention layer, controlled with a MASK_CONVERTER = {"default": AttentionMask2DTo4D}, and in the attention layer just do self.mask_converter = MASK_CONVERTER[config.mask_converter], with the attribute added to the common config (naming can be improved for sure!).

Contributor (Author):

Sorry, I don't fully understand this.

@patrickvonplaten (Contributor, Author) commented on Oct 16, 2023:

To begin with I would not make the "attention cache" a class that the user plays around with, but instead use it as an internal convenience class that doesn't sacrifice speed but helps readability.

Since the same instance of the class needs to be shared among the different layers, we need to instantiate it at a ...Model level and then let it trickle down to the respective attention classes.

Collaborator:

I see, sorry, I realised that if you want to share the same cached mask you have to pass it around; ignore my comment.

patrickvonplaten and others added 5 commits October 16, 2023 10:42
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
@patrickvonplaten patrickvonplaten changed the title [WIP][Attn Mask Converter] refactor attn mask Remove ambiguous padding_mask and instead use an AttentionMaskConverter class Oct 16, 2023
@patrickvonplaten patrickvonplaten changed the title Remove ambiguous padding_mask and instead use an AttentionMaskConverter class Remove ambiguous padding_mask and instead use a 2D->4D Attn Mask Mapper with Cache Oct 16, 2023
@@ -538,7 +640,7 @@ def _flash_attention_forward(
max_seqlen_k=max_seqlen_in_batch_k,
dropout_p=dropout,
softmax_scale=softmax_scale,
- causal=True,
+ causal=self.attention_mask_cache.is_causal,

Contributor (Author):

This allows us to easily copy-paste this function to non-causal attention layers (e.g. BERT).

@patrickvonplaten patrickvonplaten changed the title Remove ambiguous padding_mask and instead use a 2D->4D Attn Mask Mapper with Cache Remove ambiguous padding_mask and instead use a 2D->4D Attn Mask Mapper / Cache Oct 16, 2023
@patrickvonplaten (Contributor, Author) commented:

This PR should help make the following PRs nicer / cleaner:

@@ -609,7 +721,6 @@ def forward(
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:

Contributor:

@patrickvonplaten Make sure to change the docstring line 728 (of this branch):

attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.

@ArthurZucker (Collaborator) left a comment

LGTM, we might have to keep some logic to pop the padding mask for one release for BC. Let's do a deprecation cycle, no?

@@ -560,7 +560,7 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.embed_tokens = value

- # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask
+ # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask

Contributor (Author):

Let's copy from the original source.

@patrickvonplaten patrickvonplaten changed the title Remove ambiguous padding_mask and instead use a 2D->4D Attn Mask Mapper / Cache Remove ambiguous padding_mask and instead use a 2D->4D Attn Mask Mapper Oct 19, 2023
@patrickvonplaten (Contributor, Author) commented on Oct 19, 2023

Update: I removed all the cache logic and instead just pass the attention_mask in the format that's needed. This is cleaner than caching tensors according to their shape, memory id, etc.

All the benefits are kept, including much improved readability and a comprehensive attention mask class that can be copied / re-used by other models.

All tests pass just like before!
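
For readers following the diff, below is a rough sketch of what such a converter could look like, using the to_4d / to_causal_4d method names that appear in the code quoted further down this thread. The class name and internals here are my own simplification, not the PR's actual implementation (which also handles sliding windows and more edge cases):

import torch


class AttnMaskConverterSketch:
    """Illustrative 2D -> 4D attention mask converter (simplified sketch, not the PR's code)."""

    def __init__(self, is_causal: bool = True):
        self.is_causal = is_causal

    def to_causal_4d(self, batch_size, query_length, key_value_length, dtype, device="cpu"):
        # Build a (batch, 1, query_length, key_value_length) additive causal mask.
        min_value = torch.finfo(dtype).min
        mask = torch.zeros(query_length, key_value_length, dtype=dtype, device=device)
        if self.is_causal:
            past = key_value_length - query_length
            future = torch.triu(
                torch.ones(query_length, key_value_length, dtype=torch.bool, device=device),
                diagonal=past + 1,
            )
            mask = mask.masked_fill(future, min_value)
        return mask[None, None].expand(batch_size, 1, query_length, key_value_length)

    def to_4d(self, attention_mask_2d, query_length, key_value_length, dtype):
        # Expand a (batch, key_value_length) padding mask and merge it with the causal mask.
        batch_size = attention_mask_2d.shape[0]
        causal = self.to_causal_4d(
            batch_size, query_length, key_value_length, dtype, device=attention_mask_2d.device
        )
        padding = attention_mask_2d[:, None, None, :] == 0
        return causal.masked_fill(padding, torch.finfo(dtype).min)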

@fxmarty (Contributor) commented on Oct 20, 2023

Then I am not sure what the point of the AttnMaskConverter class is? By the way, for SDPA we ideally need two pieces of information: 1) whether padding is used and 2) whether a custom Transformers attention mask is used. This is because if custom masking is not used, we may dispatch to flash attention. So passing only a 4D mask to SDPA is suboptimal in my opinion.

Or I could just always pass the 4D attention mask to SDPA, but that kind of defeats the point given that dispatch to FA is then impossible.
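
As a rough sketch of this dispatch concern (my own illustration with a hypothetical sdpa_attention helper, not code from this PR): SDPA can only select its flash backend when no explicit mask tensor is passed, so the caller needs to know from the 2D mask whether any padding is actually present.

import torch
import torch.nn.functional as F


def sdpa_attention(query, key, value, attention_mask_2d=None):
    # query/key/value: (batch, num_heads, seq_len, head_dim); mask: (batch, seq_len), 0 = padding.
    if attention_mask_2d is None or bool((attention_mask_2d == 1).all()):
        # No padding: let SDPA apply causality itself, keeping the flash backend eligible.
        return F.scaled_dot_product_attention(query, key, value, is_causal=True)

    # Padding present: fold causal + padding into one additive mask.
    # Passing an explicit mask generally rules out the flash backend.
    seq_len = query.shape[-2]
    min_value = torch.finfo(query.dtype).min
    causal = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool, device=query.device), diagonal=1)
    blocked = causal[None, None] | (attention_mask_2d[:, None, None, :] == 0)
    attn_mask = torch.zeros(blocked.shape, dtype=query.dtype, device=query.device).masked_fill(blocked, min_value)
    return F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask)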

) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
- `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+ `(batch, sequence_length)` where padding elements are indicated by 0.

Contributor:

I think this docstring is incorrect in the latest version.

@younesbelkada (Contributor) left a comment

I can confirm it all looks good on the FA-2 end (benchmarks + tests)! Thanks a lot @patrickvonplaten!

@ArthurZucker (Collaborator) left a comment

Are we adding the sliding window as a new feature for these models? Otherwise I would just use two different classes for Mistral and the others.

@@ -548,7 +548,6 @@ class PersimmonModel(PersimmonPreTrainedModel):
config: PersimmonConfig
"""

# Copied from transformers.models.llama.modeling_llama.LlamaModel.__init__ with LLAMA->PERSIMMON,Llama->Persimmon,PersimmonRMSNorm->nn.LayerNorm,norm->final_layernorm,rms_final_layernorm_eps->layer_norm_eps

Collaborator:

Persimmon = Llama in terms of architecture. It's alright to remove as it's also very long, but Persimmon (and thus Fuyu) will benefit from whatever happens in Llama, so maybe add a TODO!

Comment on lines +105 to +106
sliding_window (`int`, *optional*):
Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.

Collaborator:

In this case the sliding window seems specific to Mistral, so maybe only include it in Mistral's case, no?

Contributor (Author):

We lose the copied-from then. I'd expect more Mistral-like models to pop up and think it's not worth removing it; see the arguments here: #26792 (comment)

Comment on lines +1261 to +1274
if getattr(self.config, "_flash_attn_2_enabled", False):
    # 2d mask is passed through the layers
    attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
    key_value_length = seq_length + past_key_values_length
    # 4d mask is passed through the layers
    if attention_mask is not None:
        attention_mask = self.attn_mask_converter.to_4d(
            attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype
        )
    else:
        attention_mask = self.attn_mask_converter.to_causal_4d(
            batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
        )

Collaborator:

Not sure this is cleaner than the previous version of passing a None attention mask. Seems like we could handle the None case in the class rather than here.

Contributor (Author):

We were also previously passing a None attention mask for padding_mask.

Comment on lines +453 to +459
if "padding_mask" in kwargs:
warnings.warn(
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
)

# overwrite attention_mask with padding_mask
attention_mask = kwargs.pop("padding_mask")

Collaborator:

Thanks for this.


# add lower triangular sliding window mask if necessary
if sliding_window is not None:
    diagonal = past_key_values_length - sliding_window + 1

Collaborator:

Good for me!

@patrickvonplaten (Contributor, Author) commented on Oct 23, 2023

> Are we adding the sliding window as a new feature for these models? Otherwise I would just use two different classes for Mistral and the others.

It's not really possible to use a sliding window in Llama because sliding_window=... is hardcoded at initialization for Mistral. So the user can't (and should not) use sliding_window for Llama via any config parameter (in the same way is_causal is hardcoded to True). It is true that we only have a couple of architectures that use sliding windows (Mistral, Longformer, ...), so we could move it out of the attention mask converter class and instead put it directly into Mistral's forward. I think it's better though to leave it as is because:

  • It's just a one-liner to add it to the attention converter and it can be very nicely tested (which allowed us to spot the bug in Mistral).
  • There is a high chance that we'll have more models with windowed attention if models build on Mistral.
  • We don't allow the user to configure windowed attention, so it's not like we are adding a windowed attention feature to Llama or Falcon and thus making them more complicated.

But I do see how the sliding window is arguably a bit exotic for the mask converter, and if people feel strongly I can put it in Mistral's forward method instead.

Overall, we do move away a bit from the "single-file" policy here, as the attention converter is a general class that has more features than needed for some models. But it does make sense here since there is really not much variation in attention masks across models and it greatly helps readability.
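
For readers unfamiliar with the term: a sliding-window causal mask additionally blocks keys that are more than sliding_window - 1 positions behind the query. A conceptual sketch with a hypothetical helper (illustration only, not the exact code added in this PR):

import torch


def sliding_window_causal_mask(seq_len, sliding_window, dtype=torch.float32):
    # Additive (1, 1, seq_len, seq_len) mask: each query attends to itself and at most
    # the previous sliding_window - 1 positions; everything else gets a large negative value.
    min_value = torch.finfo(dtype).min
    idx = torch.arange(seq_len)
    rel = idx[None, :] - idx[:, None]                  # key index minus query index
    blocked = (rel > 0) | (rel <= -sliding_window)     # future, or too far in the past
    return torch.zeros(seq_len, seq_len, dtype=dtype).masked_fill(blocked, min_value)[None, None]


# With seq_len=5 and sliding_window=3, row i allows keys i-2 .. i only.
print((sliding_window_causal_mask(5, 3)[0, 0] == 0).int())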

@LysandreJik (Member) commented:

No problem for me to leave the sliding window in the mask converter class; I indeed think we'll see more models leveraging the sliding window (or users who want it supported) in other architectures.

@patrickvonplaten patrickvonplaten merged commit 33f98cf into main Oct 23, 2023
3 checks passed
@patrickvonplaten patrickvonplaten deleted the attn_mask_converter branch October 23, 2023 16:54
staghado pushed a commit to staghado/transformers that referenced this pull request Oct 24, 2023
…pper (huggingface#26792)

* [Attn Mask Converter] refactor attn mask

* up

* Apply suggestions from code review

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>

* improve

* rename

* better cache

* renaming

* improve more

* improve

* fix bug

* finalize

* make style & make fix-copies

* correct more

* start moving attention_mask

* fix llama

* improve falcon

* up

* improve more

* improve more

* Update src/transformers/models/owlv2/modeling_owlv2.py

* make style

* make style

* rename to converter

* Apply suggestions from code review

---------

Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
EduardoPach pushed a commit to EduardoPach/transformers that referenced this pull request Nov 19, 2023
…pper (huggingface#26792)
