[Core][Kernels] Enable FP8 KV Cache with Flashinfer backend. + BugFix for kv_cache_dtype=auto #7985

Merged 4 commits on Aug 29, 2024
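
As context for the changes below, here is a minimal sketch of how the new FP8 KV cache path could be exercised, assuming the standard vLLM Python entrypoint and the `VLLM_ATTENTION_BACKEND` environment variable referenced in `vllm/attention/selector.py`; the model name is a placeholder and exact parameter names may vary by vLLM version.

```python
# Sketch only: not part of this PR; flag/parameter names may differ by vLLM version.
import os

# Select the FlashInfer attention backend (env var referenced in vllm/attention/selector.py).
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model
    kv_cache_dtype="fp8",  # mapped to torch.float8_e4m3fn by this PR's helper
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```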
Changes from all commits
228 changes: 222 additions & 6 deletions tests/kernels/test_flashinfer.py
@@ -73,11 +73,14 @@ def ref_paged_attn(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode
def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
num_heads: Tuple[int,
int], head_size: int,
dtype: torch.dtype, block_size: int,
soft_cap: Optional[float]) -> None:
def test_flashinfer_decode_with_paged_kv(
kv_lens: List[int],
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_seqs = len(kv_lens)
@@ -88,6 +91,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
scale = head_size**-0.5

query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)

key_value_cache = torch.randn(NUM_BLOCKS,
2,
block_size,
@@ -125,7 +129,7 @@ def test_flashinfer_decode_with_paged_kv(kv_lens: List[int],
wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
use_tensor_cores=(
(num_query_heads//num_kv_heads) not in (1, 2, 4, 8))
(num_query_heads//num_kv_heads) > 4)
)
wrapper.begin_forward(kv_indptr,
kv_indices,
@@ -249,3 +253,215 @@ def test_flashinfer_prefill_with_paged_kv(seq_lens: List[Tuple[int, int]],
soft_cap=soft_cap)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"


@pytest.mark.parametrize("seq_lens", [[(1, 132), (5, 18)]])
@pytest.mark.parametrize("num_heads", [(32, 8), (6, 1)])
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
def test_flashinfer_prefill_with_paged_fp8_kv(
seq_lens: List[Tuple[int, int]], num_heads: Tuple[int, int],
head_size: int, dtype: torch.dtype, block_size: int,
soft_cap: Optional[float]) -> None:
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_seqs = len(seq_lens)
query_lens = [x[0] for x in seq_lens]
kv_lens = [x[1] for x in seq_lens]
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5

kv_cache_dtype = torch.float8_e4m3fn

query = torch.randn(sum(query_lens),
num_query_heads,
head_size,
dtype=dtype)
NUM_BLOCKS_FP8 = 2048
key_value_cache = torch.randn(NUM_BLOCKS_FP8,
2,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
key_cache /= head_size**0.5
value_cache /= head_size**0.5

k_scale = key_cache.amax().item() / 448.0
v_scale = value_cache.amax().item() / 448.0

kv_cache_fp8 = torch.cat([key_cache / k_scale, value_cache / v_scale],
dim=1).to(kv_cache_dtype)

assert (kv_cache_fp8.shape == key_value_cache.shape)
max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS_FP8,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

qo_indptr = [0]
kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)
qo_indptr.append(qo_indptr[-1] + query_lens[i])

qo_indptr = torch.tensor(qo_indptr, dtype=torch.int32)
kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
workspace_buffer, "NHD")
wrapper.begin_forward(
qo_indptr,
kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
)

output = wrapper.forward(query,
kv_cache_fp8,
logits_soft_cap=soft_cap,
k_scale=k_scale,
v_scale=v_scale)

ref_output = ref_paged_attn(query=query,
key_cache=key_cache.squeeze(1),
value_cache=value_cache.squeeze(1),
query_lens=query_lens,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
del query
del block_tables
# verify prefill fp8
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"


@pytest.mark.parametrize("kv_lens", [[1328, 18, 463], [1, 54, 293, 70]])
@pytest.mark.parametrize("num_heads", [(32, 8), (64, 8), (6, 1)])
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 30.0, 50.0])
@torch.inference_mode
def test_flashinfer_decode_with_paged_fp8_kv(
kv_lens: List[int],
num_heads: Tuple[int, int],
head_size: int,
dtype: torch.dtype,
block_size: int,
soft_cap: Optional[float],
) -> None:
# test doesn't work for num_heads = (16,16)
torch.set_default_device("cuda")
torch.cuda.manual_seed_all(0)
num_seqs = len(kv_lens)
num_query_heads = num_heads[0]
num_kv_heads = num_heads[1]
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5
use_tensor_cores = (num_query_heads // num_kv_heads) > 4
kv_cache_dtype = torch.float8_e4m3fn

query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
NUM_BLOCKS_FP8 = 2048
key_value_cache = torch.randn(NUM_BLOCKS_FP8,
2,
block_size,
num_kv_heads,
head_size,
dtype=dtype)
key_cache, value_cache = torch.chunk(key_value_cache, 2, dim=1)
key_cache /= head_size**0.5
value_cache /= head_size**0.5

k_scale = key_cache.amax().item() / 448.0
v_scale = value_cache.amax().item() / 448.0

key_cache_fp8 = (key_cache / k_scale).to(kv_cache_dtype)
value_cache_fp8 = (value_cache / v_scale).to(kv_cache_dtype)
assert (key_cache_fp8.shape[1] == 1 and value_cache_fp8.shape[1] == 1)
kv_cache_fp8 = torch.cat([key_cache_fp8, value_cache_fp8], dim=1)

max_num_blocks_per_seq = (max_kv_len + block_size - 1) // block_size
block_tables = torch.randint(0,
NUM_BLOCKS_FP8,
(num_seqs, max_num_blocks_per_seq),
dtype=torch.int32)

kv_indptr = [0]
kv_indices = []
kv_last_page_lens = []
for i in range(num_seqs):
seq_len = kv_lens[i]
assert seq_len > 0
num_blocks = (seq_len + block_size - 1) // block_size
kv_indices.extend(block_tables[i, :num_blocks])
kv_indptr.append(kv_indptr[-1] + num_blocks)
kv_last_page_len = seq_len % block_size
if kv_last_page_len == 0:
kv_last_page_len = block_size
kv_last_page_lens.append(kv_last_page_len)

kv_indptr = torch.tensor(kv_indptr, dtype=torch.int32)
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)

workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
wrapper = flashinfer.\
BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD",
use_tensor_cores=use_tensor_cores)
wrapper.begin_forward(kv_indptr,
kv_indices,
kv_last_page_lens,
num_query_heads,
num_kv_heads,
head_size,
block_size,
"NONE",
data_type=dtype)
output = wrapper.forward(query,
kv_cache_fp8,
logits_soft_cap=soft_cap,
k_scale=k_scale,
v_scale=v_scale)
key_cache = key_value_cache[:, 0, :, :, :].squeeze(1)
value_cache = key_value_cache[:, 1, :, :, :].squeeze(1)

ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap)
# Temporary fix: Increasing the tolerance. Seems like a flashinfer issue
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
30 changes: 24 additions & 6 deletions vllm/attention/backends/flashinfer.py
@@ -83,6 +83,15 @@ def copy_blocks(
def get_supported_head_sizes() -> List[int]:
return [64, 128, 256]

@staticmethod
def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
if kv_cache_dtype in ("fp8", "fp8_e4m3"):
return torch.float8_e4m3fn
elif kv_cache_dtype == "fp8_e5m2":
return torch.float8_e5m2
else:
raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")


class FlashInferState(AttentionState):

@@ -177,9 +186,9 @@ def graph_capture_get_metadata_for_batch(self, batch_size: int):
self._graph_decode_workspace_buffer, _indptr_buffer,
self._graph_indices_buffer, _last_page_len_buffer, "NHD",
use_tensor_cores)
kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype)

kv_cache_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.runner.kv_cache_dtype)
paged_kv_indptr_tensor_host = torch.arange(0,
batch_size + 1,
dtype=torch.int32)
@@ -340,7 +349,7 @@ def begin_forward(self):
self.page_size,
# Disable flashinfer's pos encoding and use vllm's rope.
pos_encoding_mode="NONE",
data_type=self.data_type)
)
Comment on lines -343 to +352
Collaborator:
Why is this not needed anymore?

Contributor Author (@pavanimajety, Aug 29, 2024):
The wrappers throw dispatch errors for the Float8 data type with the current FlashInfer versions (0.1.4 and 0.1.5). I have a couple of follow-up MRs planned and will re-enable this if a future FlashInfer release fixes the dispatch API; it also looks like some FlashInfer API names will change.

Collaborator:
I see, this is sufficient for now, and it seems we will likely need to address this in FlashInfer first.


def asdict_zerocopy(self,
skip_fields: Optional[Set[str]] = None
@@ -366,7 +375,8 @@ def prefill_metadata(self) -> Optional["FlashInferMetadata"]:
def decode_metadata(self) -> Optional["FlashInferMetadata"]:
# Currently chunked prefill is not supported
if self.num_prefills > 0:
assert self.num_decode_tokens == 0
assert self.num_decode_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")
return None

return self
@@ -578,6 +588,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],

kv_cache_dtype = get_kv_cache_torch_dtype(
self.runner.kv_cache_dtype, self.runner.model_config.dtype)

return FlashInferMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
@@ -661,7 +672,6 @@ def forward(
if attn_metadata.num_decode_tokens > 0:
assert attn_metadata.num_prefill_tokens == 0, (
"Chunked prefill is not supported with flashinfer yet.")

if kv_cache is not None:
# Use the same reshape and cache kernel as flash attention.
ops.reshape_and_cache_flash(
@@ -674,6 +684,12 @@
k_scale,
v_scale,
)
# The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2
# to process the cache when the kv_cache_dtype is fp8
if self.kv_cache_dtype.startswith("fp8"):
torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer(
self.kv_cache_dtype)
kv_cache = kv_cache.view(torch_dtype)

query = query.contiguous(
) # Flashinfer requires query to be contiguous
@@ -711,5 +727,7 @@ def forward(
query,
kv_cache,
sm_scale=self.scale,
logits_soft_cap=self.logits_soft_cap)
logits_soft_cap=self.logits_soft_cap,
k_scale=k_scale,
v_scale=v_scale)
return output.view(num_tokens, hidden_size)
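
A standalone sketch of how the pieces in this file fit together: the string-valued `kv_cache_dtype` is mapped to a torch FP8 dtype (mirroring the new `FlashInferBackend.get_fp8_dtype_for_flashinfer`), and the cache tensor is reinterpreted in place with `view()` before being handed to FlashInfer. The `torch.uint8` storage dtype is an assumption about how the cache is allocated, not something shown in this diff.

```python
import torch

def get_fp8_dtype_for_flashinfer(kv_cache_dtype: str) -> torch.dtype:
    # Same mapping as the new FlashInferBackend.get_fp8_dtype_for_flashinfer.
    if kv_cache_dtype in ("fp8", "fp8_e4m3"):
        return torch.float8_e4m3fn
    elif kv_cache_dtype == "fp8_e5m2":
        return torch.float8_e5m2
    raise ValueError(f"Unrecognized FP8 dtype: {kv_cache_dtype}")

# The fp8 cache is stored as raw bytes (assumed torch.uint8 here); before the
# FlashInfer call it is reinterpreted, not copied, into an FP8 tensor:
kv_cache = torch.empty(16, 2, 16, 8, 128, dtype=torch.uint8)
kv_cache_dtype = "fp8"
if kv_cache_dtype.startswith("fp8"):
    kv_cache = kv_cache.view(get_fp8_dtype_for_flashinfer(kv_cache_dtype))
print(kv_cache.dtype)  # torch.float8_e4m3fn
```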
4 changes: 4 additions & 0 deletions vllm/attention/selector.py
@@ -226,6 +226,10 @@ def which_attn_to_use(
elif kv_cache_dtype is not None and kv_cache_dtype.startswith("fp8"):
logger.info(
"Cannot use FlashAttention-2 backend for FP8 KV cache.")
logger.warning(
"Please use FlashInfer backend with FP8 KV Cache for "
"better performance by setting environment variable "
"VLLM_ATTENTION_BACKEND=FLASHINFER")
selected_backend = _Backend.XFORMERS
elif block_size % 16 != 0:
logger.info(