Add 2, 3, 4, 5 bit custom ops #828

Merged: 1 commit, Sep 9, 2024
File 1 of 2: Python example script exercising the custom ops (diff)
@@ -4,16 +4,21 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from torch_custom_op import quantize, replace_linear_with_quantized_linear
import torch
import copy

group_size = 16
import torch
from torch_custom_op import (
linear_a8sz_w_lowbit_reference_impl,
replace_linear_with_quantized_linear,
)

group_size = 256
m = 1
n = 4096
k = 4096
nbit = 4
n_layers = 10
nbit = 5
Contributor: wondering, is this uint5 or int5?

Contributor Author (@metascroy, Sep 6, 2024): The 5-bit kernel has quantized range [-16, 15].
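For reference, a minimal sketch (not part of the PR) of the signed range this implies for a given nbit, assuming the same two's-complement convention holds for the 2-, 3-, and 4-bit kernels:

# Signed nbit quantized values span [-2**(nbit - 1), 2**(nbit - 1) - 1].
def signed_quant_range(nbit: int):
    return -(2 ** (nbit - 1)), 2 ** (nbit - 1) - 1

assert signed_quant_range(5) == (-16, 15)  # matches the reply above
assert signed_quant_range(4) == (-8, 7)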

has_weight_zeros = True
n_layers = 5

print("Creating random model")
layers = [torch.nn.Linear(k, n, bias=False) for _ in range(n_layers)]
@@ -22,8 +27,15 @@

print("Quantizing random model")
quantized_model = copy.deepcopy(model)
quantized_model = quantized_model.eval()
replace_linear_with_quantized_linear(quantized_model, kwargs={"group_size": group_size, "nbit": nbit})
quantized_model = quantized_model.eval()
replace_linear_with_quantized_linear(
quantized_model,
kwargs={
"group_size": group_size,
"nbit": nbit,
"has_weight_zeros": has_weight_zeros,
},
)

print("Creating random activations")
activations = torch.randn(m, k, dtype=torch.float32)
@@ -48,36 +60,42 @@
fn(activations)


print("Checking correctness on layer 0")

rtol=1e-05

# default is 1e-8, but PyTorch and C++ (and ARM neon) have different rounding
# conventions for ties (PyTorch rounds half to even and C++ rounds half to odd)
# TODO(T200109708): address this
atol=1e-05
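# For illustration: torch.round breaks ties toward the nearest even value,
# e.g. torch.round(torch.tensor([0.5, 1.5, 2.5])) returns tensor([0., 2., 2.]),
# while the C++ kernel may break ties the other way, hence the relaxed atol.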

print("\nChecking correctness on layer 0")
linear = model[0]
quantized_linear = quantized_model[0]
weight_qvals, weight_scales = quantize(linear.weight, group_size, quantized_linear.nbit, scale_only=True)

activation_qvals, activations_scales, activations_zeros = quantize(activations, k, 8, False)
activations_dequantized = activations_scales * (activation_qvals - activations_zeros)
weights_dequantized = (weight_qvals.reshape(-1, group_size) * weight_scales.reshape(-1, 1)).reshape(n, k)

with torch.no_grad():
result = quantized_linear(activations)
expected_result = torch.matmul(activations_dequantized, weights_dequantized.transpose(1, 0))
expected_result = linear_a8sz_w_lowbit_reference_impl(
linear.weight, activations, group_size, nbit, has_weight_zeros
)
non_quantized_result = linear(activations)

if not (torch.allclose(result, expected_result, rtol=rtol, atol=atol)):
rand_idxs = torch.randint(0, result.shape[1], (5,))
print("rand_idxs: ", rand_idxs)
print("kernel_result[rand_idxs]: ", result[0][rand_idxs])
print("expected_result[rand_idxs]: ", expected_result[0][rand_idxs])
assert False
else:
print("Correctness check passed")

print("kernel_result[0:5]: ", result[0][0:5])
print("non_quantized_result[0:5]: ", non_quantized_result[0][0:5])

# Check that entries in result match entries in expected_result
num_mismatch_at_low_tol = 0
num_total = result.reshape(-1).shape[0]
for i in range(num_total):
actual_val = result.reshape(-1)[i]
expected_val = expected_result.reshape(-1)[i]
if not torch.allclose(actual_val, expected_val):
num_mismatch_at_low_tol += 1

# If results are not close at a relaxed tolerance, exit with failure
if not torch.allclose(actual_val, expected_val, atol=1e-6):
assert False, "Correctness check failed"

# Assert at most 5% of entries are not close at a low tolerance
assert num_mismatch_at_low_tol / num_total <= 0.05, "Correctness check failed"
print(
"Correctness check passed. All results are close, and ",
(num_total - num_mismatch_at_low_tol),
"/",
num_total,
" entries are close at a low tolerance.",
)
print("Quantization errors:")
print("\tL1 error: ", torch.mean(torch.abs(result - non_quantized_result)).item())
print("\tL2 error: ", torch.mean((result - non_quantized_result) ** 2).item())
print("\tquantized_result[0:5]: ", result[0][0:5])
print("\tnon_quantized_result[0:5]: ", non_quantized_result[0][0:5])
File 2 of 2: C++ custom op registration (diff)
@@ -11,7 +11,7 @@
#include <torchao/experimental/kernels/cpu/parallel.h>

template <int weight_nbit>
at::Tensor pack_weights_cpu(
at::Tensor pack_weights_without_zeros_cpu(
const at::Tensor& weight_qvals,
const at::Tensor& weight_scales,
// TODO(T200095131): convert to int64_t when supported by AOTI
@@ -54,9 +54,8 @@ at::Tensor pack_weights_cpu(

auto packed_weight_data_size =
get_packed_weight_data_size(ukernel_config, n, k, group_size);
auto options = torch::TensorOptions().dtype(torch::kInt8);

at::Tensor packed_weights = torch::empty({packed_weight_data_size}, options);
at::Tensor packed_weights =
torch::empty({packed_weight_data_size}, torch::kInt8);
pack_weight_data_operator(
ukernel_config,
pack_weight_tiling_params,
@@ -72,7 +71,74 @@ at::Tensor pack_weights_cpu(
}

template <int weight_nbit>
at::Tensor pack_weights_meta(
at::Tensor pack_weights_with_zeros_cpu(
const at::Tensor& weight_qvals,
const at::Tensor& weight_scales,
const at::Tensor& weight_zeros,
// TODO(T200095131): convert to int64_t when supported by AOTI
// group_size is a meta tensor with size (group_size)
const at::Tensor& group_size_tensor) {
int64_t group_size = group_size_tensor.size(0);

TORCH_CHECK(
weight_qvals.dtype() == torch::kInt8, "weight_qvals must be int8");
TORCH_CHECK(weight_qvals.dim() == 2, "weight_qvals must be 2D");

// In PyTorch, weights are nxk in row-major format (with activations being
// right-multiplied).
// In kernel, activations are left-multiplied by kxn transposed
// weights in column-major format.
// Note the underlying data is the same in both cases
int n = weight_qvals.size(0);
int k = weight_qvals.size(1);

TORCH_CHECK(
weight_scales.dtype() == torch::kFloat32,
"weight_scales must be float32");
TORCH_CHECK(weight_scales.dim() == 1, "weight_scales must be 1D");
TORCH_CHECK(
weight_scales.size(0) == ((n * k) / group_size),
"expected 1 scale per group");
TORCH_CHECK(
weight_zeros.dtype() == torch::kInt8, "weight_zeros must be int8");
TORCH_CHECK(weight_zeros.dim() == 1, "weight_zeros must be 1D");
TORCH_CHECK(
weight_zeros.size(0) == ((n * k) / group_size),
"expected 1 zero per group");

using namespace torchao::operators::cpu::linear::
channelwise_8bit_activation_groupwise_lowbit_weight;

auto ukernel_config = get_ukernel_config<
weight_nbit,
true /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>();
auto pack_weight_tiling_params = get_default_pack_weight_data_tiling_params(
ukernel_config, n, /*target_panels_per_thread=*/1);

torchao::set_num_threads(torch::get_num_threads());

auto packed_weight_data_size =
get_packed_weight_data_size(ukernel_config, n, k, group_size);
at::Tensor packed_weights =
torch::empty({packed_weight_data_size}, torch::kInt8);
pack_weight_data_operator(
ukernel_config,
pack_weight_tiling_params,
packed_weights.data_ptr<int8_t>(),
n,
k,
group_size,
weight_qvals.const_data_ptr<int8_t>(),
weight_scales.const_data_ptr<float>(),
weight_zeros.const_data_ptr<int8_t>());

return packed_weights;
}

template <int weight_nbit>
at::Tensor pack_weights_without_zeros_meta(
const at::Tensor& weight_qvals,
const at::Tensor& weight_scales,
// TODO(T200095131): convert to int64_t when supported by AOTI
@@ -98,6 +164,33 @@ at::Tensor pack_weights_meta(
}

template <int weight_nbit>
at::Tensor pack_weights_with_zeros_meta(
const at::Tensor& weight_qvals,
const at::Tensor& weight_scales,
const at::Tensor& weight_zeros,
// TODO(T200095131): convert to int64_t when supported by AOTI
// group_size is a meta tensor with size (group_size)
const at::Tensor& group_size_tensor) {
int64_t group_size = group_size_tensor.size(0);

int n = weight_qvals.size(0);
int k = weight_qvals.size(1);

using namespace torchao::operators::cpu::linear::
channelwise_8bit_activation_groupwise_lowbit_weight;

auto ukernel_config = get_ukernel_config<
weight_nbit,
true /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>();

auto packed_weight_data_size =
get_packed_weight_data_size(ukernel_config, n, k, group_size);
return torch::empty({packed_weight_data_size}).to("meta");
}

template <int weight_nbit, bool has_weight_zeros>
at::Tensor linear_cpu(
const at::Tensor& packed_weights,
// TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
@@ -123,7 +216,7 @@ at::Tensor linear_cpu(

auto ukernel_config = get_ukernel_config<
weight_nbit,
false /*has_weight_zeros*/,
has_weight_zeros /*has_weight_zeros*/,
false /*has_bias*/,
false /*has_clamp*/>();
auto linear_tiling_params = get_default_linear_tiling_params(
@@ -167,7 +260,7 @@ at::Tensor linear_cpu(
return output_tensor;
}

template <int weight_nbit>
template <int weight_nbit, bool has_weight_zeros>
at::Tensor linear_meta(
const at::Tensor& packed_weights,
// TODO(T200095131): convert n_tensor, k_tensor, group_size_tensor to
@@ -187,26 +280,78 @@ at::Tensor linear_meta(
}

TORCH_LIBRARY(torchao, m) {
// Pack weights without zeros
m.def(
"_pack_weights_a8sz_w2s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
m.def(
"_pack_weights_a8sz_w3s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
m.def(
"_pack_weights_a8sz_w4s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
m.def(
"_pack_weights_a8sz_w5s(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
// Pack weights with zeros
m.def(
"_pack_weights_a8sz_w2sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor");
m.def(
"_pack_weights_a8sz_w3sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor");
m.def(
"_pack_weights_a8sz_w4sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor");
m.def(
"_pack_weights_a8sz_w5sz(Tensor weight_qvals, Tensor weight_scales, Tensor weight_zeros, Tensor group_size) -> Tensor");
// Linear weights without zeros
m.def(
"_linear_a8sz_w2s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
m.def(
"_linear_a8sz_w3s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
m.def(
"_linear_a8sz_w4s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
m.def(
"_linear_a8sz_w5s(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
// Linear weights with zeros
m.def(
"_pack_weights_3bit(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
"_linear_a8sz_w2sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
m.def(
"_linear_3bit(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
"_linear_a8sz_w3sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
m.def(
"_pack_weights_4bit(Tensor weight_qvals, Tensor weight_scales, Tensor group_size) -> Tensor");
"_linear_a8sz_w4sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
m.def(
"_linear_4bit(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
"_linear_a8sz_w5sz(Tensor packed_weights, Tensor n, Tensor k, Tensor group_size, Tensor activations) -> Tensor");
}
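As a usage sketch (not from the PR), the pack/linear schemas above can be called through torch.ops once the compiled op library is loaded; the tensor-encoded group_size follows the TODO(T200095131) comments in this file, the library path is illustrative, and the 4-bit scales-only variant is used purely as an example:

import torch

# Assumes the shared library that registers these ops has already been loaded,
# e.g. torch.ops.load_library("/path/to/libtorchao_ops.so")  # path is illustrative
n, k, group_size = 4096, 4096, 256
weight_qvals = torch.randint(-8, 8, (n, k), dtype=torch.int8)  # 4-bit range [-8, 7]
weight_scales = torch.rand(n * k // group_size, dtype=torch.float32) + 0.1  # one scale per group

# Per the TODO above, group_size is passed as a tensor whose size encodes the value.
group_size_tensor = torch.empty(group_size)

packed = torch.ops.torchao._pack_weights_a8sz_w4s(
    weight_qvals, weight_scales, group_size_tensor
)
# _linear_a8sz_w4s(packed, n_tensor, k_tensor, group_size_tensor, activations)
# then consumes the packed weights; n and k are presumably encoded as tensors
# in the same way (see the TODOs above).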

TORCH_LIBRARY_IMPL(torchao, CPU, m) {
m.impl("_pack_weights_3bit", &pack_weights_cpu<3>);
m.impl("_linear_3bit", &linear_cpu<3>);
m.impl("_pack_weights_4bit", &pack_weights_cpu<4>);
m.impl("_linear_4bit", &linear_cpu<4>);
m.impl("_pack_weights_a8sz_w2s", &pack_weights_without_zeros_cpu<2>);
m.impl("_pack_weights_a8sz_w3s", &pack_weights_without_zeros_cpu<3>);
m.impl("_pack_weights_a8sz_w4s", &pack_weights_without_zeros_cpu<4>);
m.impl("_pack_weights_a8sz_w5s", &pack_weights_without_zeros_cpu<5>);
m.impl("_pack_weights_a8sz_w2sz", &pack_weights_with_zeros_cpu<2>);
m.impl("_pack_weights_a8sz_w3sz", &pack_weights_with_zeros_cpu<3>);
m.impl("_pack_weights_a8sz_w4sz", &pack_weights_with_zeros_cpu<4>);
m.impl("_pack_weights_a8sz_w5sz", &pack_weights_with_zeros_cpu<5>);
m.impl("_linear_a8sz_w2s", &linear_cpu<2, false>);
m.impl("_linear_a8sz_w3s", &linear_cpu<3, false>);
m.impl("_linear_a8sz_w4s", &linear_cpu<4, false>);
m.impl("_linear_a8sz_w5s", &linear_cpu<5, false>);
m.impl("_linear_a8sz_w2sz", &linear_cpu<2, true>);
m.impl("_linear_a8sz_w3sz", &linear_cpu<3, true>);
m.impl("_linear_a8sz_w4sz", &linear_cpu<4, true>);
m.impl("_linear_a8sz_w5sz", &linear_cpu<5, true>);
}

TORCH_LIBRARY_IMPL(torchao, Meta, m) {
m.impl("_pack_weights_3bit", &pack_weights_meta<3>);
m.impl("_linear_3bit", &linear_meta<3>);
m.impl("_pack_weights_4bit", &pack_weights_meta<4>);
m.impl("_linear_4bit", &linear_meta<4>);
m.impl("_pack_weights_a8sz_w2s", &pack_weights_without_zeros_meta<2>);
m.impl("_pack_weights_a8sz_w3s", &pack_weights_without_zeros_meta<3>);
m.impl("_pack_weights_a8sz_w4s", &pack_weights_without_zeros_meta<4>);
m.impl("_pack_weights_a8sz_w5s", &pack_weights_without_zeros_meta<5>);
m.impl("_pack_weights_a8sz_w2sz", &pack_weights_with_zeros_meta<2>);
m.impl("_pack_weights_a8sz_w3sz", &pack_weights_with_zeros_meta<3>);
m.impl("_pack_weights_a8sz_w4sz", &pack_weights_with_zeros_meta<4>);
m.impl("_pack_weights_a8sz_w5sz", &pack_weights_with_zeros_meta<5>);
m.impl("_linear_a8sz_w2s", &linear_meta<2, false>);
m.impl("_linear_a8sz_w3s", &linear_meta<3, false>);
m.impl("_linear_a8sz_w4s", &linear_meta<4, false>);
m.impl("_linear_a8sz_w5s", &linear_meta<5, false>);
m.impl("_linear_a8sz_w2sz", &linear_meta<2, true>);
m.impl("_linear_a8sz_w3sz", &linear_meta<3, true>);
m.impl("_linear_a8sz_w4sz", &linear_meta<4, true>);
m.impl("_linear_a8sz_w5sz", &linear_meta<5, true>);
}