From c3bf7d076bc3109e52b17e3c2ea8d779fb030207 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 22 Aug 2023 13:10:03 +0800 Subject: [PATCH 01/22] added _vllm_rms_norm --- colossalai/shardformer/modeling/llama.py | 48 ++++++++++++++++++++++-- 1 file changed, 45 insertions(+), 3 deletions(-) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index f1d2998bbee4..7fe75ea4c483 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -7,7 +7,7 @@ CausalLMOutputWithPast, SequenceClassifierOutputWithPast, ) -from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel +from transformers.models.llama.modeling_llama import LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaRMSNorm from transformers.utils import logging from colossalai.pipeline.stage_manager import PipelineStageManager @@ -391,8 +391,17 @@ def llama_for_sequence_classification_forward( def get_llama_flash_attention_forward(): from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb - from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention + + try: + from vllm import pos_encoding_ops + rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox + HAS_VLLM_KERNERL = True + except: + print("fall back to original rotary_embedding_neox of huggingface") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + HAS_VLLM_KERNERL = False + def forward( self: LlamaAttention, @@ -415,7 +424,12 @@ def forward( kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if HAS_VLLM_KERNERL: + cos_sin_cache = torch.cat((cos, sin), dim=-1) + rotary_embedding_neox(position_ids, query_states, key_states, self.head_dim, cos_sin_cache) + else: + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) if past_key_value is not None: # reuse k, v, self_attention @@ -450,3 +464,31 @@ def forward( return attn_output, None, past_key_value return forward + + +def get_llama_vllm_rmsnorm_forward(): + try: + from vllm import layernorm_ops + rms_norm = layernorm_ops.rms_norm + HAS_VLLM_KERNERL = True + except: + print("please install vllm kernels to install rmsnorm") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + HAS_VLLM_KERNERL = False + + if HAS_VLLM_KERNERL: + def _vllm_rmsnorm_forward(self: LlamaRMSNorm, hidden_states: torch.Tensor): + x = hidden_states + out = torch.empty_like(x) + rms_norm( + out, + x, + self.weight.data, + self.variance_epsilon, + ) + + return out + + return _vllm_rmsnorm_forward + else: + return None From 3a31b68c9b126b6d2f07e220bca1ba2e619114f6 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 22 Aug 2023 13:14:05 +0800 Subject: [PATCH 02/22] change place --- tests/test_kernels/{ => triton}/test_self_attention.py | 0 tests/test_kernels/{ => triton}/test_softmax.py | 0 2 files changed, 0 insertions(+), 0 deletions(-) rename tests/test_kernels/{ => triton}/test_self_attention.py (100%) rename tests/test_kernels/{ => triton}/test_softmax.py (100%) diff --git a/tests/test_kernels/test_self_attention.py b/tests/test_kernels/triton/test_self_attention.py similarity index 100% rename from tests/test_kernels/test_self_attention.py rename to tests/test_kernels/triton/test_self_attention.py diff --git a/tests/test_kernels/test_softmax.py b/tests/test_kernels/triton/test_softmax.py similarity index 100% rename from tests/test_kernels/test_softmax.py rename to tests/test_kernels/triton/test_softmax.py From 2ef5d836ef71cce7a325b314d03352d394cdf752 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 22 Aug 2023 13:44:43 +0800 Subject: [PATCH 03/22] added tests --- tests/test_kernels/cuda/test_vllm_rmsnorm.py | 60 +++++++++++++++++++ ...on.py => test_self_attention_nonfusion.py} | 0 2 files changed, 60 insertions(+) create mode 100644 tests/test_kernels/cuda/test_vllm_rmsnorm.py rename tests/test_kernels/triton/{test_self_attention.py => test_self_attention_nonfusion.py} (100%) diff --git a/tests/test_kernels/cuda/test_vllm_rmsnorm.py b/tests/test_kernels/cuda/test_vllm_rmsnorm.py new file mode 100644 index 000000000000..cb12faf6276c --- /dev/null +++ b/tests/test_kernels/cuda/test_vllm_rmsnorm.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import os +import pytest +import numpy as np +from packaging import version + +import torch +from torch import nn +from torch.nn import functional as F + +try: + from vllm import layernorm_ops + rms_norm = layernorm_ops.rms_norm + HAS_VLLM_KERNERL = True +except: + print("please install vllm kernels to install rmsnorm") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + HAS_VLLM_KERNERL = False + +class LlamaRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + LlamaRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + +def cuda_rmsnorm_forward(hidden_states, weight, variance_epsilon): + x = hidden_states + out = torch.empty_like(x) + rms_norm( + out, + x, + weight, + variance_epsilon, + ) + return out + +@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test") +def test_rmsnorm(): + data = torch.randn((1024, 64), dtype=torch.float16, device="cuda") + hg_rms = LlamaRMSNorm(64) + hg_rms = hg_rms.half().cuda() + out_torch = hg_rms(data) + out_cuda = cuda_rmsnorm_forward(data, hg_rms.weight.data, hg_rms.variance_epsilon) + + check = torch.allclose(out_torch.cpu(), out_cuda.cpu(), rtol=1e-3, atol=1e-5) + assert check is True, "cuda rmsnorm forward is not matched with torch rmsnorm forward" + +if __name__ == "__main__": + test_rmsnorm() \ No newline at end of file diff --git a/tests/test_kernels/triton/test_self_attention.py b/tests/test_kernels/triton/test_self_attention_nonfusion.py similarity index 100% rename from tests/test_kernels/triton/test_self_attention.py rename to tests/test_kernels/triton/test_self_attention_nonfusion.py From 11e771c32a1b1ab669497a0fb7f3699079d0acd6 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 22 Aug 2023 14:09:41 +0800 Subject: [PATCH 04/22] added tests --- LICENSE | 16 ++ .../cuda/test_vllm_rotary_embedding.py | 156 ++++++++++++++++++ 2 files changed, 172 insertions(+) create mode 100644 tests/test_kernels/cuda/test_vllm_rotary_embedding.py diff --git a/LICENSE b/LICENSE index c7a5bb16880e..6e7742c3078b 100644 --- a/LICENSE +++ b/LICENSE @@ -396,3 +396,19 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + ---------------- LICENSE FOR VLLM TEAM ---------------- + + from VLLM TEAM: + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://github.com/vllm-project/vllm/blob/main/LICENSE + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/tests/test_kernels/cuda/test_vllm_rotary_embedding.py b/tests/test_kernels/cuda/test_vllm_rotary_embedding.py new file mode 100644 index 000000000000..2a85566c65c6 --- /dev/null +++ b/tests/test_kernels/cuda/test_vllm_rotary_embedding.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- +import pytest +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, rotate_half + +try: + from vllm import pos_encoding_ops + rotary_embedding_neox = pos_encoding_ops.rotary_embedding_neox + HAS_VLLM_KERNERL = True +except: + print("fall back to original rotary_embedding_neox of huggingface") + print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + HAS_VLLM_KERNERL = False + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def apply_rotary_pos_emb( + q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +class RefRotaryEmbeddingNeox(nn.Module): + """Reference implementation of the GPT-NeoX style rotary embedding.""" + + def __init__( + self, + dim: int, + max_position_embeddings: int = 2048, + base: int = 10000, + ) -> None: + super().__init__() + self.rotary_dim = dim + self.max_position_embeddings = max_position_embeddings + + # Create cos and sin embeddings. + inv_freq = 1.0 / (base**(torch.arange(0, dim, 2) / dim)) + t = torch.arange(max_position_embeddings).float() + freqs = torch.einsum("i,j->ij", t, inv_freq.float()) + emb = torch.cat((freqs, freqs), dim=-1) + cos = emb.cos().to(dtype=inv_freq.dtype) + sin = emb.sin().to(dtype=inv_freq.dtype) + self.register_buffer("cos_cached", cos, persistent=False) + self.register_buffer("sin_cached", sin, persistent=False) + + def forward( + self, + positions: torch.Tensor, # [num_tokens] + query: torch.Tensor, # [num_tokens, num_heads, head_size] + key: torch.Tensor, # [num_tokens, num_heads, head_size] + ) -> Tuple[torch.Tensor, torch.Tensor]: + + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + + query_rot = query_rot.transpose(0, 1) + key_rot = key_rot.transpose(0, 1) + cos = F.embedding(positions, self.cos_cached) + sin = F.embedding(positions, self.sin_cached) + query_rot, key_rot = apply_rotary_pos_emb(query_rot, key_rot, cos, sin) + query_rot = query_rot.transpose(0, 1).contiguous() + key_rot = key_rot.transpose(0, 1).contiguous() + + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + + # Output query/key shape: [num_tokens, num_tokens, head_size] + return query, key + +def run_rotary_embedding_neox( + num_tokens: int, + num_heads: int, + head_size: int, + max_position: int, + rotary_dim: int, + dtype: torch.dtype, + base: int = 10000, +) -> None: + positions = torch.randint(0, max_position, (num_tokens, ), device='cuda') + query = torch.randn(num_tokens, + num_heads * head_size, + dtype=dtype, + device='cuda') + key = torch.randn(num_tokens, + num_heads * head_size, + dtype=dtype, + device='cuda') + + # Create the rotary embedding. + inv_freq = 1.0 / (base**(torch.arange(0, rotary_dim, 2) / rotary_dim)) + t = torch.arange(max_position).float() + freqs = torch.einsum('i,j -> ij', t, inv_freq.float()) + cos = freqs.cos() + sin = freqs.sin() + cos_sin_cache = torch.cat((cos, sin), dim=-1) + cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda') + + # Run the kernel. The kernel is in-place, so we need to clone the inputs. + out_query = query.clone() + out_key = key.clone() + rotary_embedding_neox( + positions, + out_query, + out_key, + head_size, + cos_sin_cache, + ) + + # Run the reference implementation. + ref_rotary_embedding = RefRotaryEmbeddingNeox( + dim=rotary_dim, + max_position_embeddings=max_position, + base=base, + ).to(dtype=dtype, device='cuda') + ref_query, ref_key = ref_rotary_embedding( + positions, + query.view(num_tokens, num_heads, head_size), + key.view(num_tokens, num_heads, head_size), + ) + ref_query = ref_query.view(num_tokens, num_heads * head_size) + ref_key = ref_key.view(num_tokens, num_heads * head_size) + + # Compare the results. + assert torch.allclose(out_query, ref_query, atol=1e-3, rtol=1e-5) + assert torch.allclose(out_key, ref_key, atol=1e-3, rtol=1e-5) + +@pytest.mark.skipif(not HAS_VLLM_KERNERL, reason="You need to install llama supported cuda kernels to run this test") +def test_rotary_embedding(): + run_rotary_embedding_neox( + num_tokens=1024, + num_heads=8, + head_size=64, + max_position=8192, + rotary_dim=64, + dtype=torch.float16, + ) + +if __name__ == "__main__": + test_rotary_embedding() \ No newline at end of file From 5dd17d266d9d71fa3ef7df7cbcdbb4adf2863599 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 22 Aug 2023 22:22:52 +0800 Subject: [PATCH 05/22] modify --- colossalai/kernel/triton/ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/kernel/triton/ops.py b/colossalai/kernel/triton/ops.py index 5e8d4ba3ec99..fbca2cf3a73d 100644 --- a/colossalai/kernel/triton/ops.py +++ b/colossalai/kernel/triton/ops.py @@ -196,7 +196,7 @@ def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Ten elif block_size >= 2048: BLOCK_M = 8 - softmax_kernel_2[grid](output_ptr = output, + softmax_kernel[grid](output_ptr = output, input_ptr = input, row_stride = input.stride(0), n_rows = num_rows, From cdabf05b54409c7d8fa5eda8781784352a96f037 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Tue, 22 Aug 2023 22:42:16 +0800 Subject: [PATCH 06/22] adding kernels --- LICENSE | 16 +++ .../triton/bloom_context_attention_kernel.py | 101 ++++++++++++++++++ .../triton/llama_context_attention_kernel.py | 9 ++ 3 files changed, 126 insertions(+) create mode 100644 colossalai/kernel/triton/bloom_context_attention_kernel.py create mode 100644 colossalai/kernel/triton/llama_context_attention_kernel.py diff --git a/LICENSE b/LICENSE index 6e7742c3078b..06629068faa5 100644 --- a/LICENSE +++ b/LICENSE @@ -412,3 +412,19 @@ Copyright 2021- HPC-AI Technology Inc. All rights reserved. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. + + ---------------- LICENSE FOR LIGHTLLM TEAM ---------------- + + from LIGHTLLM TEAM: + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + https://github.com/ModelTC/lightllm/blob/main/LICENSE + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/colossalai/kernel/triton/bloom_context_attention_kernel.py b/colossalai/kernel/triton/bloom_context_attention_kernel.py new file mode 100644 index 000000000000..9ffd45d8c06a --- /dev/null +++ b/colossalai/kernel/triton/bloom_context_attention_kernel.py @@ -0,0 +1,101 @@ +import torch +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + + +if HAS_TRITON: + @triton.jit + def _context_fwd_kernel( + Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen, + TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + Out, + stride_qbs, stride_qh, stride_qd, + stride_kbs, stride_kh, stride_kd, + stride_vbs, stride_vh, stride_vd, + stride_obs, stride_oh, stride_od, + stride_tmp_b, stride_tmp_h, stride_tmp_s, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + + q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K + off_k + v_ptrs = V + off_v + t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s + + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + alibi_m = tl.load(Alibi + cur_head) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + + alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) + qk -= alibi_loc * alibi_m + + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return \ No newline at end of file diff --git a/colossalai/kernel/triton/llama_context_attention_kernel.py b/colossalai/kernel/triton/llama_context_attention_kernel.py new file mode 100644 index 000000000000..c8b6f3a4b7bd --- /dev/null +++ b/colossalai/kernel/triton/llama_context_attention_kernel.py @@ -0,0 +1,9 @@ +import torch +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + \ No newline at end of file From cabbe27fd3926d2df013ad22fd0ebf1b85802cee Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 23 Aug 2023 11:39:24 +0800 Subject: [PATCH 07/22] added tests: --- ...n_kernel.py => bloom_context_attention.py} | 4 +- .../kernel/triton/llama_context_attention.py | 96 +++++++++++++++++++ .../triton/llama_context_attention_kernel.py | 9 -- .../{ops.py => self_attention_nofusion.py} | 55 +---------- colossalai/kernel/triton/softmax.py | 96 +++++++++++++++++++ colossalai/kernel/triton/softmax_kernel.py | 44 --------- .../triton/test_self_attention_nonfusion.py | 2 +- tests/test_kernels/triton/test_softmax.py | 2 +- 8 files changed, 198 insertions(+), 110 deletions(-) rename colossalai/kernel/triton/{bloom_context_attention_kernel.py => bloom_context_attention.py} (97%) create mode 100644 colossalai/kernel/triton/llama_context_attention.py delete mode 100644 colossalai/kernel/triton/llama_context_attention_kernel.py rename colossalai/kernel/triton/{ops.py => self_attention_nofusion.py} (74%) create mode 100644 colossalai/kernel/triton/softmax.py delete mode 100644 colossalai/kernel/triton/softmax_kernel.py diff --git a/colossalai/kernel/triton/bloom_context_attention_kernel.py b/colossalai/kernel/triton/bloom_context_attention.py similarity index 97% rename from colossalai/kernel/triton/bloom_context_attention_kernel.py rename to colossalai/kernel/triton/bloom_context_attention.py index 9ffd45d8c06a..b6eb3872ba48 100644 --- a/colossalai/kernel/triton/bloom_context_attention_kernel.py +++ b/colossalai/kernel/triton/bloom_context_attention.py @@ -11,7 +11,7 @@ if HAS_TRITON: @triton.jit def _context_fwd_kernel( - Q, K, V, sm_scale, Alibi, B_Start_Loc, B_Seqlen, + Q, K, V, sm_scale, bias, B_Start_Loc, B_Seqlen, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug Out, stride_qbs, stride_qh, stride_qd, @@ -49,7 +49,7 @@ def _context_fwd_kernel( l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - alibi_m = tl.load(Alibi + cur_head) + alibi_m = tl.load(bias + cur_head) block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) diff --git a/colossalai/kernel/triton/llama_context_attention.py b/colossalai/kernel/triton/llama_context_attention.py new file mode 100644 index 000000000000..7972bb6ffaf2 --- /dev/null +++ b/colossalai/kernel/triton/llama_context_attention.py @@ -0,0 +1,96 @@ +import torch +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + + +if HAS_TRITON: + @triton.jit + def _fwd_kernel( + Q, K, V, sm_scale, B_Start_Loc, B_Seqlen, + TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + Out, + stride_qbs, stride_qh, stride_qd, + stride_kbs, stride_kh, stride_kd, + stride_vbs, stride_vh, stride_vd, + stride_obs, stride_oh, stride_od, + stride_tmp_b, stride_tmp_h, stride_tmp_s, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + ): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd + q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + + k_ptrs = K + off_k + v_ptrs = V + off_v + + t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s + # t_ptrs = TMP + offs_m + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + tl.store(t_ptrs, acc_scale) + acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load + acc = acc * acc_scale[:, None] + # update acc + v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + + return \ No newline at end of file diff --git a/colossalai/kernel/triton/llama_context_attention_kernel.py b/colossalai/kernel/triton/llama_context_attention_kernel.py deleted file mode 100644 index c8b6f3a4b7bd..000000000000 --- a/colossalai/kernel/triton/llama_context_attention_kernel.py +++ /dev/null @@ -1,9 +0,0 @@ -import torch -try: - import triton - import triton.language as tl - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - \ No newline at end of file diff --git a/colossalai/kernel/triton/ops.py b/colossalai/kernel/triton/self_attention_nofusion.py similarity index 74% rename from colossalai/kernel/triton/ops.py rename to colossalai/kernel/triton/self_attention_nofusion.py index fbca2cf3a73d..a6c9bdfbdff6 100644 --- a/colossalai/kernel/triton/ops.py +++ b/colossalai/kernel/triton/self_attention_nofusion.py @@ -11,7 +11,7 @@ if HAS_TRITON: from .qkv_matmul_kernel import qkv_gemm_4d_kernel - from .softmax_kernel import softmax_kernel + from .softmax import softmax_kernel def self_attention_forward_without_fusion(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, input_mask: torch.Tensor, scale: float): r""" A function to do QKV Attention calculation by calling GEMM and softmax triton kernels @@ -155,55 +155,4 @@ def self_attention_compute_using_triton(qkv, data_output_triton = self_attention_forward_without_fusion( q, k, v, input_mask, scale) - return data_output_triton - - - def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor: - if mask is not None: - assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask" - assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention" - - hidden_dim = input.shape[-1] - output = torch.empty_like(input) - input = input.view(-1, hidden_dim) - if mask is not None: - mask = mask.view(-1, hidden_dim) - assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same" - - num_rows, num_cols = input.shape - block_size = max(triton.next_power_of_2(num_cols), 2) - num_warps = 16 - if block_size >= 4096: - num_warps = 16 - elif block_size >= 2048: - num_warps = 8 - else: - num_warps = 4 - - if num_rows <= 350000: - grid = (num_rows,) - softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps) - else: - grid = lambda meta: () - - grid = lambda meta: ( - triton.cdiv(num_rows, meta["BLOCK_M"]), - ) - - BLOCK_M = 32 - if block_size >= 4096: - BLOCK_M = 4 - elif block_size >= 2048: - BLOCK_M = 8 - - softmax_kernel[grid](output_ptr = output, - input_ptr = input, - row_stride = input.stride(0), - n_rows = num_rows, - n_cols = num_cols, - mask_ptr = mask, - # currently manually setting up size - BLOCK_M = 32, - BLOCK_SIZE = block_size) - - return output \ No newline at end of file + return data_output_triton \ No newline at end of file diff --git a/colossalai/kernel/triton/softmax.py b/colossalai/kernel/triton/softmax.py new file mode 100644 index 000000000000..c65adaf40dda --- /dev/null +++ b/colossalai/kernel/triton/softmax.py @@ -0,0 +1,96 @@ +import torch +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + ''' + softmax kernel is modified based on + https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py + ''' + @triton.jit + def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr): + r""" the kernel function for implementing softmax operator + Args: + output_ptr: the output after finishing softmax operation, (N, hidden_dim) + input_ptr: the tensor of input, shape should be (N, hidden_dim) + n_cols(tl.constexpr): the number of cols of input + BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim + """ + row_idx = tl.program_id(0) + row_start_ptr = input_ptr + row_idx * row_stride + col_offsets = tl.arange(0, BLOCK_SIZE) + input_ptrs = row_start_ptr + col_offsets + row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32) + row_minus_max = row - tl.max(row, axis=0) + + if mask_ptr is not None: + # load mask into SRAM + mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets + mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32) + + # update + row_minus_max = row_minus_max + mask + + numerator = tl.exp(row_minus_max) + denominator = tl.sum(numerator, axis=0) + softmax_output = numerator / denominator + output_row_start_ptr = output_ptr + row_idx * row_stride + output_ptrs = output_row_start_ptr + col_offsets + # Write back output to DRAM + tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) + + + def softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor: + if mask is not None: + assert input[-1] == mask[-1], "the last dimentions should be the same for input and mask" + assert dim == -1 or dim == len(input.shape)-1, "currently softmax layer only support last dimention" + + hidden_dim = input.shape[-1] + output = torch.empty_like(input) + input = input.view(-1, hidden_dim) + if mask is not None: + mask = mask.view(-1, hidden_dim) + assert input.shape[0] == mask.shape[0], "the fist dimention of mask and input should be the same" + + num_rows, num_cols = input.shape + block_size = max(triton.next_power_of_2(num_cols), 2) + num_warps = 16 + if block_size >= 4096: + num_warps = 16 + elif block_size >= 2048: + num_warps = 8 + else: + num_warps = 4 + + if num_rows <= 350000: + grid = (num_rows,) + softmax_kernel[grid](output, input, input.stride(0), num_cols, mask, BLOCK_SIZE = block_size, num_warps=num_warps) + else: + grid = lambda meta: () + + grid = lambda meta: ( + triton.cdiv(num_rows, meta["BLOCK_M"]), + ) + + BLOCK_M = 32 + if block_size >= 4096: + BLOCK_M = 4 + elif block_size >= 2048: + BLOCK_M = 8 + + softmax_kernel[grid](output_ptr = output, + input_ptr = input, + row_stride = input.stride(0), + n_rows = num_rows, + n_cols = num_cols, + mask_ptr = mask, + # currently manually setting up size + BLOCK_M = 32, + BLOCK_SIZE = block_size) + + return output \ No newline at end of file diff --git a/colossalai/kernel/triton/softmax_kernel.py b/colossalai/kernel/triton/softmax_kernel.py deleted file mode 100644 index c215890badff..000000000000 --- a/colossalai/kernel/triton/softmax_kernel.py +++ /dev/null @@ -1,44 +0,0 @@ -try: - import triton - import triton.language as tl - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - -if HAS_TRITON: - ''' - softmax kernel is modified based on - https://github.com/openai/triton/blob/34817ecc954a6f4ca7b4dfb352fdde1f8bd49ca5/python/tutorials/02-fused-softmax.py - ''' - @triton.jit - def softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr): - r""" the kernel function for implementing softmax operator - Args: - output_ptr: the output after finishing softmax operation, (N, hidden_dim) - input_ptr: the tensor of input, shape should be (N, hidden_dim) - n_cols(tl.constexpr): the number of cols of input - BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim - """ - row_idx = tl.program_id(0) - row_start_ptr = input_ptr + row_idx * row_stride - col_offsets = tl.arange(0, BLOCK_SIZE) - input_ptrs = row_start_ptr + col_offsets - row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf')).to(tl.float32) - row_minus_max = row - tl.max(row, axis=0) - - if mask_ptr is not None: - # load mask into SRAM - mask_ptrs = (mask_ptr + (row_indx * row_stride)) + col_offsets - mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32) - - # update - row_minus_max = row_minus_max + mask - - numerator = tl.exp(row_minus_max) - denominator = tl.sum(numerator, axis=0) - softmax_output = numerator / denominator - output_row_start_ptr = output_ptr + row_idx * row_stride - output_ptrs = output_row_start_ptr + col_offsets - # Write back output to DRAM - tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols) \ No newline at end of file diff --git a/tests/test_kernels/triton/test_self_attention_nonfusion.py b/tests/test_kernels/triton/test_self_attention_nonfusion.py index b316404a58db..bfe67afdd638 100644 --- a/tests/test_kernels/triton/test_self_attention_nonfusion.py +++ b/tests/test_kernels/triton/test_self_attention_nonfusion.py @@ -4,7 +4,7 @@ from torch import nn import torch.nn.functional as F -from colossalai.kernel.triton.ops import self_attention_compute_using_triton +from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel try: diff --git a/tests/test_kernels/triton/test_softmax.py b/tests/test_kernels/triton/test_softmax.py index 843d811d019c..e7309b08cd84 100644 --- a/tests/test_kernels/triton/test_softmax.py +++ b/tests/test_kernels/triton/test_softmax.py @@ -3,7 +3,7 @@ import torch from torch import nn -from colossalai.kernel.triton.ops import softmax +from colossalai.kernel.triton.softmax import softmax TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') From b400a0c3fd2b4187b53a7a54c8788c8f3135a263 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 23 Aug 2023 15:04:27 +0800 Subject: [PATCH 08/22] adding kernels --- .../kernel/triton/copy_kv_cache_dest.py | 69 +++++++++++++++++++ .../test_kernels/triton/test_copy_kv_dest.py | 34 +++++++++ tests/test_kernels/triton/test_softmax.py | 1 + tests/test_kernels/triton/utils.py | 25 +++++++ 4 files changed, 129 insertions(+) create mode 100644 colossalai/kernel/triton/copy_kv_cache_dest.py create mode 100644 tests/test_kernels/triton/test_copy_kv_dest.py create mode 100644 tests/test_kernels/triton/utils.py diff --git a/colossalai/kernel/triton/copy_kv_cache_dest.py b/colossalai/kernel/triton/copy_kv_cache_dest.py new file mode 100644 index 000000000000..c1eaa8a10ed1 --- /dev/null +++ b/colossalai/kernel/triton/copy_kv_cache_dest.py @@ -0,0 +1,69 @@ +import torch + +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +if HAS_TRITON: + @triton.jit + def _fwd_copy_kv_cache_dest( + kv_cache_ptr, dest_index_ptr, + out, + stride_k_bs, + stride_k_h, + stride_k_d, + stride_o_bs, + stride_o_h, + stride_o_d, + head_num, + BLOCK_DMODEL: tl.constexpr, + BLOCK_HEAD: tl.constexpr + ): + cur_index = tl.program_id(0) + offs_h = tl.arange(0, BLOCK_HEAD) + offs_d = tl.arange(0, BLOCK_DMODEL) + + dest_index = tl.load(dest_index_ptr + cur_index) + + cache_offsets = stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :] + k_ptrs = kv_cache_ptr + cur_index * stride_k_bs + cache_offsets + + o_offsets = stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] + o_ptrs = out + dest_index * stride_o_bs + o_offsets + + k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0) + tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num) + return + + + @torch.no_grad() + def copy_kv_cache_to_dest(k_ptr, dest_index_ptr, out): + seq_len = dest_index_ptr.shape[0] + head_num = k_ptr.shape[1] + head_dim = k_ptr.shape[2] + assert head_num == out.shape[1], "head_num should be the same for k_ptr and out" + assert head_dim == out.shape[2], "head_dim should be the same for k_ptr and out" + + num_warps = 2 + + _fwd_copy_kv_cache_dest[(seq_len,)]( + k_ptr, dest_index_ptr, out, + k_ptr.stride(0), + k_ptr.stride(1), + k_ptr.stride(2), + out.stride(0), + out.stride(1), + out.stride(2), + head_num, + BLOCK_DMODEL=head_dim, + BLOCK_HEAD=triton.next_power_of_2(head_num), + num_warps=num_warps, + num_stages=2, + ) + return + + diff --git a/tests/test_kernels/triton/test_copy_kv_dest.py b/tests/test_kernels/triton/test_copy_kv_dest.py new file mode 100644 index 000000000000..72771c66ca14 --- /dev/null +++ b/tests/test_kernels/triton/test_copy_kv_dest.py @@ -0,0 +1,34 @@ +import pytest +from packaging import version + +import torch +from torch import nn + +from tests.test_kernels.triton.utils import benchmark +from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4") +def test_kv_cache_copy_op(): + + B_NTX = 32 * 2048 + head_num = 8 + head_dim = 64 + + cache = torch.randn((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16) + dest_index = torch.arange(0, B_NTX, device="cuda", dtype=torch.int32) + + dest_data = torch.ones((B_NTX, head_num, head_dim), device="cuda", dtype=torch.float16) + + copy_kv_cache_to_dest(cache, dest_index, dest_data) + + assert torch.allclose(cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3), "copy_kv_cache_to_dest outputs from triton and torch are not matched" + + latency = benchmark(copy_kv_cache_to_dest, cache, dest_index, dest_data) + print("the average latency is {} ms".format(str(latency))) + + +if __name__ == "__main__": + test_kv_cache_copy_op() + diff --git a/tests/test_kernels/triton/test_softmax.py b/tests/test_kernels/triton/test_softmax.py index e7309b08cd84..d5eb9b53abbf 100644 --- a/tests/test_kernels/triton/test_softmax.py +++ b/tests/test_kernels/triton/test_softmax.py @@ -1,5 +1,6 @@ import pytest from packaging import version + import torch from torch import nn diff --git a/tests/test_kernels/triton/utils.py b/tests/test_kernels/triton/utils.py new file mode 100644 index 000000000000..c32c5ad181f0 --- /dev/null +++ b/tests/test_kernels/triton/utils.py @@ -0,0 +1,25 @@ +import torch +import numpy as np + + +def benchmark(func, *args): + starter, ender = torch.cuda.Event( + enable_timing=True), torch.cuda.Event(enable_timing=True) + repetitions = 300 + + for i in range(10): + func(*args) + + timings = np.zeros((repetitions, 1)) + with torch.no_grad(): + for rep in range(repetitions): + starter.record() + func(*args) + ender.record() + # WAIT FOR GPU SYNC + torch.cuda.synchronize() + curr_time = starter.elapsed_time(ender) + timings[rep] = curr_time + + mean_syn = np.sum(timings) / repetitions + return mean_syn \ No newline at end of file From 0799569576ce9e1b23ef42c19b5cdacb385df86a Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 23 Aug 2023 15:07:24 +0800 Subject: [PATCH 09/22] modify --- .../kernel/triton/bloom_context_attention.py | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/colossalai/kernel/triton/bloom_context_attention.py b/colossalai/kernel/triton/bloom_context_attention.py index b6eb3872ba48..a4f188866c9c 100644 --- a/colossalai/kernel/triton/bloom_context_attention.py +++ b/colossalai/kernel/triton/bloom_context_attention.py @@ -14,12 +14,23 @@ def _context_fwd_kernel( Q, K, V, sm_scale, bias, B_Start_Loc, B_Seqlen, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug Out, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, - stride_tmp_b, stride_tmp_h, stride_tmp_s, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_tmp_b, + stride_tmp_h, + stride_tmp_s, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) From 5d662fc0758ca5f566bff43046017f836608b50d Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 23 Aug 2023 15:37:25 +0800 Subject: [PATCH 10/22] added --- tests/test_kernels/triton/test_copy_kv_dest.py | 10 +++++++++- .../triton/test_self_attention_nonfusion.py | 4 ++-- tests/test_kernels/triton/test_softmax.py | 12 +++++++++--- 3 files changed, 20 insertions(+), 6 deletions(-) diff --git a/tests/test_kernels/triton/test_copy_kv_dest.py b/tests/test_kernels/triton/test_copy_kv_dest.py index 72771c66ca14..35f30d4f1243 100644 --- a/tests/test_kernels/triton/test_copy_kv_dest.py +++ b/tests/test_kernels/triton/test_copy_kv_dest.py @@ -7,9 +7,17 @@ from tests.test_kernels.triton.utils import benchmark from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") def test_kv_cache_copy_op(): B_NTX = 32 * 2048 diff --git a/tests/test_kernels/triton/test_self_attention_nonfusion.py b/tests/test_kernels/triton/test_self_attention_nonfusion.py index bfe67afdd638..714b6bd975f8 100644 --- a/tests/test_kernels/triton/test_self_attention_nonfusion.py +++ b/tests/test_kernels/triton/test_self_attention_nonfusion.py @@ -17,7 +17,7 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") def test_qkv_matmul(): qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) scale = 1.2 @@ -106,7 +106,7 @@ def self_attention_compute_using_torch(qkv, return res.view(batches, -1, d_model), score_output, softmax_output -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") def test_self_atttention_test(): qkv = torch.randn((4, 24, 64*3), device="cuda", dtype=torch.float16) diff --git a/tests/test_kernels/triton/test_softmax.py b/tests/test_kernels/triton/test_softmax.py index d5eb9b53abbf..970e54c8eb79 100644 --- a/tests/test_kernels/triton/test_softmax.py +++ b/tests/test_kernels/triton/test_softmax.py @@ -1,14 +1,20 @@ import pytest from packaging import version - import torch from torch import nn - from colossalai.kernel.triton.softmax import softmax +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') -@pytest.mark.skipif(not TRITON_CUDA_SUPPORT, reason="triton requires cuda version to be higher than 11.4") +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") def test_softmax_op(): data_samples = [ torch.randn((3, 4, 5, 32), device = "cuda", dtype = torch.float32), From 8498ca146db14b83c9d2bcc3559b89d496581be6 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 23 Aug 2023 17:17:58 +0800 Subject: [PATCH 11/22] updating kernels --- .../kernel/triton/bloom_context_attention.py | 67 ++++++++++----- .../triton/test_bloom_context_attention.py | 81 +++++++++++++++++++ 2 files changed, 128 insertions(+), 20 deletions(-) create mode 100644 tests/test_kernels/triton/test_bloom_context_attention.py diff --git a/colossalai/kernel/triton/bloom_context_attention.py b/colossalai/kernel/triton/bloom_context_attention.py index a4f188866c9c..9c6bdaaa6600 100644 --- a/colossalai/kernel/triton/bloom_context_attention.py +++ b/colossalai/kernel/triton/bloom_context_attention.py @@ -1,4 +1,5 @@ import torch +import math try: import triton import triton.language as tl @@ -10,27 +11,18 @@ if HAS_TRITON: @triton.jit - def _context_fwd_kernel( - Q, K, V, sm_scale, bias, B_Start_Loc, B_Seqlen, + def _bloom_context_flash_attention_kernel( + Q, K, V, sm_scale, + Alibi, + B_Start_Loc, B_Seqlen, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug Out, - stride_qbs, - stride_qh, - stride_qd, - stride_kbs, - stride_kh, - stride_kd, - stride_vbs, - stride_vh, - stride_vd, - stride_obs, - stride_oh, - stride_od, - stride_tmp_b, - stride_tmp_h, - stride_tmp_s, - BLOCK_M: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, + stride_qbs, stride_qh, stride_qd, + stride_kbs, stride_kh, stride_kd, + stride_vbs, stride_vh, stride_vd, + stride_obs, stride_oh, stride_od, + stride_tmp_b, stride_tmp_h, stride_tmp_s, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) @@ -60,7 +52,7 @@ def _context_fwd_kernel( l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - alibi_m = tl.load(bias + cur_head) + alibi_m = tl.load(Alibi + cur_head) block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) @@ -109,4 +101,39 @@ def _context_fwd_kernel( off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + + + @torch.no_grad() + def bloom_context_flash_attention_fwd(q, k, v, o, alibi, b_start_loc, b_seq_len, max_input_len): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv, "context process only supports equal query, key, value length" + assert Lk in {16, 32, 64, 128} + + sm_scale = 1.0 / math.sqrt(Lk) + batch, head = b_seq_len.shape[0], q.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + + num_warps = 4 if Lk <= 64 else 8 + + tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) + + _bloom_context_flash_attention_kernel[grid]( + q, k, v, sm_scale, alibi, b_start_loc, b_seq_len, + tmp, + o, + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), + o.stride(0), o.stride(1), o.stride(2), + tmp.stride(0), tmp.stride(1), tmp.stride(2), + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) return \ No newline at end of file diff --git a/tests/test_kernels/triton/test_bloom_context_attention.py b/tests/test_kernels/triton/test_bloom_context_attention.py new file mode 100644 index 000000000000..94e283fcdab3 --- /dev/null +++ b/tests/test_kernels/triton/test_bloom_context_attention.py @@ -0,0 +1,81 @@ +import pytest +import math +from packaging import version + +import torch +from torch import nn +from torch.nn import functional as F + + +from tests.test_kernels.triton.utils import benchmark +from colossalai.kernel.triton.bloom_context_attention import bloom_context_flash_attention_fwd as bloom_context_attn_fwd + +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + +def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): + ''' + adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 + ''' + xq = xq.view(bs, seqlen, num_head, head_dim) + xk = xk.view(bs, seqlen, num_head, head_dim) + xv = xv.view(bs, seqlen, num_head, head_dim) + mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() + mask[mask == 0.] = -100000000.0 + mask = mask.repeat(bs, num_head, 1, 1) + keys = xk + values = xv + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + sm_scale = 1/math.sqrt(head_dim) + scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale + scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16) + + output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) + return output + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") +def test_bloom_context_attention(): + bs = 4 + head_num = 8 + seq_len = 1024 + head_dim = 64 + + query = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + k = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + v = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + + + max_input_len = seq_len + b_start = torch.zeros((bs, ), device="cuda", dtype=torch.int32) + b_len = torch.zeros((bs, ), device="cuda", dtype=torch.int32) + + for i in range(bs): + b_start[i] = i * seq_len + b_len[i] = seq_len + + o = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda") + bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, alibi, b_start, b_len,max_input_len) + + torch_out = torch_att(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) + + assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched" + + latency_1 = benchmark(bloom_context_attn_fwd, query, k, v, o, alibi, b_start, b_len, max_input_len) + latency_2 = benchmark(torch_att, query, k, v, bs, seq_len, head_num, head_dim) + + print("the triton op latency is {} ms".format(str(latency_1))) + print("the torch op latency is {} ms".format(str(latency_2))) + + +if __name__ == "__main__": + test_bloom_context_attention() \ No newline at end of file From 37588d72a9a3f8348bd1eb8e283e1d0db79f474a Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 23 Aug 2023 17:47:46 +0800 Subject: [PATCH 12/22] adding tests --- .../kernel/triton/llama_context_attention.py | 35 ++++++++++- .../triton/test_bloom_context_attention.py | 28 +-------- .../triton/test_llama_context_attention.py | 59 +++++++++++++++++++ tests/test_kernels/triton/utils.py | 29 ++++++++- 4 files changed, 123 insertions(+), 28 deletions(-) create mode 100644 tests/test_kernels/triton/test_llama_context_attention.py diff --git a/colossalai/kernel/triton/llama_context_attention.py b/colossalai/kernel/triton/llama_context_attention.py index 7972bb6ffaf2..ad464bed0c40 100644 --- a/colossalai/kernel/triton/llama_context_attention.py +++ b/colossalai/kernel/triton/llama_context_attention.py @@ -10,7 +10,7 @@ if HAS_TRITON: @triton.jit - def _fwd_kernel( + def _llama_context_forward_kernel( Q, K, V, sm_scale, B_Start_Loc, B_Seqlen, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug Out, @@ -93,4 +93,37 @@ def _fwd_kernel( out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + return + + @torch.no_grad() + def llama_context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + + sm_scale = 1.0 / (Lq**0.5) + batch, head = b_seq_len.shape[0], q.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + + tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + # num_warps = 4 + _llama_context_forward_kernel[grid]( + q, k, v, sm_scale, b_start_loc, b_seq_len, + tmp, + o, + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), + o.stride(0), o.stride(1), o.stride(2), + tmp.stride(0), tmp.stride(1), tmp.stride(2), + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) return \ No newline at end of file diff --git a/tests/test_kernels/triton/test_bloom_context_attention.py b/tests/test_kernels/triton/test_bloom_context_attention.py index 94e283fcdab3..2fd8f9346263 100644 --- a/tests/test_kernels/triton/test_bloom_context_attention.py +++ b/tests/test_kernels/triton/test_bloom_context_attention.py @@ -7,7 +7,7 @@ from torch.nn import functional as F -from tests.test_kernels.triton.utils import benchmark +from tests.test_kernels.triton.utils import benchmark, torch_context_attention from colossalai.kernel.triton.bloom_context_attention import bloom_context_flash_attention_fwd as bloom_context_attn_fwd try: @@ -20,28 +20,6 @@ TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') -def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim): - ''' - adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 - ''' - xq = xq.view(bs, seqlen, num_head, head_dim) - xk = xk.view(bs, seqlen, num_head, head_dim) - xv = xv.view(bs, seqlen, num_head, head_dim) - mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() - mask[mask == 0.] = -100000000.0 - mask = mask.repeat(bs, num_head, 1, 1) - keys = xk - values = xv - xq = xq.transpose(1, 2) - keys = keys.transpose(1, 2) - values = values.transpose(1, 2) - sm_scale = 1/math.sqrt(head_dim) - scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale - scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16) - - output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) - return output - @pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") def test_bloom_context_attention(): bs = 4 @@ -66,12 +44,12 @@ def test_bloom_context_attention(): alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda") bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, alibi, b_start, b_len,max_input_len) - torch_out = torch_att(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) + torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched" latency_1 = benchmark(bloom_context_attn_fwd, query, k, v, o, alibi, b_start, b_len, max_input_len) - latency_2 = benchmark(torch_att, query, k, v, bs, seq_len, head_num, head_dim) + latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim) print("the triton op latency is {} ms".format(str(latency_1))) print("the torch op latency is {} ms".format(str(latency_2))) diff --git a/tests/test_kernels/triton/test_llama_context_attention.py b/tests/test_kernels/triton/test_llama_context_attention.py new file mode 100644 index 000000000000..9a3e89deff35 --- /dev/null +++ b/tests/test_kernels/triton/test_llama_context_attention.py @@ -0,0 +1,59 @@ +import pytest +import math +from packaging import version + +import torch +from torch import nn +from torch.nn import functional as F + + +from tests.test_kernels.triton.utils import benchmark, torch_context_attention +from colossalai.kernel.triton.llama_context_attention import llama_context_attention_fwd + +try: + import triton + import triton.language as tl + HAS_TRITON = True +except ImportError: + HAS_TRITON = False + print("please install triton from https://github.com/openai/triton") + +TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse('11.4') + + +@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4") +def test_llama_context_attention(): + bs = 4 + head_num = 8 + seq_len = 1024 + head_dim = 64 + + query = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + k = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + v = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + + + max_input_len = seq_len + b_start = torch.zeros((bs, ), device="cuda", dtype=torch.int32) + b_len = torch.zeros((bs, ), device="cuda", dtype=torch.int32) + + for i in range(bs): + b_start[i] = i * seq_len + b_len[i] = seq_len + + o = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") + llama_context_attention_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len,max_input_len) + + torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) + + assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched" + + latency_1 = benchmark(llama_context_attention_fwd, query, k, v, o, b_start, b_len, max_input_len) + latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim) + + print("the triton op latency is {} ms".format(str(latency_1))) + print("the torch op latency is {} ms".format(str(latency_2))) + + +if __name__ == "__main__": + test_llama_context_attention() \ No newline at end of file diff --git a/tests/test_kernels/triton/utils.py b/tests/test_kernels/triton/utils.py index c32c5ad181f0..940d277cfb02 100644 --- a/tests/test_kernels/triton/utils.py +++ b/tests/test_kernels/triton/utils.py @@ -1,5 +1,8 @@ -import torch import numpy as np +import math + +import torch +from torch.nn import functional as F def benchmark(func, *args): @@ -22,4 +25,26 @@ def benchmark(func, *args): timings[rep] = curr_time mean_syn = np.sum(timings) / repetitions - return mean_syn \ No newline at end of file + return mean_syn + +def torch_context_attention(xq, xk, xv, bs, seqlen, num_head, head_dim): + ''' + adepted from https://github.com/ModelTC/lightllm/blob/main/lightllm/models/bloom/triton_kernel/context_flashattention_nopad.py#L253 + ''' + xq = xq.view(bs, seqlen, num_head, head_dim) + xk = xk.view(bs, seqlen, num_head, head_dim) + xv = xv.view(bs, seqlen, num_head, head_dim) + mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() + mask[mask == 0.] = -100000000.0 + mask = mask.repeat(bs, num_head, 1, 1) + keys = xk + values = xv + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + sm_scale = 1/math.sqrt(head_dim) + scores = torch.matmul(xq, keys.transpose(2, 3)) * sm_scale + scores = F.softmax(scores.float() + mask, dim=-1).to(dtype=torch.float16) + + output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) + return output \ No newline at end of file From eab096f7405e73b802dce94fe6a64826f6c73dd9 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 23 Aug 2023 18:19:59 +0800 Subject: [PATCH 13/22] added tests --- ...text_attention.py => context_attention.py} | 59 ++++++-- .../kernel/triton/llama_context_attention.py | 129 ------------------ .../triton/test_bloom_context_attention.py | 6 +- .../triton/test_llama_context_attention.py | 7 +- 4 files changed, 56 insertions(+), 145 deletions(-) rename colossalai/kernel/triton/{bloom_context_attention.py => context_attention.py} (72%) delete mode 100644 colossalai/kernel/triton/llama_context_attention.py diff --git a/colossalai/kernel/triton/bloom_context_attention.py b/colossalai/kernel/triton/context_attention.py similarity index 72% rename from colossalai/kernel/triton/bloom_context_attention.py rename to colossalai/kernel/triton/context_attention.py index 9c6bdaaa6600..9ff89b3530b9 100644 --- a/colossalai/kernel/triton/bloom_context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -11,11 +11,11 @@ if HAS_TRITON: @triton.jit - def _bloom_context_flash_attention_kernel( + def _context_flash_attention_kernel( Q, K, V, sm_scale, - Alibi, B_Start_Loc, B_Seqlen, TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + alibi_ptr, Out, stride_qbs, stride_qh, stride_qd, stride_kbs, stride_kh, stride_kd, @@ -51,11 +51,13 @@ def _bloom_context_flash_attention_kernel( m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - alibi_m = tl.load(Alibi + cur_head) + + if alibi_ptr is not None: + alibi_m = tl.load(alibi_ptr + cur_head) block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + # modify from for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) # -- compute qk ---- @@ -66,8 +68,10 @@ def _bloom_context_flash_attention_kernel( qk += tl.dot(q, k) qk *= sm_scale - alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) - qk -= alibi_loc * alibi_m + if alibi_ptr is not None: + + alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) + qk -= alibi_loc * alibi_m qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) @@ -105,7 +109,7 @@ def _bloom_context_flash_attention_kernel( @torch.no_grad() - def bloom_context_flash_attention_fwd(q, k, v, o, alibi, b_start_loc, b_seq_len, max_input_len): + def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, alibi=None): BLOCK = 128 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] @@ -121,9 +125,46 @@ def bloom_context_flash_attention_fwd(q, k, v, o, alibi, b_start_loc, b_seq_len, tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) - _bloom_context_flash_attention_kernel[grid]( - q, k, v, sm_scale, alibi, b_start_loc, b_seq_len, + _context_flash_attention_kernel[grid]( + q, k, v, sm_scale, + b_start_loc, b_seq_len, + tmp, + alibi, + o, + q.stride(0), q.stride(1), q.stride(2), + k.stride(0), k.stride(1), k.stride(2), + v.stride(0), v.stride(1), v.stride(2), + o.stride(0), o.stride(1), o.stride(2), + tmp.stride(0), tmp.stride(1), tmp.stride(2), + # manually setting this blcok num, we can use tuning config to futher speed-up + BLOCK_M=BLOCK, + BLOCK_DMODEL=Lk, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + @torch.no_grad() + def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): + BLOCK = 128 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + + sm_scale = 1.0 / math.sqrt(Lk) + batch, head = b_seq_len.shape[0], q.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) + + tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + # num_warps = 4 + _context_flash_attention_kernel[grid]( + q, k, v, sm_scale, b_start_loc, b_seq_len, tmp, + None, o, q.stride(0), q.stride(1), q.stride(2), k.stride(0), k.stride(1), k.stride(2), diff --git a/colossalai/kernel/triton/llama_context_attention.py b/colossalai/kernel/triton/llama_context_attention.py deleted file mode 100644 index ad464bed0c40..000000000000 --- a/colossalai/kernel/triton/llama_context_attention.py +++ /dev/null @@ -1,129 +0,0 @@ -import torch -try: - import triton - import triton.language as tl - HAS_TRITON = True -except ImportError: - HAS_TRITON = False - print("please install triton from https://github.com/openai/triton") - - -if HAS_TRITON: - @triton.jit - def _llama_context_forward_kernel( - Q, K, V, sm_scale, B_Start_Loc, B_Seqlen, - TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug - Out, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, - stride_tmp_b, stride_tmp_h, stride_tmp_s, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - ): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - start_m = tl.program_id(2) - - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - - block_start_loc = BLOCK_M * start_m - - # initialize offsets - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd - off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - - k_ptrs = K + off_k - v_ptrs = V + off_v - - t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s - # t_ptrs = TMP + offs_m - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - - block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - - for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) - - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - qk *= sm_scale - qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) - - m_ij = tl.max(qk, 1) - p = tl.exp(qk - m_ij[:, None]) - l_ij = tl.sum(p, 1) - # -- update m_i and l_i - m_i_new = tl.maximum(m_i, m_ij) - alpha = tl.exp(m_i - m_i_new) - beta = tl.exp(m_ij - m_i_new) - l_i_new = alpha * l_i + beta * l_ij - # -- update output accumulator -- - # scale p - p_scale = beta / l_i_new - p = p * p_scale[:, None] - # scale acc - acc_scale = l_i / l_i_new * alpha - tl.store(t_ptrs, acc_scale) - acc_scale = tl.load(t_ptrs) # BUG: have to store and immediately load - acc = acc * acc_scale[:, None] - # update acc - v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) - - p = p.to(v.dtype) - acc += tl.dot(p, v) - # update m_i and l_i - l_i = l_i_new - m_i = m_i_new - # initialize pointers to output - off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) - - return - - @torch.no_grad() - def llama_context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): - BLOCK = 128 - # shape constraints - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - - sm_scale = 1.0 / (Lq**0.5) - batch, head = b_seq_len.shape[0], q.shape[1] - - grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) - - tmp = torch.empty((batch, head, max_input_len + 256), device=q.device, dtype=torch.float32) - num_warps = 4 if Lk <= 64 else 8 - # num_warps = 4 - _llama_context_forward_kernel[grid]( - q, k, v, sm_scale, b_start_loc, b_seq_len, - tmp, - o, - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - v.stride(0), v.stride(1), v.stride(2), - o.stride(0), o.stride(1), o.stride(2), - tmp.stride(0), tmp.stride(1), tmp.stride(2), - BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, - BLOCK_N=BLOCK, - num_warps=num_warps, - num_stages=1, - ) - return \ No newline at end of file diff --git a/tests/test_kernels/triton/test_bloom_context_attention.py b/tests/test_kernels/triton/test_bloom_context_attention.py index 2fd8f9346263..30d823a6f06c 100644 --- a/tests/test_kernels/triton/test_bloom_context_attention.py +++ b/tests/test_kernels/triton/test_bloom_context_attention.py @@ -8,7 +8,7 @@ from tests.test_kernels.triton.utils import benchmark, torch_context_attention -from colossalai.kernel.triton.bloom_context_attention import bloom_context_flash_attention_fwd as bloom_context_attn_fwd +from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd try: import triton @@ -42,13 +42,13 @@ def test_bloom_context_attention(): o = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") alibi = torch.zeros((head_num,), dtype=torch.float32, device="cuda") - bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, alibi, b_start, b_len,max_input_len) + bloom_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len, alibi) torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched" - latency_1 = benchmark(bloom_context_attn_fwd, query, k, v, o, alibi, b_start, b_len, max_input_len) + latency_1 = benchmark(bloom_context_attn_fwd, query, k, v, o, b_start, b_len, max_input_len, alibi) latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim) print("the triton op latency is {} ms".format(str(latency_1))) diff --git a/tests/test_kernels/triton/test_llama_context_attention.py b/tests/test_kernels/triton/test_llama_context_attention.py index 9a3e89deff35..203da895055c 100644 --- a/tests/test_kernels/triton/test_llama_context_attention.py +++ b/tests/test_kernels/triton/test_llama_context_attention.py @@ -8,8 +8,7 @@ from tests.test_kernels.triton.utils import benchmark, torch_context_attention -from colossalai.kernel.triton.llama_context_attention import llama_context_attention_fwd - +from colossalai.kernel.triton.context_attention import llama_context_attn_fwd try: import triton import triton.language as tl @@ -42,13 +41,13 @@ def test_llama_context_attention(): b_len[i] = seq_len o = torch.randn((bs*seq_len, head_num, head_dim), dtype=torch.float16, device="cuda") - llama_context_attention_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len,max_input_len) + llama_context_attn_fwd(query.clone(), k.clone(), v.clone(), o, b_start, b_len, max_input_len) torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim) assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched" - latency_1 = benchmark(llama_context_attention_fwd, query, k, v, o, b_start, b_len, max_input_len) + latency_1 = benchmark(llama_context_attn_fwd, query, k, v, o, b_start, b_len, max_input_len) latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim) print("the triton op latency is {} ms".format(str(latency_1))) From efa1c40c23cee4b70fba96f7843a7c7814a33ef6 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 23 Aug 2023 19:00:50 +0800 Subject: [PATCH 14/22] kernel change --- colossalai/kernel/triton/context_attention.py | 24 +++++++++++-------- 1 file changed, 14 insertions(+), 10 deletions(-) diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 9ff89b3530b9..c89650f76c01 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -10,6 +10,10 @@ if HAS_TRITON: + ''' + this function is modified from + https://github.com/ModelTC/lightllm/blob/f093edc20683ac3ea1bca3fb5d8320a0dd36cf7b/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L10 + ''' @triton.jit def _context_flash_attention_kernel( Q, K, V, sm_scale, @@ -25,6 +29,7 @@ def _context_flash_attention_kernel( BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): + cur_batch = tl.program_id(0) cur_head = tl.program_id(1) start_m = tl.program_id(2) @@ -38,14 +43,11 @@ def _context_flash_attention_kernel( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd - off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd - off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd - - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + load_p_ptrs = Q + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) - k_ptrs = K + off_k - v_ptrs = V + off_v + k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd + v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") @@ -69,7 +71,6 @@ def _context_flash_attention_kernel( qk *= sm_scale if alibi_ptr is not None: - alibi_loc = offs_m[:, None] - (start_n + offs_n[None, :]) qk -= alibi_loc * alibi_m @@ -101,6 +102,7 @@ def _context_flash_attention_kernel( # update m_i and l_i l_i = l_i_new m_i = m_i_new + # initialize pointers to output off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od out_ptrs = Out + off_o @@ -113,7 +115,8 @@ def bloom_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, al BLOCK = 128 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv, "context process only supports equal query, key, value length" + assert Lq == Lk, "context process only supports equal query, key, value length" + assert Lk == Lv, "context process only supports equal query, key, value length" assert Lk in {16, 32, 64, 128} sm_scale = 1.0 / math.sqrt(Lk) @@ -150,7 +153,8 @@ def llama_context_attn_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): BLOCK = 128 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv + assert Lq == Lk, "context process only supports equal query, key, value length" + assert Lk == Lv, "context process only supports equal query, key, value length" assert Lk in {16, 32, 64, 128} sm_scale = 1.0 / math.sqrt(Lk) From 62837d8ea4ba557ca88fc26cf93963540c64055f Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 23 Aug 2023 19:14:21 +0800 Subject: [PATCH 15/22] submit --- colossalai/kernel/triton/context_attention.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index c89650f76c01..9d9e17454677 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -18,7 +18,7 @@ def _context_flash_attention_kernel( Q, K, V, sm_scale, B_Start_Loc, B_Seqlen, - TMP, # NOTE: TMP is a scratchpad buffer to workaround a compiler bug + TMP, alibi_ptr, Out, stride_qbs, stride_qh, stride_qd, @@ -26,16 +26,18 @@ def _context_flash_attention_kernel( stride_vbs, stride_vh, stride_vd, stride_obs, stride_oh, stride_od, stride_tmp_b, stride_tmp_h, stride_tmp_s, - BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + # suggtest set-up 128, 256, 512, 1024 + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): - cur_batch = tl.program_id(0) + batch_id = tl.program_id(0) cur_head = tl.program_id(1) start_m = tl.program_id(2) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + batch_id) + cur_batch_start_index = tl.load(B_Start_Loc + batch_id) block_start_loc = BLOCK_M * start_m @@ -43,12 +45,12 @@ def _context_flash_attention_kernel( offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - load_p_ptrs = Q + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd + load_p_ptrs = Q + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) k_ptrs = K + offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd v_ptrs = V + offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd - t_ptrs = TMP + cur_batch * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s + t_ptrs = TMP + batch_id * stride_tmp_b + cur_head * stride_tmp_h + offs_m * stride_tmp_s m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") l_i = tl.zeros([BLOCK_M], dtype=tl.float32) @@ -62,8 +64,7 @@ def _context_flash_attention_kernel( # modify from for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + k = tl.load(k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) @@ -94,7 +95,7 @@ def _context_flash_attention_kernel( acc_scale = tl.load(t_ptrs) acc = acc * acc_scale[:, None] # update acc - v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + v = tl.load(v_ptrs + (cur_batch_start_index + start_n) * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0) p = p.to(v.dtype) @@ -103,8 +104,7 @@ def _context_flash_attention_kernel( l_i = l_i_new m_i = m_i_new - # initialize pointers to output - off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od + off_o = (cur_batch_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od out_ptrs = Out + off_o tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) return From 4685e6dd60ef58d8cdb1db140764786bfb33aed8 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 23 Aug 2023 19:19:11 +0800 Subject: [PATCH 16/22] modify --- colossalai/kernel/triton/context_attention.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index 9d9e17454677..be8a146aa443 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -36,15 +36,16 @@ def _context_flash_attention_kernel( cur_head = tl.program_id(1) start_m = tl.program_id(2) - cur_batch_seq_len = tl.load(B_Seqlen + batch_id) - cur_batch_start_index = tl.load(B_Start_Loc + batch_id) - - block_start_loc = BLOCK_M * start_m - # initialize offsets offs_n = tl.arange(0, BLOCK_N) offs_d = tl.arange(0, BLOCK_DMODEL) offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + + # get batch info + cur_batch_seq_len = tl.load(B_Seqlen + batch_id) + cur_batch_start_index = tl.load(B_Start_Loc + batch_id) + block_start_loc = BLOCK_M * start_m + load_p_ptrs = Q + (cur_batch_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd q = tl.load(load_p_ptrs, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) From 042c1ca59325ea48c4a652bf4f29c02337c3581d Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Wed, 23 Aug 2023 19:44:43 +0800 Subject: [PATCH 17/22] added --- colossalai/kernel/triton/context_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index be8a146aa443..d945ac4614bf 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -26,7 +26,7 @@ def _context_flash_attention_kernel( stride_vbs, stride_vh, stride_vd, stride_obs, stride_oh, stride_od, stride_tmp_b, stride_tmp_h, stride_tmp_s, - # suggtest set-up 128, 256, 512, 1024 + # suggtest set-up 64, 128, 256, 512 BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, From fa9042ff1138ba0f2b64c0cf310f18d5901faaf3 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 24 Aug 2023 00:03:43 +0800 Subject: [PATCH 18/22] edit comments --- colossalai/shardformer/modeling/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py index 7fe75ea4c483..dd9a27d67367 100644 --- a/colossalai/shardformer/modeling/llama.py +++ b/colossalai/shardformer/modeling/llama.py @@ -400,6 +400,7 @@ def get_llama_flash_attention_forward(): except: print("fall back to original rotary_embedding_neox of huggingface") print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + print("if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch") HAS_VLLM_KERNERL = False @@ -474,6 +475,7 @@ def get_llama_vllm_rmsnorm_forward(): except: print("please install vllm kernels to install rmsnorm") print("install vllm from https://github.com/vllm-project/vllm to accelerate your inference") + print("if falied to install vllm, please use this branch to install: https://github.com/tiandiao123/vllm/tree/setup_branch") HAS_VLLM_KERNERL = False if HAS_VLLM_KERNERL: From 580518806b1dedec8dcad356bda1d5806994f9d4 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 24 Aug 2023 11:04:36 +0800 Subject: [PATCH 19/22] change name --- tests/{test_kernels => test_infer_ops}/cuda/test_vllm_rmsnorm.py | 0 .../cuda/test_vllm_rotary_embedding.py | 0 .../triton/test_bloom_context_attention.py | 0 .../{test_kernels => test_infer_ops}/triton/test_copy_kv_dest.py | 0 .../triton/test_llama_context_attention.py | 0 .../triton/test_self_attention_nonfusion.py | 0 tests/{test_kernels => test_infer_ops}/triton/test_softmax.py | 0 tests/{test_kernels => test_infer_ops}/triton/utils.py | 0 8 files changed, 0 insertions(+), 0 deletions(-) rename tests/{test_kernels => test_infer_ops}/cuda/test_vllm_rmsnorm.py (100%) rename tests/{test_kernels => test_infer_ops}/cuda/test_vllm_rotary_embedding.py (100%) rename tests/{test_kernels => test_infer_ops}/triton/test_bloom_context_attention.py (100%) rename tests/{test_kernels => test_infer_ops}/triton/test_copy_kv_dest.py (100%) rename tests/{test_kernels => test_infer_ops}/triton/test_llama_context_attention.py (100%) rename tests/{test_kernels => test_infer_ops}/triton/test_self_attention_nonfusion.py (100%) rename tests/{test_kernels => test_infer_ops}/triton/test_softmax.py (100%) rename tests/{test_kernels => test_infer_ops}/triton/utils.py (100%) diff --git a/tests/test_kernels/cuda/test_vllm_rmsnorm.py b/tests/test_infer_ops/cuda/test_vllm_rmsnorm.py similarity index 100% rename from tests/test_kernels/cuda/test_vllm_rmsnorm.py rename to tests/test_infer_ops/cuda/test_vllm_rmsnorm.py diff --git a/tests/test_kernels/cuda/test_vllm_rotary_embedding.py b/tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py similarity index 100% rename from tests/test_kernels/cuda/test_vllm_rotary_embedding.py rename to tests/test_infer_ops/cuda/test_vllm_rotary_embedding.py diff --git a/tests/test_kernels/triton/test_bloom_context_attention.py b/tests/test_infer_ops/triton/test_bloom_context_attention.py similarity index 100% rename from tests/test_kernels/triton/test_bloom_context_attention.py rename to tests/test_infer_ops/triton/test_bloom_context_attention.py diff --git a/tests/test_kernels/triton/test_copy_kv_dest.py b/tests/test_infer_ops/triton/test_copy_kv_dest.py similarity index 100% rename from tests/test_kernels/triton/test_copy_kv_dest.py rename to tests/test_infer_ops/triton/test_copy_kv_dest.py diff --git a/tests/test_kernels/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py similarity index 100% rename from tests/test_kernels/triton/test_llama_context_attention.py rename to tests/test_infer_ops/triton/test_llama_context_attention.py diff --git a/tests/test_kernels/triton/test_self_attention_nonfusion.py b/tests/test_infer_ops/triton/test_self_attention_nonfusion.py similarity index 100% rename from tests/test_kernels/triton/test_self_attention_nonfusion.py rename to tests/test_infer_ops/triton/test_self_attention_nonfusion.py diff --git a/tests/test_kernels/triton/test_softmax.py b/tests/test_infer_ops/triton/test_softmax.py similarity index 100% rename from tests/test_kernels/triton/test_softmax.py rename to tests/test_infer_ops/triton/test_softmax.py diff --git a/tests/test_kernels/triton/utils.py b/tests/test_infer_ops/triton/utils.py similarity index 100% rename from tests/test_kernels/triton/utils.py rename to tests/test_infer_ops/triton/utils.py From fa47c1c5673ba5f206217d30f71c8c9a7079f813 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 24 Aug 2023 14:44:06 +0800 Subject: [PATCH 20/22] change commnets and fix import --- colossalai/kernel/triton/context_attention.py | 1 - 1 file changed, 1 deletion(-) diff --git a/colossalai/kernel/triton/context_attention.py b/colossalai/kernel/triton/context_attention.py index d945ac4614bf..38db2048c6a4 100644 --- a/colossalai/kernel/triton/context_attention.py +++ b/colossalai/kernel/triton/context_attention.py @@ -62,7 +62,6 @@ def _context_flash_attention_kernel( block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) - # modify from for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) k = tl.load(k_ptrs + (cur_batch_start_index + start_n) * stride_kbs, From 96eb052e7ae13d9ad0957b9b493c7049f52687d7 Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 24 Aug 2023 14:55:20 +0800 Subject: [PATCH 21/22] add --- tests/test_infer_ops/triton/test_bloom_context_attention.py | 6 ++---- tests/test_infer_ops/triton/test_llama_context_attention.py | 5 ++--- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/tests/test_infer_ops/triton/test_bloom_context_attention.py b/tests/test_infer_ops/triton/test_bloom_context_attention.py index 30d823a6f06c..6c10ee3ffe3f 100644 --- a/tests/test_infer_ops/triton/test_bloom_context_attention.py +++ b/tests/test_infer_ops/triton/test_bloom_context_attention.py @@ -6,13 +6,11 @@ from torch import nn from torch.nn import functional as F - -from tests.test_kernels.triton.utils import benchmark, torch_context_attention -from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd - try: import triton import triton.language as tl + from tests.test_kernels.triton.utils import benchmark, torch_context_attention + from colossalai.kernel.triton.context_attention import bloom_context_attn_fwd HAS_TRITON = True except ImportError: HAS_TRITON = False diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py index 203da895055c..04d08140815d 100644 --- a/tests/test_infer_ops/triton/test_llama_context_attention.py +++ b/tests/test_infer_ops/triton/test_llama_context_attention.py @@ -6,12 +6,11 @@ from torch import nn from torch.nn import functional as F - -from tests.test_kernels.triton.utils import benchmark, torch_context_attention -from colossalai.kernel.triton.context_attention import llama_context_attn_fwd try: import triton import triton.language as tl + from tests.test_kernels.triton.utils import benchmark, torch_context_attention + from colossalai.kernel.triton.context_attention import llama_context_attn_fwd HAS_TRITON = True except ImportError: HAS_TRITON = False From d5b8b9073611c3ecc396884add0bc4dbfb167a3d Mon Sep 17 00:00:00 2001 From: "cuiqing.li" Date: Thu, 24 Aug 2023 15:09:34 +0800 Subject: [PATCH 22/22] added --- tests/test_infer_ops/triton/test_copy_kv_dest.py | 5 ++--- tests/test_infer_ops/triton/test_self_attention_nonfusion.py | 5 ++--- tests/test_infer_ops/triton/test_softmax.py | 3 ++- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/tests/test_infer_ops/triton/test_copy_kv_dest.py b/tests/test_infer_ops/triton/test_copy_kv_dest.py index 35f30d4f1243..068295a0e4a9 100644 --- a/tests/test_infer_ops/triton/test_copy_kv_dest.py +++ b/tests/test_infer_ops/triton/test_copy_kv_dest.py @@ -4,12 +4,11 @@ import torch from torch import nn -from tests.test_kernels.triton.utils import benchmark -from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest - try: import triton import triton.language as tl + from tests.test_kernels.triton.utils import benchmark + from colossalai.kernel.triton.copy_kv_cache_dest import copy_kv_cache_to_dest HAS_TRITON = True except ImportError: HAS_TRITON = False diff --git a/tests/test_infer_ops/triton/test_self_attention_nonfusion.py b/tests/test_infer_ops/triton/test_self_attention_nonfusion.py index 714b6bd975f8..9692737a05a0 100644 --- a/tests/test_infer_ops/triton/test_self_attention_nonfusion.py +++ b/tests/test_infer_ops/triton/test_self_attention_nonfusion.py @@ -4,12 +4,11 @@ from torch import nn import torch.nn.functional as F -from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton -from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel - try: import triton import triton.language as tl + from colossalai.kernel.triton.self_attention_nofusion import self_attention_compute_using_triton + from colossalai.kernel.triton.qkv_matmul_kernel import qkv_gemm_4d_kernel HAS_TRITON = True except ImportError: HAS_TRITON = False diff --git a/tests/test_infer_ops/triton/test_softmax.py b/tests/test_infer_ops/triton/test_softmax.py index 970e54c8eb79..6a244608c43f 100644 --- a/tests/test_infer_ops/triton/test_softmax.py +++ b/tests/test_infer_ops/triton/test_softmax.py @@ -2,11 +2,12 @@ from packaging import version import torch from torch import nn -from colossalai.kernel.triton.softmax import softmax + try: import triton import triton.language as tl + from colossalai.kernel.triton.softmax import softmax HAS_TRITON = True except ImportError: HAS_TRITON = False