Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Infer] Revise and Adapt Triton Kernels for Spec-Dec #5401

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion colossalai/kernel/triton/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from .flash_decoding import flash_decoding_attention
from .fused_rotary_embedding import fused_rotary_embedding
from .gptq_triton import gptq_fused_linear_triton
from .kvcache_copy import copy_kv_to_blocked_cache
from .kvcache_copy import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
from .no_pad_rotary_embedding import decoding_fused_rotary_embedding, rotary_embedding
from .rms_layernorm import rms_layernorm
from .rotary_cache_copy import get_xine_cache
Expand All @@ -21,6 +21,7 @@
__all__ = [
"context_attention_unpadded",
"flash_decoding_attention",
"copy_k_to_blocked_cache",
"copy_kv_to_blocked_cache",
"softmax",
"rms_layernorm",
Expand Down
106 changes: 60 additions & 46 deletions colossalai/kernel/triton/flash_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,14 @@
# Triton 2.1.0
@triton.jit
def _flash_decoding_fwd_kernel(
Q, # [batch_size, head_num, q_len(1), head_dim]
Q, # [batch_size * q_len, head_num, head_dim]
KCache, # [num_blocks, num_kv_heads, block_size, head_dim]
VCache, # [num_blocks, num_kv_heads, block_size, head_dim]
block_tables, # [batch_size, max_blocks_per_sequence]
mid_o, # [batch_size, head_num, kv_split_num, head_dim]
mid_o_lse, # [batch_size, head_num, kv_split_num]
mid_o, # [batch_size * q_len, head_num, kv_split_num, head_dim]
mid_o_lse, # [batch_size * q_len, head_num, kv_split_num]
kv_seq_len, # [batch_size]
q_len,
batch_size,
stride_qt,
stride_qh,
Expand All @@ -39,44 +40,37 @@ def _flash_decoding_fwd_kernel(
BLOCK_SIZE: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
cur_seq_idx = tl.program_id(0)
cur_token_idx = tl.program_id(0)
cur_seq_idx = cur_token_idx // q_len
if cur_seq_idx >= batch_size:
return
cur_head_idx = tl.program_id(1)
block_start_kv = tl.program_id(2) # for splitting k/v

cur_kv_head_idx = cur_head_idx // KV_GROUPS
offsets_dmodel = tl.arange(0, HEAD_DIM)

# NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same
# TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE)
# and then support calculating multiple kv cache blocks on an instance
tl.static_assert(BLOCK_KV == BLOCK_SIZE)

# get the current (kv) sequence length from provided context lengths tensor
# get the current (kv) sequence length
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
return

offsets_q = cur_seq_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
q = tl.load(Q + offsets_q)

# block table for the current sequence
block_table_ptr = block_tables + cur_seq_idx * stride_bts

# actually current block table current block start idx
# cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
cur_bt_start_idx = block_start_kv
cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)

if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
return

# cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb)
cur_occupied_size = tl.where(
(block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE
)
tl.device_assert(cur_occupied_size >= 0)

cur_kv_head_idx = cur_head_idx // KV_GROUPS
offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh

K_block_ptr = tl.make_block_ptr(
base=KCache + offset_kvcache,
shape=(cur_occupied_size, HEAD_DIM),
Expand Down Expand Up @@ -115,14 +109,14 @@ def _flash_decoding_fwd_kernel(
acc = acc / l

offsets_mid_o = (
cur_seq_idx * stride_mid_ot
cur_token_idx * stride_mid_ot
+ cur_head_idx * stride_mid_oh
+ block_start_kv * stride_mid_ob
+ offsets_dmodel * stride_mid_od
)
tl.store(mid_o + offsets_mid_o, acc)
offsets_mid_o_lse = (
cur_seq_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
)
# logsumexp L^(j) = m^(j) + log(l^(j))
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
Expand All @@ -135,6 +129,7 @@ def _flash_decoding_fwd_reduce_kernel(
mid_o_lse, # [batch_size, head_num, kv_split_num]
O, # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
kv_seq_len,
q_len,
batch_size,
stride_mid_ot,
stride_mid_oh,
Expand All @@ -149,7 +144,8 @@ def _flash_decoding_fwd_reduce_kernel(
BLOCK_KV: tl.constexpr,
HEAD_DIM: tl.constexpr,
):
cur_seq_idx = tl.program_id(0)
cur_token_idx = tl.program_id(0)
cur_seq_idx = cur_token_idx // q_len
if cur_seq_idx >= batch_size:
return
cur_head_idx = tl.program_id(1)
Expand All @@ -164,8 +160,8 @@ def _flash_decoding_fwd_reduce_kernel(
l = 0.0 # sum exp
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)

offsets_mid_o = cur_seq_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
offset_mid_lse = cur_seq_idx * stride_o_lset + cur_head_idx * stride_o_lseh
offsets_mid_o = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
offset_mid_lse = cur_token_idx * stride_o_lset + cur_head_idx * stride_o_lseh
for block_i in range(0, kv_split_num, 1):
mid_o_block = tl.load(mid_o + offsets_mid_o + block_i * stride_mid_ob)
lse = tl.load(mid_o_lse + offset_mid_lse + block_i * stride_o_lseb)
Expand All @@ -179,7 +175,7 @@ def _flash_decoding_fwd_reduce_kernel(
m_i = m_ij

acc = acc / l
offsets_O = cur_seq_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
offsets_O = cur_token_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
tl.store(O + offsets_O, acc.to(O.type.element_ty))
return

Expand All @@ -199,32 +195,40 @@ def flash_decoding_attention(
mid_output_lse: torch.Tensor = None,
sm_scale: int = None,
kv_group_num: int = 1,
q_len: int = 1,
):
"""
Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.

Args:
q (torch.Tensor): [bsz, num_heads, head_dim]
q (torch.Tensor): [bsz * q_len, num_heads, head_dim]
q_len > 1 only for verification process in speculative-decoding.
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
kv_seq_len (torch.Tensor): [batch_size]
records the (kv) sequence lengths incorporating past kv sequence lengths.
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
max_seq_len_in_batch (int): Maximum sequence length in the batch.
output (torch.Tensor): [bsz, num_heads * head_dim]
mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
mid_output (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num, head_dim]
Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
q_len > 1 only for verification process in speculative-decoding.
mid_output_lse (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num]
Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`.
q_len > 1 only for verification process in speculative-decoding.
block_size (int): Size of each block in the blocked key/value cache.
kv_group_num (int, optional): Number of key/value groups. Defaults to 1.
q_len (int): Query length. Use for speculative decoding when `q_len` > 1 (i.e. the last n tokens).
Defaults to 1.

Returns:
Output tensor with shape [bsz, num_heads * head_dim]
Output tensor with shape [bsz * q_len, num_heads * head_dim]
"""
q = q.squeeze() if q.dim() == 4 else q
assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
bsz, num_heads, head_dim = q.shape
n_tokens, num_heads, head_dim = q.shape
assert n_tokens % q_len == 0, "Invalid q_len"
bsz = n_tokens // q_len

assert head_dim in {32, 64, 128, 256}
assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
Expand All @@ -247,22 +251,31 @@ def flash_decoding_attention(
max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch
# For compatibility (TODO revise modeling in future)
kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV
mid_output = (
torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device)
if mid_output is None
else mid_output
)
mid_output_lse = (
torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
if mid_output_lse is None
else mid_output_lse
)

if mid_output is None:
mid_output = torch.empty(
(bsz * q_len, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device
)
if mid_output_lse is None:
mid_output_lse = torch.empty((bsz * q_len, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
if output is None:
# A hack to prevent `view` operation in modeling
output = torch.empty((bsz * q_len, num_heads * head_dim), dtype=q.dtype, device=q.device)

assert (
mid_output.size(2) == mid_output_lse.size(2) >= kv_max_split_num
), "Incompatible kv split number of intermediate output tensors"
assert (
mid_output.size(0) == mid_output_lse.size(0) >= output.size(0) == n_tokens
), f"Incompatible first dimension of output tensors"

# NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
# To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
output = torch.empty((bsz, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output

grid = (
triton.next_power_of_2(bsz * q_len),
num_heads,
triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),
)
_flash_decoding_fwd_kernel[grid](
q,
k_cache,
Expand All @@ -271,6 +284,7 @@ def flash_decoding_attention(
mid_output,
mid_output_lse,
kv_seq_len,
q_len,
bsz,
q.stride(0),
q.stride(1),
Expand All @@ -295,13 +309,13 @@ def flash_decoding_attention(
HEAD_DIM=head_dim,
)

grid = (triton.next_power_of_2(bsz), num_heads)

grid = (triton.next_power_of_2(bsz * q_len), num_heads)
_flash_decoding_fwd_reduce_kernel[grid](
mid_output,
mid_output_lse,
output,
kv_seq_len,
q_len,
bsz,
mid_output.stride(0),
mid_output.stride(1),
Expand Down
109 changes: 106 additions & 3 deletions colossalai/kernel/triton/kvcache_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,50 @@
import triton.language as tl


# Triton 2.1.0
# Copies the last `n` tokens of every sequence from a dense K (or V) tensor into
# the blocked (paged) cache. The launcher in this file uses a grid of
# (bsz * n, num_kv_heads), so each program moves one head vector of one token.
@triton.jit
def _copy_to_kcache_seqlen_n_kernel(
    KV,  # K or V
    KVCache,  # KCache or VCache
    BLOCK_TABLES,
    context_lengths,
    stride_kt,
    stride_kh,
    stride_kd,
    stride_cacheb,
    stride_cacheh,
    stride_cachebs,
    stride_cached,
    stride_bts,
    stride_btb,
    block_size,
    n,
    HEAD_DIM: tl.constexpr,
):
    token_idx = tl.program_id(0)
    kv_head_idx = tl.program_id(1)
    seq_idx = token_idx // n

    # Offset of this token relative to the END of its sequence; always in [-n, -1].
    # `context_lengths` is expected to already include the n tokens being copied,
    # so adding this negative shift yields the token's absolute position.
    shift_from_end = token_idx - n * seq_idx - n
    dst_position = tl.load(context_lengths + seq_idx) + shift_from_end

    # Translate the absolute position into (physical block id, slot within block).
    logical_block_idx = dst_position // block_size
    slot_in_block = dst_position % block_size
    seq_block_table = BLOCK_TABLES + seq_idx * stride_bts
    physical_block_id = tl.load(seq_block_table + logical_block_idx * stride_btb)

    # Read the source head vector for this token.
    dim_range = tl.arange(0, HEAD_DIM)
    src_offsets = token_idx * stride_kt + kv_head_idx * stride_kh + dim_range * stride_kd
    head_vec = tl.load(KV + src_offsets)

    # Write it into its slot in the blocked cache.
    dst_offsets = (
        physical_block_id * stride_cacheb
        + kv_head_idx * stride_cacheh
        + slot_in_block * stride_cachebs
        + dim_range * stride_cached
    )
    tl.store(KVCache + dst_offsets, head_vec)
    return


# Triton 2.1.0
@triton.jit
def _copy_to_kvcache_seqlen1_kernel(
Expand Down Expand Up @@ -40,10 +84,11 @@ def _copy_to_kvcache_seqlen1_kernel(
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
offsets_in_last_block = past_kv_seq_len % block_size
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
offsets_k = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
offsets_v = cur_seq_idx * stride_vt + cur_kv_head_idx * stride_vh + offsets_dmodel * stride_vd

k = tl.load(K + offsets_kv)
v = tl.load(V + offsets_kv)
k = tl.load(K + offsets_k)
v = tl.load(V + offsets_v)

offsets_kcache = (
block_id * stride_cachekb
Expand All @@ -63,6 +108,64 @@ def _copy_to_kvcache_seqlen1_kernel(
return


def copy_k_to_blocked_cache(
    k: torch.Tensor, k_cache: torch.Tensor, kv_lengths: torch.Tensor, block_tables: torch.Tensor, n: int = 1
) -> None:
    """
    Copy keys or values to the blocked key/value cache during decoding stage.

    The copy is done in place on `k_cache`; nothing is returned.

    Args:
        k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
            [bsz * n, num_kv_heads, head_dim] - Keys or values with seq len n
        k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
        kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
        block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
        n (int): Number of tokens to copy for each sequence. Must be >= 1. Default to 1.

    Raises:
        AssertionError: If `n` < 1, if head dim or dtype of `k` and `k_cache` differ, if `k` is not
            3-D after flattening, if the token count is not a multiple of `n`, or if the batch
            sizes of `kv_lengths`/`block_tables` do not match the one derived from `k`.
    """
    # Reject n <= 0 up front: the kernel divides by n, and the launch grid would be empty/negative.
    assert n >= 1, f"Expected n >= 1, got {n}"
    assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
    assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."

    # Flatten [bsz, seq, num_kv_heads, head_dim] into [num_tokens, num_kv_heads, head_dim].
    k = k.reshape(-1, k.size(-2), k.size(-1)) if k.dim() == 4 else k
    assert k.dim() == 3, f"Invalid k dim {k.dim()}"
    num_tokens, num_kv_heads, head_dim = k.shape
    # NOTE the first dim of k is bsz * n; every sequence contributes exactly n tokens.
    assert num_tokens % n == 0, "Each sequence should have the same number of tokens to be copied"
    bsz = num_tokens // n

    assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
        f"Got incompatible batch size (number of seqs):\n"
        f"  Past kv sequence lengths bsz {kv_lengths.shape[0]}; "
        f" block tables bsz {block_tables.shape[0]}, input k batch size {bsz}"
    )

    # Modify if the shape of kv cache is changed.
    block_size = k_cache.size(-2)

    num_warps = 8 if head_dim > 128 else 4

    # One program per (token, kv head).
    grid = (num_tokens, num_kv_heads)
    _copy_to_kcache_seqlen_n_kernel[grid](
        k,
        k_cache,
        block_tables,
        kv_lengths,
        k.stride(0),
        k.stride(1),
        k.stride(2),
        k_cache.stride(0),
        k_cache.stride(1),
        k_cache.stride(2),
        k_cache.stride(3),
        block_tables.stride(0),
        block_tables.stride(1),
        block_size,
        n=n,
        HEAD_DIM=head_dim,
        num_warps=num_warps,
    )


def copy_kv_to_blocked_cache(
k: torch.Tensor,
v: torch.Tensor,
Expand Down
Loading
Loading