From 6d9a2059a5e1d842b56bdce8f70cea82e880979f Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Sat, 15 Jun 2024 15:14:14 -0700 Subject: [PATCH] Add support for int4 weight-only QAT Summary: This commit adds support for int4 weight-only QAT, which simulates the numerics of the existing Int4WeightOnlyQuantizer. The main motivation for this is to provide an end-to-end path for running QAT and lowering to the efficient int4 tinygemm cuda kernel. To enable this, we have to add new fake quantization primitives to match the numerics of the tinygemm kernel, and this required refactoring existing quant primitives to skip dtype casting. Test Plan: python test/quantization/test_qat.py -k test_qat_4w_linear Reviewers: jerryzh168, msaroufim Subscribers: jerryzh168, msaroufim, HDCharles, supriyar --- test/quantization/test_qat.py | 184 ++++++++++++++++++-- torchao/quantization/GPTQ.py | 89 ++++++++-- torchao/quantization/prototype/qat.py | 204 ++++++++++++++++++++++- torchao/quantization/quant_primitives.py | 91 ++++++++-- torchao/quantization/utils.py | 8 +- 5 files changed, 526 insertions(+), 50 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 433fdcb28a..6e2859e109 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -18,18 +18,26 @@ fake_quantize_per_channel_group, fake_quantize_per_token, ) -from torchao.quantization.utils import get_group_qparams_symmetric +from torchao.quantization.utils import ( + get_group_qparams_symmetric, + get_groupwise_affine_qparams, + groupwise_affine_dequantize_tensor_from_qparams, + groupwise_affine_quantize_tensor, + groupwise_affine_quantize_tensor_from_qparams, +) from torchao.utils import TORCH_VERSION_AFTER_2_4 # TODO: put this in a common test utils file +_CUDA_IS_AVAILABLE = torch.cuda.is_available() + class Sub(torch.nn.Module): def __init__(self): super().__init__() - self.linear = torch.nn.Linear(32, 32, bias=False).to(torch.float) + self.linear = torch.nn.Linear(256, 256, bias=False).to(torch.float) def example_inputs(self): - return (torch.randn(1, 32).to(torch.float),) + return (torch.randn(1, 256).to(torch.float),) def forward(self, x): return self.linear(x) @@ -37,12 +45,12 @@ def forward(self, x): class M(torch.nn.Module): def __init__(self): super().__init__() - self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float) + self.linear1 = torch.nn.Linear(512, 256, bias=False).to(torch.float) self.sub = Sub() - self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float) + self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float) def example_inputs(self): - return (torch.randn(1, 64).to(torch.float),) + return (torch.randn(1, 512).to(torch.float),) def forward(self, x): x = self.linear1(x) @@ -111,23 +119,46 @@ def test_fake_quantize_per_token(self): def _set_ptq_weight( self, - ptq_linear: "Int8DynActInt4WeightLinear", - fp32_weight: torch.Tensor, - group_size: int, + ptq_linear: torch.nn.Module, + qat_linear: torch.nn.Module, ): """ Set the weight to the quantized version of the given fp32 weights, for making linear outputs comparable with QAT. """ + from torchao.quantization.GPTQ import ( + Int8DynActInt4WeightLinear, + WeightOnlyInt4Linear, + ) + from torchao.quantization.prototype.qat import ( + Int8DynActInt4WeightQATLinear, + Int4WeightOnlyQATLinear, + ) n_bit = 4 (qmin, qmax) = self._get_qmin_qmax(n_bit) - (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size) - q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group( - fp32_weight, s, zp, qmin, qmax, torch.int8, group_size, - ) - ptq_linear.weight = q_weight - ptq_linear.scales = s - ptq_linear.zeros = zp + if isinstance(ptq_linear, Int8DynActInt4WeightLinear): + assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear) + fp32_weight = qat_linear.weight + group_size = qat_linear.groupsize + (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size) + q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group( + fp32_weight, s, zp, qmin, qmax, torch.int8, group_size, + ) + ptq_linear.weight = q_weight + ptq_linear.scales = s + ptq_linear.zeros = zp + elif isinstance(ptq_linear, WeightOnlyInt4Linear): + assert isinstance(qat_linear, Int4WeightOnlyQATLinear) + (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( + qat_linear.weight, n_bit, qat_linear.groupsize, + ) + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to("cuda"), qat_linear.inner_k_tiles, + ) + ptq_linear.weight = q_weight + ptq_linear.scales_and_zeros = scales_and_zeros + else: + raise ValueError("Unknown ptq_linear type: %s" % type(ptq_linear)) @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_linear(self): @@ -144,7 +175,7 @@ def test_qat_8da4w_linear(self): ) # Force the weights to be the same - self._set_ptq_weight(ptq_linear, qat_linear.weight, group_size) + self._set_ptq_weight(ptq_linear, qat_linear) # Compare linear values torch.manual_seed(self.SEED) @@ -280,7 +311,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): loss_fn1 = torch.nn.CrossEntropyLoss() loss_fn2 = torch.nn.CrossEntropyLoss() example_inputs = nn_model.example_inputs() - target = torch.randn(1, 64).float() + target = torch.randn(1, 512).float() output1 = nn_model(*example_inputs) output2 = qat_model(*example_inputs) torch.testing.assert_close(output1, output2, atol=0, rtol=0) @@ -322,6 +353,123 @@ def test_qat_generic_fake_quantize(self): torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0) torch.testing.assert_close(py_input.grad, ao_input.grad, atol=0, rtol=0) + def _assert_close_4w(self, val, ref): + # Note: for int4 weight-only quantization, we do not expect exact match + # because torch._weight_int4pack_mm and torch.mm do not match exactly. + # Here we use the same error bar as PyTorch core to determine closeness: + # https://github.com/pytorch/pytorch/blob/6079c5091091d872b8dafbaa4e31a5b6194647ad/test/test_linalg.py#L6079 + mean_err = ((val - ref) / ref).mean().abs() + print(mean_err) + self.assertTrue(mean_err < 0.05) + + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + def test_qat_4w_primitives(self): + n_bit = 4 + group_size = 32 + inner_k_tiles = 8 + scales_precision = torch.bfloat16 + device = torch.device("cuda") + dtype = torch.bfloat16 + torch.manual_seed(self.SEED) + x = torch.randn(100, 256, dtype=dtype, device=device) + weight = torch.randn(512, 256, dtype=dtype, device=device) + + # PTQ + (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( + weight, n_bit, group_size, scales_precision, + ) + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to(device), inner_k_tiles, + ) + ptq_out = torch.ops.aten._weight_int4pack_mm( + x, q_weight, group_size, scales_and_zeros + ) + + # QAT + scales, zero_points = get_groupwise_affine_qparams( + weight, n_bit, group_size, scales_precision, + ) + w_q = groupwise_affine_quantize_tensor_from_qparams( + weight, scales, zero_points, n_bit, group_size, cast_dtypes=False, + ) + w_dq = groupwise_affine_dequantize_tensor_from_qparams( + w_q, scales, zero_points, n_bit, group_size, cast_dtypes=False, + ) + qat_out = torch.nn.functional.linear(x, w_dq) + + self._assert_close_4w(qat_out, ptq_out) + + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + def test_qat_4w_linear(self): + from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear + from torchao.quantization.GPTQ import WeightOnlyInt4Linear + + group_size = 128 + device = torch.device("cuda") + dtype = torch.bfloat16 + torch.manual_seed(self.SEED) + qat_linear = Int4WeightOnlyQATLinear( + 256, 688, bias=False, groupsize=group_size, device=device, + ) + ptq_linear = WeightOnlyInt4Linear( + 256, 688, bias=False, groupsize=group_size, device=device, + ) + + # Force the weights to be the same + self._set_ptq_weight(ptq_linear, qat_linear) + + # Compare linear values + torch.manual_seed(self.SEED) + x = torch.randn(100, 256, dtype=dtype, device=device) + x2 = copy.deepcopy(x) + qat_out = qat_linear(x) + ptq_out = ptq_linear(x2) + self._assert_close_4w(qat_out, ptq_out) + + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + def test_qat_4w_quantizer(self): + from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer + from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer + + group_size = 32 + inner_k_tiles = 8 + device = torch.device("cuda") + dtype = torch.bfloat16 + torch.manual_seed(self.SEED) + m = M().to(device).to(dtype) + m2 = copy.deepcopy(m) + qat_quantizer = Int4WeightOnlyQATQuantizer( + groupsize=group_size, inner_k_tiles=inner_k_tiles, + ) + ptq_quantizer = Int4WeightOnlyQuantizer( + groupsize=group_size, inner_k_tiles=inner_k_tiles, + ) + qat_model = qat_quantizer.prepare(m) + ptq_model = ptq_quantizer.quantize(m2) + + # Compare model values + torch.manual_seed(self.SEED) + x = [i.to(device).to(dtype) for i in m.example_inputs()] + x2 = copy.deepcopy(x) + qat_out = qat_model(*x) + ptq_out = ptq_model(*x2) + self._assert_close_4w(qat_out, ptq_out) + + # Convert QAT model and compare model values + converted_model = qat_quantizer.convert(qat_model) + converted_out = converted_model(*x) + torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0) + + # Compare converted state dict + ptq_state_dict = ptq_model.state_dict() + converted_state_dict = converted_model.state_dict() + self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) + for k in ptq_state_dict.keys(): + torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index 996d812ba8..c27c5d6de8 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -9,7 +9,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Optional, List, Type +from typing import Optional, Callable, List, Type import torch @@ -522,14 +522,21 @@ def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None): return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles return k_divisible_by_groupsize -def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize): +def linear_forward_int4( + x: torch.Tensor, + weight_int4pack: torch.Tensor, + scales_and_zeros: torch.Tensor, + out_features: int, + groupsize: int, + scales_precision: torch.dtype = torch.bfloat16, +): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) c = torch.ops.aten._weight_int4pack_mm( x.to(torch.bfloat16), weight_int4pack, groupsize, - scales_and_zeros.to(torch.bfloat16) + scales_and_zeros.to(scales_precision) ).to(dtype=x.dtype) new_shape = origin_x_size[:-1] + (out_features,) c = c.reshape(new_shape) @@ -544,6 +551,7 @@ class WeightOnlyInt4Linear(torch.nn.Module): def __init__( self, in_features: int, out_features: int, bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, + precision: torch.dtype = torch.int32, scales_precision: torch.dtype = torch.bfloat16, ) -> None: super().__init__() self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles) @@ -555,40 +563,91 @@ def __init__( self.in_features = in_features self.out_features = out_features assert not bias, "require bias=False" + self.device = device self.groupsize = groupsize self.inner_k_tiles = inner_k_tiles + self.precision = precision + self.scales_precision = scales_precision assert out_features % 8 == 0, "require out_features % 8 == 0" assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" self.register_buffer( "weight", - torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) + torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=self.precision, device=device) ) self.register_buffer( "scales_and_zeros", - torch.empty((in_features // groupsize, out_features, 2), dtype=torch.bfloat16) + torch.empty((in_features // groupsize, out_features, 2), dtype=self.scales_precision, device=device) ) def forward(self, input: torch.Tensor) -> torch.Tensor: if self.padding: - import torch.nn.functional as F input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) return linear_forward_int4( input, - self.weight, self.scales_and_zeros, self.out_features, self.groupsize + self.weight, + self.scales_and_zeros, + self.out_features, + self.groupsize, + self.scales_precision, ) -def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None): +def _replace_linear_int4( + module: torch.nn.Module, + groupsize: int, + inner_k_tiles: Optional[int], + padding_allowed: bool, + skip_layer_func: Optional[Callable] = None, + precision: torch.dtype = torch.int32, + scales_precision: torch.dtype = torch.bfloat16, + linear_class: Type[torch.nn.Module] = WeightOnlyInt4Linear, + copy_weights: bool = False, +): for name, child in module.named_children(): if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)): if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed: - setattr(module, name, WeightOnlyInt4Linear( - child.in_features, child.out_features, bias=False, - groupsize=groupsize, inner_k_tiles=inner_k_tiles, - )) + new_linear = linear_class( + child.in_features, + child.out_features, + bias=False, + device=child.weight.device, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + precision=precision, + scales_precision=scales_precision, + ) + # TODO: merge with 8da4w? + # In distributed training, the model may be instantiated + # on the meta device, in which case there is no need to + # copy the weights, and doing so will result in an error + if copy_weights and child.weight.device != torch.device("meta"): + new_linear.weight = child.weight + setattr(module, name, new_linear) else: - replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, skip_layer_func) + _replace_linear_int4( + child, + groupsize, + inner_k_tiles, + padding_allowed, + skip_layer_func, + precision, + scales_precision, + linear_class, + copy_weights, + ) + + +def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None): + _replace_linear_int4( + module, + groupsize, + inner_k_tiles, + padding_allowed, + skip_layer_func, + linear_class=WeightOnlyInt4Linear, + ) + class Int4WeightOnlyQuantizer(Quantizer): def __init__( @@ -646,6 +705,7 @@ def _create_quantized_state_dict( 4, # n_bit self.groupsize, ) + # TODO: just get the device from mod.weight.device? weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to(self.device), self.inner_k_tiles) cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device) cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(self.device) @@ -669,6 +729,7 @@ def quantize( model.load_state_dict(state_dict, strict=False) return model + class Int4WeightOnlyGPTQQuantizer(GPTQQuantizer): def __init__( self, @@ -1001,6 +1062,7 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module: self.groupsize, self.padding_allowed, self.precision, + # TODO: this should be self.scales_precision? self.precision, ) return model @@ -1086,6 +1148,7 @@ def _convert_for_runtime(self, model): self.groupsize, self.padding_allowed, self.precision, + # TODO: this should be self.scales_precision? self.precision, ) return model diff --git a/torchao/quantization/prototype/qat.py b/torchao/quantization/prototype/qat.py index 71b585b15e..2d86f79c6e 100644 --- a/torchao/quantization/prototype/qat.py +++ b/torchao/quantization/prototype/qat.py @@ -4,20 +4,34 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Tuple +from typing import Any, Optional, Tuple import torch +import torch.nn.functional as F from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib from torch.library import impl -from torchao.quantization.utils import get_group_qparams_symmetric +from torchao.quantization.utils import ( + get_group_qparams_symmetric, + groupwise_affine_dequantize_tensor, +) from torchao.quantization.unified import TwoStepQuantizer from torchao.quantization.GPTQ import ( + _check_linear_int4_k, + _replace_linear_int4, _replace_linear_8da4w, + get_groupwise_affine_qparams, + groupwise_affine_quantize_tensor, + groupwise_affine_quantize_tensor_from_qparams, + groupwise_affine_dequantize_tensor_from_qparams, Int8DynActInt4WeightLinear, + WeightOnlyInt4Linear, ) +# ================= +# | 8da4w QAT | +# ================= class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer): """ @@ -171,7 +185,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) else: w_fq = self.weight - return torch.nn.functional.linear(x_fq, w_fq) + return F.linear(x_fq, w_fq) # TODO: move this to common util def _get_qmin_qmax(self, n_bit: int): @@ -193,10 +207,194 @@ def disable_8da4w_fake_quant(mod: torch.nn.Module): if isinstance(mod, Int8DynActInt4WeightQATLinear): mod.disable_fake_quant() + +# ================== +# | int4wo QAT | +# ================== + +class Int4WeightOnlyQATQuantizer(TwoStepQuantizer): + """ + Quantizer for performing QAT on a model, where linear layers have + int4 fake quantized grouped per channel weights. + """ + + def __init__( + self, + groupsize: int = 256, + inner_k_tiles: Optional[int] = 8, + precision: torch.dtype = torch.bfloat16, + scales_precision: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__() + assert inner_k_tiles in [2, 4, 8] + assert groupsize in [32, 64, 128, 256] + self.inner_k_tiles = inner_k_tiles + self.groupsize = groupsize + self.precision = precision + self.scales_precision = scales_precision + + def prepare( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _replace_linear_int4( + model, + self.groupsize, + self.inner_k_tiles, + padding_allowed=True, + precision=self.precision, + scales_precision=self.scales_precision, + linear_class=Int4WeightOnlyQATLinear, + copy_weights=True, + ) + return model + + def convert( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _convert_qat_linear_4w(model) + return model + +def _convert_qat_linear_4w(module: torch.nn.Module): + """ + Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. + """ + for name, child in module.named_children(): + if isinstance(child, Int4WeightOnlyQATLinear): + in_features = child.in_features + out_features = child.out_features + groupsize = child.groupsize + inner_k_tiles = child.inner_k_tiles + quantized_linear = WeightOnlyInt4Linear( + in_features, + out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + precision=child.precision, + scales_precision=child.scales_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( + child.weight, n_bit, child.groupsize, + ) + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to(child.weight.device), child.inner_k_tiles, + ) + quantized_linear.weight = q_weight + quantized_linear.scales_and_zeros = scales_and_zeros + else: + _convert_qat_linear_4w(child) + +class Int4WeightOnlyQATLinear(torch.nn.Linear): + """ + This module implements a linear layer with int4 fake quantized grouped + per channel weights, with forward numerics matching `WeightOnlyInt4Linear`, + which uses the efficient int4 tinygemm kernel. + + args: + groupsize: the number of elements in each quantized group for weights + precision: precision of weights + scales_precision: precision of per group scales and zero points + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + device: torch.device = None, + groupsize: int = 256, + inner_k_tiles: int = 8, + precision: torch.dtype = torch.bfloat16, + scales_precision: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__( + in_features, + out_features, + bias, + device=device, + dtype=precision, + ) + assert not bias, "require bias=False" + assert scales_precision == torch.bfloat16, "only bf16 is supported for scales" + if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles): + raise ValueError("Padding for QAT 4w is not supported yet") + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.precision = precision + self.scales_precision = scales_precision + self._fake_quant_enabled = True + + def enable_fake_quant(self, enabled: bool = True): + self._fake_quant_enabled = enabled + + def disable_fake_quant(self): + self.enable_fake_quant(False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + n_bit = 4 + qmin = 0 + qmax = 2 ** n_bit - 1 + scales, zero_points = get_groupwise_affine_qparams( + self.weight, n_bit, self.groupsize, self.scales_precision, + ) + w_fq = _Int4WeightOnlyFakeQuantize.apply( + self.weight, scales, zero_points, qmin, qmax, self.groupsize, + ) + return F.linear(x, w_fq) + +def enable_4w_fake_quant(mod: torch.nn.Module): + """ + Enable fake quantization for `Int4WeightOnlyQATLinear`. + """ + if isinstance(mod, Int4WeightOnlyQATLinear): + mod.enable_fake_quant() + +def disable_4w_fake_quant(mod: torch.nn.Module): + """ + Disable fake quantization for `Int4WeightOnlyQATLinear`. + """ + if isinstance(mod, Int4WeightOnlyQATLinear): + mod.disable_fake_quant() + + # ======================== # | QUANT PRIMITIVES | # ======================== +class _Int4WeightOnlyFakeQuantize(torch.autograd.Function): + """ + Implementation of int4 grouped per channel weight-only fake quantize + intended to match the numerics of the efficient int4 tinygemm kernel. + """ + + @staticmethod + def forward(ctx, input, scales, zero_points, quant_min, quant_max, groupsize): + n_bit = 4 + w_q = groupwise_affine_quantize_tensor_from_qparams( + input, scales, zero_points, n_bit, groupsize, cast_dtypes=False, + ) + w_dq = groupwise_affine_dequantize_tensor_from_qparams( + w_q, scales, zero_points, n_bit, groupsize, cast_dtypes=False, + ) + mask = torch.logical_and((w_q >= quant_min), (w_q <= quant_max)) + ctx.save_for_backward(mask) + return w_dq + + @staticmethod + def backward(ctx, gy): + (mask,) = ctx.saved_tensors + return gy * mask, None, None, None, None, None + class _GenericFakeQuantize(torch.autograd.Function): """ Implementation of generic fake quantize with backward STE. diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index a78c42605a..876be13e63 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -174,6 +174,29 @@ def quantize_affine( Output: quantized tensor with requested dtype """ + return _do_quantize_affine( + input, + block_size, + scale, + zero_point, + output_dtype, + quant_min, + quant_max, + zero_point_domain, + cast_dtypes=True, + ) + +def _do_quantize_affine( + input: torch.Tensor, + block_size: Tuple[int, ...], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + output_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + cast_dtypes: bool = True, +): # TODO: validations # TODO: validate scale/zero_point dimensions are compatible with block_size assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}" @@ -191,7 +214,9 @@ def quantize_affine( if zero_point_domain == ZeroPointDomain.INT: quant = torch.clamp( torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max - ).to(output_dtype) + ) + if cast_dtypes: + quant = quant.to(output_dtype) else: assert zero_point_domain == ZeroPointDomain.FLOAT mid_point = (quant_max + quant_min + 1) / 2 @@ -200,7 +225,9 @@ def quantize_affine( torch.clamp( torch.round((input - min_val) / scale), quant_min, quant_max) - ).to(output_dtype) + ) + if cast_dtypes: + quant = quant.to(output_dtype) quant = quant.view(original_shape) return quant @@ -238,11 +265,37 @@ def dequantize_affine( Output: dequantized Tensor, with requested dtype or fp32 """ - + return _do_dequantize_affine( + input, + block_size, + scale, + zero_point, + input_dtype, + quant_min, + quant_max, + zero_point_domain, + output_dtype=output_dtype, + cast_dtypes=True, + ) + +def _do_dequantize_affine( + input: torch.Tensor, + block_size: Tuple[int, ...], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + input_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + *, + output_dtype: torch.dtype = torch.float32, + cast_dtypes: bool = True, +): # TODO: validations # TODO: validate scale/zero_point dimensions are compatible with block_size - assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}" - assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}" + if cast_dtypes: + assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}" + assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}" quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max) shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size()) @@ -256,24 +309,34 @@ def dequantize_affine( zero_point = zero_point.view(shape_after_reduction) if zero_point_domain == ZeroPointDomain.INT: - # Force a copy to avoid input modification due - # to upcoming in-place operations. - dequant = input.to(torch.int32, copy=True) - if zero_point is not None: - dequant -= zero_point.to(torch.int32) - dequant = dequant.to(output_dtype) - dequant *= scale + if cast_dtypes: + # Force a copy to avoid input modification due + # to upcoming in-place operations. + dequant = input.to(torch.int32, copy=True) + if zero_point is not None: + dequant -= zero_point.to(torch.int32) + dequant = dequant.to(output_dtype) + dequant *= scale + else: + dequant = input.clone() + if zero_point is not None: + dequant -= zero_point + dequant *= scale else: assert zero_point_domain == ZeroPointDomain.FLOAT, f"Unexpected zero point domain: {zero_point_domain}" mid_point = (quant_max + quant_min + 1) / 2 # This should allocate new memory and avoid input modification dequant = input - mid_point - dequant = dequant.to(output_dtype) + if cast_dtypes: + dequant = dequant.to(output_dtype) dequant *= scale if zero_point is not None: dequant += zero_point - return dequant.view(original_shape).to(output_dtype) + dequant = dequant.view(original_shape) + if cast_dtypes: + dequant = dequant.to(output_dtype) + return dequant def choose_qparams_affine( input: torch.Tensor, diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py index 3e3943c93c..a3205b8fb7 100644 --- a/torchao/quantization/utils.py +++ b/torchao/quantization/utils.py @@ -10,6 +10,8 @@ import torch.nn.utils.parametrize as parametrize from torchao.utils import find_multiple from .quant_primitives import ( + _do_quantize_affine, + _do_dequantize_affine, MappingType, ZeroPointDomain, choose_qparams_affine, @@ -333,6 +335,7 @@ def groupwise_affine_quantize_tensor_from_qparams( zeros, n_bit=4, groupsize=128, + cast_dtypes=True, ): assert groupsize > 1 # needed for GPTQ single column quantize @@ -347,7 +350,7 @@ def groupwise_affine_quantize_tensor_from_qparams( quant_min = 0 quant_max = 2 ** n_bit - 1 - return quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain = ZeroPointDomain.FLOAT) + return _do_quantize_affine(w, block_size, scales, zeros, output_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, cast_dtypes=cast_dtypes) def groupwise_affine_dequantize_tensor_from_qparams( w_int4x8, @@ -355,6 +358,7 @@ def groupwise_affine_dequantize_tensor_from_qparams( zeros, n_bit=4, groupsize=128, + cast_dtypes=True, ): assert groupsize > 1 # needed for GPTQ single column dequantize @@ -367,7 +371,7 @@ def groupwise_affine_dequantize_tensor_from_qparams( input_dtype = torch.int32 quant_min = 0 quant_max = 2**n_bit - 1 - return dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype) + return _do_dequantize_affine(w_int4x8, block_size, scales, zeros, input_dtype, quant_min, quant_max, zero_point_domain=ZeroPointDomain.FLOAT, output_dtype=scales.dtype, cast_dtypes=cast_dtypes) def groupwise_affine_quantize_tensor(w, n_bit=4, groupsize=128, dtype=torch.bfloat16):