From 4595971e2d31769dfc5c3fe92ff4f63d620c2fb6 Mon Sep 17 00:00:00 2001
From: Manuel Candales
Date: Thu, 25 Jul 2024 14:21:20 -0400
Subject: [PATCH] Adapt to _convert_weight_to_int4pack new behavior

---
 torchao/quantization/GPTQ.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/torchao/quantization/GPTQ.py b/torchao/quantization/GPTQ.py
index e45bb26e4d..4a7911d90f 100644
--- a/torchao/quantization/GPTQ.py
+++ b/torchao/quantization/GPTQ.py
@@ -720,6 +720,8 @@ def _create_quantized_state_dict(
                     self.precision, # dtype for scales_and_zeros
                 )
                 # TODO: just get the device from mod.weight.device?
+                w_cpu = w_int4x8.cpu()
+                w_int4x8 = (w_cpu[::, ::2] << 4 | w_cpu[::, 1::2]).to(torch.uint8)
                 weight_int4pack = torch.ops.aten._convert_weight_to_int4pack(w_int4x8.to(self.device), self.inner_k_tiles)
                 cur_state_dict[f"{fqn}.weight"] = weight_int4pack.to(self.device)
                 cur_state_dict[f"{fqn}.scales_and_zeros"] = scales_and_zeros.to(self.device)