diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 700c9c9b98..5cc5ac1fa3 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -300,7 +300,6 @@ def test_gptq_quantizer_gpt_fast(self):
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_gptq_quantizer_int4wo(self):
         from torchao.quantization.GPTQ import Int4WeightOnlyGPTQQuantizer, InputRecorder, TransformerEvalWrapper
-        # should be similar to TorchCompileDynamicQuantizer
         precision = torch.bfloat16
         device = "cuda"
         checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
@@ -357,6 +356,41 @@ def test_gptq_quantizer_int4wo(self):
             f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
         )
 
+    @unittest.skip("skipping until we get checkpoints for gpt-fast")
+    def test_quantizer_int4wo(self):
+        from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer, TransformerEvalWrapper
+        precision = torch.bfloat16
+        device = "cuda"
+        checkpoint_path = Path("../gpt-fast/checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth")
+        model = Transformer.from_name(checkpoint_path.parent.name)
+        checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+        model.load_state_dict(checkpoint, assign=True)
+        model = model.to(dtype=precision, device=device)
+        model.eval()
+        tokenizer_path = checkpoint_path.parent / "tokenizer.model"
+        assert tokenizer_path.is_file(), tokenizer_path
+        tokenizer = SentencePieceProcessor(  # pyre-ignore[28]
+            model_file=str(tokenizer_path)
+        )
+        groupsize = 128
+        quantizer = Int4WeightOnlyQuantizer(
+            groupsize,
+        )
+        model = quantizer.quantize(model).cuda()
+        result = TransformerEvalWrapper(
+            model,
+            tokenizer,
+            model.config.block_size,
+            prepare_inputs_for_model,
+            device,
+        ).run_eval(
+            ["wikitext"],
+            1,
+        )
+        assert result['results']['wikitext']['word_perplexity,none'] < 8.24, (
+            f"accuracy regressed from 8.23 to {result['results']['wikitext']['word_perplexity,none']}"
+        )
+
     @unittest.skip("skipping until we get checkpoints for gpt-fast")
     def test_eval_wrapper(self):
         from torchao.quantization.GPTQ import TransformerEvalWrapper
diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index 559ab54f7d..d648507085 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -28,6 +28,7 @@
     groupwise_affine_quantize_tensor_from_qparams,
     groupwise_affine_dequantize_tensor_from_qparams,
     pack_tinygemm_scales_and_zeros,
+    groupwise_affine_quantize_tensor,
 )
 
 aten = torch.ops.aten
@@ -65,8 +66,8 @@
 
 __all__ = [
     "MultiInput",
-    "WeightOnlyInt4Linear",
     "Int4WeightOnlyGPTQQuantizer",
+    "Int4WeightOnlyQuantizer",
 ] + add_ons
 
 if lm_eval_available:
@@ -117,7 +118,10 @@ def __init__(
 
     @property
     def eot_token_id(self):
-        return self._tokenizer.eos_id()
+        try:
+            return self._tokenizer.eos_id()
+        except:
+            return self._tokenizer.eos_id
 
     @property
     def max_length(self):
@@ -139,7 +143,10 @@ def tok_encode(self, string: str, **kwargs):
         # TODO: verify this for multi-batch as well
         tokens = self._tokenizer.encode(string)
         if hasattr(self._tokenizer, "bos_id"):
-            tokens = [self._tokenizer.bos_id()] + tokens
+            try:
+                tokens = [self._tokenizer.bos_id()] + tokens
+            except:
+                tokens = [self._tokenizer.bos_id] + tokens
         return tokens
 
     def tok_decode(self, tokens):
@@ -747,6 +754,12 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> "nn.Module":
     def quantize(self, model: torch.nn.Module, inputs: List[MultiInput], **kwargs: Any) -> torch.nn.Module:
         pass
 
+def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None):
+    k_divisible_by_groupsize = k % groupsize == 0
+    if inner_k_tiles is not None:
+        k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0
+        return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles
+    return k_divisible_by_groupsize
 
 def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize):
     origin_x_size = x.size()
@@ -767,7 +780,7 @@ def __init__(
             bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, use_cuda=True,
     ) -> None:
         super().__init__()
-        self.padding = _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
+        self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles)
         if self.padding:
             from model import find_multiple
             self.origin_in_features = in_features
@@ -806,14 +819,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.weight, self.scales_and_zeros, self.out_features, self.groupsize
         )
 
-
-def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None):
-    k_divisible_by_groupsize = k % groupsize == 0
-    if inner_k_tiles is not None:
-        k_divisible_by_16_times_inner_k_tiles = k % (inner_k_tiles * 16) == 0
-        return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles
-    return k_divisible_by_groupsize
-
 def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_cuda=True, skip_layer_func = None):
 
     for name, child in module.named_children():
@@ -826,6 +831,83 @@ def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, use_c
         else:
             replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, use_cuda, skip_layer_func)
 
+class Int4WeightOnlyQuantizer(Quantizer):
+    def __init__(
+        self,
+        groupsize: int = 256,
+        padding_allowed: bool = True,
+        inner_k_tiles: Optional[int] = 8,
+    ) -> None:
+        super().__init__()
+        assert inner_k_tiles in [2, 4, 8]
+        assert groupsize in [32, 64, 128, 256]
+
+        self.inner_k_tiles = inner_k_tiles
+        self.groupsize: int = groupsize
+        self.padding_allowed: bool = padding_allowed
+
+    @torch.no_grad()
+    def _create_quantized_state_dict(
+        self, model: torch.nn.Module
+    ) -> Dict[str, torch.Tensor]:
+        cur_state_dict = model.state_dict()
+        for fqn, mod in model.named_modules():
+            if isinstance(mod, torch.nn.Linear):
+                assert not mod.bias
+                out_features = mod.out_features
+                in_features = mod.in_features
+                # assert out_features % 8 == 0, "require out_features % 8 == 0"
+                print(f"linear: {fqn}, in={in_features}, out={out_features}")
+
+                assert (
+                    in_features % self.groupsize == 0
+                ), f"require in_features:{in_features} % self.groupsize:{self.groupsize} == 0"
+
+                weight = mod.weight.data
+                if not _check_linear_int4_k(
+                    in_features, self.groupsize, self.inner_k_tiles
+                ):
+                    if self.padding_allowed:
+                        from .utils import find_multiple
+                        import torch.nn.functional as F
+                        print(f"warning: {fqn} is padded to satisfy in_features % 1024 == 0")
+                        padded_in_features = find_multiple(in_features, 1024)
+                        weight = F.pad(weight, pad=(0, padded_in_features - in_features))
+                    else:
+                        print(f"warning: {fqn} is skipped, int4 requires that in_features is 32, 64, or is divisible by 1024, " +
+                              "and that groupsize and inner_k_tiles*16 evenly divide into it")
+                        continue
+                (
+                    w_int4x8,
+                    scales_and_zeros
+                ) = groupwise_affine_quantize_tensor(
+                    weight,
+                    4,  # n_bit
+                    self.groupsize,
+                )
+                weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to("cuda"), self.inner_k_tiles)
+                cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to("cuda")
+                cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to("cuda")
+        return cur_state_dict
+
+    def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module:
+        replace_linear_int4(
+            model,
+            self.groupsize,
+            self.inner_k_tiles,
+            self.padding_allowed,
+        )
+        return model
+
+    def quantize(
+        self, model: torch.nn.Module, *args: Any, **kwargs: Any
+    ) -> torch.nn.Module:
+        state_dict = self._create_quantized_state_dict(model)
+        model = self._convert_for_runtime(model)
+        # TODO: make it strict
+        model.load_state_dict(state_dict, strict=False)
+        return model
+
 class Int4WeightOnlyGPTQQuantizer(GPTQQuantizer):
     def __init__(
         self,
diff --git a/torchao/quantization/__init__.py b/torchao/quantization/__init__.py
index 4bfb279769..12aa70039b 100644
--- a/torchao/quantization/__init__.py
+++ b/torchao/quantization/__init__.py
@@ -42,4 +42,6 @@
     "compute_error",
     "get_model_size_in_bytes",
     "WeightOnlyInt8QuantLinear",
+    "Int4WeightOnlyGPTQQuantizer",
+    "Int4WeightOnlyQuantizer",
 ]
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index e7ce92976b..581e312927 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -32,6 +32,7 @@
 from .unified import Quantizer, TwoStepQuantizer
 from .GPTQ import (
     Int4WeightOnlyGPTQQuantizer,
+    Int4WeightOnlyQuantizer,
 )
 
 
@@ -45,6 +46,7 @@
     "Quantizer",
     "TwoStepQuantizer",
     "Int4WeightOnlyGPTQQuantizer",
+    "Int4WeightOnlyQuantizer"
 ]
 
 if TORCH_VERSION_AFTER_2_3:
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index 8b6cc9cc7f..88eafd4b2a 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -383,7 +383,6 @@ def pack_tinygemm_scales_and_zeros(scales, zeros):
 
 def unpack_tinygemm_scales_and_zeros(scales_and_zeros):
     assert len(scales_and_zeros.shape) == 3 and scales_and_zeros.shape[2] == 2
-    assert scales_and_zeros.dtype == torch.float
     return torch.split(scales_and_zeros.transpose(0, 1), 1, 2)