diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 529daf54faecf..dcfe228ce8eae 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -242,7 +242,7 @@ steps: source_file_dependencies: - vllm/lora - tests/lora - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py --ignore=lora/test_chatglm3_tp.py --ignore=lora/test_llama_tp.py --ignore=lora/test_minicpmv_tp.py parallelism: 4 - label: "PyTorch Fullgraph Smoke Test" # 9min @@ -535,6 +535,7 @@ steps: # requires multi-GPU testing for validation. - pytest -v -s -x lora/test_chatglm3_tp.py - pytest -v -s -x lora/test_llama_tp.py + - pytest -v -s -x lora/test_minicpmv_tp.py - label: Weight Loading Multiple GPU Test # 33min diff --git a/tests/lora/test_minicpmv.py b/tests/lora/test_minicpmv.py deleted file mode 100644 index 78bf5a1617233..0000000000000 --- a/tests/lora/test_minicpmv.py +++ /dev/null @@ -1,77 +0,0 @@ -from typing import List - -import pytest - -import vllm -from vllm.assets.image import ImageAsset -from vllm.lora.request import LoRARequest -from vllm.platforms import current_platform - -MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" - -PROMPT_TEMPLATE = ( - "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" - "(./)\nWhat is in the image?<|eot_id|>" - "<|start_header_id|>assistant<|end_header_id|>\n\n") - -IMAGE_ASSETS = [ - ImageAsset("stop_sign"), - ImageAsset("cherry_blossom"), -] - -# After fine-tuning with LoRA, all generated content should start begin `A`. -EXPECTED_OUTPUT = [ - "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501 - "A pink cherry blossom tree with a blue sky in the background.", -] - - -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: - sampling_params = vllm.SamplingParams( - temperature=0, - max_tokens=5, - stop_token_ids=[128001, 128009], # eos_id, eot_id - ) - - inputs = [{ - "prompt": PROMPT_TEMPLATE, - "multi_modal_data": { - "image": asset.pil_image - }, - } for asset in IMAGE_ASSETS] - - outputs = llm.generate( - inputs, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None, - ) - # Print the outputs. 
- generated_texts: List[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text.strip() - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -@pytest.mark.xfail( - current_platform.is_rocm(), - reason="MiniCPM-V dependency xformers incompatible with ROCm") -def test_minicpmv_lora(minicpmv_lora_files): - llm = vllm.LLM( - MODEL_PATH, - max_num_seqs=2, - enable_lora=True, - max_loras=4, - max_lora_rank=64, - trust_remote_code=True, - enable_chunked_prefill=True, - ) - output1 = do_sample(llm, minicpmv_lora_files, lora_id=1) - for i in range(len(EXPECTED_OUTPUT)): - assert EXPECTED_OUTPUT[i].startswith(output1[i]) - output2 = do_sample(llm, minicpmv_lora_files, lora_id=2) - for i in range(len(EXPECTED_OUTPUT)): - assert EXPECTED_OUTPUT[i].startswith(output2[i]) diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py index 930f177953a5f..3b0f18325a40b 100644 --- a/tests/lora/test_minicpmv_tp.py +++ b/tests/lora/test_minicpmv_tp.py @@ -3,10 +3,10 @@ import pytest import vllm +from tests.utils import fork_new_process_for_each_test from vllm.assets.image import ImageAsset from vllm.lora.request import LoRARequest - -from ..utils import multi_gpu_test +from vllm.platforms import current_platform MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" @@ -17,13 +17,11 @@ IMAGE_ASSETS = [ ImageAsset("stop_sign"), - ImageAsset("cherry_blossom"), ] # After fine-tuning with LoRA, all generated content should start begin `A`. EXPECTED_OUTPUT = [ "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501 - "A pink cherry blossom tree with a blue sky in the background.", ] @@ -50,48 +48,75 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: # Print the outputs. 
generated_texts: List[str] = [] for output in outputs: - prompt = output.prompt generated_text = output.outputs[0].text.strip() generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(f"Generated text: {generated_text!r}") return generated_texts -@multi_gpu_test(num_gpus=2) -@pytest.mark.parametrize("fully_sharded", [True, False]) -def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded): +@pytest.mark.xfail( + current_platform.is_rocm(), + reason="MiniCPM-V dependency xformers incompatible with ROCm") +@fork_new_process_for_each_test +def test_minicpmv_lora(minicpmv_lora_files): + llm = vllm.LLM( + MODEL_PATH, + max_num_seqs=2, + enable_lora=True, + max_loras=2, + max_lora_rank=8, + enforce_eager=True, + trust_remote_code=True, + enable_chunked_prefill=True, + ) + output1 = do_sample(llm, minicpmv_lora_files, lora_id=1) + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output1[i]) + output2 = do_sample(llm, minicpmv_lora_files, lora_id=2) + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output2[i]) + + +@pytest.mark.xfail( + current_platform.is_rocm(), + reason="MiniCPM-V dependency xformers incompatible with ROCm") +@fork_new_process_for_each_test +def test_minicpmv_tp4_wo_fully_sharded_loras(minicpmv_lora_files): llm = vllm.LLM( MODEL_PATH, enable_lora=True, max_num_seqs=2, max_loras=4, max_lora_rank=64, - tensor_parallel_size=2, + tensor_parallel_size=4, trust_remote_code=True, - fully_sharded_loras=fully_sharded, + enforce_eager=True, enable_chunked_prefill=True, ) - output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) - for i in range(len(EXPECTED_OUTPUT)): assert EXPECTED_OUTPUT[i].startswith(output_tp[i]) -@multi_gpu_test(num_gpus=4) -@pytest.mark.parametrize("fully_sharded", [True, False]) -def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded): +@pytest.mark.xfail( + current_platform.is_rocm(), + reason="MiniCPM-V dependency xformers incompatible with ROCm") +@fork_new_process_for_each_test +def test_minicpmv_tp4_fully_sharded_loras(minicpmv_lora_files): llm = vllm.LLM( MODEL_PATH, enable_lora=True, max_num_seqs=2, - max_loras=4, - max_lora_rank=64, + max_loras=2, + max_lora_rank=8, tensor_parallel_size=4, trust_remote_code=True, - fully_sharded_loras=fully_sharded, + fully_sharded_loras=True, enable_chunked_prefill=True, ) output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) for i in range(len(EXPECTED_OUTPUT)): assert EXPECTED_OUTPUT[i].startswith(output_tp[i]) + output_tp = do_sample(llm, minicpmv_lora_files, lora_id=2) + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output_tp[i]) diff --git a/tests/lora/test_punica_sizes.py b/tests/lora/test_punica_sizes.py index 66b5f82bbb97d..0351fedd1cfa5 100644 --- a/tests/lora/test_punica_sizes.py +++ b/tests/lora/test_punica_sizes.py @@ -4,6 +4,8 @@ whether the corresponding Triton kernel can run normally when tensor parallelism is set to [1, 2, 4, 8, 16, 32, 64]. 
""" +from threading import Lock + import pytest import torch @@ -11,12 +13,13 @@ from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice from vllm.lora.ops.bgmv_shrink import bgmv_shrink from vllm.lora.ops.sgmv_expand import sgmv_expand -from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink +from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT from vllm.platforms import current_platform -from .utils import (generate_data, generate_data_for_expand_nslices, - ref_torch_groupgemm) +from .utils import (assert_close, generate_data, + generate_data_for_expand_nslices, + generate_data_for_nslices, ref_torch_groupgemm) HIDDEN_SIZES = [ 128, @@ -112,14 +115,7 @@ SEED = [0] CUDA_DEVICES = [f"cuda:{0}"] - -def assert_close(a, b): - rtol, atol = { - torch.float16: (6e-2, 6e-2), - torch.bfloat16: (6e-2, 6e-2), - torch.float32: (1e-2, 1e-2), - }[a.dtype] - torch.testing.assert_close(a, b, rtol=rtol, atol=atol) +_dict_lock = Lock() @pytest.mark.parametrize("batches", BATCHES) @@ -127,6 +123,7 @@ def assert_close(a, b): @pytest.mark.parametrize("rank", MAX_RANKS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("nslices", [1, 2, 3]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) @@ -137,6 +134,7 @@ def test_punica_sgmv( rank: int, hidden_size: int, scaling: float, + nslices: int, dtype: torch.dtype, op_type: str, seed: int, @@ -148,19 +146,20 @@ def test_punica_sgmv( seq_length = 128 ( inputs_tensor, - lora_weights, + lora_weights_lst, our_out_tensor, ref_out_tensor, b_seq_start_loc, lora_indices_tensor, seq_len_tensor, indices, - ) = generate_data( + ) = generate_data_for_nslices( batches, hidden_size, num_loras, rank, seq_length, + nslices, dtype, op_type, device, @@ -172,43 +171,64 @@ def test_punica_sgmv( else: max_seq_length = max_seq_length.item() if op_type == "shrink": - sgmv_shrink( - inputs_tensor, - lora_weights, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - scaling, - ) + # Preventing cache error pointer. 
+ with _dict_lock: + _LORA_A_PTR_DICT.clear() + sgmv_shrink( + inputs_tensor, + lora_weights_lst, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + scaling, + ) + for index in range(nslices): + ref_torch_groupgemm( + ref_out_tensor[index], + inputs_tensor, + lora_weights_lst[index], + lora_indices_tensor, + seq_len_tensor, + batches, + scaling, + op_type, + ) else: - sgmv_expand( - inputs_tensor, - lora_weights, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - add_inputs=True, - ) - ref_torch_groupgemm( - ref_out_tensor, - inputs_tensor, - lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - scaling if op_type == "shrink" else 1.0, - op_type, - ) - if op_type == "shrink": - ref_out_tensor = ref_out_tensor.to(torch.float32) + with _dict_lock: + _LORA_B_PTR_DICT.clear() + sgmv_expand( + inputs_tensor, + lora_weights_lst, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + offset_start=0, + add_inputs=True, + ) + + slice_offset = 0 + for index in range(nslices): + lora_weights = lora_weights_lst[index] + ref_torch_groupgemm( + ref_out_tensor[:, slice_offset:slice_offset + hidden_size], + inputs_tensor[index], + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + 1.0, + op_type, + ) + slice_offset += hidden_size + assert_close(our_out_tensor, ref_out_tensor) @@ -292,25 +312,22 @@ def test_punica_bgmv( @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("nslices", [2, 3]) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"]) @pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_punica_expand_nslices( +def test_punica_bgmv_expand_nslices( batches: int, num_loras: int, rank: int, hidden_size: int, nslices: int, dtype: torch.dtype, - op_type: str, seed: int, device: str, ): - torch.set_default_device(device) current_platform.seed_everything(seed) - seq_length = 128 if op_type == "sgmv" else 1 + seq_length = 1 ( inputs_tensor, lora_weights_lst, @@ -330,41 +347,18 @@ def test_punica_expand_nslices( nslices, device, ) - max_seq_length = seq_len_tensor.max() - token_nums = seq_len_tensor.sum().item() - if isinstance(max_seq_length, tuple): - max_seq_length = max_seq_length[0].item() - else: - max_seq_length = max_seq_length.item() slice_offset = 0 for index in range(nslices): lora_weights = lora_weights_lst[index] - if op_type == "sgmv": - sgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - slice_offset, - hidden_size, - add_inputs=True, - ) - else: - - bgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) + bgmv_expand_slice( + inputs_tensor, + lora_weights, + our_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, + ) ref_torch_groupgemm( ref_outputs[:, slice_offset:slice_offset + hidden_size], inputs_tensor, diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index 3b20033271d26..9ee10e7c23ee6 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -3,6 +3,8 @@ under different conditions, including various batches, numbers of LoRA , and 
maximum ranks. """ +from threading import Lock + import pytest import torch @@ -11,12 +13,13 @@ import vllm.lora.ops.bgmv_expand_slice import vllm.lora.ops.bgmv_shrink import vllm.lora.ops.sgmv_expand -import vllm.lora.ops.sgmv_expand_slice import vllm.lora.ops.sgmv_shrink # noqa: F401 +from vllm.lora.ops.utils import _LORA_A_PTR_DICT, _LORA_B_PTR_DICT from vllm.platforms import current_platform -from .utils import (generate_data, generate_data_for_expand_nslices, - ref_torch_groupgemm) +from .utils import (assert_close, generate_data, + generate_data_for_expand_nslices, + generate_data_for_nslices, ref_torch_groupgemm) HIDDEN_SIZES = [4097] @@ -28,31 +31,23 @@ SEED = [0] CUDA_DEVICES = [f"cuda:{0}"] - -def assert_close(a, b): - rtol, atol = { - torch.float16: (6e-2, 6e-2), - torch.bfloat16: (6e-2, 6e-2), - torch.float32: (1e-2, 1e-2), - }[a.dtype] - torch.testing.assert_close(a, b, rtol=rtol, atol=atol) - - # Unlike test_punica_sizes.py, we directly utilize custom op for # testing, which verifies the correct registration of these ops. bgmv_expand = torch.ops.vllm.bgmv_expand bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice bgmv_shrink = torch.ops.vllm.bgmv_shrink sgmv_expand = torch.ops.vllm.sgmv_expand -sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice sgmv_shrink = torch.ops.vllm.sgmv_shrink +_dict_lock = Lock() + @pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("num_loras", NUM_LORA) @pytest.mark.parametrize("rank", MAX_RANKS) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("scaling", SCALES) +@pytest.mark.parametrize("nslices", [1, 2, 3]) @pytest.mark.parametrize("dtype", DTYPES) @pytest.mark.parametrize("op_type", ["shrink", "expand"]) @pytest.mark.parametrize("seed", SEED) @@ -63,6 +58,7 @@ def test_punica_sgmv( rank: int, hidden_size: int, scaling: float, + nslices: int, dtype: torch.dtype, op_type: str, seed: int, @@ -74,19 +70,20 @@ def test_punica_sgmv( seq_length = 128 ( inputs_tensor, - lora_weights, + lora_weights_lst, our_out_tensor, ref_out_tensor, b_seq_start_loc, lora_indices_tensor, seq_len_tensor, indices, - ) = generate_data( + ) = generate_data_for_nslices( batches, hidden_size, num_loras, rank, seq_length, + nslices, dtype, op_type, device, @@ -98,43 +95,64 @@ def test_punica_sgmv( else: max_seq_length = max_seq_length.item() if op_type == "shrink": - sgmv_shrink( - inputs_tensor, - lora_weights, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - scaling, - ) + # Preventing cache error pointer. 
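+        # As in test_punica_sizes.py: the pointer LUT is keyed by data_ptr(),
+        # so it is cleared before each parametrized run to avoid stale entries.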
+ with _dict_lock: + _LORA_A_PTR_DICT.clear() + sgmv_shrink( + inputs_tensor, + lora_weights_lst, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + scaling, + ) + for index in range(nslices): + ref_torch_groupgemm( + ref_out_tensor[index], + inputs_tensor, + lora_weights_lst[index], + lora_indices_tensor, + seq_len_tensor, + batches, + scaling, + op_type, + ) else: - sgmv_expand( - inputs_tensor, - lora_weights, - our_out_tensor, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - add_inputs=True, - ) - ref_torch_groupgemm( - ref_out_tensor, - inputs_tensor, - lora_weights, - lora_indices_tensor, - seq_len_tensor, - batches, - scaling if op_type == "shrink" else 1.0, - op_type, - ) - if op_type == "shrink": - ref_out_tensor = ref_out_tensor.to(torch.float32) + with _dict_lock: + _LORA_B_PTR_DICT.clear() + sgmv_expand( + inputs_tensor, + lora_weights_lst, + our_out_tensor, + b_seq_start_loc, + seq_len_tensor, + lora_indices_tensor, + batches, + max_seq_length, + token_nums, + offset_start=0, + add_inputs=True, + ) + + slice_offset = 0 + for index in range(nslices): + lora_weights = lora_weights_lst[index] + ref_torch_groupgemm( + ref_out_tensor[:, slice_offset:slice_offset + hidden_size], + inputs_tensor[index], + lora_weights, + lora_indices_tensor, + seq_len_tensor, + batches, + 1.0, + op_type, + ) + slice_offset += hidden_size + assert_close(our_out_tensor, ref_out_tensor) @@ -220,24 +238,22 @@ def test_punica_bgmv( @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("nslices", [2, 3]) @pytest.mark.parametrize("dtype", DTYPES) -@pytest.mark.parametrize("op_type", ["sgmv", "bgmv"]) @pytest.mark.parametrize("seed", SEED) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_punica_expand_nslices( +def test_punica_bgmv_expand_nslices( batches: int, num_loras: int, rank: int, hidden_size: int, nslices: int, dtype: torch.dtype, - op_type: str, seed: int, device: str, ): torch.set_default_device(device) current_platform.seed_everything(seed) - seq_length = 128 if op_type == "sgmv" else 1 + seq_length = 1 ( inputs_tensor, lora_weights_lst, @@ -257,40 +273,18 @@ def test_punica_expand_nslices( nslices, device, ) - max_seq_length = seq_len_tensor.max() - token_nums = seq_len_tensor.sum().item() - if isinstance(max_seq_length, tuple): - max_seq_length = max_seq_length[0].item() - else: - max_seq_length = max_seq_length.item() slice_offset = 0 for index in range(nslices): lora_weights = lora_weights_lst[index] - if op_type == "sgmv": - sgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - batches, - max_seq_length, - token_nums, - slice_offset, - hidden_size, - add_inputs=True, - ) - else: - bgmv_expand_slice( - inputs_tensor, - lora_weights, - our_outputs, - indices, - slice_offset, - slice_size=hidden_size, - add_inputs=True, - ) + bgmv_expand_slice( + inputs_tensor, + lora_weights, + our_outputs, + indices, + slice_offset, + slice_size=hidden_size, + add_inputs=True, + ) ref_torch_groupgemm( ref_outputs[:, slice_offset:slice_offset + hidden_size], inputs_tensor, diff --git a/tests/lora/utils.py b/tests/lora/utils.py index e394c33b3f9ea..b66d18074a7bf 100644 --- a/tests/lora/utils.py +++ b/tests/lora/utils.py @@ -18,11 +18,13 @@ def set_module_lora(self, module_name: str, lora: LoRALayerWeights): def get_module_lora(self, module_name: str) -> LoRALayerWeights: return 
self._loras[module_name] - def init_random_lora(self, - module_name: str, - weight: torch.Tensor, - rank: int = 8, - generate_embeddings_tensor: int = 0): + def init_random_lora( + self, + module_name: str, + weight: torch.Tensor, + rank: int = 8, + generate_embeddings_tensor: int = 0, + ): lora = LoRALayerWeights( module_name, rank=rank, @@ -35,21 +37,25 @@ def init_random_lora(self, device=self._device), ) if generate_embeddings_tensor: - lora.embeddings_tensor = torch.rand(5, - generate_embeddings_tensor, - dtype=weight.dtype, - device=self._device) + lora.embeddings_tensor = torch.rand( + 5, + generate_embeddings_tensor, + dtype=weight.dtype, + device=self._device, + ) self.set_module_lora(module_name, lora) return lora - def init_lora(self, - module_name: str, - input_dim: int, - output_dim: int, - rank=8, - noop=False, - embeddings_tensor=None): + def init_lora( + self, + module_name: str, + input_dim: int, + output_dim: int, + rank=8, + noop=False, + embeddings_tensor=None, + ): lora = LoRALayerWeights( module_name, rank=rank, @@ -125,8 +131,16 @@ def ref_torch_groupgemm( return -def generate_data(batches, hidden_size, lora_nums, max_rank, seq_length, dtype, - op_type, device): +def generate_data( + batches, + hidden_size, + lora_nums, + max_rank, + seq_length, + dtype, + op_type, + device, +): seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches, )).to(device) b_seq_start_loc = torch.cumsum( @@ -187,8 +201,16 @@ def generate_data(batches, hidden_size, lora_nums, max_rank, seq_length, dtype, ) -def generate_data_for_expand_nslices(batches, hidden_size, lora_nums, max_rank, - seq_length, dtype, nslices, device): +def generate_data_for_expand_nslices( + batches, + hidden_size, + lora_nums, + max_rank, + seq_length, + dtype, + nslices, + device, +): seq_len_tensor = torch.randint(seq_length, seq_length + 1, (batches, )).to(device) b_seq_start_loc = torch.cumsum( @@ -221,7 +243,87 @@ def generate_data_for_expand_nslices(batches, hidden_size, lora_nums, max_rank, for b_id in range(batches): lora_index = lora_indices_tensor[b_id] indices[current_offset:current_offset + - seq_len_tensor[b_id]] = lora_index.item() + seq_len_tensor[b_id]] = (lora_index.item()) + current_offset += seq_len_tensor[b_id].item() + + lora_indices_tensor = lora_indices_tensor.to(device) + return ( + inputs_tensor, + lora_weights_lst, + our_out_tensor, + ref_out_tensor, + b_seq_start_loc, + lora_indices_tensor, + seq_len_tensor, + indices, + ) + + +def generate_data_for_nslices( + batches, + hidden_size, + lora_nums, + max_rank, + seq_length, + nslices, + dtype, + op_type, + device, +): + seq_len_tensor = torch.randint(seq_length, seq_length + 1, + (batches, )).to(device) + b_seq_start_loc = torch.cumsum( + torch.tensor([0] + seq_len_tensor[:-1].tolist(), dtype=torch.long), + dim=0, + ).to(device) + total_tokens = seq_len_tensor.sum() + + lora_weights_lst = [] + if op_type == "shrink": + + inputs_tensor = torch.rand((total_tokens, hidden_size), + dtype=dtype).to(device) + + for _ in range(nslices): + if op_type == "shrink": + lora_weights_lst.append( + torch.rand( + (lora_nums, max_rank, hidden_size), # col-major + dtype=dtype, + ).to(device)) + # NOTE shrink kernel using torch.float32 as output type + # shrink op need atomic_add, so output is initinized by 0 + our_out_tensor = torch.zeros( + (nslices, total_tokens, max_rank), + dtype=torch.float32, + ).to(device) + else: + inputs_tensor = torch.rand( + (nslices, total_tokens, max_rank), + dtype=dtype, + ).to(device) + for _ in range(nslices): + 
lora_weights_lst.append( + torch.rand( + (lora_nums, hidden_size, max_rank), # col-major + dtype=dtype, + ).to(device)) + # expand op needs to complete y+=a@lora_b, so output is + # initinized randomly + our_out_tensor = torch.rand((total_tokens, hidden_size * nslices), + dtype=dtype).to(device) + + # Ensure the same input. + ref_out_tensor = our_out_tensor.clone() + lora_indices_tensor = torch.randint(0, + lora_nums - 1 if lora_nums > 1 else 1, + (batches, )) + indices = torch.zeros((total_tokens), dtype=torch.long).to(device) + current_offset = 0 + for b_id in range(batches): + lora_index = lora_indices_tensor[b_id] + indices[current_offset:current_offset + + seq_len_tensor[b_id]] = (lora_index.item()) current_offset += seq_len_tensor[b_id].item() lora_indices_tensor = lora_indices_tensor.to(device) diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index 77c5178493c44..8af44b703810b 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -1,66 +1,109 @@ """ Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. +Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). +Punica: Multi-Tenant LoRA Serving. https://arxiv.org/abs/2310.18547 """ +from typing import List + import torch import triton import triton.language as tl from vllm.utils import direct_register_custom_op +from .utils import _get_lora_b_ptr + @triton.jit def _sgmv_expand_kernel( - input_ptr, - lora_ptr, - out_ptr, - N, - K, - b_seq_start_loc, - seq_lens, - lora_indices, - xm_stride, - xk_stride, # 1 - l0_stride, # hidden_size*max_rank - lora_k_stride, - lora_n_stride, - cm_stride, - cn_stride, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - EVEN_K: tl.constexpr, - ADD_INPUTS: tl.constexpr, - CAST_TYPE: tl.constexpr, -): + input_ptr, + lora_ptr, + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + slice_start_loc, + input_d0_stride, + input_d1_stride, + input_d2_stride, # 1 + ls_d0_ptr, + ls_d1_ptr, + ls_d2_ptr, # 1 + output_d0_stride, + output_d1_stride, # 1 + output_hs_ptr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + ADD_INPUTS: tl.constexpr, + CAST_TYPE: tl.constexpr, + SLICE_NUM: tl.constexpr, + SAME_STRIDE: tl.constexpr): """ - The sgmv's expand triton kernel is based on GroupGEMM. + + Similar to the 'sgmv_expand' operator, but with an added parameter + 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator + might be that in the future, we could implement a fusion operator to + achieve the current functionality instead of having to call it multiple + times. """ pid = tl.program_id(axis=0) cur_batch = tl.program_id(axis=1) + slice_id = tl.program_id(axis=2) cta_n_num = tl.cdiv(N, BLOCK_N) + # When the output dimensions of each slice are the same,cur_n=N, otherwise + # cur_n=tl.load(output_hs_ptr + slice_id), this situation exists in GQA's + # qkv linear. 
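+    # e.g. with GQA, a fused qkv lora_b might have per-slice output widths
+    # such as (4096, 1024, 1024), so curr_N must be read per slice_id.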
+ curr_N = N if SAME_STRIDE else tl.load(output_hs_ptr + slice_id) pid_m = pid // cta_n_num pid_n = pid % cta_n_num M = tl.load(seq_lens + cur_batch) if pid_m * BLOCK_M > M: return + if pid_n * BLOCK_N > curr_N: + return lora_index = tl.load(lora_indices + cur_batch) if lora_index == -1: return + cur_seq_start = tl.load(b_seq_start_loc + cur_batch) offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N offset_k = tl.arange(0, BLOCK_K) ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + rbn = tl.max_contiguous(tl.multiple_of(offset_n % curr_N, BLOCK_N), + BLOCK_N) + # ls_d*_ptr can be either an integer or a pointer + if SAME_STRIDE: + # integer + cur_lora_d0_stride = ls_d0_ptr + cur_lora_d1_stride = ls_d1_ptr + cur_lora_d2_stride = ls_d2_ptr + else: + # pointer + cur_lora_d0_stride = tl.load(ls_d0_ptr + slice_id) + cur_lora_d1_stride = tl.load(ls_d1_ptr + slice_id) + cur_lora_d2_stride = tl.load(ls_d2_ptr + slice_id) + if SLICE_NUM == 1: + cur_input_ptr = input_ptr + cur_lora_ptr = lora_ptr - a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + - offset_k[None, :] * xk_stride, ) - b_ptr = (lora_ptr + l0_stride * lora_index + - offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride) + else: + cur_input_ptr = input_ptr + slice_id * input_d0_stride + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(out_ptr.dtype.element_ty)) + + a_ptr = (cur_input_ptr + cur_seq_start * input_d1_stride + + ram[:, None] * input_d1_stride + + offset_k[None, :] * input_d2_stride, ) + b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + + offset_k[:, None] * cur_lora_d2_stride + + rbn[None, :] * cur_lora_d1_stride) accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(tl.cdiv(K, BLOCK_K)): if EVEN_K: @@ -74,26 +117,30 @@ def _sgmv_expand_kernel( mask=offset_k[:, None] < K - k * BLOCK_K, other=0) if CAST_TYPE: - tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) + tiled_a = tiled_a.to(cur_lora_ptr.dtype.element_ty) accumulator += tl.dot( tiled_a, tiled_b, ) - a_ptr += BLOCK_K * xk_stride - b_ptr += BLOCK_K * lora_n_stride - tiled_c = accumulator.to(lora_ptr.dtype.element_ty) + a_ptr += BLOCK_K * input_d2_stride + b_ptr += BLOCK_K * cur_lora_d2_stride + + tiled_c = accumulator.to(cur_lora_ptr.dtype.element_ty) + if SLICE_NUM == 1: + cur_slice_start = slice_start_loc + else: + cur_slice_start = tl.load(slice_start_loc + slice_id) + offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M - offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N - c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + - offset_cn[None, :] * cn_stride) + offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + cur_slice_start + c_ptr = (out_ptr + offset_cm[:, None] * output_d0_stride + + offset_cn[None, :] * output_d1_stride) M = tl.load(seq_lens + cur_batch) c_mask = (offset_cm[:, None] < - (cur_seq_start + M)) & (offset_cn[None, :] < N) + (cur_seq_start + M)) & (offset_cn[None, :] < + (cur_slice_start + curr_N)) if ADD_INPUTS: - # explicitly pass in other=None to tell triton that masked values - # can be uninitialized. 
This is OK because the later tl.store operation - # uses the same mask, eliminating the risk of garbage values propagating - tiled_out = tl.load(c_ptr, mask=c_mask, other=None) + tiled_out = tl.load(c_ptr, mask=c_mask) tiled_c += tiled_out tl.store(c_ptr, tiled_c, mask=c_mask) @@ -101,7 +148,7 @@ def _sgmv_expand_kernel( @torch.inference_mode() def _sgmv_expand( inputs: torch.Tensor, - lora_b_weights: torch.Tensor, + lora_b_weights: List[torch.Tensor], output_tensor: torch.Tensor, b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, @@ -109,17 +156,18 @@ def _sgmv_expand( batches: int, max_seq_length: int, token_nums: int, + offset_start: int = 0, add_inputs: bool = False, ) -> None: """ Args: inputs (torch.Tensor): input tensor - lora_b_weights (torch.Tensor): lora'a weight + lora_b_weights (List[torch.Tensor]): lora'b weight output_tensor (torch.Tensor): output tensor b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative sequence lengths of the sequences in the batch, used to index into sequence. E.g., if the sequence length is [4, 6], it is - [0, 4, 10]. + [0, 4]. seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence length of the sequences in the batch. lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index @@ -130,77 +178,80 @@ def _sgmv_expand( batch. token_nums (int): The token numbers in the batch. Used to verify if the token numbers in the inputs matches the one in the metadata. - add_inputs (bool, optional): Defaults to False, adds the final lora - results to the output. + offset_start (int, optional): Offset start for output_tensor. + Defaults to 0. + add_inputs (bool, optional): Whether to add the input tensor to the + output tensor. Defaults to False. """ - assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] - assert lora_b_weights.dtype in [ - torch.float16, - torch.bfloat16, - ] - assert inputs.size(0) == token_nums - assert inputs.size(1) == lora_b_weights.size(-1) + for weight in lora_b_weights: + assert weight.dtype in [torch.float16, torch.bfloat16] + + assert inputs.size(1) == token_nums + assert inputs.size(0) == len(lora_b_weights) + assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches - assert inputs.is_contiguous() assert output_tensor.is_contiguous() - - if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) - assert lora_b_weights.size(1) == 1 - lora_b_weights = lora_b_weights.squeeze(dim=1) - else: - assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) - - assert lora_b_weights.is_contiguous() + (slice_start_tensor, lora_ptr_tensor, lora_strides_d0_tensor, + lora_strides_d1_tensor, lora_strides_d2_tensor, hidden_sizes_tensor, + same_stride, MAX_N) = _get_lora_b_ptr(lora_b_weights, offset_start, + b_seq_start_loc.device) # TODO tuning this config + K = lora_b_weights[0].shape[-1] # K= rank - N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size - BLOCK_M = 32 - BLOCK_N = 32 + BLOCK_M = 64 + BLOCK_N = 128 BLOCK_K = 16 EVEN_K = K % BLOCK_K == 0 ADD_INPUTS = add_inputs CAST_TYPE = False - if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ + + if inputs.dtype == torch.float32 and lora_b_weights[0].dtype in [ torch.float16, torch.bfloat16, ]: CAST_TYPE = True grid = ( - triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), + triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(MAX_N, BLOCK_N), batches, + len(lora_b_weights), ) _sgmv_expand_kernel[grid]( inputs, - lora_b_weights, + lora_ptr_tensor, output_tensor, - N, + MAX_N, K, 
b_seq_start_loc, seq_len_tensor, lora_indices_tensor, + slice_start_tensor, inputs.stride(0), inputs.stride(1), - lora_b_weights.stride(0), - lora_b_weights.stride(1), - lora_b_weights.stride(2), + inputs.stride(2), + lora_strides_d0_tensor, + lora_strides_d1_tensor, + lora_strides_d2_tensor, output_tensor.stride(0), output_tensor.stride(1), + hidden_sizes_tensor, BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, ADD_INPUTS, CAST_TYPE, + len(lora_b_weights), + same_stride, ) return -def sgmv_expand_fake( +def _sgmv_expand_fake( inputs: torch.Tensor, - lora_b_weights: torch.Tensor, + lora_b_weights: List[torch.Tensor], output_tensor: torch.Tensor, b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, @@ -208,18 +259,18 @@ def sgmv_expand_fake( batches: int, max_seq_length: int, token_nums: int, + offset_start: int = 0, add_inputs: bool = False, ) -> None: return try: - direct_register_custom_op( op_name="sgmv_expand", op_func=_sgmv_expand, mutates_args=["output_tensor"], - fake_impl=sgmv_expand_fake, + fake_impl=_sgmv_expand_fake, ) sgmv_expand = torch.ops.vllm.sgmv_expand diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py deleted file mode 100644 index 55c4fb68ed128..0000000000000 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ /dev/null @@ -1,241 +0,0 @@ -""" -Based on: -Chen, L., Ye, Z., Wu, Y., Zhuo, D., Ceze, L., & Krishnamurthy, A. (2023). -Punica: Multi-Tenant LoRA Serving. -https://arxiv.org/abs/2310.18547 -""" - -import torch -import triton -import triton.language as tl - -from vllm.utils import direct_register_custom_op - - -@triton.jit -def _sgmv_expand_slice_kernel( - input_ptr, - lora_ptr, - out_ptr, - N, - K, - b_seq_start_loc, - seq_lens, - lora_indices, - xm_stride, - xk_stride, # 1 - l0_stride, # hidden_size*max_rank - lora_k_stride, - lora_n_stride, - cm_stride, - cn_stride, - slice_offset, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - EVEN_K: tl.constexpr, - ADD_INPUTS: tl.constexpr, - CAST_TYPE: tl.constexpr, -): - """ - - Similar to the 'sgmv_expand' operator, but with an added parameter - 'slice_offset'. The reason for not reusing the 'sgmv_expand' operator - might be that in the future, we could implement a fusion operator to - achieve the current functionality instead of having to call it multiple - times. 
- """ - pid = tl.program_id(axis=0) - cur_batch = tl.program_id(axis=1) - cta_n_num = tl.cdiv(N, BLOCK_N) - pid_m = pid // cta_n_num - pid_n = pid % cta_n_num - M = tl.load(seq_lens + cur_batch) - if pid_m * BLOCK_M > M: - return - lora_index = tl.load(lora_indices + cur_batch) - if lora_index == -1: - return - cur_seq_start = tl.load(b_seq_start_loc + cur_batch) - offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M - offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N - offset_k = tl.arange(0, BLOCK_K) - ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) - rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) - - a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + - offset_k[None, :] * xk_stride, ) - b_ptr = (lora_ptr + l0_stride * lora_index + - offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride) - accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) - for k in range(tl.cdiv(K, BLOCK_K)): - if EVEN_K: - tiled_a = tl.load(a_ptr) - tiled_b = tl.load(b_ptr) - else: - tiled_a = tl.load(a_ptr, - mask=offset_k[None, :] < K - k * BLOCK_K, - other=0) - tiled_b = tl.load(b_ptr, - mask=offset_k[:, None] < K - k * BLOCK_K, - other=0) - if CAST_TYPE: - tiled_a = tiled_a.to(lora_ptr.dtype.element_ty) - accumulator += tl.dot( - tiled_a, - tiled_b, - ) - a_ptr += BLOCK_K * xk_stride - b_ptr += BLOCK_K * lora_n_stride - tiled_c = accumulator.to(lora_ptr.dtype.element_ty) - offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M - offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset - c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + - offset_cn[None, :] * cn_stride) - M = tl.load(seq_lens + cur_batch) - c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < - (slice_offset + N)) - if ADD_INPUTS: - # explicitly pass in other=None to tell triton that masked values - # can be uninitialized. This is OK because the later tl.store operation - # uses the same mask, eliminating the risk of garbage values propagating - tiled_out = tl.load(c_ptr, mask=c_mask, other=None) - tiled_c += tiled_out - tl.store(c_ptr, tiled_c, mask=c_mask) - - -@torch.inference_mode() -def _sgmv_expand_slice( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - slice_offset: int, - slice_size: int, - add_inputs: bool = False, -) -> None: - """_summary_ - - Args: - inputs (torch.Tensor): input tensor - lora_b_weights (torch.Tensor): lora'a weight - output_tensor (torch.Tensor): output tensor - b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative - sequence lengths of the sequences in the batch, used to index - into sequence. E.g., if the sequence length is [4, 6], it is - [0, 4, 10]. - seq_len_tensor (torch.Tensor): (batch_size,). Record the sequence - length of the sequences in the batch - lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index - corresponding to each batch. An index of -1 means no lora should be - applied. - batches (int): batch size - max_seq_length (int): The max sequence lengths of the sequences - in the batch - token_nums (int): The token numbers in the batch. Used to verify if the - token numbers in the inputs matches the one in the metadata. 
- slice_offset (int): output_tensor's offset - slice_size (int): current output_tensor's size - add_inputs (bool, optional): Defaults to False, adds the final lora - results to the output. - """ - - assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] - assert lora_b_weights.dtype in [ - torch.float16, - torch.bfloat16, - ] - assert inputs.size(0) == token_nums - assert inputs.size(1) == lora_b_weights.size(-1) - assert b_seq_start_loc.size(0) == batches - assert lora_indices_tensor.size(0) == batches - assert slice_size == lora_b_weights.size(-2) - assert inputs.is_contiguous() - assert output_tensor.is_contiguous() - - if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank) - assert lora_b_weights.size(1) == 1 - lora_b_weights = lora_b_weights.squeeze(dim=1) - else: - assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank) - - assert lora_b_weights.is_contiguous() - - # TODO tuning this config - N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size - - BLOCK_M = 32 - BLOCK_N = 32 - BLOCK_K = 16 - EVEN_K = K % BLOCK_K == 0 - ADD_INPUTS = add_inputs - CAST_TYPE = False - if inputs.dtype == torch.float32 and lora_b_weights.dtype in [ - torch.float16, - torch.bfloat16, - ]: - CAST_TYPE = True - grid = ( - triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), - batches, - ) - _sgmv_expand_slice_kernel[grid]( - inputs, - lora_b_weights, - output_tensor, - N, - K, - b_seq_start_loc, - seq_len_tensor, - lora_indices_tensor, - inputs.stride(0), - inputs.stride(1), - lora_b_weights.stride(0), - lora_b_weights.stride(1), - lora_b_weights.stride(2), - output_tensor.stride(0), - output_tensor.stride(1), - slice_offset, - BLOCK_M, - BLOCK_N, - BLOCK_K, - EVEN_K, - ADD_INPUTS, - CAST_TYPE, - ) - return - - -def sgmv_expand_slice_fake( - inputs: torch.Tensor, - lora_b_weights: torch.Tensor, - output_tensor: torch.Tensor, - b_seq_start_loc: torch.Tensor, - seq_len_tensor: torch.Tensor, - lora_indices_tensor: torch.Tensor, - batches: int, - max_seq_length: int, - token_nums: int, - slice_offset: int, - slice_size: int, - add_inputs: bool = False, -) -> None: - return - - -try: - direct_register_custom_op( - op_name="sgmv_expand_slice", - op_func=_sgmv_expand_slice, - mutates_args=["output_tensor"], - fake_impl=sgmv_expand_slice_fake, - ) - sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice - -except AttributeError: - sgmv_expand_slice = _sgmv_expand_slice diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index 37d1dc84eebca..3d2ebe8286f56 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -5,48 +5,60 @@ https://arxiv.org/abs/2310.18547 """ +from typing import List + import torch import triton import triton.language as tl from vllm.utils import direct_register_custom_op +from .utils import _get_lora_a_ptr + @triton.jit def _sgmv_shrink_kernel( - input_ptr, - lora_ptr, - out_ptr, - N, - K, - b_seq_start_loc, - seq_lens, - lora_indices, - scaling, - xm_stride, # hidden_size - xk_stride, # 1 - l0_stride, # hidden_size*max_rank - lora_k_stride, - lora_n_stride, - cm_stride, - cn_stride, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_K: tl.constexpr, - EVEN_K: tl.constexpr, - SPLIT_K: tl.constexpr, -): + input_ptr, + lora_ptr, #1-3 + out_ptr, + N, + K, + b_seq_start_loc, + seq_lens, + lora_indices, + scaling, + input_d0_stride, + input_d1_stride, # 1 + lora_d0_stride, + lora_d1_stride, + lora_d2_stride, # 1 + output_d0_stride, + output_d1_stride, + output_d2_stride, # 1 + BLOCK_M: tl.constexpr, + BLOCK_N: 
tl.constexpr, + BLOCK_K: tl.constexpr, + EVEN_K: tl.constexpr, + SPLIT_K: tl.constexpr, + SLICE_NUM: tl.constexpr): """ The sgmv's shrink triton kernel is based on GroupGEMM+SPLIT-K. The GEMM of Multi-LoRA can be considered as GroupGEMM. Additionally, introducing SPLIT-K can improve performance """ pid = tl.program_id(axis=0) - pid_sk = tl.program_id(axis=1) + pid_mix = tl.program_id(axis=1) cur_batch = tl.program_id(axis=2) cta_n_num = tl.cdiv(N, BLOCK_N) pid_m = pid // cta_n_num pid_n = pid % cta_n_num + if SLICE_NUM == 1: + slice_id: tl.constexpr = 0 + pid_sk = tl.program_id(axis=1) + else: + pid_mix = tl.program_id(axis=1) + slice_id = pid_mix // SPLIT_K + pid_sk = pid_mix % SPLIT_K M = tl.load(seq_lens + cur_batch) if pid_m * BLOCK_M > M: @@ -61,11 +73,22 @@ def _sgmv_shrink_kernel( ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M) rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N) + # input ptr + a_ptr = (input_ptr + cur_seq_start * input_d0_stride + + ram[:, None] * input_d0_stride + + offset_k[None, :] * input_d1_stride) + + if SLICE_NUM == 1: + # current lora ptr + cur_lora_ptr = lora_ptr + else: + # current lora ptr + cur_lora_ptr = tl.load(lora_ptr + slice_id).to( + tl.pointer_type(input_ptr.dtype.element_ty)) - a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride + - offset_k[None, :] * xk_stride) - b_ptr = (lora_ptr + l0_stride * lora_index + rbn[None, :] * lora_k_stride + - offset_k[:, None] * lora_n_stride) + b_ptr = (cur_lora_ptr + lora_d0_stride * lora_index + + rbn[None, :] * lora_d1_stride + + offset_k[:, None] * lora_d2_stride) accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32) for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)): @@ -82,13 +105,15 @@ def _sgmv_shrink_kernel( other=0.0) accumulator += tl.dot(tiled_a, tiled_b) - a_ptr += BLOCK_K * SPLIT_K * xk_stride - b_ptr += BLOCK_K * SPLIT_K * lora_n_stride + a_ptr += BLOCK_K * SPLIT_K * input_d1_stride + b_ptr += BLOCK_K * SPLIT_K * lora_d2_stride offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N - c_ptr = (out_ptr + offset_cm[:, None] * cm_stride + - offset_cn[None, :] * cn_stride) + cur_out_ptr = (out_ptr if SLICE_NUM == 1 else out_ptr + + slice_id * output_d0_stride) + c_ptr = cur_out_ptr + offset_cm[:, None] * output_d1_stride + offset_cn[ + None, :] * output_d2_stride c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] < N) accumulator *= scaling @@ -102,7 +127,7 @@ def _sgmv_shrink_kernel( @torch.inference_mode() def _sgmv_shrink( inputs: torch.Tensor, - lora_a_weights: torch.Tensor, + lora_a_weights: List[torch.Tensor], output_tensor: torch.Tensor, b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, @@ -113,10 +138,9 @@ def _sgmv_shrink( scaling: float, ) -> None: """ - Args: inputs (torch.Tensor): input tensor - lora_a_weights (torch.Tensor): lora'a weight + lora_a_weights (List[torch.Tensor]): lora'a weight output_tensor (torch.Tensor): output tensor b_seq_start_loc (torch.Tensor): (batch_size,). The cumulative sequence lengths of the sequences in the batch, used to index @@ -134,27 +158,21 @@ def _sgmv_shrink( token numbers in the inputs matches the one in the metadata. scaling (float): Scaling factor. 
""" - assert inputs.dtype == lora_a_weights.dtype + assert inputs.dtype == lora_a_weights[0].dtype assert inputs.dtype in [torch.float16, torch.bfloat16] - assert lora_a_weights.dtype in [ - torch.float16, - torch.bfloat16, - ] + for weight in lora_a_weights: + assert weight.dtype in [torch.float16, torch.bfloat16] + assert inputs.size(0) == token_nums - assert inputs.size(1) == lora_a_weights.size(-1) + assert inputs.size(1) == lora_a_weights[0].size(-1) assert b_seq_start_loc.size(0) == batches assert lora_indices_tensor.size(0) == batches assert inputs.is_contiguous() - - if lora_a_weights.ndim == 4: # shape:(lora_num,1,rank, size) - assert lora_a_weights.size(1) == 1 - lora_a_weights = lora_a_weights.squeeze(dim=1) - else: - assert lora_a_weights.ndim == 3 # shape:(lora_num,rank, size) - assert lora_a_weights.is_contiguous() assert output_tensor.is_contiguous() + (lora_ptr_tensor, lora_strides_d0, lora_strides_d1, + lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, b_seq_start_loc.device) # TODO tuning this config - N, K = lora_a_weights.shape[-2:] # K=hidden_size,N=rank + N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank BLOCK_M = 32 BLOCK_N = 16 BLOCK_K = 32 @@ -162,13 +180,12 @@ def _sgmv_shrink( EVEN_K = K % (BLOCK_K * SPLIT_K) == 0 grid = ( triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N), - SPLIT_K, + SPLIT_K * len(lora_a_weights), batches, ) - _sgmv_shrink_kernel[grid]( inputs, - lora_a_weights, + lora_ptr_tensor, output_tensor, N, K, @@ -178,23 +195,25 @@ def _sgmv_shrink( scaling, inputs.stride(0), inputs.stride(1), - lora_a_weights.stride(0), - lora_a_weights.stride(1), - lora_a_weights.stride(2), + lora_strides_d0, + lora_strides_d1, + lora_strides_d2, output_tensor.stride(0), output_tensor.stride(1), + output_tensor.stride(2), BLOCK_M, BLOCK_N, BLOCK_K, EVEN_K, SPLIT_K, + len(lora_a_weights), ) return def sgmv_shrink_fake( inputs: torch.Tensor, - lora_a_weights: torch.Tensor, + lora_a_weights: List[torch.Tensor], output_tensor: torch.Tensor, b_seq_start_loc: torch.Tensor, seq_len_tensor: torch.Tensor, diff --git a/vllm/lora/ops/utils.py b/vllm/lora/ops/utils.py index 7c3e27313ad97..7df5bc2c225e5 100644 --- a/vllm/lora/ops/utils.py +++ b/vllm/lora/ops/utils.py @@ -1,5 +1,7 @@ import functools -from typing import Dict +from typing import Dict, List, Tuple + +import torch @functools.lru_cache @@ -44,3 +46,120 @@ def get_lora_op_configs(op_type: str, batch: int, if not config: config = _get_default_config(op_type, batch, hidden_size) return config + + +_LORA_A_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} +_LORA_B_PTR_DICT: Dict[Tuple[int, ...], Tuple[torch.tensor, ...]] = {} + + +def _get_lora_a_ptr(lora_a_weights: List[torch.Tensor], device: str): + """ + `_LORA_A_PTR_DICT` collects the required information during `profile_run`, + After this, it remains constant and subsequent usage is through LUT. 
+ Refer to: + https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py + """ + key = tuple(lora_weight.data_ptr() for lora_weight in lora_a_weights) + + if values := _LORA_A_PTR_DICT.get(key): + return values + + lora_strides_d0 = [] + lora_strides_d1 = [] + lora_strides_d2 = [] + tensor_ptrs = [] + for lora_a_weight in lora_a_weights: + if lora_a_weight.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_a_weight.size(1) == 1 + lora_a_weight = lora_a_weight.squeeze(dim=1) + else: + assert lora_a_weight.ndim == 3 # shape:(lora_num,size,rank) + assert lora_a_weight.is_contiguous() + tensor_ptrs.append(lora_a_weight.data_ptr()) + lora_strides_d0.append(lora_a_weight.stride(0)) + lora_strides_d1.append(lora_a_weight.stride(1)) + lora_strides_d2.append(lora_a_weight.stride(2)) + if len(lora_a_weights) > 1: + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) + else: + lora_ptr_tensor = lora_a_weights[0] + + if (len(set(lora_strides_d0)) > 1 or len(set(lora_strides_d1)) > 1 + or len(set(lora_strides_d2)) > 1): + raise ValueError("All LoRA weights must have the same stride.") + + _LORA_A_PTR_DICT[key] = ( + lora_ptr_tensor, + lora_strides_d0[0], + lora_strides_d1[0], + lora_strides_d2[0], + ) + return _LORA_A_PTR_DICT.get(key) + + +def _get_lora_b_ptr(lora_weights: List[torch.Tensor], offset_start: int, + device: str): + """ + `_LORA_B_PTR_DICT` collects the required information during `profile_run`, + After this, it remains constant and subsequent usage is through LUT. + Refer to: + https://github.com/triton-lang/triton/blob/release/3.1.x/python/tutorials/08-grouped-gemm.py + + """ + + key = tuple(lora_weight.data_ptr() for lora_weight in lora_weights) + if values := _LORA_B_PTR_DICT.get(key): + return values + slice_offset_lst = [] + tensor_ptrs = [] + lora_strides_d0 = [] + lora_strides_d1 = [] + lora_strides_d2 = [] + hidden_sizes = [] + slice_offset = offset_start + for lora_b_weight in lora_weights: + if lora_b_weight.ndim == 4: # shape:(lora_num,1,size,rank) + assert lora_b_weight.size(1) == 1 + lora_b_weight = lora_b_weight.squeeze(dim=1) + else: + assert lora_b_weight.ndim == 3 # shape:(lora_num,size,rank) + assert lora_b_weight.is_contiguous() + tensor_ptrs.append(lora_b_weight.data_ptr()) + lora_strides_d0.append(lora_b_weight.stride(0)) + lora_strides_d1.append(lora_b_weight.stride(1)) + lora_strides_d2.append(lora_b_weight.stride(2)) + slice_offset_lst.append(slice_offset) + slice_offset += lora_b_weight.size(1) + hidden_sizes.append(lora_b_weight.size(1)) + + if len(lora_weights) > 1: + # note these are device tensors + lora_ptr_tensor = torch.tensor(tensor_ptrs, device=device) + slice_start_tensor = torch.tensor(slice_offset_lst, device=device) + else: + slice_start_tensor = slice_offset_lst[0] + lora_ptr_tensor = lora_b_weight[0] + + # If each lora has the same stride, there's no need to use a + # tensor for storage. 
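+    # In that case plain Python ints are passed through and the kernel reads
+    # them directly (SAME_STRIDE=True) instead of doing a per-slice tl.load.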
+ if (len(set(lora_strides_d0)) == 1 and len(set(lora_strides_d1)) == 1 and + len(set(lora_strides_d2)) == 1) and len(set(hidden_sizes)) == 1: + lora_strides_d0_tensor = lora_strides_d0[0] + lora_strides_d1_tensor = lora_strides_d1[0] + lora_strides_d2_tensor = lora_strides_d2[0] + hidden_sizes_tensor = hidden_sizes[0] + same_stride = True + + else: + lora_strides_d0_tensor = torch.tensor(lora_strides_d0, device=device) + lora_strides_d1_tensor = torch.tensor(lora_strides_d1, device=device) + lora_strides_d2_tensor = torch.tensor(lora_strides_d2, device=device) + hidden_sizes_tensor = torch.tensor(hidden_sizes, device=device) + same_stride = False + # MAX_N is the maximum hidden size among all the lora_b weights + MAX_N = max(hidden_sizes) + _LORA_B_PTR_DICT[key] = (slice_start_tensor, lora_ptr_tensor, + lora_strides_d0_tensor, lora_strides_d1_tensor, + lora_strides_d2_tensor, hidden_sizes_tensor, + same_stride, MAX_N) + return _LORA_B_PTR_DICT.get(key) diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py index de378df8b3cfa..278f7b5a8e9f4 100644 --- a/vllm/lora/punica_wrapper/punica_gpu.py +++ b/vllm/lora/punica_wrapper/punica_gpu.py @@ -5,7 +5,7 @@ https://arxiv.org/abs/2310.18547 """ -from typing import Callable, Optional, Tuple, Union, final +from typing import Optional, Tuple, Union, final import torch @@ -16,7 +16,6 @@ from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice from vllm.lora.ops.bgmv_shrink import bgmv_shrink from vllm.lora.ops.sgmv_expand import sgmv_expand - from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice from vllm.lora.ops.sgmv_shrink import sgmv_shrink from .punica_base import PunicaWrapperBase @@ -35,11 +34,11 @@ def __init__(self, max_num_batched_tokens: int, max_batches: int, PunicaWrapperBase.__init__(self, max_num_batched_tokens, max_batches, device) - def _shrink_prefill( + def _apply_shrink_prefill( self, y: torch.Tensor, x: torch.Tensor, - w_t_all: torch.Tensor, + w_t_all: Tuple[torch.Tensor, ...], scale: float, ): #No LoRA request, so return directly @@ -53,7 +52,7 @@ def _shrink_prefill( scale, ) - def _shrink_decode( + def _apply_shrink_decode( self, y: torch.Tensor, x: torch.Tensor, @@ -62,56 +61,28 @@ def _shrink_decode( ): bgmv_shrink(x, w_t_all, y, self.token_lora_indices, scale) - def _expand_prefill( + def _apply_expand_prefill( self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, + offset_start: int, add_inputs: bool, ): #No LoRA request, so return directly if self.no_lora: return - sgmv_expand( - x, - w_t_all, - y, - *self.prefill_metadata, - add_inputs, - ) - - def _expand_decode( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - add_inputs: bool, - ): - bgmv_expand(x, w_t_all, y, self.token_lora_indices, add_inputs) - def _expand_slice_prefill( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - y_offset: Optional[int], - y_slice_size: Optional[int], - add_inputs: bool, - ): - #No LoRA request, so return directly - if self.no_lora: - return - sgmv_expand_slice( + sgmv_expand( x, w_t_all, y, *self.prefill_metadata, - y_offset, - y_slice_size, - add_inputs, + offset_start=offset_start, + add_inputs=add_inputs, ) - def _expand_slice_decode( + def _apply_expand_decode( self, y: torch.Tensor, x: torch.Tensor, @@ -123,43 +94,6 @@ def _expand_slice_decode( bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_inputs) - def _apply_expand( - self, - y: torch.Tensor, - x: torch.Tensor, - w_t_all: torch.Tensor, - 
y_offset: Optional[int], - y_slice_size: Optional[int], - add_inputs: bool = True, - ): - """ - Perform the ` y[:,y_offset:y_offset+y_slice_size]+=x@w_t_all` - computation, which is suitable for the - GEMM of lora'b. - """ - - expand_slice_fun: Callable = (self._expand_slice_prefill - if self.is_prefill else - self._expand_slice_decode) - expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_inputs) - - def _apply_shrink(self, y: torch.Tensor, x: torch.Tensor, - w_t_all: torch.Tensor, scale: float): - """ - Perform the ` y+=x@w_t_all` computation, which is suitable for the - GEMM of lora'a. - When `is_prefill is` true, it indicates that it is currently the - prefill stage, and the `_shrink_prefill` function should be called. - Otherwise, it is the decode stage, and the _shrink_decode function - should be called. - """ - y_org = y - y = y.view(-1, y.shape[-1]) - shrink_fun: Callable = (self._shrink_prefill - if self.is_prefill else self._shrink_decode) - shrink_fun(y, x, w_t_all, scale) - y = y.view_as(y_org) - def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], x: torch.Tensor, lora_a_stacked: Tuple[torch.Tensor, ...], scale: float, **kwargs): @@ -182,10 +116,15 @@ def add_shrink(self, y: Union[Tuple[torch.Tensor, ...], torch.Tensor], """ x = x.view(-1, x.shape[-1]) - # TODO fuse these kernels - for slice_idx in range(len(lora_a_stacked)): - self._apply_shrink(y[slice_idx], x, lora_a_stacked[slice_idx], - scale) + + if self.is_prefill: + # NOTE fused kernel + self._apply_shrink_prefill(y, x, lora_a_stacked, scale) + else: + # TODO fuse these kernels + for slice_idx in range(len(lora_a_stacked)): + self._apply_shrink_decode(y[slice_idx], x, + lora_a_stacked[slice_idx], scale) def add_expand(self, y: torch.Tensor, @@ -217,20 +156,28 @@ def add_expand(self, """ y_org = y y = y.view(-1, y.shape[-1]) - offset_left = offset_start if lora_bias_stacked is not None: self._apply_bias(self.token_lora_indices, y, output_slices, lora_bias_stacked) - for slice_idx in range(len(lora_b_stacked)): - self._apply_expand( - y, - x[slice_idx], - lora_b_stacked[slice_idx], - offset_left, - output_slices[slice_idx], - add_inputs=add_inputs, - ) - offset_left += output_slices[slice_idx] + if self.is_prefill: + # NOTE fused kernel + self._apply_expand_prefill(y, + x, + lora_b_stacked, + offset_start, + add_inputs=True) + else: + # TODO fuse these kernels + for slice_idx in range(len(lora_b_stacked)): + self._apply_expand_decode( + y, + x[slice_idx], + lora_b_stacked[slice_idx], + offset_start, + output_slices[slice_idx], + add_inputs=add_inputs, + ) + offset_start += output_slices[slice_idx] y = y.view_as(y_org) def add_lora_embedding(self, @@ -252,10 +199,18 @@ def add_lora_embedding(self, add_inputs (bool): Default to True. 
""" - # Embedding layer only need expand op - expand_fun: Callable = (self._expand_prefill - if self.is_prefill else self._expand_decode) - expand_fun(y, x, lora_b_stacked, add_inputs) + if self.is_prefill: + sgmv_expand( + x.unsqueeze(dim=0), + [lora_b_stacked], + y, + *self.prefill_metadata, + offset_start=0, + add_inputs=add_inputs, + ) + else: + bgmv_expand(x, lora_b_stacked, y, self.token_lora_indices, + add_inputs) def add_lora_linear(self, y: torch.Tensor, @@ -301,10 +256,11 @@ def add_lora_linear(self, r = lora_b_stacked[0].size(-1) # We set the buffer to be float32 by default ,refer to: # https://github.com/triton-lang/triton/issues/1387 - buffer = tuple( - torch.zeros( - (x.size(0), r), dtype=torch.float32, device=x.device) - for _ in range(len(output_slices))) + buffer = torch.zeros( + (len(output_slices), x.size(0), r), + dtype=torch.float32, + device=x.device, + ) self.add_shrink(buffer, x, lora_a_stacked, scale, **kwargs) self.add_expand(y, buffer,