
Commit

enable alibi in pagedattention
SunflowerAries committed Apr 28, 2024
1 parent 97c6134 commit 183dfd1
Showing 6 changed files with 72 additions and 20 deletions.
45 changes: 30 additions & 15 deletions colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -253,6 +253,21 @@ def forward(
inference_ops.decode_kv_cache_memcpy(
key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
)
inference_ops.flash_decoding_attention(
output_tensor,
query_states,
k_cache,
v_cache,
sequence_lengths,
block_tables,
block_size,
kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.mid_output_lse,
self.alibi_slopes,
sm_scale,
)
attn_output = output_tensor
else:
if not is_verifier and not self.use_alibi_attn:
decoding_fused_rotary_embedding(
@@ -276,21 +291,21 @@ def forward(
value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
)

attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
alibi_slopes=self.alibi_slopes,
sm_scale=sm_scale,
q_len=q_len,
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
alibi_slopes=self.alibi_slopes,
sm_scale=sm_scale,
q_len=q_len,
)

attn_output = attn_output.view(-1, self.hidden_size)
attn_output = torch.mm(attn_output, self.o_proj_weight)
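For context, the larger Baichuan models use ALiBi instead of rotary embeddings, which is why the forward pass above skips RoPE when self.use_alibi_attn is set and instead forwards self.alibi_slopes into the paged-attention kernels on both the CUDA and Triton paths. The slopes are fixed per-head constants; below is a minimal sketch of the usual geometric-sequence construction from the ALiBi paper (the repository's get_alibi_slopes may differ in details such as device handling):

```python
import math
import torch

def alibi_slopes_sketch(num_heads: int, device: str = "cpu") -> torch.Tensor:
    # Power-of-two head counts: slope_i = 2 ** (-8 * (i + 1) / num_heads).
    closest_pow2 = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-8.0 / closest_pow2)
    slopes = [base ** (i + 1) for i in range(closest_pow2)]
    if closest_pow2 != num_heads:
        # Remaining heads reuse every other slope from the next power of two.
        extra_base = 2.0 ** (-4.0 / closest_pow2)
        slopes += [extra_base ** (2 * i + 1) for i in range(num_heads - closest_pow2)]
    return torch.tensor(slopes, dtype=torch.float32, device=device)  # [num_heads]
```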
1 change: 1 addition & 0 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -597,6 +597,7 @@ def forward(
kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.mid_output_lse,
None,
sm_scale,
)
attn_output = output_tensor
@@ -113,6 +113,7 @@ def benchmark_flash_decoding_attention(
kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)
alibi_slopes = None
kv_scale = 1.0

mid_output = torch.empty(
@@ -166,6 +167,7 @@ def benchmark_flash_decoding_attention(
max_seq_len_across_batch,
mid_output,
mid_output_lse,
alibi_slopes,
sm_scale,
)
else:
15 changes: 13 additions & 2 deletions extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu
@@ -54,6 +54,7 @@ __global__ void flash_decoding_attention_kernel(
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, block_size, head_size]
const int* __restrict__ context_lens, // [num_tokens]
const int* __restrict__ block_tables, // [num_tokens, max_num_blocks_per_seq]
const float* __restrict__ alibi_slopes, // [num_heads]
const int max_seq_len,
const int num_kv_heads,
const float scale,
@@ -90,6 +91,7 @@
using Float_vec = typename FloatVecTypeTrait<L_vec>::Type;

const int context_len = context_lens[seq_idx];
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
const int thread_group_offset = lane % NUM_THREADS_PER_X;
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
@@ -149,6 +151,7 @@

if (thread_group_offset == 0 && lane < NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X) {
const int token_idx = block_idx * BLOCK_SIZE + i * NUM_ROWS_PER_ROUNDS + lane / NUM_THREADS_PER_X;
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
const bool mask = token_idx >= context_len;
logits[token_idx] = mask ? 0.f : qk;
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@@ -246,6 +249,7 @@ __global__ void flash_decoding_attention_kernel(
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
context_lens.data_ptr<int>(), \
block_tables.data_ptr<int>(), \
alibi_slopes_ptr, \
max_context_len, \
num_kv_heads, \
scale, \
@@ -267,7 +271,8 @@ void flash_decoding_attention_v1_launcher(
torch::Tensor& context_lens, // [num_tokens]
torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq]
int max_context_len,
float scale) {
float scale,
const c10::optional<torch::Tensor>& alibi_slopes) {
int num_tokens = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -289,6 +294,10 @@
// Keep that in sync with the logic here!
int shared_mem_size = std::max(logits_size, outputs_size) + DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4);

const float* alibi_slopes_ptr = alibi_slopes ?
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;

dim3 grid(num_heads, num_tokens, 1);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
@@ -321,7 +330,8 @@ void flash_decoding_attention_v1_launcher(
context_lens, \
block_tables, \
max_context_len, \
scale);
scale, \
alibi_slopes);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@@ -352,6 +362,7 @@ void flash_decoding_attention(
int max_context_len,
torch::Tensor& tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size]
torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions]
const c10::optional<torch::Tensor>& alibi_slopes,
float scale) {
switch (query.scalar_type()) {
case at::ScalarType::Float:
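The kernel change above adds a per-token bias, qk += alibi_slope * (token_idx - context_len + 1), before masking and the softmax, and a null alibi_slopes pointer (slope 0) preserves the previous behaviour. Below is a minimal single-head, single-sequence reference of that decode-step math, assuming a plain [context_len, head_size] key/value layout; it is a sketch for intuition, not the repository code:

```python
import torch

def decode_attn_reference(q, k, v, alibi_slope=None, sm_scale=None):
    # q: [head_size] (the newest token); k, v: [context_len, head_size].
    context_len, head_size = k.shape
    sm_scale = head_size ** -0.5 if sm_scale is None else sm_scale
    logits = (k @ q) * sm_scale  # [context_len]
    if alibi_slope is not None:
        # Same bias as the kernel: zero for the newest position,
        # increasingly negative for older cached tokens.
        pos = torch.arange(context_len, dtype=logits.dtype, device=logits.device)
        logits = logits + alibi_slope * (pos - (context_len - 1))
    return torch.softmax(logits, dim=-1) @ v  # [head_size]
```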
2 changes: 1 addition & 1 deletion extensions/pybind/inference/inference.cpp
@@ -73,7 +73,7 @@ void flash_decoding_attention(
torch::Tensor&
tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size]
torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions]
float scale);
const c10::optional<torch::Tensor>& alibi_slopes, float scale);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
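With the binding signature above, the Python-side custom op now takes alibi_slopes as an extra positional argument just before sm_scale, and passing None keeps ALiBi disabled. A hedged sketch of the call shape (assuming the op is loaded through InferenceOpsLoader as in the benchmark and test files; tensor shapes follow the comments in the CUDA source):

```python
from colossalai.kernel.kernel_loader import InferenceOpsLoader

inference_ops = InferenceOpsLoader().load()

def run_flash_decoding_attention(output, query, k_cache, v_cache, context_lens,
                                 block_tables, block_size, max_context_len,
                                 mid_output, mid_output_lse, alibi_slopes, sm_scale):
    # output/query:   [num_tokens, num_heads, head_size]
    # context_lens:   [num_tokens]
    # block_tables:   [num_tokens, max_num_blocks_per_seq]
    # mid_output:     [num_tokens, num_heads, max_num_partitions, head_size]
    # mid_output_lse: [num_tokens, num_heads, max_num_partitions]
    # alibi_slopes:   [num_heads] float tensor, or None to disable the bias
    inference_ops.flash_decoding_attention(
        output, query, k_cache, v_cache, context_lens, block_tables,
        block_size, max_context_len, mid_output, mid_output_lse,
        alibi_slopes, sm_scale,
    )
    return output
```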
27 changes: 25 additions & 2 deletions tests/test_infer/test_ops/cuda/test_flash_decoding_attention.py
@@ -4,8 +4,10 @@
import pytest
import torch

from colossalai.inference.modeling.models.nopadding_baichuan import get_alibi_slopes
from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.utils import get_current_device
from tests.test_infer.test_ops.triton.test_context_attn_unpad import generate_alibi_mask

inference_ops = InferenceOpsLoader().load()

@@ -60,8 +62,9 @@ def numpy_allclose(x, y, rtol, atol):
@pytest.mark.parametrize("NUM_ATTN_HEADS", [16])
@pytest.mark.parametrize("KV_GROUP_NUM", [1, 2, 16])
@pytest.mark.parametrize("dtype", [torch.float16, torch.float32])
@pytest.mark.parametrize("use_alibi_slopes", [True, False])
def test_flash_decoding_attention(
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype
BATCH_SIZE, BLOCK_SIZE, MAX_NUM_BLOCKS_PER_SEQ, HEAD_SIZE, NUM_ATTN_HEADS, KV_GROUP_NUM, dtype, use_alibi_slopes
):
torch.manual_seed(123)
torch.cuda.empty_cache()
@@ -73,6 +76,11 @@ def test_flash_decoding_attention(
MAX_SEQ_LEN = BLOCK_SIZE * MAX_NUM_BLOCKS_PER_SEQ
device = get_current_device()

if use_alibi_slopes:
alibi_slopes = get_alibi_slopes(NUM_ATTN_HEADS, device)
else:
alibi_slopes = None

q, k_unpad, v_unpad, kv_seq_lengths = prepare_data(
BATCH_SIZE, HEAD_SIZE, NUM_ATTN_HEADS, NUM_KV_HEADS, MAX_SEQ_LEN, dtype, device
)
@@ -91,6 +99,15 @@ def test_flash_decoding_attention(
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
torch_padding_mask = create_attention_mask(kv_seq_lengths, BATCH_SIZE, q_len, max_seq_len_across_batch, device)

if use_alibi_slopes:
alibi_mask = generate_alibi_mask(alibi_slopes, NUM_ATTN_HEADS, max_seq_len_across_batch, device)
torch_padding_mask = torch_padding_mask + alibi_mask

if len(torch_padding_mask.size()) == 4:
torch_padding_mask = torch_padding_mask[:, :, -1:, :]
else:
torch_padding_mask = torch_padding_mask[:, -1:, :]

mid_output = torch.empty(
size=(BATCH_SIZE, NUM_ATTN_HEADS, kv_max_split_num, HEAD_SIZE), dtype=torch.float32, device=device
)
@@ -146,8 +163,14 @@ def test_flash_decoding_attention(
max_seq_len_across_batch,
mid_output,
mid_output_lse,
alibi_slopes,
sm_scale,
)

# The alibi may introduce relatively large errors
if use_alibi_slopes:
rtol = 1e0

numpy_allclose(out_ref, output, rtol=rtol, atol=atol)


@@ -277,5 +300,5 @@ def test_vllm_flash_decoding_attention(
dtype,
) in test_combinations:
test_flash_decoding_attention(
batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype
batch_size, block_size, max_num_blocks_per_seq, head_size, num_attn_heads, kv_group_num, dtype, True
)
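A note on the mask handling in the test above: during decoding only the last query row of the full ALiBi mask matters, which is why the reference mask is sliced with [..., -1:, :] before being fed to the torch baseline. A minimal sketch of that last row (not the repository's generate_alibi_mask, just the expected math, consistent with the kernel-side formula):

```python
import torch

def alibi_bias_last_row(slopes: torch.Tensor, context_len: int) -> torch.Tensor:
    # slopes: [num_heads]. Returns [num_heads, 1, context_len]: the bias the single
    # decode-step query applies to cached key j, i.e. slope * (j - (context_len - 1)).
    pos = torch.arange(context_len, dtype=slopes.dtype, device=slopes.device)
    bias = slopes[:, None] * (pos - (context_len - 1))[None, :]
    return bias[:, None, :]
```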
