From 929b4f2973ec6a53ea4f0f03d21147ef8b8278be Mon Sep 17 00:00:00 2001
From: Woosuk Kwon
Date: Wed, 28 Feb 2024 13:03:28 -0800
Subject: [PATCH] Add LoRA support for Gemma (#3050)

---
 .buildkite/test-pipeline.yaml       |  2 +-
 csrc/punica/bgmv/bgmv_config.h      |  2 ++
 tests/lora/conftest.py              |  5 ++++
 tests/lora/test_gemma.py            | 46 +++++++++++++++++++++++++++++
 tests/lora/test_punica.py           |  4 +--
 vllm/model_executor/models/gemma.py | 28 ++++++++++++++++--
 vllm/model_executor/models/llama.py |  2 +-
 7 files changed, 82 insertions(+), 7 deletions(-)
 create mode 100644 tests/lora/test_gemma.py

diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index efcc4d2d07a12..c65ab04b8ddda 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -50,7 +50,7 @@ steps:
   command: pytest -v -s worker
 
 - label: LoRA Test
-  command: pytest -v -s lora
+  command: pytest -v -s lora --forked
 
 - label: Metrics Test
   command: pytest -v -s metrics
diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h
index ebf638f104c3f..d5fee9c40d00c 100644
--- a/csrc/punica/bgmv/bgmv_config.h
+++ b/csrc/punica/bgmv/bgmv_config.h
@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, narrow, 5120) \
   f(in_T, out_T, W_T, narrow, 5504) \
   f(in_T, out_T, W_T, narrow, 5632) \
+  f(in_T, out_T, W_T, narrow, 6144) \
   f(in_T, out_T, W_T, narrow, 6912) \
   f(in_T, out_T, W_T, narrow, 7168) \
   f(in_T, out_T, W_T, narrow, 8192) \
@@ -39,6 +40,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
   f(in_T, out_T, W_T, narrow, 14336) \
   f(in_T, out_T, W_T, narrow, 16384) \
   f(in_T, out_T, W_T, narrow, 20480) \
+  f(in_T, out_T, W_T, narrow, 24576) \
   f(in_T, out_T, W_T, narrow, 28672) \
   f(in_T, out_T, W_T, narrow, 32000) \
   f(in_T, out_T, W_T, narrow, 32256) \
diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py
index 75f4e41290c36..67273144ecd02 100644
--- a/tests/lora/conftest.py
+++ b/tests/lora/conftest.py
@@ -126,6 +126,11 @@ def mixtral_lora_files():
     return snapshot_download(repo_id="terrysun/mixtral-lora-adapter")
 
 
+@pytest.fixture(scope="session")
+def gemma_lora_files():
+    return snapshot_download(repo_id="wskwon/gemma-7b-test-lora")
+
+
 @pytest.fixture
 def llama_2_7b_engine_extra_embeddings() -> nn.Module:
     cleanup()
diff --git a/tests/lora/test_gemma.py b/tests/lora/test_gemma.py
new file mode 100644
index 0000000000000..0082c6e74e888
--- /dev/null
+++ b/tests/lora/test_gemma.py
@@ -0,0 +1,46 @@
+import vllm
+from vllm.lora.request import LoRARequest
+
+MODEL_PATH = "google/gemma-7b"
+
+
+def do_sample(llm, lora_path: str, lora_id: int) -> str:
+    prompts = [
+        "Quote: Imagination is",
+        "Quote: Be yourself;",
+        "Quote: So many books,",
+    ]
+    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+        if lora_id else None)
+    # Print the outputs.
+    generated_texts = []
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text.strip()
+        generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+    return generated_texts
+
+
+def test_gemma_lora(gemma_lora_files):
+    llm = vllm.LLM(MODEL_PATH,
+                   max_model_len=1024,
+                   enable_lora=True,
+                   max_loras=4)
+
+    expected_lora_output = [
+        "more important than knowledge.\nAuthor: Albert Einstein\n",
+        "everyone else is already taken.\nAuthor: Oscar Wilde\n",
+        "so little time\nAuthor: Frank Zappa\n",
+    ]
+
+    output1 = do_sample(llm, gemma_lora_files, lora_id=1)
+    for i in range(len(expected_lora_output)):
+        assert output1[i].startswith(expected_lora_output[i])
+    output2 = do_sample(llm, gemma_lora_files, lora_id=2)
+    for i in range(len(expected_lora_output)):
+        assert output2[i].startswith(expected_lora_output[i])
diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py
index 903814faa5dc7..cbe0f6fa2e851 100644
--- a/tests/lora/test_punica.py
+++ b/tests/lora/test_punica.py
@@ -44,8 +44,8 @@ def _lora_ref_impl(
 
 H1 = H2 = [
     128, 256, 512, 1024, 1280, 2048, 2560, 2752, 3072, 3456, 3584, 4096, 5120,
-    5504, 5632, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336, 32000,
-    32256, 32512, 32768, 33024
+    5504, 5632, 6144, 6912, 7168, 8192, 9216, 10240, 11008, 13824, 14336,
+    24576, 32000, 32256, 32512, 32768, 33024
 ]
 SEED = [0xabcdabcd987]
 
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index d8b515993d8ff..03948132d32c3 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -20,6 +20,7 @@
 from torch import nn
 from transformers import GemmaConfig
 
+from vllm.config import LoRAConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.attention import PagedAttention
@@ -246,12 +247,36 @@ def forward(
 
 
 class GemmaForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": [
+            "q_proj",
+            "k_proj",
+            "v_proj",
+        ],
+        "gate_up_proj": [
+            "gate_proj",
+            "up_proj",
+        ],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+    ]
+    # Gemma does not apply LoRA to the embedding layer.
+    embedding_modules = {}
+    embedding_padding_modules = []
 
     def __init__(
         self,
         config: GemmaConfig,
         linear_method: Optional[LinearMethodBase] = None,
+        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
+        del lora_config  # Unused.
         super().__init__()
         self.config = config
         self.linear_method = linear_method
@@ -305,9 +330,6 @@ def load_weights(self,
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                # Skip loading extra layer for lora models.
-                if "lm_head" in name:
-                    continue
                 # GemmaRMSNorm is different from Llama's in that it multiplies
                 # (1 + weight) to the output, instead of just weight.
                 if "norm.weight" in name:
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index b7f6b8f3ec374..d35887cc0f6a3 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -27,6 +27,7 @@
 from torch import nn
 from transformers import LlamaConfig
 
+from vllm.config import LoRAConfig
 from vllm.model_executor.input_metadata import InputMetadata
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.attention import PagedAttention
@@ -45,7 +46,6 @@
 from vllm.model_executor.weight_utils import (default_weight_loader,
                                               hf_model_weights_iterator)
 from vllm.sequence import SamplerOutput
-from vllm.config import LoRAConfig
 
 KVCache = Tuple[torch.Tensor, torch.Tensor]
 
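
A minimal usage sketch, distilled from the new tests/lora/test_gemma.py above, showing how the Gemma LoRA support added by this patch is exercised through the vLLM API at this commit (vllm.LLM, vllm.SamplingParams, vllm.lora.request.LoRARequest). The adapter path below is a placeholder rather than something defined by the patch; any locally downloaded Gemma-7B LoRA adapter directory (for example a snapshot of wskwon/gemma-7b-test-lora, the adapter used by the test) should work.

import vllm
from vllm.lora.request import LoRARequest

# Enable LoRA support when constructing the engine, as the new test does.
llm = vllm.LLM("google/gemma-7b",
               max_model_len=1024,
               enable_lora=True,
               max_loras=4)

# Placeholder path: point this at a locally downloaded Gemma LoRA adapter.
lora_path = "/path/to/gemma-lora-adapter"

sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32)

# LoRARequest takes a name, an integer id, and the local adapter path;
# passing lora_request=None would run the base model instead.
outputs = llm.generate(["Quote: Imagination is"],
                       sampling_params,
                       lora_request=LoRARequest("gemma-lora", 1, lora_path))
print(outputs[0].outputs[0].text)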