From b94376d4f7360872c760568d7b96a855b43eb901 Mon Sep 17 00:00:00 2001 From: eljandoubi Date: Tue, 22 Oct 2024 12:42:32 +0200 Subject: [PATCH] fix conflict --- .../models/pix2struct/modeling_pix2struct.py | 410 ++++++++++++------ 1 file changed, 276 insertions(+), 134 deletions(-) diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index 9adeea48b78850..182e866e9e0f7b 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -22,7 +22,9 @@ from torch import nn from ...activations import ACT2FN +from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache, StaticCache from ...generation import GenerationMixin +from ...modeling_attn_mask_utils import AttentionMaskConverter from ...modeling_outputs import ( BaseModelOutput, BaseModelOutputWithPooling, @@ -38,6 +40,7 @@ add_start_docstrings, add_start_docstrings_to_model_forward, is_torch_fx_proxy, + is_torchdynamo_compiling, logging, replace_return_docstrings, ) @@ -184,14 +187,17 @@ def to_projection_shape(states): if self.gradient_checkpointing and self.training: position_bias.requires_grad = True - if attention_mask is None: - attention_mask = torch.ones((batch_size, seq_length), device=scores.device, dtype=scores.dtype) - if attention_mask.dim() == 2: position_bias = position_bias + attention_mask[:, None, None, :].to(position_bias.device) - else: + elif attention_mask is not None: # (batch_size, n_heads, seq_length, key_length) position_bias = position_bias + attention_mask.to(position_bias.device) + elif not is_torchdynamo_compiling(): + attention_mask = torch.ones( + (batch_size, seq_length), device=position_bias.device, dtype=position_bias.dtype + ) + position_bias = position_bias + attention_mask.to(position_bias.device) + position_bias = 1 - position_bias position_bias_masked = position_bias.masked_fill(position_bias == 1, torch.finfo(scores.dtype).min) @@ -355,6 +361,8 @@ class Pix2StructPreTrainedModel(PreTrainedModel): """ config_class = Pix2StructConfig + _supports_cache_class = True + _supports_static_cache = False @property def dummy_inputs(self): @@ -673,7 +681,9 @@ def forward(self, hidden_states): class Pix2StructTextAttention(nn.Module): - def __init__(self, config: Pix2StructTextConfig, has_relative_attention_bias=False): + def __init__( + self, config: Pix2StructTextConfig, has_relative_attention_bias=False, layer_idx: Optional[int] = None + ): super().__init__() self.has_relative_attention_bias = has_relative_attention_bias self.relative_attention_num_buckets = config.relative_attention_num_buckets @@ -683,6 +693,13 @@ def __init__(self, config: Pix2StructTextConfig, has_relative_attention_bias=Fal self.n_heads = config.num_heads self.dropout = config.dropout_rate self.inner_dim = self.n_heads * self.key_value_proj_dim + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and " + "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class."
+ ) # Mesh TensorFlow initialization to avoid scaling before softmax self.query = nn.Linear(self.hidden_size, self.hidden_size, bias=False) @@ -773,75 +790,56 @@ def forward( query_length=None, use_cache=False, output_attentions=False, + cache_position=None, ): """ Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). """ # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length) - # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head) + # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, query_length, key_length) batch_size, seq_length = hidden_states.shape[:2] - real_seq_length = seq_length - - if past_key_value is not None: - if len(past_key_value) != 2: - raise ValueError( - f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states" - ) - real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length - - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] - - def to_projection_shape(states): - """projection""" - return states.contiguous().view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + # if key_value_states are provided this layer is used as a cross-attention layer for the decoder + is_cross_attention = key_value_states is not None - def project(hidden_states, proj_layer, key_value_states, past_key_value): - """projects hidden states correctly to key/query states""" - if key_value_states is None: - # self-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = to_projection_shape(proj_layer(hidden_states)) - elif past_key_value is None: - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = to_projection_shape(proj_layer(key_value_states)) - - if past_key_value is not None: - if key_value_states is None: - # self-attn - # (batch_size, n_heads, key_length, dim_per_head) - hidden_states = torch.cat([past_key_value, hidden_states], dim=2) - elif past_key_value.shape[2] != key_value_states.shape[1]: - # checking that the `sequence_length` of the `past_key_value` is the same as - # the provided `key_value_states` to support prefix tuning - # cross-attn - # (batch_size, n_heads, seq_length, dim_per_head) - hidden_states = to_projection_shape(proj_layer(key_value_states)) - else: - # cross-attn - hidden_states = past_key_value - return hidden_states + query_states = self.query(hidden_states).contiguous() + query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) - # get query states - # (batch_size, n_heads, seq_length, dim_per_head) - query_states = to_projection_shape(self.query(hidden_states)) + if past_key_value is not None: + is_updated = past_key_value.is_updated.get(self.layer_idx) + if is_cross_attention: + # after the first generated id, we can subsequently re-use all key/value_states from cache + past_key_value = past_key_value.cross_attention_cache + else: + past_key_value = past_key_value.self_attention_cache # get key/value states - key_states = project( - hidden_states, self.key, key_value_states, past_key_value[0] if past_key_value is not None else None - ) - value_states = project( - hidden_states, self.value, key_value_states, past_key_value[1] if past_key_value is not None else None - ) + current_states = key_value_states if is_cross_attention else hidden_states + if is_cross_attention 
and past_key_value and is_updated: + # reuse k,v, cross_attentions + key_states = past_key_value.key_cache[self.layer_idx] + value_states = past_key_value.value_cache[self.layer_idx] + else: + key_states = self.key(current_states).contiguous() + value_states = self.value(current_states).contiguous() + key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + if past_key_value is not None: + # save all key/value_states to cache to be re-used for fast auto-regressive generation + cache_position = cache_position if not is_cross_attention else None + key_states, value_states = past_key_value.update( + key_states, value_states, self.layer_idx, {"cache_position": cache_position} + ) + # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls + if is_cross_attention: + past_key_value.is_updated[self.layer_idx] = True # compute scores - scores = torch.matmul( - query_states, key_states.transpose(3, 2) - ) # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 + scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: + real_seq_length = cache_position[-1] + 1 if query_length is None else query_length + key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] if not self.has_relative_attention_bias: position_bias = torch.zeros( (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype @@ -851,11 +849,6 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): else: position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) - # if key and values are already calculated - # we want only the last query position bias - if past_key_value is not None: - position_bias = position_bias[:, :, -hidden_states.size(1) :, :] - if mask is not None: position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) @@ -883,19 +876,20 @@ def project(hidden_states, proj_layer, key_value_states, past_key_value): attn_output = self.output(attn_output) - present_key_value_state = (key_states, value_states) if use_cache else None - outputs = (attn_output,) + (present_key_value_state,) + (position_bias,) + outputs = (attn_output,) + (past_key_value,) + (position_bias,) if output_attentions: outputs = outputs + (attn_weights,) return outputs -# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,self.SelfAttention->self.attention,config.d_model->config.hidden_size +# Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerSelfAttention->Pix2StructTextLayerSelfAttention,self.SelfAttention->self.attention,config.d_model->config.hidden_size class Pix2StructTextLayerSelfAttention(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() - self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=has_relative_attention_bias) + self.attention = Pix2StructTextAttention( + config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx + ) self.layer_norm = 
Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -908,6 +902,7 @@ def forward( past_key_value=None, use_cache=False, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.attention( @@ -918,17 +913,18 @@ def forward( past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) hidden_states = hidden_states + self.dropout(attention_output[0]) outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them return outputs -# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,self.EncDecAttention->self.attention,config.d_model->config.hidden_size +# Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5LayerNorm->Pix2StructLayerNorm,T5Attention->Pix2StructTextAttention,T5LayerCrossAttention->Pix2StructTextLayerCrossAttention,self.EncDecAttention->self.attention,config.d_model->config.hidden_size class Pix2StructTextLayerCrossAttention(nn.Module): - def __init__(self, config): + def __init__(self, config, layer_idx: Optional[int] = None): super().__init__() - self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False) + self.attention = Pix2StructTextAttention(config, has_relative_attention_bias=False, layer_idx=layer_idx) self.layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -943,6 +939,7 @@ def forward( use_cache=False, query_length=None, output_attentions=False, + cache_position=None, ): normed_hidden_states = self.layer_norm(hidden_states) attention_output = self.attention( @@ -955,6 +952,7 @@ def forward( use_cache=use_cache, query_length=query_length, output_attentions=output_attentions, + cache_position=cache_position, ) layer_output = hidden_states + self.dropout(attention_output[0]) outputs = (layer_output,) + attention_output[1:] # add attentions if we output them @@ -962,11 +960,13 @@ def forward( class Pix2StructTextBlock(nn.Module): - def __init__(self, config, has_relative_attention_bias=False): + def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None): super().__init__() self.self_attention = Pix2StructTextLayerSelfAttention( - config, has_relative_attention_bias=has_relative_attention_bias + config, + has_relative_attention_bias=has_relative_attention_bias, + layer_idx=layer_idx, ) self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(config) @@ -987,32 +987,19 @@ def forward( use_cache=False, output_attentions=False, return_dict=True, + cache_position=None, ): - if past_key_value is not None: - expected_num_past_key_values = 2 if encoder_hidden_states is None else 4 - - if len(past_key_value) != expected_num_past_key_values: - raise ValueError( - f"There should be {expected_num_past_key_values} past states. " - f"{'2 (past / key) for cross attention. 
' if expected_num_past_key_values == 4 else ''}" - f"Got {len(past_key_value)} past key / value states" - ) - - self_attn_past_key_value = past_key_value[:2] - cross_attn_past_key_value = past_key_value[2:] - else: - self_attn_past_key_value, cross_attn_past_key_value = None, None - self_attention_outputs = self.self_attention( hidden_states, attention_mask=attention_mask, position_bias=position_bias, layer_head_mask=layer_head_mask, - past_key_value=self_attn_past_key_value, + past_key_value=past_key_value, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states, present_key_value_state = self_attention_outputs[:2] + hidden_states, past_key_value = self_attention_outputs[:2] attention_outputs = self_attention_outputs[2:] # Keep self-attention outputs and relative position weights # clamp inf values to enable fp16 training @@ -1022,35 +1009,25 @@ def forward( do_cross_attention = encoder_hidden_states is not None if do_cross_attention: - # the actual query length is unknown for cross attention - # if using past key value states. Need to inject it here - if present_key_value_state is not None: - query_length = present_key_value_state[0].shape[2] - else: - query_length = None - cross_attention_outputs = self.encoder_decoder_attention( hidden_states, key_value_states=encoder_hidden_states, attention_mask=encoder_attention_mask, position_bias=encoder_decoder_position_bias, layer_head_mask=cross_attn_layer_head_mask, - past_key_value=cross_attn_past_key_value, - query_length=query_length, + past_key_value=past_key_value, + query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) - hidden_states = cross_attention_outputs[0] + hidden_states, past_key_value = cross_attention_outputs[:2] # clamp inf values to enable fp16 training if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any(): clamp_value = torch.finfo(hidden_states.dtype).max - 1000 hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value) - # Combine self attn and cross attn key value states - if present_key_value_state is not None: - present_key_value_state = present_key_value_state + cross_attention_outputs[1] - # Keep cross-attention outputs and relative position weights attention_outputs = attention_outputs + cross_attention_outputs[2:] @@ -1065,7 +1042,7 @@ def forward( outputs = (hidden_states,) if use_cache: - outputs = outputs + (present_key_value_state,) + attention_outputs + outputs = outputs + (past_key_value,) + attention_outputs else: outputs = outputs + attention_outputs @@ -1187,6 +1164,9 @@ def forward( more detail. return_dict (`bool`, *optional*): Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. + cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*): + Indices depicting the position of the input sequence tokens in the sequence. It is used to update the + cache in the correct position and to infer the complete sequence length. 
""" PIX2STRUCT_INPUTS_DOCSTRING = r""" @@ -1290,11 +1270,13 @@ class Pix2StructTextModel(Pix2StructPreTrainedModel): def __init__(self, config): super().__init__(config) - self.num_layers = config.num_layers self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) self.layer = nn.ModuleList( - [Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0)) for i in range(config.num_layers)] + [ + Pix2StructTextBlock(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) + for i in range(config.num_layers) + ] ) self.final_layer_norm = Pix2StructLayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.dropout = nn.Dropout(config.dropout_rate) @@ -1365,6 +1347,7 @@ def forward( output_hidden_states: Optional[bool] = None, labels: Optional[torch.LongTensor] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, **kwargs, ) -> Union[Tuple[torch.FloatTensor, ...], CausalLMOutputWithCrossAttentions]: r""" @@ -1406,24 +1389,54 @@ def forward( batch_size, seq_length = input_shape - # required mask seq length can be calculated via length of past - mask_seq_length = past_key_values[0][0].shape[2] + seq_length if past_key_values is not None else seq_length + # initialize past_key_values + return_legacy_cache = False + return_self_attention_cache = False + if use_cache or past_key_values is not None: + if isinstance(past_key_values, Cache) and not isinstance(past_key_values, EncoderDecoderCache): + return_self_attention_cache = True + past_key_values = EncoderDecoderCache(past_key_values, DynamicCache()) + elif not isinstance(past_key_values, EncoderDecoderCache): + return_legacy_cache = True + logger.warning_once( + "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. " + "You should pass an instance of `EncoderDecoderCache` instead, e.g. " + "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`." + ) + past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values) + elif past_key_values is None: + past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache()) + + past_key_values_length = 0 + if cache_position is not None: + past_key_values_length = cache_position[0] + elif past_key_values is not None: + past_key_values_length = past_key_values.get_seq_length() + + if cache_position is None: + cache_position = torch.arange( + past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device + ) if attention_mask is None: - attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - if encoder_attention_mask is None and encoder_hidden_states is not None: - encoder_seq_length = encoder_hidden_states.shape[1] - encoder_attention_mask = torch.ones( - batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long + # required mask seq length can be calculated via length of past + mask_seq_length = ( + past_key_values.get_seq_length() + seq_length if past_key_values is not None else seq_length ) + attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device) - # initialize past_key_values with `None` if past does not exist - if past_key_values is None: - past_key_values = [None] * self.num_layers - - # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length] - # ourselves in which case we just need to make it broadcastable to all heads. 
- extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape) + if self.config.is_decoder: + causal_mask = self._update_causal_mask( + attention_mask, + inputs_embeds, + cache_position, + past_key_values.self_attention_cache if past_key_values is not None else None, + output_attentions, + ) + else: + causal_mask = attention_mask[:, None, None, :] + causal_mask = causal_mask.to(dtype=inputs_embeds.dtype) + causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min # If a 2D or 3D attention mask is provided for the cross-attention # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length] @@ -1439,7 +1452,6 @@ def forward( # Prepare head mask if needed head_mask = self.get_head_mask(head_mask, self.config.num_layers) cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers) - present_key_value_states = () if use_cache else None all_hidden_states = () if output_hidden_states else None all_attentions = () if output_attentions else None all_cross_attentions = () if (output_attentions) else None @@ -1448,7 +1460,7 @@ def forward( hidden_states = self.dropout(inputs_embeds) - for i, (layer_module, past_key_value) in enumerate(zip(self.layer, past_key_values)): + for i, layer_module in enumerate(self.layer): layer_head_mask = head_mask[i] cross_attn_layer_head_mask = cross_attn_head_mask[i] if output_hidden_states: @@ -1463,7 +1475,7 @@ def forward( layer_outputs = self._gradient_checkpointing_func( layer_module.forward, hidden_states, - extended_attention_mask, + causal_mask, position_bias, encoder_hidden_states, encoder_extended_attention_mask, @@ -1473,20 +1485,22 @@ def forward( None, # past_key_value is always None with gradient checkpointing use_cache, output_attentions, + cache_position, ) else: layer_outputs = layer_module( hidden_states, - attention_mask=extended_attention_mask, + attention_mask=causal_mask, position_bias=position_bias, encoder_hidden_states=encoder_hidden_states, encoder_attention_mask=encoder_extended_attention_mask, encoder_decoder_position_bias=encoder_decoder_position_bias, layer_head_mask=layer_head_mask, cross_attn_layer_head_mask=cross_attn_layer_head_mask, - past_key_value=past_key_value, + past_key_value=past_key_values, use_cache=use_cache, output_attentions=output_attentions, + cache_position=cache_position, ) # layer_outputs is a tuple with: @@ -1494,7 +1508,7 @@ def forward( if use_cache is False: layer_outputs = layer_outputs[:1] + (None,) + layer_outputs[1:] - hidden_states, present_key_value_state = layer_outputs[:2] + hidden_states, next_decoder_cache = layer_outputs[:2] # We share the position biases between the layers - the first layer store them # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights), @@ -1502,9 +1516,6 @@ def forward( position_bias = layer_outputs[2] if encoder_hidden_states is not None: encoder_decoder_position_bias = layer_outputs[4 if output_attentions else 3] - # append next layer key value states - if use_cache: - present_key_value_states = present_key_value_states + (present_key_value_state,) if output_attentions: all_attentions = all_attentions + (layer_outputs[3],) @@ -1528,13 +1539,19 @@ def forward( loss = loss_fct(logits.contiguous().view(-1, logits.size(-1)), labels.contiguous().view(-1)) + next_cache = next_decoder_cache if use_cache else None + if return_self_attention_cache: + next_cache = past_key_values.self_attention_cache + if return_legacy_cache: + next_cache = 
past_key_values.to_legacy_cache() + if not return_dict: return tuple( v for v in [ loss, logits, - present_key_value_states, + next_cache, all_hidden_states, all_attentions, all_cross_attentions, @@ -1544,12 +1561,135 @@ def forward( return CausalLMOutputWithCrossAttentions( loss=loss, logits=logits, - past_key_values=present_key_value_states, + past_key_values=next_cache, hidden_states=all_hidden_states, attentions=all_attentions, cross_attentions=all_cross_attentions, ) + # Copied from transformers.models.llama.modeling_llama.LlamaModel._update_causal_mask + def _update_causal_mask( + self, + attention_mask: torch.Tensor, + input_tensor: torch.Tensor, + cache_position: torch.Tensor, + past_key_values: Cache, + output_attentions: bool, + ): + if self.config._attn_implementation == "flash_attention_2": + if attention_mask is not None and 0.0 in attention_mask: + return attention_mask + return None + + # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in + # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail + # to infer the attention mask. + past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0 + using_static_cache = isinstance(past_key_values, StaticCache) + + # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward + if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: + if AttentionMaskConverter._ignore_causal_mask_sdpa( + attention_mask, + inputs_embeds=input_tensor, + past_key_values_length=past_seen_tokens, + is_training=self.training, + ): + return None + + dtype, device = input_tensor.dtype, input_tensor.device + sequence_length = input_tensor.shape[1] + if using_static_cache: + target_length = past_key_values.get_max_cache_shape() + else: + target_length = ( + attention_mask.shape[-1] + if isinstance(attention_mask, torch.Tensor) + else past_seen_tokens + sequence_length + 1 + ) + + # In case the provided `attention` mask is 2D, we generate a causal mask here (4D). + causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position( + attention_mask, + sequence_length=sequence_length, + target_length=target_length, + dtype=dtype, + device=device, + cache_position=cache_position, + batch_size=input_tensor.shape[0], + ) + + if ( + self.config._attn_implementation == "sdpa" + and attention_mask is not None + and attention_mask.device.type == "cuda" + and not output_attentions + ): + # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when + # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path. 
+ # Details: https://github.com/pytorch/pytorch/issues/110213 + min_dtype = torch.finfo(dtype).min + causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) + + return causal_mask + + @staticmethod + # Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel._prepare_4d_causal_attention_mask_with_cache_position + def _prepare_4d_causal_attention_mask_with_cache_position( + attention_mask: torch.Tensor, + sequence_length: int, + target_length: int, + dtype: torch.dtype, + device: torch.device, + cache_position: torch.Tensor, + batch_size: int, + **kwargs, + ): + """ + Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape + `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing. + + Args: + attention_mask (`torch.Tensor`): + A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape + `(batch_size, 1, query_length, key_value_length)`. + sequence_length (`int`): + The sequence length being processed. + target_length (`int`): + The target length: when generating with static cache, the mask should be as long as the static cache, + to account for the 0 padding, the part of the cache that is not filled yet. + dtype (`torch.dtype`): + The dtype to use for the 4D attention mask. + device (`torch.device`): + The device to place the 4D attention mask on. + cache_position (`torch.Tensor`): + Indices depicting the position of the input sequence tokens in the sequence. + batch_size (`torch.Tensor`): + Batch size. + """ + if attention_mask is not None and attention_mask.dim() == 4: + # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing. + causal_mask = attention_mask + else: + min_dtype = torch.finfo(dtype).min + causal_mask = torch.full( + (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device + ) + if sequence_length != 1: + causal_mask = torch.triu(causal_mask, diagonal=1) + causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1) + causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1) + if attention_mask is not None: + causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit + mask_length = attention_mask.shape[-1] + padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :] + padding_mask = padding_mask == 0 + causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill( + padding_mask, min_dtype + ) + + return causal_mask + + @add_start_docstrings( "A conditional generation model with a language modeling head. Can be used for sequence generation tasks.", @@ -1616,6 +1756,7 @@ def forward( output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, + cache_position: Optional[torch.LongTensor] = None, ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]: r""" Returns: @@ -1724,6 +1865,7 @@ def forward( output_hidden_states=output_hidden_states, labels=labels, return_dict=return_dict, + cache_position=cache_position, ) if not return_dict: @@ -1739,4 +1881,4 @@ def forward( encoder_last_hidden_state=encoder_outputs.last_hidden_state, encoder_hidden_states=encoder_outputs.hidden_states, encoder_attentions=encoder_outputs.attentions, - ) + ) \ No newline at end of file
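
Note (editor's addition, not part of the patch): a minimal sketch of how the decoder's new cache path could be exercised once this change lands. The checkpoint id and the explicit cache hand-off to `generate` are illustrative assumptions, not something this diff prescribes; the cache construction mirrors the `EncoderDecoderCache(DynamicCache(), DynamicCache())` default the patched decoder builds when none is supplied.

from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from transformers.cache_utils import DynamicCache, EncoderDecoderCache

repo_id = "google/pix2struct-textcaps-base"  # assumed example checkpoint
processor = AutoProcessor.from_pretrained(repo_id)
model = Pix2StructForConditionalGeneration.from_pretrained(repo_id)

# The processor turns an image into flattened_patches + attention_mask:
# inputs = processor(images=image, return_tensors="pt")

# New-style cache object: a self-attention cache plus a cross-attention cache,
# replacing the deprecated tuple of past key/value tensors.
past_key_values = EncoderDecoderCache(DynamicCache(), DynamicCache())

# generated_ids = model.generate(**inputs, past_key_values=past_key_values, use_cache=True)
# print(processor.batch_decode(generated_ids, skip_special_tokens=True))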