From 5755483d958a5d3639346a8d61a10372bc70e507 Mon Sep 17 00:00:00 2001
From: Jerry Zhang
Date: Wed, 30 Oct 2024 14:04:32 -0700
Subject: [PATCH] Make sure int4 weight only quantization supports cpu as well

Summary:
We want to deprecate the int4 weight only quantizer in torchchat, so we are
making sure cpu is also supported.

Test Plan:
python test/dtypes/test_affine_quantized.py

Reviewers:

Subscribers:

Tasks:

Tags:
---
 test/dtypes/test_affine_quantized.py      | 23 +++++++++++++++++++++++
 torchao/dtypes/affine_quantized_tensor.py |  7 +++++--
 2 files changed, 28 insertions(+), 2 deletions(-)

diff --git a/test/dtypes/test_affine_quantized.py b/test/dtypes/test_affine_quantized.py
index 4e98ffd564..1882afd36b 100644
--- a/test/dtypes/test_affine_quantized.py
+++ b/test/dtypes/test_affine_quantized.py
@@ -133,7 +133,30 @@ def test_print_quantized_module(self, apply_quant):
         assert "AffineQuantizedTensor" in str(ql)
 
 
+class TestAffineQuantizedBasic(TestCase):
+    COMMON_DEVICES = ["cpu"] + (["cuda"] if torch.cuda.is_available() else [])
+    COMMON_DTYPES = [torch.bfloat16]
+
+    @common_utils.parametrize("apply_quant", get_quantization_functions(False, True))
+    @common_utils.parametrize("device", COMMON_DEVICES)
+    @common_utils.parametrize("dtype", COMMON_DTYPES)
+    def test_flatten_unflatten(self, apply_quant, device, dtype):
+        l = torch.nn.Linear(128, 256, dtype=dtype, device=device)
+        ql = apply_quant(l)
+        lp_tensor = ql.weight
+        tensor_data_name_dict, tensor_attributes = lp_tensor.__tensor_flatten__()
+        tensor_data_dict = {name: getattr(lp_tensor, name) for name in tensor_data_name_dict}
+        outer_size = lp_tensor.size()
+        outer_stride = lp_tensor.stride()
+        reconstructed = type(lp_tensor).__tensor_unflatten__(tensor_data_dict, tensor_attributes, outer_size, outer_stride)
+        example_inputs = (torch.randn(32, 128, dtype=dtype, device=device),)
+        ref = ql(*example_inputs)
+        ql.weight = torch.nn.Parameter(reconstructed, requires_grad=False)
+        reconstruct_res = ql(*example_inputs)
+        self.assertEqual(reconstruct_res, ref)
+
 common_utils.instantiate_parametrized_tests(TestAffineQuantized)
+common_utils.instantiate_parametrized_tests(TestAffineQuantizedBasic)
 
 
 if __name__ == "__main__":
diff --git a/torchao/dtypes/affine_quantized_tensor.py b/torchao/dtypes/affine_quantized_tensor.py
index 75d178fb50..46b4b51ef6 100644
--- a/torchao/dtypes/affine_quantized_tensor.py
+++ b/torchao/dtypes/affine_quantized_tensor.py
@@ -1268,8 +1268,11 @@ def from_plain(
     def to(self, *args, **kwargs):
         kwargs = self._get_to_kwargs(*args, **kwargs)
         device = kwargs["device"]
-        if not is_device("cuda", device):
-            raise ValueError(f"TensorCoreTiledAQTTensorImpl is only available for cuda device, can't convert to {device}")
+        # the tensor core tiled layout supports both cpu and cuda, but it does not
+        # support conversion between the two devices; in the future we should not use
+        # the same layout for cpu and cuda devices: https://github.com/pytorch/ao/issues/1117
+        if not is_device(torch.device(self.device).type, device):
+            raise ValueError(f"TensorCoreTiledAQTTensorImpl does not support conversion from {self.device} to {device}")
         return self.__class__(
             self.packed_weight.to(device),
             self.scale_and_zero.to(device),
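
Usage note (not part of the patch): a minimal sketch of the workflow this change
enables, assuming the quantize_ and int4_weight_only entry points from
torchao.quantization (exact names may vary across torchao versions):

    import torch
    from torchao.quantization import int4_weight_only, quantize_

    # int4 weight-only quantization applied directly to a cpu model;
    # before this patch the tensor core tiled layout was effectively cuda-only
    model = torch.nn.Sequential(torch.nn.Linear(128, 256)).to(torch.bfloat16)
    quantize_(model, int4_weight_only(group_size=32))

    # inference runs on cpu; the linear weight is now an AffineQuantizedTensor
    x = torch.randn(32, 128, dtype=torch.bfloat16)
    out = model(x)

    # note: moving the quantized model between cpu and cuda still raises, since
    # the layout does not support cross-device conversion
    # (https://github.com/pytorch/ao/issues/1117)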