Skip to content

Commit

Permalink
Reduce memory usage for symmetric choose_qparams
Browse files Browse the repository at this point in the history
Summary:
Also unified the impl a bit more

Test Plan:
python test/quantization/test_quant_primitives.py

Reviewers:

Subscribers:

Tasks:

Tags:
  • Loading branch information
jerryzh168 committed May 2, 2024
1 parent 6ae2c0b commit e5e484f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 12 deletions.
8 changes: 8 additions & 0 deletions test/quantization/test_quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,14 @@ def test_quantize_dequantize_channel_asym_4d_multi_dim_reduction(self):
# we don't have corresponding ops in existing primitives, so just make sure it runs and it's close to float
torch.testing.assert_allclose(dequantized, input, rtol=2, atol=0.02)

@unittest.skipIf(not torch.cuda.is_available(), "skipping when cuda is not available")
def test_get_group_qparams_symmetric_memory(self):
    """Check that symmetric qparam computation stays within a small CUDA peak-memory budget.

    Guards the memory optimization for the symmetric choose_qparams path:
    computing group-wise scales for a 1024x1024 fp32 weight should not
    allocate much beyond the weight itself.
    """
    # Reset the peak-memory counter first: max_memory_allocated() reports
    # the peak since program start (or the last reset), so without this the
    # assertion depends on what earlier tests allocated and is order-flaky.
    torch.cuda.reset_peak_memory_stats()
    weight = torch.randn(1024, 1024).to(device="cuda")
    n_bit = 4
    groupsize = 128
    (scale_ao, _) = get_group_qparams_symmetric(weight, n_bit, groupsize)
    # The weight alone is 1024*1024*4 bytes ~= 4.2e6, so 4.5e6 leaves only
    # a small headroom for the qparams computation itself.
    self.assertTrue(torch.cuda.max_memory_allocated() < 4.5e6)

# Allow running this test file directly: python test/quantization/test_quant_primitives.py
if __name__ == "__main__":
    unittest.main()
23 changes: 11 additions & 12 deletions torchao/quantization/quant_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,8 @@ def choose_qparams_affine(
Tuple of scales and zero_points Tensor with requested dtype
"""
quant_min, quant_max = _get_and_check_qmin_qmax(target_dtype, quant_min, quant_max)
assert mapping_type in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC], f"Unsupported mapping type: {mapping_type}"

if scale_dtype is None:
scale_dtype = input.dtype
if zero_point_dtype is None:
Expand All @@ -269,23 +271,20 @@ def choose_qparams_affine(
shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
input = input.view(shape_for_reduction)

if mapping_type == MappingType.SYMMETRIC:
amax = torch.amax(torch.abs(input), dim=reduction_dims, keepdim=False)
scale = amax / (float(quant_max - quant_min) / 2)
zero_point = torch.ones_like(scale)
zero_point *= int((quant_min + quant_max + 1) / 2)
elif mapping_type == MappingType.ASYMMETRIC:
min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
max_val = torch.amax(input, dim=reduction_dims, keepdim=False)
min_val = torch.amin(input, dim=reduction_dims, keepdim=False)
max_val = torch.amax(input, dim=reduction_dims, keepdim=False)

min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))
min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

if mapping_type == MappingType.SYMMETRIC:
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
zero_point = torch.full_like(scale, int((quant_min + quant_max + 1) / 2))
else:
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
zero_point = quant_min - torch.round(min_val_neg / scale)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
else:
raise RuntimeError(f"Unsupported mapping type: {mapping_type}")

if eps is not None:
scale = torch.clamp(scale, min=eps)
Expand Down

0 comments on commit e5e484f

Please sign in to comment.