From 97bb9299c4f179384d470b6909d2259368011be1 Mon Sep 17 00:00:00 2001 From: Ilyas Moutawwakil <57442720+IlyasMoutawwakil@users.noreply.github.com> Date: Mon, 28 Oct 2024 11:24:56 +0100 Subject: [PATCH] Fix pix2struct (#34374) * fix * fix and test use_cache test * style * remove atol --- .../models/pix2struct/modeling_pix2struct.py | 60 +++++++++++-------- .../pix2struct/test_modeling_pix2struct.py | 11 ++++ 2 files changed, 46 insertions(+), 25 deletions(-) diff --git a/src/transformers/models/pix2struct/modeling_pix2struct.py b/src/transformers/models/pix2struct/modeling_pix2struct.py index b1ac81bb1f21..176dadd5b883 100644 --- a/src/transformers/models/pix2struct/modeling_pix2struct.py +++ b/src/transformers/models/pix2struct/modeling_pix2struct.py @@ -762,11 +762,14 @@ def _relative_position_bucket(relative_position, bidirectional=True, num_buckets return relative_buckets # Adapted from transformers.models.t5.modeling_t5.T5Attention.compute_bias - def compute_bias(self, query_length, key_length, device=None): + def compute_bias(self, query_length, key_length, device=None, cache_position=None): """Compute binned relative position bias""" if device is None: device = self.relative_attention_bias.weight.device - context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + if cache_position is None: + context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None] + else: + context_position = cache_position[:, None].to(device) memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :] relative_position = memory_position - context_position # shape (query_length, key_length) relative_position_bucket = self._relative_position_bucket( @@ -779,6 +782,7 @@ def compute_bias(self, query_length, key_length, device=None): values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length) return values + # Adapted from transformers.models.t5.modeling_t5.T5Attention.forward def forward( self, hidden_states, @@ -796,61 +800,66 @@ def forward( Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states). 
""" # Input is (batch_size, seq_length, dim) - # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, query_length, key_length) + # Mask is (batch_size, 1, 1, key_length) (non-causal) or (batch_size, 1, seq_length, key_length) (causal decoder) batch_size, seq_length = hidden_states.shape[:2] # if key_value_states are provided this layer is used as a cross-attention layer for the decoder is_cross_attention = key_value_states is not None - query_states = self.query(hidden_states).contiguous() + query_states = self.query(hidden_states) query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) if past_key_value is not None: is_updated = past_key_value.is_updated.get(self.layer_idx) if is_cross_attention: # after the first generated id, we can subsequently re-use all key/value_states from cache - past_key_value = past_key_value.cross_attention_cache + curr_past_key_value = past_key_value.cross_attention_cache else: - past_key_value = past_key_value.self_attention_cache + curr_past_key_value = past_key_value.self_attention_cache - # get key/value states current_states = key_value_states if is_cross_attention else hidden_states if is_cross_attention and past_key_value and is_updated: # reuse k,v, cross_attentions - key_states = past_key_value.key_cache[self.layer_idx] - value_states = past_key_value.value_cache[self.layer_idx] + key_states = curr_past_key_value.key_cache[self.layer_idx] + value_states = curr_past_key_value.value_cache[self.layer_idx] else: - key_states = self.key(current_states).contiguous() - value_states = self.value(current_states).contiguous() + key_states = self.key(current_states) + value_states = self.value(current_states) key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2) + if past_key_value is not None: # save all key/value_states to cache to be re-used for fast auto-regressive generation cache_position = cache_position if not is_cross_attention else None - key_states, value_states = past_key_value.update( + key_states, value_states = curr_past_key_value.update( key_states, value_states, self.layer_idx, {"cache_position": cache_position} ) # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls if is_cross_attention: past_key_value.is_updated[self.layer_idx] = True - # compute scores + # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9 scores = torch.matmul(query_states, key_states.transpose(3, 2)) if position_bias is None: - real_seq_length = cache_position[-1] + 1 if query_length is None else query_length - key_length = real_seq_length if key_value_states is None else key_value_states.shape[1] + key_length = key_states.shape[-2] + # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past) + real_seq_length = query_length if query_length is not None else cache_position[-1] + 1 if not self.has_relative_attention_bias: position_bias = torch.zeros( - (1, self.n_heads, real_seq_length, key_length), device=scores.device, dtype=scores.dtype + (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype ) if self.gradient_checkpointing and self.training: position_bias.requires_grad = True else: - position_bias = self.compute_bias(real_seq_length, key_length, device=scores.device) + position_bias = 
self.compute_bias( + real_seq_length, key_length, device=scores.device, cache_position=cache_position + ) + position_bias = position_bias[:, :, -seq_length:, :] if mask is not None: - position_bias = position_bias + mask # (batch_size, n_heads, seq_length, key_length) + causal_mask = mask[:, :, :, : key_states.shape[-2]] + position_bias = position_bias + causal_mask if self.pruned_heads: mask = torch.ones(position_bias.shape[1]) @@ -860,10 +869,9 @@ def forward( position_bias_masked = position_bias scores += position_bias_masked - # (batch_size, n_heads, seq_length, key_length) - attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) # (batch_size, n_heads, seq_length, key_length) + attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores) attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) # Mask heads if we want to @@ -871,12 +879,12 @@ def forward( attn_weights = attn_weights * layer_head_mask attn_output = torch.matmul(attn_weights, value_states) - # (batch_size, seq_length, dim) - attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim) + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.view(batch_size, -1, self.inner_dim) attn_output = self.output(attn_output) - outputs = (attn_output,) + (past_key_value,) + (position_bias,) + outputs = (attn_output, past_key_value, position_bias) if output_attentions: outputs = outputs + (attn_weights,) @@ -969,7 +977,10 @@ def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optiona layer_idx=layer_idx, ) - self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention(config) + self.encoder_decoder_attention = Pix2StructTextLayerCrossAttention( + config, + layer_idx=layer_idx, + ) self.mlp = Pix2StructTextLayerFF(config) @@ -1019,7 +1030,6 @@ def forward( query_length=cache_position[-1] + 1, use_cache=use_cache, output_attentions=output_attentions, - cache_position=cache_position, ) hidden_states, past_key_value = cross_attention_outputs[:2] diff --git a/tests/models/pix2struct/test_modeling_pix2struct.py b/tests/models/pix2struct/test_modeling_pix2struct.py index 2d762008cbbc..18b79f3fbc9c 100644 --- a/tests/models/pix2struct/test_modeling_pix2struct.py +++ b/tests/models/pix2struct/test_modeling_pix2struct.py @@ -419,6 +419,7 @@ def prepare_config_and_inputs_for_common(self): @require_torch class Pix2StructModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase): all_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else () + all_generative_model_classes = (Pix2StructForConditionalGeneration,) if is_torch_available() else {} pipeline_model_mapping = {"image-to-text": Pix2StructForConditionalGeneration} if is_torch_available() else {} fx_compatible = False test_head_masking = False @@ -445,6 +446,16 @@ def test_model(self): ), ) + def test_generative_model(self): + config, input_dict = self.model_tester.prepare_config_and_inputs_for_common() + for model_class in self.all_generative_model_classes: + model = model_class(config).eval().to(torch_device) + + output = model.generate(**input_dict, use_cache=False, min_new_tokens=10, max_new_tokens=10) + output_use_cache = model.generate(**input_dict, use_cache=True, min_new_tokens=10, max_new_tokens=10) + + torch.testing.assert_close(output, output_use_cache) + @unittest.skip(reason="Hidden_states is tested in individual model tests") def test_hidden_states_output(self): pass
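
Reviewer note: the core of this fix is that compute_bias now derives the query positions from cache_position instead of assuming queries always start at position 0, so the relative-position bias stays aligned with the key/value cache during incremental decoding. Below is a minimal, self-contained sketch of that idea, not the transformers implementation: compute_bias_sketch is a hypothetical stand-in that returns raw relative distances instead of the learned relative_attention_bias embedding, and its arguments only mirror the shape of the patched compute_bias.

    import torch

    def compute_bias_sketch(query_length, key_length, num_heads, device="cpu", cache_position=None):
        # Mirrors the patched compute_bias: query_length is only used when no
        # cache_position is given; with a cache, the query rows come from
        # cache_position (the absolute positions of the newly fed tokens).
        if cache_position is None:
            context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
        else:
            context_position = cache_position[:, None].to(device)
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position  # (num_queries, key_length)
        # Real model: bucketize the distances, then index a learned embedding to
        # get (1, n_heads, num_queries, key_length); the raw distance stands in here.
        return relative_position[None, None].expand(1, num_heads, -1, -1).float()

    # Decoding step with a 4-token cache: a single new query at absolute position 4.
    step_bias = compute_bias_sketch(
        query_length=5, key_length=5, num_heads=2, cache_position=torch.tensor([4])
    )

    # Same bias computed without a cache over the full 5-token sequence.
    full_bias = compute_bias_sketch(query_length=5, key_length=5, num_heads=2)

    # The cached step reproduces exactly the last query row of the full bias,
    # which is the invariant behind the `position_bias[:, :, -seq_length:, :]` crop
    # and behind the new test's check that generate() matches with and without use_cache.
    torch.testing.assert_close(step_bias, full_bias[:, :, -1:, :])

The sketch is only meant to make the diff easier to review; the actual behavioral guarantee is exercised end to end by the added test_generative_model, which compares generate() outputs with use_cache=False and use_cache=True.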