
Commit 564f54d

[gptq] add gptq kernel (hpcaitech#4416)

* add gptq
* refactor code
* fix tests
* replace auto-gptq
* rename inference/quant
* refactor test
* add auto-gptq as an option
* reset requirements
* change assert and check auto-gptq
* add import warnings
* change test flash attn version
* remove example
* change requirements of flash_attn
* modify tests
* [skip ci] change requirements-test

1 parent 8844691 · commit 564f54d

7 files changed: +973 -0 lines changed

colossalai/gptq/__init__.py (+7)

@@ -0,0 +1,7 @@
from .cai_gptq import HAS_AUTO_GPTQ

if HAS_AUTO_GPTQ:
    from .cai_gptq import (gptq_fused_linear_triton, make_cai_quant_linear,
                           CaiQuantLinear, CaiGPTQLinearOp)
colossalai/gptq/cai_gptq/__init__.py (+14)

@@ -0,0 +1,14 @@
import warnings

HAS_AUTO_GPTQ = False
try:
    import auto_gptq
    HAS_AUTO_GPTQ = True
except ImportError:
    warnings.warn('please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ')
    HAS_AUTO_GPTQ = False

if HAS_AUTO_GPTQ:
    from .gptq_triton import gptq_fused_linear_triton
    from .cai_quant_linear import make_cai_quant_linear, CaiQuantLinear
    from .gptq_op import CaiGPTQLinearOp
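
This guard makes auto-gptq strictly optional: the Triton wrapper, the quantized linear layer, and the fused op are only re-exported when the import succeeds. A minimal consumer-side sketch, assuming only the names defined in this commit (the fallback branch is illustrative):

# Check the flag before touching any GPTQ symbol, so code still imports cleanly
# when auto-gptq is missing (only the warning above is emitted).
from colossalai.gptq.cai_gptq import HAS_AUTO_GPTQ

if HAS_AUTO_GPTQ:
    from colossalai.gptq.cai_gptq import CaiGPTQLinearOp, make_cai_quant_linear
    # ... build CaiQuantLinear layers / call the fused Triton kernel ...
else:
    # Illustrative fallback: keep the original fp16 nn.Linear layers.
    print("auto-gptq not installed; GPTQ int4 inference is disabled")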
colossalai/gptq/cai_gptq/cai_quant_linear.py (+131)

@@ -0,0 +1,131 @@
import math

import numpy as np
import torch
import torch.nn as nn
import triton

from .gptq_op import CaiGPTQLinearOp


class CaiQuantLinear(nn.Module):

    def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
        super().__init__()
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits
        self.maxq = 2**self.bits - 1
        self.groupsize = groupsize if groupsize != -1 else infeatures

        # Quantized weights and zero points are packed into 64-bit words, so each
        # int64 holds 64 // bits values.
        self.register_buffer('qweight', torch.zeros((infeatures // 64 * self.bits, outfeatures), dtype=torch.int64))
        self.register_buffer(
            'qzeros',
            torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 64 * self.bits), dtype=torch.int64))
        self.register_buffer('scales',
                             torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
        self.register_buffer('g_idx',
                             torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))

        if bias:
            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
        else:
            self.bias = None

        self.gptq_linear = CaiGPTQLinearOp(groupsize, bits)

    def pack(self, linear, scales, zeros, g_idx=None):

        g_idx = g_idx.clone() if g_idx is not None else torch.tensor(
            [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        half_scales = scales.clone().half()
        self.scales = scales.clone().half()
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()

        # 64-bit packing parameters (a 32-bit variant would use wn = 8, pbits = 32,
        # ptype = torch.int32, unsign_type = np.uint32, sign_type = np.int32).
        wn = 16
        pbits = 64
        ptype = torch.int64
        unsign_type = np.uint64
        sign_type = np.int64

        # Quantize the dense weights column by column using the per-group scales
        # and zero points.
        intweight = []
        for idx in range(self.infeatures):
            intweight.append(
                torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) /
                            half_scales[g_idx[idx]]).to(ptype)[:, None])
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(unsign_type)
        qweight = np.zeros((intweight.shape[0] // pbits * self.bits, intweight.shape[1]), dtype=unsign_type)

        # Pack pbits // bits quantized values into each 64-bit row of qweight.
        i = 0
        row = 0
        while row < qweight.shape[0]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (pbits // self.bits)):
                    qweight[row] |= intweight[j] << (self.bits * (j - i))
                i += pbits // self.bits
                row += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")
        qweight = qweight.astype(sign_type)
        qweight1 = torch.from_numpy(qweight)
        qweight1 = qweight1.contiguous()
        self.qweight.data.copy_(qweight1)

        # Pack the zero points column-wise with the same bit layout.
        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
        zeros -= 1
        zeros = zeros.numpy().astype(unsign_type)
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (pbits // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += pbits // self.bits
                col += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")
        qzeros = qzeros.astype(sign_type)
        qzeros = torch.from_numpy(qzeros)
        self.qzeros.data.copy_(qzeros)

        # Drop g_idx when it is the trivial contiguous layout (no act-order reindexing).
        if torch.equal(self.g_idx, g_idx):
            self.g_idx = None
        else:
            self.g_idx = g_idx

    def forward(self, x):
        cai_out = self.gptq_linear(x,
                                   self.qweight,
                                   self.scales,
                                   self.qzeros,
                                   g_idx=self.g_idx,
                                   bias=self.bias)
        return cai_out


def make_cai_quant_linear(module, names, bits, groupsize, name=''):
    if isinstance(module, CaiQuantLinear):
        return
    for attr in dir(module):
        tmp = getattr(module, attr)
        name1 = name + '.' + attr if name != '' else attr
        if name1 in names:
            delattr(module, attr)
            setattr(module, attr,
                    CaiQuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
    for name1, child in module.named_children():
        make_cai_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
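
A hedged usage sketch of the two entry points above (not part of the commit): make_cai_quant_linear swaps the listed nn.Linear modules for CaiQuantLinear, and pack() quantizes a dense layer into the packed int64 layout. The toy module and the hand-made scales/zero points are illustrative stand-ins for the output of a real GPTQ calibration run; a CUDA device is assumed, because CaiGPTQLinearOp allocates a placeholder tensor on the current GPU.

# Sketch only: TinyBlock and the quantization parameters below are made up;
# CaiQuantLinear / make_cai_quant_linear come from this commit.
import torch
import torch.nn as nn

from colossalai.gptq import CaiQuantLinear, make_cai_quant_linear

in_features, out_features, bits, groupsize = 256, 512, 4, 128
num_groups = in_features // groupsize

class TinyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(in_features, out_features, bias=True)

model = TinyBlock()

# Replace the listed linears in place; shapes are read from the original layers.
make_cai_quant_linear(model, names=['proj'], bits=bits, groupsize=groupsize)
assert isinstance(model.proj, CaiQuantLinear)

# Hand-made quantization parameters, shaped (out_features, num_groups);
# pack() transposes them to (num_groups, out_features) internally.
scales = torch.full((out_features, num_groups), 0.01)
zeros = torch.full((out_features, num_groups), 8.0)   # mid-range zero point for 4 bits

# Build a dense layer whose weights are exactly representable at 4 bits:
# W = (q - zero) * scale with q in [0, 15], so pack() recovers q losslessly.
q_ref = torch.randint(0, 16, (out_features, in_features)).float()
dense = nn.Linear(in_features, out_features, bias=True)
dense.weight.data = (q_ref - 8.0) * 0.01

# 64 // bits = 16 values per int64 word, so qweight is
# (in_features // 64 * bits, out_features) = (16, 512) and qzeros is (2, 32).
model.proj.pack(dense, scales, zeros)
print(model.proj.qweight.shape, model.proj.qzeros.shape)

A forward pass through model.proj then dispatches to the fused Triton kernel via CaiGPTQLinearOp (gptq_op.py, below).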

colossalai/gptq/cai_gptq/gptq_op.py (+44)

@@ -0,0 +1,44 @@
import torch

from .gptq_triton import gptq_fused_linear_triton


class CaiGPTQLinearOp(torch.nn.Module):

    def __init__(self, gptq_group_size, gptq_quant_bits):
        super(CaiGPTQLinearOp, self).__init__()
        self.group_size = gptq_group_size
        self.bits = gptq_quant_bits
        self.maxq = 2**self.bits - 1
        # Placeholder handed to the kernel when bias/residual are absent.
        self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device())

    def forward(self,
                input: torch.Tensor,
                weight: torch.Tensor,
                weight_scales: torch.Tensor,
                weight_zeros: torch.Tensor,
                g_idx: torch.Tensor = None,
                act_type=0,
                bias: torch.Tensor = None,
                residual: torch.Tensor = None,
                qkv_fused=False):

        add_bias = True
        if bias is None:
            bias = self.empty_tensor
            add_bias = False

        add_residual = True
        if residual is None:
            residual = self.empty_tensor
            add_residual = False

        # Flatten the leading dims so the Triton kernel sees a 2D matmul.
        x = input.view(-1, input.shape[-1])

        out = gptq_fused_linear_triton(x, weight, weight_scales, weight_zeros, bias, residual,
                                       self.bits, self.maxq, self.group_size, qkv_fused, add_bias, add_residual,
                                       act_type=act_type, g_idx=g_idx)
        if qkv_fused:
            out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1])
        else:
            out = out.view(input.shape[0], input.shape[1], weight.shape[-1])

        return out
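
The op flattens the leading (batch, sequence) dimensions into a single 2D matmul for the Triton kernel and reshapes the result afterwards; with qkv_fused=True the output gains a leading dimension of 3 (one slice each for the fused Q, K and V projections). A shape-only sketch with assumed sizes, using stand-in tensors instead of a real kernel call:

# Shape sketch (not from the commit): how CaiGPTQLinearOp reshapes activations
# around the fused Triton matmul. Sizes are arbitrary examples.
import torch

batch, seq, in_features, out_features = 2, 16, 256, 512
x = torch.randn(batch, seq, in_features)

# forward() first flattens the leading dims so the kernel sees a 2D problem:
x2d = x.view(-1, in_features)                  # (batch * seq, in_features) = (32, 256)

# ...the kernel returns a 2D result, which is reshaped back afterwards:
out = torch.empty(batch * seq, out_features)   # stand-in for the kernel output
out = out.view(batch, seq, out_features)       # regular linear: (2, 16, 512)

# With qkv_fused=True the kernel computes Q, K and V in one launch and the
# output is reshaped to (3, batch, seq, out_features) instead.
qkv = torch.empty(3 * batch * seq, out_features)
qkv = qkv.view(3, batch, seq, out_features)    # (3, 2, 16, 512)
print(x2d.shape, out.shape, qkv.shape)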
