Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introducing 1-bit quantization for Llama in torchchat (#910) #911

Merged
merged 1 commit into from
Sep 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <benchmark/benchmark.h>

#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint1.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
Expand All @@ -16,6 +17,128 @@

namespace {

// Benchmark utility to compare variants of uint1 packing.
//
// Packs `unpacked_size` values read from `unpacked` (one value per byte) into
// `packed`, handling `variant` values per kernel invocation:
//   8   -> pack_8_uint1_values
//   64  -> vec_load_64_uint8_values + vec_pack_64_uint1_values
//   128 -> 2x vec_load_64_uint8_values + vec_pack_128_uint1_values
// Any other `variant` is unsupported and trips an assert in debug builds.
void pack_uint1_values(
    uint8_t* packed,
    uint8_t* unpacked,
    int packed_size,
    int unpacked_size,
    int variant) {
  constexpr int nbit = 1;
  constexpr int bitsPerByte = 8;
  assert(unpacked_size * nbit / bitsPerByte == packed_size);
  // Every loop below advances by `variant` unpacked values, so it is the
  // unpacked count (not the packed byte count) that must divide evenly.
  assert(unpacked_size % variant == 0);

  switch (variant) {
    case 8:
      for (int i = 0; i < unpacked_size; i += 8) {
        torchao::bitpacking::internal::pack_8_uint1_values(
            packed + ((i * nbit) / bitsPerByte), unpacked + i);
      }
      break;
    case 64:
      for (int i = 0; i < unpacked_size; i += 64) {
        uint8x16_t unpacked0;
        uint8x16_t unpacked1;
        uint8x16_t unpacked2;
        uint8x16_t unpacked3;
        torchao::bitpacking::internal::vec_load_64_uint8_values(
            unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i);
        torchao::bitpacking::internal::vec_pack_64_uint1_values(
            packed + ((i * nbit) / bitsPerByte),
            unpacked0,
            unpacked1,
            unpacked2,
            unpacked3);
      }
      break;
    case 128:
      for (int i = 0; i < unpacked_size; i += 128) {
        uint8x16_t unpacked0;
        uint8x16_t unpacked1;
        uint8x16_t unpacked2;
        uint8x16_t unpacked3;
        uint8x16_t unpacked4;
        uint8x16_t unpacked5;
        uint8x16_t unpacked6;
        uint8x16_t unpacked7;
        // The 128-value kernel takes eight vectors; load them as two
        // consecutive groups of 64 values.
        torchao::bitpacking::internal::vec_load_64_uint8_values(
            unpacked0, unpacked1, unpacked2, unpacked3, unpacked + i);
        torchao::bitpacking::internal::vec_load_64_uint8_values(
            unpacked4, unpacked5, unpacked6, unpacked7, unpacked + i + 64);
        torchao::bitpacking::internal::vec_pack_128_uint1_values(
            packed + ((i * nbit) / bitsPerByte),
            unpacked0,
            unpacked1,
            unpacked2,
            unpacked3,
            unpacked4,
            unpacked5,
            unpacked6,
            unpacked7);
      }
      break;
    default:
      // Fail loudly rather than benchmarking an empty function.
      assert(false && "pack_uint1_values: unsupported variant");
      break;
  }
}

// Benchmark utility to compare variants of uint1 unpacking.
//
// Unpacks the bytes in `packed` into `unpacked_size` values written to
// `unpacked` (one value per byte), handling `variant` values per kernel
// invocation:
//   8   -> unpack_8_uint1_values
//   64  -> vec_unpack_64_uint1_values + vec_store_64_uint8_values
//   128 -> vec_unpack_128_uint1_values + 2x vec_store_64_uint8_values
// Any other `variant` is unsupported and trips an assert in debug builds.
void unpack_uint1_values(
    uint8_t* unpacked,
    uint8_t* packed,
    int unpacked_size,
    int packed_size,
    int variant) {
  constexpr int nbit = 1;
  constexpr int bitsPerByte = 8;
  assert(unpacked_size * nbit / bitsPerByte == packed_size);
  // Every loop below advances by `variant` unpacked values, so it is the
  // unpacked count (not the packed byte count) that must divide evenly.
  assert(unpacked_size % variant == 0);

  switch (variant) {
    case 8:
      for (int i = 0; i < unpacked_size; i += 8) {
        torchao::bitpacking::internal::unpack_8_uint1_values(
            unpacked + i, packed + ((i * nbit) / bitsPerByte));
      }
      break;
    case 64:
      for (int i = 0; i < unpacked_size; i += 64) {
        uint8x16_t unpacked0;
        uint8x16_t unpacked1;
        uint8x16_t unpacked2;
        uint8x16_t unpacked3;
        torchao::bitpacking::internal::vec_unpack_64_uint1_values(
            unpacked0,
            unpacked1,
            unpacked2,
            unpacked3,
            packed + ((i * nbit) / bitsPerByte));
        torchao::bitpacking::internal::vec_store_64_uint8_values(
            unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3);
      }
      break;
    case 128:
      for (int i = 0; i < unpacked_size; i += 128) {
        uint8x16_t unpacked0;
        uint8x16_t unpacked1;
        uint8x16_t unpacked2;
        uint8x16_t unpacked3;
        uint8x16_t unpacked4;
        uint8x16_t unpacked5;
        uint8x16_t unpacked6;
        uint8x16_t unpacked7;
        torchao::bitpacking::internal::vec_unpack_128_uint1_values(
            unpacked0,
            unpacked1,
            unpacked2,
            unpacked3,
            unpacked4,
            unpacked5,
            unpacked6,
            unpacked7,
            packed + ((i * nbit) / bitsPerByte));
        // The 128-value kernel yields eight vectors; store them as two
        // consecutive groups of 64 values.
        torchao::bitpacking::internal::vec_store_64_uint8_values(
            unpacked + i, unpacked0, unpacked1, unpacked2, unpacked3);
        torchao::bitpacking::internal::vec_store_64_uint8_values(
            unpacked + i + 64, unpacked4, unpacked5, unpacked6, unpacked7);
      }
      break;
    default:
      // Fail loudly rather than benchmarking an empty function.
      assert(false && "unpack_uint1_values: unsupported variant");
      break;
  }
}

// Benchmark utility to compare variants of uint2 packing
void pack_uint2_values(
uint8_t* packed,
Expand Down Expand Up @@ -470,6 +593,44 @@ void unpack_uint5_values(

} // namespace

// Benchmarks pack_uint1_values.
// state.range(0): number of unpacked values (must be a multiple of 8).
// state.range(1): packing variant forwarded to pack_uint1_values.
static void benchmark_pack_uint1_values(benchmark::State& state) {
  constexpr int kNbit = 1;
  const int num_values = state.range(0);
  const int variant_id = state.range(1);

  assert(num_values % 8 == 0);
  const int num_packed_bytes = (num_values / 8) * kNbit;

  // Destination buffer, zero-initialised; source holds random kNbit-wide values.
  std::vector<uint8_t> packed(num_packed_bytes, 0);
  auto unpacked = torchao::get_random_lowbit_vector(num_values, kNbit);

  for (auto _ : state) {
    pack_uint1_values(
        packed.data(),
        unpacked.data(),
        num_packed_bytes,
        num_values,
        variant_id);
  }
}

// Benchmarks unpack_uint1_values.
// state.range(0): number of unpacked values (must be a multiple of 8).
// state.range(1): unpacking variant forwarded to unpack_uint1_values.
static void benchmark_unpack_uint1_values(benchmark::State& state) {
  constexpr int kNbit = 1;
  const int num_values = state.range(0);
  const int variant_id = state.range(1);

  assert(num_values % 8 == 0);
  const int num_packed_bytes = (num_values / 8) * kNbit;

  // Source is random full bytes (8-bit values); destination is zero-filled.
  auto packed = torchao::get_random_lowbit_vector(num_packed_bytes, 8);
  std::vector<uint8_t> unpacked(num_values, 0);

  uint8_t* const dst = unpacked.data();
  uint8_t* const src = packed.data();
  for (auto _ : state) {
    unpack_uint1_values(dst, src, unpacked.size(), packed.size(), variant_id);
  }
}

static void benchmark_pack_uint2_values(benchmark::State& state) {
int unpacked_size = state.range(0);
int variant = state.range(1);
Expand All @@ -478,8 +639,8 @@ static void benchmark_pack_uint2_values(benchmark::State& state) {
assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = std::vector<uint8_t>(unpacked_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8);
auto packed = std::vector<uint8_t>(packed_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);

for (auto _ : state) {
pack_uint2_values(
Expand Down Expand Up @@ -516,8 +677,8 @@ static void benchmark_pack_uint3_values(benchmark::State& state) {
assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = std::vector<uint8_t>(unpacked_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8);
auto packed = std::vector<uint8_t>(packed_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);

for (auto _ : state) {
pack_uint3_values(
Expand Down Expand Up @@ -554,8 +715,8 @@ static void benchmark_pack_uint4_values(benchmark::State& state) {
assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = std::vector<uint8_t>(unpacked_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8);
auto packed = std::vector<uint8_t>(packed_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);

for (auto _ : state) {
pack_uint4_values(
Expand Down Expand Up @@ -592,8 +753,8 @@ static void benchmark_pack_uint5_values(benchmark::State& state) {
assert(unpacked_size % 8 == 0);
int packed_size = (unpacked_size / 8) * nbit;

auto packed = std::vector<uint8_t>(unpacked_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(packed_size, 8);
auto packed = std::vector<uint8_t>(packed_size, 0);
auto unpacked = torchao::get_random_lowbit_vector(unpacked_size, nbit);

for (auto _ : state) {
pack_uint5_values(
Expand Down Expand Up @@ -622,6 +783,8 @@ static void benchmark_unpack_uint5_values(benchmark::State& state) {
}
}

// uint1 pack/unpack benchmarks. The first argument list is unpacked_size
// ({128}); the second is the variant, i.e. how many values each kernel call
// handles (8, 64, or 128).
BENCHMARK(benchmark_pack_uint1_values)->ArgsProduct({{128}, {8, 64, 128}});
BENCHMARK(benchmark_unpack_uint1_values)->ArgsProduct({{128}, {8, 64, 128}});
BENCHMARK(benchmark_pack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}});
BENCHMARK(benchmark_unpack_uint2_values)->ArgsProduct({{128}, {4, 32, 64}});
BENCHMARK(benchmark_pack_uint3_values)->ArgsProduct({{128}, {8, 64, 128}});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,8 @@ channelwise_8bit_activation_groupwise_lowbit_weight_1x8x16_f32_neondot(
false>) \
->ArgsProduct(BENCHMARK_PARAMS)

BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
1);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
2);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
Expand All @@ -236,6 +238,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT
4);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x1x32_F32_NEONDOT(
5);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
1);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
2);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
Expand All @@ -244,6 +248,8 @@ BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT
4);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x4x16_F32_NEONDOT(
5);
// First entry of the 1x8x16 group (argument 1), mirroring the 1x1x32 and
// 1x4x16 groups above, which each start at 1.
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
    1);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
2);
BENCHMARK_CHANNELWISE_8BIT_ACTIVATION_GROUPWISE_LOWBIT_WEIGHT_1x8x16_F32_NEONDOT(
Expand Down
Loading