diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index f26248d44612..c749ce145dcb 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -16,7 +16,6 @@
 )
 from transformers.utils import logging
 
-from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
 from colossalai.pipeline.stage_manager import PipelineStageManager
 
 
@@ -399,6 +398,8 @@ def llama_for_sequence_classification_forward(
 
 
 def get_llama_flash_attention_forward():
+
+    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
 
     def forward(
         self: LlamaAttention,
diff --git a/tests/test_infer_ops/triton/test_bloom_context_attention.py b/tests/test_infer_ops/triton/test_bloom_context_attention.py
index 63d77ce3e16e..ea89d6bb4764 100644
--- a/tests/test_infer_ops/triton/test_bloom_context_attention.py
+++ b/tests/test_infer_ops/triton/test_bloom_context_attention.py
@@ -45,13 +45,6 @@ def test_bloom_context_attention():
     torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)
 
     assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched"
-
-    latency_1 = benchmark(bloom_context_attn_fwd, query, k, v, o, b_start, b_len, max_input_len, alibi)
-    latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim)
-
-    print("the triton op latency is {} ms".format(str(latency_1)))
-    print("the torch op latency is {} ms".format(str(latency_2)))
-
 
 if __name__ == "__main__":
     test_bloom_context_attention()
\ No newline at end of file
diff --git a/tests/test_infer_ops/triton/test_copy_kv_dest.py b/tests/test_infer_ops/triton/test_copy_kv_dest.py
index 068295a0e4a9..188493eb13ce 100644
--- a/tests/test_infer_ops/triton/test_copy_kv_dest.py
+++ b/tests/test_infer_ops/triton/test_copy_kv_dest.py
@@ -32,9 +32,6 @@ def test_kv_cache_copy_op():
 
     assert torch.allclose(cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3), "copy_kv_cache_to_dest outputs from triton and torch are not matched"
 
-    latency = benchmark(copy_kv_cache_to_dest, cache, dest_index, dest_data)
-    print("the average latency is {} ms".format(str(latency)))
-
 
 if __name__ == "__main__":
     test_kv_cache_copy_op()
diff --git a/tests/test_infer_ops/triton/test_layernorm_triton.py b/tests/test_infer_ops/triton/test_layernorm_triton.py
index 15d0fe74c1ed..9648f91e2f28 100644
--- a/tests/test_infer_ops/triton/test_layernorm_triton.py
+++ b/tests/test_infer_ops/triton/test_layernorm_triton.py
@@ -41,24 +41,5 @@ def test_layer_norm(M, N):
     assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0)
 
 
-@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
-                    reason="triton requires cuda version to be higher than 11.4")
-@parameterize('M', [4])
-@parameterize('N', [128])
-def test_benchmark(M, N):
-    dtype = torch.float16
-    eps = 1e-5
-    x_shape = (M, N)
-    w_shape = (x_shape[-1],)
-    weight = torch.rand(w_shape, dtype=dtype, device='cuda')
-    bias = torch.rand(w_shape, dtype=dtype, device='cuda')
-    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')
-
-    latency_1 = benchmark(layer_norm, x, weight, bias, eps)
-    latency_2 = benchmark(torch.nn.functional.layer_norm, x, w_shape, weight, bias, eps)
-    print("the triton op latency is {} ms".format(str(latency_1)))
-    print("the torch op latency is {} ms".format(str(latency_2)))
-
-
 if __name__ == "__main__":
     test_layer_norm()
diff --git a/tests/test_infer_ops/triton/test_llama_context_attention.py b/tests/test_infer_ops/triton/test_llama_context_attention.py
index b0fac1263047..4c49c0b51333 100644
--- a/tests/test_infer_ops/triton/test_llama_context_attention.py
+++ b/tests/test_infer_ops/triton/test_llama_context_attention.py
@@ -46,12 +46,5 @@ def test_llama_context_attention():
 
 
     assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3), "outputs from triton and torch are not matched"
-    latency_1 = benchmark(llama_context_attn_fwd, query, k, v, o, b_start, b_len, max_input_len)
-    latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim)
-
-    print("the triton op latency is {} ms".format(str(latency_1)))
-    print("the torch op latency is {} ms".format(str(latency_2)))
-
-
 if __name__ == "__main__":
     test_llama_context_attention()
\ No newline at end of file
diff --git a/tests/test_infer_ops/triton/test_rotary_embedding.py b/tests/test_infer_ops/triton/test_rotary_embedding.py
index 9fafd480a956..f9457c1a04f7 100644
--- a/tests/test_infer_ops/triton/test_rotary_embedding.py
+++ b/tests/test_infer_ops/triton/test_rotary_embedding.py
@@ -49,7 +49,6 @@ def test_rotary_emb():
     y_torch = torch_rotary_emb(x, cos, sin)
     rotary_embedding_fwd(x, cos, sin)
    y_triton = x
-    # print("max delta:", torch.max(torch.abs(y_torch - y_triton)))
     # compare
     assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=1e-2)
 
diff --git a/tests/test_infer_ops/triton/test_token_attn_1.py b/tests/test_infer_ops/triton/test_token_attn_1.py
index ba236de82498..d01685e7788f 100644
--- a/tests/test_infer_ops/triton/test_token_attn_1.py
+++ b/tests/test_infer_ops/triton/test_token_attn_1.py
@@ -62,6 +62,7 @@ def test_attn_1():
     # Warm up
     for _ in range(10):
         token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
+
     run_iter = 1000
     torch.cuda.synchronize()
     t1 = time.time()
@@ -77,38 +78,5 @@
     print("mean ", torch.mean(torch.abs(torch_out - o)))
     assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
-
-# def test_alibi_attn_1():
-#     import torch
-
-#     batch_size, seq_len, head_num, head_dim = 2, 1025, 12, 128
-
-#     dtype = torch.float16
-
-#     q = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-#     k = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
-#     attn_out = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda")
-
-#     # print(attn_out)
-
-#     b_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda")
-#     kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-#     kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-
-#     for i in range(batch_size):
-#         kv_cache_start_loc[i] = i * seq_len
-#         kv_cache_seq_len[i] = seq_len
-#         b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")
-#         # print(b_loc[i])
-
-#     token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)
-
-#     torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze()
-#     o = attn_out.squeeze()
-#     print("max ", torch.max(torch.abs(torch_out - o)))
-#     print("mean ", torch.mean(torch.abs(torch_out - o)))
-#     assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)
-
 
 if __name__ == "__main__":
     test_attn_1()
-    # test_alibi_attn_1()