diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh
index ccc2f090565e4..5548071390aff 100644
--- a/.buildkite/run-amd-test.sh
+++ b/.buildkite/run-amd-test.sh
@@ -75,6 +75,7 @@ docker run \
         --network host \
         --shm-size=16gb \
         --rm \
+        -e HIP_VISIBLE_DEVICES=0 \
         -e HF_TOKEN \
         -v ${HF_CACHE}:${HF_MOUNT} \
         -e HF_HOME=${HF_MOUNT} \
diff --git a/.buildkite/run-tpu-test.sh b/.buildkite/run-tpu-test.sh
index 4aabd123ae234..335ffd83fcd7a 100644
--- a/.buildkite/run-tpu-test.sh
+++ b/.buildkite/run-tpu-test.sh
@@ -12,5 +12,4 @@ remove_docker_container
 # For HF_TOKEN.
 source /etc/environment
 # Run a simple end-to-end example.
-docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu \
-    python3 /workspace/vllm/examples/offline_inference_tpu.py
+docker run --privileged --net host --shm-size=16G -it -e HF_TOKEN=$HF_TOKEN --name tpu-test vllm-tpu /bin/bash -c "python3 -m pip install git+https://github.com/thuml/depyf.git && python3 /workspace/vllm/tests/tpu/test_compilation.py && python3 /workspace/vllm/examples/offline_inference_tpu.py"
diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml
index e406938647479..9f449ff650b90 100644
--- a/.buildkite/test-pipeline.yaml
+++ b/.buildkite/test-pipeline.yaml
@@ -233,12 +233,13 @@ steps:
   parallelism: 4

 - label: Tensorizer Test # 11min
+  mirror_hardwares: [amd]
   soft_fail: true
   source_file_dependencies:
   - vllm/model_executor/model_loader
   - tests/tensorizer_loader
   commands:
-  - apt-get install -y curl libsodium23
+  - apt-get update && apt-get install -y curl libsodium23
   - export VLLM_WORKER_MULTIPROC_METHOD=spawn
   - pytest -v -s tensorizer_loader

diff --git a/.github/workflows/mypy.yaml b/.github/workflows/mypy.yaml
index 3474bd3861598..ea767f4c3e264 100644
--- a/.github/workflows/mypy.yaml
+++ b/.github/workflows/mypy.yaml
@@ -35,7 +35,6 @@ jobs:
         mypy
         mypy tests --follow-imports skip
         mypy vllm/attention --follow-imports skip
-        mypy vllm/core --follow-imports skip
         mypy vllm/distributed --follow-imports skip
         mypy vllm/engine --follow-imports skip
         mypy vllm/executor --follow-imports skip
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ab91b86426cd4..5b0d0ba904c32 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -296,6 +296,11 @@ set(VLLM_MOE_EXT_SRC
   "csrc/moe/torch_bindings.cpp"
   "csrc/moe/topk_softmax_kernels.cu")

+if(VLLM_GPU_LANG STREQUAL "CUDA")
+  list(APPEND VLLM_MOE_EXT_SRC
+       "csrc/moe/marlin_moe_ops.cu")
+endif()
+
 define_gpu_extension_target(
   _moe_C
   DESTINATION vllm
diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py
index 1ccab2c65e697..eaf256f7cb8c2 100644
--- a/benchmarks/benchmark_throughput.py
+++ b/benchmarks/benchmark_throughput.py
@@ -86,6 +86,7 @@ def run_vllm(
     use_v2_block_manager: bool = False,
     download_dir: Optional[str] = None,
     load_format: str = EngineArgs.load_format,
+    disable_async_output_proc: bool = False,
 ) -> float:
     from vllm import LLM, SamplingParams
     llm = LLM(
@@ -110,6 +111,7 @@ def run_vllm(
         load_format=load_format,
         num_scheduler_steps=num_scheduler_steps,
         use_v2_block_manager=use_v2_block_manager,
+        disable_async_output_proc=disable_async_output_proc,
     )

     # Add the requests to the engine.
@@ -237,7 +239,8 @@ def main(args: argparse.Namespace):
             args.enable_prefix_caching, args.enable_chunked_prefill,
             args.max_num_batched_tokens, args.distributed_executor_backend,
             args.gpu_memory_utilization, args.num_scheduler_steps,
-            args.use_v2_block_manager, args.download_dir, args.load_format)
+            args.use_v2_block_manager, args.download_dir, args.load_format,
+            args.disable_async_output_proc)
     elif args.backend == "hf":
         assert args.tensor_parallel_size == 1
         elapsed_time = run_hf(requests, args.model, tokenizer, args.n,
@@ -418,6 +421,11 @@ def main(args: argparse.Namespace):
         'section for more information.\n'
         '* "bitsandbytes" will load the weights using bitsandbytes '
         'quantization.\n')
+    parser.add_argument(
+        "--disable-async-output-proc",
+        action='store_true',
+        default=False,
+        help="Disable async output processor for vLLM backend.")
     args = parser.parse_args()
     if args.tokenizer is None:
         args.tokenizer = args.model
diff --git a/benchmarks/launch_tgi_server.sh b/benchmarks/launch_tgi_server.sh
index f491c90d0683e..8c5cd454fbbee 100755
--- a/benchmarks/launch_tgi_server.sh
+++ b/benchmarks/launch_tgi_server.sh
@@ -6,7 +6,7 @@ TOKENS=$2

 docker run -e HF_TOKEN=$HF_TOKEN --gpus all --shm-size 1g -p $PORT:80 \
            -v $PWD/data:/data \
-           ghcr.io/huggingface/text-generation-inference:1.4.0 \
+           ghcr.io/huggingface/text-generation-inference:2.2.0 \
            --model-id $MODEL \
            --sharded false \
            --max-input-length 1024 \
diff --git a/csrc/core/scalar_type.hpp b/csrc/core/scalar_type.hpp
index b1e10fecb6b54..0e1f360d74bd5 100644
--- a/csrc/core/scalar_type.hpp
+++ b/csrc/core/scalar_type.hpp
@@ -387,7 +387,8 @@ class ScalarTypeTorch : public torch::CustomClassHolder, public ScalarType {
   // This needs to be implemented and throw a TypeError in order for
   // PyTorch's opcheck to work on ops that use ScalarTypes.
   int64_t len() const {
-    throw c10::TypeError("__len__ not implemented");
+    throw c10::TypeError({__func__, __FILE__, static_cast<uint32_t>(__LINE__)},
+                         "__len__ not implemented");
     return 0;
   }

diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu
new file mode 100644
index 0000000000000..1e170e80d2f70
--- /dev/null
+++ b/csrc/moe/marlin_moe_ops.cu
@@ -0,0 +1,1740 @@
+/*
+ * Modified by Neural Magic
+ * Copyright (C) Marlin.2024 Elias Frantar
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include <torch/all.h>
+
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <cuda.h>
+#include <cuda_fp16.h>
+#include <cuda_runtime.h>
+
+#include <iostream>
+
+template <typename T>
+inline std::string str(T x) {
+  return std::to_string(x);
+}
+
+namespace marlin_moe {
+
+constexpr int ceildiv(int a, int b) { return (a + b - 1) / b; }
+
+#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
+
+// Instances of `Vec` are used to organize groups of >>registers<<, as needed
+// for instance as inputs to tensor core operations. Consequently, all
+// corresponding index accesses must be compile-time constants, which is why we
+// extensively use `#pragma unroll` throughout the kernel code to guarantee
+// this.
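+// For example, a fragment may only be indexed inside a fully unrolled loop so
+// that every index is a compile-time literal and each element stays in a
+// register (use() below is just an illustrative placeholder):
+//   #pragma unroll
+//   for (int i = 0; i < 4; i++) use(frag_a[i]);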
+template <typename T, int n>
+struct Vec {
+  T elems[n];
+  __device__ T& operator[](int i) { return elems[i]; }
+};
+
+using I4 = Vec<int, 4>;
+
+// Matrix fragments for tensor core instructions; their precise layout is
+// documented here:
+// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+using FragA = Vec<half2, 4>;
+using FragB = Vec<half2, 2>;
+using FragC = Vec<float, 4>;
+using FragS = Vec<half2, 1>;  // quantization scales
+
+// Predicated asynchronous global->shared copy; used for inputs A where we apply
+// predication to handle batchsizes that are not multiples of 16.
+__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr,
+                                      bool pred = true) {
+  const int BYTES = 16;
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "{\n"
+      "   .reg .pred p;\n"
+      "   setp.ne.b32 p, %0, 0;\n"
+      "   @p cp.async.cg.shared.global [%1], [%2], %3;\n"
+      "}\n" ::"r"((int)pred),
+      "r"(smem), "l"(glob_ptr), "n"(BYTES));
+}
+
+// Asynchronous global->shared copy
+__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) {
+  const int BYTES = 16;
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile(
+      "{\n"
+      "   cp.async.cg.shared.global [%0], [%1], %2;\n"
+      "}\n" ::"r"(smem),
+      "l"(glob_ptr), "n"(BYTES));
+}
+
+// Async copy fence.
+__device__ inline void cp_async_fence() {
+  asm volatile("cp.async.commit_group;\n" ::);
+}
+
+// Wait until at most `n` async copy stages are still pending.
+template <int n>
+__device__ inline void cp_async_wait() {
+  asm volatile("cp.async.wait_group %0;\n" ::"n"(n));
+}
+
+// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
+// output/accumulation.
+__device__ inline void mma(const FragA& a_frag, const FragB& frag_b,
+                           FragC& frag_c) {
+  const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
+  const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
+  float* c = reinterpret_cast<float*>(&frag_c);
+  asm volatile(
+      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
+      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
+      : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
+      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
+        "f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
+}
+
+// Instruction for loading a full 16x16 matrix fragment of operand A from shared
+// memory, directly in tensor core layout.
+__device__ inline void ldsm4(FragA& frag_a, const void* smem_ptr) {
+  uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
+  uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
+  asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
+               : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
+               : "r"(smem));
+}
+
+// Lookup-table based 3-input logical operation; explicitly used for
+// dequantization as the compiler does not seem to automatically recognize it in
+// all cases.
+template <int lut>
+__device__ inline int lop3(int a, int b, int c) {
+  int res;
+  asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n"
+               : "=r"(res)
+               : "r"(a), "r"(b), "r"(c), "n"(lut));
+  return res;
+}
+
+// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
+// values.
We mostly follow the strategy in the link below, with some small +// changes: +// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h +__device__ inline FragB dequant(int q) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + FragB frag_b; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), + *reinterpret_cast(&SUB)); + frag_b[1] = __hfma2(*reinterpret_cast(&hi), + *reinterpret_cast(&MUL), + *reinterpret_cast(&ADD)); + return frag_b; +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +__device__ inline void scale(FragB& frag_b, FragS& frag_s, int i) { + half2 s = __half2half2(reinterpret_cast<__half*>(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +// Given 2 floats multiply by 2 scales (halves) +__device__ inline void scale_float(float* c, FragS& s) { + __half* s_ptr = reinterpret_cast<__half*>(&s); + c[0] = __fmul_rn(c[0], __half2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], __half2float(s_ptr[1])); +} + +// Same as above, but for act_order (each K is multiplied individually) +__device__ inline void scale4(FragB& frag_b, FragS& frag_s_1, FragS& frag_s_2, + FragS& frag_s_3, FragS& frag_s_4, int i) { + __half2 s_val_1_2; + s_val_1_2.x = reinterpret_cast<__half*>(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast<__half*>(&frag_s_2)[i]; + + __half2 s_val_3_4; + s_val_3_4.x = reinterpret_cast<__half*>(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast<__half*>(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" + : "=r"(state) + : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. + asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" + : + : "l"(lock), "r"(val)); + } +} + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. 
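+// Conceptually (ignoring the int4/half vectorization used in the kernel
+// below), each row handled by a block is rewritten as:
+//   for (int k = 0; k < size_k; ++k)
+//     out[row][k] = a[row][perm[k]];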
+__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + int start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / blockDim.x; + int rest = size_k % blockDim.x; + + int offset = row * row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + offset); + half* out_half = reinterpret_cast(out_int4_ptr + offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += blockDim.x; + } + + if (rest) { + if (threadIdx.x < rest) { + int cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, + int* __restrict__ expert_offsets, + int topk_length, int block_size) { + int expert_id = threadIdx.x; + int num_experts = blockDim.x; + + int occurrences = 0; + for (int i = 0; i < topk_length; ++i) { + occurrences += (topk_ids[i] == expert_id); + } + expert_offsets[expert_id + 1] = occurrences; + __syncthreads(); + + if (threadIdx.x == 0) { + int tot_offset = 0; + expert_offsets[0] = 0; + for (int i = 0; i < num_experts; ++i) { + tot_offset += ceildiv(expert_offsets[i + 1], block_size) * block_size; + expert_offsets[i + 1] = tot_offset; + } + } + __syncthreads(); +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__device__ inline void MarlinMoESingle( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? 
+ bool apply_weights, // apply weights to output + int current_m_block // current m block to start kernel computation from +) { + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > 16 * thread_m_blocks) { + parallel = prob_m / (16 * thread_m_blocks); + prob_m = 16 * thread_m_blocks; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = ceildiv(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * + ceildiv(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = + 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + locks += (slice_col_par / n_tiles) * n_tiles; + slice_col = slice_col_par % n_tiles; + sorted_ids += (slice_col_par / n_tiles) * 16 * thread_m_blocks; + } + + // Compute all information about the current slice which is required for + // synchronization. + auto init_slice = [&]() { + slice_iters = + iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * ceildiv(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = ceildiv(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (slice_col == n_tiles) { + sorted_ids += 16 * thread_m_blocks; + locks += n_tiles; + slice_col = 0; + } + }; + init_slice(); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = prob_k / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks); + // number of 
shared write iterations for a tile + constexpr int a_sh_wr_iters = ceildiv(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / 32; + constexpr int b_sh_stride = 32 * thread_n_blocks / 4; + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride); + constexpr int b_sh_wr_delta = threads; + constexpr int b_sh_rd_delta = threads; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = !has_act_order && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + constexpr int sorted_sh_stride = threads; + constexpr int sorted_gl_stride = threads; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. + int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = + a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16; + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = + b_gl_stride * (threadIdx.x / b_sh_stride) + (threadIdx.x % b_sh_stride); + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + int b_sh_wr = threadIdx.x; + int b_sh_rd = threadIdx.x; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (group_blocks == -1 || group_blocks == 0) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + + s_sh_stride * slice_col + threadIdx.x; + } + int s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. + int s_sh_rd; + if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) / 4; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + (threadIdx.x % 32) % 4; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + constexpr int sh_max_num_groups = 32; + + int shs_size; + if constexpr (has_act_order) + shs_size = sh_max_num_groups * s_sh_stride + threads; + else + shs_size = group_blocks > 0 ? 
stages * s_sh_stage : threads; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + int4* sh_a = sh; + int4* sh_b = sh_a + (stages * a_sh_stage); + int4* sh_g_idx = sh_b + (stages * b_sh_stage); + int4* sh_s = sh_g_idx + (stages * g_idx_stage); + int* sh_sorted = (int*)(sh_s + shs_size); + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_sh_wr_delta * i + a_sh_wr; + int row = a_idx / a_gl_rd_delta_o; + if (row >= prob_m) { + a_sh_wr_pred[i] = false; + } else { + a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + } + } + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row; + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) + a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + #pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = + transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + + // Zero accumulators. 
+ auto zero_accums = [&]() { + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) + reinterpret_cast(frag_c)[i] = 0; + }; + + auto fetch_scales_to_shared = [&](bool is_async, int first_group_id, + int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups < sh_max_num_groups) { + sh_num_groups = sh_max_num_groups; + } + + if (sh_first_group_id + sh_num_groups > num_groups) { + sh_num_groups = num_groups - sh_first_group_id; + } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = + scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + + threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + int a_idx = a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off; + int row = a_idx / a_gl_stride; + int sorted_row = + replicate_input ? sorted_ids[row] / topk : sorted_ids[row]; + int new_idx = sorted_row * a_gl_stride + a_idx % a_gl_stride; + if (sorted_row < tot_m * (replicate_input ? 1 : topk) && + new_idx < a_gl_stride * tot_m * (replicate_input ? 1 : topk)) { + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[new_idx], + a_sh_wr_pred[i]); + } + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { + cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr], B_ptr[i]); + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = + reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { + cp_async4_pred(&sh_g_idx_stage[threadIdx.x], + &cur_g_idx_stage_ptr[threadIdx.x]); + } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { + cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], + &scales_ptr[s_gl_rd]); + } + s_gl_rd += s_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. 
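+    // (cp_async_fence() commits the cp.async operations issued above as one
+    // group; wait_for_stage() later blocks in cp_async_wait<stages - 2>()
+    // until enough of these groups have drained.)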
+ cp_async_fence(); + }; + + // TODO we are currently hitting illegal memory accesses when fetching + // sorted_ids to shared data: fix this + auto fetch_sorted_ids_to_shared = [&]() { + const int mpt = ceildiv(prob_m, threads); + for (int i = 0; i < mpt; i++) { + if ((i * sorted_gl_stride) + threadIdx.x < prob_m) { + sh_sorted[(i * sorted_sh_stride) + threadIdx.x] = + sorted_ids[(i * sorted_gl_stride) + threadIdx.x]; + } + } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. + auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm4(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + frag_b_quant[k % 2] = *reinterpret_cast( + &sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd]); + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * + (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + int warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / group_blocks; + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + reinterpret_cast(&frag_s[k % 2])[0] = + sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { + return; + } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + int warp_id = threadIdx.x / 32; + int n_warps = + thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + int th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + 
if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, + 9}; // Tensor core offsets per thread + + #pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = + sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + // Execute the actual tensor core matmul of a sub-tile. + auto matmul = [&](int k) { + // We have the m dimension as the inner loop in order to encourage overlapping + // dequantization and matmul operations. + #pragma unroll + for (int j = 0; j < 4; j++) { + int b_quant = frag_b_quant[k % 2][j]; + int b_quant_shift = b_quant >> 8; + + FragB frag_b0 = dequant(b_quant); + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + scale4(frag_b0, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 0); + } else { + if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k % 2][j], 0); + } + } + + FragB frag_b1 = dequant(b_quant_shift); + + // Apply scale to frag_b1 + if constexpr (has_act_order) { + scale4(frag_b1, act_frag_s[k % 2][0][j], act_frag_s[k % 2][1][j], + act_frag_s[k % 2][2][j], act_frag_s[k % 2][3][j], 1); + + } else { + if constexpr (group_blocks != -1) { + scale(frag_b1, frag_s[k % 2][j], 1); + } + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + mma(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]); + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride / 2; + if (red_off >= 1) { + int red_idx = threadIdx.x / b_sh_stride; + constexpr int red_sh_stride = b_sh_stride * 4 * 2; + constexpr int red_sh_delta = b_sh_stride; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride) + + (threadIdx.x % b_sh_stride); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. 
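+      // (In effect this is a tree reduction over the red_off groups of warps
+      // that share an output tile: each pass halves the stride i until all
+      // partial sums have been folded into the warps with red_idx == 0.)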
+ + #pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { + #pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { + #pragma unroll + for (int j = 0; j < 4 * 2; j++) { + int red_sh_wr = + red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh[red_sh_wr]); + #pragma unroll + for (int k = 0; k < 4; k++) + reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += + c_rd[k] + c_wr[k]; + } + sh[red_sh_wr] = + reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { + #pragma unroll + for (int i = 0; i < 4 * 2; i++) { + float* c_rd = + reinterpret_cast(&sh[red_sh_delta * i + red_sh_rd]); + #pragma unroll + for (int j = 0; j < 4; j++) + reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += + c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + constexpr int c_sh_wr_delta = active_threads; + int c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { + // Interestingly, doing direct global accesses here really seems to mess up + // the compiler and lead to slowdowns, hence we also use async-copies even + // though these fetches are not actually asynchronous. 
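+        // (Flow of the loop below: cp.async the existing C tile into shared
+        // memory, wait for it, accumulate it into frag_c in fp32, and, unless
+        // this block is the last contributor to the slice, write the running
+        // sum back to C as fp16.)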
+ #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int sorted_row = sorted_ids[c_idx / c_gl_stride]; + int new_idx = sorted_row * c_gl_stride + c_idx % c_gl_stride; + cp_async4_pred(&sh[c_sh_wr + c_sh_wr_delta * i], &C[new_idx], + sorted_row < tot_m * topk && + (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk))); + } + cp_async_fence(); + cp_async_wait<0>(); + } + + #pragma unroll + for (int i = 0; i < thread_m_blocks * 4; i++) { + if (8 * (i / 2) + row < prob_m && + (i < (thread_m_blocks - 1) * 4 || + sorted_ids[8 * (i / 2) + row] < tot_m * topk)) { + if (!first) { + int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta]; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] += + __half2float(reinterpret_cast<__half*>(&c_red)[j]); + } + } + if (!last) { + int4 c; + #pragma unroll + for (int j = 0; j < 2 * 4; j++) { + reinterpret_cast<__half*>(&c)[j] = + __float2half(reinterpret_cast( + &frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]); + } + int c_idx = + c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2); + int row = sorted_ids[c_idx / c_gl_stride]; + if (row < tot_m * topk) { + int new_idx = row * c_gl_stride + c_idx % c_gl_stride; + C[new_idx] = c; + } + } + } + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. + auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = + c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr = + (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + half2 res = __halves2half2(__float2half(c0), __float2half(c1)); + + // For per-column quantization we finally apply the scale here + if constexpr (!has_act_order && group_blocks == -1) { + res = __hmul2(res, s[0]); + } + + ((half2*)sh)[idx] = res; + }; + if (threadIdx.x / 32 < thread_n_blocks / 4) { + #pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { + #pragma unroll + for (int j = 0; j < 4; j++) { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], + frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], + frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], + frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], + frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + + #pragma unroll + for (int i = 0; + i < ceildiv(16 * thread_m_blocks, threads / (2 * 
thread_n_blocks)); + i++) { + if (c_gl_wr < c_gl_wr_end) { + int row = sorted_ids[c_gl_wr / c_gl_stride]; + if (row < tot_m * topk) { + int off = row * c_gl_stride + c_gl_wr % c_gl_stride; + if (!apply_weights) { + C[off] = sh[c_sh_rd]; + } else { + __half* ctrg = reinterpret_cast<__half*>(&C[off]); + __half* csrc = reinterpret_cast<__half*>(&sh[c_sh_rd]); + for (int j = 0; j < 8; ++j) { + ctrg[j] = __float2half(topk_weights[row] * __half2float(csrc[j])); + } + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + } + }; + + // Start global fetch and register load pipelines. + auto start_pipes = [&]() { + // TODO re-enable after fixing this function + // fetch_sorted_ids_to_shared(); + __syncthreads(); + + #pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + slice_k_start_shared_fetch += tb_k * (stages - 1); + }; + if (slice_iters) { + start_pipes(); + } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. + #pragma unroll + for (int pipe = 0; pipe < stages;) { + #pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, + slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { + break; + } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + slice_k_start += tb_k * stages; + slice_k_start_shared_fetch += tb_k * stages; + + if constexpr (has_act_order) { + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { + last_g_idx = prob_k - 1; + } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
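+    // (Epilogue per slice: once slice_iters reaches 0, do the intra-block
+    // reduction, optionally the serialized global reduction across blocks of
+    // the same slice, let the last block write the result, then advance to
+    // the next column slice via init_slice()/start_pipes().)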
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1) { + if (last) { + if (s_sh_wr_pred) { + cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); + } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1) { + if (last) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } + } + if (slice_count > 1) { // only globally reduce if there is more than one + // block in a slice + barrier_acquire(&locks[slice_col], slice_idx); + global_reduce(slice_idx == 0, last); + barrier_release(&locks[slice_col], last); + } + if (last) // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + init_slice(); + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + + (threadIdx.x % a_gl_rd_delta_o); + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) + B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { + #pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } + start_pipes(); + } + } + } +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids_base, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? 
+ bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par // maximum parallelism +) { + int m_block_ctr = current_m_block; + + const int* sorted_ids_expert = + sorted_ids_base + expert_offsets[expert_idx] + m_block_ctr * 4 * max_par; + int tot_its = expert_offsets[expert_idx + 1] - expert_offsets[expert_idx]; + if (tot_its == 0) { + return; + } + int tot_m_blocks = ceildiv(tot_its, 16); + int pad = 16 * tot_m_blocks - tot_its; + + if (m_block_ctr >= tot_m_blocks) { + return; + } + + int max_block = tot_m_blocks - m_block_ctr; + prob_m = tot_its - 16 * m_block_ctr; + + int par = 1; + if (max_block > 4) { + // Note that parallel > 1 currently only works for inputs without any + // padding + par = (16 * max_block - pad) / 64; + par = min((16 * max_block - pad) / 64, max_par); + prob_m = 64 * par; + m_block_ctr += 4 * (par - 1); + max_block = 4; + } + + if (max_block == 1) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 2) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else if (max_block == 3) { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } else { + MarlinMoESingle( + A, B, C, sorted_ids_expert, topk_weights, scales_ptr, g_idx, + expert_offsets, num_groups, expert_idx, num_experts, topk, prob_m, + prob_n, prob_k, tot_m, locks, replicate_input, apply_weights, + current_m_block); + } +} + +#else + +__global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, int size_m, + int size_k, int block_rows) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +__global__ void compute_expert_offsets(int const* __restrict__ topk_ids, + int* __restrict__ expert_offsets, + int topk_length, int block_size) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +template shared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks = -1 // number of consecutive 16x16 blocks + // with a separate quantization scale + > +__global__ void MarlinMoE( + const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + const int* __restrict__ sorted_ids, // int32 sorted ids of experts + const float* __restrict__ topk_weights, // float topk weights + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + const int* __restrict__ expert_offsets, + int num_groups, // number of scale groups per output channel + int expert_idx, // idx of current expert + int num_experts, // number of experts + int topk, // topk parameter of moe + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction 
dimension k + int tot_m, // total number of rows in A and C + int* locks, // extra global storage for barrier synchronization + bool replicate_input, // do we use the same input for each expert? + bool apply_weights, // apply weights to output + int current_m_block, // current m block to start kernel computation from + int max_par // maximum parallelism +) { + // Marlin is not implemented yet for SM < 8.0 + assert(false); + return; +} + +#endif + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +const int USER_THREADS = + 256; // Note: This is only used with user-provided thread_k/n +const int STAGES = 4; // 4 pipeline stages fit into shared memory +// const int SHARED_MEM = +// 96 * 1024; // max shared memory on compute capability 8.6 (< 8.0) + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; + +#define __CALL_IF_MOE(THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \ + HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \ + else if (thread_m_blocks == THREAD_M_BLOCKS && \ + thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && \ + has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS) { \ + cudaFuncSetAttribute( \ + MarlinMoE, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + MarlinMoE \ + <<>>( \ + A_ptr, B_ptr, C_ptr, sorted_ids_ptr, topk_weights_ptr, s_ptr, \ + g_idx_ptr, expert_offsets_ptr, num_groups, expert_idx, \ + num_experts, topk, prob_m, prob_n, prob_k, tot_m, locks, \ + replicate_input, apply_weights, m_block, max_par); \ + } + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, // Default + {128, 64, 128}, // Reduce N 2X, same K + {64, 256, 256}, // Reduce K 2X, increase N 2X + {64, 128, 128}, // Reduce K 2X, same N +}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, // Default + {128, 128, 256}, // Reduce N 2X, increase K 2X + {64, 128, 128}, // Reduce N 2X, same K + {128, 64, 128}, // Reduce N 4X, increase K 2X +}; + +bool is_valid_config(thread_config_t const& th_config, int prob_m, int prob_n, + int prob_k) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || + th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // thread_k can be only 128 or 64 (because it must be less than groupsize + // which is 128) + if (th_config.thread_k != 128 && th_config.thread_k != 64) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + return true; +} + +thread_config_t determine_thread_config(int prob_m, int prob_n, int prob_k) { + if (prob_m <= 16) { + for (auto th_config : small_batch_thread_configs) { + if (is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + + } else { + for (auto th_config : large_batch_thread_configs) { + if 
(is_valid_config(th_config, prob_m, prob_n, prob_k)) { + return th_config; + } + } + } + + return thread_config_t{-1, -1, -1}; +} + +#define CALL_IF_MOE(N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \ + \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \ + \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \ + __CALL_IF_MOE(4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) + +void marlin_mm_moe_f16i4(const void* A, const void* B, void* C, + const void* sorted_ids, const void* topk_weights, + const void* topk_ids, const void* s, const void* g_idx, + const void* perm, void* a_tmp, void* expert_offsets, + int prob_m, int prob_n, int prob_k, void* workspace, + bool has_act_order, bool is_k_full, int num_groups, + int group_size, int num_experts, int topk, + int moe_block_size, int dev, cudaStream_t stream, + int thread_k, int thread_n, int sms, int max_par, + bool replicate_input, bool apply_weights) { + TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, + ", ", prob_n, ", ", prob_k, "]"); + + if (sms == -1) { + cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev); + } + + // Set thread config + thread_config_t th_config; + if (thread_k != -1 && thread_n != -1) { + // User-defined config + th_config = thread_config_t{thread_k, thread_n, USER_THREADS}; + } else { + // Auto config + th_config = determine_thread_config(prob_m, prob_n, prob_k); + } + + TORCH_CHECK(is_valid_config(th_config, prob_m, prob_n, prob_k), + "Invalid thread config: thread_k = " + str(th_config.thread_k) + + ", thread_n = " + str(th_config.thread_n) + + ", num_threads = " + str(th_config.num_threads) + + " for MKN = [" + str(prob_m) + ", " + str(prob_k) + ", " + + str(prob_n) + "]"); + + int num_threads = th_config.num_threads; + thread_k = th_config.thread_k; + thread_n = th_config.thread_n; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + int blocks = sms; + + TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n, + " is not divisible by thread_n = ", thread_n); + TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k, + " is not divisible by thread_k = ", thread_k); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(group_size != -1); + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } else { + TORCH_CHECK(group_size == 0); 
+ group_blocks = 0; + } + + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + TORCH_CHECK(prob_k % group_blocks == 0, "prob_k = ", prob_k, + " is not divisible by group_blocks = ", group_blocks); + } + } + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, + cudaDevAttrMaxSharedMemoryPerBlockOptin, dev); + TORCH_CHECK(max_shared_mem > 0); + + int tot_m = prob_m; + + const int* topk_ids_ptr = (const int*)topk_ids; + int* expert_offsets_ptr = (int*)expert_offsets; + compute_expert_offsets<<<1, num_experts, 0, stream>>>( + topk_ids_ptr, expert_offsets_ptr, tot_m * topk, moe_block_size); + + bool do_permute_a = has_act_order; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) { + has_act_order = false; + } + + for (int expert_idx = 0; expert_idx < num_experts; ++expert_idx) { + const int4* A_ptr = (const int4*)A; + int4* a_tmp_ptr = (int4*)a_tmp; + const int4* B_ptr = (const int4*)B + (prob_n * prob_k / 32) * expert_idx; + int4* C_ptr = (int4*)C; + const float* topk_weights_ptr = (const float*)topk_weights; + const int* sorted_ids_ptr = (const int*)sorted_ids; + const int4* s_ptr = + (const int4*)s + + (((group_size == -1 || group_size == 0) ? 1 : prob_k / group_size) * + prob_n / 8) * + expert_idx; + const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; + const int* perm_ptr = (const int*)perm + prob_k * expert_idx; + int* locks = (int*)workspace; + + if (do_permute_a) { + // Permute A columns + int topk_rows = replicate_input ? tot_m : tot_m * topk; + int block_rows = ceildiv(topk_rows, blocks); + permute_cols_kernel<<>>( + A_ptr, perm_ptr, a_tmp_ptr, topk_rows, prob_k, block_rows); + A_ptr = a_tmp_ptr; + } + + int max_m_blocks = ceildiv(tot_m, 16); + for (int m_block = 0; m_block < max_m_blocks; m_block += 16) { + // Define kernel configurations + + // make it max possible value + int thread_m_blocks = 4; + + if (false) { + } + CALL_IF_MOE(16, 4, 256) + CALL_IF_MOE(8, 8, 256) + CALL_IF_MOE(8, 4, 128) + CALL_IF_MOE(4, 8, 128) + else { + TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + + str(prob_n) + ", " + str(prob_k) + "]" + + ", has_act_order = " + str(has_act_order) + + ", num_groups = " + str(num_groups) + + ", group_size = " + str(group_size) + + ", thread_m_blocks = " + str(thread_m_blocks) + + ", thread_n_blocks = " + str(thread_n_blocks) + + ", thread_k_blocks = " + str(thread_k_blocks)); + } + } + } +} + +} // namespace marlin_moe + +torch::Tensor marlin_gemm_moe( + const torch::Tensor& a, const torch::Tensor& b_q_weights, + const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, + const torch::Tensor& topk_ids, const torch::Tensor& b_scales, + const torch::Tensor& g_idx, const torch::Tensor& perm, + torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, + bool replicate_input, bool apply_weights) { + int max_par = 4; + + int dev = a.get_device(); + + auto options_dtype = + torch::TensorOptions().dtype(a.dtype()).device(a.device()); + auto options_int = + torch::TensorOptions().dtype(torch::kInt).device(a.device()); + torch::Tensor c = torch::zeros({size_m, topk, size_n}, options_dtype); + torch::Tensor a_tmp = + replicate_input ? 
torch::zeros({size_m, size_k}, options_dtype) + : torch::zeros({size_m, topk, size_k}, options_dtype); + torch::Tensor expert_offsets = torch::empty({num_experts + 1}, options_int); + + // thread_k: `k` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_k = -1; + // thread_n: `n` size of a thread_tile in `weights` (can usually be left as + // auto -1) + int thread_n = -1; + // sms: number of SMs to use for the kernel (can usually be left as auto -1) + int sms = -1; + + // Detect groupsize and act_order + int num_groups = -1; + int group_size = -1; + bool has_act_order = g_idx.size(1) != 0; + + int b_rank = b_scales.sizes().size(); + TORCH_CHECK(b_rank == 3, "b_scales rank = ", b_rank, " is not 3"); + TORCH_CHECK(b_scales.size(2) == size_n, "b_scales dim 2 = ", b_scales.size(2), + " is not size_n = ", size_n); + num_groups = b_scales.size(1); + + if (has_act_order) { + if (is_k_full) { + TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); + TORCH_CHECK(size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by num_groups = ", num_groups); + group_size = size_k / num_groups; + } else { + group_size = 0; + } + + } else { + if (num_groups > 1) { + TORCH_CHECK( + size_k % num_groups == 0, "size_k = ", size_k, + ", is not divisible by b_scales.size(0) = ", b_scales.size(0)); + group_size = size_k / num_groups; + } else { + group_size = -1; + } + } + + marlin_moe::marlin_mm_moe_f16i4( + a.data_ptr(), b_q_weights.data_ptr(), c.data_ptr(), sorted_ids.data_ptr(), + topk_weights.data_ptr(), topk_ids.data_ptr(), b_scales.data_ptr(), + g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr(), + expert_offsets.data_ptr(), size_m, size_n, size_k, workspace.data_ptr(), + has_act_order, is_k_full, num_groups, group_size, num_experts, topk, + moe_block_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k, + thread_n, sms, max_par, replicate_input, apply_weights); + return c; +} \ No newline at end of file diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h new file mode 100644 index 0000000000000..01ba8ff69850d --- /dev/null +++ b/csrc/moe/marlin_moe_ops.h @@ -0,0 +1,12 @@ +#pragma once + +#include + +torch::Tensor marlin_gemm_moe( + const torch::Tensor& a, const torch::Tensor& b_q_weights, + const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, + const torch::Tensor& topk_ids, const torch::Tensor& b_scales, + const torch::Tensor& g_idx, const torch::Tensor& perm, + torch::Tensor& workspace, int64_t size_m, int64_t size_n, int64_t size_k, + bool is_k_full, int64_t num_experts, int64_t topk, int64_t moe_block_size, + bool replicate_input, bool apply_weights); \ No newline at end of file diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 86e42af44df15..d4d43e2c601b5 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -1,5 +1,6 @@ #include "core/registration.h" #include "moe_ops.h" +#include "marlin_moe_ops.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Apply topk softmax to the gating outputs. @@ -7,6 +8,17 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " "token_expert_indices, Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); + +#ifndef USE_ROCM + m.def( + "marlin_gemm_moe(Tensor! a, Tensor! b_q_weights, Tensor! sorted_ids, " + "Tensor! topk_weights, Tensor! topk_ids, Tensor! b_scales, Tensor! " + "g_idx, Tensor! perm, Tensor! 
workspace, int size_m, int size_n, int " + "size_k, bool is_k_full, int num_experts, int topk, int moe_block_size, " + "bool replicate_input, bool apply_weights) -> Tensor"); + + m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); +#endif } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index e292c32999d63..95a9be7806633 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -11,4 +11,5 @@ pydantic >= 2.8 torch py-cpuinfo transformers +mistral_common >= 1.3.4 openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index a45bc885dc122..241b2ccd0991e 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -45,8 +45,6 @@ Base Classes .. autodata:: vllm.multimodal.NestedTensors -.. autodata:: vllm.multimodal.BatchedTensors - .. autodata:: vllm.multimodal.BatchedTensorInputs .. autoclass:: vllm.multimodal.MultiModalDataBuiltins diff --git a/format.sh b/format.sh index 9e0780870303d..2204b3ba59498 100755 --- a/format.sh +++ b/format.sh @@ -99,7 +99,6 @@ echo 'vLLM mypy:' mypy --follow-imports skip # Note that this is less strict than CI mypy tests --follow-imports skip mypy vllm/attention --follow-imports skip -mypy vllm/core --follow-imports skip mypy vllm/distributed --follow-imports skip mypy vllm/engine --follow-imports skip mypy vllm/executor --follow-imports skip diff --git a/pyproject.toml b/pyproject.toml index bcedbb53ab887..22a25d9cf32e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,6 +58,7 @@ files = [ "vllm/adapter_commons", "vllm/assets", "vllm/entrypoints", + "vllm/core", "vllm/inputs", "vllm/logging", "vllm/multimodal", diff --git a/requirements-common.txt b/requirements-common.txt index 534d63feec2b8..61daf99819756 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -26,3 +26,4 @@ librosa # Required for audio processing soundfile # Required for audio processing gguf == 0.9.1 importlib_metadata +mistral_common >= 1.3.4 diff --git a/requirements-rocm.txt b/requirements-rocm.txt index cc955e279a845..121123611d2da 100644 --- a/requirements-rocm.txt +++ b/requirements-rocm.txt @@ -8,3 +8,4 @@ botocore ray >= 2.10.0 peft pytest-asyncio +tensorizer>=2.9.0 \ No newline at end of file diff --git a/tests/async_engine/test_openapi_server_ray.py b/tests/async_engine/test_openapi_server_ray.py index d5c88708d047b..f70118546c7b6 100644 --- a/tests/async_engine/test_openapi_server_ray.py +++ b/tests/async_engine/test_openapi_server_ray.py @@ -1,5 +1,6 @@ import openai # use the official client for correctness check import pytest +import pytest_asyncio from ..utils import VLLM_PATH, RemoteOpenAIServer @@ -31,9 +32,10 @@ def server(): yield remote_server -@pytest.fixture(scope="module") -def client(server): - return server.get_async_client() +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client @pytest.mark.asyncio diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index 9c6364ecc6792..1211e6ba5aafc 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -88,6 +88,9 @@ def test_models( # NOTE: Increasing this in this suite will fail CI because we currently cannot # 
reset distributed env properly. Use a value > 1 just when you test. @pytest.mark.parametrize("tensor_parallel_size", [1]) +# Due to low-precision numerical divergence, this test is too sensitive to +# the async postprocessor +@pytest.mark.parametrize("disable_async_output_proc", [True]) def test_models_with_fp8_kv_cache( vllm_runner, example_prompts, @@ -97,6 +100,7 @@ def test_models_with_fp8_kv_cache( chunked_prefill_token_size: int, enforce_eager: bool, tensor_parallel_size: int, + disable_async_output_proc: bool, ) -> None: """ Only checks log probs match between chunked-prefill and @@ -126,6 +130,7 @@ def test_models_with_fp8_kv_cache( enforce_eager=enforce_eager, max_num_seqs=max_num_seqs, kv_cache_dtype=kv_cache_dtype, + disable_async_output_proc=disable_async_output_proc, **extra_kwargs, ) as vllm_model: no_chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( @@ -139,6 +144,7 @@ def test_models_with_fp8_kv_cache( enforce_eager=enforce_eager, max_num_seqs=max_num_seqs, kv_cache_dtype=kv_cache_dtype, + disable_async_output_proc=disable_async_output_proc, **extra_kwargs, ) as vllm_model: chunked_prefill_outputs = vllm_model.generate_greedy_logprobs( diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index 7c62de9fa9e37..7e77037da07d3 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -209,7 +209,6 @@ def test_swap_infeasible( prefill_blocks = 2 decode_blocks = max_tokens // BLOCK_SIZE example_prompts = example_prompts[:1] - with vllm_runner( model, dtype=dtype, diff --git a/tests/conftest.py b/tests/conftest.py index ae362b228d9d8..d8264f65b6149 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -41,6 +41,10 @@ _TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")] _LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "summary.txt")] +PromptImageInput = Union[List[Image.Image], List[List[Image.Image]]] +PromptAudioInput = Union[List[Tuple[np.ndarray, int]], + List[List[Tuple[np.ndarray, int]]]] + def _read_prompts(filename: str) -> List[str]: with open(filename, "r") as f: @@ -161,7 +165,7 @@ def example_encoder_decoder_prompts( decoder prompt) tuple. 
Returns: - + * Encoder prompt list * Decoder prompt list (reverse of encoder prompt list) ''' @@ -578,8 +582,7 @@ def generate( self, prompts: List[str], sampling_params: SamplingParams, - images: Optional[Union[List[Image.Image], - List[List[Image.Image]]]] = None, + images: Optional[PromptImageInput] = None, ) -> List[Tuple[List[List[int]], List[str]]]: if images is not None: assert len(prompts) == len(images) @@ -623,10 +626,8 @@ def generate_w_logprobs( self, prompts: List[str], sampling_params: SamplingParams, - images: Optional[Union[List[Image.Image], - List[List[Image.Image]]]] = None, - audios: Optional[Union[List[Tuple[np.ndarray, int]], - List[List[Tuple[np.ndarray, int]]]]] = None + images: Optional[PromptImageInput] = None, + audios: Optional[PromptAudioInput] = None, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: assert sampling_params.logprobs is not None @@ -676,10 +677,8 @@ def generate_greedy_logprobs( prompts: List[str], max_tokens: int, num_logprobs: int, - images: Optional[Union[List[Image.Image], - List[List[Image.Image]]]] = None, - audios: Optional[Union[List[Tuple[np.ndarray, int]], - List[List[Tuple[np.ndarray, int]]]]] = None, + images: Optional[PromptImageInput] = None, + audios: Optional[PromptAudioInput] = None, stop_token_ids: Optional[List[int]] = None, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: greedy_logprobs_params = SamplingParams(temperature=0.0, diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index c2226870c2e83..25be2dd13f8bd 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -708,6 +708,37 @@ def test_metric(): token_ids=token_ids) assert allocator.get_prefix_cache_hit_rate() > 0.99 + # Test case for marking cache hit blocks as computed right after + # a batch of prefill sequences are scheduled. + @staticmethod + def test_touch_block(): + block_size = 16 + common_blocks = 4 + allocator = PrefixCachingBlockAllocator(num_blocks=8, + block_size=block_size) + + common_token_ids = list(range(block_size * common_blocks)) + + # Mimic the behavior of allocating the same block chain + # (i.e., common prefix) for a batch of 3 different prefill sequences. + for _ in range(3): + blocks = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=common_token_ids, + allocator=allocator, + ) + block_ids = [block.block_id for block in blocks] + # The allocated blocks should be marked as touched + # but not computed. 
+ computed_block_ids = allocator.get_computed_block_ids( + [], block_ids, skip_last_block_id=False) + assert len(computed_block_ids) == 0 + + allocator.mark_blocks_as_computed([]) + computed_block_ids = allocator.get_computed_block_ids( + [], block_ids, skip_last_block_id=False) + assert len(computed_block_ids) == common_blocks + @staticmethod def create_immutable_chain( block_size: int, diff --git a/tests/core/test_chunked_prefill_scheduler.py b/tests/core/test_chunked_prefill_scheduler.py index a3b76327e0a53..6d9c2f3ebba4a 100644 --- a/tests/core/test_chunked_prefill_scheduler.py +++ b/tests/core/test_chunked_prefill_scheduler.py @@ -21,7 +21,7 @@ def append_new_token(seq_group, token_id: int): def schedule_and_update_computed_tokens(scheduler): - metas, out = scheduler.schedule() + metas, out, _ = scheduler.schedule() for s, meta in zip(out.scheduled_seq_groups, metas): s.seq_group.update_num_computed_tokens(meta.token_chunk_size) return metas, out @@ -180,7 +180,7 @@ def test_maximal_decoding(): """Verify decoding requests are prioritized.""" block_size = 4 max_seqs = 2 - max_model_len = 2 + max_model_len = 8 max_num_batched_tokens = 2 scheduler_config = SchedulerConfig(max_num_batched_tokens, max_seqs, diff --git a/tests/core/utils.py b/tests/core/utils.py index 12b66d50749db..40d8f51fc186e 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -199,7 +199,7 @@ def append_new_token(out, token_id: int): def schedule_and_update_computed_tokens(scheduler): - metas, out = scheduler.schedule() + metas, out, _ = scheduler.schedule() for s, meta in zip(out.scheduled_seq_groups, metas): s.seq_group.update_num_computed_tokens(meta.token_chunk_size) return metas, out diff --git a/tests/engine/test_stop_strings.py b/tests/engine/test_stop_strings.py index 1584b85aeb064..499935620c16a 100644 --- a/tests/engine/test_stop_strings.py +++ b/tests/engine/test_stop_strings.py @@ -7,6 +7,8 @@ MODEL = "meta-llama/llama-2-7b-hf" MAX_TOKENS = 200 +IS_ASYNC = False + @pytest.fixture(scope="session") def vllm_model(vllm_runner): @@ -14,99 +16,148 @@ def vllm_model(vllm_runner): yield vllm_model -@pytest.mark.skip_global_cleanup -def test_stop_basic(vllm_model): - _test_stopping(vllm_model.model.llm_engine, +def _test_stopping(llm_engine: LLMEngine, + expected_output: str, + expected_reason: Any, + stop: Optional[List[str]] = None, + stop_token_ids: Optional[List[int]] = None, + include_in_output: bool = False, + use_async_output_proc: bool = False) -> None: + llm_engine.add_request( + "id", "A story about vLLM:\n", + SamplingParams( + temperature=0.0, + max_tokens=MAX_TOKENS, + stop=stop, + stop_token_ids=stop_token_ids, + include_stop_str_in_output=include_in_output, + ), None) + + output: Optional[CompletionOutput] = None + output_text = "" + stop_reason = None + + if use_async_output_proc: + llm_engine.step() + + while llm_engine.has_unfinished_requests(): + (request_output, ) = llm_engine.step() + (output, ) = request_output.outputs + + # Ensure we don't backtrack + assert output.text.startswith(output_text) + output_text = output.text + stop_reason = output.stop_reason + + assert output is not None + assert output_text == expected_output + assert stop_reason == expected_reason + + +def _set_async_mode(llm_engine, is_async): + llm_engine.scheduler[0].use_async_output_proc = is_async + + +def _stop_basic(llm_engine, is_async): + _test_stopping(llm_engine, stop=["."], include_in_output=False, expected_output="VLLM is a 100% volunteer organization", - expected_reason=".") + expected_reason=".", + 
use_async_output_proc=is_async) - _test_stopping(vllm_model.model.llm_engine, + _test_stopping(llm_engine, stop=["."], include_in_output=True, expected_output="VLLM is a 100% volunteer organization.", - expected_reason=".") + expected_reason=".", + use_async_output_proc=is_async) -@pytest.mark.skip_global_cleanup -def test_stop_multi_tokens(vllm_model): +def _stop_multi_tokens(llm_engine, is_async): _test_stopping( - vllm_model.model.llm_engine, + llm_engine, stop=["group of peo", "short"], include_in_output=False, expected_output="VLLM is a 100% volunteer organization. We are a ", - expected_reason="group of peo") + expected_reason="group of peo", + use_async_output_proc=is_async) _test_stopping( - vllm_model.model.llm_engine, + llm_engine, stop=["group of peo", "short"], include_in_output=True, expected_output= "VLLM is a 100% volunteer organization. We are a group of peo", - expected_reason="group of peo") + expected_reason="group of peo", + use_async_output_proc=is_async) -@pytest.mark.skip_global_cleanup -def test_stop_partial_token(vllm_model): - _test_stopping(vllm_model.model.llm_engine, +def _stop_partial_token(llm_engine, is_async): + _test_stopping(llm_engine, stop=["gani"], include_in_output=False, expected_output="VLLM is a 100% volunteer or", - expected_reason="gani") + expected_reason="gani", + use_async_output_proc=is_async) - _test_stopping(vllm_model.model.llm_engine, + _test_stopping(llm_engine, stop=["gani"], include_in_output=True, expected_output="VLLM is a 100% volunteer organi", - expected_reason="gani") + expected_reason="gani", + use_async_output_proc=is_async) -@pytest.mark.skip_global_cleanup -def test_stop_token_id(vllm_model): +def _stop_token_id(llm_engine, is_async): # token id 13013 => " organization" - _test_stopping(vllm_model.model.llm_engine, + _test_stopping(llm_engine, stop_token_ids=[13013], include_in_output=False, expected_output="VLLM is a 100% volunteer", - expected_reason=13013) + expected_reason=13013, + use_async_output_proc=is_async) - _test_stopping(vllm_model.model.llm_engine, + _test_stopping(llm_engine, stop_token_ids=[13013], include_in_output=True, expected_output="VLLM is a 100% volunteer organization", - expected_reason=13013) + expected_reason=13013, + use_async_output_proc=is_async) -def _test_stopping(llm_engine: LLMEngine, - expected_output: str, - expected_reason: Any, - stop: Optional[List[str]] = None, - stop_token_ids: Optional[List[int]] = None, - include_in_output: bool = False) -> None: - llm_engine.add_request( - "id", "A story about vLLM:\n", - SamplingParams( - temperature=0.0, - max_tokens=MAX_TOKENS, - stop=stop, - stop_token_ids=stop_token_ids, - include_stop_str_in_output=include_in_output, - ), None) +@pytest.mark.skip_global_cleanup +def test_stop_basic(vllm_model): + _set_async_mode(vllm_model.model.llm_engine, True) + _stop_basic(vllm_model.model.llm_engine, is_async=True) - output: Optional[CompletionOutput] = None - output_text = "" - stop_reason = None - while llm_engine.has_unfinished_requests(): - (request_output, ) = llm_engine.step() - (output, ) = request_output.outputs + _set_async_mode(vllm_model.model.llm_engine, False) + _stop_basic(vllm_model.model.llm_engine, is_async=False) - # Ensure we don't backtrack - assert output.text.startswith(output_text) - output_text = output.text - stop_reason = output.stop_reason - assert output is not None - assert output_text == expected_output - assert stop_reason == expected_reason +@pytest.mark.skip_global_cleanup +def test_stop_multi_tokens(vllm_model): + 
_set_async_mode(vllm_model.model.llm_engine, True) + _stop_multi_tokens(vllm_model.model.llm_engine, is_async=True) + + _set_async_mode(vllm_model.model.llm_engine, False) + _stop_multi_tokens(vllm_model.model.llm_engine, is_async=False) + + +@pytest.mark.skip_global_cleanup +def test_stop_partial_token(vllm_model): + _set_async_mode(vllm_model.model.llm_engine, True) + _stop_partial_token(vllm_model.model.llm_engine, is_async=True) + + _set_async_mode(vllm_model.model.llm_engine, False) + _stop_partial_token(vllm_model.model.llm_engine, is_async=False) + + +@pytest.mark.skip_global_cleanup +def test_stop_token_id(vllm_model): + _set_async_mode(vllm_model.model.llm_engine, True) + _stop_token_id(vllm_model.model.llm_engine, is_async=True) + + _set_async_mode(vllm_model.model.llm_engine, False) + _stop_token_id(vllm_model.model.llm_engine, is_async=False) diff --git a/tests/entrypoints/openai/test_audio.py b/tests/entrypoints/openai/test_audio.py index 6dc8dde667389..a9a0ac012c8ff 100644 --- a/tests/entrypoints/openai/test_audio.py +++ b/tests/entrypoints/openai/test_audio.py @@ -2,6 +2,7 @@ import openai import pytest +import pytest_asyncio from vllm.assets.audio import AudioAsset from vllm.multimodal.utils import encode_audio_base64, fetch_audio @@ -28,9 +29,10 @@ def server(): yield remote_server -@pytest.fixture(scope="module") -def client(server): - return server.get_async_client() +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client @pytest.fixture(scope="session") diff --git a/tests/entrypoints/openai/test_basic.py b/tests/entrypoints/openai/test_basic.py index faada2ce64bcd..a7e418db30a29 100644 --- a/tests/entrypoints/openai/test_basic.py +++ b/tests/entrypoints/openai/test_basic.py @@ -2,6 +2,7 @@ import openai import pytest +import pytest_asyncio import requests from vllm.version import __version__ as VLLM_VERSION @@ -28,9 +29,10 @@ def server(): yield remote_server -@pytest.fixture(scope="module") -def client(server): - return server.get_async_client() +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index ce5bf3d5d7ba0..0fbc4cca83bd2 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -6,6 +6,7 @@ import jsonschema import openai # use the official client for correctness check import pytest +import pytest_asyncio import torch from openai import BadRequestError @@ -46,9 +47,10 @@ def server(zephyr_lora_files, zephyr_lora_added_tokens_files): # noqa: F811 yield remote_server -@pytest.fixture(scope="module") -def client(server): - return server.get_async_client() +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_completion.py b/tests/entrypoints/openai/test_completion.py index 18f41f5fc671b..d77cd57f12471 100644 --- a/tests/entrypoints/openai/test_completion.py +++ b/tests/entrypoints/openai/test_completion.py @@ -8,6 +8,7 @@ import jsonschema import openai # use the official client for correctness check import pytest +import pytest_asyncio # downloading lora to test lora requests from huggingface_hub import snapshot_download from openai import BadRequestError @@ -89,11 +90,17 @@ def default_server_args(zephyr_lora_files, 
zephyr_lora_added_tokens_files, @pytest.fixture(scope="module", params=["", "--disable-frontend-multiprocessing"]) -def client(default_server_args, request): +def server(default_server_args, request): if request.param: default_server_args.append(request.param) with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: - yield remote_server.get_async_client() + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index 6bf170b94c0d7..3baaeab2feeaf 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -3,6 +3,7 @@ import numpy as np import openai import pytest +import pytest_asyncio from ...utils import RemoteOpenAIServer @@ -24,10 +25,10 @@ def embedding_server(): yield remote_server -@pytest.mark.asyncio -@pytest.fixture(scope="module") -def embedding_client(embedding_server): - return embedding_server.get_async_client() +@pytest_asyncio.fixture +async def embedding_client(embedding_server): + async with embedding_server.get_async_client() as async_client: + yield async_client @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_encoder_decoder.py b/tests/entrypoints/openai/test_encoder_decoder.py index 85f1c6f18bf36..51eba694e62ad 100644 --- a/tests/entrypoints/openai/test_encoder_decoder.py +++ b/tests/entrypoints/openai/test_encoder_decoder.py @@ -1,5 +1,6 @@ import openai import pytest +import pytest_asyncio from ...utils import RemoteOpenAIServer @@ -18,9 +19,10 @@ def server(): yield remote_server -@pytest.fixture(scope="module") -def client(server): - return server.get_async_client() +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_metrics.py b/tests/entrypoints/openai/test_metrics.py index 042c3730e09f5..5e9a9f8ab7d4d 100644 --- a/tests/entrypoints/openai/test_metrics.py +++ b/tests/entrypoints/openai/test_metrics.py @@ -6,6 +6,7 @@ import openai import pytest +import pytest_asyncio import requests from prometheus_client.parser import text_string_to_metric_families from transformers import AutoTokenizer @@ -35,11 +36,17 @@ def default_server_args(): "--enable-chunked-prefill", "--disable-frontend-multiprocessing", ]) -def client(default_server_args, request): +def server(default_server_args, request): if request.param: default_server_args.append(request.param) with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server: - yield remote_server.get_async_client() + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as cl: + yield cl _PROMPT = "Hello my name is Robert and I love magic" diff --git a/tests/entrypoints/openai/test_models.py b/tests/entrypoints/openai/test_models.py index c2cfff228c546..5cd570f43e1a7 100644 --- a/tests/entrypoints/openai/test_models.py +++ b/tests/entrypoints/openai/test_models.py @@ -1,5 +1,6 @@ import openai # use the official client for correctness check import pytest +import pytest_asyncio # downloading lora to test lora requests from huggingface_hub import snapshot_download @@ -43,9 +44,10 @@ def server(zephyr_lora_files): yield remote_server -@pytest.fixture(scope="module") -def client(server): - return server.get_async_client() 
+@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_return_tokens_as_ids.py b/tests/entrypoints/openai/test_return_tokens_as_ids.py index abe413978e0e5..99f6da160d6f9 100644 --- a/tests/entrypoints/openai/test_return_tokens_as_ids.py +++ b/tests/entrypoints/openai/test_return_tokens_as_ids.py @@ -25,59 +25,63 @@ def server_with_return_tokens_as_token_ids_flag( @pytest.mark.asyncio async def test_completion_return_tokens_as_token_ids_completion( server_with_return_tokens_as_token_ids_flag): - client = server_with_return_tokens_as_token_ids_flag.get_async_client() + async with server_with_return_tokens_as_token_ids_flag.get_async_client( + ) as client: - completion = await client.completions.create( - model=MODEL_NAME, - # Include Unicode characters to test for dividing a single - # character across multiple tokens: 🎉 is [28705, 31862] for the - # Zephyr tokenizer - prompt="Say 'Hello, world! 🎉'", - echo=True, - temperature=0, - max_tokens=10, - logprobs=1) + completion = await client.completions.create( + model=MODEL_NAME, + # Include Unicode characters to test for dividing a single + # character across multiple tokens: 🎉 is [28705, 31862] for the + # Zephyr tokenizer + prompt="Say 'Hello, world! 🎉'", + echo=True, + temperature=0, + max_tokens=10, + logprobs=1) - text = completion.choices[0].text - token_strs = completion.choices[0].logprobs.tokens - tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) - # Check that the token representations are consistent between raw tokens - # and top_logprobs - # Slice off the first one, because there's no scoring associated with BOS - top_logprobs = completion.choices[0].logprobs.top_logprobs[1:] - top_logprob_keys = [ - next(iter(logprob_by_tokens)) for logprob_by_tokens in top_logprobs - ] - assert token_strs[1:] == top_logprob_keys + text = completion.choices[0].text + token_strs = completion.choices[0].logprobs.tokens + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + # Check that the token representations are consistent between raw + # tokens and top_logprobs + # Slice off the first one, because there's no scoring associated + # with BOS + top_logprobs = completion.choices[0].logprobs.top_logprobs[1:] + top_logprob_keys = [ + next(iter(logprob_by_tokens)) for logprob_by_tokens in top_logprobs + ] + assert token_strs[1:] == top_logprob_keys - # Check that decoding the tokens gives the expected text - tokens = [int(token.removeprefix("token_id:")) for token in token_strs] - assert text == tokenizer.decode(tokens, skip_special_tokens=True) + # Check that decoding the tokens gives the expected text + tokens = [int(token.removeprefix("token_id:")) for token in token_strs] + assert text == tokenizer.decode(tokens, skip_special_tokens=True) @pytest.mark.asyncio async def test_chat_return_tokens_as_token_ids_completion( server_with_return_tokens_as_token_ids_flag): - client = server_with_return_tokens_as_token_ids_flag.get_async_client() - response = await client.chat.completions.create( - model=MODEL_NAME, - # Include Unicode characters to test for dividing a single - # character across multiple tokens: 🎉 is [28705, 31862] for the - # Zephyr tokenizer - messages=[{ - "role": "system", - "content": "You like to respond in only emojis, like 🎉" - }, { - "role": "user", - "content": "Please write some emojis: 🐱🐶🎉" - }], - temperature=0, - max_tokens=8, - logprobs=True) + async with 
server_with_return_tokens_as_token_ids_flag.get_async_client( + ) as client: + response = await client.chat.completions.create( + model=MODEL_NAME, + # Include Unicode characters to test for dividing a single + # character across multiple tokens: 🎉 is [28705, 31862] for the + # Zephyr tokenizer + messages=[{ + "role": "system", + "content": "You like to respond in only emojis, like 🎉" + }, { + "role": "user", + "content": "Please write some emojis: 🐱🐶🎉" + }], + temperature=0, + max_tokens=8, + logprobs=True) - text = response.choices[0].message.content - tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) - token_ids = [] - for logprob_content in response.choices[0].logprobs.content: - token_ids.append(int(logprob_content.token.removeprefix("token_id:"))) - assert tokenizer.decode(token_ids, skip_special_tokens=True) == text + text = response.choices[0].message.content + tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME) + token_ids = [] + for logprob_content in response.choices[0].logprobs.content: + token_ids.append( + int(logprob_content.token.removeprefix("token_id:"))) + assert tokenizer.decode(token_ids, skip_special_tokens=True) == text diff --git a/tests/entrypoints/openai/test_shutdown.py b/tests/entrypoints/openai/test_shutdown.py index 6dff1cfbe7f75..73ecb74007272 100644 --- a/tests/entrypoints/openai/test_shutdown.py +++ b/tests/entrypoints/openai/test_shutdown.py @@ -35,13 +35,14 @@ async def test_shutdown_on_engine_failure(tmp_path): ] with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: - client = remote_server.get_async_client() + async with remote_server.get_async_client() as client: - with pytest.raises(openai.APIConnectionError): - # This crashes the engine - await client.completions.create(model="bad-adapter", - prompt="Hello, my name is") + with pytest.raises( + (openai.APIConnectionError, openai.InternalServerError)): + # This crashes the engine + await client.completions.create(model="bad-adapter", + prompt="Hello, my name is") - # Now the server should shut down - return_code = remote_server.proc.wait(timeout=1) - assert return_code is not None + # Now the server should shut down + return_code = remote_server.proc.wait(timeout=3) + assert return_code is not None diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index 18c51c560b511..316ca11b8e95a 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -1,5 +1,6 @@ import openai # use the official client for correctness check import pytest +import pytest_asyncio import requests from vllm.transformers_utils.tokenizer import get_tokenizer @@ -42,9 +43,10 @@ def tokenizer_name(model_name: str, model_name == "zephyr-lora2") else model_name -@pytest.fixture(scope="module") -def client(server): - return server.get_async_client() +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_vision.py b/tests/entrypoints/openai/test_vision.py index 843ba91f7a076..d2ef3c2071efb 100644 --- a/tests/entrypoints/openai/test_vision.py +++ b/tests/entrypoints/openai/test_vision.py @@ -2,6 +2,7 @@ import openai import pytest +import pytest_asyncio from vllm.multimodal.utils import encode_image_base64, fetch_image @@ -36,9 +37,10 @@ def server(): yield remote_server -@pytest.fixture(scope="module") -def client(server): - return server.get_async_client() +@pytest_asyncio.fixture 
+async def client(server): + async with server.get_async_client() as async_client: + yield async_client @pytest.fixture(scope="session") diff --git a/tests/models/test_llava_next.py b/tests/models/test_llava_next.py index 9cf55c0858df0..d5fe0cbe32880 100644 --- a/tests/models/test_llava_next.py +++ b/tests/models/test_llava_next.py @@ -6,24 +6,22 @@ from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ..conftest import (IMAGE_ASSETS, HfRunner, PromptImageInput, VllmRunner, + _ImageAssets) from .utils import check_logprobs_close pytestmark = pytest.mark.vlm -_PREFACE = ( - "A chat between a curious human and an artificial intelligence assistant. " - "The assistant gives helpful, detailed, and polite answers to the human's " - "questions.") +_LIMIT_IMAGE_PER_PROMPT = 4 HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": - f"{_PREFACE} USER: \nWhat's the content of the image? ASSISTANT:", + "[INST] \nWhat's the content of the image? [/INST]", "cherry_blossom": - f"{_PREFACE} USER: \nWhat is the season? ASSISTANT:", + "[INST] \nWhat is the season? [/INST]", }) -models = ["llava-hf/llava-v1.6-vicuna-7b-hf"] +models = ["llava-hf/llava-v1.6-mistral-7b-hf"] def vllm_to_hf_output(vllm_output: Tuple[List[int], str, @@ -114,19 +112,43 @@ def run_test( else: raise ValueError("You must provide either `size_factors` or `sizes`") + _run_test(hf_runner, + vllm_runner, + inputs_per_image, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=tensor_parallel_size, + distributed_executor_backend=distributed_executor_backend) + + +def _run_test( + hf_runner: Type[HfRunner], + vllm_runner: Type[VllmRunner], + inputs: List[Tuple[List[str], PromptImageInput]], + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, + tensor_parallel_size: int, + distributed_executor_backend: Optional[str] = None, +): # max_model_len should be greater than image_feature_size with vllm_runner(model, dtype=dtype, - max_model_len=4096, + max_model_len=10240, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: + enforce_eager=True, + limit_mm_per_prompt={"image": _LIMIT_IMAGE_PER_PROMPT + }) as vllm_model: vllm_outputs_per_image = [ vllm_model.generate_greedy_logprobs(prompts, max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in inputs_per_image + for prompts, images in inputs ] with hf_runner(model, dtype=dtype, @@ -136,7 +158,7 @@ def run_test( max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in inputs_per_image + for prompts, images in inputs ] for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, @@ -177,7 +199,7 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, All the image fixtures for the test is under tests/images. For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects + For vllm runner, we provide MultiModalDataDict objects and corresponding MultiModalConfig as input. Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. 
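For context on the limit_mm_per_prompt argument threaded through the runner in the hunk above: it caps how many items of each modality a single prompt may carry, and the same option can be passed when constructing an LLM directly. A minimal offline sketch, assuming the llava-v1.6-mistral-7b checkpoint and the limit used by these tests:

    from vllm import LLM

    # Allow up to four images per prompt, mirroring _LIMIT_IMAGE_PER_PROMPT above.
    llm = LLM(
        model="llava-hf/llava-v1.6-mistral-7b-hf",
        max_model_len=10240,
        limit_mm_per_prompt={"image": 4},
    )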
@@ -216,3 +238,48 @@ def test_models_fixed_sizes(hf_runner, vllm_runner, image_assets, model, sizes, num_logprobs=num_logprobs, tensor_parallel_size=1, ) + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("max_tokens", [128]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models_multiple_image_inputs(hf_runner, vllm_runner, image_assets, + model, dtype, max_tokens, + num_logprobs) -> None: + stop_sign = image_assets[0].pil_image + cherry_blossom = image_assets[1].pil_image + + inputs = [( + [ + "[INST] \nDescribe 2 images. [/INST]", + "[INST] \nDescribe 2 images. [/INST]", + "[INST] \nDescribe 4 images. [/INST]", + "[INST] \nWhat is the season? [/INST]" + ], + [ + [stop_sign, cherry_blossom], + # Images with different sizes and aspect-ratios + [ + rescale_image_size(stop_sign, 0.1), + stop_sign, + ], + [ + stop_sign, + rescale_image_size(stop_sign, 0.25), + cherry_blossom.resize((183, 488)), + cherry_blossom.resize((488, 183)) + ], + cherry_blossom, + ])] + + _run_test( + hf_runner, + vllm_runner, + inputs, + model, + dtype=dtype, + max_tokens=max_tokens, + num_logprobs=num_logprobs, + tensor_parallel_size=1, + ) diff --git a/tests/models/test_minicpmv.py b/tests/models/test_minicpmv.py index bf72dad0d1f5b..99e49c14f1f26 100644 --- a/tests/models/test_minicpmv.py +++ b/tests/models/test_minicpmv.py @@ -1,14 +1,15 @@ -from typing import List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type, Union import pytest import torch import torch.types +from PIL import Image from transformers import BatchEncoding from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs -from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner from .utils import check_logprobs_close pytestmark = pytest.mark.vlm @@ -24,6 +25,11 @@ "(./)\nWhat is the season?<|eot_id|>" \ "<|start_header_id|>assistant<|end_header_id|>\n\n", }) +HF_MULTIIMAGE_IMAGE_PROMPT = \ + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \ + "(./)\n(./)\n" \ + "Describe these images.<|eot_id|>" \ + "<|start_header_id|>assistant<|end_header_id|>\n\n" models = ["openbmb/MiniCPM-Llama3-V-2_5"] @@ -46,13 +52,14 @@ def trunc_hf_output(hf_output: Tuple[List[int], str, def run_test( hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], - image_assets: _ImageAssets, + inputs: List[Tuple[List[str], Union[List[Image.Image], + List[List[Image.Image]]]]], model: str, *, - size_factors: List[float], dtype: str, max_tokens: int, num_logprobs: int, + mm_limit: int, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, ): @@ -65,12 +72,6 @@ def run_test( Note, the text input is also adjusted to abide by vllm contract. The text output is sanitized to be able to compare with hf. """ - images = [asset.pil_image for asset in image_assets] - - inputs_per_image = [( - [prompt for _ in size_factors], - [rescale_image_size(image, factor) for factor in size_factors], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. 
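The multi-image test added above feeds several PIL images to a single prompt; the equivalent standalone request is a text prompt whose multi_modal_data carries a list of images, with one image placeholder per image in the prompt string. A rough sketch under the same assumptions as the previous snippet (the file paths and the exact placeholder layout are illustrative, not taken from this diff):

    from PIL import Image
    from vllm import LLM, SamplingParams

    llm = LLM(model="llava-hf/llava-v1.6-mistral-7b-hf",
              max_model_len=10240,
              limit_mm_per_prompt={"image": 2})

    # Two images for one prompt; the list length must stay within the configured limit.
    images = [Image.open("stop_sign.jpg"), Image.open("cherry_blossom.jpg")]
    outputs = llm.generate(
        {
            "prompt": "[INST] <image>\n<image>\nDescribe 2 images. [/INST]",
            "multi_modal_data": {"image": images},
        },
        SamplingParams(temperature=0.0, max_tokens=128),
    )
    print(outputs[0].outputs[0].text)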
@@ -82,6 +83,7 @@ def run_test( max_model_len=4096, max_num_seqs=1, dtype=dtype, + limit_mm_per_prompt={"image": mm_limit}, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, enforce_eager=True) as vllm_model: @@ -93,7 +95,7 @@ def run_test( num_logprobs=num_logprobs, images=images, stop_token_ids=stop_token_ids) - for prompts, images in inputs_per_image + for prompts, images in inputs ] hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs) @@ -104,7 +106,7 @@ def run_test( num_logprobs=num_logprobs, images=images, tokenizer=tokenizer) - for prompts, images in inputs_per_image + for prompts, images in inputs ] for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, @@ -138,104 +140,26 @@ def run_test( @pytest.mark.parametrize("num_logprobs", [5]) def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, dtype: str, max_tokens: int, num_logprobs: int) -> None: + images = [asset.pil_image for asset in image_assets] + + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + run_test( hf_runner, vllm_runner, - image_assets, + inputs_per_image, model, - size_factors=size_factors, dtype=dtype, max_tokens=max_tokens, num_logprobs=num_logprobs, + mm_limit=1, tensor_parallel_size=1, ) -HF_MULTIIMAGE_IMAGE_PROMPT = \ - "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" \ - "(./)\n(./)\n" \ - "Describe these images.<|eot_id|>" \ - "<|start_header_id|>assistant<|end_header_id|>\n\n" - - -def run_multi_image_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - image_assets: _ImageAssets, - model: str, - *, - size_factors: List[float], - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - """Inference result should be the same between hf and vllm. - - All the image fixtures for the test is under tests/images. - For huggingface runner, we provide the PIL images as input. - For vllm runner, we provide MultiModalDataDict objects - and corresponding MultiModalConfig as input. - Note, the text input is also adjusted to abide by vllm contract. - The text output is sanitized to be able to compare with hf. - """ - images = [asset.pil_image for asset in image_assets] - - inputs_per_case = [ - ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], - [[rescale_image_size(image, factor) for image in images] - for factor in size_factors]) - ] - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). 
- - # max_model_len should be greater than image_feature_size - with vllm_runner(model, - max_model_len=4096, - max_num_seqs=1, - limit_mm_per_prompt={"image": len(images)}, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: - tokenizer = vllm_model.model.get_tokenizer() - stop_token_ids = [tokenizer.eos_id, tokenizer.eot_id] - vllm_outputs_per_case = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - stop_token_ids=stop_token_ids) - for prompts, images in inputs_per_case - ] - - hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs) - with hf_model, torch.no_grad(): - hf_outputs_per_case = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - tokenizer=tokenizer) - for prompts, images in inputs_per_case - ] - - for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, - vllm_outputs_per_case): - check_logprobs_close( - outputs_0_lst=[ - trunc_hf_output(hf_output) for hf_output in hf_outputs - ], - outputs_1_lst=vllm_outputs, - name_0="hf", - name_1="vllm", - ) - - @pytest.mark.parametrize("model", models) @pytest.mark.parametrize( "size_factors", @@ -256,14 +180,22 @@ def run_multi_image_test( def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, size_factors, dtype: str, max_tokens: int, num_logprobs: int) -> None: - run_multi_image_test( + images = [asset.pil_image for asset in image_assets] + + inputs_per_case = [ + ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], + [[rescale_image_size(image, factor) for image in images] + for factor in size_factors]) + ] + + run_test( hf_runner, vllm_runner, - image_assets, + inputs_per_case, model, - size_factors=size_factors, dtype=dtype, max_tokens=max_tokens, num_logprobs=num_logprobs, + mm_limit=2, tensor_parallel_size=1, ) diff --git a/tests/models/test_mistral.py b/tests/models/test_mistral.py index 6acc057fe588c..4965354c0016b 100644 --- a/tests/models/test_mistral.py +++ b/tests/models/test_mistral.py @@ -30,9 +30,11 @@ def test_models( hf_outputs = hf_model.generate_greedy_logprobs_limit( example_prompts, max_tokens, num_logprobs) - with vllm_runner(model, dtype=dtype) as vllm_model: + with vllm_runner(model, dtype=dtype, + tokenizer_mode="mistral") as vllm_model: vllm_outputs = vllm_model.generate_greedy_logprobs( example_prompts, max_tokens, num_logprobs) + check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs, diff --git a/tests/models/test_phi3v.py b/tests/models/test_phi3v.py index 259cbe515066d..e416a85b8962a 100644 --- a/tests/models/test_phi3v.py +++ b/tests/models/test_phi3v.py @@ -1,6 +1,6 @@ import os import re -from typing import List, Optional, Tuple, Type +from typing import List, Optional, Tuple, Type, Union import pytest from PIL import Image @@ -60,13 +60,14 @@ def vllm_to_hf_output(vllm_output: Tuple[List[int], str, def run_test( hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], - images: List[Image.Image], + inputs: List[Tuple[List[str], Union[List[Image.Image], + List[List[Image.Image]]]]], model: str, *, - size_factors: List[float], dtype: str, max_tokens: int, num_logprobs: int, + mm_limit: int, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, ): @@ -79,13 +80,6 @@ def run_test( Note, the text input is also adjusted to abide by vllm contract. 
The text output is sanitized to be able to compare with hf. """ - inputs_per_image = [( - [prompt for _ in size_factors], - [ - rescale_image_size(image, factor, transpose=idx) - for idx, factor in enumerate(size_factors) - ], - ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] # NOTE: take care of the order. run vLLM first, and then run HF. # vLLM needs a fresh new process without cuda initialization. @@ -97,15 +91,16 @@ def run_test( max_model_len=4096, max_num_seqs=1, dtype=dtype, + limit_mm_per_prompt={"image": mm_limit}, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, enforce_eager=True) as vllm_model: - vllm_outputs_per_image = [ + vllm_outputs_per_case = [ vllm_model.generate_greedy_logprobs(prompts, max_tokens, num_logprobs=num_logprobs, images=images) - for prompts, images in inputs_per_image + for prompts, images in inputs ] # use eager mode for hf runner, since phi3_v didn't work with flash_attn @@ -113,17 +108,17 @@ def run_test( with hf_runner(model, dtype=dtype, model_kwargs=hf_model_kwargs) as hf_model: eos_token_id = hf_model.processor.tokenizer.eos_token_id - hf_outputs_per_image = [ + hf_outputs_per_case = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, num_logprobs=num_logprobs, images=images, eos_token_id=eos_token_id) - for prompts, images in inputs_per_image + for prompts, images in inputs ] - for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, - vllm_outputs_per_image): + for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, + vllm_outputs_per_case): check_logprobs_close( outputs_0_lst=hf_outputs, outputs_1_lst=[ @@ -156,15 +151,22 @@ def run_test( @pytest.mark.parametrize("num_logprobs", [10]) def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, dtype: str, max_tokens: int, num_logprobs: int) -> None: + images = [asset.pil_image for asset in image_assets] + + inputs_per_image = [( + [prompt for _ in size_factors], + [rescale_image_size(image, factor) for factor in size_factors], + ) for image, prompt in zip(images, HF_IMAGE_PROMPTS)] + run_test( hf_runner, vllm_runner, - [asset.pil_image for asset in image_assets], + inputs_per_image, model, - size_factors=size_factors, dtype=dtype, max_tokens=max_tokens, num_logprobs=num_logprobs, + mm_limit=1, tensor_parallel_size=1, ) @@ -173,97 +175,26 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, @pytest.mark.parametrize("dtype", [target_dtype]) def test_regression_7840(hf_runner, vllm_runner, image_assets, model, dtype) -> None: + images = [asset.pil_image for asset in image_assets] + + inputs_regresion_7840 = [ + ([prompt], [image]) for image, prompt in zip(images, HF_IMAGE_PROMPTS) + ] + # Regression test for #7840. run_test( hf_runner, vllm_runner, - [image_assets[0].pil_image.resize((465, 226))], + inputs_regresion_7840, model, - size_factors=[1.0], dtype=dtype, max_tokens=128, num_logprobs=10, + mm_limit=1, tensor_parallel_size=1, ) -def run_multi_image_test( - hf_runner: Type[HfRunner], - vllm_runner: Type[VllmRunner], - images: List[Image.Image], - model: str, - *, - size_factors: List[float], - dtype: str, - max_tokens: int, - num_logprobs: int, - tensor_parallel_size: int, - distributed_executor_backend: Optional[str] = None, -): - """Inference result should be the same between hf and vllm. - - All the image fixtures for the test is under tests/images. - For huggingface runner, we provide the PIL images as input. 
- For vllm runner, we provide MultiModalDataDict objects - and corresponding MultiModalConfig as input. - Note, the text input is also adjusted to abide by vllm contract. - The text output is sanitized to be able to compare with hf. - """ - - inputs_per_case = [ - ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], - [[rescale_image_size(image, factor) for image in images] - for factor in size_factors]) - ] - - # NOTE: take care of the order. run vLLM first, and then run HF. - # vLLM needs a fresh new process without cuda initialization. - # if we run HF first, the cuda initialization will be done and it - # will hurt multiprocessing backend with fork method (the default method). - - # max_model_len should be greater than image_feature_size - with vllm_runner(model, - max_model_len=4096, - max_num_seqs=1, - limit_mm_per_prompt={"image": len(images)}, - dtype=dtype, - tensor_parallel_size=tensor_parallel_size, - distributed_executor_backend=distributed_executor_backend, - enforce_eager=True) as vllm_model: - vllm_outputs_per_case = [ - vllm_model.generate_greedy_logprobs(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images) - for prompts, images in inputs_per_case - ] - - hf_model_kwargs = {"_attn_implementation": "eager"} - with hf_runner(model, dtype=dtype, - model_kwargs=hf_model_kwargs) as hf_model: - eos_token_id = hf_model.processor.tokenizer.eos_token_id - hf_outputs_per_case = [ - hf_model.generate_greedy_logprobs_limit(prompts, - max_tokens, - num_logprobs=num_logprobs, - images=images, - eos_token_id=eos_token_id) - for prompts, images in inputs_per_case - ] - - for hf_outputs, vllm_outputs in zip(hf_outputs_per_case, - vllm_outputs_per_case): - check_logprobs_close( - outputs_0_lst=hf_outputs, - outputs_1_lst=[ - vllm_to_hf_output(vllm_output, model) - for vllm_output in vllm_outputs - ], - name_0="hf", - name_1="vllm", - ) - - @pytest.mark.parametrize("model", models) @pytest.mark.parametrize( "size_factors", @@ -280,18 +211,26 @@ def run_multi_image_test( ) @pytest.mark.parametrize("dtype", [target_dtype]) @pytest.mark.parametrize("max_tokens", [128]) -@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("num_logprobs", [10]) def test_multi_images_models(hf_runner, vllm_runner, image_assets, model, size_factors, dtype: str, max_tokens: int, num_logprobs: int) -> None: - run_multi_image_test( + images = [asset.pil_image for asset in image_assets] + + inputs_per_case = [ + ([HF_MULTIIMAGE_IMAGE_PROMPT for _ in size_factors], + [[rescale_image_size(image, factor) for image in images] + for factor in size_factors]) + ] + + run_test( hf_runner, vllm_runner, - [asset.pil_image for asset in image_assets], + inputs_per_case, model, - size_factors=size_factors, dtype=dtype, max_tokens=max_tokens, num_logprobs=num_logprobs, + mm_limit=2, tensor_parallel_size=1, ) diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index bc14311c66424..ad99d70d7417c 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -28,12 +28,12 @@ async def completions_with_server_args(prompts: List[str], model_name: str, outputs = None with RemoteOpenAIServer(model_name, server_cli_args) as server: - client = server.get_async_client() - outputs = await client.completions.create(model=model_name, - prompt=prompts, - temperature=0, - stream=False, - max_tokens=5) + async with server.get_async_client() as client: + outputs = await 
client.completions.create(model=model_name, + prompt=prompts, + temperature=0, + stream=False, + max_tokens=5) assert outputs is not None return outputs @@ -62,6 +62,9 @@ async def test_multi_step(example_prompts, model: str, tp_size: int, ms_server_args = DEFAULT_SERVER_ARGS + \ ["--num-scheduler-steps", f"{num_scheduler_steps}"] + # Disable output proc callback as its not supported + # with multi-step right now + ms_server_args += ["--disable-async-output-proc"] if eager_mode: ms_server_args.append("--enforce-eager") diff --git a/tests/multimodal/test_base.py b/tests/multimodal/test_base.py new file mode 100644 index 0000000000000..f19a0f33fe067 --- /dev/null +++ b/tests/multimodal/test_base.py @@ -0,0 +1,83 @@ +import torch + +from vllm.multimodal.base import MultiModalInputs, NestedTensors + + +def assert_nested_tensors_equal(expected: NestedTensors, + actual: NestedTensors): + assert type(expected) == type(actual) + if isinstance(expected, torch.Tensor): + assert torch.equal(expected, actual) + else: + for expected_item, actual_item in zip(expected, actual): + assert_nested_tensors_equal(expected_item, actual_item) + + +def assert_multimodal_inputs_equal(expected: MultiModalInputs, + actual: MultiModalInputs): + assert set(expected.keys()) == set(actual.keys()) + for key in expected: + assert_nested_tensors_equal(expected[key], actual[key]) + + +def test_multimodal_input_batch_single_tensor(): + t = torch.rand([1, 2]) + result = MultiModalInputs.batch([{"image": t}]) + assert_multimodal_inputs_equal(result, {"image": t.unsqueeze(0)}) + + +def test_multimodal_input_batch_multiple_tensors(): + a = torch.rand([1, 1, 2]) + b = torch.rand([1, 1, 2]) + c = torch.rand([1, 1, 2]) + result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}]) + assert_multimodal_inputs_equal(result, {"image": torch.stack([a, b, c])}) + + +def test_multimodal_input_batch_multiple_heterogeneous_tensors(): + a = torch.rand([1, 2, 2]) + b = torch.rand([1, 3, 2]) + c = torch.rand([1, 4, 2]) + result = MultiModalInputs.batch([{"image": a}, {"image": b}, {"image": c}]) + assert_multimodal_inputs_equal(result, {"image": [a, b, c]}) + + +def test_multimodal_input_batch_nested_tensors(): + a = torch.rand([2, 3]) + b = torch.rand([2, 3]) + c = torch.rand([2, 3]) + result = MultiModalInputs.batch([{ + "image": [a] + }, { + "image": [b] + }, { + "image": [c] + }]) + assert_multimodal_inputs_equal(result, { + "image": + torch.stack([a.unsqueeze(0), + b.unsqueeze(0), + c.unsqueeze(0)]) + }) + + +def test_multimodal_input_batch_heterogeneous_lists(): + a = torch.rand([1, 2, 3]) + b = torch.rand([1, 2, 3]) + c = torch.rand([1, 2, 3]) + result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c]}]) + assert_multimodal_inputs_equal( + result, + {"image": [torch.stack([a, b]), c.unsqueeze(0)]}) + + +def test_multimodal_input_batch_multiple_batchable_lists(): + a = torch.rand([1, 2, 3]) + b = torch.rand([1, 2, 3]) + c = torch.rand([1, 2, 3]) + d = torch.rand([1, 2, 3]) + result = MultiModalInputs.batch([{"image": [a, b]}, {"image": [c, d]}]) + assert_multimodal_inputs_equal( + result, + {"image": torch.stack([torch.stack([a, b]), + torch.stack([c, d])])}) diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index cd1fc91c29374..38cd48629f903 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -6,8 +6,10 @@ import numpy as np import pytest from PIL import Image +from transformers import AutoConfig, AutoTokenizer -from vllm.multimodal.utils import 
async_fetch_image, fetch_image +from vllm.multimodal.utils import (async_fetch_image, fetch_image, + repeat_and_pad_placeholder_tokens) # Test different image extensions (JPG/PNG) and formats (gray/RGB/RGBA) TEST_IMAGE_URLS = [ @@ -80,3 +82,34 @@ async def test_fetch_image_base64(url_images: Dict[str, Image.Image], data_image_async = await async_fetch_image(data_url) assert _image_equals(data_image_sync, data_image_async) + + +@pytest.mark.parametrize("model", ["llava-hf/llava-v1.6-mistral-7b-hf"]) +def test_repeat_and_pad_placeholder_tokens(model): + config = AutoConfig.from_pretrained(model) + image_token_id = config.image_token_index + + tokenizer = AutoTokenizer.from_pretrained(model) + + test_cases = [ + ("", 2, "", [32000, 32000]), + ("", 2, "", [32000, 32000, 32000]), + ("", [3, 2], "", + [32000, 32000, 32000, 32000, 32000]), + ("Image:Image:!", [3, 2], + "Image:Image:!", + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918]), + ("", [3, 2], "", [32000, 32000, 32000]), + ] + + for prompt, repeat_count, expected_prompt, expected_token_ids in test_cases: + new_prompt, new_token_ids = repeat_and_pad_placeholder_tokens( + tokenizer=tokenizer, + prompt=prompt, + prompt_token_ids=tokenizer.encode(prompt, + add_special_tokens=False), + placeholder_token_id=image_token_id, + repeat_count=repeat_count, + ) + assert new_prompt == expected_prompt + assert new_token_ids == expected_token_ids diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 2ea340779b819..7dd20636c892f 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -160,4 +160,4 @@ def test_compressed_tensors_kv_cache(vllm_runner): model_path = "nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme" with vllm_runner(model_path, kv_cache_dtype="fp8") as llm: output = llm.generate_greedy("Hello world!", max_tokens=20) - assert output + assert output \ No newline at end of file diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py new file mode 100644 index 0000000000000..5a432fb78b3da --- /dev/null +++ b/tests/tpu/test_compilation.py @@ -0,0 +1,34 @@ +import glob +import os +import runpy +import tempfile + +import depyf + +temp_dir = tempfile.mkdtemp() +with depyf.prepare_debug(temp_dir): + cur_dir = os.path.dirname(__file__) + parent_dir = os.path.dirname(cur_dir) + root_dir = os.path.dirname(parent_dir) + example_file = os.path.join(root_dir, "examples", + "offline_inference_tpu.py") + runpy.run_path(example_file) + +compiled_code = sorted( + glob.glob(os.path.join(temp_dir, "__transformed_code*.py"))) +full_code = glob.glob(os.path.join(temp_dir, "full_code*.py"))[0] +# we should only trigger Dynamo compilation three times: +# one for the profiling phase (and the compiled artifact will be discarded) +# one for the prefill phase with symbolic shapes +# one for the decode phase with symbolic shapes +# and later calls should not trigger Dynamo compilation again. +# NOTE: it might still trigger XLA compilation. 
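For readers unfamiliar with how the assertion that follows is even possible, here is a minimal, self-contained sketch (illustrative only, not part of this patch) of counting Dynamo compilations with a custom torch.compile backend; the function f and the shapes are arbitrary placeholders.

import torch

compile_count = 0

def counting_backend(gm, example_inputs):
    # Dynamo calls the backend once per newly captured graph.
    global compile_count
    compile_count += 1
    return gm.forward  # run the captured graph eagerly

@torch.compile(backend=counting_backend, dynamic=True)
def f(x):
    return x * 2 + 1

for n in (4, 8, 16):          # different input shapes
    f(torch.randn(n))

# With symbolic shapes, later shapes reuse the same graph, so the count
# stays small instead of growing with every new input shape.
print(compile_count)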
+ +# check we have three compiled code +assert len(compiled_code) == 3 + +# check the first compilation is discarded +with open(full_code) as f: + full_code_content = f.read() + profile_function = compiled_code[0].split(".")[0] + assert profile_function not in full_code_content diff --git a/tests/utils.py b/tests/utils.py index b73a05b5fe67f..de887bc8cf6fb 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -154,6 +154,7 @@ def get_async_client(self): return openai.AsyncOpenAI( base_url=self.url_for("v1"), api_key=self.DUMMY_API_KEY, + max_retries=0, ) diff --git a/tests/weight_loading/models.txt b/tests/weight_loading/models.txt index 98a66b6701ea9..cbe30305c14f6 100644 --- a/tests/weight_loading/models.txt +++ b/tests/weight_loading/models.txt @@ -13,8 +13,12 @@ compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-quantized, main +compressed-tensors, nm-testing/Mixtral-8x7B-Instruct-v0.1-W4A16-channel-quantized, main awq, casperhansen/mixtral-instruct-awq, main awq_marlin, casperhansen/mixtral-instruct-awq, main fp8, neuralmagic/Meta-Llama-3-8B-Instruct-FP8-KV, main marlin, nm-testing/zephyr-beta-7b-marlin-g128, main -marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main \ No newline at end of file +marlin, robertgshaw2/zephyr-7b-beta-channelwise-marlin, main +qqq, HandH1998/QQQ-Llama-3-8b-g128, main +qqq, HandH1998/QQQ-Llama-3-8b, main \ No newline at end of file diff --git a/vllm/_core_ext.py b/vllm/_core_ext.py index aa520e1eafbaf..a27b8648bee47 100644 --- a/vllm/_core_ext.py +++ b/vllm/_core_ext.py @@ -181,92 +181,98 @@ def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, ScalarType = torch.classes._core_C.ScalarType - # Needed for dynamo support of ScalarType. - @torch._library.register_fake_class("_core_C::ScalarType") - class FakeScalarType: + if (hasattr(torch, "_library") + and hasattr(torch._library, "register_fake_class")): + # Needed for dynamo support of ScalarType. 
+ @torch._library.register_fake_class("_core_C::ScalarType") + class FakeScalarType: - def __init__(self, scalar_type): - self.ScalarType = scalar_type + def __init__(self, scalar_type): + self.ScalarType = scalar_type - def bias_getter(self) -> int: - return self.ScalarType.bias + def bias_getter(self) -> int: + return self.ScalarType.bias - def exponent_getter(self) -> int: - return self.ScalarType.exponent + def exponent_getter(self) -> int: + return self.ScalarType.exponent - def mantissa_getter(self) -> int: - return self.ScalarType.mantissa + def mantissa_getter(self) -> int: + return self.ScalarType.mantissa - def signed_getter(self) -> bool: - return self.ScalarType.signed + def signed_getter(self) -> bool: + return self.ScalarType.signed - def size_bits_getter(self) -> int: - return self.ScalarType.size_bits + def size_bits_getter(self) -> int: + return self.ScalarType.size_bits - @property - def size_bits(self) -> int: - return self.ScalarType.size_bits + @property + def size_bits(self) -> int: + return self.ScalarType.size_bits - def min(self) -> Union[int, float]: - return self.ScalarType.min() + def min(self) -> Union[int, float]: + return self.ScalarType.min() - def max(self) -> Union[int, float]: - return self.ScalarType.max() + def max(self) -> Union[int, float]: + return self.ScalarType.max() - def is_signed(self) -> bool: - return self.ScalarType.is_signed() + def is_signed(self) -> bool: + return self.ScalarType.is_signed() - def is_floating_point(self) -> bool: - return self.ScalarType.is_floating_point() + def is_floating_point(self) -> bool: + return self.ScalarType.is_floating_point() - def is_integer(self) -> bool: - return self.ScalarType.is_integer() + def is_integer(self) -> bool: + return self.ScalarType.is_integer() - def has_bias(self) -> bool: - return self.ScalarType.has_bias() + def has_bias(self) -> bool: + return self.ScalarType.has_bias() - def has_infs(self) -> bool: - return self.ScalarType.has_infs() + def has_infs(self) -> bool: + return self.ScalarType.has_infs() - def has_nans(self) -> bool: - return self.ScalarType.has_nans() + def has_nans(self) -> bool: + return self.ScalarType.has_nans() - def is_ieee_754(self) -> bool: - return self.ScalarType.is_ieee_754() + def is_ieee_754(self) -> bool: + return self.ScalarType.is_ieee_754() - def __str__(self) -> str: - return self.ScalarType.__str__() + def __str__(self) -> str: + return self.ScalarType.__str__() - def __repr__(self) -> str: - return self.ScalarType.__repr__() + def __repr__(self) -> str: + return self.ScalarType.__repr__() - def __len__(self) -> int: - return self.ScalarType.__len__() + def __len__(self) -> int: + return self.ScalarType.__len__() - def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]: - return torch.classes._core_C.ScalarType.__obj_flatten__( - self.ScalarType) + def __obj_flatten__(self) -> Tuple[Tuple[str, Any], ...]: + return torch.classes._core_C.ScalarType.__obj_flatten__( + self.ScalarType) - @classmethod - def __obj_unflatten__( - cls, flat_type: Tuple[Tuple[str, Any], ...]) -> 'ScalarType': - return cls( - torch.classes._core_C.ScalarType.__obj_unflatten__(flat_type)) + @classmethod + def __obj_unflatten__( + cls, flat_type: Tuple[Tuple[str, Any], + ...]) -> 'ScalarType': + return cls( + torch.classes._core_C.ScalarType.__obj_unflatten__( + flat_type)) - @classmethod - def int_(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - return ScalarType.int_(size_bits, bias) + @classmethod + def int_(cls, size_bits: int, bias: Optional[int]) -> 
'ScalarType': + return ScalarType.int_(size_bits, bias) - @classmethod - def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': - return ScalarType.uint(size_bits, bias) + @classmethod + def uint(cls, size_bits: int, bias: Optional[int]) -> 'ScalarType': + return ScalarType.uint(size_bits, bias) - @classmethod - def float_IEEE754(cls, exponent: int, mantissa: int) -> 'ScalarType': - return ScalarType.float_IEEE754(exponent, mantissa) + @classmethod + def float_IEEE754(cls, exponent: int, + mantissa: int) -> 'ScalarType': + return ScalarType.float_IEEE754(exponent, mantissa) - @classmethod - def float_(cls, exponent: int, mantissa: int, finite_values_only: bool, - nan_repr: int) -> 'ScalarType': - return ScalarType.float_(exponent, mantissa, finite_values_only, - nan_repr) + @classmethod + def float_(cls, exponent: int, mantissa: int, + finite_values_only: bool, + nan_repr: int) -> 'ScalarType': + return ScalarType.float_(exponent, mantissa, + finite_values_only, nan_repr) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index b89a90ef0f70c..ae90af563c0cf 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -300,6 +300,20 @@ def awq_marlin_repack(b_q_weight: torch.Tensor, size_k: int, size_n: int, return torch.ops._C.awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) +def gptq_marlin_moe_repack(b_q_weight: torch.Tensor, perm: torch.Tensor, + size_k: int, size_n: int, + num_bits: int) -> torch.Tensor: + num_experts = b_q_weight.shape[0] + assert size_k % 16 == 0 + output = torch.empty((num_experts, size_k // 16, size_n * 2), + device=b_q_weight.device, + dtype=b_q_weight.dtype) + for e in range(num_experts): + output[e] = torch.ops._C.gptq_marlin_repack(b_q_weight[e], perm[e], + size_k, size_n, num_bits) + return output + + def gptq_marlin_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, b_scales: torch.Tensor, diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index ce7a7198dc400..a8d76b79ff204 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -113,8 +113,7 @@ def _get_decode_wrapper(self): self.runner.parallel_config)) num_kv_heads = self.runner.model_config.get_num_kv_heads( self.runner.parallel_config) - use_tensor_cores = (num_qo_heads // num_kv_heads) not in \ - (1, 2, 4, 8) + use_tensor_cores = num_qo_heads // num_kv_heads > 4 self._decode_wrapper = BatchDecodeWithPagedKVCacheWrapper( self._get_workspace_buffer(), "NHD", @@ -172,8 +171,7 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int): self.runner.parallel_config)) num_kv_heads = self.runner.model_config.get_num_kv_heads( self.runner.parallel_config) - use_tensor_cores = (num_qo_heads // num_kv_heads) not in \ - (1, 2, 4, 8) + use_tensor_cores = num_qo_heads // num_kv_heads > 4 self._graph_decode_wrapper = \ CUDAGraphBatchDecodeWithPagedKVCacheWrapper( self._graph_decode_workspace_buffer, _indptr_buffer, diff --git a/vllm/block.py b/vllm/block.py index 95286048d9115..47c381c19383b 100644 --- a/vllm/block.py +++ b/vllm/block.py @@ -1,9 +1,9 @@ """Token blocks.""" -from typing import List, Optional +from typing import TYPE_CHECKING, Iterator, List, Optional from vllm.utils import Device -DEFAULT_LAST_ACCESSED_TIME = -1 +DEFAULT_LAST_ACCESSED_TIME: float = -1 class PhysicalTokenBlock: @@ -59,6 +59,11 @@ def __len__(self) -> int: def __getitem__(self, key): return self._blocks[key] + if TYPE_CHECKING: + + def __iter__(self) -> Iterator[PhysicalTokenBlock]: + raise RuntimeError("Method should be 
automatically generated") + def __setitem__(self, key, value): if isinstance(key, slice): blocks = value diff --git a/vllm/config.py b/vllm/config.py index 4cbdde5e113a2..4e014e43d849a 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -61,7 +61,8 @@ class ModelConfig: output when `served_model_name` is not specified. tokenizer: Name or path of the huggingface tokenizer to use. tokenizer_mode: Tokenizer mode. "auto" will use the fast tokenizer if - available, and "slow" will always use the slow tokenizer. + available, "slow" will always use the slow tokenizer, and + "mistral" will always use the tokenizer from `mistral_common`. trust_remote_code: Trust remote code (e.g., from HuggingFace) when downloading the model and tokenizer. dtype: Data type for model weights and activations. The "auto" option @@ -140,6 +141,7 @@ def __init__( skip_tokenizer_init: bool = False, served_model_name: Optional[Union[str, List[str]]] = None, limit_mm_per_prompt: Optional[Mapping[str, int]] = None, + use_async_output_proc: bool = True, ) -> None: self.model = model self.tokenizer = tokenizer @@ -172,6 +174,7 @@ def __init__( self.hf_image_processor_config = get_hf_image_processor_config( self.model, revision) self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + self.use_async_output_proc = use_async_output_proc # Choose a default enforce_eager value if the user did not specify # a value (enforce_eager is None) @@ -244,10 +247,10 @@ def _init_multimodal_config( def _verify_tokenizer_mode(self) -> None: tokenizer_mode = self.tokenizer_mode.lower() - if tokenizer_mode not in ["auto", "slow"]: + if tokenizer_mode not in ["auto", "slow", "mistral"]: raise ValueError( f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " - "either 'auto' or 'slow'.") + "either 'auto', 'slow' or 'mistral'.") self.tokenizer_mode = tokenizer_mode def _verify_embedding_mode(self) -> None: @@ -326,6 +329,49 @@ def _verify_cuda_graph(self) -> None: self.max_seq_len_to_capture = min(self.max_seq_len_to_capture, self.max_model_len) + def verify_async_output_proc(self, parallel_config, speculative_config, + device_config) -> None: + if not self.use_async_output_proc: + # Nothing to check + return + + if parallel_config.pipeline_parallel_size > 1: + logger.warning("Async output processing can not be enabled " + "with pipeline parallel") + self.use_async_output_proc = False + return + + if device_config.device_type != "cuda": + logger.warning( + "Async output processing is only supported for CUDA." + " Disabling it for other platforms.") + self.use_async_output_proc = False + return + + if envs.VLLM_USE_RAY_SPMD_WORKER: + logger.warning( + "Async output processing can not be enabled with ray spmd") + self.use_async_output_proc = False + return + + if self.enforce_eager: + logger.warning( + "To see benefits of async output processing, enable CUDA " + "graph. 
Since, enforce-eager is enabled, async output " + "processor cannot be used") + self.use_async_output_proc = not self.enforce_eager + return + + # Async postprocessor is not necessary with embedding mode + # since there is no token generation + if self.embedding_mode: + self.use_async_output_proc = False + + if speculative_config: + logger.warning("Async output processing is not supported with" + " speculative decoding currently.") + self.use_async_output_proc = False + def verify_with_parallel_config( self, parallel_config: "ParallelConfig", @@ -358,6 +404,11 @@ def verify_with_parallel_config( "fallback to the eager mode.") self.enforce_eager = True + if pipeline_parallel_size > 1 and self.use_async_output_proc: + logger.warning("Async output processor is not supported with " + "pipeline parallelism currently. Disabling it.") + self.use_async_output_proc = False + def get_hf_config_sliding_window(self) -> Optional[int]: """Get the sliding window size, or None if disabled.""" @@ -1769,6 +1820,9 @@ class EngineConfig: def __post_init__(self): """Verify configs are valid & consistent with each other. """ + self.model_config.verify_async_output_proc(self.parallel_config, + self.speculative_config, + self.device_config) self.model_config.verify_with_parallel_config(self.parallel_config) self.cache_config.verify_with_parallel_config(self.parallel_config) diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index c6330df2a485a..c87246c1c6d6a 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -132,7 +132,7 @@ def allocate_mutable_block(self, prev_block: Optional[Block], def allocate_immutable_blocks(self, prev_block: Optional[Block], block_token_ids: List[List[int]], - device: Optional[Device]) -> List[Block]: + device: Device) -> List[Block]: """Allocates a new group of immutable blocks with the provided block token IDs on the specified device. diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 432a6651ab07a..a87e814cfb041 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -1,6 +1,6 @@ """Token blocks.""" from os.path import commonprefix -from typing import Dict, FrozenSet, Iterable, List, Optional, Tuple +from typing import Dict, FrozenSet, Iterable, List, Optional, Set, Tuple from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, get_all_blocks_recursively) @@ -73,6 +73,11 @@ def __init__( # prefix hash will be in this dict, even if they have refcount 0. self._cached_blocks: Dict[PrefixHash, BlockId] = {} + # A list of immutable block IDs that have been touched by scheduler + # and should be marked as computed after an entire batch of sequences + # are scheduled. + self._touched_blocks: Set[BlockId] = set() + # Used to track status of each physical block id self._block_tracker: Dict[BlockId, BlockTracker] = {} for block_id in block_ids: @@ -438,10 +443,14 @@ def promote_to_immutable_block(self, block: Block) -> BlockId: assert self._refcounter.get(block.block_id) > 0 if block.content_hash not in self._cached_blocks: - # No cached content hash => Set this block as cached - # (Note that this block is not computed yet => - # Will be computed after free()) + # No cached content hash => Set this block as cached. + # Note that this block cannot be marked as computed yet + # because other sequences in the same batch cannot reuse + # this block. 
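The deferral described in the comment above is easy to get wrong, so here is a standalone toy (ToyPrefixCache and its method names are invented for illustration; this is not the vLLM allocator) showing the cached-but-not-yet-computed bookkeeping.

class ToyPrefixCache:
    def __init__(self):
        self.cached = {}          # content_hash -> block_id
        self.computed = set()     # block ids safe for other batches to reuse
        self.touched = set()      # cached this batch, not yet marked computed

    def promote(self, content_hash, block_id):
        if content_hash not in self.cached:
            self.cached[content_hash] = block_id
            self.touched.add(block_id)   # defer the "computed" flag
        return self.cached[content_hash]

    def mark_batch_computed(self):
        self.computed |= self.touched
        self.touched.clear()

cache = ToyPrefixCache()
cache.promote("abc", 0)
assert 0 not in cache.computed        # same-batch sequences must not reuse it yet
cache.mark_batch_computed()           # called once the whole batch is scheduled
assert 0 in cache.computed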
self._cached_blocks[block.content_hash] = block.block_id + # Mark this block as touched so that it can be marked as + # computed after the entire batch of sequences are scheduled. + self._touched_blocks.add(block.block_id) return block.block_id # Reuse the cached content hash @@ -507,7 +516,10 @@ def mark_blocks_as_accessed(self, block_ids: List[int], "Mark block as accessed which is not belonged to GPU") def mark_blocks_as_computed(self, block_ids: List[int]) -> None: - raise NotImplementedError("Marking as computed is incremental") + # Mark all touched blocks as computed. + for block_id in self._touched_blocks: + self._block_tracker[block_id].computed = True + self._touched_blocks.clear() def _track_block_id(self, block_id: Optional[BlockId], computed: bool) -> None: diff --git a/vllm/core/block_manager_v1.py b/vllm/core/block_manager_v1.py index 0af04399a4b31..666723313c829 100644 --- a/vllm/core/block_manager_v1.py +++ b/vllm/core/block_manager_v1.py @@ -278,7 +278,7 @@ def __init__( # request ID self.cross_block_tables: Dict[str, BlockTable] = {} - def _get_seq_num_required_blocks(self, seq: Sequence) -> int: + def _get_seq_num_required_blocks(self, seq: Optional[Sequence]) -> int: return 0 if seq is None else seq.n_blocks def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: @@ -310,13 +310,14 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: return AllocStatus.LATER def _allocate_sequence(self, \ - seq: Sequence, \ + seq: Optional[Sequence], \ ref_count: int, \ is_encoder_decoder: bool = True) -> BlockTable: # Allocate new physical token blocks that will store the prompt tokens. - num_prompt_blocks = seq.n_blocks + num_prompt_blocks = self._get_seq_num_required_blocks(seq) block_table: BlockTable = BlockTable() + assert seq is not None for logical_idx in range(num_prompt_blocks): if (self.block_sliding_window is not None and logical_idx >= self.block_sliding_window): diff --git a/vllm/core/block_manager_v2.py b/vllm/core/block_manager_v2.py index b7d9451f18067..7d2db43cb4602 100644 --- a/vllm/core/block_manager_v2.py +++ b/vllm/core/block_manager_v2.py @@ -120,8 +120,10 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: ) if seq_group.is_encoder_decoder(): + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None num_required_blocks += BlockTable.get_num_required_blocks( - seq_group.get_encoder_seq().get_token_ids(), + encoder_seq.get_token_ids(), block_size=self.block_size, ) @@ -189,7 +191,9 @@ def allocate(self, seq_group: SequenceGroup) -> None: check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) if seq_group.is_encoder_decoder(): - block_table = self._allocate_sequence(seq_group.get_encoder_seq()) + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + block_table = self._allocate_sequence(encoder_seq) self.cross_block_tables[request_id] = block_table def can_append_slots(self, seq_group: SequenceGroup, @@ -287,11 +291,11 @@ def access_all_blocks_in_seq(self, seq: Sequence, now: float): seq.seq_id, now) def mark_blocks_as_computed(self, seq_group: SequenceGroup): - # The only need for mark block as computed is for prefix caching, - # while currently we could determine whether one block is computed - # or not by check whether it has content hash. - # So this function is useless for block_v2. - pass + # If prefix caching is enabled, mark immutable blocks as computed + # right after they have been scheduled (for prefill). 
This assumes + # the scheduler is synchronous so blocks are actually computed when + # scheduling the next batch. + self.block_allocator.mark_blocks_as_computed([]) def get_common_computed_block_ids( self, seqs: List[Sequence]) -> GenericSequence[int]: diff --git a/vllm/core/embedding_model_block_manager.py b/vllm/core/embedding_model_block_manager.py index 3d864a73f91d0..f16f66e99e7f8 100644 --- a/vllm/core/embedding_model_block_manager.py +++ b/vllm/core/embedding_model_block_manager.py @@ -77,8 +77,8 @@ def access_all_blocks_in_seq( pass def get_common_computed_block_ids(self, - seq_group: SequenceGroup) -> List[int]: - return None # type: ignore + seq_group: List[Sequence]) -> List[int]: + return [] def mark_blocks_as_computed(self, seq_group: SequenceGroup): pass diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 4e6298eba6c30..b7e3f48522672 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -4,7 +4,8 @@ import time from collections import deque from dataclasses import dataclass, field -from typing import Deque, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import (Callable, Deque, Dict, Iterable, List, Optional, Set, + Tuple, Union) from vllm.config import CacheConfig, LoRAConfig, SchedulerConfig from vllm.core.interfaces import AllocStatus, BlockSpaceManager @@ -220,10 +221,10 @@ class SchedulerSwappedInOutputs: """ # Selected sequences that are going to be swapped in and is in a # decoding phase. - decode_seq_groups: List[SequenceGroup] + decode_seq_groups: List[ScheduledSequenceGroup] # Selected sequences that are going to be swapped in and in a prefill # phase. I.e., it means the prefill has been chunked. - prefill_seq_groups: List[SequenceGroup] + prefill_seq_groups: List[ScheduledSequenceGroup] # The blocks to swap in. blocks_to_swap_in: List[Tuple[int, int]] # The blocks to copy. @@ -253,7 +254,7 @@ class SchedulerPrefillOutputs: to be recomputed from scratch. """ # Selected sequences for prefill. - seq_groups: List[SequenceGroup] + seq_groups: List[ScheduledSequenceGroup] # Ignored sequence groups. ignored_seq_groups: List[SequenceGroup] num_lookahead_slots: int @@ -288,7 +289,9 @@ def scheduler_running_outputs_builder(): def scheduled_seq_group_builder(): - return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) + return ScheduledSequenceGroup(SequenceGroup("", [], -1), + token_chunk_size=0) + # return ScheduledSequenceGroup(seq_group=None, token_chunk_size=0) class Scheduler: @@ -299,6 +302,7 @@ def __init__( cache_config: CacheConfig, lora_config: Optional[LoRAConfig], pipeline_parallel_size: int = 1, + output_proc_callback_fn: Optional[Callable] = None, ) -> None: self.scheduler_config = scheduler_config self.cache_config = cache_config @@ -364,10 +368,36 @@ def __init__( self.num_cumulative_preemption: int = 0 # Used to cache python objects - self._scheduler_running_outputs_cache: PyObjectCache = PyObjectCache( - scheduler_running_outputs_builder) - self._scheduled_seq_group_cache: PyObjectCache = PyObjectCache( - scheduled_seq_group_builder) + self._seq_group_metadata_cache: List[PyObjectCache] = [] + self._scheduler_running_outputs_cache: List[PyObjectCache] = [] + self._scheduled_seq_group_cache: List[PyObjectCache] = [] + + # For async output processing, we need to swap cache buffers between + # iterations. I.e. since the output processing is lagged one step, + # we cannot reuse the cached objects immediately when the schedule() + # is called again, but only when schedule() is called the second time. 
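A compact way to see why two cache slots are needed (purely illustrative; the real code rotates PyObjectCache instances rather than plain lists):

NUM_SLOTS = 2                     # one extra slot because processing lags a step
caches = [[] for _ in range(NUM_SLOTS)]
cache_id = 0

def schedule_step(step):
    global cache_id
    buf = caches[cache_id]
    buf.clear()                   # safe: the other slot still holds the objects
    buf.append(f"metadata for step {step}")   # the async processor may be reading
    cache_id = (cache_id + 1) % NUM_SLOTS     # hand the next step the other slot
    return buf

m0 = schedule_step(0)
m1 = schedule_step(1)             # step 0's outputs may still be in flight
assert m0 is not m1               # consecutive steps never share a buffer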
+ self.output_proc_callback_fn = output_proc_callback_fn + self.use_async_output_proc = self.output_proc_callback_fn is not None + self.num_cache_iters = 2 if self.use_async_output_proc else 1 + + self.cache_id = 0 + for i in range(self.num_cache_iters): + self._seq_group_metadata_cache.append( + PyObjectCache(seq_group_metadata_builder)) + self._scheduler_running_outputs_cache.append( + PyObjectCache(scheduler_running_outputs_builder)) + self._scheduled_seq_group_cache.append( + PyObjectCache(scheduled_seq_group_builder)) + + # For async postprocessor, the extra decode run cannot be done + # when the request reaches max_model_len. In this case, the request + # will be stopped during schedule() call and added to this stop list + # for processing and deallocation by the free_finished_seq_groups() + self._async_stopped: List[SequenceGroup] = [] + + @property + def next_cache_id(self): + return (self.cache_id + 1) % self.num_cache_iters @property def lora_enabled(self) -> bool: @@ -483,7 +513,7 @@ def _schedule_running( SchedulerRunningOutputs. """ ret: SchedulerRunningOutputs = \ - self._scheduler_running_outputs_cache.get_object() + self._scheduler_running_outputs_cache[self.cache_id].get_object() ret.blocks_to_swap_out.clear() ret.blocks_to_copy.clear() ret.decode_seq_groups.clear() @@ -510,8 +540,12 @@ def _schedule_running( # NOTE(woosuk): Preemption happens only when there is no available slot # to keep all the sequence groups in the RUNNING state. - running_queue = self.running + # Store original running requests for the case of async + preemption + if self.use_async_output_proc: + orig_running = self.running.copy() + running_queue = self.running + assert len(self._async_stopped) == 0 while running_queue: seq_group = running_queue[0] num_running_tokens = self._get_num_new_tokens( @@ -521,6 +555,28 @@ def _schedule_running( break running_queue.popleft() + + # With async postprocessor, an extra decode run is done + # to process the final tokens. The check below avoids this extra + # decode run when the model max len is reached, in order to avoid + # a memory overflow. + if self.use_async_output_proc and seq_group.seqs[0].get_len( + ) > self.scheduler_config.max_model_len: + self._async_stopped.append(seq_group) + continue + + # With async postprocessor, when preemption kicks in, we need + # first to drain the async postprocessor, so that all async + # block_table freeing is applied before the preemption freeing + # is applied. 
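The ordering constraint in the comment above can be summarized with a tiny standalone sketch (simplified; in the real engine the callback is _process_model_outputs and the frees go through the block manager):

pending_async_work = ["free block table of seq_group_A"]   # lagged postprocessing

def output_proc_callback(is_async: bool = True):
    # Apply everything the previous step left behind before touching state.
    while pending_async_work:
        print("drained:", pending_async_work.pop(0))

def preempt(seq_group: str):
    output_proc_callback(is_async=True)   # drain first, then preempt
    print("preempted and freed blocks of", seq_group)

preempt("seq_group_A")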
+ if self.use_async_output_proc and not self._can_append_slots( + seq_group): + tmp = self.running + self.running = orig_running + assert self.output_proc_callback_fn is not None + self.output_proc_callback_fn(is_async=True) + self.running = tmp + while not self._can_append_slots(seq_group): budget.subtract_num_batched_tokens(seq_group.request_id, num_running_tokens) @@ -556,7 +612,7 @@ def _schedule_running( is_prefill = seq_group.is_prefill() scheduled_seq_group: ScheduledSequenceGroup = \ - self._scheduled_seq_group_cache.get_object() + self._scheduled_seq_group_cache[self.cache_id].get_object() scheduled_seq_group.seq_group = seq_group if is_prefill: scheduled_seq_group.token_chunk_size = num_running_tokens @@ -579,8 +635,8 @@ def _schedule_running( if curr_loras is not None and seq_group.lora_int_id > 0: curr_loras.add(seq_group.lora_int_id) - self._scheduler_running_outputs_cache.reset() - self._scheduled_seq_group_cache.reset() + self._scheduler_running_outputs_cache[self.next_cache_id].reset() + self._scheduled_seq_group_cache[self.next_cache_id].reset() return ret @@ -737,7 +793,7 @@ def _schedule_prefills( SchedulerPrefillOutputs. """ ignored_seq_groups: List[SequenceGroup] = [] - seq_groups: List[SequenceGroup] = [] + seq_groups: List[ScheduledSequenceGroup] = [] waiting_queue = self.waiting @@ -1036,17 +1092,32 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), ) - def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: + def _allow_async_output_proc(self, seq_group: SequenceGroup) -> bool: + no_beam_search = seq_group.sampling_params is None or ( + seq_group.sampling_params.best_of == 1 + and not seq_group.sampling_params.use_beam_search) + + return no_beam_search + + def schedule( + self + ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, bool]: # Schedule sequence groups. # This function call changes the internal states of the scheduler # such as self.running, self.swapped, and self.waiting. scheduler_start_time = time.perf_counter() + scheduler_outputs = self._schedule() now = time.time() if not self.cache_config.enable_prefix_caching: common_computed_block_nums = [] + # TODO: Combine multi-step and async postprocessor + allow_async_output_proc: bool = ( + self.use_async_output_proc + and not self.scheduler_config.is_multi_step) + # Create input data structures. 
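For clarity, the predicate introduced above boils down to the following standalone check (ToySamplingParams is a stripped-down stand-in, not the real SamplingParams class):

from dataclasses import dataclass
from typing import Optional

@dataclass
class ToySamplingParams:
    best_of: int = 1
    use_beam_search: bool = False

def allow_async_output_proc(params: Optional[ToySamplingParams]) -> bool:
    # Async processing is only allowed for plain single-sequence sampling.
    return params is None or (params.best_of == 1
                              and not params.use_beam_search)

assert allow_async_output_proc(None)
assert allow_async_output_proc(ToySamplingParams())
assert not allow_async_output_proc(ToySamplingParams(best_of=4))
assert not allow_async_output_proc(ToySamplingParams(use_beam_search=True))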
seq_group_metadata_list: List[SequenceGroupMetadata] = [] for i, scheduled_seq_group in enumerate( @@ -1055,6 +1126,11 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: token_chunk_size = scheduled_seq_group.token_chunk_size seq_group.maybe_set_first_scheduled_time(now) + seq_group_metadata = self._seq_group_metadata_cache[ + self.cache_id].get_object() + seq_group_metadata.seq_data.clear() + seq_group_metadata.block_tables.clear() + # seq_id -> SequenceData seq_data: Dict[int, SequenceData] = {} # seq_id -> physical block numbers @@ -1062,7 +1138,9 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: if seq_group.is_encoder_decoder(): # Encoder associated with SequenceGroup - encoder_seq_data = seq_group.get_encoder_seq().data + encoder_seq = seq_group.get_encoder_seq() + assert encoder_seq is not None + encoder_seq_data = encoder_seq.data # Block table for cross-attention # Also managed at SequenceGroup level cross_block_table = self.block_manager.get_cross_block_table( @@ -1144,6 +1222,10 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: ) seq_group_metadata_list.append(seq_group_metadata) + if allow_async_output_proc: + allow_async_output_proc = self._allow_async_output_proc( + seq_group) + # Now that the batch has been created, we can assume all blocks in the # batch will have been computed before the next scheduling invocation. # This is because the engine assumes that a failure in model execution @@ -1152,6 +1234,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: self.block_manager.mark_blocks_as_computed( scheduled_seq_group.seq_group) + self._seq_group_metadata_cache[self.next_cache_id].reset() + scheduler_time = time.perf_counter() - scheduler_start_time # Add this to scheduler time to all the sequences that are currently # running. 
This will help estimate if the scheduler is a significant @@ -1163,7 +1247,12 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: else: seq_group.metrics.scheduler_time = scheduler_time - return seq_group_metadata_list, scheduler_outputs + # Move to next cache (if exists) + self.cache_id = self.next_cache_id + + # Return results + return (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) def fork_seq(self, parent_seq: Sequence, child_seq: Sequence) -> None: self.block_manager.fork(parent_seq, child_seq) @@ -1172,6 +1261,12 @@ def free_seq(self, seq: Sequence) -> None: """Free a sequence from a block table.""" self.block_manager.free(seq) + def _free_finished_seqs(self, seq_group: SequenceGroup) -> None: + """Free finished seqs in a sequence group.""" + for seq in seq_group.get_seqs(): + if seq.is_finished(): + self.free_seq(seq) + def free_finished_seq_groups(self) -> None: remaining: Deque[SequenceGroup] = deque() for seq_group in self.running: @@ -1184,8 +1279,24 @@ def free_finished_seq_groups(self) -> None: self._finished_requests_ids.append(seq_group.request_id) else: remaining.append(seq_group) + + # Free finished seqs + self._free_finished_seqs(seq_group) + self.running = remaining + # Handle async stopped sequence groups + # (ones that reached max model len) + if self._async_stopped: + for seq_group in self._async_stopped: + self._free_seq_group_cross_attn_blocks(seq_group) + self._finished_requests_ids.append(seq_group.request_id) + + # Free finished seqs + self._free_finished_seqs(seq_group) + + self._async_stopped.clear() + def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None: self.block_manager.allocate(seq_group) for seq in seq_group.get_seqs(status=SequenceStatus.WAITING): diff --git a/vllm/distributed/device_communicators/custom_all_reduce_utils.py b/vllm/distributed/device_communicators/custom_all_reduce_utils.py index 37ae94c671e33..983e772a3f79b 100644 --- a/vllm/distributed/device_communicators/custom_all_reduce_utils.py +++ b/vllm/distributed/device_communicators/custom_all_reduce_utils.py @@ -4,6 +4,7 @@ import pickle import subprocess import sys +import tempfile from itertools import product from typing import Dict, List, Optional, Sequence @@ -211,20 +212,27 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: # However, `can_actually_p2p` requires spawn method. # The fix is, we use `subprocess` to call the function, # where we have `if __name__ == "__main__":` in this file. 
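The comment above motivates the change that follows. A generic, self-contained version of the same pattern (not vLLM code, POSIX only) looks like this: the result travels back through a temporary file so that any logging the child writes to stdout cannot corrupt the pickled payload.

import pickle
import subprocess
import sys
import tempfile

child_src = r"""
import pickle, sys
payload, out_path = pickle.loads(sys.stdin.buffer.read())
print("a stray log line that would have corrupted stdout-based pickling")
with open(out_path, "wb") as f:
    pickle.dump([x * x for x in payload], f)
"""

with tempfile.NamedTemporaryFile() as out:
    proc = subprocess.run([sys.executable, "-c", child_src],
                          input=pickle.dumps(([1, 2, 3], out.name)),
                          capture_output=True)
    proc.check_returncode()
    with open(out.name, "rb") as f:
        result = pickle.load(f)

assert result == [1, 4, 9]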
- input_bytes = pickle.dumps((batch_src, batch_tgt)) - returned = subprocess.run([sys.executable, __file__], - input=input_bytes, - capture_output=True) - # check if the subprocess is successful - try: - returned.check_returncode() - except Exception as e: - # wrap raised exception to provide more information - raise RuntimeError( - f"Error happened when batch testing " - f"peer-to-peer access from {batch_src} to {batch_tgt}:\n" - f"{returned.stderr.decode()}") from e - result = pickle.loads(returned.stdout) + + # use a temporary file to store the result + # we don't use the output of the subprocess directly, + # because the subprocess might produce logging output + with tempfile.NamedTemporaryFile() as output_file: + input_bytes = pickle.dumps( + (batch_src, batch_tgt, output_file.name)) + returned = subprocess.run([sys.executable, __file__], + input=input_bytes, + capture_output=True) + # check if the subprocess is successful + try: + returned.check_returncode() + except Exception as e: + # wrap raised exception to provide more information + raise RuntimeError( + f"Error happened when batch testing " + f"peer-to-peer access from {batch_src} to {batch_tgt}:\n" + f"{returned.stderr.decode()}") from e + with open(output_file.name, "rb") as f: + result = pickle.load(f) for _i, _j, r in zip(batch_src, batch_tgt, result): cache[f"{_i}->{_j}"] = r with open(path, "w") as f: @@ -241,6 +249,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool: __all__ = ["gpu_p2p_access_check"] if __name__ == "__main__": - batch_src, batch_tgt = pickle.loads(sys.stdin.buffer.read()) + batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read()) result = can_actually_p2p(batch_src, batch_tgt) - sys.stdout.buffer.write(pickle.dumps(result)) + with open(output_file, "wb") as f: + f.write(pickle.dumps(result)) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 987c1be3d5ad9..6e66198e203fc 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -147,6 +147,7 @@ class EngineArgs: otlp_traces_endpoint: Optional[str] = None collect_detailed_traces: Optional[str] = None + disable_async_output_proc: bool = False def __post_init__(self): if self.tokenizer is None: @@ -197,10 +198,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: '--tokenizer-mode', type=str, default=EngineArgs.tokenizer_mode, - choices=['auto', 'slow'], + choices=['auto', 'slow', 'mistral'], help='The tokenizer mode.\n\n* "auto" will use the ' 'fast tokenizer if available.\n* "slow" will ' - 'always use the slow tokenizer.') + 'always use the slow tokenizer. \n* ' + '"mistral" will always use the `mistral_common` tokenizer.') parser.add_argument('--trust-remote-code', action='store_true', help='Trust remote code from huggingface.') @@ -317,9 +319,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument('--block-size', type=int, default=EngineArgs.block_size, - choices=[8, 16, 32, 128, 256, 512, 1024, 2048], + choices=[8, 16, 32], help='Token block size for contiguous chunks of ' - 'tokens.') + 'tokens. This is ignored on neuron devices and ' + 'set to max-model-len') parser.add_argument('--enable-prefix-caching', action='store_true', @@ -732,6 +735,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "modules. 
This involves use of possibly costly and or blocking " "operations and hence might have a performance impact.") + parser.add_argument( + '--disable-async-output-proc', + action='store_true', + default=EngineArgs.disable_async_output_proc, + help="Disable async output processing. This may result in " + "lower performance.") return parser @classmethod @@ -791,9 +800,11 @@ def create_engine_config(self) -> EngineConfig: skip_tokenizer_init=self.skip_tokenizer_init, served_model_name=self.served_model_name, limit_mm_per_prompt=self.limit_mm_per_prompt, + use_async_output_proc=not self.disable_async_output_proc, ) cache_config = CacheConfig( - block_size=self.block_size, + block_size=self.block_size if self.device != "neuron" else + self.max_model_len, # neuron needs block_size = max_model_len gpu_memory_utilization=self.gpu_memory_utilization, swap_space=self.swap_space, cache_dtype=self.kv_cache_dtype, diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index a2a80b1412132..10e14ff996f36 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -277,23 +277,36 @@ async def step_async( cached_outputs = self.cached_scheduler_outputs[virtual_engine] seq_group_metadata_list = cached_outputs.seq_group_metadata_list scheduler_outputs = cached_outputs.scheduler_outputs + allow_async_output_proc = cached_outputs.allow_async_output_proc + # skip the scheduler if there are any remaining steps in the seq groups. # This ensures that the scheduler is only called again when the current # batch has completed. if not self._has_remaining_steps(seq_group_metadata_list): - seq_group_metadata_list, scheduler_outputs = self.scheduler[ - virtual_engine].schedule() + (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc + ) = self.scheduler[virtual_engine].schedule() + + # If current scheduler iteration has no async postprocessor, + # then we need first to drain the pending async postprocessor + # before moving forward + if not allow_async_output_proc and len(self.output_queue) > 0: + self._process_model_outputs(is_async=True) if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): # cache the scheduler outputs for the next iteration if we have # lookahead slots self._cache_scheduler_outputs_for_multi_step( - virtual_engine, seq_group_metadata_list, scheduler_outputs) + virtual_engine, seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) assert seq_group_metadata_list is not None assert scheduler_outputs is not None + assert not (self.scheduler_config.is_multi_step and \ + allow_async_output_proc) + if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ virtual_engine].get_and_reset_finished_requests_ids() @@ -317,6 +330,11 @@ async def step_async( # We use ExecuteModelRequest to pass the last sampled_token_ids # to each of the non-last PP stages for in-place prepare_input. last_sampled_token_ids=last_sampled_token_ids) + + if allow_async_output_proc: + execute_model_req.output_proc_callback_fn = \ + self._process_model_outputs + # Execute the model. output = await self.model_executor.execute_model_async( execute_model_req) @@ -325,6 +343,9 @@ async def step_async( if self.scheduler_config.is_multi_step: self._update_cached_scheduler_output(virtual_engine, output) else: + if len(self.output_queue) > 0: + assert not self.scheduler_config.is_multi_step + self._process_model_outputs(is_async=True) output = [] # Finish the current step for all the sequence groups. 
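Usage-wise, the new engine argument introduced earlier in this file maps onto the config as sketched below (hedged example: it assumes the model config can be fetched from the Hub, and facebook/opt-125m is only a placeholder model name):

from vllm.engine.arg_utils import EngineArgs

# Opting out of the new async output processor, e.g. to compare performance.
args = EngineArgs(model="facebook/opt-125m", disable_async_output_proc=True)
config = args.create_engine_config()
assert config.model_config.use_async_output_proc is False

On the command line the same switch is spelled --disable-async-output-proc, as used by the multi-step server test earlier in this patch.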
@@ -337,19 +358,32 @@ async def step_async( if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[ virtual_engine] = SchedulerOutputState() - request_outputs = self._process_model_outputs( - output, scheduler_outputs.scheduled_seq_groups, - scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) - else: - request_outputs = [] - # Log stats. - self.do_log_stats(scheduler_outputs, output) + # Cache results in engine + self.output_queue.append( + (output, seq_group_metadata_list, scheduler_outputs)) - # Tracing - self.do_tracing(scheduler_outputs) + if output and allow_async_output_proc: + assert len( + output + ) == 1, "Multi step decoding does not work with async output processing." # noqa: E501 + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) - return request_outputs + if not allow_async_output_proc: + self._process_model_outputs(is_async=False) + + # Log stats. + self.do_log_stats(scheduler_outputs, output) + + # Tracing + self.do_tracing(scheduler_outputs) + + else: + self.request_outputs = [] + + return self.request_outputs async def stop_remote_worker_execution_loop_async(self) -> None: """Stop the remote worker execution loop.""" @@ -632,6 +666,11 @@ def _get_executor_cls( initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_xpu_executor import RayXPUExecutorAsync executor_class = RayXPUExecutorAsync + elif distributed_executor_backend == "mp": + initialize_ray_cluster(engine_config.parallel_config) + from vllm.executor.multiproc_xpu_executor import ( + MultiprocessingXPUExecutorAsync) + executor_class = MultiprocessingXPUExecutorAsync else: raise RuntimeError( "Not supported distributed execution model on XPU device.") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 79072e403dc1b..addde032f2639 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,7 +1,8 @@ import time +from collections import deque from contextlib import contextmanager from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, ClassVar, Dict, Iterable, List, +from typing import (TYPE_CHECKING, Any, ClassVar, Deque, Dict, Iterable, List, Mapping, Optional) from typing import Sequence as GenericSequence from typing import Set, Tuple, Type, Union @@ -38,9 +39,8 @@ from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.sampling_params import SamplingParams from vllm.sequence import (EmbeddingSequenceGroupOutput, ExecuteModelRequest, - PoolerOutput, SamplerOutput, Sequence, - SequenceGroup, SequenceGroupMetadata, - SequenceStatus) + SamplerOutput, Sequence, SequenceGroup, + SequenceGroupMetadata, SequenceStatus) from vllm.tracing import (SpanAttributes, SpanKind, extract_trace_context, init_tracer) from vllm.transformers_utils.config import try_get_generation_config @@ -82,9 +82,10 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: @dataclass class SchedulerOutputState: """Caches the scheduler outputs for a virtual engine. 
Used for Multi-Step""" - last_output: Optional[SamplerOutput] = None seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] = None scheduler_outputs: Optional[SchedulerOutputs] = None + allow_async_output_proc: bool = False + last_output: Optional[SamplerOutput] = None class LLMEngine: @@ -190,6 +191,9 @@ def __init__( usage_context: UsageContext = UsageContext.ENGINE_CONTEXT, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, input_registry: InputRegistry = INPUT_REGISTRY, + # To improve performance, only final requests outputs may be required. + # If this set to true, then no intermediate outputs will be returned. + step_return_finished_only: bool = False, ) -> None: logger.info( "Initializing an LLM engine (v%s) with config: " @@ -204,7 +208,8 @@ def __init__( "quantization_param_path=%s, device_config=%s, " "decoding_config=%r, observability_config=%r, " "seed=%d, served_model_name=%s, use_v2_block_manager=%s, " - "num_scheduler_steps=%d, enable_prefix_caching=%s)", + "num_scheduler_steps=%d, enable_prefix_caching=%s, " + "use_async_output_proc=%s)", VLLM_VERSION, model_config.model, speculative_config, @@ -235,6 +240,7 @@ def __init__( scheduler_config.use_v2_block_manager, scheduler_config.num_scheduler_steps, cache_config.enable_prefix_caching, + model_config.use_async_output_proc, ) # TODO(woosuk): Print more configs in debug mode. from vllm.plugins import load_general_plugins @@ -253,6 +259,7 @@ def __init__( self.observability_config = observability_config or ObservabilityConfig( ) self.log_stats = log_stats + self.step_return_finished_only = step_return_finished_only if not self.model_config.skip_tokenizer_init: self.tokenizer = self._init_tokenizer() @@ -340,8 +347,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: # NOTE: the cache_config here have been updated with the numbers of # GPU and CPU blocks, which are profiled in the distributed executor. self.scheduler = [ - Scheduler(scheduler_config, cache_config, lora_config, - parallel_config.pipeline_parallel_size) + Scheduler( + scheduler_config, cache_config, lora_config, + parallel_config.pipeline_parallel_size, + self._process_model_outputs + if model_config.use_async_output_proc else None) for _ in range(parallel_config.pipeline_parallel_size) ] @@ -396,6 +406,13 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: for _ in range(self.parallel_config.pipeline_parallel_size) ] + # Async output processing pointers + self.output_queue: Deque[Tuple[List[SamplerOutput], + List[SequenceGroupMetadata], + SchedulerOutputs]] = deque() + self.request_outputs: List[Union[RequestOutput, + EmbeddingRequestOutput]] = [] + def _initialize_kv_caches(self) -> None: """Initialize the KV cache in the worker(s). @@ -455,6 +472,13 @@ def _get_executor_cls(cls, initialize_ray_cluster(engine_config.parallel_config) from vllm.executor.ray_xpu_executor import RayXPUExecutor executor_class = RayXPUExecutor + elif distributed_executor_backend == "mp": + # FIXME(kunshang): + # spawn needs calling `if __name__ == '__main__':`` + # fork is not supported for xpu start new process. 
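To make the new output_queue plumbing easier to follow, here is a standalone toy (function names are invented; the real engine queues (output, seq_group_metadata_list, scheduler_outputs) tuples and drains them via _process_model_outputs):

from collections import deque

output_queue = deque()

def execute_step(step: int, allow_async: bool):
    # Stand-in for the model forward pass of `step`.
    output_queue.append(f"raw outputs of step {step}")
    if not allow_async:
        process_outputs()          # synchronous fallback path

def process_outputs():
    # Stand-in for the postprocessor draining the queue.
    while output_queue:
        print("processed:", output_queue.popleft())

execute_step(0, allow_async=True)  # outputs stay queued
execute_step(1, allow_async=True)  # overlaps with processing of step 0
process_outputs()                  # drained later, e.g. via the callback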
+ logger.error( + "Both start methods (spawn and fork) have issue " + "on XPU if you use mp backend, Please try ray instead.") else: from vllm.executor.xpu_executor import XPUExecutor executor_class = XPUExecutor @@ -1197,34 +1221,66 @@ def _process_sequence_group_outputs( return - def _process_model_outputs( - self, - output: GenericSequence[Union[SamplerOutput, PoolerOutput]], - scheduled_seq_groups: List[ScheduledSequenceGroup], - ignored_seq_groups: List[SequenceGroup], - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: + def _process_model_outputs(self, + is_async: bool, + clear_outputs: bool = True) -> None: """Apply the model output to the sequences in the scheduled seq groups. + is_async: Indicates whether this postprocessor runs in + parallel with the GPU forward pass and is processing + tokens from the previous step. If this is true, then + no tokens need to be appended since it is already done + externally (before the next schedule() call) + clear_outputs: Sometimes existing outputs need to be combined + with outputs of this call. This happens for postprocessor + draining at the final stage (like when sequences are finished) + Returns RequestOutputs that can be returned to the client. """ - now = time.time() - # Organize outputs by [sequence group][step] instead of - # [step][sequence group]. - output_by_sequence_group = create_output_by_sequence_group( - output, num_seq_groups=len(scheduled_seq_groups)) + if clear_outputs: + self.request_outputs.clear() + + if len(self.output_queue) == 0: + return None + + (outputs, seq_group_metadata_list, + scheduler_outputs) = self.output_queue.popleft() + + # Sanity check + assert len(seq_group_metadata_list) == len( + scheduler_outputs.scheduled_seq_groups) + + # Organize outputs by [step][sequence group] instead of + # [sequence group][step]. + if len(outputs) > 1: + outputs_by_sequence_group = create_output_by_sequence_group( + outputs, num_seq_groups=len(seq_group_metadata_list)) + else: + outputs_by_sequence_group = outputs + + finished_before: List[int] = [] + for i, seq_group_meta in enumerate(seq_group_metadata_list): + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] - # Update the scheduled sequence groups with the model outputs. 
- for scheduled_seq_group, outputs, seq_group_meta in zip( - scheduled_seq_groups, output_by_sequence_group, - seq_group_metadata_list): seq_group = scheduled_seq_group.seq_group - seq_group.update_num_computed_tokens( - scheduled_seq_group.token_chunk_size) - if output is not None and len(output) > 0: - for o in output: + + if seq_group.is_finished(): + finished_before.append(i) + continue + + if len(outputs) > 1: + output = outputs_by_sequence_group[i] + else: + output = [outputs_by_sequence_group[0][i]] + + if not is_async: + seq_group.update_num_computed_tokens( + scheduled_seq_group.token_chunk_size) + + if outputs: + for o in outputs: if (isinstance(o, SamplerOutput) and seq_group.metrics is not None): if seq_group.metrics.model_forward_time is not None: @@ -1239,30 +1295,75 @@ def _process_model_outputs( else: seq_group.metrics.model_execute_time = ( o.model_execute_time) + if self.model_config.embedding_mode: - self._process_sequence_group_outputs(seq_group, outputs) + self._process_sequence_group_outputs(seq_group, output) continue - self.output_processor.process_prompt_logprob(seq_group, outputs) + self.output_processor.process_prompt_logprob(seq_group, output) if seq_group_meta.do_sample: - self.output_processor.process_outputs(seq_group, outputs) + self.output_processor.process_outputs(seq_group, output, + is_async) # Free the finished sequence groups. for scheduler in self.scheduler: scheduler.free_finished_seq_groups() # Create the outputs. - request_outputs: List[Union[RequestOutput, - EmbeddingRequestOutput]] = [] - for scheduled_seq_group in scheduled_seq_groups: + for i, _ in enumerate(seq_group_metadata_list): + scheduled_seq_group = scheduler_outputs.scheduled_seq_groups[i] + + if i in finished_before: + continue # Avoids double processing + seq_group = scheduled_seq_group.seq_group seq_group.maybe_set_first_token_time(now) + if (seq_group.is_finished() + if self.step_return_finished_only else True): + request_output = RequestOutputFactory.create(seq_group) + self.request_outputs.append(request_output) + + for seq_group in scheduler_outputs.ignored_seq_groups: request_output = RequestOutputFactory.create(seq_group) - request_outputs.append(request_output) - for seq_group in ignored_seq_groups: - request_output = RequestOutputFactory.create(seq_group) - request_outputs.append(request_output) - return request_outputs + self.request_outputs.append(request_output) + + if is_async: + # Log stats. + self.do_log_stats(scheduler_outputs, outputs, finished_before) + + # Tracing + self.do_tracing(scheduler_outputs) + + return None + + def _advance_to_next_step( + self, output: List[SamplerOutput], + seq_group_metadata_list: List[SequenceGroupMetadata], + scheduled_seq_groups: List[ScheduledSequenceGroup]) -> None: + """Given model output from a single run, append the tokens to the + sequences. This is normally done inside output processor, but it is + required if the worker is to perform async forward pass to next step. 
+ """ + for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in \ + zip(seq_group_metadata_list, output, scheduled_seq_groups): + seq_group = scheduled_seq_group.seq_group + + if seq_group.is_finished(): + continue + + seq_group.update_num_computed_tokens( + seq_group_metadata.token_chunk_size) + + if seq_group_metadata.do_sample: + assert len(sequence_group_outputs.samples) == 1, ( + "Async output processor expects a single sample" + " (i.e sampling_params.n == 1 and no " + "sampling_params.best_of > 1)") + sample = sequence_group_outputs.samples[0] + + assert len(seq_group.seqs) == 1 + seq = seq_group.seqs[0] + seq.append_token_id(sample.output_token, sample.logprobs) def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: """Performs one decoding iteration and returns newly generated results. @@ -1325,24 +1426,32 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: cached_outputs = self.cached_scheduler_outputs[0] seq_group_metadata_list = cached_outputs.seq_group_metadata_list scheduler_outputs = cached_outputs.scheduler_outputs + allow_async_output_proc = cached_outputs.allow_async_output_proc # Skip the scheduler if there are any remaining steps in the seq groups. # This ensures that the scheduler is only called again when the current # batch has completed. if not self._has_remaining_steps(seq_group_metadata_list): - seq_group_metadata_list, scheduler_outputs = self.scheduler[ - 0].schedule() + (seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) = self.scheduler[0].schedule() + + if not allow_async_output_proc and len(self.output_queue) > 0: + self._process_model_outputs(is_async=True) if (self.scheduler_config.is_multi_step and scheduler_outputs.num_lookahead_slots > 0): # cache the scheduler outputs for the next iteration if we have # lookahead slots self._cache_scheduler_outputs_for_multi_step( - 0, seq_group_metadata_list, scheduler_outputs) + 0, seq_group_metadata_list, scheduler_outputs, + allow_async_output_proc) assert seq_group_metadata_list is not None assert scheduler_outputs is not None + assert not (self.scheduler_config.is_multi_step and \ + allow_async_output_proc) + if not scheduler_outputs.is_empty(): finished_requests_ids = self.scheduler[ 0].get_and_reset_finished_requests_ids() @@ -1366,6 +1475,10 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # to each of the non-last PP stages for in-place prepare_input. last_sampled_token_ids=last_sampled_token_ids) + if allow_async_output_proc: + execute_model_req.output_proc_callback_fn = \ + self._process_model_outputs + output = self.model_executor.execute_model( execute_model_req=execute_model_req) @@ -1374,6 +1487,9 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: if self.scheduler_config.is_multi_step: self._update_cached_scheduler_output(0, output) else: + if len(self.output_queue) > 0: + assert not self.scheduler_config.is_multi_step + self._process_model_outputs(is_async=True) output = [] # Finish the current step for all the sequence groups. @@ -1382,23 +1498,41 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: seq_group.finish_step() if not self._has_remaining_steps(seq_group_metadata_list): - # clear the cache if we have finished all the steps + # clear the cache if we have finished all the steps. 
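_advance_to_next_step above is the piece that lets the next forward pass start before detokenization finishes; a heavily simplified standalone version of the idea (ToySequence is made up for illustration) is:

class ToySequence:
    def __init__(self, prompt_token_ids):
        self.token_ids = list(prompt_token_ids)

    def append_token_id(self, token_id):
        self.token_ids.append(token_id)

seq = ToySequence([1, 2, 3])
sampled_token = 42                # produced by step N's sampler

# Advance immediately so step N+1 can be scheduled with the new token...
seq.append_token_id(sampled_token)

# ...while detokenization and stop checks for token 42 are handled later by
# the lagged output processor.
assert seq.token_ids == [1, 2, 3, 42]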
if self.scheduler_config.is_multi_step: self.cached_scheduler_outputs[0] = SchedulerOutputState() - request_outputs = self._process_model_outputs( - output, scheduler_outputs.scheduled_seq_groups, - scheduler_outputs.ignored_seq_groups, seq_group_metadata_list) - else: - request_outputs = [] + # Add results to the output_queue + # (for async or non-async postprocessing) + self.output_queue.append( + (output, seq_group_metadata_list, scheduler_outputs)) - # Log stats. - self.do_log_stats(scheduler_outputs, output) + if output and allow_async_output_proc: + assert len(output) == 1, ("Multi step decoding does not work " + "with async output processing.") - # Tracing - self.do_tracing(scheduler_outputs) + self._advance_to_next_step( + output[0], seq_group_metadata_list, + scheduler_outputs.scheduled_seq_groups) + + if not allow_async_output_proc: + self._process_model_outputs(is_async=False) + + # Log stats. + self.do_log_stats(scheduler_outputs, output) + + # Tracing + self.do_tracing(scheduler_outputs) + else: + self.request_outputs = [] if not self.has_unfinished_requests(): + # Drain async postprocessor + if len(self.output_queue) > 0: + assert not self.scheduler_config.is_multi_step + self._process_model_outputs(is_async=True, clear_outputs=False) + assert len(self.output_queue) == 0 + # Stop the execute model loop in parallel workers until there are # more requests to process. This avoids waiting indefinitely in # torch.distributed ops which may otherwise timeout, and unblocks @@ -1406,7 +1540,7 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]: # queued control plane messages, such as add/remove lora adapters. self.model_executor.stop_remote_worker_execution_loop() - return request_outputs + return self.request_outputs def _has_remaining_steps( self, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]] @@ -1431,12 +1565,14 @@ def _has_remaining_steps( def _cache_scheduler_outputs_for_multi_step( self, virtual_engine: int, seq_group_metadata_list: Optional[List[SequenceGroupMetadata]], - scheduler_outputs: SchedulerOutputs) -> None: - self.cached_scheduler_outputs[ - virtual_engine].seq_group_metadata_list = seq_group_metadata_list - self.cached_scheduler_outputs[virtual_engine].scheduler_outputs = \ - scheduler_outputs - self.cached_scheduler_outputs[virtual_engine].last_output = None + scheduler_outputs: SchedulerOutputs, + allow_async_output_proc: bool) -> None: + co = self.cached_scheduler_outputs[virtual_engine] + + co.seq_group_metadata_list = seq_group_metadata_list + co.scheduler_outputs = scheduler_outputs + co.allow_async_output_proc = allow_async_output_proc + co.last_output = None def _update_cached_scheduler_output( self, virtual_engine: int, @@ -1472,20 +1608,21 @@ def remove_logger(self, logger_name: str) -> None: raise KeyError(f"Logger with name {logger_name} does not exist.") del self.stat_loggers[logger_name] - def do_log_stats( - self, - scheduler_outputs: Optional[SchedulerOutputs] = None, - model_output: Optional[List[SamplerOutput]] = None) -> None: + def do_log_stats(self, + scheduler_outputs: Optional[SchedulerOutputs] = None, + model_output: Optional[List[SamplerOutput]] = None, + finished_before: Optional[List[int]] = None) -> None: """Forced log when no requests active.""" if self.log_stats: - stats = self._get_stats(scheduler_outputs, model_output) + stats = self._get_stats(scheduler_outputs, model_output, + finished_before) for logger in self.stat_loggers.values(): logger.log(stats) - def _get_stats( - self, - 
scheduler_outputs: Optional[SchedulerOutputs], - model_output: Optional[List[SamplerOutput]] = None) -> Stats: + def _get_stats(self, + scheduler_outputs: Optional[SchedulerOutputs], + model_output: Optional[List[SamplerOutput]] = None, + finished_before: Optional[List[int]] = None) -> Stats: """Get Stats to be Logged to Prometheus. Args: @@ -1550,6 +1687,10 @@ def _get_stats( # NOTE: This loop assumes prefill seq_groups are before # decode seq_groups in scheduled_seq_groups. if scheduler_outputs is not None: + # For async postprocessor, already finished sequences need to be + # not counted (to avoid double counting) + actual_num_batched_tokens = scheduler_outputs.num_batched_tokens # type: ignore + num_generation_tokens_from_prefill_groups = 0. # NOTE: if scheduler_outputs.num_prefill_groups > 0 and # the len of scheduler_outputs.scheduled_seq_groups is != @@ -1558,6 +1699,11 @@ def _get_stats( for idx, scheduled_seq_group in enumerate( scheduler_outputs.scheduled_seq_groups): + # Skip double logging when using async output proc + if finished_before and idx in finished_before: + actual_num_batched_tokens -= 1 + continue + group_was_prefill = idx < scheduler_outputs.num_prefill_groups seq_group = scheduled_seq_group.seq_group @@ -1592,7 +1738,6 @@ def _get_stats( # Latency timings time_e2e_requests.append(now - seq_group.metrics.arrival_time) - # Metadata num_prompt_tokens_requests.append( len(seq_group.prompt_token_ids)) @@ -1616,7 +1761,7 @@ def _get_stats( # + num_generation_tokens_from_prefill_groups (since we generate # one token on prefills on iters where the prefill finishes). num_generation_tokens_iter = ( - scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter + + actual_num_batched_tokens - num_prompt_tokens_iter + num_generation_tokens_from_prefill_groups) # Spec decode, if enabled, emits specialized metrics from the worker in diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index a385f37d807ad..50adaf4e59188 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -40,13 +40,9 @@ def create_output_processor( # Importing here to avoid cycle. from vllm.engine.output_processor.single_step import ( SingleStepOutputProcessor) - return SingleStepOutputProcessor( - scheduler_config, - detokenizer, - scheduler, - seq_counter, - stop_checker, - ) + return SingleStepOutputProcessor(scheduler_config, detokenizer, + scheduler, seq_counter, + stop_checker) else: # Importing here to avoid cycle. from vllm.engine.output_processor.multi_step import ( @@ -61,7 +57,8 @@ def create_output_processor( @abstractmethod def process_outputs(self, sequence_group: SequenceGroup, - outputs: List[SequenceGroupOutput]) -> None: + outputs: List[SequenceGroupOutput], + is_async: bool) -> None: """Process new token ids for the sequence group. Handles logic such as detokenization, stop checking, and freeing/forking sequences in the scheduler. diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 6c472528a7a9c..49a33ded5fcaa 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -57,17 +57,28 @@ def _log_prompt_logprob_unsupported_warning_once(): "Prompt logprob is not supported by multi step workers. 
" "(e.g., speculative decode uses multi step workers).") - def process_outputs(self, sequence_group: SequenceGroup, - outputs: List[SequenceGroupOutput]) -> None: + def process_outputs(self, + sequence_group: SequenceGroup, + outputs: List[SequenceGroupOutput], + is_async: bool = False) -> None: """Append new tokens in the outputs to sequences in the sequence group. This only supports sequence groups of size 1. It supports greater than one new token per sequence. - This applies logic like stop condition checking and detokenization, - including freeing finished sequences. It also handles cases where there - are tokens emitted after the EOS token. + This applies logic like stop condition checking and detokenization. + It also handles cases where there are tokens emitted after + the EOS token. + + is_async - Indicates whether this postprocessor runs in + parallel with the GPU forward pass and is processing + tokens from the previous step. If this is true, then + no tokens need to be appended since it is already done + externally (before the next schedule() call) """ + # TODO: Add support for async if necessary + assert not is_async + seqs = sequence_group.get_seqs(status=SequenceStatus.RUNNING) assert seqs, "expected running sequences" @@ -138,7 +149,3 @@ def _process_seq_outputs(self, seq: Sequence, ) if seq.is_finished(): break - - if seq.is_finished(): - for scheduler in self.scheduler: - scheduler.free_seq(seq) diff --git a/vllm/engine/output_processor/single_step.py b/vllm/engine/output_processor/single_step.py index 4a46c93f84256..4b0c3f37a5e21 100644 --- a/vllm/engine/output_processor/single_step.py +++ b/vllm/engine/output_processor/single_step.py @@ -29,14 +29,9 @@ class SingleStepOutputProcessor(SequenceGroupOutputProcessor): that is currently difficult to schedule multiple steps ahead of time. """ - def __init__( - self, - scheduler_config: SchedulerConfig, - detokenizer: Detokenizer, - scheduler: List[Scheduler], - seq_counter: Counter, - stop_checker: StopChecker, - ): + def __init__(self, scheduler_config: SchedulerConfig, + detokenizer: Detokenizer, scheduler: List[Scheduler], + seq_counter: Counter, stop_checker: StopChecker): self.scheduler_config = scheduler_config self.detokenizer = detokenizer self.scheduler = scheduler @@ -44,16 +39,24 @@ def __init__( self.stop_checker = stop_checker def process_outputs(self, sequence_group: SequenceGroup, - outputs: List[SequenceGroupOutput]) -> None: + outputs: List[SequenceGroupOutput], + is_async: bool) -> None: """Append all new tokens to sequences in the sequence group. Fork any surviving beam candidates; free any unsurviving ones. Invokes detokenizer to detokenize new tokens, and also marks sequences as finished if they meet stop conditions. + + is_async - Indicates whether this postprocessor runs in + parallel with the GPU forward pass and is processing + tokens from the previous step. 
If this is true, then + no tokens need to be appended since it is already done + externally (before the next schedule() call) """ assert (len(outputs) == 1 ), f"{type(self)} does not support multiple outputs per step" - return self._process_sequence_group_outputs(sequence_group, outputs[0]) + return self._process_sequence_group_outputs(sequence_group, outputs[0], + is_async) def process_prompt_logprob(self, seq_group: SequenceGroup, outputs: List[SequenceGroupOutput]) -> None: @@ -80,14 +83,16 @@ def process_prompt_logprob(self, seq_group: SequenceGroup, seq_group.prompt_logprobs.extend(prompt_logprobs) def _process_sequence_group_outputs(self, seq_group: SequenceGroup, - outputs: SequenceGroupOutput) -> None: + outputs: SequenceGroupOutput, + is_async: bool) -> None: sampling_params = seq_group.sampling_params if sampling_params.n == 1 and not sampling_params.use_beam_search: # only have one output sample sample = outputs.samples[0] # only have one sequence seq = seq_group.seqs[0] - seq.append_token_id(sample.output_token, sample.logprobs) + if not is_async: + seq.append_token_id(sample.output_token, sample.logprobs) if sampling_params.detokenize and self.detokenizer: new_char_count = self.detokenizer.decode_sequence_inplace( seq, sampling_params) @@ -104,6 +109,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, scheduler.free_seq(seq) return + # TODO: Add support for async for beam search + assert not is_async + # Process samples samples = outputs.samples parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index 19d1095084293..c5368ac3bf026 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -267,7 +267,7 @@ def apply_chat_template( *, tokenize: bool = False, # Different from HF's default **kwargs: Any, -) -> str: +) -> Union[str, List[int]]: if chat_template is None and tokenizer.chat_template is None: raise ValueError( "As of transformers v4.44, default chat template is no longer " @@ -280,6 +280,4 @@ def apply_chat_template( tokenize=tokenize, **kwargs, ) - assert isinstance(prompt, str) - return prompt diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 31175724c6c79..0edd4bfaecd6a 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -129,6 +129,7 @@ def __init__( max_context_len_to_capture: Optional[int] = None, max_seq_len_to_capture: int = 8192, disable_custom_all_reduce: bool = False, + disable_async_output_proc: bool = False, **kwargs, ) -> None: ''' @@ -170,6 +171,7 @@ def __init__( max_context_len_to_capture=max_context_len_to_capture, max_seq_len_to_capture=max_seq_len_to_capture, disable_custom_all_reduce=disable_custom_all_reduce, + disable_async_output_proc=disable_async_output_proc, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args( @@ -388,15 +390,21 @@ def chat( conversations, _ = parse_chat_messages(messages, model_config, tokenizer) - prompts = apply_chat_template( + prompt = apply_chat_template( tokenizer, conversations, chat_template=chat_template, add_generation_prompt=add_generation_prompt) + inputs: PromptInputs + if isinstance(prompt, list) and isinstance(prompt[0], int): + inputs = TokensPrompt(prompt_token_ids=prompt) + else: + inputs = TextPrompt(prompt=prompt) + return self.generate( - prompts, - sampling_params, + inputs, + sampling_params=sampling_params, use_tqdm=use_tqdm, lora_request=lora_request, ) @@ -603,7 +611,6 @@ def _validate_and_add_requests( inputs = [inputs] 
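With apply_chat_template now returning either a rendered string or a list of token ids, LLM.chat (above) has to wrap the result in the matching prompt type before calling generate. A minimal standalone sketch of that dispatch follows; the to_prompt_inputs helper and the dict stand-ins for TextPrompt/TokensPrompt are illustrative only, and the emptiness check is an addition of this sketch.

    from typing import List, Union

    def to_prompt_inputs(prompt: Union[str, List[int]]) -> dict:
        """Mirror the branch added to LLM.chat: token ids become a
        TokensPrompt-style input, everything else a TextPrompt-style input."""
        if isinstance(prompt, list) and prompt and isinstance(prompt[0], int):
            return {"prompt_token_ids": prompt}   # TokensPrompt(prompt_token_ids=...)
        return {"prompt": prompt}                 # TextPrompt(prompt=...)

    # e.g. tokenize=False yields a string, tokenize=True yields token ids:
    assert to_prompt_inputs("Hello!") == {"prompt": "Hello!"}
    assert to_prompt_inputs([1, 2, 3]) == {"prompt_token_ids": [1, 2, 3]}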
num_requests = len(inputs) - if isinstance(params, list) and len(params) != num_requests: raise ValueError("The lengths of prompts and params " "must be the same.") @@ -678,6 +685,10 @@ def _run_engine( postfix=(f"est. speed input: {0:.2f} toks/s, " f"output: {0:.2f} toks/s"), ) + + # In the loop below, only finished outputs are used + self.llm_engine.step_return_finished_only = True + # Run the engine. outputs: List[Union[RequestOutput, EmbeddingRequestOutput]] = [] total_in_toks = 0 @@ -700,6 +711,10 @@ def _run_engine( f"est. speed input: {in_spd:.2f} toks/s, " f"output: {out_spd:.2f} toks/s") pbar.update(1) + + # Restore original behavior + self.llm_engine.step_return_finished_only = False + if use_tqdm: pbar.close() # Sort the outputs by request ID. diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 4d8e240a88ee6..d31ac4995fe2f 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -22,7 +22,8 @@ FunctionCall, ToolCall, UsageInfo) from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, OpenAIServing, - PromptAdapterPath) + PromptAdapterPath, + TextTokensPrompt) from vllm.inputs import TokensPrompt from vllm.logger import init_logger from vllm.multimodal import MultiModalDataDict @@ -130,13 +131,22 @@ async def create_chat_completion( guided_decode_logits_processor = ( await self._guided_decode_logits_processor(request, tokenizer)) - prompt_inputs = self._tokenize_prompt_input( - request, - tokenizer, - prompt, - truncate_prompt_tokens=request.truncate_prompt_tokens, - add_special_tokens=request.add_special_tokens, - ) + if isinstance(prompt, str): + prompt_inputs = self._tokenize_prompt_input( + request, + tokenizer, + prompt, + truncate_prompt_tokens=request.truncate_prompt_tokens, + add_special_tokens=request.add_special_tokens, + ) + else: + assert isinstance(prompt, list) and isinstance( + prompt[0], int + ), "Prompt has to be either a string or a list of token ids" + prompt_inputs = TextTokensPrompt( + prompt=tokenizer.decode(prompt), prompt_token_ids=prompt) + + assert prompt_inputs is not None sampling_params = request.to_sampling_params( tokenizer, diff --git a/vllm/executor/distributed_gpu_executor.py b/vllm/executor/distributed_gpu_executor.py index 4df54a09e5e8c..1a35a7c3b8f75 100644 --- a/vllm/executor/distributed_gpu_executor.py +++ b/vllm/executor/distributed_gpu_executor.py @@ -64,8 +64,9 @@ def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks=num_cpu_blocks) def execute_model( - self, - execute_model_req: ExecuteModelRequest) -> List[SamplerOutput]: + self, + execute_model_req: ExecuteModelRequest, + ) -> List[SamplerOutput]: if self.parallel_worker_tasks is None: self.parallel_worker_tasks = self._run_workers( "start_worker_execution_loop", @@ -188,7 +189,7 @@ async def stop_remote_worker_execution_loop_async(self) -> None: @abstractmethod async def _driver_execute_model_async( self, - execute_model_req: Optional[ExecuteModelRequest] = None + execute_model_req: Optional[ExecuteModelRequest] = None, ) -> List[SamplerOutput]: """Execute the model asynchronously in the driver worker. 
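The _run_engine changes above toggle the new step_return_finished_only flag so that, while the async output processor is active, step() only materializes RequestOutputs for finished sequence groups during batch generation, and then restores the default. A simplified sketch of that pattern, assuming an engine object exposing the attributes named in the diff; the try/finally and the int() cast on request_id are additions of this sketch, not part of the diff.

    def run_to_completion(engine):
        """Collect only finished outputs, loosely mirroring LLM._run_engine."""
        finished = []
        # Only finished outputs are consumed below, so let the engine skip
        # building RequestOutputs for still-running sequence groups.
        engine.step_return_finished_only = True
        try:
            while engine.has_unfinished_requests():
                for out in engine.step():
                    if out.finished:
                        finished.append(out)
        finally:
            # Restore the original behavior for other callers.
            engine.step_return_finished_only = False
        # _run_engine likewise sorts by request ID before returning.
        return sorted(finished, key=lambda o: int(o.request_id))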
diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index 8346c3cc1d3ea..795692195f84d 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -176,5 +176,5 @@ async def execute_model_async( execute_model_req: ExecuteModelRequest, ) -> List[Union[SamplerOutput, PoolerOutput]]: output = await make_async(self.driver_worker.execute_model - )(execute_model_req=execute_model_req, ) + )(execute_model_req=execute_model_req) return output diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 08a35a074b37b..7b98fbea5cd0a 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -30,16 +30,12 @@ class MultiprocessingGPUExecutor(DistributedGPUExecutor): uses_ray: bool = False def _init_executor(self) -> None: + self._check_executor_parameters() + # Create the parallel GPU workers. world_size = self.parallel_config.world_size tensor_parallel_size = self.parallel_config.tensor_parallel_size - # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers - if "CUDA_VISIBLE_DEVICES" not in os.environ: - update_environment_variables({ - "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size)))) - }) - # Ensure that VLLM_INSTANCE_ID is set, to be inherited by workers os.environ["VLLM_INSTANCE_ID"] = get_vllm_instance_id() @@ -68,16 +64,6 @@ def _init_executor(self) -> None: if world_size > 1: maybe_set_triton_cache_manager() - cuda_device_count = cuda_device_count_stateless() - # Use confusing message for more common TP-only case. - assert tensor_parallel_size <= cuda_device_count, ( - f"please set tensor_parallel_size ({tensor_parallel_size}) " - f"to less than max local gpu count ({cuda_device_count})") - - assert world_size <= cuda_device_count, ( - f"please ensure that world_size ({world_size}) " - f"is less than than max local gpu count ({cuda_device_count})") - # Multiprocessing-based executor does not support multi-node setting. # Since it only works for single node, we can use the loopback address # 127.0.0.1 for communication. @@ -139,6 +125,26 @@ def shutdown(signum, frame): max_concurrent_workers=self.parallel_config. max_parallel_loading_workers) + def _check_executor_parameters(self): + world_size = self.parallel_config.tensor_parallel_size + tensor_parallel_size = self.parallel_config.tensor_parallel_size + + # Set CUDA_VISIBLE_DEVICES for the driver, inherited by workers + if "CUDA_VISIBLE_DEVICES" not in os.environ: + update_environment_variables({ + "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size)))) + }) + + cuda_device_count = cuda_device_count_stateless() + # Use confusing message for more common TP-only case. 
+ assert tensor_parallel_size <= cuda_device_count, ( + f"please set tensor_parallel_size ({tensor_parallel_size}) " + f"to less than max local gpu count ({cuda_device_count})") + + assert world_size <= cuda_device_count, ( + f"please ensure that world_size ({world_size}) " + f"is less than than max local gpu count ({cuda_device_count})") + def shutdown(self): if (worker_monitor := getattr(self, "worker_monitor", None)) is not None: diff --git a/vllm/executor/multiproc_xpu_executor.py b/vllm/executor/multiproc_xpu_executor.py new file mode 100644 index 0000000000000..a66afbf939ef0 --- /dev/null +++ b/vllm/executor/multiproc_xpu_executor.py @@ -0,0 +1,26 @@ +import vllm.envs as envs +from vllm.executor.multiproc_gpu_executor import ( + MultiprocessingGPUExecutor, MultiprocessingGPUExecutorAsync) +from vllm.executor.xpu_executor import XPUExecutor +from vllm.logger import init_logger +from vllm.utils import make_async + +logger = init_logger(__name__) + + +class MultiprocessingXPUExecutor(MultiprocessingGPUExecutor, XPUExecutor): + """Python multiprocessing-based multi-XPU executor""" + + def _check_executor_parameters(self): + mp_method = envs.VLLM_WORKER_MULTIPROC_METHOD + if mp_method != "spawn": + raise RuntimeError( + "XPU multiprocess executor only support spawn as mp method") + + +class MultiprocessingXPUExecutorAsync(MultiprocessingXPUExecutor, + MultiprocessingGPUExecutorAsync): + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.driver_exec_model = make_async(self.driver_worker.execute_model) diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index b45d5d86b54fa..02627de3e0be7 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -4,7 +4,8 @@ from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.sequence import ExecuteModelRequest, SamplerOutput -from vllm.utils import make_async +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + make_async) logger = init_logger(__name__) @@ -24,14 +25,17 @@ def _init_executor(self) -> None: def _init_worker(self): from vllm.worker.neuron_worker import NeuronWorker - + distributed_init_method = get_distributed_init_method( + get_ip(), get_open_port()) self.driver_worker = NeuronWorker( - self.model_config, - self.parallel_config, - self.scheduler_config, - self.device_config, - self.cache_config, - ) + model_config=self.model_config, + parallel_config=self.parallel_config, + scheduler_config=self.scheduler_config, + device_config=self.device_config, + cache_config=self.cache_config, + local_rank=0, + rank=0, + distributed_init_method=distributed_init_method) self.driver_worker.init_device() self.driver_worker.load_model() diff --git a/vllm/model_executor/layers/fused_moe/__init__.py b/vllm/model_executor/layers/fused_moe/__init__.py index 3e0767c7d2665..fd6f41b90042e 100644 --- a/vllm/model_executor/layers/fused_moe/__init__.py +++ b/vllm/model_executor/layers/fused_moe/__init__.py @@ -1,19 +1,17 @@ -from vllm.model_executor.layers.fused_moe.layer import (FusedMoE, - FusedMoEMethodBase) +from vllm.model_executor.layers.fused_moe.layer import ( + FusedMoE, FusedMoEMethodBase, FusedMoeWeightScaleSupported) from vllm.triton_utils import HAS_TRITON -__all__ = [ - "FusedMoE", - "FusedMoEMethodBase", -] +__all__ = ["FusedMoE", "FusedMoEMethodBase", "FusedMoeWeightScaleSupported"] if HAS_TRITON: from vllm.model_executor.layers.fused_moe.fused_moe import ( - fused_experts, fused_moe, fused_topk, 
get_config_file_name, - grouped_topk) + fused_experts, fused_marlin_moe, fused_moe, fused_topk, + get_config_file_name, grouped_topk) __all__ += [ + "fused_marlin_moe", "fused_moe", "fused_topk", "fused_experts", diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index bcf25d2631042..d2b152320e11e 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -323,21 +323,16 @@ def get_moe_configs(E: int, N: int, return None -def get_default_config( - M: int, - E: int, - N: int, - K: int, - topk: int, - dtype: Optional[str], -) -> Dict[str, int]: +def get_default_config(M: int, E: int, N: int, K: int, topk: int, + dtype: Optional[str], + is_marlin: bool) -> Dict[str, int]: config = { 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8 } - if M <= E: + if M <= E or (is_marlin and M <= 32): config = { 'BLOCK_SIZE_M': 16, 'BLOCK_SIZE_N': 32, @@ -347,14 +342,14 @@ def get_default_config( return config -def try_get_optimal_moe_config( - w1_shape: Tuple[int, ...], - w2_shape: Tuple[int, ...], - top_k: int, - dtype: Optional[str], - M: int, - override_config: Optional[Dict[str, Any]] = None, -): +def try_get_optimal_moe_config(w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + override_config: Optional[Dict[str, + Any]] = None, + is_marlin: bool = False): if override_config: config = override_config else: @@ -368,7 +363,8 @@ def try_get_optimal_moe_config( config = configs[min(configs.keys(), key=lambda x: abs(x - M))] else: # Else use the default config - config = get_default_config(M, E, N, w1_shape[2], top_k, dtype) + config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, + is_marlin) return config @@ -441,6 +437,108 @@ def grouped_topk(hidden_states: torch.Tensor, return topk_weights, topk_ids +def fused_marlin_moe(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + g_idx1: torch.Tensor, + g_idx2: torch.Tensor, + rand_perm1: torch.Tensor, + rand_perm2: torch.Tensor, + topk: int, + renormalize: bool, + override_config: Optional[Dict[str, Any]] = None, + use_fp8: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. 
+ assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + assert hidden_states.shape[ + 1] == w1.shape[1] * 16, "Hidden size mismatch w1" + assert hidden_states.shape[ + 1] == w2.shape[2] // 2, "Hidden size mismatch w2" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + #TODO fp8 is not implemented yet + assert not use_fp8 + + M, K = hidden_states.shape + E = w1.shape[0] + N = w2.shape[1] * 16 + + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + + get_config_func = functools.partial(try_get_optimal_moe_config, + w1.shape, + w2.shape, + topk_ids.shape[1], + "float8" if use_fp8 else None, + override_config=override_config, + is_marlin=True) + config = get_config_func(M) + + block_size_m = config['BLOCK_SIZE_M'] + + sorted_token_ids, _, _ = moe_align_block_size(topk_ids, block_size_m, E) + + max_workspace_size = ((M + 255) // 256) * (max(2 * N, K) // 64) * 16 + workspace = torch.zeros(max_workspace_size, + dtype=torch.int, + device="cuda", + requires_grad=False) + + intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype) + + intermediate_cache1 = torch.ops._moe_C.marlin_gemm_moe( + hidden_states, w1, sorted_token_ids, topk_weights, topk_ids, w1_scale, + g_idx1, rand_perm1, workspace, M, 2 * N, K, True, E, topk, + block_size_m, True, False) + + ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, 2 * N)) + + intermediate_cache3 = torch.ops._moe_C.marlin_gemm_moe( + intermediate_cache2, w2, sorted_token_ids, topk_weights, topk_ids, + w2_scale, g_idx2, rand_perm2, workspace, M, K, N, True, E, topk, + block_size_m, False, True) + + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), + dim=1) + + def get_config_dtype_str(dtype: torch.dtype, use_int8_w8a16: Optional[bool] = False, use_fp8_w8a8: Optional[bool] = False): diff --git a/vllm/model_executor/layers/fused_moe/layer.py b/vllm/model_executor/layers/fused_moe/layer.py index 4e29ab701b937..61ebef5e11f43 100644 --- a/vllm/model_executor/layers/fused_moe/layer.py +++ b/vllm/model_executor/layers/fused_moe/layer.py @@ -1,4 +1,5 @@ from abc import abstractmethod +from enum import Enum from typing import List, Optional, Tuple import torch @@ -15,6 +16,12 @@ logger = init_logger(__name__) +class FusedMoeWeightScaleSupported(Enum): + TENSOR = "tensor" + CHANNEL = "channel" + GROUP = "group" + + class FusedMoEMethodBase(QuantizeMethodBase): @abstractmethod @@ -199,55 +206,182 @@ def __init__( params_dtype=params_dtype, weight_loader=self.weight_loader) + def _load_per_tensor_weight_scale(self, shard_id: str, + param: torch.nn.Parameter, + loaded_weight: torch.Tensor, + expert_id: int): + param_data = param.data + # for per tensor weight quantization + if shard_id in ("w1", "w3"): + # We have to keep the weight scales of w1 and w3 because + # we need to re-quantize w1/w3 weights after weight loading. 
+ idx = 0 if shard_id == "w1" else 1 + param_data[expert_id][idx] = loaded_weight + # If we are in the row parallel case (down_proj) + elif shard_id == "w2": + param_data[expert_id] = loaded_weight + + def _load_model_weight_or_group_weight_scale(self, shard_dim: int, + expert_data: torch.Tensor, + shard_id: str, + loaded_weight: torch.tensor, + tp_rank: int): + # Load grouped weight scales for group quantization + # or model weights + if shard_id == "w2": + self._load_w2(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + elif shard_id in ("w1", "w3"): + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + + def _load_per_channel_weight_scale(self, expert_data: torch.Tensor, + shard_dim: int, shard_id: str, + loaded_weight: torch.tensor, + tp_rank: int): + # for per channel weight quantization + if shard_id == "w2": + expert_data.copy_(loaded_weight) + elif shard_id in ("w1", "w3"): + self._load_w13(shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + + def _load_w13(self, expert_data: torch.Tensor, shard_dim: int, + shard_id: str, loaded_weight: torch.tensor, tp_rank: int): + + # Index the loaded weight for tp sharding. + # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim + shard_size = expert_data.shape[shard_dim] // 2 + loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, + shard_size) + # Narrow parameter and load. + # w1, gate_proj: Load into first logical weight of w13. + if shard_id == "w1": + expert_data = expert_data.narrow(shard_dim, 0, shard_size) + # w3, up_proj: Load into second logical weight of w13. + else: + assert shard_id == "w3" + expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) + expert_data.copy_(loaded_weight) + + def _load_w2(self, expert_data: torch.Tensor, shard_dim: int, + shard_id: str, loaded_weight: torch.tensor, tp_rank: int): + + # Index the loaded weight for tp sharding. + # down_proj: "RowParallel" so tp sharding on input_dim + # Narrow parameter and load. + shard_size = expert_data.shape[shard_dim] + loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, + shard_size) + # w2, down_proj: Load into only logical weight of w2. + expert_data.copy_(loaded_weight) + + def _load_single_value(self, param: torch.nn.Parameter, + loaded_weight: torch.Tensor, expert_id: int): + param_data = param.data + + # Input scales can be loaded directly and should be equal. + param_data[expert_id] = loaded_weight + def weight_loader(self, param: torch.nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int) -> None: + if shard_id not in ("w1", "w2", "w3"): raise ValueError(f"shard_id must be ['w1','w2','w3'] but " f"got {shard_id}.") - # Special case for fp8 scales. - if getattr(param, "is_fp8_scale", False): - self._load_fp8_scale(param.data, loaded_weight, weight_name, - shard_id, expert_id) - return + WEIGHT_SCALE_SUPPORTED = [ + e.value for e in FusedMoeWeightScaleSupported + ] + # Fetch the dim to shard the parameter/loaded weight + # based on the shard id. This will be whatever + # dimension intermediate_size is used. 
+ SHARD_ID_TO_SHARDED_DIM = {"w1": 0, "w2": 1, "w3": 0} expert_data = param.data[expert_id] tp_rank = get_tensor_model_parallel_rank() - # If transposed, weight is saved as [input_dim, output_dim] - # Otherwise, weight is saved as [output_dim, input_dim] - # Default is not transposed/input dim is dim 1 - input_dim = getattr(param, "input_dim", 1) - output_dim = getattr(param, "output_dim", 0) + # is_transposed: whether or not the parameter is transposed on disk + # If transposed, the loaded weight will be transposed and the dim + # to shard the loaded weight will be flipped. + is_transposed = getattr(param, "is_transposed", False) + shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] + if is_transposed: + loaded_weight = loaded_weight.t().contiguous() + shard_dim = ~shard_dim + + # Case weight_scales + if "weight_scale" in weight_name: + # load the weight scaling based on the quantization scheme + # supported weight scales can be found in + # FusedMoeWeightScaleSupported + # TODO @dsikka: once hardened, refactor to use vLLM Parameters + # specific to each case + quant_method = getattr(param, "quant_method", None) + if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value: + self._load_per_channel_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + elif quant_method == FusedMoeWeightScaleSupported.GROUP.value: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value: + self._load_per_tensor_weight_scale(shard_id=shard_id, + param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + else: + raise ValueError( + f"quant method must be one of {WEIGHT_SCALE_SUPPORTED}") + return - # Index the loaded weight for tp sharding. - # down_proj: "RowParallel" so tp sharding on input_dim - if shard_id == "w2": - shard_dim = input_dim - shard_size = expert_data.shape[shard_dim] - # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim - elif shard_id in ("w1", "w3"): - shard_dim = output_dim - shard_size = expert_data.shape[output_dim] // 2 - offset = shard_size * tp_rank - loaded_weight = loaded_weight.narrow(shard_dim, offset, shard_size) + if "weight_shape" in weight_name: + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return - # Narrow parameter and load. - # w1, gate_proj: Load into first logical weight of w13. - if shard_id == "w1": - expert_data = expert_data.narrow(shard_dim, 0, shard_size) - expert_data.copy_(loaded_weight) - # w3, up_proj: Load into second logical weight of w13. - elif shard_id == "w3": - expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) - expert_data.copy_(loaded_weight) - # w2, down_proj: Load into only logical weight of w2. - elif shard_id == "w2": - expert_data.copy_(loaded_weight) - else: - raise ValueError( - f"Expected shard_id w1,w2 or w3 but got {shard_id}") + # Case input scale + if "input_scale" in weight_name: + # Note: input_scale loading is only supported for fp8 + if param.data[expert_id] != 1 and (param.data[expert_id] - + loaded_weight).abs() > 1e-5: + raise ValueError( + "input_scales of w1 and w3 of a layer " + f"must be equal. But got {param.data[expert_id]} " + f"vs. 
{loaded_weight}") + + self._load_single_value(param=param, + loaded_weight=loaded_weight, + expert_id=expert_id) + return + + # Case model weights + if "weight" in weight_name: + self._load_model_weight_or_group_weight_scale( + shard_id=shard_id, + shard_dim=shard_dim, + loaded_weight=loaded_weight, + expert_data=expert_data, + tp_rank=tp_rank) + return @staticmethod def select_experts(hidden_states: torch.Tensor, @@ -342,4 +476,4 @@ def _load_fp8_scale(self, param: torch.nn.Parameter, param_data[expert_id][idx] = loaded_weight # If we are in the row parallel case (down_proj) else: - param_data[expert_id] = loaded_weight + param_data[expert_id] = loaded_weight \ No newline at end of file diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index e5b40a64abc41..1cad4e55f51ee 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -23,7 +23,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [ "CompressedTensorsLinearMethod", "AWQMarlinLinearMethod", "AWQLinearMethod", "GPTQMarlinLinearMethod", "Fp8LinearMethod", - "MarlinLinearMethod" + "MarlinLinearMethod", "QQQLinearMethod", "GPTQMarlin24LinearMethod" ] @@ -208,8 +208,7 @@ def __init__(self, self.input_size, self.output_size, self.params_dtype, - weight_loader=self.weight_loader, - prefix=prefix) + weight_loader=self.weight_loader) if bias: self.bias = Parameter( @@ -307,8 +306,7 @@ def __init__(self, params_dtype=self.params_dtype, weight_loader=( self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader), - prefix=prefix) + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) if bias: self.bias = Parameter( torch.empty(self.output_size_per_partition, @@ -976,8 +974,7 @@ def __init__(self, params_dtype=self.params_dtype, weight_loader=( self.weight_loader_v2 if self.quant_method.__class__.__name__ - in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader), - prefix=prefix) + in WEIGHT_LOADER_V2_SUPPORTED else self.weight_loader)) if not reduce_results and (bias and not skip_bias_add): raise ValueError("When not reduce the results, adding bias to the " "results can lead to incorrect results") diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index ae75781927381..0768b37044aac 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -3,15 +3,18 @@ import torch from pydantic import BaseModel -from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, + UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( # noqa: E501 QuantizationConfig, QuantizeMethodBase) +from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import ( # noqa: E501 + CompressedTensorsMoEMethod) from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( W4A16SPARSE24_SUPPORTED_BITS, WNA16_SUPPORTED_BITS, - CompressedTensorsScheme, CompressedTensorsUnquantized, - CompressedTensorsW4A16Sparse24, CompressedTensorsW8A8Fp8, - CompressedTensorsW8A8Int8, CompressedTensorsW8A16Fp8, - CompressedTensorsWNA16) + CompressedTensorsScheme, CompressedTensorsW4A16Sparse24, + 
CompressedTensorsW8A8Fp8, CompressedTensorsW8A8Int8, + CompressedTensorsW8A16Fp8, CompressedTensorsWNA16) from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( CompressionFormat, QuantizationArgs, QuantizationStrategy, QuantizationType, find_matched_target, is_activation_quantization_format, @@ -52,18 +55,25 @@ def get_min_capability(cls) -> int: def get_name(self) -> str: return "compressed_tensors" - # TODO (@robertgshaw2-neuralmagic): do layer skipping though here - # rather than though create_weights to match other methods def get_quant_method( self, layer: torch.nn.Module, prefix: str, ) -> Optional["QuantizeMethodBase"]: from vllm.attention.layer import Attention # Avoid circular import + + # Check if the layer is skipped for quantization. + # TODO (@robertgshaw2): support module names + if should_ignore_layer(prefix, ignore=self.ignore): + return UnquantizedLinearMethod() if isinstance(layer, LinearBase): + scheme = self.get_scheme(layer=layer, layer_name=prefix) + layer.scheme = scheme return CompressedTensorsLinearMethod(self) if isinstance(layer, Attention): return CompressedTensorsKVCacheMethod(self) + if isinstance(layer, FusedMoE): + return CompressedTensorsMoEMethod(self) return None @classmethod @@ -281,15 +291,11 @@ def get_scheme( to select the CompressedTensorsScheme used for infernece. """ - # Check if the layer is skipped for quantization. - # TODO (@robertgshaw2): support module names - if should_ignore_layer(layer_name, ignore=self.ignore): - return CompressedTensorsUnquantized() - # Find the "target" in the compressed-tensors config # that our layer conforms to. # TODO (@robertgshaw): add compressed-tensors as dep # so we do not have to re-write these functions + # need to make accelerate optional in ct to do this matched_target = find_matched_target( layer_name=layer_name, module=layer, @@ -327,10 +333,7 @@ def create_weights(self, layer: torch.nn.Module, details """ weight_loader = extra_weight_attrs.get("weight_loader") - layer_name = extra_weight_attrs.get("prefix") - - scheme = self.quantization_config.get_scheme(layer, layer_name) - scheme.create_weights( + layer.scheme.create_weights( layer=layer, input_size=input_size, input_size_per_partition=input_size_per_partition, @@ -339,8 +342,6 @@ def create_weights(self, layer: torch.nn.Module, params_dtype=params_dtype, weight_loader=weight_loader) - layer.scheme = scheme - def apply(self, layer: torch.nn.Module, x: torch.Tensor, diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py new file mode 100644 index 0000000000000..0e0ab9ce9169f --- /dev/null +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py @@ -0,0 +1,283 @@ +import enum +from enum import Enum +from typing import List, Optional + +import torch + +from vllm import _custom_ops as ops +from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase +from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( + WNA16_SUPPORTED_BITS) +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + CompressionFormat) +from vllm.model_executor.utils import set_weight_attrs + + +class GPTQMarlinState(Enum): + REPACK = enum.auto() + READY = enum.auto() + + +__all__ = ["CompressedTensorsMoEMethod"] + + +class CompressedTensorsMoEMethod(FusedMoEMethodBase): + + def __init__( + self, + quant_config: "CompressedTensorsConfig" # type: 
ignore # noqa E501 + ): + self.quant_config = quant_config + # TODO: @dsikka: refactor this to use schemes as other kernels + # are supported + check if the layer is being ignored. + config = self.quant_config.target_scheme_map["Linear"].get("weights") + self.num_bits = config.num_bits + self.packed_factor = 32 // config.num_bits + self.strategy = config.strategy.value + self.group_size = config.group_size + assert config.symmetric, ( + "Only symmetric quantization is supported for MoE") + + if not (self.quant_config.quant_format + == CompressionFormat.pack_quantized.value + and self.num_bits in WNA16_SUPPORTED_BITS): + raise ValueError("For Fused MoE layers, only ", + f"{CompressionFormat.pack_quantized.value} ", + "is supported for the following bits: ", + f"{WNA16_SUPPORTED_BITS}") + + def create_weights(self, layer: torch.nn.Module, num_experts: int, + hidden_size: int, intermediate_size: int, + params_dtype: torch.dtype, **extra_weight_attrs): + + # Will transpose the loaded weight along the + # intermediate and hidden dim sizes. Will + # shard for TP along the transposed dims + extra_weight_attrs.update({ + "is_transposed": True, + "quant_method": self.strategy + }) + w13_weight = torch.nn.Parameter(torch.empty(num_experts, + hidden_size // + self.packed_factor, + 2 * intermediate_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w13_weight_packed", w13_weight) + set_weight_attrs(w13_weight, extra_weight_attrs) + + w2_weight = torch.nn.Parameter(torch.empty(num_experts, + intermediate_size // + self.packed_factor, + hidden_size, + dtype=torch.int32), + requires_grad=False) + layer.register_parameter("w2_weight_packed", w2_weight) + set_weight_attrs(w2_weight, extra_weight_attrs) + + if self.strategy == "channel": + num_groups_w2 = num_groups_w13 = 1 + self.group_size = -1 + else: + num_groups_w2 = intermediate_size // self.group_size + num_groups_w13 = hidden_size // self.group_size + + w13_scale = torch.nn.Parameter(torch.ones(num_experts, + num_groups_w13, + 2 * intermediate_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w13_weight_scale", w13_scale) + set_weight_attrs(w13_scale, extra_weight_attrs) + + w2_scale = torch.nn.Parameter(torch.ones(num_experts, + num_groups_w2, + hidden_size, + dtype=params_dtype), + requires_grad=False) + layer.register_parameter("w2_weight_scale", w2_scale) + set_weight_attrs(w2_scale, extra_weight_attrs) + + w2_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), + requires_grad=False) + layer.register_parameter("w2_weight_shape", w2_weight_shape) + set_weight_attrs(w2_weight_shape, extra_weight_attrs) + w13_weight_shape = torch.nn.Parameter(torch.empty(num_experts, 2), + requires_grad=False) + + layer.register_parameter("w13_weight_shape", w13_weight_shape) + set_weight_attrs(w13_weight_shape, extra_weight_attrs) + + w13_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx", w13_g_idx) + set_weight_attrs(w13_g_idx, extra_weight_attrs) + + w2_g_idx = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx", w2_g_idx) + set_weight_attrs(w2_g_idx, extra_weight_attrs) + + w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + hidden_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w13_g_idx_sort_indices", + 
w13_g_idx_sort_indices) + set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs) + + w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty( + num_experts, + intermediate_size, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("w2_g_idx_sort_indices", + w2_g_idx_sort_indices) + set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs) + + layer.a13_scale = None + layer.a2_scale = None + layer.marlin_state = GPTQMarlinState.REPACK + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + + def replace_tensor(name, new_t): + # It is important to use resize_() here since it ensures + # the same buffer is reused + getattr(layer, name).resize_(new_t.shape) + getattr(layer, name).copy_(new_t) + del new_t + + def get_scale_perms(num_bits: int): + scale_perm: List[int] = [] + for i in range(8): + scale_perm.extend([i + 8 * j for j in range(8)]) + scale_perm_single: List[int] = [] + for i in range(4): + scale_perm_single.extend( + [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]) + return scale_perm, scale_perm_single + + def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int, + group_size: int, num_bits: int): + scale_perm, scale_perm_single = get_scale_perms(num_bits) + if group_size < size_k and group_size != -1: + s = s.reshape((-1, len(scale_perm)))[:, scale_perm] + else: + s = s.reshape((-1, len(scale_perm_single)))[:, + scale_perm_single] + s = s.reshape((-1, size_n)).contiguous() + return s + + def marlin_moe_permute_scales(s: torch.Tensor, size_k: int, + size_n: int, group_size: int, + num_bits: int): + num_experts = s.shape[0] + output = torch.empty((num_experts, s.shape[1], s.shape[2]), + device=s.device, + dtype=s.dtype) + for e in range(num_experts): + output[e] = marlin_permute_scales(s[e], size_k, size_n, + group_size, num_bits) + return output + + size_k2 = layer.w2_weight_packed.shape[2] + size_k13 = layer.w13_weight_packed.shape[2] + + num_experts = layer.w13_g_idx.shape[0] + device = layer.w13_g_idx.device + layer.w13_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w13_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + layer.w2_g_idx_sort_indices = torch.nn.Parameter( + torch.empty((num_experts, 0), dtype=torch.int32, device=device), + requires_grad=False, + ) + + marlin_w13_qweight = ops.gptq_marlin_moe_repack( + layer.w13_weight_packed, + layer.w13_g_idx_sort_indices, + layer.w13_weight_packed.shape[1] * self.packed_factor, + layer.w13_weight_packed.shape[2], + self.num_bits, + ) + replace_tensor("w13_weight_packed", marlin_w13_qweight) + marlin_w2_qweight = ops.gptq_marlin_moe_repack( + layer.w2_weight_packed, + layer.w2_g_idx_sort_indices, + layer.w2_weight_packed.shape[1] * self.packed_factor, + layer.w2_weight_packed.shape[2], + self.num_bits, + ) + replace_tensor("w2_weight_packed", marlin_w2_qweight) + # Repack scales + marlin_w13_scales = marlin_moe_permute_scales( + layer.w13_weight_scale, + size_k13, + layer.w13_weight_scale.shape[2], + self.group_size, + self.num_bits, + ) + replace_tensor("w13_weight_scale", marlin_w13_scales) + marlin_w2_scales = marlin_moe_permute_scales( + layer.w2_weight_scale, + layer.w2_weight_scale.shape[1] * self.packed_factor, + size_k2, + self.group_size, + 
self.num_bits, + ) + replace_tensor("w2_weight_scale", marlin_w2_scales) + + def apply(self, + layer: torch.nn.Module, + x: torch.Tensor, + router_logits: torch.Tensor, + top_k: int, + renormalize: bool = True, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None) -> torch.Tensor: + + from vllm.model_executor.layers.fused_moe.fused_moe import ( + fused_marlin_moe) + + return fused_marlin_moe(x, + layer.w13_weight_packed, + layer.w2_weight_packed, + router_logits, + layer.w13_g_idx, + layer.w2_g_idx, + layer.w13_g_idx_sort_indices, + layer.w2_g_idx_sort_indices, + top_k, + renormalize=renormalize, + w1_scale=layer.w13_weight_scale, + w2_scale=layer.w2_weight_scale) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py index ca9e286ce5b2d..5d259ec72051c 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py @@ -1,5 +1,4 @@ from .compressed_tensors_scheme import CompressedTensorsScheme -from .compressed_tensors_unquantized import CompressedTensorsUnquantized from .compressed_tensors_w4a16_24 import (W4A16SPARSE24_SUPPORTED_BITS, CompressedTensorsW4A16Sparse24) from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8 @@ -10,7 +9,6 @@ __all__ = [ "CompressedTensorsScheme", - "CompressedTensorsUnquantized", "CompressedTensorsWNA16", "CompressedTensorsW8A16Fp8", "CompressedTensorsW4A16Sparse24", diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py b/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py deleted file mode 100644 index 2e8d520eacc81..0000000000000 --- a/vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_unquantized.py +++ /dev/null @@ -1,49 +0,0 @@ -from typing import Callable, List, Optional - -import torch -import torch.nn.functional as F - -from vllm.model_executor.layers.quantization.compressed_tensors.schemes import ( - CompressedTensorsScheme) -from vllm.model_executor.parameter import ModelWeightParameter - -__all__ = ["CompressedTensorsUnquantized"] - - -class CompressedTensorsUnquantized(CompressedTensorsScheme): - """ - Implements the scheme for all layers which are ignored - in the CompressedTensors config. The input and loaded weight are used - in a linear transformation. 
- """ - - @classmethod - def get_min_capability(cls) -> int: - # volta and up - return 70 - - def process_weights_after_loading(self, layer: torch.nn.Module) -> None: - # required by torch.compile to be torch.nn.Parameter - layer.weight = torch.nn.Parameter(layer.weight.data, - requires_grad=False) - - def create_weights(self, layer: torch.nn.Module, - output_partition_sizes: List[int], - input_size_per_partition: int, - params_dtype: torch.dtype, weight_loader: Callable, - **kwargs): - - weight = ModelWeightParameter(data=torch.empty( - sum(output_partition_sizes), - input_size_per_partition, - dtype=params_dtype), - input_dim=1, - output_dim=0, - weight_loader=weight_loader) - - layer.register_parameter("weight", weight) - - def apply_weights(self, layer: torch.nn.Module, x: torch.Tensor, - bias: Optional[torch.Tensor]) -> torch.Tensor: - - return F.linear(x, layer.weight, bias) diff --git a/vllm/model_executor/layers/quantization/fp8.py b/vllm/model_executor/layers/quantization/fp8.py index b10988b992ae1..1817dbcb023a7 100644 --- a/vllm/model_executor/layers/quantization/fp8.py +++ b/vllm/model_executor/layers/quantization/fp8.py @@ -7,7 +7,8 @@ import vllm.envs as envs from vllm import _custom_ops as ops from vllm.logger import init_logger -from vllm.model_executor.layers.fused_moe import FusedMoE, FusedMoEMethodBase +from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, + FusedMoeWeightScaleSupported) from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, UnquantizedLinearMethod) from vllm.model_executor.layers.quantization.base_config import ( @@ -332,19 +333,16 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, dtype=torch.float32), requires_grad=False) layer.register_parameter("w2_weight_scale", w2_weight_scale) - + # Add the quantization method used (per tensor/grouped/channel) + # to ensure the weight scales are loaded in properly + extra_weight_attrs.update( + {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}) # If loading fp8 checkpoint, pass the weight loaders. 
# If loading an fp16 checkpoint, do not (we will quantize in # process_weights_after_loading() if self.quant_config.is_checkpoint_fp8_serialized: - set_weight_attrs(w13_weight_scale, { - "is_fp8_scale": True, - **extra_weight_attrs - }) - set_weight_attrs(w2_weight_scale, { - "is_fp8_scale": True, - **extra_weight_attrs - }) + set_weight_attrs(w13_weight_scale, extra_weight_attrs) + set_weight_attrs(w2_weight_scale, extra_weight_attrs) # INPUT_SCALES if self.quant_config.activation_scheme == "static": @@ -357,19 +355,14 @@ def create_weights(self, layer: Module, num_experts: int, hidden_size: int, num_experts, dtype=torch.float32), requires_grad=False) layer.register_parameter("w13_input_scale", w13_input_scale) - set_weight_attrs(w13_input_scale, { - "is_fp8_scale": True, - **extra_weight_attrs - }) + set_weight_attrs(w13_input_scale, extra_weight_attrs) w2_input_scale = torch.nn.Parameter(torch.ones( num_experts, dtype=torch.float32), requires_grad=False) layer.register_parameter("w2_input_scale", w2_input_scale) - set_weight_attrs(w2_input_scale, { - "is_fp8_scale": True, - **extra_weight_attrs - }) + set_weight_attrs(w2_input_scale, extra_weight_attrs) + else: layer.w13_input_scale = None layer.w2_input_scale = None diff --git a/vllm/model_executor/layers/quantization/gptq_marlin_24.py b/vllm/model_executor/layers/quantization/gptq_marlin_24.py index cafd100a2f40c..0971aedba4c3c 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin_24.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin_24.py @@ -8,7 +8,10 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) from vllm.scalar_type import scalar_types logger = init_logger(__name__) @@ -149,7 +152,7 @@ def create_weights( **extra_weight_attrs, ): del output_size # Unused. - + weight_loader = extra_weight_attrs["weight_loader"] if params_dtype != torch.float16: raise ValueError( f"The params dtype must be float16, but got {params_dtype}") @@ -187,87 +190,80 @@ def create_weights( "Each permutation group must reside on the same gpu") # Quantized 4Bit weights packed into Int32. 
- qweight = Parameter( - torch.empty( + qweight = PackedvLLMParameter( + data=torch.empty( input_size_per_partition // self.quant_config.tile_size // 2, output_size_per_partition * self.quant_config.tile_size // self.quant_config.pack_factor, device="cuda", dtype=torch.int32, ), - requires_grad=False, - ) - set_weight_attrs( - qweight, - { - "input_dim": 0, - "output_dim": 1, - "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, - "marlin_tile_size": self.quant_config.tile_size, - }, - ) + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + marlin_tile_size=self.quant_config.tile_size, + weight_loader=weight_loader) # Meta - meta = Parameter( - torch.empty( - input_size_per_partition // 8 // 2 // 2, - output_size_per_partition * 2, - device="cuda", - dtype=torch.int16, - ), - requires_grad=False, - ) - set_weight_attrs( - meta, - { - "input_dim": 0, - "packed_dim": 1, - "pack_factor": 1, - "output_dim": 1, - "marlin_tile_size": 2, - }, - ) + meta = PackedvLLMParameter(data=torch.empty( + input_size_per_partition // 8 // 2 // 2, + output_size_per_partition * 2, + device="cuda", + dtype=torch.int16, + ), + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=1, + marlin_tile_size=2, + weight_loader=weight_loader) # Determine if channelwise or not input_groups = (1 if self.quant_config.group_size == -1 else input_size_per_partition // self.quant_config.group_size) - scales = Parameter( + weight_scale_args = { + "data": torch.empty( input_groups, output_size_per_partition, device="cuda", dtype=params_dtype, ), - requires_grad=False, - ) - set_weight_attrs( - scales, - { - "input_dim": None if input_groups == 1 else 0, - "output_dim": 1, - }, - ) + "weight_loader": + weight_loader + } + if input_groups == 1: + scales = ChannelQuantScaleParameter(output_dim=1, + **weight_scale_args) + else: + scales = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **weight_scale_args) # Allocate workspace (Used for internal locking mechanism) max_workspace_size = ( output_size_per_partition // self.quant_config.min_n_threads) * self.quant_config.max_parallel - workspace = Parameter(torch.zeros(max_workspace_size, - device="cuda", - dtype=torch.int), - requires_grad=False) + + workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, + device="cuda", + dtype=torch.int), + weight_loader=weight_loader) layer.register_parameter("B_24", qweight) - set_weight_attrs(qweight, extra_weight_attrs) layer.register_parameter("B_meta", meta) - set_weight_attrs(meta, extra_weight_attrs) layer.register_parameter("s", scales) - set_weight_attrs(scales, extra_weight_attrs) layer.register_parameter("workspace", workspace) - set_weight_attrs(workspace, extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # required by torch.compile + layer.B_24 = Parameter(layer.B_24.data, requires_grad=False) + layer.s = Parameter(layer.s.data, requires_grad=False) + layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False) + layer.workspace = Parameter(layer.workspace.data, requires_grad=False) def apply( self, diff --git a/vllm/model_executor/layers/quantization/qqq.py b/vllm/model_executor/layers/quantization/qqq.py index be10cee2cf68f..c3434214a1cde 100644 --- a/vllm/model_executor/layers/quantization/qqq.py +++ b/vllm/model_executor/layers/quantization/qqq.py @@ -8,7 +8,10 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase from vllm.model_executor.layers.quantization.base_config 
import ( QuantizationConfig) -from vllm.model_executor.utils import set_weight_attrs +from vllm.model_executor.parameter import (BasevLLMParameter, + ChannelQuantScaleParameter, + GroupQuantScaleParameter, + PackedvLLMParameter) logger = init_logger(__name__) @@ -133,6 +136,7 @@ def create_weights( params_dtype: torch.dtype, **extra_weight_attrs, ): + weight_loader = extra_weight_attrs["weight_loader"] if params_dtype != torch.float16: raise ValueError( f"The params dtype must be float16, but got {params_dtype}") @@ -170,90 +174,74 @@ def create_weights( "Each permutation group must reside on the same gpu") # Quantized 4Bit weights packed into Int32. - qweight = Parameter( - torch.empty( + qweight = PackedvLLMParameter( + data=torch.empty( input_size_per_partition // self.quant_config.tile_size, output_size_per_partition * self.quant_config.tile_size // self.quant_config.pack_factor, device="cuda", dtype=torch.int32, ), - requires_grad=False, - ) - set_weight_attrs( - qweight, - { - "input_dim": 0, - "output_dim": 1, - "packed_dim": 1, - "pack_factor": self.quant_config.pack_factor, - "marlin_tile_size": self.quant_config.tile_size, - }, - ) - - s_channel = Parameter( - torch.empty( - 1, - output_size_per_partition, - device="cuda", - dtype=torch.float, - ), - requires_grad=False, - ) - set_weight_attrs( - s_channel, - { - "input_dim": None, - "output_dim": 1, - }, - ) + input_dim=0, + output_dim=1, + packed_dim=1, + packed_factor=self.quant_config.pack_factor, + marlin_tile_size=self.quant_config.tile_size, + weight_loader=weight_loader) + + s_channel = ChannelQuantScaleParameter(data=torch.empty( + 1, + output_size_per_partition, + device="cuda", + dtype=torch.float, + ), + weight_loader=weight_loader, + output_dim=1) if self.quant_config.group_size == -1: - s_group = Parameter( - torch.tensor( - [], - device="cuda", - dtype=torch.half, - ), - requires_grad=False, + s_group_data = torch.tensor( + [], + device="cuda", + dtype=torch.half, ) else: - s_group = Parameter( - torch.empty( - input_size_per_partition // self.quant_config.group_size, - output_size_per_partition, - device="cuda", - dtype=torch.half, - ), - requires_grad=False, + s_group_data = torch.empty( + input_size_per_partition // self.quant_config.group_size, + output_size_per_partition, + device="cuda", + dtype=torch.half, ) - set_weight_attrs( - s_group, - { - "input_dim": None if self.quant_config.group_size == -1 else 0, - "output_dim": - None if self.quant_config.group_size == -1 else 1, - }, - ) + s_group_attr = {"data": s_group_data, "weight_loader": weight_loader} + + if self.quant_config.group_size == -1: + s_group = BasevLLMParameter(**s_group_attr) + else: + s_group = GroupQuantScaleParameter(output_dim=1, + input_dim=0, + **s_group_attr) # Allocate workspace (Used for internal locking mechanism) max_workspace_size = ( output_size_per_partition // self.quant_config.min_n_threads) * self.quant_config.max_parallel - workspace = Parameter(torch.zeros(max_workspace_size, - device="cuda", - dtype=torch.int), - requires_grad=False) + + workspace = BasevLLMParameter(data=torch.zeros(max_workspace_size, + device="cuda", + dtype=torch.int), + weight_loader=weight_loader) layer.register_parameter("B", qweight) - set_weight_attrs(qweight, extra_weight_attrs) layer.register_parameter("s_channel", s_channel) - set_weight_attrs(s_channel, extra_weight_attrs) layer.register_parameter("s_group", s_group) - set_weight_attrs(s_group, extra_weight_attrs) layer.register_parameter("workspace", workspace) - set_weight_attrs(workspace, 
extra_weight_attrs) + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + # required by torch.compile + layer.B = Parameter(layer.B.data, requires_grad=False) + layer.s_channel = Parameter(layer.s_channel.data, requires_grad=False) + layer.s_group = Parameter(layer.s_group.data, requires_grad=False) + layer.workspace = Parameter(layer.workspace.data, requires_grad=False) def apply( self, diff --git a/vllm/model_executor/model_loader/utils.py b/vllm/model_executor/model_loader/utils.py index 331b859d2adec..4bb943ab3afe4 100644 --- a/vllm/model_executor/model_loader/utils.py +++ b/vllm/model_executor/model_loader/utils.py @@ -23,11 +23,11 @@ def get_model_architecture( architectures = getattr(model_config.hf_config, "architectures", []) # Special handling for quantized Mixtral. # FIXME(woosuk): This is a temporary hack. + mixtral_supported = ["fp8", "compressed-tensors"] if (model_config.quantization is not None - and model_config.quantization != "fp8" + and model_config.quantization not in mixtral_supported and "MixtralForCausalLM" in architectures): architectures = ["QuantMixtralForCausalLM"] - return ModelRegistry.resolve_model_cls(architectures) diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 20dda2a67820d..7c9123079c44f 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -555,6 +555,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + # Remove the N dimension until multiple images are supported. + pixel_values = pixel_values.squeeze(1) + return Blip2ImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), @@ -564,6 +567,10 @@ def _parse_and_validate_image_input( if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") + + # Remove the N dimension until multiple images are supported. + image_embeds = image_embeds.squeeze(1) + return Blip2ImageEmbeddingInputs( type="image_embeds", data=image_embeds, diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index a335e1766b2a9..2d4f172ce0be6 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -946,6 +946,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + # Remove the N dimension until multiple images are supported. 
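# --- Editorial sketch (not part of the patch): the recurring squeeze(1) idiom.
# Multimodal inputs are now batched with an extra N axis (images per prompt).
# While only a single image per prompt is supported, N == 1, so squeeze(1)
# restores the (B, C, H, W) shape the vision encoders expect. Shapes below are
# illustrative only.
import torch

pixel_values = torch.randn(2, 1, 3, 336, 336)   # (B, N, C, H, W) with N == 1
print(pixel_values.squeeze(1).shape)            # torch.Size([2, 3, 336, 336])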
+ pixel_values = pixel_values.squeeze(1) + return ChameleonImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 0933966055330..69bb9f6f3afee 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -1,7 +1,7 @@ -"""Minimal implementation of CLIPVisionModel intended to be only used +"""Minimal implementation of CLIPVisionModel intended to be only used within a vision language model.""" from array import array -from typing import Iterable, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch import torch.nn as nn @@ -84,7 +84,7 @@ def input_processor_for_clip( llm_inputs: LLMInputs, *, image_token_id: int, - image_feature_size_override: Optional[int] = None, + image_feature_size_override: Optional[Union[int, List[int]]] = None, ): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: @@ -217,7 +217,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: class CLIPEncoder(nn.Module): """ - Transformer encoder consisting of `config.num_hidden_layers` self + Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a [`CLIPEncoderLayer`]. Args: diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index cfc2a5288a37b..6cdf331fed8b7 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -249,6 +249,9 @@ def _parse_and_validate_image_input( image_patches = kwargs.pop("image_patches", None) if isinstance(image_patches, torch.Tensor): + # Remove the N dimension until multiple images are supported. + image_patches = image_patches.squeeze(1) + expected_feature_size = self.image_feature_size if image_patches.size(-1) != expected_feature_size: raise ValueError( diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index c996f0b73f293..7f213287f33b4 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -244,6 +244,8 @@ def input_mapper_for_internvl(ctx: InputContext, data: object): min_num, max_num, use_thumbnail=use_thumbnail) + # Add an N dimension for number of images per prompt (currently 1). + data = data.unsqueeze(0) model_config = ctx.model_config tokenizer = cached_get_tokenizer(model_config.tokenizer, trust_remote_code=True) @@ -410,6 +412,10 @@ def _parse_and_validate_image_input( if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") + + # Flatten the B and N dimensions + image_embeds = image_embeds.flatten(0, 2) + return InternVLImageEmbeddingInputs( type="image_embeds", data=image_embeds, @@ -422,6 +428,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of pixel values. 
" f"Got type: {type(pixel_values)}") + # Flatten the B and N dimensions + pixel_values = pixel_values.flatten(0, 2) + return InternVLImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index b82eb14fb5f23..caeda4e42d8a0 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -920,7 +920,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader = param.weight_loader weight_loader(param, loaded_weight, - weight_name, + name, shard_id=shard_id, expert_id=expert_id) break diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 6433ea380cbfe..03a0abf1db481 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -232,6 +232,10 @@ def _parse_and_validate_image_input( if not isinstance(pixel_values, torch.Tensor): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + + # Remove the N dimension until multiple images are supported. + pixel_values = pixel_values.squeeze(1) + return LlavaImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), @@ -241,6 +245,10 @@ def _parse_and_validate_image_input( if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") + + # Remove the N dimension until multiple images are supported. + image_embeds = image_embeds.squeeze(1) + return LlavaImageEmbeddingInputs( type="image_embeds", data=image_embeds, diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index c7cb243fa84da..3a87242954114 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -19,6 +19,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.sequence import IntermediateTensors, SamplerOutput +from vllm.utils import is_list_of from .clip import (CLIPVisionModel, dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_image_feature_size, @@ -223,6 +224,13 @@ def input_processor_for_llava_next(ctx: InputContext, llm_inputs: LLMInputs): input_height=height, input_width=width, ) + elif is_list_of(image_data, Image.Image): + image_feature_size = [ + get_llava_next_image_feature_size(hf_config, + input_height=img.height, + input_width=img.width) + for img in image_data + ] elif isinstance(image_data, torch.Tensor): image_feature_size = image_data.shape[0] else: @@ -353,6 +361,14 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of image sizes. " f"Got type: {type(image_sizes)}") + # Remove the N dimension until multiple images are supported. + if isinstance(pixel_values, torch.Tensor): + pixel_values = pixel_values.squeeze(1) + else: + pixel_values = [t.squeeze(0) for t in pixel_values] + + image_sizes = image_sizes.squeeze(1) + return LlavaNextImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), @@ -364,6 +380,9 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of image embeds. " f"Got type: {type(image_embeds)}") + # Remove the N dimension until multiple images are supported. 
+ image_embeds = image_embeds.squeeze(1) + return LlavaNextImageEmbeddingInputs( type="image_embeds", data=image_embeds, @@ -425,7 +444,10 @@ def _merge_image_patch_embeddings(self, image_size: torch.Tensor, self.config.image_grid_pinpoints, self.config.vision_config.image_size, ) - other_patch_embeds = other_patch_embeds \ + num_patches = num_patch_height * num_patch_width + + # Image patches might be padded for batch processing + other_patch_embeds = other_patch_embeds[:num_patches] \ .view(num_patch_height, num_patch_width, height, width, -1) if "unpad" in strategy: @@ -496,7 +518,6 @@ def _process_image_input( self, image_input: LlavaNextImageInputs, ) -> Union[torch.Tensor, List[torch.Tensor]]: - if image_input["type"] == "image_embeds": return [image_input["data"]] diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 29f3640e2458b..6a3d5422e0ce4 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -594,9 +594,14 @@ def _parse_and_validate_inputs( pixel_values_flat: List[torch.Tensor] = [] tgt_sizes_flat: List[torch.Tensor] = [] - for b in range(len(pixel_values)): - pixel_values_flat += pixel_values[b] - tgt_sizes_flat += tgt_sizes[b] + for pixel_b, tgt_b in zip(pixel_values, tgt_sizes): + if len(pixel_b) != len(tgt_b): + raise ValueError("Inconsistent N lengths, found: " + f"{len(pixel_b)} vs {len(tgt_b)}") + + for pixel_n, tgt_n in zip(pixel_b, tgt_b): + pixel_values_flat += pixel_n + tgt_sizes_flat += tgt_n # NOTE: Input IDs does not contain image tokens during memory profiling, # so we allow it to be empty diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 34f581ac78582..413783ba4b259 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -73,6 +73,7 @@ def __init__(self, self.hidden_size = hidden_size # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear(hidden_size, num_experts, bias=False, diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 8cb5065ed79ec..0700f0c29d708 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -185,6 +185,10 @@ def _parse_and_validate_image_input( if not isinstance(pixel_values, torch.Tensor): raise ValueError("Incorrect type of pixel values. " f"Got type: {type(pixel_values)}") + + # Remove the N dimension until multiple images are supported. + pixel_values = pixel_values.squeeze(1) + return PaliGemmaImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), @@ -194,6 +198,10 @@ def _parse_and_validate_image_input( if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") + + # Remove the N dimension until multiple images are supported. + image_embeds = image_embeds.squeeze(1) + return PaliGemmaImageEmbeddingInputs( type="image_embeds", data=image_embeds, diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 4872929ec36cc..61f1d73976379 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -422,7 +422,9 @@ def input_processor_for_phi3v(ctx: InputContext, llm_inputs: LLMInputs): prompt = llm_inputs.get("prompt") if prompt is None: - image_idx = [] + # for async server request, we assume prompt and its token_ids is always + # in correct format. 
And num_image_tags == len(image_data) always True. + image_idx = range(1, len(image_data) + 1) new_prompt = None else: image_idx = sorted(map(int, re.findall(r"<\|image_(\d+)\|>+", prompt))) @@ -558,6 +560,14 @@ def _parse_and_validate_image_input( raise ValueError("Incorrect type of image sizes. " f"Got type: {type(image_sizes)}") + # Merge the B and N dimensions. + if isinstance(pixel_values, torch.Tensor): + pixel_values = pixel_values.flatten(0, 1) + else: + pixel_values = torch.cat(pixel_values) + + image_sizes = image_sizes.flatten(0, 1) + return Phi3VImagePixelInputs( type="pixel_values", data=self._validate_pixel_values(pixel_values), diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index 7f6186fa010a4..073f60bb3a056 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -3,7 +3,7 @@ import math from array import array -from typing import Iterable, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from PIL import Image @@ -93,7 +93,7 @@ def input_processor_for_siglip( llm_inputs: LLMInputs, *, image_token_id: int, - image_feature_size_override: Optional[int] = None, + image_feature_size_override: Optional[Union[int, List[int]]] = None, ): multi_modal_data = llm_inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 842264f765866..c81c2fd114eb8 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -333,6 +333,12 @@ def _parse_and_validate_audio_input( raise ValueError("Incorrect type of audio features. " f"Got type: {type(audio_features)}") + # Remove the N dimension until multiple audios are supported. + if isinstance(audio_features, torch.Tensor): + audio_features = audio_features.squeeze(1) + else: + audio_features = [t.squeeze(0) for t in audio_features] + return UltravoxAudioFeatureInputs(type="audio_features", data=audio_features) @@ -341,6 +347,9 @@ def _parse_and_validate_audio_input( raise ValueError("Incorrect type of audio embeds. " f"Got type: {type(audio_embeds)}") + # Remove the N dimension until multiple audios are supported. + audio_embeds = audio_embeds.squeeze(1) + return UltravoxAudioEmbeddingInputs(type="audio_embeds", data=audio_embeds) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 91b414b1fd91a..00026b7ebe2e1 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -1,5 +1,6 @@ from typing import Dict, Iterable, List, Optional, Protocol, Tuple +import numpy as np import torch import torch.nn as nn from torch.func import functional_call @@ -10,7 +11,7 @@ from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.loader import build_model from vllm.model_executor.models import ModelRegistry -from vllm.multimodal import BatchedTensors +from vllm.multimodal.base import NestedTensors from vllm.utils import is_pin_memory_available @@ -54,9 +55,34 @@ def init_vllm_registered_model( ) +def _flatten_embeddings(embeddings: NestedTensors) -> torch.Tensor: + """ + Recursively concatenates NestedTensors along any heterogeneously sized + dimensions. 
+ """ + + if isinstance(embeddings, torch.Tensor): + return embeddings + + return torch.cat(tuple(_flatten_embeddings(t) for t in embeddings)) + + +def _embedding_count_expression(embeddings: NestedTensors) -> str: + """ + Constructs a debugging representation of the number of embeddings in the + NestedTensors. + """ + + if isinstance(embeddings, torch.Tensor): + return " x ".join([str(dim) for dim in embeddings.shape[:-1]]) + + return " + ".join( + _embedding_count_expression(inner) for inner in embeddings) + + def merge_multimodal_embeddings(input_ids: torch.Tensor, inputs_embeds: torch.Tensor, - multimodal_embeddings: BatchedTensors, + multimodal_embeddings: NestedTensors, placeholder_token_id: int) -> torch.Tensor: """ Merge ``multimodal_embeddings`` into ``inputs_embeds`` by overwriting the @@ -69,28 +95,16 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, mask = (input_ids == placeholder_token_id) num_expected_tokens = mask.sum() - if isinstance(multimodal_embeddings, torch.Tensor): - batch_size, batch_tokens, *_, embed_dim = multimodal_embeddings.shape - total_tokens = batch_size * batch_tokens - if num_expected_tokens != total_tokens: - expr = f"{batch_size} x {batch_tokens}" - raise ValueError( - f"Attempted to assign {expr} = {total_tokens} " - f"multimodal tokens to {num_expected_tokens} placeholders") - - inputs_embeds[mask] = multimodal_embeddings.view( - total_tokens, embed_dim) - else: - size_per_batch = [t.shape[0] for t in multimodal_embeddings] - total_tokens = sum(size_per_batch) - if num_expected_tokens != total_tokens: - expr = ' + '.join(map(str, size_per_batch)) - raise ValueError( - f"Attempted to assign {expr} = {total_tokens} " - f"multimodal tokens to {num_expected_tokens} placeholders") - - inputs_embeds[mask] = torch.cat(multimodal_embeddings) + flattened = _flatten_embeddings(multimodal_embeddings) + *dims, embed_dim = flattened.shape + num_multimodal_embeddings = np.prod(dims) + if num_multimodal_embeddings != num_expected_tokens: + expr = _embedding_count_expression(multimodal_embeddings) + raise ValueError( + f"Attempted to assign {expr} = {num_multimodal_embeddings} " + f"multimodal tokens to {num_expected_tokens} placeholders") + inputs_embeds[mask] = flattened.view(num_expected_tokens, embed_dim) return inputs_embeds diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 456e41ebfad03..489e1e51f05cb 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -1,4 +1,4 @@ -from .base import (BatchedTensorInputs, BatchedTensors, MultiModalDataBuiltins, +from .base import (BatchedTensorInputs, MultiModalDataBuiltins, MultiModalDataDict, MultiModalInputs, MultiModalPlugin, NestedTensors) from .registry import MultiModalRegistry @@ -14,7 +14,6 @@ __all__ = [ "BatchedTensorInputs", - "BatchedTensors", "MultiModalDataBuiltins", "MultiModalDataDict", "MultiModalInputs", diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index 8ada60c8fd6ae..5b00117c64e53 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -1,9 +1,8 @@ import sys from abc import ABC, abstractmethod from collections import UserDict, defaultdict -from typing import Callable, Dict, List, Mapping, Optional -from typing import Sequence as GenericSequence -from typing import Tuple, Type, TypedDict, TypeVar, Union, cast, final +from typing import (Callable, Dict, List, Mapping, Optional, Tuple, Type, + TypedDict, TypeVar, Union, cast, final) import numpy as np import torch @@ -15,23 +14,16 @@ from vllm.config import 
ModelConfig from vllm.inputs import InputContext from vllm.logger import init_logger -from vllm.utils import JSONTree, json_map_leaves +from vllm.utils import json_map_leaves logger = init_logger(__name__) -NestedTensors = Union[GenericSequence[torch.Tensor], torch.Tensor] +NestedTensors = Union[List["NestedTensors"], torch.Tensor] """ -Use a list instead of a tensor if the dimensions of each element do not match. -Currently only supports up to singly nested list of tensors. +Uses a list instead of a tensor if the dimensions of each element do not match. """ -BatchedTensors: TypeAlias = JSONTree[torch.Tensor] -""" -A nested JSON structure of tensors which have been batched via -:meth:`MultiModalInputs.batch`. -""" - -BatchedTensorInputs: TypeAlias = Dict[str, JSONTree[torch.Tensor]] +BatchedTensorInputs: TypeAlias = Dict[str, NestedTensors] """ A dictionary containing nested tensors which have been batched via :meth:`MultiModalInputs.batch`. @@ -54,26 +46,23 @@ class MultiModalInputs(_MultiModalInputsBase): """ @staticmethod - def _try_concat(tensors: List[NestedTensors]) -> BatchedTensors: + def _try_stack(nested_tensors: NestedTensors) -> NestedTensors: """ - If each input tensor in the batch has the same shape, return a single - batched tensor; otherwise, return a list of :class:`NestedTensors` with - one element per item in the batch. + Recursively stacks lists of tensors when they all have the same shape. """ - # may be list rather than tensors - if isinstance(tensors[0], list): - return [[t for t in tensor[0]] - for tensor in cast(List[List[torch.Tensor]], tensors)] - - tensors_ = cast(List[torch.Tensor], tensors) + if isinstance(nested_tensors, torch.Tensor): + return nested_tensors - unbatched_shape = tensors_[0].shape[1:] + stacked = [MultiModalInputs._try_stack(t) for t in nested_tensors] + if any(isinstance(t, list) for t in stacked): + return stacked - for tensor in tensors_: - if tensor.shape[1:] != unbatched_shape: - return [tensor.squeeze(0) for tensor in tensors_] + tensors_ = cast(List[torch.Tensor], stacked) + if any(t.shape != tensors_[0].shape for t in tensors_): + # The tensors have incompatible shapes and can't be stacked. 
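# --- Editorial sketch (not part of the patch): the _try_stack recursion above.
# When shapes disagree, the branch right after this comment keeps the plain
# list; otherwise the leaves are stacked into a single tensor. A stand-alone
# mirror of that behaviour (try_stack is a hypothetical name, not the vLLM
# symbol):
from typing import List, Union
import torch

Nested = Union[List["Nested"], torch.Tensor]

def try_stack(nested: Nested) -> Nested:
    if isinstance(nested, torch.Tensor):
        return nested
    stacked = [try_stack(t) for t in nested]
    if any(isinstance(t, list) for t in stacked):
        return stacked                        # some children stayed lists
    if any(t.shape != stacked[0].shape for t in stacked):
        return stacked                        # heterogeneous shapes stay a list
    return torch.stack(stacked)

print(try_stack([torch.zeros(3, 4), torch.ones(3, 4)]).shape)   # (2, 3, 4)
print(type(try_stack([torch.zeros(3, 4), torch.ones(2, 4)])))   # <class 'list'>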
+ return tensors_ - return torch.cat(tensors_, dim=0) + return torch.stack(tensors_) @staticmethod def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs: @@ -102,7 +91,7 @@ def batch(inputs_list: List["MultiModalInputs"]) -> BatchedTensorInputs: item_lists[k].append(v) return { - k: MultiModalInputs._try_concat(item_list) + k: MultiModalInputs._try_stack(item_list) for k, item_list in item_lists.items() } diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 3bf430235462b..989b2e1a814c9 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -189,10 +189,13 @@ def repeat_and_pad_placeholder_tokens( prompt_token_ids: List[int], *, placeholder_token_id: int, - repeat_count: int = 1, + repeat_count: Union[int, List[int]], pad_token_left: Optional[int] = None, pad_token_right: Optional[int] = None, ) -> Tuple[Optional[str], List[int]]: + if isinstance(repeat_count, int): + repeat_count = [repeat_count] + if prompt is None: new_prompt = None else: @@ -201,13 +204,6 @@ def repeat_and_pad_placeholder_tokens( tokenizer.decode(pad_token_left)) pad_token_str_right = (None if pad_token_right is None else tokenizer.decode(pad_token_right)) - replacement_str = "".join( - repeat_and_pad_token( - placeholder_token_str, - repeat_count=repeat_count, - pad_token_left=pad_token_str_left, - pad_token_right=pad_token_str_right, - )) placeholder_token_count = prompt.count(placeholder_token_str) # This is an arbitrary number to distinguish between the two cases @@ -216,28 +212,45 @@ def repeat_and_pad_placeholder_tokens( "Please follow the prompt format that is " "documented on HuggingFace which does not involve " "repeating %s tokens.", placeholder_token_str) - elif placeholder_token_count > 1: - logger.warning("Multiple multi-modal input is not supported yet, " - "so any extra placeholder tokens will be treated " - "as plain text.") - - # The image tokens are removed to be consistent with HuggingFace - new_prompt = prompt.replace(placeholder_token_str, replacement_str, 1) + if placeholder_token_count < len(repeat_count): + logger.warning( + "The number of multi-modal placeholder tokens in the prompt " + "is less than the number of multi-modal inputs. 
Extra " + "placeholder tokens will be treated as plain text") + repeat_count = repeat_count[:placeholder_token_count] + + prompt_parts = prompt.split(placeholder_token_str, + maxsplit=len(repeat_count)) + new_prompt = "" + for i, repeat_count_item in enumerate(repeat_count): + replacement_str = "".join( + repeat_and_pad_token( + placeholder_token_str, + repeat_count=repeat_count_item, + pad_token_left=pad_token_str_left, + pad_token_right=pad_token_str_right, + )) + # The image tokens are removed to be consistent with HuggingFace + new_prompt += prompt_parts[i] + replacement_str + new_prompt += prompt_parts[-1] new_token_ids: List[int] = [] + placeholder_token_idx = 0 for i, token in enumerate(prompt_token_ids): if token == placeholder_token_id: replacement_ids = repeat_and_pad_token( placeholder_token_id, - repeat_count=repeat_count, + repeat_count=repeat_count[placeholder_token_idx], pad_token_left=pad_token_left, pad_token_right=pad_token_right, ) new_token_ids.extend(replacement_ids) + placeholder_token_idx += 1 - # No need to further scan the list since we only replace once - new_token_ids.extend(prompt_token_ids[i + 1:]) - break + # No need to further scan the list since we replaced all tokens + if placeholder_token_idx >= len(repeat_count): + new_token_ids.extend(prompt_token_ids[i + 1:]) + break else: new_token_ids.append(token) diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index bda82d3712f09..8d18527e7c973 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -84,6 +84,9 @@ def warn_if_different_devices(): def device_id_to_physical_device_id(device_id: int) -> int: if "CUDA_VISIBLE_DEVICES" in os.environ: device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") + if device_ids == [""]: + raise RuntimeError("CUDA_VISIBLE_DEVICES is set to empty string," + " which means GPU support is disabled.") physical_device_id = device_ids[device_id] return int(physical_device_id) else: diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 3f6f5adee5a56..28525e8ff8811 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,10 +1,21 @@ +import os from functools import lru_cache from typing import Tuple import torch +from vllm.logger import init_logger + from .interface import Platform, PlatformEnum +logger = init_logger(__name__) + +if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]: + logger.warning("`fork` method is not supported by ROCm. 
" + "VLLM_WORKER_MULTIPROC_METHOD is overridden to" + " `spawn` instead.") + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + class RocmPlatform(Platform): _enum = PlatformEnum.ROCM diff --git a/vllm/sequence.py b/vllm/sequence.py index 2fe8ae9d7b270..964072dd7c8f1 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -5,8 +5,8 @@ from array import array from collections import defaultdict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Set, - Tuple, Union, cast) +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Mapping, + Optional, Set, Tuple, Union, cast) import msgspec import torch @@ -474,11 +474,8 @@ def reset_state_for_recompute(self): """Reset the sequence states for recomputation.""" self.data.reset_state_for_recompute() - def append_token_id( - self, - token_id: int, - logprobs: Dict[int, Logprob], - ) -> None: + def append_token_id(self, token_id: int, logprobs: Dict[int, + Logprob]) -> None: assert token_id in logprobs self.output_logprobs.append(logprobs) self.data.append_token_id(token_id, logprobs[token_id].logprob) @@ -1293,6 +1290,8 @@ class ExecuteModelRequest( finished_requests_ids: List[str] = msgspec.field(default_factory=list) # The last sampled token ids for multi step decoding. last_sampled_token_ids: Optional[torch.Tensor] = None + # Async postprocessor + output_proc_callback_fn: Optional[Callable] = None @property def is_first_multi_step(self) -> bool: @@ -1338,4 +1337,5 @@ def clone( num_steps=self.num_steps, finished_requests_ids=self.finished_requests_ids, last_sampled_token_ids=self.last_sampled_token_ids.clone() - if self.last_sampled_token_ids is not None else None) + if self.last_sampled_token_ids is not None else None, + output_proc_callback_fn=self.output_proc_callback_fn) diff --git a/vllm/transformers_utils/detokenizer.py b/vllm/transformers_utils/detokenizer.py index b7624c471cdb2..d27d7ba9e67bb 100644 --- a/vllm/transformers_utils/detokenizer.py +++ b/vllm/transformers_utils/detokenizer.py @@ -230,7 +230,7 @@ def convert_prompt_ids_to_tokens( prefix_offset = max( read_offset - INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET, 0) # This is required to guard against out-of-vocab prompt token ids - _replace_none_with_empty(new_tokens) + _replace_none_with_empty(new_tokens) # type: ignore[arg-type] return new_tokens, prefix_offset, read_offset diff --git a/vllm/transformers_utils/tokenizer.py b/vllm/transformers_utils/tokenizer.py index 0271aa809320e..2866975850db3 100644 --- a/vllm/transformers_utils/tokenizer.py +++ b/vllm/transformers_utils/tokenizer.py @@ -1,4 +1,5 @@ import os +import warnings from pathlib import Path from typing import Optional, Union @@ -9,12 +10,14 @@ from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger from vllm.lora.request import LoRARequest -from vllm.transformers_utils.tokenizers import BaichuanTokenizer +from vllm.transformers_utils.tokenizers import (BaichuanTokenizer, + MistralTokenizer) from vllm.utils import make_async logger = init_logger(__name__) -AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast] +AnyTokenizer = Union[PreTrainedTokenizer, PreTrainedTokenizerFast, + MistralTokenizer] def get_cached_tokenizer(tokenizer: AnyTokenizer) -> AnyTokenizer: @@ -99,45 +102,64 @@ def get_tokenizer( kwargs["gguf_file"] = Path(tokenizer_name).name tokenizer_name = Path(tokenizer_name).parent - try: - tokenizer = AutoTokenizer.from_pretrained( - tokenizer_name, - *args, - trust_remote_code=trust_remote_code, - 
revision=revision, - **kwargs) - except ValueError as e: - # If the error pertains to the tokenizer class not existing or not - # currently being imported, suggest using the --trust-remote-code flag. - if (not trust_remote_code and - ("does not exist or is not currently imported." in str(e) - or "requires you to execute the tokenizer file" in str(e))): - err_msg = ( - "Failed to load the tokenizer. If the tokenizer is a custom " - "tokenizer not yet available in the HuggingFace transformers " - "library, consider setting `trust_remote_code=True` in LLM " - "or using the `--trust-remote-code` flag in the CLI.") - raise RuntimeError(err_msg) from e - else: - raise e - except AttributeError as e: - if "BaichuanTokenizer" in str(e): - # This is for the error "'BaichuanTokenizer' object has no - # attribute 'sp_model'". - tokenizer = BaichuanTokenizer.from_pretrained( + # if tokenizer is from official mistral org + is_from_mistral_org = str(tokenizer_name).split("/")[0] == "mistralai" + if is_from_mistral_org and tokenizer_mode != "mistral": + warnings.warn( + 'It is strongly recommended to run mistral models with ' + '`--tokenizer_mode "mistral"` to ensure correct ' + 'encoding and decoding.', + FutureWarning, + stacklevel=2) + + if tokenizer_mode == "mistral": + tokenizer = MistralTokenizer.from_pretrained(str(tokenizer_name), + revision=revision) + else: + try: + tokenizer = AutoTokenizer.from_pretrained( tokenizer_name, *args, trust_remote_code=trust_remote_code, revision=revision, - **kwargs) - else: - raise e + **kwargs, + ) + except ValueError as e: + # If the error pertains to the tokenizer class not existing or not + # currently being imported, + # suggest using the --trust-remote-code flag. + if not trust_remote_code and ( + "does not exist or is not currently imported." in str(e) + or "requires you to execute the tokenizer file" in str(e)): + err_msg = ("Failed to load the tokenizer. If the tokenizer " + "is a custom tokenizer not yet available in the " + "HuggingFace transformers library, consider " + "setting `trust_remote_code=True` in LLM or using " + "the `--trust-remote-code` flag in the CLI.") + raise RuntimeError(err_msg) from e + else: + raise e + except AttributeError as e: + if "BaichuanTokenizer" in str(e): + # This is for the error "'BaichuanTokenizer' object has no + # attribute 'sp_model'". + tokenizer = BaichuanTokenizer.from_pretrained( + tokenizer_name, + *args, + trust_remote_code=trust_remote_code, + revision=revision, + **kwargs, + ) + else: + raise e + + if not isinstance(tokenizer, PreTrainedTokenizerFast): + logger.warning( + "Using a slow tokenizer. This might cause a significant " + "slowdown. Consider using a fast tokenizer instead.") + tokenizer = get_cached_tokenizer(tokenizer) - if not isinstance(tokenizer, PreTrainedTokenizerFast): - logger.warning( - "Using a slow tokenizer. This might cause a significant " - "slowdown. 
Consider using a fast tokenizer instead.") - return get_cached_tokenizer(tokenizer) + return tokenizer def get_lora_tokenizer(lora_request: LoRARequest, *args, diff --git a/vllm/transformers_utils/tokenizers/__init__.py b/vllm/transformers_utils/tokenizers/__init__.py index e6b59722c2591..9433f2d48f6f3 100644 --- a/vllm/transformers_utils/tokenizers/__init__.py +++ b/vllm/transformers_utils/tokenizers/__init__.py @@ -1,5 +1,4 @@ from vllm.transformers_utils.tokenizers.baichuan import BaichuanTokenizer +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer -__all__ = [ - "BaichuanTokenizer", -] +__all__ = ["BaichuanTokenizer", "MistralTokenizer"] diff --git a/vllm/transformers_utils/tokenizers/mistral.py b/vllm/transformers_utils/tokenizers/mistral.py new file mode 100644 index 0000000000000..23ecfc0af6be4 --- /dev/null +++ b/vllm/transformers_utils/tokenizers/mistral.py @@ -0,0 +1,174 @@ +import os +import re +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from huggingface_hub import HfApi, hf_hub_download +# yapf: disable +from mistral_common.tokens.tokenizers.mistral import ChatCompletionRequest +from mistral_common.tokens.tokenizers.mistral import ( + MistralTokenizer as PublicMistralTokenizer) +# yapf: enable +from mistral_common.tokens.tokenizers.sentencepiece import ( + SentencePieceTokenizer) +from mistral_common.tokens.tokenizers.tekken import (SpecialTokenPolicy, + Tekkenizer) + +if TYPE_CHECKING: + from vllm.entrypoints.chat_utils import ConversationMessage + + +@dataclass +class Encoding: + input_ids: List[int] + + +def find_tokenizer_file(files: List[str]): + file_pattern = re.compile(r"^tokenizer\.model\.v.*$|^tekken\.json$") + + matched_files = [file for file in files if file_pattern.match(file)] + if len(matched_files) > 1: + raise OSError(f"Found {len(matched_files)} files matching the " + "pattern: {matched_files}. Make sure only one Mistral " + "tokenizer is present in {tokenizer_name}.") + elif len(matched_files) == 0: + raise OSError(f"Found {len(matched_files)} files matching the " + "pattern: {matched_files}. 
Make sure that a Mistral " + "tokenizer is present in {tokenizer_name}.") + + return matched_files[0] + + +class MistralTokenizer: + + def __init__(self, tokenizer: PublicMistralTokenizer) -> None: + self.mistral = tokenizer + self.instruct = tokenizer.instruct_tokenizer + self.tokenizer = tokenizer.instruct_tokenizer.tokenizer + + self.vocab_size = len(self.tokenizer.vocab()) + + assert isinstance(self.tokenizer, + (Tekkenizer, SentencePieceTokenizer)), type( + self.tokenizer) + self._is_tekken = isinstance(self.tokenizer, Tekkenizer) + + if self._is_tekken: + # Make sure special tokens will not raise + self.tokenizer.special_token_policy = SpecialTokenPolicy.IGNORE + + # the following attributes are set to fit VLLM's design + self.is_fast = True + self.chat_template = True + self.all_special_ids: List[Any] = [] + self.all_special_tokens: List[Any] = [] + self.all_special_tokens_extended: List[Any] = [] + + @classmethod + def from_pretrained(cls, + path_or_repo_id: str, + *, + revision: Optional[str] = None) -> "MistralTokenizer": + if not Path(path_or_repo_id).exists(): + assert len(path_or_repo_id.split("/")) == 2, ( + "You have either provided a non-existent path: " + "{path_or_repo_id} or an invalid HF Hub repo id.") + tokenizer_file = cls._download_mistral_tokenizer_from_hf( + path_or_repo_id, revision) + elif Path(path_or_repo_id).is_dir(): + tokenizer_file_name = find_tokenizer_file( + os.listdir(path_or_repo_id)) + tokenizer_file = str(Path(path_or_repo_id) / tokenizer_file_name) + else: + assert Path( + path_or_repo_id).is_file(), f"Invalid path: {path_or_repo_id}" + + mistral_tokenizer = PublicMistralTokenizer.from_file(tokenizer_file) + return cls(mistral_tokenizer) + + @staticmethod + def _download_mistral_tokenizer_from_hf(tokenizer_name: str, + revision: Optional[str]) -> str: + api = HfApi() + repo_info = api.model_info(tokenizer_name) + files = [s.rfilename for s in repo_info.siblings] + + filename = find_tokenizer_file(files) + + tokenizer_file = hf_hub_download(tokenizer_name, + filename=filename, + revision=revision) + return tokenizer_file + + def __call__( + self, + prompt: str, + add_special_tokens: bool = False, + truncation: bool = False, + max_length: Optional[int] = None, + ): + # Mistral Tokenizers should not add special tokens + input_ids = self.encode(prompt) + + if truncation: + input_ids = input_ids[:max_length] + + return Encoding(input_ids=input_ids) + + def get_added_vocab(self) -> List[str]: + # Mistral tokenizers have no added vocabulary + return [] + + def encode(self, prompt: str) -> List[int]: + # `encode ` should only be used for prompt completion + # it should never be used for chat_completion. + # For chat completion use `apply_chat_template` + return self.tokenizer.encode(prompt, bos=True, eos=False) + + def apply_chat_template(self, + conversation: List["ConversationMessage"], + tools: Optional[Dict[str, Any]] = None, + **kwargs) -> List[int]: + assert tools is None, "`tools` are not yet supported." 
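# --- Editorial note (not part of the patch): intended use of this class when
# tokenizer_mode is set to "mistral". The repo id is only an example and the
# calls need network access plus mistral_common installed, so they are left as
# comments rather than runnable code:
#
#   tok = MistralTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.3")
#   ids = tok.encode("Hello")                               # completion-style
#   chat_ids = tok.apply_chat_template(
#       [{"role": "user", "content": "Hello"}])             # chat-style
#   text = tok.decode(chat_ids)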
+ + request = ChatCompletionRequest( + messages=conversation) # type: ignore[type-var] + encoded = self.mistral.encode_chat_completion(request) + + # encode-decode to get clean prompt + return encoded.tokens + + def convert_tokens_to_string(self, tokens: List[str]) -> str: + if self._is_tekken: + return "".join(tokens) + else: + return self.tokenizer.decode(tokens) # type: ignore[arg-type] + + def decode(self, ids: Union[List[int], int]) -> str: + if isinstance(ids, int): + ids = [ids] + return self.tokenizer.decode(ids) + + @property + def eos_token_id(self): + return self.tokenizer.eos_id + + def convert_ids_to_tokens( + self, + ids: List[int], + skip_special_tokens: Optional[bool] = True) -> List[str]: + # TODO(Patrick) - potentially allow special tokens to not be skipped + assert ( + skip_special_tokens + ), "Skipping special tokens is not supported for Mistral tokenizers." + + assert isinstance(self.tokenizer, + (Tekkenizer, SentencePieceTokenizer)), type( + self.tokenizer) + + tokens = [self.tokenizer.id_to_piece(id) for id in ids] + return tokens + + def __len__(self): + return self.vocab_size diff --git a/vllm/utils.py b/vllm/utils.py index 0b7457a70b362..dab8e5fe04359 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -25,6 +25,7 @@ import psutil import torch import torch.types +from packaging.version import Version from typing_extensions import ParamSpec, TypeIs, assert_never import vllm.envs as envs @@ -1114,3 +1115,11 @@ async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, """Utility function to run async task in a lock""" async with lock: return await task(*args, **kwargs) + + +# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. +# In particular, the FakeScalarType is not supported for earlier versions of +# PyTorch which breaks dynamo for any ops registered using ScalarType. +def supports_dynamo() -> bool: + base_torch_version = Version(Version(torch.__version__).base_version) + return base_torch_version >= Version("2.4.0") diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 5d930919b8ae5..6073810962769 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -6,8 +6,8 @@ import warnings import weakref from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Set, Tuple, Type, - TypeVar, Union) +from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, + Tuple, Type, TypeVar, Union) import numpy as np import torch @@ -44,7 +44,8 @@ from vllm.sequence import (IntermediateTensors, SamplerOutput, SequenceGroupMetadata) from vllm.utils import (CudaMemoryProfiler, PyObjectCache, async_tensor_h2d, - flatten_2d_lists, is_hip, is_pin_memory_available) + flatten_2d_lists, is_hip, is_pin_memory_available, + supports_dynamo) from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -90,6 +91,7 @@ class ModelInputForGPU(ModelRunnerInputBase): request_ids_to_seq_ids: Optional[Dict[str, List[int]]] = None finished_requests_ids: Optional[List[str]] = None virtual_engine: int = 0 + output_proc_callback_fn: Optional[Callable] = None def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { @@ -945,7 +947,7 @@ def load_model(self) -> None: "provided. Defaulting to scaling factors of 1.0. 
" "This may lead to less accurate results!") - if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE: + if envs.VLLM_TEST_DYNAMO_GRAPH_CAPTURE and supports_dynamo(): self.model = torch.compile(self.model, fullgraph=True, backend="eager") @@ -1096,6 +1098,10 @@ def profile_run(self) -> None: device=self.device) self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() + + # reset and discard the guard and compiled bytecode for profiling runs + torch._dynamo.reset() + return def remove_all_loras(self): @@ -1327,7 +1333,7 @@ def prepare_model_input( self, seq_group_metadata_list: List[SequenceGroupMetadata], virtual_engine: int = 0, - finished_requests_ids: Optional[List[str]] = None + finished_requests_ids: Optional[List[str]] = None, ) -> ModelInputForGPUWithSamplingMetadata: """Prepare the model input based on a given sequence group, including metadata for the sampling step. @@ -1451,6 +1457,9 @@ def execute_model( if not self.is_driver_worker: return [] + if model_input.output_proc_callback_fn is not None: + model_input.output_proc_callback_fn(is_async=True) + # Sample the next token. output: SamplerOutput = self.model.sample( logits=logits, diff --git a/vllm/worker/neuron_worker.py b/vllm/worker/neuron_worker.py index 3b0ded36ca1b6..fff14d6402b44 100644 --- a/vllm/worker/neuron_worker.py +++ b/vllm/worker/neuron_worker.py @@ -6,6 +6,8 @@ from vllm.config import (CacheConfig, DeviceConfig, ModelConfig, ParallelConfig, SchedulerConfig) +from vllm.distributed import (ensure_model_parallel_initialized, + init_distributed_environment) from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest from vllm.worker.neuron_model_runner import NeuronModelRunner @@ -24,12 +26,18 @@ def __init__( scheduler_config: SchedulerConfig, device_config: DeviceConfig, cache_config: CacheConfig, + local_rank: int, + rank: int, + distributed_init_method: str, ) -> None: self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.device_config = device_config self.cache_config = cache_config + self.local_rank = local_rank + self.rank = rank + self.distributed_init_method = distributed_init_method if self.model_config.trust_remote_code: # note: lazy import to avoid importing torch before initializing from vllm.utils import init_cached_hf_modules @@ -40,6 +48,8 @@ def __init__( self.is_driver_worker = True def init_device(self) -> None: + self.init_distributed_environment() + # Set random seed. set_random_seed(self.model_config.seed) @@ -98,3 +108,20 @@ def get_cache_block_size_bytes(self) -> int: This is required for speculative decoding; it is not yet implemented. """ raise NotImplementedError + + def init_distributed_environment(self): + """Neuron uses transformers-neuronx for tensor parallelism. + + vLLM still needs the environment inited when TP/PP > 1 + """ + init_distributed_environment( + world_size=1, + rank=self.rank, + local_rank=self.local_rank, + distributed_init_method=self.distributed_init_method, + backend="gloo", + ) + ensure_model_parallel_initialized( + 1, + 1, + ) diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 44fa3aed5816d..320b15d3604bc 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -143,6 +143,10 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: num_cpu_blocks = int(self.cache_config.swap_space_bytes // block_size_bytes) num_cpu_blocks = (num_cpu_blocks // 8) * 8 # Round down to 8. 
+ + # reset and discard the guard and compiled bytecode for profiling runs + torch._dynamo.reset() + return num_tpu_blocks, num_cpu_blocks def initialize_cache( diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index 516e386595195..e35d5c962a489 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -263,6 +263,12 @@ def _get_driver_input_and_broadcast( broadcast_data.update(kwargs) broadcast_tensor_dict(broadcast_data, src=0) + if execute_model_req.output_proc_callback_fn: + model_input = dataclasses.replace( # type: ignore + model_input, + output_proc_callback_fn=execute_model_req. + output_proc_callback_fn) + return model_input, worker_input, kwargs def prepare_input( @@ -289,7 +295,7 @@ def prepare_input( def execute_model( self, - execute_model_req: Optional[ExecuteModelRequest] = None + execute_model_req: Optional[ExecuteModelRequest] = None, ) -> Optional[List[SamplerOutput]]: """Executes at least one model step on the given sequences, unless no sequences are provided.""" diff --git a/vllm/worker/xpu_model_runner.py b/vllm/worker/xpu_model_runner.py index 0335bbcd091e8..3894658a095f3 100644 --- a/vllm/worker/xpu_model_runner.py +++ b/vllm/worker/xpu_model_runner.py @@ -12,6 +12,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) +from vllm.distributed import get_pp_group from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor.model_loader import get_model @@ -439,9 +440,11 @@ def profile_run(self) -> None: "Setting it to the minimum value of 1.", expr) max_num_seqs = 1 + batch_size = 0 for group_id in range(max_num_seqs): seq_len = (max_num_batched_tokens // max_num_seqs + (group_id < max_num_batched_tokens % max_num_seqs)) + batch_size += seq_len seq_data, dummy_multi_modal_data = self.input_registry \ .dummy_data_for_profiling(self.model_config, @@ -465,7 +468,13 @@ def profile_run(self) -> None: finished_requests_ids = [seq.request_id for seq in seqs] model_input = self.prepare_model_input( seqs, finished_requests_ids=finished_requests_ids) - self.execute_model(model_input, kv_caches) + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = self.model.make_empty_intermediate_tensors( + batch_size=batch_size, + dtype=self.model_config.dtype, + device=self.device) + self.execute_model(model_input, kv_caches, intermediate_tensors) torch.xpu.synchronize() return @@ -537,7 +546,7 @@ def execute_model( and self.observability_config.collect_model_forward_time): model_forward_start_time = time.time() - hidden_states = model_executable( + hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, kv_caches=kv_caches, @@ -545,12 +554,16 @@ def execute_model( intermediate_tensors=intermediate_tensors, **MultiModalInputs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device)) + # Compute the logits in the last pipeline stage. + if not get_pp_group().is_last_rank: + return hidden_or_intermediate_states + if (self.observability_config is not None and self.observability_config.collect_model_forward_time): model_forward_end_time = time.time() # Compute the logits. - logits = self.model.compute_logits(hidden_states, + logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) # Only perform sampling in the driver worker. 
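# --- Editorial sketch (not part of the patch): the pipeline-parallel shape of
# the XPU runner change above. Only the last PP rank computes logits; earlier
# ranks return their intermediate tensors for the next stage. Minimal stand-in,
# with is_last_rank replacing get_pp_group().is_last_rank (all names here are
# illustrative):
from typing import Any, Callable

def forward_step(hidden_or_intermediate_states: Any,
                 is_last_rank: bool,
                 compute_logits: Callable[[Any], Any]) -> Any:
    if not is_last_rank:
        # Not the final stage: hand the intermediate tensors downstream as-is.
        return hidden_or_intermediate_states
    # Final stage: turn hidden states into logits; sampling happens afterwards.
    return compute_logits(hidden_or_intermediate_states)

print(forward_step([0.1, 0.2], is_last_rank=True, compute_logits=lambda h: h))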
diff --git a/vllm/worker/xpu_worker.py b/vllm/worker/xpu_worker.py index b00d1889f8d4b..9ad070d042a3d 100644 --- a/vllm/worker/xpu_worker.py +++ b/vllm/worker/xpu_worker.py @@ -14,6 +14,7 @@ SpeculativeConfig) from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment) +from vllm.distributed.parallel_state import get_pp_group from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.utils import is_xpu @@ -198,3 +199,8 @@ def init_worker_distributed_environment(self) -> None: ensure_model_parallel_initialized( parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size) + + if parallel_config.pipeline_parallel_size > 1: + # torch-ccl xpu need a collective API warm up + # before calling send/recv API + get_pp_group().all_reduce(torch.zeros(1).xpu())
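# --- Editorial note (not part of the patch): the all_reduce added above is a
# warm-up so torch-ccl initialises its communicators before any send/recv is
# issued. A stand-alone analogue of the "cheap collective first" idea using
# torch.distributed with the gloo backend (single process, CPU only):
import torch
import torch.distributed as dist

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500",
                        rank=0, world_size=1)
dist.all_reduce(torch.zeros(1))    # trivial collective warms up the backend
dist.destroy_process_group()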