Add symmetric quantization with no clipping error in the tensor subclass based API
Martin Yuan committed Sep 9, 2024
1 parent 1b317f9 commit 9b1120b
Showing 2 changed files with 28 additions and 4 deletions.
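
For orientation, here is a minimal sketch of how the new flag is meant to be invoked through the tensor subclass API. The toy model is illustrative; the import paths follow the files touched in this commit and may differ across torchao versions:

    import torch
    from torchao.quantization.quant_api import quantize_, int8_dynamic_activation_int4_weight
    from torchao.quantization.quant_primitives import MappingType

    # Any model whose Linear in_features are divisible by group_size works here;
    # the test below uses ToyLinearModel for the same purpose.
    model = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.Linear(32, 8)).eval()

    # int8 dynamic per-token activation quantization + int4 per-group weight
    # quantization, with weight scales chosen so no weight value is clipped.
    quantize_(model, int8_dynamic_activation_int4_weight(
        group_size=32,
        mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR,
    ))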
25 changes: 25 additions & 0 deletions test/quantization/test_quant_api.py
@@ -526,6 +526,31 @@ def test_quantized_tensor_subclass_8da4w(self):
         self.assertTrue(torch.equal(res, ref))
 
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Test only enabled for 2.4+")
+    def test_quantized_tensor_subclass_8da4w_no_clipping_err(self):
+        group_size = 32
+        m = ToyLinearModel().eval()
+        m_copy = copy.deepcopy(m)
+        example_inputs = m.example_inputs()
+        quantize_(m, int8_dynamic_activation_int4_weight(group_size=group_size, mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR))
+
+        assert isinstance(m.linear1.weight, LinearActivationQuantizedTensor)
+        assert isinstance(m.linear2.weight, LinearActivationQuantizedTensor)
+        assert isinstance(m.linear1.weight.original_weight_tensor, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight.original_weight_tensor, AffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+
+        quantizer = Int8DynActInt4WeightQuantizer(groupsize=group_size, mapping_type=MappingType.SYMMETRIC_NO_CLIPPING_ERR)
+        m_copy = quantizer.quantize(m_copy)
+        assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear)
+        assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)
+        self.assertTrue(torch.equal(res, ref))
 
     # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "Test currently doesn't work for 2.5+")
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_quantized_tensor_subclass_int4(self):
7 changes: 3 additions & 4 deletions torchao/quantization/quant_api.py
@@ -473,14 +473,13 @@ def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
     target_dtype = torch.int8
     return to_affine_quantized_intx(x, mapping_type, _get_per_token_block_size(x), target_dtype)
 
-def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
+def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32, mapping_type=MappingType.SYMMETRIC):
     """This is defined here instead of as a local function to support serialization
     """
     if weight.shape[-1] % group_size != 0:
         return weight
 
     # weight settings
-    mapping_type = MappingType.SYMMETRIC
     block_size = (1, group_size)
     target_dtype = torch.int8
     eps = torch.finfo(torch.float32).eps
@@ -494,7 +493,7 @@ def apply_int8_dynamic_activation_int4_weight_quant(weight, group_size=32):
     weight = to_linear_activation_quantized(weight, input_quant_func)
     return weight
 
-def int8_dynamic_activation_int4_weight(group_size=32):
+def int8_dynamic_activation_int4_weight(group_size=32, mapping_type=MappingType.SYMMETRIC):
     """Applies int8 dynamic per-token asymmetric activation quantization and int4 per-group weight symmetric quantization to linear layers.
     This is used to produce a model for the executorch backend, but currently executorch does not
     yet support lowering the quantized model from this flow.
@@ -503,7 +502,7 @@ def int8_dynamic_activation_int4_weight(group_size=32):
     `group_size`: parameter for quantization, controls the granularity of quantization; a smaller
     group size gives finer-grained quantization
     """
-    return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size)
+    return _get_linear_subclass_inserter(apply_int8_dynamic_activation_int4_weight_quant, group_size=group_size, mapping_type=mapping_type)
 
 
 def int4_weight_only(group_size=128, layout_type=TensorCoreTiledLayoutType(inner_k_tiles=8), use_hqq=False):
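A note on what the new mapping type changes, since the diff only threads the flag through. With plain SYMMETRIC mapping, a single scale is derived from the largest magnitude in each group, and because the int4 range [-8, 7] is asymmetric, a value at one extreme can still quantize outside the representable range and get clipped. SYMMETRIC_NO_CLIPPING_ERR instead takes the larger of the two per-side scales, so both extremes stay representable. A simplified sketch of the two scale choices, as I understand torchao's quant primitives (not the library code verbatim):

    import torch

    def choose_scale_symmetric(w: torch.Tensor, quant_min=-8, quant_max=7):
        # One scale from the largest magnitude; a value at the positive extreme
        # can quantize to 8 and be clamped to 7, i.e. a clipping error.
        max_val_pos = torch.max(-w.min(), w.max())
        return max_val_pos / ((quant_max - quant_min) / 2)

    def choose_scale_no_clipping_err(w: torch.Tensor, quant_min=-8, quant_max=7):
        # Larger of the per-side scales, so both w.min() and w.max() quantize
        # inside [quant_min, quant_max].
        smin = w.min().clamp(max=0) / quant_min   # quant_min < 0, so smin >= 0
        smax = w.max().clamp(min=0) / quant_max
        return torch.max(smin, smax)

    w = torch.tensor([-0.4, 0.1, 0.25, 1.0])
    s1 = choose_scale_symmetric(w)        # 1.0 / 7.5 ~ 0.1333; 1.0/s1 = 7.5 -> rounds to 8, clamped to 7
    s2 = choose_scale_no_clipping_err(w)  # max(0.4/8, 1.0/7) ~ 0.1429; 1.0/s2 = 7.0, no clipping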
