diff --git a/QUANTIZE.md b/QUANTIZE.md new file mode 100644 index 000000000..f061b19da --- /dev/null +++ b/QUANTIZE.md @@ -0,0 +1,35 @@ +## Quantization to 4-bit for CPU inference + +First, make sure to install all [dependencies](../INSTALL.md). + +Select the model to quantize: +```bash +MODEL=h2ogpt-oig-oasst1-512-6.9b +#MODEL=h2ogpt-oasst1-512-12b +#MODEL=h2ogpt-oasst1-512-20b +``` + +Run the conversion: +```bash +PYTHONPATH=. CUDA_VISIBLE_DEVICES=0 python \ + quantize/neox.py h2oai/${MODEL} wikitext2 \ + --wbits 4 \ + --save ${MODEL}-4bit.pt +``` + +Now test the model: +```bash +CUDA_VISIBLE_DEVICES=0 python \ + quantize/inference.py h2oai/${MODEL} \ + --wbits 4 \ + --load ${MODEL}-4bit.pt \ + --text "Tell me a joke about cookies." +``` + +FIXME: creates garbage output +``` +The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results. +Setting `pad_token_id` to `eos_token_id`:0 for open-end generation. +Tell me a joke about cookies.� emitted column Sullivan Meatrah thinkers TemplateSLavorable homologous beat qubit WD differentiallyabstractKBchu + 260 econ environments unitaryimage endorse physicistisksaines observables preference euthan Creation 580 blinkowa metrics extrac lowered Raz proportions numerically claimant Plugin +``` \ No newline at end of file diff --git a/quantize/gptq.py b/quantize/gptq.py new file mode 100644 index 000000000..530fb9be9 --- /dev/null +++ b/quantize/gptq.py @@ -0,0 +1,251 @@ +import math +import time + +import torch +import torch.nn as nn +import transformers +import quant +from texttable import Texttable +from utils import torch_snr_error + +torch.backends.cuda.matmul.allow_tf32 = False +torch.backends.cudnn.allow_tf32 = False + + +class Observer: + def __init__(self, topk=32): + self.loss_list = [] + self.topk = topk + + def submit(self, name: str, layerid: int, gptq, error: float): + + item = (name, layerid, {"gptq": gptq, "error": error}) + + if len(self.loss_list) < self.topk: + self.loss_list.append(item) + return + + min_error = error + min_idx = -1 + for idx, data in enumerate(self.loss_list): + if min_error > data[2]["error"]: + min_idx = idx + min_error = data[2]["error"] + + if min_idx >= 0: + self.loss_list[min_idx] = item + + def print(self): + self.loss_list = sorted( + self.loss_list, key=lambda s: s[2]["error"], reverse=True + ) + + table = Texttable() + + table.header(["name", "error"]) + table.set_cols_dtype(["t", "f"]) + + for item in self.loss_list: + table.add_row([f"{item[0]}.{item[1]}", item[2]["error"]]) + print(table.draw()) + print("\n") + + def items(self): + return self.loss_list + + +class GPTQ: + def __init__(self, layer, observe=False): + self.layer = layer + self.dev = self.layer.weight.device + W = layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + self.rows = W.shape[0] + self.columns = W.shape[1] + self.H = torch.zeros((self.columns, self.columns), device=self.dev) + self.nsamples = 0 + self.quantizer = quant.Quantizer() + self.observe = observe + + def add_batch(self, inp, out): + # Hessian H = 2 X XT + λ I + if self.observe: + self.inp1 = inp + self.out1 = out + else: + self.inp1 = None + self.out1 = None + + if len(inp.shape) == 2: + inp = inp.unsqueeze(0) + tmp = inp.shape[0] + if isinstance(self.layer, nn.Linear) or isinstance( + self.layer, transformers.Conv1D + ): + if len(inp.shape) == 3: + inp = 
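The streaming Hessian update in `GPTQ.add_batch` above (the running average behind the `# Hessian H = 2 X XT + λ I` comment) is easier to verify in isolation. Below is a minimal sketch with synthetic tensors — `hessian_update` and the shapes are illustrative, not part of this PR — showing that shrinking the old estimate by `n/(n+m)` and scaling the new batch by `sqrt(2/(n+m))` reproduces `2/N · X·Xᵀ` over all samples seen so far:

```python
import torch

def hessian_update(H, nsamples, inp):
    # Streaming estimate of H ≈ (2 / N) * X @ X.T, mirroring GPTQ.add_batch above.
    tmp = inp.shape[0]                                 # samples in this batch
    inp = inp.reshape(-1, inp.shape[-1]).t().float()   # (features, tokens)
    H = H * (nsamples / (nsamples + tmp))              # shrink the old average
    nsamples += tmp
    inp = (2.0 / nsamples) ** 0.5 * inp                # the sqrt(2/nsamples) scaling above
    return H + inp.matmul(inp.t()), nsamples

torch.manual_seed(0)
X1 = torch.randn(4, 8, 16)                             # two calibration batches,
X2 = torch.randn(4, 8, 16)                             # shaped (batch, seq, features)
H, n = torch.zeros(16, 16), 0
H, n = hessian_update(H, n, X1)
H, n = hessian_update(H, n, X2)

# Reference: 2/N * X Xᵀ computed over all 8 samples at once.
X = torch.cat([X1, X2]).reshape(-1, 16).t()
H_ref = 2.0 / 8 * X @ X.t()
assert torch.allclose(H, H_ref, atol=1e-4)
```

The `λ·I` damping from the comment is only applied later, in `fasterquant`, as `percdamp` times the mean of the Hessian diagonal.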
inp.reshape((-1, inp.shape[-1])) + inp = inp.t() + if isinstance(self.layer, nn.Conv2d): + unfold = nn.Unfold( + self.layer.kernel_size, + dilation=self.layer.dilation, + padding=self.layer.padding, + stride=self.layer.stride, + ) + inp = unfold(inp) + inp = inp.permute([1, 0, 2]) + inp = inp.flatten(1) + self.H *= self.nsamples / (self.nsamples + tmp) + self.nsamples += tmp + # inp = inp.float() + inp = math.sqrt(2 / self.nsamples) * inp.float() + # self.H += 2 / self.nsamples * inp.matmul(inp.t()) + self.H += inp.matmul(inp.t()) + + def print_loss(self, name, q_weight, weight_error, timecost): + table = Texttable() + name += " " * (16 - len(name)) + + table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"]) + + # assign weight + self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to( + self.layer.weight.data.dtype + ) + + if self.inp1 is not None: + # quantize input to int8 + quantizer = quant.Quantizer() + quantizer.configure(8, perchannel=False, sym=True, mse=False) + quantizer.find_params(self.inp1) + q_in = quantizer.quantize(self.inp1).type(torch.float16) + q_out = self.layer(q_in) + + # get kinds of SNR + q_SNR = torch_snr_error(q_out, self.out1).item() + fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item() + else: + q_SNR = "-" + fp_SNR = "-" + + table.add_row([name, weight_error, fp_SNR, q_SNR, timecost]) + print(table.draw().split("\n")[-2]) + + def fasterquant( + self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False, name="" + ): + self.layer.to(self.dev) + + W = self.layer.weight.data.clone() + if isinstance(self.layer, nn.Conv2d): + W = W.flatten(1) + if isinstance(self.layer, transformers.Conv1D): + W = W.t() + W = W.float() + + tick = time.time() + + if not self.quantizer.ready(): + self.quantizer.find_params(W, weight=True) + + H = self.H + if not self.observe: + del self.H + dead = torch.diag(H) == 0 + H[dead, dead] = 1 + W[:, dead] = 0 + + if actorder: + perm = torch.argsort(torch.diag(H), descending=True) + W = W[:, perm] + H = H[perm][:, perm] + + Losses = torch.zeros_like(W) + Q = torch.zeros_like(W) + + damp = percdamp * torch.mean(torch.diag(H)) + diag = torch.arange(self.columns, device=self.dev) + H[diag, diag] += damp + H = torch.linalg.cholesky(H) + H = torch.cholesky_inverse(H) + H = torch.linalg.cholesky(H, upper=True) + Hinv = H + + g_idx = [] + scale = [] + zero = [] + now_idx = 1 + + for i1 in range(0, self.columns, blocksize): + i2 = min(i1 + blocksize, self.columns) + count = i2 - i1 + + W1 = W[:, i1:i2].clone() + Q1 = torch.zeros_like(W1) + Err1 = torch.zeros_like(W1) + Losses1 = torch.zeros_like(W1) + Hinv1 = Hinv[i1:i2, i1:i2] + + for i in range(count): + w = W1[:, i] + d = Hinv1[i, i] + + if groupsize != -1: + if (i1 + i) % groupsize == 0: + self.quantizer.find_params( + W[:, (i1 + i) : (i1 + i + groupsize)], weight=True + ) + + if ((i1 + i) // groupsize) - now_idx == -1: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + now_idx += 1 + + q = self.quantizer.quantize(w.unsqueeze(1)).flatten() + Q1[:, i] = q + Losses1[:, i] = (w - q) ** 2 / d ** 2 + + err1 = (w - q) / d + W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0)) + Err1[:, i] = err1 + + Q[:, i1:i2] = Q1 + Losses[:, i1:i2] = Losses1 / 2 + + W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:]) + + torch.cuda.synchronize() + error = torch.sum(Losses).item() + + groupsize = groupsize if groupsize != -1 else self.columns + g_idx = [i // groupsize for i in range(self.columns)] + g_idx = torch.tensor(g_idx, 
dtype=torch.int32, device=Q.device) + if actorder: + invperm = torch.argsort(perm) + Q = Q[:, invperm] + g_idx = g_idx[invperm] + + if isinstance(self.layer, transformers.Conv1D): + Q = Q.t() + + self.print_loss( + name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick) + ) + + if scale == []: + scale.append(self.quantizer.scale) + zero.append(self.quantizer.zero) + scale = torch.cat(scale, dim=1) + zero = torch.cat(zero, dim=1) + return scale, zero, g_idx, error + + def free(self): + self.inp1 = None + self.out1 = None + self.H = None + self.Losses = None + self.Trace = None + torch.cuda.empty_cache() diff --git a/quantize/inference.py b/quantize/inference.py new file mode 100644 index 000000000..24d905971 --- /dev/null +++ b/quantize/inference.py @@ -0,0 +1,123 @@ +import argparse + +import torch +import torch.nn as nn +import quant + +from gptq import GPTQ +from utils import find_layers, DEV, set_seed, get_wikitext2, get_ptb, get_c4, get_ptb_new, get_c4_new, get_loaders +import transformers +from transformers import AutoTokenizer + + +def get_llama(model): + + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + from transformers import LlamaForCausalLM + model = LlamaForCausalLM.from_pretrained(model, torch_dtype='auto') + model.seqlen = 2048 + return model + + +def load_quant(model, checkpoint, wbits, groupsize=-1, fused_mlp=True, eval=True, warmup_autotune=True): + from transformers import GPTNeoXConfig, GPTNeoXForCausalLM + + config = GPTNeoXConfig.from_pretrained(model) + + def noop(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = noop + torch.nn.init.uniform_ = noop + torch.nn.init.normal_ = noop + + torch.set_default_dtype(torch.half) + transformers.modeling_utils._init_weights = False + torch.set_default_dtype(torch.half) + model = GPTNeoXForCausalLM(config) + torch.set_default_dtype(torch.float) + if eval: + model = model.eval() + layers = find_layers(model) + for name in ['lm_head']: + if name in layers: + del layers[name] + quant.make_quant_linear(model, layers, wbits, groupsize) + + del layers + + print('Loading model ...') + if checkpoint.endswith('.safetensors'): + from safetensors.torch import load_file as safe_load + model.load_state_dict(safe_load(checkpoint), strict=False) + else: + model.load_state_dict(torch.load(checkpoint), strict=False) + + quant.make_quant_attn(model) + if eval and fused_mlp: + quant.make_fused_mlp(model) + + if warmup_autotune: + quant.autotune_warmup_linear(model, transpose=not (eval)) + if eval and fused_mlp: + quant.autotune_warmup_fused(model) + model.seqlen = 2048 + print('Done.') + + return model + + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + + parser.add_argument('model', type=str, help='llama model to load') + parser.add_argument('--wbits', type=int, default=16, choices=[2, 3, 4, 8, 16], help='#bits to use for quantization; use 16 for evaluating base model.') + parser.add_argument('--groupsize', type=int, default=-1, help='Groupsize to use for quantization; default uses full row.') + parser.add_argument('--load', type=str, default='', help='Load quantized model.') + + parser.add_argument('--text', type=str, help='input text') + + parser.add_argument('--min_length', type=int, default=10, help='The minimum length of the sequence to be generated.') + + parser.add_argument('--max_length', type=int, default=50, help='The maximum length of the sequence to be generated.') + + 
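The inner loop of `fasterquant` above is the core of GPTQ: quantize one column, divide the residual by the diagonal of the inverse-Hessian Cholesky factor, and push the error into the columns that have not been quantized yet. A toy, self-contained sketch of that update follows — `quantize_rtn` is a crude round-to-nearest stand-in for the repo's `quant.Quantizer`, and the sizes and fixed scale are arbitrary:

```python
import torch

def quantize_rtn(w, scale):
    # Crude symmetric 4-bit round-to-nearest; a stand-in for quant.Quantizer.
    return torch.clamp(torch.round(w / scale), -8, 7) * scale

def gptq_block(W, Hinv_chol, scale):
    # Quantize columns left to right, spreading each column's error over the
    # remaining columns via the inverse-Hessian Cholesky factor, as fasterquant does.
    W = W.clone()
    Q = torch.zeros_like(W)
    for i in range(W.shape[1]):
        w = W[:, i]
        d = Hinv_chol[i, i]
        q = quantize_rtn(w, scale)
        Q[:, i] = q
        err = (w - q) / d
        W[:, i:] -= err.unsqueeze(1) * Hinv_chol[i, i:].unsqueeze(0)
    return Q

torch.manual_seed(0)
W = torch.randn(8, 4)
A = torch.randn(4, 4)
H = A @ A.t() + 0.1 * torch.eye(4)                    # damped SPD "Hessian"
# Same factorization chain as fasterquant: cholesky -> inverse -> upper cholesky.
Hinv = torch.cholesky_inverse(torch.linalg.cholesky(H))
Hinv_chol = torch.linalg.cholesky(Hinv, upper=True)
Q = gptq_block(W, Hinv_chol, scale=0.25)
```

With `--act-order`, the same loop runs over columns permuted by descending Hessian diagonal, and the permutation is undone afterwards, as the code above shows.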
parser.add_argument('--top_p', + type=float, + default=0.95, + help='If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation.') + + parser.add_argument('--temperature', type=float, default=0.8, help='The value used to module the next token probabilities.') + + parser.add_argument('--device', type=int, default=-1, help='The device used to load the model when using safetensors. Default device is "cpu" or specify, 0,1,2,3,... for GPU device.') + + args = parser.parse_args() + + if type(args.load) is not str: + args.load = args.load.as_posix() + + if args.load: + model = load_quant(args.model, args.load, args.wbits, args.groupsize) + else: + model = get_llama(args.model) + model.eval() + + model.to(DEV) + tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=True) + input_ids = tokenizer.encode(args.text, return_tensors="pt").to(DEV) + + with torch.no_grad(): + generated_ids = model.generate( + input_ids, + do_sample=True, + min_length=args.min_length, + max_length=args.max_length, + top_p=args.top_p, + temperature=args.temperature, + ) + print(tokenizer.decode([el.item() for el in generated_ids[0]])) diff --git a/quantize/neox.py b/quantize/neox.py new file mode 100644 index 000000000..70d59d167 --- /dev/null +++ b/quantize/neox.py @@ -0,0 +1,552 @@ +import argparse +import time +import numpy as np +import torch +import torch.nn as nn +import quant + +from gptq import GPTQ, Observer +from utils import ( + find_layers, + DEV, + set_seed, + get_wikitext2, + get_ptb, + get_c4, + get_ptb_new, + get_c4_new, + get_loaders, + export_quant_table, + gen_conditions, +) +from texttable import Texttable + + +def get_neox(model, seqlen=-1): + def skip(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + from transformers import GPTNeoXForCausalLM + + model = GPTNeoXForCausalLM.from_pretrained(model, torch_dtype=torch.float16) + model.seqlen = seqlen if seqlen != -1 else model.config.max_position_embeddings + return model + + +@torch.no_grad() +def neox_sequential(model, dataloader, dev): + print("Starting ...") + + use_cache = model.config.use_cache + model.config.use_cache = False + layers = model.gpt_neox.layers + + model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(dev) + layers[0] = layers[0].to(dev) + + dtype = next(iter(model.parameters())).dtype + inps = torch.zeros( + (args.nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev + ) + cache = {"i": 0, "attention_mask": None} + + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, inp, **kwargs): + inps[cache["i"]] = inp + cache["i"] += 1 + cache["attention_mask"] = kwargs["attention_mask"] + cache["position_ids"] = kwargs["position_ids"] + raise ValueError + + layers[0] = Catcher(layers[0]) + for batch in dataloader: + try: + model(batch[0].to(dev)) + except ValueError: + pass + layers[0] = layers[0].module + + layers[0] = layers[0].cpu() + model.gpt_neox.embed_in = model.gpt_neox.embed_in.cpu() + torch.cuda.empty_cache() + + outs = torch.zeros_like(inps) + attention_mask = cache["attention_mask"] + position_ids = cache["position_ids"] + + print("Ready.") + + quantizers = {} + observer = Observer() + for i in range(len(layers)): + + print(f"Quantizing layer {i+1}/{len(layers)}..") + print("+------------------+--------------+------------+-----------+-------+") + print("| name | 
weight_error | fp_inp_SNR | q_inp_SNR | time |") + print("+==================+==============+============+===========+=======+") + + layer = layers[i].to(dev) + full = find_layers(layer) + sequential = [list(full.keys())] + + for names in sequential: + subset = {n: full[n] for n in names} + gptq = {} + for name in subset: + gptq[name] = GPTQ(subset[name], observe=False) + gptq[name].quantizer.configure( + args.wbits, perchannel=True, sym=args.sym, mse=False + ) + + def add_batch(name): + def tmp(_, inp, out): + gptq[name].add_batch(inp[0].data, out.data) + + return tmp + + handles = [] + for name in subset: + handles.append(subset[name].register_forward_hook(add_batch(name))) + for j in range(args.nsamples): + outs[j] = layer( + inps[j].unsqueeze(0), + attention_mask=attention_mask, + position_ids=position_ids, + )[0] + for h in handles: + h.remove() + + for name in subset: + scale, zero, g_idx, error = gptq[name].fasterquant( + percdamp=args.percdamp, + groupsize=args.groupsize, + actorder=args.act_order, + name=name, + ) + quantizers["gpt_neox.layers.%d.%s" % (i, name)] = ( + gptq[name].quantizer.cpu(), + scale.cpu(), + zero.cpu(), + g_idx.cpu(), + args.wbits, + args.groupsize, + ) + gptq[name].free() + + for j in range(args.nsamples): + outs[j] = layer( + inps[j].unsqueeze(0), + attention_mask=attention_mask, + position_ids=position_ids, + )[0] + + layers[i] = layer.cpu() + del layer + del gptq + torch.cuda.empty_cache() + + inps, outs = outs, inps + print("+------------------+--------------+------------+-----------+-------+") + print("\n") + + model.config.use_cache = use_cache + + return quantizers + + +@torch.no_grad() +def neox_eval(model, testenc, dev): + print("Evaluating ...") + + testenc = testenc.input_ids + nsamples = testenc.numel() // model.seqlen + + use_cache = model.config.use_cache + model.config.use_cache = False + layers = model.gpt_neox.layers + + model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(dev) + layers[0] = layers[0].to(dev) + + dtype = next(iter(model.parameters())).dtype + inps = torch.zeros( + (nsamples, model.seqlen, model.config.hidden_size), dtype=dtype, device=dev + ) + cache = {"i": 0, "attention_mask": None} + + class Catcher(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + + def forward(self, inp, **kwargs): + inps[cache["i"]] = inp + cache["i"] += 1 + cache["attention_mask"] = kwargs["attention_mask"] + cache["position_ids"] = kwargs["position_ids"] + raise ValueError + + layers[0] = Catcher(layers[0]) + for i in range(nsamples): + batch = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)].to(dev) + try: + model(batch) + except ValueError: + pass + layers[0] = layers[0].module + + layers[0] = layers[0].cpu() + model.gpt_neox.embed_in = model.gpt_neox.embed_in.cpu() + torch.cuda.empty_cache() + + outs = torch.zeros_like(inps) + attention_mask = cache["attention_mask"] + position_ids = cache["position_ids"] + + for i in range(len(layers)): + print(i) + layer = layers[i].to(dev) + + if args.nearest: + subset = find_layers(layer) + for name in subset: + quantizer = quant.Quantizer() + quantizer.configure( + args.wbits, perchannel=True, sym=args.sym, mse=False + ) + W = subset[name].weight.data + quantizer.find_params(W, weight=True) + subset[name].weight.data = quantizer.quantize(W).to( + next(iter(layer.parameters())).dtype + ) + + for j in range(nsamples): + outs[j] = layer( + inps[j].unsqueeze(0), + attention_mask=attention_mask, + position_ids=position_ids, + )[0] + layers[i] = layer.cpu() + del layer 
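Both `neox_sequential` and `neox_eval` collect calibration inputs with the same trick: wrap the first transformer block in a `Catcher`, run the model, and abort with a `ValueError` once the block's input (plus `attention_mask`/`position_ids`) has been recorded. A stripped-down sketch of the pattern on a toy `nn.Sequential` (names and sizes are illustrative):

```python
import torch
import torch.nn as nn

class Catcher(nn.Module):
    # Wraps the first block, records its input, then aborts the forward pass.
    def __init__(self, module, store):
        super().__init__()
        self.module = module
        self.store = store

    def forward(self, x, **kwargs):
        self.store.append(x.detach())
        raise ValueError  # unwind: only the block's input is needed

blocks = nn.Sequential(nn.Linear(16, 16), nn.Linear(16, 16))
captured = []
blocks[0] = Catcher(blocks[0], captured)
try:
    blocks(torch.randn(2, 16))
except ValueError:
    pass
blocks[0] = blocks[0].module  # restore the original block
# captured[0] now holds the hidden states fed to block 0, which is what
# neox_sequential stores in `inps` before quantizing layer by layer.
```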
+ torch.cuda.empty_cache() + inps, outs = outs, inps + + model.gpt_neox.final_layer_norm = model.gpt_neox.final_layer_norm.to(dev) + model.embed_out = model.embed_out.to(dev) + + testenc = testenc.to(dev) + nlls = [] + for i in range(nsamples): + hidden_states = inps[i].unsqueeze(0) + hidden_states = model.gpt_neox.final_layer_norm(hidden_states) + lm_logits = model.embed_out(hidden_states) + shift_logits = lm_logits[:, :-1, :].contiguous() + shift_labels = testenc[:, (i * model.seqlen) : ((i + 1) * model.seqlen)][:, 1:] + loss_fct = nn.CrossEntropyLoss() + loss = loss_fct( + shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1) + ) + neg_log_likelihood = loss.float() * model.seqlen + nlls.append(neg_log_likelihood) + ppl = torch.exp(torch.stack(nlls).sum() / (nsamples * model.seqlen)) + print(ppl.item()) + + model.config.use_cache = use_cache + + +# TODO: perform packing on GPU +def neox_pack(model, quantizers, wbits, groupsize): + layers = find_layers(model) + layers = {n: layers[n] for n in quantizers} + quant.make_quant_linear(model, quantizers, wbits, groupsize) + qlayers = find_layers(model, [quant.QuantLinear]) + print("Packing ...") + for name in qlayers: + print(name) + quantizers[name], scale, zero, g_idx, _, _ = quantizers[name] + qlayers[name].pack(layers[name], scale, zero, g_idx) + print("Done.") + return model + + +def load_quant(model, checkpoint, wbits, groupsize=-1, eval=True, warmup_autotune=True): + from transformers import GPTNeoXConfig, GPTNeoXForCausalLM, modeling_utils + + config = GPTNeoXConfig.from_pretrained(model) + + def noop(*args, **kwargs): + pass + + torch.nn.init.kaiming_uniform_ = noop + torch.nn.init.uniform_ = noop + torch.nn.init.normal_ = noop + + torch.set_default_dtype(torch.half) + modeling_utils._init_weights = False + torch.set_default_dtype(torch.half) + model = GPTNeoXForCausalLM(config) + torch.set_default_dtype(torch.float) + if eval: + model = model.eval() + layers = find_layers(model) + for name in ["embed_in", "embed_out"]: + if name in layers: + del layers[name] + quant.make_quant_linear(model, layers, wbits, groupsize) + + del layers + + print("Loading model ...") + if checkpoint.endswith(".safetensors"): + from safetensors.torch import load_file as safe_load + + model.load_state_dict(safe_load(checkpoint)) + else: + model.load_state_dict(torch.load(checkpoint)) + + if warmup_autotune: + quant.autotune_warmup_linear(model, transpose=not (eval)) + + model.seqlen = model.config.max_position_embeddings + print("Done.") + + return model + + +def neox_multigpu(model, gpus): + model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(gpus[0]) + model.gpt_neox.final_layer_norm = model.gpt_neox.final_layer_norm.to(gpus[-1]) + import copy + + model.embed_out = copy.deepcopy(model.embed_out).to(gpus[-1]) + + cache = {"mask": None} + + class MoveModule(nn.Module): + def __init__(self, module): + super().__init__() + self.module = module + self.dev = next(iter(self.module.parameters())).device + + def forward(self, *inp, **kwargs): + inp = list(inp) + if inp[0].device != self.dev: + inp[0] = inp[0].to(self.dev) + if cache["mask"] is None or cache["mask"].device != self.dev: + cache["mask"] = kwargs["attention_mask"].to(self.dev) + kwargs["attention_mask"] = cache["mask"] + tmp = self.module(*inp, **kwargs) + return tmp + + layers = model.gpt_neox.layers + pergpu = math.ceil(len(layers) / len(gpus)) + for i in range(len(layers)): + layers[i] = MoveModule(layers[i].to(gpus[i // pergpu])) + + model.gpus = gpus + + +def benchmark(model, 
input_ids, check=False): + input_ids = input_ids.to(model.gpus[0] if hasattr(model, "gpus") else DEV) + torch.cuda.synchronize() + + cache = {"past": None} + + def clear_past(i): + def tmp(layer, inp, out): + if cache["past"]: + cache["past"][i] = None + + return tmp + + for i, layer in enumerate(model.gpt_neox.layers): + layer.register_forward_hook(clear_past(i)) + + print("Benchmarking ...") + + if check: + loss = nn.CrossEntropyLoss() + tot = 0.0 + + def sync(): + if hasattr(model, "gpus"): + for gpu in model.gpus: + torch.cuda.synchronize(gpu) + else: + torch.cuda.synchronize() + + max_memory = 0 + with torch.no_grad(): + attention_mask = torch.ones((1, input_ids.numel()), device=DEV) + times = [] + for i in range(input_ids.numel()): + tick = time.time() + out = model( + input_ids[:, i : i + 1], + past_key_values=cache["past"], + attention_mask=attention_mask[:, : (i + 1)].reshape((1, -1)), + ) + sync() + times.append(time.time() - tick) + print(i, times[-1]) + max_memory = max(max_memory, torch.cuda.memory_allocated() / 1024 / 1024) + if check and i != input_ids.numel() - 1: + tot += loss( + out.logits[0].to(DEV), input_ids[:, (i + 1)].to(DEV) + ).float() + cache["past"] = list(out.past_key_values) + del out + sync() + print("Median:", np.median(times)) + if check: + print("PPL:", torch.exp(tot / (input_ids.numel() - 1)).item()) + print("max memory(MiB):", max_memory) + + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + + parser.add_argument("model", type=str, help="llama model to load") + parser.add_argument( + "dataset", + type=str, + choices=["wikitext2", "ptb", "c4"], + help="Where to extract calibration data from.", + ) + parser.add_argument( + "--seed", type=int, default=0, help="Seed for sampling the calibration data." + ) + parser.add_argument( + "--nsamples", type=int, default=128, help="Number of calibration data samples." + ) + parser.add_argument( + "--percdamp", + type=float, + default=0.01, + help="Percent of the average Hessian diagonal to use for dampening.", + ) + parser.add_argument( + "--nearest", action="store_true", help="Whether to run the RTN baseline." + ) + parser.add_argument( + "--wbits", + type=int, + default=16, + choices=[2, 3, 4, 8, 16], + help="bits to use for quantization; use 16 for evaluating base model.", + ) + parser.add_argument( + "--seqlen", + type=int, + default=-1, + help="seqlen to use for quantization; default uses full seqlen", + ) + parser.add_argument( + "--trits", action="store_true", help="Whether to use trits for quantization." + ) + parser.add_argument( + "--groupsize", + type=int, + default=-1, + help="Groupsize to use for quantization; default uses full row.", + ) + parser.add_argument("--eval", action="store_true", help="evaluate quantized model.") + parser.add_argument( + "--save", + type=str, + default="", + help="Save quantized checkpoint under this name.", + ) + parser.add_argument( + "--save_safetensors", + type=str, + default="", + help="Save quantized `.safetensors` checkpoint under this name.", + ) + parser.add_argument("--load", type=str, default="", help="Load quantized model.") + parser.add_argument( + "--benchmark", + type=int, + default=0, + help="Number of tokens to use for benchmarking.", + ) + parser.add_argument( + "--check", + action="store_true", + help="Whether to compute perplexity during benchmarking for verification.", + ) + parser.add_argument( + "--sym", action="store_true", help="Whether to perform symmetric quantization." 
+ ) + parser.add_argument( + "--act-order", + action="store_true", + help="Whether to apply the activation order GPTQ heuristic", + ) + parser.add_argument( + "--new-eval", action="store_true", help="Whether to use the new PTB and C4 eval" + ) + args = parser.parse_args() + + if type(args.load) is not str: + args.load = args.load.as_posix() + + if args.load: + model = load_quant(args.model, args.load, args.wbits, args.groupsize) + else: + model = get_neox(args.model) + model.eval() + + dataloader, testloader = get_loaders( + args.dataset, + nsamples=args.nsamples, + seed=args.seed, + model=args.model, + seqlen=model.seqlen, + ) + + if not args.load and args.wbits < 16 and not args.nearest: + tick = time.time() + quantizers = neox_sequential(model, dataloader, DEV) + print(time.time() - tick) + + if args.benchmark: + gpus = [torch.device("cuda:%d" % i) for i in range(torch.cuda.device_count())] + if len(gpus) > 1: + neox_multigpu(model, gpus) + else: + model = model.to(DEV) + if args.benchmark: + input_ids = next(iter(dataloader))[0][:, : args.benchmark] + benchmark(model, input_ids, check=args.check) + + if args.eval: + datasets = ["wikitext2", "ptb", "c4"] + if args.new_eval: + datasets = ["wikitext2", "ptb-new", "c4-new"] + for dataset in datasets: + dataloader, testloader = get_loaders( + dataset, seed=args.seed, model=args.model, seqlen=model.seqlen + ) + print(dataset) + neox_eval(model, testloader, DEV) + + if args.save: + neox_pack(model, quantizers, args.wbits, args.groupsize) + torch.save(model.state_dict(), args.save) + + if args.save_safetensors: + neox_pack(model, quantizers, args.wbits, args.groupsize) + from safetensors.torch import save_file as safe_save + + state_dict = model.state_dict() + state_dict = {k: v.clone().contiguous() for k, v in state_dict.items()} + safe_save(state_dict, args.save_safetensors) diff --git a/quantize/quant/__init__.py b/quantize/quant/__init__.py new file mode 100644 index 000000000..cd639a406 --- /dev/null +++ b/quantize/quant/__init__.py @@ -0,0 +1,4 @@ +from .quantizer import Quantizer +from .fused_attn import QuantLlamaAttention, make_quant_attn +from .fused_mlp import QuantLlamaMLP, make_fused_mlp, autotune_warmup_fused +from .quant_linear import QuantLinear, make_quant_linear, autotune_warmup_linear diff --git a/quantize/quant/custom_autotune.py b/quantize/quant/custom_autotune.py new file mode 100644 index 000000000..875c832e8 --- /dev/null +++ b/quantize/quant/custom_autotune.py @@ -0,0 +1,193 @@ +#https://github.com/fpgaminer/GPTQ-triton +""" +Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100. +""" + +import builtins +import math +import time +from typing import Dict + +import triton + + +class Autotuner(triton.KernelInterface): + + def __init__(self, fn, arg_names, configs, key, reset_to_zero, prune_configs_by: Dict = None, nearest_power_of_two: bool = False): + ''' + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs. 
+ 'nearest_power_of_two'(optional): whether to round key arguments to the nearest power of two when caching tuning results + ''' + if not configs: + self.configs = [triton.Config({}, num_warps=4, num_stages=2)] + else: + self.configs = configs + self.key_idx = [arg_names.index(k) for k in key] + self.nearest_power_of_two = nearest_power_of_two + self.cache = {} + # hook to reset all required tensor to zeros before relaunching a kernel + self.hook = lambda args: 0 + if reset_to_zero is not None: + self.reset_idx = [arg_names.index(k) for k in reset_to_zero] + + def _hook(args): + for i in self.reset_idx: + args[i].zero_() + + self.hook = _hook + self.arg_names = arg_names + # prune configs + if prune_configs_by: + perf_model, top_k = prune_configs_by['perf_model'], prune_configs_by['top_k'] + if 'early_config_prune' in prune_configs_by: + early_config_prune = prune_configs_by['early_config_prune'] + else: + perf_model, top_k, early_config_prune = None, None, None + self.perf_model, self.configs_top_k = perf_model, top_k + self.early_config_prune = early_config_prune + self.fn = fn + + def _bench(self, *args, config, **meta): + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError(f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols.") + # augment meta-parameters with tunable ones + current = dict(meta, **config.kwargs) + + def kernel_call(): + if config.pre_hook: + config.pre_hook(self.nargs) + self.hook(args) + self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current) + + try: + # In testings using only 40 reps seems to be close enough and it appears to be what PyTorch uses + # PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default + return triton.testing.do_bench(kernel_call, percentiles=(0.5, 0.2, 0.8), rep=40) + except triton.compiler.OutOfResources: + return (float('inf'), float('inf'), float('inf')) + + def run(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + if len(self.configs) > 1: + key = tuple(args[i] for i in self.key_idx) + + # This reduces the amount of autotuning by rounding the keys to the nearest power of two + # In my testing this gives decent results, and greatly reduces the amount of tuning required + if self.nearest_power_of_two: + key = tuple([2**int(math.log2(x) + 0.5) for x in key]) + + if key not in self.cache: + # prune configs + pruned_configs = self.prune_configs(kwargs) + bench_start = time.time() + timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs} + bench_end = time.time() + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.hook(args) + self.configs_timings = timings + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if config.pre_hook is not None: + config.pre_hook(self.nargs) + return self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **kwargs, **config.kwargs) + + def prune_configs(self, kwargs): + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + if len(pruned_configs) > top_k: + est_timing = {config: 
self.perf_model(**self.nargs, **kwargs, **config.kwargs, num_stages=config.num_stages, num_warps=config.num_warps) for config in pruned_configs} + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + def warmup(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + for config in self.prune_configs(kwargs): + self.fn.warmup( + *args, + num_warps=config.num_warps, + num_stages=config.num_stages, + **kwargs, + **config.kwargs, + ) + self.nargs = None + + +def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False): + """ + Decorator for auto-tuning a :code:`triton.jit`'d function. + .. highlight:: python + .. code-block:: python + @triton.autotune(configs=[ + triton.Config(meta={'BLOCK_SIZE': 128}, num_warps=4), + triton.Config(meta={'BLOCK_SIZE': 1024}, num_warps=8), + ], + key=['x_size'] # the two above configs will be evaluated anytime + # the value of x_size changes + ) + @triton.jit + def kernel(x_ptr, x_size, **META): + BLOCK_SIZE = META['BLOCK_SIZE'] + :note: When all the configurations are evaluated, the kernel will run multiple time. + This means that whatever value the kernel updates will be updated multiple times. + To avoid this undesired behavior, you can use the `reset_to_zero` argument, which + reset the value of the provided tensor to `zero` before running any configuration. + :param configs: a list of :code:`triton.Config` objects + :type configs: list[triton.Config] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs. + :param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs. + :type reset_to_zero: list[str] + """ + + def decorator(fn): + return Autotuner(fn, fn.arg_names, configs, key, reset_to_zero, prune_configs_by, nearest_power_of_two) + + return decorator + + +def matmul248_kernel_config_pruner(configs, nargs): + """ + The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller. 
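One detail of the autotuner above worth calling out: with `nearest_power_of_two=True`, cache keys are rounded to the nearest power of two in log space, so nearby problem sizes share one tuning result instead of each triggering a benchmark sweep. A tiny sketch of that rounding (the function name is illustrative):

```python
import math

def round_key_pow2(key):
    # Same rounding the autotuner applies to its cache key when
    # nearest_power_of_two=True: nearest power of two in log space.
    return tuple(2 ** int(math.log2(x) + 0.5) for x in key)

# e.g. M=100, N=700, K=4096 reuse the config cached under (128, 512, 4096)
assert round_key_pow2((100, 700, 4096)) == (128, 512, 4096)
```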
+ """ + m = max(2**int(math.ceil(math.log2(nargs['M']))), 16) + n = max(2**int(math.ceil(math.log2(nargs['N']))), 16) + k = max(2**int(math.ceil(math.log2(nargs['K']))), 16) + + used = set() + for config in configs: + block_size_m = min(m, config.kwargs['BLOCK_SIZE_M']) + block_size_n = min(n, config.kwargs['BLOCK_SIZE_N']) + block_size_k = min(k, config.kwargs['BLOCK_SIZE_K']) + group_size_m = config.kwargs['GROUP_SIZE_M'] + + if (block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps) in used: + continue + + used.add((block_size_m, block_size_n, block_size_k, group_size_m, config.num_stages, config.num_warps)) + yield triton.Config({ + 'BLOCK_SIZE_M': block_size_m, + 'BLOCK_SIZE_N': block_size_n, + 'BLOCK_SIZE_K': block_size_k, + 'GROUP_SIZE_M': group_size_m + }, + num_stages=config.num_stages, + num_warps=config.num_warps) diff --git a/quantize/quant/fused_attn.py b/quantize/quant/fused_attn.py new file mode 100644 index 000000000..b4e8d464e --- /dev/null +++ b/quantize/quant/fused_attn.py @@ -0,0 +1,123 @@ +import numpy as np +import torch +import torch.nn as nn +from torch.nn import functional as F +from torch.cuda.amp import custom_bwd, custom_fwd +from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb +from .quant_linear import * + + +class QuantLlamaAttention(nn.Module): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__( + self, + hidden_size, + num_heads, + qkv_proj, + o_proj, + rotary_emb, + ): + super().__init__() + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = hidden_size // num_heads + + if (self.head_dim * num_heads) != self.hidden_size: + raise ValueError(f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {num_heads}).") + self.qkv_proj = qkv_proj + self.o_proj = o_proj + self.rotary_emb = rotary_emb + + def _shape(self, tensor, seq_len, bsz): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False): + """Input shape: Batch x Time x Channel""" + + bsz, q_len, _ = hidden_states.size() + + qkv_states = self.qkv_proj(hidden_states) + query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + kv_seq_len += past_key_value[0].shape[-2] + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + # [bsz, nh, t, hd] + + is_causal = past_key_value is None + if past_key_value is not None: + # reuse k, v, self_attention + key_states = torch.cat([past_key_value[0], key_states], dim=2) + value_states = torch.cat([past_key_value[1], value_states], dim=2) + + if use_cache: + # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor + # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this. 
+ query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + + past_key_value = (key_states, value_states) if use_cache else None + + with torch.backends.cuda.sdp_kernel(enable_math=False): + attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal) + + attn_output = attn_output.transpose(1, 2) + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +def make_quant_attn(model): + """ + Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections. + """ + for name, m in model.named_modules(): + if not isinstance(m, LlamaAttention): + continue + + q_proj = m.q_proj + k_proj = m.k_proj + v_proj = m.v_proj + + qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1) + qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1) + scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1) + g_idx = torch.cat([q_proj.g_idx, k_proj.g_idx, v_proj.g_idx], dim=0) + bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None + + qkv_layer = QuantLinear(q_proj.bits, q_proj.groupsize, q_proj.infeatures, q_proj.outfeatures + k_proj.outfeatures + v_proj.outfeatures, True if q_proj.bias is not None else False) + qkv_layer.qweight = qweights + qkv_layer.qzeros = qzeros + qkv_layer.scales = scales + qkv_layer.g_idx = g_idx + qkv_layer.bias = bias + + attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, m.rotary_emb) + + if '.' in name: + parent_name = name.rsplit('.', 1)[0] + child_name = name[len(parent_name) + 1:] + parent = model.get_submodule(parent_name) + else: + parent_name = '' + parent = model + child_name = name + + #print(f"Replacing {name} with quant_attn; parent: {parent_name}, child's name: {child_name}") + + setattr(parent, child_name, attn) diff --git a/quantize/quant/fused_mlp.py b/quantize/quant/fused_mlp.py new file mode 100644 index 000000000..a5e402e38 --- /dev/null +++ b/quantize/quant/fused_mlp.py @@ -0,0 +1,288 @@ +import numpy as np +import torch +import torch.nn as nn +from torch.cuda.amp import custom_bwd, custom_fwd +from transformers.models.llama.modeling_llama import LlamaMLP + +try: + import triton + import triton.language as tl + from . 
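`make_quant_attn` above fuses the q/k/v projections into a single `QuantLinear`, and the attention forward splits the fused output back apart with `torch.split`. The equivalence is easier to see with plain float `nn.Linear` layers, used here only because `QuantLinear` expects packed int32 weights; for the packed buffers the concatenation runs along `dim=1`, since `qweight` stores out-features last, which is exactly what the code above does:

```python
import torch
import torch.nn as nn

hidden = 32
q_proj, k_proj, v_proj = (nn.Linear(hidden, hidden, bias=False) for _ in range(3))

# One fused projection, analogous to the qkv_layer built by make_quant_attn.
qkv_proj = nn.Linear(hidden, 3 * hidden, bias=False)
with torch.no_grad():
    qkv_proj.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))

x = torch.randn(2, 5, hidden)
q, k, v = torch.split(qkv_proj(x), hidden, dim=2)   # same split as the forward above
assert torch.allclose(q, q_proj(x), atol=1e-6)
assert torch.allclose(v, v_proj(x), atol=1e-6)
```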
import custom_autotune + + # code based https://github.com/fpgaminer/GPTQ-triton + @custom_autotune.autotune( + configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 256, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), # 3090 + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 16, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), # 3090 + triton.Config({ + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=4), # 3090 + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 16, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), # 3090 + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), # 3090 + ], + key=['M', 'N', 'K'], + nearest_power_of_two=True, + prune_configs_by={ + 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, + 'perf_model': None, + 'top_k': None, + }, + ) + @triton.jit + def fusedmatmul_248_kernel(a_ptr, c_ptr, b1_ptr, scales1_ptr, zeros1_ptr, g1_ptr, b2_ptr, scales2_ptr, zeros2_ptr, g2_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, + stride_cm, stride_cn, stride_scales, stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + """ + Computes: C = silu(A * B1) * (A * B2) + A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (1, N) float16 + zeros is of shape (1, N//8) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b1_ptrs = b1_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) + b2_ptrs = b2_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) + g1_ptrs = g1_ptr + offs_k + g2_ptrs = g2_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales1_ptrs = scales1_ptr + 
offs_bn[None, :] + scales2_ptrs = scales2_ptr + offs_bn[None, :] + zeros1_ptrs = zeros1_ptr + (offs_bn[None, :] // infearure_per_bits) + zeros2_ptrs = zeros2_ptr + (offs_bn[None, :] // infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + accumulator2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + for k in range(0, num_pid_k): + g1_idx = tl.load(g1_ptrs) + g2_idx = tl.load(g2_ptrs) + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales1 = tl.load(scales1_ptrs + g1_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + scales2 = tl.load(scales2_ptrs + g2_idx[:, None] * stride_scales) + + zeros1 = tl.load(zeros1_ptrs + g1_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros1 = (zeros1 >> zeros_shifter[None, :]) & maxq + zeros1 = (zeros1 + 1) + + zeros2 = tl.load(zeros2_ptrs + g2_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros2 = (zeros2 >> zeros_shifter[None, :]) & maxq + zeros2 = (zeros2 + 1) + + a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + b1 = tl.load(b1_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + b2 = tl.load(b2_ptrs) + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b1 = (b1 >> shifter[:, None]) & maxq # Extract the N-bit values + b1 = (b1 - zeros1) * scales1 # Scale and shift + accumulator1 += tl.dot(a, b1) + + b2 = (b2 >> shifter[:, None]) & maxq + b2 = (b2 - zeros2) * scales2 + accumulator2 += tl.dot(a, b2) + + a_ptrs += BLOCK_SIZE_K + b1_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + b2_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g1_ptrs += BLOCK_SIZE_K + g2_ptrs += BLOCK_SIZE_K + + accumulator1 = silu(accumulator1) + c = accumulator1 * accumulator2 + c = c.to(tl.float16) + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + tl.store(c_ptrs, c, mask=c_mask) + + @triton.jit + def silu(x): + return x * tl.sigmoid(x) +except: + print('triton not installed.') + + +class QuantLlamaMLP(nn.Module): + + def __init__( + self, + gate_proj, + down_proj, + up_proj, + ): + super().__init__() + self.register_buffer('gate_proj_qweight', gate_proj.qweight) + self.register_buffer('gate_proj_scales', gate_proj.scales) + self.register_buffer('gate_proj_qzeros', gate_proj.qzeros) + self.register_buffer('gate_proj_g_idx', gate_proj.g_idx) + self.register_buffer('up_proj_qweight', up_proj.qweight) + self.register_buffer('up_proj_scales', up_proj.scales) + self.register_buffer('up_proj_qzeros', up_proj.qzeros) + self.register_buffer('up_proj_g_idx', up_proj.g_idx) + + self.infeatures = gate_proj.infeatures + self.intermediate_size = gate_proj.outfeatures + self.outfeatures = down_proj.outfeatures + self.bits = gate_proj.bits + self.maxq = gate_proj.maxq + + self.down_proj = down_proj + + def forward(self, x): + return self.down_proj(self.triton_llama_mlp(x)) + + def triton_llama_mlp(self, x): + with torch.cuda.device(x.device): + out_shape = x.shape[:-1] + (self.intermediate_size, ) + x = x.reshape(-1, x.shape[-1]) + M, K = x.shape + N = self.intermediate_size + c = torch.empty((M, N), device=x.device, dtype=torch.float16) + grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), ) + fusedmatmul_248_kernel[grid](x, c, self.gate_proj_qweight, 
self.gate_proj_scales, self.gate_proj_qzeros, self.gate_proj_g_idx, self.up_proj_qweight, self.up_proj_scales, + self.up_proj_qzeros, self.up_proj_g_idx, M, N, K, self.bits, self.maxq, x.stride(0), x.stride(1), self.gate_proj_qweight.stride(0), + self.gate_proj_qweight.stride(1), c.stride(0), c.stride(1), self.gate_proj_scales.stride(0), self.gate_proj_qzeros.stride(0)) + c = c.reshape(out_shape) + return c + + def fused2cuda(self): + self.gate_proj_qweight = self.gate_proj_qweight.cuda() + self.gate_proj_scales = self.gate_proj_scales.cuda() + self.gate_proj_qzeros = self.gate_proj_qzeros.cuda() + self.gate_proj_g_idx = self.gate_proj_g_idx.cuda() + self.up_proj_qweight = self.up_proj_qweight.cuda() + self.up_proj_scales = self.up_proj_scales.cuda() + self.up_proj_qzeros = self.up_proj_qzeros.cuda() + self.up_proj_g_idx = self.up_proj_g_idx.cuda() + + def fused2cpu(self): + self.gate_proj_qweight = self.gate_proj_qweight.cpu() + self.gate_proj_scales = self.gate_proj_scales.cpu() + self.gate_proj_qzeros = self.gate_proj_qzeros.cpu() + self.gate_proj_g_idx = self.gate_proj_g_idx.cpu() + self.up_proj_qweight = self.up_proj_qweight.cpu() + self.up_proj_scales = self.up_proj_scales.cpu() + self.up_proj_qzeros = self.up_proj_qzeros.cpu() + self.up_proj_g_idx = self.up_proj_g_idx.cpu() + + +def make_fused_mlp(m, parent_name=''): + """ + Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations. + """ + if isinstance(m, LlamaMLP): + return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj) + + for name, child in m.named_children(): + child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}") + + if isinstance(child, QuantLlamaMLP): + setattr(m, name, child) + return m + + +def autotune_warmup_fused(model): + """ + Pre-tunes the quantized kernel + """ + from tqdm import tqdm + + kn_values = {} + + for _, m in model.named_modules(): + if not isinstance(m, QuantLlamaMLP): + continue + + k = m.infeatures + n = m.intermediate_size + + m.fused2cuda() + if (k, n) not in kn_values: + kn_values[(k, n)] = m + + print(f'Found {len(kn_values)} unique fused mlp KN values.') + + print('Warming up autotune cache ...') + with torch.no_grad(): + for m in tqdm(range(0, 12)): + m = 2**m # [1, 2048] + for (k, n), (modules) in kn_values.items(): + a = torch.randn(m, k, dtype=torch.float16, device='cuda') + modules.triton_llama_mlp(a) + + for (k, n), (modules) in kn_values.items(): + a = torch.randn(m, k, dtype=torch.float16, device='cuda') + modules.fused2cpu() + del kn_values diff --git a/quantize/quant/quant_linear.py b/quantize/quant/quant_linear.py new file mode 100644 index 000000000..9b1b776b9 --- /dev/null +++ b/quantize/quant/quant_linear.py @@ -0,0 +1,423 @@ +import math +import numpy as np +import torch +import torch.nn as nn +from torch.cuda.amp import custom_bwd, custom_fwd + +try: + import triton + import triton.language as tl + from . 
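For reference, here is what `fusedmatmul_248_kernel` computes per its docstring — the SiLU-gated product that `LlamaMLP` normally performs with separate `gate_proj` and `up_proj` matmuls, followed by `down_proj` in `QuantLlamaMLP.forward`. A sketch with made-up shapes; `gate_w`, `up_w`, `down_w` stand in for the dequantized projection weights:

```python
import torch
import torch.nn.functional as F

def llama_mlp_reference(x, gate_w, up_w, down_w):
    # silu(x @ gateᵀ) * (x @ upᵀ), then down_proj; the fused Triton kernel
    # produces the gated product in a single pass over x.
    h = F.silu(x @ gate_w.t()) * (x @ up_w.t())
    return h @ down_w.t()

x = torch.randn(2, 5, 64)
gate_w = torch.randn(256, 64)   # (intermediate_size, hidden)
up_w = torch.randn(256, 64)
down_w = torch.randn(64, 256)
out = llama_mlp_reference(x, gate_w, up_w, down_w)   # (2, 5, 64)
```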
import custom_autotune + + # code based https://github.com/fpgaminer/GPTQ-triton + @custom_autotune.autotune( + configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 256, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=3, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], + nearest_power_of_two=True, + prune_configs_by={ + 'early_config_prune': custom_autotune.matmul248_kernel_config_pruner, + 'perf_model': None, + 'top_k': None, + }, + ) + @triton.jit + def matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, stride_zeros, + BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. 
+ A is of shape (M, K) float16 + B is of shape (K//8, N) int32 + C is of shape (M, N) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_k + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_bn[None, :] + zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits) + + shifter = (offs_k % infearure_per_bits) * bits + zeros_shifter = (offs_bn % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, num_pid_k): + g_idx = tl.load(g_ptrs) + + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + + a = tl.load(a_ptrs, mask=a_mask, other=0.) 
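The `shifter`/`maxq` arithmetic in the kernels above is plain fixed-width bit packing: each int32 in `qweight` holds `32 // bits` codes, and a code is recovered by shifting and masking. A small sketch of that layout and the dequantization step — values are held in int64 here purely to sidestep signed-int32 edge cases; the real buffers are int32:

```python
import torch

bits = 4
maxq = 2 ** bits - 1                     # 15
codes = [3, 14, 7, 0, 9, 15, 1, 8]       # eight 4-bit quantization codes

# Pack eight codes into one 32-bit word, lowest nibble first.
word = 0
for i, c in enumerate(codes):
    word |= c << (i * bits)

packed = torch.tensor([word])
shifts = torch.arange(len(codes)) * bits

# The same shift-and-mask the Triton kernels apply per element.
unpacked = (packed >> shifts) & maxq
assert torch.equal(unpacked, torch.tensor(codes))

# Dequantize: weight ≈ (code - zero) * scale; the kernels add 1 back to the
# unpacked zero point, which is why they compute `zeros = zeros + 1` above.
scale, zero = 0.01, 8
w = (unpacked.float() - zero) * scale
```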
# (BLOCK_SIZE_M, BLOCK_SIZE_K) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_K + b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk + g_ptrs += BLOCK_SIZE_K + + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + @custom_autotune.autotune(configs=[ + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 256, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 128, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=4, num_warps=4), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 32, + 'BLOCK_SIZE_K': 128, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 64, + 'BLOCK_SIZE_K': 64, + 'GROUP_SIZE_M': 8 + }, num_stages=3, num_warps=8), + triton.Config({ + 'BLOCK_SIZE_M': 32, + 'BLOCK_SIZE_N': 128, + 'BLOCK_SIZE_K': 32, + 'GROUP_SIZE_M': 8 + }, num_stages=2, num_warps=4), + ], + key=['M', 'N', 'K'], + nearest_power_of_two=True) + @triton.jit + def transpose_matmul_248_kernel(a_ptr, b_ptr, c_ptr, scales_ptr, zeros_ptr, g_ptr, M, N, K, bits, maxq, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, stride_scales, + stride_zeros, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, GROUP_SIZE_M: tl.constexpr): + """ + Compute the matrix multiplication C = A x B. 
+ A is of shape (M, N) float16 + B is of shape (K//8, N) int32 + C is of shape (M, K) float16 + scales is of shape (G, N) float16 + zeros is of shape (G, N) float16 + g_ptr is of shape (K) int32 + """ + infearure_per_bits = 32 // bits + + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(M, BLOCK_SIZE_M) + num_pid_k = tl.cdiv(K, BLOCK_SIZE_K) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_k + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + (pid % group_size_m) + pid_k = (pid % num_pid_in_group) // group_size_m + + offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K) + offs_n = tl.arange(0, BLOCK_SIZE_N) + a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N) + a_mask = (offs_am[:, None] < M) + # b_ptrs is set up such that it repeats elements along the K axis 8 times + b_ptrs = b_ptr + ((offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N) + g_ptrs = g_ptr + offs_bk + g_idx = tl.load(g_ptrs) + + # shifter is used to extract the N bits of each element in the 32-bit word from B + scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales + zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros + + shifter = (offs_bk % infearure_per_bits) * bits + zeros_shifter = (offs_n % infearure_per_bits) * bits + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32) + + for n in range(0, num_pid_n): + # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop + scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,) + + zeros = (zeros >> zeros_shifter[None, :]) & maxq + zeros = (zeros + 1) + + a = tl.load(a_ptrs, mask=a_mask, other=0.) 
# (BLOCK_SIZE_M, BLOCK_SIZE_N) + b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated + + # Now we need to unpack b (which is N-bit values) into 32-bit values + b = (b >> shifter[:, None]) & maxq # Extract the N-bit values + b = (b - zeros) * scales # Scale and shift + b = tl.trans(b) + + accumulator += tl.dot(a, b) + a_ptrs += BLOCK_SIZE_N + b_ptrs += BLOCK_SIZE_N + scales_ptrs += BLOCK_SIZE_N + zeros_ptrs += (BLOCK_SIZE_N // infearure_per_bits) + + c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :] + c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K) + tl.store(c_ptrs, accumulator, mask=c_mask) +except: + print('trioton not installed.') + + +def matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): + with torch.cuda.device(input.device): + output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16) + grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(qweight.shape[1], META['BLOCK_SIZE_N']), ) + matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], input.shape[1], bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), + qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) + return output + + +def transpose_matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq): + with torch.cuda.device(input.device): + output_dim = (qweight.shape[0] * 32) // bits + output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=torch.float16) + grid = lambda META: (triton.cdiv(input.shape[0], META['BLOCK_SIZE_M']) * triton.cdiv(output_dim, META['BLOCK_SIZE_K']), ) + transpose_matmul_248_kernel[grid](input, qweight, output, scales, qzeros, g_idx, input.shape[0], qweight.shape[1], output_dim, bits, maxq, input.stride(0), input.stride(1), qweight.stride(0), + qweight.stride(1), output.stride(0), output.stride(1), scales.stride(0), qzeros.stride(0)) + return output + + +class QuantLinearFunction(torch.autograd.Function): + + @staticmethod + @custom_fwd(cast_inputs=torch.float16) + def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq): + output = matmul248(input, qweight, scales, qzeros, g_idx, bits, maxq) + ctx.save_for_backward(qweight, scales, qzeros, g_idx) + ctx.bits, ctx.maxq = bits, maxq + return output + + @staticmethod + @custom_bwd + def backward(ctx, grad_output): + qweight, scales, qzeros, g_idx = ctx.saved_tensors + bits, maxq = ctx.bits, ctx.maxq + grad_input = None + + if ctx.needs_input_grad[0]: + grad_input = transpose_matmul248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq) + return grad_input, None, None, None, None, None, None + + +class QuantLinear(nn.Module): + + def __init__(self, bits, groupsize, infeatures, outfeatures, bias): + super().__init__() + if bits not in [2, 4, 8]: + raise NotImplementedError("Only 2,4,8 bits are supported.") + self.infeatures = infeatures + self.outfeatures = outfeatures + self.bits = bits + self.maxq = 2**self.bits - 1 + self.groupsize = groupsize if groupsize != -1 else infeatures + + self.register_buffer('qweight', torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32)) + self.register_buffer('qzeros', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures // 32 * self.bits), dtype=torch.int32)) + self.register_buffer('scales', torch.zeros((math.ceil(infeatures / self.groupsize), outfeatures), dtype=torch.float16)) + self.register_buffer('g_idx', torch.tensor([i 
// self.groupsize for i in range(infeatures)], dtype=torch.int32)) + if bias: + self.register_buffer('bias', torch.zeros((outfeatures), dtype=torch.float16)) + else: + self.bias = None + + def pack(self, linear, scales, zeros, g_idx=None): + self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx + + scales = scales.t().contiguous() + zeros = zeros.t().contiguous() + scale_zeros = zeros * scales + self.scales = scales.clone().half() + if linear.bias is not None: + self.bias = linear.bias.clone().half() + + intweight = [] + for idx in range(self.infeatures): + intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[:, None]) + intweight = torch.cat(intweight, dim=1) + intweight = intweight.t().contiguous() + intweight = intweight.numpy().astype(np.uint32) + qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32) + i = 0 + row = 0 + while row < qweight.shape[0]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qweight[row] |= intweight[j] << (self.bits * (j - i)) + i += 32 // self.bits + row += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qweight = qweight.astype(np.int32) + self.qweight = torch.from_numpy(qweight) + + zeros -= 1 + zeros = zeros.numpy().astype(np.uint32) + qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32) + i = 0 + col = 0 + while col < qzeros.shape[1]: + if self.bits in [2, 4, 8]: + for j in range(i, i + (32 // self.bits)): + qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i)) + i += 32 // self.bits + col += 1 + else: + raise NotImplementedError("Only 2,4,8 bits are supported.") + + qzeros = qzeros.astype(np.int32) + self.qzeros = torch.from_numpy(qzeros) + + def forward(self, x): + out_shape = x.shape[:-1] + (self.outfeatures, ) + out = QuantLinearFunction.apply(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.g_idx, self.bits, self.maxq) + out = out + self.bias if self.bias is not None else out + return out.reshape(out_shape) + + +def make_quant_linear(module, names, bits, groupsize, name=''): + if isinstance(module, QuantLinear): + return + for attr in dir(module): + tmp = getattr(module, attr) + name1 = name + '.' + attr if name != '' else attr + if name1 in names: + delattr(module, attr) + setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features, tmp.bias is not None)) + for name1, child in module.named_children(): + make_quant_linear(child, names, bits, groupsize, name + '.' 
+ name1 if name != '' else name1) + + +def autotune_warmup_linear(model, transpose=False): + """ + Pre-tunes the quantized kernel + """ + from tqdm import tqdm + + kn_values = {} + + for _, m in model.named_modules(): + if not isinstance(m, QuantLinear): + continue + + k = m.infeatures + n = m.outfeatures + + if (k, n) not in kn_values: + kn_values[(k, n)] = (m.qweight.cuda(), m.scales.cuda(), m.qzeros.cuda(), m.g_idx.cuda(), m.bits, m.maxq) + + print(f'Found {len(kn_values)} unique KN Linear values.') + + print('Warming up autotune cache ...') + with torch.no_grad(): + for m in tqdm(range(0, 12)): + m = 2**m # [1, 2048] + for (k, n), (qweight, scales, qzeros, g_idx, bits, maxq) in kn_values.items(): + a = torch.randn(m, k, dtype=torch.float16, device='cuda') + matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq) + if transpose: + a = torch.randn(m, n, dtype=torch.float16, device='cuda') + transpose_matmul248(a, qweight, scales, qzeros, g_idx, bits, maxq) + del kn_values diff --git a/quantize/quant/quantizer.py b/quantize/quant/quantizer.py new file mode 100644 index 000000000..76844b876 --- /dev/null +++ b/quantize/quant/quantizer.py @@ -0,0 +1,127 @@ +import numpy as np +import torch +import torch.nn as nn +import math + + +class Quantizer(nn.Module): + + def __init__(self, shape=1): + super(Quantizer, self).__init__() + self.register_buffer('maxq', torch.tensor(0)) + self.register_buffer('scale', torch.zeros(shape)) + self.register_buffer('zero', torch.zeros(shape)) + + def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8, trits=False): + + self.maxq = torch.tensor(2**bits - 1) + self.perchannel = perchannel + self.sym = sym + self.mse = mse + self.norm = norm + self.grid = grid + self.maxshrink = maxshrink + if trits: + self.maxq = torch.tensor(-1) + self.scale = torch.zeros_like(self.scale) + + def _quantize(self, x, scale, zero, maxq): + if maxq < 0: + return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero + q = torch.clamp(torch.round(x / scale) + zero, 0, maxq) + return scale * (q - zero) + + def find_params(self, x, weight=False): + dev = x.device + self.maxq = self.maxq.to(dev) + + shape = x.shape + if self.perchannel: + if weight: + x = x.flatten(1) + else: + if len(shape) == 4: + x = x.permute([1, 0, 2, 3]) + x = x.flatten(1) + if len(shape) == 3: + x = x.reshape((-1, shape[-1])).t() + if len(shape) == 2: + x = x.t() + else: + x = x.flatten().unsqueeze(0) + + tmp = torch.zeros(x.shape[0], device=dev) + xmin = torch.minimum(x.min(1)[0], tmp) + xmax = torch.maximum(x.max(1)[0], tmp) + + if self.sym: + xmax = torch.maximum(torch.abs(xmin), xmax) + tmp = xmin < 0 + if torch.any(tmp): + xmin[tmp] = -xmax[tmp] + tmp = (xmin == 0) & (xmax == 0) + xmin[tmp] = -1 + xmax[tmp] = +1 + + if self.maxq < 0: + self.scale = xmax + self.zero = xmin + else: + self.scale = (xmax - xmin) / self.maxq + if self.sym: + self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2) + else: + self.zero = torch.round(-xmin / self.scale) + + if self.mse: + best = torch.full([x.shape[0]], float('inf'), device=dev) + for i in range(int(self.maxshrink * self.grid)): + p = 1 - i / self.grid + xmin1 = p * xmin + xmax1 = p * xmax + scale1 = (xmax1 - xmin1) / self.maxq + zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero + q = self._quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq) + q -= x + q.abs_() + q.pow_(self.norm) + err = torch.sum(q, 1) + tmp = err < best + if torch.any(tmp): + best[tmp] = err[tmp] + 
self.scale[tmp] = scale1[tmp] + self.zero[tmp] = zero1[tmp] + if not self.perchannel: + if weight: + tmp = shape[0] + else: + tmp = shape[1] if len(shape) != 3 else shape[2] + self.scale = self.scale.repeat(tmp) + self.zero = self.zero.repeat(tmp) + + if weight: + shape = [-1] + [1] * (len(shape) - 1) + self.scale = self.scale.reshape(shape) + self.zero = self.zero.reshape(shape) + return + if len(shape) == 4: + self.scale = self.scale.reshape((1, -1, 1, 1)) + self.zero = self.zero.reshape((1, -1, 1, 1)) + if len(shape) == 3: + self.scale = self.scale.reshape((1, 1, -1)) + self.zero = self.zero.reshape((1, 1, -1)) + if len(shape) == 2: + self.scale = self.scale.unsqueeze(0) + self.zero = self.zero.unsqueeze(0) + + def quantize(self, x): + if self.ready(): + return self._quantize(x, self.scale, self.zero, self.maxq) + + return x + + def enabled(self): + return self.maxq > 0 + + def ready(self): + return torch.all(self.scale != 0) diff --git a/quantize/utils/__init__.py b/quantize/utils/__init__.py new file mode 100644 index 000000000..79c85521c --- /dev/null +++ b/quantize/utils/__init__.py @@ -0,0 +1,11 @@ +from .modelutils import DEV, find_layers, gen_conditions, torch_snr_error +from .datautils import ( + set_seed, + get_wikitext2, + get_ptb, + get_c4, + get_ptb_new, + get_c4_new, + get_loaders, +) +from .export import export_quant_table diff --git a/quantize/utils/datautils.py b/quantize/utils/datautils.py new file mode 100644 index 000000000..7f04e4fa2 --- /dev/null +++ b/quantize/utils/datautils.py @@ -0,0 +1,225 @@ +import numpy as np +import torch + + +def set_seed(seed): + np.random.seed(seed) + torch.random.manual_seed(seed) + + +def get_wikitext2(nsamples, seed, seqlen, model): + from datasets import load_dataset + + traindata = load_dataset("wikitext", "wikitext-2-raw-v1", split="train") + testdata = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + + from transformers import AutoTokenizer + + try: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + except: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) + trainenc = tokenizer("\n\n".join(traindata["text"]), return_tensors="pt") + testenc = tokenizer("\n\n".join(testdata["text"]), return_tensors="pt") + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_ptb(nsamples, seed, seqlen, model): + from datasets import load_dataset + + traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") + valdata = load_dataset("ptb_text_only", "penn_treebank", split="validation") + + from transformers import AutoTokenizer + + try: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + except: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) + trainenc = tokenizer("\n\n".join(traindata["sentence"]), return_tensors="pt") + testenc = tokenizer("\n\n".join(valdata["sentence"]), return_tensors="pt") + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_c4(nsamples, seed, seqlen, model): + from datasets import load_dataset + + 
traindata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, + split="train", + use_auth_token=False, + ) + valdata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, + split="validation", + use_auth_token=False, + ) + + from transformers import AutoTokenizer + + try: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + except: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + import random + + random.seed(0) + valenc = [] + for _ in range(256): + while True: + i = random.randint(0, len(valdata) - 1) + tmp = tokenizer(valdata[i]["text"], return_tensors="pt") + if tmp.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, tmp.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + valenc.append(tmp.input_ids[:, i:j]) + valenc = torch.hstack(valenc) + + class TokenizerWrapper: + def __init__(self, input_ids): + self.input_ids = input_ids + + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + + +def get_ptb_new(nsamples, seed, seqlen, model): + from datasets import load_dataset + + traindata = load_dataset("ptb_text_only", "penn_treebank", split="train") + testdata = load_dataset("ptb_text_only", "penn_treebank", split="test") + + from transformers import AutoTokenizer + + try: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + except: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) + trainenc = tokenizer(" ".join(traindata["sentence"]), return_tensors="pt") + testenc = tokenizer(" ".join(testdata["sentence"]), return_tensors="pt") + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + return trainloader, testenc + + +def get_c4_new(nsamples, seed, seqlen, model): + from datasets import load_dataset + + traindata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"train": "en/c4-train.00000-of-01024.json.gz"}, + split="train", + ) + valdata = load_dataset( + "allenai/c4", + "allenai--c4", + data_files={"validation": "en/c4-validation.00000-of-00008.json.gz"}, + split="validation", + ) + + from transformers import AutoTokenizer + + try: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=False) + except: + tokenizer = AutoTokenizer.from_pretrained(model, use_fast=True) + + import random + + random.seed(seed) + trainloader = [] + for _ in range(nsamples): + while True: + i = random.randint(0, len(traindata) - 1) + trainenc = tokenizer(traindata[i]["text"], return_tensors="pt") + if trainenc.input_ids.shape[1] >= seqlen: + break + i = random.randint(0, trainenc.input_ids.shape[1] - seqlen - 1) + j = i + seqlen + inp = trainenc.input_ids[:, i:j] + tar = inp.clone() + tar[:, :-1] = -100 + trainloader.append((inp, tar)) + + valenc = tokenizer(" ".join(valdata[:1100]["text"]), 
return_tensors="pt") + valenc = valenc.input_ids[:, : (256 * seqlen)] + + class TokenizerWrapper: + def __init__(self, input_ids): + self.input_ids = input_ids + + valenc = TokenizerWrapper(valenc) + + return trainloader, valenc + + +def get_loaders(name, nsamples=128, seed=0, seqlen=2048, model=""): + if "wikitext2" in name: + return get_wikitext2(nsamples, seed, seqlen, model) + if "ptb" in name: + if "new" in name: + return get_ptb_new(nsamples, seed, seqlen, model) + return get_ptb(nsamples, seed, seqlen, model) + if "c4" in name: + if "new" in name: + return get_c4_new(nsamples, seed, seqlen, model) + return get_c4(nsamples, seed, seqlen, model) diff --git a/quantize/utils/export.py b/quantize/utils/export.py new file mode 100644 index 000000000..943e25c56 --- /dev/null +++ b/quantize/utils/export.py @@ -0,0 +1,37 @@ +import numpy as np +import toml +import os + + +def export_quant_table(quantizers: dict, quant_dir: str, format: str = "toml"): + + table = {} + + def save_tensor(name: str, tensor): + np.save(os.path.join(quant_dir, name), tensor.numpy()) + return "{}.npy".format(name) + + for key, value in quantizers.items(): + quantizer = value[0] + + dump = dict() + + sym = quantizer.sym + if not sym: + dump["zero"] = save_tensor(name=key + ".zero", tensor=value[2]) + dump["scale"] = save_tensor(name=key + ".scale", tensor=value[1]) + dump["wbits"] = value[4] + dump["groupsize"] = value[5] + if value[5] > 0: + dump["group_ids"] = save_tensor(name=key + ".group_ids", tensor=value[3]) + + dump["sym"] = sym + dump["perchannel"] = quantizer.perchannel + + table[key] = dump + + if not os.path.exists(quant_dir): + os.mkdir(quant_dir) + + with open(os.path.join(quant_dir, "quant.toml"), "w") as f: + toml.dump(table, f) diff --git a/quantize/utils/modelutils.py b/quantize/utils/modelutils.py new file mode 100644 index 000000000..d57941079 --- /dev/null +++ b/quantize/utils/modelutils.py @@ -0,0 +1,91 @@ +import torch +import torch.nn as nn + +DEV = torch.device("cuda:0") + + +def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=""): + if type(module) in layers: + return {name: module} + res = {} + for name1, child in module.named_children(): + res.update( + find_layers( + child, layers=layers, name=name + "." + name1 if name != "" else name1 + ) + ) + return res + + +def gen_conditions(_wbits, _groupsize): + wbits = _wbits + groupsize = _groupsize + conditions = [] + while True: + if wbits >= 8: + if groupsize == -1 or groupsize == 32: + break + + if groupsize > 32: + groupsize /= 2 + else: + wbits *= 2 + groupsize = _groupsize + + conditions.append((int(wbits), int(groupsize))) + return conditions + + +# copy from https://github.com/openppl-public/ppq/blob/master/ppq/quantization/measure/norm.py +def torch_snr_error( + y_pred: torch.Tensor, y_real: torch.Tensor, reduction: str = "mean" +) -> torch.Tensor: + """ + Compute SNR between y_pred(tensor) and y_real(tensor) + + SNR can be calcualted as following equation: + + SNR(pred, real) = (pred - real) ^ 2 / (real) ^ 2 + + if x and y are matrixs, SNR error over matrix should be the mean value of SNR error over all elements. + + SNR(pred, real) = mean((pred - real) ^ 2 / (real) ^ 2) + Args: + y_pred (torch.Tensor): _description_ + y_real (torch.Tensor): _description_ + reduction (str, optional): _description_. Defaults to 'mean'. 
+ Raises: + ValueError: _description_ + ValueError: _description_ + Returns: + torch.Tensor: _description_ + """ + y_pred = y_pred.type(torch.float32) + y_real = y_real.type(torch.float32) + + if y_pred.shape != y_real.shape: + raise ValueError( + f"Can not compute snr loss for tensors with different shape. " + f"({y_pred.shape} and {y_real.shape})" + ) + reduction = str(reduction).lower() + + if y_pred.ndim == 1: + y_pred = y_pred.unsqueeze(0) + y_real = y_real.unsqueeze(0) + + y_pred = y_pred.flatten(start_dim=1) + y_real = y_real.flatten(start_dim=1) + + noise_power = torch.pow(y_pred - y_real, 2).sum(dim=-1) + signal_power = torch.pow(y_real, 2).sum(dim=-1) + snr = (noise_power) / (signal_power + 1e-7) + + if reduction == "mean": + return torch.mean(snr) + elif reduction == "sum": + return torch.sum(snr) + elif reduction == "none": + return snr + else: + raise ValueError(f"Unsupported reduction method.") diff --git a/requirements.txt b/requirements.txt index d3fcd9f97..e8fa4e8a5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -49,3 +49,8 @@ pypandoc==1.11 openpyxl==3.1.2 lm_dataformat==0.0.20 bioc==2.0 + +# quantization +safetensors==0.3.1 +texttable==1.6.7 +toml==0.10.2
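
A few hedged usage sketches for the new quantization utilities follow. First, a minimal round trip through the `Quantizer` added in `quantize/quant/quantizer.py`: `configure` picks the bit width, `find_params` derives a per-row scale and zero point, and `quantize` returns the dequantized 4-bit approximation. The `from quant import Quantizer` import assumes the `quantize/` directory is on `PYTHONPATH` and that the `quant` package re-exports `Quantizer`, as the `import quant` usage in `quantize/gptq.py` suggests.

```python
# Sketch: 4-bit round-trip ("fake") quantization of a weight matrix.
# Assumes quantize/ is on PYTHONPATH so the quant package resolves.
import torch
from quant import Quantizer

W = torch.randn(4096, 4096)  # stand-in for an nn.Linear weight

q = Quantizer()
q.configure(4, perchannel=True, sym=False, mse=False)  # 4-bit, asymmetric, per-row
q.find_params(W, weight=True)                          # scale/zero per output row
W_q = q.quantize(W)                                    # quantize + dequantize

rel_err = ((W - W_q).pow(2).sum() / W.pow(2).sum()).item()
print(f"relative quantization error: {rel_err:.4e}")
```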
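
Next, a sketch of pulling calibration data through `get_loaders` from `quantize/utils/datautils.py`. The model id is only an example; `wikitext2` is downloaded on first use, and each `trainloader` entry is an `(input_ids, targets)` pair of shape `(1, seqlen)`.

```python
# Sketch: fetch 128 calibration sequences of length 2048 from wikitext2,
# tokenized with the given model's tokenizer (example model id below).
from utils import get_loaders  # quantize/utils/__init__.py re-exports this

trainloader, testenc = get_loaders(
    "wikitext2",
    nsamples=128,
    seed=0,
    seqlen=2048,
    model="h2oai/h2ogpt-oig-oasst1-512-6.9b",
)
print(len(trainloader), trainloader[0][0].shape)  # 128 torch.Size([1, 2048])
```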
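
Finally, a sketch of how the pieces are typically wired together to run a packed 4-bit checkpoint: `find_layers` collects the `nn.Linear` modules, `make_quant_linear` swaps them for `QuantLinear`, and `autotune_warmup_linear` pre-tunes the Triton `matmul248` kernels. The `quant` import path for `make_quant_linear`/`autotune_warmup_linear` is an assumption (mirroring `import quant` in `gptq.py`), the checkpoint name is a placeholder for a file produced by `quantize/neox.py`, and the layer selection, `bits`, and `groupsize` must match whatever was used when the checkpoint was saved (e.g., whether the output head was excluded).

```python
# Sketch only: replace Linear layers with QuantLinear, load packed 4-bit weights,
# and warm up the Triton autotune cache. Paths and layer selection are placeholders.
import torch
from transformers import AutoModelForCausalLM
from utils import find_layers                                # quantize/utils/modelutils.py
from quant import make_quant_linear, autotune_warmup_linear  # assumed re-exports

MODEL_ID = "h2oai/h2ogpt-oig-oasst1-512-6.9b"      # example model
CHECKPOINT = "h2ogpt-oig-oasst1-512-6.9b-4bit.pt"  # e.g. output of quantize/neox.py

model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=torch.float16)
layers = find_layers(model)                        # {qualified name: nn.Linear}
make_quant_linear(model, layers, bits=4, groupsize=-1)
model.load_state_dict(torch.load(CHECKPOINT), strict=False)
model = model.cuda()
autotune_warmup_linear(model)                      # pre-tunes matmul248 for each (K, N)
```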