Add Quantization code.
arnocandel committed May 2, 2023
1 parent 562e757 commit 48e2530
Showing 15 changed files with 2,488 additions and 0 deletions.
35 changes: 35 additions & 0 deletions QUANTIZE.md
@@ -0,0 +1,35 @@
## Quantization to 4-bit for CPU inference

First, make sure to install all [dependencies](../INSTALL.md).

Select the model to quantize:
```bash
MODEL=h2ogpt-oig-oasst1-512-6.9b
#MODEL=h2ogpt-oasst1-512-12b
#MODEL=h2ogpt-oasst1-512-20b
```

Run the conversion (`wikitext2` below is the calibration dataset used to collect activation statistics):
```bash
PYTHONPATH=. CUDA_VISIBLE_DEVICES=0 python \
quantize/neox.py h2oai/${MODEL} wikitext2 \
--wbits 4 \
--save ${MODEL}-4bit.pt
```

Now test the model:
```bash
CUDA_VISIBLE_DEVICES=0 python \
quantize/inference.py h2oai/${MODEL} \
--wbits 4 \
--load ${MODEL}-4bit.pt \
--text "Tell me a joke about cookies."
```

FIXME: creates garbage output
```
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:0 for open-end generation.
Tell me a joke about cookies.� emitted column Sullivan Meatrah thinkers TemplateSLavorable homologous beat qubit WD differentiallyabstractKBchu
260 econ environments unitaryimage endorse physicistisksaines observables preference euthan Creation 580 blinkowa metrics extrac lowered Raz proportions numerically claimant Plugin
```
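
The attention-mask warnings above are easy to silence and worth ruling out, though they are unlikely to be the cause of the garbage output. A minimal sketch, assuming a standard `transformers` model/tokenizer pair already loaded as `model` and `tokenizer` (both names are placeholders, not from this commit):

```python
# Sketch only: silences the attention_mask / pad_token_id warnings; it does
# not by itself fix the garbage output noted above.
inputs = tokenizer("Tell me a joke about cookies.", return_tensors="pt").to("cuda")
outputs = model.generate(
    input_ids=inputs["input_ids"],
    attention_mask=inputs["attention_mask"],  # explicit mask, as the warning asks
    pad_token_id=tokenizer.eos_token_id,      # avoids the open-end generation notice
    max_new_tokens=64,
)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```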
251 changes: 251 additions & 0 deletions quantize/gptq.py
@@ -0,0 +1,251 @@
import math
import time

import torch
import torch.nn as nn
import transformers
import quant
from texttable import Texttable
from utils import torch_snr_error

torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False


class Observer:
    def __init__(self, topk=32):
        self.loss_list = []
        self.topk = topk

    def submit(self, name: str, layerid: int, gptq, error: float):

        item = (name, layerid, {"gptq": gptq, "error": error})

        if len(self.loss_list) < self.topk:
            self.loss_list.append(item)
            return

        min_error = error
        min_idx = -1
        for idx, data in enumerate(self.loss_list):
            if min_error > data[2]["error"]:
                min_idx = idx
                min_error = data[2]["error"]

        if min_idx >= 0:
            self.loss_list[min_idx] = item

    def print(self):
        self.loss_list = sorted(
            self.loss_list, key=lambda s: s[2]["error"], reverse=True
        )

        table = Texttable()

        table.header(["name", "error"])
        table.set_cols_dtype(["t", "f"])

        for item in self.loss_list:
            table.add_row([f"{item[0]}.{item[1]}", item[2]["error"]])
        print(table.draw())
        print("\n")

    def items(self):
        return self.loss_list


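# Context for the class below: GPTQ quantizes one layer at a time. add_batch()
# accumulates the Hessian H = 2 X X^T from calibration activations, and
# fasterquant() then quantizes the weight matrix column by column, spreading
# each column's rounding error over the not-yet-quantized columns via the
# inverse Hessian (the GPTQ algorithm of Frantar et al., 2022).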
class GPTQ:
    def __init__(self, layer, observe=False):
        self.layer = layer
        self.dev = self.layer.weight.device
        W = layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        self.rows = W.shape[0]
        self.columns = W.shape[1]
        self.H = torch.zeros((self.columns, self.columns), device=self.dev)
        self.nsamples = 0
        self.quantizer = quant.Quantizer()
        self.observe = observe

    def add_batch(self, inp, out):
        # Hessian H = 2 X X^T (the damping term λ I is added later in fasterquant)
        if self.observe:
            self.inp1 = inp
            self.out1 = out
        else:
            self.inp1 = None
            self.out1 = None

        if len(inp.shape) == 2:
            inp = inp.unsqueeze(0)
        tmp = inp.shape[0]
        if isinstance(self.layer, nn.Linear) or isinstance(
            self.layer, transformers.Conv1D
        ):
            if len(inp.shape) == 3:
                inp = inp.reshape((-1, inp.shape[-1]))
            inp = inp.t()
        if isinstance(self.layer, nn.Conv2d):
            unfold = nn.Unfold(
                self.layer.kernel_size,
                dilation=self.layer.dilation,
                padding=self.layer.padding,
                stride=self.layer.stride,
            )
            inp = unfold(inp)
            inp = inp.permute([1, 0, 2])
            inp = inp.flatten(1)
        # Running average over calibration batches: rescale the accumulated H,
        # then add this batch's contribution. The sqrt(2 / nsamples) factor
        # folds the 2 / nsamples of H = 2 X X^T into inp, so the update below
        # is equivalent to: self.H += 2 / self.nsamples * inp.matmul(inp.t())
        self.H *= self.nsamples / (self.nsamples + tmp)
        self.nsamples += tmp
        inp = math.sqrt(2 / self.nsamples) * inp.float()
        self.H += inp.matmul(inp.t())

    def print_loss(self, name, q_weight, weight_error, timecost):
        table = Texttable()
        name += " " * (16 - len(name))

        table.header(["name", "weight_error", "fp_inp_SNR", "q_inp_SNR", "time"])

        # assign weight
        self.layer.weight.data = q_weight.reshape(self.layer.weight.shape).to(
            self.layer.weight.data.dtype
        )

        if self.inp1 is not None:
            # quantize input to int8
            quantizer = quant.Quantizer()
            quantizer.configure(8, perchannel=False, sym=True, mse=False)
            quantizer.find_params(self.inp1)
            q_in = quantizer.quantize(self.inp1).type(torch.float16)
            q_out = self.layer(q_in)

            # get kinds of SNR
            q_SNR = torch_snr_error(q_out, self.out1).item()
            fp_SNR = torch_snr_error(self.layer(self.inp1), self.out1).item()
        else:
            q_SNR = "-"
            fp_SNR = "-"

        table.add_row([name, weight_error, fp_SNR, q_SNR, timecost])
        print(table.draw().split("\n")[-2])

    def fasterquant(
        self, blocksize=128, percdamp=0.01, groupsize=-1, actorder=False, name=""
    ):
        self.layer.to(self.dev)

        W = self.layer.weight.data.clone()
        if isinstance(self.layer, nn.Conv2d):
            W = W.flatten(1)
        if isinstance(self.layer, transformers.Conv1D):
            W = W.t()
        W = W.float()

        tick = time.time()

        if not self.quantizer.ready():
            self.quantizer.find_params(W, weight=True)

        H = self.H
        if not self.observe:
            del self.H
        # Columns whose activations were always zero carry no signal: pin their
        # diagonal to 1 and zero the weights so the update is a no-op for them.
        dead = torch.diag(H) == 0
        H[dead, dead] = 1
        W[:, dead] = 0

        if actorder:
            perm = torch.argsort(torch.diag(H), descending=True)
            W = W[:, perm]
            H = H[perm][:, perm]

        Losses = torch.zeros_like(W)
        Q = torch.zeros_like(W)

        # Damp H (the λ I term), then compute the upper Cholesky factor of
        # H^{-1}; its rows supply the error-propagation coefficients used below.
        damp = percdamp * torch.mean(torch.diag(H))
        diag = torch.arange(self.columns, device=self.dev)
        H[diag, diag] += damp
        H = torch.linalg.cholesky(H)
        H = torch.cholesky_inverse(H)
        H = torch.linalg.cholesky(H, upper=True)
        Hinv = H

        g_idx = []
        scale = []
        zero = []
        now_idx = 1

        for i1 in range(0, self.columns, blocksize):
            i2 = min(i1 + blocksize, self.columns)
            count = i2 - i1

            W1 = W[:, i1:i2].clone()
            Q1 = torch.zeros_like(W1)
            Err1 = torch.zeros_like(W1)
            Losses1 = torch.zeros_like(W1)
            Hinv1 = Hinv[i1:i2, i1:i2]

            for i in range(count):
                w = W1[:, i]
                d = Hinv1[i, i]

                if groupsize != -1:
                    if (i1 + i) % groupsize == 0:
                        self.quantizer.find_params(
                            W[:, (i1 + i) : (i1 + i + groupsize)], weight=True
                        )

                    if ((i1 + i) // groupsize) - now_idx == -1:
                        scale.append(self.quantizer.scale)
                        zero.append(self.quantizer.zero)
                        now_idx += 1

                q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
                Q1[:, i] = q
                Losses1[:, i] = (w - q) ** 2 / d ** 2

                # Spread column i's rounding error over the not-yet-quantized
                # columns of this block.
                err1 = (w - q) / d
                W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
                Err1[:, i] = err1

            Q[:, i1:i2] = Q1
            Losses[:, i1:i2] = Losses1 / 2

            # Propagate the block's accumulated error into all later columns.
            W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])

        torch.cuda.synchronize()
        error = torch.sum(Losses).item()

        groupsize = groupsize if groupsize != -1 else self.columns
        g_idx = [i // groupsize for i in range(self.columns)]
        g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
        if actorder:
            invperm = torch.argsort(perm)
            Q = Q[:, invperm]
            g_idx = g_idx[invperm]

        if isinstance(self.layer, transformers.Conv1D):
            Q = Q.t()

        self.print_loss(
            name=name, q_weight=Q, weight_error=error, timecost=(time.time() - tick)
        )

        if scale == []:
            scale.append(self.quantizer.scale)
            zero.append(self.quantizer.zero)
        scale = torch.cat(scale, dim=1)
        zero = torch.cat(zero, dim=1)
        return scale, zero, g_idx, error

    def free(self):
        self.inp1 = None
        self.out1 = None
        self.H = None
        self.Losses = None
        self.Trace = None
        torch.cuda.empty_cache()
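
For orientation, here is a minimal, self-contained sketch of how the `GPTQ` class above is typically driven. It is hypothetical: the real driver for this commit is `quantize/neox.py` (not shown in this hunk), and the layer size, batch count, and random calibration data below are illustrative only.

```python
# Hypothetical usage sketch for the GPTQ class; not the actual driver code.
import torch
import torch.nn as nn
from gptq import GPTQ  # this file; assumes quantize/ is on PYTHONPATH

layer = nn.Linear(512, 512).cuda()
gptq = GPTQ(layer)
# 4-bit, per-channel, asymmetric, using the same configure() API seen in print_loss.
gptq.quantizer.configure(4, perchannel=True, sym=False, mse=False)

# Accumulate the Hessian from calibration batches; a real run captures these
# with forward hooks while streaming a dataset such as wikitext2.
for _ in range(16):
    inp = torch.randn(8, 512, device="cuda")
    gptq.add_batch(inp, layer(inp))

# Quantize (the layer's weights are overwritten inside print_loss) and collect
# the per-group quantization parameters.
scale, zero, g_idx, error = gptq.fasterquant(groupsize=128, name="example")
gptq.free()
```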