diff --git a/.buildkite/run-neuron-test.sh b/.buildkite/run-neuron-test.sh
index 8ba03b78e8db..252c0f7fecd1 100644
--- a/.buildkite/run-neuron-test.sh
+++ b/.buildkite/run-neuron-test.sh
@@ -4,6 +4,20 @@ set -e
 
 # Try building the docker image
 aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin 763104351884.dkr.ecr.us-west-2.amazonaws.com
+
+# prune old image and containers to save disk space, and only once a day
+# by using a timestamp file in tmp.
+if [ -f /tmp/neuron-docker-build-timestamp ]; then
+    last_build=$(cat /tmp/neuron-docker-build-timestamp)
+    current_time=$(date +%s)
+    if [ $((current_time - last_build)) -gt 86400 ]; then
+        docker system prune -f
+        echo $current_time > /tmp/neuron-docker-build-timestamp
+    fi
+else
+    echo $(date +%s) > /tmp/neuron-docker-build-timestamp
+fi
+
 docker build -t neuron -f Dockerfile.neuron .
 
 # Setup cleanup
diff --git a/.buildkite/test-template.j2 b/.buildkite/test-template.j2
index fb1086db7782..5c9515840bb0 100644
--- a/.buildkite/test-template.j2
+++ b/.buildkite/test-template.j2
@@ -25,6 +25,7 @@ steps:
     agents:
       queue: neuron
     command: bash .buildkite/run-neuron-test.sh
+    soft_fail: true
 
   - label: "CPU Test"
     command: bash .buildkite/run-cpu-test.sh
diff --git a/.github/ISSUE_TEMPLATE/750-RFC.yml b/.github/ISSUE_TEMPLATE/750-RFC.yml
new file mode 100644
index 000000000000..5382b124dcd7
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/750-RFC.yml
@@ -0,0 +1,49 @@
+name: 💬 Request for comments (RFC).
+description: Ask for feedback on major architectural changes or design choices.
+title: "[RFC]: "
+labels: ["RFC"]
+
+body:
+- type: markdown
+  attributes:
+    value: >
+      #### Please take a look at previous [RFCs](https://github.com/vllm-project/vllm/issues?q=label%3ARFC+sort%3Aupdated-desc) for reference.
+- type: textarea
+  attributes:
+    label: Motivation.
+    description: >
+      The motivation of the RFC.
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: Proposed Change.
+    description: >
+      The proposed change of the RFC.
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: Feedback Period.
+    description: >
+      The feedback period of the RFC. Usually at least one week.
+  validations:
+    required: false
+- type: textarea
+  attributes:
+    label: CC List.
+    description: >
+      The list of people you want to CC.
+  validations:
+    required: false
+- type: textarea
+  attributes:
+    label: Any Other Things.
+    description: >
+      Any other things you would like to mention.
+  validations:
+    required: false
+- type: markdown
+  attributes:
+    value: >
+      Thanks for contributing 🎉!
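For reference, the once-a-day gate used in .buildkite/run-neuron-test.sh above boils down to a small reusable pattern: record the epoch time of the last run in a file and only repeat the expensive step after 24 hours have passed. A minimal standalone sketch of that pattern, for illustration only; the /tmp/cleanup-timestamp path is a placeholder and not something this patch introduces:

    # Sketch: run "docker system prune" at most once per day, gated by a timestamp file.
    STAMP=/tmp/cleanup-timestamp
    now=$(date +%s)
    if [ -f "$STAMP" ] && [ $((now - $(cat "$STAMP"))) -le 86400 ]; then
        echo "cleanup already ran within the last 24h, skipping"
    else
        docker system prune -f      # expensive step, at most once per day
        echo "$now" > "$STAMP"      # record when it last ran
    fi

The patch above applies the same idea, writing the timestamp only when it actually prunes or when the file is first created.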
diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml
index 089c7d18ad6f..a19be8525f90 100644
--- a/.github/workflows/mypy.yaml
+++ b/.github/workflows/mypy.yaml
@@ -43,8 +43,8 @@ jobs:
         mypy vllm/worker --config-file pyproject.toml
         mypy vllm/spec_decode --config-file pyproject.toml
         mypy vllm/lora --config-file pyproject.toml
+        mypy vllm/model_executor --config-file pyproject.toml
         # TODO(sang): Fix nested dir
-        mypy vllm/model_executor/*.py --config-file pyproject.toml
         mypy vllm/core/*.py --follow-imports=skip --config-file pyproject.toml
diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml
index 4b9fc3d04d87..d79681f03b00 100644
--- a/.github/workflows/publish.yml
+++ b/.github/workflows/publish.yml
@@ -49,7 +49,7 @@ jobs:
       matrix:
           os: ['ubuntu-20.04']
           python-version: ['3.8', '3.9', '3.10', '3.11']
-          pytorch-version: ['2.2.1'] # Must be the most recent version that meets requirements-cuda.txt.
+          pytorch-version: ['2.3.0'] # Must be the most recent version that meets requirements-cuda.txt.
           cuda-version: ['11.8', '12.1']
     steps:
diff --git a/CMakeLists.txt b/CMakeLists.txt
index e9262b57d086..f817f3382c5e 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -31,7 +31,7 @@ set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx11
 # requirements.txt files and should be kept consistent. The ROCm torch
 # versions are derived from Dockerfile.rocm
 #
-set(TORCH_SUPPORTED_VERSION_CUDA "2.2.1")
+set(TORCH_SUPPORTED_VERSION_CUDA "2.3.0")
 set(TORCH_SUPPORTED_VERSION_ROCM_5X "2.0.1")
 set(TORCH_SUPPORTED_VERSION_ROCM_6X "2.1.1")
@@ -177,6 +177,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
     "csrc/quantization/aqlm/gemm_kernels.cu"
     "csrc/quantization/awq/gemm_kernels.cu"
     "csrc/quantization/marlin/marlin_cuda_kernel.cu"
+    "csrc/quantization/gptq_marlin/gptq_marlin.cu"
+    "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
     "csrc/custom_all_reduce.cu")
 endif()
diff --git a/Dockerfile b/Dockerfile
index d1d29177b0f4..e471a6e93b96 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -85,7 +85,7 @@ FROM dev as flash-attn-builder
 ARG max_jobs=2
 ENV MAX_JOBS=${max_jobs}
 # flash attention version
-ARG flash_attn_version=v2.5.6
+ARG flash_attn_version=v2.5.8
 ENV FLASH_ATTN_VERSION=${flash_attn_version}
 WORKDIR /usr/src/flash-attention-v2
diff --git a/csrc/ops.h b/csrc/ops.h
index ff7a3de1a0a8..04b97d1784cd 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -124,6 +124,24 @@ torch::Tensor marlin_gemm(
     int64_t size_m,
     int64_t size_n,
     int64_t size_k);
+
+torch::Tensor gptq_marlin_gemm(
+    torch::Tensor &a,
+    torch::Tensor &b_q_weight,
+    torch::Tensor &b_scales,
+    torch::Tensor &g_idx,
+    torch::Tensor &perm,
+    torch::Tensor &workspace,
+    int64_t size_m,
+    int64_t size_n,
+    int64_t size_k,
+    bool is_k_full);
+
+torch::Tensor gptq_marlin_repack(
+    torch::Tensor &b_q_weight,
+    torch::Tensor &perm,
+    int64_t size_k,
+    int64_t size_n);
 #endif
 
 void squeezellm_gemm(
@@ -146,7 +164,12 @@ void gptq_shuffle(
   torch::Tensor q_perm,
   int bit);
 
-void scaled_fp8_quant(
+void static_scaled_fp8_quant(
+  torch::Tensor& out,
+  torch::Tensor& input,
+  torch::Tensor& scale);
+
+void dynamic_scaled_fp8_quant(
   torch::Tensor& out,
   torch::Tensor& input,
   torch::Tensor& scale);
diff --git a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
index c642e94925fe..86846c274c90 100644
--- a/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
+++ b/csrc/punica/bgmv/bgmv_bf16_bf16_bf16.cu
@@ -2,3 +2,4 @@
 #include "bgmv_impl.cuh"
 
 FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, nv_bfloat16,
nv_bfloat16) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu index 0607cebfeac4..de39c3121f5d 100644 --- a/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu +++ b/csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu @@ -2,3 +2,4 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_bfloat16, float, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_config.h b/csrc/punica/bgmv/bgmv_config.h index fec484d69305..19c058cacfbc 100644 --- a/csrc/punica/bgmv/bgmv_config.h +++ b/csrc/punica/bgmv/bgmv_config.h @@ -74,6 +74,74 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, // Keep above in sync with vllm/lora/layers::LogitsProcessorWithLoRA // and vllm/tests/lora/test_punica.py +// Used for defining kernels going from the variety of +// dim in to the narrow dim out + // Using it for the fully sharded column + // parallel LoRA A which splits the rank dim +#define FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, narrow) \ + f(in_T, out_T, W_T, 128, narrow) \ + f(in_T, out_T, W_T, 256, narrow) \ + f(in_T, out_T, W_T, 512, narrow) \ + f(in_T, out_T, W_T, 640, narrow) \ + f(in_T, out_T, W_T, 768, narrow) \ + f(in_T, out_T, W_T, 1024, narrow) \ + f(in_T, out_T, W_T, 1152, narrow) \ + f(in_T, out_T, W_T, 1280, narrow) \ + f(in_T, out_T, W_T, 1536, narrow) \ + f(in_T, out_T, W_T, 1728, narrow) \ + f(in_T, out_T, W_T, 1792, narrow) \ + f(in_T, out_T, W_T, 2048, narrow) \ + f(in_T, out_T, W_T, 2304, narrow) \ + f(in_T, out_T, W_T, 2560, narrow) \ + f(in_T, out_T, W_T, 2752, narrow) \ + f(in_T, out_T, W_T, 2816, narrow) \ + f(in_T, out_T, W_T, 3072, narrow) \ + f(in_T, out_T, W_T, 3456, narrow) \ + f(in_T, out_T, W_T, 3584, narrow) \ + f(in_T, out_T, W_T, 4096, narrow) \ + f(in_T, out_T, W_T, 4608, narrow) \ + f(in_T, out_T, W_T, 5120, narrow) \ + f(in_T, out_T, W_T, 5504, narrow) \ + f(in_T, out_T, W_T, 5632, narrow) \ + f(in_T, out_T, W_T, 6144, narrow) \ + f(in_T, out_T, W_T, 6848, narrow) \ + f(in_T, out_T, W_T, 6912, narrow) \ + f(in_T, out_T, W_T, 7168, narrow) \ + f(in_T, out_T, W_T, 8192, narrow) \ + f(in_T, out_T, W_T, 9216, narrow) \ + f(in_T, out_T, W_T, 10240, narrow) \ + f(in_T, out_T, W_T, 11008, narrow) \ + f(in_T, out_T, W_T, 12288, narrow) \ + f(in_T, out_T, W_T, 13696, narrow) \ + f(in_T, out_T, W_T, 13824, narrow) \ + f(in_T, out_T, W_T, 14336, narrow) \ + f(in_T, out_T, W_T, 15360, narrow) \ + f(in_T, out_T, W_T, 16384, narrow) \ + f(in_T, out_T, W_T, 20480, narrow) \ + f(in_T, out_T, W_T, 22016, narrow) \ + f(in_T, out_T, W_T, 24576, narrow) \ + f(in_T, out_T, W_T, 27392, narrow) \ + f(in_T, out_T, W_T, 28672, narrow) \ + f(in_T, out_T, W_T, 32000, narrow) \ + f(in_T, out_T, W_T, 32256, narrow) \ + f(in_T, out_T, W_T, 32512, narrow) \ + f(in_T, out_T, W_T, 32768, narrow) \ + f(in_T, out_T, W_T, 33024, narrow) \ + f(in_T, out_T, W_T, 36864, narrow) \ + f(in_T, out_T, W_T, 43264, narrow) \ + f(in_T, out_T, W_T, 49152, narrow) \ + f(in_T, out_T, W_T, 64000, narrow) \ + f(in_T, out_T, W_T, 64256, narrow) \ + f(in_T, out_T, W_T, 64512, narrow) \ + f(in_T, out_T, W_T, 102400, narrow) \ + f(in_T, out_T, W_T, 102656, narrow) \ + f(in_T, out_T, W_T, 102912, narrow) \ + f(in_T, out_T, W_T, 128000, narrow) \ + f(in_T, out_T, W_T, 128256, narrow) \ + f(in_T, out_T, W_T, 128512, narrow) \ +// Keep above in sync with vllm/lora/layers::SamplerWithLoRA + + // Keep this in sync with 
vllm/config::LoRAConfig #define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \ @@ -81,4 +149,14 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \ FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64) + +#define FOR_INST_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \ + FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 1) \ + FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 2) \ + FOR_INST_BGMV_NARROW(f, in_T, out_T, W_T, 4) \ + f(in_T, out_T, W_T, 8, 64) \ + f(in_T, out_T, W_T, 16, 64) \ + f(in_T, out_T, W_T, 32, 64) \ + f(in_T, out_T, W_T, 64, 64) + // clang-format on diff --git a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu index f1db6df5f733..d225a1eaa82b 100644 --- a/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu +++ b/csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu @@ -2,3 +2,4 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu index c01ddd009d74..b37d288a7556 100644 --- a/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu +++ b/csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu @@ -2,3 +2,4 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, nv_half, float, nv_half) diff --git a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu index f45183ffd348..a1ab2deecbab 100644 --- a/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu +++ b/csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu @@ -2,3 +2,4 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_bfloat16, nv_bfloat16) diff --git a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu index 409774348808..0b35bf569989 100644 --- a/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu +++ b/csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu @@ -2,3 +2,4 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, float, nv_half, nv_half) diff --git a/csrc/punica/bgmv/bgmv_impl.cuh b/csrc/punica/bgmv/bgmv_impl.cuh index 995de26e8bad..dad8805c750c 100644 --- a/csrc/punica/bgmv/bgmv_impl.cuh +++ b/csrc/punica/bgmv/bgmv_impl.cuh @@ -199,7 +199,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, constexpr int tz = 4; const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - if constexpr (feat_in < feat_out) { + if constexpr (feat_in <= feat_out) { static_assert(feat_in % vec_size == 0); constexpr int tx = feat_in / vec_size; @@ -289,6 +289,9 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X, int64_t y_offset, int64_t full_y_size, int64_t batch_size, \ int64_t num_layers, int64_t layer_idx, float scale); +#define INST_BGMV_ONESIDE(in_T, out_T, W_T, feat_in, feat_out) \ + INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) + #define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \ INST_BGMV(narrow, wide, in_T, out_T, W_T) \ INST_BGMV(wide, narrow, in_T, out_T, W_T) diff --git a/csrc/punica/bgmv/generator.py b/csrc/punica/bgmv/generator.py index 9bf7f6358880..972df5a7208c 100644 --- a/csrc/punica/bgmv/generator.py +++ b/csrc/punica/bgmv/generator.py @@ -10,6 +10,7 @@ #include "bgmv_impl.cuh" FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, 
{output_dtype}, {weight_dtype}) +FOR_INST_BGMV_WIDE_NARROW(INST_BGMV_ONESIDE, {input_dtype}, {output_dtype}, {weight_dtype}) """.lstrip() # noqa: E501 for input_dtype in DTYPES: diff --git a/csrc/punica/punica_ops.cc b/csrc/punica/punica_ops.cc index a1eaa90e85f2..8797fde85744 100644 --- a/csrc/punica/punica_ops.cc +++ b/csrc/punica/punica_ops.cc @@ -79,12 +79,12 @@ inline bool launch_bgmv_kernel(out_T *Y, const in_T *X, const W_T *W, CASE_ONESIDE(in_T, out_T, W_T, wide, narrow) FOR_BGMV_WIDE_NARROW(CASE, _, _, _) + FOR_INST_BGMV_WIDE_NARROW(CASE_ONESIDE, _, _, _) #undef CASE #undef CASE_ONESIDE default: return false; } - return true; } diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp index a5b16c5abc3e..9839bfc0331c 100644 --- a/csrc/pybind.cpp +++ b/csrc/pybind.cpp @@ -67,13 +67,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { ops.def("aqlm_dequant", &aqlm_dequant, "Decompression method for AQLM"); ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ"); ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_gemm", &gptq_marlin_gemm, "gptq_marlin Optimized Quantized GEMM for GPTQ"); + ops.def("gptq_marlin_repack", &gptq_marlin_repack, "gptq_marlin repack from GPTQ"); ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ"); #endif ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ"); ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ"); ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM"); - ops.def("scaled_fp8_quant", &scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); + ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor"); + ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor"); ops.def( "moe_align_block_size", &moe_align_block_size, diff --git a/csrc/quantization/fp8/fp8_cuda_kernels.cu b/csrc/quantization/fp8/fp8_cuda_kernels.cu index c3337cede128..2477051eb60d 100644 --- a/csrc/quantization/fp8/fp8_cuda_kernels.cu +++ b/csrc/quantization/fp8/fp8_cuda_kernels.cu @@ -74,7 +74,30 @@ __global__ void scaled_fp8_quant_kernel( } // namespace vllm -void scaled_fp8_quant( +void static_scaled_fp8_quant( + torch::Tensor& out, // [..., d] + torch::Tensor& input, // [..., d] + torch::Tensor& scale) // [1] +{ + int64_t num_tokens = input.numel() / input.size(-1); + int64_t num_elems = input.numel(); + dim3 grid(num_tokens); + dim3 block(1024); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), + "scaled_fp8_quant_kernel", + [&] { + vllm::scaled_fp8_quant_kernel<<>>( + out.data_ptr(), + input.data_ptr(), + scale.data_ptr(), + num_elems); + }); +} + +void dynamic_scaled_fp8_quant( torch::Tensor& out, // [..., d] torch::Tensor& input, // [..., d] torch::Tensor& scale) // [1] diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu new file mode 100644 index 000000000000..9902f55167d8 --- /dev/null +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -0,0 +1,1520 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#include "gptq_marlin.cuh" + +template inline std::string str(T x) { return std::to_string(x); } + +namespace gptq_marlin { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) {} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks with + // a separate quantization scale + > +__global__ void +Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int *__restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks // extra global storage for barrier synchronization +) {} + +} // namespace gptq_marlin + +torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_scales, torch::Tensor &g_idx, + torch::Tensor &perm, torch::Tensor &workspace, + int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full) { + TORCH_CHECK_NOT_IMPLEMENTED(false, + "marlin_gemm(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +// Matrix fragments for tensor core instructions; their precise layout is +// documented here: +// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type +using FragA = Vec; +using FragB = Vec; +using FragC = Vec; +using FragS = Vec; // quantization scales + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +__device__ inline void mma(const FragA &a_frag, const FragB &frag_b, + FragC &frag_c) { + const uint32_t *a = reinterpret_cast(&a_frag); + const uint32_t *b = reinterpret_cast(&frag_b); + float *c = reinterpret_cast(&frag_c); + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), + "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3])); +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. 
+__device__ inline void ldsm4(FragA &frag_a, const void *smem_ptr) { + uint32_t *a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); +} + +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template __device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16 +// values. We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +__device__ inline void scale(FragB &frag_b, FragS &frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half *>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Same as above, but for act_order (each K is multiplied individually) +__device__ inline void scale4(FragB &frag_b, FragS &frag_s_1, FragS &frag_s_2, + FragS &frag_s_3, FragS &frag_s_4, int i) { + __half2 s_val_1_2; + s_val_1_2.x = reinterpret_cast<__half *>(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast<__half *>(&frag_s_2)[i]; + + __half2 s_val_3_4; + s_val_3_4.x = reinterpret_cast<__half *>(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast<__half *>(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int *lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int *lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. 
+ asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. +__global__ void permute_cols_kernel(int4 const *__restrict__ a_int4_ptr, + int const *__restrict__ perm_int_ptr, + int4 *__restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int offset = row * row_stride; + + half const *a_row_half = + reinterpret_cast(a_int4_ptr + offset); + half *out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks with + // a separate quantization scale + > +__global__ void +Marlin(const int4 *__restrict__ A, // fp16 input matrix of shape mxk + const int4 *__restrict__ B, // 4bit quantized weight matrix of shape kxn + int4 *__restrict__ C, // fp16 output buffer of shape mxn + const int4 *__restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int *__restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int *locks // extra global storage for barrier synchronization +) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + // `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. 
+ + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) + slice_iters = 0; + if (slice_iters == 0) + return; + if (slice_row + slice_iters > k_tiles) + slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) + slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) + slice_idx--; + } + } + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * prob_k / 8; + C += 16 * thread_m_blocks * prob_n / 8; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // 
number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = !has_act_order && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. 
+ bool a_sh_wr_pred[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4 *B_ptr[b_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4 *sh_a = sh; + int4 *sh_b = sh_a + (stages * a_sh_stage); + int4 *sh_g_idx = sh_b + (stages * b_sh_stage); + int4 *sh_s = sh_g_idx + (stages * g_idx_stage); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + + // Zero accumulators. + auto zero_accums = [&]() { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. 
+ auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred( + &sh_a_stage[a_sh_wr_trans[i]], + &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4_stream(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const *cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4_stream(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4_stream(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. 
+ auto fetch_to_registers = [&](int k, int pipe) { + int4 *sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4 *sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4 *sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4 *sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4 *sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int *sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + +#pragma unroll + for (int i = 0; i < 4; i++) { + + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + 
+ // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. +#pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; + + FragB frag_b0 = dequant(b_quant); + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + FragB frag_b1 = dequant(b_quant_shift); + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + +#pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { +#pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { +#pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float *c_rd = reinterpret_cast( + &sh[red_sh_delta * j + red_sh_rd]); + float *c_wr = reinterpret_cast(&sh[red_sh_wr]); +#pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { +#pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float *c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); +#pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped portioning + // minimizes the number of such reductions and our outputs are usually rather + // small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. 
To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { +// Interestingly, doing direct global accesses here really seems to mess up the +// compiler and lead to slowdowns, hence we also use async-copies even though +// these fetches are not actually asynchronous. +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || + 8 * (i / 2) + row < prob_m); + } + cp_async_fence(); + cp_async_wait<0>(); + } + +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + __half2float(reinterpret_cast<__half *>(&c_red)[j]); + } + } + if (!last) { + int4 c; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half *>(&c)[j] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = + c; + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
+ auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS &s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + + // For per-column quantization we finally apply the scale here + if constexpr (!has_act_order && group_blocks == -1) { + res = __hmul2(res, s[0]); + } + + ((half2 *)sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + +#pragma unroll + for (int i = 0; + i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + C[c_gl_wr] = sh[c_sh_rd]; + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + +#pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. 
+#pragma unroll + for (int pipe = 0; pipe < stages;) { +#pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. + if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1) { + if (last) { + if (s_sh_wr_pred) { + cp_async4_stream(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } + + // if (blockIdx.x == 0 && threadIdx.x == 0) { + // printf("Move\n"); + // } + start_pipes(); + } + } + } +} + +#define __CALL_IF(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ + else if (thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + Marlin, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + Marlin \ + <<>>( \ + A_ptr, B_ptr, C_ptr, s_ptr, 
g_idx_ptr, num_groups, prob_m, prob_n, \ + prob_k, locks); \ + } + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 128, 128}, // Reduce N 2X, same K + // {128, 64, 128}, // Reduce N 4X, increase K 2X +}; + +bool is_valid_config(thread_config_t const &th_config, int prob_m, int prob_n, + int prob_k) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + return true; +} + +thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { + + // TODO: Enable if needed after some more testing + if (prob_m <= 0) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + + } else { + for (auto th_config : large_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + } + + return thread_config_t{-1, -1, -1}; +} + +#define CALL_IF(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + +void marlin_cuda(const void *A, const void *B, void *C, void *s, void *g_idx, + void *perm, void *a_tmp, int prob_m, int prob_n, int prob_k, + void *workspace, bool has_act_order, bool is_k_full, + int num_groups, int group_size, int dev = 0, + cudaStream_t stream = 0, int thread_k = -1, int thread_n = -1, + int sms = -1, int max_par = 16) { + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", 
prob_n, ", ", prob_k, "]"); + + int tot_m = prob_m; + int tot_m_blocks = div_ceil(tot_m, 16); + int pad = 16 * tot_m_blocks - tot_m; + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + // Set thread config + thread_config_t th_config; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + th_config = thread_config_t{thread_k, thread_n, default_threads}; + } else { + // Auto config + th_config = determine_thread_config(prob_m, prob_n, prob_k); + } + + TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k), + "Invalid thread config: thread_k = " + str(th_config.thread_k) + + ", thread_n = " + str(th_config.thread_n) + + ", num_threads = " + str(th_config.num_threads) + + " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + + str(prob_n) + "]"); + + int num_threads = th_config.num_threads; + thread_k = th_config.thread_k; + thread_n = th_config.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + int blocks = sms; + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } else { + TORCH_CHECK(group_size == 0); + group_blocks = 0; + } + + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + } + + const int4 *A_ptr = (const int4 *)A; + const int4 *B_ptr = (const int4 *)B; + int4 *C_ptr = (int4 *)C; + const int4 *s_ptr = (const int4 *)s; + const int *g_idx_ptr = (const int *)g_idx; + const int *perm_ptr = (const int *)perm; + int4 *a_tmp_ptr = (int4 *)a_tmp; + + int *locks = (int *)workspace; + + if (has_act_order) { + // Permute A columns + int block_rows = div_ceil(prob_m, blocks); + permute_cols_kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, block_rows); + A_ptr = a_tmp_ptr; + } + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by having + // a full K, we have full original groups) + if (is_k_full) { + has_act_order = false; + } + + // Main loop + for (int i = 0; i < tot_m_blocks; i += 4) { + int thread_m_blocks = tot_m_blocks - i; + prob_m = tot_m - 16 * i; + int par = 1; + if (thread_m_blocks > 4) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * thread_m_blocks - pad) / 64; + if (par > max_par) + par = max_par; + prob_m = 64 * par; + i += 4 * (par - 1); + thread_m_blocks = 4; + } + + // Define kernel configurations + if (false) { + } + CALL_IF(16, 4, 256) + CALL_IF(8, 8, 256) + CALL_IF(8, 4, 128) + CALL_IF(4, 8, 128) + else { + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + + str(prob_n) + ", " + str(prob_k) + "]" + + ", has_act_order = " + str(has_act_order) + + ", num_groups = " + str(num_groups) + + ", group_size = " + str(group_size) + + ", thread_m_blocks 
= " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + + A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par; + C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par; + } +} + +} // namespace gptq_marlin + +torch::Tensor gptq_marlin_gemm(torch::Tensor &a, torch::Tensor &b_q_weight, + torch::Tensor &b_scales, torch::Tensor &g_idx, + torch::Tensor &perm, torch::Tensor &workspace, + int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full) { + // Verify A + TORCH_CHECK(a.size(0) == size_m, + "Shape mismatch: a.size(0) = " + str(a.size(0)) + + ", size_m = " + str(size_m)); + TORCH_CHECK(a.size(1) == size_k, + "Shape mismatch: a.size(1) = " + str(a.size(1)) + + ", size_k = " + str(size_k)); + + // Verify B + TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, + "size_k = " + str(size_k) + " is not divisible by tile_size = " + + str(gptq_marlin::tile_size)); + TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = " + + str(b_q_weight.size(0)) + ", size_k = " + str(size_k) + + ", tile_size = " + str(gptq_marlin::tile_size)); + TORCH_CHECK( + b_q_weight.size(1) % gptq_marlin::tile_size == 0, + "b_q_weight.size(1) = " + str(b_q_weight.size(1)) + + " is not divisible by tile_size = " + str(gptq_marlin::tile_size)); + int actual_size_n = (b_q_weight.size(1) / gptq_marlin::tile_size) * + gptq_marlin::pack_factor_4bit; + TORCH_CHECK(size_n == actual_size_n, + "size_n = " + str(size_n) + + ", actual_size_n = " + str(actual_size_n)); + + // Verify device and strides + TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); + TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + + TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); + TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + + TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU"); + TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous"); + + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); + auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device()); + torch::Tensor c = torch::empty({size_m, size_n}, options); + torch::Tensor a_tmp = torch::empty({size_m, size_k}, options); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + // Verify g_idx and perm + TORCH_CHECK((g_idx.size(0) == 0 && perm.size(0) == 0) || + (g_idx.size(0) == size_k && perm.size(0) == size_k), + "Unexpected g_idx.size(0) = " + str(g_idx.size(0)) + + " and perm.size(0) = " + str(perm.size(0)) + + ", where size_k = " + str(size_k)); + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + bool has_act_order = g_idx.size(0) != 0; + + int b_rank = b_scales.sizes().size(); + TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2"); + TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1), + " is not size_n = ", size_n); + num_groups = 
b_scales.size(0); + + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); + TORCH_CHECK(size_k % num_groups == 0, + "size_k = " + str(size_k) + + ", is not divisible by num_groups = " + str(num_groups)); + group_size = size_k / num_groups; + } else { + group_size = 0; + } + + } else { + if (num_groups > 1) { + TORCH_CHECK(size_k % num_groups == 0, + "size_k = " + str(size_k) + + ", is not divisible by b_scales.size(0) = " + + str(b_scales.size(0))); + group_size = size_k / num_groups; + } else { + group_size = -1; + } + } + + // Verify workspace size + TORCH_CHECK(size_n % gptq_marlin::min_thread_n == 0, + "size_n = " + str(size_n) + + ", is not divisible by min_thread_n = " + + str(gptq_marlin::min_thread_n)); + int min_workspace_size = + (size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par; + TORCH_CHECK(workspace.numel() >= min_workspace_size, + "workspace.numel = " + str(workspace.numel()) + + " is below min_workspace_size = " + str(min_workspace_size)); + + int dev = a.get_device(); + gptq_marlin::marlin_cuda( + a.data_ptr(), b_q_weight.data_ptr(), c.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), size_m, size_n, + size_k, workspace.data_ptr(), has_act_order, is_k_full, num_groups, + group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, + sms, gptq_marlin::max_par); + + return c; +} + +#endif diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cuh b/csrc/quantization/gptq_marlin/gptq_marlin.cuh new file mode 100644 index 000000000000..8cfce6b2575d --- /dev/null +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cuh @@ -0,0 +1,74 @@ +#pragma once + +#include + +#include +#include +#include +#include +#include +#include + +namespace gptq_marlin { + +// 8 warps are a good choice since every SM has 4 schedulers and having more than 1 warp per +// schedule allows some more latency hiding. At the same time, we want relatively few warps to have +// many registers per warp and small tiles. 
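+// (8 warps x 32 threads per warp = 256 threads, i.e. the default_threads value below.)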
+static constexpr int default_threads = 256; + +static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +static constexpr int pack_factor_4bit = 8; // We have 8 4-bit vals inside a 32 bit + +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { return elems[i]; } +}; + +using I4 = Vec; + +constexpr int div_ceil(int a, int b) { return (a + b - 1) / b; } + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + // No support for async +#else + +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async4_stream(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile("{\n" + " .reg .b64 p;\n" + " createpolicy.fractional.L2::evict_first.b64 p, 1.0;" + " cp.async.cg.shared.global.L2::cache_hint [%0], [%1], %2, p;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), "n"(BYTES)); +} + +__device__ inline void cp_async_fence() { asm volatile("cp.async.commit_group;\n" ::); } + +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" ::"n"(n)); +} + +#endif + +} // namespace gptq_marlin diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu new file mode 100644 index 000000000000..fa45ce68a0c7 --- /dev/null +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -0,0 +1,324 @@ +#include "gptq_marlin.cuh" + +namespace gptq_marlin { + +static constexpr int repack_stages = 8; + +static constexpr int repack_threads = 256; + +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +template +__global__ void +marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, + uint32_t const *__restrict__ perm_ptr, + uint32_t *__restrict__ out_ptr, int size_k, int size_n) {} + +} // namespace gptq_marlin + +torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, + int64_t size_k, int64_t size_n) { + TORCH_CHECK_NOT_IMPLEMENTED( + false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); + return torch::empty({1, 1}); +} + +#else + +template +__global__ void +marlin_repack_kernel(uint32_t const *__restrict__ b_q_weight_ptr, + uint32_t const *__restrict__ perm_ptr, + uint32_t *__restrict__ out_ptr, int size_k, int size_n) { + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = div_ceil(k_tiles, gridDim.x); + + int start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). 
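+    // i.e. wait here until the oldest outstanding cp.async group has landed,
+    // leaving at most `repack_stages - 2` fetches in flight, then sync the block.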
+ cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int perm_size = tile_k_size / 4; + + int4 *sh_perm_ptr = sh; + int4 *sh_pipe_ptr = sh_perm_ptr; + if constexpr (has_perm) { + sh_pipe_ptr += perm_size; + } + + constexpr int stage_n_threads = tile_n_size / 4; + constexpr int stage_k_threads = + has_perm ? tile_k_size : tile_k_size / pack_factor_4bit; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto load_perm_to_shared = [&](int k_tile_id) { + int first_k_int4 = (k_tile_id * tile_k_size) / 4; + + int4 const *perm_int4_ptr = reinterpret_cast(perm_ptr); + + if (threadIdx.x < perm_size) { + sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; + } + __syncthreads(); + }; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + + int4 *sh_ptr = sh_pipe_ptr + stage_size * pipe; + + if constexpr (has_perm) { + if (threadIdx.x < stage_size) { + int k_id = threadIdx.x / stage_n_threads; + int n_id = threadIdx.x % stage_n_threads; + + uint32_t const *sh_perm_int_ptr = + reinterpret_cast(sh_perm_ptr); + + int src_k = sh_perm_int_ptr[k_id]; + int src_k_packed = src_k / pack_factor_4bit; + + cp_async4_stream( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&( + b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); + } + + } else { + if (threadIdx.x < stage_size) { + int k_id = threadIdx.x / stage_n_threads; + int n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + int first_k_packed = first_k / pack_factor_4bit; + + cp_async4_stream(&sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k_packed + k_id) * size_n + + first_n + (n_id * 4)]))); + } + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } + + int warp_id = threadIdx.x / 32; + int th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + + constexpr int sh_stride = 64; + + int4 *sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t *sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + uint32_t *sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + uint32_t vals[pack_factor_4bit]; + + if constexpr (has_perm) { + for (int i = 0; i < 4; i++) { + int k_idx = tc_row + tc_offsets[i]; + + uint32_t src_k = sh_perm_int_ptr[k_idx]; + uint32_t src_k_pos = src_k % pack_factor_4bit; + + uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; + uint32_t b1_cur_val = (b1_val >> (src_k_pos * 4)) & 0xf; + + uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; + uint32_t b2_cur_val = (b2_val >> (src_k_pos * 4)) & 0xf; + + vals[i] = b1_cur_val; + vals[4 + i] = b2_cur_val; + } + + } else { + + uint32_t b1_val_1 = sh_stage_int_ptr[cur_n]; + uint32_t b1_val_2 = sh_stage_int_ptr[sh_stride + cur_n]; + + uint32_t b2_val_1 = sh_stage_int_ptr[cur_n + 8]; + uint32_t b2_val_2 = sh_stage_int_ptr[sh_stride + cur_n + 8]; + +#pragma unroll + for (int i = 0; i < 2; i++) { + int cur_elem = tc_row + tc_offsets[i]; + vals[i] = (b1_val_1 >> (cur_elem * 4)) & 0xf; + vals[4 + i] = (b2_val_1 >> (cur_elem * 4)) & 0xf; + } + +#pragma unroll + for (int i = 2; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i] - 8; + vals[i] = 
(b1_val_2 >> (cur_elem * 4)) & 0xf; + vals[4 + i] = (b2_val_2 >> (cur_elem * 4)) & 0xf; + } + } + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + constexpr int pack_idx[pack_factor_4bit] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < pack_factor_4bit; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor_4bit; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; +#pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + if constexpr (has_perm) { + load_perm_to_shared(k_tile_id); + } + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, + n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} + +} // namespace gptq_marlin + +torch::Tensor gptq_marlin_repack(torch::Tensor &b_q_weight, torch::Tensor &perm, + int64_t size_k, int64_t size_n) { + // Verify compatibility with marlin tile of 16x64 + TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k, + " is not divisible by tile_k_size = ", gptq_marlin::tile_k_size); + TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n, + " is not divisible by tile_n_size = ", gptq_marlin::tile_n_size); + + // Verify B + TORCH_CHECK((size_k / gptq_marlin::pack_factor_4bit) == b_q_weight.size(0), + "Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0), + ", size_k = ", size_k, + ", pack_factor_4bit = ", gptq_marlin::pack_factor_4bit); + TORCH_CHECK(b_q_weight.size(1) == size_n, + "b_q_weight.size(1) = ", b_q_weight.size(1), + " is not size_n = ", size_n); + + // Verify device and strides + TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); + TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous"); + TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt"); + + TORCH_CHECK(perm.device().is_cuda(), "perm is not on GPU"); + TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous"); + TORCH_CHECK(perm.dtype() == at::kInt, "perm type is not at::kInt"); + + // Alloc buffers + const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight)); + auto options = torch::TensorOptions() + .dtype(b_q_weight.dtype()) + .device(b_q_weight.device()); + torch::Tensor out = torch::empty( + {size_k / gptq_marlin::tile_size, + size_n * gptq_marlin::tile_size / gptq_marlin::pack_factor_4bit}, + options); + + // Detect if there is act_order + bool has_perm = perm.size(0) != 0; + + // Get ptrs + uint32_t const *b_q_weight_ptr = + reinterpret_cast(b_q_weight.data_ptr()); + uint32_t const *perm_ptr = + reinterpret_cast(perm.data_ptr()); + uint32_t *out_ptr = reinterpret_cast(out.data_ptr()); + + // Get dev info + int dev = b_q_weight.get_device(); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev); + int blocks; + 
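+  // Launch one thread block per SM; the kernel itself splits the k-tiles
+  // across blocks via div_ceil(k_tiles, gridDim.x), so any SM count works.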
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + if (has_perm) { + cudaFuncSetAttribute( + gptq_marlin::marlin_repack_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + max_shared_mem); + gptq_marlin::marlin_repack_kernel + <<>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); + + } else { + cudaFuncSetAttribute( + gptq_marlin::marlin_repack_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + max_shared_mem); + gptq_marlin::marlin_repack_kernel + <<>>(b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); + } + + return out; +} + +#endif diff --git a/examples/production_monitoring/grafana.json b/examples/production_monitoring/grafana.json index 071f134c6e5e..5e9bd5bd0386 100644 --- a/examples/production_monitoring/grafana.json +++ b/examples/production_monitoring/grafana.json @@ -873,6 +873,289 @@ ], "title": "Cache Utilization", "type": "timeseries" + }, + { + "type": "heatmap", + "title": "Request Prompt Length", + "description": "Heatmap of request prompt length", + "gridPos": { + "x": 0, + "y": 24, + "w": 12, + "h": 8 + }, + "datasource": { + "uid": "prometheus", + "type": "prometheus" + }, + "id": 12, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "refId": "A", + "expr": "sum by(le) (increase(vllm:request_prompt_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", + "range": true, + "instant": false, + "editorMode": "builder", + "legendFormat": "{{le}}", + "useBackend": false, + "disableTextWrap": false, + "fullMetaSearch": false, + "includeNullMetadata": true, + "format": "heatmap" + } + ], + "options": { + "calculate": false, + "yAxis": { + "axisPlacement": "left", + "reverse": false, + "unit": "none", + "axisLabel": "Prompt Length" + }, + "rowsFrame": { + "layout": "auto", + "value": "Request count" + }, + "color": { + "mode": "scheme", + "fill": "dark-orange", + "scale": "exponential", + "exponent": 0.5, + "scheme": "Spectral", + "steps": 64, + "reverse": false, + "min": 0 + }, + "cellGap": 1, + "filterValues": { + "le": 1e-9 + }, + "tooltip": { + "show": true, + "yHistogram": true + }, + "legend": { + "show": true + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "cellValues": { + "unit": "none" + } + }, + "fieldConfig": { + "defaults": { + "custom": { + "scaleDistribution": { + "type": "linear" + }, + "hideFrom": { + "tooltip": false, + "viz": false, + "legend": false + } + } + }, + "overrides": [] + }, + "pluginVersion": "10.2.0" + }, + { + "datasource": { + "uid": "prometheus", + "type": "prometheus" + }, + "type": "heatmap", + "title": "Request Generation Length", + "description": "Heatmap of request generation length", + "gridPos": { + "x": 12, + "y": 24, + "w": 12, + "h": 8 + }, + "id": 13, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "refId": "A", + "expr": "sum by(le) (increase(vllm:request_generation_tokens_bucket{model_name=\"$model_name\"}[$__rate_interval]))", + "range": true, + "instant": false, + "editorMode": "builder", + "legendFormat": "{{le}}", + "useBackend": false, + "disableTextWrap": false, + "fullMetaSearch": false, + "includeNullMetadata": true, + "format": "heatmap" + } + ], + "options": { + "calculate": false, + "yAxis": { + "axisPlacement": "left", + "reverse": false, + "unit": "none", + "axisLabel": "Generation Length" + }, + "rowsFrame": { + "layout": 
"auto", + "value": "Request count" + }, + "color": { + "mode": "scheme", + "fill": "dark-orange", + "scale": "exponential", + "exponent": 0.5, + "scheme": "Spectral", + "steps": 64, + "reverse": false, + "min": 0 + }, + "cellGap": 1, + "filterValues": { + "le": 1e-9 + }, + "tooltip": { + "show": true, + "yHistogram": true + }, + "legend": { + "show": true + }, + "exemplars": { + "color": "rgba(255,0,255,0.7)" + }, + "cellValues": { + "unit": "none" + } + }, + "fieldConfig": { + "defaults": { + "custom": { + "scaleDistribution": { + "type": "linear" + }, + "hideFrom": { + "tooltip": false, + "viz": false, + "legend": false + } + } + }, + "overrides": [] + }, + "pluginVersion": "10.2.0" + }, + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "fieldConfig": { + "defaults": { + "custom": { + "drawStyle": "line", + "lineInterpolation": "linear", + "barAlignment": 0, + "lineWidth": 1, + "fillOpacity": 0, + "gradientMode": "none", + "spanNulls": false, + "insertNulls": false, + "showPoints": "auto", + "pointSize": 5, + "stacking": { + "mode": "none", + "group": "A" + }, + "axisPlacement": "auto", + "axisLabel": "", + "axisColorMode": "text", + "axisBorderShow": false, + "scaleDistribution": { + "type": "linear" + }, + "axisCenteredZero": false, + "hideFrom": { + "tooltip": false, + "viz": false, + "legend": false + }, + "thresholdsStyle": { + "mode": "off" + } + }, + "color": { + "mode": "palette-classic" + }, + "mappings": [], + "thresholds": { + "mode": "absolute", + "steps": [ + { + "color": "green", + "value": null + }, + { + "color": "red", + "value": 80 + } + ] + } + }, + "overrides": [] + }, + "gridPos": { + "h": 8, + "w": 12, + "x": 0, + "y": 32 + }, + "id": 11, + "options": { + "tooltip": { + "mode": "single", + "sort": "none" + }, + "legend": { + "showLegend": true, + "displayMode": "list", + "placement": "bottom", + "calcs": [] + } + }, + "targets": [ + { + "datasource": { + "type": "prometheus", + "uid": "prometheus" + }, + "disableTextWrap": false, + "editorMode": "builder", + "expr": "sum by(finished_reason) (increase(vllm:request_success_total{model_name=\"$model_name\"}[$__rate_interval]))", + "fullMetaSearch": false, + "includeNullMetadata": true, + "instant": false, + "interval": "", + "legendFormat": "__auto", + "range": true, + "refId": "A", + "useBackend": false + } + ], + "title": "Finish Reason", + "description": "Number of finished requests by their finish reason: either an EOS token was generated or the max sequence length was reached.", + "type": "timeseries" } ], "refresh": "", diff --git a/format.sh b/format.sh index 4ac1842daef0..bd12e61d7780 100755 --- a/format.sh +++ b/format.sh @@ -105,7 +105,7 @@ mypy vllm/transformers_utils --config-file pyproject.toml mypy vllm/engine --config-file pyproject.toml mypy vllm/worker --config-file pyproject.toml mypy vllm/spec_decode --config-file pyproject.toml -mypy vllm/model_executor/*.py --config-file pyproject.toml +mypy vllm/model_executor --config-file pyproject.toml mypy vllm/lora --config-file pyproject.toml diff --git a/pyproject.toml b/pyproject.toml index 2e026c1ac891..6a448defc16e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ requires = [ "ninja", "packaging", "setuptools >= 49.4.0", - "torch == 2.2.1", + "torch == 2.3.0", "wheel", ] build-backend = "setuptools.build_meta" diff --git a/requirements-build.txt b/requirements-build.txt index 2bc07fb152aa..1a07a94e82e0 100644 --- a/requirements-build.txt +++ b/requirements-build.txt @@ -3,5 +3,5 @@ cmake>=3.21 ninja packaging 
setuptools>=49.4.0 -torch==2.2.1 +torch==2.3.0 wheel diff --git a/requirements-common.txt b/requirements-common.txt index 3cc7bba8f84d..3abb82811668 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -8,9 +8,11 @@ py-cpuinfo transformers >= 4.40.0 # Required for StarCoder2 & Llava, Llama 3. tokenizers >= 0.19.1 # Required for Llama 3. fastapi +openai uvicorn[standard] pydantic >= 2.0 # Required for OpenAI server. prometheus_client >= 0.18.0 +prometheus-fastapi-instrumentator >= 7.0.0 tiktoken == 0.6.0 # Required for DBRX tokenizer lm-format-enforcer == 0.9.8 outlines == 0.0.34 # Requires torch >= 2.1.0 diff --git a/requirements-cpu.txt b/requirements-cpu.txt index e911ad03295f..b739642d8d34 100644 --- a/requirements-cpu.txt +++ b/requirements-cpu.txt @@ -2,5 +2,5 @@ -r requirements-common.txt # Dependencies for x86_64 CPUs -torch == 2.2.1+cpu +torch == 2.3.0+cpu triton >= 2.2.0 # FIXME(woosuk): This is a hack to avoid import error. \ No newline at end of file diff --git a/requirements-cuda.txt b/requirements-cuda.txt index 1bddae4c6f40..6548d7a6684b 100644 --- a/requirements-cuda.txt +++ b/requirements-cuda.txt @@ -5,5 +5,5 @@ ray >= 2.9 nvidia-ml-py # for pynvml package vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library -torch == 2.2.1 -xformers == 0.0.25 # Requires PyTorch 2.2.1 +torch == 2.3.0 +xformers == 0.0.26.post1 # Requires PyTorch 2.3.0 diff --git a/requirements-dev.txt b/requirements-dev.txt index d9816828d007..324039186142 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -21,7 +21,6 @@ pytest-rerunfailures pytest-shard httpx einops # required for MPT -openai requests ray peft diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index cb125a7bfec3..b69cdc0a2140 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -91,4 +91,6 @@ async def test_new_requests_event(): assert engine.engine.step_calls == old_step_calls + 1 engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True) + assert engine.get_model_config() is not None assert engine.get_tokenizer() is not None + assert engine.get_decoding_config() is not None diff --git a/tests/async_engine/test_merge_async_iterators.py b/tests/async_engine/test_merge_async_iterators.py new file mode 100644 index 000000000000..ea453526c77f --- /dev/null +++ b/tests/async_engine/test_merge_async_iterators.py @@ -0,0 +1,41 @@ +import asyncio +from typing import AsyncIterator, Tuple + +import pytest + +from vllm.utils import merge_async_iterators + + +@pytest.mark.asyncio +async def test_merge_async_iterators(): + + async def mock_async_iterator(idx: int) -> AsyncIterator[str]: + try: + while True: + yield f"item from iterator {idx}" + await asyncio.sleep(0.1) + except asyncio.CancelledError: + pass + + iterators = [mock_async_iterator(i) for i in range(3)] + merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators( + *iterators) + + async def stream_output(generator: AsyncIterator[Tuple[int, str]]): + async for idx, output in generator: + print(f"idx: {idx}, output: {output}") + + task = asyncio.create_task(stream_output(merged_iterator)) + await asyncio.sleep(0.5) + task.cancel() + with pytest.raises(asyncio.CancelledError): + await task + + for iterator in iterators: + try: + await asyncio.wait_for(anext(iterator), 1) + except StopAsyncIteration: + # All iterators should be cancelled and print this message. 
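+            # (The CancelledError is thrown into each generator at its await
+            # point and swallowed by its except block, so the generator
+            # returns and anext() raises StopAsyncIteration.)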
+ print("Iterator was cancelled normally") + except (Exception, asyncio.CancelledError) as e: + raise AssertionError() from e diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py new file mode 100644 index 000000000000..4b97af88012b --- /dev/null +++ b/tests/async_engine/test_openapi_server_ray.py @@ -0,0 +1,157 @@ +# imports for guided decoding tests +import os +import subprocess +import sys +import time + +import openai # use the official client for correctness check +import pytest +# using Ray for overall ease of process management, parallel requests, +# and debugging. +import ray +import requests + +MAX_SERVER_START_WAIT_S = 600 # wait for server to start for 60 seconds +# any model with a chat template should work here +MODEL_NAME = "facebook/opt-125m" + + +@ray.remote(num_gpus=1) +class ServerRunner: + + def __init__(self, args): + env = os.environ.copy() + env["PYTHONUNBUFFERED"] = "1" + self.proc = subprocess.Popen( + ["python3", "-m", "vllm.entrypoints.openai.api_server"] + args, + env=env, + stdout=sys.stdout, + stderr=sys.stderr, + ) + self._wait_for_server() + + def ready(self): + return True + + def _wait_for_server(self): + # run health check + start = time.time() + while True: + try: + if requests.get( + "http://localhost:8000/health").status_code == 200: + break + except Exception as err: + if self.proc.poll() is not None: + raise RuntimeError("Server exited unexpectedly.") from err + + time.sleep(0.5) + if time.time() - start > MAX_SERVER_START_WAIT_S: + raise RuntimeError( + "Server failed to start in time.") from err + + def __del__(self): + if hasattr(self, "proc"): + self.proc.terminate() + + +@pytest.fixture(scope="session") +def server(): + ray.init() + server_runner = ServerRunner.remote([ + "--model", + MODEL_NAME, + # use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--max-model-len", + "2048", + "--enforce-eager", + "--engine-use-ray" + ]) + ray.get(server_runner.ready.remote()) + yield server_runner + ray.shutdown() + + +@pytest.fixture(scope="session") +def client(): + client = openai.AsyncOpenAI( + base_url="http://localhost:8000/v1", + api_key="token-abc123", + ) + yield client + + +@pytest.mark.asyncio +async def test_check_models(server, client: openai.AsyncOpenAI): + models = await client.models.list() + models = models.data + served_model = models[0] + assert served_model.id == MODEL_NAME + assert all(model.root == MODEL_NAME for model in models) + + +@pytest.mark.asyncio +async def test_single_completion(server, client: openai.AsyncOpenAI): + completion = await client.completions.create(model=MODEL_NAME, + prompt="Hello, my name is", + max_tokens=5, + temperature=0.0) + + assert completion.id is not None + assert completion.choices is not None and len(completion.choices) == 1 + assert completion.choices[0].text is not None and len( + completion.choices[0].text) >= 5 + assert completion.choices[0].finish_reason == "length" + assert completion.usage == openai.types.CompletionUsage( + completion_tokens=5, prompt_tokens=6, total_tokens=11) + + # test using token IDs + completion = await client.completions.create( + model=MODEL_NAME, + prompt=[0, 0, 0, 0, 0], + max_tokens=5, + temperature=0.0, + ) + assert completion.choices[0].text is not None and len( + completion.choices[0].text) >= 5 + + +@pytest.mark.asyncio +async def test_single_chat_session(server, client: openai.AsyncOpenAI): + messages = [{ + "role": "system", + "content": "you are a helpful assistant" 
+ }, { + "role": "user", + "content": "what is 1+1?" + }] + + # test single completion + chat_completion = await client.chat.completions.create(model=MODEL_NAME, + messages=messages, + max_tokens=10, + logprobs=True, + top_logprobs=5) + assert chat_completion.id is not None + assert chat_completion.choices is not None and len( + chat_completion.choices) == 1 + assert chat_completion.choices[0].message is not None + assert chat_completion.choices[0].logprobs is not None + assert chat_completion.choices[0].logprobs.top_logprobs is not None + assert len(chat_completion.choices[0].logprobs.top_logprobs[0]) == 5 + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 10 + assert message.role == "assistant" + messages.append({"role": "assistant", "content": message.content}) + + # test multi-turn dialogue + messages.append({"role": "user", "content": "express your result in json"}) + chat_completion = await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_tokens=10, + ) + message = chat_completion.choices[0].message + assert message.content is not None and len(message.content) >= 0 diff --git a/tests/entrypoints/test_openai_server.py b/tests/entrypoints/test_openai_server.py index 85a7ef464c03..68332228ace0 100644 --- a/tests/entrypoints/test_openai_server.py +++ b/tests/entrypoints/test_openai_server.py @@ -15,6 +15,7 @@ import requests # downloading lora to test lora requests from huggingface_hub import snapshot_download +from openai import BadRequestError from vllm.transformers_utils.tokenizer import get_tokenizer @@ -770,6 +771,21 @@ async def test_response_format_json_object(server, client: openai.AsyncOpenAI): assert loaded == {"result": 2}, loaded +async def test_extra_fields(server, client: openai.AsyncOpenAI): + with pytest.raises(BadRequestError) as exc_info: + await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "system", + "content": "You are a helpful assistant.", + "extra_field": "0", + }], # type: ignore + temperature=0, + seed=0) + + assert "extra_forbidden" in exc_info.value.message + + async def test_guided_grammar(server, client: openai.AsyncOpenAI): simple_sql_grammar = """ start: select_statement diff --git a/tests/lora/test_layers.py b/tests/lora/test_layers.py index 1616fdfd4cff..0eb04f4ccd13 100644 --- a/tests/lora/test_layers.py +++ b/tests/lora/test_layers.py @@ -8,6 +8,10 @@ import torch.nn.functional as F from vllm.config import LoRAConfig +from vllm.lora.fully_sharded_layers import ( + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA) # yapf conflicts with isort for this block # yapf: disable from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, @@ -524,13 +528,16 @@ def _pretest(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("orientation", ["row", "column"]) +@pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_linear_parallel(dist_init, num_loras, orientation, device) -> None: +def test_linear_parallel(dist_init, num_loras, orientation, fully_shard, + device) -> None: torch.set_default_device(device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, + fully_sharded_loras=fully_shard, lora_dtype=torch.float16) def create_random_linear_parallel_layer(): @@ -540,14 +547,17 @@ def 
create_random_linear_parallel_layer(): bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = RowParallelLinearWithLoRA(linear) + lora_linear = (RowParallelLinearWithLoRA(linear) if not fully_shard + else RowParallelLinearWithShardedLoRA(linear)) else: linear = ColumnParallelLinear(4096, 4096, bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = ColumnParallelLinearWithLoRA(linear) + lora_linear = (ColumnParallelLinearWithLoRA(linear) + if not fully_shard else + ColumnParallelLinearWithShardedLoRA(linear)) lora_linear.create_lora_weights(max_loras, lora_config) return linear, lora_linear @@ -629,13 +639,16 @@ def create_random_linear_parallel_layer(): @torch.inference_mode() @pytest.mark.parametrize("num_loras", [1, 2, 4, 8]) @pytest.mark.parametrize("repeats", [1, 2, 3]) +@pytest.mark.parametrize("fully_shard", [True, False]) @pytest.mark.parametrize("device", CUDA_DEVICES) -def test_column_parallel_packed(dist_init, num_loras, repeats, device) -> None: +def test_column_parallel_packed(dist_init, num_loras, repeats, fully_shard, + device) -> None: torch.set_default_device(device) max_loras = 8 lora_config = LoRAConfig(max_loras=max_loras, max_lora_rank=8, + fully_sharded_loras=fully_shard, lora_dtype=torch.float16) def create_column_parallel_packed_layer(): @@ -644,7 +657,9 @@ def create_column_parallel_packed_layer(): bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = MergedColumnParallelLinearWithLoRA(linear) + lora_linear = (MergedColumnParallelLinearWithLoRA(linear) + if not fully_shard else + MergedColumnParallelLinearWithShardedLoRA(linear)) elif repeats == 3: linear = QKVParallelLinear(4096, 64, @@ -652,7 +667,9 @@ def create_column_parallel_packed_layer(): bias=False, params_dtype=torch.float16) linear.weight.data = torch.rand_like(linear.weight.data) - lora_linear = MergedQKVParallelLinearWithLora(linear) + lora_linear = (MergedQKVParallelLinearWithLora(linear) + if not fully_shard else + MergedQKVParallelLinearWithShardedLora(linear)) else: linear = QKVParallelLinear(4096, 64, diff --git a/tests/lora/test_punica.py b/tests/lora/test_punica.py index f3b9bd591296..fd2a1b75f460 100644 --- a/tests/lora/test_punica.py +++ b/tests/lora/test_punica.py @@ -34,11 +34,14 @@ def _lora_ref_impl( for i, lora_idx in zip(range(bs), indicies.cpu().tolist()): xi = x[i].unsqueeze(0).to(torch.float32) wa = wa_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) - wb = wb_T_all[lora_idx, layer_idx].transpose(-1, -2).to(torch.float32) + if wb_T_all is not None: + wb = wb_T_all[lora_idx, layer_idx].transpose(-1, + -2).to(torch.float32) tmp = xi @ wa y_stage_1[i] = tmp.squeeze(0) - y_final[i] += (tmp @ wb).squeeze(0) * s + y_final[i] += ((tmp @ wb).squeeze(0) * + s if wb_T_all is not None else y_stage_1[i]) return y_final, y_stage_1 @@ -91,12 +94,56 @@ def _lora_ref_impl( 128000, 128256, ] +H2 = [64] + H2 +R = [1, 2, 4] SEED = [0xabcdabcd987] CUDA_DEVICES = [ f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) ] +@pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) +@pytest.mark.parametrize("h1", H1) +@pytest.mark.parametrize("r", R) +@pytest.mark.parametrize("seed", SEED) +@torch.inference_mode() +def test_lora_a_extra_shapes(dtype_str, h1, r, seed): + torch.manual_seed(seed) + num_loras = 4 + num_layers = 1 + bs = 32 + dtype = getattr(torch, dtype_str) + device = torch.device("cuda") 
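+    # This exercises only the LoRA-A (shrink) matmul: the reference impl is
+    # called with wb_T_all=None, so it accumulates just the h1 -> r projection
+    # into y, which should match a single bgmv call with the LoRA-A weights.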
+ + wa_T_all = torch.randn(num_loras, + num_layers, + r, + h1, + dtype=dtype, + device=device) + indices = torch.randint(num_loras, (bs, ), dtype=torch.long, device=device) + + for layer_idx in range(num_layers): + x = torch.randn(bs, h1, dtype=dtype, device=device) + y = torch.randn(bs, r, dtype=dtype, device=device) + + y_ref = y.clone() + _lora_ref_impl( + y_ref, + x, + wa_T_all, + None, + indices, + layer_idx, + 1.0, + ) + + y_our = y.clone() + punica.bgmv(y_our, x, wa_T_all, indices, layer_idx, 1.0) + + assert_close(y_ref, y_our) + + @pytest.mark.parametrize("dtype_str", ["float16", "bfloat16"]) @pytest.mark.parametrize("h1", H1) @pytest.mark.parametrize("h2", H2) diff --git a/tests/model_executor/weight_utils.py b/tests/model_executor/weight_utils.py index b0086dd7a7d7..c8b9bed691bb 100644 --- a/tests/model_executor/weight_utils.py +++ b/tests/model_executor/weight_utils.py @@ -1,9 +1,12 @@ import os +import tempfile import huggingface_hub.constants import pytest +from huggingface_hub.utils import LocalEntryNotFoundError -from vllm.model_executor.model_loader.weight_utils import enable_hf_transfer +from vllm.model_executor.model_loader.weight_utils import ( + download_weights_from_hf, enable_hf_transfer) def test_hf_transfer_auto_activation(): @@ -22,5 +25,30 @@ def test_hf_transfer_auto_activation(): HF_TRANFER_ACTIVE) +def test_download_weights_from_hf(): + with tempfile.TemporaryDirectory() as tmpdir: + # assert LocalEntryNotFoundError error is thrown + # if offline is set and model is not cached + huggingface_hub.constants.HF_HUB_OFFLINE = True + with pytest.raises(LocalEntryNotFoundError): + download_weights_from_hf("facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir) + + # download the model + huggingface_hub.constants.HF_HUB_OFFLINE = False + download_weights_from_hf("facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir) + + # now it should work offline + huggingface_hub.constants.HF_HUB_OFFLINE = True + assert download_weights_from_hf( + "facebook/opt-125m", + allow_patterns=["*.safetensors", "*.bin"], + cache_dir=tmpdir) is not None + + if __name__ == "__main__": test_hf_transfer_auto_activation() + test_download_weights_from_hf() diff --git a/tests/models/test_gptq_marlin.py b/tests/models/test_gptq_marlin.py new file mode 100644 index 000000000000..dc027697ffd4 --- /dev/null +++ b/tests/models/test_gptq_marlin.py @@ -0,0 +1,93 @@ +"""Compares the outputs of gptq vs gptq_marlin +Note: GPTQ and Marlin do not have bitwise correctness. +As a result, in this test, we just confirm that the top selected tokens of the +Marlin/GPTQ models are in the top 3 selections of each other. +Note: Marlin internally uses locks to synchronize the threads. This can +result in very slight nondeterminism for Marlin. As a result, we re-run the test +up to 3 times to see if we pass. +Note: This test currently fails running with --forked with the following: + RuntimeError: Cannot re-initialize CUDA in forked subprocess. + To use CUDA with multiprocessing, you must use the 'spawn' start method +Run `pytest tests/models/test_gptq_marlin.py`. 
+""" +import os + +import pytest +import torch + +from tests.models.utils import check_logprobs_close +from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS + +os.environ["TOKENIZERS_PARALLELISM"] = "true" + +MAX_MODEL_LEN = 1024 + +capability = torch.cuda.get_device_capability() +capability = capability[0] * 10 + capability[1] +gptq_marlin_not_supported = ( + capability < QUANTIZATION_METHODS["gptq_marlin"].get_min_capability()) + +MODELS = [ + # act_order==False, group_size=channelwise + ("robertgshaw2/zephyr-7b-beta-channelwise-gptq", "main"), + # act_order==False, group_size=128 + ("TheBloke/Llama-2-7B-GPTQ", "main"), + + # act_order==True, group_size=128 + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "main"), + # act_order==True, group_size=64 + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-64g-actorder_True"), + # act_order==True, group_size=32 + ("TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ", "gptq-4bit-32g-actorder_True"), +] + + +@pytest.mark.flaky(reruns=2) +@pytest.mark.skipif(gptq_marlin_not_supported, + reason="gptq_marlin is not supported on this GPU type.") +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + vllm_runner, + example_prompts, + model, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + model_name, revision = model + + # Run marlin. + gptq_marlin_model = vllm_runner(model_name=model_name, + revision=revision, + dtype=dtype, + quantization="marlin", + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=1, + disable_custom_all_reduce=True) + + gptq_marlin_outputs = gptq_marlin_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + del gptq_marlin_model + + # Run gptq. + gptq_model = vllm_runner(model_name=model_name, + revision=revision, + dtype=dtype, + quantization="gptq", + max_model_len=MAX_MODEL_LEN, + tensor_parallel_size=1, + disable_custom_all_reduce=True) + gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, + max_tokens, + num_logprobs) + del gptq_model + + check_logprobs_close( + outputs_0_lst=gptq_outputs, + outputs_1_lst=gptq_marlin_outputs, + name_0="gptq", + name_1="gptq_marlin", + ) diff --git a/tests/models/test_marlin.py b/tests/models/test_marlin.py index 4fe6daec0252..fa846d43d0e8 100644 --- a/tests/models/test_marlin.py +++ b/tests/models/test_marlin.py @@ -10,12 +10,12 @@ Run `pytest tests/models/test_marlin.py`. """ - from dataclasses import dataclass import pytest import torch +from tests.models.utils import check_logprobs_close from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS capability = torch.cuda.get_device_capability() @@ -55,43 +55,24 @@ def test_models( max_tokens: int, num_logprobs: int, ) -> None: - marlin_model = vllm_runner(model_pair.model_marlin, dtype=dtype) + marlin_model = vllm_runner(model_pair.model_marlin, + dtype=dtype, + quantization="marlin") marlin_outputs = marlin_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) - - # Note: not sure why, but deleting just the model on Ada Lovelace - # does not free the GPU memory. On Ampere, deleting the just model - # frees the memory. 
del marlin_model - gptq_model = vllm_runner(model_pair.model_gptq, dtype=dtype) + gptq_model = vllm_runner(model_pair.model_gptq, + dtype=dtype, + quantization="gptq") gptq_outputs = gptq_model.generate_greedy_logprobs(example_prompts, max_tokens, num_logprobs) - - # Note: not sure why, but deleting just the model on Ada Lovelace - # does not free the GPU memory. On Ampere, deleting the just model - # frees the memory. del gptq_model - # loop through the prompts - for prompt_idx in range(len(example_prompts)): - gptq_output_ids, gptq_output_str, gptq_logprobs = gptq_outputs[ - prompt_idx] - marlin_output_ids, marlin_output_str, marlin_logprobs = marlin_outputs[ - prompt_idx] - - for idx, (gptq_output_id, marlin_output_id) in enumerate( - zip(gptq_output_ids, marlin_output_ids)): - # If sequence is not an exact match, - if marlin_output_id != gptq_output_id: - # Each predicted token must be in top 5 of the other's - assert gptq_output_id in marlin_logprobs[idx], ( - f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n" - f"Marlin:\t{marlin_output_str!r}") - assert marlin_output_id in gptq_logprobs[idx], ( - f"Test{prompt_idx}:\nGPTQ:\t{gptq_output_str!r}\n" - f"Marlin:\t{marlin_output_str!r}") - - # Break out since sequences will now diverge. - break + check_logprobs_close( + outputs_0_lst=gptq_outputs, + outputs_1_lst=marlin_outputs, + name_0="gptq", + name_1="marlin", + ) diff --git a/tests/models/utils.py b/tests/models/utils.py new file mode 100644 index 000000000000..3e49dfb33117 --- /dev/null +++ b/tests/models/utils.py @@ -0,0 +1,29 @@ +def check_logprobs_close(outputs_0_lst, outputs_1_lst, name_0, name_1): + """Compare the logprobs of two sequences generated by different models, + which should be similar but not necessarily equal. + """ + # Loop through responses to each prompt. + for prompt_idx, (outputs_0, + outputs_1) in enumerate(zip(outputs_0_lst, + outputs_1_lst)): + output_ids_0, output_str_0, logprobs_0 = outputs_0 + output_ids_1, output_str_1, logprobs_1 = outputs_1 + + # Loop through generated tokens. + for idx, (output_id_0, + output_id_1) in enumerate(zip(output_ids_0, output_ids_1)): + + # If generated tokens don't match, then + if output_id_0 != output_id_1: + # Each predicted token must be in top N logprobs of the other + assert output_id_0 in logprobs_1[idx], ( + f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") + assert output_id_1 in logprobs_0[idx], ( + f"Test{prompt_idx}:" + f"\n{name_0}:\t{output_str_0!r}" + f"\n{name_1}:\t{output_str_1!r}") + + # Break out since sequences will now diverge. + break diff --git a/tests/quantization/test_autogptq_marlin_configs.py b/tests/quantization/test_autogptq_marlin_configs.py deleted file mode 100644 index 1310b4da218b..000000000000 --- a/tests/quantization/test_autogptq_marlin_configs.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Tests whether Marlin models can be loaded from the autogptq config. - -Run `pytest tests/quantization/test_autogptq_marlin_configs.py --forked`. 
-""" - -from dataclasses import dataclass - -import pytest - -from vllm.config import ModelConfig - - -@dataclass -class ModelPair: - model_marlin: str - model_gptq: str - - -# Model Id // Expected Kernel -MODELS_QUANT_TYPE = [ - # compat: autogptq <=0.7.1 is_marlin_format: bool - ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin"), - ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq"), - # compat: autogptq >=0.8.0 use checkpoint_format: str - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin"), - ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq") -] - - -@pytest.mark.parametrize("model_quant_type", MODELS_QUANT_TYPE) -def test_auto_gptq(model_quant_type: str, ) -> None: - model_path, quant_type = model_quant_type - - model_config_no_quant_arg = ModelConfig( - model_path, - model_path, - tokenizer_mode="auto", - trust_remote_code=False, - seed=0, - dtype="float16", - revision=None, - quantization=None # case 1 - ) - - model_config_quant_arg = ModelConfig( - model_path, - model_path, - tokenizer_mode="auto", - trust_remote_code=False, - seed=0, - dtype="float16", - revision=None, - quantization="gptq" # case 2 - ) - - assert model_config_no_quant_arg.quantization == quant_type, ( - f"Expected quant_type == {quant_type} for {model_path}, " - f"but found {model_config_no_quant_arg.quantization} " - "for no --quantization None case") - - assert model_config_quant_arg.quantization == quant_type, ( - f"Expected quant_type == {quant_type} for {model_path}, " - f"but found {model_config_quant_arg.quantization} " - "for --quantization gptq case") diff --git a/tests/quantization/test_configs.py b/tests/quantization/test_configs.py new file mode 100644 index 000000000000..6820b2728e3c --- /dev/null +++ b/tests/quantization/test_configs.py @@ -0,0 +1,73 @@ +"""Tests whether Marlin models can be loaded from the autogptq config. + +Run `pytest tests/quantization/test_configs.py --forked`. +""" + +from dataclasses import dataclass + +import pytest + +from vllm.config import ModelConfig + + +@dataclass +class ModelPair: + model_marlin: str + model_gptq: str + + +# Model Id // Quantization Arg // Expected Type +MODEL_ARG_EXPTYPES = [ + # AUTOGPTQ + # compat: autogptq <=0.7.1 is_marlin_format: bool + # Model Serialized in Marlin Format should always use Marlin kernel. + ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", None, "marlin"), + ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "marlin", "marlin"), + ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "gptq", "marlin"), + ("neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin", "awq", "ERROR"), + # Model Serialized in Exllama Format. + ("TheBloke/Llama-2-7B-Chat-GPTQ", None, "gptq_marlin"), + ("TheBloke/Llama-2-7B-Chat-GPTQ", "marlin", "gptq_marlin"), + ("TheBloke/Llama-2-7B-Chat-GPTQ", "gptq", "gptq"), + ("TheBloke/Llama-2-7B-Chat-GPTQ", "awq", "ERROR"), + # compat: autogptq >=0.8.0 use checkpoint_format: str + # Model Serialized in Marlin Format should always use Marlin kernel. + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", None, "marlin"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "marlin", "marlin"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "gptq", "marlin"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit", "awq", "ERROR"), + # Model Serialized in Exllama Format. 
+ ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", None, "gptq_marlin"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "marlin", "gptq_marlin"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "gptq", "gptq"), + ("LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit", "awq", "ERROR"), + + # AUTOAWQ + ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", None, "awq"), + ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "awq", "awq"), + ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "marlin", "ERROR"), + ("TheBloke/OpenHermes-2.5-Mistral-7B-AWQ", "gptq", "ERROR"), +] + + +@pytest.mark.parametrize("model_arg_exptype", MODEL_ARG_EXPTYPES) +def test_auto_gptq(model_arg_exptype: str) -> None: + model_path, quantization_arg, expected_type = model_arg_exptype + + try: + model_config = ModelConfig(model_path, + model_path, + tokenizer_mode="auto", + trust_remote_code=False, + seed=0, + dtype="float16", + revision=None, + quantization=quantization_arg) + found_quantization_type = model_config.quantization + except ValueError: + found_quantization_type = "ERROR" + + assert found_quantization_type == expected_type, ( + f"Expected quant_type == {expected_type} for {model_path}, " + f"but found {found_quantization_type} " + f"for no --quantization {quantization_arg} case") diff --git a/tests/quantization/test_fp8.py b/tests/quantization/test_fp8.py index fa10e60de10a..607544a1c839 100644 --- a/tests/quantization/test_fp8.py +++ b/tests/quantization/test_fp8.py @@ -20,5 +20,5 @@ def test_load_fp16_model(vllm_runner) -> None: model = llm.model.llm_engine.model_executor.driver_worker.model_runner.model fc1 = model.model.decoder.layers[0].fc1 - assert isinstance(fc1.linear_method, Fp8LinearMethod) + assert isinstance(fc1.quant_method, Fp8LinearMethod) assert fc1.weight.dtype == torch.float8_e4m3fn diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 6f2145f8cdcf..7859f0b21812 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -207,7 +207,7 @@ def test_sampler_min_tokens_penalty(seed: int, device: str): def create_sampling_params(min_tokens, eos_token_id=0, *, - stop_token_ids: Optional[List[str]] = None, + stop_token_ids: Optional[List[int]] = None, prompt_logprobs: Optional[int] = None): sampling_params = SamplingParams( min_tokens=min_tokens, @@ -216,7 +216,7 @@ def create_sampling_params(min_tokens, # requesting prompt_logprobs changes the structure of `logits` prompt_logprobs=prompt_logprobs, ) - sampling_params.eos_token_id = eos_token_id + sampling_params.all_stop_token_ids.add(eos_token_id) return sampling_params def create_sequence_data(num_input=3, num_generated=0): @@ -461,10 +461,7 @@ def run_test_case(*, for logits_idx, (should_penalize, sampling_params) in enumerate( zip(expected_penalization, sampling_params_per_row)): - tokens_to_check = [sampling_params.eos_token_id] - if sampling_params.stop_token_ids: - tokens_to_check.extend(sampling_params.stop_token_ids) - tokens_to_check = set(tokens_to_check) + tokens_to_check = sampling_params.all_stop_token_ids if should_penalize: for token_id in tokens_to_check: diff --git a/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py b/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py index e4b15fd57add..0e113ab647e6 100644 --- a/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py +++ b/tests/tensorizer_loader/tensorize_vllm_model_for_testing.py @@ -6,14 +6,14 @@ from functools import partial from typing import Type -import torch import torch.nn as nn from tensorizer import (DecryptionParams, 
EncryptionParams, TensorDeserializer, TensorSerializer, stream_io) from tensorizer.utils import convert_bytes, get_mem_usage, no_init_or_tensor from transformers import AutoConfig, PretrainedConfig -from vllm.distributed import initialize_model_parallel +from vllm.distributed import (init_distributed_environment, + initialize_model_parallel) from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.model_executor.model_loader.tensorizer import TensorizerArgs @@ -226,7 +226,7 @@ def deserialize(): os.environ["MASTER_ADDR"] = "127.0.0.1" os.environ["MASTER_PORT"] = "8080" -torch.distributed.init_process_group(world_size=1, rank=0) +init_distributed_environment(world_size=1, rank=0, local_rank=0) initialize_model_parallel() keyfile = args.keyfile if args.keyfile else None diff --git a/tests/tensorizer_loader/test_tensorizer.py b/tests/tensorizer_loader/test_tensorizer.py index a97cc0b3706b..df1db4e6c400 100644 --- a/tests/tensorizer_loader/test_tensorizer.py +++ b/tests/tensorizer_loader/test_tensorizer.py @@ -50,10 +50,10 @@ def test_load_with_tensorizer(mock_agent, tensorizer_config): mock_agent_instance.deserialize.return_value = MagicMock() result = load_with_tensorizer(tensorizer_config, - linear_method=mock_linear_method) + quant_method=mock_linear_method) mock_agent.assert_called_once_with(tensorizer_config, - linear_method=mock_linear_method) + quant_method=mock_linear_method) mock_agent_instance.deserialize.assert_called_once() assert result == mock_agent_instance.deserialize.return_value diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index abb401f25c10..56fe6db589f1 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -2,8 +2,10 @@ import torch from vllm.config import ModelConfig, SchedulerConfig +from vllm.distributed.parallel_state import init_distributed_environment from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata +from vllm.utils import get_open_port from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size @@ -249,19 +251,18 @@ def test_empty_seq_group(): assert len(return_prompt_lens) == 0 -@pytest.mark.parametrize("batch_size", list(range(2, 128))) -@pytest.mark.parametrize("enforce_eager", [True, False]) -def test_hybrid_batches(batch_size, enforce_eager, monkeypatch): - - def get_world_size(group=None): - return 1 +@pytest.fixture +def distributed_init(): + init_distributed_environment( + world_size=1, + rank=0, + distributed_init_method=f"tcp://127.0.0.1:{get_open_port()}", + local_rank=0) - def mock_get_process_group_ranks(group=None): - return [0] - monkeypatch.setattr(torch.distributed, "get_world_size", get_world_size) - monkeypatch.setattr(torch.distributed, "get_process_group_ranks", - mock_get_process_group_ranks) +@pytest.mark.parametrize("batch_size", list(range(2, 128))) +@pytest.mark.parametrize("enforce_eager", [True, False]) +def test_hybrid_batches(batch_size, enforce_eager, distributed_init): model_config = ModelConfig( "facebook/opt-125m", diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 508d35656eb0..4af8b09b1e16 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -167,11 +167,33 @@ def aqlm_dequant(codes: torch.Tensor, codebooks: torch.Tensor, return vllm_ops.aqlm_dequant(codes, codebooks, codebook_partition_sizes) +# gptq_marlin +def gptq_marlin_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, 
size_n: int) -> torch.Tensor: + return vllm_ops.gptq_marlin_repack(b_q_weight, perm, size_k, size_n) + + +def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, + b_scales: torch.Tensor, g_idx: torch.Tensor, + perm: torch.Tensor, workspace: torch.Tensor, size_m: int, + size_n: int, size_k: int, + is_k_full: bool) -> torch.Tensor: + return vllm_ops.gptq_marlin_gemm(a, b_q_weight, b_scales, g_idx, perm, + workspace, size_m, size_n, size_k, + is_k_full) + + # fp8 -def scaled_fp8_quant(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - scale = torch.zeros(1, device=input.device, dtype=torch.float32) +def scaled_fp8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, +) -> Tuple[torch.Tensor, torch.Tensor]: output = torch.empty_like(input, dtype=torch.float8_e4m3fn) - vllm_ops.scaled_fp8_quant(output, input, scale) + if scale is None: + scale = torch.zeros(1, device=input.device, dtype=torch.float32) + vllm_ops.dynamic_scaled_fp8_quant(output, input, scale) + else: + vllm_ops.static_scaled_fp8_quant(output, input, scale) return output, scale diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 7c5863a030ff..934acea0a3d6 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -253,36 +253,31 @@ def forward( # triton attention # When block_tables are not filled, it means q and k are the # prompt, and they have the same length. - if self.use_triton_flash_attn or self.use_naive_attn: + if self.use_triton_flash_attn: + out, _ = self.attn_func( + query, + key, + value, + None, + prefill_meta.seq_start_loc, + prefill_meta.seq_start_loc, + prefill_meta.max_prompt_len, + prefill_meta.max_prompt_len, + True, + self.scale, + ) + elif self.use_naive_attn: if self.num_kv_heads != self.num_heads: # Interleave for MQA workaround. 
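
Editor's note: the naive-attention fallback in this hunk expands the key/value heads so that grouped-query (MQA/GQA) models present one KV head per query head before the reference kernel runs; the repeat_kv calls that follow perform that expansion. A minimal sketch of such an interleave, assuming a [num_tokens, num_kv_heads, head_dim] layout (the layout and helper name here are illustrative, not necessarily the backend's exact implementation):

    import torch

    def repeat_kv_sketch(x: torch.Tensor, n_rep: int) -> torch.Tensor:
        # Expand [num_tokens, num_kv_heads, head_dim] to
        # [num_tokens, num_kv_heads * n_rep, head_dim] by repeating each KV head.
        return torch.repeat_interleave(x, repeats=n_rep, dim=1)

    # Example: 8 query heads sharing 2 KV heads -> repeat each KV head 4 times.
    key = torch.randn(16, 2, 64)
    assert repeat_kv_sketch(key, n_rep=4).shape == (16, 8, 64)
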
key = self.repeat_kv(key, self.num_queries_per_kv) value = self.repeat_kv(value, self.num_queries_per_kv) - if self.use_naive_attn: - out = self.attn_func( - query, - key, - value, - prefill_meta.prompt_lens, - self.scale, - ) - assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out - else: - out, _ = self.attn_func( - query, - key, - value, - None, - prefill_meta.seq_start_loc, - prefill_meta.seq_start_loc, - prefill_meta.max_prompt_len, - prefill_meta.max_prompt_len, - True, - self.scale, - ) - assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out + out = self.attn_func( + query, + key, + value, + prefill_meta.prompt_lens, + self.scale, + ) else: out = self.attn_func( q=query, @@ -295,8 +290,10 @@ def forward( softmax_scale=self.scale, causal=True, ) - assert output[:num_prefill_tokens].shape == out.shape - output[:num_prefill_tokens] = out + + # common code for prefill + assert output[:num_prefill_tokens].shape == out.shape + output[:num_prefill_tokens] = out else: # prefix-enabled attention output[:num_prefill_tokens] = PagedAttention.forward_prefix( diff --git a/vllm/attention/ops/triton_flash_attention.py b/vllm/attention/ops/triton_flash_attention.py index e160411859f0..1147664183ff 100644 --- a/vllm/attention/ops/triton_flash_attention.py +++ b/vllm/attention/ops/triton_flash_attention.py @@ -293,7 +293,7 @@ def _attn_fwd_inner( num_warps=4, ), ], - key=["hq", "hk", "IS_CAUSAL", "dropout_p", "BLOCK_DMODEL"], + key=['IS_CAUSAL', 'dropout_p', 'BLOCK_DMODEL'], ) @triton.jit def attn_fwd( @@ -330,8 +330,8 @@ def attn_fwd( philox_seed, philox_offset_base, encoded_softmax, - hq, - hk, + HQ: tl.constexpr, + HK: tl.constexpr, ACTUAL_BLOCK_DMODEL: tl.constexpr, MAX_SEQLENS_Q: tl.constexpr, MAX_SEQLENS_K: tl.constexpr, @@ -403,7 +403,7 @@ def attn_fwd( # We still need to write 0s to the result # tl.store(O_block_ptr, # acc.to(Out.type.element_ty), boundary_check=(0,1)) - # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q # + offs_m # We store inf to LSE, not -inf because in the bwd pass, # we subtract this @@ -414,11 +414,9 @@ def attn_fwd( # TODO: Should dropout and return encoded softmax be handled here? return - is_mqa = hq != hk - if is_mqa: # noqa: SIM108 - off_h_k = off_h_q % hk - else: - off_h_k = off_h_q + # If MQA / GQA, set the K and V head offsets appropriately. + GROUP_SIZE: tl.constexpr = HQ // HK + off_h_k = off_h_q // GROUP_SIZE if GROUP_SIZE != 1 else off_h_q n_extra_tokens = 0 if seqlen_k < BLOCK_N: @@ -471,7 +469,7 @@ def attn_fwd( bias_ptr = None if ENABLE_DROPOUT: batch_philox_offset = philox_offset_base \ - + (off_z * hq + off_h_q) \ + + (off_z * HQ + off_h_q) \ * seqlen_q * seqlen_k else: batch_philox_offset = 0 @@ -624,7 +622,7 @@ def attn_fwd( z = 0.0 acc = tl.where(out_ptrs_mask, acc, z.to(acc.type.element_ty)) # write back LSE - # l_ptrs = L + off_z * hq * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m + # l_ptrs = L + off_z * HQ * MAX_SEQLENS_Q + off_h_q * MAX_SEQLENS_Q + offs_m # If seqlen_q not multiple of BLOCK_M, we need to mask out the last # few rows. This is only true for the last M block. 
For others, # overflow_size will be -ve @@ -784,8 +782,8 @@ def forward( philox_seed=philox_seed, philox_offset_base=philox_offset, encoded_softmax=encoded_softmax, - hq=nheads_q, - hk=nheads_k, + HQ=nheads_q, + HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q, MAX_SEQLENS_K=max_seqlens_k, diff --git a/vllm/config.py b/vllm/config.py index 887a73d9462d..a5512c657e03 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -9,11 +9,14 @@ from transformers import PretrainedConfig from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, + get_quantization_config) from vllm.transformers_utils.config import get_config, get_hf_text_config from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip, is_neuron) +GPTQMarlinConfig = get_quantization_config("gptq_marlin") + if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -138,14 +141,34 @@ def _verify_quantization(self) -> None: is_format_marlin = (quant_cfg.get("checkpoint_format") == "marlin" or quant_cfg.get("is_marlin_format", False)) - # Use marlin if the GPTQ model is serialized in marlin format. - if quant_method == "gptq" and is_format_marlin: - logger.info("The model is serialized in Marlin format. " + # Check which LinearMethod the GPTQ model should use. + if quant_method == "gptq": + # If serialized in Marlin format, use MarlinLinearMethod. + # TODO (@robertgshaw): migrate under GPTQMarlinLinearMethod. + if is_format_marlin: + logger.info("The model is serialized in Marlin format. " + "Using Marlin kernel.") + quant_method = "marlin" + if self.quantization == "gptq": + self.quantization = quant_method + + # If convertible to Marlin format, use GPTQMarlinLinearMethod + # unless the user explicitly specified GPTQLinearMethod. + elif GPTQMarlinConfig.is_marlin_compatible(quant_cfg): + if self.quantization == "gptq": + logger.warning( + "The model is convertible to Marlin format, but " + "you specified quantization=gptq. Use " + "quantization=marlin for faster inference.") + else: + logger.info( + "The model is convertible to Marlin format. " "Using Marlin kernel.") - quant_method = "marlin" - if self.quantization == "gptq": - self.quantization = quant_method + quant_method = "gptq_marlin" + if self.quantization == "marlin": + self.quantization = quant_method + # Verify quantization configurations. if self.quantization is None: self.quantization = quant_method elif self.quantization != quant_method: @@ -165,7 +188,7 @@ def _verify_quantization(self) -> None: raise ValueError( f"{self.quantization} quantization is currently not " f"supported in ROCm.") - if self.quantization != "marlin": + if (self.quantization not in ["marlin", "gptq_marlin"]): logger.warning( "%s quantization is not fully " "optimized yet. 
The speed can be slower than " @@ -862,6 +885,7 @@ def __repr__(self) -> str: class LoRAConfig: max_lora_rank: int max_loras: int + fully_sharded_loras: bool = False max_cpu_loras: Optional[int] = None lora_dtype: Optional[torch.dtype] = None lora_extra_vocab_size: int = 256 diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index be093922b84f..1fac2636e86f 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -1,4 +1,5 @@ """A block manager that manages token blocks.""" +import math from abc import ABC, abstractmethod from itertools import count, takewhile from os.path import commonprefix @@ -220,9 +221,9 @@ def __init__( self.block_sliding_window = None if sliding_window is not None: - assert sliding_window % block_size == 0, (sliding_window, - block_size) - self.block_sliding_window = sliding_window // block_size + # Round up to nearest block size to regularize sliding window + # allocation sizes. + self.block_sliding_window = math.ceil(sliding_window / block_size) self.watermark = watermark assert watermark >= 0.0 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 7439f7dc33e8..024b7e701344 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -320,7 +320,7 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: for seq_group in state_queue: if not request_ids: # Using 'break' here may add two extra iterations, - # but is acceptable to reduce complexity . + # but is acceptable to reduce complexity. break if seq_group.request_id in request_ids: # Appending aborted group into pending list. diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index a3e93691a1e8..8b2c26c3a8af 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -4,7 +4,8 @@ import torch from torch.distributed import ProcessGroup -from .parallel_state import (get_tensor_model_parallel_group, +from .parallel_state import (get_cpu_world_group, + get_tensor_model_parallel_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, is_pynccl_enabled_for_all_reduce) @@ -140,13 +141,46 @@ def broadcast_object_list(obj_list: List[Any], TensorMetadata = namedtuple("TensorMetadata", ["dtype", "size"]) +def _split_tensor_dict( + tensor_dict: Dict[Any, Union[torch.Tensor, Any]] +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + """ + metadata_list = [] + tensor_list = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + # Note(youkaichao): currently this only supports broadcasting + # tensors on cuda. In the future, we can add device as a field in + # TensorMetadata to support broadcasting tensors on different + # devices. + assert value.is_cuda, ( + f"Tensor {key}: {value} is not on cuda. 
Currently we only " + f"support broadcasting tensors on cuda.") + metadata_list.append((key, TensorMetadata(value.dtype, + value.size()))) + tensor_list.append(value) + else: + metadata_list.append((key, value)) + return metadata_list, tensor_list + + def broadcast_tensor_dict( tensor_dict: Optional[Dict[Any, Union[torch.Tensor, Any]]] = None, src: int = 0, group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None ) -> Optional[Dict[Any, Union[torch.Tensor, Any]]]: - """Broadcast the input tensor dictionary.""" + """Broadcast the input tensor dictionary. + `group` is used to broadcast the tensors, while `metadata_group` is used + to broadcast the metadata of the dict (e.g. dict structure, tensor sizes, + dtypes). + """ group = group or torch.distributed.group.WORLD + metadata_group = metadata_group or get_cpu_world_group() ranks = torch.distributed.get_process_group_ranks(group) assert src in ranks, f"Invalid src rank ({src})" @@ -161,27 +195,20 @@ def broadcast_tensor_dict( assert isinstance( tensor_dict, dict), (f"Expecting a dictionary, got {type(tensor_dict)}") - for key, value in tensor_dict.items(): - if isinstance(value, torch.Tensor): - assert value.is_cuda, ( - f"Tensor {key}: {value} is not on cuda. Currently we only " - f"support broadcasting tensors on cuda.") - metadata_list.append( - (key, TensorMetadata(value.dtype, value.size()))) - else: - metadata_list.append((key, value)) + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` involves serialization and deserialization, + # all happening on CPU. Therefore, we can use the CPU group. torch.distributed.broadcast_object_list([metadata_list], src=src, - group=group) + group=metadata_group) async_handles = [] - for key, value in metadata_list: - if isinstance(value, TensorMetadata): - tensor = tensor_dict[key] - async_handles.append( - torch.distributed.broadcast(tensor, - src=src, - group=group, - async_op=True)) + for tensor in tensor_list: + async_handles.append( + torch.distributed.broadcast(tensor, + src=src, + group=group, + async_op=True)) for async_handle in async_handles: async_handle.wait() @@ -189,7 +216,7 @@ def broadcast_tensor_dict( recv_metadata_list = [None] torch.distributed.broadcast_object_list(recv_metadata_list, src=src, - group=group) + group=metadata_group) assert recv_metadata_list[0] is not None tensor_dict = {} async_handles = [] diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 6a6ac49ae321..bd6437ee44c2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -52,6 +52,7 @@ class EngineArgs: enable_lora: bool = False max_loras: int = 1 max_lora_rank: int = 16 + fully_sharded_loras: bool = False lora_extra_vocab_size: int = 256 lora_dtype = 'auto' max_cpu_loras: Optional[int] = None @@ -376,6 +377,14 @@ def add_cli_args( help=('Maximum number of LoRAs to store in CPU memory. ' 'Must be >= than max_num_seqs. ' 'Defaults to max_num_seqs.')) + parser.add_argument( + '--fully-sharded-loras', + action='store_true', + help=('By default, only half of the LoRA computation is ' + 'sharded with tensor parallelism. ' + 'Enabling this will use the fully sharded layers. 
' + 'At high sequence length, max rank or ' + 'tensor parallel size, this is likely faster.')) parser.add_argument("--device", type=str, default=EngineArgs.device, @@ -509,6 +518,7 @@ def create_engine_config(self, ) -> EngineConfig: lora_config = LoRAConfig( max_lora_rank=self.max_lora_rank, max_loras=self.max_loras, + fully_sharded_loras=self.fully_sharded_loras, lora_extra_vocab_size=self.lora_extra_vocab_size, lora_dtype=self.lora_dtype, max_cpu_loras=self.max_cpu_loras if self.max_cpu_loras diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 89ee3f0db491..7c1eb2ecbe55 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -7,7 +7,7 @@ from transformers import PreTrainedTokenizer -from vllm.config import ModelConfig +from vllm.config import DecodingConfig, ModelConfig from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.executor.ray_utils import initialize_ray_cluster, ray @@ -697,6 +697,14 @@ async def get_model_config(self) -> ModelConfig: else: return self.engine.get_model_config() + async def get_decoding_config(self) -> DecodingConfig: + """Get the decoding configuration of the vLLM engine.""" + if self.engine_use_ray: + return await self.engine.get_decoding_config.remote( # type: ignore + ) + else: + return self.engine.get_decoding_config() + async def do_log_stats(self) -> None: if self.engine_use_ray: await self.engine.do_log_stats.remote() # type: ignore diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 741d3bcd8089..835803fd4e75 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -22,7 +22,8 @@ from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams from vllm.sequence import (MultiModalData, SamplerOutput, Sequence, - SequenceGroup, SequenceGroupMetadata) + SequenceGroup, SequenceGroupMetadata, + SequenceStatus) from vllm.transformers_utils.detokenizer import Detokenizer from vllm.transformers_utils.tokenizer_group import (BaseTokenizerGroup, get_tokenizer_group) @@ -101,7 +102,7 @@ def __init__( "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " "tokenizer_revision=%s, trust_remote_code=%s, dtype=%s, " "max_seq_len=%d, download_dir=%r, load_format=%s, " - "tensor_parallel_size=%d, disable_custom_all_reduce=%s" + "tensor_parallel_size=%d, disable_custom_all_reduce=%s, " "quantization=%s, enforce_eager=%s, kv_cache_dtype=%s, " "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, seed=%d)", @@ -217,7 +218,8 @@ def __init__( if self.log_stats: self.stat_logger = StatLogger( local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - labels=dict(model_name=model_config.model)) + labels=dict(model_name=model_config.model), + max_model_len=self.model_config.max_model_len) self.stat_logger.info("cache_config", self.cache_config) # Create sequence output processor, e.g. 
for beam search or @@ -431,9 +433,10 @@ def add_request( # Defensive copy of SamplingParams, which are used by the sampler, # this doesn't deep-copy LogitsProcessor objects sampling_params = sampling_params.clone() - # inject the eos token id into the sampling_params to support min_tokens + # Add the eos token id into the sampling_params to support min_tokens # processing - sampling_params.eos_token_id = seq.eos_token_id + if seq.eos_token_id is not None: + sampling_params.all_stop_token_ids.add(seq.eos_token_id) sampling_params.update_from_generation_config( self.generation_config_fields) @@ -467,6 +470,10 @@ def get_model_config(self) -> ModelConfig: """Gets the model configuration.""" return self.model_config + def get_decoding_config(self) -> DecodingConfig: + """Gets the decoding configuration.""" + return self.decoding_config + def get_num_unfinished_requests(self) -> int: """Gets the number of unfinished requests.""" return self.scheduler.get_num_unfinished_seq_groups() @@ -614,59 +621,109 @@ def _get_stats( """ now = time.time() - # KV Cache Usage in %. + # System State + # Scheduler State + num_running_sys = len(self.scheduler.running) + num_swapped_sys = len(self.scheduler.swapped) + num_waiting_sys = len(self.scheduler.waiting) + + # KV Cache Usage in % num_total_gpu = self.cache_config.num_gpu_blocks num_free_gpu = self.scheduler.block_manager.get_num_free_gpu_blocks() - gpu_cache_usage = 1.0 - (num_free_gpu / num_total_gpu) + gpu_cache_usage_sys = 1.0 - (num_free_gpu / num_total_gpu) num_total_cpu = self.cache_config.num_cpu_blocks - cpu_cache_usage = 0. + cpu_cache_usage_sys = 0. if num_total_cpu > 0: num_free_cpu = self.scheduler.block_manager.get_num_free_cpu_blocks( ) - cpu_cache_usage = 1.0 - (num_free_cpu / num_total_cpu) - - # Scheduler State - num_running = len(self.scheduler.running) - num_swapped = len(self.scheduler.swapped) - num_waiting = len(self.scheduler.waiting) - - # Iteration stats if we have scheduler output. - num_prompt_tokens = 0 - num_generation_tokens = 0 - time_to_first_tokens = [] - time_per_output_tokens = [] - time_e2e_requests = [] + cpu_cache_usage_sys = 1.0 - (num_free_cpu / num_total_cpu) + + # Iteration stats + num_prompt_tokens_iter = 0 + num_generation_tokens_iter = 0 + time_to_first_tokens_iter: List[float] = [] + time_per_output_tokens_iter: List[float] = [] + + # Request stats + # Latency + time_e2e_requests: List[float] = [] + # Metadata + num_prompt_tokens_requests: List[int] = [] + num_generation_tokens_requests: List[int] = [] + best_of_requests: List[int] = [] + n_requests: List[int] = [] + finished_reason_requests: List[str] = [] + + # NOTE: This loop assumes prefill seq_groups are before + # decode seq_groups in scheduled_seq_groups. if scheduler_outputs is not None: - prompt_run = scheduler_outputs.num_prefill_groups > 0 - - # Number of Tokens. - if prompt_run: - num_prompt_tokens = sum( - len(scheduled_seq_group.seq_group.prompt_token_ids) - for scheduled_seq_group in - scheduler_outputs.scheduled_seq_groups) - num_generation_tokens = sum( - scheduled_seq_group.seq_group.num_seqs() - for scheduled_seq_group in - scheduler_outputs.scheduled_seq_groups) - else: - num_generation_tokens = scheduler_outputs.num_batched_tokens - - # Latency Timings. - time_last_iters = [] - for scheduled_seq_group in scheduler_outputs.scheduled_seq_groups: + num_generation_tokens_from_prefill_groups = 0. 
+ if scheduler_outputs.num_prefill_groups > 0 and len( + scheduler_outputs.scheduled_seq_groups + ) != scheduler_outputs.num_prefill_groups: + print("DETECTED CHUNKED") + + for idx, scheduled_seq_group in enumerate( + scheduler_outputs.scheduled_seq_groups): + group_was_prefill = idx < scheduler_outputs.num_prefill_groups seq_group = scheduled_seq_group.seq_group - # Time since last token. - # (n.b. updates seq_group.metrics.last_token_time) - time_last_iters.append(seq_group.get_last_latency(now)) - # Time since arrival for all finished requests. + + # NOTE: a seq_group that completed all of its prefill tokens + # in the last iteration will have seq_group.is_prefill() = False + # with group_was_prefill = True + if group_was_prefill: + # Number of prompt tokens. + num_prompt_tokens_iter += ( + scheduled_seq_group.token_chunk_size) + + # If the seq_group just finished the prefill state + # get TTFT. + if not seq_group.is_prefill(): + latency = seq_group.get_last_latency(now) + time_to_first_tokens_iter.append(latency) + + # One generation token per finished prefill. + num_generation_tokens_from_prefill_groups += ( + seq_group.num_seqs()) + else: + # TPOTs. + latency = seq_group.get_last_latency(now) + time_per_output_tokens_iter.append(latency) + + # Because of chunked prefill, we can have a single sequence + # group that does multiple prompt_runs. To prevent logging + # the same metadata more than once per request, we standardize + # on logging request level information for finished requests, + # which can only happen once. if seq_group.is_finished(): + # Latency timings time_e2e_requests.append(now - seq_group.metrics.arrival_time) - time_to_first_tokens = time_last_iters if prompt_run else [] - time_per_output_tokens = [] if prompt_run else time_last_iters + # Metadata + num_prompt_tokens_requests.append( + len(seq_group.prompt_token_ids)) + num_generation_tokens_requests.extend([ + seq.get_output_len() + for seq in seq_group.get_finished_seqs() + ]) + best_of_requests.append(seq_group.sampling_params.best_of) + n_requests.append(seq_group.sampling_params.n) + finished_reason_requests.extend([ + SequenceStatus.get_finished_reason(seq.status) + for seq in seq_group.get_finished_seqs() + ]) + + # Number of generation tokens. + # num_batched_tokens equals the number of prompt_tokens plus the + # number of decode_tokens in a single iteration. So, + # num_generation_tokens = num_batched_tokens - num_prompt_tokens + # + num_generation_tokens_from_prefill_groups (since we generate + # one token on prefills on iters where the prefill finishes). + num_generation_tokens_iter = ( + scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter + + num_generation_tokens_from_prefill_groups) # Spec decode, if enabled, emits specialized metrics from the worker in # sampler output. 
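
Editor's note: with chunked prefill, a sequence group can be scheduled for several prompt chunks before it emits its first output token, so the iteration-level generation-token count is derived rather than measured directly: num_batched_tokens covers both prompt and decode tokens, the prompt tokens are subtracted back out, and one token is credited for each prefill that finished in this iteration. A small worked example of that arithmetic (the numbers are illustrative, not from the source):

    # Suppose one iteration batched 768 tokens in total:
    num_batched_tokens = 768
    # 512 of them were prompt tokens across the scheduled prefill groups,
    num_prompt_tokens_iter = 512
    # and 2 prefill groups finished their prompt this iteration (1 seq each),
    # so each also produced its first generated token.
    num_generation_tokens_from_prefill_groups = 2

    num_generation_tokens_iter = (num_batched_tokens - num_prompt_tokens_iter
                                  + num_generation_tokens_from_prefill_groups)
    assert num_generation_tokens_iter == 258
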
@@ -678,17 +735,32 @@ def _get_stats( return Stats( now=now, - num_running=num_running, - num_swapped=num_swapped, - num_waiting=num_waiting, - gpu_cache_usage=gpu_cache_usage, - cpu_cache_usage=cpu_cache_usage, - num_prompt_tokens=num_prompt_tokens, - num_generation_tokens=num_generation_tokens, - time_to_first_tokens=time_to_first_tokens, - time_per_output_tokens=time_per_output_tokens, - time_e2e_requests=time_e2e_requests, + + # System stats + # Scheduler State + num_running_sys=num_running_sys, + num_swapped_sys=num_swapped_sys, + num_waiting_sys=num_waiting_sys, + # KV Cache Usage in % + gpu_cache_usage_sys=gpu_cache_usage_sys, + cpu_cache_usage_sys=cpu_cache_usage_sys, + + # Iteration stats + num_prompt_tokens_iter=num_prompt_tokens_iter, + num_generation_tokens_iter=num_generation_tokens_iter, + time_to_first_tokens_iter=time_to_first_tokens_iter, + time_per_output_tokens_iter=time_per_output_tokens_iter, spec_decode_metrics=spec_decode_metrics, + + # Request stats + # Latency + time_e2e_requests=time_e2e_requests, + # Metadata + num_prompt_tokens_requests=num_prompt_tokens_requests, + num_generation_tokens_requests=num_generation_tokens_requests, + best_of_requests=best_of_requests, + n_requests=n_requests, + finished_reason_requests=finished_reason_requests, ) def add_lora(self, lora_request: LoRARequest) -> bool: diff --git a/vllm/engine/metrics.py b/vllm/engine/metrics.py index d3560f5fefff..45bfad03ec86 100644 --- a/vllm/engine/metrics.py +++ b/vllm/engine/metrics.py @@ -1,6 +1,8 @@ import time from dataclasses import dataclass -from typing import TYPE_CHECKING, Dict, List, Optional, Protocol +from typing import TYPE_CHECKING +from typing import Counter as CollectionsCounter +from typing import Dict, List, Optional, Protocol, Union import numpy as np from prometheus_client import (REGISTRY, Counter, Gauge, Histogram, Info, @@ -21,8 +23,9 @@ # begin-metrics-definitions class Metrics: + labelname_finish_reason = "finished_reason" - def __init__(self, labelnames: List[str]): + def __init__(self, labelnames: List[str], max_model_len: int): # Unregister any existing vLLM collectors for collector in list(REGISTRY._collector_to_names): if hasattr(collector, "_name") and "vllm" in collector._name: @@ -34,18 +37,20 @@ def __init__(self, labelnames: List[str]): documentation='information of cache_config') # System stats + # Scheduler State self.gauge_scheduler_running = Gauge( name="vllm:num_requests_running", documentation="Number of requests currently running on GPU.", labelnames=labelnames) - self.gauge_scheduler_swapped = Gauge( - name="vllm:num_requests_swapped", - documentation="Number of requests swapped to CPU.", - labelnames=labelnames) self.gauge_scheduler_waiting = Gauge( name="vllm:num_requests_waiting", documentation="Number of requests waiting to be processed.", labelnames=labelnames) + self.gauge_scheduler_swapped = Gauge( + name="vllm:num_requests_swapped", + documentation="Number of requests swapped to CPU.", + labelnames=labelnames) + # KV Cache Usage in % self.gauge_gpu_cache_usage = Gauge( name="vllm:gpu_cache_usage_perc", documentation="GPU KV-cache usage. 1 means 100 percent usage.", @@ -55,7 +60,7 @@ def __init__(self, labelnames: List[str]): documentation="CPU KV-cache usage. 
1 means 100 percent usage.", labelnames=labelnames) - # Raw stats from last model iteration + # Iteration stats self.counter_prompt_tokens = Counter( name="vllm:prompt_tokens_total", documentation="Number of prefill tokens processed.", @@ -80,18 +85,51 @@ def __init__(self, labelnames: List[str]): 0.01, 0.025, 0.05, 0.075, 0.1, 0.15, 0.2, 0.3, 0.4, 0.5, 0.75, 1.0, 2.5 ]) - self.histogram_e2e_request_latency = Histogram( + + # Request stats + # Latency + self.histogram_e2e_time_request = Histogram( name="vllm:e2e_request_latency_seconds", documentation="Histogram of end to end request latency in seconds.", labelnames=labelnames, buckets=[1.0, 2.5, 5.0, 10.0, 15.0, 20.0, 30.0, 40.0, 50.0, 60.0]) + # Metadata + self.histogram_num_prompt_tokens_request = Histogram( + name="vllm:request_prompt_tokens", + documentation="Number of prefill tokens processed.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.histogram_num_generation_tokens_request = Histogram( + name="vllm:request_generation_tokens", + documentation="Number of generation tokens processed.", + labelnames=labelnames, + buckets=build_1_2_5_buckets(max_model_len), + ) + self.histogram_best_of_request = Histogram( + name="vllm:request_params_best_of", + documentation="Histogram of the best_of request parameter.", + labelnames=labelnames, + buckets=[1, 2, 5, 10, 20], + ) + self.histogram_n_request = Histogram( + name="vllm:request_params_n", + documentation="Histogram of the n request parameter.", + labelnames=labelnames, + buckets=[1, 2, 5, 10, 20], + ) + self.counter_request_success = Counter( + name="vllm:request_success", + documentation="Count of successfully processed requests.", + labelnames=labelnames + [Metrics.labelname_finish_reason]) - # Legacy metrics + # Deprecated in favor of vllm:prompt_tokens_total self.gauge_avg_prompt_throughput = Gauge( name="vllm:avg_prompt_throughput_toks_per_s", documentation="Average prefill throughput in tokens/s.", labelnames=labelnames, ) + # Deprecated in favor of vllm:generation_tokens_total self.gauge_avg_generation_throughput = Gauge( name="vllm:avg_generation_throughput_toks_per_s", documentation="Average generation throughput in tokens/s.", @@ -102,24 +140,57 @@ def __init__(self, labelnames: List[str]): # end-metrics-definitions +def build_1_2_5_buckets(max_value: int): + """ + Builds a list of buckets with increasing powers of 10 multiplied by + mantissa values (1, 2, 5) until the value exceeds the specified maximum. + + Example: + >>> build_1_2_5_buckets(100) + [1, 2, 5, 10, 20, 50, 100] + """ + mantissa_lst = [1, 2, 5] + exponent = 0 + buckets = [] + while True: + for m in mantissa_lst: + value = m * 10**exponent + if value <= max_value: + buckets.append(value) + else: + return buckets + exponent += 1 + + @dataclass class Stats: """Created by LLMEngine for use by StatLogger.""" now: float - # System stats. - num_running: int - num_waiting: int - num_swapped: int - gpu_cache_usage: float - cpu_cache_usage: float - - # Raw stats from last model iteration. 
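
Editor's note: the new request-level token histograms size their buckets with build_1_2_5_buckets, which walks the 1/2/5 mantissa ladder up to the model's maximum length. A quick check of the helper as added in vllm/engine/metrics.py above, assuming a model with max_model_len = 4096:

    from vllm.engine.metrics import build_1_2_5_buckets

    assert build_1_2_5_buckets(4096) == [
        1, 2, 5, 10, 20, 50, 100, 200, 500, 1000, 2000
    ]
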
- num_prompt_tokens: int - num_generation_tokens: int - time_to_first_tokens: List[float] - time_per_output_tokens: List[float] + # System stats (should have _sys suffix) + # Scheduler State + num_running_sys: int + num_waiting_sys: int + num_swapped_sys: int + # KV Cache Usage in % + gpu_cache_usage_sys: float + cpu_cache_usage_sys: float + + # Iteration stats (should have _iter suffix) + num_prompt_tokens_iter: int + num_generation_tokens_iter: int + time_to_first_tokens_iter: List[float] + time_per_output_tokens_iter: List[float] + + # Request stats (should have _requests suffix) + # Latency time_e2e_requests: List[float] + # Metadata + num_prompt_tokens_requests: List[int] + num_generation_tokens_requests: List[int] + best_of_requests: List[int] + n_requests: List[int] + finished_reason_requests: List[str] spec_decode_metrics: Optional["SpecDecodeWorkerMetrics"] = None @@ -133,7 +204,8 @@ def metrics_info(self) -> Dict[str, str]: class StatLogger: """StatLogger is used LLMEngine to log to Promethus and Stdout.""" - def __init__(self, local_interval: float, labels: Dict[str, str]) -> None: + def __init__(self, local_interval: float, labels: Dict[str, str], + max_model_len: int) -> None: # Metadata for logging locally. self.last_local_log = time.time() self.local_interval = local_interval @@ -144,7 +216,8 @@ def __init__(self, local_interval: float, labels: Dict[str, str]) -> None: # Prometheus metrics self.labels = labels - self.metrics = Metrics(labelnames=list(labels.keys())) + self.metrics = Metrics(labelnames=list(labels.keys()), + max_model_len=max_model_len) def info(self, type: str, obj: SupportsMetricsInfo) -> None: if type == "cache_config": @@ -158,34 +231,66 @@ def _local_interval_elapsed(self, now: float) -> bool: return elapsed_time > self.local_interval def _log_prometheus(self, stats: Stats) -> None: - # Set system stat gauges. - self.metrics.gauge_scheduler_running.labels(**self.labels).set( - stats.num_running) - self.metrics.gauge_scheduler_swapped.labels(**self.labels).set( - stats.num_swapped) - self.metrics.gauge_scheduler_waiting.labels(**self.labels).set( - stats.num_waiting) - self.metrics.gauge_gpu_cache_usage.labels(**self.labels).set( - stats.gpu_cache_usage) - self.metrics.gauge_cpu_cache_usage.labels(**self.labels).set( - stats.cpu_cache_usage) - - # Add to token counters. - self.metrics.counter_prompt_tokens.labels(**self.labels).inc( - stats.num_prompt_tokens) - self.metrics.counter_generation_tokens.labels(**self.labels).inc( - stats.num_generation_tokens) - - # Observe request level latencies in histograms. 
- for ttft in stats.time_to_first_tokens: - self.metrics.histogram_time_to_first_token.labels( - **self.labels).observe(ttft) - for tpot in stats.time_per_output_tokens: - self.metrics.histogram_time_per_output_token.labels( - **self.labels).observe(tpot) - for e2e in stats.time_e2e_requests: - self.metrics.histogram_e2e_request_latency.labels( - **self.labels).observe(e2e) + # System state data + self._log_gauge(self.metrics.gauge_scheduler_running, + stats.num_running_sys) + self._log_gauge(self.metrics.gauge_scheduler_swapped, + stats.num_swapped_sys) + self._log_gauge(self.metrics.gauge_scheduler_waiting, + stats.num_waiting_sys) + self._log_gauge(self.metrics.gauge_gpu_cache_usage, + stats.gpu_cache_usage_sys) + self._log_gauge(self.metrics.gauge_cpu_cache_usage, + stats.cpu_cache_usage_sys) + + # Iteration level data + self._log_counter(self.metrics.counter_prompt_tokens, + stats.num_prompt_tokens_iter) + self._log_counter(self.metrics.counter_generation_tokens, + stats.num_generation_tokens_iter) + self._log_histogram(self.metrics.histogram_time_to_first_token, + stats.time_to_first_tokens_iter) + self._log_histogram(self.metrics.histogram_time_per_output_token, + stats.time_per_output_tokens_iter) + + # Request level data + # Latency + self._log_histogram(self.metrics.histogram_e2e_time_request, + stats.time_e2e_requests) + # Metadata + finished_reason_counter = CollectionsCounter( + stats.finished_reason_requests) + self._log_counter_labels(self.metrics.counter_request_success, + finished_reason_counter, + Metrics.labelname_finish_reason) + self._log_histogram(self.metrics.histogram_num_prompt_tokens_request, + stats.num_prompt_tokens_requests) + self._log_histogram( + self.metrics.histogram_num_generation_tokens_request, + stats.num_generation_tokens_requests) + self._log_histogram(self.metrics.histogram_n_request, stats.n_requests) + self._log_histogram(self.metrics.histogram_best_of_request, + stats.best_of_requests) + + def _log_gauge(self, gauge: Gauge, data: Union[int, float]) -> None: + # Convenience function for logging to gauge. + gauge.labels(**self.labels).set(data) + + def _log_counter(self, counter: Counter, data: Union[int, float]) -> None: + # Convenience function for logging to counter. + counter.labels(**self.labels).inc(data) + + def _log_counter_labels(self, counter: Counter, data: CollectionsCounter, + label_key: str) -> None: + # Convenience function for collection counter of labels. + for label, count in data.items(): + counter.labels(**{**self.labels, label_key: label}).inc(count) + + def _log_histogram(self, histogram: Histogram, + data: Union[List[int], List[float]]) -> None: + # Convenience function for logging list to histogram. + for datum in data: + histogram.labels(**self.labels).observe(datum) def _log_prometheus_interval(self, prompt_throughput: float, generation_throughput: float) -> None: @@ -210,8 +315,8 @@ def log(self, stats: Stats) -> None: self._log_prometheus(stats) # Save tracked stats for token counters. - self.num_prompt_tokens.append(stats.num_prompt_tokens) - self.num_generation_tokens.append(stats.num_generation_tokens) + self.num_prompt_tokens.append(stats.num_prompt_tokens_iter) + self.num_generation_tokens.append(stats.num_generation_tokens_iter) # Log locally every local_interval seconds. 
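
Editor's note: the finished-request metadata is folded into Prometheus through the small helpers above: finish reasons are first tallied with collections.Counter and then emitted through a counter that carries an extra finished_reason label. A minimal sketch of that labelling pattern with prometheus_client (the metric and label values here are illustrative, not the ones registered above):

    from collections import Counter
    from prometheus_client import Counter as PromCounter

    request_success = PromCounter(
        "demo_request_success",
        "Count of successfully processed requests.",
        labelnames=["model_name", "finished_reason"])

    finished_reasons = ["stop", "length", "stop", "stop"]
    for reason, count in Counter(finished_reasons).items():
        # One labelled increment per distinct finish reason.
        request_success.labels(model_name="demo-model",
                               finished_reason=reason).inc(count)
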
if self._local_interval_elapsed(stats.now): @@ -230,15 +335,15 @@ def log(self, stats: Stats) -> None: "Avg prompt throughput: %.1f tokens/s, " "Avg generation throughput: %.1f tokens/s, " "Running: %d reqs, Swapped: %d reqs, " - "Pending: %d reqs, GPU KV cache usage: %.1f%, " - "CPU KV cache usage: %.1f%", + "Pending: %d reqs, GPU KV cache usage: %.1f%%, " + "CPU KV cache usage: %.1f%%", prompt_throughput, generation_throughput, - stats.num_running, - stats.num_swapped, - stats.num_waiting, - stats.gpu_cache_usage * 100, - stats.cpu_cache_usage * 100, + stats.num_running_sys, + stats.num_swapped_sys, + stats.num_waiting_sys, + stats.gpu_cache_usage_sys * 100, + stats.cpu_cache_usage_sys * 100, ) # Reset tracked stats for next interval. diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 5c361b4d184e..16c5b6c08d37 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -9,7 +9,7 @@ import ssl from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.entrypoints.openai.serving_engine import LoRA +from vllm.entrypoints.openai.serving_engine import LoRAModulePath class LoRAParserAction(argparse.Action): @@ -18,7 +18,7 @@ def __call__(self, parser, namespace, values, option_string=None): lora_list = [] for item in values: name, path = item.split('=') - lora_list.append(LoRA(name, path)) + lora_list.append(LoRAModulePath(name, path)) setattr(namespace, self.dest, lora_list) diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index d9763d024eb8..0a949f986775 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -4,14 +4,20 @@ from typing import Dict, List, Literal, Optional, Union import torch -from pydantic import BaseModel, Field, model_validator +from openai.types.chat import ChatCompletionMessageParam +from pydantic import BaseModel, ConfigDict, Field, model_validator from typing_extensions import Annotated from vllm.sampling_params import SamplingParams from vllm.utils import random_uuid -class ErrorResponse(BaseModel): +class OpenAIBaseModel(BaseModel): + # OpenAI API does not allow extra fields + model_config = ConfigDict(extra="forbid") + + +class ErrorResponse(OpenAIBaseModel): object: str = "error" message: str type: str @@ -19,7 +25,7 @@ class ErrorResponse(BaseModel): code: int -class ModelPermission(BaseModel): +class ModelPermission(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"modelperm-{random_uuid()}") object: str = "model_permission" created: int = Field(default_factory=lambda: int(time.time())) @@ -34,7 +40,7 @@ class ModelPermission(BaseModel): is_blocking: bool = False -class ModelCard(BaseModel): +class ModelCard(OpenAIBaseModel): id: str object: str = "model" created: int = Field(default_factory=lambda: int(time.time())) @@ -44,26 +50,26 @@ class ModelCard(BaseModel): permission: List[ModelPermission] = Field(default_factory=list) -class ModelList(BaseModel): +class ModelList(OpenAIBaseModel): object: str = "list" data: List[ModelCard] = Field(default_factory=list) -class UsageInfo(BaseModel): +class UsageInfo(OpenAIBaseModel): prompt_tokens: int = 0 total_tokens: int = 0 completion_tokens: Optional[int] = 0 -class ResponseFormat(BaseModel): +class ResponseFormat(OpenAIBaseModel): # type must be "json_object" or "text" type: Literal["text", "json_object"] -class ChatCompletionRequest(BaseModel): +class ChatCompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # 
https://platform.openai.com/docs/api-reference/chat/create - messages: List[Dict[str, str]] + messages: List[ChatCompletionMessageParam] model: str frequency_penalty: Optional[float] = 0.0 logit_bias: Optional[Dict[str, float]] = None @@ -204,7 +210,7 @@ def check_guided_decoding_count(cls, data): return data -class CompletionRequest(BaseModel): +class CompletionRequest(OpenAIBaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/completions/create model: str @@ -343,19 +349,19 @@ def check_guided_decoding_count(cls, data): return data -class LogProbs(BaseModel): +class LogProbs(OpenAIBaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) tokens: List[str] = Field(default_factory=list) top_logprobs: Optional[List[Optional[Dict[str, float]]]] = None -class CompletionResponseChoice(BaseModel): +class CompletionResponseChoice(OpenAIBaseModel): index: int text: str logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = Field( + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( default=None, description=( "The stop string or token id that caused the completion " @@ -364,7 +370,7 @@ class CompletionResponseChoice(BaseModel): ) -class CompletionResponse(BaseModel): +class CompletionResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") object: str = "text_completion" created: int = Field(default_factory=lambda: int(time.time())) @@ -373,12 +379,12 @@ class CompletionResponse(BaseModel): usage: UsageInfo -class CompletionResponseStreamChoice(BaseModel): +class CompletionResponseStreamChoice(OpenAIBaseModel): index: int text: str logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = Field( + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = Field( default=None, description=( "The stop string or token id that caused the completion " @@ -387,7 +393,7 @@ class CompletionResponseStreamChoice(BaseModel): ) -class CompletionStreamResponse(BaseModel): +class CompletionStreamResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"cmpl-{random_uuid()}") object: str = "text_completion" created: int = Field(default_factory=lambda: int(time.time())) @@ -396,20 +402,20 @@ class CompletionStreamResponse(BaseModel): usage: Optional[UsageInfo] = Field(default=None) -class ChatMessage(BaseModel): +class ChatMessage(OpenAIBaseModel): role: str content: str -class ChatCompletionResponseChoice(BaseModel): +class ChatCompletionResponseChoice(OpenAIBaseModel): index: int message: ChatMessage logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None -class ChatCompletionResponse(BaseModel): +class ChatCompletionResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") object: str = "chat.completion" created: int = Field(default_factory=lambda: int(time.time())) @@ -418,20 +424,20 @@ class ChatCompletionResponse(BaseModel): usage: UsageInfo -class DeltaMessage(BaseModel): +class DeltaMessage(OpenAIBaseModel): role: Optional[str] = None content: Optional[str] = None -class ChatCompletionResponseStreamChoice(BaseModel): 
+class ChatCompletionResponseStreamChoice(OpenAIBaseModel): index: int delta: DeltaMessage logprobs: Optional[LogProbs] = None - finish_reason: Optional[Literal["stop", "length"]] = None - stop_reason: Union[None, int, str] = None + finish_reason: Optional[str] = None + stop_reason: Optional[Union[int, str]] = None -class ChatCompletionStreamResponse(BaseModel): +class ChatCompletionStreamResponse(OpenAIBaseModel): id: str = Field(default_factory=lambda: f"chatcmpl-{random_uuid()}") object: str = "chat.completion.chunk" created: int = Field(default_factory=lambda: int(time.time())) diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index f6011b6fc4cb..5ed042ef386e 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,8 +1,11 @@ import codecs import time -from typing import AsyncGenerator, AsyncIterator, List, Optional, Union +from typing import (AsyncGenerator, AsyncIterator, Awaitable, Iterable, List, + Optional, Tuple, TypedDict, Union, final) from fastapi import Request +from openai.types.chat import (ChatCompletionContentPartParam, + ChatCompletionRole) from vllm.engine.async_llm_engine import AsyncLLMEngine from vllm.entrypoints.openai.protocol import ( @@ -10,7 +13,8 @@ ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice, ChatCompletionStreamResponse, ChatMessage, DeltaMessage, ErrorResponse, UsageInfo) -from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + OpenAIServing) from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) @@ -20,20 +24,41 @@ logger = init_logger(__name__) +@final # So that it should be compatible with Dict[str, str] +class ConversationMessage(TypedDict): + role: str + content: str + + class OpenAIServingChat(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str], response_role: str, - lora_modules: Optional[List[LoRA]] = None, - chat_template=None): + lora_modules: Optional[List[LoRAModulePath]] = None, + chat_template: Optional[str] = None): super().__init__(engine=engine, served_model_names=served_model_names, lora_modules=lora_modules) self.response_role = response_role self._load_chat_template(chat_template) + def _parse_chat_message_content( + self, + role: ChatCompletionRole, + content: Optional[Union[str, + Iterable[ChatCompletionContentPartParam]]], + ) -> Tuple[List[ConversationMessage], List[Awaitable[object]]]: + if content is None: + return [], [] + if isinstance(content, str): + return [ConversationMessage(role=role, content=content)], [] + + # To be implemented: https://github.com/vllm-project/vllm/pull/3467 + # To be implemented: https://github.com/vllm-project/vllm/pull/4200 + raise NotImplementedError("Complex input not supported yet") + async def create_chat_completion( self, request: ChatCompletionRequest, raw_request: Request ) -> Union[ErrorResponse, AsyncGenerator[str, None], @@ -52,10 +77,19 @@ async def create_chat_completion( return error_check_ret try: + conversation: List[ConversationMessage] = [] + + for m in request.messages: + messages, _ = self._parse_chat_message_content( + m["role"], m["content"]) + + conversation.extend(messages) + prompt = self.tokenizer.apply_chat_template( - conversation=request.messages, + conversation=conversation, tokenize=False, - add_generation_prompt=request.add_generation_prompt) + 
add_generation_prompt=request.add_generation_prompt, + ) except Exception as e: logger.error("Error in applying chat template from request: %s", e) return self.create_error_response(str(e)) @@ -67,7 +101,7 @@ async def create_chat_completion( request, prompt=prompt) sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) - decoding_config = self.engine.engine.decoding_config + decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend guided_decode_logits_processor = ( @@ -105,9 +139,8 @@ def get_chat_request_role(self, request: ChatCompletionRequest) -> str: async def chat_completion_stream_generator( self, request: ChatCompletionRequest, - result_generator: AsyncIterator[RequestOutput], request_id: str - ) -> Union[ErrorResponse, AsyncGenerator[str, None]]: - + result_generator: AsyncIterator[RequestOutput], + request_id: str) -> AsyncGenerator[str, None]: model_name = self.served_model_names[0] created_time = int(time.time()) chunk_object_type = "chat.completion.chunk" @@ -252,7 +285,7 @@ async def chat_completion_full_generator( model_name = self.served_model_names[0] created_time = int(time.time()) - final_res: RequestOutput = None + final_res: Optional[RequestOutput] = None async for res in result_generator: if await raw_request.is_disconnected(): @@ -317,7 +350,7 @@ async def chat_completion_full_generator( return response - def _load_chat_template(self, chat_template): + def _load_chat_template(self, chat_template: Optional[str]): tokenizer = self.tokenizer if chat_template is not None: diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index 211b2e0424c3..6a7f29c4c96f 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -11,7 +11,8 @@ CompletionResponseStreamChoice, CompletionStreamResponse, LogProbs, UsageInfo) -from vllm.entrypoints.openai.serving_engine import LoRA, OpenAIServing +from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, + OpenAIServing) from vllm.logger import init_logger from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) @@ -54,7 +55,7 @@ class OpenAIServingCompletion(OpenAIServing): def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str], - lora_modules: Optional[List[LoRA]] = None): + lora_modules: Optional[List[LoRAModulePath]] = None): super().__init__(engine=engine, served_model_names=served_model_names, lora_modules=lora_modules) @@ -84,11 +85,11 @@ async def create_completion(self, request: CompletionRequest, created_time = int(time.time()) # Schedule the request and get the result generator. 
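
Editor's note: back in the chat endpoint, the new _parse_chat_message_content added above accepts only plain-string message content for now: a string becomes a single role/content entry in the conversation, None yields nothing, and the structured content-part form is rejected until the referenced follow-up PRs land. A simplified, standalone sketch of just that string/None handling (it omits the awaitables list the real method also returns):

    from typing import Iterable, List, Optional, Union

    def parse_content(role: str,
                      content: Optional[Union[str, Iterable[dict]]]
                      ) -> List[dict]:
        if content is None:
            return []
        if isinstance(content, str):
            return [{"role": role, "content": content}]
        raise NotImplementedError("Complex input not supported yet")

    assert parse_content("user", "Hello!") == [{"role": "user",
                                                "content": "Hello!"}]
    assert parse_content("user", None) == []
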
- generators = [] + generators: List[AsyncIterator[RequestOutput]] = [] try: sampling_params = request.to_sampling_params() lora_request = self._maybe_get_lora(request) - decoding_config = self.engine.engine.decoding_config + decoding_config = await self.engine.get_decoding_config() guided_decoding_backend = request.guided_decoding_backend \ or decoding_config.guided_decoding_backend guided_decode_logit_processor = ( @@ -148,7 +149,7 @@ async def create_completion(self, request: CompletionRequest, num_prompts=len(prompts)) # Non-streaming response - final_res_batch: RequestOutput = [None] * len(prompts) + final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts) try: async for i, res in result_generator: if await raw_request.is_disconnected(): diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py index 31da27a447c6..3d5ed328b9d1 100644 --- a/vllm/entrypoints/openai/serving_engine.py +++ b/vllm/entrypoints/openai/serving_engine.py @@ -22,17 +22,15 @@ @dataclass -class LoRA: +class LoRAModulePath: name: str local_path: str class OpenAIServing: - def __init__(self, - engine: AsyncLLMEngine, - served_model_names: List[str], - lora_modules=Optional[List[LoRA]]): + def __init__(self, engine: AsyncLLMEngine, served_model_names: List[str], + lora_modules: Optional[List[LoRAModulePath]]): self.engine = engine self.served_model_names = served_model_names if lora_modules is None: @@ -158,7 +156,9 @@ def create_streaming_error_response( }) return json_str - async def _check_model(self, request) -> Optional[ErrorResponse]: + async def _check_model( + self, request: Union[CompletionRequest, ChatCompletionRequest] + ) -> Optional[ErrorResponse]: if request.model in self.served_model_names: return None if request.model in [lora.lora_name for lora in self.lora_requests]: @@ -168,14 +168,16 @@ async def _check_model(self, request) -> Optional[ErrorResponse]: err_type="NotFoundError", status_code=HTTPStatus.NOT_FOUND) - def _maybe_get_lora(self, request) -> Optional[LoRARequest]: + def _maybe_get_lora( + self, request: Union[CompletionRequest, ChatCompletionRequest] + ) -> Optional[LoRARequest]: if request.model in self.served_model_names: return None for lora in self.lora_requests: if request.model == lora.lora_name: return lora # if _check_model has been called earlier, this will be unreachable - raise ValueError("The model `{request.model}` does not exist.") + raise ValueError(f"The model `{request.model}` does not exist.") def _validate_prompt_and_tokenize( self, diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index aa810f974339..e4436b2144bd 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -109,7 +109,7 @@ async def execute_model_async( blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: + ) -> List[SamplerOutput]: output = await make_async(self.driver_worker.execute_model)( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py new file mode 100644 index 000000000000..4c922ef63ee0 --- /dev/null +++ b/vllm/executor/distributed_gpu_executor.py @@ -0,0 +1,115 @@ +from abc import abstractmethod +from typing import Any, Dict, List, Optional, Set, Tuple + +from vllm.executor.executor_base import ExecutorAsyncBase +from vllm.executor.gpu_executor import GPUExecutor +from vllm.logger 
import init_logger +from vllm.lora.request import LoRARequest +from vllm.sequence import SamplerOutput + +logger = init_logger(__name__) + + +class DistributedGPUExecutor(GPUExecutor): + """Abstract superclass of multi-GPU executor implementations.""" + + def determine_num_available_blocks(self) -> Tuple[int, int]: + """Determine the number of available KV blocks. + + This invokes `determine_num_available_blocks` on each worker and takes + the min of the results, guaranteeing that the selected cache sizes are + compatible with all workers. + + Returns: + - tuple[num_gpu_blocks, num_cpu_blocks] + """ + # Get the maximum number of blocks that can be allocated on GPU and CPU. + num_blocks = self._run_workers("determine_num_available_blocks", ) + + # Since we use a shared centralized controller, we take the minimum + # number of blocks across all workers to make sure all the memory + # operators can be applied to all workers. + num_gpu_blocks = min(b[0] for b in num_blocks) + num_cpu_blocks = min(b[1] for b in num_blocks) + + return num_gpu_blocks, num_cpu_blocks + + def initialize_cache(self, num_gpu_blocks: int, + num_cpu_blocks: int) -> None: + """Initialize the KV cache in all workers. + """ + + # NOTE: We log here to avoid multiple logs when number of workers is + # greater than one. We could log in the engine, but not all executors + # have GPUs. + logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, + num_cpu_blocks) + + self.cache_config.num_gpu_blocks = num_gpu_blocks + self.cache_config.num_cpu_blocks = num_cpu_blocks + + self._run_workers("initialize_cache", + num_gpu_blocks=num_gpu_blocks, + num_cpu_blocks=num_cpu_blocks) + + def execute_model(self, *args, **kwargs) -> List[SamplerOutput]: + all_outputs = self._run_workers("execute_model", + driver_args=args, + driver_kwargs=kwargs) + + # Only the driver worker returns the sampling results. + return all_outputs[0] + + def add_lora(self, lora_request: LoRARequest) -> bool: + assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "add_lora", + lora_request=lora_request, + ) + + def remove_lora(self, lora_id: int) -> bool: + assert lora_id > 0, "lora_id must be greater than 0." + return self._run_workers( + "remove_lora", + lora_id=lora_id, + ) + + def list_loras(self) -> Set[int]: + return self._run_workers("list_loras") + + @abstractmethod + def _run_workers( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + max_concurrent_workers: Optional[int] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + raise NotImplementedError + + +class DistributedGPUExecutorAsync(DistributedGPUExecutor, ExecutorAsyncBase): + + @abstractmethod + async def _run_workers_async( + self, + method: str, + *args, + driver_args: Optional[Tuple[Any, ...]] = None, + driver_kwargs: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> Any: + """Runs the given method on all workers.""" + raise NotImplementedError + + async def execute_model_async(self, *args, + **kwargs) -> List[SamplerOutput]: + all_outputs = await self._run_workers_async("execute_model", + driver_args=args, + driver_kwargs=kwargs) + + # Only the driver worker returns the sampling results. 
+ return all_outputs[0] diff --git a/vllm/executor/executor_base.py b/vllm/executor/executor_base.py index 1838c34be2fd..c36aa18fb25b 100644 --- a/vllm/executor/executor_base.py +++ b/vllm/executor/executor_base.py @@ -112,7 +112,7 @@ async def execute_model_async( blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: + ) -> List[SamplerOutput]: """Executes one model step on the given sequences.""" raise NotImplementedError diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index d2c60a3b68e1..5ac62f02b99c 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -163,7 +163,7 @@ async def execute_model_async( blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: + ) -> List[SamplerOutput]: output = await make_async(self.driver_worker.execute_model)( seq_group_metadata_list=seq_group_metadata_list, blocks_to_swap_in=blocks_to_swap_in, diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 5a137d1bdcb3..f406287f3c1d 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -84,7 +84,7 @@ async def execute_model_async( blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: + ) -> List[SamplerOutput]: output = await make_async(self.driver_worker.execute_model)( seq_group_metadata_list=seq_group_metadata_list, ) return output diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 6f72babe14fd..b6bcda4e6b18 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -3,12 +3,12 @@ import pickle from collections import defaultdict from itertools import islice, repeat -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple -from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase +from vllm.executor.distributed_gpu_executor import ( # yapf: disable + DistributedGPUExecutor, DistributedGPUExecutorAsync) from vllm.executor.ray_utils import RayWorkerWrapper, ray from vllm.logger import init_logger -from vllm.lora.request import LoRARequest from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, get_vllm_instance_id, make_async) @@ -27,7 +27,7 @@ USE_RAY_COMPILED_DAG = bool(os.getenv("VLLM_USE_RAY_COMPILED_DAG", 0)) -class RayGPUExecutor(ExecutorBase): +class RayGPUExecutor(DistributedGPUExecutor): def _init_executor(self) -> None: assert (not self.speculative_config @@ -179,57 +179,16 @@ def collect_arg_helper_func(**kwargs): self._run_workers("init_worker", all_kwargs=init_worker_all_kwargs) self._run_workers("init_device") - self._run_workers( - "load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers, - ) - - def determine_num_available_blocks(self) -> Tuple[int, int]: - """Determine the number of available KV blocks. - - This invokes `determine_num_available_blocks` on each worker and takes - the min of the results, guaranteeing that the selected cache sizes are - compatible with all workers. - - Returns: - - Tuple[num_gpu_blocks, num_cpu_blocks] - """ - # Get the maximum number of blocks that can be allocated on GPU and CPU. 
- num_blocks = self._run_workers("determine_num_available_blocks", ) - - # Since we use a shared centralized controller, we take the minimum - # number of blocks across all workers to make sure all the memory - # operators can be applied to all workers. - num_gpu_blocks = min(b[0] for b in num_blocks) - num_cpu_blocks = min(b[1] for b in num_blocks) - - return num_gpu_blocks, num_cpu_blocks - - def initialize_cache(self, num_gpu_blocks: int, - num_cpu_blocks: int) -> None: - """Initialize the KV cache in all workers. - """ - - # NOTE: We log here to avoid multiple logs when number of workers is - # greater than one. We could log in the engine, but not all executors - # have GPUs. - logger.info("# GPU blocks: %d, # CPU blocks: %d", num_gpu_blocks, - num_cpu_blocks) - - self.cache_config.num_gpu_blocks = num_gpu_blocks - self.cache_config.num_cpu_blocks = num_cpu_blocks - - self._run_workers("initialize_cache", - num_gpu_blocks=num_gpu_blocks, - num_cpu_blocks=num_cpu_blocks) + self._run_workers("load_model", + max_concurrent_workers=self.parallel_config. + max_parallel_loading_workers) def execute_model(self, seq_group_metadata_list: List[SequenceGroupMetadata], blocks_to_swap_in: Dict[int, int], blocks_to_swap_out: Dict[int, int], blocks_to_copy: Dict[int, List[int]], - num_lookahead_slots: int = 0) -> SamplerOutput: + num_lookahead_slots: int = 0) -> List[SamplerOutput]: all_outputs = self._run_workers( "execute_model", driver_kwargs={ @@ -244,23 +203,6 @@ def execute_model(self, output = all_outputs[0] return output - def add_lora(self, lora_request: LoRARequest) -> bool: - assert lora_request.lora_int_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "add_lora", - lora_request=lora_request, - ) - - def remove_lora(self, lora_id: int) -> bool: - assert lora_id > 0, "lora_id must be greater than 0." - return self._run_workers( - "remove_lora", - lora_id=lora_id, - ) - - def list_loras(self) -> Set[int]: - return self._run_workers("list_loras") - def _run_workers( self, method: str, @@ -378,7 +320,7 @@ def _check_if_any_actor_is_dead(self): f"Dead Workers: {dead_actors}. ") -class RayGPUExecutorAsync(RayGPUExecutor, ExecutorAsyncBase): +class RayGPUExecutorAsync(RayGPUExecutor, DistributedGPUExecutorAsync): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -409,23 +351,3 @@ async def _run_workers_async( all_outputs = await asyncio.gather(*coros) return all_outputs - - async def execute_model_async( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - blocks_to_swap_in: Dict[int, int], - blocks_to_swap_out: Dict[int, int], - blocks_to_copy: Dict[int, List[int]], - ) -> SamplerOutput: - all_outputs = await self._run_workers_async( - "execute_model", - driver_kwargs={ - "seq_group_metadata_list": seq_group_metadata_list, - "blocks_to_swap_in": blocks_to_swap_in, - "blocks_to_swap_out": blocks_to_swap_out, - "blocks_to_copy": blocks_to_copy, - }) - - # Only the driver worker returns the sampling results. 
- output = all_outputs[0] - return output diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py new file mode 100644 index 000000000000..1720566840bb --- /dev/null +++ b/vllm/lora/fully_sharded_layers.py @@ -0,0 +1,262 @@ +# pylint: disable=unused-argument +from typing import TYPE_CHECKING, List, Optional + +import torch +import torch.nn as nn +from transformers import PretrainedConfig + +from vllm.config import LoRAConfig +from vllm.distributed.communication_op import ( + tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.lora.layers import (ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLora, + RowParallelLinearWithLoRA) +from vllm.lora.punica import bgmv, dispatch_bgmv_low_level + +if TYPE_CHECKING: + pass + + +def _fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + return (can_replace(*args, **kwargs) + and kwargs['lora_config'].fully_sharded_loras) + + return dec + + +# these layers are based on the tensor parallelism strategy given in +# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023, +# https://arxiv.org/abs/2311.03285. + + +class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): + """ + Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.lora_a_stacked.shape[2] + start_idx = tp_rank * shard_size + lora_a = lora_a[:, start_idx:start_idx + shard_size] + return lora_a + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), + dtype=torch.float32, + device=x.device) + + bgmv(buffer, x, self.lora_a_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + buffer = tensor_model_parallel_all_gather(buffer) + bgmv(output, buffer, self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + # now have column partitioned output + + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +def _mcp_apply_weights(x, bias, layer): + """ + MergedColumnParallelLinearWithShardedLoRA and + QKVParallelLinearWithShardedLora share the same + LoRa weight application method. + + The main difference is the step by shard_size for lora_b which can + vary for QKVParallelLinearWithShardedLora but is constant for + MergedColumnParallelLinearWithShardedLoRA. 
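# Illustrative sketch (not in this patch) of the rank-dimension split used by
# the fully sharded layers above: with tensor-parallel size 4 and
# max_lora_rank 16, each rank keeps a 4-column slice of LoRA A, computes a
# partial low-rank activation, and the slices are all-gathered before LoRA B
# is applied. Plain torch; shapes are illustrative.
import torch

tp_size, max_lora_rank, input_size = 4, 16, 32
lora_a = torch.randn(input_size, max_lora_rank)  # full (input, rank) matrix

shard_size = max_lora_rank // tp_size
shards = [
    lora_a[:, rank * shard_size:(rank + 1) * shard_size]
    for rank in range(tp_size)
]
assert all(s.shape == (input_size, shard_size) for s in shards)
# Concatenating the rank shards recovers the original LoRA A.
assert torch.equal(torch.cat(shards, dim=1), lora_a)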
+ """ + # expecting 2 for column parallel and 3 for qkv + n = len(layer.lora_a_stacked) + output = layer.base_layer.linear_method.apply_weights( + layer.base_layer, x, bias) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape + buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]), + dtype=torch.float32, + device=x.device) + for idx in range(n): + bgmv(buffers[idx], x, layer.lora_a_stacked[idx], + layer.indices[:layer.indices_len[0]], 0, 1.0) + + buffers = tensor_model_parallel_all_gather(buffers) + left_offset = 0 + for idx in range(n): + shard_size = layer.lora_b_stacked[idx].shape[2] + dispatch_bgmv_low_level(output, buffers[idx], + layer.lora_b_stacked[idx], + layer.indices[:layer.indices_len[0]], 0, 1.0, + left_offset, shard_size) + left_offset += shard_size + + output = output.view(*out_orig_shape) + # now have column partitioned and packed output + return output + + +class MergedColumnParallelLinearWithShardedLoRA( + MergedColumnParallelLinearWithLoRA): + """ + Differs from MergedColumnParallelLinearWithLoRA by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + output_shard_size = self.lora_a_stacked[0].shape[2] + output_start_idx = self.tp_rank * output_shard_size + lora_a = [ + lora_a[i][:, output_start_idx:output_start_idx + output_shard_size] + for i in range(2) + ] + return lora_a + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return _mcp_apply_weights(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora): + """ + Differs from QKVParallelLinearWithLora by slicing the + LoRA A's also. + + Based on S-LoRA, slicing happens along the rank dim. + """ + + def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)] + start_idx = [self.tp_rank * shard_size[i] for i in range(3)] + lora_a = [ + lora_a[i][:, start_idx[i]:start_idx[i] + + shard_size[i]] if lora_a[i] is not None else None + for i in range(3) + ] + return lora_a + + def apply_weights(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + return _mcp_apply_weights(x, bias, self) + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) + + +class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA): + """ + Differs from RowParallelLinearWithLoRA by slicing the + LoRA B's also. + + Based on S-LoRA, slicing happens along the output dim. 
+ This yields a combined partial sum from the row parallel base + layer and column partitioned output from the LoRA. + """ + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + shard_size = self.lora_b_stacked.shape[2] + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + + def apply_weights(self, x: torch.Tensor) -> torch.Tensor: + output = self.base_layer.linear_method.apply_weights( + self.base_layer, x) + + x = x.view(-1, x.shape[-1]) + output, out_orig_shape = output.view(-1, + output.shape[-1]), output.shape + buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]), + dtype=torch.float32, + device=x.device) + bgmv(buffer, x, self.lora_a_stacked, + self.indices[:self.indices_len[0]], 0, 1.0) + buffer = tensor_model_parallel_all_reduce(buffer) + + # following S-LoRA, allows the fusing of all_gather and all_reduce + # by adding the column partitioned lora output to a slice of output + # tensor, which is a partial sum due to row parallel. All that + # remains is a standard all_reduce. User should be aware though that + # the output is not the same as a normal row_parallel, it should be + # reduced before being used + shard_size = self.lora_b_stacked.shape[2] + start_idx = self.tp_rank * shard_size + dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked, + self.indices[:self.indices_len[0]], 0, 1.0, + start_idx, shard_size) + + output = output.view(*out_orig_shape) + return output + + @classmethod + @_fully_sharded_can_replace + def can_replace_layer(cls, source_layer: nn.Module, + lora_config: LoRAConfig, packed_modules_list: List, + model_config: Optional[PretrainedConfig]) -> bool: + # specifying kwargs so they can be easily accessed in decorator + return super().can_replace_layer( + source_layer=source_layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config, + decorate=False, + ) diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 98e74168002c..b3609666b2ec 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -1,8 +1,7 @@ # pylint: disable=unused-argument -import inspect import math from dataclasses import dataclass -from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Type +from typing import TYPE_CHECKING, List, Optional, Tuple import torch import torch.nn as nn @@ -16,6 +15,7 @@ tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce, tensor_model_parallel_gather) +from vllm.distributed.utils import divide from vllm.lora.punica import add_lora, add_lora_slice, bgmv from vllm.model_executor.layers.linear import (ColumnParallelLinear, MergedColumnParallelLinear, @@ -23,7 +23,7 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.vocab_parallel_embedding import ( - ParallelLMHead, VocabParallelEmbedding) + VocabParallelEmbedding) if TYPE_CHECKING: pass @@ -45,6 +45,21 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device: raise ValueError(f"Unsupported base layer: {base_layer}") +def _not_fully_sharded_can_replace(can_replace): + """ + decorator which adds the condition of not using fully sharded loras + intended to wrap can_replace_layer() + """ + + def dec(*args, **kwargs): + decorate = kwargs.pop('decorate') if 'decorate' in kwargs else True + condition = (not kwargs['lora_config'].fully_sharded_loras + if decorate else True) + return can_replace(*args, **kwargs) and condition + + return dec + + def 
_apply_lora( x: torch.Tensor, lora_a_stacked: torch.Tensor, @@ -130,6 +145,14 @@ def __post_init__(self): class BaseLayerWithLoRA(nn.Module): + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + """Slice lora a if splitting for tensor parallelism.""" + ... + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + """Slice lora b if splitting with tensor parallelism.""" + ... + def create_lora_weights( self, max_loras: int, @@ -317,6 +340,11 @@ def can_replace_layer(cls, source_layer: nn.Module, class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): + """ + LoRA on top of ColumnParallelLinear layer. + + LoRA B is sliced for tensor parallelism. + """ def __init__(self, base_layer: ColumnParallelLinear) -> None: super().__init__() @@ -331,10 +359,15 @@ def create_lora_weights( max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_config = lora_config + self.tp_size = get_tensor_model_parallel_world_size() + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) self.lora_a_stacked = torch.zeros( max_loras, 1, - lora_config.max_lora_rank, + lora_a_output_size_per_partition, self.input_size, dtype=lora_config.lora_dtype, device=self.device, @@ -357,6 +390,17 @@ def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_dim + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] + return lora_b + def set_lora( self, index: int, @@ -365,12 +409,11 @@ def set_lora( embeddings_tensor: Optional[torch.Tensor], ): self.reset_lora(index) + if self.tp_size > 1: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.output_dim - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + self.lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( lora_a.T, non_blocking=True) @@ -389,10 +432,9 @@ def set_mapping( self.indices = base_indices self.indices_len = indices_len - def apply_weights(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.linear_method.apply_weights( - self.base_layer, x, bias) + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) _apply_lora( x, self.lora_a_stacked, @@ -416,7 +458,7 @@ def forward(self, input_): if not self.base_layer.skip_bias_add else None) # Matrix multiply. - output_parallel = self.apply_weights(input_, bias) + output_parallel = self.apply(input_, bias) if self.base_layer.gather_output: # All-gather across the partitions. 
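# Illustrative sketch (not in this patch) contrasting the two slicing hooks
# introduced above: plain tensor-parallel LoRA keeps LoRA A whole and shards
# LoRA B along the output dim (slice_lora_b), while lora_config.fully_sharded_loras
# additionally shards LoRA A along the rank dim, so lora_a_stacked is allocated
# with a rank dimension of divide(max_lora_rank, tp_size). Shapes are illustrative.
import torch

tp_size, tp_rank = 2, 1
max_lora_rank, output_size = 8, 1024
lora_b = torch.randn(max_lora_rank, output_size)

output_dim = output_size // tp_size                 # per-partition output size
b_shard = lora_b[:, tp_rank * output_dim:(tp_rank + 1) * output_dim]
assert b_shard.shape == (max_lora_rank, output_dim)

lora_a_rank_dim_default = max_lora_rank             # fully_sharded_loras=False
lora_a_rank_dim_sharded = max_lora_rank // tp_size  # fully_sharded_loras=True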
output = tensor_model_parallel_all_gather(output_parallel) @@ -427,6 +469,7 @@ def forward(self, input_): return output, output_bias @classmethod + @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, model_config: Optional[PretrainedConfig]) -> bool: @@ -452,6 +495,7 @@ def create_lora_weights( max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_config = lora_config n_slices = 2 if not (len(self.base_layer.output_sizes) == n_slices and self.base_layer.output_sizes[0] @@ -460,12 +504,17 @@ def create_lora_weights( "LoRAColumnParallelLinear2Slice requires 2 slices with " "the same size.") self.tp_size = get_tensor_model_parallel_world_size() + self.tp_rank = get_tensor_model_parallel_rank() + + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) self.lora_a_stacked = tuple( torch.zeros( max_loras, 1, - lora_config.max_lora_rank, + lora_a_output_size_per_partition, self.input_size, dtype=lora_config.lora_dtype, device=self.device, @@ -490,6 +539,18 @@ def reset_lora(self, index: int): self.lora_b_stacked[0][index] = 0 self.lora_b_stacked[1][index] = 0 + def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + return lora_a + + def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]: + shard_size = self.output_dim + start_idx = self.tp_rank * shard_size + end_idx = (self.tp_rank + 1) * shard_size + lora_b = [ + lora_b[0][:, start_idx:end_idx], lora_b[1][:, start_idx:end_idx] + ] + return lora_b + def set_lora( self, index: int, @@ -500,13 +561,8 @@ def set_lora( self.reset_lora(index) if self.tp_size > 1: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.output_dim - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_b = lora_b[0][:, - start_idx:end_idx], lora_b[1][:, - start_idx:end_idx] + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) if lora_a[0] is not None: self.lora_a_stacked[0][ @@ -523,10 +579,9 @@ def set_lora( index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( lora_b[1].T, non_blocking=True) - def apply_weights(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.linear_method.apply_weights( - self.base_layer, x, bias) + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) _apply_lora_packed_nslice( x, self.lora_a_stacked, @@ -538,6 +593,7 @@ def apply_weights(self, x: torch.Tensor, return output @classmethod + @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, model_config: Optional[PretrainedConfig]) -> bool: @@ -629,21 +685,25 @@ def create_lora_weights( max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_config = lora_config self.tp_size = get_tensor_model_parallel_world_size() - tp_rank = get_tensor_model_parallel_rank() + self.tp_rank = get_tensor_model_parallel_rank() self.q_proj_shard_size = (self.base_layer.num_heads * self.base_layer.head_size) self.kv_proj_shard_size = (self.base_layer.num_kv_heads * self.base_layer.head_size) - self.q_shard_id = tp_rank - self.kv_shard_id = 
tp_rank // self.base_layer.num_kv_head_replicas + self.q_shard_id = self.tp_rank + self.kv_shard_id = self.tp_rank // self.base_layer.num_kv_head_replicas + lora_a_output_size_per_partition = ( + lora_config.max_lora_rank if not lora_config.fully_sharded_loras + else divide(lora_config.max_lora_rank, self.tp_size)) # q, k, v self.lora_a_stacked = ( torch.zeros( max_loras, 1, - lora_config.max_lora_rank, + lora_a_output_size_per_partition, self.input_size, dtype=lora_config.lora_dtype, device=self.device, @@ -651,7 +711,7 @@ def create_lora_weights( torch.zeros( max_loras, 1, - lora_config.max_lora_rank, + lora_a_output_size_per_partition, self.input_size, dtype=lora_config.lora_dtype, device=self.device, @@ -659,7 +719,7 @@ def create_lora_weights( torch.zeros( max_loras, 1, - lora_config.max_lora_rank, + lora_a_output_size_per_partition, self.input_size, dtype=lora_config.lora_dtype, device=self.device, @@ -707,6 +767,25 @@ def reset_lora(self, index: int): self.lora_a_stacked[2][index] = 0 self.lora_b_stacked[2][index] = 0 + def slice_lora_a(self, lora_a: List[torch.Tensor]) -> List[torch.Tensor]: + return lora_a + + def slice_lora_b(self, lora_b: List[torch.Tensor]) -> List[torch.Tensor]: + if lora_b[0] is not None: + lora_b_q = lora_b[0][:, self.q_proj_shard_size * + self.q_shard_id:self.q_proj_shard_size * + (self.q_shard_id + 1)] + if lora_b[1] is not None: + lora_b_k = lora_b[1][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + if lora_b[2] is not None: + lora_b_v = lora_b[2][:, self.kv_proj_shard_size * + self.kv_shard_id:self.kv_proj_shard_size * + (self.kv_shard_id + 1)] + lora_b = [lora_b_q, lora_b_k, lora_b_v] + return lora_b + def set_lora( self, index: int, @@ -717,40 +796,24 @@ def set_lora( self.reset_lora(index) if self.tp_size > 1: - if lora_b[0] is not None: - lora_b_q = lora_b[0][:, self.q_proj_shard_size * - self.q_shard_id:self.q_proj_shard_size * - (self.q_shard_id + 1)] - self.lora_b_stacked[0][ - index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_( - lora_b_q.T, non_blocking=True) - if lora_b[1] is not None: - lora_b_k = lora_b[1][:, self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1)] - self.lora_b_stacked[1][ - index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_( - lora_b_k.T, non_blocking=True) - if lora_b[2] is not None: - lora_b_v = lora_b[2][:, self.kv_proj_shard_size * - self.kv_shard_id:self.kv_proj_shard_size * - (self.kv_shard_id + 1)] - self.lora_b_stacked[2][ - index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_( - lora_b_v.T, non_blocking=True) - else: - if lora_b[0] is not None: - self.lora_b_stacked[0][ - index, 0, :lora_b[0].shape[1], :lora_b[0].shape[0]].copy_( - lora_b[0].T, non_blocking=True) - if lora_b[1] is not None: - self.lora_b_stacked[1][ - index, 0, :lora_b[1].shape[1], :lora_b[1].shape[0]].copy_( - lora_b[1].T, non_blocking=True) - if lora_b[2] is not None: - self.lora_b_stacked[2][ - index, 0, :lora_b[2].shape[1], :lora_b[2].shape[0]].copy_( - lora_b[2].T, non_blocking=True) + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) + + if lora_b[0] is not None: + lora_b_q = lora_b[0] + self.lora_b_stacked[0][ + index, 0, :lora_b_q.shape[1], :lora_b_q.shape[0]].copy_( + lora_b_q.T, non_blocking=True) + if lora_b[1] is not None: + lora_b_k = lora_b[1] + self.lora_b_stacked[1][ + index, 0, :lora_b_k.shape[1], :lora_b_k.shape[0]].copy_( + lora_b_k.T, non_blocking=True) + if lora_b[2] is not None: + lora_b_v = lora_b[2] 
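# Worked example (not in this patch) of the shard-id arithmetic behind
# slice_lora_b for the packed QKV layer: with tensor-parallel size 4 and a GQA
# model exposing 2 total KV heads, each KV head is replicated on
# num_kv_head_replicas = 2 ranks, so ranks 0-1 and 2-3 slice the same K/V
# columns while every rank slices its own Q columns. Numbers are illustrative,
# and the replica rule below is an assumption for the example.
tp_size, total_kv_heads, head_size = 4, 2, 128
q_heads_per_rank, kv_heads_per_rank = 8, 1
num_kv_head_replicas = tp_size // total_kv_heads

q_proj_shard_size = q_heads_per_rank * head_size
kv_proj_shard_size = kv_heads_per_rank * head_size
for tp_rank in range(tp_size):
    q_shard_id = tp_rank
    kv_shard_id = tp_rank // num_kv_head_replicas
    q_cols = (q_proj_shard_size * q_shard_id,
              q_proj_shard_size * (q_shard_id + 1))
    kv_cols = (kv_proj_shard_size * kv_shard_id,
               kv_proj_shard_size * (kv_shard_id + 1))
    print(tp_rank, q_cols, kv_cols)  # ranks 0/1 and 2/3 share kv_cols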
+ self.lora_b_stacked[2][ + index, 0, :lora_b_v.shape[1], :lora_b_v.shape[0]].copy_( + lora_b_v.T, non_blocking=True) if lora_a[0] is not None: self.lora_a_stacked[0][ @@ -765,10 +828,9 @@ def set_lora( index, 0, :lora_a[2].shape[1], :lora_a[2].shape[0]].copy_( lora_a[2].T, non_blocking=True) - def apply_weights(self, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - output = self.base_layer.linear_method.apply_weights( - self.base_layer, x, bias) + def apply(self, x: torch.Tensor, + bias: Optional[torch.Tensor]) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x, bias) _apply_lora_packed_nslice( x, self.lora_a_stacked, @@ -780,6 +842,7 @@ def apply_weights(self, x: torch.Tensor, return output @classmethod + @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, model_config: Optional[PretrainedConfig]) -> bool: @@ -801,6 +864,8 @@ def create_lora_weights( max_loras: int, lora_config: LoRAConfig, model_config: Optional[PretrainedConfig] = None) -> None: + self.lora_config = lora_config + self.tp_rank = get_tensor_model_parallel_rank() self.lora_a_stacked = torch.zeros( ( max_loras, @@ -811,11 +876,16 @@ def create_lora_weights( dtype=lora_config.lora_dtype, device=self.device, ) + tp_size = get_tensor_model_parallel_world_size() + lora_b_output_size_per_partition = ( + self.output_size if not lora_config.fully_sharded_loras else + divide(self.output_size, tp_size)) + self.lora_b_stacked = torch.zeros( ( max_loras, 1, - self.output_size, + lora_b_output_size_per_partition, lora_config.max_lora_rank, ), dtype=lora_config.lora_dtype, @@ -829,6 +899,17 @@ def reset_lora(self, index: int): self.lora_a_stacked[index] = 0 self.lora_b_stacked[index] = 0 + def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.input_size + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_a = lora_a[start_idx:end_idx, :] + return lora_a + + def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: + return lora_b + def set_lora( self, index: int, @@ -837,12 +918,10 @@ def set_lora( embeddings_tensor: Optional[torch.Tensor], ): self.reset_lora(index) + if self.base_layer.tp_size > 1: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.input_size - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_a = lora_a[start_idx:end_idx, :] + lora_a = self.slice_lora_a(lora_a) + lora_b = self.slice_lora_b(lora_b) self.lora_a_stacked[index, 0, :lora_a.shape[1], :lora_a.shape[0]].copy_( @@ -862,9 +941,8 @@ def set_mapping( self.indices = base_indices self.indices_len = indices_len - def apply_weights(self, x: torch.Tensor) -> torch.Tensor: - output = self.base_layer.linear_method.apply_weights( - self.base_layer, x) + def apply(self, x: torch.Tensor) -> torch.Tensor: + output = self.base_layer.quant_method.apply(self.base_layer, x) _apply_lora( x, self.lora_a_stacked, @@ -897,7 +975,7 @@ def forward(self, input_): input_parallel = splitted_input[tp_rank].contiguous() # Matrix multiply. 
- output_parallel = self.apply_weights(input_parallel) + output_parallel = self.apply(input_parallel) if self.base_layer.reduce_results and self.base_layer.tp_size > 1: output_ = tensor_model_parallel_all_reduce(output_parallel) else: @@ -919,6 +997,7 @@ def weight(self): self.base_layer, "weight") else self.base_layer.qweight @classmethod + @_not_fully_sharded_can_replace def can_replace_layer(cls, source_layer: nn.Module, lora_config: LoRAConfig, packed_modules_list: List, model_config: Optional[PretrainedConfig]) -> bool: @@ -1100,37 +1179,3 @@ def can_replace_layer(cls, source_layer: nn.Module, model_config: Optional[PretrainedConfig]) -> bool: # Special handling for the LogitsProcessor. return False - - -_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { - cls - for cls in globals().values() if inspect.isclass(cls) - and issubclass(cls, BaseLayerWithLoRA) and cls is not BaseLayerWithLoRA -} - - -def from_layer(layer: nn.Module, - max_loras: int, - lora_config: LoRAConfig, - packed_modules_list: List, - model_config: Optional[PretrainedConfig] = None) -> nn.Module: - for lora_cls in _all_lora_classes: - if lora_cls.can_replace_layer(layer, lora_config, packed_modules_list, - model_config): - ret = lora_cls(layer) - ret.create_lora_weights(max_loras, lora_config, model_config) - return ret - return layer - - -def from_layer_logits_processor( - layer: LogitsProcessor, - lm_head: ParallelLMHead, - max_loras: int, - lora_config: LoRAConfig, - model_config: Optional[PretrainedConfig] = None, -) -> LogitsProcessorWithLoRA: - ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim, - lm_head.weight.dtype, lm_head.weight.device) - ret.create_lora_weights(max_loras, lora_config, model_config) - return ret diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 6a077e9b0c75..50d7e9133e0e 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -11,10 +11,10 @@ from vllm.config import LoRAConfig from vllm.logger import init_logger -from vllm.lora.layers import (BaseLayerWithLoRA, LoRAMapping, from_layer, - from_layer_logits_processor) +from vllm.lora.layers import BaseLayerWithLoRA, LoRAMapping from vllm.lora.lora import LoRALayerWeights, PackedLoRALayerWeights -from vllm.lora.utils import parse_fine_tuned_lora_name, replace_submodule +from vllm.lora.utils import (from_layer, from_layer_logits_processor, + parse_fine_tuned_lora_name, replace_submodule) from vllm.utils import LRUCache, is_pin_memory_available logger = init_logger(__name__) diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index fc74269e5587..c87bed54726f 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -49,6 +49,49 @@ def bgmv( punica_kernels.dispatch_bgmv(y, x, w_t_all, indicies, layer_idx, scale) +def dispatch_bgmv_low_level(y: torch.Tensor, x: torch.Tensor, + w_t_all: torch.Tensor, indicies: torch.LongTensor, + layer_idx: int, scale: float, y_offset: int, + y_slice_size: int): + """ + Same as `bgmv` but you can operate on slices of y. + Pass whole y, define y_offset and y_slice_size. + + Semantics: + y[i] += ( + x[i].unsqueeze(0) + @ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) + * scale + ).squeeze(0) + + Args: + y: Shape: `[B, H2]`. Output vectors. Will be changed in-place. + x: Shape: `[B, H1]`. Input vectors. + w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of + all of the transposed LoRA matrices. + indicies: Shape: `[B]`. Indices of the LoRA weights. + layer_idx: Layer index of LoRA weights. + scale: Scaling factor. 
+ y_offset: Offset to apply to the starting column of y. + y_slice_size: Size of the y column slice. + """ + try: + import vllm._punica_C as punica_kernels + except ImportError as e: + _raise_import_error(e) + punica_kernels.dispatch_bgmv_low_level( + y, + x, + w_t_all, + indicies, + layer_idx, + scale, + x.size(1), + y_slice_size, + y_offset, + ) + + def add_lora(y: torch.Tensor, x: torch.Tensor, wa_t_all: torch.Tensor, diff --git a/vllm/lora/utils.py b/vllm/lora/utils.py index 39e08f0412e4..9942a5fd40de 100644 --- a/vllm/lora/utils.py +++ b/vllm/lora/utils.py @@ -1,11 +1,69 @@ -from typing import Tuple +from typing import List, Optional, Set, Tuple, Type from torch import nn +from transformers import PretrainedConfig +from vllm.config import LoRAConfig from vllm.logger import init_logger +from vllm.lora.fully_sharded_layers import ( + ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA) +# being imported for _all_lora_classes below +# yapf conflicts with isort for this block +# yapf: disable +from vllm.lora.layers import (BaseLayerWithLoRA, ColumnParallelLinearWithLoRA, + LogitsProcessorWithLoRA, + MergedColumnParallelLinearWithLoRA, + MergedQKVParallelLinearWithLora, + QKVParallelLinearWithLora, + RowParallelLinearWithLoRA, + VocabParallelEmbeddingWithLoRA) +# yapf: enable +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead logger = init_logger(__name__) +_all_lora_classes: Set[Type[BaseLayerWithLoRA]] = { + VocabParallelEmbeddingWithLoRA, ColumnParallelLinearWithLoRA, + MergedColumnParallelLinearWithLoRA, QKVParallelLinearWithLora, + MergedQKVParallelLinearWithLora, RowParallelLinearWithLoRA, + LogitsProcessorWithLoRA, ColumnParallelLinearWithShardedLoRA, + MergedColumnParallelLinearWithShardedLoRA, + MergedQKVParallelLinearWithShardedLora, RowParallelLinearWithShardedLoRA +} + + +def from_layer(layer: nn.Module, + max_loras: int, + lora_config: LoRAConfig, + packed_modules_list: List, + model_config: Optional[PretrainedConfig] = None) -> nn.Module: + for lora_cls in _all_lora_classes: + # specifying kwargs so they can be easily accessed in decorator + if lora_cls.can_replace_layer(source_layer=layer, + lora_config=lora_config, + packed_modules_list=packed_modules_list, + model_config=model_config): + ret = lora_cls(layer) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + return layer + + +def from_layer_logits_processor( + layer: LogitsProcessor, + lm_head: ParallelLMHead, + max_loras: int, + lora_config: LoRAConfig, + model_config: Optional[PretrainedConfig] = None, +) -> LogitsProcessorWithLoRA: + ret = LogitsProcessorWithLoRA(layer, lm_head.embedding_dim, + lm_head.weight.dtype, lm_head.weight.device) + ret.create_lora_weights(max_loras, lora_config, model_config) + return ret + def replace_submodule(model: nn.Module, module_name: str, new_module: nn.Module) -> nn.Module: diff --git a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py index 0d74a5f8e81f..d0a5ca5592f9 100644 --- a/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +++ b/vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py @@ -61,6 +61,7 @@ def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict: return schema if isinstance(schema, BaseModel): return 
schema.model_json_schema() + raise AssertionError(f"Unsupported schema type {schema}") @lru_cache diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index aed2c350bdd1..b4f81527141a 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -220,8 +220,9 @@ def moe_align_block_size( def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, - B_scale: torch.Tensor, topk_weights: torch.Tensor, - topk_ids: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, topk_ids: torch.Tensor, sorted_token_ids: torch.Tensor, expert_ids: torch.Tensor, num_tokens_post_padded: torch.Tensor, @@ -232,10 +233,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, assert sorted_token_ids.stride(0) == 1 if not use_fp8: - A_scale = None + assert A_scale is None assert B_scale is None else: - A, A_scale = ops.scaled_fp8_quant(A) + A, A_scale = ops.scaled_fp8_quant(A, A_scale) assert B_scale is not None grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ @@ -318,6 +319,8 @@ def fused_moe( use_fp8: bool = False, w1_scale: Optional[torch.Tensor] = None, w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ This function computes a Mixture of Experts (MoE) layer using two sets of @@ -430,10 +433,13 @@ def fused_moe( sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size( topk_ids, config['BLOCK_SIZE_M'], E) + compute_type = (tl.bfloat16 + if hidden_states.dtype == torch.bfloat16 else tl.float16) invoke_fused_moe_kernel(hidden_states, w1, intermediate_cache1, + a1_scale, w1_scale, topk_weights, topk_ids, @@ -443,7 +449,7 @@ def fused_moe( False, topk_ids.shape[1], config, - compute_type=tl.float16, + compute_type=compute_type, use_fp8=use_fp8) ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) @@ -451,6 +457,7 @@ def fused_moe( invoke_fused_moe_kernel(intermediate_cache2, w2, intermediate_cache3, + a2_scale, w2_scale, topk_weights, topk_ids, @@ -460,7 +467,7 @@ def fused_moe( True, 1, config, - compute_type=tl.float16, + compute_type=compute_type, use_fp8=use_fp8) if inplace: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 6ad7ae0f6319..4d43ed4c5f14 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,9 +1,8 @@ -from abc import ABC, abstractmethod +from abc import abstractmethod from typing import List, Optional import torch import torch.nn.functional as F -from torch import nn from torch.nn.parameter import Parameter from vllm.distributed import (divide, get_tensor_model_parallel_rank, @@ -12,6 +11,8 @@ tensor_model_parallel_all_gather, tensor_model_parallel_all_reduce) from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig, QuantizeMethodBase) from vllm.model_executor.utils import set_weight_attrs logger = init_logger(__name__) @@ -25,7 +26,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset): return shard_size * marlin_tile_size, shard_offset * marlin_tile_size -class LinearMethodBase(ABC): +class LinearMethodBase(QuantizeMethodBase): """Base class for different (maybe quantized) linear methods.""" @abstractmethod @@ -50,22 +51,15 @@ def create_weights(self, layer: 
torch.nn.Module, raise NotImplementedError @abstractmethod - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: """Apply the weights in layer to the input tensor. Expects create_weights to have been called before on the layer.""" raise NotImplementedError - def process_weights_after_loading(self, layer: nn.Module) -> None: - """Process the weight after loading. - - This can be used for example, to transpose weights for computation. - """ - return - class UnquantizedLinearMethod(LinearMethodBase): """Linear method without quantization. @@ -92,10 +86,10 @@ def create_weights(self, layer: torch.nn.Module, layer.register_parameter("weight", weight) set_weight_attrs(weight, extra_weight_attrs) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: weight = layer.weight if self.separate_bias_add: if bias is not None: @@ -104,8 +98,8 @@ def apply_weights(self, return F.linear(x, weight, bias) -class ReplicatedLinear(torch.nn.Module): - """Replicated linear layer. +class LinearBase(torch.nn.Module): + """Base linear layer. Args: input_size: input dimension of the linear layer. @@ -113,17 +107,16 @@ class ReplicatedLinear(torch.nn.Module): bias: If true, add bias. skip_bias_add: If true, skip adding bias but instead return it. params_dtype: Data type for the parameters. - linear_method: (Maybe quantized) linear method. + quant_config: Quantization configure. """ def __init__( self, input_size: int, output_size: int, - bias: bool = True, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -134,12 +127,46 @@ def __init__( if params_dtype is None: params_dtype = torch.get_default_dtype() self.params_dtype = params_dtype - if linear_method is None: - linear_method = UnquantizedLinearMethod() - self.linear_method = linear_method - self.linear_method.create_weights(self, self.input_size, - [self.output_size], self.input_size, - self.output_size, self.params_dtype) + if quant_config is None: + self.quant_method: Optional[ + QuantizeMethodBase] = UnquantizedLinearMethod() + else: + self.quant_method = quant_config.get_quant_method(self) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + +class ReplicatedLinear(LinearBase): + """Replicated linear layer. + + Args: + input_size: input dimension of the linear layer. + output_size: output dimension of the linear layer. + bias: If true, add bias. + skip_bias_add: If true, skip adding bias but instead return it. + params_dtype: Data type for the parameters. + quant_config: Quantization configure. + """ + + def __init__( + self, + input_size: int, + output_size: int, + bias: bool = True, + skip_bias_add: bool = False, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + ): + super().__init__(input_size, output_size, skip_bias_add, params_dtype, + quant_config) + + # All the linear layer supports quant method. 
+ assert self.quant_method is not None + self.quant_method.create_weights(self, self.input_size, + [self.output_size], self.input_size, + self.output_size, self.params_dtype) + if bias: self.bias = Parameter( torch.empty(self.output_size, dtype=self.params_dtype)) @@ -149,12 +176,13 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: bias = self.bias if not self.skip_bias_add else None - output = self.linear_method.apply_weights(self, x, bias) + assert self.quant_method is not None + output = self.quant_method.apply(self, x, bias) output_bias = self.bias if self.skip_bias_add else None return output, output_bias -class ColumnParallelLinear(torch.nn.Module): +class ColumnParallelLinear(LinearBase): """Linear layer with column parallelism. The linear layer is defined as Y = XA + b. A is parallelized along @@ -171,7 +199,7 @@ class ColumnParallelLinear(torch.nn.Module): bias can be fused with other element-wise operations. we skip adding bias but instead return it. params_dtype: Data type for the parameters. - linear_method: (Maybe quantized) linear method. + quant_config: Quantization configure. output_sizes: list of output sizes packed into one output, like for QKV the list would be size 3. """ @@ -184,34 +212,28 @@ def __init__( gather_output: bool = False, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, output_sizes: Optional[List[int]] = None, ): - super().__init__() + super().__init__(input_size, output_size, skip_bias_add, params_dtype, + quant_config) - # Keep input parameters - self.input_size = input_size - self.output_size = output_size self.gather_output = gather_output + # Divide the weight matrix along the last dimension. tp_size = get_tensor_model_parallel_world_size() self.output_size_per_partition = divide(output_size, tp_size) - self.skip_bias_add = skip_bias_add - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype - if linear_method is None: - linear_method = UnquantizedLinearMethod() if output_sizes is None: output_sizes = [output_size] - self.linear_method = linear_method - self.linear_method.create_weights(self, - self.input_size, - [x // tp_size for x in output_sizes], - self.input_size, - self.output_size, - self.params_dtype, - weight_loader=self.weight_loader) + # All the linear layer supports quant method. + assert self.quant_method is not None + self.quant_method.create_weights(self, + self.input_size, + [x // tp_size for x in output_sizes], + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -239,7 +261,8 @@ def forward(self, input_): bias = self.bias if not self.skip_bias_add else None # Matrix multiply. - output_parallel = self.linear_method.apply_weights(self, input_, bias) + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_, bias) if self.gather_output: # All-gather across the partitions. output = tensor_model_parallel_all_gather(output_parallel) @@ -267,7 +290,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): bias can be fused with other element-wise operations. we skip adding bias but instead return it. params_dtype: Data type for the parameters. - linear_method: (Maybe quantized) linear method. + quant_config: Quantization configure. 
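# Illustrative sketch (not in this patch) of how call sites change with this
# refactor: model code that used to pass `linear_method=...` now hands the
# QuantizationConfig itself to the layer, which resolves its quant method via
# get_quant_method(). Sizes are illustrative; constructing the layer requires
# vLLM's tensor-parallel state to be initialized first.
from typing import Optional

from vllm.model_executor.layers.linear import MergedColumnParallelLinear
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


def build_gate_up_proj(hidden_size: int = 4096,
                       intermediate_size: int = 11008,
                       quant_config: Optional[QuantizationConfig] = None):
    return MergedColumnParallelLinear(
        hidden_size,
        [intermediate_size] * 2,    # gate and up projections packed together
        bias=False,
        quant_config=quant_config,  # e.g. Fp8Config() or a GPTQMarlinConfig
    )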
""" def __init__( @@ -278,13 +301,13 @@ def __init__( gather_output: bool = False, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): self.output_sizes = output_sizes tp_size = get_tensor_model_parallel_world_size() assert all(output_size % tp_size == 0 for output_size in output_sizes) super().__init__(input_size, sum(output_sizes), bias, gather_output, - skip_bias_add, params_dtype, linear_method, + skip_bias_add, params_dtype, quant_config, self.output_sizes) def weight_loader(self, @@ -384,7 +407,7 @@ class QKVParallelLinear(ColumnParallelLinear): bias can be fused with other element-wise operations. we skip adding bias but instead return it. params_dtype: Data type for the parameters. - linear_method: (Maybe quantized) linear method. + quant_config: Quantization configure. """ def __init__( @@ -396,7 +419,7 @@ def __init__( bias: bool = True, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): self.hidden_size = hidden_size self.head_size = head_size @@ -424,7 +447,7 @@ def __init__( ] super().__init__(input_size, output_size, bias, False, skip_bias_add, - params_dtype, linear_method, output_sizes) + params_dtype, quant_config, output_sizes) def weight_loader(self, param: Parameter, @@ -517,7 +540,7 @@ def weight_loader(self, param_data.copy_(loaded_weight) -class RowParallelLinear(torch.nn.Module): +class RowParallelLinear(LinearBase): """Linear layer with row parallelism. The linear layer is defined as Y = XA + b. A is parallelized along @@ -540,7 +563,7 @@ class RowParallelLinear(torch.nn.Module): bias can be fused with other element-wise operations. We skip adding bias but instead return it. params_dtype: Data type for the parameters. - linear_method: (Maybe quantized) linear method. + quant_config: Quantization configure. """ def __init__( @@ -552,32 +575,26 @@ def __init__( skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, reduce_results: bool = True, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): - super().__init__() - # Keep input parameters - self.input_size = input_size - self.output_size = output_size + super().__init__(input_size, output_size, skip_bias_add, params_dtype, + quant_config) + self.input_is_parallel = input_is_parallel self.reduce_results = reduce_results - if params_dtype is None: - params_dtype = torch.get_default_dtype() - self.params_dtype = params_dtype # Divide the weight matrix along the last dimension. self.tp_size = get_tensor_model_parallel_world_size() self.input_size_per_partition = divide(input_size, self.tp_size) - self.skip_bias_add = skip_bias_add - if linear_method is None: - linear_method = UnquantizedLinearMethod() - self.linear_method = linear_method - self.linear_method.create_weights(self, - self.input_size_per_partition, - [self.output_size], - self.input_size, - self.output_size, - self.params_dtype, - weight_loader=self.weight_loader) + # All the linear layer supports quant method. 
+ assert self.quant_method is not None + self.quant_method.create_weights(self, + self.input_size_per_partition, + [self.output_size], + self.input_size, + self.output_size, + self.params_dtype, + weight_loader=self.weight_loader) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " @@ -616,8 +633,8 @@ def forward(self, input_): input_parallel = splitted_input[tp_rank].contiguous() # Matrix multiply. - output_parallel = self.linear_method.apply_weights( - self, input_parallel) + assert self.quant_method is not None + output_parallel = self.quant_method.apply(self, input_parallel) if self.reduce_results and self.tp_size > 1: output_ = tensor_model_parallel_all_reduce(output_parallel) else: diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index a525add45849..1c652e347d4a 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,20 +1,23 @@ -from typing import Type +from typing import Dict, Type from vllm.model_executor.layers.quantization.aqlm import AQLMConfig from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.fp8 import FP8Config +from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) from vllm.model_executor.layers.quantization.marlin import MarlinConfig from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig -QUANTIZATION_METHODS = { +QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { "aqlm": AQLMConfig, "awq": AWQConfig, - "fp8": FP8Config, + "fp8": Fp8Config, "gptq": GPTQConfig, "squeezellm": SqueezeLLMConfig, + "gptq_marlin": GPTQMarlinConfig, "marlin": MarlinConfig, } diff --git a/vllm/model_executor/layers/quantization/aqlm.py b/vllm/model_executor/layers/quantization/aqlm.py index b48c6e1702be..83e24fadc140 100644 --- a/vllm/model_executor/layers/quantization/aqlm.py +++ b/vllm/model_executor/layers/quantization/aqlm.py @@ -9,10 +9,10 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs def get_int_dtype(nbits: int) -> torch.dtype: @@ -207,8 +207,11 @@ def from_config(cls, config: Dict[str, Any]) -> "AQLMConfig": return cls(in_group_size, nbits_per_codebook, num_code_books, out_group_size) - def get_linear_method(self) -> "AQLMLinearMethod": - return AQLMLinearMethod(self) + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["AQLMLinearMethod"]: + if isinstance(layer, LinearBase): + return AQLMLinearMethod(self) + return None def get_scaled_act_names(self) -> List[str]: return [] @@ -321,7 +324,7 @@ def create_weights(self, layer: torch.nn.Module, layer.register_parameter("scales", scales) set_weight_attrs(scales, extra_weight_attrs) - def apply_weights( + def apply( self, layer: torch.nn.Module, x: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/awq.py 
b/vllm/model_executor/layers/quantization/awq.py index 4f75134ee188..f4fc7ce020e9 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -4,10 +4,10 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs class AWQConfig(QuantizationConfig): @@ -62,8 +62,11 @@ def from_config(cls, config: Dict[str, Any]) -> "AWQConfig": zero_point = cls.get_from_keys(config, ["zero_point"]) return cls(weight_bits, group_size, zero_point) - def get_linear_method(self) -> "AWQLinearMethod": - return AWQLinearMethod(self) + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["AWQLinearMethod"]: + if isinstance(layer, LinearBase): + return AWQLinearMethod(self) + return None def get_scaled_act_names(self) -> List[str]: return ["gelu", "gelu_fast", "gelu_new", "gelu_pytorch_tanh"] @@ -147,10 +150,10 @@ def create_weights(self, layer: torch.nn.Module, layer.register_parameter("scales", scales) set_weight_attrs(scales, extra_weight_attrs) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: qweight = layer.qweight scales = layer.scales qzeros = layer.qzeros diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 6115e7c3be95..ff5cf0b2bd61 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -1,9 +1,34 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional import torch +from torch import nn -from vllm.model_executor.layers.linear import LinearMethodBase + +class QuantizeMethodBase(ABC): + """Base class for different quantized methods.""" + + @abstractmethod + def create_weights(self, layer: torch.nn.Module, *weight_args, + **extra_weight_attrs): + """Create weights for a layer. + + The weights will be set as attributes of the layer.""" + raise NotImplementedError + + @abstractmethod + def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor: + """Apply the weights in layer to the input tensor. + + Expects create_weights to have been called before on the layer.""" + raise NotImplementedError + + def process_weights_after_loading(self, layer: nn.Module) -> None: + """Process the weight after loading. + + This can be used for example, to transpose weights for computation. + """ + return class QuantizationConfig(ABC): @@ -51,8 +76,16 @@ def get_from_keys(config: Dict[str, Any], keys: List[str]) -> Any: "quantization config.") @abstractmethod - def get_linear_method(self) -> LinearMethodBase: - """Get the linear method to use for the quantized linear layer.""" + def get_quant_method( + self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]: + """Get the quantize method to use for the quantized layer. + + Args: + layer: The layer for the quant method. + Returns: + The quantize method. None if the given layer doesn't support quant + method. 
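# Illustrative sketch (not in this patch) of the contract described above: a
# config returns a quant method only for the layer types its scheme applies to
# and None for everything else. `ToyConfig` is a hypothetical name; a real
# implementation would subclass QuantizationConfig and return its own
# LinearMethodBase subclass rather than UnquantizedLinearMethod.
from typing import Optional

import torch

from vllm.model_executor.layers.linear import (LinearBase,
                                               UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizeMethodBase)


class ToyConfig:

    def get_quant_method(
            self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]:
        if isinstance(layer, LinearBase):
            return UnquantizedLinearMethod()
        return None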
+ """ raise NotImplementedError @abstractmethod diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index 01e494c870e7..ba9f3149649c 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -1,18 +1,25 @@ -from typing import Any, Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional import torch from torch.nn import Module from torch.nn.parameter import Parameter -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.utils import set_weight_attrs -class FP8Config(QuantizationConfig): +class Fp8Config(QuantizationConfig): """Config class for FP8.""" + def __init__( + self, + activation_scheme: str = "dynamic", + ) -> None: + self.activation_scheme = activation_scheme + @classmethod def get_name(cls) -> str: return "fp8" @@ -33,11 +40,15 @@ def get_config_filenames(cls) -> List[str]: return [] @classmethod - def from_config(cls, config: Dict[str, Any]) -> "FP8Config": - return cls() + def from_config(cls, config: Dict[str, Any]) -> "Fp8Config": + activation_scheme = cls.get_from_keys(config, ["activation_scheme"]) + return cls(activation_scheme) - def get_linear_method(self) -> "Fp8LinearMethod": - return Fp8LinearMethod(self) + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["QuantizeMethodBase"]: + if isinstance(layer, LinearBase): + return Fp8LinearMethod(self) + return None def get_scaled_act_names(self) -> List[str]: return [] @@ -57,7 +68,7 @@ class Fp8LinearMethod(LinearMethodBase): quant_config: The quantization config. """ - def __init__(self, quant_config: FP8Config): + def __init__(self, quant_config: Fp8Config): self.quant_config = quant_config def create_weights( @@ -86,24 +97,24 @@ def create_weights( layer.register_parameter("weight_scaling_factor", w_scale) def process_weights_after_loading(self, layer: Module) -> None: - # Although the linear_method is propagated to all layers, + # Although the quant_method is propagated to all layers, # only linear layers invoke "create_weights". So we check # whether "weight_scaling_facor" is registered to determine # whether the layer is a linear layer that requires quantization. if not hasattr(layer, "weight_scaling_factor"): return - qweight, weight_scale = per_tensor_quantize(layer.weight) + qweight, weight_scale = ops.scaled_fp8_quant(layer.weight) # torch._scaled_mm requires column-major in the second # input (weight), so we transpose the quantized weight. layer.weight = Parameter(qweight.t(), requires_grad=False) layer.weight_scaling_factor.data.copy_(weight_scale) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - qinput, x_scale = per_tensor_quantize(x) + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: + qinput, x_scale = ops.scaled_fp8_quant(x) output, _ = torch._scaled_mm( qinput, layer.weight, @@ -113,27 +124,3 @@ def apply_weights(self, bias=bias, ) return output - - -def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]: - """Quantize a tensor using per-tensor static scaling factor. 
- - Args: - tensor: The input tensor. - """ - finfo = torch.finfo(torch.float8_e4m3fn) - # Calculate the scale as dtype max divided by absmax. - # Since .abs() creates a new tensor, we use aminmax to get - # the min and max first and then calculate the absmax. - min_val, max_val = tensor.aminmax() - amax = min_val.abs().max(max_val.abs()) - scale = finfo.max / amax.clamp(min=1e-12) - # scale and clamp the tensor to bring it to - # the representative range of float8 data type - # (as default cast is unsaturated) - qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max) - # Return both float8 data and the inverse scale (as float), - # as both required as inputs to torch._scaled_mm - qweight = qweight.to(torch.float8_e4m3fn) - scale = scale.float().reciprocal() - return qweight, scale diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 92a5cdb9af92..ae9f7019f059 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -7,10 +7,10 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs class GPTQConfig(QuantizationConfig): @@ -63,8 +63,11 @@ def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig": desc_act = cls.get_from_keys(config, ["desc_act"]) return cls(weight_bits, group_size, desc_act) - def get_linear_method(self) -> "GPTQLinearMethod": - return GPTQLinearMethod(self) + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["GPTQLinearMethod"]: + if isinstance(layer, LinearBase): + return GPTQLinearMethod(self) + return None def get_scaled_act_names(self) -> List[str]: return [] @@ -194,10 +197,10 @@ def create_weights( layer.exllama_state = exllama_state - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: qweight = layer.qweight out_shape = x.shape[:-1] + (qweight.shape[-1], ) reshaped_x = x.reshape(-1, x.shape[-1]) diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py new file mode 100644 index 000000000000..efbffa0878c4 --- /dev/null +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -0,0 +1,444 @@ +import enum +from enum import Enum +from typing import Any, Dict, List, Optional + +import numpy +import torch +from torch.nn.parameter import Parameter + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + set_weight_attrs) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) + +GPTQ_MARLIN_TILE = 16 +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_MAX_PARALLEL = 16 + +GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4] +GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] +GPTQ_MARLIN_SUPPORTED_SYM = [True] + + +# Precompute permutations for Marlin weight and scale shuffling +# +# Marlin works on [16,64] tiles. 
The goal of the permutations +# is to reorder the weight data so that it is compatible +# with the tensor-core format that is described here: +# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type # noqa: E501 +# +# As a result of this reordering, the vector loads inside the +# kernel will get the data as it is needed for tensor-core +# (without the need to use ldmatrix instructions) +def _get_perms(): + perm = [] + for i in range(32): + perm1 = [] + col = i // 4 + for block in [0, 1]: + for row in [ + 2 * (i % 4), + 2 * (i % 4) + 1, + 2 * (i % 4 + 4), + 2 * (i % 4 + 4) + 1, + ]: + perm1.append(16 * row + col + 8 * block) + for j in range(4): + perm.extend([p + 256 * j for p in perm1]) + + perm = numpy.array(perm) + interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7]) + perm = perm.reshape((-1, 8))[:, interleave].ravel() # type: ignore + perm = torch.from_numpy(perm) + scale_perm = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return perm, scale_perm, scale_perm_single + + +_perm, _scale_perm, _scale_perm_single = _get_perms() + + +def get_pack_factor(num_bits): + assert num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS, ( + f"Unsupported num_bits = {num_bits}") + return 32 // num_bits + + +def marlin_permute_scales(s, size_k, size_n, group_size): + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm] + else: + s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + + return s + + +class GPTQMarlinConfig(QuantizationConfig): + """Config class for GPTQ Marlin""" + + def __init__(self, weight_bits: int, group_size: int, desc_act: bool, + is_sym: bool) -> None: + if desc_act and group_size == -1: + # In this case, act_order == True is the same as act_order == False + # (since we have only one group per output channel) + desc_act = False + + self.weight_bits = weight_bits + self.group_size = group_size + self.desc_act = desc_act + self.is_sym = is_sym + + # Verify + if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS: + raise ValueError( + f"Marlin does not support weight_bits = {self.weight_bits}. " + f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} " + "are supported.") + if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES: + raise ValueError( + f"Marlin does not support group_size = {self.group_size}. " + f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} " + "are supported.") + if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM: + raise ValueError( + f"Marlin does not support is_sym = {self.is_sym}. 
" + f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.") + + # Init + self.pack_factor = get_pack_factor(weight_bits) + self.tile_size = GPTQ_MARLIN_TILE + self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N + self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K + self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL + + def __repr__(self) -> str: + return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, " + f"group_size={self.group_size}, " + f"desc_act={self.desc_act})") + + @classmethod + def get_name(cls) -> str: + return "gptq_marlin" + + @classmethod + def get_supported_act_dtypes(cls) -> List[torch.dtype]: + return [torch.half] + + @classmethod + def get_min_capability(cls) -> int: + return 80 + + @classmethod + def get_config_filenames(cls) -> List[str]: + return ["quantize_config.json"] + + @classmethod + def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig": + weight_bits = cls.get_from_keys(config, ["bits"]) + group_size = cls.get_from_keys(config, ["group_size"]) + desc_act = cls.get_from_keys(config, ["desc_act"]) + is_sym = cls.get_from_keys(config, ["sym"]) + return cls(weight_bits, group_size, desc_act, is_sym) + + def get_quant_method( + self, + layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]: + if isinstance(layer, LinearBase): + return GPTQMarlinLinearMethod(self) + return None + + def get_scaled_act_names(self) -> List[str]: + return [] + + @classmethod + def is_marlin_compatible(cls, quant_config: Dict[str, Any]): + # Extract data from quant config. + num_bits = quant_config.get("bits", None) + group_size = quant_config.get("group_size", None) + sym = quant_config.get("sym", None) + desc_act = quant_config.get("desc_act", None) + + # If we cannot find the info needed in the config, cannot convert. + if (num_bits is None or group_size is None or sym is None + or desc_act is None): + return False + + # If the capability of the device is too low, cannot convert. + major, minor = torch.cuda.get_device_capability() + device_capability = major * 10 + minor + if device_capability < cls.get_min_capability(): + return False + + # Otherwise, can convert if model satisfies marlin constraints. + return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS + and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES + and sym in GPTQ_MARLIN_SUPPORTED_SYM) + + +class GPTQMarlinState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + +class GPTQMarlinLinearMethod(LinearMethodBase): + """Linear method for GPTQ Marlin. + + Args: + quant_config: The GPTQ Marlin quantization config. 
+ """ + + def __init__(self, quant_config: GPTQMarlinConfig) -> None: + self.quant_config = quant_config + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs, + ) -> None: + del output_size + + # Normalize group_size + if self.quant_config.group_size != -1: + group_size = self.quant_config.group_size + else: + group_size = input_size + + # Validate dtype + if params_dtype != torch.float16: + raise ValueError( + f"The params dtype must be float16, but got {params_dtype}") + + # Validate output_size_per_partition + output_size_per_partition = sum(output_partition_sizes) + if output_size_per_partition % self.quant_config.min_thread_n != 0: + raise ValueError( + f"Weight output_size_per_partition = " + f"{output_size_per_partition} is not divisible by " + f" min_thread_n = {self.quant_config.min_thread_n}.") + + # Validate input_size_per_partition + if input_size_per_partition % self.quant_config.min_thread_k != 0: + raise ValueError( + f"Weight input_size_per_partition = " + f"{input_size_per_partition} is not divisible " + f"by min_thread_k = {self.quant_config.min_thread_k}.") + + if (group_size < input_size + and input_size_per_partition % group_size != 0): + raise ValueError( + f"Weight input_size_per_partition = {input_size_per_partition}" + f" is not divisible by group_size = {group_size}.") + + # Detect sharding of scales/zp + + # By default, no sharding over "input dim" + scales_and_zp_size = input_size // group_size + scales_and_zp_input_dim = None + + if self.quant_config.desc_act: + # Act-order case + assert self.quant_config.group_size != -1 + + is_k_full = input_size_per_partition == input_size + + else: + # No act-order case + + # K is always full due to full alignment with + # group-size and shard of scales/zp + is_k_full = True + + # If this is a row-parallel case, then shard scales/zp + if (input_size != input_size_per_partition + and self.quant_config.group_size != -1): + scales_and_zp_size = input_size_per_partition // group_size + scales_and_zp_input_dim = 0 + + # Init buffers + + # Quantized weights + qweight = Parameter( + torch.empty( + input_size_per_partition // self.quant_config.pack_factor, + output_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs( + qweight, { + **extra_weight_attrs, + "input_dim": 0, + "output_dim": 1, + "packed_dim": 0, + "pack_factor": self.quant_config.pack_factor, + }) + + # Activation order + g_idx = Parameter( + torch.empty( + input_size_per_partition, + dtype=torch.int32, + ), + requires_grad=False, + ) + # Ignore warning from fused linear layers such as QKVParallelLinear. 
+ set_weight_attrs(g_idx, { + **extra_weight_attrs, "input_dim": 0, + "ignore_warning": True + }) + + g_idx_sort_indices = Parameter( + torch.empty( + g_idx.shape, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs(g_idx_sort_indices, extra_weight_attrs) + + # Scales + scales = Parameter( + torch.empty( + scales_and_zp_size, + output_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs( + scales, { + **extra_weight_attrs, + "input_dim": scales_and_zp_input_dim, + "output_dim": 1, + }) + + # Quantized zero-points + qzeros = Parameter( + torch.empty(scales_and_zp_size, + output_size_per_partition // + self.quant_config.pack_factor, + dtype=torch.int32, + device="meta"), + requires_grad=False, + ) + set_weight_attrs( + qzeros, { + **extra_weight_attrs, + "input_dim": scales_and_zp_input_dim, + "output_dim": 1, + "packed_dim": 1, + "pack_factor": self.quant_config.pack_factor, + }) + + # Allocate marlin workspace + max_workspace_size = ( + output_size_per_partition // + self.quant_config.min_thread_n) * self.quant_config.max_parallel + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + requires_grad=False) + + layer.register_parameter("qweight", qweight) + layer.register_parameter("g_idx", g_idx) + layer.register_parameter("g_idx_sort_indices", g_idx_sort_indices) + layer.register_parameter("scales", scales) + layer.register_parameter("qzeros", qzeros) + layer.workspace = workspace + layer.input_size_per_partition = input_size_per_partition + layer.output_size_per_partition = output_size_per_partition + layer.input_size = input_size + layer.is_k_full = is_k_full + layer.marlin_state = GPTQMarlinState.REPACK + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + reshaped_x = x.reshape(-1, x.shape[-1]) + + size_m = reshaped_x.shape[0] + part_size_n = layer.output_size_per_partition + part_size_k = layer.input_size_per_partition + full_size_k = layer.input_size + + out_shape = x.shape[:-1] + (part_size_n, ) + + if layer.marlin_state == GPTQMarlinState.REPACK: + layer.marlin_state = GPTQMarlinState.READY + + # Newly generated tensors need to replace existing tensors that are + # already registered as parameters by vLLM (and won't be freed) + def replace_tensor(name, new_t): + # It is important to use resize_() here since it ensures + # the same buffer is reused + getattr(layer, name).resize_(new_t.shape) + getattr(layer, name).copy_(new_t) + del new_t + + cur_device = layer.qweight.device + + # Process act_order + if self.quant_config.desc_act: + # Get sorting based on g_idx + g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int) + + sorted_g_idx = layer.g_idx[g_idx_sort_indices] + + replace_tensor("g_idx", sorted_g_idx) + replace_tensor("g_idx_sort_indices", g_idx_sort_indices) + + else: + # Reset g_idx related tensors + layer.g_idx = Parameter(torch.empty(0, + dtype=torch.int, + device=cur_device), + requires_grad=False) + layer.g_idx_sort_indices = Parameter(torch.empty( + 0, dtype=torch.int, device=cur_device), + requires_grad=False) + + # Repack weights + marlin_qweight = ops.gptq_marlin_repack( + layer.qweight, + layer.g_idx_sort_indices, + part_size_k, + part_size_n, + ) + replace_tensor("qweight", marlin_qweight) + + # Permute scales + scales_size_k = part_size_k + scales_size_n = part_size_n + if self.quant_config.desc_act: + scales_size_k = full_size_k + + marlin_scales = marlin_permute_scales(layer.scales, scales_size_k, + 
scales_size_n, + self.quant_config.group_size) + replace_tensor("scales", marlin_scales) + + output = ops.gptq_marlin_gemm(reshaped_x, layer.qweight, layer.scales, + layer.g_idx, layer.g_idx_sort_indices, + layer.workspace, size_m, part_size_n, + part_size_k, layer.is_k_full) + + if bias is not None: + output.add_(bias) # In-place add + + return output.reshape(out_shape) diff --git a/vllm/model_executor/layers/quantization/marlin.py b/vllm/model_executor/layers/quantization/marlin.py index 00c3c404c2d7..94aba620ea08 100644 --- a/vllm/model_executor/layers/quantization/marlin.py +++ b/vllm/model_executor/layers/quantization/marlin.py @@ -4,10 +4,10 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) +from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) +from vllm.model_executor.utils import set_weight_attrs class MarlinConfig(QuantizationConfig): @@ -72,8 +72,11 @@ def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig": group_size = cls.get_from_keys(config, ["group_size"]) return cls(group_size) - def get_linear_method(self) -> "MarlinLinearMethod": - return MarlinLinearMethod(self) + def get_quant_method( + self, layer: torch.nn.Module) -> Optional["MarlinLinearMethod"]: + if isinstance(layer, LinearBase): + return MarlinLinearMethod(self) + return None def get_scaled_act_names(self) -> List[str]: return [] @@ -197,7 +200,7 @@ def create_weights( layer.register_parameter("workspace", workspace) set_weight_attrs(workspace, extra_weight_attrs) - def apply_weights( + def apply( self, layer: torch.nn.Module, x: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/squeezellm.py b/vllm/model_executor/layers/quantization/squeezellm.py index cc44447d347b..207dbcee8afc 100644 --- a/vllm/model_executor/layers/quantization/squeezellm.py +++ b/vllm/model_executor/layers/quantization/squeezellm.py @@ -4,10 +4,10 @@ from torch.nn.parameter import Parameter from vllm import _custom_ops as ops -from vllm.model_executor.layers.linear import (LinearMethodBase, - set_weight_attrs) +from vllm.model_executor.layers.linear import LinearBase from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) + QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.utils import set_weight_attrs from vllm.utils import is_hip @@ -51,14 +51,17 @@ def from_config(cls, config: Dict[str, Any]) -> "SqueezeLLMConfig": weight_bits = cls.get_from_keys(config, ["wbits"]) return cls(weight_bits) - def get_linear_method(self) -> "SqueezeLLMLinearMethod": - return SqueezeLLMLinearMethod(self) + def get_quant_method( + self, layer: torch.nn.Module) -> Optional[QuantizeMethodBase]: + if isinstance(layer, LinearBase): + return SqueezeLLMLinearMethod(self) + return None def get_scaled_act_names(self) -> List[str]: return [] -class SqueezeLLMLinearMethod(LinearMethodBase): +class SqueezeLLMLinearMethod(QuantizeMethodBase): """Linear method for SqueezeLLM. 
Args: @@ -112,10 +115,10 @@ def create_weights(self, layer: torch.nn.Module, layer.register_parameter("lookup_table", lookup_table) set_weight_attrs(lookup_table, extra_weight_attrs) - def apply_weights(self, - layer: torch.nn.Module, - x: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None) -> torch.Tensor: qweight = layer.qweight lookup_table = layer.lookup_table out_shape = x.shape[:-1] + (qweight.shape[-1], ) diff --git a/vllm/model_executor/layers/rotary_embedding.py b/vllm/model_executor/layers/rotary_embedding.py index b8361af61ae3..25365a9b50a1 100644 --- a/vllm/model_executor/layers/rotary_embedding.py +++ b/vllm/model_executor/layers/rotary_embedding.py @@ -431,8 +431,8 @@ def forward( torch.full_like(positions, k)).long() idx = (torch.add(positions, long_prompt_offset) if long_prompt_offset is not None else positions) - self.long_short_cos_sin_cache = self.long_short_cos_sin_cache.to( - idx.device) + self.long_short_cos_sin_cache: torch.Tensor = ( + self.long_short_cos_sin_cache.to(idx.device)) idx = torch.add(idx, offsets) if offsets is not None else idx cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx) diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 2ffa8227cc4e..d79c99e5d0a4 100644 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -13,6 +13,9 @@ from vllm.sequence import (Logprob, PromptLogprobs, SampleLogprobs, SamplerOutput, SequenceGroupOutput, SequenceOutput) +# (num_token_ids, num_parent_ids) per sequence group. +SampleResultType = List[Tuple[List[int], List[int]]] + class Sampler(nn.Module): """Samples the next tokens from the model's outputs. @@ -155,7 +158,7 @@ def _apply_min_tokens_penalty( have not been generated yet """ # list of indices in logits that will be set to -inf - logits_to_penalize = [] + logits_to_penalize: List[Tuple[int, int]] = [] logits_applied = 0 for seq_group in sampling_metadata.seq_groups: seq_ids = seq_group.seq_ids @@ -169,19 +172,17 @@ def _apply_min_tokens_penalty( start_idx = sample_indices[0] min_tokens = sampling_params.min_tokens - if min_tokens > 0: + token_ids_to_penalize = sampling_params.all_stop_token_ids + if min_tokens > 0 and token_ids_to_penalize: seqs_to_penalize = [] - for i, seq_id in enumerate(seq_ids): + for j, seq_id in enumerate(seq_ids): seq_data = seq_group.seq_data[seq_id] if len(seq_data.output_token_ids) < min_tokens: - seqs_to_penalize.append(i) + seqs_to_penalize.append(j) if seqs_to_penalize: # convert to the index into logits - seqs_to_penalize = [start_idx + i for i in seqs_to_penalize] - # use set() to remove any duplicates - token_ids_to_penalize = set(sampling_params.stop_token_ids + - [sampling_params.eos_token_id]) + seqs_to_penalize = [start_idx + j for j in seqs_to_penalize] # itertools.product pairs each seq index with every token id logits_to_penalize.extend( itertools.product(seqs_to_penalize, token_ids_to_penalize)) @@ -271,7 +272,7 @@ def _apply_min_p( def _greedy_sample( selected_seq_groups: List[SequenceGroupToSample], samples: torch.Tensor, -) -> List[Tuple[List[int], List[int]]]: +) -> SampleResultType: """Run greedy sampling on a given samples. 
Args: @@ -286,7 +287,7 @@ def _greedy_sample( """ samples = samples.tolist() sample_idx = 0 - results = [] + results: SampleResultType = [] for seq_group in selected_seq_groups: if not seq_group.do_sample: results.append(([], [])) @@ -306,7 +307,7 @@ def _greedy_sample( def _random_sample( selected_seq_groups: List[SequenceGroupToSample], random_samples: torch.Tensor, -) -> List[Tuple[List[int], List[int]]]: +) -> SampleResultType: """Run random sampling on a given samples. Args: @@ -322,7 +323,7 @@ def _random_sample( # Find the maximum best_of value of the prompt phase requests. random_samples = random_samples.cpu() sample_idx = 0 - results = [] + results: SampleResultType = [] for seq_group in selected_seq_groups: if not seq_group.do_sample: results.append(([], [])) @@ -350,7 +351,7 @@ def _random_sample( def _beam_search_sample( selected_seq_groups: List[SequenceGroupToSample], logprobs: torch.Tensor, -) -> List[Tuple[List[int], List[int]]]: +) -> SampleResultType: """Run beam sampling on a given samples. Args: @@ -372,7 +373,7 @@ def _beam_search_sample( # NOTE: Beam search is not vectorized, so its speed can be slower than # other sampling methods. sample_idx = 0 - results = [] + results: SampleResultType = [] for seq_group in selected_seq_groups: if not seq_group.do_sample: results.append(([], [])) @@ -393,16 +394,16 @@ def _beam_search_sample( next_token_ids = next_token_ids.tolist() else: # Generation phase. - cumulative_logprobs = [ + cumulative_logprobs: List[int] = [ seq_group.seq_data[seq_id].cumulative_logprob for seq_id in seq_ids ] - cumulative_logprobs = torch.tensor( + cumulative_logprobs_tensor = torch.tensor( cumulative_logprobs, dtype=torch.float, device=seq_group_logprobs.device) seq_group_logprobs = (seq_group_logprobs + - cumulative_logprobs.unsqueeze(dim=1)) + cumulative_logprobs_tensor.unsqueeze(dim=1)) _, topk_ids = torch.topk(seq_group_logprobs.flatten(), 2 * beam_width) topk_ids = topk_ids.tolist() @@ -454,8 +455,10 @@ def _sample_with_torch( sampling_metadata: SamplingMetadata, include_gpu_probs_tensor: bool, modify_greedy_probs: bool, -) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]: - categorized_seq_group_ids = {t: [] for t in SamplingType} +) -> Tuple[SampleResultType, Optional[torch.Tensor]]: + categorized_seq_group_ids: Dict[SamplingType, + List[int]] = {t: [] + for t in SamplingType} categorized_sample_indices = sampling_metadata.categorized_sample_indices for i, seq_group in enumerate(sampling_metadata.seq_groups): sampling_params = seq_group.sampling_params @@ -557,8 +560,10 @@ def _sample_with_triton_kernel( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors, -) -> List[Tuple[List[int], List[int]]]: - categorized_seq_group_ids = {t: [] for t in SamplingType} +) -> SampleResultType: + categorized_seq_group_ids: Dict[SamplingType, + List[int]] = {t: [] + for t in SamplingType} categorized_sample_indices = sampling_metadata.categorized_sample_indices for i, seq_group in enumerate(sampling_metadata.seq_groups): sampling_params = seq_group.sampling_params @@ -634,7 +639,7 @@ def _sample( probs: torch.Tensor, logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, sampling_tensors: SamplingTensors, include_gpu_probs_tensor: bool, modify_greedy_probs: bool -) -> Tuple[List[Tuple[List[int], List[int]]], Optional[torch.Tensor]]: +) -> Tuple[SampleResultType, Optional[torch.Tensor]]: """ Args: probs: (num_query_tokens_in_batch, num_vocab) @@ -645,7 +650,7 @@ def _sample( Returns: 
(next_token_ids, parent_seq_ids) for each seq group in a batch. If sampling is skipped, it returns ([], []) - sampled_token_ids_tensor: A tensor of sampled token ids. + sampled_token_ids_tensor: A tensor of sampled token ids. """ return _sample_with_torch( probs, @@ -682,7 +687,7 @@ def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor: def _get_logprobs( logprobs: torch.Tensor, sampling_metadata: SamplingMetadata, - sample_results: List[Tuple[List[int], List[int]]], + sample_results: SampleResultType, ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: """Return sample lobprobs and prompt logprobs. @@ -753,8 +758,8 @@ def _get_logprobs( assert len(next_token_ids) == len(query_indices) if len(query_indices) == 0: - empty_sampled_logprob = [] - empty_prompt_logprob = None + empty_sampled_logprob: SampleLogprobs = [] + empty_prompt_logprob: Optional[PromptLogprobs] = None return [empty_prompt_logprob], [empty_sampled_logprob] query_indices_gpu = torch.tensor(query_indices, device=logprobs.device) @@ -967,7 +972,7 @@ def _modify_greedy_probs_inplace(logprobs: torch.Tensor, probs: torch.Tensor, def _build_sampler_output( - sample_results: List[Tuple[List[int], List[int]]], + sample_results: SampleResultType, sampling_metadata: SamplingMetadata, prompt_logprobs: List[Optional[PromptLogprobs]], sample_logprobs: List[SampleLogprobs], @@ -1011,7 +1016,7 @@ def _build_sampler_output( ) -def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[str]: +def _get_next_prompt_tokens(seq_group: SequenceGroupToSample) -> List[int]: """Get a list of next prompt tokens to compute logprob from a given sequence group. diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index f75c35a69d4a..70e64167f869 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -3,9 +3,9 @@ import glob import os from abc import ABC, abstractmethod -from typing import (TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple, - Type) +from typing import Any, Dict, Generator, List, Optional, Tuple, Type +import huggingface_hub import torch from torch import nn @@ -13,6 +13,8 @@ LoadFormat, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig) from vllm.logger import init_logger +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.model_loader.tensorizer import ( TensorizerConfig, is_vllm_serialized_tensorizer, load_with_tensorizer, tensorizer_weights_iterator) @@ -24,9 +26,6 @@ pt_weights_iterator, safetensors_weights_iterator) from vllm.model_executor.models.llava import LlavaForConditionalGeneration -if TYPE_CHECKING: - from vllm.model_executor.layers.linear import LinearMethodBase - _VISION_MODEL_CLASSES = [ LlavaForConditionalGeneration, ] @@ -34,11 +33,10 @@ logger = init_logger(__name__) -def _get_linear_method( +def _get_quantization_config( model_config: ModelConfig, - load_config: LoadConfig) -> Optional["LinearMethodBase"]: - """Get the (maybe quantized) linear method.""" - linear_method = None + load_config: LoadConfig) -> Optional[QuantizationConfig]: + """Get the quantization config.""" if model_config.quantization is not None: quant_config = get_quant_config(model_config, load_config) capability = torch.cuda.get_device_capability() @@ -55,9 +53,8 @@ def _get_linear_method( f"{model_config.dtype} is not supported for quantization " f"method {model_config.quantization}. 
Supported dtypes: " f"{supported_dtypes}") - - linear_method = quant_config.get_linear_method() - return linear_method + return quant_config + return None def _get_model_initialization_kwargs( @@ -85,10 +82,10 @@ def _initialize_model( vision_language_config: Optional[VisionLanguageConfig]) -> nn.Module: """Initialize a model with the given configurations.""" model_class = get_model_architecture(model_config)[0] - linear_method = _get_linear_method(model_config, load_config) + quant_config = _get_quantization_config(model_config, load_config) return model_class(config=model_config.hf_config, - linear_method=linear_method, + quant_config=quant_config, **_get_model_initialization_kwargs( model_class, lora_config, vision_language_config)) @@ -135,7 +132,9 @@ def _maybe_download_from_modelscope( model_path = snapshot_download( model_id=model, cache_dir=self.load_config.download_dir, - revision=revision) + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + revision=revision, + ) else: model_path = model return model_path @@ -229,9 +228,11 @@ def load_model(self, *, model_config: ModelConfig, "fall_back_to_pt_during_load", True)), ) for _, module in model.named_modules(): - linear_method = getattr(module, "linear_method", None) - if linear_method is not None: - linear_method.process_weights_after_loading(module) + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) + # FIXME: Remove this after Mixtral is updated + # to use quant_method. if hasattr(module, "process_weights_after_loading"): module.process_weights_after_loading() return model.eval() @@ -314,11 +315,11 @@ def _load_model_serialized( with set_default_torch_dtype(model_config.dtype): with torch.device(device_config.device): model_class = get_model_architecture(model_config)[0] - linear_method = _get_linear_method(model_config, - self.load_config) + quant_config = _get_quantization_config( + model_config, self.load_config) extra_kwargs = _get_model_initialization_kwargs( model_class, lora_config, vision_language_config) - extra_kwargs["linear_method"] = linear_method + extra_kwargs["quant_config"] = quant_config tensorizer_config = copy.copy(self.tensorizer_config) tensorizer_config.model_class = model_class diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index 7e65d54bc522..2d654b2fefb8 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -13,7 +13,8 @@ from vllm.config import ModelConfig, ParallelConfig from vllm.logger import init_logger -from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -63,7 +64,7 @@ def _construct_tensorizer_args(self) -> "TensorizerArgs": "s3_secret_access_key": self.s3_secret_access_key, "s3_endpoint": self.s3_endpoint, } - return TensorizerArgs(**tensorizer_args) + return TensorizerArgs(**tensorizer_args) # type: ignore def verify_with_parallel_config( self, @@ -251,7 +252,7 @@ class TensorizerAgent: """ def __init__(self, tensorizer_config: TensorizerConfig, - linear_method: LinearMethodBase, **extra_kwargs): + quant_config: QuantizationConfig, **extra_kwargs): if tensorizer_load_fail is not None: raise ImportError( "Tensorizer is not installed. 
Please install tensorizer " @@ -262,19 +263,21 @@ def __init__(self, tensorizer_config: TensorizerConfig, self.tensorizer_args = ( self.tensorizer_config._construct_tensorizer_args()) self.extra_kwargs = extra_kwargs - if extra_kwargs.get("linear_method", None) is not None: - self.linear_method = extra_kwargs["linear_method"] + if extra_kwargs.get("quant_config", None) is not None: + self.quant_config = extra_kwargs["quant_config"] else: - self.linear_method = linear_method + self.quant_config = quant_config self.model = self._init_model() def _init_model(self): + assert self.tensorizer_config.hf_config is not None model_args = self.tensorizer_config.hf_config model_args.torch_dtype = self.tensorizer_config.dtype + assert self.tensorizer_config.model_class is not None with no_init_or_tensor(): return self.tensorizer_config.model_class( config=model_args, - linear_method=self.linear_method, + quant_config=self.quant_config, **self.extra_kwargs) def _resize_lora_embeddings(self): diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index c0905b905131..c1abde9af770 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -127,11 +127,14 @@ def get_quant_config(model_config: ModelConfig, if not is_local: # Download the config files. with get_lock(model_name_or_path, load_config.download_dir): - hf_folder = snapshot_download(model_name_or_path, - revision=model_config.revision, - allow_patterns="*.json", - cache_dir=load_config.download_dir, - tqdm_class=DisabledTqdm) + hf_folder = snapshot_download( + model_name_or_path, + revision=model_config.revision, + allow_patterns="*.json", + cache_dir=load_config.download_dir, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + tqdm_class=DisabledTqdm, + ) else: hf_folder = model_name_or_path @@ -161,12 +164,14 @@ def get_quant_config(model_config: ModelConfig, return quant_cls.from_config(config) -def download_weights_from_hf(model_name_or_path: str, - cache_dir: Optional[str], - allow_patterns: List[str], - revision: Optional[str] = None) -> str: +def download_weights_from_hf( + model_name_or_path: str, + cache_dir: Optional[str], + allow_patterns: List[str], + revision: Optional[str] = None, +) -> str: """Download model weights from Hugging Face Hub. - + Args: model_name_or_path (str): The model name or path. cache_dir (Optional[str]): The cache directory to store the model @@ -179,26 +184,30 @@ def download_weights_from_hf(model_name_or_path: str, Returns: str: The path to the downloaded model weights. 
""" - # Before we download we look at that is available: - fs = HfFileSystem() - file_list = fs.ls(model_name_or_path, detail=False, revision=revision) - - # depending on what is available we download different things - for pattern in allow_patterns: - matching = fnmatch.filter(file_list, pattern) - if len(matching) > 0: - allow_patterns = [pattern] - break + if not huggingface_hub.constants.HF_HUB_OFFLINE: + # Before we download we look at that is available: + fs = HfFileSystem() + file_list = fs.ls(model_name_or_path, detail=False, revision=revision) + + # depending on what is available we download different things + for pattern in allow_patterns: + matching = fnmatch.filter(file_list, pattern) + if len(matching) > 0: + allow_patterns = [pattern] + break logger.info("Using model weights format %s", allow_patterns) # Use file lock to prevent multiple processes from # downloading the same model weights at the same time. with get_lock(model_name_or_path, cache_dir): - hf_folder = snapshot_download(model_name_or_path, - allow_patterns=allow_patterns, - cache_dir=cache_dir, - tqdm_class=DisabledTqdm, - revision=revision) + hf_folder = snapshot_download( + model_name_or_path, + allow_patterns=allow_patterns, + cache_dir=cache_dir, + tqdm_class=DisabledTqdm, + revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, + ) return hf_folder diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 69162b0a92d6..186cee258436 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -31,11 +31,12 @@ get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -77,17 +78,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -110,7 +111,7 @@ def __init__( position_embedding: str, rope_theta: float = 10000, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = hidden_size @@ -132,13 +133,13 @@ def __init__( self.total_num_heads, self.total_num_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) # Create the alibi slopes and slice them. if self.postion_embedding == "ALIBI": @@ -184,7 +185,7 @@ class BaiChuanDecoderLayer(nn.Module): def __init__(self, config: PretrainedConfig, position_embedding: str, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) @@ -196,13 +197,13 @@ def __init__(self, position_embedding=position_embedding, rope_theta=rope_theta, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, ) self.mlp = BaiChuanMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -243,7 +244,7 @@ class BaiChuanModel(nn.Module): def __init__(self, config: PretrainedConfig, position_embedding: str, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -254,7 +255,7 @@ def __init__(self, config.hidden_size, ) self.layers = nn.ModuleList([ - BaiChuanDecoderLayer(config, position_embedding, linear_method) + BaiChuanDecoderLayer(config, position_embedding, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -303,13 +304,13 @@ def __init__( self, config, position_embedding: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.model = BaiChuanModel(config, position_embedding, linear_method) + self.quant_config = quant_config + self.model = BaiChuanModel(config, position_embedding, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() @@ -388,13 +389,13 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): if config.hidden_size == 4096: # baichuan2 7b - super().__init__(config, "ROPE", linear_method, lora_config) + super().__init__(config, "ROPE", quant_config, lora_config) else: # baichuan 13b, baichuan2 13b - super().__init__(config, "ALIBI", linear_method, lora_config) + super().__init__(config, "ALIBI", quant_config, lora_config) class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): @@ -403,7 +404,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, config, - 
linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): - super().__init__(config, "ROPE", linear_method, lora_config) + super().__init__(config, "ROPE", quant_config, lora_config) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 14f325e624f4..1d7e5d2517c7 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -28,10 +28,11 @@ get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -70,7 +71,7 @@ class BloomAttention(nn.Module): def __init__( self, config: BloomConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = config.hidden_size @@ -87,13 +88,13 @@ def __init__( self.head_dim, self.total_num_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.dense = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) # Create the alibi slopes and slice them. @@ -129,21 +130,20 @@ class BloomMLP(nn.Module): def __init__( self, config: BloomConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size self.dense_h_to_4h = ColumnParallelLinear( hidden_size, 4 * hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.gelu_impl = get_act_fn("gelu", quant_config, 4 * hidden_size) self.dense_4h_to_h = RowParallelLinear( 4 * hidden_size, hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -158,17 +158,17 @@ class BloomBlock(nn.Module): def __init__( self, config: BloomConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.self_attention = BloomAttention(config, linear_method) + self.self_attention = BloomAttention(config, quant_config) self.post_attention_layernorm = nn.LayerNorm( hidden_size, eps=config.layer_norm_epsilon) - self.mlp = BloomMLP(config, linear_method) + self.mlp = BloomMLP(config, quant_config) self.apply_residual_connection_post_layernorm = ( config.apply_residual_connection_post_layernorm) @@ -214,7 +214,7 @@ class BloomModel(nn.Module): def __init__( self, config: BloomConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.embed_dim = config.hidden_size @@ -229,7 +229,7 @@ def __init__( # Transformer blocks self.h = nn.ModuleList([ - BloomBlock(config, linear_method) + BloomBlock(config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -262,12 +262,12 @@ class 
BloomForCausalLM(nn.Module): def __init__( self, config: BloomConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = BloomModel(config, linear_method) + self.quant_config = quant_config + self.transformer = BloomModel(config, quant_config) self.lm_head_weight = self.transformer.word_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 3cdb7a7bca1c..e116af2ed080 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -13,11 +13,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -33,7 +34,7 @@ class GLMAttention(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = config.hidden_size @@ -65,13 +66,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=config.add_bias_linear or config.add_qkv_bias, - linear_method=linear_method, + quant_config=quant_config, ) self.dense = RowParallelLinear( self.total_num_heads * self.head_dim, config.hidden_size, bias=config.add_bias_linear, - linear_method=linear_method, + quant_config=quant_config, ) # https://huggingface.co/THUDM/chatglm3-6b-32k/blob/e210410255278dd9d74463cf396ba559c0ef801c/modeling_chatglm.py#L141 @@ -123,7 +124,7 @@ class GLMMLP(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -134,7 +135,7 @@ def __init__( config.hidden_size, [config.ffn_hidden_size] * 2, bias=config.add_bias_linear, - linear_method=linear_method, + quant_config=quant_config, ) self.activation_func = SiluAndMul() @@ -144,7 +145,7 @@ def __init__( config.ffn_hidden_size, config.hidden_size, bias=config.add_bias_linear, - linear_method=linear_method, + quant_config=quant_config, ) def forward(self, hidden_states): @@ -166,7 +167,7 @@ class GLMBlock(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.apply_residual_connection_post_layernorm = ( @@ -180,7 +181,7 @@ def __init__( eps=config.layernorm_epsilon) # Self attention. 
- self.self_attention = GLMAttention(config, linear_method) + self.self_attention = GLMAttention(config, quant_config) self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output @@ -188,7 +189,7 @@ def __init__( config.hidden_size, eps=config.layernorm_epsilon) # MLP - self.mlp = GLMMLP(config, linear_method) + self.mlp = GLMMLP(config, quant_config) def forward( self, @@ -236,7 +237,7 @@ class GLMTransformer(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.post_layer_norm = config.post_layer_norm @@ -246,7 +247,7 @@ def __init__( # Transformer layers. self.layers = nn.ModuleList( - [GLMBlock(config, linear_method) for i in range(self.num_layers)]) + [GLMBlock(config, quant_config) for i in range(self.num_layers)]) if self.post_layer_norm: layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm @@ -281,7 +282,7 @@ class ChatGLMModel(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -291,7 +292,7 @@ def __init__( self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels - self.encoder = GLMTransformer(config, linear_method) + self.encoder = GLMTransformer(config, quant_config) self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size) @@ -333,13 +334,13 @@ class ChatGLMForCausalLM(nn.Module): def __init__( self, config: ChatGLMConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ): super().__init__() self.config: ChatGLMConfig = config - self.linear_method = linear_method - self.transformer = ChatGLMModel(config, linear_method) + self.quant_config = quant_config + self.transformer = ChatGLMModel(config, quant_config) self.lm_head_weight = self.transformer.output_layer.weight self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index d80969773e16..17c2f1223d96 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -32,11 +32,12 @@ from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -91,7 +92,7 @@ class CohereMLP(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -101,13 +102,13 @@ def __init__( self.hidden_size, [self.intermediate_size] * 2, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.down_proj = RowParallelLinear( 
self.intermediate_size, self.hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.act_fn = SiluAndMul() @@ -123,7 +124,7 @@ class CohereAttention(nn.Module): def __init__( self, config: CohereConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() tp_size = get_tensor_model_parallel_world_size() @@ -158,13 +159,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -218,13 +219,13 @@ class CohereDecoderLayer(nn.Module): def __init__(self, config: CohereConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = CohereAttention(config, linear_method=linear_method) + self.self_attn = CohereAttention(config, quant_config=quant_config) - self.mlp = CohereMLP(config, linear_method=linear_method) + self.mlp = CohereMLP(config, quant_config=quant_config) self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), eps=config.layer_norm_eps) @@ -257,7 +258,7 @@ class CohereModel(nn.Module): def __init__( self, config: CohereConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -265,7 +266,7 @@ def __init__( self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - CohereDecoderLayer(config, linear_method=linear_method) + CohereDecoderLayer(config, quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = LayerNorm(param_shape=(config.hidden_size), @@ -298,14 +299,14 @@ class CohereForCausalLM(nn.Module): def __init__( self, config: CohereConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method + self.quant_config = quant_config self.logits_processor = LogitsProcessor(config.vocab_size, scale=config.logit_scale) - self.model = CohereModel(config, linear_method) + self.model = CohereModel(config, quant_config) self.sampler = Sampler() @torch.no_grad() diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 179094b8fd7a..a4a0ae50c645 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -9,11 +9,12 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.fused_moe import fused_moe -from vllm.model_executor.layers.linear import (LinearMethodBase, - QKVParallelLinear, +from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -44,7 +45,7 @@ def __init__( self.num_total_experts, bias=False, params_dtype=params_dtype, - 
linear_method=None, + quant_config=None, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: @@ -63,7 +64,7 @@ class DbrxExperts(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, params_dtype: Optional[torch.dtype] = None, ): super().__init__() @@ -165,7 +166,7 @@ class DbrxAttention(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.d_model = config.d_model @@ -183,13 +184,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.out_proj = RowParallelLinear( self.d_model, self.d_model, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -244,11 +245,11 @@ class DbrxFusedNormAttention(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.d_model = config.d_model - self.attn = DbrxAttention(config, linear_method) + self.attn = DbrxAttention(config, quant_config) self.norm_1 = nn.LayerNorm(self.d_model) self.norm_2 = nn.LayerNorm(self.d_model) @@ -278,11 +279,11 @@ class DbrxBlock(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.norm_attn_norm = DbrxFusedNormAttention(config, linear_method) - self.ffn = DbrxExperts(config, linear_method) + self.norm_attn_norm = DbrxFusedNormAttention(config, quant_config) + self.ffn = DbrxExperts(config, quant_config) def forward( self, @@ -307,7 +308,7 @@ class DbrxModel(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.wte = VocabParallelEmbedding( @@ -315,7 +316,7 @@ def __init__( config.d_model, ) self.blocks = nn.ModuleList( - [DbrxBlock(config, linear_method) for _ in range(config.n_layers)]) + [DbrxBlock(config, quant_config) for _ in range(config.n_layers)]) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) for module in self.modules(): if hasattr(module, "bias") and isinstance(module.bias, @@ -348,13 +349,13 @@ class DbrxForCausalLM(nn.Module): def __init__( self, config: DbrxConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method + self.quant_config = quant_config self.unpadded_vocab_size = config.vocab_size - self.transformer = DbrxModel(config, linear_method) + self.transformer = DbrxModel(config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.d_model, diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py index d476630ee6f1..be9a6b6813f8 100644 --- a/vllm/model_executor/models/decilm.py +++ b/vllm/model_executor/models/decilm.py @@ -29,7 +29,8 @@ from transformers import PretrainedConfig from vllm.config import LoRAConfig -from vllm.model_executor.layers.linear import LinearMethodBase +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from 
vllm.model_executor.models.llama import LlamaForCausalLM @@ -55,13 +56,13 @@ class DeciLMForCausalLM(LlamaForCausalLM): def __init__( self, config: Optional[PretrainedConfig] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: config.num_key_value_heads = max(config.num_key_value_heads_per_layer) delattr(config, "num_key_value_heads_per_layer") super().__init__(config=config, - linear_method=linear_method, + quant_config=quant_config, lora_config=lora_config) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 46101a152ec0..e5f7ba086a35 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -34,12 +34,13 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -56,18 +57,18 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, reduce_results=reduce_results) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" @@ -86,7 +87,7 @@ class DeepseekMoE(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -103,7 +104,7 @@ def __init__( DeepseekMLP(hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, reduce_results=False) for idx in range(self.n_routed_experts) ]) @@ -112,7 +113,7 @@ def __init__( self.gate = ReplicatedLinear(config.hidden_size, self.n_routed_experts, bias=False, - linear_method=None) + quant_config=None) if config.n_shared_experts is not None: intermediate_size = (config.moe_intermediate_size * @@ -121,7 +122,7 @@ def __init__( hidden_size=config.hidden_size, intermediate_size=intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, reduce_results=False, ) @@ -177,7 +178,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -208,14 +209,14 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -251,7 +252,7 @@ def __init__( self, config: PretrainedConfig, layer_idx: int, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -266,18 +267,18 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, ) if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekMoE(config=config, linear_method=linear_method) + self.mlp = DeepseekMoE(config=config, quant_config=quant_config) else: self.mlp = DeepseekMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -320,7 +321,7 @@ class DeepseekModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -331,9 +332,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - DeepseekDecoderLayer(config, - layer_idx, - linear_method=linear_method) + DeepseekDecoderLayer(config, layer_idx, quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -361,12 +360,12 @@ class DeepseekForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = 
linear_method - self.model = DeepseekModel(config, linear_method) + self.quant_config = quant_config + self.model = DeepseekModel(config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 25ce239d1466..08dd69923dc6 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -32,10 +32,11 @@ tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -76,7 +77,7 @@ class FalconAttention(nn.Module): def __init__( self, config: FalconConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -115,7 +116,7 @@ def __init__( self.total_num_kv_heads, bias=config.bias, skip_bias_add=True, - linear_method=linear_method, + quant_config=quant_config, ) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim @@ -129,7 +130,7 @@ def __init__( self.hidden_size, bias=config.bias, skip_bias_add=True, - linear_method=linear_method, + quant_config=quant_config, reduce_results=self.reduce_row_parallel_results) self.use_rotary = config.rotary @@ -192,7 +193,7 @@ class FalconMLP(nn.Module): def __init__( self, config: FalconConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -201,8 +202,7 @@ def __init__( 4 * hidden_size, bias=config.bias, skip_bias_add=True, - linear_method=linear_method) - quant_config = getattr(linear_method, "quant_config", None) + quant_config=quant_config) self.act = get_act_fn("gelu", quant_config, 4 * hidden_size) self.reduce_row_parallel_results = not (config.new_decoder_architecture or config.parallel_attn) @@ -212,7 +212,7 @@ def __init__( bias=config.bias, skip_bias_add=True, reduce_results=self.reduce_row_parallel_results, - linear_method=linear_method) + quant_config=quant_config) def forward(self, x: torch.Tensor) -> torch.Tensor: # NOTE(zhuohan): Following huggingface, we do not fuse bias add here. 
@@ -229,13 +229,13 @@ class FalconDecoderLayer(nn.Module): def __init__( self, config: FalconConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.self_attention = FalconAttention(config, linear_method) - self.mlp = FalconMLP(config, linear_method) + self.self_attention = FalconAttention(config, quant_config) + self.mlp = FalconMLP(config, quant_config) self.config = config if config.new_decoder_architecture: @@ -311,7 +311,7 @@ class FalconModel(nn.Module): def __init__( self, config: FalconConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -327,7 +327,7 @@ def __init__( # Transformer blocks self.h = nn.ModuleList([ - FalconDecoderLayer(config, linear_method) + FalconDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -359,12 +359,12 @@ class FalconForCausalLM(nn.Module): def __init__( self, config: FalconConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = FalconModel(config, linear_method) + self.quant_config = quant_config + self.transformer = FalconModel(config, quant_config) self.lm_head_weight = self.transformer.word_embeddings.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index c3193258d641..bb73ff4d206d 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -27,11 +27,12 @@ from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -77,17 +78,17 @@ def __init__( intermediate_size: int, hidden_act: Optional[str] = None, hidden_activation: Optional[str] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.act_fn = _get_gemma_act_fn(hidden_act, hidden_activation) def forward(self, x): @@ -106,7 +107,7 @@ def __init__(self, head_dim: int, max_position_embeddings: int = 8192, rope_theta: float = 10000, - linear_method: Optional[LinearMethodBase] = None) -> None: + quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -135,13 
+136,13 @@ def __init__(self, self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -176,7 +177,7 @@ class GemmaDecoderLayer(nn.Module): def __init__( self, config: GemmaConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -187,14 +188,14 @@ def __init__( head_dim=config.head_dim, max_position_embeddings=config.max_position_embeddings, rope_theta=config.rope_theta, - linear_method=linear_method, + quant_config=quant_config, ) self.mlp = GemmaMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, hidden_activation=getattr(config, "hidden_activation", None), - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -235,7 +236,7 @@ class GemmaModel(nn.Module): def __init__( self, config: GemmaConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -245,7 +246,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - GemmaDecoderLayer(config, linear_method) + GemmaDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -308,14 +309,14 @@ class GemmaForCausalLM(nn.Module): def __init__( self, config: GemmaConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: del lora_config # Unused. 
super().__init__() self.config = config - self.linear_method = linear_method - self.model = GemmaModel(config, linear_method) + self.quant_config = quant_config + self.model = GemmaModel(config, quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 850050c7232d..75eaebf0dbd1 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -27,10 +27,11 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -44,7 +45,7 @@ class GPT2Attention(nn.Module): def __init__( self, config: GPT2Config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = config.hidden_size @@ -61,13 +62,13 @@ def __init__( self.head_dim, total_num_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.attn = Attention(self.num_heads, self.head_dim, scale=self.scale) @@ -90,7 +91,7 @@ def __init__( self, intermediate_size: int, config: GPT2Config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -98,15 +99,14 @@ def __init__( hidden_size, intermediate_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.act = get_act_fn(config.activation_function, quant_config, intermediate_size) @@ -122,7 +122,7 @@ class GPT2Block(nn.Module): def __init__( self, config: GPT2Config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -130,9 +130,9 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPT2Attention(config, linear_method) + self.attn = GPT2Attention(config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPT2MLP(inner_dim, config, linear_method) + self.mlp = GPT2MLP(inner_dim, config, quant_config) def forward( self, @@ -163,7 +163,7 @@ class GPT2Model(nn.Module): def __init__( self, config: GPT2Config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -174,7 +174,7 @@ def __init__( self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ - GPT2Block(config, linear_method) + GPT2Block(config, quant_config) for _ in 
range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -203,12 +203,12 @@ class GPT2LMHeadModel(nn.Module): def __init__( self, config: GPT2Config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = GPT2Model(config, linear_method) + self.quant_config = quant_config + self.transformer = GPT2Model(config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 8278ba02514d..d057fd928fdb 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -28,10 +28,11 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -45,7 +46,7 @@ class GPTBigCodeAttention(nn.Module): def __init__( self, config: GPTBigCodeConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = config.hidden_size @@ -72,14 +73,14 @@ def __init__( total_num_heads, total_num_kv_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.attn = Attention(self.num_heads, self.head_dim, @@ -111,7 +112,7 @@ def __init__( self, intermediate_size: int, config: GPTBigCodeConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -119,15 +120,14 @@ def __init__( hidden_size, intermediate_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.act = get_act_fn(config.activation_function, quant_config, intermediate_size) @@ -143,7 +143,7 @@ class GPTBigCodeBlock(nn.Module): def __init__( self, config: GPTBigCodeConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -151,9 +151,9 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(config, linear_method) + self.attn = GPTBigCodeAttention(config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = GPTBigMLP(inner_dim, config, linear_method) + self.mlp = GPTBigMLP(inner_dim, config, quant_config) def forward( self, @@ -184,7 +184,7 @@ class GPTBigCodeModel(nn.Module): def __init__( self, config: GPTBigCodeConfig, - 
linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -195,7 +195,7 @@ def __init__( self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.h = nn.ModuleList([ - GPTBigCodeBlock(config, linear_method) + GPTBigCodeBlock(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -224,12 +224,12 @@ class GPTBigCodeForCausalLM(nn.Module): def __init__( self, config: GPTBigCodeConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = GPTBigCodeModel(config, linear_method) + self.quant_config = quant_config + self.transformer = GPTBigCodeModel(config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 7a830d7f9c96..8d7fe8a5beef 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -26,10 +26,11 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -44,7 +45,7 @@ class GPTJAttention(nn.Module): def __init__( self, config: GPTJConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.total_num_heads = config.num_attention_heads @@ -56,13 +57,13 @@ def __init__( self.head_size, self.total_num_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.out_proj = RowParallelLinear( config.hidden_size, config.hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) tp_world_size = get_tensor_model_parallel_world_size() @@ -105,21 +106,20 @@ def __init__( self, intermediate_size: int, config: GPTJConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.n_embd self.fc_in = ColumnParallelLinear( hidden_size, intermediate_size, - linear_method=linear_method, + quant_config=quant_config, ) self.fc_out = RowParallelLinear( intermediate_size, hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.act = get_act_fn(config.activation_function, quant_config, intermediate_size) @@ -135,14 +135,14 @@ class GPTJBlock(nn.Module): def __init__( self, config: GPTJConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() inner_dim = (4 * config.n_embd if config.n_inner is None else config.n_inner) self.ln_1 = nn.LayerNorm(config.n_embd, 
eps=config.layer_norm_epsilon) - self.attn = GPTJAttention(config, linear_method) - self.mlp = GPTJMLP(inner_dim, config, linear_method) + self.attn = GPTJAttention(config, quant_config) + self.mlp = GPTJMLP(inner_dim, config, quant_config) def forward( self, @@ -169,7 +169,7 @@ class GPTJModel(nn.Module): def __init__( self, config: GPTJConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -179,7 +179,7 @@ def __init__( self.embed_dim, ) self.h = nn.ModuleList( - [GPTJBlock(config, linear_method) for _ in range(config.n_layer)]) + [GPTJBlock(config, quant_config) for _ in range(config.n_layer)]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) def forward( @@ -207,13 +207,13 @@ class GPTJForCausalLM(nn.Module): def __init__( self, config: GPTJConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method + self.quant_config = quant_config assert not config.tie_word_embeddings - self.transformer = GPTJModel(config, linear_method) + self.transformer = GPTJModel(config, quant_config) self.lm_head = ParallelLMHead( config.vocab_size, config.n_embd, diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index b946aed92ed3..bab563b9c5a3 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -26,10 +26,11 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -44,7 +45,7 @@ class GPTNeoXAttention(nn.Module): def __init__( self, config: GPTNeoXConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.total_num_heads = config.num_attention_heads @@ -63,13 +64,13 @@ def __init__( self.head_size, self.total_num_heads, bias=self.bias, - linear_method=linear_method, + quant_config=quant_config, ) self.dense = RowParallelLinear( config.hidden_size, config.hidden_size, bias=self.bias, - linear_method=linear_method, + quant_config=quant_config, ) scaling = self.head_size**-0.5 rotary_dim = int(self.head_size * config.rotary_pct) @@ -105,20 +106,19 @@ class GPTNeoXMLP(nn.Module): def __init__( self, config: GPTNeoXConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.dense_h_to_4h = ColumnParallelLinear( config.hidden_size, config.intermediate_size, - linear_method=linear_method, + quant_config=quant_config, ) self.dense_4h_to_h = RowParallelLinear( config.intermediate_size, config.hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.act = get_act_fn(config.hidden_act, quant_config, config.intermediate_size) @@ -134,7 +134,7 @@ class GPTNeoXLayer(nn.Module): def __init__( self, 
config: GPTNeoXConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.use_parallel_residual = config.use_parallel_residual @@ -142,8 +142,8 @@ def __init__( eps=config.layer_norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = GPTNeoXAttention(config, linear_method) - self.mlp = GPTNeoXMLP(config, linear_method) + self.attention = GPTNeoXAttention(config, quant_config) + self.mlp = GPTNeoXMLP(config, quant_config) def forward( self, @@ -182,7 +182,7 @@ class GPTNeoXModel(nn.Module): def __init__( self, config: GPTNeoXConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -192,7 +192,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - GPTNeoXLayer(config, linear_method) + GPTNeoXLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.final_layer_norm = nn.LayerNorm(config.hidden_size, @@ -223,12 +223,12 @@ class GPTNeoXForCausalLM(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.gpt_neox = GPTNeoXModel(config, linear_method) + self.quant_config = quant_config + self.gpt_neox = GPTNeoXModel(config, quant_config) self.embed_out = ParallelLMHead( config.vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index db1da8bdc4fb..5811cae83bf8 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -9,11 +9,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -30,17 +31,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.w2 = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -63,7 +64,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -94,13 +95,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.wo = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -135,7 +136,7 @@ class InternLMDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -150,13 +151,13 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, ) self.feed_forward = InternLM2MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.attention_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -195,7 +196,7 @@ class InternLM2Model(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -206,7 +207,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - InternLMDecoderLayer(config, linear_method) + InternLMDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -238,12 +239,12 @@ class InternLM2ForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = InternLM2Model(config, linear_method) + self.quant_config = quant_config + self.model = InternLM2Model(config, quant_config) self.output = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index e7ee749e824e..bd6a180ec8df 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -29,10 +29,11 @@ from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -68,7 +69,7 @@ class JAISAttention(nn.Module): def __init__( self, config: JAISConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() 
self.hidden_size = config.hidden_size @@ -88,13 +89,13 @@ def __init__( self.head_dim, total_num_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( self.hidden_size, self.hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) tp_rank = get_tensor_model_parallel_rank() @@ -128,7 +129,7 @@ def __init__( self, intermediate_size: int, config: JAISConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -137,19 +138,19 @@ def __init__( hidden_size, intermediate_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_fc2 = (ColumnParallelLinear( hidden_size, intermediate_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) if self.swiglu else None) self.c_proj = RowParallelLinear( intermediate_size, hidden_size, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.act = SwiGLUActivation() @@ -169,7 +170,7 @@ class JAISBlock(nn.Module): def __init__( self, config: JAISConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.hidden_size @@ -177,9 +178,9 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = JAISAttention(config, linear_method) + self.attn = JAISAttention(config, quant_config) self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.mlp = JAISMLP(inner_dim, config, linear_method) + self.mlp = JAISMLP(inner_dim, config, quant_config) def forward( self, @@ -210,7 +211,7 @@ class JAISModel(nn.Module): def __init__( self, config: JAISConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -227,7 +228,7 @@ def __init__( else: self.embeddings_scale = config.mup_embeddings_scale self.h = nn.ModuleList([ - JAISBlock(config, linear_method) + JAISBlock(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) @@ -261,12 +262,12 @@ class JAISLMHeadModel(nn.Module): def __init__( self, config: JAISConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = JAISModel(config, linear_method) + self.quant_config = quant_config + self.transformer = JAISModel(config, quant_config) self.lm_head_weight = self.transformer.wte.weight if hasattr(config, "width_scale"): self.output_logits_scale = config.width_scale diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index c102b40045c9..f6d7fc8733fc 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -33,11 +33,12 @@ get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from 
vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -56,17 +57,17 @@ def __init__( self, hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -89,7 +90,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, bias: bool = False, sliding_window: Optional[int] = None, ) -> None: @@ -131,13 +132,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=bias, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=bias, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -174,7 +175,7 @@ class LlamaDecoderLayer(nn.Module): def __init__( self, config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -199,7 +200,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, bias=attention_bias, sliding_window=sliding_window, ) @@ -207,7 +208,7 @@ def __init__( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -248,7 +249,7 @@ class LlamaModel(nn.Module): def __init__( self, config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() @@ -264,7 +265,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, linear_method) + LlamaDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -329,13 +330,12 @@ class LlamaForCausalLM(nn.Module): def __init__( self, config: LlamaConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = LlamaModel(config, linear_method, lora_config=lora_config) + self.model = LlamaModel(config, quant_config, lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size diff --git 
a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 314a2792bf16..dcde4dfa0795 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -9,8 +9,9 @@ from vllm.attention import AttentionMetadata from vllm.config import VisionLanguageConfig from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.linear import LinearMethodBase from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader @@ -61,7 +62,7 @@ class LlavaForConditionalGeneration(nn.Module): def __init__(self, config: "LlavaConfig", vision_language_config: VisionLanguageConfig, - linear_method: Optional["LinearMethodBase"] = None) -> None: + quant_config: Optional["QuantizationConfig"] = None) -> None: super().__init__() self.config = config @@ -83,8 +84,8 @@ def __init__(self, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act) - self.linear_method = linear_method - self.language_model = LlamaModel(config.text_config, linear_method) + self.quant_config = quant_config + self.language_model = LlamaModel(config.text_config, quant_config) self.unpadded_vocab_size = config.text_config.vocab_size self.lm_head = ParallelLMHead( self.unpadded_vocab_size, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index f0d72fafcaf7..c90bcfbfc470 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -35,12 +35,13 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -84,7 +85,7 @@ def __init__( self.num_total_experts, bias=False, params_dtype=self.params_dtype, - linear_method=None) + quant_config=None) self.ws = nn.Parameter( torch.empty(self.num_total_experts, @@ -147,17 +148,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -180,7 +181,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -211,13 +212,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -258,7 +259,7 @@ class MiniCPMDecoderLayer(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -274,7 +275,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, ) self.num_experts = getattr(self.config, "num_experts", 0) if self.num_experts == 0: @@ -282,7 +283,7 @@ def __init__( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) else: self.mlp = MiniCPMMoE(num_experts=config.num_experts, @@ -329,7 +330,7 @@ class MiniCPMModel(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() @@ -345,7 +346,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - MiniCPMDecoderLayer(config, linear_method) + MiniCPMDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -412,15 +413,15 @@ class MiniCPMForCausalLM(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config self.num_experts = getattr(self.config, "num_experts", 0) - self.linear_method = linear_method + self.quant_config = quant_config self.model = MiniCPMModel(config, - linear_method, + quant_config, lora_config=lora_config) unpadded_vocab_size = config.vocab_size if lora_config: diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index a33b795d7088..c5dd1a63e2f7 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -27,6 +27,7 @@ from torch import nn from transformers import MixtralConfig +from vllm import _custom_ops as ops from vllm.attention import Attention, AttentionMetadata from vllm.config import LoRAConfig from vllm.distributed import (get_tensor_model_parallel_rank, @@ -34,13 +35,13 @@ tensor_model_parallel_all_reduce) from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - QKVParallelLinear, +from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from 
vllm.model_executor.layers.quantization.fp8 import (Fp8LinearMethod, - per_tensor_quantize) +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -69,7 +70,7 @@ def __init__( intermediate_size: int, params_dtype: Optional[torch.dtype] = None, tp_size: Optional[int] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.tp_size = tp_size or get_tensor_model_parallel_world_size() @@ -79,7 +80,7 @@ def __init__( self.intermediate_size = intermediate_size // self.tp_size # FIXME(pcmoritz): Make this more general to support different # quantization schemes - self.use_fp8 = isinstance(linear_method, Fp8LinearMethod) + self.use_fp8 = isinstance(quant_config, Fp8Config) if params_dtype is None: params_dtype = torch.get_default_dtype() @@ -89,7 +90,7 @@ def __init__( self.num_total_experts, bias=False, params_dtype=self.params_dtype, - linear_method=None) + quant_config=None) self.ws = nn.Parameter( torch.empty(self.num_total_experts, @@ -104,6 +105,13 @@ def __init__( device="cuda", dtype=self.params_dtype)) + set_weight_attrs(self.ws, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.w2s, { + "weight_loader": self.weight_loader, + }) + # Scaling factors for FP8 weights self.ws_scale = nn.Parameter( torch.ones( @@ -114,12 +122,23 @@ def __init__( self.num_total_experts, device="cuda", dtype=torch.float32), requires_grad=False) if self.use_fp8 else None - set_weight_attrs(self.ws, { - "weight_loader": self.weight_loader, - }) - set_weight_attrs(self.w2s, { - "weight_loader": self.weight_loader, - }) + # Scaling factors for FP8 activations + need_act_scales = (self.use_fp8 + and quant_config.activation_scheme == "static") + self.as_scale = nn.Parameter( + torch.zeros(1, device="cuda", dtype=torch.float32), + requires_grad=False) if need_act_scales else None + self.a2s_scale = nn.Parameter( + torch.zeros(1, device="cuda", dtype=torch.float32), + requires_grad=False) if need_act_scales else None + + if need_act_scales: + set_weight_attrs(self.as_scale, { + "weight_loader": self.weight_loader, + }) + set_weight_attrs(self.a2s_scale, { + "weight_loader": self.weight_loader, + }) def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, expert_id: int): @@ -134,16 +153,18 @@ def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, shard_size:2 * shard_size, :] = loaded_weight[shard, :] if weight_name.endswith("w2.weight"): param_data[expert_id, :, :] = loaded_weight[:, shard] + if "act_scale" in weight_name: + param_data[:] = param_data[:].max(loaded_weight) def process_weights_after_loading(self): if self.use_fp8: ws = torch.empty_like(self.ws.data, dtype=torch.float8_e4m3fn) w2s = torch.empty_like(self.w2s.data, dtype=torch.float8_e4m3fn) for expert in range(self.num_total_experts): - ws[expert, :, :], self.ws_scale[expert] = per_tensor_quantize( + ws[expert, :, :], self.ws_scale[expert] = ops.scaled_fp8_quant( self.ws.data[expert, :, :]) w2s[expert, :, :], self.w2s_scale[ - expert] = per_tensor_quantize(self.w2s.data[expert, :, :]) + expert] = ops.scaled_fp8_quant(self.w2s.data[expert, :, :]) self.ws = nn.Parameter(ws, requires_grad=False) self.w2s = 
nn.Parameter(w2s, requires_grad=False) @@ -161,7 +182,9 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: inplace=True, use_fp8=self.use_fp8, w1_scale=self.ws_scale, - w2_scale=self.w2s_scale) + w2_scale=self.w2s_scale, + a1_scale=self.as_scale, + a2_scale=self.a2s_scale) if self.tp_size > 1: final_hidden_states = tensor_model_parallel_all_reduce( @@ -178,7 +201,7 @@ def __init__(self, num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, sliding_window: Optional[int] = None) -> None: super().__init__() self.hidden_size = hidden_size @@ -203,12 +226,12 @@ def __init__(self, self.rope_theta = rope_theta self.sliding_window = sliding_window - if isinstance(linear_method, Fp8LinearMethod): + if isinstance(quant_config, Fp8Config): print_warning_once( "For Mixtral FP8 quantization, we currently do not quantize " "the attention layers until their FP8 performance is improved." ) - linear_method = None + quant_config = None self.qkv_proj = QKVParallelLinear( hidden_size, @@ -216,13 +239,13 @@ def __init__(self, self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -259,7 +282,7 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -272,13 +295,13 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, sliding_window=config.sliding_window, - linear_method=linear_method) + quant_config=quant_config) self.block_sparse_moe = MixtralMoE( num_experts=config.num_local_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, - linear_method=linear_method) + quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -318,7 +341,7 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() @@ -334,7 +357,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, linear_method=linear_method) + MixtralDecoderLayer(config, quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -384,14 +407,13 @@ class MixtralForCausalLM(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method self.model = MixtralModel(config, - linear_method, + quant_config, lora_config=lora_config) self.unpadded_vocab_size = config.vocab_size if lora_config: @@ -443,11 +465,19 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] expert_params_mapping = [ + # 
These are the weights for the experts # (param_name, weight_name, expert_id) ("ws" if weight_name in ["w1", "w3"] else "w2s", f"experts.{expert_id}.{weight_name}.weight", expert_id) for expert_id in range(self.config.num_local_experts) for weight_name in ["w1", "w2", "w3"] + ] + [ + # These are the activation scales for the experts + # (param_name, weight_name, expert_id) + ("as_scale" if weight_name in ["w1", "w3"] else "a2s_scale", + f"experts.{expert_id}.{weight_name}.act_scale", expert_id) + for expert_id in range(self.config.num_local_experts) + for weight_name in ["w1", "w2", "w3"] ] params_dict = dict(self.named_parameters()) diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index acd13cc27f15..38c62afced28 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -34,11 +34,12 @@ get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - QKVParallelLinear, +from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -55,7 +56,7 @@ def __init__( num_experts: int, hidden_size: int, intermediate_size: int, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.num_experts = num_experts @@ -65,15 +66,15 @@ def __init__( self.w1 = ReplicatedLinear(self.hidden_dim, self.ffn_dim, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.w2 = ReplicatedLinear(self.ffn_dim, self.hidden_dim, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.w3 = ReplicatedLinear(self.hidden_dim, self.ffn_dim, bias=False, - linear_method=linear_method) + quant_config=quant_config) # TODO: Use vllm's SiluAndMul self.act_fn = nn.SiLU() @@ -92,7 +93,7 @@ class MixtralMoE(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -115,14 +116,14 @@ def __init__( MixtralMLP(self.num_total_experts, config.hidden_size, config.intermediate_size, - linear_method=linear_method) + quant_config=quant_config) if idx in self.expert_indicies else None for idx in range(self.num_total_experts) ]) self.gate = ReplicatedLinear(config.hidden_size, self.num_total_experts, bias=False, - linear_method=None) + quant_config=None) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: num_tokens, hidden_dim = hidden_states.shape @@ -162,7 +163,7 @@ def __init__(self, num_kv_heads: int, max_position: int = 4096 * 32, rope_theta: float = 10000, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, sliding_window: Optional[int] = None) -> None: super().__init__() self.hidden_size = hidden_size @@ -193,13 +194,13 @@ def __init__(self, self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = 
RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -236,7 +237,7 @@ class MixtralDecoderLayer(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -249,9 +250,9 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, sliding_window=config.sliding_window, - linear_method=linear_method) + quant_config=quant_config) self.block_sparse_moe = MixtralMoE(config=config, - linear_method=linear_method) + quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm(config.hidden_size, @@ -291,7 +292,7 @@ class MixtralModel(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -302,7 +303,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, linear_method=linear_method) + MixtralDecoderLayer(config, quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -331,12 +332,12 @@ class MixtralForCausalLM(nn.Module): def __init__( self, config: MixtralConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = MixtralModel(config, linear_method) + self.quant_config = quant_config + self.model = MixtralModel(config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 340f63286739..6fa5c5bd3014 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -11,10 +11,11 @@ get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -42,7 +43,7 @@ class MPTAttention(nn.Module): def __init__( self, config: MPTConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.d_model = config.d_model @@ -65,7 +66,7 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=not config.no_bias, - linear_method=linear_method, + quant_config=quant_config, ) if self.qk_ln: self.q_ln = nn.LayerNorm(self.d_model) @@ -74,7 +75,7 @@ def __init__( self.d_model, self.d_model, bias=not config.no_bias, - linear_method=linear_method, + quant_config=quant_config, ) tp_world_size = get_tensor_model_parallel_world_size() @@ -133,7 +134,7 @@ class MPTMLP(nn.Module): def __init__( self, 
config: MPTConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.d_model @@ -143,15 +144,14 @@ def __init__( hidden_size, intermediate_size, bias=not config.no_bias, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.act = get_act_fn("gelu", quant_config, intermediate_size) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=not config.no_bias, - linear_method=linear_method, + quant_config=quant_config, ) def forward(self, x: torch.Tensor) -> torch.Tensor: @@ -166,14 +166,14 @@ class MPTBlock(nn.Module): def __init__( self, config: MPTConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() hidden_size = config.d_model self.norm_1 = nn.LayerNorm(hidden_size) - self.attn = MPTAttention(config, linear_method) + self.attn = MPTAttention(config, quant_config) self.norm_2 = nn.LayerNorm(hidden_size) - self.ffn = MPTMLP(config, linear_method) + self.ffn = MPTMLP(config, quant_config) def forward( self, @@ -201,7 +201,7 @@ class MPTModel(nn.Module): def __init__( self, config: MPTConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() assert config.embedding_fraction == 1.0 @@ -212,7 +212,7 @@ def __init__( config.d_model, ) self.blocks = nn.ModuleList( - [MPTBlock(config, linear_method) for _ in range(config.n_layers)]) + [MPTBlock(config, quant_config) for _ in range(config.n_layers)]) self.norm_f = nn.LayerNorm(config.d_model) if config.no_bias: for module in self.modules(): @@ -246,14 +246,14 @@ class MPTForCausalLM(nn.Module): def __init__( self, config: MPTConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config assert config.tie_word_embeddings - self.linear_method = linear_method + self.quant_config = quant_config - self.transformer = MPTModel(config, linear_method) + self.transformer = MPTModel(config, quant_config) self.lm_head_weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 15527569b9e2..f212ea2166e1 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -30,11 +30,12 @@ from vllm.attention import Attention, AttentionMetadata from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -54,7 +55,7 @@ class OlmoAttention(nn.Module): def __init__( self, config: OlmoConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -79,7 
+80,7 @@ def __init__( self.head_dim, self.total_num_heads, bias=config.attention_bias, - linear_method=linear_method, + quant_config=quant_config, ) # Rotary embeddings. @@ -99,7 +100,7 @@ def __init__( self.hidden_size, self.hidden_size, bias=config.attention_bias, - linear_method=linear_method, + quant_config=quant_config, ) def forward( @@ -129,7 +130,7 @@ class OlmoMLP(nn.Module): def __init__( self, config: OlmoConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -141,7 +142,7 @@ def __init__( self.hidden_size, [self.intermediate_size] * 2, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) # Activation function. @@ -152,7 +153,7 @@ def __init__( self.intermediate_size, self.hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) def forward( @@ -174,13 +175,13 @@ class OlmoDecoderLayer(nn.Module): def __init__(self, config: OlmoConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() # Attention block. - self.self_attn = OlmoAttention(config, linear_method) + self.self_attn = OlmoAttention(config, quant_config) # MLP block. - self.mlp = OlmoMLP(config, linear_method) + self.mlp = OlmoMLP(config, quant_config) # LayerNorm self.input_layernorm = nn.LayerNorm(config.hidden_size, @@ -216,14 +217,14 @@ class OlmoModel(nn.Module): def __init__(self, config: OlmoConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - OlmoDecoderLayer(config, linear_method) + OlmoDecoderLayer(config, quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, @@ -270,11 +271,10 @@ class OlmoForCausalLM(nn.Module): def __init__(self, config: OlmoConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.linear_method = linear_method - self.model = OlmoModel(config, linear_method) + self.model = OlmoModel(config, quant_config) if config.tie_word_embeddings: self.lm_head_weight = self.model.embed_tokens.weight else: diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 89263166bca8..336f765ababa 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -27,11 +27,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -60,7 +61,7 @@ def __init__( embed_dim: int, num_heads: int, bias: bool = True, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.embed_dim = embed_dim @@ -77,13 +78,13 @@ def __init__( self.head_dim, total_num_heads, bias=bias, 
- linear_method=linear_method, + quant_config=quant_config, ) self.out_proj = RowParallelLinear( embed_dim, embed_dim, bias=bias, - linear_method=linear_method, + quant_config=quant_config, ) self.attn = Attention(self.num_heads, self.head_dim, @@ -107,7 +108,7 @@ class OPTDecoderLayer(nn.Module): def __init__( self, config: OPTConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -116,7 +117,7 @@ def __init__( embed_dim=self.embed_dim, num_heads=config.num_attention_heads, bias=config.enable_bias, - linear_method=linear_method, + quant_config=quant_config, ) self.do_layer_norm_before = config.do_layer_norm_before @@ -127,16 +128,15 @@ def __init__( self.embed_dim, config.ffn_dim, bias=config.enable_bias, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.activation_fn = get_act_fn(config.activation_function, quant_config, config.ffn_dim) self.fc2 = RowParallelLinear( config.ffn_dim, self.embed_dim, bias=config.enable_bias, - linear_method=linear_method, + quant_config=quant_config, ) self.final_layer_norm = nn.LayerNorm( self.embed_dim, @@ -181,7 +181,7 @@ class OPTDecoder(nn.Module): def __init__( self, config: OPTConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -202,7 +202,7 @@ def __init__( self.project_out = ReplicatedLinear(config.hidden_size, config.word_embed_proj_dim, bias=False, - linear_method=linear_method) + quant_config=quant_config) else: self.project_out = None @@ -210,7 +210,7 @@ def __init__( self.project_in = ReplicatedLinear(config.word_embed_proj_dim, config.hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) else: self.project_in = None @@ -226,7 +226,7 @@ def __init__( self.final_layer_norm = None self.layers = nn.ModuleList([ - OPTDecoderLayer(config, linear_method) + OPTDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) @@ -259,10 +259,10 @@ class OPTModel(nn.Module): def __init__( self, config: OPTConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() - self.decoder = OPTDecoder(config, linear_method) + self.decoder = OPTDecoder(config, quant_config) def forward( self, @@ -279,12 +279,12 @@ class OPTForCausalLM(nn.Module): def __init__( self, config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.model = OPTModel(config, linear_method) + self.quant_config = quant_config + self.model = OPTModel(config, quant_config) self.lm_head_weight = self.model.decoder.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index bbb9fa5347cc..9ab5dfb97c19 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -13,11 +13,12 @@ from vllm.attention import Attention, AttentionMetadata from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear 
import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -34,17 +35,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -67,7 +68,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -98,13 +99,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -139,7 +140,7 @@ class OrionDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -154,13 +155,13 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, ) self.mlp = OrionMLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = nn.LayerNorm(config.hidden_size, @@ -201,7 +202,7 @@ class OrionModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -212,7 +213,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - OrionDecoderLayer(config, linear_method) + OrionDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -244,12 +245,12 @@ class OrionForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = OrionModel(config, linear_method) + self.quant_config = quant_config + self.model = OrionModel(config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff 
--git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index f974b78a0fbd..4a45879201af 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -45,10 +45,11 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -62,7 +63,7 @@ class PhiAttention(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.total_num_heads = config.num_attention_heads self.hidden_size = config.hidden_size @@ -80,12 +81,12 @@ def __init__(self, self.head_size, self.total_num_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.dense = RowParallelLinear( self.hidden_size, self.hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) scaling = self.head_size**-0.5 @@ -125,7 +126,7 @@ class PhiMLP(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() n_inner = getattr(config, "n_inner", None) @@ -134,14 +135,13 @@ def __init__(self, self.fc1 = ColumnParallelLinear( config.hidden_size, n_inner, - linear_method=linear_method, + quant_config=quant_config, ) self.fc2 = RowParallelLinear( n_inner, config.hidden_size, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.act = get_act_fn(config.hidden_act, quant_config, n_inner) def forward(self, hidden_states): @@ -155,12 +155,12 @@ class PhiLayer(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self_attn = PhiAttention(config, linear_method) - self.mlp = PhiMLP(config, linear_method) + self.self_attn = PhiAttention(config, quant_config) + self.mlp = PhiMLP(config, quant_config) def forward( self, @@ -186,14 +186,14 @@ class PhiModel(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.linear_method = linear_method + self.quant_config = quant_config self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - PhiLayer(config, linear_method) + PhiLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.final_layernorm = nn.LayerNorm(config.hidden_size, @@ -225,12 +225,12 @@ class PhiForCausalLM(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.linear_method = linear_method + self.quant_config 
= quant_config - self.model = PhiModel(config, linear_method) + self.model = PhiModel(config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index a77da7cb1598..e5e0028888c8 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -14,11 +14,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -35,17 +36,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str = "silu", - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.c_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -67,7 +68,7 @@ def __init__( max_position_embeddings: int, rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.hidden_size = hidden_size @@ -83,13 +84,13 @@ def __init__( self.head_dim, self.total_num_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.scaling = self.head_dim**-0.5 @@ -122,7 +123,7 @@ class QWenBlock(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -134,13 +135,13 @@ def __init__( config.max_position_embeddings, rope_theta=rope_theta, rope_scaling=rope_scaling, - linear_method=linear_method) + quant_config=quant_config) self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.mlp = QWenMLP(config.hidden_size, config.intermediate_size // 2, - linear_method=linear_method) + quant_config=quant_config) def forward( self, @@ -174,7 +175,7 @@ class QWenModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -185,7 +186,7 @@ def __init__( config.hidden_size, ) self.h = nn.ModuleList([ - QWenBlock(config, linear_method) + QWenBlock(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.ln_f 
= RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -217,12 +218,12 @@ class QWenLMHeadModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config - self.linear_method = linear_method - self.transformer = QWenModel(config, linear_method) + self.quant_config = quant_config + self.transformer = QWenModel(config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 71b906e20ac1..62bc7fe22c36 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -33,11 +33,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -54,17 +55,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -86,7 +87,7 @@ def __init__(self, max_position: int = 4096 * 32, rope_theta: float = 10000, use_sliding_window: bool = False, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, sliding_window: Optional[int] = None) -> None: super().__init__() self.hidden_size = hidden_size @@ -117,13 +118,13 @@ def __init__(self, self.total_num_heads, self.total_num_kv_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -159,7 +160,7 @@ def __init__( self, config: Qwen2Config, layer_idx: int, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -174,13 +175,13 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, use_sliding_window=use_sliding_window, - linear_method=linear_method, + quant_config=quant_config, sliding_window=config.sliding_window) self.mlp = Qwen2MLP( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -221,7 +222,7 @@ class Qwen2Model(nn.Module): def __init__( self, config: Qwen2Config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config @@ -233,7 +234,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - Qwen2DecoderLayer(config, layer_idx, linear_method) + Qwen2DecoderLayer(config, layer_idx, quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -286,14 +287,14 @@ class Qwen2ForCausalLM(nn.Module): def __init__( self, config: Qwen2Config, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: del lora_config super().__init__() self.config = config - self.linear_method = linear_method - self.model = Qwen2Model(config, linear_method) + self.quant_config = quant_config + self.model = Qwen2Model(config, quant_config) if config.tie_word_embeddings: self.lm_head_weight = self.model.embed_tokens.weight diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 59908bc9ef26..8da89a2b7ba6 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -36,12 +36,13 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from 
vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -58,18 +59,18 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, reduce_results=reduce_results) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " @@ -88,7 +89,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ): super().__init__() self.config = config @@ -105,7 +106,7 @@ def __init__( Qwen2MoeMLP(hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, reduce_results=False) for idx in range(self.n_routed_experts) ]) @@ -114,13 +115,13 @@ def __init__( self.gate = ReplicatedLinear(config.hidden_size, self.n_routed_experts, bias=False, - linear_method=None) + quant_config=None) if config.shared_expert_intermediate_size > 0: self.shared_expert = Qwen2MoeMLP( hidden_size=config.hidden_size, intermediate_size=config.shared_expert_intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, reduce_results=False, ) else: @@ -186,7 +187,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = hidden_size @@ -217,14 +218,14 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=True, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -260,7 +261,7 @@ def __init__( self, config: PretrainedConfig, layer_idx: int, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -275,18 +276,18 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, ) if (config.num_experts is not None and (layer_idx + 1) % config.decoder_sparse_step == 0): self.mlp = Qwen2MoeSparseMoeBlock(config=config, - linear_method=linear_method) + quant_config=quant_config) else: self.mlp = Qwen2MoeMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -327,7 +328,7 @@ class Qwen2MoeModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.padding_idx = 
config.pad_token_id @@ -338,9 +339,7 @@ def __init__( config.hidden_size, ) self.layers = nn.ModuleList([ - Qwen2MoeDecoderLayer(config, - layer_idx, - linear_method=linear_method) + Qwen2MoeDecoderLayer(config, layer_idx, quant_config=quant_config) for layer_idx in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -370,12 +369,12 @@ class Qwen2MoeForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = Qwen2MoeModel(config, linear_method) + self.quant_config = quant_config + self.model = Qwen2MoeModel(config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 3e6c2db6f3c6..3d4f4f700f86 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -28,11 +28,12 @@ from vllm.attention import Attention, AttentionMetadata from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -46,7 +47,7 @@ class StablelmMLP(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None) -> None: + quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -54,7 +55,7 @@ def __init__(self, self.gate_up_proj = MergedColumnParallelLinear( config.hidden_size, [config.intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(config.intermediate_size, config.hidden_size, bias=False) @@ -71,7 +72,7 @@ class StablelmAttention(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None) -> None: + quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -109,11 +110,11 @@ def __init__(self, self.total_num_heads, self.total_num_key_value_heads, self.qkv_bias, - linear_method=linear_method) + quant_config=quant_config) self.o_proj = RowParallelLinear(self.total_num_heads * self.head_dim, self.hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.rotary_emb = get_rope( self.head_dim, rotary_dim=self.rotary_ndims, @@ -145,11 +146,11 @@ class StablelmDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.self_attn = StablelmAttention(config) - self.mlp = 
StablelmMLP(config, linear_method) + self.mlp = StablelmMLP(config, quant_config) norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=norm_eps) @@ -187,14 +188,14 @@ class StableLMEpochModel(nn.Module): def __init__(self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None) -> None: + quant_config: Optional[QuantizationConfig] = None) -> None: super().__init__() self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, ) self.layers = nn.ModuleList([ - StablelmDecoderLayer(config, linear_method) + StablelmDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) norm_eps = getattr(config, "norm_eps", @@ -226,12 +227,12 @@ class StablelmForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = StableLMEpochModel(config, linear_method) + self.quant_config = quant_config + self.model = StableLMEpochModel(config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index b90f3da141c2..33998e2aad5c 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -28,10 +28,11 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, - LinearMethodBase, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -45,7 +46,7 @@ class Starcoder2Attention(nn.Module): def __init__(self, config: Starcoder2Config, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config @@ -79,13 +80,13 @@ def __init__(self, self.total_num_heads, self.total_num_kv_heads, bias=self.use_bias, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, self.hidden_size, bias=self.use_bias, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( self.head_dim, @@ -121,21 +122,20 @@ class Starcoder2MLP(nn.Module): def __init__(self, config: Starcoder2Config, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.c_fc = ColumnParallelLinear( config.hidden_size, config.intermediate_size, bias=config.use_bias, - linear_method=linear_method, + quant_config=quant_config, ) self.c_proj = RowParallelLinear( config.intermediate_size, config.hidden_size, bias=config.use_bias, - linear_method=linear_method, + quant_config=quant_config, ) - quant_config = getattr(linear_method, "quant_config", None) self.act = get_act_fn(config.hidden_act, quant_config, config.intermediate_size) @@ 
-150,12 +150,11 @@ class Starcoder2DecoderLayer(nn.Module): def __init__(self, config: Starcoder2Config, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.hidden_size = config.hidden_size - self.self_attn = Starcoder2Attention(config, - linear_method=linear_method) - self.mlp = Starcoder2MLP(config, linear_method=linear_method) + self.self_attn = Starcoder2Attention(config, quant_config=quant_config) + self.mlp = Starcoder2MLP(config, quant_config=quant_config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, @@ -192,7 +191,7 @@ class Starcoder2Model(nn.Module): def __init__(self, config: Starcoder2Config, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -202,7 +201,7 @@ def __init__(self, self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.layers = nn.ModuleList([ - Starcoder2DecoderLayer(config, linear_method=linear_method) + Starcoder2DecoderLayer(config, quant_config=quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -227,10 +226,10 @@ class Starcoder2ForCausalLM(nn.Module): def __init__(self, config: Starcoder2Config, - linear_method: Optional[LinearMethodBase] = None): + quant_config: Optional[QuantizationConfig] = None): super().__init__() self.config = config - self.model = Starcoder2Model(config, linear_method=linear_method) + self.model = Starcoder2Model(config, quant_config=quant_config) self.vocab_size = config.vocab_size self.unpadded_vocab_size = config.vocab_size if config.tie_word_embeddings: diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 4e905390c234..0fb2662b2f71 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -31,11 +31,12 @@ from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -52,17 +53,17 @@ def __init__( hidden_size: int, intermediate_size: int, hidden_act: str, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, - linear_method=linear_method) + quant_config=quant_config) self.down_proj = RowParallelLinear(intermediate_size, hidden_size, bias=False, - linear_method=linear_method) + quant_config=quant_config) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. 
" "Only silu is supported for now.") @@ -85,7 +86,7 @@ def __init__( rope_theta: float = 10000, rope_scaling: Optional[Dict[str, Any]] = None, max_position_embeddings: int = 8192, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, bias: bool = False, sliding_window: Optional[int] = None, ) -> None: @@ -112,13 +113,13 @@ def __init__( self.total_num_heads, self.total_num_kv_heads, bias=bias, - linear_method=linear_method, + quant_config=quant_config, ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=bias, - linear_method=linear_method, + quant_config=quant_config, ) self.rotary_emb = get_rope( @@ -154,7 +155,7 @@ class XverseDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -171,7 +172,7 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, max_position_embeddings=max_position_embeddings, - linear_method=linear_method, + quant_config=quant_config, bias=getattr(config, "bias", False), sliding_window=sliding_window, ) @@ -179,7 +180,7 @@ def __init__( hidden_size=self.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, - linear_method=linear_method, + quant_config=quant_config, ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -220,7 +221,7 @@ class XverseModel(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, ) -> None: super().__init__() @@ -236,7 +237,7 @@ def __init__( org_num_embeddings=config.vocab_size, ) self.layers = nn.ModuleList([ - XverseDecoderLayer(config, linear_method) + XverseDecoderLayer(config, quant_config) for _ in range(config.num_hidden_layers) ]) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -294,13 +295,13 @@ class XverseForCausalLM(nn.Module): def __init__( self, config: PretrainedConfig, - linear_method: Optional[LinearMethodBase] = None, + quant_config: Optional[QuantizationConfig] = None, lora_config=None, ) -> None: super().__init__() self.config = config - self.linear_method = linear_method - self.model = XverseModel(config, linear_method) + self.quant_config = quant_config + self.model = XverseModel(config, quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() diff --git a/vllm/sampling_params.py b/vllm/sampling_params.py index dc0e60344d85..0ed6a01a6221 100644 --- a/vllm/sampling_params.py +++ b/vllm/sampling_params.py @@ -185,8 +185,8 @@ def __init__( self.top_k = -1 self.min_p = 0.0 self._verify_greedy_sampling() - # injected by the engine - self.eos_token_id = None + # eos_token_id is added to this by the engine + self.all_stop_token_ids = set(self.stop_token_ids) def _verify_args(self) -> None: if self.n < 1: diff --git a/vllm/sequence.py b/vllm/sequence.py index 567fca570951..0e931ebbb657 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -442,15 +442,27 @@ def prompt_token_ids(self) -> List[int]: def lora_int_id(self) -> int: return self.lora_request.lora_int_id if self.lora_request else 0 - def get_last_latency(self, now: float) -> float: - """Gets last token latency for Request level timings.""" + 
def get_last_latency(self, now: float) -> Optional[float]: + """Returns the last token latency and updates the last token time for Request level timings.""" + # If still in prefill phase, raise Error. + if self.is_prefill(): + raise ValueError( + "seq_group.get_last_latency() should not be called " + "if the seq_group is in prefill phase.") + + # Otherwise return token latency. latency = now - self.metrics.last_token_time self.metrics.last_token_time = now return latency def maybe_set_first_token_time(self, time: float) -> None: """Sets the first token time for Request level timings.""" - if self.metrics.first_token_time is None: + # Note: in a case where a sequence_group is swapped and + # recomputed, the time between iterations is counted + # in TPOT, rather than recalculating TTFT (since, from the + # POV of the user, there is simply a long generation delay). + if (self.metrics.first_token_time is None + and self.get_seqs()[0].get_output_len() == 1): self.metrics.first_token_time = time def maybe_set_first_scheduled_time(self, time: float) -> None: diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 2fcddc3bea5a..fa4693cb7dac 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -1,6 +1,7 @@ import os from typing import Optional, Union +import huggingface_hub from transformers import (AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast) @@ -76,6 +77,7 @@ def get_tokenizer( model_id=tokenizer_name, cache_dir=download_dir, revision=revision, + local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE, # Ignore weights - we only need the tokenizer. ignore_file_pattern=["*.pt", "*.safetensors", "*.bin"]) tokenizer_name = tokenizer_path diff --git a/vllm/utils.py b/vllm/utils.py index 76c2fc66e47c..88447878f170 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -225,11 +225,18 @@ async def producer(i: int, iterator: AsyncIterator[T]): ] async def consumer(): - while not all(finished) or not queue.empty(): - item = await queue.get() - if isinstance(item, Exception): - raise item - yield item + try: + while not all(finished) or not queue.empty(): + item = await queue.get() + if isinstance(item, Exception): + raise item + yield item + except (Exception, asyncio.CancelledError) as e: + for task in _tasks: + # NOTE: Pass the error msg in cancel() + # when only Python 3.9+ is supported. + task.cancel() + raise e await asyncio.gather(*_tasks) return consumer()
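
Note on the final hunk in vllm/utils.py: the merged async iterator previously left producer tasks running if the consumer raised or was cancelled; the patched consumer now cancels every producer before re-raising. Below is a self-contained sketch of the same pattern using only asyncio; merge_streams and its internals are illustrative names, not the vLLM API.

import asyncio
from typing import AsyncIterator

async def merge_streams(*iterators: AsyncIterator) -> AsyncIterator:
    """Merge async iterators through a shared queue (illustrative sketch)."""
    queue: asyncio.Queue = asyncio.Queue()
    finished = [False] * len(iterators)

    async def producer(i: int, it: AsyncIterator) -> None:
        try:
            async for item in it:
                await queue.put(item)
        except Exception as exc:
            # Surface producer errors to the consumer via the queue.
            await queue.put(exc)
        finished[i] = True

    tasks = [asyncio.create_task(producer(i, it))
             for i, it in enumerate(iterators)]
    try:
        while not all(finished) or not queue.empty():
            item = await queue.get()
            if isinstance(item, Exception):
                raise item
            yield item
    except (Exception, asyncio.CancelledError):
        # Mirror the fix above: tear the producers down before re-raising
        # so no task is left running in the background.
        for task in tasks:
            task.cancel()
        raise
    await asyncio.gather(*tasks)

# Usage: async for item in merge_streams(stream_a(), stream_b()): ...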
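
On the Mixtral FP8 changes: process_weights_after_loading now calls the custom op ops.scaled_fp8_quant instead of the former Python helper per_tensor_quantize, and static activation scales (act_scale) are loaded by taking the max across shards. A rough PyTorch-only sketch of per-tensor FP8 quantization follows for reference; the function name and the clamp epsilon are assumptions, and the real work happens in the fused CUDA op.

import torch

def per_tensor_fp8_quantize(weight: torch.Tensor):
    """Quantize a tensor to float8_e4m3fn with one per-tensor scale (sketch)."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Pick the scale so the largest magnitude maps onto the FP8 max value;
    # the clamp avoids division by zero for an all-zero tensor.
    scale = weight.abs().max().clamp(min=1e-12) / finfo.max
    qweight = (weight / scale).clamp(min=finfo.min, max=finfo.max)
    # Dequantization is approximately qweight.float() * scale.
    return qweight.to(torch.float8_e4m3fn), scale

# Mirrors the loop in MixtralMoE.process_weights_after_loading, roughly:
#   ws_fp8, ws_scale[expert] = per_tensor_fp8_quantize(self.ws.data[expert])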
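
More broadly, the model files in this patch all switch from taking a prebuilt LinearMethodBase to taking the model-wide QuantizationConfig and forwarding it to the parallel linear layers (and, where needed, to get_act_fn). A toy attention block showing the new plumbing; the class is hypothetical and assumes vLLM's tensor-parallel state is initialized as usual.

from typing import Optional

from torch import nn

from vllm.model_executor.layers.linear import (QKVParallelLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
    QuantizationConfig)


class ToyAttention(nn.Module):
    """Hypothetical module illustrating the linear_method -> quant_config move."""

    def __init__(self,
                 hidden_size: int,
                 num_heads: int,
                 quant_config: Optional[QuantizationConfig] = None) -> None:
        super().__init__()
        head_dim = hidden_size // num_heads
        # The linear layers now decide internally how to apply the scheme;
        # callers just forward the config (or None to disable quantization,
        # as the MoE gate layers above do).
        self.qkv_proj = QKVParallelLinear(hidden_size,
                                          head_dim,
                                          num_heads,
                                          bias=False,
                                          quant_config=quant_config)
        self.o_proj = RowParallelLinear(num_heads * head_dim,
                                        hidden_size,
                                        bias=False,
                                        quant_config=quant_config)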