ENH FIX Allow "all-linear" to target custom models (#2267)
Description

When the option to specify target_modules="all-linear" was introduced in
PEFT (#1295), it was restricted to instances of PreTrainedModel. This was
because we want to exclude the output layer from being targeted, even if it
is a linear layer, and we cannot reliably identify that layer except by
convention.
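
For illustration, a minimal sketch of the convention in question (the model id is only an example; the actual check lives in `_maybe_include_all_linear_layers`):

```python
from transformers import AutoModelForCausalLM

# Hedged sketch: PreTrainedModel exposes its output layer via get_output_embeddings(),
# which is the convention that lets "all-linear" exclude it from targeting.
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")
output_emb = model.get_output_embeddings()  # e.g. the lm_head
output_name = next(n for n, m in model.named_modules() if m is output_emb)
print(output_name)  # "lm_head" -- only identifiable through this convention
```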

This PR lifts the restriction to PreTrainedModel instances. Users can now
target other models, such as diffusers models or custom models. The caveat is
that this is used "at your own risk", since all linear layers will be
targeted, whether they are output layers or not.
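
As a usage sketch (mirroring the updated diffusion test further below; the model id is the tiny test checkpoint used there):

```python
from diffusers import StableDiffusionPipeline
from peft import LoraConfig, get_peft_model

pipe = StableDiffusionPipeline.from_pretrained("hf-internal-testing/tiny-stable-diffusion-torch")
config = LoraConfig(target_modules="all-linear")

# With this PR, the unet (not a PreTrainedModel) can be targeted directly.
# Caveat: every linear layer is wrapped, including any that act as output layers.
pipe.unet = get_peft_model(pipe.unet, config)
pipe.unet.print_trainable_parameters()
```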

Bugfix

While working on this, I found a potential bug. The logic for updating
target_modules was that only the last part of the linear module's name
was used. So e.g. if the module was named "foo.bar.baz", then "baz" was
added to target_modules. This will lead to problems if there is another
"baz" module that is not a linear layer.

This bug was fixed by adding the full name ("foo.bar.baz" in this example) to
the updated target_modules. This can potentially lead to a large
target_modules set with many near-repetitions, but it's worth it to avoid
targeting the wrong module.
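
A toy illustration of the collision and the fix (the module names are hypothetical, matching the "foo.bar.baz" example above):

```python
import torch.nn as nn

class Bar(nn.Module):
    def __init__(self):
        super().__init__()
        self.baz = nn.Linear(8, 8)   # full name from the root model: "bar.baz"

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.bar = Bar()
        self.baz = nn.LayerNorm(8)   # a non-linear module that is *also* named "baz"

model = Toy()
print([name for name, _ in model.named_modules() if name])
# ['bar', 'bar.baz', 'baz']
# Old behavior: target_modules contained "baz", which suffix-matches both "bar.baz"
# (the linear layer) and "baz" (the LayerNorm) -> the wrong module gets targeted.
# New behavior: target_modules contains the full name "bar.baz", so only the
# linear layer matches.
```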

It is not clear to me why only the last part was added. The PR that
added this to PEFT copied that part from here:

https://github.com/artidoro/qlora/blob/7f4e95a68dc076bea9b3a413d2b512eca6d004e5/qlora.py#L248

but it's not clear why that repo did it that way. Maybe it was just to
keep the set size smaller.

The bug was uncovered by the already existing unet test. Still, I extended
this test, as well as another one, to better cover this potential issue by
ensuring that the number of targeted layers is as expected.

Backwards compatibility

Technically, this change breaks backwards compatibility. To return to the
previous example, let's say we have a module called "conv.baz" that is a
Conv2d layer. With the old behavior, since "baz" was added to target_modules,
this Conv2d layer would also be targeted (Conv2d is supported by LoRA). After
merging this PR, the Conv2d layer is no longer targeted.

I'd argue this is the correct behavior and thus worth changing. Also,
note that since we override target_modules, this is reflected in the
adapter_config.json. Therefore, if a user loads an adapter that had this
"baz" target, it will still work as it did previously.
BenjaminBossan authored Dec 13, 2024
1 parent 5cdade9 commit a217507
Showing 3 changed files with 35 additions and 32 deletions.
11 changes: 6 additions & 5 deletions src/peft/tuners/lora/config.py
@@ -132,10 +132,10 @@ class LoraConfig(PeftConfig):
The names of the modules to apply the adapter to. If this is specified, only the modules with the specified
names will be replaced. When passing a string, a regex match will be performed. When passing a list of
strings, either an exact match will be performed or it is checked if the name of the module ends with any
of the passed strings. If this is specified as 'all-linear', then all linear/Conv1D modules are chosen,
excluding the output layer. If this is not specified, modules will be chosen according to the model
architecture. If the architecture is not known, an error will be raised -- in this case, you should specify
the target modules manually.
of the passed strings. If this is specified as 'all-linear', then all linear/Conv1D modules are chosen (if
the model is a PreTrainedModel, the output layer excluded). If this is not specified, modules will be
chosen according to the model architecture. If the architecture is not known, an error will be raised -- in
this case, you should specify the target modules manually.
exclude_modules (`Optional[Union[List[str], str]]`):
The names of the modules to not apply the adapter. When passing a string, a regex match will be performed.
When passing a list of strings, either an exact match will be performed or it is checked if the name of the
@@ -225,7 +225,8 @@ class LoraConfig(PeftConfig):
"help": (
"List of module names or regex expression of the module names to replace with LoRA."
"For example, ['q', 'v'] or '.*decoder.*(SelfAttention|EncDecAttention).*(q|v)$'."
"This can also be a wildcard 'all-linear' which matches all linear/Conv1D layers except the output layer."
"This can also be a wildcard 'all-linear' which matches all linear/Conv1D "
"(if the model is a PreTrainedModel, the output layer excluded)."
"If not specified, modules will be chosen according to the model architecture, If the architecture is "
"not known, an error will be raised -- in this case, you should specify the target modules manually."
),
38 changes: 16 additions & 22 deletions src/peft/tuners/tuners_utils.py
@@ -1039,37 +1039,31 @@ def _maybe_include_all_linear_layers(peft_config: PeftConfig, model: nn.Module)
):
return peft_config

if not isinstance(model, PreTrainedModel):
raise ValueError(
f"Only instances of PreTrainedModel support `target_modules={INCLUDE_LINEAR_LAYERS_SHORTHAND!r}`"
)

linear_classes = (torch.nn.Linear, Conv1D)

linear_module_names = set()
for name, module in model.named_modules():
# match with all linear classes.
if isinstance(module, linear_classes):
names = name.rsplit(".", 1)[-1] # get the base name
linear_module_names.add(names)
linear_module_names.add(name)

# Try to remove linear layers that should not be targeted as best as possible. We have to rely on convention as
# there are no hard rules to detect these modules.
module_names_to_exclude = set()
output_emb = model.get_output_embeddings()
if output_emb is not None:
# ignore the last classification head for text generation models
last_module_name = [name for name, module in model.named_modules() if module is output_emb][0]
module_names_to_exclude.add(last_module_name)
elif peft_config.task_type == TaskType.SEQ_CLS:
# ignore classifier head for classification models (issue 2027)
# there is no fix name for the classifier head, so check the common ones
for name in SEQ_CLS_HEAD_NAMES:
cls_head = getattr(model, name, None)
if cls_head is not None:
last_module_name = [name for name, module in model.named_modules() if module is cls_head][0]
module_names_to_exclude.add(last_module_name)
break
if isinstance(model, PreTrainedModel):
output_emb = model.get_output_embeddings()
if output_emb is not None:
# ignore the last classification head for text generation models
last_module_name = [name for name, module in model.named_modules() if module is output_emb][0]
module_names_to_exclude.add(last_module_name)
elif peft_config.task_type == TaskType.SEQ_CLS:
# ignore classifier head for classification models (issue 2027)
# there is no fix name for the classifier head, so check the common ones
for name in SEQ_CLS_HEAD_NAMES:
cls_head = getattr(model, name, None)
if cls_head is not None:
last_module_name = [name for name, module in model.named_modules() if module is cls_head][0]
module_names_to_exclude.add(last_module_name)
break

linear_module_names -= module_names_to_exclude
peft_config.target_modules = linear_module_names
18 changes: 13 additions & 5 deletions tests/test_tuners_utils.py
@@ -30,6 +30,7 @@
AutoModelForSequenceClassification,
BitsAndBytesConfig,
)
from transformers.pytorch_utils import Conv1D

from peft import (
AdaptionPromptConfig,
@@ -333,11 +334,12 @@ def test_maybe_include_all_linear_layers_diffusion(self):
model_id = "hf-internal-testing/tiny-stable-diffusion-torch"
model = StableDiffusionPipeline.from_pretrained(model_id)
config = LoraConfig(base_model_name_or_path=model_id, target_modules="all-linear")
with pytest.raises(
ValueError,
match="Only instances of PreTrainedModel support `target_modules='all-linear'`",
):
model.unet = get_peft_model(model.unet, config)

# all linear layers should be converted
num_linear = sum(isinstance(module, (nn.Linear, Conv1D)) for module in model.unet.modules())
model.unet = get_peft_model(model.unet, config)
num_lora = sum(isinstance(module, LoraLayer) for module in model.unet.modules())
assert num_lora == num_linear

def test_maybe_include_all_linear_does_not_target_classifier_head(self):
# See issue 2027
@@ -348,6 +350,8 @@ def test_maybe_include_all_linear_does_not_target_classifier_head(self):
# sanity check
assert isinstance(model.score, nn.Linear)

num_linear = sum(isinstance(module, (nn.Linear, Conv1D)) for module in model.modules())

config = LoraConfig(task_type="SEQ_CLS", target_modules="all-linear")
model = get_peft_model(model, config)
assert isinstance(model.base_model.score, ModulesToSaveWrapper)
@@ -356,6 +360,10 @@ def test_maybe_include_all_linear_does_not_target_classifier_head(self):
assert isinstance(model.base_model.score.original_module, nn.Linear)
assert isinstance(model.base_model.score.modules_to_save["default"], nn.Linear)

# ensure that all but one linear layer was targeted by LoRA
num_lora = sum(isinstance(module, LoraLayer) for module in model.modules())
assert num_lora == num_linear - 1


class MLP(nn.Module):
def __init__(self, bias=True):
