diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu
index 34b4858bc5ed..cae942db0084 100644
--- a/csrc/prepare_inputs/advance_step.cu
+++ b/csrc/prepare_inputs/advance_step.cu
@@ -240,7 +240,7 @@ void advance_step_flashinfer(
 
   if (logging) {
     printf("launching kernel with %d blocks\n", blocks);
   }
-
+  // TODO(will): support arbitrary block_tables stride
   if ((blocks * threads) / block_tables.stride(0) < num_queries) {
     TORCH_CHECK(false,
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 236a07ad4f9e..90a528c2ce9d 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -171,18 +171,23 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int,
                                      input_positions, seq_lens, slot_mapping,
                                      block_tables)
 
 
-def advance_step_flashinfer( num_seqs: int, num_queries: int, block_size: int,
-    input_tokens: torch.Tensor, sampled_token_ids: torch.Tensor,
-    input_positions: torch.Tensor, seq_lens: torch.Tensor,
-    slot_mapping: torch.Tensor, block_tables: torch.Tensor,
-    paged_kv_indices: torch.Tensor, paged_kv_indptr: torch.Tensor,
-    paged_kv_last_page_len: torch.Tensor, block_table_bound: torch.Tensor
-) -> None:
-    return torch.ops._C.advance_step_flashinfer(num_seqs, num_queries,
-        block_size, input_tokens, sampled_token_ids, input_positions,
-        seq_lens, slot_mapping, block_tables, paged_kv_indices, paged_kv_indptr,
-        paged_kv_last_page_len, block_table_bound)
+def advance_step_flashinfer(num_seqs: int, num_queries: int, block_size: int,
+                            input_tokens: torch.Tensor,
+                            sampled_token_ids: torch.Tensor,
+                            input_positions: torch.Tensor,
+                            seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
+                            block_tables: torch.Tensor,
+                            paged_kv_indices: torch.Tensor,
+                            paged_kv_indptr: torch.Tensor,
+                            paged_kv_last_page_len: torch.Tensor,
+                            block_table_bound: torch.Tensor) -> None:
+
+    return torch.ops._C.advance_step_flashinfer(
+        num_seqs, num_queries, block_size, input_tokens, sampled_token_ids,
+        input_positions, seq_lens, slot_mapping, block_tables,
+        paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len,
+        block_table_bound)
 
 
 # quantization ops
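For orientation, the sketch below is a rough, eager-mode PyTorch approximation of the per-query bookkeeping that the fused `advance_step_flashinfer` kernel performs on the GPU. It is illustrative only: the function name `advance_step_reference`, the assumed tensor shapes, and the exact update rules are assumptions, and the real kernel in `csrc/prepare_inputs/advance_step.cu` additionally rebuilds the FlashInfer paged-KV metadata (`paged_kv_indices`, `paged_kv_indptr`, `paged_kv_last_page_len`, `block_table_bound`).

```python
import torch


def advance_step_reference(input_tokens: torch.Tensor,
                           sampled_token_ids: torch.Tensor,
                           input_positions: torch.Tensor,
                           seq_lens: torch.Tensor, slot_mapping: torch.Tensor,
                           block_tables: torch.Tensor, block_size: int,
                           num_queries: int) -> None:
    """Approximate, eager-mode view of one decode-step advance."""
    q = torch.arange(num_queries, device=seq_lens.device)
    # The token sampled at the previous step becomes the next input token.
    input_tokens[:num_queries] = sampled_token_ids[:num_queries].flatten()
    # Each running sequence grows by one token.
    seq_lens[:num_queries] += 1
    input_positions[:num_queries] = seq_lens[:num_queries] - 1
    # Map the new token's position to its physical slot in the paged KV cache.
    pos = input_positions[:num_queries]
    block_number = block_tables[q, pos // block_size]
    slot_mapping[:num_queries] = block_number * block_size + pos % block_size
```

Doing all of this inside one CUDA kernel is what lets the multi-step runner advance to the next decode iteration without a host-device synchronization.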
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index 074e5880fcbe..c733f8858604 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -306,6 +306,8 @@ def begin_forward(self):
             assert self.paged_kv_indices is not None
             assert self.paged_kv_indptr is not None
             assert self.paged_kv_last_page_len is not None
+            assert self.block_table_bound is not None
+            assert self.seq_lens_tensor is not None
             batch_size = self.query_start_loc.shape[0] - 1
             assert batch_size >= 0
             # We will use flash attention for profiling to
@@ -325,7 +327,6 @@ def begin_forward(self):
                 self.num_qo_heads, self.num_kv_heads, self.head_dim,
                 self.page_size)
         else:
-            #if not self.use_cuda_graph:
             assert self.paged_kv_indices is not None
             assert self.paged_kv_indptr is not None
             assert self.paged_kv_last_page_len is not None
@@ -333,8 +334,11 @@ def begin_forward(self):
             self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
             self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
                 self.device)
-            self.block_table_bound = self.block_table_bound.to(self.device)
-            self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
+            # handle model warmup path
+            if self.block_table_bound is not None:
+                self.block_table_bound = self.block_table_bound.to(self.device)
+            if self.seq_lens_tensor is not None:
+                self.seq_lens_tensor = self.seq_lens_tensor.to(self.device)
 
             assert self.decode_wrapper is not None
             self.decode_wrapper.end_forward()
@@ -524,7 +528,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
 
         if use_captured_graph:
             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
-            self.block_tables.extend([] for _ in range(cuda_graph_pad_size))
+            self.block_tables.extend([] * cuda_graph_pad_size)
             num_decode_tokens = batch_size
 
             # The shape of graph_block_tables is
@@ -574,8 +578,8 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         if len(self.paged_kv_indptr) > 0:
             # extend to the maximum number of blocks as returned by the
             # scheduler
-            self.paged_kv_indices.extend([0] *
-                (self.total_blocks - len(self.paged_kv_indices)))
+            self.paged_kv_indices.extend(
+                [0] * (self.total_blocks - len(self.paged_kv_indices)))
             paged_kv_indices_tensor = torch.tensor(self.paged_kv_indices,
                                                    device="cpu",
                                                    dtype=torch.int)
@@ -584,14 +588,16 @@ def build(self, seq_lens: List[int], query_lens: List[int],
                                                   dtype=torch.int)
             paged_kv_last_page_len_tensor = torch.tensor(
                 self.paged_kv_last_page_len, device="cpu", dtype=torch.int)
-            block_table_bound_tensor = torch.zeros(
-                len(self.paged_kv_indptr) - 1, device="cpu", dtype=torch.int)
+            block_table_bound_tensor = torch.zeros(len(self.paged_kv_indptr) -
+                                                   1,
+                                                   device="cpu",
+                                                   dtype=torch.int)
         else:
             paged_kv_indices_tensor = None
             paged_kv_indptr_tensor = None
             paged_kv_last_page_len_tensor = None
             block_table_bound_tensor = None
-
+
         kv_cache_dtype = get_kv_cache_torch_dtype(
             self.runner.kv_cache_dtype, self.runner.model_config.dtype)
         return FlashInferMetadata(
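To make the tensors being built and padded above easier to follow, here is a small CPU-side sketch of how FlashInfer-style paged-KV metadata is laid out: `paged_kv_indptr` is a CSR-style offset array over the flattened per-sequence block lists in `paged_kv_indices`, and `paged_kv_last_page_len` records how many tokens occupy each sequence's final page. The helper name `build_paged_kv_metadata` and its plain-Python construction are illustrative, not code from this change.

```python
from typing import List, Tuple

import torch


def build_paged_kv_metadata(
        seq_lens: List[int], block_tables: List[List[int]],
        page_size: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Toy construction of (paged_kv_indices, indptr, last_page_len)."""
    indptr = [0]
    indices: List[int] = []
    last_page_len: List[int] = []
    for seq_len, blocks in zip(seq_lens, block_tables):
        num_pages = (seq_len + page_size - 1) // page_size
        indices.extend(blocks[:num_pages])
        indptr.append(indptr[-1] + num_pages)
        # Tokens used in the final page; a full page if seq_len divides evenly.
        rem = seq_len % page_size
        last_page_len.append(rem if rem != 0 else page_size)
    return (torch.tensor(indices, dtype=torch.int),
            torch.tensor(indptr, dtype=torch.int),
            torch.tensor(last_page_len, dtype=torch.int))
```

In the builder above, `paged_kv_indices` is right-padded with zeros up to `self.total_blocks`, which appears intended to give the CUDA-graph path a fixed-size buffer, with `block_table_bound_tensor` holding one per-sequence bound so the fused advance kernel can distinguish real entries from padding.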
diff --git a/vllm/worker/multi_step_model_runner.py b/vllm/worker/multi_step_model_runner.py
index 6b1b19e4fcd5..579d9affba5f 100644
--- a/vllm/worker/multi_step_model_runner.py
+++ b/vllm/worker/multi_step_model_runner.py
@@ -346,7 +346,7 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs,
             assert seq_group.query_len is None  # Decode
 
     def _advance_step_flashattn(self, model_input: StatefulModelInput,
-                            out: SamplerOutput) -> StatefulModelInput:
+                                out: SamplerOutput) -> StatefulModelInput:
         frozen_model_input = model_input.frozen_model_input
         assert frozen_model_input is not None
         assert frozen_model_input.attn_metadata is not None
@@ -378,7 +378,7 @@ def _advance_step_flashattn(self, model_input: StatefulModelInput,
                 frozen_model_input.seq_lens[i] = attn_metadata.seq_lens[i]
 
         return model_input
-
+
     def _advance_step_flashinfer(
             self,
             model_input: StatefulModelInput,
@@ -394,7 +394,10 @@ def _advance_step_flashinfer(
         num_queries = model_input.num_queries
 
         sampled_tokens = model_input.cached_outputs[-1].sampled_token_ids
-        frozen_model_input.input_tokens[:num_queries] = sampled_tokens.flatten()
+        assert sampled_tokens is not None
+        assert frozen_model_input.input_tokens is not None
+        frozen_model_input.input_tokens[:num_queries] = sampled_tokens.flatten(
+        )
 
         # Update GPU tensors
         ops.advance_step_flashinfer(
@@ -411,10 +414,9 @@ def _advance_step_flashinfer(
             paged_kv_indptr=attn_metadata.paged_kv_indptr,
             paged_kv_last_page_len=attn_metadata.paged_kv_last_page_len,
             block_table_bound=attn_metadata.block_table_bound)
-        #frozen_model_input.seq_lens[:num_queries] = [x + 1 for x in frozen_model_input.seq_lens[:num_queries]]
 
         return model_input
-
+
     def _advance_step(self, model_input: StatefulModelInput,
                       out: SamplerOutput) -> StatefulModelInput:
         if self.attn_backend.get_name() == "flash-attn":
@@ -422,7 +424,8 @@ def _advance_step(self, model_input: StatefulModelInput,
         elif self.attn_backend.get_name() == "flashinfer":
             return self._advance_step_flashinfer(model_input, out)
         else:
-            raise ValueError(f"Unsupported attention backend: {self.attn_backend}")
+            raise ValueError(
+                f"Unsupported attention backend: {self.attn_backend}")
 
     def load_model(self) -> None:
         return self._base_model_runner.load_model()
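For context on how this dispatch path gets exercised end to end, the sketch below shows one way to run multi-step decoding on the FlashInfer backend. It assumes the multi-step scheduler is enabled through the `num_scheduler_steps` engine argument and that the backend is selected via the `VLLM_ATTENTION_BACKEND` environment variable; both names are assumptions about the surrounding vLLM release rather than part of this diff.

```python
import os

# Assumption: the FlashInfer backend is selected through this environment
# variable, which must be set before vLLM is imported.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams  # noqa: E402

# Assumption: num_scheduler_steps > 1 routes decoding through the multi-step
# model runner and therefore through _advance_step_flashinfer() above.
llm = LLM(model="facebook/opt-125m", num_scheduler_steps=8)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```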