Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate Liger (Linkedin GPU Efficient Runtime) Kernel to Trainer #32860

Merged
merged 25 commits into from
Aug 23, 2024

Conversation

JasonZhu1313
Copy link
Contributor

@JasonZhu1313 JasonZhu1313 commented Aug 17, 2024

What does this PR do?

Integrate Liger (Linkedin GPU Efficient Runtime) Kernel to HF Trainer with optional flag

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

Tests:

  • pytest tests/trainer/test_trainer.py::TrainerIntegrationTest::test_apply_liger_kernel
pytest tests/trainer/test_trainer.py::TrainerIntegrationTest::test_use_liger_kernel_patching tests/trainer/test_trainer.py::TrainerIntegrationTest::test_use_liger_kernel_trainer

======================================= test session starts ========================================
platform linux -- Python 3.10.12, pytest-7.4.4, pluggy-1.5.0
rootdir: /content/transformers-jaszhu
configfile: pyproject.toml
plugins: rich-0.1.1, timeout-2.3.1, xdist-3.6.1
collected 2 items

tests/trainer/test_trainer.py ..                                                             [100%]

======================================== 2 passed in 9.47s =========================================


  • E2E test
{'loss': 1.6157, 'grad_norm': 32.0, 'learning_rate': 2.4324324324324326e-07, 'epoch': 0.0, 'num_input_tokens_seen': 60416, 'step': 3, 'step_time_sec': 4.87, 'avg_step_time_sec': 6.82, 'time_to_completion_sec': 4970.4, 'estimated_total_time_sec': 4990.85, 'step_peak_memory_allocated_MB': 76728.45, 'total_peak_memory_allocated_MB': 76728.74, 'step_peak_memory_reserved_MB': 79692.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3138.55, 'avg_tokens_per_second': 3158.65}
{'loss': 1.5678, 'grad_norm': 26.875, 'learning_rate': 3.2432432432432436e-07, 'epoch': 0.01, 'num_input_tokens_seen': 84992, 'step': 4, 'step_time_sec': 7.82, 'avg_step_time_sec': 7.15, 'time_to_completion_sec': 5206.53, 'estimated_total_time_sec': 5235.14, 'step_peak_memory_allocated_MB': 76728.67, 'total_peak_memory_allocated_MB': 76728.74, 'step_peak_memory_reserved_MB': 80194.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3142.99, 'avg_tokens_per_second': 3152.94}
{'loss': 1.74, 'grad_norm': 28.875, 'learning_rate': 4.0540540540540546e-07, 'epoch': 0.01, 'num_input_tokens_seen': 103936, 'step': 5, 'step_time_sec': 5.75, 'avg_step_time_sec': 6.8, 'time_to_completion_sec': 4945.07, 'estimated_total_time_sec': 4979.08, 'step_peak_memory_allocated_MB': 76728.54, 'total_peak_memory_allocated_MB': 76728.74, 'step_peak_memory_reserved_MB': 80324.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3293.14, 'avg_tokens_per_second': 3182.59}
{'loss': 1.7297, 'grad_norm': 29.25, 'learning_rate': 4.864864864864865e-07, 'epoch': 0.01, 'num_input_tokens_seen': 124416, 'step': 6, 'step_time_sec': 6.23, 'avg_step_time_sec': 6.69, 'time_to_completion_sec': 4855.78, 'estimated_total_time_sec': 4895.91, 'step_peak_memory_allocated_MB': 76728.57, 'total_peak_memory_allocated_MB': 76728.74, 'step_peak_memory_reserved_MB': 80288.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3285.22, 'avg_tokens_per_second': 3201.72}
{'loss': 1.6393, 'grad_norm': 27.75, 'learning_rate': 5.675675675675676e-07, 'epoch': 0.01, 'num_input_tokens_seen': 153920, 'step': 7, 'step_time_sec': 9.22, 'avg_step_time_sec': 7.11, 'time_to_completion_sec': 5154.73, 'estimated_total_time_sec': 5204.5, 'step_peak_memory_allocated_MB': 76728.78, 'total_peak_memory_allocated_MB': 76728.78, 'step_peak_memory_reserved_MB': 79652.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3200.77, 'avg_tokens_per_second': 3201.51}
{'loss': 1.5642, 'grad_norm': 27.25, 'learning_rate': 6.486486486486487e-07, 'epoch': 0.01, 'num_input_tokens_seen': 170752, 'step': 8, 'step_time_sec': 5.49, 'avg_step_time_sec': 6.88, 'time_to_completion_sec': 4980.15, 'estimated_total_time_sec': 5035.18, 'step_peak_memory_allocated_MB': 76728.49, 'total_peak_memory_allocated_MB': 76728.78, 'step_peak_memory_reserved_MB': 79988.0, 'total_peak_memory_reserved_MB': 80364.0, 'step_tokens_per_second': 3065.48, 'avg_tokens_per_second': 3186.0}

  • When liger is lower version, the error is thrown ImportError: You have set use_ligertoTruebut liger-kernel >= 0.1.0 is not available. Please install it withpip install liger-kernel`
  • Model type is correct extracted as "llama"

Screenshot 2024-08-20 at 3 13 45 PM

Test conditions: LLaMA 3-8B, Batch Size = 64, Data Type = bf16, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 4 A100s.

When use_liger=Ture, memory usage and throughput shows improvement compared to use_liger=False, default value

image (3)
image (4)

Note: for more detailed benchmark setup and more exciting efficiency for multi-head training (Medusa), please refer to original repo: https://github.com/linkedin/Liger-Kernel (repo will be public soon!!!)

@amyeroberts
Copy link
Collaborator

cc @ArthurZucker @muellerzr

Copy link
Collaborator

@ArthurZucker ArthurZucker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds great to me! Let's make sure we add a tad bit of doc about it! 🤗

@JasonZhu1313 JasonZhu1313 changed the title [WIP] Integrate Liger (Linkedin GPU Efficient Runtime) Kernel to Trainer Integrate Liger (Linkedin GPU Efficient Runtime) Kernel to Trainer Aug 19, 2024
Copy link
Contributor

@muellerzr muellerzr left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks! Can you rebase from main? (This should fix the CI I think)

@SunMarc SunMarc mentioned this pull request Aug 20, 2024
5 tasks
shimizust and others added 6 commits August 20, 2024 09:35
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
@shimizust shimizust force-pushed the jaszhu/liger-kernel branch from ade13f4 to 8639629 Compare August 20, 2024 16:51
@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@JasonZhu1313 JasonZhu1313 marked this pull request as ready for review August 20, 2024 21:52
docs/source/en/trainer.md Outdated Show resolved Hide resolved
src/transformers/training_args.py Outdated Show resolved Hide resolved
@ByronHsu
Copy link
Contributor

lgtm!

Copy link
Member

@SunMarc SunMarc left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice ! Just a nit ! Also, let us know when you want to merge this PR as the Liger repo is still not public.

tests/trainer/test_trainer.py Outdated Show resolved Hide resolved
@muellerzr
Copy link
Contributor

@JasonZhu1313 if you run make fixup it should fix the quality tests :) Otherwise as Marc said, let us know when we're okay to land this and we'll merge it immediately 🚀

@JasonZhu1313
Copy link
Contributor Author

@JasonZhu1313 if you run make fixup it should fix the quality tests :) Otherwise as Marc said, let us know when we're okay to land this and we'll merge it immediately 🚀

Thanks the repo will be open sourced on Friday

@JasonZhu1313
Copy link
Contributor Author

@JasonZhu1313 if you run make fixup it should fix the quality tests :) Otherwise as Marc said, let us know when we're okay to land this and we'll merge it immediately 🚀

Thanks the repo will be open sourced on Friday

The code is open to public, we are ready to merge the PR!

Copy link
Contributor

@ByronHsu ByronHsu left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Excited to collaborate with Hugging Face!!

@SunMarc
Copy link
Member

SunMarc commented Aug 23, 2024

Nice ! Merging !

@SunMarc SunMarc merged commit adb9117 into huggingface:main Aug 23, 2024
24 checks passed
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Aug 30, 2024
…uggingface#32860)

* add liger integration

* fix syntax

* fix import issue

* add trainer.md

* Use _apply_liger_kernel()

* Fixed log message

* Update docs/source/en/trainer.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update docs/source/en/trainer.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Update src/transformers/trainer.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Update docs/source/en/trainer.md

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Fixed checkstyle and updated readme

* Added test

* Fixed checkstyle

* fix docstring

* rename use_liger to use_liger_kernel

* Trigger Build

* Added test

* add fix-copies

* Fixed copy inconsistencies

---------

Co-authored-by: shimizust <sshimizu@linkedin.com>
Co-authored-by: Steven Shimizu <shimizust@gmail.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Aug 30, 2024
…uggingface#32860)

* add liger integration

* fix syntax

* fix import issue

* add trainer.md

* Use _apply_liger_kernel()

* Fixed log message

* Update docs/source/en/trainer.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update docs/source/en/trainer.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Update src/transformers/trainer.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Update docs/source/en/trainer.md

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Fixed checkstyle and updated readme

* Added test

* Fixed checkstyle

* fix docstring

* rename use_liger to use_liger_kernel

* Trigger Build

* Added test

* add fix-copies

* Fixed copy inconsistencies

---------

Co-authored-by: shimizust <sshimizu@linkedin.com>
Co-authored-by: Steven Shimizu <shimizust@gmail.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
itazap pushed a commit to NielsRogge/transformers that referenced this pull request Sep 20, 2024
…uggingface#32860)

* add liger integration

* fix syntax

* fix import issue

* add trainer.md

* Use _apply_liger_kernel()

* Fixed log message

* Update docs/source/en/trainer.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update docs/source/en/trainer.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Update src/transformers/trainer.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Update docs/source/en/trainer.md

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Fixed checkstyle and updated readme

* Added test

* Fixed checkstyle

* fix docstring

* rename use_liger to use_liger_kernel

* Trigger Build

* Added test

* add fix-copies

* Fixed copy inconsistencies

---------

Co-authored-by: shimizust <sshimizu@linkedin.com>
Co-authored-by: Steven Shimizu <shimizust@gmail.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
BernardZach pushed a commit to BernardZach/transformers that referenced this pull request Dec 5, 2024
…uggingface#32860)

* add liger integration

* fix syntax

* fix import issue

* add trainer.md

* Use _apply_liger_kernel()

* Fixed log message

* Update docs/source/en/trainer.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update docs/source/en/trainer.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Update src/transformers/trainer.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Update docs/source/en/trainer.md

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Fixed checkstyle and updated readme

* Added test

* Fixed checkstyle

* fix docstring

* rename use_liger to use_liger_kernel

* Trigger Build

* Added test

* add fix-copies

* Fixed copy inconsistencies

---------

Co-authored-by: shimizust <sshimizu@linkedin.com>
Co-authored-by: Steven Shimizu <shimizust@gmail.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
BernardZach pushed a commit to innovationcore/transformers that referenced this pull request Dec 6, 2024
…uggingface#32860)

* add liger integration

* fix syntax

* fix import issue

* add trainer.md

* Use _apply_liger_kernel()

* Fixed log message

* Update docs/source/en/trainer.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update docs/source/en/trainer.md

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Update src/transformers/trainer.py

Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>

* Update src/transformers/training_args.py

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Update docs/source/en/trainer.md

Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>

* Fixed checkstyle and updated readme

* Added test

* Fixed checkstyle

* fix docstring

* rename use_liger to use_liger_kernel

* Trigger Build

* Added test

* add fix-copies

* Fixed copy inconsistencies

---------

Co-authored-by: shimizust <sshimizu@linkedin.com>
Co-authored-by: Steven Shimizu <shimizust@gmail.com>
Co-authored-by: Marc Sun <57196510+SunMarc@users.noreply.github.com>
Co-authored-by: Byron Hsu <byronhsu1230@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

9 participants