[gptq] add gptq kernel (hpcaitech#4416)
* add gptq

* refactor code

* fix tests

* replace auto-gptq

* rename inference/quant

* refactor test

* add auto-gptq as an option

* reset requirements

* change assert and check auto-gptq

* add import warnings

* change test flash attn version

* remove example

* change requirements of flash_attn

* modify tests

* [skip ci] change requirements-test
Xu-Kai committed Sep 12, 2023
1 parent 8844691 commit 564f54d
Showing 7 changed files with 973 additions and 0 deletions.
7 changes: 7 additions & 0 deletions colossalai/gptq/__init__.py
@@ -0,0 +1,7 @@
from .cai_gptq import HAS_AUTO_GPTQ

if HAS_AUTO_GPTQ:
    from .cai_gptq import (gptq_fused_linear_triton, make_cai_quant_linear,
                           CaiQuantLinear, CaiGPTQLinearOp)


14 changes: 14 additions & 0 deletions colossalai/gptq/cai_gptq/__init__.py
@@ -0,0 +1,14 @@
import warnings

HAS_AUTO_GPTQ = False
try:
    import auto_gptq
    HAS_AUTO_GPTQ = True
except ImportError:
    warnings.warn('please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ')
    HAS_AUTO_GPTQ = False

if HAS_AUTO_GPTQ:
    from .gptq_triton import gptq_fused_linear_triton
    from .cai_quant_linear import make_cai_quant_linear, CaiQuantLinear
    from .gptq_op import CaiGPTQLinearOp
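
Since auto-gptq is an optional dependency, the GPTQ entry points only exist when HAS_AUTO_GPTQ is True, so callers should guard on it. A minimal sketch of such a guard; the `model` and layer names below are purely illustrative and not part of this commit:

from colossalai.gptq import HAS_AUTO_GPTQ

if HAS_AUTO_GPTQ:
    from colossalai.gptq import make_cai_quant_linear

    # `model` and the layer names are placeholders for the caller's own module tree.
    names = {'mlp.dense_h_to_4h', 'mlp.dense_4h_to_h'}
    make_cai_quant_linear(model, names, bits=4, groupsize=128)
else:
    # auto-gptq is not installed; keep the unquantized fp16 linears.
    pass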
131 changes: 131 additions & 0 deletions colossalai/gptq/cai_gptq/cai_quant_linear.py
@@ -0,0 +1,131 @@

import math
import numpy as np
import torch
import torch.nn as nn
from .gptq_op import CaiGPTQLinearOp
import triton

class CaiQuantLinear(nn.Module):

    def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
        super().__init__()
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2,4,8 bits are supported.")
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits
        self.maxq = 2**self.bits - 1
        self.groupsize = groupsize if groupsize != -1 else infeatures

        self.register_buffer('qweight', torch.zeros((infeatures // 64 * self.bits, outfeatures), dtype=torch.int64))
        self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 64 * self.bits), dtype=torch.int64))
        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
        self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))

        if bias:
            self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16))
        else:
            self.bias = None

        self.gptq_linear = CaiGPTQLinearOp(groupsize, bits)
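        # Shape note (illustrative numbers, not part of the original file): each int64 word
        # packs 64 // bits quantized values, so with bits=4, groupsize=128, infeatures=4096,
        # outfeatures=11008 the buffers above are
        #     qweight: (4096 // 64 * 4, 11008) = (256, 11008)
        #     qzeros:  (ceil(4096 / 128), 11008 // 64 * 4) = (32, 688)
        #     scales:  (32, 11008)
        #     g_idx:   (4096,)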


    def pack(self, linear, scales, zeros, g_idx=None):

        g_idx = g_idx.clone() if g_idx is not None else torch.tensor([i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        half_scales = scales.clone().half()
        # print("scale shape ", scales.shape, scale_zeros.shape, linear.weight.shape)
        self.scales = scales.clone().half()
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()

        wn = 16
        pbits = 64
        ptype = torch.int64
        unsign_type = np.uint64
        sign_type = np.int64

        # wn = 8
        # pbits = 32
        # ptype = torch.int32
        # unsign_type = np.uint32
        # sign_type = np.int32

        intweight = []
        for idx in range(self.infeatures):
            intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, None])
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(unsign_type)
        qweight = np.zeros((intweight.shape[0] // pbits * self.bits, intweight.shape[1]), dtype=unsign_type)

        i = 0
        row = 0
        # print("weight shape ", intweight.shape, qweight.shape, out_qweight.shape, bits)
        # print("weight shape ", intweight[0].shape, qweight[0].shape, out_qweight[0].shape)
        # print("weight value ", intweight[0], qweight[0])

        while row < qweight.shape[0]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (pbits // self.bits)):
                    qweight[row] |= intweight[j] << (self.bits * (j - i))
                i += pbits // self.bits
                row += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")
        qweight = qweight.astype(sign_type)
        qweight1 = torch.from_numpy(qweight)
        qweight1 = qweight1.contiguous()    # .to("cuda")
        self.qweight.data.copy_(qweight1)

        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
        zeros -= 1
        zeros = zeros.numpy().astype(unsign_type)
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (pbits // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += pbits // self.bits
                col += 1
            else:
                raise NotImplementedError("Only 2,4,8 bits are supported.")
        qzeros = qzeros.astype(sign_type)
        qzeros = torch.from_numpy(qzeros)
        self.qzeros.data.copy_(qzeros)

        if torch.equal(self.g_idx, g_idx):
            self.g_idx = None
        else:
            self.g_idx = g_idx


    def forward(self, x):

        cai_out = self.gptq_linear(x,
                                   self.qweight,
                                   self.scales,
                                   self.qzeros,
                                   g_idx=self.g_idx,
                                   bias=self.bias)
        return cai_out

def make_cai_quant_linear(module, names, bits, groupsize, name=''):
    if isinstance(module, CaiQuantLinear):
        return
    for attr in dir(module):
        tmp = getattr(module, attr)
        name1 = name + '.' + attr if name != '' else attr
        if name1 in names:
            delattr(module, attr)
            setattr(module, attr, CaiQuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
    for name1, child in module.named_children():
        make_cai_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
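
From the code above, pack() quantizes each weight column per group as q = round(W / scale + zero) (so W ≈ scale * (q - zero)), then packs pbits // bits consecutive q values into one int64 word, least-significant bits first; qzeros stores the packed (zero - 1) values. A self-contained sketch of that packing scheme and its inverse, illustrative only and not part of the commit:

import numpy as np

bits, pbits = 4, 64
vals = np.arange(pbits // bits, dtype=np.uint64)    # 16 values in [0, 15]

packed = np.uint64(0)
for j, v in enumerate(vals):
    packed |= v << np.uint64(bits * j)    # same shift pattern as the loop in pack()

# unpack and verify the round trip
mask = np.uint64((1 << bits) - 1)
unpacked = np.array([(packed >> np.uint64(bits * j)) & mask for j in range(pbits // bits)], dtype=np.uint64)
assert np.array_equal(unpacked, vals)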

44 changes: 44 additions & 0 deletions colossalai/gptq/cai_gptq/gptq_op.py
@@ -0,0 +1,44 @@
from .gptq_triton import gptq_fused_linear_triton
import torch


class CaiGPTQLinearOp(torch.nn.Module):

    def __init__(self, gptq_group_size, gptq_quant_bits):
        super(CaiGPTQLinearOp, self).__init__()
        self.group_size = gptq_group_size
        self.bits = gptq_quant_bits
        self.maxq = 2**self.bits - 1
        self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device())

    def forward(self,
                input: torch.Tensor,
                weight: torch.Tensor,
                weight_scales: torch.Tensor,
                weight_zeros: torch.Tensor,
                g_idx: torch.Tensor = None,
                act_type=0,
                bias: torch.Tensor = None,
                residual: torch.Tensor = None,
                qkv_fused=False):

        add_bias = True
        if bias is None:
            bias = self.empty_tensor
            add_bias = False

        add_residual = True
        if residual is None:
            residual = self.empty_tensor
            add_residual = False
        x = input.view(-1, input.shape[-1])

        out = gptq_fused_linear_triton(x, weight, weight_scales, weight_zeros, bias, residual,
                                       self.bits, self.maxq, self.group_size, qkv_fused, add_bias, add_residual,
                                       act_type=act_type, g_idx=g_idx)
        if qkv_fused:
            out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1])
        else:
            out = out.view(input.shape[0], input.shape[1], weight.shape[-1])

        return out
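
Putting the wrapper together: CaiGPTQLinearOp.forward flattens a (batch, seq, hidden) activation, calls the fused Triton kernel, and reshapes the result to (batch, seq, out_features), with an extra leading 3 when qkv_fused is set. A rough end-to-end usage sketch through CaiQuantLinear, assuming a CUDA device and buffers already filled via pack(); it is not verified against gptq_triton.py, which is not shown above:

import torch
from colossalai.gptq import CaiQuantLinear

layer = CaiQuantLinear(bits=4, groupsize=128, infeatures=4096, outfeatures=4096, bias=True).cuda()
# layer.pack(fp16_linear, scales, zeros, g_idx) would populate qweight/qzeros/scales here.

x = torch.randn(1, 16, 4096, dtype=torch.float16, device='cuda')
y = layer(x)    # expected shape: (1, 16, 4096)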
