Skip to content

Commit

Permalink
Let HNSW throw exception when calculate distance for unsupported Data…
Browse files Browse the repository at this point in the history
…Type (zilliztech#808)

Signed-off-by: Cai Yudong <yudong.cai@zilliz.com>
  • Loading branch information
cydrain authored Sep 3, 2024
1 parent d059d9e commit f1818e0
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 45 deletions.
19 changes: 6 additions & 13 deletions thirdparty/hnswlib/hnswlib/space_cosine.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,14 @@ namespace hnswlib {
template <typename DataType, typename DistanceType>
static DistanceType
Cosine(const void* pVect1, const void* pVect2, const void* qty_ptr) {
if constexpr (std::is_same_v<DataType, knowhere::fp16>) {
return faiss::fp16_vec_inner_product((const knowhere::fp16*)pVect1, (const knowhere::fp16*)pVect2,
*((size_t*)qty_ptr));
if constexpr (std::is_same_v<DataType, knowhere::fp32>) {
return faiss::fvec_inner_product((const DataType*)pVect1, (const DataType*)pVect2, *((size_t*)qty_ptr));
} else if constexpr (std::is_same_v<DataType, knowhere::fp16>) {
return faiss::fp16_vec_inner_product((const DataType*)pVect1, (const DataType*)pVect2, *((size_t*)qty_ptr));
} else if constexpr (std::is_same_v<DataType, knowhere::bf16>) {
return faiss::bf16_vec_inner_product((const knowhere::bf16*)pVect1, (const knowhere::bf16*)pVect2,
*((size_t*)qty_ptr));
} else if constexpr (std::is_same_v<DataType, knowhere::fp32>) {
return faiss::fvec_inner_product((const float*)pVect1, (const float*)pVect2, *((size_t*)qty_ptr));
return faiss::bf16_vec_inner_product((const DataType*)pVect1, (const DataType*)pVect2, *((size_t*)qty_ptr));
} else {
size_t qty = *((size_t*)qty_ptr);
float res = 0;
for (unsigned i = 0; i < qty; i++) {
res += (DistanceType)((DataType*)pVect1)[i] * (DistanceType)((DataType*)pVect2)[i];
}
return res;
throw std::runtime_error("Unknown Datatype\n");
}
}

Expand Down
19 changes: 6 additions & 13 deletions thirdparty/hnswlib/hnswlib/space_ip.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,14 @@ namespace hnswlib {
template <typename DataType, typename DistanceType>
static DistanceType
InnerProduct(const void* pVect1, const void* pVect2, const void* qty_ptr) {
if constexpr (std::is_same_v<DataType, knowhere::fp16>) {
return faiss::fp16_vec_inner_product((const knowhere::fp16*)pVect1, (const knowhere::fp16*)pVect2,
*((size_t*)qty_ptr));
if constexpr (std::is_same_v<DataType, knowhere::fp32>) {
return faiss::fvec_inner_product((const DataType*)pVect1, (const DataType*)pVect2, *((size_t*)qty_ptr));
} else if constexpr (std::is_same_v<DataType, knowhere::fp16>) {
return faiss::fp16_vec_inner_product((const DataType*)pVect1, (const DataType*)pVect2, *((size_t*)qty_ptr));
} else if constexpr (std::is_same_v<DataType, knowhere::bf16>) {
return faiss::bf16_vec_inner_product((const knowhere::bf16*)pVect1, (const knowhere::bf16*)pVect2,
*((size_t*)qty_ptr));
} else if constexpr (std::is_same_v<DataType, knowhere::fp32>) {
return faiss::fvec_inner_product((const float*)pVect1, (const float*)pVect2, *((size_t*)qty_ptr));
return faiss::bf16_vec_inner_product((const DataType*)pVect1, (const DataType*)pVect2, *((size_t*)qty_ptr));
} else {
size_t qty = *((size_t*)qty_ptr);
float res = 0;
for (unsigned i = 0; i < qty; i++) {
res += (DistanceType)((DataType*)pVect1)[i] * (DistanceType)((DataType*)pVect2)[i];
}
return res;
throw std::runtime_error("Unknown Datatype\n");
}
}

Expand Down
25 changes: 6 additions & 19 deletions thirdparty/hnswlib/hnswlib/space_l2.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,27 +22,14 @@ NormSqr(const void* pVect1v, const void* qty_ptr) {
template <typename DataType, typename DistanceType>
static DistanceType
L2Sqr(const void* pVect1v, const void* pVect2v, const void* qty_ptr) {
if constexpr (std::is_same_v<DataType, knowhere::fp16>) {
return faiss::fp16_vec_L2sqr((const knowhere::fp16*)pVect1v, (const knowhere::fp16*)pVect2v,
*((size_t*)qty_ptr));
if constexpr (std::is_same_v<DataType, knowhere::fp32>) {
return faiss::fvec_L2sqr((const DataType*)pVect1v, (const DataType*)pVect2v, *((size_t*)qty_ptr));
} else if constexpr (std::is_same_v<DataType, knowhere::fp16>) {
return faiss::fp16_vec_L2sqr((const DataType*)pVect1v, (const DataType*)pVect2v, *((size_t*)qty_ptr));
} else if constexpr (std::is_same_v<DataType, knowhere::bf16>) {
return faiss::bf16_vec_L2sqr((const knowhere::bf16*)pVect1v, (const knowhere::bf16*)pVect2v,
*((size_t*)qty_ptr));
} else if constexpr (std::is_same_v<DataType, knowhere::fp32>) {
return faiss::fvec_L2sqr((const float*)pVect1v, (const float*)pVect2v, *((size_t*)qty_ptr));
return faiss::bf16_vec_L2sqr((const DataType*)pVect1v, (const DataType*)pVect2v, *((size_t*)qty_ptr));
} else {
auto pVect1 = (DataType*)pVect1v;
auto pVect2 = (DataType*)pVect2v;
size_t qty = *((size_t*)qty_ptr);

float res = 0;
for (size_t i = 0; i < qty; i++) {
float t = (DistanceType)(*pVect1) - (DistanceType)(*pVect2);
pVect1++;
pVect2++;
res += t * t;
}
return (res);
throw std::runtime_error("Unknown Datatype\n");
}
}

Expand Down

0 comments on commit f1818e0

Please sign in to comment.