diff --git a/awq/modules/linear/gemm.py b/awq/modules/linear/gemm.py index 1eb942c2..6efc7ee1 100644 --- a/awq/modules/linear/gemm.py +++ b/awq/modules/linear/gemm.py @@ -253,7 +253,6 @@ def forward(self, x): if input_dtype != torch.float16: out = out.to(dtype=input_dtype) - out = out + self.bias if self.bias is not None else out return out.reshape(out_shape) def extra_repr(self) -> str: