[kernel] Revise KVCache copy triton kernel API #5273

Merged
33 changes: 20 additions & 13 deletions colossalai/kernel/triton/kvcache_copy.py
@@ -25,11 +25,11 @@ def _copy_to_kvcache_seqlen1_kernel(
cur_seq_idx = tl.program_id(0)
cur_kv_head_idx = tl.program_id(1)

- cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx)
- last_bt_block_idx = cur_kv_seq_len // block_size
+ past_kv_seq_len = tl.load(context_lengths + cur_seq_idx) - 1
+ last_bt_block_idx = past_kv_seq_len // block_size
block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
- offsets_in_last_block = (cur_kv_seq_len % block_size) * stride_cachebs
+ offsets_in_last_block = (past_kv_seq_len % block_size) * stride_cachebs
offsets_dmodel = tl.arange(0, HEAD_DIM)
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
kv = tl.load(KV + offsets_kv)
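Note on this hunk: the kernel previously treated the loaded length as the past kv length; it now loads the total length (current token included) and subtracts 1 to get the write position. A minimal PyTorch sketch of the same addressing for one sequence and one kv head (a hypothetical reference helper, not part of the PR):

```python
import torch

def copy_one_token_reference(kv, kv_cache, kv_len, block_table, block_size):
    """What each (seq, head) program instance does: write the newest token
    (position kv_len - 1, since kv_len includes it) into the blocked cache."""
    past_kv_seq_len = kv_len - 1
    block_id = block_table[past_kv_seq_len // block_size]
    offset_in_block = past_kv_seq_len % block_size
    # kv: [head_dim]; kv_cache: [num_blocks, head_dim, block_size] for this head
    kv_cache[block_id, :, offset_in_block] = kv
```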
@@ -43,23 +43,30 @@ def _copy_to_kvcache_seqlen1_kernel(
return


- # Used with blocked kv cache.
- # Copy k or v to block k/v cache during decoding stage
def copy_kv_to_blocked_cache(
- k: torch.Tensor,  # [bsz, 1, num_kv_heads, head_dim], k or v during decoding stage
- k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size], blocked k or v cache (for now, the shapes of them are the same)
- context_lengths: torch.Tensor,  # [bsz], past kv seq len (not incorporating the current kv of length 1)
- block_tables: torch.Tensor,  # [bsz, max_blocks_per_sequence]
+ k: torch.Tensor,
+ k_cache: torch.Tensor,
+ kv_lengths: torch.Tensor,
+ block_tables: torch.Tensor,
):
"""
Copy keys or values to the blocked key/value cache during decoding stage.

Parameters:
- k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
- k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] - Blocked key or value cache.
- kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
- block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
"""
assert k.dim() == 4, "Unsupported shape of k (supposed to be used for decoding stage)"
assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)"
assert k.size(-1) == k_cache.size(-2), "Incompatible head dim"
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
bsz, _, num_kv_heads, head_dim = k.shape
- assert context_lengths.shape[0] == block_tables.shape[0] == bsz, (
+ assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
f"Got incompatible batch size (number of seqs):\n"
f" Conext lengths bsz {context_lengths.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
f"batch size {bsz}"
f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; "
f" block tables bsz {block_tables.shape[0]}, input k batch size {bsz}"
)

# Modify if the shape of the kv cache is changed.
@@ -74,7 +81,7 @@ def copy_kv_to_blocked_cache(
k,
k_cache,
block_tables,
- context_lengths,
+ kv_lengths,
k.stride(0),
k.stride(1),
k.stride(2),
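For illustration, a minimal call of the revised API with toy shapes (a sketch: the import path is assumed from the file location, and all sizes are made up):

```python
import torch
# Import path assumed from colossalai/kernel/triton/kvcache_copy.py:
from colossalai.kernel.triton.kvcache_copy import copy_kv_to_blocked_cache

bsz, num_kv_heads, head_dim = 4, 8, 64
block_size, max_blocks_per_seq = 16, 8
num_blocks = bsz * max_blocks_per_seq

# One new key (or value) per sequence, as produced at each decoding step.
new_k = torch.randn(bsz, 1, num_kv_heads, head_dim, dtype=torch.float16, device="cuda")
k_cache = torch.zeros(num_blocks, num_kv_heads, head_dim, block_size, dtype=torch.float16, device="cuda")

# Total kv lengths *including* the current token -- the renamed kv_lengths argument.
kv_lengths = torch.full((bsz,), 33, dtype=torch.int32, device="cuda")
# Toy block tables: sequence i owns blocks i*8 .. i*8+7.
block_tables = torch.arange(num_blocks, dtype=torch.int32, device="cuda").view(bsz, max_blocks_per_seq)

copy_kv_to_blocked_cache(new_k, k_cache, kv_lengths, block_tables)
# The token at position 32 lands in block_tables[:, 2], slot 0 of that block.
```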
44 changes: 23 additions & 21 deletions tests/test_infer_ops/triton/test_kvcache_copy.py
@@ -30,12 +30,12 @@ def prepare_data(
dtype=torch.float16,
):
if same_context_len:
- # context_lengths in this test records the previous kv seq len
+ # past_kv_seq_lengths in this test records the previous kv seq len
# (not incorporating the current input whose seq len is 1)
- context_lengths = torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device)
+ past_kv_seq_lengths = torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device)
else:
- context_lengths = torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device)
- num_tokens = torch.sum(context_lengths).item()
+ past_kv_seq_lengths = torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device)
+ num_tokens = torch.sum(past_kv_seq_lengths).item()

kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
@@ -46,15 +46,18 @@ def prepare_data(
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
# Mock allocation on block tables as well as blocked kv caches
block_tables = mock_alloc_block_table_and_kvcache(
- k, v, k_cache, v_cache, context_lengths, bsz, max_num_blocks_per_seq, block_size
+ k, v, k_cache, v_cache, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size
)
block_tables = block_tables.to(device=device)

new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device)
# mock allocating blocks for the new k/v and update block tables
- mock_alloc_single_token(block_tables, context_lengths, block_size)
+ mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size)

- return new_k, k_cache, context_lengths, block_tables
+ # kv seq len = past kv seq len + seq len (1 during decoding stage)
+ kv_seq_lengths = past_kv_seq_lengths + 1
+
+ return new_k, k_cache, kv_seq_lengths, block_tables


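A quick arithmetic check of the convention prepare_data now follows (illustrative numbers only): it returns the total kv length, and the kernel subtracts 1 to recover the write position.

```python
block_size = 16
past_kv_seq_len = 31                 # tokens already in the cache
kv_seq_len = past_kv_seq_len + 1     # what prepare_data returns (current token included)
write_pos = kv_seq_len - 1           # what the kernel recovers
assert write_pos // block_size == 1  # second block in this sequence's block table
assert write_pos % block_size == 15  # last slot of that block
```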
@pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton")
@@ -80,7 +83,7 @@ def test_copy_kv_to_caches(
dtype = torch.float16
device = get_current_device()

- new_k, k_cache, context_lengths, block_tables = prepare_data(
+ new_k, k_cache, kv_seq_lengths, block_tables = prepare_data(
bsz,
num_kv_heads,
head_dim,
@@ -91,25 +94,24 @@ def test_copy_kv_to_caches(
device=device,
dtype=dtype,
)

- copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables)
+ copy_kv_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables)

for seq_i in range(bsz):
ki = new_k[seq_i]
ki = ki.squeeze()
- context_len_i = context_lengths[seq_i]
- target_block_id = block_tables[seq_i, context_len_i // block_size]
- offsets_in_block = context_len_i % block_size
+ past_kv_seq_len = kv_seq_lengths[seq_i] - 1
+ target_block_id = block_tables[seq_i, past_kv_seq_len // block_size]
+ offsets_in_block = past_kv_seq_len % block_size
target = k_cache[target_block_id, :, :, offsets_in_block]
orig = new_k[seq_i].squeeze(dim=0)
assert torch.equal(orig, target)
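As a side note, the per-sequence check above could also be done in one shot with advanced indexing; a sketch under the same tensor shapes and names as this test (not part of the PR):

```python
# Vectorized equivalent of the loop above.
seq_ids = torch.arange(bsz, device=device)
past_lens = (kv_seq_lengths - 1).long()
block_ids = block_tables[seq_ids, past_lens // block_size].long()
copied = k_cache[block_ids, :, :, past_lens % block_size]  # [bsz, num_kv_heads, head_dim]
assert torch.equal(copied, new_k.squeeze(dim=1))
```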


- BATCH = 4
+ BATCH = 16
configs = [
triton.testing.Benchmark(
x_names=["PAST_KVLEN"],
x_vals=[2**i - 1 for i in range(8, 13)],
x_names=["KV_SEQ_LEN"],
x_vals=[2**i for i in range(8, 13)],
line_arg="provider",
line_vals=["torch_copy_func", "triton_copy_func"],
line_names=["torch_copy_func", "triton_copy_func"],
@@ -127,7 +129,7 @@ def benchmark_kvcache_copy(
bsz: int,
block_size: int,
max_seq_len: int,
- PAST_KVLEN: int,  # maximum past kv length (unequal context lens in batch) or past kv len (equal context lens)
+ KV_SEQ_LEN: int,  # maximum kv seq length (unequal kv lengths in batch) or the kv seq length (equal kv lengths)
num_kv_heads: int,
same_context_len: bool,
):
@@ -138,7 +140,7 @@ def benchmark_kvcache_copy(
dtype = torch.float16
device = get_current_device()

- assert PAST_KVLEN < max_seq_len, "Assigned maximum past kv length must be smaller or equal to maximum seq len"
+ assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller than or equal to the maximum seq len"

new_k, k_cache, context_lengths, block_tables = prepare_data(
bsz,
@@ -147,7 +149,7 @@
block_size,
max_seq_len // block_size,
same_context_len,
- PAST_KVLEN,
+ KV_SEQ_LEN,
device=device,
dtype=dtype,
)
@@ -164,5 +166,5 @@ def benchmark_kvcache_copy(


if __name__ == "__main__":
- test_copy_kv_to_caches(4, 32, 8, 16, False)
- # benchmark_kvcache_copy.run(save_path=".")
+ test_copy_kv_to_caches(4, 32, 8, 16, True)
+ # benchmark_kvcache_copy.run(save_path=".", print_data=True)
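One detail hidden by the collapsed hunks: for the commented-out `.run(...)` call to work, `benchmark_kvcache_copy` is presumably decorated with `triton.testing.perf_report(configs)` (a real Triton utility, though its presence here is an assumption, since the decorator sits in a folded region). A sketch of that pattern, using the `configs` defined in the test above:

```python
import triton

@triton.testing.perf_report(configs)  # assumed: decorator hidden by the collapsed diff
def benchmark_kvcache_copy(bsz, block_size, max_seq_len, KV_SEQ_LEN, num_kv_heads, same_context_len):
    ...

# benchmark_kvcache_copy.run(save_path=".", print_data=True) then times each provider
# over the configured KV_SEQ_LEN values and prints/saves the timing table.
```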