Skip to content

Commit

Permalink
[Core/Bugfix] Add query dtype as per FlashInfer API requirements. (vl…
Browse files Browse the repository at this point in the history
  • Loading branch information
elfiegg authored and Jeffwan committed Sep 19, 2024
1 parent bc01c61 commit a026bca
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 2 deletions.
3 changes: 2 additions & 1 deletion tests/kernels/test_flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,8 @@ def test_flashinfer_decode_with_paged_fp8_kv(
head_size,
block_size,
"NONE",
data_type=dtype)
data_type=dtype,
q_data_type=dtype)
output = wrapper.forward(query,
kv_cache_fp8,
logits_soft_cap=soft_cap,
Expand Down
9 changes: 8 additions & 1 deletion vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
query_start_loc=query_start_loc_host,
device=self.runner.device,
data_type=kv_cache_dtype,
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=True,
decode_wrapper=self._graph_decode_wrapper,
prefill_wrapper=None)
Expand Down Expand Up @@ -292,6 +293,8 @@ class FlashInferMetadata(AttentionMetadata):
page_size: Optional[int] = None
# The data type of the paged kv cache
data_type: torch.dtype = None
# The data type of the query
q_data_type: torch.dtype = None
device: torch.device = torch.device("cuda")
is_profile_run: bool = False

Expand Down Expand Up @@ -353,7 +356,10 @@ def begin_forward(self):
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
data_type=self.data_type)
# kv-cache data type.
data_type=self.data_type,
# query data type.
q_data_type=self.q_data_type)

def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
Expand Down Expand Up @@ -617,6 +623,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
query_start_loc=query_start_loc,
device=device,
data_type=kv_cache_dtype,
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=use_captured_graph,
is_profile_run=self.is_profile_run)

Expand Down

0 comments on commit a026bca

Please sign in to comment.