From 9505cd5f58757c1878f8a2b922d91b8c03e6a151 Mon Sep 17 00:00:00 2001 From: Yuanheng Date: Tue, 16 Jan 2024 14:13:09 +0800 Subject: [PATCH 1/2] [kernel/fix] revise kvcache copy kernel api --- colossalai/kernel/triton/kvcache_copy.py | 33 +++++++++++-------- .../triton/test_kvcache_copy.py | 30 +++++++++-------- 2 files changed, 36 insertions(+), 27 deletions(-) diff --git a/colossalai/kernel/triton/kvcache_copy.py b/colossalai/kernel/triton/kvcache_copy.py index b979e24cd0fa..253b3912e6ab 100644 --- a/colossalai/kernel/triton/kvcache_copy.py +++ b/colossalai/kernel/triton/kvcache_copy.py @@ -25,11 +25,11 @@ def _copy_to_kvcache_seqlen1_kernel( cur_seq_idx = tl.program_id(0) cur_kv_head_idx = tl.program_id(1) - cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx) - last_bt_block_idx = cur_kv_seq_len // block_size + past_kv_seq_len = tl.load(context_lengths + cur_seq_idx) - 1 + last_bt_block_idx = past_kv_seq_len // block_size block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb) - offsets_in_last_block = (cur_kv_seq_len % block_size) * stride_cachebs + offsets_in_last_block = (past_kv_seq_len % block_size) * stride_cachebs offsets_dmodel = tl.arange(0, HEAD_DIM) offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd kv = tl.load(KV + offsets_kv) @@ -43,23 +43,30 @@ def _copy_to_kvcache_seqlen1_kernel( return -# Used with blocked kv cache. -# Copy k or v to block k/v cache during decoding stage def copy_kv_to_blocked_cache( - k: torch.Tensor, # [bsz, 1, num_kv_heads, head_dim], k or v during decoding stage - k_cache: torch.Tensor, # [num_blocks, num_kv_heads, head_dim, block_size], blocked k or v cache (for now, the shapes of them are the same) - context_lengths: torch.Tensor, # [bsz], past kv seq len (not incorporating the current kv of length 1) - block_tables: torch.Tensor, # [bsz, max_blocks_per_sequence] + k: torch.Tensor, + k_cache: torch.Tensor, + kv_lengths: torch.Tensor, + block_tables: torch.Tensor, ): + """ + Copy keys or values to the blocked key/value cache during decoding stage. + + Parameters: + - k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1. + - k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size] - Blocked key or value cache. + - kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence. + - block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence. + """ assert k.dim() == 4, "Unsupported shape of k (supposed to be used for decoding stage)" assert k.size(1) == 1, "Unsupported kv seq len (supposed to be used for decoding stage)" assert k.size(-1) == k_cache.size(-2), "Incompatible head dim" assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache." bsz, _, num_kv_heads, head_dim = k.shape - assert context_lengths.shape[0] == block_tables.shape[0] == bsz, ( + assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, ( f"Got incompatible batch size (number of seqs):\n" - f" Conext lengths bsz {context_lengths.shape[0]}, Block tables bsz {block_tables.shape[0]}, " - f"batch size {bsz}" + f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; " + f" block tables bsz {block_tables.shape[0]}, input k batch size {bsz}" ) # Modify if the shape of kv cahce is changed. @@ -74,7 +81,7 @@ def copy_kv_to_blocked_cache( k, k_cache, block_tables, - context_lengths, + kv_lengths, k.stride(0), k.stride(1), k.stride(2), diff --git a/tests/test_infer_ops/triton/test_kvcache_copy.py b/tests/test_infer_ops/triton/test_kvcache_copy.py index 875c34fba4dc..ce25f2d55b96 100644 --- a/tests/test_infer_ops/triton/test_kvcache_copy.py +++ b/tests/test_infer_ops/triton/test_kvcache_copy.py @@ -30,12 +30,12 @@ def prepare_data( dtype=torch.float16, ): if same_context_len: - # context_lengths in this test records the previous kv seq len + # past_kv_seq_lengths in this test records the previous kv seq len # (not incorporating the current input whose seq len is 1) - context_lengths = torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device) + past_kv_seq_lengths = torch.tensor([max_seq_len - 1 for _ in range(bsz)], dtype=torch.int32, device=device) else: - context_lengths = torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device) - num_tokens = torch.sum(context_lengths).item() + past_kv_seq_lengths = torch.randint(low=1, high=max_seq_len - 1, size=(bsz,), dtype=torch.int32, device=device) + num_tokens = torch.sum(past_kv_seq_lengths).item() kv_size = (num_tokens, 2 * num_kv_heads, head_dim) kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5) @@ -46,15 +46,18 @@ def prepare_data( v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device) # Mock allocation on block tables as well as blocked kv caches block_tables = mock_alloc_block_table_and_kvcache( - k, v, k_cache, v_cache, context_lengths, bsz, max_num_blocks_per_seq, block_size + k, v, k_cache, v_cache, past_kv_seq_lengths, bsz, max_num_blocks_per_seq, block_size ) block_tables = block_tables.to(device=device) new_k = torch.randn((bsz, 1, num_kv_heads, head_dim), dtype=dtype, device=device) # mock allocating blocks for the new k/v and update block tables - mock_alloc_single_token(block_tables, context_lengths, block_size) + mock_alloc_single_token(block_tables, past_kv_seq_lengths, block_size) - return new_k, k_cache, context_lengths, block_tables + # kv seq len = past kv seq len + seq len (1 during decoding stage) + kv_seq_lengths = past_kv_seq_lengths + 1 + + return new_k, k_cache, kv_seq_lengths, block_tables @pytest.mark.skipif(not (HAS_TRITON and TRITON_CUDA_SUPPORT), reason="requires triton") @@ -80,7 +83,7 @@ def test_copy_kv_to_caches( dtype = torch.float16 device = get_current_device() - new_k, k_cache, context_lengths, block_tables = prepare_data( + new_k, k_cache, kv_seq_lengths, block_tables = prepare_data( bsz, num_kv_heads, head_dim, @@ -91,15 +94,14 @@ def test_copy_kv_to_caches( device=device, dtype=dtype, ) - - copy_kv_to_blocked_cache(new_k, k_cache, context_lengths, block_tables) + copy_kv_to_blocked_cache(new_k, k_cache, kv_seq_lengths, block_tables) for seq_i in range(bsz): ki = new_k[seq_i] ki = ki.squeeze() - context_len_i = context_lengths[seq_i] - target_block_id = block_tables[seq_i, context_len_i // block_size] - offsets_in_block = context_len_i % block_size + past_kv_seq_len = kv_seq_lengths[seq_i] - 1 + target_block_id = block_tables[seq_i, past_kv_seq_len // block_size] + offsets_in_block = past_kv_seq_len % block_size target = k_cache[target_block_id, :, :, offsets_in_block] orig = new_k[seq_i].squeeze(dim=0) assert torch.equal(orig, target) @@ -164,5 +166,5 @@ def benchmark_kvcache_copy( if __name__ == "__main__": - test_copy_kv_to_caches(4, 32, 8, 16, False) + test_copy_kv_to_caches(4, 32, 8, 16, True) # benchmark_kvcache_copy.run(save_path=".") From c4264a0c2c36f40b83170333f8426fc4a943ce80 Mon Sep 17 00:00:00 2001 From: Yuanheng Date: Tue, 16 Jan 2024 14:35:02 +0800 Subject: [PATCH 2/2] fix benchmark --- tests/test_infer_ops/triton/test_kvcache_copy.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/tests/test_infer_ops/triton/test_kvcache_copy.py b/tests/test_infer_ops/triton/test_kvcache_copy.py index ce25f2d55b96..c2ccb5ef5f7b 100644 --- a/tests/test_infer_ops/triton/test_kvcache_copy.py +++ b/tests/test_infer_ops/triton/test_kvcache_copy.py @@ -107,11 +107,11 @@ def test_copy_kv_to_caches( assert torch.equal(orig, target) -BATCH = 4 +BATCH = 16 configs = [ triton.testing.Benchmark( - x_names=["PAST_KVLEN"], - x_vals=[2**i - 1 for i in range(8, 13)], + x_names=["KV_SEQ_LEN"], + x_vals=[2**i for i in range(8, 13)], line_arg="provider", line_vals=["torch_copy_func", "triton_copy_func"], line_names=["torch_copy_func", "triton_copy_func"], @@ -129,7 +129,7 @@ def benchmark_kvcache_copy( bsz: int, block_size: int, max_seq_len: int, - PAST_KVLEN: int, # maximum past kv length (unequal context lens in batch) or past kv len (equal context lens) + KV_SEQ_LEN: int, # maximum past kv length (unequal context lens in batch) or past kv len (equal context lens) num_kv_heads: int, same_context_len: bool, ): @@ -140,7 +140,7 @@ def benchmark_kvcache_copy( dtype = torch.float16 device = get_current_device() - assert PAST_KVLEN < max_seq_len, "Assigned maximum past kv length must be smaller or equal to maximum seq len" + assert KV_SEQ_LEN <= max_seq_len, "Assigned maximum kv length must be smaller or equal to maximum seq len" new_k, k_cache, context_lengths, block_tables = prepare_data( bsz, @@ -149,7 +149,7 @@ def benchmark_kvcache_copy( block_size, max_seq_len // block_size, same_context_len, - PAST_KVLEN, + KV_SEQ_LEN, device=device, dtype=dtype, ) @@ -167,4 +167,4 @@ def benchmark_kvcache_copy( if __name__ == "__main__": test_copy_kv_to_caches(4, 32, 8, 16, True) - # benchmark_kvcache_copy.run(save_path=".") + # benchmark_kvcache_copy.run(save_path=".", print_data=True)