[typo]Comments fix #4633

Merged
3 changes: 2 additions & 1 deletion colossalai/shardformer/modeling/llama.py
@@ -16,7 +16,6 @@
)
from transformers.utils import logging

from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention
from colossalai.pipeline.stage_manager import PipelineStageManager


@@ -399,6 +398,8 @@ def llama_for_sequence_classification_forward(


def get_llama_flash_attention_forward():

    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

    def forward(
        self: LlamaAttention,
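The llama.py hunks above amount to a deferred import: the line "from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention" moves from module scope into get_llama_flash_attention_forward(), so the module can be imported on machines without the CUDA-native kernels and the dependency is only resolved when the flash-attention forward is actually built. A minimal sketch of the resulting factory shape (only the def line and the import come from the diff; the inner signature, body, and return are illustrative):

def get_llama_flash_attention_forward():
    # Deferred import: the CUDA-native kernels are only required when this
    # factory is called, not when the modeling module itself is imported.
    from colossalai.kernel.cuda_native import AttnMaskType, ColoAttention

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        # Build the flash-attention output with ColoAttention / AttnMaskType here
        # (elided in this sketch).
        ...

    return forward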
7 changes: 0 additions & 7 deletions tests/test_infer_ops/triton/test_bloom_context_attention.py
@@ -45,13 +45,6 @@ def test_bloom_context_attention():
    torch_out = torch_context_attention(query.clone(), k.clone(), v.clone(), bs, seq_len, head_num, head_dim)

    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-2), "outputs from triton and torch are not matched"

    latency_1 = benchmark(bloom_context_attn_fwd, query, k, v, o, b_start, b_len, max_input_len, alibi)
    latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim)

    print("the triton op latency is {} ms".format(str(latency_1)))
    print("the torch op latency is {} ms".format(str(latency_2)))


if __name__ == "__main__":
    test_bloom_context_attention()
3 changes: 0 additions & 3 deletions tests/test_infer_ops/triton/test_copy_kv_dest.py
@@ -32,9 +32,6 @@ def test_kv_cache_copy_op():

    assert torch.allclose(cache.cpu(), dest_data.cpu(), rtol=1e-3, atol=1e-3), "copy_kv_cache_to_dest outputs from triton and torch are not matched"

    latency = benchmark(copy_kv_cache_to_dest, cache, dest_index, dest_data)
    print("the average latency is {} ms".format(str(latency)))


if __name__ == "__main__":
    test_kv_cache_copy_op()
19 changes: 0 additions & 19 deletions tests/test_infer_ops/triton/test_layernorm_triton.py
@@ -41,24 +41,5 @@ def test_layer_norm(M, N):
    assert torch.allclose(y_triton, y_torch, atol=1e-2, rtol=0)


@pytest.mark.skipif(not TRITON_CUDA_SUPPORT or not HAS_TRITON,
                    reason="triton requires cuda version to be higher than 11.4")
@parameterize('M', [4])
@parameterize('N', [128])
def test_benchmark(M, N):
    dtype = torch.float16
    eps = 1e-5
    x_shape = (M, N)
    w_shape = (x_shape[-1],)
    weight = torch.rand(w_shape, dtype=dtype, device='cuda')
    bias = torch.rand(w_shape, dtype=dtype, device='cuda')
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device='cuda')

    latency_1 = benchmark(layer_norm, x, weight, bias, eps)
    latency_2 = benchmark(torch.nn.functional.layer_norm, x, w_shape, weight, bias, eps)
    print("the triton op latency is {} ms".format(str(latency_1)))
    print("the torch op latency is {} ms".format(str(latency_2)))


if __name__ == "__main__":
    test_layer_norm()
7 changes: 0 additions & 7 deletions tests/test_infer_ops/triton/test_llama_context_attention.py
@@ -46,12 +46,5 @@ def test_llama_context_attention():

    assert torch.allclose(torch_out.cpu(), o.cpu(), rtol=1e-3, atol=1e-3), "outputs from triton and torch are not matched"

    latency_1 = benchmark(llama_context_attn_fwd, query, k, v, o, b_start, b_len, max_input_len)
    latency_2 = benchmark(torch_context_attention, query, k, v, bs, seq_len, head_num, head_dim)

    print("the triton op latency is {} ms".format(str(latency_1)))
    print("the torch op latency is {} ms".format(str(latency_2)))


if __name__ == "__main__":
    test_llama_context_attention()
1 change: 0 additions & 1 deletion tests/test_infer_ops/triton/test_rotary_embedding.py
@@ -49,7 +49,6 @@ def test_rotary_emb():
    y_torch = torch_rotary_emb(x, cos, sin)
    rotary_embedding_fwd(x, cos, sin)
    y_triton = x
    # print("max delta:", torch.max(torch.abs(y_torch - y_triton)))
    # compare
    assert torch.allclose(y_torch, y_triton, atol=1e-2, rtol=1e-2)

34 changes: 1 addition & 33 deletions tests/test_infer_ops/triton/test_token_attn_1.py
@@ -62,6 +62,7 @@ def test_attn_1():
    # Warm up
    for _ in range(10):
        token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)

    run_iter = 1000
    torch.cuda.synchronize()
    t1 = time.time()
@@ -77,38 +78,5 @@ def test_attn_1():
print("mean ", torch.mean(torch.abs(torch_out - o)))
assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)


# def test_alibi_attn_1():
#     import torch

#     batch_size, seq_len, head_num, head_dim = 2, 1025, 12, 128

#     dtype = torch.float16

#     q = torch.empty((batch_size, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
#     k = torch.empty((batch_size * seq_len, head_num, head_dim), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2)
#     attn_out = torch.empty((head_num, batch_size * seq_len), dtype=dtype, device="cuda")

#     # print(attn_out)

#     b_loc = torch.zeros((batch_size, seq_len), dtype=torch.int32, device="cuda")
#     kv_cache_start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
#     kv_cache_seq_len = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")

#     for i in range(batch_size):
#         kv_cache_start_loc[i] = i * seq_len
#         kv_cache_seq_len[i] = seq_len
#         b_loc[i] = i * seq_len + torch.arange(0, seq_len, dtype=torch.int32, device="cuda")
#         # print(b_loc[i])

#     token_attn_fwd_1(q, k, attn_out, b_loc, kv_cache_start_loc, kv_cache_seq_len, seq_len)

#     torch_out = torch_attn(q, k, batch_size, seq_len, head_num, head_dim).squeeze()
#     o = attn_out.squeeze()
#     print("max ", torch.max(torch.abs(torch_out - o)))
#     print("mean ", torch.mean(torch.abs(torch_out - o)))
#     assert torch.allclose(torch_out, o, atol=1e-2, rtol=0)

if __name__ == "__main__":
    test_attn_1()
    # test_alibi_attn_1()
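The visible part of the hunk above also shows the timing pattern these Triton tests rely on: a short warm-up loop, then torch.cuda.synchronize() on both sides of a timed loop of run_iter calls. A minimal standalone sketch of that pattern, assuming a hypothetical helper name time_cuda_op and a callable fn (neither is code from this repository):

import time

import torch


def time_cuda_op(fn, *args, warmup=10, run_iter=1000):
    # Warm up first so one-time costs (compilation, caches) are excluded.
    for _ in range(warmup):
        fn(*args)
    # CUDA launches are asynchronous: synchronize before starting the clock
    # and again before stopping it, or only launch overhead gets measured.
    torch.cuda.synchronize()
    t1 = time.time()
    for _ in range(run_iter):
        fn(*args)
    torch.cuda.synchronize()
    t2 = time.time()
    return (t2 - t1) / run_iter * 1000  # mean latency per call, in milliseconds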