
Commit

add int4 non-gptq and bugfixes (#119)
Summary: WeightOnlyInt4Linear had a bug that made it not pad when it should have; the padding check was inverted (see the sketch after the commit-message fields below).

Test Plan: python test/quantization/test_quant_api.py -k "int4wo"

Reviewers:

Subscribers:

Tasks:

Tags:
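A minimal sketch of the inverted padding check this commit fixes, using the helper and constructor names from the GPTQ.py diff below (the snippet is illustrative and omits the rest of WeightOnlyInt4Linear.__init__):

    def _check_linear_int4_k(k, groupsize=1, inner_k_tiles=None):
        # True when k already meets the int4 kernel's divisibility requirements,
        # i.e. no padding is needed
        k_divisible_by_groupsize = k % groupsize == 0
        if inner_k_tiles is not None:
            return k_divisible_by_groupsize and k % (inner_k_tiles * 16) == 0
        return k_divisible_by_groupsize

    # Before the fix, WeightOnlyInt4Linear padded exactly when padding was NOT needed:
    #     self.padding = _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
    # After the fix, it pads only when the check fails:
    #     self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)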
HDCharles authored Apr 4, 2024
1 parent b0a333c commit ec258e0
Showing 5 changed files with 133 additions and 14 deletions.
36 changes: 35 additions & 1 deletion test/quantization/test_quant_api.py
@@ -300,7 +300,6 @@ def test_gptq_quantizer_gpt_fast(self):
@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_gptq_quantizer_int4wo(self):
from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer, InputRecorder, TransformerEvalWrapper
# should be similar to TorchCompileDynamicQuantizer
precision = torch.bfloat16
device = "cuda"
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
@@ -357,6 +356,41 @@ def test_gptq_quantizer_int4wo(self):
f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
)

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_quantizer_int4wo(self):
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer, TransformerEvalWrapper
precision = torch.bfloat16
device = "cuda"
checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
model = Transformer.from_name(checkpoint_path.parent.name)
checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
model.load_state_dict(checkpoint, assign=True)
model = model.to(dtype=precision, device=device)
model.eval()
tokenizer_path = checkpoint_path.parent / "tokenizer.model"
assert tokenizer_path.is_file(), tokenizer_path
tokenizer = SentencePieceProcessor( # pyre-ignore[28]
model_file=str(tokenizer_path)
)
groupsize = 128
quantizer = Int4WeightOnlyQuantizer(
groupsize,
)
model = quantizer.quantize(model).cuda()
result = TransformerEvalWrapper(
model,
tokenizer,
model.config.block_size,
prepare_inputs_for_model,
device,
).run_eval(
["wikitext"],
1,
)
assert result['results']['wikitext']['word_perplexity,none'] < 8.24, (
f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}"
)

@unittest.skip("skipping until we get checkpoints for gpt-fast")
def test_eval_wrapper(self):
from torchao.quantization.GPTQ import TransformerEvalWrapper
106 changes: 94 additions & 12 deletions torchao/quantization/GPTQ.py
@@ -28,6 +28,7 @@
groupwise_affine_quantize_tensor_from_qparams,
groupwise_affine_dequantize_tensor_from_qparams,
pack_tinygemm_scales_and_zeros,
groupwise_affine_quantize_tensor,
)
aten = torch.ops.aten

@@ -65,8 +66,8 @@

__all__ = [
"MultiInput",
"WeightOnlyInt4Linear",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
] + add_ons

if lm_eval_available:
@@ -117,7 +118,10 @@ def __init__(

@property
def eot_token_id(self):
return self._tokenizer.eos_id()
try:
return self._tokenizer.eos_id()
except:
return self._tokenizer.eos_id

@property
def max_length(self):
@@ -139,7 +143,10 @@ def tok_encode(self, string: str, **kwargs):
# TODO: verify this for multi-batch as well
tokens = self._tokenizer.encode(string)
if hasattr(self._tokenizer, "bos_id"):
tokens = [self._tokenizer.bos_id()] + tokens
try:
tokens = [self._tokenizer.bos_id()] + tokens
except:
tokens = [self._tokenizer.bos_id] + tokens
return tokens

def tok_decode(self, tokens):
@@ -747,6 +754,12 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module":
def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module:
pass

def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None):
k_divisible_by_groupsize = k % groupsize == 0
if inner_k_tiles is not None:
k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0
return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles
return k_divisible_by_groupsize

def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
origin_x_size = x.size()
@@ -767,7 +780,7 @@ def __init__(
bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
) -> None:
super().__init__()
self.padding = _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
if self.padding:
from model import find_multiple
self.origin_in_features = in_features
@@ -806,14 +819,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
self.weight, self.scales_and_zeros, self.out_features, self.groupsize
)


def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None):
k_divisible_by_groupsize = k % groupsize == 0
if inner_k_tiles is not None:
k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0
return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles
return k_divisible_by_groupsize

def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda=True, skip_layer_func = None):

for name, child in module.named_children():
@@ -826,6 +831,83 @@ def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_c
else:
replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda, skip_layer_func)

class Int4WeightOnlyQuantizer(Quantizer):
def __init__(
self,
groupsize: int = 256,
padding_allowed: bool = True,
inner_k_tiles: Optional[int] = 8,
) -> None:
super().__init__()
assert inner_k_tiles in [2, 4, 8]
assert groupsize in [32, 64, 128, 256]

self.inner_k_tiles = inner_k_tiles
self.groupsize: int = groupsize
self.padding_allowed: bool = padding_allowed

@torch.no_grad()
def _create_quantized_state_dict(
self, model: torch.nn.Module
) -> Dict[str, torch.Tensor]:
cur_state_dict = model.state_dict()
for fqn, mod in model.named_modules():
if isinstance(mod, torch.nn.Linear):
assert not mod.bias
out_features = mod.out_features
in_features = mod.in_features
# assert out_features % 8 == 0, "require out_features % 8 == 0"
print(f"linear: {fqn}, in={in_features}, out={out_features}")

assert (
in_features % self.groupsize == 0
), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"

weight = mod.weight.data
if not _check_linear_int4_k(
in_features, self.groupsize, self.inner_k_tiles
):
if self.padding_allowed:
from .utils import find_multiple
import torch.nn.functional as F
print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
padded_in_features = find_multiple(in_features, 1024)
weight = F.pad(weight, pad=(0, padded_in_features - in_features))
else:
print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
"and that groupsize and inner_k_tiles*16 evenly divide into it")
continue
(
w_int4x8,
scales_and_zeros
) = groupwise_affine_quantize_tensor(
weight,
4, # n_bit
self.groupsize,
)
weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to("cuda"), self.inner_k_tiles)
cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cuda")
cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cuda")
return cur_state_dict

def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
replace_linear_int4(
model,
self.groupsize,
self.inner_k_tiles,
self.padding_allowed,
)
return model

def quantize(
self, model: torch.nn.Module, *args: Any, **kwargs: Any
) -> torch.nn.Module:
state_dict = self._create_quantized_state_dict(model)
model = self._convert_for_runtime(model)
# TODO: make it strict
model.load_state_dict(state_dict, strict=False)
return model

class Int4WeightOnlyGPTQQuantizer(GPTQQuantizer):
def __init__(
self,
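For reference, a minimal usage sketch of the new Int4WeightOnlyQuantizer, mirroring the added test_quantizer_int4wo test; the toy model here is an assumption standing in for the Llama checkpoint, and the sketch requires a CUDA device plus a torch build with the int4 tinygemm kernel:

    import torch
    from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer

    # toy stand-in for the eval-mode, bias-free, bfloat16 CUDA model used in the test;
    # in_features must be divisible by groupsize and by inner_k_tiles * 16
    model = torch.nn.Sequential(
        torch.nn.Linear(4096, 4096, bias=False),
    ).to(dtype=torch.bfloat16, device="cuda").eval()

    quantizer = Int4WeightOnlyQuantizer(groupsize=128)
    model = quantizer.quantize(model).cuda()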
2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -42,4 +42,6 @@
"compute_error",
"get_model_size_in_bytes",
"WeightOnlyInt8QuantLinear",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer",
]
2 changes: 2 additions & 0 deletions torchao/quantization/quant_api.py
@@ -32,6 +32,7 @@
from .unified import Quantizer, TwoStepQuantizer
from .GPTQ import (
Int4WeightOnlyGPTQQuantizer,
Int4WeightOnlyQuantizer,
)


@@ -45,6 +46,7 @@
"Quantizer",
"TwoStepQuantizer",
"Int4WeightOnlyGPTQQuantizer",
"Int4WeightOnlyQuantizer"
]

if TORCH_VERSION_AFTER_2_3:
1 change: 0 additions & 1 deletion torchao/quantization/quant_primitives.py
@@ -383,7 +383,6 @@ def pack_tinygemm_scales_and_zeros(scales, zeros):

def unpack_tinygemm_scales_and_zeros(scales_and_zeros):
assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
assert scales_and_zeros.dtype == torch.float
return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)


Expand Down
