From d71a65a71c4c66d25a3637dd5e617f9e1a13e6e1 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Tue, 14 May 2024 18:20:01 -0700 Subject: [PATCH 1/3] Remove input_quant_func from AffineQuantizedTensor subclass Summary: Currently we have a input_quant_func in the AffineQuantizedTensor, which is a bit convoluted, we want to use a separate LinearActAffineQuantizedTensor subclass for activation quantization (dynamic quantization) instead Test Plan: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_8da4w Reviewers: Subscribers: Tasks: Tags: --- test/quantization/test_quant_api.py | 23 ++- torchao/quantization/subclass.py | 276 +++++++++++++++++++++------- 2 files changed, 219 insertions(+), 80 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index cea659e61d..946febe0a9 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -395,7 +395,10 @@ def test_eval_wrapper(self): # TODO: move to a separate test file @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+") def test_quantized_tensor_subclass_8da4w(self): - from torchao.quantization.subclass import AffineQuantizedTensor + from torchao.quantization.subclass import ( + AffineQuantizedTensor, + LinearActAffineQuantizedTensor, + ) from torchao.quantization.quant_primitives import MappingType import copy @@ -419,13 +422,16 @@ def get_per_token_block_size(x): # input settings input_mapping_type = MappingType.ASYMMETRIC input_target_dtype = torch.int8 - input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype) + + def dynamic_quant(linear): + linear.weight = torch.nn.Parameter(LinearActAffineQuantizedTensor.from_float(linear.weight, input_mapping_type, get_per_token_block_size, input_target_dtype, quant_min, quant_max, eps), requires_grad=False) + linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps), requires_grad=False) m = ToyLinearModel().eval() m_copy = copy.deepcopy(m) example_inputs = m.example_inputs() - m.linear1.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear1.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False) - m.linear2.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear2.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False) + dynamic_quant(m.linear1) + dynamic_quant(m.linear2) assert isinstance(m.linear1.weight, AffineQuantizedTensor) assert isinstance(m.linear2.weight, AffineQuantizedTensor) @@ -461,9 +467,6 @@ def test_quantized_tensor_subclass_int4(self): preserve_zero = False zero_point_dtype = torch.bfloat16 - # weight only quantization - input_quant_func = None - # use 1024 so that we don't need padding m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda") m_copy = copy.deepcopy(m) @@ -475,7 +478,6 @@ def to_quantized(weight): zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=ZeroPointDomain.FLOAT, - input_quant_func=input_quant_func, ) m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False) @@ -506,16 +508,13 @@ def test_quantized_tensor_subclass_int8(self): eps = torch.finfo(torch.float32).eps zero_point_dtype = torch.int64 - # weight only 
quantization - input_quant_func = None - m = ToyLinearModel().eval().to(torch.bfloat16) m_copy = copy.deepcopy(m) example_inputs = tuple(map(lambda x: x.to(torch.bfloat16), m.example_inputs())) def to_quantized(weight): block_size = (1, weight.shape[1]) - return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, input_quant_func=input_quant_func) + return AffineQuantizedTensor.from_float(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) m.linear1.weight = torch.nn.Parameter(to_quantized(m.linear1.weight), requires_grad=False) m.linear2.weight = torch.nn.Parameter(to_quantized(m.linear2.weight), requires_grad=False) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 607cb77766..ef772ba422 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -21,6 +21,7 @@ quantize_affine, dequantize_affine, ZeroPointDomain, + MappingType, ) from .utils import find_multiple from typing import Tuple, Optional, Callable @@ -643,7 +644,6 @@ def __new__( quant_min: Optional[int] = None, quant_max: Optional[int] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - input_quant_func: Optional[Callable] = None, dtype=None, # TODO: remove args and kwargs *args, @@ -670,7 +670,6 @@ def __init__( quant_min: Optional[int] = None, quant_max: Optional[int] = None, zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - input_quant_func: Optional[Callable] = None, dtype=None, *args, **kwargs @@ -682,12 +681,11 @@ def __init__( self.quant_min = quant_min self.quant_max = quant_max self.zero_point_domain = zero_point_domain - self.input_quant_func = input_quant_func def __repr__(self): return ( f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, " - f"device={self.device}, dtype={self.dtype}, input_quant_func={self.input_quant_func}, requires_grad={self.requires_grad})" + f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" ) def dequantize(self, output_dtype=None): @@ -696,14 +694,14 @@ def dequantize(self, output_dtype=None): return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype) def __tensor_flatten__(self): - return ["int_data", "scales", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.input_quant_func, self.dtype] + return ["int_data", "scales", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"] - block_size, shape, quant_min, quant_max, zero_point_domain, input_quant_func, dtype = tensor_attributes + block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes return cls( int_data, scale, @@ -713,7 +711,6 @@ def __tensor_unflatten__( quant_min, quant_max, zero_point_domain, - input_quant_func=input_quant_func, dtype=dtype, strides=outer_stride, ) @@ -730,7 +727,6 @@ def from_float( eps = None, scale_dtype = None, zero_point_dtype = None, - input_quant_func = None, preserve_zero = True, zero_point_domain = ZeroPointDomain.INT, ): @@ -745,7 +741,6 @@ def from_float( 
quant_min, quant_max, zero_point_domain, - input_quant_func=input_quant_func, dtype=input_float.dtype ) @@ -759,56 +754,52 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): args[1], args[2] if len(args) > 2 else None, ) - if weight_qtensor.input_quant_func is None: - is_cuda = args[0].is_cuda - is_cpu = args[0].device == torch.device("cpu") - # weight only quantization - is_int8 = ( - weight_qtensor.int_data.dtype == torch.int8 and - weight_qtensor.quant_min is None or weight_qtensor.quant_min == -128 and - weight_qtensor.quant_max is None or weight_qtensor.quant_max == 127 - ) - is_uint4 = ( - weight_qtensor.int_data.dtype == torch.int32 and - weight_qtensor.quant_min == 0 and - weight_qtensor.quant_max == 15 - ) + is_cuda = args[1].is_cuda + is_cpu = args[1].device == torch.device("cpu") + # weight only quantization + is_int8 = ( + weight_qtensor.int_data.dtype == torch.int8 and + weight_qtensor.quant_min is None or weight_qtensor.quant_min == -128 and + weight_qtensor.quant_max is None or weight_qtensor.quant_max == 127 + ) + is_uint4 = ( + weight_qtensor.int_data.dtype == torch.int32 and + weight_qtensor.quant_min == 0 and + weight_qtensor.quant_max == 15 + ) + + # TODO: enable cpu and mps path as well + # TODO: make sure weight dimension matches the expectation of the int4mm kernel + # TODO: move this to TinygemmAffineQuantizedTensor + if ( + is_cuda and + is_uint4 and + weight_qtensor.dtype == torch.bfloat16 and + len(weight_qtensor.shape) == 2 and + weight_qtensor.block_size[0] == 1 and + weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT + ): + # groupwise int4 quantization + # TODO: currently doing packing on the fly, we'll need to figure out + # the API to do packing before hand + # TODO: expose the arg + innerKTiles = 8 + packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles) + scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point) + groupsize = weight_qtensor.block_size[-1] + return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros) + elif ( + is_cpu and + is_int8 and + len(weight_qtensor.shape) == 2 and + len(weight_qtensor.block_size) == 2 and + weight_qtensor.block_size[0] == 1 and + weight_qtensor.block_size[1] == weight_qtensor.shape[1] + ): + # TODO: enable mps path as well + # per channel int8 weight only quantizated mm + return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale) - # TODO: enable cpu and mps path as well - # TODO: make sure weight dimension matches the expectation of the int4mm kernel - # TODO: move this to TinygemmAffineQuantizedTensor - if ( - is_cuda and - is_uint4 and - weight_qtensor.dtype == torch.bfloat16 and - len(weight_qtensor.shape) == 2 and - weight_qtensor.block_size[0] == 1 and - weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT - ): - # groupwise int4 quantization - # TODO: currently doing packing on the fly, we'll need to figure out - # the API to do packing before hand - # TODO: expose the arg - innerKTiles = 8 - packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles) - scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point) - groupsize = weight_qtensor.block_size[-1] - return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros) - elif ( - is_cpu 
and - is_int8 and - len(weight_qtensor.shape) == 2 and - len(weight_qtensor.block_size) == 2 and - weight_qtensor.block_size[0] == 1 and - weight_qtensor.block_size[1] == weight_qtensor.shape[1] - ): - # TODO: enable mps path as well - # per channel int8 weight only quantizated mm - return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale) - else: - # dynamic quantization - input_tensor = weight_qtensor.input_quant_func(input_tensor) - input_tensor = input_tensor.dequantize() weight_tensor = weight_qtensor.dequantize() return torch.nn.functional.linear(input_tensor, weight_tensor, bias) @@ -816,7 +807,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): with torch._C.DisableTorchFunctionSubclass(): return func(*args, **kwargs) except: - print(f"ERR: subclass doesn't implement {func}") + print(f"ERR: AffineQuantizedTensor subclass doesn't implement {func}") def _get_to_kwargs(self, *args, **kwargs): @@ -844,7 +835,6 @@ def to(self, *args, **kwargs): self.quant_min, self.quant_max, self.zero_point_domain, - self.input_quant_func, **kwargs, ) @@ -858,7 +848,6 @@ def _apply_fn_to_data(self, fn): self.quant_min, self.quant_max, self.zero_point_domain, - self.input_quant_func, dtype=self.dtype, ) @@ -900,16 +889,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs): args[1], None if len(args) == 2 else args[2], ) - if weight_qtensor.input_quant_func is not None: - # dynamic quantization - input_tensor = weight_qtensor.input_quant_func(input_tensor) - input_tensor = input_tensor.dequantize() weight_tensor = weight_qtensor.dequantize() return func(input_tensor, weight_tensor, bias) - if (func is aten.detach.default or - func is aten.clone.default or - func is aten._to_copy.default): + if func is aten.detach.default: return return_and_correct_aliasing( func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) ) @@ -933,3 +916,160 @@ def __torch_dispatch__(cls, func, types, args, kwargs): kwargs, args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), ) + + raise NotImplementedError( + f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported" + ) + + +class LinearActAffineQuantizedTensor(torch.Tensor): + """ + Activation quantization with AffineQuantizedTensor + Applies activation affine quantization for linear operator + """ + def __new__( + cls, + float_tensor: torch.Tensor, + mapping_type: MappingType, + get_block_size: Callable[[torch.Tensor], Tuple[int, ...]], + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero = True, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + ): + kwargs = {} + dtype = float_tensor.dtype + kwargs["dtype"] = dtype + kwargs["requires_grad"] = False + shape = float_tensor.shape + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + float_tensor: torch.Tensor, + mapping_type: MappingType, + get_block_size: Callable[[torch.Tensor], Tuple[int, ...]], + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero = True, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + ): + self.mapping_type = mapping_type + 
self.get_block_size = get_block_size + self.target_dtype = target_dtype + self.quant_min = quant_min + self.quant_max = quant_max + self.eps = eps + self.scale_dtype = scale_dtype + self.zero_point_dtype = zero_point_dtype + self.preserve_zero = preserve_zero + self.zero_point_domain = zero_point_domain + + def __tensor_flatten__(self): + return ["float_tensor"], [self.mapping_type, self.get_block_size, self.target_dtype, self.quant_min, self.quant_max, self.eps, self.scale_dtype, self.zero_point_dtype, self.preserve_zero, self.zer_point_domain] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + float_tensor = tensor_data_dict["float_tensor"] + mapping_type, get_block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain = tensor_attributes + return cls( + float_tensor, + mapping_type, + get_block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) + + @classmethod + def from_float( + cls, + input_float, + mapping_type, + get_block_size, + target_dtype, + quant_min = None, + quant_max = None, + eps = None, + scale_dtype = None, + zero_point_dtype = None, + preserve_zero = True, + zero_point_domain = ZeroPointDomain.INT, + ): + return cls( + input_float, + mapping_type, + get_block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + + if func is torch.nn.functional.linear: + input_tensor, weight_tensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + block_size = self.get_block_size(args[0]) + aqt = AffineQuantizedTensor.from_float(input_tensor, self.mapping_type, block_size, self.target_dtype, self.quant_min, self.quant_max, self.eps, self.scale_dtype, self.zero_point_dtype, self.preserve_zero, self.zero_point_domain) + return F.linear(aqt, weight_tensor, bias) + + try: + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + except: + print(f"ERR: LinearActAffineQuantizedTensor subclass doesn't implement {func}") + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.float_tensor), + self.mapping_type, + self.get_block_size, + self.target_dtype, + self.quant_min, + self.quant_max, + self.scale_dtype, + self.zero_point_dtype, + self.preserve_zero, + self.zero_point_domain, + ) + + def __torch_dispatch__(cls, func, types, args, kwargs): + if func is aten.detach.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + if func is aten.clone.default: + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + raise NotImplementedError( + f"LinearActAffineQuantizedTensor dispatch: attempting to run {func}, this is not supported" + ) From 94a058cdd5d049148ab9687ebda62cb1def81410 Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 15 May 2024 15:23:15 -0700 Subject: [PATCH 2/3] Add dispatch for dynamic quantization in `AffineQuantizedTensor` Summary: This PR added dispatch for int8act-int8 weight dynamic quantization that's calling `int_scaled_matmul` kernel in the end Test Plan: python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_int8_dyn_quant Reviewers: Subscribers: Tasks: Tags: --- 
test/quantization/test_quant_api.py | 67 ++++++- torchao/quantization/subclass.py | 294 +++++++++++++++------------- 2 files changed, 227 insertions(+), 134 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 946febe0a9..0d3999298b 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -412,6 +412,7 @@ def test_quantized_tensor_subclass_8da4w(self): quant_max = 7 # TODO: make a general helper function? + # input settings def get_per_token_block_size(x): block_size = [] for i in range(len(x.shape)-1): @@ -422,18 +423,20 @@ def get_per_token_block_size(x): # input settings input_mapping_type = MappingType.ASYMMETRIC input_target_dtype = torch.int8 + input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype) def dynamic_quant(linear): - linear.weight = torch.nn.Parameter(LinearActAffineQuantizedTensor.from_float(linear.weight, input_mapping_type, get_per_token_block_size, input_target_dtype, quant_min, quant_max, eps), requires_grad=False) + # note: order is important linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps), requires_grad=False) + linear.weight = torch.nn.Parameter(LinearActAffineQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False) m = ToyLinearModel().eval() m_copy = copy.deepcopy(m) example_inputs = m.example_inputs() dynamic_quant(m.linear1) dynamic_quant(m.linear2) - assert isinstance(m.linear1.weight, AffineQuantizedTensor) - assert isinstance(m.linear2.weight, AffineQuantizedTensor) + assert isinstance(m.linear1.weight, LinearActAffineQuantizedTensor) + assert isinstance(m.linear2.weight, LinearActAffineQuantizedTensor) # reference from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer @@ -531,5 +534,63 @@ def to_quantized(weight): torch.testing.assert_close(res, ref, rtol=0.00001, atol=1e-2) + @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+") + @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") + def test_quantized_tensor_subclass_int8_dyn_quant(self): + from torchao.quantization.subclass import AffineQuantizedTensor + from torchao.quantization.subclass import LinearActAffineQuantizedTensor + from torchao.quantization.quant_primitives import MappingType + from torchao.quantization.quant_primitives import ZeroPointDomain + import copy + + # weight settings + mapping_type = MappingType.SYMMETRIC + def get_weight_block_size(x): + return (1, x.shape[1]) + target_dtype = torch.int8 + eps = torch.finfo(torch.float32).eps + zero_point_dtype = torch.int64 + + # input settings + def get_per_token_block_size(x): + block_size = list(x.shape) + for i in range(len(block_size)-1): + block_size[i] = 1 + return block_size + + input_mapping_type = MappingType.SYMMETRIC + input_target_dtype = torch.int8 + input_eps = 1e-5 + input_quant_min = -127 + input_quant_max = 127 + input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float) + + # use 1024 so that we don't need padding + m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda") + m_copy = copy.deepcopy(m) + example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs())) + + 
def dynamic_quant(linear): + # note: order is important + linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, get_weight_block_size(linear.weight), target_dtype, eps=eps, zero_point_dtype=zero_point_dtype), requires_grad=False) + linear.weight = torch.nn.Parameter(LinearActAffineQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False) + + dynamic_quant(m.linear1) + dynamic_quant(m.linear2) + assert isinstance(m.linear1.weight, LinearActAffineQuantizedTensor) + assert isinstance(m.linear2.weight, LinearActAffineQuantizedTensor) + assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor) + assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor) + + # reference + from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors + change_linear_weights_to_int8_dqtensors(m_copy) + + res = m(*example_inputs) + ref = m_copy(*example_inputs) + + self.assertTrue(torch.equal(res, ref)) + + if __name__ == "__main__": unittest.main() diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index ef772ba422..d6bbe8130c 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -23,6 +23,7 @@ ZeroPointDomain, MappingType, ) +from torchao.kernel.intmm import int_scaled_matmul from .utils import find_multiple from typing import Tuple, Optional, Callable @@ -37,6 +38,30 @@ aten = torch.ops.aten +def _aqt_is_int8(aqt): + """Check if an AffineQuantizedTensor is int8 quantized Tensor""" + return ( + aqt.int_data.dtype == torch.int8 and + aqt.quant_min is None or aqt.quant_min == -128 and + aqt.quant_max is None or aqt.quant_max == 127 + ) + +def _aqt_is_int8_reduce_range(aqt): + return ( + aqt.int_data.dtype == torch.int8 and + aqt.quant_min is None or aqt.quant_min == -127 and + aqt.quant_max is None or aqt.quant_max == 127 + ) + +def _aqt_is_uint4(aqt): + """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" + # TODO: use torch.uint4 + return ( + aqt.int_data.dtype == torch.int32 and + aqt.quant_min is None or aqt.quant_min == 0 and + aqt.quant_max is None or aqt.quant_max == 15 + ) + class QuantizedLinearWeightBase(torch.Tensor): """ @@ -754,55 +779,93 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): args[1], args[2] if len(args) > 2 else None, ) - is_cuda = args[1].is_cuda - is_cpu = args[1].device == torch.device("cpu") - # weight only quantization - is_int8 = ( - weight_qtensor.int_data.dtype == torch.int8 and - weight_qtensor.quant_min is None or weight_qtensor.quant_min == -128 and - weight_qtensor.quant_max is None or weight_qtensor.quant_max == 127 - ) - is_uint4 = ( - weight_qtensor.int_data.dtype == torch.int32 and - weight_qtensor.quant_min == 0 and - weight_qtensor.quant_max == 15 - ) - - # TODO: enable cpu and mps path as well - # TODO: make sure weight dimension matches the expectation of the int4mm kernel - # TODO: move this to TinygemmAffineQuantizedTensor - if ( - is_cuda and - is_uint4 and - weight_qtensor.dtype == torch.bfloat16 and - len(weight_qtensor.shape) == 2 and - weight_qtensor.block_size[0] == 1 and - weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT - ): - # groupwise int4 quantization - # TODO: currently doing packing on the fly, we'll need to figure out - # the API to do packing before hand - # TODO: expose the arg - innerKTiles = 8 - packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles) - 
scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point) - groupsize = weight_qtensor.block_size[-1] - return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros) - elif ( - is_cpu and - is_int8 and - len(weight_qtensor.shape) == 2 and - len(weight_qtensor.block_size) == 2 and - weight_qtensor.block_size[0] == 1 and - weight_qtensor.block_size[1] == weight_qtensor.shape[1] - ): - # TODO: enable mps path as well - # per channel int8 weight only quantizated mm - return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale) - - weight_tensor = weight_qtensor.dequantize() - return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - + is_cuda = weight_qtensor.is_cuda + is_cpu = weight_qtensor.device == torch.device("cpu") + if isinstance(weight_qtensor, AffineQuantizedTensor): + weight_is_int8 = _aqt_is_int8(weight_qtensor) + weight_is_uint4 = _aqt_is_uint4(weight_qtensor) + + if isinstance(input_tensor, AffineQuantizedTensor): + # if input tensor is quantized, either dispatch to the int8 mm kernel + # or just dequantize the input tensor + input_is_int8 = _aqt_is_int8_reduce_range(input_tensor) + input_tensor_dtype_is_expected = input_tensor.dtype in [ + torch.float, + torch.bfloat16 + ] + if ( + is_cuda and + input_is_int8 and + input_tensor_dtype_is_expected + ): + # + # 1. do the matrix form of dot(X_i, W_j) + # + # + # 2. rescale the output + # + # in cases with large matrices, y_dot_int32 can grow sufficiently + # large that y_dot_int32 * a float16 scale is greater than the maximum + # value of a float 16, (which results in a value of inf even if multiplying + # by the other scale would bring it within the expected range) + + x_vals_int8 = input_tensor.int_data + x_scales = input_tensor.scale + w_vals_int8_t = weight_qtensor.int_data.contiguous().t() + w_scales = weight_qtensor.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1)) + + y = (y_dot_scaled * w_scales).reshape( + *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] + ) + + # can downcast only at the very end + output_dtype = input_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y + else: + input_tensor = input_tensor.dequantize() + + # weight only quantization + + # TODO: enable cpu and mps path as well + # TODO: make sure weight dimension matches the expectation of the int4mm kernel + # TODO: move this to TinygemmAffineQuantizedTensor + if ( + is_cuda and + weight_is_uint4 and + weight_qtensor.dtype == torch.bfloat16 and + len(weight_qtensor.shape) == 2 and + weight_qtensor.block_size[0] == 1 and + weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT + ): + # groupwise int4 quantization + # TODO: currently doing packing on the fly, we'll need to figure out + # the API to do packing before hand + # TODO: expose the arg + innerKTiles = 8 + packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles) + scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point) + groupsize = weight_qtensor.block_size[-1] + return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros) + elif ( + is_cpu and + weight_is_int8 and + len(weight_qtensor.shape) == 2 and + len(weight_qtensor.block_size) == 2 and + 
weight_qtensor.block_size[0] == 1 and + weight_qtensor.block_size[1] == weight_qtensor.shape[1] + ): + # TODO: enable mps path as well + # per channel int8 weight only quantizated mm + return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale) + else: + if isinstance(input_tensor, AffineQuantizedTensor): + input_tensor = input_tensor.dequantize() + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) try: with torch._C.DisableTorchFunctionSubclass(): return func(*args, **kwargs) @@ -929,100 +992,47 @@ class LinearActAffineQuantizedTensor(torch.Tensor): """ def __new__( cls, - float_tensor: torch.Tensor, - mapping_type: MappingType, - get_block_size: Callable[[torch.Tensor], Tuple[int, ...]], - target_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - eps: Optional[float] = None, - scale_dtype: Optional[torch.dtype] = None, - zero_point_dtype: Optional[torch.dtype] = None, - preserve_zero = True, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + original_weight_tensor: torch.Tensor, + input_quant_func: Callable, ): kwargs = {} - dtype = float_tensor.dtype + dtype = original_weight_tensor.dtype kwargs["dtype"] = dtype kwargs["requires_grad"] = False - shape = float_tensor.shape + shape = original_weight_tensor.shape return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] def __init__( self, - float_tensor: torch.Tensor, - mapping_type: MappingType, - get_block_size: Callable[[torch.Tensor], Tuple[int, ...]], - target_dtype: torch.dtype, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - eps: Optional[float] = None, - scale_dtype: Optional[torch.dtype] = None, - zero_point_dtype: Optional[torch.dtype] = None, - preserve_zero = True, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + original_weight_tensor: torch.Tensor, + input_quant_func: Callable, ): - self.mapping_type = mapping_type - self.get_block_size = get_block_size - self.target_dtype = target_dtype - self.quant_min = quant_min - self.quant_max = quant_max - self.eps = eps - self.scale_dtype = scale_dtype - self.zero_point_dtype = zero_point_dtype - self.preserve_zero = preserve_zero - self.zero_point_domain = zero_point_domain + self.original_weight_tensor = original_weight_tensor + self.input_quant_func = input_quant_func def __tensor_flatten__(self): - return ["float_tensor"], [self.mapping_type, self.get_block_size, self.target_dtype, self.quant_min, self.quant_max, self.eps, self.scale_dtype, self.zero_point_dtype, self.preserve_zero, self.zer_point_domain] + return ["original_weight_tensor"], [self.input_quant_func] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - float_tensor = tensor_data_dict["float_tensor"] - mapping_type, get_block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain = tensor_attributes + original_weight_tensor = tensor_data_dict["original_weight_tensor"] + input_quant_func = tensor_attributes return cls( - float_tensor, - mapping_type, - get_block_size, - target_dtype, - quant_min, - quant_max, - eps, - scale_dtype, - zero_point_dtype, - preserve_zero, - zero_point_domain, + original_weight_tensor, + input_quant_func, ) @classmethod def from_float( cls, input_float, - mapping_type, - get_block_size, - target_dtype, - quant_min = None, - quant_max = None, - eps = None, - scale_dtype = None, - 
zero_point_dtype = None, - preserve_zero = True, - zero_point_domain = ZeroPointDomain.INT, + input_quant_func, ): return cls( input_float, - mapping_type, - get_block_size, - target_dtype, - quant_min, - quant_max, - eps, - scale_dtype, - zero_point_dtype, - preserve_zero, - zero_point_domain, + input_quant_func, ) @classmethod @@ -1035,10 +1045,11 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): args[1], args[2] if len(args) > 2 else None, ) - block_size = self.get_block_size(args[0]) - aqt = AffineQuantizedTensor.from_float(input_tensor, self.mapping_type, block_size, self.target_dtype, self.quant_min, self.quant_max, self.eps, self.scale_dtype, self.zero_point_dtype, self.preserve_zero, self.zero_point_domain) - return F.linear(aqt, weight_tensor, bias) - + if isinstance(weight_tensor, LinearActAffineQuantizedTensor): + input_quant_func = weight_tensor.input_quant_func + original_weight_tensor = weight_tensor.original_weight_tensor + aqt = input_quant_func(input_tensor) + return torch.nn.functional.linear(aqt, original_weight_tensor, bias) try: with torch._C.DisableTorchFunctionSubclass(): return func(*args, **kwargs) @@ -1047,19 +1058,40 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): def _apply_fn_to_data(self, fn): return self.__class__( - fn(self.float_tensor), - self.mapping_type, - self.get_block_size, - self.target_dtype, - self.quant_min, - self.quant_max, - self.scale_dtype, - self.zero_point_dtype, - self.preserve_zero, - self.zero_point_domain, + fn(self.original_weight_tensor), + self.input_quant_func, ) def __torch_dispatch__(cls, func, types, args, kwargs): + if ( + func in [aten.mm.default, aten.addmm.default] + and args[0].is_floating_point() + ): + if func == aten.addmm.default: + assert args[1].shape[-1] == args[2].shape[0], ( + f"need mat1 shape: {args[1].shape} final" + f"dim to match mat2 shape: {args[2].shape} first dim " + ) + input_tensor, weight_qtensor, bias = ( + args[1], + args[2], + args[0], + ) + aqt = self.input_quant_func(input_tensor) + return func(bias, aqt, weight_tensor) + else: + assert args[0].shape[-1] == args[1].shape[0], ( + f"need mat1 shape: {args[0].shape} final dim" + f"to match mat2 shape: {args[1].shape} first dim" + ) + input_tensor, weight_qtensor, bias = ( + args[0], + args[1], + None if len(args) == 2 else args[2], + ) + aqt = self.input_quant_func(input_tensor) + return func(aqt, weight_tensor, bias) + if func is aten.detach.default: return return_and_correct_aliasing( func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) From b43bce7a3b500004a3557c318f8756d098c5080e Mon Sep 17 00:00:00 2001 From: Jerry Zhang Date: Wed, 15 May 2024 15:50:12 -0700 Subject: [PATCH 3/3] Fix test --- test/quantization/test_quant_api.py | 16 ++++++------- torchao/quantization/subclass.py | 35 +++++++++++++---------------- 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py index 0d3999298b..fcab07c913 100644 --- a/test/quantization/test_quant_api.py +++ b/test/quantization/test_quant_api.py @@ -397,7 +397,7 @@ def test_eval_wrapper(self): def test_quantized_tensor_subclass_8da4w(self): from torchao.quantization.subclass import ( AffineQuantizedTensor, - LinearActAffineQuantizedTensor, + LinearActQuantizedTensor, ) from torchao.quantization.quant_primitives import MappingType import copy @@ -428,15 +428,15 @@ def get_per_token_block_size(x): def dynamic_quant(linear): # note: order is important linear.weight = 
torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps), requires_grad=False) - linear.weight = torch.nn.Parameter(LinearActAffineQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False) + linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False) m = ToyLinearModel().eval() m_copy = copy.deepcopy(m) example_inputs = m.example_inputs() dynamic_quant(m.linear1) dynamic_quant(m.linear2) - assert isinstance(m.linear1.weight, LinearActAffineQuantizedTensor) - assert isinstance(m.linear2.weight, LinearActAffineQuantizedTensor) + assert isinstance(m.linear1.weight, LinearActQuantizedTensor) + assert isinstance(m.linear2.weight, LinearActQuantizedTensor) # reference from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer @@ -538,7 +538,7 @@ def to_quantized(weight): @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available") def test_quantized_tensor_subclass_int8_dyn_quant(self): from torchao.quantization.subclass import AffineQuantizedTensor - from torchao.quantization.subclass import LinearActAffineQuantizedTensor + from torchao.quantization.subclass import LinearActQuantizedTensor from torchao.quantization.quant_primitives import MappingType from torchao.quantization.quant_primitives import ZeroPointDomain import copy @@ -573,12 +573,12 @@ def get_per_token_block_size(x): def dynamic_quant(linear): # note: order is important linear.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(linear.weight, mapping_type, get_weight_block_size(linear.weight), target_dtype, eps=eps, zero_point_dtype=zero_point_dtype), requires_grad=False) - linear.weight = torch.nn.Parameter(LinearActAffineQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False) + linear.weight = torch.nn.Parameter(LinearActQuantizedTensor.from_float(linear.weight, input_quant_func), requires_grad=False) dynamic_quant(m.linear1) dynamic_quant(m.linear2) - assert isinstance(m.linear1.weight, LinearActAffineQuantizedTensor) - assert isinstance(m.linear2.weight, LinearActAffineQuantizedTensor) + assert isinstance(m.linear1.weight, LinearActQuantizedTensor) + assert isinstance(m.linear2.weight, LinearActQuantizedTensor) assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor) assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor) diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index d6bbe8130c..bc40ffeaff 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -46,10 +46,10 @@ def _aqt_is_int8(aqt): aqt.quant_max is None or aqt.quant_max == 127 ) -def _aqt_is_int8_reduce_range(aqt): +def _aqt_is_int8_reduced_range(aqt): return ( aqt.int_data.dtype == torch.int8 and - aqt.quant_min is None or aqt.quant_min == -127 and + aqt.quant_min == -127 and aqt.quant_max is None or aqt.quant_max == 127 ) @@ -788,7 +788,7 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): if isinstance(input_tensor, AffineQuantizedTensor): # if input tensor is quantized, either dispatch to the int8 mm kernel # or just dequantize the input tensor - input_is_int8 = _aqt_is_int8_reduce_range(input_tensor) + input_is_int8 = _aqt_is_int8_reduced_range(input_tensor) input_tensor_dtype_is_expected = input_tensor.dtype in [ torch.float, torch.bfloat16 @@ -830,7 +830,6 @@ def __torch_function__(cls, func, types, args=(), 
kwargs=None): input_tensor = input_tensor.dequantize() # weight only quantization - # TODO: enable cpu and mps path as well # TODO: make sure weight dimension matches the expectation of the int4mm kernel # TODO: move this to TinygemmAffineQuantizedTensor @@ -862,15 +861,16 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): # TODO: enable mps path as well # per channel int8 weight only quantizated mm return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale) + else: + weight_tensor = weight_qtensor.dequantize() + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) else: if isinstance(input_tensor, AffineQuantizedTensor): input_tensor = input_tensor.dequantize() return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - try: - with torch._C.DisableTorchFunctionSubclass(): - return func(*args, **kwargs) - except: - print(f"ERR: AffineQuantizedTensor subclass doesn't implement {func}") + + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) def _get_to_kwargs(self, *args, **kwargs): @@ -985,10 +985,9 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) -class LinearActAffineQuantizedTensor(torch.Tensor): +class LinearActQuantizedTensor(torch.Tensor): """ - Activation quantization with AffineQuantizedTensor - Applies activation affine quantization for linear operator + Applies activation quantization for linear operator """ def __new__( cls, @@ -1045,16 +1044,14 @@ def __torch_function__(cls, func, types, args=(), kwargs=None): args[1], args[2] if len(args) > 2 else None, ) - if isinstance(weight_tensor, LinearActAffineQuantizedTensor): + if isinstance(weight_tensor, LinearActQuantizedTensor): input_quant_func = weight_tensor.input_quant_func original_weight_tensor = weight_tensor.original_weight_tensor aqt = input_quant_func(input_tensor) return torch.nn.functional.linear(aqt, original_weight_tensor, bias) - try: - with torch._C.DisableTorchFunctionSubclass(): - return func(*args, **kwargs) - except: - print(f"ERR: LinearActAffineQuantizedTensor subclass doesn't implement {func}") + + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) def _apply_fn_to_data(self, fn): return self.__class__( @@ -1103,5 +1100,5 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) raise NotImplementedError( - f"LinearActAffineQuantizedTensor dispatch: attempting to run {func}, this is not supported" + f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported" )
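
For reference, the usage pattern enabled by the final state of this series (as exercised in test_quantized_tensor_subclass_int8_dyn_quant) is roughly the following sketch; the names and settings are taken from the test code above, and the torchao import paths are assumed to match this patch series rather than any later refactor:

    import torch
    from torchao.quantization.subclass import (
        AffineQuantizedTensor,
        LinearActQuantizedTensor,
    )
    from torchao.quantization.quant_primitives import MappingType

    # weight settings: per-channel symmetric int8
    mapping_type = MappingType.SYMMETRIC
    target_dtype = torch.int8
    eps = torch.finfo(torch.float32).eps
    zero_point_dtype = torch.int64

    # input settings: per-token symmetric int8 (all leading dims get block size 1)
    def get_per_token_block_size(x):
        block_size = list(x.shape)
        for i in range(len(block_size) - 1):
            block_size[i] = 1
        return block_size

    input_quant_func = lambda x: AffineQuantizedTensor.from_float(
        x, MappingType.SYMMETRIC, get_per_token_block_size(x), torch.int8,
        eps=1e-5, quant_min=-127, quant_max=127, scale_dtype=torch.float,
    )

    def dynamic_quant(linear):
        # note: order is important -- quantize the weight first, then wrap it in
        # LinearActQuantizedTensor so the activation is quantized on the fly
        # inside torch.nn.functional.linear
        weight_block_size = (1, linear.weight.shape[1])
        linear.weight = torch.nn.Parameter(
            AffineQuantizedTensor.from_float(
                linear.weight, mapping_type, weight_block_size, target_dtype,
                eps=eps, zero_point_dtype=zero_point_dtype,
            ),
            requires_grad=False,
        )
        linear.weight = torch.nn.Parameter(
            LinearActQuantizedTensor.from_float(linear.weight, input_quant_func),
            requires_grad=False,
        )

    linear = torch.nn.Linear(1024, 1024, bias=False).eval().to(torch.bfloat16).to("cuda")
    dynamic_quant(linear)
    x = torch.randn(1, 1024, dtype=torch.bfloat16, device="cuda")
    y = linear(x)  # routed to the int8 dynamic quant (int_scaled_matmul) path

The weight is wrapped twice on purpose: the inner AffineQuantizedTensor handles weight quantization and weight-only dispatch, while the outer LinearActQuantizedTensor only quantizes the incoming activation via input_quant_func and re-dispatches linear on the inner tensor, which is what routes the int8-activation/int8-weight case to int_scaled_matmul.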