From a087e500bca2baf9a0e8541379f2be49a0e18719 Mon Sep 17 00:00:00 2001
From: Tijmen Blankevoort
Date: Mon, 16 Sep 2024 20:16:16 -0700
Subject: [PATCH] Add Embedding Quantization to QAT module_swap flow (#886)

Summary:
Pull Request resolved: https://github.com/pytorch/ao/pull/886

Add an embedding quantizer to the QAT module-swap flow, set up in the same fashion as the existing module swaps for linear layers.

Differential Revision: D62664322
---
 torchao/quantization/GPTQ.py                  | 35 ++++++++
 .../prototype/qat/_module_swap_api.py         | 81 ++++++++++++++++++--
 torchao/quantization/prototype/qat/utils.py   |  5 ++
 3 files changed, 113 insertions(+), 8 deletions(-)

diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index 23c87141c..cb7d4f4e6 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -965,6 +965,41 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
             self.precision,
         )
 
+
+def _replace_embedding_4w(
+    module: torch.nn.Module,
+    groupsize: int,
+    embedding_class: Type[torch.nn.Module],
+    padding_allowed: bool,
+    copy_weights: bool = False,
+):
+    # import the util function here to avoid a circular dependency
+    from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+
+    def filter_fn(child: torch.nn.Module, cur_fqn: str) -> bool:
+        return isinstance(child, nn.Embedding) and (_check_linear_int4_k(child.embedding_dim, groupsize) or padding_allowed)
+
+    def replacement_fn(child: torch.nn.Module) -> torch.nn.Module:
+        new_embedding = embedding_class(
+            num_embeddings=child.num_embeddings,
+            embedding_dim=child.embedding_dim,
+            padding_idx=child.padding_idx,
+            max_norm=child.max_norm,
+            norm_type=child.norm_type,
+            scale_grad_by_freq=child.scale_grad_by_freq,
+            sparse=child.sparse,
+            device=child.weight.device,
+            groupsize=groupsize,
+        )
+        # In distributed training, the model may be instantiated
+        # on the meta device, in which case there is no need to
+        # copy the weights, and doing so will result in an error
+        if copy_weights and child.weight.device != torch.device("meta"):
+            new_embedding.weight = child.weight
+        return new_embedding
+
+    _replace_with_custom_fn_if_matches_filter(module, replacement_fn, filter_fn)
+
 def _replace_linear_8da4w(
     module: torch.nn.Module,
     groupsize: int,
diff --git a/torchao/quantization/prototype/qat/_module_swap_api.py b/torchao/quantization/prototype/qat/_module_swap_api.py
index e8e1e0b7c..58a10134b 100644
--- a/torchao/quantization/prototype/qat/_module_swap_api.py
+++ b/torchao/quantization/prototype/qat/_module_swap_api.py
@@ -13,6 +13,7 @@
     _check_linear_int4_k,
     _replace_linear_int4,
     _replace_linear_8da4w,
+    _replace_embedding_4w,
     get_groupwise_affine_qparams,
     groupwise_affine_quantize_tensor,
     Int8DynActInt4WeightLinear,
@@ -28,6 +29,7 @@
     _choose_qparams_per_token_asymmetric,
     _fake_quantize_per_channel_group,
     _fake_quantize_per_token,
+    _get_qmin_qmax,
 )
 
 
@@ -47,6 +49,14 @@ class Int8DynActInt4WeightQATQuantizerModuleSwap(Int8DynActInt4WeightQATQuantize
     instead if possible.
     """
 
+    def __init__(self,
+                 quantize_embedding: bool = False,
+                 embedding_groupsize: int = 32,
+                 *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.quantize_embedding = quantize_embedding
+        self.embedding_groupsize = embedding_groupsize
+
     def prepare(
         self,
         model: torch.nn.Module,
@@ -62,6 +72,14 @@ def prepare(
             Int8DynActInt4WeightQATLinear,
             copy_weights=True,
         )
+        if self.quantize_embedding:
+            _replace_embedding_4w(
+                model,
+                self.embedding_groupsize,
+                Int4WeightQATEmbedding,
+                self.padding_allowed,
+                copy_weights=True,
+            )
         return model
 
     def convert(
@@ -92,7 +110,7 @@ def _convert_qat_linear_8da4w(module: torch.nn.Module):
 
             # Load weights and qparams into quantized linear
             n_bit = 4
-            (qmin, qmax) = child._get_qmin_qmax(n_bit)
+            (qmin, qmax) = _get_qmin_qmax(n_bit)
             (s, zp) = get_group_qparams_symmetric(child.weight, n_bit, child.groupsize)
             from torchao._executorch_ops import _quantized_decomposed_quantize_per_channel_group_wrapper
             q_weight = _quantized_decomposed_quantize_per_channel_group_wrapper(
@@ -156,7 +174,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             (act_scales, act_zp) = _choose_qparams_per_token_asymmetric(
                 x, self.scales_precision, self.zero_points_precision,
             )
-            (act_qmin, act_qmax) = self._get_qmin_qmax(8)
+            (act_qmin, act_qmax) = _get_qmin_qmax(8)
             x_fq = _fake_quantize_per_token(
                 x, act_scales, act_zp, act_qmin, act_qmax,
             )
@@ -170,7 +188,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             )
             # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
             weight_zp = weight_zp.to(self.zero_points_precision)
-            (weight_qmin, weight_qmax) = self._get_qmin_qmax(4)
+            (weight_qmin, weight_qmax) = _get_qmin_qmax(4)
             w_fq = _fake_quantize_per_channel_group(
                 self.weight,
                 weight_scales,
@@ -183,11 +201,58 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             w_fq = self.weight
         return F.linear(x_fq, w_fq)
 
-    # TODO: move this to common util
-    def _get_qmin_qmax(self, n_bit: int):
-        qmin = -(2 ** (n_bit - 1))
-        qmax = 2 ** (n_bit - 1) - 1
-        return (qmin, qmax)
+
+class Int4WeightQATEmbedding(torch.nn.Embedding):
+    """
+    This module implements an embedding layer with int4 fake quantized grouped per channel weights.
+
+    args:
+        groupsize: the number of elements in each quantized group for the weights
+        scales_precision: precision of per group scales and zero points
+    """
+
+    def __init__(self,
+                 groupsize: int = 32,
+                 scales_precision: torch.dtype = torch.float32,
+                 *args,
+                 **kwargs):
+        super().__init__(*args, **kwargs)
+        self.bit_width = 4
+        self.groupsize = groupsize
+        self.scales_precision = scales_precision
+        self.zero_points_precision = torch.int32
+        self._fake_quant_enabled = True
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        weight = self.weight
+
+        if self._fake_quant_enabled:
+            (weight_scales, weight_zp) = get_group_qparams_symmetric(
+                self.weight, self.bit_width, self.groupsize, self.scales_precision,
+            )
+            # TODO: pass zp dtype to `get_group_qparams_symmetric` instead
+            weight_zp = weight_zp.to(self.zero_points_precision)
+            (weight_qmin, weight_qmax) = _get_qmin_qmax(self.bit_width)
+            w_fq = _fake_quantize_per_channel_group(
+                self.weight,
+                weight_scales,
+                weight_zp,
+                weight_qmin,
+                weight_qmax,
+                self.groupsize,
+            )
+        else:
+            w_fq = self.weight
+
+        return torch.nn.functional.embedding(
+            x, w_fq, self.padding_idx, self.max_norm,
+            self.norm_type, self.scale_grad_by_freq, self.sparse)
+
+    def enable_fake_quant(self, enabled: bool = True):
+        self._fake_quant_enabled = enabled
+
+    def disable_fake_quant(self):
+        self.enable_fake_quant(False)
 
 
 def enable_8da4w_fake_quant_module_swap(mod: torch.nn.Module):
diff --git a/torchao/quantization/prototype/qat/utils.py b/torchao/quantization/prototype/qat/utils.py
index 625da4e39..1e4b61b8a 100644
--- a/torchao/quantization/prototype/qat/utils.py
+++ b/torchao/quantization/prototype/qat/utils.py
@@ -259,3 +259,8 @@ def insert_subclass(lin):
         return lin
 
     return insert_subclass
+
+def _get_qmin_qmax(n_bit: int):
+    qmin = -(2 ** (n_bit - 1))
+    qmax = 2 ** (n_bit - 1) - 1
+    return (qmin, qmax)
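
For reviewers, a minimal usage sketch of the new option. The class and parameter names come from this patch; the toy model, the group sizes, and the omitted training loop are illustrative placeholders, not part of the change:

import torch
from torchao.quantization.prototype.qat._module_swap_api import (
    Int8DynActInt4WeightQATQuantizerModuleSwap,
)

# A toy model with an embedding and a linear layer, both sized to be divisible
# by the respective group sizes.
model = torch.nn.Sequential(
    torch.nn.Embedding(1024, 256),
    torch.nn.Linear(256, 256),
)

quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap(
    quantize_embedding=True,   # also swap nn.Embedding -> Int4WeightQATEmbedding
    embedding_groupsize=32,    # group size for the int4 embedding weight fake-quant
)

# prepare() swaps in the fake-quantized QAT modules for fine-tuning; per this
# patch, convert() lowers only the linear layers (embeddings stay fake-quantized).
model = quantizer.prepare(model)
# ... QAT fine-tuning loop ...
model = quantizer.convert(model)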
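
And a standalone sketch of the new QAT embedding module itself, showing the effect of toggling fake quantization. Names come from this patch; the sizes are illustrative. The weight is fake-quantized to int4 in groups of `groupsize` elements along the embedding dimension:

import torch
from torchao.quantization.prototype.qat._module_swap_api import Int4WeightQATEmbedding

emb = Int4WeightQATEmbedding(groupsize=32, num_embeddings=128, embedding_dim=64)
tokens = torch.randint(0, 128, (2, 16))

out_fq = emb(tokens)       # lookup through the int4 fake-quantized weight
emb.disable_fake_quant()
out_fp = emb(tokens)       # plain float lookup with the same weights
print((out_fq - out_fp).abs().max())  # group-wise int4 rounding error, should be small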