Commit 0bba88d

Enhance lora tests with more layer and rank variations (#3243)
1 parent 8437bae commit 0bba88d

File tree

csrc/punica/bgmv/bgmv_config.h
requirements-dev.txt
tests/lora/test_layer_variation.py

3 files changed: +106 -0 lines changed

csrc/punica/bgmv/bgmv_config.h (+1)

@@ -14,6 +14,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
     f(in_T, out_T, W_T, narrow, 128) \
     f(in_T, out_T, W_T, narrow, 256) \
     f(in_T, out_T, W_T, narrow, 512) \
+    f(in_T, out_T, W_T, narrow, 768) \
     f(in_T, out_T, W_T, narrow, 1024) \
     f(in_T, out_T, W_T, narrow, 1280) \
     f(in_T, out_T, W_T, narrow, 1728) \
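
The new 768 entry extends the set of feature dimensions for which the punica BGMV LoRA kernels are instantiated. It is presumably needed because the new test below uses Felladrin/Llama-68M-Chat-v1, whose hidden size is 768; a minimal check of that assumption (not part of the commit) would be:

# Hypothetical check, not part of the commit: confirm the hidden size the
# new 768 kernel dimension is presumably meant to cover.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("Felladrin/Llama-68M-Chat-v1")
print(cfg.hidden_size)  # expected to print 768 if the assumption holds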

requirements-dev.txt (+1)

@@ -21,6 +21,7 @@ einops # required for MPT
 openai
 requests
 ray
+peft
 
 # Benchmarking
 aiohttp
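
peft is added because the new test builds LoRA adapters with it and merges them back into the base model before serving. A minimal sketch of that workflow, for illustration only (it uses peft's standard get_peft_model / merge_and_unload API rather than the exact calls in the test, and the output path is hypothetical):

# Illustration only: the peft adapter-then-merge pattern the new test relies on.
import peft
from transformers import AutoModelForCausalLM

base = AutoModelForCausalLM.from_pretrained("Felladrin/Llama-68M-Chat-v1")
lora_cfg = peft.LoraConfig(r=8, target_modules=["q_proj", "v_proj"])
lora_model = peft.get_peft_model(base, lora_cfg)  # wrap the base model with a LoRA adapter
merged = lora_model.merge_and_unload()            # fold the LoRA weights back into the base
merged.save_pretrained("/tmp/merged-llama-68m")   # hypothetical path; servable without enable_lora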

tests/lora/test_layer_variation.py (+104)

@@ -0,0 +1,104 @@
+from typing import List, Optional
+import peft
+import pytest
+from random import sample
+import tempfile
+from transformers import AutoModelForCausalLM
+
+import vllm
+from vllm.lora.request import LoRARequest
+from .conftest import cleanup
+
+MODEL_PATH = "Felladrin/Llama-68M-Chat-v1"
+PROMPTS = [
+    "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nSpellForce 3 is a pretty bad game. The developer Grimlore Games is clearly a bunch of no-talent hacks, and 2017 was a terrible year for games anyway. [/user] [assistant]",
+    "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nI wanted to like Grimlore Games' 2017 entry, but in SpellForce 3 they just didn't get anything right. [/user] [assistant]",
+    "[system] Given a target sentence construct the underlying meaning representation\nof the input sentence as a single function with attributes and attribute\nvalues. This function should describe the target string accurately and the\nfunction must be one of the following ['inform', 'request', 'give_opinion',\n'confirm', 'verify_attribute', 'suggest', 'request_explanation',\n'recommend', 'request_attribute'].\n\nThe attributes must be one of the following:\n['name', 'exp_release_date', 'release_year', 'developer', 'esrb', 'rating',\n'genres', 'player_perspective', 'has_multiplayer', 'platforms',\n'available_on_steam', 'has_linux_release', 'has_mac_release', 'specifier'] [/system] [user] Here is the target sentence:\nBioShock is a good role-playing, action-adventure, shooter that released for PlayStation, Xbox, and PC in 2007. It is available on Steam, and it has a Mac release but not a Linux release. [/user] [assistant]",
+]
+
+
+def get_lora_model(model_id: str, target_modules: List[str], rank: int):
+    model = AutoModelForCausalLM.from_pretrained(model_id)
+    lora_config = peft.tuners.lora.LoraConfig(target_modules, rank)
+    lora_model = peft.PeftModel(model, lora_config)
+    return lora_model
+
+
+def do_sample(llm,
+              lora_path: Optional[str] = None,
+              lora_id: Optional[int] = None,
+              logprobs: int = 0,
+              n_tokens: int = 256):
+    prompts = PROMPTS
+    sampling_params = vllm.SamplingParams(temperature=0,
+                                          max_tokens=n_tokens,
+                                          logprobs=logprobs,
+                                          stop=["[/assistant]"])
+    outputs = llm.generate(
+        prompts,
+        sampling_params,
+        lora_request=LoRARequest(str(lora_id), lora_id, lora_path)
+        if lora_id else None)
+    # Print the outputs.
+    generated_texts = []
+    generated_logprobs = []
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        generated_texts.append(generated_text)
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+        generated_logprobs.append([
+            list(logprob.keys()) for out in output.outputs
+            for logprob in out.logprobs
+        ])
+    return generated_logprobs if logprobs else generated_texts
+
+
+SUPPORTED_MODULES = [
+    "qkv_proj", "o_proj", "gate_up_proj", "down_proj", "embed_tokens",
+    "lm_head"
+]
+TARGET_MODULES_LIST = []
+for length in range(2, 6):
+    TARGET_MODULES_LIST.extend(
+        [sample(SUPPORTED_MODULES, length) for _ in range(3)])
+
+
+# Test the correctness when layer and rank are varied
+# step 1: init a base model and serve with LoRA to get the reference results
+# step 2: merge the same LoRA to the base model, serve the merged model
+# step 3: compare the results from step 1 and step 2
+@pytest.mark.parametrize("tp_size", [1])
+@pytest.mark.parametrize("target_modules", TARGET_MODULES_LIST)
+@pytest.mark.parametrize("rank", [8, 16, 32, 64])
+def test_layer_variation_correctness(tp_size, target_modules, rank):
+    llm = vllm.LLM(MODEL_PATH,
+                   enable_lora=True,
+                   max_num_seqs=16,
+                   max_loras=4,
+                   tensor_parallel_size=tp_size,
+                   worker_use_ray=True)
+    model = get_lora_model(MODEL_PATH, target_modules, rank)
+    with tempfile.TemporaryDirectory() as tmpdir:
+        model.save_pretrained(tmpdir)
+        merged_probs = do_sample(llm, tmpdir, 1, logprobs=5, n_tokens=32)
+    del llm
+    cleanup()
+    reference_id_sets = [set(prob[0]) for prob in merged_probs]
+
+    model = get_lora_model(MODEL_PATH, target_modules, rank)
+    with tempfile.TemporaryDirectory() as tmpdir:
+        merged_model = model.merge_and_unload()
+        merged_model.save_pretrained(tmpdir)
+        llm = vllm.LLM(tmpdir,
+                       tokenizer=MODEL_PATH,
+                       enable_lora=False,
+                       max_num_seqs=16,
+                       tensor_parallel_size=tp_size,
+                       worker_use_ray=True)
+    probs = do_sample(llm, logprobs=5, n_tokens=32)
+    del llm
+    cleanup()
+    # verify the top-5 tokens are identical for each token
+    id_sets = [set(prob[0]) for prob in probs]
+    assert id_sets == reference_id_sets
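
The parametrization expands to 12 randomly sampled target-module lists (three per length, for lengths 2 through 5) crossed with four ranks, 48 cases in all. In each case, do_sample with logprobs=5 returns, per prompt, a list of per-token top-5 token-id lists, and the assertion compares the id set of the first generated token for each prompt, ignoring ordering within the top 5. A small sketch of that comparison with made-up token ids (not part of the commit):

# Sketch with made-up token ids: the shape of what the assertion compares.
merged_probs = [[[11, 42, 7, 99, 3], [5, 6, 7, 8, 9]]]  # one prompt, two generated tokens
probs = [[[42, 11, 3, 99, 7], [1, 2, 3, 4, 5]]]
reference_id_sets = [set(prob[0]) for prob in merged_probs]  # top-5 ids of the first token
id_sets = [set(prob[0]) for prob in probs]
assert id_sets == reference_id_sets  # passes: same ids, different order within the top 5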
