diff --git a/tests/kernels/test_flashinfer.py b/tests/kernels/test_flashinfer.py
index 67f12cf1ee08e..696cc0c6cdf10 100644
--- a/tests/kernels/test_flashinfer.py
+++ b/tests/kernels/test_flashinfer.py
@@ -445,7 +445,8 @@ def test_flashinfer_decode_with_paged_fp8_kv(
                           head_size,
                           block_size,
                           "NONE",
-                          data_type=dtype)
+                          data_type=dtype,
+                          q_data_type=dtype)
     output = wrapper.forward(query,
                              kv_cache_fp8,
                              logits_soft_cap=soft_cap,
diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py
index aa9d4a71dbf87..7aec8203eb1e5 100644
--- a/vllm/attention/backends/flashinfer.py
+++ b/vllm/attention/backends/flashinfer.py
@@ -224,6 +224,7 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
             query_start_loc=query_start_loc_host,
             device=self.runner.device,
             data_type=kv_cache_dtype,
+            q_data_type=self.runner.model_config.dtype,
             use_cuda_graph=True,
             decode_wrapper=self._graph_decode_wrapper,
             prefill_wrapper=None)
@@ -292,6 +293,8 @@ class FlashInferMetadata(AttentionMetadata):
     page_size: Optional[int] = None
     # The data type of the paged kv cache
     data_type: torch.dtype = None
+    # The data type of the query
+    q_data_type: torch.dtype = None
     device: torch.device = torch.device("cuda")
     is_profile_run: bool = False
 
@@ -353,7 +356,10 @@ def begin_forward(self):
                 self.page_size,
                 # Disable flashinfer's pos encoding and use vllm's rope.
                 pos_encoding_mode="NONE",
-                data_type=self.data_type)
+                # kv-cache data type.
+                data_type=self.data_type,
+                # query data type.
+                q_data_type=self.q_data_type)
 
     def asdict_zerocopy(self,
                         skip_fields: Optional[Set[str]] = None
@@ -617,6 +623,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             query_start_loc=query_start_loc,
             device=device,
             data_type=kv_cache_dtype,
+            q_data_type=self.runner.model_config.dtype,
             use_cuda_graph=use_captured_graph,
             is_profile_run=self.is_profile_run)
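
Illustrative sketch (not part of the diff): the change above threads the query dtype through to flashinfer's decode planner so that, when the paged KV cache is stored in a different dtype than the model (e.g. an fp8-quantized cache with fp16 queries), the wrapper no longer has to assume the query shares the cache dtype. The snippet below mirrors the call pattern shown in the diff against flashinfer's BatchDecodeWithPagedKVCacheWrapper; the tensor shapes, page counts, and dtypes are invented purely for illustration and may need adjusting for your flashinfer version.

import torch
import flashinfer

# One decode sequence occupying two pages of a paged KV cache (illustrative sizes).
num_qo_heads, num_kv_heads, head_size, block_size = 32, 4, 128, 16
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

kv_indptr = torch.tensor([0, 2], dtype=torch.int32, device="cuda")
kv_indices = torch.tensor([0, 1], dtype=torch.int32, device="cuda")
kv_last_page_len = torch.tensor([7], dtype=torch.int32, device="cuda")

wrapper.begin_forward(
    kv_indptr,
    kv_indices,
    kv_last_page_len,
    num_qo_heads,
    num_kv_heads,
    head_size,
    block_size,
    # Disable flashinfer's pos encoding, matching the vllm backend above.
    pos_encoding_mode="NONE",
    # dtype of the paged KV cache (would differ from the query dtype when the
    # cache is quantized, e.g. to fp8).
    data_type=torch.float16,
    # dtype of the query / model activations, now passed explicitly so the
    # planner does not fall back to assuming it equals data_type.
    q_data_type=torch.float16)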