Skip to content

Commit

Permalink
FIX: Transpose weight matrix based on fan_in_fan_out condition in PiS…
Browse files Browse the repository at this point in the history
…SA initialization (#2103)

This update addresses an issue where the weight matrix was converted to float32 without considering the need for transposition. The weight matrix is now transposed when the fan_in_fan_out condition is met, resolving dimension mismatch issues during GPT-2 training.

To ensure this fix is robust, tests have been updated to include parameterized cases for different devices and bit configurations. Additionally, the isinstance checks have been modified to include Conv1D layers, ensuring all relevant layers are processed correctly.
  • Loading branch information
Yang Su committed Oct 8, 2024
1 parent 5e91b54 commit 4d77af8
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
4 changes: 2 additions & 2 deletions src/peft/tuners/lora/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def pissa_init(self, adapter_name, init_lora_weights):
"Please initialize PiSSA under float32, float16, or bfloat16. "
"Subsequently, re-quantize the residual model to help minimize quantization errors."
)
weight = weight.to(torch.float32)
weight = transpose(weight.to(torch.float32), self.fan_in_fan_out)
if init_lora_weights == "pissa":
# USV^T = W <-> VSU^T = W^T, where W^T = weight.data in R^{out_channel, in_channel},
V, S, Uh = torch.linalg.svd(weight.data, full_matrices=False)
Expand All @@ -245,7 +245,7 @@ def pissa_init(self, adapter_name, init_lora_weights):
self.lora_A[adapter_name].weight.data = lora_A
self.lora_B[adapter_name].weight.data = lora_B
weight = weight.data - self.scaling[adapter_name] * lora_B @ lora_A
weight = weight.to(dtype)
weight = transpose(weight.to(dtype), self.fan_in_fan_out)
self.get_base_layer().weight.data = weight

def loftq_init(self, adapter_name):
Expand Down
15 changes: 13 additions & 2 deletions tests/test_gpu_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
WhisperProcessor,
WhisperTokenizer,
)
from transformers.pytorch_utils import Conv1D

from peft import (
AdaLoraConfig,
Expand Down Expand Up @@ -1718,7 +1719,7 @@ def quantize_model(self, model, num_bits=4, device="cuda"):
# Quantize the `weight.data` of the linear layer in the model to `num_bits` and store it with full precision.
quantizer = NFQuantizer(num_bits=num_bits, device=device, method="normal", block_size=64)
for name, module in model.named_modules():
if isinstance(module, torch.nn.Linear) and "lm_head" not in name:
if isinstance(module, (torch.nn.Linear, Conv1D)) and "lm_head" not in name:
quantized_weight, max_abs, shape = quantizer.quantize_block(module.weight.data.to(device))
module.weight.data = quantizer.dequantize_block(quantized_weight, max_abs, shape)
return model
Expand All @@ -1727,7 +1728,7 @@ def nuclear_norm(self, base_model, quantized_model):
# Calculate the nuclear norm (sum of singular values) of the error matrices between the `quantized_model` and the `base_model`.
error_list = []
for name, module in base_model.named_modules():
if isinstance(module, torch.nn.Linear) and "lm_head" not in name:
if isinstance(module, (torch.nn.Linear, Conv1D)) and "lm_head" not in name:
quant_module = quantized_model.get_submodule(name)
error_list.append(torch.linalg.svdvals(module.weight.data - quant_module.weight.data).sum())
return torch.Tensor(error_list).sum()
Expand Down Expand Up @@ -1821,6 +1822,16 @@ def test_t5_pissa_4bit(self, device, tmp_path):
def test_t5_pissa_8bit(self, device, tmp_path):
self.get_errors(bits=8, device=device, model_id="t5-small", tmp_path=tmp_path)

@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_gpt2_pissa_4bit(self, device, tmp_path):
# see 2104
self.get_errors(bits=4, device=device, model_id="gpt2", tmp_path=tmp_path)

@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_gpt2_pissa_8bit(self, device, tmp_path):
# see 2104
self.get_errors(bits=8, device=device, model_id="gpt2", tmp_path=tmp_path)

@require_bitsandbytes
def test_lora_pissa_conversion_same_output_after_loading_with_quantization(self, tmp_path):
# A copy of the test `test_lora_pissa_conversion_same_output_after_loading` in peft/tests/test_initialization.py,
Expand Down

0 comments on commit 4d77af8

Please sign in to comment.