Summary: Pull Request resolved: #897

This diff adds a quantizer for the new torchao kernels, similar to the Int8DynActInt4WeightQuantizer quantizer in torchchat (imported from torchao.quantization.quant_api). See the draft torchchat PR (pytorch/torchchat#1070) for how this can integrate with torchchat's quantization API.

I confirmed that models quantized with this are compatible with eager, compile, AOTI, and export to ExecuTorch in torchchat. They do not run on ExecuTorch because we have not yet written an ExecuTorch kernel wrapper.

@jerryzh168, this does not use the new subclass API, and that is something I'd like to discuss further with you. I'll set up a sync with you this week, but I wanted to have some API on the table to ground the discussion. We do not currently have the C++ methods required to support the new subclass API (e.g., we cannot unpack the packed weights from Python; they are instead unpacked inline in the kernel). From a torchchat user's perspective, I do not think this matters, but I'd like to discuss further.

Differential Revision: D62394341
1 parent 44cdd79 · commit bdd1486
Showing 8 changed files with 432 additions and 351 deletions.
56 changes: 0 additions & 56 deletions
torchao/experimental/kernels/cpu/linear/examples/torch_custom_op/test_custom_op.py
This file was deleted.
79 changes: 79 additions & 0 deletions
...al/kernels/cpu/linear/examples/torch_custom_op/test_int8_dyn_act_intx_weight_quantizer.py
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy

import glob
import os

import sys
import unittest

import torch

sys.path.insert(
    0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
)
from quant_api import (
    _Int8DynActIntxWeightQuantizedLinearFallback,
    Int8DynActIntxWeightQuantizer,
)

libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
if len(libs) == 0:
    print(
        "Could not find library lowbit_op_aten; please run `sh build_custom_op.sh` to build the library. A slow fallback kernel will be used instead."
    )
else:
    torch.ops.load_library(libs[0])


class TestInt8DynActIntxWeightQuantizer(unittest.TestCase):
    def test_accuracy(self):
        group_size = 128
        m = 1
        n = 1071
        k = 4096
        activations = torch.randn(m, k, dtype=torch.float32)
        model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])

        for nbit in [1, 2, 3, 4, 5, 6, 7]:
            for has_weight_zeros in [True, False]:
                print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
                quantized_model = copy.deepcopy(model)
                quantizer = Int8DynActIntxWeightQuantizer(
                    device="cpu",
                    precision=torch.float32,
                    bitwidth=nbit,
                    groupsize=group_size,
                    has_weight_zeros=has_weight_zeros,
                )
                quantized_model = quantizer.quantize(quantized_model)

                with torch.no_grad():
                    result = quantized_model(activations)
                    reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback()
                    reference_impl.quantize_and_pack_weights(
                        model[0].weight, nbit, group_size, has_weight_zeros
                    )
                    expected_result = reference_impl(activations)

                num_mismatch_at_low_tol = 0
                num_total = result.reshape(-1).shape[0]
                for i in range(num_total):
                    actual_val = result.reshape(-1)[i]
                    expected_val = expected_result.reshape(-1)[i]
                    # Every entry must match the reference at a loose tolerance.
                    self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6))
                    # Count entries that also fail torch.allclose's tighter default tolerance.
                    if not torch.allclose(actual_val, expected_val):
                        num_mismatch_at_low_tol += 1

                # Assert at most 5% of entries are not close at the tighter default tolerance
                self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05)


if __name__ == "__main__":
    unittest.main()
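Note: to exercise the optimized kernel rather than the slow fallback, build the shared library first (sh build_custom_op.sh, per the message printed by the test above), then run the test directly, e.g. python test_int8_dyn_act_intx_weight_quantizer.py.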