Skip to content

Commit

Permalink
Skip int4 QAT tests for nightly for now (#521)
Browse files Browse the repository at this point in the history
int4 tinygemm quantization is currently broken in master and is being
fixed in #517. Let's skip these tests until that fix lands.
  • Loading branch information
andrewor14 authored Jul 17, 2024
1 parent f8789f7 commit ec95afd
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion test/quantization/test_qat.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@
get_groupwise_affine_qparams,
groupwise_affine_quantize_tensor,
)
from torchao.utils import TORCH_VERSION_AFTER_2_4
from torchao.utils import (
TORCH_VERSION_AFTER_2_4,
TORCH_VERSION_AFTER_2_5,
)


# TODO: put this in a common test utils file
Expand Down Expand Up @@ -366,6 +369,8 @@ def _assert_close_4w(self, val, ref):

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
def test_qat_4w_primitives(self):
n_bit = 4
group_size = 32
Expand Down Expand Up @@ -411,6 +416,8 @@ def test_qat_4w_primitives(self):

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
def test_qat_4w_linear(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATLinear
from torchao.quantization.GPTQ import WeightOnlyInt4Linear
Expand Down Expand Up @@ -439,6 +446,8 @@ def test_qat_4w_linear(self):

@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "skipping when torch version is 2.4 or lower")
@unittest.skipIf(not _CUDA_IS_AVAILABLE, "skipping when cuda is not available")
# TODO: remove once we fix int4 error: https://github.com/pytorch/ao/pull/517
@unittest.skipIf(TORCH_VERSION_AFTER_2_5, "int4 doesn't work for 2.5+ right now")
def test_qat_4w_quantizer(self):
from torchao.quantization.prototype.qat import Int4WeightOnlyQATQuantizer
from torchao.quantization.GPTQ import Int4WeightOnlyQuantizer
Expand Down

0 comments on commit ec95afd

Please sign in to comment.