[WIP] Fix to prefix tuning to fit transformers #2096

Draft · wants to merge 27 commits into base: main
Commits (27)
7c8e287
[WIP] Fix to prefix tuning to fit transformers
BenjaminBossan Sep 25, 2024
b666532
Update src/peft/peft_model.py
BenjaminBossan Oct 22, 2024
73496ee
FIX: Change check if past_key_values is empty (#2106)
BenjaminBossan Sep 27, 2024
d60d1b6
DOC Update source install instruction (#2110)
Salehbigdeli Sep 30, 2024
faa4dd8
FIX Refactor OFT, small changes to BOFT (#1996)
Zeju1997 Oct 1, 2024
5cd5a45
ENH: Improved attribute access for modules_to_save (#2117)
BenjaminBossan Oct 2, 2024
0312b30
FIX low_cpu_mem_usage consolidates devices (#2113)
BenjaminBossan Oct 2, 2024
4c50892
TST Mark flaky X-LoRA test as xfail (#2114)
BenjaminBossan Oct 2, 2024
8699ba4
ENH: Warn when from_pretrained misses PEFT keys (#2118)
BenjaminBossan Oct 2, 2024
9ddc9f1
FEAT: Adding exclude modules param(#2044) (#2102)
JINO-ROHIT Oct 3, 2024
5a560da
FIX BC breaking change to boft conv2d scaling variable (#2127)
Zeju1997 Oct 7, 2024
d10151e
FEAT: VeRA quantization using bitsandbytes (#2070) (#2076)
ZiadHelal Oct 7, 2024
1d55d8b
Bump version to 0.13.2.dev0 (#2137)
BenjaminBossan Oct 8, 2024
98cf284
FEAT: Support torchao (#2062)
BenjaminBossan Oct 8, 2024
7961e8c
FIX: PiSSA now works with Conv1D layers (#2103) (#2104)
suyang160 Oct 8, 2024
fe8ba8e
FIX Type annoations in vera/bnb.py (#2139)
BenjaminBossan Oct 9, 2024
171cc75
ENH Make PEFT configs forward compatible (#2038)
BenjaminBossan Oct 9, 2024
858e1d2
FIX Raise mixed adapter infer with missing adapter (#2090)
BenjaminBossan Oct 9, 2024
b494d0e
FIX Prompt learning with latest transformers error (#2140)
BenjaminBossan Oct 9, 2024
f2d40e7
ENH LoRA notebook for NER task (#2126)
JINO-ROHIT Oct 10, 2024
7e5519a
FIX TST NaN issue with HQQ GPU test (#2143)
BenjaminBossan Oct 10, 2024
d0c22b3
FIX Bug in target module optimization if suffix (#2144)
BenjaminBossan Oct 10, 2024
3d205bc
Bump version to 0.13.2.dev0 (#2145)
BenjaminBossan Oct 11, 2024
7dfd956
FIX Don't assume past_key_valus for encoder models (#2149)
BenjaminBossan Oct 14, 2024
e74a6b9
FIX Use `SFTConfig` instead of `SFTTrainer` keyword args (#2150)
qgallouedec Oct 15, 2024
f481c5d
make style
BenjaminBossan Oct 22, 2024
9b223ea
Merge branch 'main' into fix-prefix-tuning-dynamic-cache
BenjaminBossan Oct 22, 2024
16 changes: 15 additions & 1 deletion src/peft/peft_model.py
@@ -730,7 +730,21 @@ def get_prompt(self, batch_size: int, task_ids: Optional[torch.Tensor] = None) -
if TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING.get(self.config.model_type, None) is not None:
post_process_fn = TRANSFORMERS_MODELS_TO_PREFIX_TUNING_POSTPROCESS_MAPPING[self.config.model_type]
past_key_values = post_process_fn(past_key_values)
return past_key_values
elif peft_config.num_transformer_submodules == 1:
# Don't apply this to encoder-decoder models or to models requiring special processing.
# local import in case users use a very old transformers version
from transformers import DynamicCache

past_key_values = DynamicCache.from_legacy_cache(past_key_values)
elif peft_config.num_transformer_submodules == 2 and self.base_model._supports_cache_class:
# Don't apply this to encoder-decoder models that don't support the new Cache format yet.
# If we don't apply this, prefix tuning fails to update the cross-attention cache.
from transformers import EncoderDecoderCache

past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
past_key_values.is_updated = {
layer_idx: False for layer_idx in range(len(past_key_values.cross_attention_cache.key_cache))
}
else:
if peft_config.peft_type == PeftType.MULTITASK_PROMPT_TUNING:
prompts = prompt_encoder(prompt_tokens, task_ids)
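Note: the hunk above converts the legacy tuple-of-tuples past_key_values produced by prefix tuning into the Cache objects that recent transformers releases expect. A minimal sketch of the decoder-only conversion, with made-up tensor shapes (not part of this PR):

import torch
from transformers import DynamicCache

# Illustrative shapes only; in PEFT these tensors come from the prefix encoder.
num_layers, batch_size, num_heads, num_virtual_tokens, head_dim = 2, 1, 4, 10, 8

# Legacy format: a tuple with one (key, value) pair per layer.
legacy_past_key_values = tuple(
    (
        torch.zeros(batch_size, num_heads, num_virtual_tokens, head_dim),
        torch.zeros(batch_size, num_heads, num_virtual_tokens, head_dim),
    )
    for _ in range(num_layers)
)

# New format (the num_transformer_submodules == 1 branch, i.e. decoder-only models).
cache = DynamicCache.from_legacy_cache(legacy_past_key_values)
print(cache.get_seq_length())  # 10, i.e. the number of virtual tokens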
42 changes: 41 additions & 1 deletion tests/test_decoder_models.py
@@ -11,13 +11,21 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
from unittest.mock import Mock, call, patch

import pytest
import torch
from datasets import load_dataset
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import (
AutoModelForCausalLM,
AutoTokenizer,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)

from peft import (
AdaLoraConfig,
@@ -466,3 +474,35 @@ def test_prompt_learning_with_grouped_query_attention(self):
x = torch.tensor([[1, 2, 3]])
# does not raise
model(x)

def test_prefix_tuning_foobar(self):
# TODO
# Check that a prefix-tuned model can be trained with Trainer; see issues 869 and 1962
model_id = "hf-internal-testing/tiny-random-MistralForCausalLM"
base_model = AutoModelForCausalLM.from_pretrained(model_id)
peft_config = PrefixTuningConfig(num_virtual_tokens=10, task_type="CAUSAL_LM")
model = get_peft_model(base_model, peft_config)

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

def process(samples):
tokenized = tokenizer(samples["quote"], truncation=True, max_length=128)
return tokenized

data = load_dataset("ybelkada/english_quotes_copy")
data = data.map(process, batched=True)

with tempfile.TemporaryDirectory() as tmp_dirname:
trainer = Trainer(
model=model,
train_dataset=data["train"],
args=TrainingArguments(
num_train_epochs=1,
max_steps=5,
per_device_train_batch_size=4,
output_dir=tmp_dirname,
),
data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
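As an additional, hypothetical sanity check (not part of this PR), generation with the prefix-tuned model could be exercised after training; the snippet below reuses model and tokenizer from the test above:

inputs = tokenizer("Quote:", return_tensors="pt")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))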