From ec95afdf976630d292f0526f2beea23b7705c874 Mon Sep 17 00:00:00 2001
From: andrewor14
Date: Wed, 17 Jul 2024 15:18:11 -0400
Subject: [PATCH] Skip int4 QAT tests for nightly for now (#521)

int4 tinygemm quantization is currently broken in master and being
fixed in https://github.com/pytorch/ao/pull/517. Let's skip these
tests for now until that is fixed.
---
 test/quantization/test_qat.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/test/quantization/test_qat.py b/test/quantization/test_qat.py
index e53ef03819..3634ac791f 100644
--- a/test/quantization/test_qat.py
+++ b/test/quantization/test_qat.py
@@ -27,7 +27,10 @@
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
 )
-from torchao.utils import TORCH_VERSION_AFTER_2_4
+from torchao.utils import (
+    TORCH_VERSION_AFTER_2_4,
+    TORCH_VERSION_AFTER_2_5,
+)
 
 
 # TODO: put this in a common test utils file
@@ -366,6 +369,8 @@ def _assert_close_4w(self, val, ref):
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_primitives(self):
         n_bit = 4
         group_size = 32
@@ -411,6 +416,8 @@ def test_qat_4w_primitives(self):
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
     @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_linear(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear
         from torchao.quantization.GPTQ import WeightOnlyInt4Linear
@@ -439,6 +446,8 @@ def test_qat_4w_linear(self):
 
     @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
    @unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
+    # TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
+    @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
     def test_qat_4w_quantizer(self):
         from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
         from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
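
Below is a minimal, self-contained sketch of the version-gated skip pattern this patch applies, in case the decorator stacking is unfamiliar. The TORCH_VERSION_AFTER_2_5 name here is a stand-in boolean rather than the real import from torchao.utils, and the test class and method are hypothetical.

    import unittest

    # Stand-in for torchao.utils.TORCH_VERSION_AFTER_2_5; in the real suite
    # this boolean reflects whether the installed torch is newer than 2.5.
    TORCH_VERSION_AFTER_2_5 = True

    class ExampleInt4Test(unittest.TestCase):
        # skipIf marks the test as skipped (not failed) when the condition is
        # true, so nightly CI stays green while the int4 fix lands.
        @unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
        def test_int4_path(self):
            self.assertTrue(True)

    if __name__ == "__main__":
        unittest.main()  # reports the gated test as skipped, not failed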