Fix dimension issues for int4 weight only quant path
Summary:
Currently the dimensions accepted by _quantized_linear_op are not clearly defined; this PR fixes the issue by documenting the expected dimensions of input, weight, and bias.

Additionally, the "tensor_core_tiled" layout tensor does not repack the weight in the view operation, which is incorrect. This PR removes view support (which is not needed right now), restricts the use case to the transpose op, and, for performance, records the transpose status of the tensor instead of repacking.
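As a rough sketch of the recorded-transpose idea (a hypothetical PackedWeight class, not the torchao implementation): transpose becomes an O(1) metadata update, and whatever consumes the packed data checks the flag at dispatch time instead of repacking.

import torch

class PackedWeight:
    # Minimal sketch: the packed storage never changes layout; t() only
    # flips a flag and the logical shape, and the matmul consumer is
    # expected to read `transposed` instead of repacking the bytes.
    def __init__(self, packed: torch.Tensor, shape, transposed: bool = False):
        self.packed = packed
        self.logical_shape = tuple(shape)
        self.transposed = transposed

    def t(self):
        return PackedWeight(self.packed, self.logical_shape[::-1], not self.transposed)

pw = PackedWeight(torch.zeros(8), (128, 256))
assert pw.t().logical_shape == (256, 128)
assert pw.t().t().logical_shape == (128, 256) and not pw.t().t().transposed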

Test Plan:
python test/quantization/test_quant_api.py
python test/integration/test_integration.py

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed Jun 7, 2024
1 parent e2196fd commit c0600c2
Showing 3 changed files with 75 additions and 30 deletions.
29 changes: 29 additions & 0 deletions test/dtypes/test_aq.py
@@ -0,0 +1,29 @@
from torch.testing._internal.common_utils import (
    TestCase,
    run_tests,
)
from torchao.quantization.quant_api import get_apply_int4wo_quant
import torch
import unittest


class TestAQ(TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    def test_tensor_core_layout_transpose(self):
        t = torch.rand(128, 256, dtype=torch.bfloat16, device="cuda")
        shape = t.shape
        apply_int4wo_quant = get_apply_int4wo_quant(groupsize=32)
        aqt = apply_int4wo_quant(t)
        aqt_shape = aqt.shape
        self.assertEqual(aqt_shape, shape)

        # transpose shape test
        for _ in range(10):
            t = t.t()
            aqt = aqt.t()
            shape = t.shape
            aqt_shape = aqt.shape
            self.assertEqual(aqt_shape, shape)

if __name__ == "__main__":
    run_tests()
6 changes: 3 additions & 3 deletions test/integration/test_integration.py
@@ -214,9 +214,9 @@ def _test_smooth_linear_impl(self, x_shape, lin_shape, device):
        # rtol=0.00001), \
        # 'y_smooth_fq_only not close to y_dynamic_q'

-        self.assertTrue(sqnr_smooth_fq.item() >= 40.0)
-        self.assertTrue(sqnr_dynamic_q.item() >= 40.0)
-        self.assertTrue(sqnr_fq.item() >= 40.0)
+        self.assertTrue(sqnr_smooth_fq.item() >= 40.0, f"got: {sqnr_smooth_fq.item()}")
+        self.assertTrue(sqnr_dynamic_q.item() >= 40.0, f"got: {sqnr_dynamic_q.item()}")
+        self.assertTrue(sqnr_fq.item() >= 40.0, f"got: {sqnr_fq.item()}")

        # Restore backend
        torch.backends.quantized.engine = orig_backend
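For reference, SQNR is the signal-to-quantization-noise ratio in dB; a minimal sketch of the usual definition (assumed to match the intent of the sqnr_* values above, not code taken from this diff):

import torch

def sqnr(signal: torch.Tensor, quantized: torch.Tensor) -> torch.Tensor:
    # 20 * log10(||signal|| / ||signal - quantized||), in dB; the
    # assertions above require at least 40 dB after quantization.
    noise = signal - quantized
    return 20 * torch.log10(torch.linalg.norm(signal) / torch.linalg.norm(noise))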
70 changes: 43 additions & 27 deletions torchao/dtypes/aqt.py
@@ -186,7 +186,7 @@ def _apply_fn_to_data(self, fn):
            fn(self.zero_point),
        )

-    def _change_shape(self, shape):
+    def _transpose_change_shape(self, shape):
        return self.__class__(
            self.int_data.view(shape), self.scale, self.zero_point
        )
@@ -200,9 +200,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )

-        if func is aten.view.default:
-            assert len(args) == 2
-            new = args[0]._change_shape(args[1])
+        if func is aten.t.default:
+            new = args[0]._transpose_change_shape(args[0].shape[::-1])
            return return_and_correct_aliasing(func, args, kwargs, new)

        raise NotImplementedError(
@@ -239,6 +238,7 @@ def __new__(
        cls,
        packed_weight: torch.Tensor,
        scale_and_zero: torch.Tensor,
+        transposed: bool,
    ):
        kwargs = {}
        kwargs["device"] = packed_weight.device
@@ -254,27 +254,30 @@ def __init__(
        self,
        packed_weight: torch.Tensor,
        scale_and_zero: torch.Tensor,
+        transposed: bool,
    ):
        self.packed_weight = packed_weight
        self.scale_and_zero = scale_and_zero
+        self.transposed = transposed

    def __tensor_flatten__(self):
-        return ["packed_weight", "scale_and_zero"], []
+        return ["packed_weight", "scale_and_zero"], [self.transposed]

    @classmethod
    def __tensor_unflatten__(
        cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
    ):
        packed_weight, scale_and_zero = tensor_data_dict["packed_weight"], tensor_data_dict["scale_and_zero"]
-        return cls(packed_weight, scale_and_zero)
+        transposed, = tensor_attributes
+        return cls(packed_weight, scale_and_zero, transposed)

    @classmethod
    def from_plain(cls, int_data, scale, zero_point, inner_k_tiles=8):
        packed_weight = torch.ops.aten._convert_weight_to_int4pack(int_data.to(torch.int32), inner_k_tiles)
        scale = scale.reshape(int_data.shape[0], -1)
        zero_point = zero_point.reshape(int_data.shape[0], -1)
        scale_and_zero = pack_tinygemm_scales_and_zeros(scale, zero_point)
-        return cls(packed_weight, scale_and_zero)
+        return cls(packed_weight, scale_and_zero, False)

    def to(self, *args, **kwargs):
        kwargs = self._get_to_kwargs(*args, **kwargs)
@@ -283,20 +286,21 @@ def to(self, *args, **kwargs):
            raise ValueError(f"TensorCoreTiledAQTLayout is only available for cuda device")
        return self.__class__(
            self.packed_weight.to(kwargs["device"]),
-            self.scale_and_zero.to(kwargs["device"])
+            self.scale_and_zero.to(kwargs["device"]),
+            self.transposed
        )

    def _apply_fn_to_data(self, fn):
        self.packed_weight = fn(self.packed_weight)
        self.scale_and_zero = fn(self.scale_and_zero)
        return self

-    def _change_shape(self, shape):
-        # int_data, scale, zero = self.get_plain()
-        # int_data = int_data.view(shape)
-        # changed = self.from_plain(int_data, scale, zero)
-        # return changed
-        # TODO: changing shape is no-op for int4 packed weight right now
+    def _transpose_change_shape(self):
+        """Changes the shape of the tensor for the transpose operation.
+        In this case we don't need to repack the weight; we just rely on the
+        external shape being changed and record the transposed/not-transposed status.
+        """
+        self.transposed = not self.transposed
        return self

    @classmethod
@@ -308,9 +312,8 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
                func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
            )

-        if func is aten.view.default:
-            assert len(args) == 2
-            new = args[0]._change_shape(args[1])
+        if func is aten.t.default:
+            new = args[0]._transpose_change_shape()
            return return_and_correct_aliasing(func, args, kwargs, new)

        raise NotImplementedError(
@@ -327,8 +330,7 @@ def get_plain(self):
        )
        cur_shape = self.shape
        assert len(cur_shape) == 4
-        # TODO: expose the arg
-        inner_k_tiles = self.cur_shape[-1] * 2
+        inner_k_tiles = cur_shape[-1] * 2
        original_shape = (cur_shape[0] * 8, cur_shape[1] * (inner_k_tiles * 16))
        eye_shape = original_shape[1]
        block_size = (1, 32)
@@ -555,9 +557,11 @@ def _apply_fn_to_data(self, fn):
            strides=self.stride(),
        )

-    def _change_shape(self, shape, block_size):
+    def _transpose_change_shape(self, shape, block_size):
+        """Changes the shape of the tensor for the transpose operation.
+        """
        return self.__class__(
-            self.layout_tensor.view(shape), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()
+            self.layout_tensor.t(), block_size, shape, self.quant_min, self.quant_max, self.zero_point_domain, dtype=self.dtype, strides=self.stride()
        )

    @classmethod
@@ -581,7 +585,15 @@ def __torch_dispatch__(cls, func, types, args, kwargs):
            f"AffineQuantizedTensor dispatch: attempting to run {func}, this is not supported"
        )

-def _quantized_linear_op(input_tensor, weight_qtensor, bias, _from_flinear=True):
+def _quantized_linear_op(input_tensor, weight_qtensor, bias):
+    """
+    Quantized version of the F.linear operator.
+    Args:
+        input_tensor: dimension is (batch_size, in_features)
+        weight_tensor: dimension is (out_features, in_features)
+        bias: dimension is (out_features,)
+    """
    # TODO: the old tensor subclass can use a single implementation for both the F.linear dispatch
    # and the aten.addmm/aten.mm dispatch because `_change_shape` is not implemented correctly (it got ignored
    # for the int_data); this makes the dimension of weight_qtensor nondeterministic, we need to fix
@@ -647,9 +659,11 @@ def _quantized_linear_op(input_tensor, weight_qtensor, bias, _from_flinear=True)
        weight_qtensor.zero_point_domain == ZeroPointDomain.FLOAT and
        weight_qtensor.layout == "tensor_core_tiled"
    ):
-        if not _from_flinear:
-            weight_qtensor = weight_qtensor.t()
        assert weight_qtensor.block_size[0] == 1, f"Requires groupwise quantization, got block_size: {weight_qtensor.block_size}"
+        assert input_tensor.shape[-1] == weight_qtensor.shape[1], (
+            f"need input_tensor shape: {input_tensor.shape} final "
+            f"dim to match weight_tensor shape: {weight_qtensor.shape} second dim "
+        )

        # TODO: check groupsize quantization
        # avoid circular dep, TODO: move this to a common util.py
@@ -745,7 +759,8 @@ def aten_mm(func, *args, **kwargs):
            args[0],
        )
        try:
-            return _quantized_linear_op(input_tensor, weight_tensor, bias, _from_flinear=False)
+            weight_tensor = weight_tensor.t()
+            return _quantized_linear_op(input_tensor, weight_tensor, bias)
        except:
            if isinstance(input_tensor, AffineQuantizedTensor):
                input_tensor = input_tensor.dequantize()
@@ -759,7 +774,8 @@ def aten_mm(func, *args, **kwargs):
            None
        )
        try:
-            return _quantized_linear_op(input_tensor, weight_tensor, bias, _from_flinear=False)
+            weight_tensor = weight_tensor.t()
+            return _quantized_linear_op(input_tensor, weight_tensor, bias)
        except:
            if isinstance(input_tensor, AffineQuantizedTensor):
                input_tensor = input_tensor.dequantize()
@@ -795,7 +811,7 @@ def t(func, *args, **kwargs):
    block_size = args[0].block_size
    assert len(block_size) == 2
    transposed_block_size = (block_size[1], block_size[0])
-    new = args[0]._change_shape(args[0].shape[::-1], transposed_block_size)
+    new = args[0]._transpose_change_shape(args[0].shape[::-1], transposed_block_size)
    return return_and_correct_aliasing(func, args, kwargs, new)

to_aq = AffineQuantizedTensor.from_float
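For context on the dimension contract that _quantized_linear_op now documents, here is an illustrative sketch with plain (unquantized) tensors; the shapes are the point, everything else is an assumption:

import torch
import torch.nn.functional as F

# F.linear convention expected by _quantized_linear_op:
#   input (batch_size, in_features), weight (out_features, in_features),
#   bias (out_features,)
x = torch.randn(8, 256)
w = torch.randn(128, 256)
b = torch.randn(128)
y = F.linear(x, w, b)          # shape (8, 128)

# aten.addmm/aten.mm instead see the weight as (in_features, out_features),
# which is why the dispatch code above transposes weight_tensor before
# delegating to _quantized_linear_op.
y2 = torch.addmm(b, x, w.t())  # same result via the mm-style operand layout
assert y.shape == (8, 128) and torch.allclose(y, y2)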
