rename use_liger to use_liger_kernel
JasonZhu1313 committed Aug 20, 2024
1 parent fc05ba6 commit b2bae31
Showing 4 changed files with 8 additions and 8 deletions.
6 changes: 3 additions & 3 deletions docs/source/en/trainer.md
@@ -395,7 +395,7 @@ First make sure to install Liger official repository:
pip install liger-kernel
```

- You should pass `use_liger=True` to apply liger kernel on your model, for example:
+ You should pass `use_liger_kernel=True` to apply liger kernel on your model, for example:

```py
from transformers import TrainingArguments
@@ -411,11 +411,11 @@ training_args = TrainingArguments(
save_strategy="epoch",
load_best_model_at_end=True,
push_to_hub=True,
- use_liger=True
+ use_liger_kernel=True
)
```

- The kernel supports the Llama, Gemma, Mistral, and Mixtral model architectures. The most up-to-date list of supported models can be found [here](https://github.com/linkedin/Liger-Kernel). When `use_liger` is set to `True`, the corresponding layers in the original model will be patched with Liger's efficient implementation, so you don't need to do anything extra other than setting the argument value.
+ The kernel supports the Llama, Gemma, Mistral, and Mixtral model architectures. The most up-to-date list of supported models can be found [here](https://github.com/linkedin/Liger-Kernel). When `use_liger_kernel` is set to `True`, the corresponding layers in the original model will be patched with Liger's efficient implementation, so you don't need to do anything extra other than setting the argument value.

## LOMO optimizer

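For context beyond the doc snippet above, here is a minimal end-to-end sketch of the renamed flag in use with `Trainer`. The checkpoint name, toy dataset, and hyperparameters are illustrative assumptions, not taken from the commit; liger-kernel only patches supported architectures such as Llama, Gemma, Mistral, and Mixtral.

```py
# Illustrative sketch only: checkpoint, data, and hyperparameters are assumptions.
from datasets import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    Trainer,
    TrainingArguments,
)

checkpoint = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"  # assumed example; any Liger-supported architecture works
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token  # Llama tokenizers ship without a pad token
model = AutoModelForCausalLM.from_pretrained(checkpoint)

# Tiny in-memory dataset so the example is self-contained.
raw = Dataset.from_dict({"text": ["Liger kernels reduce memory usage."] * 8})
tokenized = raw.map(lambda b: tokenizer(b["text"], truncation=True), batched=True, remove_columns=["text"])

training_args = TrainingArguments(
    output_dir="liger-demo",
    per_device_train_batch_size=1,
    num_train_epochs=1,
    use_liger_kernel=True,  # the renamed flag; Liger patches the model's layers in place
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized,
    data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
trainer.train()
```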
4 changes: 2 additions & 2 deletions src/transformers/trainer.py
@@ -464,7 +464,7 @@ def __init__(
" to `True` to avoid any unexpected behavior such as device placement mismatching."
)

- if self.args.use_liger:
+ if self.args.use_liger_kernel:
if is_liger_kernel_available():
from liger_kernel.transformers.trainer_integration import _apply_liger_kernel

@@ -478,7 +478,7 @@
)
else:
raise ImportError(
"You have set `use_liger` to `True` but liger-kernel >= 0.1.0 is not available. "
"You have set `use_liger_kernel` to `True` but liger-kernel >= 0.1.0 is not available. "
"Please install it with `pip install liger-kernel`"
)

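As the hunks above show, the Trainer raises an `ImportError` when the renamed flag is set but liger-kernel is missing. Scripts that want the kernel to stay optional can therefore gate the flag on the same availability check. A hedged sketch follows; the assumption is that `is_liger_kernel_available` is exported from `transformers.utils`, which this diff does not itself show.

```py
# Hedged sketch: assumes is_liger_kernel_available is importable from transformers.utils.
from transformers import TrainingArguments
from transformers.utils import is_liger_kernel_available

args = TrainingArguments(
    output_dir="out",
    # Request the Liger kernel only when the package is installed, so the same
    # script also runs on machines without liger-kernel.
    use_liger_kernel=is_liger_kernel_available(),
)
```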
4 changes: 2 additions & 2 deletions src/transformers/training_args.py
@@ -792,7 +792,7 @@ class TrainingArguments:
eval_use_gather_object (`bool`, *optional*, defaults to `False`):
Whether to run recursively gather object in a nested list/tuple/dictionary of objects from all devices. This should only be enabled if users are not just returning tensors, and this is actively discouraged by PyTorch.
- use_liger (`bool`, *optional*, defaults to `False`):
+ use_liger_kernel (`bool`, *optional*, defaults to `False`):
Whether enable [Liger](https://github.com/linkedin/Liger-Kernel) Kernel for LLM model training.
It can effectively increase multi-GPU training throughput by ~20% and reduces memory usage by ~60%, works out of the box with
flash attention, PyTorch FSDP, and Microsoft DeepSpeed. Currently, it supports llama, mistral, mixtral and gemma models.
@@ -1496,7 +1496,7 @@ class TrainingArguments:
},
)

- use_liger: Optional[bool] = field(
+ use_liger_kernel: Optional[bool] = field(
default=False,
metadata={"help": "Whether or not to enable the Liger Kernel for model training."},
)
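Because `use_liger_kernel` is an ordinary `TrainingArguments` dataclass field, it is also exposed as a command-line flag through `HfArgumentParser`. A short hedged sketch; the script below is illustrative and not part of the commit.

```py
# Illustrative sketch: read the renamed flag from the command line, e.g.
#   python train.py --output_dir out --use_liger_kernel True
from transformers import HfArgumentParser, TrainingArguments

parser = HfArgumentParser(TrainingArguments)
(training_args,) = parser.parse_args_into_dataclasses()
print(training_args.use_liger_kernel)
```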
2 changes: 1 addition & 1 deletion tests/trainer/test_trainer.py
@@ -1337,7 +1337,7 @@ def test_apply_liger_kernel(self):

args = TrainingArguments(
"./test",
- use_liger=True,
+ use_liger_kernel=True,
)
Trainer(tiny_model, args)

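One side effect of the rename that the updated test does not exercise: the old keyword no longer exists, so callers still passing `use_liger=True` fail when the arguments object is built. A small illustration (an assumption about caller code, not part of the commit):

```py
# Illustration: after this commit the old keyword is rejected outright.
from transformers import TrainingArguments

try:
    TrainingArguments("./test", use_liger=True)
except TypeError as err:
    print(err)  # unexpected keyword argument 'use_liger'
```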
