
🐯 Fix LigerKernel for SFTTrainer #2940

Merged
merged 3 commits into main on Feb 24, 2025

Conversation

lewtun (Member) commented Feb 23, 2025

What does this PR do?

Fixes an issue where use_liger=True raised an error during loss computation because the fix from #2874 was partially reverted by a line in #2890. Without this change, the loss computation fails with:

[rank0]:     shift_logits = outputs.logits[..., :-1, :].contiguous()
[rank0]:                    ~~~~~~~~~~~~~~^^^^^^^^^^^^^
[rank0]: TypeError: 'NoneType' object is not subscriptable

Command to test:

trl sft --model_name_or_path Qwen/Qwen2.5-0.5B --dataset_name trl-lib/Capybara --output_dir Qwen2.5-0.5B-SFT --use_liger true
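
For reference, a rough Python-API equivalent of the command above (a minimal sketch, not taken from the PR; it assumes the SFTConfig use_liger flag discussed here and the trl-lib/Capybara train split):

from datasets import load_dataset
from trl import SFTConfig, SFTTrainer

# Reproduce the CLI command programmatically: load the dataset, enable the
# Liger path via the config flag, and pass the model as a string path so the
# trainer builds it internally.
dataset = load_dataset("trl-lib/Capybara", split="train")
training_args = SFTConfig(output_dir="Qwen2.5-0.5B-SFT", use_liger=True)
trainer = SFTTrainer(model="Qwen/Qwen2.5-0.5B", args=training_args, train_dataset=dataset)
trainer.train()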

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

lewtun requested review from kashif and qgallouedec on February 23, 2025 at 19:21
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@@ -173,7 +173,6 @@ def __init__(
)
if isinstance(model, str):
model = self._create_model_from_path(model, args)
self.use_liger = is_liger_kernel_available() and isinstance(model, AutoLigerKernelForCausalLM)
Member
What if the model used is already a liger model (and args.use_liger = False)?

Member Author
Good point, but as recommended by the Liger maintainer, we shouldn't be passing Liger models at init (they should just be patched via the config).

Should we deprecate passing the Liger model to the trainer, or would you prefer an alternative?

Member
It seems that this command doesn't achieve what it's supposed to:

>>> from liger_kernel.transformers import AutoLigerKernelForCausalLM
>>> model =  AutoLigerKernelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
Applied Liger kernels to Qwen2
Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.
>>> isinstance(model, AutoLigerKernelForCausalLM)
False

I don't know of an easy way to test whether a model is liger (perhaps @ByronHsu does?).

Anyway, with your change you can still pass a Liger model to the trainer, but you'll need to specify use_liger=True, which sounds good to me.

Contributor
This line doesn't work when a Liger model is converted to PEFT before being passed into the trainer. It doesn't respect args.use_liger either.

Contributor
> Good point, but as recommended by the Liger maintainer, we shouldn't be passing Liger models at init (they should just be patched via the config).
>
> Should we deprecate passing the Liger model to the trainer, or would you prefer an alternative?

Deprecating it might cause some problems. LoRA+ requires the model instance to be created beforehand (to create the optimizer), and the use_liger flag does not convert a PEFT-wrapped model into a Liger model:

import torch
from peft import get_peft_model
from peft.optimizers import create_loraplus_optimizer  # available in recent PEFT releases

model = get_peft_model(model, lora_config)
optimizer = create_loraplus_optimizer(
    model=model,
    optimizer_cls=torch.optim.AdamW,
    lr=lr,
    eps=eps,
    betas=betas,
    weight_decay=weight_decay,
    loraplus_lr_ratio=loraplus_lr_ratio,
)

Contributor
> The use_liger flag does not convert a PEFT-wrapped model into a Liger model.

In sft_trainer.py:

        if args.use_liger:
            if not is_liger_kernel_available():
                raise ImportError("Please install Liger-kernel for use_liger=True")
            model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
        else:
            model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
        return model
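
To illustrate the point (a hedged sketch, not code from the PR; the config values and dataset are placeholders), a model that is already instantiated and PEFT-wrapped never goes through _create_model_from_path, so use_liger=True does not patch it:

from datasets import load_dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM
from trl import SFTConfig, SFTTrainer

base = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2.5-0.5B")
peft_model = get_peft_model(base, LoraConfig(task_type="CAUSAL_LM"))
dataset = load_dataset("trl-lib/Capybara", split="train")

# Because a model instance (not a string path) is passed, the branch above is
# skipped entirely and the Liger kernels are never applied to this model.
trainer = SFTTrainer(
    model=peft_model,
    args=SFTConfig(output_dir="tmp-sft", use_liger=True),
    train_dataset=dataset,
)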

Contributor
Also, many VLMs can't be loaded with AutoLigerKernelForCausalLM. For example, monkey-patching with apply_liger_kernel_to_qwen2_vl() is required for Qwen2-VL.
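
For illustration, a minimal sketch of that monkey-patching route (it assumes apply_liger_kernel_to_qwen2_vl is exported by liger_kernel.transformers and must be called before the model is instantiated; the model id is just an example):

from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl
from transformers import Qwen2VLForConditionalGeneration

# Patch the Qwen2-VL modeling code in place, then load the model as usual.
apply_liger_kernel_to_qwen2_vl()
model = Qwen2VLForConditionalGeneration.from_pretrained("Qwen/Qwen2-VL-2B-Instruct")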

edbeeching (Collaborator)

@kashif @qgallouedec @lewtun ,

I think Liger models are now supported natively in transformers if the use_liger_kernel=True flag is set; perhaps we can drop the support for this in the SFTTrainer and use the native transformers implementation?

kashif (Collaborator) commented Feb 24, 2025

I think so too... we will need to pin the transformers version, but yes, it should be a better solution.

qgallouedec (Member) commented Feb 24, 2025

Thanks @edbeeching!

> I think so too... we will need to pin the transformers version, but yes, it should be a better solution.

After checking, it seems like use_liger_kernel exists for at least 4.46, which is the minimum version in TRL, so we shouldn't need to bump transformers:

https://github.com/huggingface/transformers/blob/052e652d6d53c2b26ffde87e039b723949a53493/src/transformers/training_args.py#L1521
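
For illustration, a hedged sketch of what the native route could look like (assuming SFTConfig inherits use_liger_kernel from transformers.TrainingArguments and does not shadow it):

from trl import SFTConfig

# use_liger_kernel is handled by the transformers Trainer itself (since v4.45),
# so no TRL-specific Liger handling would be needed in SFTTrainer.
training_args = SFTConfig(output_dir="Qwen2.5-0.5B-SFT", use_liger_kernel=True)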

qgallouedec (Member) commented Feb 24, 2025

It was introduced in 4.45.

lewtun (Member Author) commented Feb 24, 2025

Thanks for the pointer to transformers! Just to confirm, the current PR is fine to merge since the model init is taken care of in this line? https://github.com/huggingface/trl/pull/2940/files#r1967113819

If yes, feel free to merge if I'm offline :)

qgallouedec changed the title from "Fix LigerKernel for SFTTrainer" to "🐯 Fix LigerKernel for SFTTrainer" on Feb 24, 2025
qgallouedec merged commit 5c05913 into main on Feb 24, 2025
14 checks passed
qgallouedec deleted the fix-liger-sft branch on February 24, 2025 at 16:29
qgallouedec added a commit that referenced this pull request Feb 25, 2025
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>
jhinpan pushed a commit to jhinpan/trl-jin that referenced this pull request Mar 12, 2025
Co-authored-by: Quentin Gallouédec <45557362+qgallouedec@users.noreply.github.com>