diff --git a/ruff.toml b/ruff.toml index 1a4a5ff097..773497eb5c 100644 --- a/ruff.toml +++ b/ruff.toml @@ -11,6 +11,6 @@ include = [ "torchao/quantization/linear_activation_weight_observer.py", "test/quantization/test_observer.py", "test/dtypes/test_affine_quantized_float.py", - "torchao/quantization/weight_tensor_linear_activation_quantization.py" - + "torchao/quantization/weight_tensor_linear_activation_quantization.py", + "torchao/dtypes/**/*.py", ] diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index 4ab0c3f701..1cd1f9289c 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,4 +1,5 @@ from .nf4tensor import NF4Tensor, to_nf4 + # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor from .uint4 import UInt4Tensor from .affine_quantized_tensor import ( @@ -21,7 +22,7 @@ __all__ = [ "NF4Tensor", "to_nf4", - "UInt4Tensor" + "UInt4Tensor", "AffineQuantizedTensor", "to_affine_quantized_intx", "to_affine_quantized_intx_static", diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index 03cec525f4..4c77f76257 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1,8 +1,5 @@ import torch -from typing import Tuple, Optional, Union, List -import torchao.ops -from collections import defaultdict -import functools +from typing import Tuple, Optional, Union import math from torchao.quantization.quant_primitives import ( _get_reduction_params, @@ -32,7 +29,7 @@ preprocess_data, Float8MMConfig, addmm_float8_unwrapped_inference, - _is_rowwise_scaled + _is_rowwise_scaled, ) from torch.utils._python_dispatch import is_traceable_wrapper_subclass from dataclasses import dataclass @@ -47,9 +44,10 @@ logger = logging.getLogger(__name__) -from torchao.float8.inference import Float8MMConfig + aten = torch.ops.aten + ############################### # Base Tensor Impl Subclass # ############################### @@ -60,6 +58,7 @@ class AQTTensorImpl(TorchAOBaseTensor): Note: This is not a user facing API, it's used by AffineQuantizedTensor to construct the underlying implementation of a AQT based on layout """ + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Get the plain (unpacked) Tensor for the tensor impl @@ -79,7 +78,7 @@ def from_plain( zero_point: torch.Tensor, _layout: Layout, ): - """ Construct a TensorImpl from data, scale, zero_point and the _layout""" + """Construct a TensorImpl from data, scale, zero_point and the _layout""" pass def __repr__(self): @@ -94,11 +93,14 @@ def __repr__(self): class QuantizedLinearNotImplementedError(NotImplementedError): - """ Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table """ + """Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table""" + pass _AQT_QLINEAR_DISPATCH_TABLE = {} + + def register_aqt_quantized_linear_dispatch(dispatch_condition, impl): """Register a dispatch for quantized linear op with dispatch_condition function and impl function both takes three arguments: @@ -115,11 +117,15 @@ def register_aqt_quantized_linear_dispatch(dispatch_condition, impl): """ _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl + def deregister_aqt_quantized_linear_dispatch(dispatch_condition): if dispatch_condition in _AQT_QLINEAR_DISPATCH_TABLE: del _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] else: - logger.warn(f"Attempting to remove non-existant dispatch condition {dispatch_condition}") + 
logger.warn( + f"Attempting to remove non-existent dispatch condition {dispatch_condition}" + ) + class AffineQuantizedTensor(TorchAOBaseTensor): """ @@ -205,9 +211,16 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor output_dtype = self.dtype from torchao.dtypes.floatx import FloatxTensorCoreLayout + if isinstance(self._layout, FloatxTensorCoreLayout): int_data, scale = self.tensor_impl.get_plain() - return dequantize_affine_floatx(int_data, scale, self._layout.ebits, self._layout.mbits, output_dtype=output_dtype) + return dequantize_affine_floatx( + int_data, + scale, + self._layout.ebits, + self._layout.mbits, + output_dtype=output_dtype, + ) else: data, scale, zero_point = self.tensor_impl.get_plain() dq = dequantize_affine( @@ -232,17 +245,28 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias): for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items(): if dispatch_condition(input_tensor, weight_tensor, bias): return impl(input_tensor, weight_tensor, bias) - raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op") + raise QuantizedLinearNotImplementedError( + "No specialized dispatch found for quantized linear op" + ) def __tensor_flatten__(self): - return ["tensor_impl"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] + return ["tensor_impl"], [ + self.block_size, + self.shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + self.dtype, + ] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): tensor_impl = tensor_data_dict["tensor_impl"] - block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes + block_size, shape, quant_min, quant_max, zero_point_domain, dtype = ( + tensor_attributes + ) return cls( tensor_impl, block_size, @@ -275,20 +299,58 @@ def from_hp_to_intx( input_float = _layout.pre_process(input_float) if use_hqq: - assert zero_point_domain == ZeroPointDomain.FLOAT and mapping_type == MappingType.ASYMMETRIC and quant_min==0, "Invalid input parameters for HQQ quantization." + assert ( + zero_point_domain == ZeroPointDomain.FLOAT + and mapping_type == MappingType.ASYMMETRIC + and quant_min == 0 + ), "Invalid input parameters for HQQ quantization."
nbits = int(math.log2(quant_max + 1)) - axis = 1 if (block_size[0]==1) else 0 + axis = 1 if (block_size[0] == 1) else 0 group_size = max(block_size) - compute_dtype = zero_point_dtype if (zero_point_dtype is not None) else input_float.dtype + compute_dtype = ( + zero_point_dtype + if (zero_point_dtype is not None) + else input_float.dtype + ) device = input_float.device - data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False) + data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq( + input_float, + nbits=nbits, + group_size=group_size, + axis=axis, + compute_dtype=compute_dtype, + device=device, + verbose=False, + raw_output=False, + ) data = data.to(target_dtype) else: - scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain) + scale, zero_point = choose_qparams_affine( + input_float, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None if zero_point_domain is None: zero_point = None - data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) + data = quantize_affine( + input_float, + block_size, + scale, + zero_point, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) # Note: output will be uint8 tensor for sub byte tensors for now data = _layout.post_process(data) @@ -301,7 +363,7 @@ def from_hp_to_intx( quant_min, quant_max, zero_point_domain, - dtype=input_float.dtype + dtype=input_float.dtype, ) @classmethod @@ -318,12 +380,27 @@ def from_hp_to_intx_static( _layout: Layout = PlainLayout(), ): if target_dtype not in FP8_TYPES: - assert zero_point_domain is not None, "zero_point_domain must be specified for non-fp8 types" - assert zero_point is not None, "zero_point must be specified for non-fp8 types" + assert ( + zero_point_domain is not None + ), "zero_point_domain must be specified for non-fp8 types" + assert ( + zero_point is not None + ), "zero_point must be specified for non-fp8 types" original_shape = input_float.shape - input_float, scale, zero_point = _layout.pre_process_static(input_float, scale, zero_point, block_size) + input_float, scale, zero_point = _layout.pre_process_static( + input_float, scale, zero_point, block_size + ) - int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) + int_data = quantize_affine( + input_float, + block_size, + scale, + zero_point, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) int_data = _layout.post_process(int_data) @@ -348,7 +425,6 @@ def from_hp_to_floatx( _layout: Layout, scale_dtype: Optional[torch.dtype] = None, ): - if target_dtype in FP8_TYPES: return cls.from_hp_to_intx( input_float=input_float, @@ -366,7 +442,9 @@ def from_hp_to_floatx( use_hqq=False, ) else: - raise NotImplementedError(f"Unsupported dtype {target_dtype} for from_hp_to_floatx") + raise NotImplementedError( + f"Unsupported dtype {target_dtype} for from_hp_to_floatx" + ) @classmethod def from_hp_to_floatx_static( @@ -377,7 +455,6 @@ def from_hp_to_floatx_static( 
target_dtype: torch.dtype, _layout: Layout, ): - if target_dtype in FP8_TYPES: return cls.from_hp_to_intx_static( input_float=input_float, @@ -391,7 +468,9 @@ def from_hp_to_floatx_static( _layout=_layout, ) else: - raise NotImplementedError(f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static") + raise NotImplementedError( + f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static" + ) @classmethod def from_hp_to_fpx( @@ -400,7 +479,10 @@ def from_hp_to_fpx( _layout: Layout, ): from torchao.dtypes.floatx import FloatxTensorCoreLayout - assert isinstance(_layout, FloatxTensorCoreLayout), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}" + + assert isinstance( + _layout, FloatxTensorCoreLayout + ), f"Only FloatxTensorCoreLayout is supported for floatx, got {_layout}" original_shape = input_float.shape input_float = _layout.pre_process(input_float) # per axis quantization, where axis = 1 @@ -415,12 +497,7 @@ def from_hp_to_fpx( tensor_impl_ctr = get_tensor_impl_constructor(type(_layout)) tensor_impl = tensor_impl_ctr(floatx_packed, scale, None, _layout) - return cls( - tensor_impl, - block_size, - original_shape, - dtype=input_float.dtype - ) + return cls(tensor_impl, block_size, original_shape, dtype=input_float.dtype) @property def _layout(self) -> Layout: @@ -472,9 +549,9 @@ def _apply_fn_to_data(self, fn): register_layout = AffineQuantizedTensor.register_layout get_tensor_impl_constructor = AffineQuantizedTensor.get_tensor_impl_constructor + @dataclass(frozen=True) class SemiSparseLayout(Layout): - def pre_process(self, input: torch.Tensor) -> torch.Tensor: # prune to 2:4 if not already temp = input.detach() @@ -494,6 +571,7 @@ class TensorCoreTiledLayout(Layout): inner_k_tiles is an internal argument for packing function of tensor core tiled layout that can affect the performance of the matmul kernel """ + inner_k_tiles: int = 8 def pre_process(self, input: torch.Tensor) -> torch.Tensor: @@ -506,14 +584,25 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: ) return input - def pre_process_static(self, input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, block_size: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def pre_process_static( + self, + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: Tuple[int, ...], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: input = self.pre_process(input) orig_qparam_shape = scale.shape - new_qparam_shape, reduction_dims = _get_reduction_params(block_size, input.size()) + new_qparam_shape, reduction_dims = _get_reduction_params( + block_size, input.size() + ) for dim in reduction_dims: new_qparam_shape.pop(dim) - change_in_qparam_shape = [new_dim_size-orig_dim_size for new_dim_size, orig_dim_size in zip(new_qparam_shape, orig_qparam_shape)] - padding_changes=[] + change_in_qparam_shape = [ + new_dim_size - orig_dim_size + for new_dim_size, orig_dim_size in zip(new_qparam_shape, orig_qparam_shape) + ] + padding_changes = [] for dim_change in change_in_qparam_shape: padding_changes = [0, dim_change] + padding_changes scale = torch.nn.functional.pad(scale, padding_changes) @@ -541,7 +630,6 @@ class Float8Layout(Layout): @dataclass(frozen=True) class MarlinSparseLayout(Layout): - def pre_process(self, input: torch.Tensor) -> torch.Tensor: """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel. 
- 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format @@ -555,6 +643,7 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: torch.Tensor: the preprocessed tensor """ from torchao.sparsity.marlin import inject_24 # avoid circular import + input_t = input.t() w_24, _ = inject_24(input_t, *input_t.shape) return w_24.t() @@ -571,6 +660,7 @@ class PlainAQTTensorImpl(AQTTensorImpl): scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor """ + def __new__( cls, int_data: torch.Tensor, @@ -607,8 +697,12 @@ def __tensor_flatten__(self): def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"] - _layout, = tensor_attributes + int_data, scale, zero_point = ( + tensor_data_dict["int_data"], + tensor_data_dict["scale"], + tensor_data_dict["zero_point"], + ) + (_layout,) = tensor_attributes return cls(int_data, scale, zero_point, _layout) def to(self, *args, **kwargs): @@ -653,13 +747,27 @@ def __torch_dispatch__(cls, func, types, args, kwargs): self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: aten.slice.Tensor(x, dim, start, end, step) + ), ) elif dim == 1: - assert len(self.scale.shape) == 1, f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" - return PlainAQTTensorImpl(aten.slice.Tensor(self.int_data, dim, start, end, step), self.scale.view(-1), self.zero_point.view(-1), self._layout) + assert ( + len(self.scale.shape) == 1 + ), f"slice dim==1 only works when len(scale.shape) == 1 currently, got: {self.scale.shape}" + return PlainAQTTensorImpl( + aten.slice.Tensor(self.int_data, dim, start, end, step), + self.scale.view(-1), + self.zero_point.view(-1), + self._layout, + ) else: - raise NotImplementedError(f"PlainAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError( + f"PlainAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) raise NotImplementedError( f"PlainAQTTensorImpl dispatch: attempting to run {func}, this is not supported" @@ -684,11 +792,13 @@ def from_plain( assert isinstance(_layout, PlainLayout) return cls(int_data, scale, zero_point, _layout) + @register_layout(SemiSparseLayout) class SemiSparseAQTTensorImpl(PlainAQTTensorImpl): """ TensorImpl for semi_sparse_cusparselt layout for affine quantized tensor """ + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): kwargs = {} if kwargs is None else kwargs @@ -706,10 +816,10 @@ def get_plain(self): # Currently we don't have cuSPARSELt expansion routines, so we matmul by # the identity matrix to get the original dense matrix. This is slow though.
cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0]) - int_data_expanded = torch._cslt_sparse_mm(self.int_data, - torch.eye(cols, - dtype=self.int_data.dtype, - device=self.int_data.device).t()) + int_data_expanded = torch._cslt_sparse_mm( + self.int_data, + torch.eye(cols, dtype=self.int_data.dtype, device=self.int_data.device).t(), + ) return int_data_expanded, self.scale, self.zero_point @classmethod @@ -724,6 +834,7 @@ def from_plain( int_data_compressed = torch._cslt_compress(int_data) return cls(int_data_compressed, scale, zero_point, _layout) + @register_layout(BlockSparseLayout) class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): bsr_crow_indices: Optional[torch.Tensor] @@ -732,7 +843,13 @@ class BlockSparseAQTTensorImpl(PlainAQTTensorImpl): scale: Optional[torch.Tensor] zero_point: Optional[torch.Tensor] - __slots__ = ["bsr_crow_indices", "bsr_col_indices", "bsr_values", "scale", "zero_point"] + __slots__ = [ + "bsr_crow_indices", + "bsr_col_indices", + "bsr_values", + "scale", + "zero_point", + ] @staticmethod def __new__( # noqa: PYI034 @@ -814,17 +931,23 @@ def from_plain(cls, int_data, scale, zero_point, _layout): bsr_values=bsr_tensor.values(), scale=scale, zero_point=zero_point, - _layout = _layout, + _layout=_layout, requires_grad=False, ) def get_plain(self): - int_data_expanded = torch.ops.blocksparse.bsr_to_dense(self.crow_indices(), self.col_indices(), self.values(), self.shape[0], self.shape[1]) + int_data_expanded = torch.ops.blocksparse.bsr_to_dense( + self.crow_indices(), + self.col_indices(), + self.values(), + self.shape[0], + self.shape[1], + ) return int_data_expanded, self.scale, self.zero_point def _apply_fn_to_data(self, func): return self.__class__( - shape = self.shape, + shape=self.shape, bsr_crow_indices=func(self.bsr_crow_indices), bsr_col_indices=func(self.bsr_col_indices), bsr_values=func(self.bsr_values), @@ -864,6 +987,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs): f"BlockSparseAQTTensorImpl dispatch: attempting to run {func}, this is not supported" ) + @register_layout(MarlinSparseLayout) class MarlinSparseAQTTensorImpl(AQTTensorImpl): """ @@ -882,6 +1006,7 @@ class MarlinSparseAQTTensorImpl(AQTTensorImpl): group_size (int): the group size used to pack the tensor num_bits (int): the number of bits used to quantize the tensor """ + @staticmethod def __new__( cls, @@ -938,7 +1063,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) def __tensor_flatten__(self): - return ["int_data", "scale", "zero_point", "meta"], [self._layout, self.original_shape, self.group_size, self.num_bits] + return ["int_data", "scale", "zero_point", "meta"], [ + self._layout, + self.original_shape, + self.group_size, + self.num_bits, + ] @classmethod def __tensor_unflatten__( @@ -949,10 +1079,22 @@ def __tensor_unflatten__( zero_point = tensor_data_dict["zero_point"] meta = tensor_data_dict["meta"] _layout, original_shape, group_size, num_bits = tensor_attributes - return cls(int_data, scale, zero_point, meta, _layout, original_shape, group_size, num_bits) + return cls( + int_data, + scale, + zero_point, + meta, + _layout, + original_shape, + group_size, + num_bits, + ) def get_plain(self): - from torchao.sparsity.marlin import unpack_from_marlin_24 # avoid circular import + from torchao.sparsity.marlin import ( + unpack_from_marlin_24, + ) # avoid circular import + int_data_expanded, scales_expanded = unpack_from_marlin_24( self.int_data, self.scale, @@ -973,7 +1115,11 @@ def from_plain( zero_point: torch.Tensor, _layout: Layout, ): - from 
torchao.sparsity.marlin import pack_to_marlin_24, const # avoid circular import + from torchao.sparsity.marlin import ( + pack_to_marlin_24, + const, + ) # avoid circular import + assert isinstance(_layout, MarlinSparseLayout) # Linear layers are (in_features, out_features) but the int_data that is reaching this point @@ -983,7 +1129,7 @@ def from_plain( if not torch.cuda.get_device_capability()[0] >= 8: raise ValueError( - f'Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel.' + f"Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel." ) if q_w_24.dtype != torch.int32: @@ -1000,14 +1146,14 @@ # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main num_bits = 4 if torch.max(q_w_24) < 16 else -1 if num_bits not in [4]: - raise ValueError( - f"Only {[4]} bits are supported, got {num_bits}." - ) + raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.") group_size = in_features // scale_t.shape[0] if group_size == 0: group_size = in_features - assert group_size <= in_features, "Group size must be less than or equal to in_features." + assert ( + group_size <= in_features + ), "Group size must be less than or equal to in_features." if group_size not in const.SUPPORTED_GROUP_SIZES: raise ValueError( @@ -1015,12 +1161,19 @@ ) # Compress quantized weight to marlin 2:4 format - marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24(q_w_24, scale_t, num_bits, group_size) + marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24( + q_w_24, scale_t, num_bits, group_size + ) return cls( - marlin_24_q_w_comp, marlin_24_s, zero_point, - meta, _layout, q_w_24.shape, - group_size, num_bits + marlin_24_q_w_comp, + marlin_24_s, + zero_point, + meta, + _layout, + q_w_24.shape, + group_size, + num_bits, ) def get_layout(self) -> Layout: @@ -1042,6 +1195,7 @@ class Float8AQTTensorImpl(AQTTensorImpl): Note: technically we should not create a new layout for float8 we should merge this into plain layout """ + float8_data: torch.Tensor scale: torch.Tensor transposed: bool @@ -1076,7 +1230,7 @@ def __init__( self._layout = _layout def _apply_fn_to_data(self, fn): - """ Applys a fn to all tensor components stored on this class""" + """Applies a fn to all tensor components stored on this class""" return self.__class__( fn(self.float8_data), fn(self.scale), @@ -1101,7 +1255,10 @@ def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): float8_data, scale = tensor_data_dict["float8_data"], tensor_data_dict["scale"] - transposed, _layout, = tensor_attributes + ( + transposed, + _layout, + ) = tensor_attributes return cls(float8_data, scale, transposed, _layout) @classmethod @@ -1125,23 +1282,50 @@ def __torch_dispatch__(cls, func, types, args, kwargs): elif func is aten.slice.Tensor: self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) if dim == 0: - #TODO: scale replecation should be dependent on block size + # TODO: scale replication should be dependent on block size if self.scale.ndim == 1: return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(lambda x: aten.slice.Tensor(x, dim, start, end, step)) + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: aten.slice.Tensor(x, dim, start, end, step) + ), )
elif self.scale.ndim == 0: return return_and_correct_aliasing( - func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step), self.scale, None, self._layout) + func, + args, + kwargs, + Float8AQTTensorImpl( + aten.slice.Tensor(self.float8_data, dim, start, end, step), + self.scale, + None, + self._layout, + ), ) else: - raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported") + raise NotImplementedError( + f"Float8AQTTensorImpl dispatch: attempting to run {func}, with scale ndim={dim}, that is not supported" + ) elif dim == 1: return return_and_correct_aliasing( - func, args, kwargs, Float8AQTTensorImpl(aten.slice.Tensor(self.float8_data, dim, start, end, step).contiguous(), self.scale, None, self._layout) + func, + args, + kwargs, + Float8AQTTensorImpl( + aten.slice.Tensor( + self.float8_data, dim, start, end, step + ).contiguous(), + self.scale, + None, + self._layout, + ), ) else: - raise NotImplementedError(f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError( + f"Float8AQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) else: raise NotImplementedError( f"Float8AQTTensorImpl dispatch: attempting to run {func}, this is not supported" @@ -1163,19 +1347,25 @@ def from_plain( zero_point: Optional[torch.Tensor], _layout: Layout, ): - """ Main entrypoint for constructing Float8TensorImpl""" - assert _is_float8_type(data.dtype), f"Float8 TensorImpl must be constructed from float8 dtype but got {data.dtype}" - assert isinstance(_layout, Float8Layout), f"Float8 TensorImpl must be constructed from Float8Layout but got {_layout}" + """Main entrypoint for constructing Float8TensorImpl""" + assert _is_float8_type( + data.dtype + ), f"Float8 TensorImpl must be constructed from float8 dtype but got {data.dtype}" + assert isinstance( + _layout, Float8Layout + ), f"Float8 TensorImpl must be constructed from Float8Layout but got {_layout}" return cls(data, scale, False, _layout) def __repr__(self): float8_data, scale, _ = self.get_plain() _layout = self.get_layout() - return (f"{self.__class__.__name__}(\n" - f"float8_data={float8_data},\n" - f"scale={scale},\n" - f"transposed={self.transposed}, " - f"_layout={_layout})") + return ( + f"{self.__class__.__name__}(\n" + f"float8_data={float8_data},\n" + f"scale={scale},\n" + f"transposed={self.transposed}, " + f"_layout={_layout})" + ) @register_layout(TensorCoreTiledLayout) @@ -1212,7 +1402,9 @@ def __new__( kwargs = {} kwargs["device"] = packed_weight.device kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else packed_weight.layout + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout ) kwargs["dtype"] = packed_weight.dtype kwargs["requires_grad"] = False @@ -1238,8 +1430,14 @@ def __tensor_flatten__(self): def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"] - transposed, _layout, = tensor_attributes + packed_weight, scale_and_zero = ( + tensor_data_dict["packed_weight"], + tensor_data_dict["scale_and_zero"], + ) + ( + transposed, + _layout, + ) = tensor_attributes return cls(packed_weight, scale_and_zero, transposed, _layout) @classmethod @@ -1248,20 +1446,25 @@ def from_plain( int_data: torch.Tensor, scale: torch.Tensor, 
zero_point: Optional[torch.Tensor], - _layout: Layout + _layout: Layout, ): - assert isinstance(_layout, TensorCoreTiledLayout) if TORCH_VERSION_AT_LEAST_2_5: int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" + assert ( + int_data.dtype == torch.uint8 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" else: - assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" - packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, _layout.inner_k_tiles) + assert ( + int_data.dtype == torch.int32 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + packed_weight = torch.ops.aten._convert_weight_to_int4pack( + int_data, _layout.inner_k_tiles + ) scale = scale.reshape(int_data.shape[0], -1) zero_point = zero_point.reshape(int_data.shape[0], -1) - from torchao.quantization.utils import pack_tinygemm_scales_and_zeros + scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) return cls(packed_weight, scale_and_zero, False, _layout) @@ -1272,7 +1475,9 @@ def to(self, *args, **kwargs): # between these two devices, in the future we should not use the same layout for # cpu and cuda device: https://github.com/pytorch/ao/issues/1117 if not is_device(torch.device(self.device).type, device): - raise ValueError(f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}") + raise ValueError( + f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}" + ) return self.__class__( self.packed_weight.to(device), self.scale_and_zero.to(device), @@ -1309,7 +1514,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): """we don't need to repack the weight and just rely on external shape being changed and record the status of transpose/no-transpose """ - transposed = TensorCoreTiledAQTTensorImpl(args[0].packed_weight, args[0].scale_and_zero, not args[0].transposed, args[0]._layout) + transposed = TensorCoreTiledAQTTensorImpl( + args[0].packed_weight, + args[0].scale_and_zero, + not args[0].transposed, + args[0]._layout, + ) return return_and_correct_aliasing(func, args, kwargs, transposed) if func is aten.slice.Tensor: @@ -1334,11 +1544,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs): # this is to handle padding int_data = self._layout.post_process(int_data) scale = aten.slice.Tensor(scale, dim, start_scale, end_scale, step) - zero_point = aten.slice.Tensor(zero_point, dim, start_scale, end_scale, step) + zero_point = aten.slice.Tensor( + zero_point, dim, start_scale, end_scale, step + ) sliced = self.from_plain(int_data, scale, zero_point, self._layout) return sliced else: - raise NotImplementedError(f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported") + raise NotImplementedError( + f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, with dim={dim}, that is not supported" + ) raise NotImplementedError( f"TensorCoreTiledAQTTensorImpl dispatch: attempting to run {func}, this is not supported" @@ -1352,6 +1566,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: quantize_affine, ) from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) cur_shape = self.shape @@ -1368,12 
+1583,26 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: quant_max = 15 zero_point_domain = ZeroPointDomain.FLOAT assert len(block_size) == 2 and block_size[0] == 1 - dequantized = torch.ops.aten._weight_int4pack_mm(torch.eye(eye_shape, device=device, dtype=original_dtype), self.packed_weight, groupsize, self.scale_and_zero) + dequantized = torch.ops.aten._weight_int4pack_mm( + torch.eye(eye_shape, device=device, dtype=original_dtype), + self.packed_weight, + groupsize, + self.scale_and_zero, + ) dequantized = dequantized.t().contiguous() # TODO: move this to `unpack_tinygemm_scales_and_zeros`? scale = scale.reshape(scale.shape[:-1]).contiguous() zero = zero.reshape(zero.shape[:-1]).contiguous() - int_data = quantize_affine(dequantized, block_size, scale, zero, target_dtype, quant_min, quant_max, zero_point_domain) + int_data = quantize_affine( + dequantized, + block_size, + scale, + zero, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) return int_data, scale, zero def get_layout(self) -> Layout: @@ -1384,28 +1613,31 @@ def get_layout(self) -> Layout: # torch functional and aten operator implementation # ##################################################### + def _aqt_is_int8(aqt): """Check if an AffineQuantizedTensor is int8 quantized Tensor""" return ( - aqt.tensor_impl.dtype == torch.int8 and - (aqt.quant_min is None or aqt.quant_min == -128) and - (aqt.quant_max is None or aqt.quant_max == 127) + aqt.tensor_impl.dtype == torch.int8 + and (aqt.quant_min is None or aqt.quant_min == -128) + and (aqt.quant_max is None or aqt.quant_max == 127) ) + def _aqt_is_int8_reduced_range(aqt): return ( - aqt.tensor_impl.dtype == torch.int8 and - aqt.quant_min == -127 and - (aqt.quant_max is None or aqt.quant_max == 127) + aqt.tensor_impl.dtype == torch.int8 + and aqt.quant_min == -127 + and (aqt.quant_max is None or aqt.quant_max == 127) ) + def _aqt_is_tensor_core_tile_uint4(aqt): """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" # TODO: use torch.uint4 return ( - aqt.tensor_impl.dtype == torch.int32 and - aqt.quant_min == 0 and - aqt.quant_max == 15 + aqt.tensor_impl.dtype == torch.int32 + and aqt.quant_min == 0 + and aqt.quant_max == 15 ) @@ -1417,16 +1649,18 @@ def _aqt_is_tensor_core_tile_uint4(aqt): # bias: dimension is (out_features,) # so that these can be shared by F.linear, aten.mm, aten.addmm dispatches + def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): return ( - isinstance(input_tensor, AffineQuantizedTensor) and - _aqt_is_int8_reduced_range(input_tensor) and - isinstance(weight_tensor, AffineQuantizedTensor) and - input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor._layout, PlainLayout) and - isinstance(weight_tensor._layout, PlainLayout) + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(weight_tensor, AffineQuantizedTensor) + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor._layout, PlainLayout) ) + def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): # # 1. 
do the matrix form of dot(X_i, W_j) @@ -1447,7 +1681,9 @@ def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): x_scales_dtype = x_scales.dtype # Cast fp16 scale to float to avoid overflow in int_scaled_matmul intermediate_dtype = torch.float if x_scales_dtype == torch.half else x_scales_dtype - y_dot_scaled = int_scaled_matmul(tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype)) + y_dot_scaled = int_scaled_matmul( + tmp, w_vals_int8_t, x_scales.reshape(-1, 1).to(intermediate_dtype) + ) y_dot_scaled = y_dot_scaled.to(x_scales_dtype) y = (y_dot_scaled * w_scales).reshape( @@ -1462,18 +1698,23 @@ def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): return y -def _linear_int8_act_int8_weight_semi_structured_sparse_check(input_tensor, weight_tensor, bias): +def _linear_int8_act_int8_weight_semi_structured_sparse_check( + input_tensor, weight_tensor, bias +): return ( - isinstance(input_tensor, AffineQuantizedTensor) and - _aqt_is_int8_reduced_range(input_tensor) and - isinstance(weight_tensor, AffineQuantizedTensor) and - weight_tensor.is_cuda and - input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor._layout, PlainLayout) and - isinstance(weight_tensor._layout, SemiSparseLayout) + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.is_cuda + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor._layout, SemiSparseLayout) ) -def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weight_tensor, bias): + +def _linear_int8_act_int8_weight_semi_structured_sparse_impl( + input_tensor, weight_tensor, bias +): x_vals_int8 = input_tensor.tensor_impl.int_data x_scales = input_tensor.tensor_impl.scale w_vals_int8 = weight_tensor.tensor_impl.int_data @@ -1481,7 +1722,10 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( - w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16, + w_vals_int8, + tmp.t(), + alpha=w_scales.to(torch.float32), + out_dtype=torch.bfloat16, ).t() y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] @@ -1493,15 +1737,16 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh y += bias return y + def _linear_int8_act_int8_weight_block_sparse_check(input_tensor, weight_tensor, bias): return ( - isinstance(input_tensor, AffineQuantizedTensor) and - _aqt_is_int8_reduced_range(input_tensor) and - isinstance(weight_tensor, AffineQuantizedTensor) and - weight_tensor.is_cuda and - input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor._layout, PlainLayout) and - isinstance(weight_tensor._layout, BlockSparseLayout) + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.is_cuda + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor._layout, PlainLayout) + and isinstance(weight_tensor._layout, BlockSparseLayout) ) @@ -1513,12 +1758,14 @@ def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, 
weight_tensor, tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) tmp_t = tmp.t() - y = torch.ops.blocksparse.int_addmm(w_vals.crow_indices(), - w_vals.col_indices(), - w_vals.values(), - tmp_t, - w_scales, - x_scales.reshape(-1)) + y = torch.ops.blocksparse.int_addmm( + w_vals.crow_indices(), + w_vals.col_indices(), + w_vals.values(), + tmp_t, + w_scales, + x_scales.reshape(-1), + ) y_shape = (*x_vals_int8.shape[:-1], w_scales.shape[-1]) y = y.reshape(*y_shape) @@ -1533,20 +1780,23 @@ def _linear_int8_act_int8_weight_block_sparse_impl(input_tensor, weight_tensor, def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): return ( # input is native bfloat16 tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.dtype == torch.bfloat16 and + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.dtype == torch.bfloat16 + and # weight is uint4, group quantized tensor_core_tiled tensor impl affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - _aqt_is_tensor_core_tile_uint4(weight_tensor) and - weight_tensor.dtype == torch.bfloat16 and - len(weight_tensor.shape) == 2 and - weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT and - isinstance(weight_tensor._layout, TensorCoreTiledLayout) + isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_tensor_core_tile_uint4(weight_tensor) + and weight_tensor.dtype == torch.bfloat16 + and len(weight_tensor.shape) == 2 + and weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT + and isinstance(weight_tensor._layout, TensorCoreTiledLayout) ) def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): - assert weight_tensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" + assert ( + weight_tensor.block_size[0] == 1 + ), f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" assert input_tensor.shape[-1] == weight_tensor.shape[1], ( f"need input_tensor shape: {input_tensor.shape} final" f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " @@ -1570,14 +1820,15 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): # groupwise int4 quantization groupsize = weight_tensor.block_size[1] - y = torch.ops.aten._weight_int4pack_mm(act_mat.contiguous(), packed_weight, groupsize, scale_and_zero) + y = torch.ops.aten._weight_int4pack_mm( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) # remove out_feature padding orig_out_features = weight_tensor.shape[-2] y = y[:, :orig_out_features] y = y.reshape(*orig_act_size[:-1], orig_out_features) - if bias is not None: y += bias return y.to(orig_dtype) @@ -1586,19 +1837,21 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): return ( # input is native float tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.is_floating_point() and + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.is_floating_point() + and # weight is int8 per channel quantized affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - _aqt_is_int8(weight_tensor) and - len(weight_tensor.shape) == 2 and - len(weight_tensor.block_size) == 2 and - weight_tensor.block_size[0] == 1 and - weight_tensor.block_size[1] == weight_tensor.shape[1] and - weight_tensor.zero_point_domain == ZeroPointDomain.INT and - isinstance(weight_tensor._layout, PlainLayout) + 
isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_int8(weight_tensor) + and len(weight_tensor.shape) == 2 + and len(weight_tensor.block_size) == 2 + and weight_tensor.block_size[0] == 1 + and weight_tensor.block_size[1] == weight_tensor.shape[1] + and weight_tensor.zero_point_domain == ZeroPointDomain.INT + and isinstance(weight_tensor._layout, PlainLayout) ) + def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): # TODO: enable cpu and mps efficient path # is_cpu and is_mps only, some issue with is_contiguous() currently @@ -1607,7 +1860,6 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): # per channel int8 weight only quantizated mm w_vals_int8_t = weight_tensor.tensor_impl.int_data.t() scale = weight_tensor.tensor_impl.scale - orig_dtype = input_tensor.dtype m = torch.mm( input_tensor.reshape(-1, input_tensor.shape[-1]), w_vals_int8_t.to(input_tensor.dtype), @@ -1618,30 +1870,31 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): y += bias.to(m.dtype) return y + def _linear_f16_bf16_act_floatx_weight_check(input_tensor, weight_tensor, bias): from torchao.dtypes.floatx import FloatxTensorCoreLayout + return ( # input is native float32 tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.is_floating_point() and - input_tensor.dtype in (torch.float16, torch.bfloat16) and + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.is_floating_point() + and input_tensor.dtype in (torch.float16, torch.bfloat16) + and # weight is floatx Tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - isinstance(weight_tensor._layout, FloatxTensorCoreLayout) and - ( + isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, FloatxTensorCoreLayout) + and ( # weight is using fp6 quantization - (weight_tensor._layout.ebits == 3 and - weight_tensor._layout.mbits == 2) or - (weight_tensor._layout.ebits == 2 and - weight_tensor._layout.mbits == 3) or + (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 2) + or (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 3) + or # weight is using fp5 quantization - (weight_tensor._layout.ebits == 2 and - weight_tensor._layout.mbits == 2) or - (weight_tensor._layout.ebits == 3 and - weight_tensor._layout.mbits == 1) + (weight_tensor._layout.ebits == 2 and weight_tensor._layout.mbits == 2) + or (weight_tensor._layout.ebits == 3 and weight_tensor._layout.mbits == 1) ) ) + def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): from torchao.dtypes.floatx import _SPLIT_K_MAP from torchao.ops import quant_llm_linear @@ -1670,6 +1923,7 @@ def _linear_f16_bf16_act_floatx_weight_impl(input_tensor, weight_tensor, bias): return out.view(*act.shape[:-1], out_dim).to(act.dtype) + def _linear_fp8_act_fp8_weight_check( input_tensor: Union[torch.Tensor, AffineQuantizedTensor], weight_tensor: Union[torch.Tensor, AffineQuantizedTensor], @@ -1677,16 +1931,17 @@ def _linear_fp8_act_fp8_weight_check( ) -> bool: def check_aqt(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: return ( - isinstance(aqt, AffineQuantizedTensor) and - isinstance(aqt._layout, Float8Layout) + isinstance(aqt, AffineQuantizedTensor) + and isinstance(aqt._layout, Float8Layout) and aqt.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and (aqt.shape == aqt.block_size or _is_rowwise_scaled(aqt)) ) + return check_aqt(input_tensor) and check_aqt(weight_tensor) def preprocess_scale(input_scale: 
torch.Tensor, input_shape: Tuple[int]): - """ Ensures input tensor is correctly formated for _scaled_mm """ + """Ensures input tensor is correctly formated for _scaled_mm""" input_scale = input_scale.unsqueeze(-1) if input_scale.dim() > 2: @@ -1694,6 +1949,7 @@ def preprocess_scale(input_scale: torch.Tensor, input_shape: Tuple[int]): return input_scale + def _linear_fp8_act_fp8_weight_impl( input_tensor: AffineQuantizedTensor, weight_tensor: AffineQuantizedTensor, @@ -1717,7 +1973,9 @@ def _linear_fp8_act_fp8_weight_impl( # Handle rowwise case if _is_rowwise_scaled(weight_tensor): - assert _is_rowwise_scaled(input_tensor), "Input tensor must be rowwise block size" + assert _is_rowwise_scaled( + input_tensor + ), "Input tensor must be rowwise block size" w_scale = w_scale.unsqueeze(-1).T input_scale = preprocess_scale(input_scale, input_tensor.shape) @@ -1735,6 +1993,7 @@ def _linear_fp8_act_fp8_weight_impl( use_fast_accum=scaled_mm_config.use_fast_accum, ).reshape(out_shape) + def _linear_fp_act_fp8_weight_check( input_tensor: Union[torch.Tensor, AffineQuantizedTensor], weight_tensor: Union[torch.Tensor, AffineQuantizedTensor], @@ -1742,15 +2001,20 @@ def _linear_fp_act_fp8_weight_check( ) -> bool: return ( # input is native float tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.is_floating_point() and + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.is_floating_point() + and # weight is float8 quantized affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - isinstance(weight_tensor._layout, Float8Layout) + isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor._layout, Float8Layout) and weight_tensor.tensor_impl.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] - and (weight_tensor.shape == weight_tensor.block_size or _is_rowwise_scaled(weight_tensor)) + and ( + weight_tensor.shape == weight_tensor.block_size + or _is_rowwise_scaled(weight_tensor) + ) ) + def _linear_fp_act_fp8_weight_impl( input_tensor: torch.Tensor, weight_tensor: AffineQuantizedTensor, @@ -1758,18 +2022,20 @@ def _linear_fp_act_fp8_weight_impl( ): return torch.nn.functional.linear(input_tensor, weight_tensor.dequantize(), bias) + def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias): return ( - isinstance(weight_tensor, AffineQuantizedTensor) and - _aqt_is_tensor_core_tile_uint4(weight_tensor) and - input_tensor.dtype == torch.float16 and - len(weight_tensor.shape) == 2 and - weight_tensor.zero_point_domain == ZeroPointDomain.INT and - isinstance(weight_tensor._layout, MarlinSparseLayout) + isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_tensor_core_tile_uint4(weight_tensor) + and input_tensor.dtype == torch.float16 + and len(weight_tensor.shape) == 2 + and weight_tensor.zero_point_domain == ZeroPointDomain.INT + and isinstance(weight_tensor._layout, MarlinSparseLayout) ) + def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias): - from torchao.sparsity.marlin import marlin_24_workspace, const + from torchao.sparsity.marlin import marlin_24_workspace from torchao.ops import marlin_24_gemm assert isinstance(weight_tensor, AffineQuantizedTensor) @@ -1789,8 +2055,15 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b workspace_24 = marlin_24_workspace(original_shape[1]) out = marlin_24_gemm( - input_2d, sparse_w_int4, meta, scale, - workspace_24, num_bits, size_m, size_n, size_k + input_2d, + sparse_w_int4, + meta, + 
scale, + workspace_24, + num_bits, + size_m, + size_n, + size_k, ) # Unfold the batch dimension @@ -1804,19 +2077,33 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b def _register_aqt_quantized_linear_dispatches(): for dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), - (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl), - (_linear_int8_act_int8_weight_block_sparse_check, _linear_int8_act_int8_weight_block_sparse_impl), + ( + _linear_int8_act_int8_weight_semi_structured_sparse_check, + _linear_int8_act_int8_weight_semi_structured_sparse_impl, + ), + ( + _linear_int8_act_int8_weight_block_sparse_check, + _linear_int8_act_int8_weight_block_sparse_impl, + ), (_linear_fp8_act_fp8_weight_check, _linear_fp8_act_fp8_weight_impl), (_linear_fp_act_fp8_weight_check, _linear_fp_act_fp8_weight_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), - (_linear_f16_bf16_act_floatx_weight_check, _linear_f16_bf16_act_floatx_weight_impl), - (_linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl), + ( + _linear_f16_bf16_act_floatx_weight_check, + _linear_f16_bf16_act_floatx_weight_impl, + ), + ( + _linear_fp_act_int4_weight_sparse_marlin_check, + _linear_fp_act_int4_weight_sparse_marlin_impl, + ), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) + _register_aqt_quantized_linear_dispatches() + @implements([torch.nn.functional.linear, aten.linear.default]) def _(func, types, args, kwargs): input_tensor, weight_tensor, bias = ( @@ -1825,7 +2112,9 @@ def _(func, types, args, kwargs): args[2] if len(args) > 2 else None, ) if not input_tensor.is_floating_point(): - raise NotImplementedError(f"{func} is not implemented for non floating point input") + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) # using try/except here so that we can have a general fallback when input_tensor/weight_tensor # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to @@ -1834,7 +2123,11 @@ def _(func, types, args, kwargs): return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` - if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor._layout, "quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is not None: + if ( + isinstance(weight_tensor, AffineQuantizedTensor) + and hasattr(weight_tensor._layout, "quantized_linear_impl") + and weight_tensor._layout.quantized_linear_impl is not None + ): raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -1843,23 +2136,36 @@ def _(func, types, args, kwargs): weight_tensor = weight_tensor.dequantize() return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + @implements(torch.nn.functional.embedding) def _(func, types, args, kwargs): # new_arg1 = args[1].dequantize() # return torch.nn.embedding(args[0], new_arg1, *args[2:], **kwargs) - assert isinstance(args[1].tensor_impl, PlainAQTTensorImpl), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}" - assert kwargs["padding_idx"] is None and kwargs["max_norm"] is None and 
not kwargs["scale_grad_by_freq"] and not kwargs["sparse"] and kwargs["norm_type"]==2.0 + assert isinstance( + args[1].tensor_impl, PlainAQTTensorImpl + ), f"embedding only works with PlainAQTTensorImpl but got {type(args[1].tensor_impl)}" + assert ( + kwargs["padding_idx"] is None + and kwargs["max_norm"] is None + and not kwargs["scale_grad_by_freq"] + and not kwargs["sparse"] + and kwargs["norm_type"] == 2.0 + ) idx = args[0] int_data, scale, zero_point = args[1].tensor_impl.get_plain() - - sliced_data, sliced_scale, sliced_zero_point = int_data[idx], scale[idx], zero_point[idx] + + sliced_data, sliced_scale, sliced_zero_point = ( + int_data[idx], + scale[idx], + zero_point[idx], + ) # Block size is expecting 2 dimensions [1, group size] but - # batchsize or other dims gets added to sliced_data, sliced_scale and sliced_zero_point so + # batchsize or other dims gets added to sliced_data, sliced_scale and sliced_zero_point so # we need to increase block size to correct dim - new_blocks = idx.dim()-1 + new_blocks = idx.dim() - 1 return dequantize_affine( sliced_data, - new_blocks*[1]+list(args[1].block_size), + new_blocks * [1] + list(args[1].block_size), sliced_scale, sliced_zero_point, sliced_data.dtype, @@ -1869,6 +2175,7 @@ def _(func, types, args, kwargs): output_dtype=sliced_scale.dtype, ) + @implements(aten.addmm.default) def _(func, types, args, kwargs): input_tensor, weight_tensor, bias = ( @@ -1877,7 +2184,9 @@ def _(func, types, args, kwargs): args[0], ) if not input_tensor.is_floating_point(): - raise NotImplementedError(f"{func} is not implemented for non floating point input") + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) # using try/except here so that we can have a general fallback when input_tensor/weight_tensor # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to @@ -1887,7 +2196,11 @@ def _(func, types, args, kwargs): return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` - if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor._layout, "quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is not None: + if ( + isinstance(weight_tensor, AffineQuantizedTensor) + and hasattr(weight_tensor._layout, "quantized_linear_impl") + and weight_tensor._layout.quantized_linear_impl is not None + ): raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -1896,22 +2209,25 @@ def _(func, types, args, kwargs): weight_tensor = weight_tensor.dequantize() return func(bias, input_tensor, weight_tensor) + @implements(aten.mm.default) def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - None - ) + input_tensor, weight_tensor, bias = (args[0], args[1], None) if not input_tensor.is_floating_point(): - raise NotImplementedError(f"{func} is not implemented for non floating point input") + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) try: weight_tensor = weight_tensor.t() return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: # fallback path is only called when user did not specify a specfic quantized linear implementation with `_layout.quantized_linear_impl` - if isinstance(weight_tensor, 
AffineQuantizedTensor) and hasattr(weight_tensor._layout, "quantized_linear_impl") and weight_tensor._layout.quantized_linear_impl is not None: + if ( + isinstance(weight_tensor, AffineQuantizedTensor) + and hasattr(weight_tensor._layout, "quantized_linear_impl") + and weight_tensor._layout.quantized_linear_impl is not None + ): raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -1920,6 +2236,7 @@ def _(func, types, args, kwargs): weight_tensor = weight_tensor.dequantize() return func(input_tensor, weight_tensor) + @implements(aten.detach.default) def _(func, types, args, kwargs): return return_and_correct_aliasing( @@ -1943,6 +2260,7 @@ def _(func, types, args, kwargs): args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), ) + @implements(aten.t.default) def _(func, types, args, kwargs): block_size = args[0].block_size @@ -1951,10 +2269,18 @@ def _(func, types, args, kwargs): tensor = args[0] shape = tensor.shape[::-1] new = tensor.__class__( - tensor.tensor_impl.t(), transposed_block_size, shape, tensor.quant_min, tensor.quant_max, tensor.zero_point_domain, dtype=tensor.dtype, strides=tensor.stride() + tensor.tensor_impl.t(), + transposed_block_size, + shape, + tensor.quant_min, + tensor.quant_max, + tensor.zero_point_domain, + dtype=tensor.dtype, + strides=tensor.stride(), ) return return_and_correct_aliasing(func, args, kwargs, new) + @implements(aten.slice.Tensor) def _(func, types, args, kwargs): self, dim, start, end, step = fill_defaults(args, 5, [0, None, None, 1]) @@ -1965,27 +2291,59 @@ def _(func, types, args, kwargs): shape = list(self.shape) shape[dim] = end - start block_size = self.block_size - assert len(block_size) == 2, f"Slice only works for 2d block_size right now, got: {block_size}" + assert ( + len(block_size) == 2 + ), f"Slice only works for 2d block_size right now, got: {block_size}" # with slice, some shape dimension might be smaller than block_size dimension, so # we need to make sure there is no overflow block_size = (min(shape[0], block_size[0]), min(shape[1], block_size[1])) - new = self.__class__(aten.slice.Tensor(self.tensor_impl, dim, start, end, step), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + new = self.__class__( + aten.slice.Tensor(self.tensor_impl, dim, start, end, step), + block_size, + shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + strides=self.stride(), + ) return return_and_correct_aliasing(func, args, kwargs, new) + # this is needed for DTensor.from_local() and for flattening tensor @implements(aten.view.default) def _(func, types, args, kwargs): self, shape = args if tuple(self.shape) == tuple(shape): - return self.__class__(self.tensor_impl, self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + return self.__class__( + self.tensor_impl, + self.block_size, + self.shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + strides=self.stride(), + ) if len(shape) == 1 and shape[0] == -1: assert len(self.block_size) == 2 and self.block_size[0] == 1 block_size = (self.block_size[1],) - return self.__class__(self.tensor_impl, block_size, (self.numel(),), self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()) + return self.__class__( + self.tensor_impl, + block_size, + (self.numel(),), + self.quant_min, + self.quant_max, + self.zero_point_domain, + dtype=self.dtype, + 
strides=self.stride(), + ) - raise ValueError(f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]") + raise ValueError( + f"{self.__class__.__name__} only supports .view() with same shape or shape=[-1]" + ) to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx diff --git a/torchao/dtypes/floatx/__init__.py b/torchao/dtypes/floatx/__init__.py index d7559015f4..c7864582fa 100644 --- a/torchao/dtypes/floatx/__init__.py +++ b/torchao/dtypes/floatx/__init__.py @@ -1 +1,15 @@ -from .floatx import FloatxTensorCoreLayout, FloatxTensorCoreAQTTensorImpl, to_scaled_tc_floatx, from_scaled_tc_floatx, _SPLIT_K_MAP +from .floatx import ( + FloatxTensorCoreLayout, + FloatxTensorCoreAQTTensorImpl, + to_scaled_tc_floatx, + from_scaled_tc_floatx, + _SPLIT_K_MAP, +) + +__all__ = [ + "FloatxTensorCoreLayout", + "FloatxTensorCoreAQTTensorImpl", + "to_scaled_tc_floatx", + "from_scaled_tc_floatx", + "_SPLIT_K_MAP", +] diff --git a/torchao/dtypes/floatx/floatx.py b/torchao/dtypes/floatx/floatx.py index a4745e9315..a9c4d46917 100644 --- a/torchao/dtypes/floatx/floatx.py +++ b/torchao/dtypes/floatx/floatx.py @@ -4,11 +4,14 @@ import torch from torch import Tensor from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.prototype.custom_fp_utils import _f32_to_floatx_unpacked, _floatx_unpacked_to_f32, _n_ones +from torchao.prototype.custom_fp_utils import ( + _f32_to_floatx_unpacked, + _floatx_unpacked_to_f32, + _n_ones, +) from torchao.dtypes.utils import ( Layout, ) -from torchao.quantization.quant_api import _get_linear_subclass_inserter from dataclasses import dataclass from torchao.dtypes.affine_quantized_tensor import AQTTensorImpl, register_layout @@ -18,11 +21,23 @@ def _pack(x: Tensor, n_bits: int) -> Tensor: - return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)]) + return reduce( + torch.bitwise_or, + [ + x[..., i :: (8 // n_bits)] << (8 - (i + 1) * n_bits) + for i in range(8 // n_bits) + ], + ) def _unpack(x: Tensor, n_bits: int) -> Tensor: - return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2) + return torch.stack( + [ + (x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) + for i in range(8 // n_bits) + ], + dim=-1, + ).flatten(-2) # https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116 @@ -36,8 +51,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: if not undo: bit_order = { - 1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31, - 0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30], + 1: [ + 1, + 5, + 9, + 13, + 17, + 21, + 25, + 29, + 3, + 7, + 11, + 15, + 19, + 23, + 27, + 31, + 0, + 4, + 8, + 12, + 16, + 20, + 24, + 28, + 2, + 6, + 10, + 14, + 18, + 22, + 26, + 30, + ], 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14], 4: [1, 5, 3, 7, 0, 4, 2, 6], }[n_bits] @@ -46,8 +93,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: # this is inverse of the above, obtained by running # [v.index(i) for i in range(len(v))] bit_order = { - 1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11, - 20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15], + 1: [ + 16, + 0, + 24, + 8, + 17, + 1, + 25, + 9, + 18, + 2, + 26, + 10, + 19, + 3, + 27, + 11, + 20, + 4, + 28, + 12, + 21, + 5, + 29, + 13, + 22, + 6, + 30, + 14, + 23, + 7, + 31, + 15, + ], 2: [8, 0, 
12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7], 4: [4, 0, 6, 2, 5, 1, 7, 3], }[n_bits] @@ -83,8 +162,12 @@ def _pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask tensor_ybit = _pack(tensor_ybit, y) - tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) # Pass 2 from original code - tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y) # Pass 3 from original code + tensor_ybit = ( + tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) + ) # Pass 2 from original code + tensor_ybit = _bit_interleave( + tensor_ybit.flatten(), y + ) # Pass 3 from original code fragments.append(tensor_ybit) used_bits += y @@ -118,7 +201,9 @@ def pack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: return _pack_tc_floatx(tensor, nbits) -def to_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Tensor]: +def to_scaled_tc_floatx( + tensor: Tensor, ebits: int, mbits: int +) -> Tuple[Tensor, Tensor]: # _n_ones() is not compatible with torch.compile() due to << operator # https://github.com/pytorch/pytorch/issues/119152 # exp_bias = _n_ones(ebits - 1) @@ -126,7 +211,9 @@ def to_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, # workaround: global lookup table exp_bias = _ONES_TABLE[ebits - 1] - max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits)) + max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * ( + _ONES_TABLE[mbits + 1] / (2**mbits) + ) dtype = tensor.dtype tensor = tensor.float() @@ -153,8 +240,10 @@ def _unpack_tc_floatx(tensor: Tensor, nbits: int) -> Tensor: tensor_ybit = tensor[offset : offset + size_ybit] offset += size_ybit - tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 - tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) # undo Pass 2 + tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 + tensor_ybit = ( + tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) + ) # undo Pass 2 tensor_ybit = _unpack(tensor_ybit.flatten(), y) tensor_ybit = tensor_ybit << (nbits - used_bits - y) @@ -225,7 +314,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 5, 14336: 7, 28672: 7, - 57344: 7 + 57344: 7, }, { # tokens: [65:128] 3072: 9, @@ -236,7 +325,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 5, 14336: 7, 28672: 7, - 57344: 6 + 57344: 6, }, { # tokens: [129:192] 3072: 6, @@ -247,7 +336,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 5, 14336: 5, 28672: 5, - 57344: 4 + 57344: 4, }, { # tokens: [193:256] 3072: 9, @@ -258,7 +347,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 4, 14336: 8, 28672: 6, - 57344: 4 + 57344: 4, }, { # tokens: [257:320] 3072: 7, @@ -269,7 +358,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 1, 14336: 3, 28672: 3, - 57344: 4 + 57344: 4, }, { # tokens: [321:384] 3072: 3, @@ -280,7 +369,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 8, 14336: 3, 28672: 4, - 57344: 3 + 57344: 3, }, { # tokens: [385:448] 3072: 5, @@ -291,7 +380,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 3, 14336: 1, 28672: 1, - 57344: 3 + 57344: 3, }, { # tokens: [449:512] 3072: 2, @@ -302,7 +391,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 2, 14336: 
6, 28672: 4, - 57344: 1 + 57344: 1, }, { # tokens: [513:576] 3072: 2, @@ -313,7 +402,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 3, 14336: 3, 28672: 1, - 57344: 1 + 57344: 1, }, { # tokens: [577:640] 3072: 5, @@ -324,7 +413,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 1, 14336: 1, 28672: 1, - 57344: 1 + 57344: 1, }, { # tokens: [641:704] 3072: 3, @@ -335,7 +424,7 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 2, 14336: 1, 28672: 1, - 57344: 1 + 57344: 1, }, { # tokens: [705:768] 3072: 3, @@ -346,20 +435,22 @@ def from_scaled_tc_floatx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> 10240: 1, 14336: 1, 28672: 1, - 57344: 1 - } + 57344: 1, + }, ] # quantization api integrations + @dataclass(frozen=True) class FloatxTensorCoreLayout(Layout): - """Layout type for FloatxTensorCoreAQTTensorImpl - """ + """Layout type for FloatxTensorCoreAQTTensorImpl""" + ebits: int mbits: int + @register_layout(FloatxTensorCoreLayout) class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): """FloatxTensorCoreAQTTensorImpl represents a Tensor with dtype floatx(ebits=a, mbits=b), @@ -383,6 +474,7 @@ class FloatxTensorCoreAQTTensorImpl(AQTTensorImpl): it will then pack the weight and instantiate the FloatxTensorCoreAQTTensorImpl tensor FloatxTensorCoreAQTTensorImpl.__init__() takes a packed floatx Tensor of shape (M, N // 8 * nbit) """ + def __new__( cls, packed_floatx_data: torch.Tensor, @@ -391,11 +483,16 @@ def __new__( ): assert packed_floatx_data.ndim == 2 assert packed_floatx_data.dtype == torch.uint8 - shape = (packed_floatx_data.shape[0], packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8) + shape = ( + packed_floatx_data.shape[0], + packed_floatx_data.shape[1] // (1 + _layout.ebits + _layout.mbits) * 8, + ) kwargs = {} kwargs["device"] = packed_floatx_data.device kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else packed_floatx_data.layout + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_floatx_data.layout ) kwargs["dtype"] = packed_floatx_data.dtype kwargs["requires_grad"] = False @@ -418,12 +515,17 @@ def __tensor_flatten__(self): def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - packed_floatx_data, scale = tensor_data_dict["packed_floatx_data"], tensor_data_dict["scale"] - _layout, = tensor_attributes + packed_floatx_data, scale = ( + tensor_data_dict["packed_floatx_data"], + tensor_data_dict["scale"], + ) + (_layout,) = tensor_attributes return cls(packed_floatx_data, scale, _layout) def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: - unpacked_floatx_data = unpack_tc_floatx(self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits) + unpacked_floatx_data = unpack_tc_floatx( + self.packed_floatx_data, 1 + self._layout.ebits + self._layout.mbits + ) return unpacked_floatx_data, self.scale @classmethod @@ -442,7 +544,9 @@ def from_plain( bit, M is mantissa bit """ assert isinstance(_layout, FloatxTensorCoreLayout) - packed_floatx_data = pack_tc_floatx(unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits) + packed_floatx_data = pack_tc_floatx( + unpacked_floatx_data, 1 + _layout.ebits + _layout.mbits + ) return cls(packed_floatx_data, scale, _layout) def __repr__(self): @@ -480,7 +584,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) elif func is aten._to_copy.default: return return_and_correct_aliasing( - func, 
args, kwargs, args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))), + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: x.to(device=kwargs.pop("device", None)) + ), ) raise NotImplementedError( diff --git a/torchao/dtypes/uint4.py b/torchao/dtypes/uint4.py index fc6eb2646c..14bd0bd3ae 100644 --- a/torchao/dtypes/uint4.py +++ b/torchao/dtypes/uint4.py @@ -105,7 +105,6 @@ def __new__(cls, elem, **kwargs): ) def __init__(self, elem, **kwargs): - self.elem = elem @classmethod diff --git a/torchao/dtypes/uintx/__init__.py b/torchao/dtypes/uintx/__init__.py index c44803f6d2..068bce8686 100644 --- a/torchao/dtypes/uintx/__init__.py +++ b/torchao/dtypes/uintx/__init__.py @@ -1 +1,15 @@ -from .uintx import UintxTensor, UintxLayout, UintxAQTTensorImpl, to_uintx, _DTYPE_TO_BIT_WIDTH +from .uintx import ( + UintxTensor, + UintxLayout, + UintxAQTTensorImpl, + to_uintx, + _DTYPE_TO_BIT_WIDTH, +) + +__all__ = [ + "UintxTensor", + "UintxLayout", + "UintxAQTTensorImpl", + "to_uintx", + "_DTYPE_TO_BIT_WIDTH", +] diff --git a/torchao/dtypes/uintx/bitpacking.py b/torchao/dtypes/uintx/bitpacking.py index 244ca437ef..e7f2218f92 100644 --- a/torchao/dtypes/uintx/bitpacking.py +++ b/torchao/dtypes/uintx/bitpacking.py @@ -7,16 +7,16 @@ 1: (0x01,), 2: (0x03,), 3: (0x03, 0x04), - 4: (0x0f,), - 5: (0x0f, 0x10), - 6: (0x0f, 0x30), - 7: (0x0f, 0x30, 0x40), + 4: (0x0F,), + 5: (0x0F, 0x10), + 6: (0x0F, 0x30), + 7: (0x0F, 0x30, 0x40), } unpack_mask = { - 1: (0x01,0x02,0x04,0x08, 0x10,0x20,0x40,0x80), - 2: (0x03,0x0c,0x30,0xc0), - 4: (0x0f,0xf0), + 1: (0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80), + 2: (0x03, 0x0C, 0x30, 0xC0), + 4: (0x0F, 0xF0), } # size of each shard @@ -41,6 +41,7 @@ 7: (0, 4, 6), } + # for shifting groups left but right if shift is negative def abs_lsh(data, shift): if shift == 0: @@ -61,9 +62,9 @@ def abs_rsh(data, shift): return data >> shift -def pack_cpu(data: torch.Tensor, - elem_size: int, - dim: Optional[int] = -1) -> List[torch.Tensor]: +def pack_cpu( + data: torch.Tensor, elem_size: int, dim: Optional[int] = -1 +) -> List[torch.Tensor]: """ Inputs: data: a tensor of sub byte elements in uint8 @@ -111,7 +112,10 @@ def pack_cpu(data: torch.Tensor, After pack, data went from 8 elements to 6: [[0, 105, 151, 37], [39, 146]] In general this means pack reduces input tensor size from n * 8 to n * elem_size """ - torch._assert(data.shape[dim] % 8 == 0, f"pack dimension size ({data.shape[dim]}) is not divisble by scale") + torch._assert( + data.shape[dim] % 8 == 0, + f"pack dimension size ({data.shape[dim]}) is not divisble by scale", + ) torch._assert(data.dtype == torch.uint8, "data must be uint8") output_shape = list(data.shape) @@ -131,9 +135,9 @@ def pack_cpu(data: torch.Tensor, return output -def unpack_cpu(data: List[torch.Tensor], - elem_size: int, - dim: Optional[int] = -1) -> torch.Tensor: +def unpack_cpu( + data: List[torch.Tensor], elem_size: int, dim: Optional[int] = -1 +) -> torch.Tensor: """ Unpacks small dtype elements from a larger dtype. 
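Note (illustrative sketch, not part of this diff): pack_cpu/unpack_cpu above generalize sub-byte bit packing to 1-7 bit elements along an arbitrary dimension. The same shift-and-mask idea for the simplest 4-bit case, with hypothetical helper names, looks roughly like this:

import torch

def pack_uint4_pairs(x: torch.Tensor) -> torch.Tensor:
    # x holds 4-bit values (0..15) stored in uint8; two values share one byte.
    assert x.dtype == torch.uint8 and x.shape[-1] % 2 == 0
    lo = x[..., 0::2] & 0x0F          # even positions -> low nibble
    hi = (x[..., 1::2] & 0x0F) << 4   # odd positions  -> high nibble
    return hi | lo                    # half as many bytes as input elements

def unpack_uint4_pairs(packed: torch.Tensor) -> torch.Tensor:
    lo = packed & 0x0F
    hi = (packed >> 4) & 0x0F
    return torch.stack([lo, hi], dim=-1).flatten(-2)  # restore original order

vals = torch.randint(0, 16, (4, 8), dtype=torch.uint8)
assert torch.equal(unpack_uint4_pairs(pack_uint4_pairs(vals)), vals)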
@@ -159,31 +163,37 @@ def unpack_cpu(data: List[torch.Tensor], for j in range(scale): output_narrow = output.narrow(dim, j * group_size, group_size) group = data[i] & unpack_mask[bit_size][j] - shift_amt = j * bit_size - rel_pos - output_narrow.copy_(torch.bitwise_or(output_narrow, abs_rsh(group, j * bit_size - rel_pos))) + output_narrow.copy_( + torch.bitwise_or(output_narrow, abs_rsh(group, j * bit_size - rel_pos)) + ) return output + # these are faster on the GPU + def _pack(data, elem_size, scale, dim): - ''' + """ Inner for loop from above pack function - ''' + """ packed_shape = list(data.shape) packed_shape[dim] = packed_shape[dim] // scale packed = torch.zeros(packed_shape, dtype=data.dtype, device=data.device) for i in range(scale): - narrow_slice = data.narrow(dim, data.shape[dim]*i//scale, data.shape[dim] // scale) + narrow_slice = data.narrow( + dim, data.shape[dim] * i // scale, data.shape[dim] // scale + ) packed |= narrow_slice << (elem_size * i) return packed + def _unpack(data, element_size, scale, dim): - ''' + """ Inner for loop from above unpack function - ''' + """ unpacked_shape = list(data.shape) unpacked_shape[dim] *= scale @@ -193,30 +203,57 @@ def _unpack(data, element_size, scale, dim): for i in range(scale): shift_amt = element_size * i - chunk = unpacked_data.narrow(dim, unpacked_data.shape[dim]*i//scale, unpacked_data.shape[dim] // scale).copy_((data >> shift_amt) & nbits) + unpacked_data.narrow( + dim, + unpacked_data.shape[dim] * i // scale, + unpacked_data.shape[dim] // scale, + ).copy_((data >> shift_amt) & nbits) return unpacked_data -def pack(data: torch.Tensor, - elem_size: int, - dim: Optional[int] = -1) -> List[torch.Tensor]: - ''' +def pack( + data: torch.Tensor, elem_size: int, dim: Optional[int] = -1 +) -> List[torch.Tensor]: + """ a less branching but more compute version so better for gpu - ''' - torch._assert(data.shape[dim] % 8 == 0, f"pack dimension size ({data.shape[dim]}) is not divisble by scale") + """ + torch._assert( + data.shape[dim] % 8 == 0, + f"pack dimension size ({data.shape[dim]}) is not divisble by scale", + ) torch._assert(data.dtype == torch.uint8, "data must be uint8") container_size = 8 - shards = [(data & maskbits[elem_size][i]) >> shifts[elem_size][i] for i in range(len(maskbits[elem_size]))] - return tuple([_pack(shards[i], numbits[elem_size][i], container_size//numbits[elem_size][i], dim) for i in range(len(maskbits[elem_size]))]) - -def unpack(data: List[torch.Tensor], - elem_size: int, - dim: Optional[int] = 0) -> torch.Tensor: - ''' + shards = [ + (data & maskbits[elem_size][i]) >> shifts[elem_size][i] + for i in range(len(maskbits[elem_size])) + ] + return tuple( + [ + _pack( + shards[i], + numbits[elem_size][i], + container_size // numbits[elem_size][i], + dim, + ) + for i in range(len(maskbits[elem_size])) + ] + ) + + +def unpack( + data: List[torch.Tensor], elem_size: int, dim: Optional[int] = 0 +) -> torch.Tensor: + """ a less branching but more compute version so better for gpu - ''' + """ container_size = 8 # unpack each 4,2,1 bit shard and unshift them back to the correct position - data = [_unpack(data[i], numbits[elem_size][i], container_size // numbits[elem_size][i], dim) << shifts[elem_size][i] for i in range(len(data))] + data = [ + _unpack( + data[i], numbits[elem_size][i], container_size // numbits[elem_size][i], dim + ) + << shifts[elem_size][i] + for i in range(len(data)) + ] return reduce(torch.bitwise_or, data) diff --git a/torchao/dtypes/uintx/uintx.py b/torchao/dtypes/uintx/uintx.py index 
a48faee8dc..ae66a0ce93 100644 --- a/torchao/dtypes/uintx/uintx.py +++ b/torchao/dtypes/uintx/uintx.py @@ -43,6 +43,7 @@ class UintxTensor(TorchAOBaseTensor): bit_width (int): number of bits for each element pack_dim: (int) dimension to pack along """ + bits_to_shard = { 1: ["int1_shard"], 2: ["int2_shard"], @@ -52,6 +53,7 @@ class UintxTensor(TorchAOBaseTensor): 6: ["int4_shard", "int2_shard"], 7: ["int4_shard", "int2_shard", "int1_shard"], } + def __new__( cls, shards: List[torch.Tensor], @@ -81,24 +83,28 @@ def __init__( self.pack_dim = pack_dim def get_shards(self): - return [getattr(self,i) for i in self.__class__.bits_to_shard[self.bit_width]] + return [getattr(self, i) for i in self.__class__.bits_to_shard[self.bit_width]] def __repr__(self): return f"Int{self.bit_width}Tensor(shape = {self.packed_shape}, data = {unpack(self.get_shards(), self.bit_width, dim = self.pack_dim)})" def __tensor_flatten__(self): - return self.__class__.bits_to_shard[self.bit_width], [self.packed_shape, self.bit_width, self.pack_dim] + return self.__class__.bits_to_shard[self.bit_width], [ + self.packed_shape, + self.bit_width, + self.pack_dim, + ] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - shards = list(tensor_data_dict.values()) + shards = list(tensor_data_dict.values()) packed_shape, bit_width, pack_dim = tensor_attributes return cls(shards, packed_shape, bit_width, pack_dim) def get_plain(self): - return unpack(self.get_shards(), self.bit_width, dim = self.pack_dim) + return unpack(self.get_shards(), self.bit_width, dim=self.pack_dim) # temporary until kernels on packed tensors are created def apply_transformation(self, fn): @@ -110,18 +116,21 @@ def apply_transformation(self, fn): # temporary until kernels on packed tensors are created def apply_fn_to_shards(self, fn): new_shards = [fn(shard) for shard in self.get_shards()] - return self.__class__(new_shards, self.packed_shape, self.bit_width, self.pack_dim) + return self.__class__( + new_shards, self.packed_shape, self.bit_width, self.pack_dim + ) @classmethod def from_uint8(cls, int_data: torch.Tensor, dtype: torch.dtype, pack_dim: int = -1): - assert dtype in _DTYPE_TO_BIT_WIDTH.keys(), "Expected dtype to be one of {_DTYPE_TO_BIT_WIDTH.keys()}" + assert ( + dtype in _DTYPE_TO_BIT_WIDTH.keys() + ), "Expected dtype to be one of {_DTYPE_TO_BIT_WIDTH.keys()}" bit_width = _DTYPE_TO_BIT_WIDTH[dtype] shards = pack(int_data, bit_width, dim=pack_dim) shape = list(int_data.shape) shape[pack_dim] = shape[pack_dim] * bit_width // 8 return cls(shards, int_data.shape, bit_width, pack_dim) - def _get_to_kwargs(self, *args, **kwargs): device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) device = self.device if device is None else device @@ -150,42 +159,52 @@ def to(self, *args, **kwargs): return super().to(*args, **kwargs) - implements = UintxTensor.implements + @implements(aten.detach.default) def _(func, types, args, kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0].apply_fn_to_shards(torch.detach) ) + @implements(aten.view.default) def _(func, types, args, kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0].apply_transformation(lambda x: x.view(*args[1:])) ) + @implements(aten._to_copy.default) def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0] - ) + return return_and_correct_aliasing(func, args, kwargs, args[0]) + @implements(aten.sub.Tensor) def _(func, types, args, kwargs): 
return return_and_correct_aliasing( - func, args, kwargs, args[0].apply_transformation(lambda x: (x - args[1]).to(torch.uint8)) + func, + args, + kwargs, + args[0].apply_transformation(lambda x: (x - args[1]).to(torch.uint8)), ) + @implements(aten.mul.Tensor) def _(func, types, args, kwargs): return return_and_correct_aliasing( - func, args, kwargs, args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8)) + func, + args, + kwargs, + args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8)), ) + # quantization api integrations to_uintx = UintxTensor.from_uint8 + @dataclass(frozen=True) class UintxLayout(Layout): dtype: torch.dtype @@ -194,9 +213,9 @@ class UintxLayout(Layout): def post_process(self, input: torch.Tensor) -> torch.Tensor: return to_uintx(input, self.dtype, self.pack_dim) + @register_layout(UintxLayout) class UintxAQTTensorImpl(PlainAQTTensorImpl): - def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return self.int_data.get_plain(), self.scale, self.zero_point diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py index 7c0dfd9dc8..b0df07cb19 100644 --- a/torchao/dtypes/utils.py +++ b/torchao/dtypes/utils.py @@ -20,6 +20,8 @@ behaviors when running the same operator, e.g. transpose, quantized_linear. This is the same as layout in PyTorch native Tensor """ + + @dataclass(frozen=True) class Layout: def pre_process(self, input: torch.Tensor) -> torch.Tensor: @@ -28,7 +30,13 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: def post_process(self, input: torch.Tensor) -> torch.Tensor: return input - def pre_process_static(self, input: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor, block_size: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + def pre_process_static( + self, + input: torch.Tensor, + scale: torch.Tensor, + zero_point: torch.Tensor, + block_size: Tuple[int, ...], + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return self.pre_process(input), scale, zero_point def __repr__(self): @@ -37,16 +45,21 @@ def __repr__(self): def extra_repr(self) -> str: return "" + """ Plain Layout, the most basic Layout, also has no extra metadata, will typically be the default """ + + @dataclass(frozen=True) class PlainLayout(Layout): pass + def is_device(target_device_str: str, device: Union[str, torch.device]): return torch.device(device).type == target_device_str + def get_out_shape(input_shape: Tuple[int], weight_shape: Tuple[int]) -> Tuple[int, int]: """Returns the unflattened shape of the input tensor. Args:
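Note (illustrative sketch, not part of this diff): Layout in torchao/dtypes/utils.py is a frozen dataclass whose pre_process/post_process hooks let a concrete layout transform a tensor before and after quantization, as UintxLayout does by packing in post_process. A hypothetical layout that pads the innermost dimension in pre_process could be sketched as follows (PadToMultipleLayout is an invented name, not a torchao API):

from dataclasses import dataclass

import torch
import torch.nn.functional as F

from torchao.dtypes.utils import Layout

@dataclass(frozen=True)
class PadToMultipleLayout(Layout):  # hypothetical example layout
    multiple: int = 8

    def pre_process(self, input: torch.Tensor) -> torch.Tensor:
        # pad the last dimension so downstream packed kernels see aligned shapes
        pad = (-input.shape[-1]) % self.multiple
        return F.pad(input, (0, pad)) if pad else input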