Unified AffineQuantizedTensor subclass #214
--- a/test/quantization/test_quant_api.py
+++ b/test/quantization/test_quant_api.py
@@ -87,7 +87,7 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
         apply_dynamic_quant(model)
         return model
 
-class M(torch.nn.Module):
+class ToyLinearModel(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
@@ -103,7 +103,7 @@ def forward(self, x):
 
 class TestQuantFlow(unittest.TestCase):
     def test_dynamic_quant_gpu_singleline(self):
-        m = M().eval()
+        m = ToyLinearModel().eval()
         m = _apply_dynamic_quant(m)
         quantized = m(*m.example_inputs())
         # AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
@@ -116,7 +116,7 @@ def test_dynamic_quant_gpu_singleline(self):
     @unittest.skip("skipping for now due to torch.compile error")
     def test_dynamic_quant_gpu_unified_api_unified_impl(self):
         quantizer = XNNPackDynamicQuantizer()
-        m = M().eval()
+        m = ToyLinearModel().eval()
         example_inputs = m.example_inputs()
         m = quantizer.prepare(m)
         m = quantizer.convert(m)
@@ -131,7 +131,7 @@ def test_dynamic_quant_gpu_unified_api_unified_impl(self):
     @unittest.skip("FAILED test/quantization/test_quant_api.py::TestQuantFlow::test_dynamic_quant_gpu_unified_api_eager_mode_impl - AssertionError: Tensor-likes are not equal!")
     def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
         quantizer = TorchCompileDynamicQuantizer()
-        m = M().eval()
+        m = ToyLinearModel().eval()
         example_inputs = m.example_inputs()
         m = quantizer.quantize(m)
         quantized = m(*example_inputs)
@@ -141,7 +141,7 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
     def test_int8_wo_quant_save_load(self):
-        m = M().eval().cpu()
+        m = ToyLinearModel().eval().cpu()
         apply_weight_only_int8_quant(m)
         example_inputs = m.example_inputs()
         ref = m(*example_inputs)
@@ -150,7 +150,7 @@ def test_int8_wo_quant_save_load(self):
 
         state_dict = torch.load(_TMP_FN)
         os.remove(_TMP_FN)
-        m2 = M().eval()
+        m2 = ToyLinearModel().eval()
         apply_weight_only_int8_quant(m2)
         m2.load_state_dict(state_dict)
         m2 = m2.to(device="cuda")
@@ -165,7 +165,7 @@ def test_8da4w_quantizer(self):
         from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
 
         quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
-        m = M().eval()
+        m = ToyLinearModel().eval()
         example_inputs = m.example_inputs()
         m = quantizer.quantize(m)
         assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
@@ -392,5 +392,58 @@ def test_eval_wrapper(self):
             f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
         )
 
+    # TODO: move to a separate test file
+    @unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
+    def test_quantized_tensor_subclass_8da4w(self):
+        from torchao.quantization.subclass import AffineQuantizedTensor
+        from torchao.quantization.quant_primitives import MappingType
+        import copy
+
+        # weight settings
+        groupsize = 32
+        mapping_type = MappingType.SYMMETRIC
+        block_size = (1, groupsize)
+        target_dtype = torch.int8
+        eps = torch.finfo(torch.float32).eps
+        quant_min = -8
+        quant_max = 7
+
+        # TODO: make a general helper function?
+        def get_per_token_block_size(x):
+            block_size = []
+            for i in range(len(x.shape)-1):
+                block_size.append(1)
+            block_size.append(x.shape[-1])
+            return block_size
+
+        # input settings
+        input_mapping_type = MappingType.ASYMMETRIC
+        input_target_dtype = torch.int8
+        input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)
+
+        m = ToyLinearModel().eval()
+        m_copy = copy.deepcopy(m)
+        example_inputs = m.example_inputs()
+        m.linear1.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear1.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
Review comment: Instead of: ao/torchao/dtypes/nf4tensor.py, Line 883 in f6d56ca

Review comment: this is not the final UI, we will need to integrate this with things like https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py#L129 as well, I'm not sure how

Review comment: What is ...

Review comment: I don't think this is a good idea. What if you have an op that takes in two tensors plus a quantized weight? Now you will need two input_quant_funcs? We are expanding the execution semantics of the op, in this case linear, for which this class's ... At a high level, it is still not clear how we will use tensor subclasses to represent quantized compute.
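To make the discussion concrete, here is a minimal, illustrative sketch of the pattern being debated: a weight tensor subclass that intercepts F.linear and applies its input_quant_func to the activation. This is not the PR's implementation; the class name and the identity input_quant_func below are made up for the example.

import torch
import torch.nn.functional as F

class WeightWithInputQuant(torch.Tensor):
    """Hypothetical weight subclass carrying an input_quant_func."""

    @staticmethod
    def __new__(cls, weight, input_quant_func=None):
        return torch.Tensor._make_subclass(cls, weight, require_grad=False)

    def __init__(self, weight, input_quant_func=None):
        self.input_quant_func = input_quant_func

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w = args[0], args[1]
            bias = args[2] if len(args) > 2 else None
            if w.input_quant_func is not None:
                # this is the point of contention: the op's execution
                # semantics are extended to quantize the activation here
                x = w.input_quant_func(x)
            # unwrap to a plain Tensor to avoid re-entering this override
            return F.linear(x, w.as_subclass(torch.Tensor), bias)
        with torch._C.DisableTorchFunctionSubclass():
            return func(*args, **kwargs)

w = WeightWithInputQuant(torch.randn(32, 64), input_quant_func=lambda x: x)
y = F.linear(torch.randn(2, 64), w)  # dispatches through __torch_function__

An op that takes two activation inputs would indeed need a policy for quantizing each of them, which is the objection raised above.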
+        m.linear2.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear2.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
+        assert isinstance(m.linear1.weight, AffineQuantizedTensor)
+        assert isinstance(m.linear2.weight, AffineQuantizedTensor)
+
+        # reference
+        from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
+        from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
+
+        quantizer = Int8DynActInt4WeightQuantizer(groupsize=groupsize)
+        m_copy = quantizer.quantize(m_copy)
+        assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear)
+        assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear)
+
+        res = m(*example_inputs)
+        ref = m_copy(*example_inputs)
+        self.assertTrue(torch.equal(res, ref))
+
+
 if __name__ == "__main__":
     unittest.main()
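As a reference for the block_size values used in the new test, here is a small worked sketch of the convention (assuming the semantics documented in quant_primitives below: each block of block_size elements shares one scale/zero_point):

# for a 2D tensor of shape (8, 64):
x_shape = (8, 64)
per_tensor = x_shape           # one qparam for the whole tensor
per_token  = (1, x_shape[-1])  # one qparam per row; what get_per_token_block_size builds
per_group  = (1, 32)           # one qparam per group of 32 along the last dim, as in the test

# the test's helper, written more compactly (equivalent behavior):
def get_per_token_block_size(x):
    return [1] * (x.dim() - 1) + [x.shape[-1]]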
--- a/torchao/quantization/quant_primitives.py
+++ b/torchao/quantization/quant_primitives.py
@@ -136,7 +136,7 @@ def _get_reduction_params(block_size, input_size):
 
 def quantize_affine(
     input: torch.Tensor,
-    block_size: List[int],
+    block_size: Tuple[int, ...],
     scale: torch.Tensor,
     zero_point: Optional[torch.Tensor],
     output_dtype: torch.dtype,

Review comment: Why this change?

Review comment: I think a Tuple is better for the block_size argument, since it's immutable.
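A two-line illustration of the immutability point (not from the PR):

block_size = (1, 32)
# block_size[0] = 2              # tuples reject mutation: TypeError
cache = {block_size: "qparams"}  # and, being hashable, a tuple can key a dict; a list cannot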
@@ -146,7 +146,7 @@ def quantize_affine(
     """
     Args:
       input (torch.Tensor): original float32 or bfloat16 Tensor
-      block_size: (List[int]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
+      block_size: (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
         e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
       scale (float): quantization parameter for affine quantization
       zero_point (int): quantization parameter for affine quantization
@@ -191,7 +191,7 @@ def quantize_affine(
 
 def dequantize_affine(
     input: torch.Tensor,
-    block_size: List[int],
+    block_size: Tuple[int, ...],
     scale: torch.Tensor,
     zero_point: Optional[torch.Tensor],
     input_dtype: torch.dtype,
@@ -244,7 +244,7 @@ class MappingType(Enum):
 def choose_qparams_affine(
     input: torch.Tensor,
     mapping_type: MappingType,
-    block_size: List[int],
+    block_size: Tuple[int, ...],
     target_dtype: torch.dtype,
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
@@ -256,12 +256,14 @@ def choose_qparams_affine(
     Args:
       input (torch.Tensor): fp32, bf16, fp16 input Tensor
       mapping_type (MappingType): determines how the qparams are calculated, symmetric or asymmetric
+      block_size: (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
+        e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
       target_dtype (torch.dtype): dtype for target quantized Tensor
       quant_min (Optional[int]): minimum quantized value for target quantized Tensor
       quant_max (Optioanl[int]): maximum quantized value for target quantized Tensor
-      eps (Optional[float]: minimum scale
-      scale_dtype (torch.dtype): dtype for scales
-      zero_point_dtype (torch.dtype): dtype for zero_points
+      eps (Optional[float]): minimum scale, if not provided, default to eps of input.dtype
+      scale_dtype (torch.dtype): dtype for scale Tensor
+      zero_point_dtype (torch.dtype): dtype for zero_point Tensor
 
     Output:
       Tuple of scales and zero_points Tensor with requested dtype

Review comment: Where can the user find eps?

Review comment: just running torch.finfo(dtype).eps

Review comment: Hm, it's a bit of a nit, but you could just add that snippet, like torch.finfo(dtype).eps, to the docstring.
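For reference, the snippet in question is a one-liner; these are the standard machine-epsilon values per float dtype:

import torch

torch.finfo(torch.float32).eps   # 1.1920928955078125e-07
torch.finfo(torch.bfloat16).eps  # 0.0078125
torch.finfo(torch.float16).eps   # 0.0009765625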
Review comment: Do you expect this list to be fairly constant over time? Should we consider some dataclass-like object for the config? Also, things like quant_min and quant_max should be derivable assuming the target_dtype is int8.

Review comment: We could have some helper functions to give us everything here in the future, I think. For quant_min/quant_max, right now this is using 4 bit; we don't have torch.int4 right now, but we could add that.
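A possible shape for such a helper, purely illustrative (nothing like this is in the PR): derive quant_min/quant_max from the target dtype, with an optional bit width for sub-byte schemes such as the 4-bit case above.

import torch
from typing import Optional, Tuple

def derive_qmin_qmax(target_dtype: torch.dtype, bit_width: Optional[int] = None) -> Tuple[int, int]:
    # hypothetical helper: signed range from an explicit bit width
    # (for values stored in a wider dtype), else from the dtype itself
    if bit_width is not None:
        return -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1
    info = torch.iinfo(target_dtype)
    return info.min, info.max

assert derive_qmin_qmax(torch.int8) == (-128, 127)
assert derive_qmin_qmax(torch.int8, bit_width=4) == (-8, 7)  # the test's quant_min/quant_max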