enable alibi in pagedattention
SunflowerAries committed Apr 28, 2024
1 parent 97c6134 commit ac1ae36
Showing 6 changed files with 49 additions and 18 deletions.
45 changes: 30 additions & 15 deletions colossalai/inference/modeling/models/nopadding_baichuan.py
@@ -253,6 +253,21 @@ def forward(
                inference_ops.decode_kv_cache_memcpy(
                    key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables
                )
                inference_ops.flash_decoding_attention(
                    output_tensor,
                    query_states,
                    k_cache,
                    v_cache,
                    sequence_lengths,
                    block_tables,
                    block_size,
                    kv_seq_len,
                    fd_inter_tensor.mid_output,
                    fd_inter_tensor.mid_output_lse,
                    self.alibi_slopes,
                    sm_scale,
                )
                attn_output = output_tensor
            else:
                if not is_verifier and not self.use_alibi_attn:
                    decoding_fused_rotary_embedding(
@@ -276,21 +291,21 @@ def forward(
                        value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables, n=q_len
                    )

                attn_output = flash_decoding_attention(
                    q=query_states,
                    k_cache=k_cache,
                    v_cache=v_cache,
                    kv_seq_len=sequence_lengths,
                    block_tables=block_tables,
                    block_size=block_size,
                    max_seq_len_in_batch=kv_seq_len,
                    output=output_tensor,
                    mid_output=fd_inter_tensor.mid_output,
                    mid_output_lse=fd_inter_tensor.mid_output_lse,
                    alibi_slopes=self.alibi_slopes,
                    sm_scale=sm_scale,
                    q_len=q_len,
                )

        attn_output = attn_output.view(-1, self.hidden_size)
        attn_output = torch.mm(attn_output, self.o_proj_weight)
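Note on `self.alibi_slopes` used by both attention calls above: these are the standard per-head ALiBi slopes, and the CUDA kernel below expects them as float32, one value per head. A minimal sketch of how such a slope tensor is typically built, following the original ALiBi recipe (the helper name and device argument are assumptions, not code from this commit):

```python
import math

import torch


def build_alibi_slopes(num_heads: int, device: str = "cpu") -> torch.Tensor:
    # Geometric sequence starting at 2^(-8/n) for the largest power-of-two
    # head count n not exceeding num_heads.
    closest_pow2 = 2 ** math.floor(math.log2(num_heads))
    base = 2.0 ** (-8.0 / closest_pow2)
    slopes = [base ** (i + 1) for i in range(closest_pow2)]
    if closest_pow2 < num_heads:
        # For non-power-of-two head counts, interleave slopes from the
        # 2 * closest_pow2 sequence, as in the ALiBi paper's reference code.
        extra_base = 2.0 ** (-4.0 / closest_pow2)
        num_extra = num_heads - closest_pow2
        slopes += [extra_base ** (2 * i + 1) for i in range(num_extra)]
    return torch.tensor(slopes, dtype=torch.float32, device=device)
```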
1 change: 1 addition & 0 deletions colossalai/inference/modeling/models/nopadding_llama.py
@@ -597,6 +597,7 @@ def forward(
                    kv_seq_len,
                    fd_inter_tensor.mid_output,
                    fd_inter_tensor.mid_output_lse,
                    None,
                    sm_scale,
                )
                attn_output = output_tensor
@@ -113,6 +113,7 @@ def benchmark_flash_decoding_attention(
kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)
alibi_slopes = None
kv_scale = 1.0

mid_output = torch.empty(
@@ -166,6 +167,7 @@ def benchmark_flash_decoding_attention(
max_seq_len_across_batch,
mid_output,
mid_output_lse,
alibi_slopes,
sm_scale,
)
else:
15 changes: 13 additions & 2 deletions extensions/csrc/kernel/cuda/flash_decoding_attention_kernel.cu
@@ -54,6 +54,7 @@ __global__ void flash_decoding_attention_kernel(
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, block_size, head_size]
const int* __restrict__ context_lens, // [num_tokens]
const int* __restrict__ block_tables, // [num_tokens, max_num_blocks_per_seq]
const float* __restrict__ alibi_slopes, // [num_heads]
const int max_seq_len,
const int num_kv_heads,
const float scale,
@@ -90,6 +91,7 @@
using Float_vec = typename FloatVecTypeTrait<L_vec>::Type;

const int context_len = context_lens[seq_idx];
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
const int thread_group_offset = lane % NUM_THREADS_PER_X;
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
@@ -149,6 +151,7 @@ __global__ void flash_decoding_attention_kernel(

if (thread_group_offset == 0 && lane < NUM_ROWS_PER_ROUNDS * NUM_THREADS_PER_X) {
const int token_idx = block_idx * BLOCK_SIZE + i * NUM_ROWS_PER_ROUNDS + lane / NUM_THREADS_PER_X;
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
const bool mask = token_idx >= context_len;
logits[token_idx] = mask ? 0.f : qk;
qk_max = mask ? qk_max : fmaxf(qk_max, qk);
@@ -246,6 +249,7 @@ __global__ void flash_decoding_attention_kernel(
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
context_lens.data_ptr<int>(), \
block_tables.data_ptr<int>(), \
alibi_slopes_ptr, \
max_context_len, \
num_kv_heads, \
scale, \
@@ -267,7 +271,8 @@ void flash_decoding_attention_v1_launcher(
torch::Tensor& context_lens, // [num_tokens]
torch::Tensor& block_tables, // [num_tokens, max_num_blocks_per_seq]
int max_context_len,
float scale) {
float scale,
const c10::optional<torch::Tensor>& alibi_slopes) {
int num_tokens = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -289,6 +294,10 @@ void flash_decoding_attention_v1_launcher(
// Keep that in sync with the logic here!
int shared_mem_size = std::max(logits_size, outputs_size) + DIVIDE_ROUND_UP(max_num_blocks_per_seq * sizeof(int), sizeof(float4)) * sizeof(float4);

const float* alibi_slopes_ptr = alibi_slopes ?
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;

dim3 grid(num_heads, num_tokens, 1);
dim3 block(NUM_THREADS);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
@@ -321,7 +330,8 @@ void flash_decoding_attention_v1_launcher(
context_lens, \
block_tables, \
max_context_len, \
scale);
scale, \
alibi_slopes);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@@ -352,6 +362,7 @@ void flash_decoding_attention(
int max_context_len,
torch::Tensor& tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size]
torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions]
const c10::optional<torch::Tensor>& alibi_slopes,
float scale) {
switch (query.scalar_type()) {
case at::ScalarType::Float:
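For readers of the kernel change above: when `alibi_slopes` is non-null, each query-key logit receives the bias `alibi_slope * (token_idx - context_len + 1)` before the softmax, as in the `qk +=` line in the kernel hunk. Below is a hedged PyTorch reference of what the decode step then computes for a single sequence and a single head; it is an illustrative sketch against contiguous keys/values, not the paged-cache layout the kernel actually reads.

```python
import torch


def alibi_decode_attention_ref(
    q: torch.Tensor,      # [head_size] query of the single decode token
    k: torch.Tensor,      # [context_len, head_size] keys for this head
    v: torch.Tensor,      # [context_len, head_size] values for this head
    alibi_slope: float,   # per-head slope; 0.0 reproduces the no-ALiBi path
    scale: float,         # usually 1 / sqrt(head_size)
) -> torch.Tensor:
    context_len = k.shape[0]
    logits = scale * (k @ q)  # [context_len]
    # ALiBi bias: (token_idx - context_len + 1) is 0 for the newest token and
    # grows more negative for older tokens, so distant keys are penalized.
    positions = torch.arange(context_len, device=q.device, dtype=logits.dtype)
    logits = logits + alibi_slope * (positions - context_len + 1)
    probs = torch.softmax(logits, dim=-1)
    return probs @ v  # [head_size]
```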
2 changes: 1 addition & 1 deletion extensions/pybind/inference/inference.cpp
@@ -73,7 +73,7 @@ void flash_decoding_attention(
torch::Tensor&
tmp_out, // [num_tokens, num_heads, max_num_partitions, head_size]
torch::Tensor& tmp_out_lse, // [num_tokens, num_heads, max_num_partitions]
float scale);
const c10::optional<torch::Tensor>& alibi_slopes, float scale);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("decode_kv_cache_memcpy", &decode_kv_cache_memcpy,
@@ -86,6 +86,7 @@ def test_flash_decoding_attention(
kv_max_split_num = (max_seq_len_across_batch + BLOCK_SIZE - 1) // BLOCK_SIZE
output = torch.empty((BATCH_SIZE, NUM_ATTN_HEADS, HEAD_SIZE), dtype=dtype, device=device)
sm_scale = 1.0 / (HEAD_SIZE**0.5)
alibi_slopes = None

k_torch = convert_kv_unpad_to_padded(k_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
v_torch = convert_kv_unpad_to_padded(v_unpad, kv_seq_lengths, BATCH_SIZE, max_seq_len_across_batch)
@@ -146,6 +147,7 @@ def test_flash_decoding_attention(
max_seq_len_across_batch,
mid_output,
mid_output_lse,
alibi_slopes,
sm_scale,
)
numpy_allclose(out_ref, output, rtol=rtol, atol=atol)
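The test above passes `alibi_slopes = None`, so it covers the new argument but not the biased code path. A possible extension (not part of this commit) would pass a real float32 slope tensor of shape `[NUM_ATTN_HEADS]` on the same device and fold the matching bias into the torch reference before the softmax. A sketch of that bias term, with the helper name being an assumption:

```python
import torch


def alibi_bias_for_reference(
    slopes: torch.Tensor,   # [num_heads], float32, e.g. from the slope helper above
    context_len: int,
    dtype: torch.dtype,
) -> torch.Tensor:
    """Bias to add to [num_heads, context_len] attention scores before softmax,
    matching the kernel's slope * (token_idx - context_len + 1) term."""
    positions = torch.arange(context_len, device=slopes.device, dtype=torch.float32)
    return (slopes[:, None] * (positions[None, :] - context_len + 1)).to(dtype)
```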
