PEFT doesn't inject virtual tokens into generate forward pass #2134

Open

Kami-chanw opened this issue Oct 6, 2024 · 3 comments

Comments

@Kami-chanw

System Info

  • transformers version: 4.46.0.dev0
  • Platform: Linux-5.4.0-148-generic-x86_64-with-glibc2.31
  • Python version: 3.9.19
  • Huggingface_hub version: 0.24.0
  • Safetensors version: 0.4.3
  • Accelerate version: 0.33.0
  • Accelerate config: not found
  • PyTorch version (GPU?): 2.4.1+cu121 (True)
  • Tensorflow version (GPU?): not installed (NA)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Using distributed or parallel set-up in script?:
  • Using GPU in script?:
  • GPU type: NVIDIA RTX A6000

Who can help?

@BenjaminBossan @sayakpaul

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

I ran into the problem with the following code:

import peft
from transformers import IdeficsForVisionText2Text, AutoProcessor
import sys
import torch

sys.path.insert(0, "..")
import config  # my own file

device = torch.device("cuda:3")
model = IdeficsForVisionText2Text.from_pretrained(
    config.idefics_9b_path, torch_dtype=torch.float16
).to(device)
processor = AutoProcessor.from_pretrained(
    config.idefics_9b_path, torch_dtype=torch.float16
)
model = peft.get_peft_model(
    model,
    peft.PrefixTuningConfig(
        peft_type="PREFIX_TUNING",
        task_type="CAUSAL_LM",
        num_virtual_tokens=2,
        token_dim=4096,
        num_transformer_submodules=1,
        num_attention_heads=32,
        num_layers=32,
        encoder_hidden_size=768,
    ),
    mixed=False,
)
inputs = processor(["hello"]).to(device)
model.eval()
model.generate(**inputs)

When I add print(past_key_values) on the transformers side, I get DynamicCache(), which means the virtual tokens weren't injected into the forward pass.

Expected behavior

It should get a cache whose length equals num_virtual_tokens.
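
For reference, here is a minimal sketch of the check I have in mind without editing transformers itself, using a forward pre-hook (assumptions: model.base_model is the wrapped Idefics model in the current PEFT layout, and with_kwargs=True needs PyTorch >= 2.0; model and inputs come from the script above):

def inspect_cache(module, args, kwargs):
    # print how many positions are already in the cache when forward is called
    pkv = kwargs.get("past_key_values")
    if pkv is not None:
        print("cache length seen by forward:", pkv.get_seq_length())

handle = model.base_model.register_forward_pre_hook(inspect_cache, with_kwargs=True)
model.generate(**inputs, max_new_tokens=5)
handle.remove()

If the virtual tokens were injected, the first printed length should already be num_virtual_tokens.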

@BenjaminBossan
Member

Could you please clarify where you check past_key_values and what you would expect there?

Also, note that #2096 is in the works, which should hopefully fix some of the issues that prefix tuning has with the latest transformers version. If possible, could you check whether that branch fixes the error?

@Kami-chanw
Author

Kami-chanw commented Oct 7, 2024

#2096 doesn't fix anything except suppressing the warnings, since the code here converts the legacy past_key_values into a proper Cache instance. I added a print(past_key_values) at this line and got DynamicCache(), which made me curious about whether the prefix tokens were injected into the forward pass during generation.

Then I debugged step by step until I reached here. The past_key_values is DynamicCache() while num_virtual_tokens is 1. I wonder whether this is correct behavior, because in my opinion the virtual tokens should be injected as past_key_values, just as is done during training.
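
To illustrate what I mean by "just as in training", this is roughly what I would expect (get_prompt is an internal PEFT helper, so treat the exact call and return type as an assumption):

# During training, PEFT builds the prefix key/value states itself and hands them
# to the base model as past_key_values; I would expect generate() to do the same.
prompt_kv = model.get_prompt(batch_size=1)  # internal helper, may differ across PEFT versions
print(type(prompt_kv))
# These states should cover num_virtual_tokens positions per layer, and the same
# states should be what the base model receives as past_key_values in generate().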

@BenjaminBossan
Member

When I add print(past_key_values) on the transformers side, I get DynamicCache(), which means the virtual tokens weren't injected into the forward pass.

Then I debugged step by step until I reached here. The past_key_values is DynamicCache() while num_virtual_tokens is 1. I wonder whether this is correct behavior, because in my opinion the virtual tokens should be injected as past_key_values, just as is done during training.

I'm not sure I follow. I set a breakpoint at the line you mentioned. past_key_values is indeed a DynamicCache instance, and it should contain the virtual tokens. When I have 2 virtual tokens, past_key_values.get_seq_length() returns 2. When I have 20 virtual tokens, past_key_values.get_seq_length() returns 20, etc. What would be your expectation?
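
If it helps, you should be able to see the same numbers without a debugger with something along these lines (assuming a transformers version where the generate output carries past_key_values when return_dict_in_generate=True):

out = model.generate(**inputs, max_new_tokens=1, return_dict_in_generate=True)
cache = out.past_key_values
if cache is not None:
    # expected: num_virtual_tokens plus the real input tokens (and any generated ones)
    print(cache.get_seq_length())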
