From a6c0f3658da4f2f23460e3e15bfa7d70ac7e60c1 Mon Sep 17 00:00:00 2001
From: William Lin
Date: Thu, 12 Sep 2024 11:16:22 -0700
Subject: [PATCH] [multi-step] add flashinfer backend (#7928)

---
 csrc/ops.h                                    |  19 +-
 csrc/prepare_inputs/advance_step.cu           | 225 ++++++++++++++++--
 csrc/torch_bindings.cpp                       |  15 +-
 .../multi_step/test_correctness_async_llm.py  |  12 +-
 vllm/_custom_ops.py                           |  38 ++-
 vllm/attention/backends/abstract.py           |   4 +-
 vllm/attention/backends/flash_attn.py         |  18 +-
 vllm/attention/backends/flashinfer.py         |  87 ++++++-
 vllm/worker/multi_step_model_runner.py        |  37 ++-
 9 files changed, 371 insertions(+), 84 deletions(-)

diff --git a/csrc/ops.h b/csrc/ops.h
index 05b89e183ca2..5333b22c536d 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -54,10 +54,21 @@ void gelu_fast(torch::Tensor& out, torch::Tensor& input);
 
 void gelu_quick(torch::Tensor& out, torch::Tensor& input);
 
-void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
-                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
-                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
-                  torch::Tensor& slot_mapping, torch::Tensor& block_tables);
+void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
+                            int64_t block_size, torch::Tensor& input_tokens,
+                            torch::Tensor& sampled_token_ids,
+                            torch::Tensor& input_positions,
+                            torch::Tensor& seq_lens,
+                            torch::Tensor& slot_mapping,
+                            torch::Tensor& block_tables);
+
+void advance_step_flashinfer(
+    int64_t num_seqs, int64_t num_queries, int64_t block_size,
+    torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
+    torch::Tensor& input_positions, torch::Tensor& seq_lens,
+    torch::Tensor& slot_mapping, torch::Tensor& block_tables,
+    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
+    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bounds);
 
 #ifndef USE_ROCM
 torch::Tensor aqlm_gemm(const torch::Tensor& input, const torch::Tensor& codes,
diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu
index 0e537ddd6c4c..a9d08ca0dc14 100644
--- a/csrc/prepare_inputs/advance_step.cu
+++ b/csrc/prepare_inputs/advance_step.cu
@@ -12,13 +12,11 @@ namespace prepare_inputs {
 
 //
 template <int const num_threads>
-__global__ void advance_step_kernel(int num_seqs, int num_queries,
-                                    int block_size, long* input_tokens_ptr,
-                                    long const* sampled_token_ids_ptr,
-                                    long* input_positions_ptr,
-                                    int* seq_lens_ptr, long* slot_mapping_ptr,
-                                    int const* block_tables_ptr,
-                                    int64_t const block_tables_stride) {
+__global__ void advance_step_flashattn_kernel(
+    int num_seqs, int num_queries, int block_size, long* input_tokens_ptr,
+    long const* sampled_token_ids_ptr, long* input_positions_ptr,
+    int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr,
+    int64_t const block_tables_stride) {
   int num_query_blocks = div_ceil(num_queries, num_threads);
 
   if (blockIdx.x >= num_query_blocks) {
@@ -79,16 +77,91 @@ inline void verify_tensor(std::string const& name, torch::Tensor& t,
   }
 }
 
-void advance_step(int num_seqs, int num_queries, int block_size,
-                  torch::Tensor& input_tokens,       // type: long
-                  torch::Tensor& sampled_token_ids,  // type: long
-                  torch::Tensor& input_positions,    // type: long
-                  torch::Tensor& seq_lens,           // type: int
-                  torch::Tensor& slot_mapping,       // type: long
-                  torch::Tensor& block_tables) {     // type: int
+__global__ void advance_step_flashinfer_kernel(
+    int num_threads, int num_seqs, int num_queries, int block_size,
+    long* input_tokens_ptr, long const* sampled_token_ids_ptr,
+    long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr,
+    int const* block_tables_ptr, int64_t const block_tables_stride,
+    int* paged_kv_last_page_len_ptr, int* block_table_bound_ptr) {
+  int num_query_blocks = div_ceil(num_queries, num_threads);
+
+  if (blockIdx.x < num_query_blocks) {
+    int cur_query_id = blockIdx.x * num_threads + threadIdx.x;
+
+    if (cur_query_id < num_queries) {
+      // Update input_tokens
+      input_tokens_ptr[cur_query_id] = sampled_token_ids_ptr[cur_query_id];
+
+      int seq_len = seq_lens_ptr[cur_query_id];
+      int next_seq_len = seq_len + 1;
+      int next_input_pos = next_seq_len - 1;
+
+      // Update seq_lens
+      seq_lens_ptr[cur_query_id] = next_seq_len;
+      // Update input_positions
+      input_positions_ptr[cur_query_id] = next_input_pos;
+
+      int const* seq_block_tables_ptr =
+          block_tables_ptr + block_tables_stride * cur_query_id;
+
+      int block_index = next_input_pos / block_size;
+      int block_offset = next_input_pos % block_size;
+
+      // Update paged_kv_last_page_len
+      paged_kv_last_page_len_ptr[cur_query_id] = block_offset + 1;
+
+      int slot_num =
+          seq_block_tables_ptr[block_index] * block_size + block_offset;
+      // Update slot_mapping
+      slot_mapping_ptr[cur_query_id] = slot_num;
+      block_table_bound_ptr[cur_query_id] = div_ceil(next_seq_len, block_size);
+    }
+  }
+}
+
+__global__ void advance_step_flashinfer_indptr_kernel(
+    int num_threads, int num_seqs, int num_queries, int* paged_kv_indptr_ptr,
+    int* block_table_bound_ptr) {
+  int idx = blockIdx.x * num_threads + threadIdx.x;
+
+  // Update paged_kv_indptr
+  if (idx < num_queries) {
+    int sum = 0;
+    for (int i = 0; i <= idx; ++i) {
+      sum += block_table_bound_ptr[i];
+    }
+    paged_kv_indptr_ptr[idx + 1] = sum;
+  }
+}
+
+__global__ void advance_step_flashinfer_indices_kernel(
+    int num_threads, int num_seqs, int num_queries, int const* block_tables_ptr,
+    int64_t const block_tables_stride, int* paged_kv_indices_ptr,
+    int* paged_kv_indptr_ptr, int* block_table_bound_ptr) {
+  int idx = blockIdx.x * num_threads + threadIdx.x;
+  int row = idx / block_tables_stride;
+  int col = idx % block_tables_stride;
+
+  if (row < num_queries && col < block_table_bound_ptr[row]) {
+    paged_kv_indices_ptr[paged_kv_indptr_ptr[row] + col] =
+        block_tables_ptr[row * block_tables_stride + col];
+  }
+  // if cudagraph, fill padded seqs with the last valid seq's indptr
+  if (num_queries < row && row <= num_seqs) {
+    paged_kv_indptr_ptr[row] = paged_kv_indptr_ptr[num_queries];
+  }
+}
+
+void advance_step_flashattn(int num_seqs, int num_queries, int block_size,
+                            torch::Tensor& input_tokens,       // type: long
+                            torch::Tensor& sampled_token_ids,  // type: long
+                            torch::Tensor& input_positions,    // type: long
+                            torch::Tensor& seq_lens,           // type: int
+                            torch::Tensor& slot_mapping,       // type: long
+                            torch::Tensor& block_tables) {     // type: int
   if (logging) {
-    printf("advance_step:\n");
+    printf("advance_step_flashattn:\n");
     printf("  num_seqs = %d\n", num_seqs);
     printf("  num_queries = %d\n", num_queries);
     printf("  block_size = %d\n", block_size);
@@ -108,24 +181,126 @@ void advance_step(int num_seqs, int num_queries, int block_size,
   int blocks;
   cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
 
-  advance_step_kernel<max_threads><<<blocks, max_threads, 0, stream>>>(
-      num_seqs, num_queries, block_size,
+  advance_step_flashattn_kernel<max_threads>
+      <<<blocks, max_threads, 0, stream>>>(
+          num_seqs, num_queries, block_size,
+          reinterpret_cast<long*>(input_tokens.data_ptr()),
+          reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
+          reinterpret_cast<long*>(input_positions.data_ptr()),
+          reinterpret_cast<int*>(seq_lens.data_ptr()),
+          reinterpret_cast<long*>(slot_mapping.data_ptr()),
+          reinterpret_cast<int const*>(block_tables.data_ptr()),
+          block_tables.stride(0));
+}
+
+void advance_step_flashinfer(
+    int num_seqs, int num_queries, int block_size,
+    torch::Tensor& input_tokens,            // type: long
+    torch::Tensor& sampled_token_ids,       // type: long
+    torch::Tensor& input_positions,         // type: long
+    torch::Tensor& seq_lens,                // type: int
+    torch::Tensor& slot_mapping,            // type: long
+    torch::Tensor& block_tables,            // type: int
+    torch::Tensor& paged_kv_indices,        // type: int
+    torch::Tensor& paged_kv_indptr,         // type: int
+    torch::Tensor& paged_kv_last_page_len,  // type: int
+    torch::Tensor& block_table_bound) {     // type: int
+
+  if (logging) {
+    printf("advance_step_flashinfer:\n");
+    printf("  num_seqs = %d\n", num_seqs);
+    printf("  num_queries = %d\n", num_queries);
+    printf("  block_size = %d\n", block_size);
+    printf("  block_tables.stride(0) = %d\n", block_tables.stride(0));
+  }
+  // Verify all tensors
+  verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong);
+  // verify_tensor("sampled_token_ids", sampled_token_ids, num_queries, 1,
+  //               at::kLong);
+  verify_tensor("input_positions", input_positions, num_seqs, -1, at::kLong);
+  verify_tensor("seq_lens", seq_lens, num_seqs, -1, at::kInt);
+  verify_tensor("slot_mapping", slot_mapping, num_seqs, -1, at::kLong);
+  verify_tensor("block_tables", block_tables, num_seqs, -1, at::kInt);
+
+  verify_tensor("paged_kv_indices", paged_kv_indices, -1, -1, at::kInt);
+  verify_tensor("paged_kv_indptr", paged_kv_indptr, num_seqs + 1, -1, at::kInt);
+  verify_tensor("paged_kv_last_page_len", paged_kv_last_page_len, num_seqs, -1,
+                at::kInt);
+
+  verify_tensor("block_table_bound", block_table_bound, num_seqs, -1, at::kInt);
+
+  int dev = sampled_token_ids.get_device();
+  cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
+
+  int blocks;
+  int threads;
+  cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
+  cudaDeviceGetAttribute(&threads, cudaDevAttrMaxThreadsPerBlock, dev);
+  if (logging) {
+    printf("launching kernel with %d blocks\n", blocks);
+  }
+
+  // TODO(will): support arbitrary block_tables stride
+  if ((blocks * threads) / block_tables.stride(0) < num_queries) {
+    TORCH_CHECK(false,
+                "multi-step: not enough threads to map block_table to "
+                "FlashInfer's paged_kv_indices on GPU. Try reducing the number "
+                "of seqs,",
+                " increasing the block size or taking smaller steps.",
+                " num_queries = ", num_queries,
+                " block_tables.stride(0) = ", block_tables.stride(0),
+                " blocks = ", blocks, " max_threads = ", threads);
+  }
+
+  advance_step_flashinfer_kernel<<<blocks, threads, 0, stream>>>(
+      threads, num_seqs, num_queries, block_size,
       reinterpret_cast<long*>(input_tokens.data_ptr()),
       reinterpret_cast<long const*>(sampled_token_ids.data_ptr()),
       reinterpret_cast<long*>(input_positions.data_ptr()),
       reinterpret_cast<int*>(seq_lens.data_ptr()),
       reinterpret_cast<long*>(slot_mapping.data_ptr()),
       reinterpret_cast<int const*>(block_tables.data_ptr()),
-      block_tables.stride(0));
+      block_tables.stride(0),
+      reinterpret_cast<int*>(paged_kv_last_page_len.data_ptr()),
+      reinterpret_cast<int*>(block_table_bound.data_ptr()));
+
+  advance_step_flashinfer_indptr_kernel<<<blocks, threads, 0, stream>>>(
+      threads, num_seqs, num_queries,
+      reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
+      reinterpret_cast<int*>(block_table_bound.data_ptr()));
+
+  advance_step_flashinfer_indices_kernel<<<blocks, threads, 0, stream>>>(
+      threads, num_seqs, num_queries,
+      reinterpret_cast<int const*>(block_tables.data_ptr()),
+      block_tables.stride(0),
+      reinterpret_cast<int*>(paged_kv_indices.data_ptr()),
+      reinterpret_cast<int*>(paged_kv_indptr.data_ptr()),
+      reinterpret_cast<int*>(block_table_bound.data_ptr()));
 }
 
 }  // namespace prepare_inputs
 
-void advance_step(int64_t num_seqs, int64_t num_queries, int64_t block_size,
-                  torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
-                  torch::Tensor& input_positions, torch::Tensor& seq_lens,
-                  torch::Tensor& slot_mapping, torch::Tensor& block_tables) {
-  prepare_inputs::advance_step(num_seqs, num_queries, block_size, input_tokens,
-                               sampled_token_ids, input_positions, seq_lens,
-                               slot_mapping, block_tables);
+void advance_step_flashattn(int64_t num_seqs, int64_t num_queries,
+                            int64_t block_size, torch::Tensor& input_tokens,
+                            torch::Tensor& sampled_token_ids,
+                            torch::Tensor& input_positions,
+                            torch::Tensor& seq_lens,
+                            torch::Tensor& slot_mapping,
+                            torch::Tensor& block_tables) {
+  prepare_inputs::advance_step_flashattn(
+      num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
+      input_positions, seq_lens, slot_mapping, block_tables);
+}
+
+void advance_step_flashinfer(
+    int64_t num_seqs, int64_t num_queries, int64_t block_size,
+    torch::Tensor& input_tokens, torch::Tensor& sampled_token_ids,
+    torch::Tensor& input_positions, torch::Tensor& seq_lens,
+    torch::Tensor& slot_mapping, torch::Tensor& block_tables,
+    torch::Tensor& paged_kv_indices, torch::Tensor& paged_kv_indptr,
+    torch::Tensor& paged_kv_last_page_len, torch::Tensor& block_table_bound) {
+  prepare_inputs::advance_step_flashinfer(
+      num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
+      input_positions, seq_lens, slot_mapping, block_tables, paged_kv_indices,
+      paged_kv_indptr, paged_kv_last_page_len, block_table_bound);
 }
\ No newline at end of file
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 57103c0936f5..51afeacfdc0a 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -74,11 +74,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
 
   // prepare_inputs advance_step
   ops.def(
-      "advance_step(int num_seqs, int num_queries, int block_size, "
+      "advance_step_flashattn(int num_seqs, int num_queries, int block_size, "
       "Tensor! input_tokens, Tensor sampled_token_ids, "
       "Tensor! input_positions, Tensor! seq_lens, Tensor! 
slot_mapping, " "Tensor block_tables) -> ()"); - ops.impl("advance_step", torch::kCUDA, &advance_step); + ops.impl("advance_step_flashattn", torch::kCUDA, &advance_step_flashattn); + + ops.def( + "advance_step_flashinfer(" + " int num_seqs, int num_queries, int block_size," + " Tensor! input_tokens, Tensor sampled_token_ids," + " Tensor! input_positions, Tensor! seq_lens, Tensor! slot_mapping," + " Tensor block_tables, Tensor! paged_kv_indices," + " Tensor! paged_kv_indptr, Tensor! paged_kv_last_page_len," + " Tensor! block_table_bounds" + ") -> ()"); + ops.impl("advance_step_flashinfer", torch::kCUDA, &advance_step_flashinfer); // Layernorm // Apply Root Mean Square (RMS) Normalization to the input tensor. diff --git a/tests/multi_step/test_correctness_async_llm.py b/tests/multi_step/test_correctness_async_llm.py index 0cbe8371e235..a75a671e57f7 100644 --- a/tests/multi_step/test_correctness_async_llm.py +++ b/tests/multi_step/test_correctness_async_llm.py @@ -1,9 +1,10 @@ # Test the AsyncLLMEngine with multi-step-decoding - from typing import List, Optional import pytest +from tests.kernels.utils import override_backend_env_variable + from ..models.utils import check_logprobs_close from ..utils import (completions_with_server_args, get_client_text_generations, get_client_text_logprob_generations) @@ -33,8 +34,9 @@ @pytest.mark.parametrize("eager_mode", [False, True]) @pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) @pytest.mark.parametrize("num_prompts", NUM_PROMPTS) -@pytest.mark.parametrize("num_logprobs", [None, 5]) -@pytest.mark.parametrize("is_async", [False, True]) +@pytest.mark.parametrize("num_logprobs", [5]) +@pytest.mark.parametrize("is_async", [True]) +@pytest.mark.parametrize("attention_backend", ["FLASHINFER", "FLASH_ATTN"]) @pytest.mark.asyncio async def test_multi_step( example_prompts, @@ -46,6 +48,8 @@ async def test_multi_step( num_prompts: int, is_async: bool, num_logprobs: Optional[int], + attention_backend: str, + monkeypatch, ) -> None: """Test vLLM engine with multi-step scheduling in an OpenAI-protocol client/server environment. 
@@ -71,6 +75,8 @@ async def test_multi_step( completions endpoint; `None` -> no logprobs """ + override_backend_env_variable(monkeypatch, attention_backend) + prompts = example_prompts if len(prompts) < num_prompts: prompts = prompts * ((num_prompts // len(prompts)) + 1) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 7a9061526ef2..efa02d36c4ac 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -161,16 +161,36 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, torch.ops._C.fused_add_rms_norm(input, residual, weight, epsilon) -def advance_step(num_seqs: int, num_queries: int, block_size: int, - input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor, - input_positions: torch.Tensor, seq_lens: torch.Tensor, - slot_mapping: torch.Tensor, - block_tables: torch.Tensor) -> None: +def advance_step_flashattn(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, + seq_lens: torch.Tensor, slot_mapping: torch.Tensor, + block_tables: torch.Tensor) -> None: """Advance a step on GPU for existing inputs for a multi-step runner""" - return torch.ops._C.advance_step(num_seqs, num_queries, block_size, - input_tokens, sampled_token_ids, - input_positions, seq_lens, slot_mapping, - block_tables) + return torch.ops._C.advance_step_flashattn(num_seqs, num_queries, + block_size, input_tokens, + sampled_token_ids, + input_positions, seq_lens, + slot_mapping, block_tables) + + +def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int, + input_tokens: torch.Tensor, + sampled_token_ids: torch.Tensor, + input_positions: torch.Tensor, + seq_lens: torch.Tensor, slot_mapping: torch.Tensor, + block_tables: torch.Tensor, + paged_kv_indices: torch.Tensor, + paged_kv_indptr: torch.Tensor, + paged_kv_last_page_len: torch.Tensor, + block_table_bound: torch.Tensor) -> None: + + return torch.ops._C.advance_step_flashinfer( + num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, + input_positions, seq_lens, slot_mapping, block_tables, + paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, + block_table_bound) # quantization ops diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index ccfc6b254c1e..adc8390e6f9e 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -83,7 +83,9 @@ def copy_blocks( ) -> None: raise NotImplementedError - def advance_step(self, num_seqs: int, num_queries: int): + def advance_step(self, model_input: "ModelRunnerInputBase", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, num_seqs: int, num_queries: int) -> None: raise NotImplementedError diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index ec9cbde7467d..bf883987bd80 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -380,15 +380,15 @@ def advance_step(self, model_input: "ModelInputForGPUWithSamplingMetadata", self.seq_lens[i] += 1 self.max_decode_seq_len = max(self.seq_lens) - ops.advance_step(num_seqs=num_seqs, - num_queries=num_queries, - block_size=block_size, - input_tokens=model_input.input_tokens, - sampled_token_ids=sampled_token_ids, - input_positions=model_input.input_positions, - seq_lens=self.seq_lens_tensor, - slot_mapping=self.slot_mapping, - block_tables=self.block_tables) + ops.advance_step_flashattn(num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + 
input_tokens=model_input.input_tokens, + sampled_token_ids=sampled_token_ids, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables) class FlashAttentionMetadataBuilder( diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 7aec8203eb1e..58d62e02e873 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -30,7 +30,8 @@ make_tensor_with_pad) if TYPE_CHECKING: - from vllm.worker.model_runner import ModelInputForGPUBuilder + from vllm.worker.model_runner import (ModelInputForGPUBuilder, + ModelInputForGPUWithSamplingMetadata) class FlashInferBackend(AttentionBackend): @@ -268,6 +269,10 @@ class FlashInferMetadata(AttentionMetadata): query_start_loc: Optional[torch.Tensor] = None block_tables: Optional[torch.Tensor] = None + # used for GPU in-place advance_step + seq_lens_tensor: Optional[torch.Tensor] = None + block_table_bound: Optional[torch.Tensor] = None + # An example for paged_kv_indices, paged_kv_indptr: # request 1, page indices [0, 5, 8] # request 2, page indices [1, 6, 7] @@ -318,6 +323,8 @@ def begin_forward(self): assert self.paged_kv_indices is not None assert self.paged_kv_indptr is not None assert self.paged_kv_last_page_len is not None + assert self.block_table_bound is not None + assert self.seq_lens_tensor is not None batch_size = self.query_start_loc.shape[0] - 1 assert batch_size >= 0 # We will use flash attention for profiling to @@ -327,6 +334,8 @@ def begin_forward(self): self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( self.device) + self.block_table_bound = self.block_table_bound.to(self.device) + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) self.paged_kv_indices = self.paged_kv_indices.to(self.device) self.prefill_wrapper.end_forward() self.prefill_wrapper.begin_forward( @@ -335,14 +344,18 @@ def begin_forward(self): self.num_qo_heads, self.num_kv_heads, self.head_dim, self.page_size) else: - if not self.use_cuda_graph: - assert self.paged_kv_indices is not None - assert self.paged_kv_indptr is not None - assert self.paged_kv_last_page_len is not None - self.paged_kv_indices = self.paged_kv_indices.to(self.device) - self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) - self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( - self.device) + assert self.paged_kv_indices is not None + assert self.paged_kv_indptr is not None + assert self.paged_kv_last_page_len is not None + self.paged_kv_indices = self.paged_kv_indices.to(self.device) + self.paged_kv_indptr = self.paged_kv_indptr.to(self.device) + self.paged_kv_last_page_len = self.paged_kv_last_page_len.to( + self.device) + # handle model warmup path + if self.block_table_bound is not None: + self.block_table_bound = self.block_table_bound.to(self.device) + if self.seq_lens_tensor is not None: + self.seq_lens_tensor = self.seq_lens_tensor.to(self.device) assert self.decode_wrapper is not None self.decode_wrapper.end_forward() @@ -391,6 +404,48 @@ def decode_metadata(self) -> Optional["FlashInferMetadata"]: return self + def advance_step( + self, + model_input: "ModelInputForGPUWithSamplingMetadata", + sampled_token_ids: Optional[torch.Tensor], + block_size: int, + num_seqs: int, + num_queries: int, + ): + """ + Update metadata in-place to advance one decode step. 
+ """ + + assert num_seqs > 0 + assert num_queries > 0 + assert model_input.attn_metadata is not None + assert sampled_token_ids is not None + + # When using cudagraph, the num_seqs is padded to the next captured + # batch sized, but num_queries tracks the actual number of requests in + # the batch. For --enforce-eager mode, num_seqs == num_queries + if num_seqs != num_queries: + assert num_seqs > num_queries + assert self.use_cuda_graph + + model_input.input_tokens[:num_queries] = sampled_token_ids.flatten() + + # Update GPU tensors + ops.advance_step_flashinfer( + num_seqs=num_seqs, + num_queries=num_queries, + block_size=block_size, + input_tokens=model_input.input_tokens, + sampled_token_ids=model_input.input_tokens, + input_positions=model_input.input_positions, + seq_lens=self.seq_lens_tensor, + slot_mapping=self.slot_mapping, + block_tables=self.block_tables, + paged_kv_indices=self.paged_kv_indices, + paged_kv_indptr=self.paged_kv_indptr, + paged_kv_last_page_len=self.paged_kv_last_page_len, + block_table_bound=self.block_table_bound) + class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): @@ -428,7 +483,7 @@ def __init__(self, input_builder: "ModelInputForGPUBuilder"): self.paged_kv_indptr: List[int] = [0] # paged_kv_last_page_len is the length of the last page of each request self.paged_kv_last_page_len: List[int] = [] - + self.total_blocks = 0 self.is_profile_run: bool = False def _add_seq_group( @@ -499,6 +554,7 @@ def _update_paged_kv_tensors(self, block_table: List[int], seq_len: int): # block_table_bound is 1 with 1 valid block. # If seq_len = 15, block_size = 16, # block_table_bound is 0 + 1 with 1 valid block. + self.total_blocks += len(block_table) block_table_bound = seq_len // self.block_size + 1 \ if seq_len % self.block_size != 0 \ else seq_len // self.block_size @@ -583,6 +639,10 @@ def build(self, seq_lens: List[int], query_lens: List[int], out=query_start_loc[1:]) if len(self.paged_kv_indptr) > 0: + # extend to the maximum number of blocks as returned by the + # scheduler + self.paged_kv_indices.extend( + [0] * (self.total_blocks - len(self.paged_kv_indices))) paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices, device="cpu", dtype=torch.int) @@ -591,10 +651,15 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=torch.int) paged_kv_last_page_len_tensor = torch.tensor( self.paged_kv_last_page_len, device="cpu", dtype=torch.int) + block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) - + 1, + device="cpu", + dtype=torch.int) else: paged_kv_indices_tensor = None paged_kv_indptr_tensor = None paged_kv_last_page_len_tensor = None + block_table_bound_tensor = None if self.runner.kv_cache_dtype.startswith("fp8"): kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( @@ -613,6 +678,8 @@ def build(self, seq_lens: List[int], query_lens: List[int], paged_kv_indptr=paged_kv_indptr_tensor, paged_kv_indices=paged_kv_indices_tensor, paged_kv_last_page_len=paged_kv_last_page_len_tensor, + block_table_bound=block_table_bound_tensor, + seq_lens_tensor=seq_lens_tensor, num_qo_heads=self.runner.model_config.get_num_attention_heads( self.runner.parallel_config), num_kv_heads=self.runner.model_config.get_num_kv_heads( diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py index cd9b20083c1a..b900eb5a610f 100644 --- a/vllm/worker/multi_step_model_runner.py +++ b/vllm/worker/multi_step_model_runner.py @@ -4,13 +4,6 @@ from typing import (TYPE_CHECKING, Any, Callable, Dict, List, 
Optional, Tuple, Union) -try: - from vllm.attention.backends.flash_attn import FlashAttentionMetadata -except ModuleNotFoundError: - # vllm_flash_attn is not installed, use the identical ROCm FA metadata - from vllm.attention.backends.rocm_flash_attn import ( - ROCmFlashAttentionMetadata as FlashAttentionMetadata) - import torch from vllm.distributed import get_pp_group @@ -36,6 +29,8 @@ logger = init_logger(__name__) +MULTI_STEP_ATTENTION_BACKENDS = ["flash-attn", "flashinfer"] + def seq_output_builder(): return SequenceOutput( @@ -489,27 +484,27 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs, def _advance_step(self, model_input: StatefulModelInput, out: SamplerOutput) -> StatefulModelInput: - frozen_model_input = model_input.frozen_model_input - assert frozen_model_input is not None - assert frozen_model_input.attn_metadata is not None + if self.attn_backend.get_name() not in MULTI_STEP_ATTENTION_BACKENDS: + raise ValueError( + f"Multi-step not supported for attention backend: " + f"{self.attn_backend.get_name()}. Set VLLM_ATTENTION_BACKEND " + f"to a value from {MULTI_STEP_ATTENTION_BACKENDS}.") + sampled_token_ids = model_input.cached_outputs[-1].sampled_token_ids num_seqs = model_input.num_seqs num_queries = model_input.num_queries - assert num_seqs > 0 - assert num_queries > 0 - assert num_seqs >= num_queries - + frozen_model_input = model_input.frozen_model_input + assert frozen_model_input is not None attn_metadata = frozen_model_input.attn_metadata - assert isinstance(attn_metadata, FlashAttentionMetadata) + assert attn_metadata is not None attn_metadata.advance_step( frozen_model_input, - model_input.cached_outputs[-1].sampled_token_ids, self.block_size, - num_seqs, num_queries) - - if frozen_model_input.seq_lens is not None: - for i in range(num_queries): - frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i] + sampled_token_ids, + self.block_size, + num_seqs, + num_queries, + ) return model_input
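
As a rough illustration of what the three advance_step_flashinfer kernels in this patch compute each decode step, here is a plain-Python sketch. It is not part of the diff: the function name and list-based "tensors" are made up for exposition, and the CUDA-graph padding of paged_kv_indptr for rows between num_queries and num_seqs is omitted.

def advance_step_flashinfer_reference(block_size, block_tables, seq_lens,
                                      sampled_token_ids, input_tokens,
                                      input_positions, slot_mapping):
    num_queries = len(seq_lens)
    block_table_bound = [0] * num_queries
    paged_kv_last_page_len = [0] * num_queries

    # advance_step_flashinfer_kernel: advance each sequence by one token.
    for i in range(num_queries):
        input_tokens[i] = sampled_token_ids[i]
        seq_lens[i] += 1
        next_pos = seq_lens[i] - 1
        input_positions[i] = next_pos
        block_idx, block_off = divmod(next_pos, block_size)
        slot_mapping[i] = block_tables[i][block_idx] * block_size + block_off
        paged_kv_last_page_len[i] = block_off + 1
        # Pages now in use by this sequence: div_ceil(seq_len, block_size).
        block_table_bound[i] = (seq_lens[i] + block_size - 1) // block_size

    # advance_step_flashinfer_indptr_kernel: prefix-sum page counts -> indptr.
    paged_kv_indptr = [0] * (num_queries + 1)
    for i in range(num_queries):
        paged_kv_indptr[i + 1] = paged_kv_indptr[i] + block_table_bound[i]

    # advance_step_flashinfer_indices_kernel: flatten the used prefix of each
    # block table into FlashInfer's paged_kv_indices layout.
    paged_kv_indices = []
    for i in range(num_queries):
        paged_kv_indices.extend(block_tables[i][:block_table_bound[i]])

    return paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len

Doing these updates with on-GPU kernels, as the patch does, lets the multi-step runner advance FlashInfer's paged-KV metadata between steps without copying block tables back to the host.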