diff --git a/gptqmodel/models/base.py b/gptqmodel/models/base.py index f7a0dcd9..085a758f 100644 --- a/gptqmodel/models/base.py +++ b/gptqmodel/models/base.py @@ -170,9 +170,8 @@ def quantize( logger.warning("According to the issue https://github.com/ModelCloud/GPTQModel/issues/278, transformers version 4.43.0 has broken batch_size. until the issue is resolved, hard set the batch_size to 1.") batch_size = 1 - # TODO: lm_head quantization is yet ready but pending - if self.quantize_config.lm_head: - raise ValueError("lm_head quantization is currently inference only and not applicable for quantization. Please set `lm_head=False`.") + if self.quantize_config.lm_head and not isinstance(self.quantize_config, AutoRoundQuantizeConfig): + raise ValueError("`lm_head=True` quantization is only available with AutoRound quantizer. Please use `AutoRoundQuantizeConfig` instead of `QuantizeConfig` and set `lm_head=True` or set `lm_head=False`.") if len(calibration_dataset) == 0: raise ValueError("Calibration dataset must not be empty.")