diff --git a/thirdparty/hnswlib/hnswlib/space_cosine.h b/thirdparty/hnswlib/hnswlib/space_cosine.h index e55dadbe1..a3b35b3b8 100644 --- a/thirdparty/hnswlib/hnswlib/space_cosine.h +++ b/thirdparty/hnswlib/hnswlib/space_cosine.h @@ -8,21 +8,14 @@ namespace hnswlib { template static DistanceType Cosine(const void* pVect1, const void* pVect2, const void* qty_ptr) { - if constexpr (std::is_same_v) { - return faiss::fp16_vec_inner_product((const knowhere::fp16*)pVect1, (const knowhere::fp16*)pVect2, - *((size_t*)qty_ptr)); + if constexpr (std::is_same_v) { + return faiss::fvec_inner_product((const DataType*)pVect1, (const DataType*)pVect2, *((size_t*)qty_ptr)); + } else if constexpr (std::is_same_v) { + return faiss::fp16_vec_inner_product((const DataType*)pVect1, (const DataType*)pVect2, *((size_t*)qty_ptr)); } else if constexpr (std::is_same_v) { - return faiss::bf16_vec_inner_product((const knowhere::bf16*)pVect1, (const knowhere::bf16*)pVect2, - *((size_t*)qty_ptr)); - } else if constexpr (std::is_same_v) { - return faiss::fvec_inner_product((const float*)pVect1, (const float*)pVect2, *((size_t*)qty_ptr)); + return faiss::bf16_vec_inner_product((const DataType*)pVect1, (const DataType*)pVect2, *((size_t*)qty_ptr)); } else { - size_t qty = *((size_t*)qty_ptr); - float res = 0; - for (unsigned i = 0; i < qty; i++) { - res += (DistanceType)((DataType*)pVect1)[i] * (DistanceType)((DataType*)pVect2)[i]; - } - return res; + throw std::runtime_error("Unknown Datatype\n"); } } diff --git a/thirdparty/hnswlib/hnswlib/space_ip.h b/thirdparty/hnswlib/hnswlib/space_ip.h index af952571d..0ce7f46c6 100644 --- a/thirdparty/hnswlib/hnswlib/space_ip.h +++ b/thirdparty/hnswlib/hnswlib/space_ip.h @@ -8,21 +8,14 @@ namespace hnswlib { template static DistanceType InnerProduct(const void* pVect1, const void* pVect2, const void* qty_ptr) { - if constexpr (std::is_same_v) { - return faiss::fp16_vec_inner_product((const knowhere::fp16*)pVect1, (const knowhere::fp16*)pVect2, - *((size_t*)qty_ptr)); + if constexpr (std::is_same_v) { + return faiss::fvec_inner_product((const DataType*)pVect1, (const DataType*)pVect2, *((size_t*)qty_ptr)); + } else if constexpr (std::is_same_v) { + return faiss::fp16_vec_inner_product((const DataType*)pVect1, (const DataType*)pVect2, *((size_t*)qty_ptr)); } else if constexpr (std::is_same_v) { - return faiss::bf16_vec_inner_product((const knowhere::bf16*)pVect1, (const knowhere::bf16*)pVect2, - *((size_t*)qty_ptr)); - } else if constexpr (std::is_same_v) { - return faiss::fvec_inner_product((const float*)pVect1, (const float*)pVect2, *((size_t*)qty_ptr)); + return faiss::bf16_vec_inner_product((const DataType*)pVect1, (const DataType*)pVect2, *((size_t*)qty_ptr)); } else { - size_t qty = *((size_t*)qty_ptr); - float res = 0; - for (unsigned i = 0; i < qty; i++) { - res += (DistanceType)((DataType*)pVect1)[i] * (DistanceType)((DataType*)pVect2)[i]; - } - return res; + throw std::runtime_error("Unknown Datatype\n"); } } diff --git a/thirdparty/hnswlib/hnswlib/space_l2.h b/thirdparty/hnswlib/hnswlib/space_l2.h index 49a7c5102..6bac0ff11 100644 --- a/thirdparty/hnswlib/hnswlib/space_l2.h +++ b/thirdparty/hnswlib/hnswlib/space_l2.h @@ -22,27 +22,14 @@ NormSqr(const void* pVect1v, const void* qty_ptr) { template static DistanceType L2Sqr(const void* pVect1v, const void* pVect2v, const void* qty_ptr) { - if constexpr (std::is_same_v) { - return faiss::fp16_vec_L2sqr((const knowhere::fp16*)pVect1v, (const knowhere::fp16*)pVect2v, - *((size_t*)qty_ptr)); + if constexpr (std::is_same_v) { + return faiss::fvec_L2sqr((const DataType*)pVect1v, (const DataType*)pVect2v, *((size_t*)qty_ptr)); + } else if constexpr (std::is_same_v) { + return faiss::fp16_vec_L2sqr((const DataType*)pVect1v, (const DataType*)pVect2v, *((size_t*)qty_ptr)); } else if constexpr (std::is_same_v) { - return faiss::bf16_vec_L2sqr((const knowhere::bf16*)pVect1v, (const knowhere::bf16*)pVect2v, - *((size_t*)qty_ptr)); - } else if constexpr (std::is_same_v) { - return faiss::fvec_L2sqr((const float*)pVect1v, (const float*)pVect2v, *((size_t*)qty_ptr)); + return faiss::bf16_vec_L2sqr((const DataType*)pVect1v, (const DataType*)pVect2v, *((size_t*)qty_ptr)); } else { - auto pVect1 = (DataType*)pVect1v; - auto pVect2 = (DataType*)pVect2v; - size_t qty = *((size_t*)qty_ptr); - - float res = 0; - for (size_t i = 0; i < qty; i++) { - float t = (DistanceType)(*pVect1) - (DistanceType)(*pVect2); - pVect1++; - pVect2++; - res += t * t; - } - return (res); + throw std::runtime_error("Unknown Datatype\n"); } }