Skip to content

Commit

Permalink
New optimized kernels (#365)
Browse files Browse the repository at this point in the history
  • Loading branch information
casper-hansen authored Feb 24, 2024
1 parent 6b7992a commit 68c727a
Show file tree
Hide file tree
Showing 7 changed files with 260 additions and 20 deletions.
18 changes: 13 additions & 5 deletions awq/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,17 @@
from huggingface_hub import snapshot_download
from transformers.modeling_utils import shard_checkpoint

from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
from awq.modules.linear.marlin import WQLinear_Marlin, marlin_post_init
from awq.modules.linear.exllama import WQLinear_Exllama, exllama_post_init
from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
from awq.modules.linear import (
WQLinear_GEMM,
WQLinear_GEMV,
WQLinear_Marlin,
WQLinear_Exllama,
WQLinear_ExllamaV2,
WQLinear_GEMVFast,
marlin_post_init,
exllama_post_init,
exllamav2_post_init,
)
from awq.utils.module import (
get_named_linears,
set_op_by_name,
Expand Down Expand Up @@ -541,6 +547,8 @@ def _load_quantized_modules(
q_linear_module = WQLinear_GEMM
elif version == "gemv":
q_linear_module = WQLinear_GEMV
elif version == "gemv_fast":
q_linear_module = WQLinear_GEMVFast

q_linear = q_linear_module.from_linear(
module, quant_config.w_bit, quant_config.q_group_size, True
Expand Down
7 changes: 4 additions & 3 deletions awq/modules/linear/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .exllama import WQLinear_Exllama
from .exllamav2 import WQLinear_ExllamaV2
from .exllama import WQLinear_Exllama, exllama_post_init
from .exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
from .gemm import WQLinear_GEMM
from .gemv import WQLinear_GEMV
from .marlin import WQLinear_Marlin
from .marlin import WQLinear_Marlin, marlin_post_init
from .gemv_fast import WQLinear_GEMVFast
209 changes: 209 additions & 0 deletions awq/modules/linear/gemv_fast.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import torch

try:
import awq_v2_ext # with CUDA kernels (AutoAWQ_kernels)

AWQ_INSTALLED = True
except:
AWQ_INSTALLED = False


def make_divisible(c, divisor):
return (c + divisor - 1) // divisor


def calculate_zeros_width(in_features, group_size=128, pack_num=8):
if group_size >= 128:
size_multiplier = 1
elif group_size == 64:
size_multiplier = 2
elif group_size == 32:
size_multiplier = 4
else:
raise NotImplementedError

base_width = make_divisible(in_features // group_size, pack_num)
base_width = make_divisible(base_width, size_multiplier) * size_multiplier
return base_width


def pack_intweight(unpacked_qweight, interleave, kstride):
# unpacked_qweight: [N, K]
N = unpacked_qweight.shape[0]
K = unpacked_qweight.shape[1]

Packed_Kernel = unpacked_qweight.cpu().numpy().reshape(N, K // 32, 32)
# np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...]
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4)
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 32)

# reorder each 8 weights for fast dequantization
# [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 8)
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3)
Packed_Kernel = Packed_Kernel.reshape(N, K)

# interleaving every four rows
Packed_Kernel = Packed_Kernel.reshape(
N // interleave, interleave, K // kstride, kstride
)
# N // 4, K // 64, 4, 64
Packed_Kernel = Packed_Kernel.transpose(0, 2, 1, 3)
Packed_Kernel = Packed_Kernel.reshape(
N // interleave, K // kstride, kstride, interleave
)
# Packing -> (N // 4, K // 64, 64)
Packed_Kernel = (
Packed_Kernel[..., 0]
| (Packed_Kernel[..., 1] << 4)
| (Packed_Kernel[..., 2] << 8)
| (Packed_Kernel[..., 3] << 12)
)
# reshape to (N // 4, K), FP16 format
Packed_Kernel = Packed_Kernel.reshape(N // interleave, K)
qweight = (
torch.tensor(Packed_Kernel.astype("int16"))
.to(unpacked_qweight.device)
.contiguous()
)
return qweight


class WQLinear_GEMVFast(torch.nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()

self.in_features = in_features
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
self.split_k_iters = 8
self.interleave = 4

# quick sanity check (make sure aligment)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
pack_num = 32 // self.w_bit
int16_pack_num = 16 // self.w_bit

assert out_features % (self.interleave) == 0
self.register_buffer(
"qweight",
torch.zeros(
(
out_features // self.interleave,
in_features // int16_pack_num * self.interleave,
),
dtype=torch.int16,
device=dev,
),
)
self.register_buffer(
"scales",
torch.zeros(
(
calculate_zeros_width(in_features, self.group_size) * pack_num,
out_features,
),
dtype=torch.float16,
device=dev,
),
)
self.register_buffer(
"qzeros",
torch.zeros(
(
calculate_zeros_width(in_features, self.group_size) * pack_num,
out_features,
),
dtype=torch.float16,
device=dev,
),
)

if bias:
self.register_buffer(
"bias", torch.zeros((out_features), dtype=torch.float16, device=dev)
)
else:
self.bias = None

@classmethod
def from_linear(
cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only:
return awq_linear

# need scales and zeros info for real quantization
assert scales is not None and zeros is not None
scale_zeros = zeros * scales

pack_num = 32 // awq_linear.w_bit
qscales = torch.zeros(
(
scales.shape[0],
calculate_zeros_width(linear.in_features, group_size) * pack_num,
),
dtype=torch.float16,
device=scales.device,
)
qscales[:, : scales.shape[1]] = scales
# awq_linear.scales = scales.clone().half()
awq_linear.scales = qscales.transpose(1, 0).contiguous()
if linear.bias is not None:
awq_linear.bias = linear.bias.clone().half()

intweight = []
for idx in range(awq_linear.in_features):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[:, idx // group_size])
/ qscales[:, idx // group_size]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.to(dtype=torch.int32)
awq_linear.qweight = pack_intweight(
intweight.contiguous(), interleave=4, kstride=64
)

zeros = zeros.to(dtype=torch.int32)
qzeros = torch.zeros_like(qscales)

qzeros[:, : scales.shape[1]] = -(
qscales[:, : scales.shape[1]] * (zeros.to(torch.float32))
).to(torch.float16)
awq_linear.qzeros = qzeros.transpose(1, 0).contiguous()

return awq_linear

@torch.no_grad()
def forward(self, x):
inputs = x
if inputs.numel() / inputs.shape[-1] < 8:
out = awq_v2_ext.gemv_forward_cuda_decode(
inputs,
self.qweight,
self.scales,
self.qzeros,
inputs.numel() // inputs.shape[-1],
self.out_features,
self.in_features,
self.group_size,
)
else:
out = awq_v2_ext.gemm_forward_cuda_prefill(
inputs, self.qweight, self.scales, self.qzeros
)
out = out + self.bias if self.bias is not None else out

return out
14 changes: 10 additions & 4 deletions awq/quantize/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.scale import apply_scale, apply_clip
from awq.utils.utils import clear_memory, get_best_device
from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
from awq.modules.linear.marlin import WQLinear_Marlin
from awq.modules.linear import (
WQLinear_GEMM,
WQLinear_GEMV,
WQLinear_Marlin,
WQLinear_GEMVFast,
)
from awq.utils.module import (
append_str_prefix,
get_op_name,
Expand Down Expand Up @@ -200,6 +203,9 @@ def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):

elif self.version == "marlin":
q_linear_module = WQLinear_Marlin

elif self.version == "gemv_fast":
q_linear_module = WQLinear_GEMVFast

else:
raise ValueError(f"Unknown version {self.version}")
Expand Down Expand Up @@ -466,6 +472,7 @@ def forward(self, *args, **kwargs):
self.model(samples.to(next(self.model.parameters()).device))
except ValueError: # work with early exit
pass
modules[0] = modules[0].module # restore

# Update the layer kwargs with `prepare_inputs_for_generation` method
# that takes care of everything to avoid unexpected errors.
Expand All @@ -474,7 +481,6 @@ def forward(self, *args, **kwargs):
layer_kwargs.pop("input_ids")

del samples
modules[0] = modules[0].module # restore
inps = inps[0]

modules[0] = modules[0].cpu()
Expand Down
26 changes: 21 additions & 5 deletions awq/utils/fused_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import torch

from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
from awq.modules.linear.marlin import WQLinear_Marlin
from awq.modules.linear.exllama import WQLinear_Exllama
from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2
from awq.modules.linear import (
WQLinear_GEMM,
WQLinear_GEMV,
WQLinear_Marlin,
WQLinear_Exllama,
WQLinear_ExllamaV2,
WQLinear_GEMVFast,
)


def prepare_correct_devices(next_layer, hidden_states, mask):
Expand Down Expand Up @@ -73,6 +76,8 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
q_linear = WQLinear_ExllamaV2
elif isinstance(q_proj, WQLinear_Marlin):
q_linear = WQLinear_Marlin
elif isinstance(q_proj, WQLinear_GEMVFast):
q_linear = WQLinear_GEMVFast

qkv_layer = q_linear(
q_proj.w_bit,
Expand Down Expand Up @@ -132,6 +137,17 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
[q_proj.scales, k_proj.scales, v_proj.scales], dim=1
)
# workspace is created in post_init
elif isinstance(q_proj, WQLinear_GEMVFast):
qkv_layer.qweight = torch.cat(
[q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0
)
qkv_layer.qzeros = torch.cat(
[q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1
).contiguous()
qkv_layer.scales = torch.cat(
[q_proj.scales, k_proj.scales, v_proj.scales], dim=1
).contiguous()
qkv_layer.split_k_iters = q_proj.split_k_iters

qkv_layer.bias = bias

Expand Down
5 changes: 3 additions & 2 deletions examples/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,12 @@ def run_round(generator, model_path, quant_file, n_generate, input_ids, batch_si
raise RuntimeError(ex)

total_memory_used = 0
memory_pct = 100
if successful_generate:
# number of tokens in context / time for processing context * batch size
prefill_tokens_per_second = input_ids.shape[1] / context_time * batch_size
prefill_tokens_per_second = round(input_ids.shape[1] / context_time * batch_size, 2)
# 1 second / median time per token in seconds * batch size
decode_tokens_per_second = 1 / np.median(generate_time) * batch_size
decode_tokens_per_second = round(1 / np.median(generate_time) * batch_size, 2)

print(f" ** Speed (Prefill): {prefill_tokens_per_second:.2f} tokens/second")
print(f" ** Speed (Decode): {decode_tokens_per_second:.2f} tokens/second")
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,6 @@ def get_kernels_whl_url(
"transformers>=4.35.0",
"tokenizers>=0.12.1",
"typing_extensions>=4.8.0",
"triton",
"accelerate",
"datasets",
"zstandard",
Expand Down

0 comments on commit 68c727a

Please sign in to comment.