
Commit 3317008

[Neuron][Kernel] Vectorize KV cache load in FlashPagedAttention to maximize DMA bandwidth (vllm-project#13245)
Signed-off-by: Lingfan Yu <lingfany@amazon.com>
1 parent 71face8 commit 3317008

File tree

3 files changed (+764, -348 lines)


tests/neuron/test_block_table.py (new file, +153 lines)
@@ -0,0 +1,153 @@
# SPDX-License-Identifier: Apache-2.0
import os

import neuronxcc.nki.language as nl
import pytest
import torch
import torch.nn.functional as F
from neuronxcc import nki

from vllm.attention.ops.nki_flash_attn import (
    load_block_tables, transform_block_tables_for_indirect_load)


def is_power_of_2(n):
    return n > 0 and (n & (n - 1) == 0)

def nki_load_and_transform_block_tables(
    block_tables,
    num_tiles,
    num_blocks_per_tile,
    num_head,
    head_id,
    block_size_tiling_factor,
):
    assert is_power_of_2(
        num_blocks_per_tile), f"{num_blocks_per_tile=} must be power of 2"
    block_tables_sbuf = load_block_tables(block_tables, num_tiles,
                                          num_blocks_per_tile)

    # we need to pass an Index as head_id
    head_id = nl.arange(1)[None, :] + head_id

    block_tables_transposed = transform_block_tables_for_indirect_load(
        block_tables_sbuf, block_size_tiling_factor, num_head, head_id)
    B_P_SIZE = 128
    assert block_tables_transposed.shape[1] == B_P_SIZE

    out = nl.ndarray(
        block_tables_transposed.shape,
        dtype=nl.int32,
        buffer=nl.shared_hbm,
    )
    for i in nl.affine_range(block_tables_transposed.shape[0]):
        nl.store(dst=out[i], value=block_tables_transposed[i])
    return out
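The helper above is only a thin harness around the two kernels under test: load_block_tables stages the flat block table in SBUF, transform_block_tables_for_indirect_load rewrites the block IDs for per-head, per-sub-block indirect loads, and the final loop copies the result into a shared_hbm output so the host can read it back. A minimal sketch of how it is launched, mirroring the call made in the test further below (the argument values are assumed to be defined as in that test):

    # Sketch only: JIT-compile the harness and launch it on a [1, 1] grid; the
    # shared_hbm output comes back as a device tensor that can be moved to CPU.
    kernel = nki.jit(nki_load_and_transform_block_tables)
    block_tables_transposed = kernel[1, 1](
        block_tables.to(device=device),  # device = torch_xla XLA device
        num_tiles,
        num_blocks_per_tile,
        q_head_per_kv_head,  # passed as num_head
        head_id,
        block_size_tiling_factor,
    ).cpu()
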
def ref_block_tables_transform(
    block_tables,
    num_tiles,
    num_blocks_per_tile,
    num_head,
    head_id,
    block_size_tiling_factor,
):
    assert block_tables.numel() == num_tiles * num_blocks_per_tile
    block_tables = block_tables.view(num_tiles, num_blocks_per_tile)
    B_F_SIZE = 128
    num_tiles_padded = (num_tiles + B_F_SIZE - 1) // B_F_SIZE * B_F_SIZE
    block_tables = F.pad(
        block_tables,
        (0, 0, 0, num_tiles_padded - num_tiles),
        "constant",
        0,
    )

    block_tables = block_tables * num_head + head_id
    block_tables = block_tables.view(num_tiles_padded, num_blocks_per_tile, 1)
    offset = torch.arange(0, block_size_tiling_factor).view(1, 1, -1)
    block_tables = block_tables * block_size_tiling_factor + offset
    block_tables_transposed = block_tables.view(num_tiles_padded, -1).t()

    num_blocks_per_tile = block_tables_transposed.shape[0]
    assert num_blocks_per_tile % B_F_SIZE == 0
    return block_tables_transposed.view(num_blocks_per_tile // B_F_SIZE,
                                        B_F_SIZE, num_tiles_padded)
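For intuition, here is a small worked example of the reference transform, using the (13, 16) parametrization from the test below (illustrative values; it assumes ref_block_tables_transform defined above is in scope):

    import torch

    # Illustrative parameters matching the (13, 16) case below.
    num_tiles, num_blocks_per_tile = 13, 16
    q_head_per_kv_head, head_id = 3, 1
    block_size_tiling_factor = 128 // num_blocks_per_tile  # 8 sub-blocks per block

    block_tables = torch.arange(num_tiles * num_blocks_per_tile, dtype=torch.int32)
    out = ref_block_tables_transform(block_tables, num_tiles, num_blocks_per_tile,
                                     q_head_per_kv_head, head_id,
                                     block_size_tiling_factor)
    # num_tiles is padded up to 128 and each of the 16 blocks is split into 8
    # sub-blocks, so the transposed table has shape (16 * 8 // 128, 128, 128).
    assert out.shape == (1, 128, 128)

Each entry is (block_id * num_head + head_id) * block_size_tiling_factor + sub_block_offset, arranged so that every tile exposes exactly 128 (sub-)block indices along the second axis, which is what the kernel's B_P_SIZE shape assertion above checks.
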
@pytest.mark.parametrize(
    "q_head_per_kv_head,head_id",
    [
        (1, 0),
        (3, 1),
    ],
)
@pytest.mark.parametrize(
    "num_tiles,num_blocks_per_tile",
    [
        (1, 1),
        (13, 16),
        (17, 128),
        (35, 512),
        (128, 128),
        (130, 64),
        (280, 256),
        (315, 1),
    ],
)
@torch.inference_mode()
def test_load_and_transform_block_tables(
    num_tiles,
    num_blocks_per_tile,
    q_head_per_kv_head,
    head_id,
) -> None:
    import torch_xla.core.xla_model as xm

    device = xm.xla_device()

    compiler_flags = [
        "-O1",
        "--retry_failed_compilation",
    ]
    compiler_flags_str = " ".join(compiler_flags)
    os.environ["NEURON_CC_FLAGS"] = compiler_flags_str

    torch.manual_seed(10000)
    torch.set_printoptions(sci_mode=False)

    # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient
    B_P_SIZE = 128
    if num_blocks_per_tile < B_P_SIZE:
        assert B_P_SIZE % num_blocks_per_tile == 0
        block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile
    else:
        block_size_tiling_factor = 1
    max_num_blocks = 100000
    block_tables = torch.randint(
        0,
        max_num_blocks,
        (num_tiles * num_blocks_per_tile, ),
        dtype=torch.int32,
    )
    nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1](
        block_tables.to(device=device),
        num_tiles,
        num_blocks_per_tile,
        q_head_per_kv_head,
        head_id,
        block_size_tiling_factor,
    ).cpu()
    ref_out = ref_block_tables_transform(
        block_tables,
        num_tiles,
        num_blocks_per_tile,
        q_head_per_kv_head,
        head_id,
        block_size_tiling_factor,
    )
    assert (nki_out.shape == ref_out.shape
            ), f"{nki_out.shape=} != {ref_out.shape=}"
    assert torch.all(nki_out == ref_out)
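
As a usage note: on a Neuron host with neuronxcc and torch_xla available (the modules imported above), the file runs as an ordinary pytest module, for example from the vLLM repository root:

    pytest tests/neuron/test_block_table.py -v

Each parametrization goes through nki.jit and therefore through Neuron compilation, so a first run may spend most of its wall-clock time in the compiler.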
