Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Forward-merge branch-24.10 into branch-24.12 #371

Merged
merged 1 commit into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ add_library(
src/neighbors/nn_descent_half.cu
src/neighbors/nn_descent_int8.cu
src/neighbors/nn_descent_uint8.cu
src/neighbors/reachability.cu
src/neighbors/refine/detail/refine_device_float_float.cu
src/neighbors/refine/detail/refine_device_half_float.cu
src/neighbors/refine/detail/refine_device_int8_t_float.cu
Expand All @@ -417,6 +418,7 @@ add_library(
src/neighbors/refine/detail/refine_host_uint8_t_float.cpp
src/neighbors/sample_filter.cu
src/selection/select_k_float_int64_t.cu
src/selection/select_k_float_int32_t.cu
src/selection/select_k_float_uint32_t.cu
src/selection/select_k_half_uint32_t.cu
src/stats/silhouette_score.cu
Expand Down
79 changes: 79 additions & 0 deletions cpp/include/cuvs/neighbors/reachability.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/core/device_mdspan.hpp>
#include <raft/core/handle.hpp>
#include <raft/sparse/coo.hpp>

#include <cuvs/distance/distance.hpp>

namespace cuvs::neighbors::reachability {

/**
* @defgroup reachability_cpp Mutual Reachability
* @{
*/
/**
* Constructs a mutual reachability graph, which is a k-nearest neighbors
* graph projected into mutual reachability space. The mutual reachability
* distance between two points a and b is
* max(core_distance(a), core_distance(b), d(a, b)),
* where core_distance(p) is the distance from p to its kth nearest neighbor.
*
* Unfortunately, points in the tails of the pdf (e.g. in sparse regions
* of the space) can have very large neighborhoods, which will impact
* nearby neighborhoods. Because of this, it is possible for the radius
* of points in the main mass, which might be very small initially, to
* grow very large. As a result, the initial knn which was used to compute
* the core distances may no longer capture the actual neighborhoods after
* projection into mutual reachability space.
*
* For the experimental version, we execute the knn twice: once
* to compute the radii (core distances) and again to capture
* the final neighborhoods. Future iterations of this algorithm
* will improve upon this "exact" version by using more specialized
* data structures, such as space-partitioning structures. It has
* also been shown that approximate nearest neighbors can yield
* reasonable neighborhoods as the data sizes increase.
*
* @param[in] handle raft resources handle for resource reuse
* @param[in] X input data points, row-major device matrix (size m * n)
* @param[in] min_samples neighborhood size used to select the core distances
* @param[out] indptr CSR indptr of output knn graph (size m + 1)
* @param[out] core_dists output core distances array (size m)
* @param[out] out COO object, uninitialized on entry, on exit it stores the
* (symmetrized) maximum reachability distance for the k nearest
* neighbors.
* @param[in] metric distance metric to use, default L2SqrtExpanded (Euclidean)
* @param[in] alpha weight applied when internal distance is chosen for
* mutual reachability (value of 1.0, the default, disables the weighting)
*/
void mutual_reachability_graph(
const raft::resources& handle,
raft::device_matrix_view<const float, int, raft::row_major> X,
int min_samples,
raft::device_vector_view<int> indptr,
raft::device_vector_view<float> core_dists,
raft::sparse::COO<float, int>& out,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2SqrtExpanded,
float alpha = 1.0);
/**
* @}
*/
} // namespace cuvs::neighbors::reachability
10 changes: 10 additions & 0 deletions cpp/include/cuvs/selection/select_k.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ void select_k(
SelectAlgo algo = SelectAlgo::kAuto,
std::optional<raft::device_vector_view<const int64_t, int64_t>> len_i = std::nullopt);

/**
* Overload of select_k operating on float values paired with int indices;
* all matrix and vector extents are indexed with int64_t.
*
* NOTE(review): the behavioral descriptions marked "presumably" below are
* inferred from parameter names and the sibling overloads, whose full
* documentation is not visible here — confirm before relying on them.
*
* @param[in] handle raft resources handle
* @param[in] in_val row-major device matrix of const float input values
* @param[in] in_idx optional row-major device matrix of const int input indices
* @param[out] out_val row-major device matrix of float; presumably receives the selected values
* @param[out] out_idx row-major device matrix of int; presumably receives the selected indices
* @param[in] select_min presumably selects the smallest values when true and the
* largest otherwise — confirm
* @param[in] sorted defaults to false; presumably requests sorted output rows — confirm
* @param[in] algo selection algorithm, defaults to SelectAlgo::kAuto
* @param[in] len_i optional device vector of const int, defaults to std::nullopt;
* presumably per-row valid input lengths — confirm
*/
void select_k(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> in_val,
std::optional<raft::device_matrix_view<const int, int64_t, raft::row_major>> in_idx,
raft::device_matrix_view<float, int64_t, raft::row_major> out_val,
raft::device_matrix_view<int, int64_t, raft::row_major> out_idx,
bool select_min,
bool sorted = false,
SelectAlgo algo = SelectAlgo::kAuto,
std::optional<raft::device_vector_view<const int, int64_t>> len_i = std::nullopt);

/**
* Select k smallest or largest key/values from each row in the input data.
*
Expand Down
27 changes: 23 additions & 4 deletions cpp/src/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@ namespace cuvs::neighbors::detail {
* Calculates brute force knn, using a fixed memory budget
* by tiling over both the rows and columns of pairwise_distances
*/
template <typename ElementType = float, typename IndexType = int64_t, typename DistanceT = float>
template <typename ElementType = float,
typename IndexType = int64_t,
typename DistanceT = float,
typename DistanceEpilogue = raft::identity_op>
void tiled_brute_force_knn(const raft::resources& handle,
const ElementType* search, // size (m ,d)
const ElementType* index, // size (n ,d)
Expand All @@ -78,7 +81,8 @@ void tiled_brute_force_knn(const raft::resources& handle,
size_t max_col_tile_size = 0,
const DistanceT* precomputed_index_norms = nullptr,
const DistanceT* precomputed_search_norms = nullptr,
const uint32_t* filter_bitmap = nullptr)
const uint32_t* filter_bitmap = nullptr,
DistanceEpilogue distance_epilogue = raft::identity_op())
{
// Figure out the number of rows/cols to tile for
size_t tile_rows = 0;
Expand Down Expand Up @@ -207,7 +211,8 @@ void tiled_brute_force_knn(const raft::resources& handle,
IndexType col = j + (idx % current_centroid_size);

cuvs::distance::detail::ops::l2_exp_cutlass_op<DistanceT, DistanceT> l2_op(sqrt);
return l2_op(row_norms[row], col_norms[col], dist[idx]);
auto val = l2_op(row_norms[row], col_norms[col], dist[idx]);
return distance_epilogue(val, row, col);
});
} else if (metric == cuvs::distance::DistanceType::CosineExpanded) {
auto row_norms = precomputed_search_norms ? precomputed_search_norms : search_norms.data();
Expand All @@ -221,8 +226,22 @@ void tiled_brute_force_knn(const raft::resources& handle,
IndexType row = i + (idx / current_centroid_size);
IndexType col = j + (idx % current_centroid_size);
auto val = DistanceT(1.0) - dist[idx] / DistanceT(row_norms[row] * col_norms[col]);
return val;
return distance_epilogue(val, row, col);
});
} else {
// neither the l2 nor the cosine branch above ran (those apply the epilogue themselves), so if a non-identity distance epilogue was supplied - run it now
if constexpr (!std::is_same_v<DistanceEpilogue, raft::identity_op>) {
auto distances_ptr = temp_distances.data();
raft::linalg::map_offset(
handle,
raft::make_device_vector_view(temp_distances.data(),
current_query_size * current_centroid_size),
[=] __device__(size_t idx) {
IndexType row = i + (idx / current_centroid_size);
IndexType col = j + (idx % current_centroid_size);
return distance_epilogue(distances_ptr[idx], row, col);
});
}
}

if (filter_bitmap != nullptr) {
Expand Down
Loading
Loading