From 9f4ab2eb924b938348df2c713bb4580972f18eb1 Mon Sep 17 00:00:00 2001
From: Jianghai <72591262+CjhHa1@users.noreply.github.com>
Date: Wed, 7 Feb 2024 11:36:04 +0800
Subject: [PATCH] [Inference] Adapt to Fused rotary (#5348)

* revise rotary embedding

* remove useless print

* adapt

* fix

* add

* fix

* modeling

* fix

* fix

* fix
---
 .../modeling/models/nopadding_llama.py        |   5 +-
 colossalai/kernel/triton/kvcache_copy.py      |   1 -
 .../kernel/triton/no_pad_rotary_embedding.py  | 136 ++++++++++++++++--
 examples/inference/run_benchmark.sh           |   1 +
 .../triton/test_rotary_embdding_unpad.py      |  40 ++++--
 5 files changed, 161 insertions(+), 22 deletions(-)

diff --git a/colossalai/inference/modeling/models/nopadding_llama.py b/colossalai/inference/modeling/models/nopadding_llama.py
index 355140bc1f46..44ce381a471c 100644
--- a/colossalai/inference/modeling/models/nopadding_llama.py
+++ b/colossalai/inference/modeling/models/nopadding_llama.py
@@ -282,11 +282,10 @@ def forward(
             torch.bmm(hidden_states, self.qkv_weight).view(3, token_nums, self.num_heads, self.head_dim).unbind(0)
         )

-        rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
-
         block_size = k_cache.size(-2)

         if is_prompts:
+            rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
             attn_output = context_attention_unpadded(
                 q=query_states,
                 k=key_states,
@@ -301,7 +300,7 @@ def forward(
                 sm_scale=sm_scale,
             )
         else:
-            copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
+            rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1], k_cache, block_tables, sequence_lengths)
             copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
             attn_output = flash_decoding_attention(
                 q=query_states,
diff --git a/colossalai/kernel/triton/kvcache_copy.py b/colossalai/kernel/triton/kvcache_copy.py
index 1aaeb6830e7c..8e31b42a8ae7 100644
--- a/colossalai/kernel/triton/kvcache_copy.py
+++ b/colossalai/kernel/triton/kvcache_copy.py
@@ -75,7 +75,6 @@ def copy_kv_to_blocked_cache(
     block_size = k_cache.size(-2)

     num_warps = 8 if head_dim > 128 else 4
-
     grid = (bsz, num_kv_heads)
     _copy_to_kvcache_seqlen1_kernel[grid](
         k,
diff --git a/colossalai/kernel/triton/no_pad_rotary_embedding.py b/colossalai/kernel/triton/no_pad_rotary_embedding.py
index 9194319d5ece..7a38c0fc8692 100644
--- a/colossalai/kernel/triton/no_pad_rotary_embedding.py
+++ b/colossalai/kernel/triton/no_pad_rotary_embedding.py
@@ -222,11 +222,11 @@ def fused_rotary_embedding_kernel(
     out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]
     out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :]  # total_tokens, head_num, head_dim

-    past_kv_seq_len = tl.load(context_lengths + tokens_range) - 1
+    past_kv_seq_len = tl.load(context_lengths + tokens_range, mask=(tokens_range < q_total_tokens)) - 1

     last_block_idx = past_kv_seq_len // block_size
     block_table_ptr = BLOCK_TABLES + tokens_range * bts_stride
-    block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride)
+    block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(tokens_range < q_total_tokens))
     offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride

     kv_range0 = (
@@ -274,6 +274,122 @@ def fused_rotary_embedding_kernel(
     )


+@triton.jit
+def fused_rotary_embedding_kernel_v2(
+    q,
+    k,
+    cos,
+    sin,
+    kv_cache,
+    BLOCK_TABLES,
+    context_lengths,
+    q_token_stride,
+    q_head_stride,
+    k_token_stride,
+    k_head_stride,
+    head_dim_stride,
+    cos_token_stride,
+    cos_stride,
+    cacheb_stride,
+    cacheh_stride,
+    cachebs_stride,
+    cached_stride,
+    bts_stride,
+    btb_stride,
+    block_size,
+    q_total_tokens,
+    Q_HEAD_NUM: tl.constexpr,
+    K_HEAD_NUM: tl.constexpr,
+    HEAD_DIM: tl.constexpr,
+):
+    block_head_index = tl.program_id(0)
+    if block_head_index >= Q_HEAD_NUM:
+        return
+    block_token_index = tl.program_id(1)
+
+    dim_range0 = tl.arange(0, HEAD_DIM // 2)
+    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)
+
+    off_q0 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range0 * head_dim_stride
+    off_q1 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range1 * head_dim_stride
+    off_k0 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range0 * head_dim_stride
+    off_k1 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range1 * head_dim_stride
+
+    loaded_q0 = tl.load(
+        q + off_q0,
+    )
+    loaded_q1 = tl.load(
+        q + off_q1,
+    )
+
+    loaded_k0 = tl.load(
+        k + off_k0,
+    )
+
+    loaded_k1 = tl.load(
+        k + off_k1,
+    )
+
+    off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride
+
+    loaded_cos = tl.load(cos + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0)
+    loaded_sin = tl.load(sin + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0)
+
+    out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin
+    out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos
+
+    out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin
+    out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos  # total_tokens, head_num, head_dim
+
+    past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1
+
+    last_block_idx = past_kv_seq_len // block_size
+    block_table_ptr = BLOCK_TABLES + block_token_index * bts_stride
+    block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(block_token_index < q_total_tokens))
+    offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride
+
+    kv_range0 = (
+        block_ids * cacheb_stride
+        + block_head_index * cacheh_stride
+        + offsets_in_last_block
+        + dim_range0 * cached_stride
+    )
+    kv_range1 = (
+        block_ids * cacheb_stride
+        + block_head_index * cacheh_stride
+        + offsets_in_last_block
+        + dim_range1 * cached_stride
+    )
+
+    tl.store(
+        kv_cache + kv_range0,
+        out_k0,
+    )
+    tl.store(
+        kv_cache + kv_range1,
+        out_k1,
+    )
+
+    # concat
+    tl.store(
+        q + off_q0,
+        out_q0,
+    )
+    tl.store(
+        q + off_q1,
+        out_q1,
+    )
+    tl.store(
+        k + off_k0,
+        out_k0,
+    )
+    tl.store(
+        k + off_k1,
+        out_k1,
+    )
+
+
+@torch.no_grad()
 def rotary_embedding(
     q: torch.Tensor,
     k: torch.Tensor,
@@ -297,12 +413,13 @@ def rotary_embedding(
     assert q.size(0) == k.size(0)
     BLOCK_HEAD = 4
     BLOCK_TOKENS = 4
-    grid = lambda META: (triton.cdiv(q_head_num, META["BLOCK_HEAD"]), triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]))

-    if head_dim >= 256:
+    if head_dim >= 1024:
         num_warps = 32
-    elif head_dim >= 128:
+    elif head_dim >= 512:
         num_warps = 16
+    elif head_dim >= 256:
+        num_warps = 8
     else:
         num_warps = 4

@@ -318,6 +435,10 @@ def rotary_embedding(
     cos_token_stride = cos.stride(0)
     cos_stride = cos.stride(1)
     if k_cache == None:
+        grid = lambda META: (
+            triton.cdiv(q_head_num, META["BLOCK_HEAD"]),
+            triton.cdiv(q_total_tokens, META["BLOCK_TOKENS"]),
+        )
         rotary_embedding_kernel[grid](
             q,
             k,
@@ -339,7 +460,8 @@ def rotary_embedding(
             num_warps=num_warps,
         )
     else:
-        fused_rotary_embedding_kernel[grid](
+        grid = (triton.next_power_of_2(q_head_num), q_total_tokens)
+        fused_rotary_embedding_kernel_v2[grid](
             q,
             k,
             cos,
@@ -365,8 +487,6 @@ def rotary_embedding(
             Q_HEAD_NUM=q_head_num,
             K_HEAD_NUM=k_head_num,
             HEAD_DIM=head_dim,
-            BLOCK_HEAD=BLOCK_HEAD,
-            BLOCK_TOKENS=BLOCK_TOKENS,
             num_warps=num_warps,
         )
     return
diff --git a/examples/inference/run_benchmark.sh b/examples/inference/run_benchmark.sh
index 2a6e5a5d75b8..a8619bce99f7 100755
--- a/examples/inference/run_benchmark.sh
+++ b/examples/inference/run_benchmark.sh
@@ -1,4 +1,5 @@
 ROOT=$(realpath $(dirname $0))
+echo $ROOT
 PY_SCRIPT=${ROOT}/benchmark_llama.py
 GPU=$(nvidia-smi -L | head -1 | cut -d' ' -f4 | cut -d'-' -f1)
 mode=$1
diff --git a/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py b/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py
index 6a8dc85f0aec..e4f4bb282647 100644
--- a/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py
+++ b/tests/test_infer/test_ops/triton/test_rotary_embdding_unpad.py
@@ -3,7 +3,7 @@
 from packaging import version
 from transformers.models.llama.modeling_llama import LlamaRotaryEmbedding, apply_rotary_pos_emb

-from colossalai.kernel.triton import rotary_embedding
+from colossalai.kernel.triton import copy_kv_to_blocked_cache, rotary_embedding
 from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2

 try:
@@ -94,8 +94,8 @@ def test_rotary_emb(BATCH_SIZE, SEQ_LEN, H, D, dtype):
         x_names=["num_tokens"],
         x_vals=[2**i for i in range(4, 11)],
         line_arg="provider",
-        line_vals=["torch_rotary_emb_func", "triton_rotary_emb_func"],
-        line_names=["torch_rotary_emb_func", "triton_rotary_emb_func"],
+        line_vals=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
+        line_names=["no_fused_rotary_emb_func", "fused_triton_rotary_emb_func"],
         styles=[("red", "-"), ("blue", "-")],
         ylabel="ms",
         plot_name=f"rotary_emb-batch-{BATCH}",
@@ -110,11 +110,16 @@ def benchmark_rotary_emb(
     num_tokens: int,
     num_kv_heads: int,
 ):
+    BATCH_SIZE = 4
+    SEQ_LEN = num_tokens // BATCH_SIZE
+    max_num_blocks_per_seq = 8
+    block_size = 64
     warmup = 10
     rep = 100

-    head_dim = 128
+    head_dim = 256
     dtype = torch.float16
+
     q_shape = (num_tokens, num_kv_heads, head_dim)
     q = -2.3 + 0.5 * torch.randn(q_shape, dtype=dtype, device="cuda")
     k_shape = (num_tokens, num_kv_heads, head_dim)
@@ -122,11 +127,26 @@ def benchmark_rotary_emb(
     cos_shape = (num_tokens, head_dim // 2)
     cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
     sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+    cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
+    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
+    v = torch.randn_like(k)
+    v_cache = torch.zeros_like(k_cache)
+    past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
+    block_tables = mock_alloc_block_table_and_kvcache_v2(
+        k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
+    )
+    new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda")
+    new_q = torch.randn_like(new_k)
+    kv_seq_lengths = past_kv_seq_lengths + 1
+    block_tables = block_tables.to(device="cuda")

-    if provider == "torch_rotary_emb_func":
-        fn = lambda: torch_rotary_emb(q, cos, sin)
-    elif provider == "triton_rotary_emb_func":
-        fn = lambda: rotary_embedding(q, k, cos, sin)
+    if provider == "no_fused_rotary_emb_func":
+        fn = lambda: [
+            rotary_embedding(new_q, new_k, cos, sin),
+            copy_kv_to_blocked_cache(new_k, k_cache, kv_lengths=kv_seq_lengths, block_tables=block_tables),
+        ]
+    elif provider == "fused_triton_rotary_emb_func":
+        fn = lambda: rotary_embedding(new_q, new_k, cos, sin, k_cache, block_tables, kv_seq_lengths)
     else:
         raise ValueError("Undefined provider")

@@ -135,5 +155,5 @@ def benchmark_rotary_emb(


 if __name__ == "__main__":
-    test_rotary_emb(4, 64, 32, 64, torch.float32)
-    # benchmark_rotary_emb.run(save_path=".",print_data=True)
+    # test_rotary_emb(4, 64, 32, 64, torch.float32)
+    benchmark_rotary_emb.run(save_path=".", print_data=True)