[gptq] add gptq kernel #4416

Merged: 15 commits on Aug 21, 2023
7 changes: 7 additions & 0 deletions colossalai/gptq/__init__.py
@@ -0,0 +1,7 @@
from .cai_gptq import HAS_AUTO_GPTQ

if HAS_AUTO_GPTQ:
    from .cai_gptq import (gptq_fused_linear_triton, make_cai_quant_linear,
                           CaiQuantLinear, CaiGPTQLinearOp)
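
These re-exports only exist when auto-gptq is importable, so downstream code should check the flag before using them. A minimal usage sketch (the model variable and the dotted layer names are placeholders, not from this PR):

    from colossalai.gptq import HAS_AUTO_GPTQ

    if HAS_AUTO_GPTQ:
        from colossalai.gptq import make_cai_quant_linear
        # Replace every nn.Linear whose dotted name appears in `names`
        # with a 4-bit CaiQuantLinear of the same shape.
        make_cai_quant_linear(model, names={'attn.qkv_proj', 'mlp.up_proj'},
                              bits=4, groupsize=128)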


14 changes: 14 additions & 0 deletions colossalai/gptq/cai_gptq/__init__.py
@@ -0,0 +1,14 @@
import warnings

HAS_AUTO_GPTQ = False
try:
    import auto_gptq
    HAS_AUTO_GPTQ = True
except ImportError:
    warnings.warn('please install auto-gptq from https://github.com/PanQiWei/AutoGPTQ')
    HAS_AUTO_GPTQ = False

if HAS_AUTO_GPTQ:
    from .gptq_triton import gptq_fused_linear_triton
    from .cai_quant_linear import make_cai_quant_linear, CaiQuantLinear
    from .gptq_op import CaiGPTQLinearOp
131 changes: 131 additions & 0 deletions colossalai/gptq/cai_gptq/cai_quant_linear.py
@@ -0,0 +1,131 @@

import math

import numpy as np
import torch
import torch.nn as nn
import triton

from .gptq_op import CaiGPTQLinearOp


class CaiQuantLinear(nn.Module):

    def __init__(self, bits, groupsize, infeatures, outfeatures, bias):
        super().__init__()
        if bits not in [2, 4, 8]:
            raise NotImplementedError("Only 2, 4, 8 bits are supported.")
        self.infeatures = infeatures
        self.outfeatures = outfeatures
        self.bits = bits
        self.maxq = 2**self.bits - 1
        self.groupsize = groupsize if groupsize != -1 else infeatures

        # Quantized weights and zero points are packed into int64 words; each
        # word holds 64 // bits values, so the packed dimension shrinks by
        # that factor.
        self.register_buffer('qweight', torch.zeros((infeatures // 64 * self.bits, outfeatures), dtype=torch.int64))
        self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 64 * self.bits), dtype=torch.int64))
        self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16))
        # g_idx maps each input feature to its quantization group.
        self.register_buffer('g_idx', torch.tensor([i // self.groupsize for i in range(infeatures)], dtype=torch.int32))

        if bias:
            self.register_buffer('bias', torch.zeros((outfeatures,), dtype=torch.float16))
        else:
            self.bias = None

        self.gptq_linear = CaiGPTQLinearOp(groupsize, bits)
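        # Illustrative sizing (not part of the PR): with bits=4, groupsize=128,
        # infeatures=outfeatures=4096, each int64 word packs 64 // 4 = 16 codes,
        # giving qweight of shape (4096 // 64 * 4, 4096) = (256, 4096),
        # qzeros of shape (32, 256), and scales of shape (32, 4096).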


    def pack(self, linear, scales, zeros, g_idx=None):
        # Fall back to the trivial contiguous grouping when no permutation is given.
        g_idx = g_idx.clone() if g_idx is not None else torch.tensor(
            [i // self.groupsize for i in range(self.infeatures)], dtype=torch.int32)

        scales = scales.t().contiguous()
        zeros = zeros.t().contiguous()
        scale_zeros = zeros * scales
        half_scales = scales.clone().half()
        self.scales = scales.clone().half()
        if linear.bias is not None:
            self.bias = linear.bias.clone().half()

        # 64-bit packing; a 32-bit variant would use pbits = 32 with
        # torch.int32 / np.uint32 / np.int32 instead.
        pbits = 64
        ptype = torch.int64
        unsign_type = np.uint64
        sign_type = np.int64

        # Quantize: round((w + zero * scale) / scale) per input feature, using
        # the scale and zero point of that feature's group.
        intweight = []
        for idx in range(self.infeatures):
            intweight.append(
                torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx[idx]]) / half_scales[g_idx[idx]]).to(ptype)[:, None])
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.numpy().astype(unsign_type)
        qweight = np.zeros((intweight.shape[0] // pbits * self.bits, intweight.shape[1]), dtype=unsign_type)

        # Pack pbits // bits consecutive quantized rows into each int64 row.
        i = 0
        row = 0
        while row < qweight.shape[0]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (pbits // self.bits)):
                    qweight[row] |= intweight[j] << (self.bits * (j - i))
                i += pbits // self.bits
                row += 1
            else:
                raise NotImplementedError("Only 2, 4, 8 bits are supported.")
        qweight = qweight.astype(sign_type)
        self.qweight.data.copy_(torch.from_numpy(qweight).contiguous())

        # Pack the (zero - 1) values column-wise, mirroring the weight packing.
        qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // pbits * self.bits), dtype=unsign_type)
        zeros -= 1
        zeros = zeros.numpy().astype(unsign_type)
        i = 0
        col = 0
        while col < qzeros.shape[1]:
            if self.bits in [2, 4, 8]:
                for j in range(i, i + (pbits // self.bits)):
                    qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
                i += pbits // self.bits
                col += 1
            else:
                raise NotImplementedError("Only 2, 4, 8 bits are supported.")
        qzeros = qzeros.astype(sign_type)
        self.qzeros.data.copy_(torch.from_numpy(qzeros))

        # Drop g_idx when it is the trivial mapping so the kernel can skip it.
        if torch.equal(self.g_idx, g_idx):
            self.g_idx = None
        else:
            self.g_idx = g_idx


    def forward(self, x):
        cai_out = self.gptq_linear(x,
                                   self.qweight,
                                   self.scales,
                                   self.qzeros,
                                   g_idx=self.g_idx,
                                   bias=self.bias)
        return cai_out

def make_cai_quant_linear(module, names, bits, groupsize, name=''):
    # Recursively replace the nn.Linear modules whose dotted names appear in
    # `names` with CaiQuantLinear modules of the same shape.
    if isinstance(module, CaiQuantLinear):
        return
    for attr in dir(module):
        tmp = getattr(module, attr)
        name1 = name + '.' + attr if name != '' else attr
        if name1 in names:
            delattr(module, attr)
            setattr(module, attr, CaiQuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None))
    for name1, child in module.named_children():
        make_cai_quant_linear(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
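
The row-major bit-packing in pack() is easiest to see with plain Python integers. A self-contained sketch (illustrative only, using bits = 4 so that 16 codes fill one 64-bit word):

    bits = 4
    codes = [3, 7, 0, 15, 1, 2, 4, 8, 5, 6, 9, 10, 11, 12, 13, 14]  # 64 // bits values

    # Pack: code j occupies the bit range [bits * j, bits * (j + 1)).
    word = 0
    for j, v in enumerate(codes):
        word |= v << (bits * j)

    # Unpack: shift back down and mask to recover each code.
    unpacked = [(word >> (bits * j)) & ((1 << bits) - 1) for j in range(64 // bits)]
    assert unpacked == codes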

44 changes: 44 additions & 0 deletions colossalai/gptq/cai_gptq/gptq_op.py
@@ -0,0 +1,44 @@
import torch

from .gptq_triton import gptq_fused_linear_triton


class CaiGPTQLinearOp(torch.nn.Module):

    def __init__(self, gptq_group_size, gptq_quant_bits):
        super(CaiGPTQLinearOp, self).__init__()
        self.group_size = gptq_group_size
        self.bits = gptq_quant_bits
        self.maxq = 2**self.bits - 1
        # Placeholder tensor passed to the kernel when bias/residual are absent.
        self.empty_tensor = torch.zeros(4, device=torch.cuda.current_device())

    def forward(self,
                input: torch.Tensor,
                weight: torch.Tensor,
                weight_scales: torch.Tensor,
                weight_zeros: torch.Tensor,
                g_idx: torch.Tensor = None,
                act_type=0,
                bias: torch.Tensor = None,
                residual: torch.Tensor = None,
                qkv_fused=False):
        add_bias = True
        if bias is None:
            bias = self.empty_tensor
            add_bias = False

        add_residual = True
        if residual is None:
            residual = self.empty_tensor
            add_residual = False

        # Flatten leading dims so the kernel sees a 2D (tokens, hidden) matmul.
        x = input.view(-1, input.shape[-1])

        out = gptq_fused_linear_triton(x, weight, weight_scales, weight_zeros, bias, residual,
                                       self.bits, self.maxq, self.group_size, qkv_fused, add_bias, add_residual,
                                       act_type=act_type, g_idx=g_idx)
        if qkv_fused:
            # A fused QKV kernel returns the three projections stacked in dim 0.
            out = out.view(3, input.shape[0], input.shape[1], weight.shape[-1])
        else:
            out = out.view(input.shape[0], input.shape[1], weight.shape[-1])

        return out
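
A rough end-to-end sketch of the call path (shapes and sizes are made up; requires CUDA, triton, and auto-gptq to be installed):

    import torch
    from colossalai.gptq import HAS_AUTO_GPTQ

    if HAS_AUTO_GPTQ:
        from colossalai.gptq import CaiQuantLinear

        layer = CaiQuantLinear(bits=4, groupsize=128, infeatures=4096,
                               outfeatures=4096, bias=True).cuda()
        x = torch.randn(2, 16, 4096, dtype=torch.float16, device='cuda')
        # forward() routes through CaiGPTQLinearOp and gptq_fused_linear_triton.
        out = layer(x)
        print(out.shape)  # torch.Size([2, 16, 4096])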