
PEFT fix: T5 prefix-tuning with new cache format #34312

Open
wants to merge 1 commit into main

Conversation

@zucchini-nlp (Member) commented Oct 22, 2024

What does this PR do?

Fixes prefix tuning for T5 models, which started breaking after the recent compile-compatibility PR. Note, however, that this alone will not enable correct prefix tuning unless huggingface/peft#2096 (review) is also merged.

The main issue is that, in the new cache format, we don't concatenate new key/values with cached key/values when past_key_values.is_updated is set. That issue is fixed on the PEFT side by initializing a cache object and setting is_updated=False.
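For illustration, the PEFT-side idea is roughly the sketch below (not this PR's diff; build_prefix_cache and the legacy-style input layout are assumptions for the sake of the example, and the real wiring lives in huggingface/peft#2096):

from transformers import DynamicCache, EncoderDecoderCache

def build_prefix_cache(prefix_key_values):
    # prefix_key_values: assumed legacy-style tuple with one (key, value) pair
    # per decoder layer, holding the virtual-token states of shape
    # [batch, num_heads, num_virtual_tokens, head_dim].
    self_attention_cache = DynamicCache()
    for layer_idx, (key, value) in enumerate(prefix_key_values):
        self_attention_cache.update(key, value, layer_idx)
    # The cross-attention cache is left empty, so is_updated stays falsy for
    # every layer: the model still computes cross-attention key/values and
    # concatenates new self-attention key/values onto the cached prefix.
    return EncoderDecoderCache(self_attention_cache, DynamicCache())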

Another issue is a mask shape mismatch, since we have to extend cross_attention_mask to account for the new virtual tokens. The current PR enables that extension, but I am also wondering whether it could be done in PEFT instead, by passing already-extended attention masks.
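For reference, the extension amounts to something like the sketch below: padding a 2D attention mask with ones on the key/value side so it also covers the num_virtual_tokens prefix states (extend_mask_for_prefix is a hypothetical helper; the actual change sits in the modeling code):

import torch

def extend_mask_for_prefix(attention_mask, num_virtual_tokens):
    # attention_mask: [batch, seq_len], 1 for real tokens, 0 for padding.
    prefix_mask = torch.ones(
        attention_mask.shape[0],
        num_virtual_tokens,
        dtype=attention_mask.dtype,
        device=attention_mask.device,
    )
    # The virtual tokens sit at the start of the key/value sequence, so the
    # mask is extended on that side.
    return torch.cat([prefix_mask, attention_mask], dim=-1)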

Tested with the code below for T5:

import torch
from transformers import AutoModelForSeq2SeqLM
from peft import PrefixTuningConfig, get_peft_model

# Dummy encoder/decoder inputs for a quick forward pass
inputs = {
    "input_ids": torch.tensor([[1, 2, 3, 4, 5, 6, 7]]),
    "decoder_input_ids": torch.tensor([[1, 2, 3, 4, 5]]),
    "attention_mask": torch.tensor([[1, 1, 1, 1, 1, 1, 1]]),
}

# Sanity-check the base model first
model_id = "ybelkada/tiny-random-T5ForConditionalGeneration-calibrated"
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
model(**inputs)

# Wrap the model with prefix tuning and run the same inputs again
config = PrefixTuningConfig(num_virtual_tokens=20, task_type="SEQ_2_SEQ_LM")
model = get_peft_model(model, config)

output = model(**inputs)
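With the PEFT-side fix applied, both forward calls should run without shape errors; as a quick sanity check (assuming the usual Seq2SeqLMOutput), output.logits should have shape [batch, decoder_length, vocab_size], i.e. [1, 5, vocab_size] here.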

cc @BenjaminBossan wdyt?

@zucchini-nlp changed the title from "T5 compile" to "PEFT fix: T5 prefix-tuning with new cache format" on Oct 22, 2024
@HuggingFaceDocBuilderDev commented

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan (Member) commented

Just to update, we're internally discussing other options.
