diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py
index 2dabfb6b4337..a3ffc53d8cd1 100644
--- a/tests/lora/conftest.py
+++ b/tests/lora/conftest.py
@@ -143,6 +143,12 @@ def baichuan_lora_files():
     return snapshot_download(repo_id="jeeejeee/baichuan7b-text2sql-spider")
 
 
+@pytest.fixture(scope="session")
+def baichuan_zero_lora_files():
+    # All of the lora_B weights are initialized to zero.
+    return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init")
+
+
 @pytest.fixture(scope="session")
 def tinyllama_lora_files():
     return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")
diff --git a/tests/lora/test_lora_checkpoints.py b/tests/lora/test_lora_checkpoints.py
index 35ad7342944c..d4d1665b624e 100644
--- a/tests/lora/test_lora_checkpoints.py
+++ b/tests/lora/test_lora_checkpoints.py
@@ -3,9 +3,16 @@
 from vllm.lora.models import LoRAModel
 from vllm.model_executor.models.baichuan import BaiChuanBaseForCausalLM
 
+lora_lst = ["baichuan7B", "baichuan7B-zero", "chatglm3-6b"]
 
-@pytest.mark.parametrize("lora_name", ["baichuan7B", "chatglm3-6b"])
-def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
+
+@pytest.mark.parametrize("lora_name", lora_lst)
+def test_load_checkpoints(
+    lora_name,
+    baichuan_lora_files,
+    baichuan_zero_lora_files,
+    chatglm3_lora_files,
+):
     supported_lora_modules = BaiChuanBaseForCausalLM.supported_lora_modules
     packed_modules_mapping = BaiChuanBaseForCausalLM.packed_modules_mapping
     embedding_modules = BaiChuanBaseForCausalLM.embedding_modules
@@ -26,6 +33,17 @@ def test_load_checkpoints(lora_name, chatglm3_lora_files, baichuan_lora_files):
             device="cpu",
             embedding_modules=embedding_modules,
             embedding_padding_modules=embed_padding_modules)
+    elif lora_name == "baichuan7B-zero":
+        # Test that target_modules entries may carry a full prefix,
+        # e.g. "model.layers.0.self_attn.W_pack";
+        # the test should pass.
+        LoRAModel.from_local_checkpoint(
+            baichuan_zero_lora_files,
+            expected_lora_modules,
+            lora_model_id=1,
+            device="cpu",
+            embedding_modules=embedding_modules,
+            embedding_padding_modules=embed_padding_modules)
     else:
         # For the baichuan7B model, load chatglm3-6b's LoRA,
         # and the test should raise the following error.
diff --git a/vllm/lora/models.py b/vllm/lora/models.py
index 62f150245800..6bb9fee27d53 100644
--- a/vllm/lora/models.py
+++ b/vllm/lora/models.py
@@ -212,7 +212,9 @@ def from_local_checkpoint(
             target_modules = config["target_modules"]
             unexpected_modules = []
             for module in target_modules:
-                if module not in expected_lora_modules:
+                # Be compatible with prefixed names, e.g. layers.11.self_attn.k_proj
+                part_name = module.split(".")[-1]
+                if part_name not in expected_lora_modules:
                     unexpected_modules.append(module)
             # loaded lora's target modules must be a subset of expected_lora_modules
             if unexpected_modules:
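
For reference, a minimal standalone sketch of the relaxed matching rule introduced in vllm/lora/models.py above (the helper name find_unexpected_modules is illustrative, not part of vLLM): an entry in target_modules is accepted when its last dot-separated component is one of the expected LoRA modules, so fully prefixed names such as "model.layers.0.self_attn.W_pack" now pass the check.

def find_unexpected_modules(target_modules, expected_lora_modules):
    """Return target_modules entries whose last path component is not expected."""
    unexpected_modules = []
    for module in target_modules:
        # Only the final dot-separated component is compared, so a prefixed
        # name like "model.layers.0.self_attn.W_pack" matches "W_pack".
        part_name = module.split(".")[-1]
        if part_name not in expected_lora_modules:
            unexpected_modules.append(module)
    return unexpected_modules


# Prefixed and bare names both pass; an unknown module would be reported.
assert not find_unexpected_modules(
    ["model.layers.0.self_attn.W_pack", "gate_up_proj"],
    ["W_pack", "o_proj", "gate_up_proj", "down_proj"],
)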