From f7bfb25ef3d2841e1f8fc0da91e28b2e71efffdf Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Wed, 12 Jun 2024 13:17:39 +0800 Subject: [PATCH 1/3] add punica dimension for qwen2-72b lora --- csrc/punica/bgmv/bgmv_config.h | 12 ++++++++++-- tests/lora/test_punica.py | 4 ++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 4b376261d30d..5dc833e01a01 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -31,6 +31,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 3328) \ f(in_T, out_T, W_T, narrow, 3456) \ f(in_T, out_T, W_T, narrow, 3584) \ + f(in_T, out_T, W_T, narrow, 3712) \ f(in_T, out_T, W_T, narrow, 4096) \ f(in_T, out_T, W_T, narrow, 4608) \ f(in_T, out_T, W_T, narrow, 5120) \ @@ -41,6 +42,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 6848) \ f(in_T, out_T, W_T, narrow, 6912) \ f(in_T, out_T, W_T, narrow, 7168) \ + f(in_T, out_T, W_T, narrow, 7424) \ f(in_T, out_T, W_T, narrow, 8192) \ f(in_T, out_T, W_T, narrow, 9216) \ f(in_T, out_T, W_T, narrow, 10240) \ @@ -49,6 +51,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 13696) \ f(in_T, out_T, W_T, narrow, 13824) \ f(in_T, out_T, W_T, narrow, 14336) \ + f(in_T, out_T, W_T, narrow, 14848) \ f(in_T, out_T, W_T, narrow, 15360) \ f(in_T, out_T, W_T, narrow, 16384) \ f(in_T, out_T, W_T, narrow, 20480) \ @@ -57,6 +60,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 27392) \ f(in_T, out_T, W_T, narrow, 27648) \ f(in_T, out_T, W_T, narrow, 28672) \ + f(in_T, out_T, W_T, narrow, 29696) \ f(in_T, out_T, W_T, narrow, 32000) \ f(in_T, out_T, W_T, narrow, 32256) \ f(in_T, out_T, W_T, narrow, 32512) \ @@ -77,9 +81,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, // Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA // and vllm/tests/lora/test_punica.py -// Used for defining kernels going from the variety of +// Used for defining kernels going from the variety of // dim in to the narrow dim out - // Using it for the fully sharded column + // Using it for the fully sharded column // parallel LoRA A which splits the rank dim #define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \ f(in_T, out_T, W_T, 128, narrow) \ @@ -102,6 +106,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 3328, narrow) \ f(in_T, out_T, W_T, 3456, narrow) \ f(in_T, out_T, W_T, 3584, narrow) \ + f(in_T, out_T, W_T, 3712, narrow) \ f(in_T, out_T, W_T, 4096, narrow) \ f(in_T, out_T, W_T, 4608, narrow) \ f(in_T, out_T, W_T, 5120, narrow) \ @@ -112,6 +117,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 6848, narrow) \ f(in_T, out_T, W_T, 6912, narrow) \ f(in_T, out_T, W_T, 7168, narrow) \ + f(in_T, out_T, W_T, 7424, narrow) \ f(in_T, out_T, W_T, 8192, narrow) \ f(in_T, out_T, W_T, 9216, narrow) \ f(in_T, out_T, W_T, 10240, narrow) \ @@ -120,6 +126,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 13696, narrow) \ f(in_T, out_T, W_T, 13824, narrow) \ f(in_T, out_T, W_T, 14336, narrow) \ + f(in_T, out_T, W_T, 14848, narrow) \ f(in_T, out_T, W_T, 15360, narrow) \ f(in_T, out_T, W_T, 16384, narrow) \ f(in_T, out_T, W_T, 20480, narrow) \ @@ -128,6 +135,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 27392, narrow) \ f(in_T, out_T, W_T, 27648, narrow) \ f(in_T, out_T, W_T, 28672, narrow) \ + f(in_T, out_T, W_T, 29696, narrow) \ f(in_T, out_T, W_T, 32000, narrow) \ f(in_T, out_T, W_T, 32256, narrow) \ f(in_T, out_T, W_T, 32512, narrow) \ diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index f021c003b132..358305879fba 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -61,6 +61,7 @@ def _lora_ref_impl( 3328, 3456, 3584, + 3712, 4096, 4608, 5120, @@ -71,17 +72,20 @@ def _lora_ref_impl( 6848, 6912, 7168, + 7424, 8192, 9216, 10240, 11008, 13824, 14336, + 14848, 15360, 22016, 24576, 27392, 27648, + 29696, 32000, 32256, 32512, From ba9c33796daa6e110c753e4180af798b828e46fd Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Wed, 12 Jun 2024 14:11:18 +0800 Subject: [PATCH 2/3] add dimension for other qwen2 model --- csrc/punica/bgmv/bgmv_config.h | 24 ++++++++++++++++++++++++ tests/lora/test_punica.py | 12 ++++++++++++ 2 files changed, 36 insertions(+) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 5dc833e01a01..99ddcb725ced 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -18,12 +18,16 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 768) \ f(in_T, out_T, W_T, narrow, 1024) \ f(in_T, out_T, W_T, narrow, 1152) \ + f(in_T, out_T, W_T, narrow, 1216) \ f(in_T, out_T, W_T, narrow, 1280) \ f(in_T, out_T, W_T, narrow, 1536) \ f(in_T, out_T, W_T, narrow, 1728) \ f(in_T, out_T, W_T, narrow, 1792) \ f(in_T, out_T, W_T, narrow, 2048) \ + f(in_T, out_T, W_T, narrow, 2240) \ f(in_T, out_T, W_T, narrow, 2304) \ + f(in_T, out_T, W_T, narrow, 2368) \ + f(in_T, out_T, W_T, narrow, 2432) \ f(in_T, out_T, W_T, narrow, 2560) \ f(in_T, out_T, W_T, narrow, 2752) \ f(in_T, out_T, W_T, narrow, 2816) \ @@ -33,7 +37,10 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 3584) \ f(in_T, out_T, W_T, narrow, 3712) \ f(in_T, out_T, W_T, narrow, 4096) \ + f(in_T, out_T, W_T, narrow, 4480) \ f(in_T, out_T, W_T, narrow, 4608) \ + f(in_T, out_T, W_T, narrow, 4736) \ + f(in_T, out_T, W_T, narrow, 4864) \ f(in_T, out_T, W_T, narrow, 5120) \ f(in_T, out_T, W_T, narrow, 5504) \ f(in_T, out_T, W_T, narrow, 5632) \ @@ -44,22 +51,27 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 7168) \ f(in_T, out_T, W_T, narrow, 7424) \ f(in_T, out_T, W_T, narrow, 8192) \ + f(in_T, out_T, W_T, narrow, 8960) \ f(in_T, out_T, W_T, narrow, 9216) \ + f(in_T, out_T, W_T, narrow, 9472) \ f(in_T, out_T, W_T, narrow, 10240) \ f(in_T, out_T, W_T, narrow, 11008) \ f(in_T, out_T, W_T, narrow, 12288) \ f(in_T, out_T, W_T, narrow, 13696) \ f(in_T, out_T, W_T, narrow, 13824) \ f(in_T, out_T, W_T, narrow, 14336) \ + f(in_T, out_T, W_T, narrow, 14784) \ f(in_T, out_T, W_T, narrow, 14848) \ f(in_T, out_T, W_T, narrow, 15360) \ f(in_T, out_T, W_T, narrow, 16384) \ + f(in_T, out_T, W_T, narrow, 18944) \ f(in_T, out_T, W_T, narrow, 20480) \ f(in_T, out_T, W_T, narrow, 22016) \ f(in_T, out_T, W_T, narrow, 24576) \ f(in_T, out_T, W_T, narrow, 27392) \ f(in_T, out_T, W_T, narrow, 27648) \ f(in_T, out_T, W_T, narrow, 28672) \ + f(in_T, out_T, W_T, narrow, 29568) \ f(in_T, out_T, W_T, narrow, 29696) \ f(in_T, out_T, W_T, narrow, 32000) \ f(in_T, out_T, W_T, narrow, 32256) \ @@ -93,12 +105,16 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 768, narrow) \ f(in_T, out_T, W_T, 1024, narrow) \ f(in_T, out_T, W_T, 1152, narrow) \ + f(in_T, out_T, W_T, 1216, narrow) \ f(in_T, out_T, W_T, 1280, narrow) \ f(in_T, out_T, W_T, 1536, narrow) \ f(in_T, out_T, W_T, 1728, narrow) \ f(in_T, out_T, W_T, 1792, narrow) \ f(in_T, out_T, W_T, 2048, narrow) \ + f(in_T, out_T, W_T, 2240, narrow) \ f(in_T, out_T, W_T, 2304, narrow) \ + f(in_T, out_T, W_T, 2368, narrow) \ + f(in_T, out_T, W_T, 2432, narrow) \ f(in_T, out_T, W_T, 2560, narrow) \ f(in_T, out_T, W_T, 2752, narrow) \ f(in_T, out_T, W_T, 2816, narrow) \ @@ -108,7 +124,10 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 3584, narrow) \ f(in_T, out_T, W_T, 3712, narrow) \ f(in_T, out_T, W_T, 4096, narrow) \ + f(in_T, out_T, W_T, 4480, narrow) \ f(in_T, out_T, W_T, 4608, narrow) \ + f(in_T, out_T, W_T, 4736, narrow) \ + f(in_T, out_T, W_T, 4864, narrow) \ f(in_T, out_T, W_T, 5120, narrow) \ f(in_T, out_T, W_T, 5504, narrow) \ f(in_T, out_T, W_T, 5632, narrow) \ @@ -119,22 +138,27 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 7168, narrow) \ f(in_T, out_T, W_T, 7424, narrow) \ f(in_T, out_T, W_T, 8192, narrow) \ + f(in_T, out_T, W_T, 8960, narrow) \ f(in_T, out_T, W_T, 9216, narrow) \ + f(in_T, out_T, W_T, 9472, narrow) \ f(in_T, out_T, W_T, 10240, narrow) \ f(in_T, out_T, W_T, 11008, narrow) \ f(in_T, out_T, W_T, 12288, narrow) \ f(in_T, out_T, W_T, 13696, narrow) \ f(in_T, out_T, W_T, 13824, narrow) \ f(in_T, out_T, W_T, 14336, narrow) \ + f(in_T, out_T, W_T, 14784, narrow) \ f(in_T, out_T, W_T, 14848, narrow) \ f(in_T, out_T, W_T, 15360, narrow) \ f(in_T, out_T, W_T, 16384, narrow) \ + f(in_T, out_T, W_T, 18944, narrow) \ f(in_T, out_T, W_T, 20480, narrow) \ f(in_T, out_T, W_T, 22016, narrow) \ f(in_T, out_T, W_T, 24576, narrow) \ f(in_T, out_T, W_T, 27392, narrow) \ f(in_T, out_T, W_T, 27648, narrow) \ f(in_T, out_T, W_T, 28672, narrow) \ + f(in_T, out_T, W_T, 29568, narrow) \ f(in_T, out_T, W_T, 29696, narrow) \ f(in_T, out_T, W_T, 32000, narrow) \ f(in_T, out_T, W_T, 32256, narrow) \ diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index 358305879fba..212a0faa7efa 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -51,10 +51,14 @@ def _lora_ref_impl( 512, 1024, 1152, + 1216, 1280, 1536, 2048, + 2240, 2304, + 2368, + 2432, 2560, 2752, 3072, @@ -63,7 +67,10 @@ def _lora_ref_impl( 3584, 3712, 4096, + 4480, 4608, + 4736, + 4864, 5120, 5504, 5632, @@ -74,17 +81,22 @@ def _lora_ref_impl( 7168, 7424, 8192, + 8960, 9216, + 9472, 10240, 11008, 13824, 14336, + 14784, 14848, 15360, + 18944, 22016, 24576, 27392, 27648, + 29568, 29696, 32000, 32256, From 2cfbf518d3eef009ab0ffd1e50b2a061046c4f17 Mon Sep 17 00:00:00 2001 From: Jinzhen Lin Date: Thu, 13 Jun 2024 09:49:09 +0800 Subject: [PATCH 3/3] add 896 --- csrc/punica/bgmv/bgmv_config.h | 2 ++ tests/lora/test_punica.py | 1 + 2 files changed, 3 insertions(+) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index 99ddcb725ced..732466587454 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -16,6 +16,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, narrow, 512) \ f(in_T, out_T, W_T, narrow, 640) \ f(in_T, out_T, W_T, narrow, 768) \ + f(in_T, out_T, W_T, narrow, 896) \ f(in_T, out_T, W_T, narrow, 1024) \ f(in_T, out_T, W_T, narrow, 1152) \ f(in_T, out_T, W_T, narrow, 1216) \ @@ -103,6 +104,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, f(in_T, out_T, W_T, 512, narrow) \ f(in_T, out_T, W_T, 640, narrow) \ f(in_T, out_T, W_T, 768, narrow) \ + f(in_T, out_T, W_T, 896, narrow) \ f(in_T, out_T, W_T, 1024, narrow) \ f(in_T, out_T, W_T, 1152, narrow) \ f(in_T, out_T, W_T, 1216, narrow) \ diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index 212a0faa7efa..aea89a58ad1d 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -49,6 +49,7 @@ def _lora_ref_impl( 128, 256, 512, + 896, 1024, 1152, 1216,