Optimize 3-bit packing (#1029)
Summary:

Optimizes 3-bit packing as outlined here: T199311618
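For readers unfamiliar with the operation, here is a minimal scalar sketch of what 3-bit packing means: 8 values of 3 bits each fit in 24 bits, i.e. 3 bytes. The function names and the little-endian bit layout are assumptions made for illustration; they are not the NEON kernels or the layout used in torchao.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical scalar reference (not the torchao implementation):
// pack 8 values, each in [0, 7], into 3 bytes. Value i occupies
// bits [3*i, 3*i + 2] of a 24-bit little-endian word (assumed layout).
inline void pack_8_uint3_scalar(uint8_t* packed, const uint8_t* unpacked) {
  uint32_t bits = 0;
  for (int i = 0; i < 8; ++i) {
    bits |= static_cast<uint32_t>(unpacked[i] & 0x7u) << (3 * i);
  }
  packed[0] = static_cast<uint8_t>(bits & 0xFFu);
  packed[1] = static_cast<uint8_t>((bits >> 8) & 0xFFu);
  packed[2] = static_cast<uint8_t>((bits >> 16) & 0xFFu);
}

// Inverse of pack_8_uint3_scalar: recover the 8 original values.
inline void unpack_8_uint3_scalar(uint8_t* unpacked, const uint8_t* packed) {
  const uint32_t bits = static_cast<uint32_t>(packed[0]) |
                        (static_cast<uint32_t>(packed[1]) << 8) |
                        (static_cast<uint32_t>(packed[2]) << 16);
  for (int i = 0; i < 8; ++i) {
    unpacked[i] = static_cast<uint8_t>((bits >> (3 * i)) & 0x7u);
  }
}
```

A scalar round trip like this is mainly useful as a correctness reference when testing vectorized versions.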

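Before change:
----------------------------------------------------------------------------------
Benchmark                                        Time             CPU   Iterations
----------------------------------------------------------------------------------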
benchmark_pack_uint_values<3>/128/8           47.0 ns         46.4 ns     15106555
benchmark_pack_uint_values<3>/128/64          6.94 ns         6.90 ns    101226284
benchmark_pack_uint_values<3>/128/128         3.27 ns         3.24 ns    215022716
benchmark_unpack_uint_values<3>/128/8         22.0 ns         21.9 ns     32585572
benchmark_unpack_uint_values<3>/128/64        6.02 ns         5.98 ns    116910230
benchmark_unpack_uint_values<3>/128/128       2.74 ns         2.73 ns    257088291

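After change:
----------------------------------------------------------------------------------
Benchmark                                        Time             CPU   Iterations
----------------------------------------------------------------------------------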
benchmark_pack_uint_values<3>/128/8           19.5 ns         19.5 ns     36050883
benchmark_pack_uint_values<3>/128/64          3.90 ns         3.87 ns    181151919
benchmark_pack_uint_values<3>/128/128         1.57 ns         1.57 ns    447247194
benchmark_unpack_uint_values<3>/128/8         20.5 ns         20.4 ns     34490914
benchmark_unpack_uint_values<3>/128/64        3.19 ns         3.11 ns    228019714
benchmark_unpack_uint_values<3>/128/128       1.71 ns         1.70 ns    408587338

Unpacking perf for 128 values is 1.60x faster (2.74/1.71).
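The same before/after ratio can be computed for every row of the two tables. The sketch below hard-codes the wall-time column from the benchmark output above; the struct and function names are made up for illustration. For example, packing 128 values comes out at roughly 2.08x (3.27/1.57).

```cpp
#include <cassert>
#include <cstdio>

// One benchmark row: wall time (ns) before and after the change,
// copied from the "Time" column of the tables above.
struct BenchRow {
  const char* name;
  double before_ns;
  double after_ns;
};

constexpr BenchRow kRows[] = {
    {"benchmark_pack_uint_values<3>/128/8", 47.0, 19.5},
    {"benchmark_pack_uint_values<3>/128/64", 6.94, 3.90},
    {"benchmark_pack_uint_values<3>/128/128", 3.27, 1.57},
    {"benchmark_unpack_uint_values<3>/128/8", 22.0, 20.5},
    {"benchmark_unpack_uint_values<3>/128/64", 6.02, 3.19},
    {"benchmark_unpack_uint_values<3>/128/128", 2.74, 1.71},
};

// Speedup factor: values above 1.0 mean the change made the row faster.
constexpr double speedup(const BenchRow& row) {
  return row.before_ns / row.after_ns;
}

// Print one "name  N.NNx" line per benchmark row.
inline void print_speedups() {
  for (const BenchRow& row : kRows) {
    std::printf("%-42s %.2fx\n", row.name, speedup(row));
  }
}
```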

Reviewed By: digantdesai

Differential Revision: D64010666
metascroy authored and facebook-github-bot committed Oct 7, 2024
1 parent dec0313 commit af0ea95
Showing 4 changed files with 143 additions and 211 deletions.
@@ -14,6 +14,7 @@
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint2.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint4.h>
+#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint5.h>
#include <torchao/experimental/kernels/cpu/aarch64/bitpacking/uint6.h>
#include <torchao/experimental/kernels/cpu/aarch64/tests/test_utils.h>
#include <cassert>
2 changes: 2 additions & 0 deletions torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h
@@ -109,6 +109,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values(
vget_high_u8(shifted0),
vget_low_u8(shifted1),
vget_high_u8(shifted1));
+break;
case 3:
uint8_t buffer3[32];
vst1q_u8(buffer3, shifted0);
@@ -185,6 +186,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values(
shifted0_low, shifted0_high, shifted1_low, shifted1_high, packed);
shifted0 = vcombine_u8(shifted0_low, shifted0_high);
shifted1 = vcombine_u8(shifted1_low, shifted1_high);
+break;
case 3:
uint8_t buffer3[32];
torchao::bitpacking::internal::unpack_8_uint3_values(buffer3, packed);
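Both bitpack.h hunks show a `break;` immediately before `case 3:`. Assuming those are the two added lines (each hunk grows from 6 to 7 lines), the fix prevents the preceding case from falling through into the 3-bit path. The following standalone sketch, with a hypothetical `run_cases` helper that is not part of torchao, shows the difference a `break;` makes in a `switch`:

```cpp
#include <cassert>

// Hypothetical dispatcher (illustration only): returns a bitmask of the
// case bodies that executed, with bit n set when the body for case n ran.
inline int run_cases(int nbit, bool with_break) {
  int ran = 0;
  switch (nbit) {
    case 2:
      ran |= 1 << 2;
      if (with_break) {
        break;  // stop here, as in the patched code
      }
      // fall through: without a break, the case 3 body also runs
    case 3:
      ran |= 1 << 3;
      break;
    default:
      break;
  }
  return ran;
}
```

With `with_break` set, `run_cases(2, true)` executes only the case 2 body; without it, the case 3 body runs as well, which in a packing routine would mean extra work or overwritten output.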