Add 2, 3, 4, 5 bit custom ops (pytorch#828)
Summary:
Pull Request resolved: pytorch#828

Refactors some of the custom op example code to add support for more kernel variants: 2-, 3-, 4-, and 5-bit weights, each with and without weight zeros.

Reviewed By: digantdesai

Differential Revision: D62248716
metascroy authored and facebook-github-bot committed Sep 9, 2024
1 parent 1b317f9 commit 7af35e6
Showing 3 changed files with 360 additions and 127 deletions.
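
The operators added by this commit follow a uniform naming scheme. As a reading aid (an inference from the registrations below, not text from the commit), the prefix "a8sz" appears to denote 8-bit activations quantized with scales and zeros, while the weight suffix "w{nbit}s" marks nbit weights with scales only and "w{nbit}sz" marks nbit weights with scales and zeros. A hypothetical helper spelling out that convention:

# Hypothetical helper, not part of the commit: maps a configuration to the
# operator names registered in TORCH_LIBRARY(torchao, m) further below.
def torchao_op_name(kind, nbit, has_weight_zeros):
    assert kind in ("_pack_weights", "_linear") and nbit in (2, 3, 4, 5)
    return f"{kind}_a8sz_w{nbit}{'sz' if has_weight_zeros else 's'}"

# torchao_op_name("_linear", 5, True) == "_linear_a8sz_w5sz"
# torchao_op_name("_pack_weights", 3, False) == "_pack_weights_a8sz_w3s"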
@@ -4,16 +4,21 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from torch_custom_op import quantize, replace_linear_with_quantized_linear
import torch
import copy

group_size = 16
import torch
from torch_custom_op import (
    linear_a8sz_w_lowbit_reference_impl,
    replace_linear_with_quantized_linear,
)

group_size = 256
m = 1
n = 4096
k = 4096
nbit = 4
n_layers = 10
nbit = 5
has_weight_zeros = True
n_layers = 5

print("Creating random model")
layers = [torch.nn.Linear(k, n, bias=False) for _ in range(n_layers)]
@@ -22,8 +27,15 @@

print("Quantizing random model")
quantized_model = copy.deepcopy(model)
quantized_model = quantized_model.eval()
replace_linear_with_quantized_linear(quantized_model, kwargs={"group_size": group_size, "nbit": nbit})
quantized_model = quantized_model.eval()
replace_linear_with_quantized_linear(
    quantized_model,
    kwargs={
        "group_size": group_size,
        "nbit": nbit,
        "has_weight_zeros": has_weight_zeros,
    },
)

print("Creating random activations")
activations = torch.randn(m, k, dtype=torch.float32)
@@ -48,36 +60,42 @@
fn(activations)


print("Checking correctness on layer 0")

rtol=1e-05

# default is 1e-8, but PyTorch and C++ (and ARM neon) have different rounding
# conventions for ties (PyTorch rounds half to even and C++ rounds half to odd)
# TODO(T200109708): address this
atol=1e-05

print("\nChecking correctness on layer 0")
linear = model[0]
quantized_linear = quantized_model[0]
weight_qvals, weight_scales = quantize(linear.weight, group_size, quantized_linear.nbit, scale_only=True)

activation_qvals, activations_scales, activations_zeros = quantize(activations, k, 8, False)
activations_dequantized = activations_scales * (activation_qvals - activations_zeros)
weights_dequantized = (weight_qvals.reshape(-1, group_size) * weight_scales.reshape(-1, 1)).reshape(n, k)

with torch.no_grad():
    result = quantized_linear(activations)
    expected_result = torch.matmul(activations_dequantized, weights_dequantized.transpose(1, 0))
    expected_result = linear_a8sz_w_lowbit_reference_impl(
        linear.weight, activations, group_size, nbit, has_weight_zeros
    )
    non_quantized_result = linear(activations)

if not (torch.allclose(result, expected_result, rtol=rtol, atol=atol)):
    rand_idxs = torch.randint(0, result.shape[1], (5,))
    print("rand_idxs: ", rand_idxs)
    print("kernel_result[rand_idxs]: ", result[0][rand_idxs])
    print("expected_result[rand_idxs]: ", expected_result[0][rand_idxs])
    assert False
else:
    print("Correctness check passed")

print("kernel_result[0:5]: ", result[0][0:5])
print("non_quantized_result[0:5]: ", non_quantized_result[0][0:5])

# Check that entries in result match entries in expected_result
num_mismatch_at_low_tol = 0
num_total = result.reshape(-1).shape[0]
for i in range(num_total):
    actual_val = result.reshape(-1)[i]
    expected_val = expected_result.reshape(-1)[i]
    if not torch.allclose(actual_val, expected_val):
        num_mismatch_at_low_tol += 1

        # If results are not close at a relaxed tolerance, exit with failure
        if not torch.allclose(actual_val, expected_val, atol=1e-6):
            assert False, "Correctness check failed"

# Assert at most 5% of entries are not close at a low tolerance
assert num_mismatch_at_low_tol / num_total <= 0.05, "Correctness check failed"
print(
    "Correctness check passed. All results are close, and ",
    (num_total - num_mismatch_at_low_tol),
    "/",
    num_total,
    " entries are close at a low tolerance.",
)
print("Quantization errors:")
print("\tL1 error: ", torch.mean(torch.abs(result - non_quantized_result)).item())
print("\tL2 error: ", torch.mean((result - non_quantized_result) ** 2).item())
print("\tquantized_result[0:5]: ", result[0][0:5])
print("\tnon_quantized_result[0:5]: ", non_quantized_result[0][0:5])
@@ -11,7 +11,7 @@
#include <torchao/experimental/kernels/cpu/parallel.h>

template <int weight_nbit>
at::Tensor pack_weights_cpu(
at::Tensor pack_weights_without_zeros_cpu(
const at::Tensor& weight_qvals,
const at::Tensor& weight_scales,
// TODO(T200095131): convert to int64_t when supported by AOTI
@@ -54,9 +54,8 @@ at::Tensor pack_weights_cpu(

auto packed_weight_data_size =
get_packed_weight_data_size(ukernel_config, n, k, group_size);
auto options = torch::TensorOptions().dtype(torch::kInt8);

at::Tensor packed_weights = torch::empty({packed_weight_data_size}, options);
at::Tensor packed_weights =
torch::empty({packed_weight_data_size}, torch::kInt8);
pack_weight_data_operator(
ukernel_config,
pack_weight_tiling_params,
@@ -72,7 +71,74 @@
}

template <int weight_nbit>
at::Tensor pack_weights_meta(
at::Tensor pack_weights_with_zeros_cpu(
const at::Tensor& weight_qvals,
const at::Tensor& weight_scales,
const at::Tensor& weight_zeros,
// TODO(T200095131): convert to int64_t when supported by AOTI
// group_size is a meta tensor with size (group_size)
const at::Tensor& group_size_tensor) {
int64_t group_size = group_size_tensor.size(0);

TORCH_CHECK(
weight_qvals.dtype() == torch::kInt8, "weight_qvals must be int8");
TORCH_CHECK(weight_qvals.dim() == 2, "weight_qvals must be 2D");

// In PyTorch, weights are nxk in row-major format (with activations being
// right-multiplied).
// In kernel, activations are left-multiplied by kxn transposed
// weights in column-major format.
// Note the underlying data is the same in both cases
int n = weight_qvals.size(0);
int k = weight_qvals.size(1);

TORCH_CHECK(
weight_scales.dtype() == torch::kFloat32,
"weight_scales must be float32");
TORCH_CHECK(weight_scales.dim() == 1, "weight_scales must be 1D");
TORCH_CHECK(
weight_scales.size(0) == ((n * k) / group_size),
"expected 1 scale per group");
TORCH_CHECK(
weight_zeros.dtype() == torch::kInt8, "weight_zeros must be int8");
TORCH_CHECK(weight_zeros.dim() == 1, "weight_zeros must be 1D");
TORCH_CHECK(
weight_zeros.size(0) == ((n * k) / group_size),
"expected 1 zero per group");

using namespace torchao::operators::cpu::linear::
channelwise_8bit_activation_groupwise_lowbit_weight;

auto ukernel_config = get_ukernel_config<
weight_nbit,
true /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>();
auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params(
ukernel_config, n, /*target_panels_per_thread=*/1);

torchao::set_num_threads(torch::get_num_threads());

auto packed_weight_data_size =
get_packed_weight_data_size(ukernel_config, n, k, group_size);
at::Tensor packed_weights =
torch::empty({packed_weight_data_size}, torch::kInt8);
pack_weight_data_operator(
ukernel_config,
pack_weight_tiling_params,
packed_weights.data_ptr<int8_t>(),
n,
k,
group_size,
weight_qvals.const_data_ptr<int8_t>(),
weight_scales.const_data_ptr<float>(),
weight_zeros.const_data_ptr<int8_t>());

return packed_weights;
}

template <int weight_nbit>
at::Tensor pack_weights_without_zeros_meta(
const at::Tensor& weight_qvals,
const at::Tensor& weight_scales,
// TODO(T200095131): convert to int64_t when supported by AOTI
@@ -98,6 +164,33 @@ at::Tensor pack_weights_meta(
}

template <int weight_nbit>
at::Tensor pack_weights_with_zeros_meta(
const at::Tensor& weight_qvals,
const at::Tensor& weight_scales,
const at::Tensor& weight_zeros,
// TODO(T200095131): convert to int64_t when supported by AOTI
// group_size is a meta tensor with size (group_size)
const at::Tensor& group_size_tensor) {
int64_t group_size = group_size_tensor.size(0);

int n = weight_qvals.size(0);
int k = weight_qvals.size(1);

using namespace torchao::operators::cpu::linear::
channelwise_8bit_activation_groupwise_lowbit_weight;

auto ukernel_config = get_ukernel_config<
weight_nbit,
true /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>();

auto packed_weight_data_size =
get_packed_weight_data_size(ukernel_config, n, k, group_size);
return torch::empty({packed_weight_data_size}).to("meta");
}

template <int weight_nbit, bool has_weight_zeros>
at::Tensor linear_cpu(
const at::Tensor& packed_weights,
// TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
@@ -123,7 +216,7 @@

auto ukernel_config = get_ukernel_config<
weight_nbit,
false /*has_weight_zeros*/,
has_weight_zeros /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>();
auto linear_tiling_params = get_default_linear_tiling_params(
@@ -167,7 +260,7 @@ at::Tensor linear_cpu(
return output_tensor;
}

template <int weight_nbit>
template <int weight_nbit, bool has_weight_zeros>
at::Tensor linear_meta(
const at::Tensor& packed_weights,
// TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
@@ -187,26 +280,78 @@ at::Tensor linear_meta(
}

TORCH_LIBRARY(torchao, m) {
// Pack weights without zeros
m.def(
"_pack_weights_a8sz_w2s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
m.def(
"_pack_weights_a8sz_w3s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
m.def(
"_pack_weights_a8sz_w4s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
m.def(
"_pack_weights_a8sz_w5s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
// Pack weights with zeros
m.def(
"_pack_weights_a8sz_w2sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor");
m.def(
"_pack_weights_a8sz_w3sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor");
m.def(
"_pack_weights_a8sz_w4sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor");
m.def(
"_pack_weights_a8sz_w5sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor");
// Linear weights without zeros
m.def(
"_linear_a8sz_w2s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
m.def(
"_linear_a8sz_w3s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
m.def(
"_linear_a8sz_w4s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
m.def(
"_linear_a8sz_w5s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
// Linear weights with zeros
m.def(
"_pack_weights_3bit(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
"_linear_a8sz_w2sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
m.def(
"_linear_3bit(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
"_linear_a8sz_w3sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
m.def(
"_pack_weights_4bit(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
"_linear_a8sz_w4sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
m.def(
"_linear_4bit(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
"_linear_a8sz_w5sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
}
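
Given these schemas, the ops can be invoked directly from Python once the compiled extension is loaded. A minimal sketch for the 4-bit, scales-only variant; the library path and the quantization snippet are illustrative assumptions, and, per the TODO(T200095131) comments above, n, k, and group_size are passed as tensors whose size encodes the scalar value:

# Hypothetical direct usage of the registered ops; the shared-library path
# and the quantization code are placeholders, not part of the commit.
import torch

torch.ops.load_library("build/libtorchao_custom_op.so")  # assumed build path

n, k, group_size, nbit = 4096, 4096, 256, 4
qmax = 2 ** (nbit - 1) - 1

# Illustrative symmetric groupwise quantization to int8-stored nbit values.
weight = torch.randn(n, k)
groups = weight.reshape(-1, group_size)
weight_scales = (groups.abs().amax(dim=1) / qmax).clamp_min(1e-12)
weight_qvals = (
    torch.clamp(torch.round(groups / weight_scales.unsqueeze(1)), -qmax - 1, qmax)
    .to(torch.int8)
    .reshape(n, k)
)

# Scalar parameters travel as tensors whose size is the value
# (temporary AOTI workaround noted in the TODO comments above).
n_t = torch.empty(n, dtype=torch.int8)
k_t = torch.empty(k, dtype=torch.int8)
group_size_t = torch.empty(group_size, dtype=torch.int8)

packed = torch.ops.torchao._pack_weights_a8sz_w4s(weight_qvals, weight_scales, group_size_t)
activations = torch.randn(1, k, dtype=torch.float32)
out = torch.ops.torchao._linear_a8sz_w4s(packed, n_t, k_t, group_size_t, activations)
print(out.shape)  # expected (1, n)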

TORCH_LIBRARY_IMPL(torchao, CPU, m) {
m.impl("_pack_weights_3bit", &pack_weights_cpu<3>);
m.impl("_linear_3bit", &linear_cpu<3>);
m.impl("_pack_weights_4bit", &pack_weights_cpu<4>);
m.impl("_linear_4bit", &linear_cpu<4>);
m.impl("_pack_weights_a8sz_w2s", &pack_weights_without_zeros_cpu<2>);
m.impl("_pack_weights_a8sz_w3s", &pack_weights_without_zeros_cpu<3>);
m.impl("_pack_weights_a8sz_w4s", &pack_weights_without_zeros_cpu<4>);
m.impl("_pack_weights_a8sz_w5s", &pack_weights_without_zeros_cpu<5>);
m.impl("_pack_weights_a8sz_w2sz", &pack_weights_with_zeros_cpu<2>);
m.impl("_pack_weights_a8sz_w3sz", &pack_weights_with_zeros_cpu<3>);
m.impl("_pack_weights_a8sz_w4sz", &pack_weights_with_zeros_cpu<4>);
m.impl("_pack_weights_a8sz_w5sz", &pack_weights_with_zeros_cpu<5>);
m.impl("_linear_a8sz_w2s", &linear_cpu<2, false>);
m.impl("_linear_a8sz_w3s", &linear_cpu<3, false>);
m.impl("_linear_a8sz_w4s", &linear_cpu<4, false>);
m.impl("_linear_a8sz_w5s", &linear_cpu<5, false>);
m.impl("_linear_a8sz_w2sz", &linear_cpu<2, true>);
m.impl("_linear_a8sz_w3sz", &linear_cpu<3, true>);
m.impl("_linear_a8sz_w4sz", &linear_cpu<4, true>);
m.impl("_linear_a8sz_w5sz", &linear_cpu<5, true>);
}

TORCH_LIBRARY_IMPL(torchao, Meta, m) {
m.impl("_pack_weights_3bit", &pack_weights_meta<3>);
m.impl("_linear_3bit", &linear_meta<3>);
m.impl("_pack_weights_4bit", &pack_weights_meta<4>);
m.impl("_linear_4bit", &linear_meta<4>);
m.impl("_pack_weights_a8sz_w2s", &pack_weights_without_zeros_meta<2>);
m.impl("_pack_weights_a8sz_w3s", &pack_weights_without_zeros_meta<3>);
m.impl("_pack_weights_a8sz_w4s", &pack_weights_without_zeros_meta<4>);
m.impl("_pack_weights_a8sz_w5s", &pack_weights_without_zeros_meta<5>);
m.impl("_pack_weights_a8sz_w2sz", &pack_weights_with_zeros_meta<2>);
m.impl("_pack_weights_a8sz_w3sz", &pack_weights_with_zeros_meta<3>);
m.impl("_pack_weights_a8sz_w4sz", &pack_weights_with_zeros_meta<4>);
m.impl("_pack_weights_a8sz_w5sz", &pack_weights_with_zeros_meta<5>);
m.impl("_linear_a8sz_w2s", &linear_meta<2, false>);
m.impl("_linear_a8sz_w3s", &linear_meta<3, false>);
m.impl("_linear_a8sz_w4s", &linear_meta<4, false>);
m.impl("_linear_a8sz_w5s", &linear_meta<5, false>);
m.impl("_linear_a8sz_w2sz", &linear_meta<2, true>);
m.impl("_linear_a8sz_w3sz", &linear_meta<3, true>);
m.impl("_linear_a8sz_w4sz", &linear_meta<4, true>);
m.impl("_linear_a8sz_w5sz", &linear_meta<5, true>);
}
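
The Meta registrations mirror the CPU ones so the ops can participate in tracing without real data (the kernels above only report output sizes). A small sketch of that behavior, under the same assumptions as the usage example earlier:

# With every input on the "meta" device, dispatch reaches the Meta kernels,
# which skip packing and only produce a tensor of the packed buffer's size.
meta_qvals = torch.empty(4096, 4096, dtype=torch.int8, device="meta")
meta_scales = torch.empty(4096 * 4096 // 256, dtype=torch.float32, device="meta")
meta_group_size = torch.empty(256, device="meta")

packed_meta = torch.ops.torchao._pack_weights_a8sz_w4s(meta_qvals, meta_scales, meta_group_size)
print(packed_meta.device, packed_meta.shape)  # meta device, size of the packed weight buffer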
(Diff for the third changed file is not shown.)
