Commit 68c727a (1 parent: 6b7992a)
Showing 7 changed files with 260 additions and 20 deletions.
@@ -1,5 +1,6 @@
-from .exllama import WQLinear_Exllama
-from .exllamav2 import WQLinear_ExllamaV2
+from .exllama import WQLinear_Exllama, exllama_post_init
+from .exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
 from .gemm import WQLinear_GEMM
 from .gemv import WQLinear_GEMV
-from .marlin import WQLinear_Marlin
+from .marlin import WQLinear_Marlin, marlin_post_init
+from .gemv_fast import WQLinear_GEMVFast
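For context, this updated __init__ (presumably awq/modules/linear/__init__.py) re-exports the new GEMVFast layer together with the backend-specific post-init hooks, so downstream code can import everything from one place. A minimal import sketch, assuming that package path:

from awq.modules.linear import (
    WQLinear_GEMM,
    WQLinear_GEMV,
    WQLinear_GEMVFast,
    exllama_post_init,
    exllamav2_post_init,
    marlin_post_init,
)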
@@ -0,0 +1,209 @@
import torch

try:
    import awq_v2_ext  # with CUDA kernels (AutoAWQ_kernels)

    AWQ_INSTALLED = True
except:
    AWQ_INSTALLED = False

def make_divisible(c, divisor):
    return (c + divisor - 1) // divisor


def calculate_zeros_width(in_features, group_size=128, pack_num=8):
    if group_size >= 128:
        size_multiplier = 1
    elif group_size == 64:
        size_multiplier = 2
    elif group_size == 32:
        size_multiplier = 4
    else:
        raise NotImplementedError

    base_width = make_divisible(in_features // group_size, pack_num)
    base_width = make_divisible(base_width, size_multiplier) * size_multiplier
    return base_width

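# Editor's note (illustrative, not part of the original commit): with
# in_features=11008 and group_size=128 there are 86 groups;
# make_divisible(86, 8) = 11, so calculate_zeros_width returns 11 and the
# scales/qzeros buffers below get 11 * pack_num = 88 rows, i.e. the 86
# per-group values are padded up to a multiple of the int32 packing width.
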
def pack_intweight(unpacked_qweight, interleave, kstride):
    # unpacked_qweight: [N, K]
    N = unpacked_qweight.shape[0]
    K = unpacked_qweight.shape[1]

    Packed_Kernel = unpacked_qweight.cpu().numpy().reshape(N, K // 32, 32)
    # np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...]
    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4)
    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 32)

    # reorder each 8 weights for fast dequantization
    # [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]
    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 8)
    Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3)
    Packed_Kernel = Packed_Kernel.reshape(N, K)

    # interleaving every four rows
    Packed_Kernel = Packed_Kernel.reshape(
        N // interleave, interleave, K // kstride, kstride
    )
    # N // 4, K // 64, 4, 64
    Packed_Kernel = Packed_Kernel.transpose(0, 2, 1, 3)
    Packed_Kernel = Packed_Kernel.reshape(
        N // interleave, K // kstride, kstride, interleave
    )
    # Packing -> (N // 4, K // 64, 64)
    Packed_Kernel = (
        Packed_Kernel[..., 0]
        | (Packed_Kernel[..., 1] << 4)
        | (Packed_Kernel[..., 2] << 8)
        | (Packed_Kernel[..., 3] << 12)
    )
    # reshape to (N // 4, K), FP16 format
    Packed_Kernel = Packed_Kernel.reshape(N // interleave, K)
    qweight = (
        torch.tensor(Packed_Kernel.astype("int16"))
        .to(unpacked_qweight.device)
        .contiguous()
    )
    return qweight

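# Editor's note (added for clarity, not in the original commit): the packed
# qweight returned above has shape [N // interleave, K] and dtype int16; each
# int16 element holds the 4-bit weights of four interleaved output channels at
# the same input position, which is presumably the layout the CUDA kernels
# expect.
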
class WQLinear_GEMVFast(torch.nn.Module):
    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
        super().__init__()

        self.in_features = in_features
        self.out_features = out_features
        self.w_bit = w_bit
        self.group_size = group_size if group_size != -1 else in_features
        self.split_k_iters = 8
        self.interleave = 4

        # quick sanity check (make sure alignment)
        assert self.in_features % self.group_size == 0
        assert out_features % (32 // self.w_bit) == 0
        pack_num = 32 // self.w_bit
        int16_pack_num = 16 // self.w_bit

        assert out_features % (self.interleave) == 0
        self.register_buffer(
            "qweight",
            torch.zeros(
                (
                    out_features // self.interleave,
                    in_features // int16_pack_num * self.interleave,
                ),
                dtype=torch.int16,
                device=dev,
            ),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (
                    calculate_zeros_width(in_features, self.group_size) * pack_num,
                    out_features,
                ),
                dtype=torch.float16,
                device=dev,
            ),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros(
                (
                    calculate_zeros_width(in_features, self.group_size) * pack_num,
                    out_features,
                ),
                dtype=torch.float16,
                device=dev,
            ),
        )

        if bias:
            self.register_buffer(
                "bias", torch.zeros((out_features), dtype=torch.float16, device=dev)
            )
        else:
            self.bias = None

    @classmethod
    def from_linear(
        cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
    ):
        awq_linear = cls(
            w_bit,
            group_size,
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
        )
        if init_only:
            return awq_linear

        # need scales and zeros info for real quantization
        assert scales is not None and zeros is not None
        scale_zeros = zeros * scales

        pack_num = 32 // awq_linear.w_bit
        qscales = torch.zeros(
            (
                scales.shape[0],
                calculate_zeros_width(linear.in_features, group_size) * pack_num,
            ),
            dtype=torch.float16,
            device=scales.device,
        )
        qscales[:, : scales.shape[1]] = scales
        # awq_linear.scales = scales.clone().half()
        awq_linear.scales = qscales.transpose(1, 0).contiguous()
        if linear.bias is not None:
            awq_linear.bias = linear.bias.clone().half()

        intweight = []
        for idx in range(awq_linear.in_features):
            intweight.append(
                torch.round(
                    (linear.weight.data[:, idx] + scale_zeros[:, idx // group_size])
                    / qscales[:, idx // group_size]
                ).to(torch.int)[:, None]
            )
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.to(dtype=torch.int32)
        awq_linear.qweight = pack_intweight(
            intweight.contiguous(), interleave=4, kstride=64
        )

        zeros = zeros.to(dtype=torch.int32)
        qzeros = torch.zeros_like(qscales)

        qzeros[:, : scales.shape[1]] = -(
            qscales[:, : scales.shape[1]] * (zeros.to(torch.float32))
        ).to(torch.float16)
        awq_linear.qzeros = qzeros.transpose(1, 0).contiguous()

        return awq_linear

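    # Editor's note (added for clarity, not in the original commit): qzeros is
    # stored as -(scale * zero_point) rather than the raw zero point, so the
    # kernel can presumably dequantize with a single fused multiply-add:
    # w_fp16 = q * scale + qzero == (q - zero_point) * scale.
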
    @torch.no_grad()
    def forward(self, x):
        inputs = x
        # fewer than 8 tokens: use the decode (GEMV) kernel, otherwise the prefill (GEMM) kernel
        if inputs.numel() / inputs.shape[-1] < 8:
            out = awq_v2_ext.gemv_forward_cuda_decode(
                inputs,
                self.qweight,
                self.scales,
                self.qzeros,
                inputs.numel() // inputs.shape[-1],
                self.out_features,
                self.in_features,
                self.group_size,
            )
        else:
            out = awq_v2_ext.gemm_forward_cuda_prefill(
                inputs, self.qweight, self.scales, self.qzeros
            )
        out = out + self.bias if self.bias is not None else out

        return out
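A minimal end-to-end sketch of how this layer might be used (editor's addition, not part of the commit). It assumes a CUDA device and the awq_v2_ext kernel extension from AutoAWQ_kernels, and it uses synthetic per-group scales/zeros in place of a real AWQ calibration pass:

import torch
import torch.nn as nn

from awq.modules.linear.gemv_fast import WQLinear_GEMVFast  # presumed module path

in_features, out_features, group_size, w_bit = 128, 128, 128, 4
linear = nn.Linear(in_features, out_features, bias=True).half().cuda()

# Synthetic quantization parameters: one scale/zero per output channel per group,
# shaped [out_features, in_features // group_size].
scales = (linear.weight.abs().amax(dim=1, keepdim=True) / 7).clamp(min=1e-4)
zeros = torch.full_like(scales, 8.0)  # mid-range zero point for 4-bit weights

q_linear = WQLinear_GEMVFast.from_linear(
    linear, w_bit, group_size, scales=scales, zeros=zeros
)

x = torch.randn(1, in_features, dtype=torch.float16, device="cuda")
y = q_linear(x)  # a single token dispatches to the decode (GEMV) kernel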