From 166b2bd2fa9900ca32c8e420f06b2ec7070035fe Mon Sep 17 00:00:00 2001
From: Steven Shimizu
Date: Mon, 16 Sep 2024 00:42:19 +0000
Subject: [PATCH 1/2] Updated liger-kernel integration in Trainer to call
 correct patching API

---
 src/transformers/trainer.py            | 13 +++++------
 src/transformers/utils/import_utils.py |  2 +-
 tests/trainer/test_trainer.py          | 31 ++++++++++++++++----------
 3 files changed, 26 insertions(+), 20 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index f815c50d597f..97a052093652 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -468,19 +468,18 @@ def __init__(
 
         if self.args.use_liger_kernel:
             if is_liger_kernel_available():
-                from liger_kernel.transformers.trainer_integration import _apply_liger_kernel
+                from liger_kernel.transformers import _apply_liger_kernel_to_instance
 
-                model_type = getattr(model, "config", None) and getattr(model.config, "model_type", None)
-                if model_type:
-                    # Monkey patch the model with liger kernels. Use the default kernel configurations.
-                    _apply_liger_kernel(model_type=model_type)
+                if isinstance(model, PreTrainedModel):
+                    # Patch the model with liger kernels. Use the default kernel configurations.
+                    _apply_liger_kernel_to_instance(model=model)
                 else:
                     logger.warning(
-                        "The model does not have a valid `model_type` specified. No liger kernels will be applied."
+                        "The model is not an instance of PreTrainedModel. No liger kernels will be applied."
                     )
             else:
                 raise ImportError(
-                    "You have set `use_liger_kernel` to `True` but liger-kernel >= 0.1.0 is not available. "
+                    "You have set `use_liger_kernel` to `True` but liger-kernel >= 0.3.0 is not available. "
                     "Please install it with `pip install liger-kernel`"
                 )
 
diff --git a/src/transformers/utils/import_utils.py b/src/transformers/utils/import_utils.py
index af8a56944346..ad8b649aaa4e 100755
--- a/src/transformers/utils/import_utils.py
+++ b/src/transformers/utils/import_utils.py
@@ -1187,7 +1187,7 @@ def is_liger_kernel_available():
     if not _liger_kernel_available:
         return False
 
-    return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.1.0")
+    return version.parse(importlib.metadata.version("liger_kernel")) >= version.parse("0.3.0")
 
 
 # docstyle-ignore
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 1837d9890352..50e736923402 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1344,22 +1344,29 @@ def test_get_eval_dataloader_with_persistent_workers(self):
 
     @require_liger_kernel
     def test_use_liger_kernel_patching(self):
-        # Test that the model code actually gets patched with Liger kernel
-        from liger_kernel.transformers.rms_norm import LigerRMSNorm
+        # Ensure any monkey patching is cleaned up for subsequent tests
+        with patch("transformers.models.llama.modeling_llama"):
+            from liger_kernel.transformers import LigerRMSNorm, liger_rotary_pos_emb
 
-        from transformers.models.llama import modeling_llama
+            from transformers.models.llama import modeling_llama
 
-        config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
-        tiny_llama = LlamaForCausalLM(config)
+            config = LlamaConfig(vocab_size=100, hidden_size=32, num_hidden_layers=3, num_attention_heads=4)
+            tiny_llama = LlamaForCausalLM(config)
 
-        args = TrainingArguments(
-            "./test",
-            use_liger_kernel=True,
-        )
-        Trainer(tiny_llama, args)
+            # Spot check that modeling code and model instance variables are not yet patched
+            self.assertNotEqual(modeling_llama.apply_rotary_pos_emb, liger_rotary_pos_emb)
+            self.assertFalse(isinstance(tiny_llama.model.norm, LigerRMSNorm))
+
+            args = TrainingArguments(
+                "./test",
+                use_liger_kernel=True,
+            )
+            Trainer(tiny_llama, args)
+
+            # Spot check that modeling code and model instance variables are patched
+            self.assertEqual(modeling_llama.apply_rotary_pos_emb, liger_rotary_pos_emb)
+            self.assertTrue(isinstance(tiny_llama.model.norm, LigerRMSNorm))
 
-        # Check that one of the Llama model layers has been correctly patched with Liger kernel
-        self.assertEqual(modeling_llama.LlamaRMSNorm, LigerRMSNorm)
 
     @require_liger_kernel
     @require_torch_gpu

From 26ea9e2d47bc29617b53cb12fe0b0154661c6291 Mon Sep 17 00:00:00 2001
From: Steven Shimizu
Date: Mon, 16 Sep 2024 00:42:40 +0000
Subject: [PATCH 2/2] Fixed styling

---
 tests/trainer/test_trainer.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 50e736923402..791486ec8374 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1367,7 +1367,6 @@ def test_use_liger_kernel_patching(self):
             self.assertEqual(modeling_llama.apply_rotary_pos_emb, liger_rotary_pos_emb)
             self.assertTrue(isinstance(tiny_llama.model.norm, LigerRMSNorm))
 
-
     @require_liger_kernel
     @require_torch_gpu
     def test_use_liger_kernel_trainer(self):