Unified tensor subclass
Summary:
Created a `QuantizedTensor` subclass that works for both weights and inputs (for dynamic quantization) and for all granularities, leveraging the recently added choose_qparams_affine, quantize_affine, and dequantize_affine ops.

Only verified for 8da4w right now; we can make it work for other types of quantization (mostly the operator dispatching part) later.
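
For illustration, a minimal usage sketch mirroring the new test below (the toy `lin` layer, its sizes, and the `per_token_block_size` helper are made up for this example; the imports and `from_float` arguments come from the code in this commit):

import torch
from torchao.quantization.subclass import QuantizedTensor
from torchao.quantization.quant_primitives import MappingType

# hypothetical toy linear layer; any nn.Linear weight works
lin = torch.nn.Linear(64, 32, bias=False).eval()

# dynamic per-token activation quantization: block size is (1, ..., 1, last_dim)
per_token_block_size = lambda x: [1] * (x.dim() - 1) + [x.shape[-1]]
input_quant_func = lambda x: QuantizedTensor.from_float(
    x, MappingType.ASYMMETRIC, per_token_block_size(x), torch.int8
)

# 8da4w-style weight: symmetric, groupwise (groupsize 32), int4 range stored in int8
lin.weight = torch.nn.Parameter(
    QuantizedTensor.from_float(
        lin.weight, MappingType.SYMMETRIC, (1, 32), torch.int8,
        quant_min=-8, quant_max=7, eps=torch.finfo(torch.float32).eps,
        input_quant_func=input_quant_func,
    ),
    requires_grad=False,
)

out = lin(torch.randn(16, 64))  # dispatches through the subclass, falls back to dequant + float linear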

Test Plan:
python test/quantization/test_quant_api.py -k test_quantized_tensor_subclass_8da4w

Reviewers:

Subscribers:

Tasks:

Tags:
jerryzh168 committed May 4, 2024
1 parent 5364de6 commit 18fa26a
Showing 2 changed files with 279 additions and 0 deletions.
53 changes: 53 additions & 0 deletions test/quantization/test_quant_api.py
@@ -392,5 +392,58 @@ def test_eval_wrapper(self):
f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
)

# TODO: move to a separate test file
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
def test_quantized_tensor_subclass_8da4w(self):
from torchao.quantization.subclass import QuantizedTensor
from torchao.quantization.quant_primitives import MappingType
import copy

# weight settings
groupsize = 32
mapping_type = MappingType.SYMMETRIC
block_size = (1, groupsize)
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
quant_min = -8
quant_max = 7

# TODO: make a general helper function?
def get_per_token_block_size(x):
block_size = []
for i in range(len(x.shape)-1):
block_size.append(1)
block_size.append(x.shape[-1])
return block_size

# input settings
input_mapping_type = MappingType.ASYMMETRIC
input_target_dtype = torch.int8
input_quant_func = lambda x: QuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

m = M().eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs()
m.linear1.weight = torch.nn.Parameter(QuantizedTensor.from_float(m.linear1.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
m.linear2.weight = torch.nn.Parameter(QuantizedTensor.from_float(m.linear2.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
assert isinstance(m.linear1.weight, QuantizedTensor)
assert isinstance(m.linear2.weight, QuantizedTensor)
res = m(*example_inputs)

# reference
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

quantizer = Int8DynActInt4WeightQuantizer(groupsize=groupsize)
m_copy = quantizer.quantize(m_copy)
assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear)
assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear)
ref = m_copy(*example_inputs)

self.assertTrue(torch.equal(res, ref))




if __name__ == "__main__":
unittest.main()
226 changes: 226 additions & 0 deletions torchao/quantization/subclass.py
@@ -15,6 +15,9 @@
groupwise_affine_quantize_tensor,
quant_int8_dynamic_per_token_linear,
unpack_tinygemm_scales_and_zeros,
choose_qparams_affine,
quantize_affine,
dequantize_affine,
)
from .utils import find_multiple

@@ -23,6 +26,7 @@
"Int8DynamicallyQuantizedLinearWeight",
"Int8WeightOnlyQuantizedLinearWeight",
"Int4WeightOnlyQuantizedLinearWeight",
"QuantizedTensor",
]


@@ -592,3 +596,225 @@ def to_qtensor_components(cls, input_float, groupsize=128, inner_k_tiles=8):
)
int_data = aten._convert_weight_to_int4pack(input_int4x8, inner_k_tiles)
return int_data, scales_and_zeros, False, groupsize, inner_k_tiles


class QuantizedTensor(torch.Tensor):
"""
Base quantized tensor subclass. Use the `from_float` method to create an
instance of `QuantizedTensor` from a floating point tensor.
The shape and dtype of the tensor subclass represent how the tensor subclass looks externally,
regardless of the internal representation's type or orientation.
"""

@staticmethod
def __new__(cls, int_data, scale, zero_point, quant_min, quant_max, block_size, shape, input_quant_func=None, dtype=None, *args, **kwargs):
kwargs["device"] = int_data.device
kwargs["layout"] = (
kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout
)
if dtype is None:
dtype = scale.dtype
kwargs["dtype"] = dtype
assert not kwargs.get("requires_grad", False)
kwargs["requires_grad"] = False
return torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]

def __init__(self, int_data, scale, zero_point, quant_min, quant_max, block_size, shape, input_quant_func=None, dtype=None, *args, **kwargs):
self.int_data = int_data
self.scale = scale
self.zero_point = zero_point
self.block_size = block_size
self.quant_min = quant_min
self.quant_max = quant_max
self.input_quant_func = input_quant_func

def __repr__(self):
return (
f"{self.__class__.__name__}(data={self.dequantize()}, shape={self.shape}, "
f"device={self.device}, dtype={self.dtype}, input_quant_func={self.input_quant_func}, requires_grad={self.requires_grad})"
)

def dequantize(self, output_dtype=None):
if output_dtype is None:
output_dtype = torch.float32

return dequantize_affine(self.int_data, self.block_size, self.scale, self.zero_point, self.int_data.dtype, self.quant_min, self.quant_max, output_dtype=output_dtype)

def __tensor_flatten__(self):
return ["int_data", "scales", "zero_point"], [self.quant_min, self.quant_max, self.block_size, self.shape, self.input_quant_func, self.dtype]

@classmethod
def __tensor_unflatten__(
cls, tensor_data_dict, tensor_attributes, outer_size, outer_stride
):
int_data, scale, zero_point = tensor_data_dict["int_data"], tensor_data_dict["scale"], tensor_data_dict["zero_point"]
quant_min, quant_max, block_size, shape, input_quant_func, dtype = tensor_attributes
return cls(
int_data,
scale,
zero_point,
quant_min,
quant_max,
block_size,
shape if outer_size is None else outer_size,
input_quant_func=input_quant_func,
dtype=dtype,
strides=outer_stride,
)

@classmethod
def from_float(
cls,
input_float,
mapping_type,
block_size,
target_dtype,
quant_min = None,
quant_max = None,
eps = None,
scale_dtype = None,
zero_point_dtype = None,
input_quant_func = None,
):
scale, zero_point = choose_qparams_affine(input_float, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, scale_dtype, zero_point_dtype)
int_data = quantize_affine(input_float, block_size, scale, zero_point, target_dtype, quant_min, quant_max)
return cls(
int_data,
scale,
zero_point,
quant_min,
quant_max,
block_size,
input_float.shape,
input_quant_func=input_quant_func,
dtype=input_float.dtype
)

# __torch_function__ = torch._C._disabled_torch_function_impl

@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
kwargs = {} if kwargs is None else kwargs

if func is torch.nn.functional.linear:
input_tensor, weight_qtensor, bias = (
args[0],
args[1],
args[2] if len(args) > 2 else None,
)
if weight_qtensor.input_quant_func is not None:
input_tensor = weight_qtensor.input_quant_func(input_tensor)
input_tensor = input_tensor.dequantize()
weight_tensor = weight_qtensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)

try:
with torch._C.DisableTorchFunctionSubclass():
return func(*args, **kwargs)
except Exception as e:
raise NotImplementedError(f"{cls.__name__} subclass doesn't implement {func}") from e


def _get_to_kwargs(self, *args, **kwargs):
device, dtype, _, memory_format = torch._C._nn._parse_to(*args, **kwargs)
device = self.device if device is None else device
dtype = self.dtype if dtype is None else dtype
memory_format = (
memory_format if memory_format is not None else torch.preserve_format
)
kwargs = {
"device": device,
"dtype": dtype,
"memory_format": memory_format,
}
return kwargs

def to(self, *args, **kwargs):
kwargs = self._get_to_kwargs(*args, **kwargs)
return self.__class__(
self.int_data.to(kwargs["device"]),
self.scale.to(kwargs["device"]),
self.zero_point.to(kwargs["device"]),
self.quant_min,
self.quant_max,
self.block_size,
self.shape,
self.input_quant_func,
**kwargs,
)

def _apply_fn_to_data(self, fn):
return self.__class__(
fn(self.int_data),
fn(self.scale),
fn(self.zero_point),
self.quant_min,
self.quant_max,
self.block_size,
self.shape,
self.input_quant_func,
dtype=self.dtype,
)

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
# two scenarios where we currently fall back to vanilla mm:
# 1 - when tensor is on CPU: we are missing qmm for CPU, but we should have a CPU implementation
# for consistency and to allow people to test
# 2 - we're given non-floats - quantizing long to int8 is crazy
if (
func in [aten.mm.default, aten.addmm.default]
and args[0].is_floating_point()
and args[0].is_cuda
):
if func == aten.addmm.default:
assert args[1].shape[-1] == args[2].shape[0], (
f"need mat1 shape: {args[1].shape} final"
f"dim to match mat2 shape: {args[2].shape} first dim "
)
input_tensor, weight_qtensor, bias = (
args[1],
args[2],
args[0],
)
else:
assert args[0].shape[-1] == args[1].shape[0], (
f"need mat1 shape: {args[0].shape} final dim"
f"to match mat2 shape: {args[1].shape} first dim"
)
input_tensor, weight_qtensor, bias = (
args[0],
args[1],
None if len(args) == 2 else args[2],
)
if weight_qtensor.input_quant_func is not None:
input_tensor = weight_qtensor.input_quant_func(input_tensor)
input_tensor = input_tensor.dequantize()
weight_tensor = weight_qtensor.dequantize()
return func(input_tensor, weight_tensor, bias)

if func is aten.detach.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.detach)
)

if func is aten.clone.default:
return return_and_correct_aliasing(
func, args, kwargs, args[0]._apply_fn_to_data(torch.clone)
)

if func is aten.t.default:
# TODO: need to implement this
# args[0].transposed = not args[0].transposed
# new = args[0]._change_shape(args[0].shape[::-1])
# return return_and_correct_aliasing(func, args, kwargs, new)
pass

if func is aten._to_copy.default:
return return_and_correct_aliasing(
func,
args,
kwargs,
args[0].to(*args[1:], **kwargs)._apply_fn_to_data(torch.clone),
)

raise NotImplementedError(f"{cls.__name__} dispatch: attempting to run {func}, this is not supported")
