based on: https://github.com/qwopqwop200/GPTQ-for-LLaMa (Apache 2.0)
1 parent 562e757, commit 48e2530. Showing 15 changed files with 2,488 additions and 0 deletions.
@@ -0,0 +1,35 @@
## Quantization to 4-bit for CPU inference

First, make sure to install all [dependencies](../INSTALL.md).

Select the model to quantize:
```bash
MODEL=h2ogpt-oig-oasst1-512-6.9b
#MODEL=h2ogpt-oasst1-512-12b
#MODEL=h2ogpt-oasst1-512-20b
```

Run the conversion:
```bash
PYTHONPATH=. CUDA_VISIBLE_DEVICES=0 python \
    quantize/neox.py h2oai/${MODEL} wikitext2 \
    --wbits 4 \
    --save ${MODEL}-4bit.pt
```

Now test the model:
```bash
CUDA_VISIBLE_DEVICES=0 python \
    quantize/inference.py h2oai/${MODEL} \
    --wbits 4 \
    --load ${MODEL}-4bit.pt \
    --text "Tell me a joke about cookies."
```

FIXME: creates garbage output
```
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Tell me a joke about cookies.� emitted column Sullivan Meatrah thinkers TemplateSLavorable homologous beat qubit WD differentiallyabstractKBchu
260 econ environments unitaryimage endorse physicistisksaines observables preference euthan Creation 580 blinkowa metrics extrac lowered Raz proportions numerically claimant Plugin
```
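
A note on the warnings above: the missing `attention_mask` / `pad_token_id` messages come from `transformers` and can be silenced by passing both to `generate()`. That may or may not be related to the garbage output itself. A minimal sketch, assuming a plain `transformers` generation call rather than the exact code path in `quantize/inference.py`:

```python
# Illustrative only: silence the attention_mask / pad_token_id warnings.
# This does not by itself explain or fix the garbage tokens, which likely
# come from the quantization step.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "h2oai/h2ogpt-oig-oasst1-512-6.9b"  # any of the models listed above
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).cuda()

if tokenizer.pad_token_id is None:
    tokenizer.pad_token = tokenizer.eos_token  # avoids "Setting pad_token_id to eos_token_id"

inputs = tokenizer("Tell me a joke about cookies.", return_tensors="pt").to(model.device)
output = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],  # avoids the attention_mask warning
    pad_token_id=tokenizer.pad_token_id,
    max_new_tokens=64,
)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```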
@@ -0,0 +1,251 @@
import math
import time

import torch
import torch.nn as nn
import transformers
import quant
from texttable import Texttable
from utils import torch_snr_error

# Keep matmuls in full fp32 (no TF32) for numerical accuracy.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


class Observer:
    def __init__(self, topk=32):
        self.loss_list = []
        self.topk = topk

    def submit(self, name: str, layerid: int, gptq, error: float):
        item = (name, layerid, {"gptq": gptq, "error": error})

        if len(self.loss_list) < self.topk:
            self.loss_list.append(item)
            return

        # If the new error exceeds the smallest error in the list,
        # replace that smallest-error entry.
        min_error = error
        min_idx = -1
        for idx, data in enumerate(self.loss_list):
            if min_error > data[2]["error"]:
                min_idx = idx
                min_error = data[2]["error"]

        if min_idx >= 0:
            self.loss_list[min_idx] = item

    def print(self):
        self.loss_list = sorted(
            self.loss_list, key=lambda s: s[2]["error"], reverse=True
        )

        table = Texttable()

        table.header(["name", "error"])
        table.set_cols_dtype(["t", "f"])

        for item in self.loss_list:
            table.add_row([f"{item[0]}.{item[1]}", item[2]["error"]])
        print(table.draw())
        print("\n")

    def items(self):
        return self.loss_list


class GPTQ:
    def __init__(self, layer, observe=False):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.quantizer = quant.Quantizer()
        self.observe = observe

    def add_batch(self, inp, out):
        # Hessian H = 2 X X^T + λ I
        if self.observe:
            self.inp1 = inp
            self.out1 = out
        else:
            self.inp1 = None
            self.out1 = None

        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(
            self.layer, transformers.Conv1D
        ):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride,
            )
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
        # Running average of the Hessian over all calibration samples seen so far.
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        # inp = inp.float()
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        # self.H += 2 / self.nsamples * inp.matmul(inp.t())
        self.H += inp.matmul(inp.t())

    def print_loss(self, name, q_weight, weight_error, timecost):
        table = Texttable()
        name += " " * (16 - len(name))

        table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"])

        # assign weight
        self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(
            self.layer.weight.data.dtype
        )

        if self.inp1 is not None:
            # quantize input to int8
            quantizer = quant.Quantizer()
            quantizer.configure(8, perchannel=False, sym=True, mse=False)
            quantizer.find_params(self.inp1)
            q_in = quantizer.quantize(self.inp1).type(torch.float16)
            q_out = self.layer(q_in)

            # get kinds of SNR
            q_SNR = torch_snr_error(q_out, self.out1).item()
            fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()
        else:
            q_SNR = "-"
            fp_SNR = "-"

        table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
        print(table.draw().split("\n")[-2])

    def fasterquant(
        self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False, name=""
    ):
        self.layer.to(self.dev)

        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        if not self.observe:
            del self.H
        # Columns with a zero Hessian diagonal carry no signal; pin them so H stays non-singular.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        # Act-order: quantize columns in order of decreasing Hessian diagonal.
        if actorder:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        # Dampen the Hessian, then form the upper Cholesky factor of its inverse.
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        g_idx = []
        scale = []
        zero = []
        now_idx = 1

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1:
                    if (i1 + i) % groupsize == 0:
                        self.quantizer.find_params(
                            W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
                        )

                    if ((i1 + i) // groupsize) - now_idx == -1:
                        scale.append(self.quantizer.scale)
                        zero.append(self.quantizer.zero)
                        now_idx += 1

                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d ** 2

                # Propagate the quantization error to the not-yet-quantized columns.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        torch.cuda.synchronize()
        error = torch.sum(Losses).item()

        groupsize = groupsize if groupsize != -1 else self.columns
        g_idx = [i // groupsize for i in range(self.columns)]
        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
        if actorder:
            invperm = torch.argsort(perm)
            Q = Q[:, invperm]
            g_idx = g_idx[invperm]

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()

        self.print_loss(
            name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)
        )

        if scale == []:
            scale.append(self.quantizer.scale)
            zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
        return scale, zero, g_idx, error

    def free(self):
        self.inp1 = None
        self.out1 = None
        self.H = None
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()
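
For orientation, a minimal sketch of how a caller is expected to drive `GPTQ` and `Observer`. The layer, its size, and the random calibration inputs below are hypothetical; the real driver is the quantization script from this commit (`quantize/neox.py`), which hooks the model's layers.

```python
# Illustrative sketch only, not part of the committed files. It assumes
# quant.Quantizer.configure(bits, perchannel, sym, mse) behaves as used in
# print_loss() above, and a single fp16 Linear layer standing in for a model.
import torch
import torch.nn as nn

layer = nn.Linear(1024, 1024).cuda().half()

gptq = GPTQ(layer, observe=True)
gptq.quantizer.configure(4, perchannel=True, sym=False, mse=False)

# Feed a few calibration batches; in the real script a forward hook captures
# each layer's inputs/outputs and forwards them to add_batch().
for _ in range(8):
    inp = torch.randn(2, 128, 1024, device="cuda", dtype=torch.float16)
    gptq.add_batch(inp, layer(inp))

# Quantize the weights in place and collect the quantization parameters.
scale, zero, g_idx, error = gptq.fasterquant(groupsize=128, actorder=True, name="proj")

# Track the worst-quantized layers.
observer = Observer(topk=32)
observer.submit("proj", layerid=0, gptq=gptq, error=error)
observer.print()

gptq.free()
```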