diff --git a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp index 2fd2f5391c..b602146ecf 100644 --- a/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/benchmarks/benchmark_bitpacking.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h index b5276b94ce..2cff223268 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/bitpack.h @@ -109,6 +109,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_32_lowbit_values( vget_high_u8(shifted0), vget_low_u8(shifted1), vget_high_u8(shifted1)); + break; case 3: uint8_t buffer3[32]; vst1q_u8(buffer3, shifted0); @@ -185,6 +186,7 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_32_lowbit_values( shifted0_low, shifted0_high, shifted1_low, shifted1_high, packed); shifted0 = vcombine_u8(shifted0_low, shifted0_high); shifted1 = vcombine_u8(shifted1_low, shifted1_high); + break; case 3: uint8_t buffer3[32]; torchao::bitpacking::internal::unpack_8_uint3_values(buffer3, packed); diff --git a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h index 41cc1d0b1a..9a42bdb002 100644 --- a/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h +++ b/torchao/experimental/kernels/cpu/aarch64/bitpacking/uint3.h @@ -22,29 +22,28 @@ namespace internal { TORCHAO_ALWAYS_INLINE inline void pack_8_uint3_values( uint8_t* packed, const uint8_t* unpacked) { - // Given 8 unpacked uint3 values: 0ab, 1cd, 2ef, 3gh, 4ij, 5kl, 6mn, 7op, + // Given 8 unpacked uint3 values: abc, def, ghi, jkl, mno, pqr, 012, 345 // this function packs them as: - // b2: 7|6|5|4|3|2|1|0 (upper bits for all values) - // b10_0: gh|ef|cd|ab (lower 2 bits for first 4 values) - // b10_1: op|mn|kl|ij (lower 2 bits for last 4 values) - // These are stored in packed as: b2, b10_0, b10_1 + // b0: 12|abc|def (bottom bits from 7th value, full bits from 1st/2nd + // value) + // b1: 45|ghi|jkl (bottom bits from 8th value, full bits from + // 3rd/4th value) + // b2: 03|mno|pqr (top bit from 7th/8th value, full bits + // from 5th/6th value) + // These are stored in packed as: b0, b1, b2 // // Input is 8 bytes // Output is 24 bits = 3 bytes + // b0 + packed[0] = ((unpacked[6] & 3) << 6) | ((unpacked[0] & 7) << 3) | unpacked[1]; + + // b1 + packed[1] = ((unpacked[7] & 3) << 6) | ((unpacked[2] & 7) << 3) | unpacked[3]; + // b2 - packed[0] = ((unpacked[0] & 4) >> 2) | ((unpacked[1] & 4) >> 1) | - ((unpacked[2] & 4)) | ((unpacked[3] & 4) << 1) | - ((unpacked[4] & 4) << 2) | ((unpacked[5] & 4) << 3) | - ((unpacked[6] & 4) << 4) | ((unpacked[7] & 4) << 5); - - // b10_0 - packed[1] = (unpacked[0] & 3) | ((unpacked[1] & 3) << 2) | - ((unpacked[2] & 3) << 4) | ((unpacked[3] & 3) << 6); - - // b10_1 - packed[2] = (unpacked[4] & 3) | ((unpacked[5] & 3) << 2) | - ((unpacked[6] & 3) << 4) | ((unpacked[7] & 3) << 6); + packed[2] = ((unpacked[6] & 4) << 5) | ((unpacked[7] & 4) << 4) | + ((unpacked[4] & 7) << 3) | unpacked[5]; } TORCHAO_ALWAYS_INLINE inline void unpack_8_uint3_values( @@ -55,27 +54,29 @@ TORCHAO_ALWAYS_INLINE inline void unpack_8_uint3_values( // Input is 24 bits = 3 bytes // Output is 8 bytes - uint8_t b2 = packed[0]; - uint8_t b10_0 = packed[1]; - uint8_t b10_1 = packed[2]; + uint8_t b0 = packed[0]; + uint8_t b1 = packed[1]; + uint8_t b2 = packed[2]; + + unpacked[0] = ((b0 >> 3) & 7); + unpacked[1] = b0 & 7; - unpacked[0] = ((b2 & 1) << 2) | (b10_0 & 3); - unpacked[1] = ((b2 & 2) << 1) | ((b10_0 & 12) >> 2); - unpacked[2] = (b2 & 4) | ((b10_0 & 48) >> 4); - unpacked[3] = ((b2 & 8) >> 1) | ((b10_0 & 192) >> 6); + unpacked[2] = ((b1 >> 3) & 7); + unpacked[3] = b1 & 7; - unpacked[4] = ((b2 & 16) >> 2) | (b10_1 & 3); - unpacked[5] = ((b2 & 32) >> 3) | ((b10_1 & 12) >> 2); - unpacked[6] = ((b2 & 64) >> 4) | ((b10_1 & 48) >> 4); - unpacked[7] = ((b2 & 128) >> 5) | ((b10_1 & 192) >> 6); + unpacked[4] = ((b2 >> 3) & 7); + unpacked[5] = b2 & 7; + + unpacked[6] = (b0 >> 6) | ((b2 >> 5) & 4); + unpacked[7] = (b1 >> 6) | ((b2 >> 4) & 4); } TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint3_values( uint8_t* packed, - const uint8x16_t& unpacked0, - const uint8x16_t& unpacked1, - const uint8x16_t& unpacked2, - const uint8x16_t& unpacked3) { + const uint8x16_t& unpacked0, // 0, 1 + const uint8x16_t& unpacked1, // 2, 3 + const uint8x16_t& unpacked2, // 4, 5 + const uint8x16_t& unpacked3) { // 6, 7 // This function is a vectorized version of pack_8_uint3_values // To understand it, please see pack_8_uint3_values first. // Before each code section, there is a comment indicating the @@ -84,62 +85,38 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_64_uint3_values( // Input is 64 bytes // Output is 3*64= 192 bits = 24 bytes - uint8x8_t b2; - uint8x8_t mask; + uint8x8_t b; + // b0 + // packed[0] = ((unpacked[6] & 3) << 6) | ((unpacked[0] & 7) << 3) | + // unpacked[1]; + b = vshl_n_u8(vand_u8(vget_low_u8(unpacked3), vdup_n_u8(3)), 6); + b = vorr_u8(b, vshl_n_u8(vand_u8(vget_low_u8(unpacked0), vdup_n_u8(7)), 3)); + b = vorr_u8(b, vget_high_u8(unpacked0)); + vst1_u8(packed, b); + + // b1 + // packed[1] = ((unpacked[7] & 3) << 6) | ((unpacked[2] & 7) << 3) | + // unpacked[3]; + b = vshl_n_u8(vand_u8(vget_high_u8(unpacked3), vdup_n_u8(3)), 6); + b = vorr_u8(b, vshl_n_u8(vand_u8(vget_low_u8(unpacked1), vdup_n_u8(7)), 3)); + b = vorr_u8(b, vget_high_u8(unpacked1)); + vst1_u8(packed + 8, b); // b2 - // packed[0] = ((unpacked[0] & 4) >> 2) | ((unpacked[1] & 4) >> 1) | - // ((unpacked[2] & 4)) | ((unpacked[3] & 4) << 1) | - // ((unpacked[4] & 4) << 2) | ((unpacked[5] & 4) << 3) | - // ((unpacked[6] & 4) << 4) | ((unpacked[7] & 4) << 5); - mask = vdup_n_u8(4); - b2 = vshr_n_u8(vand_u8(vget_low_u8(unpacked0), mask), 2); - b2 = vorr_u8(b2, vshr_n_u8(vand_u8(vget_high_u8(unpacked0), mask), 1)); - - b2 = vorr_u8(b2, vand_u8(vget_low_u8(unpacked1), mask)); - b2 = vorr_u8(b2, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), mask), 1)); - - b2 = vorr_u8(b2, vshl_n_u8(vand_u8(vget_low_u8(unpacked2), mask), 2)); - b2 = vorr_u8(b2, vshl_n_u8(vand_u8(vget_high_u8(unpacked2), mask), 3)); - - b2 = vorr_u8(b2, vshl_n_u8(vand_u8(vget_low_u8(unpacked3), mask), 4)); - b2 = vorr_u8(b2, vshl_n_u8(vand_u8(vget_high_u8(unpacked3), mask), 5)); - - vst1_u8(packed, b2); - - // b10_0 - // packed[1] = (unpacked[0] & 3) | ((unpacked[1] & 3) << 2) | - // ((unpacked[2] & 3) << 4) | ((unpacked[3] & 3) << 6); - mask = vdup_n_u8(3); - uint8x8_t b10_0; - - b10_0 = vand_u8(vget_low_u8(unpacked0), mask); - b10_0 = vorr_u8(b10_0, vshl_n_u8(vand_u8(vget_high_u8(unpacked0), mask), 2)); - - b10_0 = vorr_u8(b10_0, vshl_n_u8(vand_u8(vget_low_u8(unpacked1), mask), 4)); - b10_0 = vorr_u8(b10_0, vshl_n_u8(vand_u8(vget_high_u8(unpacked1), mask), 6)); - - vst1_u8(packed + 8, b10_0); - - // b10_1 - // packed[2] = (unpacked[4] & 3) | ((unpacked[5] & 3) << 2) | - // ((unpacked[6] & 3) << 4) | ((unpacked[7] & 3) << 6); - uint8x8_t b10_1; - - b10_1 = vand_u8(vget_low_u8(unpacked2), mask); - b10_1 = vorr_u8(b10_1, vshl_n_u8(vand_u8(vget_high_u8(unpacked2), mask), 2)); - - b10_1 = vorr_u8(b10_1, vshl_n_u8(vand_u8(vget_low_u8(unpacked3), mask), 4)); - b10_1 = vorr_u8(b10_1, vshl_n_u8(vand_u8(vget_high_u8(unpacked3), mask), 6)); - - vst1_u8(packed + 16, b10_1); + // packed[2] = ((unpacked[6] & 4) << 5) | ((unpacked[7] & 4) << 4) | + // ((unpacked[4] & 7) << 3) | unpacked[5]; + b = vshl_n_u8(vand_u8(vget_low_u8(unpacked3), vdup_n_u8(4)), 5); + b = vorr_u8(b, vshl_n_u8(vand_u8(vget_high_u8(unpacked3), vdup_n_u8(4)), 4)); + b = vorr_u8(b, vshl_n_u8(vand_u8(vget_low_u8(unpacked2), vdup_n_u8(7)), 3)); + b = vorr_u8(b, vget_high_u8(unpacked2)); + vst1_u8(packed + 16, b); } TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint3_values( - uint8x16_t& unpacked0, - uint8x16_t& unpacked1, - uint8x16_t& unpacked2, - uint8x16_t& unpacked3, + uint8x16_t& unpacked0, // 0, 1 + uint8x16_t& unpacked1, // 2, 3 + uint8x16_t& unpacked2, // 4, 5 + uint8x16_t& unpacked3, // 6, 7 const uint8_t* packed) { // Unpacks data packed by pack_64_uint3_values // @@ -151,55 +128,43 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_64_uint3_values( // Input is 3*64= 192 bits = 24 bytes // Output is 64 bytes - uint8x8_t b2 = vld1_u8(packed); - uint8x8_t b10_0 = vld1_u8(packed + 8); + uint8x8_t b0 = vld1_u8(packed); + uint8x8_t b1 = vld1_u8(packed + 8); + uint8x8_t b2 = vld1_u8(packed + 16); uint8x8_t unpacked_tmp0; uint8x8_t unpacked_tmp1; - // unpacked[0] = ((b2 & 1) << 2) | (b10_0 & 3); - unpacked_tmp0 = vshl_n_u8(vand_u8(b2, vdup_n_u8(1)), 2); - unpacked_tmp0 = vorr_u8(unpacked_tmp0, vand_u8(b10_0, vdup_n_u8(3))); - - // unpacked[1] = ((b2 & 2) << 1) | ((b10_0 & 12) >> 2); - unpacked_tmp1 = vshl_n_u8(vand_u8(b2, vdup_n_u8(2)), 1); - unpacked_tmp1 = - vorr_u8(unpacked_tmp1, vshr_n_u8(vand_u8(b10_0, vdup_n_u8(12)), 2)); + // unpacked[0] = ((b0 >> 3) & 7); + uint8x8_t mask = vdup_n_u8(7); + unpacked_tmp0 = vand_u8(vshr_n_u8(b0, 3), mask); + // unpacked[1] = b0 & 7; + unpacked_tmp1 = vand_u8(b0, mask); unpacked0 = vcombine_u8(unpacked_tmp0, unpacked_tmp1); - // unpacked[2] = (b2 & 4) | ((b10_0 & 48) >> 4); - unpacked_tmp0 = vand_u8(b2, vdup_n_u8(4)); - unpacked_tmp0 = - vorr_u8(unpacked_tmp0, vshr_n_u8(vand_u8(b10_0, vdup_n_u8(48)), 4)); - - // unpacked[3] = ((b2 & 8) >> 1) | ((b10_0 & 192) >> 6); - unpacked_tmp1 = vshr_n_u8(vand_u8(b2, vdup_n_u8(8)), 1); - unpacked_tmp1 = - vorr_u8(unpacked_tmp1, vshr_n_u8(vand_u8(b10_0, vdup_n_u8(192)), 6)); + // unpacked[2] = ((b1 >> 3) & 7); + unpacked_tmp0 = vand_u8(vshr_n_u8(b1, 3), mask); + // unpacked[3] = b1 & 7; + unpacked_tmp1 = vand_u8(b1, mask); unpacked1 = vcombine_u8(unpacked_tmp0, unpacked_tmp1); - // unpacked[4] = ((b2 & 16) >> 2) | (b10_1 & 3); - uint8x8_t b10_1 = vld1_u8(packed + 16); - unpacked_tmp0 = vshr_n_u8(vand_u8(b2, vdup_n_u8(16)), 2); - unpacked_tmp0 = vorr_u8(unpacked_tmp0, vand_u8(b10_1, vdup_n_u8(3))); - - // unpacked[5] = ((b2 & 32) >> 3) | ((b10_1 & 12) >> 2); - unpacked_tmp1 = vshr_n_u8(vand_u8(b2, vdup_n_u8(32)), 3); - unpacked_tmp1 = - vorr_u8(unpacked_tmp1, vshr_n_u8(vand_u8(b10_1, vdup_n_u8(12)), 2)); + // unpacked[4] = ((b2 >> 3) & 7); + unpacked_tmp0 = vand_u8(vshr_n_u8(b2, 3), mask); + // unpacked[5] = b2 & 7; + unpacked_tmp1 = vand_u8(b2, mask); unpacked2 = vcombine_u8(unpacked_tmp0, unpacked_tmp1); - // unpacked[6] = ((b2 & 64) >> 4) | ((b10_1 & 48) >> 4); - unpacked_tmp0 = vshr_n_u8(vand_u8(b2, vdup_n_u8(64)), 4); - unpacked_tmp0 = - vorr_u8(unpacked_tmp0, vshr_n_u8(vand_u8(b10_1, vdup_n_u8(48)), 4)); + // unpacked[6] = (b0 >> 6) | ((b2 >> 5) & 4); + mask = vdup_n_u8(4); + unpacked_tmp0 = vshr_n_u8(b0, 6); + unpacked_tmp0 = vorr_u8(unpacked_tmp0, vand_u8(vshr_n_u8(b2, 5), mask)); + + // unpacked[7] = (b1 >> 6) | ((b2 >> 4) & 4); + unpacked_tmp1 = vshr_n_u8(b1, 6); + unpacked_tmp1 = vorr_u8(unpacked_tmp1, vand_u8(vshr_n_u8(b2, 4), mask)); - // unpacked[7] = ((b2 & 128) >> 5) | ((b10_1 & 192) >> 6); - unpacked_tmp1 = vshr_n_u8(vand_u8(b2, vdup_n_u8(128)), 5); - unpacked_tmp1 = - vorr_u8(unpacked_tmp1, vshr_n_u8(vand_u8(b10_1, vdup_n_u8(192)), 6)); unpacked3 = vcombine_u8(unpacked_tmp0, unpacked_tmp1); } @@ -221,49 +186,31 @@ TORCHAO_ALWAYS_INLINE inline void vec_pack_128_uint3_values( // Input is 128 bytes // Output is 3*128= 384 bits = 48 bytes - uint8x16_t b2; - uint8x16_t mask; + uint8x16_t b; + // b0 + // packed[0] = ((unpacked[6] & 3) << 6) | ((unpacked[0] & 7) << 3) | + // unpacked[1]; + b = vshlq_n_u8(vandq_u8(unpacked6, vdupq_n_u8(3)), 6); + b = vorrq_u8(b, vshlq_n_u8(vandq_u8(unpacked0, vdupq_n_u8(7)), 3)); + b = vorrq_u8(b, unpacked1); + vst1q_u8(packed, b); + + // b1 + // packed[1] = ((unpacked[7] & 3) << 6) | ((unpacked[2] & 7) << 3) | + // unpacked[3]; + b = vshlq_n_u8(vandq_u8(unpacked7, vdupq_n_u8(3)), 6); + b = vorrq_u8(b, vshlq_n_u8(vandq_u8(unpacked2, vdupq_n_u8(7)), 3)); + b = vorrq_u8(b, unpacked3); + vst1q_u8(packed + 16, b); // b2 - // packed[0] = ((unpacked[0] & 4) >> 2) | ((unpacked[1] & 4) >> 1) | - // ((unpacked[2] & 4)) | ((unpacked[3] & 4) << 1) | - // ((unpacked[4] & 4) << 2) | ((unpacked[5] & 4) << 3) | - // ((unpacked[6] & 4) << 4) | ((unpacked[7] & 4) << 5); - mask = vdupq_n_u8(4); - b2 = vshrq_n_u8(vandq_u8(unpacked0, mask), 2); - b2 = vorrq_u8(b2, vshrq_n_u8(vandq_u8(unpacked1, mask), 1)); - b2 = vorrq_u8(b2, vandq_u8(unpacked2, mask)); - b2 = vorrq_u8(b2, vshlq_n_u8(vandq_u8(unpacked3, mask), 1)); - b2 = vorrq_u8(b2, vshlq_n_u8(vandq_u8(unpacked4, mask), 2)); - b2 = vorrq_u8(b2, vshlq_n_u8(vandq_u8(unpacked5, mask), 3)); - b2 = vorrq_u8(b2, vshlq_n_u8(vandq_u8(unpacked6, mask), 4)); - b2 = vorrq_u8(b2, vshlq_n_u8(vandq_u8(unpacked7, mask), 5)); - - vst1q_u8(packed, b2); - - // b10_0 - // packed[1] = (unpacked[0] & 3) | ((unpacked[1] & 3) << 2) | - // ((unpacked[2] & 3) << 4) | ((unpacked[3] & 3) << 6); - mask = vdupq_n_u8(3); - uint8x16_t b10_0; - - b10_0 = vandq_u8(unpacked0, mask); - b10_0 = vorrq_u8(b10_0, vshlq_n_u8(vandq_u8(unpacked1, mask), 2)); - b10_0 = vorrq_u8(b10_0, vshlq_n_u8(vandq_u8(unpacked2, mask), 4)); - b10_0 = vorrq_u8(b10_0, vshlq_n_u8(vandq_u8(unpacked3, mask), 6)); - - vst1q_u8(packed + 16, b10_0); - - // b10_1 - // packed[2] = (unpacked[4] & 3) | ((unpacked[5] & 3) << 2) | - // ((unpacked[6] & 3) << 4) | ((unpacked[7] & 3) << 6); - uint8x16_t b10_1; - b10_1 = vandq_u8(unpacked4, mask); - b10_1 = vorrq_u8(b10_1, vshlq_n_u8(vandq_u8(unpacked5, mask), 2)); - b10_1 = vorrq_u8(b10_1, vshlq_n_u8(vandq_u8(unpacked6, mask), 4)); - b10_1 = vorrq_u8(b10_1, vshlq_n_u8(vandq_u8(unpacked7, mask), 6)); - - vst1q_u8(packed + 32, b10_1); + // packed[2] = ((unpacked[6] & 4) << 5) | ((unpacked[7] & 4) << 4) | + // ((unpacked[4] & 7) << 3) | unpacked[5]; + b = vshlq_n_u8(vandq_u8(unpacked6, vdupq_n_u8(4)), 5); + b = vorrq_u8(b, vshlq_n_u8(vandq_u8(unpacked7, vdupq_n_u8(4)), 4)); + b = vorrq_u8(b, vshlq_n_u8(vandq_u8(unpacked4, vdupq_n_u8(7)), 3)); + b = vorrq_u8(b, unpacked5); + vst1q_u8(packed + 32, b); } TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_uint3_values( @@ -286,47 +233,37 @@ TORCHAO_ALWAYS_INLINE inline void vec_unpack_128_uint3_values( // Input is 3*128 = 384 bits = 48 bytes // Output is 128 bytes - uint8x16_t b2 = vld1q_u8(packed); - uint8x16_t b10_0 = vld1q_u8(packed + 16); - - // unpacked[0] = ((b2 & 1) << 2) | (b10_0 & 3); - unpacked0 = vshlq_n_u8(vandq_u8(b2, vdupq_n_u8(1)), 2); - unpacked0 = vorrq_u8(unpacked0, vandq_u8(b10_0, vdupq_n_u8(3))); - - // unpacked[1] = ((b2 & 2) << 1) | ((b10_0 & 12) >> 2); - unpacked1 = vshlq_n_u8(vandq_u8(b2, vdupq_n_u8(2)), 1); - unpacked1 = - vorrq_u8(unpacked1, vshrq_n_u8(vandq_u8(b10_0, vdupq_n_u8(12)), 2)); - - // unpacked[2] = (b2 & 4) | ((b10_0 & 48) >> 4); - unpacked2 = vandq_u8(b2, vdupq_n_u8(4)); - unpacked2 = - vorrq_u8(unpacked2, vshrq_n_u8(vandq_u8(b10_0, vdupq_n_u8(48)), 4)); - - // unpacked[3] = ((b2 & 8) >> 1) | ((b10_0 & 192) >> 6); - unpacked3 = vshrq_n_u8(vandq_u8(b2, vdupq_n_u8(8)), 1); - unpacked3 = - vorrq_u8(unpacked3, vshrq_n_u8(vandq_u8(b10_0, vdupq_n_u8(192)), 6)); - - // unpacked[4] = ((b2 & 16) >> 2) | (b10_1 & 3); - uint8x16_t b10_1 = vld1q_u8(packed + 32); - unpacked4 = vshrq_n_u8(vandq_u8(b2, vdupq_n_u8(16)), 2); - unpacked4 = vorrq_u8(unpacked4, vandq_u8(b10_1, vdupq_n_u8(3))); - - // unpacked[5] = ((b2 & 32) >> 3) | ((b10_1 & 12) >> 2); - unpacked5 = vshrq_n_u8(vandq_u8(b2, vdupq_n_u8(32)), 3); - unpacked5 = - vorrq_u8(unpacked5, vshrq_n_u8(vandq_u8(b10_1, vdupq_n_u8(12)), 2)); - - // unpacked[6] = ((b2 & 64) >> 4) | ((b10_1 & 48) >> 4); - unpacked6 = vshrq_n_u8(vandq_u8(b2, vdupq_n_u8(64)), 4); - unpacked6 = - vorrq_u8(unpacked6, vshrq_n_u8(vandq_u8(b10_1, vdupq_n_u8(48)), 4)); - - // unpacked[7] = ((b2 & 128) >> 5) | ((b10_1 & 192) >> 6); - unpacked7 = vshrq_n_u8(vandq_u8(b2, vdupq_n_u8(128)), 5); - unpacked7 = - vorrq_u8(unpacked7, vshrq_n_u8(vandq_u8(b10_1, vdupq_n_u8(192)), 6)); + uint8x16_t b0 = vld1q_u8(packed); + uint8x16_t b1 = vld1q_u8(packed + 16); + uint8x16_t b2 = vld1q_u8(packed + 32); + + // unpacked[0] = ((b0 >> 3) & 7); + uint8x16_t mask = vdupq_n_u8(7); + unpacked0 = vandq_u8(vshrq_n_u8(b0, 3), mask); + + // unpacked[1] = b0 & 7; + unpacked1 = vandq_u8(b0, mask); + + // unpacked[2] = ((b1 >> 3) & 7); + unpacked2 = vandq_u8(vshrq_n_u8(b1, 3), mask); + + // unpacked[3] = b1 & 7; + unpacked3 = vandq_u8(b1, mask); + + // unpacked[4] = ((b2 >> 3) & 7); + unpacked4 = vandq_u8(vshrq_n_u8(b2, 3), mask); + + // unpacked[5] = b2 & 7; + unpacked5 = vandq_u8(b2, mask); + + // unpacked[6] = (b0 >> 6) | ((b2 >> 5) & 4); + mask = vdupq_n_u8(4); + unpacked6 = vshrq_n_u8(b0, 6); + unpacked6 = vorrq_u8(unpacked6, vandq_u8(vshrq_n_u8(b2, 5), mask)); + + // unpacked[7] = (b1 >> 6) | ((b2 >> 4) & 4); + unpacked7 = vshrq_n_u8(b1, 6); + unpacked7 = vorrq_u8(unpacked7, vandq_u8(vshrq_n_u8(b2, 4), mask)); } } // namespace internal diff --git a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp index ae9c5c5344..ef51fd7d43 100644 --- a/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp +++ b/torchao/experimental/kernels/cpu/aarch64/tests/test_bitpacking.cpp @@ -548,17 +548,9 @@ TEST(test_bitpacking_64_uint6_values, PackUnpackAreSame) { torchao::bitpacking::internal::vec_load_64_uint8_values( input0, input1, input2, input3, input.data()); torchao::bitpacking::internal::vec_pack_64_uint6_values( - packed.data(), - input0, - input1, - input2, - input3); + packed.data(), input0, input1, input2, input3); torchao::bitpacking::internal::vec_unpack_64_uint6_values( - unpacked0, - unpacked1, - unpacked2, - unpacked3, - packed.data()); + unpacked0, unpacked1, unpacked2, unpacked3, packed.data()); for (int i = 0; i < 16; ++i) { EXPECT_EQ(input0[i], unpacked0[i]);