From 95fc9240b0d38314e9a705ad62148e3ac4117a4a Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Fri, 12 Jul 2024 13:33:38 -0700 Subject: [PATCH 1/2] Add cachemask variant for fake_quantize_affine Summary: In QAT, we often wish to filter out the gradients corresponding to values outside the expected quantization range, for example: ``` q = _quantize_affine_no_dtype_cast(...) dq = _dequantize_affine_no_dtype_check(...) mask = torch.logical_and((q >= quant_min), (q <= quant_max)) grad = grad * mask ``` The existing `fake_quantize_affine` returns the dequantized values only, so callers do not have access to this mask. This commit adds the variant to this op that returns both the dequantized values and the mask, similar to `fake_quantize_per_tensor_affine_cachemask` in core. Test Plan: python test/quantization/test_quant_primitives.py -k test_fake_quantize_affine_cachemask --- test/quantization/test_quant_primitives.py | 24 +++++++ torchao/quantization/quant_primitives.py | 73 +++++++++++++++++++++- 2 files changed, 96 insertions(+), 1 deletion(-) diff --git a/test/quantization/test_quant_primitives.py b/test/quantization/test_quant_primitives.py index 6e26256e96..5f8680a509 100644 --- a/test/quantization/test_quant_primitives.py +++ b/test/quantization/test_quant_primitives.py @@ -10,6 +10,7 @@ import torch from torchao.quantization.quant_primitives import ( fake_quantize_affine, + fake_quantize_affine_cachemask, quantize_affine, dequantize_affine, choose_qparams_affine, @@ -523,5 +524,28 @@ def test_fake_quantize_affine(self): fake_quantized = fake_quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max) torch.testing.assert_close(dequantized, fake_quantized) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") + def test_fake_quantize_affine_cachemask(self): + input = torch.randn(10, 10) + + mapping_type = MappingType.SYMMETRIC + block_size = list(input.shape) + for i in range(len(block_size) - 1): + block_size[i] = 1 + dtype = torch.int8 + eps = 1e-5 + quant_min = -127 + quant_max = 127 + scale, zero_point = choose_qparams_affine(input, mapping_type, block_size, dtype, quant_min, quant_max, eps=eps, scale_dtype=torch.float) + + quantized = quantize_affine(input, block_size, scale, zero_point, dtype, quant_min, quant_max) + dequantized = dequantize_affine(quantized, block_size, scale, zero_point, dtype, quant_min, quant_max) + (fake_quantized, mask) = fake_quantize_affine_cachemask( + input, block_size, scale, zero_point, dtype, quant_min, quant_max, + ) + expected_mask = torch.full(input.shape, True) + torch.testing.assert_close(dequantized, fake_quantized) + torch.testing.assert_close(expected_mask, mask) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py index 8f860917f4..9745a26d1a 100644 --- a/torchao/quantization/quant_primitives.py +++ b/torchao/quantization/quant_primitives.py @@ -24,6 +24,7 @@ "quantize_affine", "dequantize_affine", "fake_quantize_affine", + "fake_quantize_affine_cachemask", ] class MappingType(Enum): @@ -411,6 +412,76 @@ def fake_quantize_affine( value during quantization default is ZeroPointDomain.INT """ + (_, fq) = _do_fake_quantize_affine( + input, + block_size, + scale, + zero_point, + quant_dtype, + quant_min, + quant_max, + zero_point_domain, + ) + return fq + + +def fake_quantize_affine_cachemask( + input: torch.Tensor, + block_size: Tuple[int, ...], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + quant_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + General fake quantize op for quantization-aware training (QAT). + This is equivalent to calling `quantize_affine` + `dequantize_affine` + but without the dtype casts. + + Note: Compared to :func:`~torchao.quantization.quant_primitives.fake_quantize_affine`, + this consumes more memory and returns an additional outlier mask for + intermediate quantized values. + + Args: + Same as :func:`~torchao.quantization.quant_primitives.fake_quantize_affine`. + + Returns: + A 2-tuple of ( + final fake quantized values, + outlier mask for intermediate quantized values + ) + + """ + (q, dq) = _do_fake_quantize_affine( + input, + block_size, + scale, + zero_point, + quant_dtype, + quant_min, + quant_max, + zero_point_domain, + ) + mask = torch.logical_and((q >= quant_min), (q <= quant_max)) + return (dq, mask) + + +def _do_fake_quantize_affine( + input: torch.Tensor, + block_size: Tuple[int, ...], + scale: torch.Tensor, + zero_point: Optional[torch.Tensor], + quant_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Helper function for `fake_quantize_affine` that returns both the + intermediate quantized values and the final dequantized values. + """ input_dtype = input.dtype quant_min, quant_max = _get_and_check_qmin_qmax(quant_dtype, quant_min, quant_max) q = _quantize_affine_no_dtype_cast( @@ -432,7 +503,7 @@ def fake_quantize_affine( zero_point_domain.name, output_dtype=input_dtype, ) - return dq + return (q, dq) def choose_qparams_affine( From 2f91cd3a9423a7e1cf5a9edaa63cfeae9310f5af Mon Sep 17 00:00:00 2001 From: andrewor14 Date: Sat, 15 Jun 2024 15:14:14 -0700 Subject: [PATCH 2/2] Add support for int4 weight-only QAT Summary: This commit adds support for int4 weight-only QAT, which simulates the numerics of the existing Int4WeightOnlyQuantizer. The main motivation for this is to provide an end-to-end path for running QAT and lowering to the efficient int4 tinygemm cuda kernel. To enable this, we have to add new fake quantization primitives to match the numerics of the tinygemm kernel, and this required refactoring existing quant primitives to skip dtype casting. Test Plan: python test/quantization/test_qat.py -k test_qat_4w_linear Reviewers: jerryzh168, msaroufim Subscribers: jerryzh168, msaroufim, HDCharles, supriyar --- test/quantization/test_qat.py | 193 ++++++++++++++++++++--- torchao/quantization/GPTQ.py | 109 ++++++++++--- torchao/quantization/prototype/qat.py | 216 +++++++++++++++++++++++++- 3 files changed, 477 insertions(+), 41 deletions(-) diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py index 433fdcb28a..e53ef03819 100644 --- a/test/quantization/test_qat.py +++ b/test/quantization/test_qat.py @@ -18,18 +18,28 @@ fake_quantize_per_channel_group, fake_quantize_per_token, ) -from torchao.quantization.utils import get_group_qparams_symmetric +from torchao.quantization.quant_primitives import ( + fake_quantize_affine, + ZeroPointDomain, +) +from torchao.quantization.utils import ( + get_group_qparams_symmetric, + get_groupwise_affine_qparams, + groupwise_affine_quantize_tensor, +) from torchao.utils import TORCH_VERSION_AFTER_2_4 # TODO: put this in a common test utils file +_CUDA_IS_AVAILABLE = torch.cuda.is_available() + class Sub(torch.nn.Module): def __init__(self): super().__init__() - self.linear = torch.nn.Linear(32, 32, bias=False).to(torch.float) + self.linear = torch.nn.Linear(256, 256, bias=False).to(torch.float) def example_inputs(self): - return (torch.randn(1, 32).to(torch.float),) + return (torch.randn(1, 256).to(torch.float),) def forward(self, x): return self.linear(x) @@ -37,12 +47,12 @@ def forward(self, x): class M(torch.nn.Module): def __init__(self): super().__init__() - self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float) + self.linear1 = torch.nn.Linear(512, 256, bias=False).to(torch.float) self.sub = Sub() - self.linear2 = torch.nn.Linear(32, 64, bias=False).to(torch.float) + self.linear2 = torch.nn.Linear(256, 512, bias=False).to(torch.float) def example_inputs(self): - return (torch.randn(1, 64).to(torch.float),) + return (torch.randn(1, 512).to(torch.float),) def forward(self, x): x = self.linear1(x) @@ -111,23 +121,46 @@ def test_fake_quantize_per_token(self): def _set_ptq_weight( self, - ptq_linear: "Int8DynActInt4WeightLinear", - fp32_weight: torch.Tensor, - group_size: int, + ptq_linear: torch.nn.Module, + qat_linear: torch.nn.Module, ): """ Set the weight to the quantized version of the given fp32 weights, for making linear outputs comparable with QAT. """ + from torchao.quantization.GPTQ import ( + Int8DynActInt4WeightLinear, + WeightOnlyInt4Linear, + ) + from torchao.quantization.prototype.qat import ( + Int8DynActInt4WeightQATLinear, + Int4WeightOnlyQATLinear, + ) n_bit = 4 (qmin, qmax) = self._get_qmin_qmax(n_bit) - (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size) - q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group( - fp32_weight, s, zp, qmin, qmax, torch.int8, group_size, - ) - ptq_linear.weight = q_weight - ptq_linear.scales = s - ptq_linear.zeros = zp + if isinstance(ptq_linear, Int8DynActInt4WeightLinear): + assert isinstance(qat_linear, Int8DynActInt4WeightQATLinear) + fp32_weight = qat_linear.weight + group_size = qat_linear.groupsize + (s, zp) = get_group_qparams_symmetric(fp32_weight, n_bit, group_size) + q_weight = torch.ops.quantized_decomposed.quantize_per_channel_group( + fp32_weight, s, zp, qmin, qmax, torch.int8, group_size, + ) + ptq_linear.weight = q_weight + ptq_linear.scales = s + ptq_linear.zeros = zp + elif isinstance(ptq_linear, WeightOnlyInt4Linear): + assert isinstance(qat_linear, Int4WeightOnlyQATLinear) + (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( + qat_linear.weight, n_bit, qat_linear.groupsize, + ) + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to("cuda"), qat_linear.inner_k_tiles, + ) + ptq_linear.weight = q_weight + ptq_linear.scales_and_zeros = scales_and_zeros + else: + raise ValueError("Unknown ptq_linear type: %s" % type(ptq_linear)) @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") def test_qat_8da4w_linear(self): @@ -144,7 +177,7 @@ def test_qat_8da4w_linear(self): ) # Force the weights to be the same - self._set_ptq_weight(ptq_linear, qat_linear.weight, group_size) + self._set_ptq_weight(ptq_linear, qat_linear) # Compare linear values torch.manual_seed(self.SEED) @@ -280,7 +313,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant_backward(self): loss_fn1 = torch.nn.CrossEntropyLoss() loss_fn2 = torch.nn.CrossEntropyLoss() example_inputs = nn_model.example_inputs() - target = torch.randn(1, 64).float() + target = torch.randn(1, 512).float() output1 = nn_model(*example_inputs) output2 = qat_model(*example_inputs) torch.testing.assert_close(output1, output2, atol=0, rtol=0) @@ -322,6 +355,130 @@ def test_qat_generic_fake_quantize(self): torch.testing.assert_close(py_out, ao_out, atol=0, rtol=0) torch.testing.assert_close(py_input.grad, ao_input.grad, atol=0, rtol=0) + def _assert_close_4w(self, val, ref): + # Note: for int4 weight-only quantization, we do not expect exact match + # because torch._weight_int4pack_mm and torch.mm do not match exactly. + # Here we use the same error bar as PyTorch core to determine closeness: + # https://github.com/pytorch/pytorch/blob/6079c5091091d872b8dafbaa4e31a5b6194647ad/test/test_linalg.py#L6079 + mean_err = ((val - ref) / ref).mean().abs() + print(mean_err) + self.assertTrue(mean_err < 0.05) + + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + def test_qat_4w_primitives(self): + n_bit = 4 + group_size = 32 + inner_k_tiles = 8 + scales_precision = torch.bfloat16 + device = torch.device("cuda") + dtype = torch.bfloat16 + torch.manual_seed(self.SEED) + x = torch.randn(100, 256, dtype=dtype, device=device) + weight = torch.randn(512, 256, dtype=dtype, device=device) + + # PTQ + (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( + weight, n_bit, group_size, scales_precision, + ) + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to(device), inner_k_tiles, + ) + ptq_out = torch.ops.aten._weight_int4pack_mm( + x, q_weight, group_size, scales_and_zeros + ) + + # QAT + block_size = (1, group_size) + quant_min = 0 + quant_max = 2 ** n_bit - 1 + scales, zero_points = get_groupwise_affine_qparams( + weight, n_bit, group_size, scales_precision, + ) + w_fq = fake_quantize_affine( + weight, + block_size, + scales, + zero_points, + torch.int32, + quant_min, + quant_max, + zero_point_domain = ZeroPointDomain.FLOAT, + ) + qat_out = torch.nn.functional.linear(x, w_fq) + + self._assert_close_4w(qat_out, ptq_out) + + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + def test_qat_4w_linear(self): + from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear + from torchao.quantization.GPTQ import WeightOnlyInt4Linear + + group_size = 128 + device = torch.device("cuda") + dtype = torch.bfloat16 + torch.manual_seed(self.SEED) + qat_linear = Int4WeightOnlyQATLinear( + 256, 688, bias=False, groupsize=group_size, device=device, + ) + ptq_linear = WeightOnlyInt4Linear( + 256, 688, bias=False, groupsize=group_size, device=device, + ) + + # Force the weights to be the same + self._set_ptq_weight(ptq_linear, qat_linear) + + # Compare linear values + torch.manual_seed(self.SEED) + x = torch.randn(100, 256, dtype=dtype, device=device) + x2 = copy.deepcopy(x) + qat_out = qat_linear(x) + ptq_out = ptq_linear(x2) + self._assert_close_4w(qat_out, ptq_out) + + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower") + @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available") + def test_qat_4w_quantizer(self): + from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer + from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer + + group_size = 32 + inner_k_tiles = 8 + device = torch.device("cuda") + dtype = torch.bfloat16 + torch.manual_seed(self.SEED) + m = M().to(device).to(dtype) + m2 = copy.deepcopy(m) + qat_quantizer = Int4WeightOnlyQATQuantizer( + groupsize=group_size, inner_k_tiles=inner_k_tiles, + ) + ptq_quantizer = Int4WeightOnlyQuantizer( + groupsize=group_size, inner_k_tiles=inner_k_tiles, + ) + qat_model = qat_quantizer.prepare(m) + ptq_model = ptq_quantizer.quantize(m2) + + # Compare model values + torch.manual_seed(self.SEED) + x = [i.to(device).to(dtype) for i in m.example_inputs()] + x2 = copy.deepcopy(x) + qat_out = qat_model(*x) + ptq_out = ptq_model(*x2) + self._assert_close_4w(qat_out, ptq_out) + + # Convert QAT model and compare model values + converted_model = qat_quantizer.convert(qat_model) + converted_out = converted_model(*x) + torch.testing.assert_close(converted_out, ptq_out, atol=0, rtol=0) + + # Compare converted state dict + ptq_state_dict = ptq_model.state_dict() + converted_state_dict = converted_model.state_dict() + self.assertEqual(ptq_state_dict.keys(), converted_state_dict.keys()) + for k in ptq_state_dict.keys(): + torch.testing.assert_close(ptq_state_dict[k], converted_state_dict[k], atol=0, rtol=0) + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py index 99b7621eed..e45bb26e4d 100644 --- a/torchao/quantization/GPTQ.py +++ b/torchao/quantization/GPTQ.py @@ -9,7 +9,7 @@ # LICENSE file in the root directory of this source tree. import logging -from typing import Optional, List, Type +from typing import Optional, Callable, List, Type import torch @@ -525,14 +525,22 @@ def _check_linear_int4_k(k, groupsize = 1, inner_k_tiles = None): return k_divisible_by_groupsize and k_divisible_by_16_times_inner_k_tiles return k_divisible_by_groupsize -def linear_forward_int4(x, weight_int4pack, scales_and_zeros, out_features, groupsize, dtype=torch.bfloat16): +def linear_forward_int4( + x: torch.Tensor, + weight_int4pack: torch.Tensor, + scales_and_zeros: torch.Tensor, + out_features: int, + groupsize: int, + precision: torch.dtype = torch.bfloat16, + scales_precision: torch.dtype = torch.bfloat16, +): origin_x_size = x.size() x = x.reshape(-1, origin_x_size[-1]) c = torch.ops.aten._weight_int4pack_mm( - x.to(dtype), + x.to(precision), weight_int4pack, groupsize, - scales_and_zeros.to(dtype) + scales_and_zeros.to(scales_precision) ).to(dtype=x.dtype) new_shape = origin_x_size[:-1] + (out_features,) c = c.reshape(new_shape) @@ -546,7 +554,9 @@ class WeightOnlyInt4Linear(torch.nn.Module): def __init__( self, in_features: int, out_features: int, - bias=False, device=None, dtype=torch.bfloat16, groupsize: int = 128, inner_k_tiles: int = 8, + # TODO: remove dtype field, not used + bias=False, device=None, dtype=None, groupsize: int = 128, inner_k_tiles: int = 8, + precision: torch.dtype = torch.bfloat16, scales_precision: torch.dtype = torch.bfloat16, ) -> None: super().__init__() self.padding = not _check_linear_int4_k(in_features, groupsize, inner_k_tiles) @@ -558,42 +568,96 @@ def __init__( self.in_features = in_features self.out_features = out_features assert not bias, "require bias=False" + self.device = device self.groupsize = groupsize self.inner_k_tiles = inner_k_tiles + self.precision = precision + self.scales_precision = scales_precision + + if dtype is not None: + raise ValueError("Please specify 'precision' instead of 'dtype'") assert out_features % 8 == 0, "require out_features % 8 == 0" assert in_features % (inner_k_tiles * 16) == 0, "require in_features % (innerKTiles * 16) == 0" self.register_buffer( "weight", - torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32) + torch.empty((out_features // 8, in_features // (inner_k_tiles * 16), 32, inner_k_tiles // 2), dtype=torch.int32, device=device) ) self.dtype = dtype self.register_buffer( "scales_and_zeros", - torch.empty((in_features // groupsize, out_features, 2), dtype=self.dtype) + torch.empty((in_features // groupsize, out_features, 2), dtype=self.scales_precision, device=device) ) def forward(self, input: torch.Tensor) -> torch.Tensor: if self.padding: - import torch.nn.functional as F input = F.pad(input, pad=(0, self.in_features - self.origin_in_features)) return linear_forward_int4( input, - self.weight, self.scales_and_zeros, self.out_features, self.groupsize, self.dtype + self.weight, + self.scales_and_zeros, + self.out_features, + self.groupsize, + self.precision, + self.scales_precision, ) -def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None, dtype=torch.bfloat16): +def _replace_linear_int4( + module: torch.nn.Module, + groupsize: int, + inner_k_tiles: Optional[int], + padding_allowed: bool, + skip_layer_func: Optional[Callable] = None, + precision: torch.dtype = torch.bfloat16, + scales_precision: torch.dtype = torch.bfloat16, + linear_class: Type[torch.nn.Module] = WeightOnlyInt4Linear, + copy_weights: bool = False, +): for name, child in module.named_children(): if isinstance(child, nn.Linear) and (skip_layer_func is None or not skip_layer_func(child.weight)): if _check_linear_int4_k(child.in_features, groupsize, inner_k_tiles) or padding_allowed: - setattr(module, name, WeightOnlyInt4Linear( - child.in_features, child.out_features, bias=False, - groupsize=groupsize, inner_k_tiles=inner_k_tiles, - dtype=dtype, - )) + new_linear = linear_class( + child.in_features, + child.out_features, + bias=False, + device=child.weight.device, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + precision=precision, + scales_precision=scales_precision, + ) + # TODO: merge with 8da4w? + # In distributed training, the model may be instantiated + # on the meta device, in which case there is no need to + # copy the weights, and doing so will result in an error + if copy_weights and child.weight.device != torch.device("meta"): + new_linear.weight = child.weight + setattr(module, name, new_linear) else: - replace_linear_int4(child, groupsize, inner_k_tiles, padding_allowed, skip_layer_func, dtype) + _replace_linear_int4( + child, + groupsize, + inner_k_tiles, + padding_allowed, + skip_layer_func, + precision, + scales_precision, + linear_class, + copy_weights, + ) + + +def replace_linear_int4(module, groupsize, inner_k_tiles, padding_allowed, skip_layer_func = None): + _replace_linear_int4( + module, + groupsize, + inner_k_tiles, + padding_allowed, + skip_layer_func, + linear_class=WeightOnlyInt4Linear, + ) + class Int4WeightOnlyQuantizer(Quantizer): def __init__( @@ -655,19 +719,21 @@ def _create_quantized_state_dict( self.groupsize, self.precision, # dtype for scales_and_zeros ) + # TODO: just get the device from mod.weight.device? weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to(self.device), self.inner_k_tiles) cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device) cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(self.device) return cur_state_dict def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module: - replace_linear_int4( + _replace_linear_int4( model, self.groupsize, self.inner_k_tiles, self.padding_allowed, skip_layer_func=None, - dtype=self.precision, + precision=self.precision, + scales_precision=self.precision, ) return model @@ -680,6 +746,7 @@ def quantize( model.load_state_dict(state_dict, strict=False) return model + class Int4WeightOnlyGPTQQuantizer(GPTQQuantizer): def __init__( self, @@ -834,6 +901,7 @@ def __init__( out_features: int, bias=True, device=None, + # TODO: remove this field, not used dtype=None, groupsize: int = 256, precision: torch.dtype = torch.float32, @@ -858,6 +926,9 @@ def __init__( # that his module represents. self.precision = precision + if dtype is not None: + raise ValueError("Please specify 'precision' instead of 'dtype'") + # currently storing unpacked int8 weights self.register_buffer( "weight", @@ -1011,6 +1082,7 @@ def _convert_for_runtime(self, model: torch.nn.Module) -> torch.nn.Module: self.groupsize, self.padding_allowed, self.precision, + # TODO: this should be self.scales_precision? self.precision, ) return model @@ -1096,6 +1168,7 @@ def _convert_for_runtime(self, model): self.groupsize, self.padding_allowed, self.precision, + # TODO: this should be self.scales_precision? self.precision, ) return model diff --git a/torchao/quantization/prototype/qat.py b/torchao/quantization/prototype/qat.py index 71b585b15e..ac056916c4 100644 --- a/torchao/quantization/prototype/qat.py +++ b/torchao/quantization/prototype/qat.py @@ -4,21 +4,34 @@ # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. -from typing import Any, Tuple +from typing import Any, Optional, Tuple import torch +import torch.nn.functional as F from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib from torch.library import impl -from torchao.quantization.utils import get_group_qparams_symmetric -from torchao.quantization.unified import TwoStepQuantizer - from torchao.quantization.GPTQ import ( + _check_linear_int4_k, + _replace_linear_int4, _replace_linear_8da4w, + get_groupwise_affine_qparams, + groupwise_affine_quantize_tensor, Int8DynActInt4WeightLinear, + WeightOnlyInt4Linear, +) +from torchao.quantization.quant_primitives import ( + fake_quantize_affine_cachemask, + ZeroPointDomain, ) +from torchao.quantization.unified import TwoStepQuantizer +from torchao.quantization.utils import get_group_qparams_symmetric +# ================= +# | 8da4w QAT | +# ================= + class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer): """ Quantizer for performing QAT on a model, where linear layers have int8 @@ -171,7 +184,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: ) else: w_fq = self.weight - return torch.nn.functional.linear(x_fq, w_fq) + return F.linear(x_fq, w_fq) # TODO: move this to common util def _get_qmin_qmax(self, n_bit: int): @@ -193,10 +206,203 @@ def disable_8da4w_fake_quant(mod: torch.nn.Module): if isinstance(mod, Int8DynActInt4WeightQATLinear): mod.disable_fake_quant() + +# ================== +# | int4wo QAT | +# ================== + +class Int4WeightOnlyQATQuantizer(TwoStepQuantizer): + """ + Quantizer for performing QAT on a model, where linear layers have + int4 fake quantized grouped per channel weights. + """ + + def __init__( + self, + groupsize: int = 256, + inner_k_tiles: Optional[int] = 8, + precision: torch.dtype = torch.bfloat16, + scales_precision: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__() + assert inner_k_tiles in [2, 4, 8] + assert groupsize in [32, 64, 128, 256] + self.inner_k_tiles = inner_k_tiles + self.groupsize = groupsize + self.precision = precision + self.scales_precision = scales_precision + + def prepare( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _replace_linear_int4( + model, + self.groupsize, + self.inner_k_tiles, + padding_allowed=True, + precision=self.precision, + scales_precision=self.scales_precision, + linear_class=Int4WeightOnlyQATLinear, + copy_weights=True, + ) + return model + + def convert( + self, + model: torch.nn.Module, + *args: Any, + **kwargs: Any + ) -> torch.nn.Module: + _convert_qat_linear_4w(model) + return model + +def _convert_qat_linear_4w(module: torch.nn.Module): + """ + Replace all `Int4WeightOnlyQATLinear` with `WeightOnlyInt4Linear`. + """ + for name, child in module.named_children(): + if isinstance(child, Int4WeightOnlyQATLinear): + in_features = child.in_features + out_features = child.out_features + groupsize = child.groupsize + inner_k_tiles = child.inner_k_tiles + quantized_linear = WeightOnlyInt4Linear( + in_features, + out_features, + bias=False, + groupsize=groupsize, + inner_k_tiles=inner_k_tiles, + precision=child.precision, + scales_precision=child.scales_precision, + ) + setattr(module, name, quantized_linear) + + # Load weights and qparams into quantized linear + n_bit = 4 + (q_weight, scales_and_zeros) = groupwise_affine_quantize_tensor( + child.weight, n_bit, child.groupsize, + ) + q_weight = torch.ops.aten._convert_weight_to_int4pack( + q_weight.to(child.weight.device), child.inner_k_tiles, + ) + quantized_linear.weight = q_weight + quantized_linear.scales_and_zeros = scales_and_zeros + else: + _convert_qat_linear_4w(child) + +class Int4WeightOnlyQATLinear(torch.nn.Linear): + """ + This module implements a linear layer with int4 fake quantized grouped + per channel weights, with forward numerics matching `WeightOnlyInt4Linear`, + which uses the efficient int4 tinygemm kernel. + + args: + groupsize: the number of elements in each quantized group for weights + precision: precision of weights + scales_precision: precision of per group scales and zero points + """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = False, + device: torch.device = None, + groupsize: int = 256, + inner_k_tiles: int = 8, + precision: torch.dtype = torch.bfloat16, + scales_precision: torch.dtype = torch.bfloat16, + ) -> None: + super().__init__( + in_features, + out_features, + bias, + device=device, + dtype=precision, + ) + assert not bias, "require bias=False" + assert scales_precision == torch.bfloat16, "only bf16 is supported for scales" + if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles): + raise ValueError("Padding for QAT 4w is not supported yet") + self.groupsize = groupsize + self.inner_k_tiles = inner_k_tiles + self.precision = precision + self.scales_precision = scales_precision + self._fake_quant_enabled = True + + def enable_fake_quant(self, enabled: bool = True): + self._fake_quant_enabled = enabled + + def disable_fake_quant(self): + self.enable_fake_quant(False) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + n_bit = 4 + qmin = 0 + qmax = 2 ** n_bit - 1 + scales, zero_points = get_groupwise_affine_qparams( + self.weight, n_bit, self.groupsize, self.scales_precision, + ) + w_fq = _Int4WeightOnlyFakeQuantize.apply( + self.weight, scales, zero_points, qmin, qmax, self.groupsize, + ) + return F.linear(x, w_fq) + +def enable_4w_fake_quant(mod: torch.nn.Module): + """ + Enable fake quantization for `Int4WeightOnlyQATLinear`. + """ + if isinstance(mod, Int4WeightOnlyQATLinear): + mod.enable_fake_quant() + +def disable_4w_fake_quant(mod: torch.nn.Module): + """ + Disable fake quantization for `Int4WeightOnlyQATLinear`. + """ + if isinstance(mod, Int4WeightOnlyQATLinear): + mod.disable_fake_quant() + + # ======================== # | QUANT PRIMITIVES | # ======================== +class _Int4WeightOnlyFakeQuantize(torch.autograd.Function): + """ + Implementation of int4 grouped per channel weight-only fake quantize + intended to match the numerics of the efficient int4 tinygemm kernel. + """ + + @staticmethod + def forward(ctx, input, scales, zero_points, quant_min, quant_max, groupsize): + assert groupsize > 1 + assert input.shape[-1] % groupsize == 0 + assert input.dim() == 2 + n_bit = 4 + block_size = (1, groupsize) + quant_min = 0 + quant_max = 2 ** n_bit - 1 + (fq, mask) = fake_quantize_affine_cachemask( + input, + block_size, + scales, + zero_points, + torch.int32, + quant_min, + quant_max, + zero_point_domain = ZeroPointDomain.FLOAT, + ) + ctx.save_for_backward(mask) + return fq + + @staticmethod + def backward(ctx, gy): + (mask,) = ctx.saved_tensors + return gy * mask, None, None, None, None, None + class _GenericFakeQuantize(torch.autograd.Function): """ Implementation of generic fake quantize with backward STE.