Skip to content

Commit

Permalink
Set norm to 1.0 for all-0 vectors (zilliztech#803)
Browse files Browse the repository at this point in the history
Signed-off-by: Cai Yudong <yudong.cai@zilliz.com>
  • Loading branch information
cydrain authored Sep 3, 2024
1 parent 7a6be01 commit 14818a1
Show file tree
Hide file tree
Showing 4 changed files with 12 additions and 8 deletions.
2 changes: 1 addition & 1 deletion thirdparty/DiskANN/include/diskann/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -691,7 +691,7 @@ namespace diskann {
}

for (auto& norm : norms) {
- norm = std::sqrt(norm);
+ norm = (norm == 0.0 ? 1.0 : std::sqrt(norm));
}

in_reader.seekg(2 * sizeof(_u32), std::ios::beg);
Expand Down
1 change: 1 addition & 0 deletions thirdparty/DiskANN/src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2990,6 +2990,7 @@ namespace diskann {
for (unsigned i = 0; i < _nd; i++) {
char *cur_node_offset = _opt_graph + i * _node_size;
float cur_norm = norm_l2sqr(_data + i * _aligned_dim, _aligned_dim);
+ cur_norm = (cur_norm == 0.0 ? 1.0 : cur_norm);
std::memcpy(cur_node_offset, &cur_norm, sizeof(float));
std::memcpy(cur_node_offset + sizeof(float), _data + i * _aligned_dim,
_data_len - sizeof(float));
Expand Down
14 changes: 8 additions & 6 deletions thirdparty/faiss/faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -460,11 +460,11 @@ void exhaustive_cosine_seq_impl(

// the lambda that applies a filtered element.
auto apply = [&resi, y, y_norms, d](const float ip, const idx_t j) {
- const float norm =
+ float norm =
(y_norms != nullptr) ?
y_norms[j] :
sqrtf(fvec_norm_L2sqr(y + j * d, d));
-
+ norm = (norm == 0.0 ? 1.0 : norm);
resi.add_result(ip / norm, j);
};

Expand Down Expand Up @@ -719,9 +719,10 @@ void exhaustive_cosine_blas(

for (size_t j = j0; j < j1; j++) {
float ip = *ip_line;
- float dis = (y_norms_in != nullptr) ? ip / y_norms_in[j]
- : ip / y_norms[j];
- *ip_line = dis;
+ float norm = (y_norms_in != nullptr) ? y_norms_in[j]
+ : y_norms[j];
+ norm = (norm == 0.0 ? 1.0 : norm);
+ *ip_line = ip / norm;
ip_line++;
}
}
Expand Down Expand Up @@ -1448,10 +1449,11 @@ void knn_cosine_by_idx(
break;
}
float ip = fvec_inner_product(x_, y + d * idsi[j], d);
- const float norm =
+ float norm =
(y_norms != nullptr) ?
y_norms[idsi[j]] :
sqrtf(fvec_norm_L2sqr(y + d * idsi[j], d));
+ norm = (norm == 0.0 ? 1.0 : norm);
ip /= norm;

if (ip > simi[0]) {
Expand Down
3 changes: 2 additions & 1 deletion thirdparty/hnswlib/hnswlib/hnswalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -1236,7 +1236,8 @@ class HierarchicalNSW : public AlgorithmInterface<dist_t> {
if constexpr (has_raw_data) {
memcpy(getDataByInternalId(cur_c), data_point, data_size_);
if (metric_type_ == Metric::COSINE) {
- data_norm_l2_[cur_c] = std::sqrt(NormSqr<data_t, dist_t>(data_point, dist_func_param_));
+ float norm = NormSqr<data_t, dist_t>(data_point, dist_func_param_);
+ data_norm_l2_[cur_c] = (norm == 0.0 ? 1.0 : std::sqrt(norm));
}
}
if constexpr (sq_enabled) {
Expand Down

0 comments on commit 14818a1

Please sign in to comment.