Remove ambiguous padding_mask and instead use a 2D->4D Attn Mask Mapper #26792
Conversation
The documentation is not available anymore as the PR was closed or merged.
Thank you! I am happy with it, just wondering whether changing the `attention_mask` input from being 4D to 2D in `LlamaDecoderLayer` & `LlamaAttention` is considered a breaking change or not.
To me they are internal classes and thus changing the format of the `attention_mask` is ok. @LysandreJik @younesbelkada @ArthurZucker what do you think?
Thanks a lot, this looks much cleaner indeed!
Regarding your comment: as those are internal classes, I think it is fine, as long as the changes are explicitly detailed in the next release notes.
Makes a lot of sense indeed, we have had the past logic for quite a while and an update is welcome!
@@ -41,6 +41,49 @@
 from .configuration_llama import LlamaConfig


+class AttentionMask2DTo4D:
Do we plan to move this to the modeling utils, or is this going to be here for all models?
Either `# Copied from` or we move it to a utils file. Both would work for me.
@@ -589,13 +634,13 @@ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_l


 class LlamaDecoderLayer(nn.Module):
-    def __init__(self, config: LlamaConfig):
+    def __init__(self, config: LlamaConfig, mask_converter=None):
If we support passing the mask converter here but not in the parent classes, it's kind of pointless, no?
Wondering which one would be best:
- Pass the mask converter class to all classes
- Only have it in the attention layer, controlled with a `MASK_CONVERTER = {"default": AttentionMask2DTo4D}`, and just in the attention layer do `self.mask_converter = MASK_CONVERTER[config.mask_converter]`, with the attribute added to the common config?
(naming can be improved for sure! A rough sketch of the second option follows below.)
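For illustration only, a minimal sketch of the registry idea suggested above. The names `MASK_CONVERTER` and `config.mask_converter` are hypothetical (taken from the suggestion, not from the PR), and the stub class stands in for the converter introduced in this PR:

```python
from torch import nn


class AttentionMask2DTo4D:
    """Stub standing in for the converter class added in this PR."""
    pass


# Hypothetical registry mapping config values to converter classes.
MASK_CONVERTER = {"default": AttentionMask2DTo4D}


class LlamaAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Pick the converter via a (hypothetical) common config attribute.
        key = getattr(config, "mask_converter", "default")
        self.mask_converter = MASK_CONVERTER[key]()
```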
Sorry, I don't fully understand this.
To begin with, I would not make the "attention cache" a class that the user plays around with, but instead use it as an internal convenience class that doesn't sacrifice speed but helps readability.
Since the same instance of the class needs to be shared among the different layers, we need to instantiate it at the `...Model` level and then let it trickle down to the respective attention classes.
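A rough sketch of what "instantiate at the `...Model` level and trickle down" could look like. The class and argument names are illustrative, and the cache-based design shown here was later dropped from the PR (see the update further down), so this only illustrates the sharing pattern being discussed:

```python
from torch import nn


class AttentionMask2DTo4D:
    """Stand-in for the internal converter/cache class discussed in this thread."""
    pass


class LlamaAttention(nn.Module):
    def __init__(self, config, mask_converter):
        super().__init__()
        self.mask_converter = mask_converter  # shared instance, not one per layer


class LlamaDecoderLayer(nn.Module):
    def __init__(self, config, mask_converter):
        super().__init__()
        self.self_attn = LlamaAttention(config, mask_converter)


class LlamaModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        # One converter is created at the model level and passed down,
        # so every layer reuses the same cached mask.
        mask_converter = AttentionMask2DTo4D()
        self.layers = nn.ModuleList(
            LlamaDecoderLayer(config, mask_converter) for _ in range(config.num_hidden_layers)
        )
```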
I see, sorry, I realised that if you want to share the same cached mask you have to pass it; ignore my comment.
Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
PR title changed: "Remove ambiguous padding_mask and instead use an AttentionMaskConverter class" → "Remove ambiguous padding_mask and instead use a 2D->4D Attn Mask Mapper with Cache"
@@ -538,7 +640,7 @@ def _flash_attention_forward(
             max_seqlen_k=max_seqlen_in_batch_k,
             dropout_p=dropout,
             softmax_scale=softmax_scale,
-            causal=True,
+            causal=self.attention_mask_cache.is_causal,
This allows us to easily copy-paste this function to non-causal attention layers (BERT)
PR title changed: "Remove ambiguous padding_mask and instead use a 2D->4D Attn Mask Mapper with Cache" → "Remove ambiguous padding_mask and instead use a 2D->4D Attn Mask Mapper / Cache"
This PR should help make the following PRs nicer / cleaner:
@@ -609,7 +721,6 @@ def forward(
         past_key_value: Optional[Tuple[torch.Tensor]] = None,
         output_attentions: Optional[bool] = False,
         use_cache: Optional[bool] = False,
-        padding_mask: Optional[torch.LongTensor] = None,
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
@patrickvonplaten Make sure to change the docstring at line 728 (of this branch):
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
LGTM. We might have to keep some logic to pop the padding mask for one release for backward compatibility. Let's do a deprecation cycle, no?
@@ -560,7 +560,7 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, value):
         self.embed_tokens = value

     # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask
     # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
Let's copy from the original source.
PR title changed: "Remove ambiguous padding_mask and instead use a 2D->4D Attn Mask Mapper / Cache" → "Remove ambiguous padding_mask and instead use a 2D->4D Attn Mask Mapper"
Update: I removed all the cache logic and instead just pass the attention_mask in the format that's needed. This is cleaner than caching tensors according to their shape, memory id, etc. All the benefits are kept, including much improved readability and a comprehensive attention mask class that can be copied / re-used by other models. All tests pass just like before!
Then I am not sure what the point of the AttnMaskConverter class is? By the way, for SDPA, ideally we need both pieces of information: 1/ whether padding is used, and 2/ whether a custom transformers attention mask is used. This is because if custom masking is not used, we may dispatch on flash attention. So passing only a 4D mask for SDPA is suboptimal in my opinion. Or I could just always pass the 4D attention mask to SDPA, but that kind of defeats the point given that dispatch to FA is then impossible.
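For context, a small plain-PyTorch illustration of why this matters for SDPA dispatch (generic code, not from the PR): with no padding and no custom mask, `attn_mask=None` plus `is_causal=True` keeps the flash backend of `scaled_dot_product_attention` eligible, whereas a materialized 4D additive mask generally forces the generic kernel.

```python
import torch
import torch.nn.functional as F

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
batch, heads, seq, dim = 2, 8, 128, 64
q = torch.randn(batch, heads, seq, dim, device=device, dtype=dtype)
k = torch.randn_like(q)
v = torch.randn_like(q)

# No padding, no custom mask: let SDPA handle causality itself,
# which keeps the flash-attention backend eligible.
out_fast = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# An explicit 4D additive mask expresses the same thing but typically
# prevents dispatch to the flash kernel.
causal = torch.full((seq, seq), torch.finfo(dtype).min, device=device, dtype=dtype)
causal = torch.triu(causal, diagonal=1)[None, None, :, :]
out_generic = F.scaled_dot_product_attention(q, k, v, attn_mask=causal, is_causal=False)

assert out_fast.shape == out_generic.shape == (batch, heads, seq, dim)
```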
     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
         """
         Args:
             hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
             attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
-                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
+                `(batch, sequence_length)` where padding elements are indicated by 0.
I think this docstring is incorrect in the latest version.
I can confirm it all looks good on the FA-2 end (benchmarks + tests)! Thanks a lot @patrickvonplaten!
Are we adding the sliding window as a new feature for these models? Otherwise I would just use two different classes for Mistral and the others.
@@ -548,7 +548,6 @@ class PersimmonModel(PersimmonPreTrainedModel):
         config: PersimmonConfig
     """

-    # Copied from transformers.models.llama.modeling_llama.LlamaModel.__init__ with LLAMA->PERSIMMON,Llama->Persimmon,PersimmonRMSNorm->nn.LayerNorm,norm->final_layernorm,rms_final_layernorm_eps->layer_norm_eps
Persimmon = Llama in terms of architecture. It's alright to remove as it's also very long, but Persimmon (and thus Fuyu) will benefit from whatever happens in Llama, so maybe add a TODO!
            sliding_window (`int`, *optional*):
                Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer.
In this case the sliding window seems specific to Mistral, so maybe only include it in Mistral's case, no?
We lose the copied-from then. I'd expect more Mistral-like models to pop up and think it's not worth removing it, see arguments here: #26792 (comment)
        if getattr(self.config, "_flash_attn_2_enabled", False):
            # 2d mask is passed through the layers
            attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
        else:
            key_value_length = seq_length + past_key_values_length
            # 4d mask is passed through the layers
            if attention_mask is not None:
                attention_mask = self.attn_mask_converter.to_4d(
                    attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype
                )
            else:
                attention_mask = self.attn_mask_converter.to_causal_4d(
                    batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device
                )
Not sure this is cleaner than the previous version of passing a None attention mask. Seems like we could handle the None case in the class rather than here.
We were also passing a `None` attention mask previously for `padding_mask`.
if "padding_mask" in kwargs: | ||
warnings.warn( | ||
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" | ||
) | ||
|
||
# overwrite attention_mask with padding_mask | ||
attention_mask = kwargs.pop("padding_mask") |
Thanks for this!

        # add lower triangular sliding window mask if necessary
        if sliding_window is not None:
            diagonal = past_key_values_length - sliding_window + 1
Good for me!
It's not really possible to use the sliding window in Llama because it's hardcoded at initialization ("sliding_window=....") for Mistral, so the user can't (and should not) use it there.
But I do see how the sliding window is arguably a bit exotic for the mask converter, and if people feel strongly I can put it in Mistral's forward method instead. Overall, we do move away a bit from the "single-file" policy here, as the attention converter is a general class that has more features than needed for some models. But it does make sense here, since there is really not much variation in attention masks across models and it greatly helps with readability.
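For reference, a minimal standalone sketch of how a causal sliding-window mask can be built from the `diagonal = past_key_values_length - sliding_window + 1` offset shown in the diff above. The function name `causal_sliding_window_mask` is illustrative; it mirrors the idea, not the exact converter code:

```python
import torch


def causal_sliding_window_mask(tgt_len, past_key_values_length, sliding_window, dtype=torch.float32):
    """Additive (tgt_len, tgt_len + past) mask: causal and limited to `sliding_window` tokens."""
    min_value = torch.finfo(dtype).min
    total_len = tgt_len + past_key_values_length

    # Causal part: query i (absolute position past + i) may not see keys after itself.
    mask = torch.full((tgt_len, total_len), min_value, dtype=dtype)
    rows = torch.arange(tgt_len).unsqueeze(-1) + past_key_values_length
    cols = torch.arange(total_len)
    mask.masked_fill_(cols <= rows, 0.0)

    # Sliding-window part: additionally mask everything further back than the window,
    # using the same diagonal offset as in the diff above.
    diagonal = past_key_values_length - sliding_window + 1
    context_mask = 1 - torch.triu(torch.ones(tgt_len, total_len, dtype=torch.int), diagonal=diagonal)
    mask.masked_fill_(context_mask.bool(), min_value)
    return mask  # broadcastable to (batch, 1, tgt_len, total_len)


print(causal_sliding_window_mask(tgt_len=4, past_key_values_length=2, sliding_window=3))
```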
No problem for me to leave the sliding window in the mask converter class; I indeed think we'll get to see more models leveraging the sliding window (or users who want it supported) in other architectures.
…pper (huggingface#26792) * [Attn Mask Converter] refactor attn mask * up * Apply suggestions from code review Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com> * improve * rename * better cache * renaming * improve more * improve * fix bug * finalize * make style & make fix-copies * correct more * start moving attention_mask * fix llama * improve falcon * up * improve more * improve more * Update src/transformers/models/owlv2/modeling_owlv2.py * make style * make style * rename to converter * Apply suggestions from code review --------- Co-authored-by: fxmarty <9808326+fxmarty@users.noreply.github.com>
What does this PR do?
For models that have Flash Attention 2 (FA2) implemented, we currently pass both `padding_mask` and `attention_mask` to the respective vanilla attention class, e.g. `LlamaAttention`, and to the FA2 class, e.g. `LlamaFlashAttention2`.

However, `padding_mask` is not used for `LlamaAttention` and `attention_mask` is not used for `LlamaFlashAttention2`. Conceptually the two masks are the same, only that `attention_mask` is a 4D mask while `padding_mask` is a 2D mask.

Passing around both masks and having both masks as concepts in our codebase is ambiguous and hurts readability. In this PR, I propose to remove the concept of `padding_mask` completely and instead just pass either a 2D or 4D `attention_mask`, depending on whether we use FA2 or not.

Note: An additional benefit of this PR is that it will improve performance when using FA2, as we will no longer create a 4D attention mask.
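To make the 2D vs. 4D distinction concrete, here is a simplified, self-contained sketch of the kind of conversion the mask converter performs. The name `expand_2d_to_4d` is illustrative only; the PR's class exposes `to_4d` / `to_causal_4d` as shown in one of the diffs above:

```python
import torch


def expand_2d_to_4d(attention_mask_2d, query_length, dtype=torch.float32):
    """Turn a (batch, key_length) padding mask of 1s/0s into an additive
    (batch, 1, query_length, key_length) causal mask. Simplified sketch."""
    batch_size, key_length = attention_mask_2d.shape
    min_value = torch.finfo(dtype).min

    # Causal part: query i may attend to keys 0..(past_length + i).
    past_length = key_length - query_length
    query_pos = torch.arange(query_length).unsqueeze(-1) + past_length
    key_pos = torch.arange(key_length)
    causal = key_pos <= query_pos  # (query_length, key_length)

    # Padding part: positions marked 0 in the 2D mask may never be attended to.
    padding = attention_mask_2d[:, None, None, :].bool()  # (batch, 1, 1, key_length)

    allowed = causal[None, None, :, :] & padding
    return torch.where(allowed, torch.zeros((), dtype=dtype), torch.full((), min_value, dtype=dtype))


# 2D mask as passed with FA2: the second sequence has one padded position.
mask_2d = torch.tensor([[1, 1, 1, 1], [0, 1, 1, 1]])
mask_4d = expand_2d_to_4d(mask_2d, query_length=4)
print(mask_4d.shape)  # torch.Size([2, 1, 4, 4])
```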
Benchmarks:
The following script was used to benchmark the effect this mask implementation has on forward and generate.
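(The actual script is not reproduced here. Purely as an illustration, a comparable forward/generate benchmark might look like the sketch below; the model id, batch size, and generation settings are assumptions, not the configuration that was actually benchmarked.)

```python
import time

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-2-7b-hf"  # hypothetical model choice for illustration
tok = AutoTokenizer.from_pretrained(model_id)
tok.pad_token = tok.eos_token
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cuda")

inputs = tok(["Hello, my name is"] * 8, return_tensors="pt", padding=True).to("cuda")

with torch.no_grad():
    # Warmup, then time the forward pass.
    for _ in range(3):
        model(**inputs)
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(10):
        model(**inputs)
    torch.cuda.synchronize()
    print("forward (s/iter):", (time.perf_counter() - start) / 10)

    # Time generate.
    start = time.perf_counter()
    model.generate(**inputs, max_new_tokens=128, do_sample=False)
    torch.cuda.synchronize()
    print("generate (s):", time.perf_counter() - start)
```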
This PR:
Current main:
=> We don't see any drop in performance at all.
I've verified that the following tests all pass on a single GPU (RTX4090):
FA2:
and all Llama fast tests: