Skip to content

Commit

Permalink
Add torchchat quantizer (#897)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #897

This diff adds a quantizer for the new torchao kernels that is similar to the Int8DynActInt4WeightQuantizer quantizer in torchchat (imported from from torchao.quantization.quant_api).  See the draft torchchat PR (pytorch/torchchat#1070) for how this can integrate with torchchat's quantization API.

I confirmed that models quantized with this are compatible with eager, compile, AOTI, and export to ExecuTorch in torchchat.  They do not run on ExecuTorch because we still have not written an ExecuTorch kernel wrapper.

jerryzh168 this does not use the new subclass API, and this is something I'd like to discuss further with you.  I'll set up a sync with you this week, but I wanted to have some API on the table to ground the discussion.

We do not currently have the required C++ methods implemented to support the new subclass API (e.g., we cannot unpack the packed weights from python; they are instead unpacked inline in the kernel).  From a torchchat user's perspective, I do not think this is important, but I'd like to discuss further.

Differential Revision: D62394341
  • Loading branch information
metascroy authored and facebook-github-bot committed Sep 25, 2024
1 parent 44cdd79 commit bdd1486
Show file tree
Hide file tree
Showing 8 changed files with 432 additions and 351 deletions.
8 changes: 4 additions & 4 deletions torchao/experimental/kernels/cpu/aarch64/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@

add_library(
kernel_aarch64
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/find_min_and_max.cpp
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/reduction/compute_sum.cpp
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/quantization/quantize.cpp
${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64/valpacking/interleave.cpp
)
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ set(CMAKE_BUILD_TYPE Release)
add_compile_options("-Wall" "-Werror")

include(CMakePrintHelpers)
message("TORCHAO_LIBRARIES: ${TORCHAO_LIBRARIES}")
include_directories(${TORCHAO_LIBRARIES})
message("TORCHAO_INCLUDE_DIRS: ${TORCHAO_INCLUDE_DIRS}")
include_directories(${TORCHAO_INCLUDE_DIRS})

add_subdirectory(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64)
add_subdirectory(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/aarch64 ${CMAKE_CURRENT_BINARY_DIR}/kernel_aarch64)

include(${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/Utils.cmake)
include(${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/Utils.cmake)

set(PLATFORM "ATEN" CACHE STRING "Choose platform surface: ATEN, EXECUTORCH")
string(TOUPPER ${PLATFORM} PLATFORM_TO_UPPER)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
# LICENSE file in the root directory of this source tree.

SCRIPT_DIR=$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" &> /dev/null && pwd)
export TORCHAO_LIBRARIES=${SCRIPT_DIR}/../../../../../../..
export TORCHAO_INCLUDE_DIRS=${SCRIPT_DIR}/../../../../../../..

export CMAKE_PREFIX_PATH="$(python -c 'import torch.utils; print(torch.utils.cmake_prefix_path)')"
echo "CMAKE_PREFIX_PATH: ${CMAKE_PREFIX_PATH}"
export CMAKE_OUT=/tmp/cmake-out/torch_ao/examples/torch_custom_op
cmake -DTORCHAO_LIBRARIES=${TORCHAO_LIBRARIES} \
export CMAKE_OUT=/tmp/cmake-out/torchao
cmake -DTORCHAO_INCLUDE_DIRS=${TORCHAO_INCLUDE_DIRS} \
-DCMAKE_PREFIX_PATH=${CMAKE_PREFIX_PATH} \
-DPLATFORM="ATEN" \
-S ${TORCHAO_LIBRARIES}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
-S ${TORCHAO_INCLUDE_DIRS}/torchao/experimental/kernels/cpu/linear/examples/torch_custom_op \
-B ${CMAKE_OUT}
cmake --build ${CMAKE_OUT}
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,21 @@
# LICENSE file in the root directory of this source tree.

import copy
import glob
import os

import sys

import torch
from torch_custom_op import (
linear_a8sz_w_lowbit_reference_impl,
replace_linear_with_quantized_linear,

sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
)
from quant_api import Int8DynActIntxWeightQuantizer

libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
torch.ops.load_library(libs[0])

group_size = 256
m = 1
Expand All @@ -27,15 +36,15 @@

print("Quantizing random model")
quantized_model = copy.deepcopy(model)
quantized_model = quantized_model.eval()
replace_linear_with_quantized_linear(
quantized_model,
kwargs={
"group_size": group_size,
"nbit": nbit,
"has_weight_zeros": has_weight_zeros,
},
quantizer = Int8DynActIntxWeightQuantizer(
device="cpu",
precision=torch.float32,
bitwidth=nbit,
groupsize=group_size,
has_weight_zeros=has_weight_zeros,
)
quantized_model = quantizer.quantize(quantized_model)
quantized_model = quantized_model.eval()

print("Creating random activations")
activations = torch.randn(m, k, dtype=torch.float32)
Expand All @@ -58,44 +67,3 @@
print("Running AOTI")
fn = torch._export.aot_load("/tmp/torch_custom_op_example_model.so", "cpu")
fn(activations)


print("\nChecking correctness on layer 0")
linear = model[0]
quantized_linear = quantized_model[0]

with torch.no_grad():
result = quantized_linear(activations)
expected_result = linear_a8sz_w_lowbit_reference_impl(
linear.weight, activations, group_size, nbit, has_weight_zeros
)
non_quantized_result = linear(activations)


# Check that entries in result match entries in expected_result
num_mismatch_at_low_tol = 0
num_total = result.reshape(-1).shape[0]
for i in range(num_total):
actual_val = result.reshape(-1)[i]
expected_val = expected_result.reshape(-1)[i]
if not torch.allclose(actual_val, expected_val):
num_mismatch_at_low_tol += 1

# If results are not close at a relaxed tolerance, exit with failure
if not torch.allclose(actual_val, expected_val, atol=1e-6):
assert False, "Correctness check failed"

# Assert at most 5% of entries are not close at a low tolerance
assert num_mismatch_at_low_tol / num_total <= 0.05, "Correctness check failed"
print(
"Correctness check passed. All results are close, and ",
(num_total - num_mismatch_at_low_tol),
"/",
num_total,
" entries are close at a low tolerance.",
)
print("Quantization errors:")
print("\tL1 error: ", torch.mean(torch.abs(result - non_quantized_result)).item())
print("\tL2 error: ", torch.mean((result - non_quantized_result) ** 2).item())
print("\tquantized_result[0:5]: ", result[0][0:5])
print("\tnon_quantized_result[0:5]: ", non_quantized_result[0][0:5])

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy

import glob
import os

import sys
import unittest

import torch

sys.path.insert(
0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../../../../.."))
)
from quant_api import (
_Int8DynActIntxWeightQuantizedLinearFallback,
Int8DynActIntxWeightQuantizer,
)

libs = glob.glob("/tmp/cmake-out/torchao/liblowbit_op_aten.*")
libs = list(filter(lambda l: (l.endswith("so") or l.endswith("dylib")), libs))
if len(libs) == 0:
print(
"Could not find library lowbit_op_aten; please run `sh build_custom_op.sh` to build the library. A slow fallback kernel will be used instaed."
)
else:
torch.ops.load_library(libs[0])


class TestInt8DynActIntxWeightQuantizer(unittest.TestCase):
def test_accuracy(self):
group_size = 128
m = 1
n = 1071
k = 4096
activations = torch.randn(m, k, dtype=torch.float32)
model = torch.nn.Sequential(*[torch.nn.Linear(k, n, bias=False)])

for nbit in [1, 2, 3, 4, 5, 6, 7]:
for has_weight_zeros in [True, False]:
print(f"Testing nbit={nbit}, has_weight_zeros={has_weight_zeros}")
quantized_model = copy.deepcopy(model)
quantizer = Int8DynActIntxWeightQuantizer(
device="cpu",
precision=torch.float32,
bitwidth=nbit,
groupsize=group_size,
has_weight_zeros=has_weight_zeros,
)
quantized_model = quantizer.quantize(quantized_model)

with torch.no_grad():
result = quantized_model(activations)
reference_impl = _Int8DynActIntxWeightQuantizedLinearFallback()
reference_impl.quantize_and_pack_weights(
model[0].weight, nbit, group_size, has_weight_zeros
)
expected_result = reference_impl(activations)

num_mismatch_at_low_tol = 0
num_total = result.reshape(-1).shape[0]
for i in range(num_total):
actual_val = result.reshape(-1)[i]
expected_val = expected_result.reshape(-1)[i]
self.assertTrue(torch.allclose(actual_val, expected_val, atol=1e-6))
if not torch.allclose(actual_val, expected_val):
num_mismatch_at_low_tol += 1

# Assert at most 5% of entries are not close at a low tolerance
self.assertTrue(num_mismatch_at_low_tol / num_total <= 0.05)


if __name__ == "__main__":
unittest.main()
Loading

0 comments on commit bdd1486

Please sign in to comment.