diff --git a/ruff.toml b/ruff.toml index 04c9e32cca..9545a9be96 100644 --- a/ruff.toml +++ b/ruff.toml @@ -2,6 +2,8 @@ # We plan to add files in chunks using the 'include' list below. # To add a new path: Simply add it to the 'include' list. # Example: To lint all files in every subfolder of 'test', add "test/**/*" +# To exclude a path: Simply add it to the 'exclude' list below. +# Example: To exclude all markdown files, add "**/*.md" include = [ "torchao/float8/inference.py", "torchao/float8/float8_utils.py", @@ -10,4 +12,9 @@ include = [ "torchao/float8/float8_tensor.py", "torchao/quantization/linear_activation_weight_observer.py", "test/quantization/test_observer.py", + "torchao/dtypes/*" ] + +exclude = [ + "**/*.md" ] diff --git a/torchao/dtypes/__init__.py b/torchao/dtypes/__init__.py index e27bf6497a..8d65fafaca 100644 --- a/torchao/dtypes/__init__.py +++ b/torchao/dtypes/__init__.py @@ -1,4 +1,5 @@ from .nf4tensor import NF4Tensor, to_nf4 + # from ..prototype.dtypes.uint2 import UInt2Tensor, BitnetTensor from .uint4 import UInt4Tensor from .affine_quantized_tensor import ( @@ -21,7 +22,7 @@ __all__ = [ "NF4Tensor", "to_nf4", - "UInt4Tensor" + "UInt4Tensor", "AffineQuantizedTensor", "to_affine_quantized_intx", "to_affine_quantized_intx_static", diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py index c6a3730859..19decc4dbb 100644 --- a/torchao/dtypes/affine_quantized_tensor.py +++ b/torchao/dtypes/affine_quantized_tensor.py @@ -1,7 +1,6 @@ import torch from typing import Tuple, Optional, Union -from collections import defaultdict -import functools +import torchao.ops import math from torchao.quantization.quant_primitives import ( choose_qparams_affine, @@ -32,13 +31,13 @@ find_multiple, TorchAOBaseTensor, TORCH_VERSION_AT_LEAST_2_5, - _is_float8_type + _is_float8_type, ) import logging +from torchao.float8.inference import Float8MMConfig logger = logging.getLogger(__name__) -from torchao.float8.inference import Float8MMConfig aten = torch.ops.aten @@ -49,6 +48,7 @@ class AQTLayout(TorchAOBaseTensor): """ Base class for the layout tensor for `AffineQuantizedTensor` """ + def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Get the plain (unpacked) Tensor for the layout Tensor Returns a tuple of plain tensors of int_data, scale, zero_point @@ -68,7 +68,7 @@ def from_plain( zero_point: torch.Tensor, layout_type: LayoutType, ): - """ Construct a Layout from data, scale, zero_point and the layout_type""" + """Construct a Layout from data, scale, zero_point and the layout_type""" pass def __repr__(self): @@ -83,11 +83,14 @@ def __repr__(self): class QuantizedLinearNotImplementedError(NotImplementedError): - """ Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table """ + """Thin wrapper around NotImplementedError to make it easier to catch this error in the dispatch table""" + pass _AQT_QLINEAR_DISPATCH_TABLE = {} + + def register_aqt_quantized_linear_dispatch(dispatch_condition, impl): """Register a dispatch for quantized linear op with dispatch_condition function and impl function both takes three arguments: @@ -104,11 +107,15 @@ def register_aqt_quantized_linear_dispatch(dispatch_condition, impl): """ _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] = impl + def deregister_aqt_quantized_linear_dispatch(dispatch_condition): if dispatch_condition in _AQT_QLINEAR_DISPATCH_TABLE: del _AQT_QLINEAR_DISPATCH_TABLE[dispatch_condition] else: - logger.warn(f"Attempting to remove non-existant dispatch condition 
{dispatch_condition}") + logger.warn( + f"Attempting to remove non-existant dispatch condition {dispatch_condition}" + ) + class AffineQuantizedTensor(TorchAOBaseTensor): """ @@ -155,7 +162,9 @@ def __new__( kwargs = {} kwargs["device"] = layout_tensor.device kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else layout_tensor.layout + kwargs.get("layout") + if kwargs.get("layout", False) + else layout_tensor.layout ) kwargs["dtype"] = dtype if strides is not None: @@ -194,9 +203,16 @@ def dequantize(self, output_dtype: Optional[torch.dtype] = None) -> torch.Tensor output_dtype = self.dtype from torchao.dtypes.fpx import FpxTensorCoreLayoutType + if isinstance(self.layout_type, FpxTensorCoreLayoutType): int_data, scale = self.layout_tensor.get_plain() - return dequantize_affine_fpx(int_data, scale, self.layout_type.ebits, self.layout_type.mbits, output_dtype=output_dtype) + return dequantize_affine_fpx( + int_data, + scale, + self.layout_type.ebits, + self.layout_type.mbits, + output_dtype=output_dtype, + ) else: data, scale, zero_point = self.layout_tensor.get_plain() return dequantize_affine( @@ -216,17 +232,28 @@ def _quantized_linear_op(input_tensor, weight_tensor, bias): for dispatch_condition, impl in _AQT_QLINEAR_DISPATCH_TABLE.items(): if dispatch_condition(input_tensor, weight_tensor, bias): return impl(input_tensor, weight_tensor, bias) - raise QuantizedLinearNotImplementedError("No specialized dispatch found for quantized linear op") + raise QuantizedLinearNotImplementedError( + "No specialized dispatch found for quantized linear op" + ) def __tensor_flatten__(self): - return ["layout_tensor"], [self.block_size, self.shape, self.quant_min, self.quant_max, self.zero_point_domain, self.dtype] + return ["layout_tensor"], [ + self.block_size, + self.shape, + self.quant_min, + self.quant_max, + self.zero_point_domain, + self.dtype, + ] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): layout_tensor = tensor_data_dict["layout_tensor"] - block_size, shape, quant_min, quant_max, zero_point_domain, dtype = tensor_attributes + block_size, shape, quant_min, quant_max, zero_point_domain, dtype = ( + tensor_attributes + ) return cls( layout_tensor, block_size, @@ -259,20 +286,58 @@ def from_hp_to_intx( input_float = layout_type.pre_process(input_float) if use_hqq: - assert zero_point_domain == ZeroPointDomain.FLOAT and mapping_type == MappingType.ASYMMETRIC and quant_min==0, "Invalid input parameters for HQQ quantization." + assert ( + zero_point_domain == ZeroPointDomain.FLOAT + and mapping_type == MappingType.ASYMMETRIC + and quant_min == 0 + ), "Invalid input parameters for HQQ quantization." 
nbits = int(math.log2(quant_max + 1)) - axis = 1 if (block_size[0]==1) else 0 + axis = 1 if (block_size[0] == 1) else 0 group_size = max(block_size) - compute_dtype = zero_point_dtype if (zero_point_dtype is not None) else input_float.dtype + compute_dtype = ( + zero_point_dtype + if (zero_point_dtype is not None) + else input_float.dtype + ) device = input_float.device - data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False) + data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq( + input_float, + nbits=nbits, + group_size=group_size, + axis=axis, + compute_dtype=compute_dtype, + device=device, + verbose=False, + raw_output=False, + ) data = data.to(target_dtype) else: - scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain) + scale, zero_point = choose_qparams_affine( + input_float, + mapping_type, + block_size, + target_dtype, + quant_min, + quant_max, + eps, + scale_dtype, + zero_point_dtype, + preserve_zero, + zero_point_domain, + ) # choose_qparams_affine is a custom op that does support returning optional Tensors. We thus set the zero_point to None if its domain is None if zero_point_domain is None: zero_point = None - data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) + data = quantize_affine( + input_float, + block_size, + scale, + zero_point, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) # Note: output will be uint8 tensor for sub byte tensors for now data = layout_type.post_process(data) @@ -285,7 +350,7 @@ def from_hp_to_intx( quant_min, quant_max, zero_point_domain, - dtype=input_float.dtype + dtype=input_float.dtype, ) @classmethod @@ -302,12 +367,25 @@ def from_hp_to_intx_static( layout_type: LayoutType = PlainLayoutType(), ): if target_dtype not in FP8_TYPES: - assert zero_point_domain is not None, "zero_point_domain must be specified for non-fp8 types" - assert zero_point is not None, "zero_point must be specified for non-fp8 types" + assert ( + zero_point_domain is not None + ), "zero_point_domain must be specified for non-fp8 types" + assert ( + zero_point is not None + ), "zero_point must be specified for non-fp8 types" original_shape = input_float.shape input_float = layout_type.pre_process(input_float) - int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max, zero_point_domain) + int_data = quantize_affine( + input_float, + block_size, + scale, + zero_point, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) int_data = layout_type.post_process(int_data) @@ -332,7 +410,6 @@ def from_hp_to_floatx( scale_dtype: Optional[torch.dtype], layout_type: LayoutType, ): - if target_dtype in FP8_TYPES: return cls.from_hp_to_intx( input_float=input_float, @@ -350,7 +427,9 @@ def from_hp_to_floatx( use_hqq=False, ) else: - raise NotImplementedError(f"Unsupported dtype {target_dtype} for from_hp_to_floatx") + raise NotImplementedError( + f"Unsupported dtype {target_dtype} for from_hp_to_floatx" + ) @classmethod def from_hp_to_floatx_static( @@ -361,7 +440,6 @@ def from_hp_to_floatx_static( target_dtype: torch.dtype, layout_type: LayoutType, ): - if target_dtype in FP8_TYPES: return cls.from_hp_to_intx_static( input_float=input_float, @@ 
-375,7 +453,9 @@ def from_hp_to_floatx_static( layout_type=layout_type, ) else: - raise NotImplementedError(f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static") + raise NotImplementedError( + f"Unsupported dtype {target_dtype} for from_hp_to_floatx_static" + ) @classmethod def from_hp_to_fpx( @@ -384,7 +464,10 @@ def from_hp_to_fpx( layout_type: LayoutType, ): from torchao.dtypes.fpx import FpxTensorCoreLayoutType - assert isinstance(layout_type, FpxTensorCoreLayoutType), f"Only FpxTensorCoreLayoutType is supported for fpx, got {layout_type}" + + assert isinstance( + layout_type, FpxTensorCoreLayoutType + ), f"Only FpxTensorCoreLayoutType is supported for fpx, got {layout_type}" original_shape = input_float.shape input_float = layout_type.pre_process(input_float) # per axis quantization, where axis = 1 @@ -399,12 +482,7 @@ def from_hp_to_fpx( layout_tensor_ctr = get_layout_tensor_constructor(type(layout_type)) layout_tensor = layout_tensor_ctr(fpx_packed, scale, None, layout_type) - return cls( - layout_tensor, - block_size, - original_shape, - dtype=input_float.dtype - ) + return cls(layout_tensor, block_size, original_shape, dtype=input_float.dtype) @property def layout_type(self) -> LayoutType: @@ -456,9 +534,9 @@ def _apply_fn_to_data(self, fn): register_layout_cls = AffineQuantizedTensor.register_layout_cls get_layout_tensor_constructor = AffineQuantizedTensor.get_layout_tensor_constructor + @dataclass(frozen=True) class SemiSparseLayoutType(LayoutType): - def pre_process(self, input: torch.Tensor) -> torch.Tensor: # prune to 2:4 if not already temp = input.detach() @@ -492,11 +570,10 @@ class Float8LayoutType(LayoutType): @dataclass(frozen=True) class MarlinSparseLayoutType(LayoutType): - def pre_process(self, input: torch.Tensor) -> torch.Tensor: """Preprocess the input tensor to be in the correct format for the Marlin sparse kernel. 
- 1º: the input tensor is transposed since the linear layer keeps the weights in a transposed format - - 2º: tensor is injected with 2:4 sparsity + - 2º: tensor is injected with 2:4 sparsity - 3º: transposes it again because the quantization process will compute the scales for dim=-1 Args: @@ -506,6 +583,7 @@ def pre_process(self, input: torch.Tensor) -> torch.Tensor: torch.Tensor: the preprocessed tensor """ from torchao.sparsity.marlin import inject_24 # avoid circular import + input_t = input.t() w_24, _ = inject_24(input_t, *input_t.shape) return w_24.t() @@ -522,6 +600,7 @@ class PlainAQTLayout(AQTLayout): scale (torch.Tensor): the scale Tensor used to map between floating point tensor to quantized tensor zero_point (torch.Tensor): the zero_point Tensor used to map between floating point tensor to quantized tensor """ + def __new__( cls, int_data: torch.Tensor, @@ -558,8 +637,12 @@ def __tensor_flatten__(self): def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"] - layout_type, = tensor_attributes + int_data, scale, zero_point = ( + tensor_data_dict["int_data"], + tensor_data_dict["scale"], + tensor_data_dict["zero_point"], + ) + (layout_type,) = tensor_attributes return cls(int_data, scale, zero_point, layout_type) def to(self, *args, **kwargs): @@ -596,7 +679,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs): if func is aten.t.default: tensor = args[0] new = tensor.__class__( - tensor.int_data.view(tensor.shape[::-1]), tensor.scale, tensor.zero_point, tensor.layout_type + tensor.int_data.view(tensor.shape[::-1]), + tensor.scale, + tensor.zero_point, + tensor.layout_type, ) return return_and_correct_aliasing(func, args, kwargs, new) @@ -623,11 +709,13 @@ def from_plain( assert isinstance(layout_type, PlainLayoutType) return cls(int_data, scale, zero_point, layout_type) + @register_layout_cls(SemiSparseLayoutType) class SemiSparseAQTLayout(PlainAQTLayout): """ Layout storage class for semi_sparse_cusparselt layout for affine quantized tensor """ + @classmethod def __torch_dispatch__(cls, func, types, args, kwargs): kwargs = {} if kwargs is None else kwargs @@ -645,10 +733,10 @@ def get_plain(self): # Currently we don't have cuSPARSELt expansion routines, so we matmul by # the identity matrix to get the original dense matrix. This is slow though. cols = self.int_data.numel() * 16 // (10 * self.scale.shape[0]) - int_data_expanded = torch._cslt_sparse_mm(self.int_data, - torch.eye(cols, - dtype=self.int_data.dtype, - device=self.int_data.device).t()) + int_data_expanded = torch._cslt_sparse_mm( + self.int_data, + torch.eye(cols, dtype=self.int_data.dtype, device=self.int_data.device).t(), + ) return int_data_expanded, self.scale, self.zero_point @classmethod @@ -667,8 +755,8 @@ def from_plain( @register_layout_cls(MarlinSparseLayoutType) class MarlinSparseAQTLayout(AQTLayout): """ - Layout storage class for sparse_marlin_24 layout for affine quantized tensor. - + Layout storage class for sparse_marlin_24 layout for affine quantized tensor. + Can be used with 4 bits and 8 bits quantization. 
Original marlin documentation and information: @@ -682,6 +770,7 @@ class MarlinSparseAQTLayout(AQTLayout): group_size (int): the group size used to pack the tensor num_bits (int): the number of bits used to quantize the tensor """ + @staticmethod def __new__( cls, @@ -738,7 +827,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) def __tensor_flatten__(self): - return ["int_data", "scale", "zero_point", "meta"], [self.layout_type, self.original_shape, self.group_size, self.num_bits] + return ["int_data", "scale", "zero_point", "meta"], [ + self.layout_type, + self.original_shape, + self.group_size, + self.num_bits, + ] @classmethod def __tensor_unflatten__( @@ -749,14 +843,26 @@ def __tensor_unflatten__( zero_point = tensor_data_dict["zero_point"] meta = tensor_data_dict["meta"] layout_type, original_shape, group_size, num_bits = tensor_attributes - return cls(int_data, scale, zero_point, meta, layout_type, original_shape, group_size, num_bits) + return cls( + int_data, + scale, + zero_point, + meta, + layout_type, + original_shape, + group_size, + num_bits, + ) def get_plain(self): - from torchao.sparsity.marlin import unpack_from_marlin_24 # avoid circular import + from torchao.sparsity.marlin import ( + unpack_from_marlin_24, + ) # avoid circular import + int_data_expanded, scales_expanded = unpack_from_marlin_24( - self.int_data, - self.scale, - self.meta, + self.int_data, + self.scale, + self.meta, self.original_shape, self.group_size, self.num_bits, @@ -773,7 +879,11 @@ def from_plain( zero_point: torch.Tensor, layout_type: LayoutType, ): - from torchao.sparsity.marlin import pack_to_marlin_24, const # avoid circular import + from torchao.sparsity.marlin import ( + pack_to_marlin_24, + const, + ) # avoid circular import + assert isinstance(layout_type, MarlinSparseLayoutType) # Linear layers are (in_features, out_features) but the int_data that is reaching this point @@ -783,12 +893,12 @@ def from_plain( if not torch.cuda.get_device_capability()[0] >= 8: raise ValueError( - f'Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel.' + f"Can not use Sparse Marlin 2:4 int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel." ) if q_w_24.dtype != torch.int32: raise ValueError("Only `torch.int32` weights are supported.") - + in_features, out_features = q_w_24.shape if in_features % 128 != 0 or out_features != 256 == 0: raise ValueError( @@ -800,14 +910,14 @@ def from_plain( # Check the link for a reference: https://github.com/neuralmagic/nm-vllm/tree/main num_bits = 4 if torch.max(q_w_24) < 16 else -1 if num_bits not in [4]: - raise ValueError( - f"Only {[4]} bits are supported, got {num_bits}." - ) + raise ValueError(f"Only {[4]} bits are supported, got {num_bits}.") group_size = in_features // scale_t.shape[0] if group_size == 0: group_size = in_features - assert group_size <= in_features, "Group size must be less than or equal to in_features." + assert ( + group_size <= in_features + ), "Group size must be less than or equal to in_features." 
if group_size not in const.SUPPORTED_GROUP_SIZES: raise ValueError( @@ -815,14 +925,21 @@ def from_plain( ) # Compress quantized weight to marlin 2:4 format - marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24(q_w_24, scale_t, num_bits, group_size) + marlin_24_q_w_comp, marlin_24_s, meta = pack_to_marlin_24( + q_w_24, scale_t, num_bits, group_size + ) return cls( - marlin_24_q_w_comp, marlin_24_s, zero_point, - meta, layout_type, q_w_24.shape, - group_size, num_bits + marlin_24_q_w_comp, + marlin_24_s, + zero_point, + meta, + layout_type, + q_w_24.shape, + group_size, + num_bits, ) - + def get_layout_type(self) -> LayoutType: return self.layout_type @@ -839,6 +956,7 @@ class Float8AQTLayout(AQTLayout): """ Layout storage class for float8 layout for affine quantized tensor """ + float8_data: torch.Tensor scale: torch.Tensor transposed: bool @@ -873,7 +991,7 @@ def __init__( self.layout_type = layout_type def _apply_fn_to_data(self, fn): - """ Applys a fn to all tensor components stored on this class""" + """Applys a fn to all tensor components stored on this class""" fn(self.float8_data) fn(self.scale) return self @@ -895,7 +1013,10 @@ def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): float8_data, scale = tensor_data_dict["float8_data"], tensor_data_dict["scale"] - transposed, layout_type, = tensor_attributes + ( + transposed, + layout_type, + ) = tensor_attributes return cls(float8_data, scale, transposed, layout_type) @classmethod @@ -937,20 +1058,26 @@ def from_plain( zero_point: Optional[torch.Tensor], layout_type: LayoutType, ): - """ Main entrypoint for constructing Float8Layout Tensor""" - assert _is_float8_type(data.dtype), f"Float8 Layout must be constructed from float8 dtype but got {data.dtype}" - assert isinstance(layout_type, Float8LayoutType), f"Float8 Layout must be constructed from Float8LayoutType but got {layout_type}" + """Main entrypoint for constructing Float8Layout Tensor""" + assert _is_float8_type( + data.dtype + ), f"Float8 Layout must be constructed from float8 dtype but got {data.dtype}" + assert isinstance( + layout_type, Float8LayoutType + ), f"Float8 Layout must be constructed from Float8LayoutType but got {layout_type}" return cls(data, scale, False, layout_type) def __repr__(self): float8_data, scale, _ = self.get_plain() layout_type = self.get_layout_type() - return (f"{self.__class__.__name__}(\n" - f"float8_data={float8_data},\n" - f"scale={scale},\n" - f"transposed={self.transposed}, " - f"layout_type={layout_type})") - + return ( + f"{self.__class__.__name__}(\n" + f"float8_data={float8_data},\n" + f"scale={scale},\n" + f"transposed={self.transposed}, " + f"layout_type={layout_type})" + ) + @register_layout_cls(TensorCoreTiledLayoutType) class TensorCoreTiledAQTLayout(AQTLayout): @@ -974,7 +1101,9 @@ def __new__( kwargs = {} kwargs["device"] = packed_weight.device kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else packed_weight.layout + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_weight.layout ) kwargs["dtype"] = packed_weight.dtype kwargs["requires_grad"] = False @@ -1000,8 +1129,14 @@ def __tensor_flatten__(self): def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"] - transposed, layout_type, = tensor_attributes + packed_weight, scale_and_zero = ( + tensor_data_dict["packed_weight"], + 
tensor_data_dict["scale_and_zero"], + ) + ( + transposed, + layout_type, + ) = tensor_attributes return cls(packed_weight, scale_and_zero, transposed, layout_type) @classmethod @@ -1010,20 +1145,24 @@ def from_plain( int_data: torch.Tensor, scale: torch.Tensor, zero_point: Optional[torch.Tensor], - layout_type: LayoutType + layout_type: LayoutType, ): - assert isinstance(layout_type, TensorCoreTiledLayoutType) if TORCH_VERSION_AT_LEAST_2_5: int_data = (int_data[::, ::2] << 4 | int_data[::, 1::2]).to(torch.uint8) - assert int_data.dtype == torch.uint8, "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" + assert ( + int_data.dtype == torch.uint8 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.5 expects `uint8` dtype" else: - assert int_data.dtype == torch.int32, "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" - packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data, layout_type.inner_k_tiles) + assert ( + int_data.dtype == torch.int32 + ), "torch.ops.aten._convert_weight_to_int4pack in torch 2.4 expects `int32` dtype" + packed_weight = torch.ops.aten._convert_weight_to_int4pack( + int_data, layout_type.inner_k_tiles + ) scale = scale.reshape(int_data.shape[0], -1) zero_point = zero_point.reshape(int_data.shape[0], -1) - from torchao.quantization.utils import pack_tinygemm_scales_and_zeros scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point) return cls(packed_weight, scale_and_zero, False, layout_type) @@ -1031,7 +1170,9 @@ def to(self, *args, **kwargs): kwargs = self._get_to_kwargs(*args, **kwargs) device = kwargs["device"] if not is_device("cuda", device): - raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device, can't convert to {device}") + raise ValueError( + f"TensorCoreTiledAQTLayout is only available for cuda device, can't convert to {device}" + ) return self.__class__( self.packed_weight.to(device), self.scale_and_zero.to(device), @@ -1077,6 +1218,7 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: quantize_affine, ) from torchao.quantization.utils import unpack_tinygemm_scales_and_zeros + scale, zero = unpack_tinygemm_scales_and_zeros(self.scale_and_zero) cur_shape = self.shape @@ -1093,12 +1235,26 @@ def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: quant_max = 15 zero_point_domain = ZeroPointDomain.FLOAT assert len(block_size) == 2 and block_size[0] == 1 - dequantized = torch.ops.aten._weight_int4pack_mm(torch.eye(eye_shape, device=device, dtype=original_dtype), self.packed_weight, groupsize, self.scale_and_zero) + dequantized = torch.ops.aten._weight_int4pack_mm( + torch.eye(eye_shape, device=device, dtype=original_dtype), + self.packed_weight, + groupsize, + self.scale_and_zero, + ) dequantized = dequantized.t().contiguous() # TODO: move this to `unpack_tinygemm_scales_and_zeros`? 
scale = scale.reshape(scale.shape[:-1]).contiguous() zero = zero.reshape(zero.shape[:-1]).contiguous() - int_data = quantize_affine(dequantized, block_size, scale, zero, target_dtype, quant_min, quant_max, zero_point_domain) + int_data = quantize_affine( + dequantized, + block_size, + scale, + zero, + target_dtype, + quant_min, + quant_max, + zero_point_domain, + ) return int_data, scale, zero def get_layout_type(self) -> LayoutType: @@ -1109,28 +1265,36 @@ def get_layout_type(self) -> LayoutType: # torch functional and aten operator implementation # ##################################################### + def _aqt_is_int8(aqt): """Check if an AffineQuantizedTensor is int8 quantized Tensor""" return ( - aqt.layout_tensor.dtype == torch.int8 and - aqt.quant_min is None or aqt.quant_min == -128 and - aqt.quant_max is None or aqt.quant_max == 127 + aqt.layout_tensor.dtype == torch.int8 + and (aqt.quant_min is None or aqt.quant_min == -128) + and (aqt.quant_max is None or aqt.quant_max == 127) ) + def _aqt_is_int8_reduced_range(aqt): return ( - aqt.layout_tensor.dtype == torch.int8 and - aqt.quant_min == -127 and - aqt.quant_max is None or aqt.quant_max == 127 + aqt.layout_tensor.dtype == torch.int8 + and aqt.quant_min == -127 + and (aqt.quant_max is None or aqt.quant_max == 127) ) + def _aqt_is_uint4(aqt): """Check if an AffineQuantizedTensor is uint4 quantized Tensor""" # TODO: use torch.uint4 return ( - aqt.layout_tensor.dtype == torch.int32 and - aqt.quant_min is None or aqt.quant_min == 0 and - aqt.quant_max is None or aqt.quant_max == 15 + aqt.layout_tensor.dtype == torch.int32 + and (aqt.quant_min is None or aqt.quant_min == 0) + and (aqt.quant_max is None or aqt.quant_max == 15) ) @@ -1142,17 +1306,19 @@ def _aqt_is_uint4(aqt): # input_tensor: dimension is (batch_size, in_features) # weight_tensor: dimension is (out_features, in_features) # bias: dimension is (out_features,) # so that these can be shared by F.linear, aten.mm, aten.addmm dispatches + def _linear_int8_act_int8_weight_check(input_tensor, weight_tensor, bias): return ( - isinstance(input_tensor, AffineQuantizedTensor) and - _aqt_is_int8_reduced_range(input_tensor) and - isinstance(weight_tensor, AffineQuantizedTensor) and - weight_tensor.is_cuda and - input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor.layout_type, PlainLayoutType) and - isinstance(weight_tensor.layout_type, PlainLayoutType) + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.is_cuda + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor.layout_type, PlainLayoutType) + and isinstance(weight_tensor.layout_type, PlainLayoutType) ) + def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): # # 1. 
do the matrix form of dot(X_i, W_j) @@ -1184,18 +1350,23 @@ def _linear_int8_act_int8_weight_impl(input_tensor, weight_tensor, bias): return y -def _linear_int8_act_int8_weight_semi_structured_sparse_check(input_tensor, weight_tensor, bias): +def _linear_int8_act_int8_weight_semi_structured_sparse_check( + input_tensor, weight_tensor, bias +): return ( - isinstance(input_tensor, AffineQuantizedTensor) and - _aqt_is_int8_reduced_range(input_tensor) and - isinstance(weight_tensor, AffineQuantizedTensor) and - weight_tensor.is_cuda and - input_tensor.dtype == weight_tensor.dtype and - isinstance(input_tensor.layout_type, PlainLayoutType) and - isinstance(weight_tensor.layout_type, SemiSparseLayoutType) + isinstance(input_tensor, AffineQuantizedTensor) + and _aqt_is_int8_reduced_range(input_tensor) + and isinstance(weight_tensor, AffineQuantizedTensor) + and weight_tensor.is_cuda + and input_tensor.dtype == weight_tensor.dtype + and isinstance(input_tensor.layout_type, PlainLayoutType) + and isinstance(weight_tensor.layout_type, SemiSparseLayoutType) ) -def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weight_tensor, bias): + +def _linear_int8_act_int8_weight_semi_structured_sparse_impl( + input_tensor, weight_tensor, bias +): x_vals_int8 = input_tensor.layout_tensor.int_data x_scales = input_tensor.layout_tensor.scale w_vals_int8 = weight_tensor.layout_tensor.int_data @@ -1203,7 +1374,10 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh tmp = x_vals_int8.reshape(-1, x_vals_int8.shape[-1]) # we fuse one of the scalar matrix multiplications (w_scales) into the sparse mm y_dot_bf16_w_scales_fused = torch._cslt_sparse_mm( - w_vals_int8, tmp.t(), alpha=w_scales.to(torch.float32), out_dtype=torch.bfloat16, + w_vals_int8, + tmp.t(), + alpha=w_scales.to(torch.float32), + out_dtype=torch.bfloat16, ).t() y = (y_dot_bf16_w_scales_fused * x_scales.reshape(-1, 1)).reshape( *x_vals_int8.shape[:-1], y_dot_bf16_w_scales_fused.shape[-1] @@ -1215,23 +1389,27 @@ def _linear_int8_act_int8_weight_semi_structured_sparse_impl(input_tensor, weigh y += bias return y + def _linear_bf16_act_uint4_weight_check(input_tensor, weight_tensor, bias): return ( # input is native bfloat16 tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.dtype == torch.bfloat16 and + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.dtype == torch.bfloat16 + and # weight is uint4, group quantized tensor_core_tiled layout affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - _aqt_is_uint4(weight_tensor) and - weight_tensor.dtype == torch.bfloat16 and - len(weight_tensor.shape) == 2 and - weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT and - isinstance(weight_tensor.layout_type, TensorCoreTiledLayoutType) + isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_uint4(weight_tensor) + and weight_tensor.dtype == torch.bfloat16 + and len(weight_tensor.shape) == 2 + and weight_tensor.zero_point_domain == ZeroPointDomain.FLOAT + and isinstance(weight_tensor.layout_type, TensorCoreTiledLayoutType) ) def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): - assert weight_tensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {block_size}" + assert ( + weight_tensor.block_size[0] == 1 + ), f"Requires groupwise quantization, got block_size: {weight_tensor.block_size}" assert input_tensor.shape[-1] == weight_tensor.shape[1], ( f"need input_tensor shape: 
{input_tensor.shape} final" f"dim to match weight_tensor shape: {weight_tensor.shape} second dim " @@ -1255,7 +1433,9 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): # groupwise int4 quantization groupsize = weight_tensor.block_size[1] - y = torch.ops.aten._weight_int4pack_mm(act_mat.contiguous(), packed_weight, groupsize, scale_and_zero) + y = torch.ops.aten._weight_int4pack_mm( + act_mat.contiguous(), packed_weight, groupsize, scale_and_zero + ) # remove out_feature padding orig_out_features = weight_tensor.shape[-2] @@ -1270,19 +1450,21 @@ def _linear_bf16_act_uint4_weight_impl(input_tensor, weight_tensor, bias): def _linear_fp_act_int8_weight_check(input_tensor, weight_tensor, bias): return ( # input is native float tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.is_floating_point() and + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.is_floating_point() + and # weight is int8 per channel quantized affine quantized tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - _aqt_is_int8(weight_tensor) and - len(weight_tensor.shape) == 2 and - len(weight_tensor.block_size) == 2 and - weight_tensor.block_size[0] == 1 and - weight_tensor.block_size[1] == weight_tensor.shape[1] and - weight_tensor.zero_point_domain == ZeroPointDomain.INT and - isinstance(weight_tensor.layout_type, PlainLayoutType) + isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_int8(weight_tensor) + and len(weight_tensor.shape) == 2 + and len(weight_tensor.block_size) == 2 + and weight_tensor.block_size[0] == 1 + and weight_tensor.block_size[1] == weight_tensor.shape[1] + and weight_tensor.zero_point_domain == ZeroPointDomain.INT + and isinstance(weight_tensor.layout_type, PlainLayoutType) ) + def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): # TODO: enable cpu and mps efficient path # is_cpu and is_mps only, some issue with is_contiguous() currently @@ -1291,7 +1473,6 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): # per channel int8 weight only quantizated mm w_vals_int8_t = weight_tensor.layout_tensor.int_data.t() scale = weight_tensor.layout_tensor.scale - orig_dtype = input_tensor.dtype m = torch.mm( input_tensor.reshape(-1, input_tensor.shape[-1]), w_vals_int8_t.to(input_tensor.dtype), @@ -1302,30 +1483,43 @@ def _linear_fp_act_int8_weight_impl(input_tensor, weight_tensor, bias): y += bias.to(m.dtype) return y + def _linear_f16_act_fpx_weight_check(input_tensor, weight_tensor, bias): from torchao.dtypes.fpx import FpxTensorCoreLayoutType + return ( # input is native float32 tensor - not is_traceable_wrapper_subclass(input_tensor) and - input_tensor.is_floating_point() and - input_tensor.dtype == torch.float16 and + not is_traceable_wrapper_subclass(input_tensor) + and input_tensor.is_floating_point() + and input_tensor.dtype == torch.float16 + and # weight is fpx Tensor - isinstance(weight_tensor, AffineQuantizedTensor) and - isinstance(weight_tensor.layout_type, FpxTensorCoreLayoutType) and - ( + isinstance(weight_tensor, AffineQuantizedTensor) + and isinstance(weight_tensor.layout_type, FpxTensorCoreLayoutType) + and ( # weight is using fp6 quantization - (weight_tensor.layout_type.ebits == 3 and - weight_tensor.layout_type.mbits == 2) or - (weight_tensor.layout_type.ebits == 2 and - weight_tensor.layout_type.mbits == 3) or + ( + weight_tensor.layout_type.ebits == 3 + and weight_tensor.layout_type.mbits == 2 + ) + or ( + weight_tensor.layout_type.ebits == 2 + and 
weight_tensor.layout_type.mbits == 3 + ) + or # weight is using fp5 quantization - (weight_tensor.layout_type.ebits == 2 and - weight_tensor.layout_type.mbits == 2) or - (weight_tensor.layout_type.ebits == 3 and - weight_tensor.layout_type.mbits == 1) + ( + weight_tensor.layout_type.ebits == 2 + and weight_tensor.layout_type.mbits == 2 + ) + or ( + weight_tensor.layout_type.ebits == 3 + and weight_tensor.layout_type.mbits == 1 + ) ) ) + def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias): from torchao.dtypes.fpx import _SPLIT_K_MAP from torchao.ops import quant_llm_linear @@ -1354,6 +1548,7 @@ def _linear_f16_act_fpx_weight_impl(input_tensor, weight_tensor, bias): return out.view(*act.shape[:-1], out_dim).to(act.dtype) + def _linear_fp_act_fp8_tensor_wise_weight_check( input_tensor: Union[torch.Tensor, AffineQuantizedTensor], weight_tensor: Union[torch.Tensor, AffineQuantizedTensor], @@ -1361,11 +1556,12 @@ def _linear_fp_act_fp8_tensor_wise_weight_check( ) -> bool: def check_aqt_tensorwise(aqt: Union[torch.Tensor, AffineQuantizedTensor]) -> bool: return ( - isinstance(aqt, AffineQuantizedTensor) and - isinstance(aqt.layout_type, Float8LayoutType) + isinstance(aqt, AffineQuantizedTensor) + and isinstance(aqt.layout_type, Float8LayoutType) and aqt.layout_tensor.dtype in [torch.float8_e4m3fn, torch.float8_e5m2] and aqt.shape == aqt.block_size ) + return check_aqt_tensorwise(input_tensor) and check_aqt_tensorwise(weight_tensor) @@ -1382,7 +1578,9 @@ def _linear_fp_act_fp8_weight_impl( ) scaled_mm_config = weight_tensor.layout_type.mm_config - scaled_mm_config = scaled_mm_config if scaled_mm_config is not None else Float8MMConfig() + scaled_mm_config = ( + scaled_mm_config if scaled_mm_config is not None else Float8MMConfig() + ) w_layout = weight_tensor.layout_tensor w_data = weight_tensor.layout_tensor.float8_data @@ -1414,17 +1612,17 @@ def _linear_fp_act_fp8_weight_impl( def _linear_fp_act_int4_weight_sparse_marlin_check(input_tensor, weight_tensor, bias): return ( - isinstance(weight_tensor, AffineQuantizedTensor) and - _aqt_is_uint4(weight_tensor) and - input_tensor.dtype == torch.float16 and - len(weight_tensor.shape) == 2 and - weight_tensor.zero_point_domain == ZeroPointDomain.INT and - isinstance(weight_tensor.layout_type, MarlinSparseLayoutType) + isinstance(weight_tensor, AffineQuantizedTensor) + and _aqt_is_uint4(weight_tensor) + and input_tensor.dtype == torch.float16 + and len(weight_tensor.shape) == 2 + and weight_tensor.zero_point_domain == ZeroPointDomain.INT + and isinstance(weight_tensor.layout_type, MarlinSparseLayoutType) ) + def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, bias): - from torchao.sparsity.marlin import marlin_24_workspace, const - from torchao.ops import marlin_24_gemm + from torchao.sparsity.marlin import marlin_24_workspace assert isinstance(weight_tensor, AffineQuantizedTensor) @@ -1442,9 +1640,16 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b size_k = input_2d.shape[1] workspace_24 = marlin_24_workspace(original_shape[1]) - out = marlin_24_gemm( - input_2d, sparse_w_int4, meta, scale, - workspace_24, num_bits, size_m, size_n, size_k + out = torchao.ops.marlin_24_gemm( + input_2d, + sparse_w_int4, + meta, + scale, + workspace_24, + num_bits, + size_m, + size_n, + size_k, ) # Unfold the batch dimension @@ -1458,17 +1663,25 @@ def _linear_fp_act_int4_weight_sparse_marlin_impl(input_tensor, weight_tensor, b def _register_aqt_quantized_linear_dispatches(): for 
dispatch_condition, impl in [ (_linear_int8_act_int8_weight_check, _linear_int8_act_int8_weight_impl), - (_linear_int8_act_int8_weight_semi_structured_sparse_check, _linear_int8_act_int8_weight_semi_structured_sparse_impl), + ( + _linear_int8_act_int8_weight_semi_structured_sparse_check, + _linear_int8_act_int8_weight_semi_structured_sparse_impl, + ), (_linear_fp_act_fp8_tensor_wise_weight_check, _linear_fp_act_fp8_weight_impl), (_linear_bf16_act_uint4_weight_check, _linear_bf16_act_uint4_weight_impl), (_linear_fp_act_int8_weight_check, _linear_fp_act_int8_weight_impl), (_linear_f16_act_fpx_weight_check, _linear_f16_act_fpx_weight_impl), - (_linear_fp_act_int4_weight_sparse_marlin_check, _linear_fp_act_int4_weight_sparse_marlin_impl), + ( + _linear_fp_act_int4_weight_sparse_marlin_check, + _linear_fp_act_int4_weight_sparse_marlin_impl, + ), ]: register_aqt_quantized_linear_dispatch(dispatch_condition, impl) + _register_aqt_quantized_linear_dispatches() + @implements(torch.nn.functional.linear) def _(func, types, args, kwargs): input_tensor, weight_tensor, bias = ( @@ -1477,7 +1690,9 @@ def _(func, types, args, kwargs): args[2] if len(args) > 2 else None, ) if not input_tensor.is_floating_point(): - raise NotImplementedError(f"{func} is not implemented for non floating point input") + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) # using try/except here so that we can have a general fallback when input_tensor/weight_tensor # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to @@ -1486,7 +1701,11 @@ def _(func, types, args, kwargs): return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: # fallback path is only called when user did not specify a specfic quantized linear implementation with `layout_type.quantized_linear_impl` - if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None: + if ( + isinstance(weight_tensor, AffineQuantizedTensor) + and hasattr(weight_tensor.layout_type, "quantized_linear_impl") + and weight_tensor.layout_type.quantized_linear_impl is not None + ): raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -1495,6 +1714,7 @@ def _(func, types, args, kwargs): weight_tensor = weight_tensor.dequantize() return torch.nn.functional.linear(input_tensor, weight_tensor, bias) + @implements(aten.addmm.default) def _(func, types, args, kwargs): input_tensor, weight_tensor, bias = ( @@ -1503,7 +1723,9 @@ def _(func, types, args, kwargs): args[0], ) if not input_tensor.is_floating_point(): - raise NotImplementedError(f"{func} is not implemented for non floating point input") + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) # using try/except here so that we can have a general fallback when input_tensor/weight_tensor # is not picked up by any of the dispatch paths in `_quantized_linear_op`, this allows us to @@ -1513,7 +1735,11 @@ def _(func, types, args, kwargs): return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: # fallback path is only called when user did not specify a specfic quantized linear implementation with `layout_type.quantized_linear_impl` - if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and 
weight_tensor.layout_type.quantized_linear_impl is not None: + if ( + isinstance(weight_tensor, AffineQuantizedTensor) + and hasattr(weight_tensor.layout_type, "quantized_linear_impl") + and weight_tensor.layout_type.quantized_linear_impl is not None + ): raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -1522,22 +1748,25 @@ def _(func, types, args, kwargs): weight_tensor = weight_tensor.dequantize() return func(bias, input_tensor, weight_tensor) + @implements(aten.mm.default) def _(func, types, args, kwargs): - input_tensor, weight_tensor, bias = ( - args[0], - args[1], - None - ) + input_tensor, weight_tensor, bias = (args[0], args[1], None) if not input_tensor.is_floating_point(): - raise NotImplementedError(f"{func} is not implemented for non floating point input") + raise NotImplementedError( + f"{func} is not implemented for non floating point input" + ) try: weight_tensor = weight_tensor.t() return weight_tensor._quantized_linear_op(input_tensor, weight_tensor, bias) except QuantizedLinearNotImplementedError as e: # fallback path is only called when user did not specify a specfic quantized linear implementation with `layout_type.quantized_linear_impl` - if isinstance(weight_tensor, AffineQuantizedTensor) and hasattr(weight_tensor.layout_type, "quantized_linear_impl") and weight_tensor.layout_type.quantized_linear_impl is not None: + if ( + isinstance(weight_tensor, AffineQuantizedTensor) + and hasattr(weight_tensor.layout_type, "quantized_linear_impl") + and weight_tensor.layout_type.quantized_linear_impl is not None + ): raise e if isinstance(input_tensor, AffineQuantizedTensor): @@ -1546,6 +1775,7 @@ def _(func, types, args, kwargs): weight_tensor = weight_tensor.dequantize() return func(input_tensor, weight_tensor) + @implements(aten.detach.default) def _(func, types, args, kwargs): return return_and_correct_aliasing( @@ -1569,6 +1799,7 @@ def _(func, types, args, kwargs): args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone), ) + @implements(aten.t.default) def _(func, types, args, kwargs): block_size = args[0].block_size @@ -1577,10 +1808,18 @@ def _(func, types, args, kwargs): tensor = args[0] shape = tensor.shape[::-1] new = tensor.__class__( - tensor.layout_tensor.t(), transposed_block_size, shape, tensor.quant_min, tensor.quant_max, tensor.zero_point_domain, dtype=tensor.dtype, strides=tensor.stride() + tensor.layout_tensor.t(), + transposed_block_size, + shape, + tensor.quant_min, + tensor.quant_max, + tensor.zero_point_domain, + dtype=tensor.dtype, + strides=tensor.stride(), ) return return_and_correct_aliasing(func, args, kwargs, new) + to_affine_quantized_intx = AffineQuantizedTensor.from_hp_to_intx to_affine_quantized_intx_static = AffineQuantizedTensor.from_hp_to_intx_static to_affine_quantized_floatx = AffineQuantizedTensor.from_hp_to_floatx diff --git a/torchao/dtypes/fpx/__init__.py b/torchao/dtypes/fpx/__init__.py index af77685fac..a62eb48283 100644 --- a/torchao/dtypes/fpx/__init__.py +++ b/torchao/dtypes/fpx/__init__.py @@ -1 +1,15 @@ -from .fpx import FpxTensorCoreLayoutType, FpxTensorCoreAQTLayout, to_scaled_tc_fpx, from_scaled_tc_fpx, _SPLIT_K_MAP +from .fpx import ( + FpxTensorCoreLayoutType, + FpxTensorCoreAQTLayout, + to_scaled_tc_fpx, + from_scaled_tc_fpx, + _SPLIT_K_MAP, +) + +__all__ = [ + "FpxTensorCoreAQTLayout", + "FpxTensorCoreLayoutType", + "to_scaled_tc_fpx", + "from_scaled_tc_fpx", + "_SPLIT_K_MAP", +] diff --git a/torchao/dtypes/fpx/fpx.py b/torchao/dtypes/fpx/fpx.py index 6afa22f560..77064baefc 100644 --- 
a/torchao/dtypes/fpx/fpx.py +++ b/torchao/dtypes/fpx/fpx.py @@ -4,11 +4,14 @@ import torch from torch import Tensor from torch.utils._python_dispatch import return_and_correct_aliasing -from torchao.prototype.custom_fp_utils import _f32_to_fpx_unpacked, _fpx_unpacked_to_f32, _n_ones +from torchao.prototype.custom_fp_utils import ( + _f32_to_fpx_unpacked, + _fpx_unpacked_to_f32, + _n_ones, +) from torchao.dtypes.utils import ( LayoutType, ) -from torchao.quantization.quant_api import _get_linear_subclass_inserter from dataclasses import dataclass from torchao.dtypes.affine_quantized_tensor import AQTLayout, register_layout_cls @@ -18,11 +21,23 @@ def _pack(x: Tensor, n_bits: int) -> Tensor: - return reduce(torch.bitwise_or, [x[..., i::(8 // n_bits)] << (8 - (i + 1) * n_bits) for i in range(8 // n_bits)]) + return reduce( + torch.bitwise_or, + [ + x[..., i :: (8 // n_bits)] << (8 - (i + 1) * n_bits) + for i in range(8 // n_bits) + ], + ) def _unpack(x: Tensor, n_bits: int) -> Tensor: - return torch.stack([(x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) for i in range(8 // n_bits)], dim=-1).flatten(-2) + return torch.stack( + [ + (x >> (8 - (i + 1) * n_bits)) & ((1 << n_bits) - 1) + for i in range(8 // n_bits) + ], + dim=-1, + ).flatten(-2) # https://github.com/usyd-fsalab/fp6_llm/blob/5df6737cca32f604e957e3f63f03ccc2e4d1df0d/fp6_llm/csrc/utils/weight_prepacking.h#L87-L116 @@ -36,8 +51,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: if not undo: bit_order = { - 1: [1, 5, 9, 13, 17, 21, 25, 29, 3, 7, 11, 15, 19, 23, 27, 31, - 0, 4, 8, 12, 16, 20, 24, 28, 2, 6, 10, 14, 18, 22, 26, 30], + 1: [ + 1, + 5, + 9, + 13, + 17, + 21, + 25, + 29, + 3, + 7, + 11, + 15, + 19, + 23, + 27, + 31, + 0, + 4, + 8, + 12, + 16, + 20, + 24, + 28, + 2, + 6, + 10, + 14, + 18, + 22, + 26, + 30, + ], 2: [1, 5, 9, 13, 3, 7, 11, 15, 0, 4, 8, 12, 2, 6, 10, 14], 4: [1, 5, 3, 7, 0, 4, 2, 6], }[n_bits] @@ -46,8 +93,40 @@ def _bit_interleave(x: Tensor, n_bits: int, undo: bool = False) -> Tensor: # this is inverse of the above, obtained by running # [v.index(i) for i in range(len(v))] bit_order = { - 1: [16, 0, 24, 8, 17, 1, 25, 9, 18, 2, 26, 10, 19, 3, 27, 11, - 20, 4, 28, 12, 21, 5, 29, 13, 22, 6, 30, 14, 23, 7, 31, 15], + 1: [ + 16, + 0, + 24, + 8, + 17, + 1, + 25, + 9, + 18, + 2, + 26, + 10, + 19, + 3, + 27, + 11, + 20, + 4, + 28, + 12, + 21, + 5, + 29, + 13, + 22, + 6, + 30, + 14, + 23, + 7, + 31, + 15, + ], 2: [8, 0, 12, 4, 9, 1, 13, 5, 10, 2, 14, 6, 11, 3, 15, 7], 4: [4, 0, 6, 2, 5, 1, 7, 3], }[n_bits] @@ -83,8 +162,12 @@ def _pack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: tensor_ybit = (tensor >> (nbits - used_bits - y)) & mask tensor_ybit = _pack(tensor_ybit, y) - tensor_ybit = tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) # Pass 2 from original code - tensor_ybit = _bit_interleave(tensor_ybit.flatten(), y) # Pass 3 from original code + tensor_ybit = ( + tensor_ybit.view(32, -1, 4).permute(1, 0, 2).flip(2) + ) # Pass 2 from original code + tensor_ybit = _bit_interleave( + tensor_ybit.flatten(), y + ) # Pass 3 from original code fragments.append(tensor_ybit) used_bits += y @@ -126,7 +209,9 @@ def to_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int) -> Tuple[Tensor, Te # workaround: global lookup table exp_bias = _ONES_TABLE[ebits - 1] - max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * (_ONES_TABLE[mbits + 1] / (2 ** mbits)) + max_normal = 2 ** (_ONES_TABLE[ebits] - exp_bias) * ( + _ONES_TABLE[mbits + 1] / (2**mbits) + ) tensor = tensor.float() scale = 
tensor.abs().amax(1).clamp(min=1e-12) / max_normal @@ -152,8 +237,10 @@ def _unpack_tc_fpx(tensor: Tensor, nbits: int) -> Tensor: tensor_ybit = tensor[offset : offset + size_ybit] offset += size_ybit - tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 - tensor_ybit = tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) # undo Pass 2 + tensor_ybit = _bit_interleave(tensor_ybit, y, undo=True) # undo Pass 3 + tensor_ybit = ( + tensor_ybit.view(-1, 32, 4).flip(2).permute(1, 0, 2) + ) # undo Pass 2 tensor_ybit = _unpack(tensor_ybit.flatten(), y) tensor_ybit = tensor_ybit << (nbits - used_bits - y) @@ -224,7 +311,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 5, 14336: 7, 28672: 7, - 57344: 7 + 57344: 7, }, { # tokens: [65:128] 3072: 9, @@ -235,7 +322,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 5, 14336: 7, 28672: 7, - 57344: 6 + 57344: 6, }, { # tokens: [129:192] 3072: 6, @@ -246,7 +333,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 5, 14336: 5, 28672: 5, - 57344: 4 + 57344: 4, }, { # tokens: [193:256] 3072: 9, @@ -257,7 +344,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 4, 14336: 8, 28672: 6, - 57344: 4 + 57344: 4, }, { # tokens: [257:320] 3072: 7, @@ -268,7 +355,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 1, 14336: 3, 28672: 3, - 57344: 4 + 57344: 4, }, { # tokens: [321:384] 3072: 3, @@ -279,7 +366,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 8, 14336: 3, 28672: 4, - 57344: 3 + 57344: 3, }, { # tokens: [385:448] 3072: 5, @@ -290,7 +377,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 3, 14336: 1, 28672: 1, - 57344: 3 + 57344: 3, }, { # tokens: [449:512] 3072: 2, @@ -301,7 +388,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 2, 14336: 6, 28672: 4, - 57344: 1 + 57344: 1, }, { # tokens: [513:576] 3072: 2, @@ -312,7 +399,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 3, 14336: 3, 28672: 1, - 57344: 1 + 57344: 1, }, { # tokens: [577:640] 3072: 5, @@ -323,7 +410,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 1, 14336: 1, 28672: 1, - 57344: 1 + 57344: 1, }, { # tokens: [641:704] 3072: 3, @@ -334,7 +421,7 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 2, 14336: 1, 28672: 1, - 57344: 1 + 57344: 1, }, { # tokens: [705:768] 3072: 3, @@ -345,20 +432,22 @@ def from_scaled_tc_fpx(tensor: Tensor, ebits: int, mbits: int, scale=None) -> Te 10240: 1, 14336: 1, 28672: 1, - 57344: 1 - } + 57344: 1, + }, ] # quantization api integrations + @dataclass(frozen=True) class FpxTensorCoreLayoutType(LayoutType): - """Layout type for FpxTensorCoreAQTLayout - """ + """Layout type for FpxTensorCoreAQTLayout""" + ebits: int mbits: int + @register_layout_cls(FpxTensorCoreLayoutType) class FpxTensorCoreAQTLayout(AQTLayout): """FpxTensorCoreAQTLayout represents a Tensor with dtype fpx(ebits=a, mbits=b), @@ -382,6 +471,7 @@ class FpxTensorCoreAQTLayout(AQTLayout): it will then pack the weight and instantiate the FpxTensorCoreAQTLayout tensor FpxTensorCoreAQTLayout.__init__() takes a packed fpx Tensor of shape (M, N // 8 * nbit) """ + def __new__( cls, packed_fpx_data: torch.Tensor, @@ -390,11 +480,16 @@ def 
__new__( ): assert packed_fpx_data.ndim == 2 assert packed_fpx_data.dtype == torch.uint8 - shape = (packed_fpx_data.shape[0], packed_fpx_data.shape[1] // (1 + layout_type.ebits + layout_type.mbits) * 8) + shape = ( + packed_fpx_data.shape[0], + packed_fpx_data.shape[1] // (1 + layout_type.ebits + layout_type.mbits) * 8, + ) kwargs = {} kwargs["device"] = packed_fpx_data.device kwargs["layout"] = ( - kwargs.get("layout") if kwargs.get("layout", False) else packed_fpx_data.layout + kwargs.get("layout") + if kwargs.get("layout", False) + else packed_fpx_data.layout ) kwargs["dtype"] = packed_fpx_data.dtype kwargs["requires_grad"] = False @@ -417,12 +512,17 @@ def __tensor_flatten__(self): def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - packed_fpx_data, scale = tensor_data_dict["packed_fpx_data"], tensor_data_dict["scale"] - layout_type, = tensor_attributes + packed_fpx_data, scale = ( + tensor_data_dict["packed_fpx_data"], + tensor_data_dict["scale"], + ) + (layout_type,) = tensor_attributes return cls(packed_fpx_data, scale, layout_type) def get_plain(self) -> Tuple[torch.Tensor, torch.Tensor]: - unpacked_fpx_data = unpack_tc_fpx(self.packed_fpx_data, 1 + self.layout_type.ebits + self.layout_type.mbits) + unpacked_fpx_data = unpack_tc_fpx( + self.packed_fpx_data, 1 + self.layout_type.ebits + self.layout_type.mbits + ) return unpacked_fpx_data, self.scale @classmethod @@ -441,7 +541,9 @@ def from_plain( bit, M is mantissa bit """ assert isinstance(layout_type, FpxTensorCoreLayoutType) - packed_fpx_data = pack_tc_fpx(unpacked_fpx_data, 1 + layout_type.ebits + layout_type.mbits) + packed_fpx_data = pack_tc_fpx( + unpacked_fpx_data, 1 + layout_type.ebits + layout_type.mbits + ) return cls(packed_fpx_data, scale, layout_type) def __repr__(self): @@ -479,7 +581,12 @@ def __torch_dispatch__(cls, func, types, args, kwargs): ) elif func is aten._to_copy.default: return return_and_correct_aliasing( - func, args, kwargs, args[0]._apply_fn_to_data(lambda x: x.to(device=kwargs.pop("device", None))), + func, + args, + kwargs, + args[0]._apply_fn_to_data( + lambda x: x.to(device=kwargs.pop("device", None)) + ), ) raise NotImplementedError( diff --git a/torchao/dtypes/uint4.py b/torchao/dtypes/uint4.py index fc6eb2646c..14bd0bd3ae 100644 --- a/torchao/dtypes/uint4.py +++ b/torchao/dtypes/uint4.py @@ -105,7 +105,6 @@ def __new__(cls, elem, **kwargs): ) def __init__(self, elem, **kwargs): - self.elem = elem @classmethod diff --git a/torchao/dtypes/uintx/Uintx.py b/torchao/dtypes/uintx/Uintx.py index cfe75f4dc7..157f40f8da 100644 --- a/torchao/dtypes/uintx/Uintx.py +++ b/torchao/dtypes/uintx/Uintx.py @@ -43,6 +43,7 @@ class UintxTensor(TorchAOBaseTensor): bit_width (int): number of bits for each element pack_dim: (int) dimension to pack along """ + bits_to_shard = { 1: ["int1_shard"], 2: ["int2_shard"], @@ -52,6 +53,7 @@ class UintxTensor(TorchAOBaseTensor): 6: ["int4_shard", "int2_shard"], 7: ["int4_shard", "int2_shard", "int1_shard"], } + def __new__( cls, shards: List[torch.Tensor], @@ -81,24 +83,28 @@ def __init__( self.pack_dim = pack_dim def get_shards(self): - return [getattr(self,i) for i in self.__class__.bits_to_shard[self.bit_width]] + return [getattr(self, i) for i in self.__class__.bits_to_shard[self.bit_width]] def __repr__(self): return f"Int{self.bit_width}Tensor(shape = {self.packed_shape}, data = {unpack(self.get_shards(), self.bit_width, dim = self.pack_dim)})" def __tensor_flatten__(self): - return 
self.__class__.bits_to_shard[self.bit_width], [self.packed_shape, self.bit_width, self.pack_dim] + return self.__class__.bits_to_shard[self.bit_width], [ + self.packed_shape, + self.bit_width, + self.pack_dim, + ] @classmethod def __tensor_unflatten__( cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride ): - shards = list(tensor_data_dict.values()) + shards = list(tensor_data_dict.values()) packed_shape, bit_width, pack_dim = tensor_attributes return cls(shards, packed_shape, bit_width, pack_dim) def get_plain(self): - return unpack(self.get_shards(), self.bit_width, dim = self.pack_dim) + return unpack(self.get_shards(), self.bit_width, dim=self.pack_dim) # temporary until kernels on packed tensors are created def apply_transformation(self, fn): @@ -110,18 +116,21 @@ def apply_transformation(self, fn): # temporary until kernels on packed tensors are created def apply_fn_to_shards(self, fn): new_shards = [fn(shard) for shard in self.get_shards()] - return self.__class__(new_shards, self.packed_shape, self.bit_width, self.pack_dim) + return self.__class__( + new_shards, self.packed_shape, self.bit_width, self.pack_dim + ) @classmethod def from_uint8(cls, int_data: torch.Tensor, dtype: torch.dtype, pack_dim: int = -1): - assert dtype in _DTYPE_TO_BIT_WIDTH.keys(), "Expected dtype to be one of {_DTYPE_TO_BIT_WIDTH.keys()}" + assert ( + dtype in _DTYPE_TO_BIT_WIDTH.keys() + ), "Expected dtype to be one of {_DTYPE_TO_BIT_WIDTH.keys()}" bit_width = _DTYPE_TO_BIT_WIDTH[dtype] shards = pack(int_data, bit_width, dim=pack_dim) shape = list(int_data.shape) shape[pack_dim] = shape[pack_dim] * bit_width // 8 return cls(shards, int_data.shape, bit_width, pack_dim) - def _get_to_kwargs(self, *args, **kwargs): device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs) device = self.device if device is None else device @@ -150,42 +159,52 @@ def to(self, *args, **kwargs): return super().to(*args, **kwargs) - implements = UintxTensor.implements + @implements(aten.detach.default) def _(func, types, args, kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0].apply_fn_to_shards(torch.detach) ) + @implements(aten.view.default) def _(func, types, args, kwargs): return return_and_correct_aliasing( func, args, kwargs, args[0].apply_transformation(lambda x: x.view(*args[1:])) ) + @implements(aten._to_copy.default) def _(func, types, args, kwargs): - return return_and_correct_aliasing( - func, args, kwargs, args[0] - ) + return return_and_correct_aliasing(func, args, kwargs, args[0]) + @implements(aten.sub.Tensor) def _(func, types, args, kwargs): return return_and_correct_aliasing( - func, args, kwargs, args[0].apply_transformation(lambda x: (x - args[1]).to(torch.uint8)) + func, + args, + kwargs, + args[0].apply_transformation(lambda x: (x - args[1]).to(torch.uint8)), ) + @implements(aten.mul.Tensor) def _(func, types, args, kwargs): return return_and_correct_aliasing( - func, args, kwargs, args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8)) + func, + args, + kwargs, + args[0].apply_transformation(lambda x: (x * args[1]).to(torch.uint8)), ) + # quantization api integrations to_uintx = UintxTensor.from_uint8 + @dataclass(frozen=True) class UintxLayoutType(LayoutType): dtype: torch.dtype @@ -194,9 +213,9 @@ class UintxLayoutType(LayoutType): def post_process(self, input: torch.Tensor) -> torch.Tensor: return to_uintx(input, self.dtype, self.pack_dim) + @register_layout_cls(UintxLayoutType) class UintxAQTLayout(PlainAQTLayout): - def get_plain(self) -> 
diff --git a/torchao/dtypes/uintx/bitpacking.py b/torchao/dtypes/uintx/bitpacking.py
index 244ca437ef..5e0331b72c 100644
--- a/torchao/dtypes/uintx/bitpacking.py
+++ b/torchao/dtypes/uintx/bitpacking.py
@@ -7,16 +7,16 @@
     1: (0x01,),
     2: (0x03,),
     3: (0x03, 0x04),
-    4: (0x0f,),
-    5: (0x0f, 0x10),
-    6: (0x0f, 0x30),
-    7: (0x0f, 0x30, 0x40),
+    4: (0x0F,),
+    5: (0x0F, 0x10),
+    6: (0x0F, 0x30),
+    7: (0x0F, 0x30, 0x40),
 }
 
 unpack_mask = {
-    1: (0x01,0x02,0x04,0x08, 0x10,0x20,0x40,0x80),
-    2: (0x03,0x0c,0x30,0xc0),
-    4: (0x0f,0xf0),
+    1: (0x01, 0x02, 0x04, 0x08, 0x10, 0x20, 0x40, 0x80),
+    2: (0x03, 0x0C, 0x30, 0xC0),
+    4: (0x0F, 0xF0),
 }
 
 # size of each shard
@@ -41,6 +41,7 @@
     7: (0, 4, 6),
 }
 
+
 # for shifting groups left but right if shift is negative
 def abs_lsh(data, shift):
     if shift == 0:
@@ -61,9 +62,9 @@ def abs_rsh(data, shift):
         return data >> shift
 
 
-def pack_cpu(data: torch.Tensor,
-             elem_size: int,
-             dim: Optional[int] = -1) -> List[torch.Tensor]:
+def pack_cpu(
+    data: torch.Tensor, elem_size: int, dim: Optional[int] = -1
+) -> List[torch.Tensor]:
     """
     Inputs:
     data: a tensor of sub byte elements in uint8
@@ -111,7 +112,10 @@ def pack_cpu(data: torch.Tensor,
     After pack, data went from 8 elements to 6: [[0, 105, 151, 37], [39, 146]]
     In general this means pack reduces input tensor size from n * 8 to n * elem_size
     """
-    torch._assert(data.shape[dim] % 8 == 0, f"pack dimension size ({data.shape[dim]}) is not divisble by scale")
+    torch._assert(
+        data.shape[dim] % 8 == 0,
+        f"pack dimension size ({data.shape[dim]}) is not divisible by 8",
+    )
     torch._assert(data.dtype == torch.uint8, "data must be uint8")
 
     output_shape = list(data.shape)
@@ -131,9 +135,9 @@ def pack_cpu(data: torch.Tensor,
     return output
 
 
-def unpack_cpu(data: List[torch.Tensor],
-               elem_size: int,
-               dim: Optional[int] = -1) -> torch.Tensor:
+def unpack_cpu(
+    data: List[torch.Tensor], elem_size: int, dim: Optional[int] = -1
+) -> torch.Tensor:
     """
     Unpacks small dtype elements from a larger dtype.
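To make the mask tables above concrete (an illustration, not part of the patch): a 7-bit element is split into a 4-bit, a 2-bit, and a 1-bit shard using maskbits[7] = (0x0F, 0x30, 0x40) and shifts[7] = (0, 4, 6); each b-bit shard is then packed so that 8 // b of its elements share one byte.

# Worked example of the shard split performed for elem_size = 7.
value = 0b1011010             # 0x5A, one 7-bit element held in a uint8
low4 = (value & 0x0F) >> 0    # 0b1010 -> goes to the 4-bit shard
mid2 = (value & 0x30) >> 4    # 0b01   -> goes to the 2-bit shard
top1 = (value & 0x40) >> 6    # 0b1    -> goes to the 1-bit shard
# unpacking reverses this: shift each shard back into place and OR the pieces.
assert (low4 << 0) | (mid2 << 4) | (top1 << 6) == value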
@@ -160,30 +164,37 @@ def unpack_cpu(data: List[torch.Tensor],
             output_narrow = output.narrow(dim, j * group_size, group_size)
             group = data[i] & unpack_mask[bit_size][j]
             shift_amt = j * bit_size - rel_pos
-            output_narrow.copy_(torch.bitwise_or(output_narrow, abs_rsh(group, j * bit_size - rel_pos)))
+            output_narrow.copy_(
+                torch.bitwise_or(output_narrow, abs_rsh(group, shift_amt))
+            )
     return output
 
+
 # these are faster on the GPU
 
+
 def _pack(data, elem_size, scale, dim):
-    '''
+    """
     Inner for loop from above pack function
-    '''
+    """
     packed_shape = list(data.shape)
     packed_shape[dim] = packed_shape[dim] // scale
     packed = torch.zeros(packed_shape, dtype=data.dtype, device=data.device)
     for i in range(scale):
-        narrow_slice = data.narrow(dim, data.shape[dim]*i//scale, data.shape[dim] // scale)
+        narrow_slice = data.narrow(
+            dim, data.shape[dim] * i // scale, data.shape[dim] // scale
+        )
         packed |= narrow_slice << (elem_size * i)
     return packed
 
+
 def _unpack(data, element_size, scale, dim):
-    '''
+    """
     Inner for loop from above unpack function
-    '''
+    """
     unpacked_shape = list(data.shape)
     unpacked_shape[dim] *= scale
@@ -193,30 +204,57 @@ def _unpack(data, element_size, scale, dim):
     for i in range(scale):
         shift_amt = element_size * i
-        chunk = unpacked_data.narrow(dim, unpacked_data.shape[dim]*i//scale, unpacked_data.shape[dim] // scale).copy_((data >> shift_amt) & nbits)
+        unpacked_data.narrow(
+            dim,
+            unpacked_data.shape[dim] * i // scale,
+            unpacked_data.shape[dim] // scale,
+        ).copy_((data >> shift_amt) & nbits)
     return unpacked_data
 
 
-def pack(data: torch.Tensor,
-         elem_size: int,
-         dim: Optional[int] = -1) -> List[torch.Tensor]:
-    '''
+def pack(
+    data: torch.Tensor, elem_size: int, dim: Optional[int] = -1
+) -> List[torch.Tensor]:
+    """
     a less branching but more compute version so better for gpu
-    '''
-    torch._assert(data.shape[dim] % 8 == 0, f"pack dimension size ({data.shape[dim]}) is not divisble by scale")
+    """
+    torch._assert(
+        data.shape[dim] % 8 == 0,
+        f"pack dimension size ({data.shape[dim]}) is not divisible by 8",
+    )
     torch._assert(data.dtype == torch.uint8, "data must be uint8")
     container_size = 8
-    shards = [(data & maskbits[elem_size][i]) >> shifts[elem_size][i] for i in range(len(maskbits[elem_size]))]
-    return tuple([_pack(shards[i], numbits[elem_size][i], container_size//numbits[elem_size][i], dim) for i in range(len(maskbits[elem_size]))])
-
-def unpack(data: List[torch.Tensor],
-           elem_size: int,
-           dim: Optional[int] = 0) -> torch.Tensor:
-    '''
+    shards = [
+        (data & maskbits[elem_size][i]) >> shifts[elem_size][i]
+        for i in range(len(maskbits[elem_size]))
+    ]
+    return tuple(
+        [
+            _pack(
+                shards[i],
+                numbits[elem_size][i],
+                container_size // numbits[elem_size][i],
+                dim,
+            )
+            for i in range(len(maskbits[elem_size]))
+        ]
+    )
+
+
+def unpack(
+    data: List[torch.Tensor], elem_size: int, dim: Optional[int] = 0
+) -> torch.Tensor:
+    """
     a less branching but more compute version so better for gpu
-    '''
+    """
     container_size = 8
     # unpack each 4,2,1 bit shard and unshift them back to the correct position
-    data = [_unpack(data[i], numbits[elem_size][i], container_size // numbits[elem_size][i], dim) << shifts[elem_size][i] for i in range(len(data))]
+    data = [
+        _unpack(
+            data[i], numbits[elem_size][i], container_size // numbits[elem_size][i], dim
+        )
+        << shifts[elem_size][i]
+        for i in range(len(data))
+    ]
     return reduce(torch.bitwise_or, data)
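A minimal usage sketch for the pack/unpack pair above (not part of the patch). It assumes the module is importable as torchao.dtypes.uintx.bitpacking; note that pack defaults to dim=-1 while unpack defaults to dim=0, so the same dim should be passed explicitly to both calls.

# Illustrative round trip through the GPU-oriented pack/unpack.
import torch
from torchao.dtypes.uintx.bitpacking import pack, unpack  # assumed module path

elem_size = 3                        # 3-bit elements -> a 2-bit shard and a 1-bit shard
x = torch.randint(0, 2**elem_size, (8, 8), dtype=torch.uint8)
shards = pack(x, elem_size, dim=0)   # tuple of packed uint8 shard tensors
recovered = unpack(list(shards), elem_size, dim=0)
assert torch.equal(recovered, x)     # should round-trip bit-exactly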
diff --git a/torchao/dtypes/utils.py b/torchao/dtypes/utils.py
index 7771bc34c5..2407393fb9 100644
--- a/torchao/dtypes/utils.py
+++ b/torchao/dtypes/utils.py
@@ -11,6 +11,8 @@
 layout interacts with different operators, e.g. the same data representation can have different
 behaviors when running the same operator, e.g. transpose, quantized_linear.
 """
+
+
 @dataclass(frozen=True)
 class LayoutType:
     def pre_process(self, input: torch.Tensor) -> torch.Tensor:
@@ -25,16 +27,21 @@ def __repr__(self):
     def extra_repr(self) -> str:
         return ""
 
+
 """
 Plain LayoutType, the most basic LayoutType, also has no extra metadata, will typically be the default
 """
+
+
 @dataclass(frozen=True)
 class PlainLayoutType(LayoutType):
     pass
 
+
 def is_device(target_device_str: str, device: Union[str, torch.device]):
     return torch.device(device).type == target_device_str
 
+
 def get_out_shape(input_shape: Tuple[int], weight_shape: Tuple[int]) -> Tuple[int, int]:
     """Returns the unflattened shape of the input tensor.
     Args: