diff --git a/test/hqq/test_hqq_affine.py b/test/hqq/test_hqq_affine.py
index c37fe948b3..f3fa41c643 100644
--- a/test/hqq/test_hqq_affine.py
+++ b/test/hqq/test_hqq_affine.py
@@ -11,8 +11,11 @@
 )
 
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
-    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_3,
+)
+from torchao.quantization import (
+    uintx_weight_only,
+    int4_weight_only,
 )
 
 cuda_available = torch.cuda.is_available()
@@ -20,7 +23,7 @@
 #Parameters
 device = 'cuda:0'
 compute_dtype = torch.bfloat16
-group_size = 64
+group_size = 64 
 mapping_type = MappingType.ASYMMETRIC
 block_size = (1, group_size) #axis=1
 preserve_zero = False
@@ -34,36 +37,24 @@
 
 def _init_data(in_features, out_features, compute_dtype, device, torch_seed):
     torch.random.manual_seed(torch_seed)
-    linear_layer = torch.nn.Linear(in_features, out_features, bias=False, device=device)
+    linear_layer = torch.nn.Linear(in_features, out_features, bias=False).to(device)
     x = torch.randn((1, linear_layer.in_features), dtype=torch.float, device=device)/20.
     y_ref = linear_layer(x)
     W = linear_layer.weight.data.clone().to(device=device, dtype=compute_dtype)
     return W, x, y_ref
 
-def _eval_hqq(nbits, layout_type):
-    W, x, y_ref = _init_data(in_features, out_features, compute_dtype, device, torch_seed)
-
-    #Plain layout
-    target_dtype = torch.uint8
-    #Tensorcore layout
-    if isinstance(layout_type, TensorCoreTiledLayoutType):
-        target_dtype = torch.uint8 if TORCH_VERSION_AT_LEAST_2_5 else torch.int32
-
-    q_tensor_hqq = to_affine_quantized_intx(
-        input_float=W,
-        mapping_type=mapping_type,
-        block_size=block_size,
-        target_dtype=target_dtype,
-        quant_min=0,
-        quant_max=2**nbits - 1,
-        zero_point_domain=zero_point_domain,
-        preserve_zero=preserve_zero,
-        layout_type=layout_type,
-        use_hqq=True,
-    )
+def _eval_hqq(dtype):
+    W, x, y_ref = _init_data(in_features, out_features, compute_dtype, device, torch_seed)
+
+    dummy_linear = torch.nn.Linear(in_features=in_features, out_features=out_features, bias=False)
+    dummy_linear.weight.data = W
+    if dtype == torch.uint4:
+        q_tensor_hqq = int4_weight_only(group_size=max(block_size), use_hqq=True)(dummy_linear).weight
+    else:
+        q_tensor_hqq = uintx_weight_only(dtype, group_size=max(block_size), use_hqq=True)(dummy_linear).weight
 
     quant_linear_layer = torch.nn.Linear(W.shape[1], W.shape[0], bias=False, device=W.device)
-    del quant_linear_layer.weight
+    del quant_linear_layer.weight 
     quant_linear_layer.weight = q_tensor_hqq
     dequantize_error = (W - q_tensor_hqq.dequantize()).abs().mean().item()
     dot_product_error = (y_ref - quant_linear_layer(x.to(compute_dtype))).abs().mean().item()
@@ -71,44 +62,35 @@ def _eval_hqq(nbits, layout_type):
 
     return dequantize_error, dot_product_error
 
-class TestHQQBase(unittest.TestCase):
-    @unittest.skipIf(not cuda_available, "Need CUDA available")
-    def test_hqq(self, nbits=None, layout_type=None, ref_dequantize_error=None, ref_dot_product_error=None):
-        if(nbits is None): return
-        dequantize_error, dot_product_error = _eval_hqq(nbits=nbits, layout_type=layout_type)
+@unittest.skipIf(not cuda_available, "Need CUDA available")
+@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "Need torch 2.3+")
+class TestHQQ(unittest.TestCase):
+    def _test_hqq(self, dtype=None, ref_dequantize_error=None, ref_dot_product_error=None):
+        if(dtype is None): return
+        dequantize_error, dot_product_error = _eval_hqq(dtype)
         self.assertTrue(dequantize_error < ref_dequantize_error)
         self.assertTrue(dot_product_error < ref_dot_product_error)
 
-class TestHQQ8Bit(TestHQQBase):
     def test_hqq_plain_8bit(self):
-        self.test_hqq(nbits=8, layout_type=PlainLayoutType(), ref_dequantize_error=5e-5, ref_dot_product_error=0.00013)
+        self._test_hqq(dtype=torch.uint8, ref_dequantize_error=5e-5, ref_dot_product_error=0.00013)
 
-class TestHQQ7Bit(TestHQQBase):
     def test_hqq_plain_7bit(self):
-        self.test_hqq(nbits=7, layout_type=PlainLayoutType(), ref_dequantize_error=6e-05, ref_dot_product_error=0.000193)
+        self._test_hqq(dtype=torch.uint7, ref_dequantize_error=6e-05, ref_dot_product_error=0.000193)
 
-class TestHQQ6Bit(TestHQQBase):
     def test_hqq_plain_6bit(self):
-        self.test_hqq(nbits=6, layout_type=PlainLayoutType(), ref_dequantize_error=0.0001131, ref_dot_product_error=0.000353)
+        self._test_hqq(dtype=torch.uint6, ref_dequantize_error=0.0001131, ref_dot_product_error=0.000353)
 
-class TestHQQ5Bit(TestHQQBase):
     def test_hqq_plain_5bit(self):
-        self.test_hqq(nbits=5, layout_type=PlainLayoutType(), ref_dequantize_error=0.00023, ref_dot_product_error=0.000704)
+        self._test_hqq(dtype=torch.uint5, ref_dequantize_error=0.00023, ref_dot_product_error=0.000704)
 
-class TestHQQ4bit(TestHQQBase):
     def test_hqq_plain_4bit(self):
-        self.test_hqq(nbits=4, layout_type=PlainLayoutType(), ref_dequantize_error=0.000487, ref_dot_product_error=0.001472)
-
-    def test_hqq_tensorcore_4bit(self):
-        self.test_hqq(nbits=4, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=inner_k_tiles), ref_dequantize_error=0.000487, ref_dot_product_error=0.00147)
+        self._test_hqq(dtype=torch.uint4, ref_dequantize_error=0.000487, ref_dot_product_error=0.001472)
 
-class TestHQQ3Bit(TestHQQBase):
     def test_hqq_plain_3bit(self):
-        self.test_hqq(nbits=3, layout_type=PlainLayoutType(), ref_dequantize_error=0.00101, ref_dot_product_error=0.003047)
+        self._test_hqq(dtype=torch.uint3, ref_dequantize_error=0.00101, ref_dot_product_error=0.003047)
 
-class TestHQQ2Bit(TestHQQBase):
     def test_hqq_plain_2bit(self):
-        self.test_hqq(nbits=2, layout_type=PlainLayoutType(), ref_dequantize_error=0.002366, ref_dot_product_error=0.007255)
+        self._test_hqq(dtype=torch.uint2, ref_dequantize_error=0.002366, ref_dot_product_error=0.007255)
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/torchao/_models/llama/eval.py b/torchao/_models/llama/eval.py
index bf7448e023..fa2e77cf05 100644
--- a/torchao/_models/llama/eval.py
+++ b/torchao/_models/llama/eval.py
@@ -81,14 +81,19 @@ def run_evaluation(
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
             quantize_(model.to(device), int4_weight_only(group_size=groupsize, use_hqq=use_hqq))
         if "uintx" in quantization:
-            # uintx-nbits-group_size
+            # uintx-nbits-groupsize
             # "uintx-2-64"
+            if "hqq" in quantization:
+                use_hqq = True
+                quantization = quantization[:-4]
+            else:
+                use_hqq = False
             _quant_args = quantization.split("-")
             nbits = int(_quant_args[1])
             _NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8}
             dtype = _NBITS_TO_DTYPE[nbits]
             group_size = int(_quant_args[2])
-            quantize_(model, uintx_weight_only(dtype, group_size))
+            quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
         if "int4wo" in quantization and "gptq" in quantization:
             groupsize=int(quantization.split("-")[-2])
             assert groupsize in [32,64,128,256], f"int4wo groupsize needs to be one of [32,64,128,256] but got {groupsize}"
@@ -135,7 +140,7 @@
     parser.add_argument('--limit', type=int, default=None, help='Number of eval samples to evaluate')
     parser.add_argument('--precision', type=lambda x: getattr(torch, x.split(".")[-1]), default=torch.bfloat16, help='dtype precision to use')
     parser.add_argument('--device', type=str, default="cuda", help='Device to use for evaluation')
-    parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-<groupsize>-hqq, uintx-<nbits>-<groupsize>")
+    parser.add_argument("-q", "--quantization", type=str, help="Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, int4wo-<groupsize>-gptq, int4wo-<groupsize>-hqq, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq")
     parser.add_argument('--compile', action='store_true', help='Whether to compile the model.')
     parser.add_argument('--max_length', type=int, default=None, help='Length of text to process at one time')
     parser.add_argument('--calibration_tasks', type=str, nargs='+', default=['wikitext'], help='tasks to do gptq calibration on, if doing gptq')
diff --git a/torchao/_models/llama/generate.py b/torchao/_models/llama/generate.py
index 7d12dc2757..91270895d6 100644
--- a/torchao/_models/llama/generate.py
+++ b/torchao/_models/llama/generate.py
@@ -273,15 +273,20 @@ def main(
         if "fp6" in quantization:
             quantize_(model, fpx_weight_only(3, 2))
         if "uintx" in quantization:
-            # uintx-nbits-group_size
-            # "uintx-2-64"
+            # uintx-nbits-groupsize, e.g. "uintx-2-64"
+            if "hqq" in quantization:
+                # uintx-nbits-groupsize-hqq
+                quantization = quantization[:-4]
+                use_hqq = True
+            else:
+                use_hqq = False
             _quant_args = quantization.split("-")
             nbits = int(_quant_args[1])
             assert nbits >= 1 and nbits <= 8, "nbits must be 1 to 8"
             _NBITS_TO_DTYPE = {1: torch.uint1, 2: torch.uint2, 3: torch.uint3, 4: torch.uint4, 5: torch.uint5, 6: torch.uint6, 7: torch.uint7, 8: torch.uint8}
             dtype = _NBITS_TO_DTYPE[nbits]
             group_size = int(_quant_args[2])
-            quantize_(model, uintx_weight_only(dtype, group_size))
+            quantize_(model, uintx_weight_only(dtype, group_size, use_hqq=use_hqq))
         if "autoquant" in quantization:
             if "autoquant-int4" == quantization:
                 model = autoquant(model, manual=True, qtensor_class_list = torchao.quantization.DEFAULT_INT4_AUTOQUANT_CLASS_LIST)
@@ -451,7 +456,7 @@ def callback(x):
     parser.add_argument('--top_k', type=int, default=200, help='Top-k for sampling.')
     parser.add_argument('--temperature', type=float, default=0.8, help='Temperature for sampling.')
    parser.add_argument('--checkpoint_path', type=Path, default=Path("../../../checkpoints/meta-llama/Llama-2-7b-chat-hf/model.pth"), help='Model checkpoint path.')
-    parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, autoround-------, uintx-<nbits>-<groupsize>')
+    parser.add_argument('-q', '--quantization', type=str, help='Which quantization techniques to apply: int8dq, int8wo, int4wo-<groupsize>, autoquant, autoquant-int4, int4wo-<groupsize>-hqq, autoround-------, uintx-<nbits>-<groupsize>, uintx-<nbits>-<groupsize>-hqq')
     parser.add_argument('--kv_cache_quantization', action='store_true', help='Whether to quantize the KV cache')
     parser.add_argument('--cache_size', type=int, default=None, help='Force size of cache to be a certain number of tokens, if not set, will use max_new_tokens+prompt_size')
     parser.add_argument('--linear_causal_mask', action='store_true', help='Whether to use the memory efficient, but slightly less fast, linear causal mask (important for long context lengths)')
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 8cce22655a..df42161c7c 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -10,7 +10,7 @@
     ZeroPointDomain,
     MappingType,
     int_scaled_matmul,
-    quantize_affine_hqq,
+    choose_qparams_and_quantize_affine_hqq,
     FP8_TYPES,
     choose_qparams_affine_fpx,
     quantize_affine_fpx,
@@ -264,7 +264,7 @@ def from_hp_to_intx(
             group_size = max(block_size)
             compute_dtype = zero_point_dtype if (zero_point_dtype is not None) else input_float.dtype
             device = input_float.device
-            data, scale, zero_point, _ = quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False)
+            data, scale, zero_point, _ = choose_qparams_and_quantize_affine_hqq(input_float, nbits=nbits, group_size=group_size, axis=axis, compute_dtype=compute_dtype, device=device, verbose=False, raw_output=False)
             data = data.to(target_dtype)
         else:
             scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype, preserve_zero, zero_point_domain)
diff --git a/torchao/quantization/README.md b/torchao/quantization/README.md
index 33bef5dd4f..2ea25ccc60 100644
--- a/torchao/quantization/README.md
+++ b/torchao/quantization/README.md
@@ -227,7 +227,12 @@ This technique works best when the torch._inductor.config.use_mixed_mm option is
 ```python
 # for torch 2.4+
 from torchao.quantization import quantize_, int4_weight_only
-quantize_(model, int4_weight_only())
+group_size = 32
+
+# you can enable [hqq](https://github.com/mobiusml/hqq/tree/master) quantization, which is expected to improve accuracy,
+# via the `use_hqq` flag of `int4_weight_only`
+use_hqq = False
+quantize_(model, int4_weight_only(group_size=group_size, use_hqq=use_hqq))
 
 # for torch 2.2.2 and 2.3
 from torchao.quantization.quant_api import change_linear_weights_to_int4_woqtensors
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 284e4b88cc..f536609eee 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -596,7 +596,7 @@ def input_quant_func(x: torch.Tensor):
 
     return _get_linear_subclass_inserter(apply_float8_dynamic_activation_quant)
 
-def uintx_weight_only(dtype, group_size=64, pack_dim=-1):
+def uintx_weight_only(dtype, group_size=64, pack_dim=-1, use_hqq=False):
     """
     Applies uintx weight-only asymmetric per-group quantization to linear layers, using uintx quantization where
    x is the number of bits specified by `dtype`
@@ -606,23 +606,46 @@ def uintx_weight_only(dtype, group_size=64, pack_dim=-1):
        `group_size`: parameter for quantization, controls the granularity of quantization, smaller
         size is more fine grained, defaults to 64
        `pack_dim`: the dimension we use for packing, defaults to -1
+        `use_hqq`: whether to use the hqq algorithm or the default algorithm to quantize the weight
     """
-    def apply_uintx_weight_only_quant(weight):
-        layout_type = UintxLayoutType(dtype=dtype, pack_dim=pack_dim)
+    from torchao.quantization.quant_primitives import _DTYPE_TO_QVALUE_BOUNDS
+
+    SUPPORTED_DTYPES = {torch.uint1, torch.uint2, torch.uint3, torch.uint4, torch.uint5, torch.uint6, torch.uint7, torch.uint8}
+    assert dtype in SUPPORTED_DTYPES, f"Unsupported dtype for hqq: {dtype}"
+
+    def apply_uintx_weight_only_quant(weight, dtype):
         mapping_type = MappingType.ASYMMETRIC
         block_size = (1, group_size)
-        eps = torch.finfo(torch.float32).eps
-        zero_point_dtype = torch.int32
-        zero_point_domain = ZeroPointDomain.INT
+
+        if use_hqq:
+            if dtype == torch.uint4:
+                logger.warn(f"Recommended to use `int4_weight_only(group_size, use_hqq=True)` for the best performance")
+            quant_min, quant_max = _DTYPE_TO_QVALUE_BOUNDS[dtype]
+            dtype = torch.uint8
+            eps = None
+            zero_point_dtype = None
+            zero_point_domain = ZeroPointDomain.FLOAT
+            preserve_zero = False
+            layout_type = PlainLayoutType()
+        else:
+            quant_min, quant_max = None, None
+            eps = torch.finfo(torch.float32).eps
+            zero_point_dtype = torch.int32
+            zero_point_domain = ZeroPointDomain.INT
+            preserve_zero = True
+            layout_type = UintxLayoutType(dtype=dtype, pack_dim=pack_dim)
 
         return to_affine_quantized_intx(
             weight, mapping_type, block_size, dtype,
+            quant_min=quant_min, quant_max=quant_max,
             eps=eps, zero_point_dtype=zero_point_dtype,
             zero_point_domain=zero_point_domain,
+            preserve_zero=preserve_zero,
             layout_type=layout_type,
+            use_hqq=use_hqq,
         )
 
-    return _get_linear_subclass_inserter(apply_uintx_weight_only_quant)
+    return _get_linear_subclass_inserter(apply_uintx_weight_only_quant, dtype=dtype)
 
 def fpx_weight_only(ebits: int, mbits: int):
     """Sub-byte floating point dtypes defined by `ebits`: exponent bits and `mbits`: mantissa bits
@@ -652,5 +675,6 @@ def apply_quant_llm(weight: torch.Tensor) -> torch.Tensor:
         return to_affine_quantized_fpx(weight, layout_type)
     return _get_linear_subclass_inserter(apply_quant_llm)
 
+
 if TORCH_VERSION_AT_LEAST_2_5:
     torch.serialization.add_safe_globals([_int8_asymm_per_token_quant, _int8_symm_per_token_reduced_range_quant])
diff --git a/torchao/quantization/quant_primitives.py b/torchao/quantization/quant_primitives.py
index a5fddcb98c..7827ab88ad 100644
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -30,7 +30,7 @@
     "dequantize_affine_fpx",
     "fake_quantize_affine",
     "fake_quantize_affine_cachemask",
-    "quantize_affine_hqq",
+    "choose_qparams_and_quantize_affine_hqq",
 ]
 
 class MappingType(Enum):
@@ -840,7 +840,7 @@ def _convert_to_affinequantized_format(W_q: torch.Tensor, scale: torch.Tensor, z
     return W_q_ao, scale_ao, zero_ao
 
 # Main hqq quantizer function
-def quantize_affine_hqq(
+def choose_qparams_and_quantize_affine_hqq(
     tensor: torch.Tensor,
     nbits: float = 4,
    group_size: int = 64,
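
For reference, here is a minimal usage sketch of the `use_hqq` flag this patch adds to `uintx_weight_only`. It is not part of the diff itself; it assumes a CUDA device, torch 2.3+, and a torchao build that contains this change, and the one-layer model is purely illustrative.

```python
# Hedged sketch: exercises uintx_weight_only(..., use_hqq=True) from this patch.
# Assumes CUDA, torch 2.3+, and a torchao build that includes the change above.
import torch
from torchao.quantization import quantize_, uintx_weight_only

# Tiny stand-in model; any module containing nn.Linear layers works with quantize_.
model = torch.nn.Sequential(torch.nn.Linear(1024, 1024, bias=False)).to("cuda", torch.bfloat16)

# 2-bit weights with group size 64, quantized via the HQQ algorithm; this mirrors
# the CLI spelling "uintx-2-64-hqq" parsed in generate.py/eval.py above.
quantize_(model, uintx_weight_only(torch.uint2, group_size=64, use_hqq=True))

out = model(torch.randn(1, 1024, device="cuda", dtype=torch.bfloat16))
```

This is the same call path the llama scripts take once they have parsed a `uintx-<nbits>-<groupsize>-hqq` quantization string; for 4-bit weights, `int4_weight_only(group_size=..., use_hqq=True)` is the recommended route, as the warning in `uintx_weight_only` notes.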