Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core/Bugfix] Add query dtype as per FlashInfer API requirements. #8173

Merged
merged 1 commit into from
Sep 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion tests/kernels/test_flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,8 @@ def test_flashinfer_decode_with_paged_fp8_kv(
head_size,
block_size,
"NONE",
data_type=dtype)
data_type=dtype,
q_data_type=dtype)
output = wrapper.forward(query,
kv_cache_fp8,
logits_soft_cap=soft_cap,
Expand Down
9 changes: 8 additions & 1 deletion vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
query_start_loc=query_start_loc_host,
device=self.runner.device,
data_type=kv_cache_dtype,
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=True,
decode_wrapper=self._graph_decode_wrapper,
prefill_wrapper=None)
Expand Down Expand Up @@ -292,6 +293,8 @@ class FlashInferMetadata(AttentionMetadata):
page_size: Optional[int] = None
# The data type of the paged kv cache
data_type: torch.dtype = None
# The data type of the query
q_data_type: torch.dtype = None
device: torch.device = torch.device("cuda")
is_profile_run: bool = False

Expand Down Expand Up @@ -353,7 +356,10 @@ def begin_forward(self):
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
data_type=self.data_type)
# kv-cache data type.
data_type=self.data_type,
# query data type.
q_data_type=self.q_data_type)

def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
Expand Down Expand Up @@ -617,6 +623,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
query_start_loc=query_start_loc,
device=device,
data_type=kv_cache_dtype,
q_data_type=self.runner.model_config.dtype,
use_cuda_graph=use_captured_graph,
is_profile_run=self.is_profile_run)

Expand Down
Loading