improve SQ6 code for AVX512 (zilliztech#744)
Signed-off-by: Alexandr Guzhva <alexanderguzhva@gmail.com>
alexanderguzhva authored Aug 5, 2024
1 parent c91e59b commit beb9f85
Showing 2 changed files with 100 additions and 46 deletions.
1 change: 1 addition & 0 deletions cmake/libs/libfaiss.cmake
@@ -96,6 +96,7 @@ if(__X86_64)
         -mavx512f
         -mavx512dq
         -mavx512bw
+        -mavx512vl
         -mpopcnt>)
 
 add_library(faiss STATIC ${FAISS_SRCS})
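The new -mavx512vl flag is what allows the decode path below to use _mm_maskz_loadu_epi8: the 128-bit form of the masked byte load is an AVX512BW instruction that additionally requires AVX512VL. A minimal sketch of the guarded usage (illustrative, not part of the commit; the helper name is ours):

    #include <immintrin.h>
    #include <cstdint>

    #if defined(__AVX512BW__) && defined(__AVX512VL__)
    // Load exactly the low 12 bytes (mask 0b0000111111111111) into a
    // 128-bit register, zeroing the remaining 4 byte lanes.
    static inline __m128i load_12_bytes(const uint8_t* p) {
        return _mm_maskz_loadu_epi8(0b0000111111111111, p);
    }
    #endif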
145 changes: 99 additions & 46 deletions thirdparty/faiss/faiss/impl/ScalarQuantizerCodec_avx512.h
@@ -67,43 +67,89 @@ struct Codec6bit_avx512 : public Codec6bit_avx {
     // TODO: can be optimized
     static FAISS_ALWAYS_INLINE __m512
     decode_16_components(const uint8_t* code, int i) {
-        // // todo aguzhva: the following piece of code is very fast
-        // // for Intel chips. AMD ones will be very slow unless Zen3+
-        //
-        // const uint16_t* data16_0 = (const uint16_t*)(code + (i >> 2) * 3);
-        // const uint64_t* data64_0 = (const uint64_t*)data16_0;
-        // const uint64_t val_0 = *data64_0;
-        // const uint64_t vext_0 = _pdep_u64(val_0, 0x3F3F3F3F3F3F3F3FULL);
-        //
-        // const uint16_t* data16_1 = data16_0 + 3;
-        // const uint32_t* data32_1 = (const uint32_t*)data16_1;
-        // const uint64_t val_1 = *data32_1 + ((uint64_t)data16_1[2] << 32);
-        // const uint64_t vext_1 = _pdep_u64(val_1, 0x3F3F3F3F3F3F3F3FULL);
-        //
-        // const __m128i i8 = _mm_set_epi64x(vext_1, vext_0);
-        // const __m512i i32 = _mm512_cvtepi8_epi32(i8);
-        // const __m512 f8 = _mm512_cvtepi32_ps(i32);
-        // const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f);
-        // const __m512 one_255 = _mm512_set1_ps(1.f / 63.f);
-        // return _mm512_fmadd_ps(f8, one_255, half_one_255);
-
-        return _mm512_set_ps(
-                decode_component(code, i + 15),
-                decode_component(code, i + 14),
-                decode_component(code, i + 13),
-                decode_component(code, i + 12),
-                decode_component(code, i + 11),
-                decode_component(code, i + 10),
-                decode_component(code, i + 9),
-                decode_component(code, i + 8),
-                decode_component(code, i + 7),
-                decode_component(code, i + 6),
-                decode_component(code, i + 5),
-                decode_component(code, i + 4),
-                decode_component(code, i + 3),
-                decode_component(code, i + 2),
-                decode_component(code, i + 1),
-                decode_component(code, i + 0));
+        /*
+        // todo aguzhva: the following piece of code is very fast
+        // for Intel chips. AMD ones will be very slow unless Zen3+
+        const uint16_t* data16_0 = (const uint16_t*)(code + (i >> 2) * 3);
+        const uint64_t* data64_0 = (const uint64_t*)data16_0;
+        const uint64_t val_0 = *data64_0;
+        const uint64_t vext_0 = _pdep_u64(val_0, 0x3F3F3F3F3F3F3F3FULL);
+        const uint16_t* data16_1 = data16_0 + 3;
+        const uint32_t* data32_1 = (const uint32_t*)data16_1;
+        const uint64_t val_1 = *data32_1 + ((uint64_t)data16_1[2] << 32);
+        const uint64_t vext_1 = _pdep_u64(val_1, 0x3F3F3F3F3F3F3F3FULL);
+        const __m128i i8 = _mm_set_epi64x(vext_1, vext_0);
+        const __m512i i32 = _mm512_cvtepi8_epi32(i8);
+        const __m512 f8 = _mm512_cvtepi32_ps(i32);
+        const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f);
+        const __m512 one_255 = _mm512_set1_ps(1.f / 63.f);
+        return _mm512_fmadd_ps(f8, one_255, half_one_255);
+        */
+
+        /*
+        // todo aguzhva: another candidate for pdep, which might be faster
+        const uint16_t* data16_0 = (const uint16_t*)(code + (i >> 2) * 3);
+        const uint64_t* data64_0 = (const uint64_t*)data16_0;
+        const uint64_t val_0 = *data64_0;
+        const uint64_t vext_0 = _pdep_u64(val_0, 0x3F3F3F3F3F3F3F3FULL);
+        const uint32_t* data32_1 = (const uint32_t*)data16_0;
+        const uint64_t val_1 = (val_0 >> 48) | (((uint64_t)data32_1[1]) << 16);
+        const uint64_t vext_1 = _pdep_u64(val_1, 0x3F3F3F3F3F3F3F3FULL);
+        const __m128i i8 = _mm_set_epi64x(vext_1, vext_0);
+        const __m512i i32 = _mm512_cvtepi8_epi32(i8);
+        const __m512 f8 = _mm512_cvtepi32_ps(i32);
+        const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f);
+        const __m512 one_255 = _mm512_set1_ps(1.f / 63.f);
+        return _mm512_fmadd_ps(f8, one_255, half_one_255);
+        */
+
+        // pure AVX512 implementation, slower than pdep one, but has no problems
+        // for AMD
+
+        // clang-format off
+
+        // 16 components, 16x6 bit=12 bytes
+        const __m128i bit_6v =
+                _mm_maskz_loadu_epi8(0b0000111111111111, code + (i >> 2) * 3);
+        const __m256i bit_6v_256 = _mm256_broadcast_i32x4(bit_6v);
+
+        // 00 01 02 03 04 05 06 07 08 09 0A 0B 0C 0D 0E 0F
+        // 00 01 02 03
+        const __m256i shuffle_mask = _mm256_setr_epi16(
+                0xFF00, 0x0100, 0x0201, 0xFF02,
+                0xFF03, 0x0403, 0x0504, 0xFF05,
+                0xFF06, 0x0706, 0x0807, 0xFF08,
+                0xFF09, 0x0A09, 0x0B0A, 0xFF0B);
+        const __m256i shuffled = _mm256_shuffle_epi8(bit_6v_256, shuffle_mask);
+
+        // 0: xxxxxxxx xx543210
+        // 1: xxxx5432 10xxxxxx
+        // 2: xxxxxx54 3210xxxx
+        // 3: xxxxxxxx 543210xx
+        const __m256i shift_right_v = _mm256_setr_epi16(
+                0x0U, 0x6U, 0x4U, 0x2U,
+                0x0U, 0x6U, 0x4U, 0x2U,
+                0x0U, 0x6U, 0x4U, 0x2U,
+                0x0U, 0x6U, 0x4U, 0x2U);
+        __m256i shuffled_shifted = _mm256_srlv_epi16(shuffled, shift_right_v);
+
+        // remove unneeded bits
+        shuffled_shifted =
+                _mm256_and_si256(shuffled_shifted, _mm256_set1_epi16(0x003F));
+
+        // scale
+        const __m512 f8 =
+                _mm512_cvtepi32_ps(_mm512_cvtepi16_epi32(shuffled_shifted));
+        const __m512 half_one_255 = _mm512_set1_ps(0.5f / 63.f);
+        const __m512 one_255 = _mm512_set1_ps(1.f / 63.f);
+        return _mm512_fmadd_ps(f8, one_255, half_one_255);
+
+        // clang-format on
     }
 };
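For reference, a scalar sketch of what each lane above computes (illustrative only, not code from this commit; the helper name is ours): component i occupies bits [6*i, 6*i + 6) of the packed little-endian stream, and a 6-bit value x decodes to (x + 0.5f) / 63.f, matching the final fmadd.

    #include <cstdint>

    static inline float decode_6bit_component_ref(const uint8_t* code, int i) {
        const int bit = i * 6;      // absolute bit offset of component i
        const int byte = bit >> 3;  // byte holding its least-significant bits
        const int shift = bit & 7;  // position of the value inside that byte
        uint32_t x = uint32_t(code[byte]) >> shift;
        if (shift > 2) {
            // the 6-bit field spills into the next byte
            x |= uint32_t(code[byte + 1]) << (8 - shift);
        }
        x &= 0x3F;
        return (float(x) + 0.5f) / 63.f;
    }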

@@ -264,8 +310,8 @@ struct Quantizer8bitDirect_avx512<16> : public Quantizer8bitDirect_avx<8> {
     FAISS_ALWAYS_INLINE __m512
     reconstruct_16_components(const uint8_t* code, int i) const {
         __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8
-        __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32
-        return _mm512_cvtepi32_ps(y16); // 16 * float32
+        __m512i y16 = _mm512_cvtepu8_epi32(x16);             // 16 * int32
+        return _mm512_cvtepi32_ps(y16);                      // 16 * float32
     }
 };

@@ -277,26 +323,33 @@ template <int SIMDWIDTH>
 struct Quantizer8bitDirectSigned_avx512 {};
 
 template <>
-struct Quantizer8bitDirectSigned_avx512<1> : public Quantizer8bitDirectSigned_avx<1> {
+struct Quantizer8bitDirectSigned_avx512<1>
+        : public Quantizer8bitDirectSigned_avx<1> {
     Quantizer8bitDirectSigned_avx512(size_t d, const std::vector<float>& unused)
             : Quantizer8bitDirectSigned_avx<1>(d, unused) {}
 };
 
 template <>
-struct Quantizer8bitDirectSigned_avx512<8> : public Quantizer8bitDirectSigned_avx<8> {
-    Quantizer8bitDirectSigned_avx512(size_t d, const std::vector<float>& trained)
+struct Quantizer8bitDirectSigned_avx512<8>
+        : public Quantizer8bitDirectSigned_avx<8> {
+    Quantizer8bitDirectSigned_avx512(
+            size_t d,
+            const std::vector<float>& trained)
             : Quantizer8bitDirectSigned_avx<8>(d, trained) {}
 };
 
 template <>
-struct Quantizer8bitDirectSigned_avx512<16> : public Quantizer8bitDirectSigned_avx<8> {
-    Quantizer8bitDirectSigned_avx512(size_t d, const std::vector<float>& trained)
+struct Quantizer8bitDirectSigned_avx512<16>
+        : public Quantizer8bitDirectSigned_avx<8> {
+    Quantizer8bitDirectSigned_avx512(
+            size_t d,
+            const std::vector<float>& trained)
             : Quantizer8bitDirectSigned_avx<8>(d, trained) {}
 
     FAISS_ALWAYS_INLINE __m512
     reconstruct_16_components(const uint8_t* code, int i) const {
-        __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8
-        __m512i y16 = _mm512_cvtepu8_epi32(x16); // 16 * int32
+        __m128i x16 = _mm_loadu_si128((__m128i*)(code + i)); // 16 * int8
+        __m512i y16 = _mm512_cvtepu8_epi32(x16);             // 16 * int32
         __m512i c16 = _mm512_set1_epi32(128);
         __m512i z16 = _mm512_sub_epi32(y16, c16); // subtract 128 from all lanes
         return _mm512_cvtepi32_ps(z16); // 16 * float32
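Per lane, the two direct quantizers above reduce to simple scalar conversions. A sketch of the equivalences (illustrative, not from the diff; the names are ours):

    #include <cstdint>

    // Quantizer8bitDirect: each stored byte is reconstructed as-is.
    static inline float reconstruct_direct_ref(const uint8_t* code, int i) {
        return float(code[i]);
    }

    // Quantizer8bitDirectSigned: bytes are stored with a +128 bias, so
    // reconstruction subtracts 128 from every lane before converting.
    static inline float reconstruct_direct_signed_ref(const uint8_t* code, int i) {
        return float(int(code[i]) - 128);
    }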
