
Commit 54a319e

Committed by Xu-Kai, flybird11111, yuanheng-zhao, binmakeswell, and Baizhou Zhang
[inference] add int8 rotary embedding kernel for smoothquant (hpcaitech#4843)
* [shardformer] fix GPT2DoubleHeadsModel (hpcaitech#4703) * [hotfix] Fix import error: colossal.kernel without triton installed (hpcaitech#4722) * [hotfix] remove triton kernels from kernel init * revise bloom/llama kernel imports for infer * [shardformer] to fix whisper test failed due to significant accuracy differences. (hpcaitech#4710) * [shardformer] fix whisper test failed * [shardformer] fix whisper test failed * [shardformer] fix whisper test failed * [shardformer] fix whisper test failed * [doc] fix llama2 code link (hpcaitech#4726) * [doc] fix llama2 code link * [doc] fix llama2 code link * [doc] fix llama2 code link * [doc] Add user document for Shardformer (hpcaitech#4702) * create shardformer doc files * add docstring for seq-parallel * update ShardConfig docstring * add links to llama example * add outdated massage * finish introduction & supporting information * finish 'how shardformer works' * finish shardformer.md English doc * fix doctest fail * add Chinese document * [format] applied code formatting on changed files in pull request 4726 (hpcaitech#4727) Co-authored-by: github-actions <github-actions@github.com> * [doc] add shardformer support matrix/update tensor parallel documents (hpcaitech#4728) * add compatibility matrix for shardformer doc * update tp doc * Optimized some syntax errors in the documentation and code under applications/ (hpcaitech#4127) Co-authored-by: flybird11111 <1829166702@qq.com> * [shardformer] update pipeline parallel document (hpcaitech#4725) * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [shardformer] update pipeline parallel document * [legacy] remove deterministic data loader test * [shardformer] update seq parallel document (hpcaitech#4730) * update doc of seq parallel * fix typo * [example] add gpt2 HybridParallelPlugin example (hpcaitech#4653) * add gpt2 HybridParallelPlugin example * update readme and testci * update test ci * fix test_ci bug * update requirements * add requirements * update requirements * add requirement * rename file * [doc] polish shardformer doc (hpcaitech#4735) * arrange position of chapters * fix typos in seq parallel doc * [shardformer] add custom policy in hybrid parallel plugin (hpcaitech#4718) * add custom policy * update assert * [example] llama2 add fine-tune example (hpcaitech#4673) * [shardformer] update shardformer readme [shardformer] update shardformer readme [shardformer] update shardformer readme * [shardformer] update llama2/opt finetune example and shardformer update to llama2 * [shardformer] update llama2/opt finetune example and shardformer update to llama2 * [shardformer] update llama2/opt finetune example and shardformer update to llama2 * [shardformer] change dataset * [shardformer] change dataset * [shardformer] fix CI * [shardformer] fix * [shardformer] fix * [shardformer] fix * [shardformer] fix * [shardformer] fix [example] update opt example [example] resolve comments fix fix * [example] llama2 add finetune example * [example] llama2 add finetune example * [example] llama2 add finetune example * [example] llama2 add finetune example * fix * update llama2 example * update llama2 example * fix * update llama2 example * update llama2 example * update llama2 example * update llama2 example * 
update llama2 example * update llama2 example * Update requirements.txt * update llama2 example * update llama2 example * update llama2 example * [doc] explaination of loading large pretrained models (hpcaitech#4741) * [kernel] update triton init hpcaitech#4740 (hpcaitech#4740) * [legacy] clean up legacy code (hpcaitech#4743) * [legacy] remove outdated codes of pipeline (hpcaitech#4692) * [legacy] remove cli of benchmark and update optim (hpcaitech#4690) * [legacy] remove cli of benchmark and update optim * [doc] fix cli doc test * [legacy] fix engine clip grad norm * [legacy] remove outdated colo tensor (hpcaitech#4694) * [legacy] remove outdated colo tensor * [test] fix test import * [legacy] move outdated zero to legacy (hpcaitech#4696) * [legacy] clean up utils (hpcaitech#4700) * [legacy] clean up utils * [example] update examples * [legacy] clean up amp * [legacy] fix amp module * [legacy] clean up gpc (hpcaitech#4742) * [legacy] clean up context * [legacy] clean core, constants and global vars * [legacy] refactor initialize * [example] fix examples ci * [example] fix examples ci * [legacy] fix tests * [example] fix gpt example * [example] fix examples ci * [devops] fix ci installation * [example] fix examples ci * [format] applied code formatting on changed files in pull request 4743 (hpcaitech#4750) Co-authored-by: github-actions <github-actions@github.com> * [misc] update pre-commit and run all files (hpcaitech#4752) * [misc] update pre-commit * [misc] run pre-commit * [misc] remove useless configuration files * [misc] ignore cuda for clang-format * [doc] explain suitable use case for each plugin * [doc] put individual plugin explanation in front * [doc] add model examples for each plugin * [doc] put native colossalai plugins first in description section * [chat]: update rm, add wandb and fix bugs (hpcaitech#4471) * feat: modify forward fn of critic and reward model * feat: modify calc_action_log_probs * to: add wandb in sft and rm trainer * feat: update train_sft * feat: update train_rm * style: modify type annotation and add warning * feat: pass tokenizer to ppo trainer * to: modify trainer base and maker base * feat: add wandb in ppo trainer * feat: pass tokenizer to generate * test: update generate fn tests * test: update train tests * fix: remove action_mask * feat: remove unused code * fix: fix wrong ignore_index * fix: fix mock tokenizer * chore: update requirements * revert: modify make_experience * fix: fix inference * fix: add padding side * style: modify _on_learn_batch_end * test: use mock tokenizer * fix: use bf16 to avoid overflow * fix: fix workflow * [chat] fix gemini strategy * [chat] fix * sync: update colossalai strategy * fix: fix args and model dtype * fix: fix checkpoint test * fix: fix requirements * fix: fix missing import and wrong arg * fix: temporarily skip gemini test in stage 3 * style: apply pre-commit * fix: temporarily skip gemini test in stage 1&2 --------- Co-authored-by: Mingyan Jiang <1829166702@qq.com> * [shardformer] fix master param sync for hybrid plugin/rewrite unwrapping logic (hpcaitech#4758) * fix master param sync for hybrid plugin * rewrite unwrap for ddp/fsdp * rewrite unwrap for zero/gemini * rewrite unwrap for hybrid plugin * fix geemini unwrap * fix bugs * [bug] fix get_default_parser in examples (hpcaitech#4764) * [doc] clean up outdated docs (hpcaitech#4765) * [doc] clean up outdated docs * [doc] fix linking * [doc] fix linking * [doc] add shardformer doc to sidebar (hpcaitech#4768) * [chat]: add lora merge weights config 
(hpcaitech#4766) * feat: modify lora merge weights fn * feat: add lora merge weights config * [lazy] support torch 2.0 (hpcaitech#4763) * [lazy] support _like methods and clamp * [lazy] pass transformers models * [lazy] fix device move and requires grad * [lazy] fix requires grad and refactor api * [lazy] fix requires grad * [bug] Fix the version check bug in colossalai run when generating the cmd. (hpcaitech#4713) * Fix the version check bug in colossalai run when generating the cmd. * polish code * [feature] add gptq for inference (hpcaitech#4754) * [gptq] add gptq kernel (hpcaitech#4416) * add gptq * refactor code * fix tests * replace auto-gptq * rname inferance/quant * refactor test * add auto-gptq as an option * reset requirements * change assert and check auto-gptq * add import warnings * change test flash attn version * remove example * change requirements of flash_attn * modify tests * [skip ci] change requirements-test * [gptq] faster gptq cuda kernel (hpcaitech#4494) * [skip ci] add cuda kernels * add license * [skip ci] fix max_input_len * format files & change test size * [skip ci] * [gptq] add gptq tensor parallel (hpcaitech#4538) * add gptq tensor parallel * add gptq tp * delete print * add test gptq check * add test auto gptq check * [gptq] combine gptq and kv cache manager (hpcaitech#4706) * combine gptq and kv cache manager * add init bits * delete useless code * add model path * delete usless print and update test * delete usless import * move option gptq to shard config * change replace linear to shardformer * update bloom policy * delete useless code * fix import bug and delete uselss code * change colossalai/gptq to colossalai/quant/gptq * update import linear for tests * delete useless code and mv gptq_kernel to kernel directory * fix triton kernel * add triton import * [inference] chatglm2 infer demo (hpcaitech#4724) * add chatglm2 * add * gather needed kernels * fix some bugs * finish context forward * finish context stage * fix * add * pause * add * fix bugs * finish chatglm * fix bug * change some logic * fix bugs * change some logics * add * add * add * fix * fix tests * fix * [release] update version (hpcaitech#4775) * [release] update version * [doc] revert versions * initial commit: add colossal llama 2 (hpcaitech#4784) * [feature] ColossalEval: Evaluation Pipeline for LLMs (hpcaitech#4786) * Add ColossalEval * Delete evaluate in Chat --------- Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> * [doc] add llama2 domain-specific solution news (hpcaitech#4789) * [doc] add llama2 domain-specific solution news * [fix] fix weekly runing example (hpcaitech#4787) * [fix] fix weekly runing example * [fix] fix weekly runing example * [doc] polish shardformer doc (hpcaitech#4779) * fix example format in docstring * polish shardformer doc * [checkpointio] support unsharded checkpointIO for hybrid parallel (hpcaitech#4774) * support unsharded saving/loading for model * support optimizer unsharded saving * update doc * support unsharded loading for optimizer * small fix * update readme * [lazy] support from_pretrained (hpcaitech#4801) * [lazy] patch from pretrained * [lazy] fix from pretrained and add tests * [devops] update ci * update * [hotfix] change llama2 Colossal-LLaMA-2 script filename (hpcaitech#4800) change filename: pretraining.py -> trainin.py there is no file named pretraing.py. 
wrong writing * [misc] add last_epoch in CosineAnnealingWarmupLR (hpcaitech#4778) * [doc] add lazy init docs (hpcaitech#4808) * [hotfix] fix norm type error in zero optimizer (hpcaitech#4795) * [hotfix] Correct several erroneous code comments (hpcaitech#4794) * [format] applied code formatting on changed files in pull request 4595 (hpcaitech#4602) Co-authored-by: github-actions <github-actions@github.com> * fix format (hpcaitech#4815) * [chat] fix gemini strategy (hpcaitech#4698) * [chat] fix gemini strategy * [chat] fix gemini strategy * [chat] fix gemini strategy * [chat] fix gemini strategy * g# This is a combination of 2 commits. [chat] fix gemini strategy fox * [chat] fix gemini strategy update llama2 example [chat] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * [fix] fix gemini strategy * fix * fix * fix * fix * fix * Update train_prompts.py * Update Qwen-7B results (hpcaitech#4821) Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> * [doc] update slack link (hpcaitech#4823) * add autotune (hpcaitech#4822) * update Colossal (hpcaitech#4832) * add int8 rotary embedding kernel * remove useless code --------- Co-authored-by: flybird11111 <1829166702@qq.com> Co-authored-by: Yuanheng Zhao <54058983+yuanheng-zhao@users.noreply.github.com> Co-authored-by: binmakeswell <binmakeswell@gmail.com> Co-authored-by: Baizhou Zhang <eddiezhang@pku.edu.cn> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: github-actions <github-actions@github.com> Co-authored-by: digger yu <digger-yu@outlook.com> Co-authored-by: Pengtai Xu <henryxu880@gmail.com> Co-authored-by: Bin Jia <45593998+FoolPlayer@users.noreply.github.com> Co-authored-by: ppt0011 <143150326+ppt0011@users.noreply.github.com> Co-authored-by: Xuanlei Zhao <43881818+oahzxl@users.noreply.github.com> Co-authored-by: Hongxin Liu <lhx0217@gmail.com> Co-authored-by: Wenhao Chen <cwher@outlook.com> Co-authored-by: littsk <1214689160@qq.com> Co-authored-by: Jianghai <72591262+CjhHa1@users.noreply.github.com> Co-authored-by: Tong Li <tong.li352711588@gmail.com> Co-authored-by: Yuanchen <70520919+chengeharrison@users.noreply.github.com> Co-authored-by: Xu Yuanchen <yuanchen.xu00@gmail.com> Co-authored-by: Desperado-Jia <502205863@qq.com> Co-authored-by: Chandler-Bing <brp12138@163.com> Co-authored-by: Yan haixu <40758050+hova88@users.noreply.github.com>
1 parent 39f2582 commit 54a319e

File tree: 3 files changed, +180 -0 lines changed


colossalai/kernel/triton/__init__.py

+2 lines changed

@@ -13,6 +13,7 @@
 from .copy_kv_cache_dest import copy_kv_cache_to_dest
 from .fused_layernorm import layer_norm
 from .gptq_triton import gptq_fused_linear_triton
+from .int8_rotary_embedding_kernel import int8_rotary_embedding_fwd
 from .rms_norm import rmsnorm_forward
 from .rotary_embedding_kernel import rotary_embedding_fwd
 from .softmax import softmax
@@ -28,4 +29,5 @@
     "rotary_embedding_fwd",
     "token_attention_fwd",
     "gptq_fused_linear_triton",
+    "int8_rotary_embedding_fwd",
 ]
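
The new export can be exercised directly once Triton and a CUDA device are available. The sketch below is illustrative only: it mirrors the quantization scheme used in the test added by this commit (symmetric per-tensor int8, scale = max|x| / 127); the tensor names and the choice to reuse the input scale as the output scale are assumptions for the example, not part of the commit.

import torch

from colossalai.kernel.triton import int8_rotary_embedding_fwd

# Illustrative shapes: q is (seq_len, head_num, head_dim), cos/sin are (seq_len, head_dim // 2).
seq_len, head_num, head_dim = 4, 32, 128
q_fp = torch.randn(seq_len, head_num, head_dim, device="cuda")
cos = torch.randn(seq_len, head_dim // 2, device="cuda")
sin = torch.randn(seq_len, head_dim // 2, device="cuda")

# Symmetric per-tensor int8 quantization; the scales stay 0-dim CUDA tensors
# because the kernel reads them with tl.load.
input_scale = q_fp.abs().max() / 127
output_scale = input_scale.clone()  # assumption: reuse the input scale as the output scale
q_int8 = (q_fp / input_scale).to(torch.int8)

# The kernel rotates q_int8 in place: it dequantizes with input_scale,
# applies rotary embedding with cos/sin, and requantizes with output_scale.
int8_rotary_embedding_fwd(q_int8, cos, sin, input_scale, output_scale)
q_rotated_fp = q_int8.to(torch.float) * output_scale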
colossalai/kernel/triton/int8_rotary_embedding_kernel.py (new file, +119 lines)

# Adapted from ModelTC https://github.com/ModelTC/lightllm
import torch
import triton
import triton.language as tl


@triton.jit
def _rotary_kernel(
    q,
    input_scale,
    output_scale,
    Cos,
    Sin,
    q_bs_stride,
    q_h_stride,
    q_d_stride,
    cos_bs_stride,
    cos_d_stride,
    total_len,
    HEAD_NUM: tl.constexpr,
    BLOCK_HEAD: tl.constexpr,
    BLOCK_SEQ: tl.constexpr,
    HEAD_DIM: tl.constexpr,
):
    current_head_index = tl.program_id(0)
    current_seq_index = tl.program_id(1)

    dim_range0 = tl.arange(0, HEAD_DIM // 2)
    dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)

    current_head_range = current_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
    current_seq_range = current_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)

    off_q0 = (
        current_seq_range[:, None, None] * q_bs_stride
        + current_head_range[None, :, None] * q_h_stride
        + dim_range0[None, None, :] * q_d_stride
    )
    off_q1 = (
        current_seq_range[:, None, None] * q_bs_stride
        + current_head_range[None, :, None] * q_h_stride
        + dim_range1[None, None, :] * q_d_stride
    )

    off_dimcos_sin = current_seq_range[:, None, None] * cos_bs_stride + dim_range0[None, None, :] * cos_d_stride

    q0 = tl.load(
        q + off_q0,
        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
        other=0.0,
    )
    q1 = tl.load(
        q + off_q1,
        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
        other=0.0,
    )

    cos = tl.load(Cos + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
    sin = tl.load(Sin + off_dimcos_sin, mask=current_seq_range[:, None, None] < total_len, other=0.0)
    in_scale = tl.load(input_scale)
    o_scale = tl.load(output_scale)

    q0 = q0.to(tl.float32) * in_scale
    q1 = q1.to(tl.float32) * in_scale

    out0 = (q0 * cos - q1 * sin) / o_scale
    out1 = (q0 * sin + q1 * cos) / o_scale

    # out0 = out0.to(tl.int8)
    # out1 = out1.to(tl.int8)

    tl.store(
        q + off_q0,
        out0,
        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
    )
    tl.store(
        q + off_q1,
        out1,
        mask=(current_seq_range[:, None, None] < total_len) & (current_head_range[None, :, None] < HEAD_NUM),
    )

    return


@torch.no_grad()
def int8_rotary_embedding_fwd(q, cos, sin, input_scale, output_scale):
    total_len = q.shape[0]
    head_num = q.shape[1]
    head_dim = q.shape[2]
    assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
    BLOCK_HEAD = 4
    BLOCK_SEQ = 32
    grid = (triton.cdiv(head_num, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
    if head_dim >= 128:
        num_warps = 8
    else:
        num_warps = 4

    _rotary_kernel[grid](
        q,
        input_scale,
        output_scale,
        cos,
        sin,
        q.stride(0),
        q.stride(1),
        q.stride(2),
        cos.stride(0),
        cos.stride(1),
        total_len,
        HEAD_NUM=head_num,
        BLOCK_HEAD=BLOCK_HEAD,
        BLOCK_SEQ=BLOCK_SEQ,
        HEAD_DIM=head_dim,
        num_warps=num_warps,
        num_stages=1,
    )
    return
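
For reference, a plain-PyTorch sketch of the computation the kernel performs is shown below. It is not part of the commit; the function name is hypothetical, and the final int8 cast only approximates the kernel's in-place int8 store (rounding behavior may differ slightly).

import torch

def int8_rotary_reference(q_int8, cos, sin, input_scale, output_scale):
    # q_int8: (seq_len, head_num, head_dim) int8; cos/sin: (seq_len, head_dim // 2) float.
    seq_len, _, head_dim = q_int8.shape
    q = q_int8.float() * input_scale                        # dequantize with the input scale
    q0, q1 = q[..., : head_dim // 2], q[..., head_dim // 2 :]
    cos = cos.view(seq_len, 1, head_dim // 2)
    sin = sin.view(seq_len, 1, head_dim // 2)
    out0 = (q0 * cos - q1 * sin) / output_scale             # rotate, then requantize
    out1 = (q0 * sin + q1 * cos) / output_scale
    return torch.cat((out0, out1), dim=-1).to(torch.int8)   # mirrors the implicit int8 store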
Test for int8_rotary_embedding_fwd (new file, +59 lines)

# Adapted from ModelTC https://github.com/ModelTC/lightllm


import pytest
import torch
from packaging import version

try:
    from colossalai.kernel.triton import int8_rotary_embedding_fwd

    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False
    print("please install triton from https://github.com/openai/triton")

TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")


def torch_rotary_emb(x, cos, sin):
    seq_len, h, dim = x.shape
    x0 = x[:, :, 0 : dim // 2]
    x1 = x[:, :, dim // 2 : dim]
    cos = cos.view((seq_len, 1, dim // 2))
    sin = sin.view((seq_len, 1, dim // 2))
    o0 = x0 * cos - x1 * sin
    o1 = x0 * sin + x1 * cos
    return torch.cat((o0, o1), dim=-1)


@pytest.mark.skipif(
    not TRITON_CUDA_SUPPORT or not HAS_TRITON, reason="triton requires cuda version to be higher than 11.4"
)
def test_rotary_emb():
    SEQ_LEN = 1
    HEAD_NUM = 32
    HEAD_DIM = 128
    dtype = torch.float
    # create data
    x_shape = (SEQ_LEN, HEAD_NUM, HEAD_DIM)
    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
    cos_shape = (SEQ_LEN, HEAD_DIM // 2)
    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
    # forward pass
    y_torch = torch_rotary_emb(x, cos, sin)

    input_scale = torch.max(torch.abs(x)) / 127
    output_scale = torch.max(torch.abs(y_torch)) / 127

    x = x / input_scale
    x = x.to(torch.int8)

    int8_rotary_embedding_fwd(x, cos, sin, input_scale, output_scale)
    y_triton = x.to(torch.float) * output_scale
    assert torch.allclose(y_triton, y_torch, atol=2e-1, rtol=1e-2, equal_nan=True)


if __name__ == "__main__":
    test_rotary_emb()
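
In an attention layer, rotary embedding is applied to both the query and the key projections, so the kernel would normally be invoked once per tensor. The sketch below assumes that setup; the helper, tensor names, and the reuse of each input scale as its output scale are illustrative assumptions, not part of the commit.

import torch

from colossalai.kernel.triton import int8_rotary_embedding_fwd

seq_len, head_num, head_dim = 16, 32, 128
cos = torch.randn(seq_len, head_dim // 2, device="cuda")
sin = torch.randn(seq_len, head_dim // 2, device="cuda")

def quantize(t):
    # symmetric per-tensor int8 quantization, as in the test above
    scale = t.abs().max() / 127
    return (t / scale).to(torch.int8), scale

q_int8, q_scale = quantize(torch.randn(seq_len, head_num, head_dim, device="cuda"))
k_int8, k_scale = quantize(torch.randn(seq_len, head_num, head_dim, device="cuda"))

# Each call rotates its tensor in place; a real pipeline would calibrate
# separate output scales instead of reusing the input scales.
int8_rotary_embedding_fwd(q_int8, cos, sin, q_scale, q_scale)
int8_rotary_embedding_fwd(k_int8, cos, sin, k_scale, k_scale)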
