diff --git a/quantize.py b/quantize.py index 635d473f69..04364fe801 100644 --- a/quantize.py +++ b/quantize.py @@ -688,7 +688,7 @@ def forward(self, indices: torch.Tensor) -> torch.Tensor: weight_odd = self.weight.remainder(16) weight_unpacked = torch.stack((weight_even, weight_odd), dim=-1) weight = weight_unpacked.view(self.weight.shape[0], -1) - weight = weight.view(torch.int8).add(-8) + weight = weight.to(torch.int8).add(-8) else: weight = self.weight