-
Notifications
You must be signed in to change notification settings - Fork 606
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Knn UDF for Exact vector search (#4524)
Co-authored-by: azevaykin <azevaykin@yandex-team.com> Co-authored-by: Valerii Mironov <mbkkt@ydb.tech>
- Loading branch information
1 parent
4d22497
commit ea2ef69
Showing
67 changed files
with
23,818 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
Библиотека для вычисления скалярного произведения векторов. | ||
===================================================== | ||
|
||
Данная библиотека содержит функцию DotProduct, вычисляющую скалярное произведение векторов различных типов. | ||
В отличие от наивной реализации, библиотека использует SSE и работает существенно быстрее. Для сравнения | ||
можно посмотреть результаты бенчмарка. | ||
|
||
Типичное использование - замена кусков кода вроде: | ||
``` | ||
for (int i = 0; i < len; i++) | ||
dot_product += a[i] * b[i]; | ||
``` | ||
на существенно более эффективный вызов ```DotProduct(a, b, len)```. | ||
|
||
Работает для типов i8, ui8, i32, float, double. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,274 @@ | ||
#include "dot_product.h" | ||
#include "dot_product_sse.h" | ||
#include "dot_product_avx2.h" | ||
#include "dot_product_simple.h" | ||
|
||
#include <library/cpp/sse/sse.h> | ||
#include <library/cpp/testing/common/env.h> | ||
#include <util/system/compiler.h> | ||
#include <util/generic/utility.h> | ||
#include <util/system/cpu_id.h> | ||
#include <util/system/env.h> | ||
|
||
namespace NDotProductImpl { | ||
i32 (*DotProductI8Impl)(const i8* lhs, const i8* rhs, size_t length) noexcept = &DotProductSimple; | ||
ui32 (*DotProductUi8Impl)(const ui8* lhs, const ui8* rhs, size_t length) noexcept = &DotProductSimple; | ||
i64 (*DotProductI32Impl)(const i32* lhs, const i32* rhs, size_t length) noexcept = &DotProductSimple; | ||
float (*DotProductFloatImpl)(const float* lhs, const float* rhs, size_t length) noexcept = &DotProductSimple; | ||
double (*DotProductDoubleImpl)(const double* lhs, const double* rhs, size_t length) noexcept = &DotProductSimple; | ||
|
||
namespace { | ||
[[maybe_unused]] const int _ = [] { | ||
if (!FromYaTest() && GetEnv("Y_NO_AVX_IN_DOT_PRODUCT") == "" && NX86::HaveAVX2() && NX86::HaveFMA()) { | ||
DotProductI8Impl = &DotProductAvx2; | ||
DotProductUi8Impl = &DotProductAvx2; | ||
DotProductI32Impl = &DotProductAvx2; | ||
DotProductFloatImpl = &DotProductAvx2; | ||
DotProductDoubleImpl = &DotProductAvx2; | ||
} else { | ||
#ifdef ARCADIA_SSE | ||
DotProductI8Impl = &DotProductSse; | ||
DotProductUi8Impl = &DotProductSse; | ||
DotProductI32Impl = &DotProductSse; | ||
DotProductFloatImpl = &DotProductSse; | ||
DotProductDoubleImpl = &DotProductSse; | ||
#endif | ||
} | ||
return 0; | ||
}(); | ||
} | ||
} | ||
|
||
#ifdef ARCADIA_SSE | ||
// Returns the sum of squares of the first `length` floats starting at `v`,
// computed with SSE. `v` does not need to be aligned (unaligned loads are used).
float L2NormSquared(const float* v, size_t length) noexcept {
    // Two independent accumulators: the main loop consumes 8 floats per pass.
    __m128 accLow = _mm_setzero_ps();
    __m128 accHigh = _mm_setzero_ps();

    for (; length >= 8; length -= 8, v += 8) {
        const __m128 low = _mm_loadu_ps(v);
        const __m128 high = _mm_loadu_ps(v + 4);
        accLow = _mm_add_ps(accLow, _mm_mul_ps(low, low));
        accHigh = _mm_add_ps(accHigh, _mm_mul_ps(high, high));
    }

    // At most one full group of four floats can remain.
    if (length >= 4) {
        const __m128 quad = _mm_loadu_ps(v);
        accLow = _mm_add_ps(accLow, _mm_mul_ps(quad, quad));
        length -= 4;
        v += 4;
    }

    accLow = _mm_add_ps(accLow, accHigh);

    if (length != 0) {
        // One to three floats are left here; lanes without data are zero, so
        // they add nothing to the sum. Elements past `length` are never read.
        const float x0 = v[0];
        const float x1 = (length > 1) ? v[1] : 0.0f;
        const float x2 = (length > 2) ? v[2] : 0.0f;
        const __m128 tail = _mm_set_ps(0.0f, x2, x1, x0);
        accLow = _mm_add_ps(accLow, _mm_mul_ps(tail, tail));
    }

    // Horizontal sum of the four lanes.
    alignas(16) float lanes[4];
    _mm_store_ps(lanes, accLow);
    return lanes[0] + lanes[1] + lanes[2] + lanes[3];
}
|
||
template <bool computeLL, bool computeLR, bool computeRR> | ||
Y_FORCE_INLINE | ||
static void TriWayDotProductIteration(__m128& sumLL, __m128& sumLR, __m128& sumRR, const __m128 a, const __m128 b) { | ||
if constexpr (computeLL) { | ||
sumLL = _mm_add_ps(sumLL, _mm_mul_ps(a, a)); | ||
} | ||
if constexpr (computeLR) { | ||
sumLR = _mm_add_ps(sumLR, _mm_mul_ps(a, b)); | ||
} | ||
if constexpr (computeRR) { | ||
sumRR = _mm_add_ps(sumRR, _mm_mul_ps(b, b)); | ||
} | ||
} | ||
|
||
|
||
template <bool computeLL, bool computeLR, bool computeRR> | ||
static TTriWayDotProduct<float> TriWayDotProductImpl(const float* lhs, const float* rhs, size_t length) noexcept { | ||
__m128 sumLL1 = _mm_setzero_ps(); | ||
__m128 sumLR1 = _mm_setzero_ps(); | ||
__m128 sumRR1 = _mm_setzero_ps(); | ||
__m128 sumLL2 = _mm_setzero_ps(); | ||
__m128 sumLR2 = _mm_setzero_ps(); | ||
__m128 sumRR2 = _mm_setzero_ps(); | ||
|
||
while (length >= 8) { | ||
TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, _mm_loadu_ps(lhs + 0), _mm_loadu_ps(rhs + 0)); | ||
TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL2, sumLR2, sumRR2, _mm_loadu_ps(lhs + 4), _mm_loadu_ps(rhs + 4)); | ||
length -= 8; | ||
lhs += 8; | ||
rhs += 8; | ||
} | ||
|
||
if (length >= 4) { | ||
TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, _mm_loadu_ps(lhs + 0), _mm_loadu_ps(rhs + 0)); | ||
length -= 4; | ||
lhs += 4; | ||
rhs += 4; | ||
} | ||
|
||
if constexpr (computeLL) { | ||
sumLL1 = _mm_add_ps(sumLL1, sumLL2); | ||
} | ||
if constexpr (computeLR) { | ||
sumLR1 = _mm_add_ps(sumLR1, sumLR2); | ||
} | ||
if constexpr (computeRR) { | ||
sumRR1 = _mm_add_ps(sumRR1, sumRR2); | ||
} | ||
|
||
if (length) { | ||
__m128 a, b; | ||
switch (length) { | ||
case 3: | ||
a = _mm_set_ps(0.0f, lhs[2], lhs[1], lhs[0]); | ||
b = _mm_set_ps(0.0f, rhs[2], rhs[1], rhs[0]); | ||
break; | ||
case 2: | ||
a = _mm_set_ps(0.0f, 0.0f, lhs[1], lhs[0]); | ||
b = _mm_set_ps(0.0f, 0.0f, rhs[1], rhs[0]); | ||
break; | ||
case 1: | ||
a = _mm_set_ps(0.0f, 0.0f, 0.0f, lhs[0]); | ||
b = _mm_set_ps(0.0f, 0.0f, 0.0f, rhs[0]); | ||
break; | ||
default: | ||
Y_UNREACHABLE(); | ||
} | ||
TriWayDotProductIteration<computeLL, computeLR, computeRR>(sumLL1, sumLR1, sumRR1, a, b); | ||
} | ||
|
||
__m128 t0 = sumLL1; | ||
__m128 t1 = sumLR1; | ||
__m128 t2 = sumRR1; | ||
__m128 t3 = _mm_setzero_ps(); | ||
_MM_TRANSPOSE4_PS(t0, t1, t2, t3); | ||
t0 = _mm_add_ps(t0, t1); | ||
t0 = _mm_add_ps(t0, t2); | ||
t0 = _mm_add_ps(t0, t3); | ||
|
||
alignas(16) float res[4]; | ||
_mm_store_ps(res, t0); | ||
TTriWayDotProduct<float> result{res[0], res[1], res[2]}; | ||
static constexpr const TTriWayDotProduct<float> def; | ||
// fill skipped fields with default values | ||
if constexpr (!computeLL) { | ||
result.LL = def.LL; | ||
} | ||
if constexpr (!computeLR) { | ||
result.LR = def.LR; | ||
} | ||
if constexpr (!computeRR) { | ||
result.RR = def.RR; | ||
} | ||
return result; | ||
} | ||
|
||
|
||
// SSE build: computes the dot products selected by `mask` for two float
// vectors of `length` elements. Bit 0b100 selects L·L, 0b010 selects L·R and
// 0b001 selects R·R; higher bits are ignored. Products that are not selected
// keep the default values of TTriWayDotProduct<float>.
TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept {
    constexpr unsigned wantLL = 0b100;
    constexpr unsigned wantLR = 0b010;
    constexpr unsigned wantRR = 0b001;
    const unsigned wanted = mask & (wantLL | wantLR | wantRR);

    // All three products in one pass.
    if (Y_LIKELY(wanted == (wantLL | wantLR | wantRR))) {
        return TriWayDotProductImpl<true, true, true>(lhs, rhs, length);
    }

    // Dot product plus the squared length of the left vector.
    if (Y_LIKELY(wanted == (wantLL | wantLR))) {
        return TriWayDotProductImpl<true, true, false>(lhs, rhs, length);
    }

    // Dot product plus the squared length of the right vector: run the same
    // kernel with the operands exchanged, then move its "LL" into RR.
    if (Y_LIKELY(wanted == (wantLR | wantRR))) {
        auto mirrored = TriWayDotProductImpl<true, true, false>(rhs, lhs, length);
        DoSwap(mirrored.LL, mirrored.RR);
        return mirrored;
    }

    // Remaining (rare) masks: 0b000, 0b100, 0b010, 0b001 and 0b101.
    TTriWayDotProduct<float> result{};
    if (wanted & wantLL) {
        result.LL = L2NormSquared(lhs, length);
    }
    if (wanted & wantLR) {
        result.LR = DotProduct(lhs, rhs, length);
    }
    if (wanted & wantRR) {
        result.RR = L2NormSquared(rhs, length);
    }
    return result;
}
|
||
#else | ||
|
||
float L2NormSquared(const float* v, size_t length) noexcept { | ||
return DotProduct(v, v, length); | ||
} | ||
|
||
// Build without ARCADIA_SSE: computes each product selected by `mask` with a
// separate call. Products that are not selected keep the default values of
// TTriWayDotProduct<float>.
TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept {
    // True when the bit of the given product is set in `mask`.
    const auto requested = [mask](ETriWayDotProductComputeMask part) {
        return (mask & static_cast<unsigned>(part)) != 0;
    };

    TTriWayDotProduct<float> result;
    if (requested(ETriWayDotProductComputeMask::LL)) {
        result.LL = L2NormSquared(lhs, length);
    }
    if (requested(ETriWayDotProductComputeMask::LR)) {
        result.LR = DotProduct(lhs, rhs, length);
    }
    if (requested(ETriWayDotProductComputeMask::RR)) {
        result.RR = L2NormSquared(rhs, length);
    }
    return result;
}
|
||
#endif // ARCADIA_SSE | ||
|
||
namespace NDotProduct { | ||
void DisableAvx2() { | ||
#ifdef ARCADIA_SSE | ||
NDotProductImpl::DotProductI8Impl = &DotProductSse; | ||
NDotProductImpl::DotProductUi8Impl = &DotProductSse; | ||
NDotProductImpl::DotProductI32Impl = &DotProductSse; | ||
NDotProductImpl::DotProductFloatImpl = &DotProductSse; | ||
NDotProductImpl::DotProductDoubleImpl = &DotProductSse; | ||
#else | ||
NDotProductImpl::DotProductI8Impl = &DotProductSimple; | ||
NDotProductImpl::DotProductUi8Impl = &DotProductSimple; | ||
NDotProductImpl::DotProductI32Impl = &DotProductSimple; | ||
NDotProductImpl::DotProductFloatImpl = &DotProductSimple; | ||
NDotProductImpl::DotProductDoubleImpl = &DotProductSimple; | ||
#endif | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
#pragma once | ||
|
||
#include <util/system/types.h> | ||
#include <util/system/compiler.h> | ||
|
||
#include <numeric> | ||
|
||
/** | ||
* Dot product (Inner product or scalar product) implementation using SSE when possible. | ||
*/ | ||
namespace NDotProductImpl { | ||
extern i32 (*DotProductI8Impl)(const i8* lhs, const i8* rhs, size_t length) noexcept; | ||
extern ui32 (*DotProductUi8Impl)(const ui8* lhs, const ui8* rhs, size_t length) noexcept; | ||
extern i64 (*DotProductI32Impl)(const i32* lhs, const i32* rhs, size_t length) noexcept; | ||
extern float (*DotProductFloatImpl)(const float* lhs, const float* rhs, size_t length) noexcept; | ||
extern double (*DotProductDoubleImpl)(const double* lhs, const double* rhs, size_t length) noexcept; | ||
} | ||
|
||
// DotProduct overloads: each one forwards to the kernel currently stored in
// the matching NDotProductImpl dispatch slot. `lhs` and `rhs` must each point
// to `length` elements.

// i8 inputs, i32 result.
Y_PURE_FUNCTION
inline i32 DotProduct(const i8* lhs, const i8* rhs, size_t length) noexcept {
    const auto kernel = NDotProductImpl::DotProductI8Impl;
    return kernel(lhs, rhs, length);
}

// ui8 inputs, ui32 result.
Y_PURE_FUNCTION
inline ui32 DotProduct(const ui8* lhs, const ui8* rhs, size_t length) noexcept {
    const auto kernel = NDotProductImpl::DotProductUi8Impl;
    return kernel(lhs, rhs, length);
}

// i32 inputs, i64 result.
Y_PURE_FUNCTION
inline i64 DotProduct(const i32* lhs, const i32* rhs, size_t length) noexcept {
    const auto kernel = NDotProductImpl::DotProductI32Impl;
    return kernel(lhs, rhs, length);
}

// float inputs, float result.
Y_PURE_FUNCTION
inline float DotProduct(const float* lhs, const float* rhs, size_t length) noexcept {
    const auto kernel = NDotProductImpl::DotProductFloatImpl;
    return kernel(lhs, rhs, length);
}

// double inputs, double result.
Y_PURE_FUNCTION
inline double DotProduct(const double* lhs, const double* rhs, size_t length) noexcept {
    const auto kernel = NDotProductImpl::DotProductDoubleImpl;
    return kernel(lhs, rhs, length);
}
|
||
/** | ||
* Squared L2 norm of `v`: the dot product of the vector with itself | ||
*/ | ||
Y_PURE_FUNCTION | ||
float L2NormSquared(const float* v, size_t length) noexcept; | ||
|
||
// TODO(yazevnul): make `L2NormSquared` for double, this should be faster than `DotProduct` | ||
// where `lhs == rhs` because it will save N load instructions. | ||
|
||
// Result of a three-way dot product of two vectors L and R.
// The initializers are the values a field keeps when its product was not
// requested: 1 for both self products and 0 for the cross product.
template <typename T>
struct TTriWayDotProduct {
    T LL = static_cast<T>(1); // L·L
    T LR = static_cast<T>(0); // L·R
    T RR = static_cast<T>(1); // R·R
};
|
||
// Bit flags selecting which of the three products TriWayDotProduct computes.
enum class ETriWayDotProductComputeMask: unsigned {
    // single products
    LL = 1u << 2,
    LR = 1u << 1,
    RR = 1u << 0,

    // useful combinations
    All = LL | LR | RR,
    Left = LL | LR,  // skip computation of R·R
    Right = LR | RR, // skip computation of L·L
};
|
||
/**
 * For two float vectors L and R of `length` elements computes up to three
 * dot products: L·L, L·R and R·R. `mask` is a combination of
 * ETriWayDotProductComputeMask bits naming the products to compute.
 */
Y_PURE_FUNCTION
TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, unsigned mask) noexcept;

/**
 * Same computation with the mask given as an enumerator; all three products
 * are computed by default.
 */
Y_PURE_FUNCTION
static inline TTriWayDotProduct<float> TriWayDotProduct(const float* lhs, const float* rhs, size_t length, ETriWayDotProductComputeMask mask = ETriWayDotProductComputeMask::All) noexcept {
    const unsigned rawMask = static_cast<unsigned>(mask);
    return TriWayDotProduct(lhs, rhs, length, rawMask);
}
|
||
namespace NDotProduct { | ||
// Simpler wrapper allowing to use this functions as template argument. | ||
template <typename T> | ||
struct TDotProduct { | ||
using TResult = decltype(DotProduct(static_cast<const T*>(nullptr), static_cast<const T*>(nullptr), 0)); | ||
Y_PURE_FUNCTION | ||
inline TResult operator()(const T* l, const T* r, size_t length) const { | ||
return DotProduct(l, r, length); | ||
} | ||
}; | ||
|
||
void DisableAvx2(); | ||
} | ||
|
Oops, something went wrong.