diff --git a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py index 6853f5333f137c..3ace323e822414 100644 --- a/src/transformers/models/deprecated/open_llama/modeling_open_llama.py +++ b/src/transformers/models/deprecated/open_llama/modeling_open_llama.py @@ -560,7 +560,7 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask + # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): # create causal mask # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] diff --git a/src/transformers/models/falcon/modeling_falcon.py b/src/transformers/models/falcon/modeling_falcon.py index e9dca6df989472..6eaeed4199771c 100644 --- a/src/transformers/models/falcon/modeling_falcon.py +++ b/src/transformers/models/falcon/modeling_falcon.py @@ -15,6 +15,7 @@ """PyTorch Falcon model.""" import math +import warnings from typing import Optional, Tuple, Union import torch @@ -76,9 +77,9 @@ def rotate_half(x): # Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(padding_mask): - seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) return ( @@ -88,6 +89,143 @@ def _get_unpad_data(padding_mask): ) +# Copied from transformers.models.llama.modeling_llama.AttnMaskConverter +class AttnMaskConverter: + """ + A utility attention mask class that allows: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. + """ + + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype = torch.float32, + device: Union[torch.device, "str"] = "cpu", + ) -> torch.Tensor: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). 
+ """ + if not self.is_causal: + raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") + + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + key_value_length: int, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. + """ + input_shape = (attention_mask_2d.shape[0], query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=attention_mask_2d.device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + elif self.sliding_window is not None: + raise NotImplementedError("Sliding window is currently only implemented for causal masking") + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( + attention_mask_2d.device + ) + expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask + + return expanded_4d_mask + + def _make_causal_mask( + self, + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window + 1 + + context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal) + mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
+ """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + # TODO (joao): Is this the same implementation as in Llama? If so, let's make them the same and add the copy facilities class FalconRotaryEmbedding(nn.Module): """Implementation of RotaryEmbedding from GPT-NeoX. @@ -311,6 +449,7 @@ def __init__(self, config: FalconConfig): self.head_dim = self.hidden_size // self.num_heads self.split_size = self.hidden_size self.hidden_dropout = config.hidden_dropout + self.is_causal = True if self.head_dim * self.num_heads != self.hidden_size: raise ValueError( @@ -431,8 +570,13 @@ def forward( head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, - padding_mask: Optional[torch.LongTensor] = None, + **kwargs, ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads # 3 x [batch_size, seq_length, num_heads, head_dim] @@ -465,9 +609,6 @@ def forward( else: present = None - float_min = torch.finfo(query_layer.dtype).min - attention_mask_float = (attention_mask * 1.0).masked_fill(attention_mask, float_min).to(query_layer.dtype) - query_layer_ = query_layer.reshape(batch_size, self.num_heads, -1, self.head_dim) key_layer_ = key_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) value_layer_ = value_layer.reshape(batch_size, num_kv_heads, -1, self.head_dim) @@ -482,16 +623,14 @@ def forward( ) attn_output = F.scaled_dot_product_attention( - query_layer_, key_layer_, value_layer_, attention_mask_float, 0.0, is_causal=False + query_layer_, key_layer_, value_layer_, attention_mask, 0.0, is_causal=False ) attention_scores = None else: attention_scores = query_layer_ @ key_layer_.transpose(-1, -2) attention_scores /= math.sqrt(self.head_dim) - attention_scores = F.softmax( - attention_scores + attention_mask_float, dim=-1, dtype=hidden_states.dtype - ) + attention_scores = F.softmax(attention_scores + attention_mask, dim=-1, dtype=hidden_states.dtype) attn_output = attention_scores @ value_layer_ attn_output = attn_output.view(batch_size, self.num_heads, query_length, self.head_dim) @@ -517,12 +656,12 @@ def forward( if input_dtype == torch.float16 or input_dtype == torch.bfloat16: attention_scores = attention_scores.to(torch.float32) # Matt (HF) note: We could possibly use F.scaled_dot_product_attention here too, by - # adding (alibi * self.inv_norm_factor) to attention_mask_float. I think this would be mathematically + # adding (alibi * self.inv_norm_factor) to attention_mask. I think this would be mathematically # equivalent and more performant, but there might be a numerical difference. If you're reading this # and you'd like to experiment and maybe file a PR, feel free! 
attention_logits = attention_scores + alibi.view(batch_size, self.num_heads, 1, -1) attention_logits *= self.inv_norm_factor - attention_probs = F.softmax(attention_logits + attention_mask_float, dim=-1, dtype=hidden_states.dtype) + attention_probs = F.softmax(attention_logits + attention_mask, dim=-1, dtype=hidden_states.dtype) # [batch_size, num_heads, q_length, kv_length] attention_probs = self.attention_dropout(attention_probs) @@ -563,8 +702,16 @@ def forward( head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, - padding_mask: Optional[torch.LongTensor] = None, + **kwargs, ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + fused_qkv = self.query_key_value(hidden_states) # [batch_size, seq_length, 3 x hidden_size] num_kv_heads = self.num_heads if self.new_decoder_architecture else self.num_kv_heads # 3 x [batch_size, seq_length, num_heads, head_dim] @@ -630,7 +777,7 @@ def forward( value_layer = value_layer.to(target_dtype) attn_output = self._flash_attention_forward( - query_layer, key_layer, value_layer, padding_mask, query_length, dropout=attn_dropout + query_layer, key_layer, value_layer, attention_mask, query_length, dropout=attn_dropout ) attn_weights = attn_output.reshape(batch_size, query_length, self.num_heads * self.head_dim) @@ -643,7 +790,7 @@ def forward( # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._flash_attention_forward def _flash_attention_forward( - self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -656,7 +803,7 @@ def _flash_attention_forward( Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API - padding_mask (`torch.Tensor`): + attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`int`, *optional*): @@ -665,10 +812,10 @@ def _flash_attention_forward( The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) """ # Contains at least one padding token in the sequence - if padding_mask is not None: + if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, padding_mask, query_length + query_states, key_states, value_states, attention_mask, query_length ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens @@ -684,7 +831,7 @@ def _flash_attention_forward( max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=True, + causal=self.is_causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) @@ -696,8 +843,8 @@ def _flash_attention_forward( return attn_output # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2._upad_input - def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( @@ -722,8 +869,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_l query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. - padding_mask = padding_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, @@ -752,7 +899,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class FalconDecoderLayer(nn.Module): - def __init__(self, config: FalconConfig): + def __init__(self, config: FalconConfig, attn_mask_converter=None): super().__init__() hidden_size = config.hidden_size self.num_heads = config.num_attention_heads @@ -786,8 +933,13 @@ def forward( head_mask: Optional[torch.Tensor] = None, use_cache: bool = False, output_attentions: bool = False, - padding_mask: Optional[torch.LongTensor] = None, + **kwargs, ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + residual = hidden_states if self.config.new_decoder_architecture: @@ -806,7 +958,7 @@ def forward( head_mask=head_mask, use_cache=use_cache, output_attentions=output_attentions, - padding_mask=padding_mask, + **kwargs, ) attention_output = attn_outputs[0] @@ -1001,6 +1153,10 @@ def __init__(self, config: FalconConfig): # Embedding + LN Embedding self.word_embeddings = nn.Embedding(config.vocab_size, self.embed_dim) + # create attention mask cache that trickles down to each attention layer + # so that the attention_mask cache can be shared among layers + self.attn_mask_converter = AttnMaskConverter(is_causal=True) + # Transformer blocks self.h = nn.ModuleList([FalconDecoderLayer(config) for _ in range(config.num_hidden_layers)]) @@ -1015,37 +1171,6 @@ def __init__(self, config: FalconConfig): def get_input_embeddings(self): return self.word_embeddings - @staticmethod - def _prepare_attn_mask( - attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int - ) -> torch.BoolTensor: - # Create a causal mask - # The attention mask we receive as input should cover the whole extended sequence, including any past - # cache, so its shape should be [batch_size, seq_length + past_key_values_length] - # The output shape will be [batch_size, 1, seq_length, seq_length + past_key_values_length] - if input_shape[1] + past_key_values_length != attention_mask.shape[1]: - raise ValueError( - "Attention mask shape should be (batch_size, seq_length + past_key_values_length)" - f" but is {attention_mask.shape} with input_ids shape {input_shape} and past length" - f" {past_key_values_length}." - ) - combined_attention_mask = None - device = attention_mask.device - _, seq_length = input_shape - - if seq_length > 1: - combined_attention_mask = _make_causal_mask( - input_shape, device=device, past_key_values_length=past_key_values_length - ) - - # [batch_size, seq_length + past_key_values_length] -> [batch_size, 1, seq_length, seq_length + past_key_values_length] - expanded_attn_mask = _expand_mask(attention_mask, past_key_values_length=past_key_values_length) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask - ) - - return combined_attention_mask - def set_input_embeddings(self, new_embeddings: torch.Tensor): self.word_embeddings = new_embeddings @@ -1114,19 +1239,16 @@ def forward( past_key_values_length = 0 if past_key_values[0] is not None: past_key_values_length = past_key_values[0][0].shape[1] # 1 because RW-cache, not standard format - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=hidden_states.device) - padding_mask = None - else: - attention_mask = attention_mask.to(hidden_states.device) - - if 0 in attention_mask: - padding_mask = attention_mask - else: - padding_mask = None if self.use_alibi: - alibi = build_alibi_tensor(attention_mask, self.num_heads, dtype=hidden_states.dtype) + mask = ( + torch.ones( + (batch_size, seq_length + past_key_values_length), device=inputs_embeds.device, dtype=torch.long + ) + if attention_mask is None + else attention_mask + ) + alibi = build_alibi_tensor(mask, self.num_heads, dtype=hidden_states.dtype) else: alibi = None if position_ids is None: @@ -1136,11 +1258,20 @@ def forward( ) position_ids = position_ids.unsqueeze(0) - causal_mask = self._prepare_attn_mask( - attention_mask, - input_shape=(batch_size, seq_length), - 
past_key_values_length=past_key_values_length, - ) + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + key_value_length = seq_length + past_key_values_length + # 4d mask is passed through the layers + if attention_mask is not None: + attention_mask = self.attn_mask_converter.to_4d( + attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype + ) + else: + attention_mask = self.attn_mask_converter.to_causal_4d( + batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): if output_hidden_states: @@ -1159,22 +1290,20 @@ def custom_forward(*inputs): create_custom_forward(block), hidden_states, alibi, - causal_mask, + attention_mask, position_ids, head_mask[i], - padding_mask, ) else: outputs = block( hidden_states, layer_past=layer_past, - attention_mask=causal_mask, + attention_mask=attention_mask, position_ids=position_ids, head_mask=head_mask[i], use_cache=use_cache, output_attentions=output_attentions, alibi=alibi, - padding_mask=padding_mask, ) hidden_states = outputs[0] diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py index b67719ac327162..541455d86afde8 100644 --- a/src/transformers/models/llama/modeling_llama.py +++ b/src/transformers/models/llama/modeling_llama.py @@ -19,6 +19,7 @@ # limitations under the License. """ PyTorch LLaMA model.""" import math +import warnings from typing import List, Optional, Tuple, Union import torch @@ -51,9 +52,9 @@ _CONFIG_FOR_DOC = "LlamaConfig" -def _get_unpad_data(padding_mask): - seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) return ( @@ -63,37 +64,140 @@ def _get_unpad_data(padding_mask): ) -# Copied from transformers.models.bart.modeling_bart._make_causal_mask -def _make_causal_mask( - input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0 -): +class AttnMaskConverter: """ - Make causal mask used for bi-directional self-attention. + A utility attention mask class that allows: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. 
""" - bsz, tgt_len = input_ids_shape - mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) - mask_cond = torch.arange(mask.size(-1), device=device) - mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - mask = mask.to(dtype) - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype = torch.float32, + device: Union[torch.device, "str"] = "cpu", + ) -> torch.Tensor: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). + """ + if not self.is_causal: + raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") -# Copied from transformers.models.bart.modeling_bart._expand_mask -def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): - """ - Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. - """ - bsz, src_len = mask.size() - tgt_len = tgt_len if tgt_len is not None else src_len + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + key_value_length: int, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. 
+ """ + input_shape = (attention_mask_2d.shape[0], query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=attention_mask_2d.device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + elif self.sliding_window is not None: + raise NotImplementedError("Sliding window is currently only implemented for causal masking") + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( + attention_mask_2d.device + ) + expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask - expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + return expanded_4d_mask - inverted_mask = 1.0 - expanded_mask + def _make_causal_mask( + self, + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) - return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window + 1 + + context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal) + mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
+ """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) class LlamaRMSNorm(nn.Module): @@ -272,6 +376,7 @@ def __init__(self, config: LlamaConfig): self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta + self.is_causal = True if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( @@ -322,8 +427,13 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() if self.config.pretraining_tp > 1: @@ -420,14 +530,22 @@ class LlamaFlashAttention2(LlamaAttention): def forward( self, hidden_states: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: # LlamaFlashAttention2 attention does not support output_attentions + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + output_attentions = False bsz, q_len, _ = hidden_states.size() @@ -492,7 +610,7 @@ def forward( value_states = value_states.to(target_dtype) attn_output = self._flash_attention_forward( - query_states, key_states, value_states, padding_mask, q_len, dropout=dropout_rate + query_states, key_states, value_states, attention_mask, q_len, dropout=dropout_rate ) attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() @@ -504,7 +622,7 @@ def forward( return attn_output, attn_weights, past_key_value def _flash_attention_forward( - self, query_states, key_states, value_states, padding_mask, query_length, dropout=0.0, softmax_scale=None + self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None ): """ Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token @@ -517,7 +635,7 @@ def _flash_attention_forward( Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API - padding_mask (`torch.Tensor`): + attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`int`, *optional*): @@ -526,10 +644,10 @@ def _flash_attention_forward( The scaling of QK^T before applying softmax. 
Default to 1 / sqrt(head_dim) """ # Contains at least one padding token in the sequence - if padding_mask is not None: + if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, padding_mask, query_length + query_states, key_states, value_states, attention_mask, query_length ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens @@ -545,7 +663,7 @@ def _flash_attention_forward( max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=True, + causal=self.is_causal, ) attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) @@ -556,8 +674,8 @@ def _flash_attention_forward( return attn_output - def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape key_layer = index_first_axis( @@ -582,8 +700,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_l query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. - padding_mask = padding_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, @@ -616,13 +734,13 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - padding_mask: Optional[torch.LongTensor] = None, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(batch, sequence_length)` where padding elements are indicated by 0. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -631,6 +749,10 @@ def forward( (see `past_key_values`). past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states """ + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) residual = hidden_states @@ -644,7 +766,7 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - padding_mask=padding_mask, + **kwargs, ) hidden_states = residual + hidden_states @@ -791,6 +913,10 @@ def __init__(self, config: LlamaConfig): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + # create attention mask cache that trickles down to each attention layer + # so that the attention_mask cache can be shared among layers + self.attn_mask_converter = AttnMaskConverter(is_causal=True) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -805,30 +931,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask - def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return combined_attention_mask - @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) def forward( self, @@ -854,18 +956,15 @@ def forward( if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") elif input_ids is not None: - batch_size, seq_length = input_ids.shape + batch_size, seq_length = input_ids.shape[:2] elif inputs_embeds is not None: - batch_size, seq_length, _ = inputs_embeds.shape + batch_size, seq_length = inputs_embeds.shape[:2] else: raise ValueError("You have to specify either input_ids or inputs_embeds") - seq_length_with_past = seq_length past_key_values_length = 0 - if past_key_values is not None: past_key_values_length = past_key_values[0][0].shape[2] - seq_length_with_past = seq_length_with_past + past_key_values_length if position_ids is None: device = input_ids.device if input_ids is not None else inputs_embeds.device @@ -876,22 +975,23 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - padding_mask = None + + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None else: - if 0 in attention_mask: - padding_mask = attention_mask + key_value_length = seq_length + past_key_values_length + # 4d mask is passed through the layers + if attention_mask is not None: + attention_mask = 
self.attn_mask_converter.to_4d( + attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype + ) else: - padding_mask = None - - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length - ) + attention_mask = self.attn_mask_converter.to_causal_4d( + batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) + # embed positions hidden_states = inputs_embeds if self.gradient_checkpointing and self.training: @@ -917,7 +1017,7 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask) + return module(*inputs, past_key_value, output_attentions) return custom_forward @@ -932,7 +1032,6 @@ def custom_forward(*inputs): past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - padding_mask=padding_mask, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py index d650d60b8a553e..f735c471268136 100644 --- a/src/transformers/models/mistral/modeling_mistral.py +++ b/src/transformers/models/mistral/modeling_mistral.py @@ -20,6 +20,7 @@ """ PyTorch Mistral model.""" import inspect import math +import warnings from typing import List, Optional, Tuple, Union import torch @@ -53,10 +54,147 @@ _CONFIG_FOR_DOC = "MistralConfig" +# Copied from transformers.models.llama.modeling_llama.AttnMaskConverter +class AttnMaskConverter: + """ + A utility attention mask class that allows: + - Create a causal 4d mask + - Create a causal 4d mask with slided window + - Convert a 2d attention mask (batch_size, query_length) to a 4d attention mask (batch_size, 1, query_length, + key_value_length) that can be multiplied with attention scores + + Parameters: + is_causal (`bool`): + Whether the attention mask should be a uni-directional (causal) or bi-directional mask. + + sliding_window (`int`, *optional*): + Optionally, the sliding window masks can be created if `sliding_window` is defined to a positive integer. + """ + + def __init__(self, is_causal: bool, sliding_window: Optional[int] = None): + self.is_causal = is_causal + self.sliding_window = sliding_window + + def to_causal_4d( + self, + batch_size: int, + query_length: int, + key_value_length: int, + dtype: torch.dtype = torch.float32, + device: Union[torch.device, "str"] = "cpu", + ) -> torch.Tensor: + """ + Creates a causal 4D mask of (bsz, head_dim=1, query_length, key_value_length) shape and adds large negative + bias to upper right hand triangular matrix (causal mask). 
+ """ + if not self.is_causal: + raise ValueError(f"Please use `to_causal_4d` only if {self.__class__} has `is_causal` set to True.") + + # If shape is not cached, create a new causal mask and cache it + input_shape = (batch_size, query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if input_shape[-1] > 1 or self.sliding_window is not None: + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + + return causal_4d_mask + + def to_4d( + self, + attention_mask_2d: torch.Tensor, + query_length: int, + key_value_length: int, + dtype: torch.dtype = torch.float32, + ) -> torch.Tensor: + """ + Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length, + key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is + causal, a causal mask will be added. + """ + input_shape = (attention_mask_2d.shape[0], query_length) + past_key_values_length = key_value_length - query_length + + # create causal mask + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + causal_4d_mask = None + if (input_shape[-1] > 1 or self.sliding_window is not None) and self.is_causal: + past_key_values_length = key_value_length - query_length + causal_4d_mask = self._make_causal_mask( + input_shape, + dtype, + device=attention_mask_2d.device, + past_key_values_length=past_key_values_length, + sliding_window=self.sliding_window, + ) + elif self.sliding_window is not None: + raise NotImplementedError("Sliding window is currently only implemented for causal masking") + + # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] + expanded_attn_mask = self._expand_mask(attention_mask_2d, dtype, tgt_len=input_shape[-1]).to( + attention_mask_2d.device + ) + expanded_4d_mask = expanded_attn_mask if causal_4d_mask is None else expanded_attn_mask + causal_4d_mask + + return expanded_4d_mask + + def _make_causal_mask( + self, + input_ids_shape: torch.Size, + dtype: torch.dtype, + device: torch.device, + past_key_values_length: int = 0, + sliding_window: Optional[int] = None, + ): + """ + Make causal mask used for bi-directional self-attention. + """ + bsz, tgt_len = input_ids_shape + mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device) + mask_cond = torch.arange(mask.size(-1), device=device) + mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0) + + mask = mask.to(dtype) + + if past_key_values_length > 0: + mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) + + # add lower triangular sliding window mask if necessary + if sliding_window is not None: + diagonal = past_key_values_length - sliding_window + 1 + + context_mask = 1 - torch.triu(torch.ones_like(mask, dtype=torch.int), diagonal=diagonal) + mask.masked_fill_(context_mask.bool(), torch.finfo(dtype).min) + + return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) + + def _expand_mask(self, mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): + """ + Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. 
+ """ + bsz, src_len = mask.size() + tgt_len = tgt_len if tgt_len is not None else src_len + + expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) + + inverted_mask = 1.0 - expanded_mask + + return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min) + + # Copied from transformers.models.llama.modeling_llama._get_unpad_data -def _get_unpad_data(padding_mask): - seqlens_in_batch = padding_mask.sum(dim=-1, dtype=torch.int32) - indices = torch.nonzero(padding_mask.flatten(), as_tuple=False).flatten() +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() max_seqlen_in_batch = seqlens_in_batch.max().item() cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.torch.int32), (1, 0)) return ( @@ -66,33 +204,6 @@ def _get_unpad_data(padding_mask): ) -def _make_sliding_window_causal_mask( - input_ids_shape: torch.Size, - dtype: torch.dtype, - device: torch.device, - past_key_values_length: int = 0, - sliding_window: int = 4096, -): - """ - Make causal mask used for sliding window attention - """ - bsz, tgt_len = input_ids_shape - - tensor = torch.full( - (tgt_len, tgt_len), - fill_value=1, - device=device, - ) - mask = torch.tril(tensor, diagonal=0) - # make the mask banded to account for sliding window - mask = torch.triu(mask, diagonal=-sliding_window) - mask = torch.log(mask).to(dtype) - - if past_key_values_length > 0: - mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1) - return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length) - - # Copied from transformers.models.bart.modeling_bart._expand_mask def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): """ @@ -223,6 +334,7 @@ def __init__(self, config: MistralConfig): self.num_key_value_groups = self.num_heads // self.num_key_value_heads self.max_position_embeddings = config.max_position_embeddings self.rope_theta = config.rope_theta + self.is_causal = True if (self.head_dim * self.num_heads) != self.hidden_size: raise ValueError( @@ -251,8 +363,12 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, - padding_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -332,8 +448,15 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: bool = False, use_cache: bool = False, - padding_mask: Optional[torch.LongTensor] = None, + **kwargs, ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. 
Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") bsz, q_len, _ = hidden_states.size() query_states = self.q_proj(hidden_states) @@ -385,9 +508,9 @@ def forward( past_key_value = (past_key, past_value) - if padding_mask is not None: - padding_mask = padding_mask[:, slicing_tokens:] - padding_mask = torch.cat([padding_mask, torch.ones_like(padding_mask[:, -1:])], dim=-1) + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) key_states = torch.cat([past_key_value[0], key_states], dim=2) value_states = torch.cat([past_key_value[1], value_states], dim=2) @@ -433,7 +556,7 @@ def forward( query_states, key_states, value_states, - padding_mask, + attention_mask, q_len, dropout=dropout_rate, use_sliding_windows=use_sliding_windows, @@ -452,7 +575,7 @@ def _flash_attention_forward( query_states, key_states, value_states, - padding_mask, + attention_mask, query_length, dropout=0.0, softmax_scale=None, @@ -469,7 +592,7 @@ def _flash_attention_forward( Input key states to be passed to Flash Attention API value_states (`torch.Tensor`): Input value states to be passed to Flash Attention API - padding_mask (`torch.Tensor`): + attention_mask (`torch.Tensor`): The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the position of padding tokens and 1 for the position of non-padding tokens. dropout (`int`, *optional*): @@ -480,10 +603,10 @@ def _flash_attention_forward( Whether to activate sliding window attention. """ # Contains at least one padding token in the sequence - if padding_mask is not None: + if attention_mask is not None: batch_size = query_states.shape[0] query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( - query_states, key_states, value_states, padding_mask, query_length + query_states, key_states, value_states, attention_mask, query_length ) cu_seqlens_q, cu_seqlens_k = cu_seq_lens @@ -500,7 +623,7 @@ def _flash_attention_forward( max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=True, + causal=self.is_causal, ) else: attn_output_unpad = flash_attn_varlen_func( @@ -513,7 +636,7 @@ def _flash_attention_forward( max_seqlen_k=max_seqlen_in_batch_k, dropout_p=dropout, softmax_scale=softmax_scale, - causal=True, + causal=self.is_causal, window_size=(self.config.sliding_window, self.config.sliding_window), ) @@ -536,16 +659,16 @@ def _flash_attention_forward( return attn_output - def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_length): + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape # On the first iteration we need to properly re-create the padding mask # by slicing it on the proper place - if kv_seq_len != padding_mask.shape[-1]: - padding_mask_num_tokens = padding_mask.shape[-1] - padding_mask = padding_mask[:, padding_mask_num_tokens - kv_seq_len :] + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] - indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(padding_mask) + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) key_layer = 
index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) @@ -566,8 +689,8 @@ def _upad_input(self, query_layer, key_layer, value_layer, padding_mask, query_l query_layer = query_layer.squeeze(1) else: # The -q_len: slice assumes left padding. - padding_mask = padding_mask[:, -query_length:] - query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, padding_mask) + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) return ( query_layer, @@ -600,13 +723,17 @@ def forward( past_key_value: Optional[Tuple[torch.Tensor]] = None, output_attentions: Optional[bool] = False, use_cache: Optional[bool] = False, - padding_mask: Optional[torch.Tensor] = None, + **kwargs, ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) """ Args: hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` attention_mask (`torch.FloatTensor`, *optional*): attention mask of size - `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. + `(batch, sequence_length)` where padding elements are indicated by 0. output_attentions (`bool`, *optional*): Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned tensors for more detail. @@ -628,7 +755,6 @@ def forward( past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - padding_mask=padding_mask, ) hidden_states = residual + hidden_states @@ -775,6 +901,10 @@ def __init__(self, config: MistralConfig): self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size + # create attention mask cache that trickles down to each attention layer + # so that the attention_mask cache can be shared among layers + self.attn_mask_converter = AttnMaskConverter(is_causal=True, sliding_window=config.sliding_window) + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) self.layers = nn.ModuleList([MistralDecoderLayer(config) for _ in range(config.num_hidden_layers)]) self.norm = MistralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -789,32 +919,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.embed_tokens = value - def _prepare_decoder_attention_mask( - self, attention_mask, input_shape, inputs_embeds, past_key_values_length, sliding_window - ): - # create causal mask - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - combined_attention_mask = None - if input_shape[-1] > 1: - combined_attention_mask = _make_sliding_window_causal_mask( - input_shape, - inputs_embeds.dtype, - device=inputs_embeds.device, - past_key_values_length=past_key_values_length, - sliding_window=sliding_window, - ) - - if attention_mask is not None: - # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] - expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to( - inputs_embeds.device - ) - combined_attention_mask = ( - expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask - ) - - return 
combined_attention_mask - @add_start_docstrings_to_model_forward(MISTRAL_INPUTS_DOCSTRING) def forward( self, @@ -865,23 +969,13 @@ def forward( if inputs_embeds is None: inputs_embeds = self.embed_tokens(input_ids) - padding_mask = None - - # embed positions - if attention_mask is None: - attention_mask = torch.ones( - (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device - ) - elif 0 in attention_mask: - padding_mask = attention_mask - if ( - padding_mask is not None + attention_mask is not None and hasattr(self.config, "_flash_attn_2_enabled") and self.config._flash_attn_2_enabled and past_key_values is not None ): - is_padding_right = padding_mask[:, -1].sum().item() != batch_size + is_padding_right = attention_mask[:, -1].sum().item() != batch_size if is_padding_right: raise ValueError( "You are attempting to perform batched generation with padding_side='right'" @@ -889,13 +983,20 @@ def forward( " call `tokenizer.padding_side = 'left'` before tokenizing the input. " ) - attention_mask = self._prepare_decoder_attention_mask( - attention_mask, - (batch_size, seq_length), - inputs_embeds, - past_key_values_length, - sliding_window=self.config.sliding_window, - ) + if getattr(self.config, "_flash_attn_2_enabled", False): + # 2d mask is passed through the layers + attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None + else: + key_value_length = seq_length + past_key_values_length + # 4d mask is passed through the layers + if attention_mask is not None: + attention_mask = self.attn_mask_converter.to_4d( + attention_mask, seq_length, key_value_length, dtype=inputs_embeds.dtype + ) + else: + attention_mask = self.attn_mask_converter.to_causal_4d( + batch_size, seq_length, key_value_length, dtype=inputs_embeds.dtype, device=inputs_embeds.device + ) hidden_states = inputs_embeds @@ -922,7 +1023,7 @@ def forward( def create_custom_forward(module): def custom_forward(*inputs): # None for past_key_value - return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask) + return module(*inputs, past_key_value, output_attentions) return custom_forward @@ -940,7 +1041,6 @@ def custom_forward(*inputs): past_key_value=past_key_value, output_attentions=output_attentions, use_cache=use_cache, - padding_mask=padding_mask, ) hidden_states = layer_outputs[0] diff --git a/src/transformers/models/persimmon/modeling_persimmon.py b/src/transformers/models/persimmon/modeling_persimmon.py index a0bc5726382336..d73cc4484484fa 100644 --- a/src/transformers/models/persimmon/modeling_persimmon.py +++ b/src/transformers/models/persimmon/modeling_persimmon.py @@ -548,7 +548,6 @@ class PersimmonModel(PersimmonPreTrainedModel): config: PersimmonConfig """ - # Copied from transformers.models.llama.modeling_llama.LlamaModel.__init__ with LLAMA->PERSIMMON,Llama->Persimmon,PersimmonRMSNorm->nn.LayerNorm,norm->final_layernorm,rms_final_layernorm_eps->layer_norm_eps def __init__(self, config: PersimmonConfig): super().__init__(config) self.padding_idx = config.pad_token_id diff --git a/tests/models/llama/test_modeling_llama.py b/tests/models/llama/test_modeling_llama.py index 2402986900fda6..df41c6c5f52044 100644 --- a/tests/models/llama/test_modeling_llama.py +++ b/tests/models/llama/test_modeling_llama.py @@ -39,6 +39,149 @@ LlamaModel, LlamaTokenizer, ) + from transformers.models.llama.modeling_llama import AttnMaskConverter + + +@require_torch +class AttentionMaskTester(unittest.TestCase): + def check_non_causal(self, bsz, q_len, 
kv_len, mask_2d, mask_4d): + mask_indices = (mask_2d != 1)[:, None].broadcast_to((bsz, q_len, kv_len)) + mask_4d_values = mask_4d[:, 0][mask_indices] + is_inf = mask_4d_values == -float("inf") + is_min = mask_4d_values == torch.finfo(mask_4d.dtype).min + assert torch.logical_or(is_inf, is_min).all() + + def check_to_4d(self, mask_converter, q_len, kv_len, additional_mask=None, bsz=3): + mask_2d = torch.ones((bsz, kv_len), device=torch_device, dtype=torch.long) + + if additional_mask is not None: + for bsz_idx, seq_idx in additional_mask: + mask_2d[bsz_idx, seq_idx] = 0 + + mask_4d = mask_converter.to_4d(mask_2d, query_length=q_len, key_value_length=kv_len) + + assert mask_4d.shape == (bsz, 1, q_len, kv_len) + + context = mask_converter.sliding_window + if mask_converter.is_causal and context is None: + # k * (k+1) / 2 tokens are masked in triangualar masks + num_tokens_masked = bsz * (q_len * (q_len - 1) // 2) + + if 0 not in mask_2d: + assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked + if 0 in mask_2d: + # at least causal mask + maybe more + assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked + self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d) + elif not mask_converter.is_causal and context is None: + if 0 not in mask_2d: + assert (mask_4d != 0).sum().cpu().item() == 0 + if 0 in mask_2d: + self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d) + elif mask_converter.is_causal and context is not None: + # k * (k+1) / 2 tokens are masked in triangualar masks + num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len) + num_tokens_masked = bsz * num_tokens_masked + + if 0 not in mask_2d: + assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked + if 0 in mask_2d: + # at least causal mask + maybe more + assert (mask_4d != 0).sum().cpu().item() >= num_tokens_masked + self.check_non_causal(bsz, q_len, kv_len, mask_2d, mask_4d) + + def check_to_causal(self, mask_converter, q_len, kv_len, bsz=3): + mask_4d = mask_converter.to_causal_4d(bsz, query_length=q_len, key_value_length=kv_len, device=torch_device) + + if q_len == 1 and mask_converter.sliding_window is None: + # no causal mask if q_len is 1 + assert mask_4d is None + return + + context = mask_converter.sliding_window + if mask_converter.is_causal and context is None: + # k * (k+1) / 2 tokens are masked in triangualar masks + num_tokens_masked = bsz * (q_len * (q_len - 1) // 2) + + assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked + elif not mask_converter.is_causal and context is None: + assert (mask_4d != 0).sum().cpu().item() == 0 + elif mask_converter.is_causal and context is not None: + # k * (k+1) / 2 tokens are masked in triangualar masks + num_tokens_masked = (q_len * (q_len - 1) // 2) + self.compute_num_context_mask(kv_len, context, q_len) + num_tokens_masked = bsz * num_tokens_masked + + assert (mask_4d != 0).sum().cpu().item() == num_tokens_masked + + def compute_num_context_mask(self, kv_len, context, q_len): + # This function computes the # of attention tokens that are added for + # the sliding window + c_mask_len = kv_len - context + num_mask_triangle = c_mask_len * (c_mask_len + 1) // 2 + cut_mask_len = max(c_mask_len - q_len, 0) + num_cut_mask = cut_mask_len * (cut_mask_len + 1) // 2 + return num_mask_triangle - num_cut_mask + + def test_2d_to_4d_causal(self): + mask_converter = AttnMaskConverter(is_causal=True) + + # auto-regressive use case + self.check_to_4d(mask_converter, q_len=1, kv_len=7) + # special auto-regressive 
case + self.check_to_4d(mask_converter, q_len=3, kv_len=7) + # non auto-regressive case + self.check_to_4d(mask_converter, q_len=7, kv_len=7) + + # same with extra attention masks + self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + + def test_2d_to_4d(self): + torch.ones((3, 7), device=torch_device, dtype=torch.long) + mask_converter = AttnMaskConverter(is_causal=False) + + # non auto-regressive case + self.check_to_4d(mask_converter, q_len=7, kv_len=7) + + # same with extra attention masks + self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + + def test_2d_to_4d_causal_sliding(self): + torch.ones((3, 7), device=torch_device, dtype=torch.long) + mask_converter = AttnMaskConverter(is_causal=True, sliding_window=5) + + # auto-regressive use case + self.check_to_4d(mask_converter, q_len=1, kv_len=7) + # special auto-regressive case + self.check_to_4d(mask_converter, q_len=3, kv_len=7) + # non auto-regressive case + self.check_to_4d(mask_converter, q_len=7, kv_len=7) + + # same with extra attention masks + self.check_to_4d(mask_converter, q_len=1, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + self.check_to_4d(mask_converter, q_len=3, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + self.check_to_4d(mask_converter, q_len=7, kv_len=7, additional_mask=[(0, 2), (1, 3), (2, 0)]) + + def test_causal_mask(self): + mask_converter = AttnMaskConverter(is_causal=True) + + # auto-regressive use case + self.check_to_causal(mask_converter, q_len=1, kv_len=7) + # special auto-regressive case + self.check_to_causal(mask_converter, q_len=3, kv_len=7) + # non auto-regressive case + self.check_to_causal(mask_converter, q_len=7, kv_len=7) + + def test_causal_mask_sliding(self): + mask_converter = AttnMaskConverter(is_causal=True, sliding_window=3) + + # auto-regressive use case + self.check_to_causal(mask_converter, q_len=1, kv_len=7) + # special auto-regressive case + self.check_to_causal(mask_converter, q_len=3, kv_len=7) + # non auto-regressive case + self.check_to_causal(mask_converter, q_len=7, kv_len=7) class LlamaModelTester:
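For reference, a minimal usage sketch of the `AttnMaskConverter` introduced in this diff, assuming it is imported from `transformers.models.llama.modeling_llama` as the new test file does. The shapes and tensor values below are illustrative only; the snippet shows how a 2D padding mask (1 = attend, 0 = padded) is expanded into the 4D additive bias that the non-flash attention paths now receive, and how a purely causal bias is built when no 2D mask is supplied.

import torch
from transformers.models.llama.modeling_llama import AttnMaskConverter

converter = AttnMaskConverter(is_causal=True)

# 2D mask for a batch of 2: the second sequence is left-padded on two positions.
attention_mask_2d = torch.tensor(
    [[1, 1, 1, 1, 1, 1, 1],
     [0, 0, 1, 1, 1, 1, 1]]
)

# query_length=3 with key_value_length=7 corresponds to 4 cached key/value positions.
mask_4d = converter.to_4d(attention_mask_2d, query_length=3, key_value_length=7, dtype=torch.float32)
print(mask_4d.shape)  # torch.Size([2, 1, 3, 7]); non-attended positions hold a large negative bias

# Without a 2D mask, the model falls back to a purely causal 4D bias.
causal_4d = converter.to_causal_4d(
    batch_size=2, query_length=3, key_value_length=7, dtype=torch.float32, device="cpu"
)
print(causal_4d.shape)  # torch.Size([2, 1, 3, 7])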