
FIX: Transpose weight matrix based on fan_in_fan_out condition in PiSSA initialization (#2103) #2104

Merged

Conversation

suyang160 (Contributor)

Previously, the weight matrix was converted to float32 without considering the need for transposition. This update ensures that the weight matrix is transposed when the fan_in_fan_out condition is met, resolving dimension mismatch issues during GPT-2 training.
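
A minimal sketch of the fix described above, assuming a PiSSA-style init that casts the base weight to float32 before running an SVD (the helper name and attribute access here are illustrative, not the exact PEFT source):

    import torch

    def weight_for_pissa(base_layer: torch.nn.Module, fan_in_fan_out: bool) -> torch.Tensor:
        # GPT-2's Conv1D stores its weight as (in_features, out_features),
        # which PEFT flags with fan_in_fan_out=True. Transpose before the
        # float32 cast so the SVD sees the (out, in) layout of nn.Linear.
        weight = base_layer.weight
        if fan_in_fan_out:
            weight = weight.T
        return weight.to(torch.float32)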

@BenjaminBossan left a comment (Member)

Thanks for this PR. First of all, my apologies for not responding earlier. The notification somehow slipped my attention and I just wasn't aware of this PR. In the future, feel free to ping me after a couple of days when there is no response.

The fix looks good, thanks for that. Let's add some tests to ensure that this bug doesn't happen again. For this, could you please add the following tests to the existing PiSSA tests:

    @pytest.mark.parametrize("device", ["cuda", "cpu"])
    def test_gpt2_pissa_4bit(self, device, tmp_path):
        # see 2104
        self.get_errors(bits=4, device=device, model_id="gpt2", tmp_path=tmp_path)

    @pytest.mark.parametrize("device", ["cuda", "cpu"])
    def test_gpt2_pissa_8bit(self, device, tmp_path):
        # see 2104
        self.get_errors(bits=8, device=device, model_id="gpt2", tmp_path=tmp_path)
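
(Once added, the new tests can be run in isolation with something like pytest -k "gpt2_pissa"; the exact selector is illustrative.)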

For this to work, we need to make some changes to these lines though:

if isinstance(module, torch.nn.Linear) and "lm_head" not in name:

(the same check appears in a second place as well)

There, we need to change isinstance(module, torch.nn.Linear) to isinstance(module, (torch.nn.Linear, Conv1D)), where Conv1D is imported from transformers.pytorch_utils.
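
Concretely, the updated check would look like this (a sketch; the surrounding loop is elided):

    from transformers.pytorch_utils import Conv1D

    # GPT-2 replaces nn.Linear with Conv1D (same operation, transposed
    # weight layout), so the check must accept both layer types:
    if isinstance(module, (torch.nn.Linear, Conv1D)) and "lm_head" not in name:
        ...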

@suyang160 force-pushed the bugfix/issue-2103-fix-pissa-init branch from 2a513f6 to 4d77af8 on October 8, 2024, 16:10
FIX: Transpose weight matrix based on fan_in_fan_out condition in PiSSA initialization (huggingface#2103)

This update addresses an issue where the weight matrix was converted to float32 without considering the need for transposition. The weight matrix is now transposed when the fan_in_fan_out condition is met, resolving dimension mismatch issues during GPT-2 training.

To ensure this fix is robust, tests have been updated to include parameterized cases for different devices and bit configurations. Additionally, the isinstance checks have been modified to include Conv1D layers, ensuring all relevant layers are processed correctly.
@suyang160 force-pushed the bugfix/issue-2103-fix-pissa-init branch from 4d77af8 to 1bf7d7a on October 8, 2024, 16:18
suyang160 (Contributor, Author)

@BenjaminBossan Thank you for your feedback and suggestions; I've updated the PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@BenjaminBossan left a comment (Member)

Thanks a lot for this fix, LGTM.

@BenjaminBossan merged commit a724834 into huggingface:main on Oct 8, 2024
14 checks passed
BenjaminBossan pushed a commit to BenjaminBossan/peft that referenced this pull request Oct 22, 2024
FIX: Transpose weight matrix based on fan_in_fan_out condition in PiSSA initialization (huggingface#2104)

Transpose weight matrix based on fan_in_fan_out condition in PiSSA initialization.

Co-authored-by: Yang Su <suyang360@gmail.com>