Unified AffineQuantizedTensor subclass #214

Merged 2 commits on May 7, 2024
67 changes: 60 additions & 7 deletions test/quantization/test_quant_api.py
@@ -87,7 +87,7 @@ def quantize(self, model: torch.nn.Module) -> torch.nn.Module:
apply_dynamic_quant(model)
return model

class M(torch.nn.Module):
class ToyLinearModel(torch.nn.Module):
def __init__(self):
super().__init__()
self.linear1 = torch.nn.Linear(64, 32, bias=False).to(torch.float)
@@ -103,7 +103,7 @@ def forward(self, x):

class TestQuantFlow(unittest.TestCase):
def test_dynamic_quant_gpu_singleline(self):
m = M().eval()
m = ToyLinearModel().eval()
m = _apply_dynamic_quant(m)
quantized = m(*m.example_inputs())
# AssertionError: Expecting input to have dtype torch.float32, but got dtype: torch.float64
@@ -116,7 +116,7 @@ def test_dynamic_quant_gpu_singleline(self):
@unittest.skip("skipping for now due to torch.compile error")
def test_dynamic_quant_gpu_unified_api_unified_impl(self):
quantizer = XNNPackDynamicQuantizer()
m = M().eval()
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
m = quantizer.prepare(m)
m = quantizer.convert(m)
@@ -131,7 +131,7 @@ def test_dynamic_quant_gpu_unified_api_unified_impl(self):
@unittest.skip("FAILED test/quantization/test_quant_api.py::TestQuantFlow::test_dynamic_quant_gpu_unified_api_eager_mode_impl - AssertionError: Tensor-likes are not equal!")
def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):
quantizer = TorchCompileDynamicQuantizer()
m = M().eval()
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
m = quantizer.quantize(m)
quantized = m(*example_inputs)
@@ -141,7 +141,7 @@ def test_dynamic_quant_gpu_unified_api_eager_mode_impl(self):

@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
def test_int8_wo_quant_save_load(self):
m = M().eval().cpu()
m = ToyLinearModel().eval().cpu()
apply_weight_only_int8_quant(m)
example_inputs = m.example_inputs()
ref = m(*example_inputs)
@@ -150,7 +150,7 @@ def test_int8_wo_quant_save_load(self):

state_dict = torch.load(_TMP_FN)
os.remove(_TMP_FN)
m2 = M().eval()
m2 = ToyLinearModel().eval()
apply_weight_only_int8_quant(m2)
m2.load_state_dict(state_dict)
m2 = m2.to(device="cuda")
@@ -165,7 +165,7 @@ def test_8da4w_quantizer(self):
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

quantizer = Int8DynActInt4WeightQuantizer(groupsize=32)
m = M().eval()
m = ToyLinearModel().eval()
example_inputs = m.example_inputs()
m = quantizer.quantize(m)
assert isinstance(m.linear1, Int8DynActInt4WeightLinear)
@@ -392,5 +392,58 @@ def test_eval_wrapper(self):
f"accuracy regressed from 7.76 to {result['results']['wikitext']['word_perplexity,none']}"
)

# TODO: move to a separate test file
@unittest.skipIf(not TORCH_VERSION_AFTER_2_4, "Test only enabled for 2.4+")
def test_quantized_tensor_subclass_8da4w(self):
from torchao.quantization.subclass import AffineQuantizedTensor
from torchao.quantization.quant_primitives import MappingType
import copy

# weight settings
Member: Do you expect this list to be fairly constant over time? Should we consider some dataclass-like object for the config?
Also, things like quant_min and quant_max should be derivable, assuming the target_dtype is int8.

Contributor (Author): We could have some helper functions to give us everything here in the future, I think.
For quant_min/quant_max, right now this is using 4 bit; we don't have torch.int4 right now, but we could add that.
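For illustration only (my addition, not part of the PR), a minimal sketch of the kind of helper discussed above, assuming quant_min/quant_max are derived from an explicit bit width because torch.int4 does not exist yet; derive_qmin_qmax is a hypothetical name:

```python
import torch

def derive_qmin_qmax(target_dtype: torch.dtype, bit_width=None):
    """Derive (quant_min, quant_max) for a signed integer quantization scheme.

    If bit_width is given (e.g. 4), it narrows the range of target_dtype, which is
    how a 4-bit scheme can be stored in an int8 container tensor.
    """
    if bit_width is not None:
        return -(2 ** (bit_width - 1)), 2 ** (bit_width - 1) - 1
    info = torch.iinfo(target_dtype)
    return info.min, info.max

# derive_qmin_qmax(torch.int8, bit_width=4) -> (-8, 7), matching the values in this test
```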

groupsize = 32
mapping_type = MappingType.SYMMETRIC
block_size = (1, groupsize)
target_dtype = torch.int8
eps = torch.finfo(torch.float32).eps
quant_min = -8
quant_max = 7

# TODO: make a general helper function?
def get_per_token_block_size(x):
block_size = []
for i in range(len(x.shape)-1):
block_size.append(1)
block_size.append(x.shape[-1])
return block_size

# input settings
input_mapping_type = MappingType.ASYMMETRIC
input_target_dtype = torch.int8
input_quant_func = lambda x: AffineQuantizedTensor.from_float(x, input_mapping_type, get_per_token_block_size(x), input_target_dtype)

m = ToyLinearModel().eval()
m_copy = copy.deepcopy(m)
example_inputs = m.example_inputs()
m.linear1.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear1.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
Contributor: Instead of AffineQuantizedTensor.from_float, could we have a factory function similar to to_nf4 (def to_nf4(tensor, block_size: int = 64, scaler_block_size: int = 256))?

Contributor (Author): This is not the final UI; we will need to integrate this with things like https://github.com/pytorch/ao/blob/main/torchao/quantization/quant_api.py#L129 as well. I'm not sure how to_nf4 would fit in there, but I think we could discuss further how to expose this to end users.
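Purely as an illustration of the reviewer's suggestion (not an API defined in this PR; to_affine_quantized is a hypothetical name), a to_nf4-style factory could be a thin wrapper:

```python
# Hypothetical convenience factory, analogous in spirit to to_nf4; the actual
# user-facing API is still under discussion in this thread.
def to_affine_quantized(tensor, mapping_type, block_size, target_dtype, **kwargs):
    return AffineQuantizedTensor.from_float(
        tensor, mapping_type, block_size, target_dtype, **kwargs
    )
```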

Contributor: What is input_quant_func?

Contributor (Author): input_quant_func is for quantizing the input (in dynamic quantization).
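A minimal sketch of the idea (my simplification, assuming the subclass exposes a dequantize() method; the real __torch_dispatch__ path is more involved):

```python
import torch

def dynamic_quant_linear(activation: torch.Tensor, weight) -> torch.Tensor:
    # weight is assumed to be an AffineQuantizedTensor that carries input_quant_func.
    if weight.input_quant_func is not None:
        # Dynamic quantization: quantize the float activation at call time, then
        # dequantize so the reference matmul below stays in float.
        activation = weight.input_quant_func(activation).dequantize()
    return torch.nn.functional.linear(activation, weight.dequantize())
```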

Contributor: I don't think this is a good idea. What if you have an op that takes in two tensors plus a quantized weight? Now you would need two input_quant_funcs. We are expanding the execution semantics of the op (in this case linear) for which this class's __torch_dispatch__ is invoked. If you wanted to do this, I would have used another tensor subclass, AffineQuantizedDynamicLinear, whose semantics are clearer in terms of how it overrides linear specifically. But I don't know if we can truly generalize this.

At a high level, it is still not clear how we will use a tensor subclass to represent quantized compute.

m.linear2.weight = torch.nn.Parameter(AffineQuantizedTensor.from_float(m.linear2.weight, mapping_type, block_size, target_dtype, quant_min, quant_max, eps, input_quant_func=input_quant_func), requires_grad=False)
assert isinstance(m.linear1.weight, AffineQuantizedTensor)
assert isinstance(m.linear2.weight, AffineQuantizedTensor)

# reference
from torchao.quantization.quant_api import Int8DynActInt4WeightQuantizer
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

quantizer = Int8DynActInt4WeightQuantizer(groupsize=groupsize)
m_copy = quantizer.quantize(m_copy)
assert isinstance(m_copy.linear1, Int8DynActInt4WeightLinear)
assert isinstance(m_copy.linear2, Int8DynActInt4WeightLinear)

res = m(*example_inputs)
ref = m_copy(*example_inputs)
self.assertTrue(torch.equal(res, ref))




if __name__ == "__main__":
unittest.main()
16 changes: 9 additions & 7 deletions torchao/quantization/quant_primitives.py
@@ -136,7 +136,7 @@ def _get_reduction_params(block_size, input_size):

def quantize_affine(
input: torch.Tensor,
block_size: List[int],
block_size: Tuple[int, ...],
Contributor: Why this change?

Contributor (Author): I think a Tuple is better for the block_size argument, since it's immutable.

scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
output_dtype: torch.dtype,
@@ -146,7 +146,7 @@ def quantize_affine(
"""
Args:
input (torch.Tensor): original float32 or bfloat16 Tensor
block_size: (List[int]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
block_size: (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
scale (float): quantization parameter for affine quantization
zero_point (int): quantization parameter for affine quantization
@@ -191,7 +191,7 @@ def quantize_affine(

def dequantize_affine(
input: torch.Tensor,
block_size: List[int],
block_size: Tuple[int, ...],
scale: torch.Tensor,
zero_point: Optional[torch.Tensor],
input_dtype: torch.dtype,
Expand Down Expand Up @@ -244,7 +244,7 @@ class MappingType(Enum):
def choose_qparams_affine(
input: torch.Tensor,
mapping_type: MappingType,
block_size: List[int],
block_size: Tuple[int, ...],
target_dtype: torch.dtype,
quant_min: Optional[int] = None,
quant_max: Optional[int] = None,
@@ -256,12 +256,14 @@ def choose_qparams_affine(
Args:
input (torch.Tensor): fp32, bf16, fp16 input Tensor
mapping_type (MappingType): determines how the qparams are calculated, symmetric or asymmetric
block_size: (Tuple[int, ...]): granularity of quantization, this means the size of the tensor elements that's sharing the same qparam
e.g. when size is the same as the input tensor dimension, we are using per tensor quantization
target_dtype (torch.dtype): dtype for target quantized Tensor
quant_min (Optional[int]): minimum quantized value for target quantized Tensor
quant_max (Optioanl[int]): maximum quantized value for target quantized Tensor
eps (Optional[float]: minimum scale
scale_dtype (torch.dtype): dtype for scales
zero_point_dtype (torch.dtype): dtype for zero_points
eps (Optional[float]): minimum scale, if not provided, default to eps of input.dtype
Contributor: Where can the user find the eps of input.dtype?

Contributor (Author): Just by running torch.finfo(dtype).eps, I think; I didn't find a table showing this.

Contributor: Hm, it's a bit of a nit, but you could just add that snippet, like: "minimum scale, if not provided, default to torch.finfo(input.dtype).eps".
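For reference (my addition, not part of the diff), the eps values in question can be inspected directly:

```python
import torch

print(torch.finfo(torch.float32).eps)   # ~1.19e-07
print(torch.finfo(torch.bfloat16).eps)  # 0.0078125
print(torch.finfo(torch.float16).eps)   # ~9.77e-04
```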

scale_dtype (torch.dtype): dtype for scale Tensor
zero_point_dtype (torch.dtype): dtype for zero_point Tensor

Output:
Tuple of scales and zero_points Tensor with requested dtype
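To tie the pieces together, a hedged usage sketch of the primitives touched by this diff, using only the parameters visible in the hunks above (defaults are assumed for everything else, and the actual default values may differ from this illustration):

```python
import torch
from torchao.quantization.quant_primitives import (
    MappingType,
    choose_qparams_affine,
    quantize_affine,
    dequantize_affine,
)

x = torch.randn(4, 64)

# block_size controls granularity: (4, 64) would be per-tensor, (1, 64) per-row,
# and (1, 32) groups of 32 elements within each row (groupsize=32, as in the test).
block_size = (1, 32)

scale, zero_point = choose_qparams_affine(x, MappingType.SYMMETRIC, block_size, torch.int8)
q = quantize_affine(x, block_size, scale, zero_point, torch.int8)
x_hat = dequantize_affine(q, block_size, scale, zero_point, torch.int8)

print((x - x_hat).abs().max())  # quantization error, expected to be small but nonzero
```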