Decoder and cross-attention shape is different when obtained by model.generate() and model() #33296

Closed

cgr71ii opened this issue Sep 4, 2024 · 2 comments

System Info

  • transformers version: 4.43.3
  • Platform: Linux-6.5.0-45-generic-x86_64-with-glibc2.35
  • Python version: 3.11.9
  • Huggingface_hub version: 0.24.5
  • Safetensors version: 0.4.3
  • Accelerate version: 0.33.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.0 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?: no
  • Using GPU in script?: yes
  • GPU type: NVIDIA A100-PCIE-40GB

Who can help?

@gante

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Hi!

If you set trigger_error to True, you will see that the decoder-attention (and cross-attention) shapes differ depending on whether they are obtained from model.generate() or from model(). I don't know if this is a bug or if they are just expected to be different. I have checked that the attention values are the same once all the information is restructured into the same layout (there are small precision differences, though, which I think is because model.generate() computes the generation differently than model()).

import torch
import transformers

trigger_error = True # Change THIS
pretrained_model = "facebook/nllb-200-distilled-600M"
device = "cuda" if torch.cuda.is_available() else "cpu" 
source_lang = "eng_Latn"
target_lang = "spa_Latn"
tokenizer = transformers.AutoTokenizer.from_pretrained(pretrained_model, src_lang=source_lang, tgt_lang=target_lang)
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(pretrained_model).to(device)
source_text = ["Hello!", "Hello again!!!!!!"]
inputs = tokenizer(source_text, return_tensors="pt", add_special_tokens=True, max_length=1024, truncation=True, padding=True).to(device)
target_lang_id = tokenizer.convert_tokens_to_ids(target_lang)
translated_tokens = model.generate(**inputs, forced_bos_token_id=target_lang_id, max_new_tokens=200, num_return_sequences=1, num_beams=1, return_dict_in_generate=True, output_attentions=True)
translated_tokens_model = model(**inputs, decoder_input_ids=translated_tokens.sequences[:,:-1], output_attentions=True)

# Checks
num_hidden_layers = model.config.num_hidden_layers
num_attention_heads = model.config.num_attention_heads
batch_size = len(source_text)

# Encoder
assert len(translated_tokens.encoder_attentions) == num_hidden_layers # 12
assert len(translated_tokens_model.encoder_attentions) == num_hidden_layers # 12

for l in range(num_hidden_layers):
  assert translated_tokens.encoder_attentions[l].shape == translated_tokens_model.encoder_attentions[l].shape
  assert (translated_tokens.encoder_attentions[l] == translated_tokens_model.encoder_attentions[l]).all().cpu().item()

#####

def transformer_attention_to_common_structure(attention_ttg, attention_ttm):
  ## Transform attention from model.generate() to common structure with model()
  decoded_tokens = attention_ttm[0].shape[-2:]
  _decoder_attention = torch.zeros(num_hidden_layers, batch_size, num_attention_heads, *decoded_tokens).to(device)

  for _decoded_tokens, t in enumerate(attention_ttg, 1):
    # Causal mask
    t = torch.stack(t, 0) # (num_hidden_layers, batch_size, num_attention_heads, 1, _decoded_tokens)
    t = t.squeeze(-2)
    _decoder_attention[:,:,:, _decoded_tokens - 1,:t.shape[-1]] = t

  for l in range(num_hidden_layers):
    assert _decoder_attention.shape == (len(attention_ttm), *attention_ttm[l].shape)
#    assert (_decoder_attention[l] == attention_ttm[l]).all().cpu().item() # Differences due to precision (even when device="cpu")... model.generate() is generating different to model()?
    assert torch.isclose(_decoder_attention[l], attention_ttm[l]).all().cpu().item()

# Decoder
assert len(translated_tokens.decoder_attentions) == num_hidden_layers if trigger_error else True
assert len(translated_tokens_model.decoder_attentions) == num_hidden_layers # 12

transformer_attention_to_common_structure(translated_tokens.decoder_attentions, translated_tokens_model.decoder_attentions)

# Cross
assert len(translated_tokens.cross_attentions) == num_hidden_layers if trigger_error else True
assert len(translated_tokens_model.cross_attentions) == num_hidden_layers # 12

transformer_attention_to_common_structure(translated_tokens.cross_attentions, translated_tokens_model.cross_attentions)

Expected behavior

I would expect the decoder and cross-attention outputs to have the same format regardless of whether I use model.generate() or model(). Specifically, I would expect to obtain the result that model() gives, where for the decoder we get one matrix per layer of shape (batch_size, attention_heads, generated_tokens - 1, generated_tokens - 1).
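
Concretely, here is a small sketch of the layout I mean, reusing translated_tokens and translated_tokens_model from the reproduction above (tgt_len is just shorthand for generated_tokens - 1):

tgt_len = translated_tokens.sequences.shape[1] - 1  # length of the decoder input passed to model()

for layer_attention in translated_tokens_model.decoder_attentions:
  # One tensor per decoder layer, all with the same square shape over the decoded tokens
  assert layer_attention.shape == (batch_size, num_attention_heads, tgt_len, tgt_len)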

cgr71ii added the bug label Sep 4, 2024

gante (Member) commented Sep 4, 2024

Hi @cgr71ii 👋 Thank you for opening this issue 🤗

As shown in our documentation, the output of generate is different from the output of forward.

Namely, generate's attention output is a tuple where each item is the attention output of one forward pass. In your example, if you replace e.g. translated_tokens.decoder_attentions by translated_tokens.decoder_attentions[0] you'll obtain the results you were expecting :)
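
To make the nesting concrete, a quick sketch using the objects from the reproduction above (shape names are only illustrative):

# model(): one tuple over layers
#   translated_tokens_model.decoder_attentions[layer] -> (batch_size, num_heads, tgt_len, tgt_len)
# model.generate(): one tuple over generation steps, each entry itself a tuple over layers
#   translated_tokens.decoder_attentions[step][layer] -> (batch_size, num_heads, query_len_of_that_step, keys_seen_so_far)
first_step = translated_tokens.decoder_attentions[0]  # tuple with one entry per decoder layer
print(len(translated_tokens.decoder_attentions))      # number of generation steps
print(len(first_step))                                # num_hidden_layers
print(first_step[0].shape)                            # layer-0 attention for the first forward pass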

cgr71ii (Author) commented Sep 6, 2024

Oh, ok! Thank you! :)
