|
| 1 | +# SPDX-License-Identifier: Apache-2.0 |
| 2 | +import os |
| 3 | + |
| 4 | +import neuronxcc.nki.language as nl |
| 5 | +import pytest |
| 6 | +import torch |
| 7 | +import torch.nn.functional as F |
| 8 | +from neuronxcc import nki |
| 9 | + |
| 10 | +from vllm.attention.ops.nki_flash_attn import ( |
| 11 | + load_block_tables, transform_block_tables_for_indirect_load) |
| 12 | + |
| 13 | + |
| 14 | +def is_power_of_2(n): |
| 15 | + return n > 0 and (n & (n - 1) == 0) |
| 16 | + |
| 17 | + |
| 18 | +def nki_load_and_transform_block_tables( |
| 19 | + block_tables, |
| 20 | + num_tiles, |
| 21 | + num_blocks_per_tile, |
| 22 | + num_head, |
| 23 | + head_id, |
| 24 | + block_size_tiling_factor, |
| 25 | +): |
| 26 | + assert is_power_of_2( |
| 27 | + num_blocks_per_tile), f"{num_blocks_per_tile=} must be power of 2" |
| 28 | + block_tables_sbuf = load_block_tables(block_tables, num_tiles, |
| 29 | + num_blocks_per_tile) |
| 30 | + |
| 31 | + # we need to pass an Index as head_id |
| 32 | + head_id = nl.arange(1)[None, :] + head_id |
| 33 | + |
| 34 | + block_tables_transposed = transform_block_tables_for_indirect_load( |
| 35 | + block_tables_sbuf, block_size_tiling_factor, num_head, head_id) |
| 36 | + B_P_SIZE = 128 |
| 37 | + assert block_tables_transposed.shape[1] == B_P_SIZE |
| 38 | + |
| 39 | + out = nl.ndarray( |
| 40 | + block_tables_transposed.shape, |
| 41 | + dtype=nl.int32, |
| 42 | + buffer=nl.shared_hbm, |
| 43 | + ) |
| 44 | + for i in nl.affine_range(block_tables_transposed.shape[0]): |
| 45 | + nl.store(dst=out[i], value=block_tables_transposed[i]) |
| 46 | + return out |
| 47 | + |
| 48 | + |
| 49 | +def ref_block_tables_transform( |
| 50 | + block_tables, |
| 51 | + num_tiles, |
| 52 | + num_blocks_per_tile, |
| 53 | + num_head, |
| 54 | + head_id, |
| 55 | + block_size_tiling_factor, |
| 56 | +): |
| 57 | + assert block_tables.numel() == num_tiles * num_blocks_per_tile |
| 58 | + block_tables = block_tables.view(num_tiles, num_blocks_per_tile) |
| 59 | + B_F_SIZE = 128 |
| 60 | + num_tiles_padded = (num_tiles + B_F_SIZE - 1) // B_F_SIZE * B_F_SIZE |
| 61 | + block_tables = F.pad( |
| 62 | + block_tables, |
| 63 | + (0, 0, 0, num_tiles_padded - num_tiles), |
| 64 | + "constant", |
| 65 | + 0, |
| 66 | + ) |
| 67 | + |
| 68 | + block_tables = block_tables * num_head + head_id |
| 69 | + block_tables = block_tables.view(num_tiles_padded, num_blocks_per_tile, 1) |
| 70 | + offset = torch.arange(0, block_size_tiling_factor).view(1, 1, -1) |
| 71 | + block_tables = block_tables * block_size_tiling_factor + offset |
| 72 | + block_tables_transposed = block_tables.view(num_tiles_padded, -1).t() |
| 73 | + |
| 74 | + num_blocks_per_tile = block_tables_transposed.shape[0] |
| 75 | + assert num_blocks_per_tile % B_F_SIZE == 0 |
| 76 | + return block_tables_transposed.view(num_blocks_per_tile // B_F_SIZE, |
| 77 | + B_F_SIZE, num_tiles_padded) |
| 78 | + |
| 79 | + |
| 80 | +@pytest.mark.parametrize( |
| 81 | + "q_head_per_kv_head,head_id", |
| 82 | + [ |
| 83 | + (1, 0), |
| 84 | + (3, 1), |
| 85 | + ], |
| 86 | +) |
| 87 | +@pytest.mark.parametrize( |
| 88 | + "num_tiles,num_blocks_per_tile", |
| 89 | + [ |
| 90 | + (1, 1), |
| 91 | + (13, 16), |
| 92 | + (17, 128), |
| 93 | + (35, 512), |
| 94 | + (128, 128), |
| 95 | + (130, 64), |
| 96 | + (280, 256), |
| 97 | + (315, 1), |
| 98 | + ], |
| 99 | +) |
| 100 | +@torch.inference_mode() |
| 101 | +def test_load_and_transform_block_tables( |
| 102 | + num_tiles, |
| 103 | + num_blocks_per_tile, |
| 104 | + q_head_per_kv_head, |
| 105 | + head_id, |
| 106 | +) -> None: |
| 107 | + import torch_xla.core.xla_model as xm |
| 108 | + |
| 109 | + device = xm.xla_device() |
| 110 | + |
| 111 | + compiler_flags = [ |
| 112 | + "-O1", |
| 113 | + "--retry_failed_compilation", |
| 114 | + ] |
| 115 | + compiler_flags_str = " ".join(compiler_flags) |
| 116 | + os.environ["NEURON_CC_FLAGS"] = compiler_flags_str |
| 117 | + |
| 118 | + torch.manual_seed(10000) |
| 119 | + torch.set_printoptions(sci_mode=False) |
| 120 | + |
| 121 | + # On Neuron, we need B_P_SIZE = 128 blocks to make DMA efficient |
| 122 | + B_P_SIZE = 128 |
| 123 | + if num_blocks_per_tile < B_P_SIZE: |
| 124 | + assert B_P_SIZE % num_blocks_per_tile == 0 |
| 125 | + block_size_tiling_factor = B_P_SIZE // num_blocks_per_tile |
| 126 | + else: |
| 127 | + block_size_tiling_factor = 1 |
| 128 | + max_num_blocks = 100000 |
| 129 | + block_tables = torch.randint( |
| 130 | + 0, |
| 131 | + max_num_blocks, |
| 132 | + (num_tiles * num_blocks_per_tile, ), |
| 133 | + dtype=torch.int32, |
| 134 | + ) |
| 135 | + nki_out = nki.jit(nki_load_and_transform_block_tables)[1, 1]( |
| 136 | + block_tables.to(device=device), |
| 137 | + num_tiles, |
| 138 | + num_blocks_per_tile, |
| 139 | + q_head_per_kv_head, |
| 140 | + head_id, |
| 141 | + block_size_tiling_factor, |
| 142 | + ).cpu() |
| 143 | + ref_out = ref_block_tables_transform( |
| 144 | + block_tables, |
| 145 | + num_tiles, |
| 146 | + num_blocks_per_tile, |
| 147 | + q_head_per_kv_head, |
| 148 | + head_id, |
| 149 | + block_size_tiling_factor, |
| 150 | + ) |
| 151 | + assert (nki_out.shape == ref_out.shape |
| 152 | + ), f"{nki_out.shape=} != {ref_out.shape=}" |
| 153 | + assert torch.all(nki_out == ref_out) |
0 commit comments