diff --git a/README.md b/README.md
index 150c67b512..2328c67c8b 100644
--- a/README.md
+++ b/README.md
@@ -29,9 +29,23 @@
 git clone https://github.com/pytorch/ao
 cd ao
 pip install -r requirements.txt
 pip install -r dev-requirements.txt
-pip install .
 ```
 
+There are two options:
+If you plan to develop the library, run:
+```Shell
+python setup.py develop
+```
+
+If you want to install from source, run:
+```Shell
+python setup.py install
+```
+
+**Note:**
+Since we build PyTorch C++/CUDA extensions by default, running `pip install .` will
+not work.
+
 ### Quantization
 ```python
diff --git a/docs/static/pruning_ecosystem_diagram.png b/docs/static/pruning_ecosystem_diagram.png
new file mode 100644
index 0000000000..f6562cdf20
Binary files /dev/null and b/docs/static/pruning_ecosystem_diagram.png differ
diff --git a/docs/static/pruning_flow.png b/docs/static/pruning_flow.png
new file mode 100644
index 0000000000..e6a4092f6c
Binary files /dev/null and b/docs/static/pruning_flow.png differ
diff --git a/test/quantization/test_quant_api.py b/test/quantization/test_quant_api.py
index f0830cf8a8..24882b8418 100644
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -18,17 +18,18 @@
     get_symmetric_quantization_config,
 )
 
-from torchao.quantization.subclass import (
-    to_aqt,
-    to_laqt,
+from torchao.dtypes import (
+    to_aq,
     AffineQuantizedTensor,
-    LinearActQuantizedTensor,
 )
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
 )
-
+from torchao.quantization.subclass import (
+    to_laq,
+    LinearActQuantizedTensor,
+)
 from torchao.quantization.quant_api import (
     _replace_with_custom_fn_if_matches_filter,
     apply_dynamic_quant,
@@ -429,17 +430,17 @@ def get_per_token_block_size(x):
         # input settings
         input_mapping_type = MappingType.ASYMMETRIC
         input_target_dtype = torch.int8
-        input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
+        input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
 
         m = ToyLinearModel().eval()
         m_copy = copy.deepcopy(m)
         example_inputs = m.example_inputs()
 
         def apply_weight_quant(weight):
-            return to_aqt(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
+            return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps)
 
         def apply_act_quant(weight):
-            return to_laqt(weight, input_quant_func)
+            return to_laq(weight, input_quant_func)
 
         # note: order is important
         m = quantize(m, apply_weight_quant)
@@ -484,7 +485,7 @@ def test_quantized_tensor_subclass_int4(self):
         example_inputs = tuple(map(lambda x: x.to(torch.bfloat16).to("cuda"), m.example_inputs()))
 
         def apply_weight_quant(weight):
-            return to_aqt(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
+            return to_aq(weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, zero_point_dtype=zero_point_dtype, preserve_zero=preserve_zero, zero_point_domain=zero_point_domain)
 
         m = quantize(m, apply_weight_quant)
         assert isinstance(m.linear1.weight, AffineQuantizedTensor)
@@ -515,7 +516,7 @@ def test_quantized_tensor_subclass_int8(self):
 
         def apply_weight_quant(weight):
             block_size = (1, weight.shape[1])
-            return to_aqt(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype)
+            return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps,
zero_point_dtype=zero_point_dtype) m = quantize(m, apply_weight_quant) @@ -555,7 +556,7 @@ def get_per_token_block_size(x): input_eps = 1e-5 input_quant_min = -127 input_quant_max = 127 - input_quant_func = lambda x: to_aqt(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) + input_quant_func = lambda x: to_aq(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype, eps=input_eps, quant_min=input_quant_min, quant_max=input_quant_max, scale_dtype=torch.float32 if x.dtype == torch.float16 else None) # use 1024 so that we don't need padding m = ToyLinearModel(1024, 1024, 1024).eval().to(torch.bfloat16).to("cuda") @@ -565,10 +566,10 @@ def get_per_token_block_size(x): def apply_weight_quant(weight): block_size = get_weight_block_size(weight) - return to_aqt(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) + return to_aq(weight, mapping_type, block_size, target_dtype, eps=eps, zero_point_dtype=zero_point_dtype) def apply_act_quant(weight): - return to_laqt(weight, input_quant_func) + return to_laq(weight, input_quant_func) m = quantize(m, apply_weight_quant) m = quantize(m, apply_act_quant) diff --git a/test/test_ops.py b/test/test_ops.py index d73ae536ac..6ce6a4afba 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -30,21 +30,6 @@ def _create_tensors_with_iou(self, N, iou_thresh): scores = torch.rand(N) return boxes, scores - @unittest.skipIf(not torch.cuda.is_available(), "CUDA not available") - @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch verion is 2.3 or lower") - def test_nms(self): - iou = 0.2 - boxes, scores = self._create_tensors_with_iou(1000, iou) - boxes = boxes.cuda() - scores = scores.cuda() - - # smoke test - _ = torchao.ops.nms(boxes, scores, iou) - - # comprehensive testing - test_utils = ["test_schema", "test_autograd_registration", "test_faketensor", "test_aot_dispatch_dynamic"] - opcheck(torch.ops.torchao.nms, (boxes, scores, iou), test_utils=test_utils) - def _create_fp6_inputs(self, BS: int, OC: int, IC: int): # Randomly initialize each bytes. The highest value for randint() is set the the max value of uint32_t. 
fp6_weight = torch.randint(4294967295, (OC, IC // 16 * 3)).to(torch.int) diff --git a/torchao/csrc/cuda/nms.cu b/torchao/csrc/cuda/nms.cu deleted file mode 100644 index 5bbbff8d79..0000000000 --- a/torchao/csrc/cuda/nms.cu +++ /dev/null @@ -1,181 +0,0 @@ -#include -#include -#include -#include -#include - -namespace torchao { - -namespace { - -#define CUDA_1D_KERNEL_LOOP_T(i, n, index_t) \ - for (index_t i = (blockIdx.x * blockDim.x) + threadIdx.x; i < (n); \ - i += (blockDim.x * gridDim.x)) - -#define CUDA_1D_KERNEL_LOOP(i, n) CUDA_1D_KERNEL_LOOP_T(i, n, int) - -template -constexpr __host__ __device__ inline integer ceil_div(integer n, integer m) { - return (n + m - 1) / m; -} - -int const threadsPerBlock = sizeof(unsigned long long) * 8; - -template -__device__ inline bool devIoU( - T const* const a, - T const* const b, - const float threshold) { - T left = max(a[0], b[0]), right = min(a[2], b[2]); - T top = max(a[1], b[1]), bottom = min(a[3], b[3]); - T width = max(right - left, (T)0), height = max(bottom - top, (T)0); - using acc_T = at::acc_type; - acc_T interS = (acc_T)width * height; - acc_T Sa = ((acc_T)a[2] - a[0]) * (a[3] - a[1]); - acc_T Sb = ((acc_T)b[2] - b[0]) * (b[3] - b[1]); - return (interS / (Sa + Sb - interS)) > threshold; -} - -template -__global__ void nms_kernel_impl( - int n_boxes, - double iou_threshold, - const T* dev_boxes, - unsigned long long* dev_mask) { - const int row_start = blockIdx.y; - const int col_start = blockIdx.x; - - if (row_start > col_start) - return; - - const int row_size = - min(n_boxes - row_start * threadsPerBlock, threadsPerBlock); - const int col_size = - min(n_boxes - col_start * threadsPerBlock, threadsPerBlock); - - __shared__ T block_boxes[threadsPerBlock * 4]; - if (threadIdx.x < col_size) { - block_boxes[threadIdx.x * 4 + 0] = - dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 0]; - block_boxes[threadIdx.x * 4 + 1] = - dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 1]; - block_boxes[threadIdx.x * 4 + 2] = - dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 2]; - block_boxes[threadIdx.x * 4 + 3] = - dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 4 + 3]; - } - __syncthreads(); - - if (threadIdx.x < row_size) { - const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; - const T* cur_box = dev_boxes + cur_box_idx * 4; - int i = 0; - unsigned long long t = 0; - int start = 0; - if (row_start == col_start) { - start = threadIdx.x + 1; - } - for (i = start; i < col_size; i++) { - if (devIoU(cur_box, block_boxes + i * 4, iou_threshold)) { - t |= 1ULL << i; - } - } - const int col_blocks = ceil_div(n_boxes, threadsPerBlock); - dev_mask[cur_box_idx * col_blocks + col_start] = t; - } -} - -at::Tensor nms_kernel( - const at::Tensor& dets, - const at::Tensor& scores, - double iou_threshold) { - TORCH_CHECK(dets.is_cuda(), "dets must be a CUDA tensor"); - TORCH_CHECK(scores.is_cuda(), "scores must be a CUDA tensor"); - - TORCH_CHECK( - dets.dim() == 2, "boxes should be a 2d tensor, got ", dets.dim(), "D"); - TORCH_CHECK( - dets.size(1) == 4, - "boxes should have 4 elements in dimension 1, got ", - dets.size(1)); - TORCH_CHECK( - scores.dim() == 1, - "scores should be a 1d tensor, got ", - scores.dim(), - "D"); - TORCH_CHECK( - dets.size(0) == scores.size(0), - "boxes and scores should have same number of elements in ", - "dimension 0, got ", - dets.size(0), - " and ", - scores.size(0)) - - at::cuda::CUDAGuard device_guard(dets.device()); - - if (dets.numel() == 0) { - return at::empty({0}, 
dets.options().dtype(at::kLong)); - } - - auto order_t = std::get<1>( - scores.sort(/*stable=*/true, /*dim=*/0, /* descending=*/true)); - auto dets_sorted = dets.index_select(0, order_t).contiguous(); - - int dets_num = dets.size(0); - - const int col_blocks = ceil_div(dets_num, threadsPerBlock); - - at::Tensor mask = - at::empty({dets_num * col_blocks}, dets.options().dtype(at::kLong)); - - dim3 blocks(col_blocks, col_blocks); - dim3 threads(threadsPerBlock); - cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - - AT_DISPATCH_FLOATING_TYPES_AND_HALF( - dets_sorted.scalar_type(), "nms_kernel", [&] { - nms_kernel_impl<<>>( - dets_num, - iou_threshold, - dets_sorted.data_ptr(), - (unsigned long long*)mask.data_ptr()); - }); - - at::Tensor mask_cpu = mask.to(at::kCPU); - unsigned long long* mask_host = - (unsigned long long*)mask_cpu.data_ptr(); - - std::vector remv(col_blocks); - memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks); - - at::Tensor keep = - at::empty({dets_num}, dets.options().dtype(at::kLong).device(at::kCPU)); - int64_t* keep_out = keep.data_ptr(); - - int num_to_keep = 0; - for (int i = 0; i < dets_num; i++) { - int nblock = i / threadsPerBlock; - int inblock = i % threadsPerBlock; - - if (!(remv[nblock] & (1ULL << inblock))) { - keep_out[num_to_keep++] = i; - unsigned long long* p = mask_host + i * col_blocks; - for (int j = nblock; j < col_blocks; j++) { - remv[j] |= p[j]; - } - } - } - - AT_CUDA_CHECK(cudaGetLastError()); - return order_t.index( - {keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep) - .to(order_t.device(), keep.scalar_type())}); -} - -} // namespace - -TORCH_LIBRARY_IMPL(torchao, CUDA, m) { - m.impl("torchao::nms", &nms_kernel); -} - -} // namespace torchao diff --git a/torchao/csrc/nms.cpp b/torchao/csrc/nms.cpp deleted file mode 100644 index 5cc26d1593..0000000000 --- a/torchao/csrc/nms.cpp +++ /dev/null @@ -1,8 +0,0 @@ -#include -#include -#include - -TORCH_LIBRARY_FRAGMENT(torchao, m) { - m.impl_abstract_pystub("torchao.ops"); - m.def("nms(Tensor dets, Tensor scores, float iou_threshold) -> Tensor"); -} diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index b14aff9904..dccd22f3d4 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,8 +1,11 @@ from .nf4tensor import NF4Tensor, to_nf4 from .uint4 import UInt4Tensor +from .aqt import AffineQuantizedTensor, to_aq __all__ = [ "NF4Tensor", "to_nf4", "UInt4Tensor" + "AffineQuantizedTensor", + "to_aq", ] diff --git a/torchao/dtypes/aqt.py b/torchao/dtypes/aqt.py new file mode 100644 index 0000000000..7619545f52 --- /dev/null +++ b/torchao/dtypes/aqt.py @@ -0,0 +1,444 @@ +import torch +from typing import Dict, Callable, Any, Tuple, Optional +from collections import defaultdict +import functools +from torchao.quantization.quant_primitives import ( + choose_qparams_affine, + quantize_affine, + dequantize_affine, + ZeroPointDomain, + MappingType, + pack_tinygemm_scales_and_zeros, +) +from torch.utils._python_dispatch import return_and_correct_aliasing +from torchao.kernel.intmm import int_scaled_matmul + +aten = torch.ops.aten + +def _aqt_is_int8(aqt): + """Check if an AffineQuantizedTensor is int8 quantized Tensor""" + return ( + aqt.int_data.dtype == torch.int8 and + aqt.quant_min is None or aqt.quant_min == -128 and + aqt.quant_max is None or aqt.quant_max == 127 + ) + +def _aqt_is_int8_reduced_range(aqt): + return ( + aqt.int_data.dtype == torch.int8 and + aqt.quant_min == -127 and + aqt.quant_max is None or aqt.quant_max == 127 + ) + +def 
_aqt_is_uint4(aqt): + """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" + # TODO: use torch.uint4 + return ( + aqt.int_data.dtype == torch.int32 and + aqt.quant_min is None or aqt.quant_min == 0 and + aqt.quant_max is None or aqt.quant_max == 15 + ) + +# TODO: merge with nf4 implements decorator +# aten op to their __torch_dispatch__ implemnetations for the tensor subclass +_ATEN_OPS_TABLE: Dict[Callable, Dict[Any, Any]] = defaultdict(dict) + +def implements_aten_ops(cls, aten_ops): + """Use this decorator to implement a function for an aten op in __torch_dispatch__""" + + def decorator(func): + for op in aten_ops: + _ATEN_OPS_TABLE[cls][op] = func + return func + + return decorator + +_TORCH_FUNCTIONS_TABLE: Dict[Callable, Dict[Any, Any]] = defaultdict(dict) + +def implements_torch_function(cls, torch_function): + def decorator(func): + functools.update_wrapper(func, torch_function) + _TORCH_FUNCTIONS_TABLE[cls][torch_function] = func + return func + + return decorator + +def implements_aqt_aten_ops(aten_ops): + return implements_aten_ops(AffineQuantizedTensor, aten_ops) + +def implements_aqt_torch_function(torch_function): + return implements_torch_function(AffineQuantizedTensor, torch_function) + + +class AffineQuantizedTensor(torch.Tensor): + """ + Base affine quantized tensor subclass. When the from_float method is used, + to create an instance of any AffineQuantizedTensor + + The shape and dtype of the tensor subclass represent how the tensor subclass looks externally, + regardless of the internal representation's type or orientation. + + Affine quantization means we quantize the floating point tensor with an affine transformation: + quantized_tensor = float_tensor / scale + zero_point + + fields: + int_data (torch.Tensor): the quantized integer data Tensor + scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor + zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor + block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam + e.g. when size is the same as the input tensor dimension, we are using per tensor quantization + shape (torch.Size): the shape for the Tensor + quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` + quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` + zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float + if zero_point is in integer domain, zero point is added to the quantized integer value during + quantization + if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) + value during quantization + default is ZeroPointDomain.INT + input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object + dtype: dtype for external representation of the tensor, e.g. 
torch.float32 + """ + + @staticmethod + def __new__( + cls, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: Tuple[int, ...], + shape: torch.Size, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + dtype=None, + strides=None, + ): + kwargs = {} + kwargs["device"] = int_data.device + kwargs["layout"] = ( + kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout + ) + if dtype is None: + dtype = scale.dtype + kwargs["dtype"] = dtype + if strides is not None: + kwargs["strides"] = strides + kwargs["requires_grad"] = False + return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] + + def __init__( + self, + int_data: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: Tuple[int, ...], + shape: torch.Size, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + dtype=None, + strides=None, + ): + self.int_data = int_data + self.scale = scale + self.zero_point = zero_point + self.block_size = block_size + self.quant_min = quant_min + self.quant_max = quant_max + self.zero_point_domain = zero_point_domain + + def __repr__(self): + return ( + f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, " + f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" + ) + + def dequantize(self, output_dtype=None): + if output_dtype is None: + output_dtype = self.dtype + return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype) + + def __tensor_flatten__(self): + return ["int_data", "scale", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] + + @classmethod + def __tensor_unflatten__( + cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride + ): + int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"] + block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes + return cls( + int_data, + scale, + zero_point, + block_size, + shape if outer_size is None else outer_size, + quant_min, + quant_max, + zero_point_domain, + dtype=dtype, + strides=outer_stride, + ) + + @classmethod + def from_float( + cls, + input_float: torch.Tensor, + mapping_type: MappingType, + block_size: Tuple[int, ...], + target_dtype: torch.dtype, + quant_min: Optional[int] = None, + quant_max: Optional[int] = None, + eps: Optional[float] = None, + scale_dtype: Optional[torch.dtype] = None, + zero_point_dtype: Optional[torch.dtype] = None, + preserve_zero: bool = True, + zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, + ): + scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain) + int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) + return cls( + int_data, + scale, + zero_point, + block_size, + input_float.shape, + quant_min, + quant_max, + zero_point_domain, + dtype=input_float.dtype + ) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + kwargs = {} if kwargs is None else kwargs + + if 
func in _TORCH_FUNCTIONS_TABLE[cls]: + return _TORCH_FUNCTIONS_TABLE[cls][func](*args, **kwargs) + + with torch._C.DisableTorchFunctionSubclass(): + return func(*args, **kwargs) + + + def _get_to_kwargs(self, *args, **kwargs): + device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) + device = self.device if device is None else device + dtype = self.dtype if dtype is None else dtype + memory_format = ( + memory_format if memory_format is not None else torch.preserve_format + ) + kwargs = { + "device": device, + "dtype": dtype, + "memory_format": memory_format, + } + return kwargs + + def to(self, *args, **kwargs): + kwargs = self._get_to_kwargs(*args, **kwargs) + return self.__class__( + self.int_data.to(kwargs["device"]), + self.scale.to(kwargs["device"]), + self.zero_point.to(kwargs["device"]), + self.block_size, + self.shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + **kwargs, + ) + + def _apply_fn_to_data(self, fn): + return self.__class__( + fn(self.int_data), + fn(self.scale), + fn(self.zero_point), + self.block_size, + self.shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + strides=self.stride(), + ) + + @classmethod + def __torch_dispatch__(cls, func, types, args, kwargs): + # Note: we only added cpu path here for 8da4w, this is for executorch, in the future + # 1. we'll add cpu/cuda version (int4mm etc.) + # 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like + # cpu device + et laytout --> gives current 8da4w executorch representation + # cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc. + # cuda device + some layout --> gives cuda kernel + + # two scenarios where we currently fall back to vanilla mm: + # 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized + # kernels in CPU as well, see the note above + # 2 - we're given non-floats - quantizing long to int8 is crazy + + if func in _ATEN_OPS_TABLE[cls]: + return _ATEN_OPS_TABLE[cls][func](func, *args, **kwargs) + + raise NotImplementedError( + f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported" + ) + +@implements_aqt_torch_function(torch.nn.functional.linear) +def functional_linear(*args, **kwargs): + input_tensor, weight_qtensor, bias = ( + args[0], + args[1], + args[2] if len(args) > 2 else None, + ) + is_cuda = weight_qtensor.is_cuda + is_cpu = weight_qtensor.device == torch.device("cpu") + if isinstance(weight_qtensor, AffineQuantizedTensor): + weight_is_int8 = _aqt_is_int8(weight_qtensor) + weight_is_uint4 = _aqt_is_uint4(weight_qtensor) + + if isinstance(input_tensor, AffineQuantizedTensor): + # if input tensor is quantized, either dispatch to the int8 mm kernel + # or just dequantize the input tensor + input_is_int8 = _aqt_is_int8_reduced_range(input_tensor) + input_tensor_dtype_is_expected = input_tensor.dtype in [ + torch.float, + torch.bfloat16 + ] + if ( + is_cuda and + input_is_int8 and + input_tensor_dtype_is_expected + ): + # + # 1. do the matrix form of dot(X_i, W_j) + # + # + # 2. 
rescale the output + # + # in cases with large matrices, y_dot_int32 can grow sufficiently + # large that y_dot_int32 * a float16 scale is greater than the maximum + # value of a float 16, (which results in a value of inf even if multiplying + # by the other scale would bring it within the expected range) + + x_vals_int8 = input_tensor.int_data + x_scales = input_tensor.scale + w_vals_int8_t = weight_qtensor.int_data.contiguous().t() + w_scales = weight_qtensor.scale + tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) + y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1)) + + y = (y_dot_scaled * w_scales).reshape( + *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] + ) + + # can downcast only at the very end + output_dtype = input_tensor.dtype + y = y.to(output_dtype) + if bias is not None: + y += bias + return y + else: + input_tensor = input_tensor.dequantize() + + # weight only quantization + # TODO: enable cpu and mps path as well + # TODO: make sure weight dimension matches the expectation of the int4mm kernel + # TODO: move this to TinygemmAffineQuantizedTensor + if ( + is_cuda and + weight_is_uint4 and + weight_qtensor.dtype == torch.bfloat16 and + len(weight_qtensor.shape) == 2 and + weight_qtensor.block_size[0] == 1 and + weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT + ): + # groupwise int4 quantization + # TODO: currently doing packing on the fly, we'll need to figure out + # the API to do packing before hand + # TODO: expose the arg + innerKTiles = 8 + packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles) + scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point) + groupsize = weight_qtensor.block_size[-1] + return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros) + elif ( + is_cpu and + weight_is_int8 and + len(weight_qtensor.shape) == 2 and + len(weight_qtensor.block_size) == 2 and + weight_qtensor.block_size[0] == 1 and + weight_qtensor.block_size[1] == weight_qtensor.shape[1] + ): + # TODO: enable mps path as well + # per channel int8 weight only quantizated mm + return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale) + else: + weight_tensor = weight_qtensor.dequantize() + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + else: + if isinstance(input_tensor, AffineQuantizedTensor): + input_tensor = input_tensor.dequantize() + return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + + +@implements_aqt_aten_ops([aten.mm.default, aten.addmm.default]) +def aten_mm(func, *args, **kwargs): + if not args[0].is_floating_point(): + raise NotImplementedError(f"{func} is not implemented for non floating point input") + + if func == aten.addmm.default: + assert args[1].shape[-1] == args[2].shape[0], ( + f"need mat1 shape: {args[1].shape} final" + f"dim to match mat2 shape: {args[2].shape} first dim " + ) + input_tensor, weight_qtensor, bias = ( + args[1], + args[2], + args[0], + ) + else: + assert args[0].shape[-1] == args[1].shape[0], ( + f"need mat1 shape: {args[0].shape} final dim" + f"to match mat2 shape: {args[1].shape} first dim" + ) + input_tensor, weight_qtensor, bias = ( + args[0], + args[1], + None if len(args) == 2 else args[2], + ) + weight_tensor = weight_qtensor.dequantize() + return func(input_tensor, weight_tensor, bias) + +@implements_aqt_aten_ops([aten.detach.default]) +def 
detach(func, *args, **kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) + ) + + +@implements_aqt_aten_ops([aten.clone.default]) +def clone(func, *args, **kwargs): + return return_and_correct_aliasing( + func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) + ) + + +@implements_aqt_aten_ops([aten._to_copy.default]) +def _to_copy(func, *args, **kwargs): + return return_and_correct_aliasing( + func, + args, + kwargs, + args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), + ) + +@implements_aqt_aten_ops([aten.t.default]) +def t(func, *args, **kwargs): + # TODO: need to implement this + # args[0].transposed = not args[0].transposed + # new = args[0]._change_shape(args[0].shape[::-1]) + # return return_and_correct_aliasing(func, args, kwargs, new) + raise Exception("transpose not implemented yet") + +to_aq = AffineQuantizedTensor.from_float diff --git a/torchao/kernel/intmm.py b/torchao/kernel/intmm.py index d2afa66a0a..8491a2ba6c 100644 --- a/torchao/kernel/intmm.py +++ b/torchao/kernel/intmm.py @@ -21,10 +21,24 @@ from torch._dynamo import is_compiling as dynamo_is_compiling from torch._higher_order_ops.out_dtype import out_dtype def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: + """ + Performs a safe integer matrix multiplication, considering different paths for + torch.compile, cublas, and fallback cases. + + Args: + input (torch.Tensor): The input tensor of shape [i, j]. + mat2 (torch.Tensor): The matrix to multiply with, of shape [j, k]. + + Returns: + torch.Tensor: The result of the matrix multiplication. + + Raises: + AssertionError: If the tensors are not on the same device. + """ # torch.compile path if dynamo_is_compiling() or "FakeTensor" in input.__repr__(): return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) - + # error checking for cublas path assert ( mat2.device == input.device @@ -39,13 +53,13 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: and j_is_nonzero_multiple_of_8 and k_is_nonzero_multiple_of_8 ) - + if device_cpu or bad_dimensions_for_cublas: # fallback path return torch.matmul(input.cpu().to(torch.int32), mat2.cpu().to(torch.int32)).to( input.device.type ) - + # cublas paths if not mat2.is_contiguous(): # silently gives incorrect result without this mat2 = mat2.contiguous() @@ -58,18 +72,53 @@ def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: return out_dtype(torch.ops.aten.mm.default, torch.int32, input, mat2) else: def safe_int_mm(input: torch.Tensor, mat2: torch.Tensor) -> torch.Tensor: + """ + Performs a fallback integer matrix multiplication for torch versions before 2.2. + + Args: + input (torch.Tensor): The input tensor of shape [i, j]. + mat2 (torch.Tensor): The matrix to multiply with, of shape [j, k]. + + Returns: + torch.Tensor: The result of the matrix multiplication in int32. + """ # We can improve on this by writing Triton code that works for older versions of Triton # that ship with 2.1 or 2.0. return torch.matmul(input.to(torch.float32), mat2.to(torch.float32)).to(torch.int32) -def int_matmul(a, b): +def int_matmul(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor: + """ + Performs integer matrix multiplication using intmm_triton if available and autotuner is enabled, + otherwise falls back to safe_int_mm. + + Args: + a (torch.Tensor): The first matrix to multiply. + b (torch.Tensor): The second matrix to multiply. + + Returns: + torch.Tensor: The result of the matrix multiplication. 
+ """ if intmm_triton is not None and AUTOTUNER_ENABLE: return torch.ops.torchao.int_matmul(a, b) return safe_int_mm(a, b) -def int_scaled_matmul(a, b, scales1): +def int_scaled_matmul(a: torch.Tensor, b: torch.Tensor, scales1: torch.Tensor) -> torch.Tensor: + """ + Performs scaled integer matrix multiplication. + + Args: + a (torch.Tensor): The first matrix to multiply. + b (torch.Tensor): The second matrix to multiply. + scales1 (torch.Tensor): The scaling factors for the rows of the result. + + Returns: + torch.Tensor: The result of the scaled matrix multiplication. + + Raises: + AssertionError: If the dimensions of the input tensors do not match the expected shapes. + """ M, K = a.shape K, N = b.shape assert M == scales1.size(0) diff --git a/torchao/ops.py b/torchao/ops.py index fcc6ae9364..05a1668399 100644 --- a/torchao/ops.py +++ b/torchao/ops.py @@ -10,26 +10,6 @@ def decorator(func): return torch.library.impl_abstract(f"{name}")(func) return decorator -def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: - """ - See https://pytorch.org/vision/main/generated/torchvision.ops.nms.html - """ - return torch.ops.torchao.nms.default(boxes, scores, iou_threshold) - - -# Defines the meta kernel / fake kernel / abstract impl -@register_custom_op("torchao::nms") -def _(dets, scores, iou_threshold): - torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D") - torch._check(dets.size(1) == 4, lambda: f"boxes should have 4 elements in dimension 1, got {dets.size(1)}") - torch._check(scores.dim() == 1, lambda: f"scores should be a 1d tensor, got {scores.dim()}") - torch._check( - dets.size(0) == scores.size(0), - lambda: f"boxes and scores should have same number of elements in dimension 0, got {dets.size(0)} and {scores.size(0)}", - ) - ctx = torch._custom_ops.get_ctx() - num_to_keep = ctx.create_unbacked_symint() - return dets.new_empty(num_to_keep, dtype=torch.long) def prepack_fp6_weight(fp6_weight: Tensor) -> Tensor: @@ -45,6 +25,7 @@ def prepack_fp6_weight(fp6_weight: Tensor) -> Tensor: return torch.ops.torchao.prepack_fp6_weight.default(fp6_weight) +# Defines the meta kernel / fake kernel / abstract impl @register_custom_op("torchao::prepack_fp6_weight") def _(fp6_weight): torch._check(fp6_weight.dim() == 2, lambda: f"weight should be a 2d tensor, got {fp6_weight.dim()}D") diff --git a/torchao/quantization/subclass.py b/torchao/quantization/subclass.py index 6e844530d4..ee13512e9f 100644 --- a/torchao/quantization/subclass.py +++ b/torchao/quantization/subclass.py @@ -14,58 +14,25 @@ dynamically_quantize_per_channel, groupwise_affine_quantize_tensor, quant_int8_dynamic_per_token_linear, - pack_tinygemm_scales_and_zeros, unpack_tinygemm_scales_and_zeros, groupwise_affine_quantize_tensor_from_qparams, - choose_qparams_affine, - quantize_affine, - dequantize_affine, - ZeroPointDomain, MappingType, ) -from torchao.kernel.intmm import int_scaled_matmul from .utils import find_multiple from typing import Tuple, Optional, Callable, Dict, Any -from collections import defaultdict -import functools __all__ = [ "Int8DynamicallyQuantizedLinearWeight", "Int8WeightOnlyQuantizedLinearWeight", "Int4WeightOnlyQuantizedLinearWeight", - "AffineQuantizedTensor", "LinearActQuantizedTensor", + "to_laq", ] aten = torch.ops.aten -def _aqt_is_int8(aqt): - """Check if an AffineQuantizedTensor is int8 quantized Tensor""" - return ( - aqt.int_data.dtype == torch.int8 and - aqt.quant_min is None or aqt.quant_min == -128 and - aqt.quant_max is None or aqt.quant_max == 
127 - ) - -def _aqt_is_int8_reduced_range(aqt): - return ( - aqt.int_data.dtype == torch.int8 and - aqt.quant_min == -127 and - aqt.quant_max is None or aqt.quant_max == 127 - ) - -def _aqt_is_uint4(aqt): - """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" - # TODO: use torch.uint4 - return ( - aqt.int_data.dtype == torch.int32 and - aqt.quant_min is None or aqt.quant_min == 0 and - aqt.quant_max is None or aqt.quant_max == 15 - ) - - class QuantizedLinearWeightBase(torch.Tensor): """ Base quantized tensor subclass for quantized linear weights. When the from_float method is used, @@ -630,409 +597,6 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8): return int_data, scales_and_zeros, False, groupsize, inner_k_tiles -# TODO: merge with nf4 implements decorator -# aten op to their __torch_dispatch__ implemnetations for the tensor subclass -_ATEN_OPS_TABLE: Dict[Callable, Dict[Any, Any]] = defaultdict(dict) - -def implements_aten_ops(cls, aten_ops): - """Use this decorator to implement a function for an aten op in __torch_dispatch__""" - - def decorator(func): - for op in aten_ops: - _ATEN_OPS_TABLE[cls][op] = func - return func - - return decorator - -_TORCH_FUNCTIONS_TABLE: Dict[Callable, Dict[Any, Any]] = defaultdict(dict) - -def implements_torch_function(cls, torch_function): - def decorator(func): - functools.update_wrapper(func, torch_function) - _TORCH_FUNCTIONS_TABLE[cls][torch_function] = func - return func - - return decorator - -def implements_aqt_aten_ops(aten_ops): - return implements_aten_ops(AffineQuantizedTensor, aten_ops) - -def implements_aqt_torch_function(torch_function): - return implements_torch_function(AffineQuantizedTensor, torch_function) - - -class AffineQuantizedTensor(torch.Tensor): - """ - Base affine quantized tensor subclass. When the from_float method is used, - to create an instance of any AffineQuantizedTensor - - The shape and dtype of the tensor subclass represent how the tensor subclass looks externally, - regardless of the internal representation's type or orientation. - - Affine quantization means we quantize the floating point tensor with an affine transformation: - quantized_tensor = float_tensor / scale + zero_point - - fields: - int_data (torch.Tensor): the quantized integer data Tensor - scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor - zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor - block_size (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam - e.g. 
when size is the same as the input tensor dimension, we are using per tensor quantization - shape (torch.Size): the shape for the Tensor - quant_min (Optional[int]): minimum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` - quant_max (Optional[int]): maximum quantized value for the Tensor, if not specified, it will be derived from dtype of `int_data` - zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be eitehr integer or float - if zero_point is in integer domain, zero point is added to the quantized integer value during - quantization - if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized) - value during quantization - default is ZeroPointDomain.INT - input_quant_func (Optional[Callable]): function for quantizing the input float Tensor to a quantized tensor subclass object, that takes float Tensor as input and outputs an AffineQuantizedTensor object - dtype: dtype for external representation of the tensor, e.g. torch.float32 - """ - - @staticmethod - def __new__( - cls, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - block_size: Tuple[int, ...], - shape: torch.Size, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - dtype=None, - strides=None, - ): - kwargs = {} - kwargs["device"] = int_data.device - kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout - ) - if dtype is None: - dtype = scale.dtype - kwargs["dtype"] = dtype - if strides is not None: - kwargs["strides"] = strides - kwargs["requires_grad"] = False - return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined] - - def __init__( - self, - int_data: torch.Tensor, - scale: torch.Tensor, - zero_point: torch.Tensor, - block_size: Tuple[int, ...], - shape: torch.Size, - quant_min: Optional[int] = None, - quant_max: Optional[int] = None, - zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT, - dtype=None, - strides=None, - ): - self.int_data = int_data - self.scale = scale - self.zero_point = zero_point - self.block_size = block_size - self.quant_min = quant_min - self.quant_max = quant_max - self.zero_point_domain = zero_point_domain - - def __repr__(self): - return ( - f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, " - f"device={self.device}, dtype={self.dtype}, requires_grad={self.requires_grad})" - ) - - def dequantize(self, output_dtype=None): - if output_dtype is None: - output_dtype = self.dtype - return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, self.zero_point_domain, output_dtype=output_dtype) - - def __tensor_flatten__(self): - return ["int_data", "scale", "zero_point"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] - - @classmethod - def __tensor_unflatten__( - cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride - ): - int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"] - block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes - return cls( - int_data, - scale, - zero_point, - block_size, - shape if outer_size is None else outer_size, - quant_min, - quant_max, - zero_point_domain, - dtype=dtype, - strides=outer_stride, - ) - - 
@classmethod - def from_float( - cls, - input_float, - mapping_type, - block_size, - target_dtype, - quant_min = None, - quant_max = None, - eps = None, - scale_dtype = None, - zero_point_dtype = None, - preserve_zero = True, - zero_point_domain = ZeroPointDomain.INT, - ): - scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain) - int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) - return cls( - int_data, - scale, - zero_point, - block_size, - input_float.shape, - quant_min, - quant_max, - zero_point_domain, - dtype=input_float.dtype - ) - - @classmethod - def __torch_function__(cls, func, types, args=(), kwargs=None): - kwargs = {} if kwargs is None else kwargs - - if func in _TORCH_FUNCTIONS_TABLE[cls]: - return _TORCH_FUNCTIONS_TABLE[cls][func](*args, **kwargs) - - with torch._C.DisableTorchFunctionSubclass(): - return func(*args, **kwargs) - - - def _get_to_kwargs(self, *args, **kwargs): - device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) - device = self.device if device is None else device - dtype = self.dtype if dtype is None else dtype - memory_format = ( - memory_format if memory_format is not None else torch.preserve_format - ) - kwargs = { - "device": device, - "dtype": dtype, - "memory_format": memory_format, - } - return kwargs - - def to(self, *args, **kwargs): - kwargs = self._get_to_kwargs(*args, **kwargs) - return self.__class__( - self.int_data.to(kwargs["device"]), - self.scale.to(kwargs["device"]), - self.zero_point.to(kwargs["device"]), - self.block_size, - self.shape, - self.quant_min, - self.quant_max, - self.zero_point_domain, - **kwargs, - ) - - def _apply_fn_to_data(self, fn): - return self.__class__( - fn(self.int_data), - fn(self.scale), - fn(self.zero_point), - self.block_size, - self.shape, - self.quant_min, - self.quant_max, - self.zero_point_domain, - dtype=self.dtype, - strides=self.stride(), - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args, kwargs): - # Note: we only added cpu path here for 8da4w, this is for executorch, in the future - # 1. we'll add cpu/cuda version (int4mm etc.) - # 2. we'll need to hide the 8da4w executorch version under things like layouts (we also have multiple impl for cpu kernel as Michael mentioned), so it will be something like - # cpu device + et laytout --> gives current 8da4w executorch representation - # cpu device + avx layout --> gives optimized kernel for 8da4w in avx cpu etc. 
- # cuda device + some layout --> gives cuda kernel - - # two scenarios where we currently fall back to vanilla mm: - # 1 - when tensor is on CUDA: we'll add this later, we'll also enable dispatching to optimized - # kernels in CPU as well, see the note above - # 2 - we're given non-floats - quantizing long to int8 is crazy - - if func in _ATEN_OPS_TABLE[cls]: - return _ATEN_OPS_TABLE[cls][func](func, *args, **kwargs) - - raise NotImplementedError( - f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported" - ) - -@implements_aqt_torch_function(torch.nn.functional.linear) -def functional_linear(*args, **kwargs): - input_tensor, weight_qtensor, bias = ( - args[0], - args[1], - args[2] if len(args) > 2 else None, - ) - is_cuda = weight_qtensor.is_cuda - is_cpu = weight_qtensor.device == torch.device("cpu") - if isinstance(weight_qtensor, AffineQuantizedTensor): - weight_is_int8 = _aqt_is_int8(weight_qtensor) - weight_is_uint4 = _aqt_is_uint4(weight_qtensor) - - if isinstance(input_tensor, AffineQuantizedTensor): - # if input tensor is quantized, either dispatch to the int8 mm kernel - # or just dequantize the input tensor - input_is_int8 = _aqt_is_int8_reduced_range(input_tensor) - input_tensor_dtype_is_expected = input_tensor.dtype in [ - torch.float, - torch.bfloat16 - ] - if ( - is_cuda and - input_is_int8 and - input_tensor_dtype_is_expected - ): - # - # 1. do the matrix form of dot(X_i, W_j) - # - # - # 2. rescale the output - # - # in cases with large matrices, y_dot_int32 can grow sufficiently - # large that y_dot_int32 * a float16 scale is greater than the maximum - # value of a float 16, (which results in a value of inf even if multiplying - # by the other scale would bring it within the expected range) - - x_vals_int8 = input_tensor.int_data - x_scales = input_tensor.scale - w_vals_int8_t = weight_qtensor.int_data.contiguous().t() - w_scales = weight_qtensor.scale - tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) - y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1)) - - y = (y_dot_scaled * w_scales).reshape( - *x_vals_int8.shape[:-1], y_dot_scaled.shape[-1] - ) - - # can downcast only at the very end - output_dtype = input_tensor.dtype - y = y.to(output_dtype) - if bias is not None: - y += bias - return y - else: - input_tensor = input_tensor.dequantize() - - # weight only quantization - # TODO: enable cpu and mps path as well - # TODO: make sure weight dimension matches the expectation of the int4mm kernel - # TODO: move this to TinygemmAffineQuantizedTensor - if ( - is_cuda and - weight_is_uint4 and - weight_qtensor.dtype == torch.bfloat16 and - len(weight_qtensor.shape) == 2 and - weight_qtensor.block_size[0] == 1 and - weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT - ): - # groupwise int4 quantization - # TODO: currently doing packing on the fly, we'll need to figure out - # the API to do packing before hand - # TODO: expose the arg - innerKTiles = 8 - packed_weight = torch.ops.aten._convert_weight_to_int4pack(weight_qtensor.int_data.to(torch.int32), innerKTiles) - scales_and_zeros = pack_tinygemm_scales_and_zeros(weight_qtensor.scale, weight_qtensor.zero_point) - groupsize = weight_qtensor.block_size[-1] - return torch.ops.aten._weight_int4pack_mm(input_tensor.contiguous(), packed_weight, groupsize, scales_and_zeros) - elif ( - is_cpu and - weight_is_int8 and - len(weight_qtensor.shape) == 2 and - len(weight_qtensor.block_size) == 2 and - weight_qtensor.block_size[0] == 1 and - weight_qtensor.block_size[1] == 
weight_qtensor.shape[1] - ): - # TODO: enable mps path as well - # per channel int8 weight only quantizated mm - return torch.ops.aten._weight_int8pack_mm(input_tensor.contiguous(), weight_qtensor.int_data, weight_qtensor.scale) - else: - weight_tensor = weight_qtensor.dequantize() - return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - else: - if isinstance(input_tensor, AffineQuantizedTensor): - input_tensor = input_tensor.dequantize() - return torch.nn.functional.linear(input_tensor, weight_tensor, bias) - - -@implements_aqt_aten_ops([aten.mm.default, aten.addmm.default]) -def aten_mm(func, *args, **kwargs): - if not args[0].is_floating_point(): - raise NotImplementedError(f"{func} is not implemented for non floating point input") - - if func == aten.addmm.default: - assert args[1].shape[-1] == args[2].shape[0], ( - f"need mat1 shape: {args[1].shape} final" - f"dim to match mat2 shape: {args[2].shape} first dim " - ) - input_tensor, weight_qtensor, bias = ( - args[1], - args[2], - args[0], - ) - else: - assert args[0].shape[-1] == args[1].shape[0], ( - f"need mat1 shape: {args[0].shape} final dim" - f"to match mat2 shape: {args[1].shape} first dim" - ) - input_tensor, weight_qtensor, bias = ( - args[0], - args[1], - None if len(args) == 2 else args[2], - ) - weight_tensor = weight_qtensor.dequantize() - return func(input_tensor, weight_tensor, bias) - -@implements_aqt_aten_ops([aten.detach.default]) -def detach(func, *args, **kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.detach) - ) - - -@implements_aqt_aten_ops([aten.clone.default]) -def clone(func, *args, **kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(torch.clone) - ) - - -@implements_aqt_aten_ops([aten._to_copy.default]) -def _to_copy(func, *args, **kwargs): - return return_and_correct_aliasing( - func, - args, - kwargs, - args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), - ) - -@implements_aqt_aten_ops([aten.t.default]) -def t(func, *args, **kwargs): - # TODO: need to implement this - # args[0].transposed = not args[0].transposed - # new = args[0]._change_shape(args[0].shape[::-1]) - # return return_and_correct_aliasing(func, args, kwargs, new) - raise Exception("transpose not implemented yet") - - class LinearActQuantizedTensor(torch.Tensor): """ Applies activation quantization for linear operator @@ -1072,15 +636,8 @@ def __tensor_unflatten__( ) @classmethod - def from_float( - cls, - input_float, - input_quant_func, - ): - return cls( - input_float, - input_quant_func, - ) + def from_float(cls, input_float, input_quant_func): + return cls(input_float, input_quant_func) @classmethod def __torch_function__(cls, func, types, args=(), kwargs=None): @@ -1151,5 +708,4 @@ def __torch_dispatch__(cls, func, types, args, kwargs): f"LinearActQuantizedTensor dispatch: attempting to run {func}, this is not supported" ) -to_aqt = AffineQuantizedTensor.from_float -to_laqt = LinearActQuantizedTensor.from_float +to_laq = LinearActQuantizedTensor.from_float diff --git a/torchao/sparsity/README.md b/torchao/sparsity/README.md index b18e996b58..49cbe51a13 100644 --- a/torchao/sparsity/README.md +++ b/torchao/sparsity/README.md @@ -44,7 +44,7 @@ The handoff point between these two pieces are sparse weights stored in a dense This also allows users with existing sparse weights in a dense format to take advantage of our fast sparse kernels. 
We anticipate many users to come up with their own custom frontend masking solution or to use another third party solution, as this is an active area of research. -![pruning_flow](https://private-user-images.githubusercontent.com/8041643/324607153-ba91eaca-14ce-4608-9db8-6cbb9ea1f9ec.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MTQ1OTgzOTYsIm5iZiI6MTcxNDU5ODA5NiwicGF0aCI6Ii84MDQxNjQzLzMyNDYwNzE1My1iYTkxZWFjYS0xNGNlLTQ2MDgtOWRiOC02Y2JiOWVhMWY5ZWMucG5nP1gtQW16LUFsZ29yaXRobT1BV1M0LUhNQUMtU0hBMjU2JlgtQW16LUNyZWRlbnRpYWw9QUtJQVZDT0RZTFNBNTNQUUs0WkElMkYyMDI0MDUwMSUyRnVzLWVhc3QtMSUyRnMzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyNDA1MDFUMjExNDU2WiZYLUFtei1FeHBpcmVzPTMwMCZYLUFtei1TaWduYXR1cmU9YWVjOWQ5ZjFjMWZmNjg4ZTgyZGFkYWU3ZDQ3MDBjMTZkNzczZWQxYzczN2ZiM2ZjZGY0NjUwMGUwY2UwZDA1YyZYLUFtei1TaWduZWRIZWFkZXJzPWhvc3QmYWN0b3JfaWQ9MCZrZXlfaWQ9MCZyZXBvX2lkPTAifQ.ni5F_wDhNkeupMJ84bFNxhaSO3xPH-9zecz_933Uu68) +![pruning_flow](/docs/static/pruning_ecosystem_diagram.png) Below, we provide an example of accelerating a model with 2:4 sparsity + bf16 using our PyTorch APIs. @@ -97,7 +97,7 @@ Note that this section focuses on **pruning**, instead of **sparse training**. T Roughly, the flow for achieving a more performant pruned model looks like this: -![flow](https://private-user-images.githubusercontent.com/8041643/324607146-53542488-65ce-4d99-a3ae-21e724f89467.png?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3MTQ1OTgzOTYsIm5iZiI6MTcxNDU5ODA5NiwicGF0aCI6Ii84MDQxNjQzLzMyNDYwNzE0Ni01MzU0MjQ4OC02NWNlLTRkOTktYTNhZS0yMWU3MjRmODk0NjcucG5nP1gtQW16LUFsZ29yaXRobT1BV1M0LUhNQUMtU0hBMjU2JlgtQW16LUNyZWRlbnRpYWw9QUtJQVZDT0RZTFNBNTNQUUs0WkElMkYyMDI0MDUwMSUyRnVzLWVhc3QtMSUyRnMzJTJGYXdzNF9yZXF1ZXN0JlgtQW16LURhdGU9MjAyNDA1MDFUMjExNDU2WiZYLUFtei1FeHBpcmVzPTMwMCZYLUFtei1TaWduYXR1cmU9ZWJlYWMzZDFmNzc2NDM1MGI2ODNlMjUxZjQxYTAwYzhhNzBkNGU2ZGIwYTg4NzA5Yjk3N2JkNzI4MmUyNzg3NiZYLUFtei1TaWduZWRIZWFkZXJzPWhvc3QmYWN0b3JfaWQ9MCZrZXlfaWQ9MCZyZXBvX2lkPTAifQ.Hxk5XMuJXhNsORVNNgcKNRCk7W1nT4CndLTAC3Oz0qE) +![flow](/docs/static/pruning_flow.png) The general idea behind pruning is that we can mask out some of the weights of a trained neural network and recover any accuracy loss. The resultant pruned model can be run on optimized kernels that take advantage of this sparsity for accelerated inference.
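For readers tracking the rename in this patch (`to_aqt`/`to_laqt` become `to_aq`/`to_laq`, with `AffineQuantizedTensor` now exported from `torchao.dtypes`), below is a minimal usage sketch distilled from the updated `test/quantization/test_quant_api.py` hunks above. It is not part of the diff itself: the `quantize` import path, the `eps`/`zero_point_dtype` values, the mapping types, and the stand-in `nn.Sequential` model are illustrative assumptions that mirror what the tests appear to do.

```python
import torch
from torchao.dtypes import to_aq, AffineQuantizedTensor
from torchao.quantization.subclass import to_laq
from torchao.quantization.quant_primitives import MappingType
from torchao.quantization.quant_api import quantize  # assumed import path, mirroring the tests

# stand-in for the tests' ToyLinearModel
m = torch.nn.Sequential(torch.nn.Linear(1024, 1024)).eval()

def apply_weight_quant(weight):
    # per-channel int8 weight quantization; qparam choices here are illustrative
    block_size = (1, weight.shape[1])
    return to_aq(weight, MappingType.SYMMETRIC, block_size, torch.int8,
                 eps=torch.finfo(torch.float32).eps, zero_point_dtype=torch.int64)

def get_per_token_block_size(x):
    # one scale/zero_point per token: leading dims are 1, last dim spans the row
    return (1,) * (x.dim() - 1) + (x.shape[-1],)

# per-token asymmetric int8 activation quantization, as in the updated test
input_quant_func = lambda x: to_aq(x, MappingType.ASYMMETRIC,
                                   get_per_token_block_size(x), torch.int8)

def apply_act_quant(weight):
    return to_laq(weight, input_quant_func)

# order matters: swap in quantized weights first, then wrap with the activation quantizer
m = quantize(m, apply_weight_quant)
assert isinstance(m[0].weight, AffineQuantizedTensor)
m = quantize(m, apply_act_quant)
```

The ordering mirrors the test's "order is important" note: weights are replaced by `AffineQuantizedTensor` first, and `to_laq` then wraps them so activations are quantized on the fly when the linear op runs.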