diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index 35b0107836..365e21aafe 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -493,7 +493,7 @@ def test_quantized_tensor_subclass_int8(self):
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int8_dyn_quant(self):
         # use 1024 so that we don't need padding
-        m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda")
+        m = ToyLinearModel(1024, 1024, 2048).eval().to(torch.bfloat16).to("cuda")
         m_copy = copy.deepcopy(m)
         # setting batch_size to 20 to be compatible with the kernel
         example_inputs = m.example_inputs(batch_size=20, dtype=torch.bfloat16, device="cuda")
diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/aqt.py
index f4b758ddca..fe3e629add 100644
--- a/torchao/dtypes/aqt.py
+++ b/torchao/dtypes/aqt.py
@@ -177,6 +177,11 @@ def _apply_fn_to_data(self, fn):
             fn(self.zero_point),
         )
 
+    def _change_shape(self, shape):
+        return self.__class__(
+            self.int_data.view(shape), self.scale, self.zero_point
+        )
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
         kwargs = {} if kwargs is None else kwargs
@@ -245,6 +250,7 @@ def __tensor_unflatten__(
         cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
     ):
         packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
+        # TODO: fix the unflatten logic
         return cls(packed_weight, scale_and_zero)
 
     def to(self, *args, **kwargs):
@@ -282,6 +288,74 @@ def get_plain(self):
             f"Unpacking for tensor core tiled storage is not yet implemented"
         )
 
+@register_aqt_layout_cls("transposed")
+class TransposedAQTLayout(PlainAQTLayout):
+    """
+    Layout storage class for transposed layout for affine quantized tensor, it's the same as
+    plain layout but stores transposed int_data.
+
+    fields:
+      int_data (torch.Tensor): the transposed quantized integer data Tensor
+      scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor
+      zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor
+    """
+    def __init__(
+        self,
+        int_data: torch.Tensor,
+        scale: torch.Tensor,
+        zero_point: torch.Tensor,
+    ):
+        self.int_data = int_data.contiguous().t()
+        self.scale = scale
+        self.zero_point = zero_point
+
+    def __tensor_flatten__(self):
+        return ["int_data", "scale", "zero_point"], []
+
+    @classmethod
+    def __tensor_unflatten__(
+        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
+    ):
+        int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
+        return cls(int_data.t(), scale, zero_point)
+
+    def _change_shape(self, shape):
+        return self.__class__(
+            self.int_data.t().view(shape), self.scale, self.zero_point
+        )
+
+    def _apply_fn_to_data(self, fn):
+        return self.__class__(
+            fn(self.int_data),
+            fn(self.scale),
+            fn(self.zero_point),
+        )
+
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args, kwargs):
+        kwargs = {} if kwargs is None else kwargs
+
+        if func is aten.detach.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
+            )
+
+        if func is aten.view.default:
+            assert len(args) == 2
+            new = args[0]._change_shape(args[1])
+            return return_and_correct_aliasing(
+                func, args, kwargs, new
+            )
+
+        raise NotImplementedError(
+            f"TransposedAQTLayout dispatch: attempting to run {func}, this is not supported"
+        )
+
+    __torch_function__ = torch._C._disabled_torch_function_impl
+
+    def get_plain(self):
+        return self.int_data.t(), self.scale, self.zero_point
+
 class AffineQuantizedTensor(torch.Tensor):
     """
     Base affine quantized tensor subclass. When the from_float method is used,
@@ -356,7 +430,7 @@ def __init__(
 
     def __repr__(self):
         return (
-            f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, "
+            f"{self.__class__.__name__}(data={self.dequantize(self.dtype)}, shape={self.shape}, "
             f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})"
         )
 
@@ -470,6 +544,11 @@ def _apply_fn_to_data(self, fn):
             strides=self.stride(),
         )
 
+    def _change_shape(self, shape, block_size):
+        return self.__class__(
+            self.layout_tensor.view(shape), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()
+        )
+
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs):
         # Note: we only added cpu path here for 8da4w, this is for executorch, in the future
@@ -491,13 +570,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
             f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
         )
 
-@implements_aqt_torch_function(torch.nn.functional.linear)
-def functional_linear(*args, **kwargs):
-    input_tensor, weight_qtensor, bias = (
-        args[0],
-        args[1],
-        args[2] if len(args) > 2 else None,
-    )
+def _quantized_linear_op(input_tensor, weight_qtensor, bias):
     is_cuda = weight_qtensor.is_cuda
     is_cpu = weight_qtensor.device == torch.device("cpu")
     if isinstance(weight_qtensor, AffineQuantizedTensor):
@@ -516,9 +589,14 @@ def functional_linear(*args, **kwargs):
                 is_cuda and
                 input_is_int8 and
                 input_tensor_dtype_is_expected and
+                input_tensor.dtype == weight_qtensor.dtype and
                 input_tensor.layout == "plain" and
-                weight_qtensor.layout == "plain"
+                weight_qtensor.layout == "transposed"
             ):
+                assert input_tensor.shape[-1] == weight_qtensor.layout_tensor.int_data.shape[0], (
+                    f"need mat1 shape: {input_tensor.shape} final "
+                    f"dim to match mat2 shape: {weight_qtensor.layout_tensor.int_data.shape} first dim "
+                )
                 #
                 # 1. do the matrix form of dot(X_i, W_j)
                 #
@@ -532,7 +610,7 @@ def functional_linear(*args, **kwargs):
                 x_vals_int8 = input_tensor.layout_tensor.int_data
                 x_scales = input_tensor.layout_tensor.scale
-                w_vals_int8_t = weight_qtensor.layout_tensor.int_data.contiguous().t()
+                w_vals_int8_t = weight_qtensor.layout_tensor.int_data
                 w_scales = weight_qtensor.layout_tensor.scale
                 tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1])
                 y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1))
@@ -579,42 +657,58 @@ def functional_linear(*args, **kwargs):
                 # TODO: enable mps path as well
                 # per channel int8 weight only quantizated mm
                 return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.layout_tensor.int_data, weight_qtensor.layout_tensor.scale)
-            else:
-                weight_tensor = weight_qtensor.dequantize()
-                return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
-    else:
+    raise NotImplementedError("No specialized dispatch found for quantized linear op")
+
+
+@implements_aqt_torch_function(torch.nn.functional.linear)
+def functional_linear(*args, **kwargs):
+    input_tensor, weight_tensor, bias = (
+        args[0],
+        args[1],
+        args[2] if len(args) > 2 else None,
+    )
+    try:
+        return _quantized_linear_op(input_tensor, weight_tensor, bias)
+    except:
         if isinstance(input_tensor, AffineQuantizedTensor):
             input_tensor = input_tensor.dequantize()
+        if isinstance(weight_tensor, AffineQuantizedTensor):
+            weight_tensor = weight_tensor.dequantize()
         return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
 
-
 @implements_aqt_aten_ops([aten.mm.default, aten.addmm.default])
 def aten_mm(func, *args, **kwargs):
     if not args[0].is_floating_point():
         raise NotImplementedError(f"{func} is not implemented for non floating point input")
 
     if func == aten.addmm.default:
-        assert args[1].shape[-1] == args[2].shape[0], (
-            f"need mat1 shape: {args[1].shape} final"
-            f"dim to match mat2 shape: {args[2].shape} first dim "
-        )
-        input_tensor, weight_qtensor, bias = (
+        input_tensor, weight_tensor, bias = (
             args[1],
             args[2],
             args[0],
         )
+        try:
+            return _quantized_linear_op(input_tensor, weight_tensor, bias)
+        except:
+            if isinstance(input_tensor, AffineQuantizedTensor):
+                input_tensor = input_tensor.dequantize()
+            if isinstance(weight_tensor, AffineQuantizedTensor):
+                weight_tensor = weight_tensor.dequantize()
+            return func(bias, input_tensor, weight_tensor)
    else:
-        assert args[0].shape[-1] == args[1].shape[0], (
-            f"need mat1 shape: {args[0].shape} final dim"
-            f"to match mat2 shape: {args[1].shape} first dim"
-        )
-        input_tensor, weight_qtensor, bias = (
+        input_tensor, weight_tensor, bias = (
             args[0],
             args[1],
-            None if len(args) == 2 else args[2],
+            None
         )
-        weight_tensor = weight_qtensor.dequantize()
-        return func(input_tensor, weight_tensor, bias)
+        try:
+            return _quantized_linear_op(input_tensor, weight_tensor, bias)
+        except:
+            if isinstance(input_tensor, AffineQuantizedTensor):
+                input_tensor = input_tensor.dequantize()
+            if isinstance(weight_tensor, AffineQuantizedTensor):
+                weight_tensor = weight_tensor.dequantize()
+            return func(input_tensor, weight_tensor)
 
 @implements_aqt_aten_ops([aten.detach.default])
 def detach(func, *args, **kwargs):
@@ -641,10 +735,10 @@ def _to_copy(func, *args, **kwargs):
 
 @implements_aqt_aten_ops([aten.t.default])
 def t(func, *args, **kwargs):
-    # TODO: need to implement this
-    # args[0].transposed = not args[0].transposed
-    # new = args[0]._change_shape(args[0].shape[::-1])
-    # return return_and_correct_aliasing(func, args, kwargs, new)
-    raise Exception("transpose not implemented yet")
+    block_size = args[0].block_size
+    assert len(block_size) == 2
+    transposed_block_size = (block_size[1], block_size[0])
+    new = args[0]._change_shape(args[0].shape[::-1], transposed_block_size)
+    return return_and_correct_aliasing(func, args, kwargs, new)
 
 to_aq = AffineQuantizedTensor.from_float
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 7ec88c7498..0fdc2e20fc 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -25,7 +25,11 @@
 from typing import Any, Callable
 
 from .dynamic_quant import DynamicallyPerAxisQuantizedLinear
-from .utils import TORCH_VERSION_AFTER_2_3, TORCH_VERSION_AFTER_2_4
+from .utils import (
+    TORCH_VERSION_AFTER_2_3,
+    TORCH_VERSION_AFTER_2_4,
+    unwrap_tensor_subclass,
+)
 
 from .subclass import (
     Int4WeightOnlyQuantizedLinearWeight,
@@ -187,9 +191,13 @@ def change_linear_weights_to_int8_dqtensors(model, filter_fn=None, **kwargs):
             *args
         )
 
-    _replace_with_custom_fn_if_matches_filter(
-        model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=TORCH_VERSION_AFTER_2_4, **kwargs), filter_fn
-    )
+    if TORCH_VERSION_AFTER_2_4:
+        quantize(model, get_apply_int8dyn_quant(), filter_fn)
+        unwrap_tensor_subclass(model, filter_fn)
+    else:
+        _replace_with_custom_fn_if_matches_filter(
+            model, _get_subclass_inserter(Int8DynamicallyQuantizedLinearWeight, enable_parametrization=False, **kwargs), filter_fn
+        )
 
 
 def change_linear_weights_to_int8_woqtensors(model, filter_fn=None, **kwargs):
@@ -393,7 +401,7 @@ def get_per_token_block_size(x):
         input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None)
 
         block_size = get_weight_block_size(weight)
-        weight = to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+        weight = to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype, extended_layout="transposed")
         weight = to_laq(weight, input_quant_func)
         return weight
     return apply_int8dyn_quant
diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py
index ee13512e9f..972699f0bf 100644
--- a/torchao/quantization/subclass.py
+++ b/torchao/quantization/subclass.py
@@ -610,6 +610,7 @@ def __new__(
         dtype = original_weight_tensor.dtype
         kwargs["dtype"] = dtype
         kwargs["requires_grad"] = False
+        kwargs["device"] = original_weight_tensor.device
         shape = original_weight_tensor.shape
         return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
 
@@ -664,6 +665,27 @@ def _apply_fn_to_data(self, fn):
             self.input_quant_func,
         )
 
+    def _get_to_kwargs(self, *args, **kwargs):
+        device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
+        device = self.device if device is None else device
+        dtype = self.dtype if dtype is None else dtype
+        memory_format = (
+            memory_format if memory_format is not None else torch.preserve_format
+        )
+        kwargs = {
+            "device": device,
+            "dtype": dtype,
+            "memory_format": memory_format,
+        }
+        return kwargs
+
+    def to(self, *args, **kwargs):
+        kwargs = self._get_to_kwargs(*args, **kwargs)
+        return self.__class__(
+            self.original_weight_tensor.to(**kwargs),
+            self.input_quant_func,
+        )
+
     def __torch_dispatch__(cls, func, types, args, kwargs):
         if (
             func in [aten.mm.default, aten.addmm.default]
@@ -674,25 +696,29 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                     f"need mat1 shape: {args[1].shape} final"
                     f"dim to match mat2 shape: {args[2].shape} first dim "
                 )
-                input_tensor, weight_qtensor, bias = (
+                input_tensor, weight_tensor, bias = (
                     args[1],
                     args[2],
                     args[0],
                 )
-                aqt = self.input_quant_func(input_tensor)
-                return func(bias, aqt, weight_tensor)
+                input_quant_func = weight_tensor.input_quant_func
+                original_weight_tensor = weight_tensor.original_weight_tensor
+                aqt = input_quant_func(input_tensor)
+                return func(bias, aqt, original_weight_tensor)
             else:
+                # aten.mm.default
                 assert args[0].shape[-1] == args[1].shape[0], (
                     f"need mat1 shape: {args[0].shape} final dim"
                     f"to match mat2 shape: {args[1].shape} first dim"
                 )
-                input_tensor, weight_qtensor, bias = (
+                input_tensor, weight_tensor = (
                     args[0],
                     args[1],
-                    None if len(args) == 2 else args[2],
                 )
-                aqt = self.input_quant_func(input_tensor)
-                return func(aqt, weight_tensor, bias)
+                input_quant_func = weight_tensor.input_quant_func
+                original_weight_tensor = weight_tensor.original_weight_tensor
+                aqt = input_quant_func(input_tensor)
+                return func(aqt, original_weight_tensor)
 
         if func is aten.detach.default:
             return return_and_correct_aliasing(
@@ -704,6 +730,19 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                 func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
             )
 
+        if func is aten._to_copy.default:
+            return return_and_correct_aliasing(
+                func,
+                args,
+                kwargs,
+                args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
+            )
+
+        if func is aten.t.default:
+            return return_and_correct_aliasing(
+                func, args, kwargs, args[0]._apply_fn_to_data(torch.t)
+            )
+
         raise NotImplementedError(
             f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported"
         )
diff --git a/torchao/quantization/utils.py b/torchao/quantization/utils.py
index 78a76863f3..e6787b0cf9 100644
--- a/torchao/quantization/utils.py
+++ b/torchao/quantization/utils.py
@@ -133,11 +133,14 @@ def right_inverse(self, tensor):
 
 def unwrap_tensor_subclass(model, filter_fn=None):
     for name, child in model.named_children():
+        # make sure child.weight is a tensor subclass
         if (
             isinstance(child, torch.nn.Linear) and
             hasattr(child, "weight") and
             type(child.weight) is not torch.Tensor and
-            isinstance(child.weight, torch.Tensor)
+            type(child.weight) is not torch.nn.Parameter and
+            isinstance(child.weight, torch.Tensor) and
+            issubclass(type(child.weight), torch.Tensor)
         ):
             parametrize.register_parametrization(child, "weight", UnwrapTensorSubclass())
         unwrap_tensor_subclass(child)
diff --git a/tutorials/quantize_vit/run_vit_b_quant.py b/tutorials/quantize_vit/run_vit_b_quant.py
index 0396a9dffd..583ad36f70 100644
--- a/tutorials/quantize_vit/run_vit_b_quant.py
+++ b/tutorials/quantize_vit/run_vit_b_quant.py
@@ -19,7 +19,7 @@
 inductorconfig.force_fuse_int_mm_with_mul = True
 ## Quantization code - end
 
-model = torch.compile(model, mode='max-autotune')
+model = torch.compile(model, mode='max-autotune', fullgraph=True)
 
 # Must run with no_grad when optimizing for inference
 with torch.no_grad():
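
For context, a minimal sketch of how the new int8 dynamic-quant path in this patch would be exercised end to end, assuming a CUDA build of torchao with the patch applied and torch >= 2.4. The ToyLinearModel class below is an illustrative stand-in for the test helper of the same name (shapes mirror the updated unit test: 1024 -> 1024 -> 2048, batch size 20, bf16):

import copy
import torch
from torchao.quantization.quant_api import change_linear_weights_to_int8_dqtensors

class ToyLinearModel(torch.nn.Module):
    # hypothetical stand-in for the test helper of the same name
    def __init__(self, k=1024, n=1024, m=2048):
        super().__init__()
        self.linear1 = torch.nn.Linear(k, n, bias=False)
        self.linear2 = torch.nn.Linear(n, m, bias=False)

    def forward(self, x):
        return self.linear2(self.linear1(x))

model = ToyLinearModel().eval().to(torch.bfloat16).to("cuda")
ref = copy.deepcopy(model)
x = torch.randn(20, 1024, dtype=torch.bfloat16, device="cuda")

# On torch >= 2.4 this now routes through quantize() + unwrap_tensor_subclass(),
# producing AffineQuantizedTensor weights stored with extended_layout="transposed".
change_linear_weights_to_int8_dqtensors(model)
model = torch.compile(model, mode="max-autotune", fullgraph=True)

with torch.no_grad():
    out = model(x)
    print(torch.nn.functional.mse_loss(out.float(), ref(x).float()))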