Skip to content

Commit

Permalink
Add sample filtering for ivf_flat. Filtering code refactoring and cle…
Browse files Browse the repository at this point in the history
…anup (#1541)

The PR does the following:
* Introduces `ivf_flat::search_with_filtering()` call in the same way the filtering was introduced to ivf_pq in #1513 
* Moves `sample_filter.cuh` from `raft/neighbor/detail` to `raft/neighbor`
* Moves `NoneSampleFilter` from `raft::neighbor::ivf_pq::detail` namespace to `raft::neighbor::filtering` namespace
* Renames `NoneSampleFilter` to `NoneIvfSampleFilter` and template argument `SampleFilterT` to `IvfSampleFilterT`
* Adds a missing `resource::get_workspace_resource(handle)` in `ivf_flat-inl.cuh` in a `search_with_filtering()` call (which was copied from `search()` call with the same problem)
* Adds more comments in `ivf_pq-inl.h`
* Some code cleanup in `ivf_pq-inl.h`

Authors:
  - Alexander Guzhva (https://github.com/alexanderguzhva)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1541
  • Loading branch information
alexanderguzhva authored Jun 5, 2023
1 parent 6bc237f commit 9bf7b4b
Show file tree
Hide file tree
Showing 25 changed files with 928 additions and 621 deletions.
52 changes: 31 additions & 21 deletions cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,24 +16,27 @@

#pragma once

#include <cstdint> // uintX_t
#include <raft/neighbors/ivf_flat_types.hpp> // raft::neighbors::ivf_flat::index
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
#include <rmm/cuda_stream_view.hpp> // rmm:cuda_stream_view
#include <cstdint> // uintX_t
#include <raft/neighbors/ivf_flat_types.hpp> // raft::neighbors::ivf_flat::index
#include <raft/neighbors/sample_filter_types.hpp> // none_ivf_sample_filter
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT
#include <rmm/cuda_stream_view.hpp> // rmm:cuda_stream_view

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

namespace raft::neighbors::ivf_flat::detail {

template <typename T, typename AccT, typename IdxT>
template <typename T, typename AccT, typename IdxT, typename IvfSampleFilterT>
void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& index,
const T* queries,
const uint32_t* coarse_query_results,
const uint32_t n_queries,
const uint32_t queries_offset,
const raft::distance::DistanceType metric,
const uint32_t n_probes,
const uint32_t k,
const bool select_min,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
float* distances,
uint32_t& grid_dim_x,
Expand All @@ -43,23 +46,30 @@ void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index<T, IdxT>& i

#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \
extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan<T, AccT, IdxT>( \
const raft::neighbors::ivf_flat::index<T, IdxT>& index, \
const T* queries, \
const uint32_t* coarse_query_results, \
const uint32_t n_queries, \
const raft::distance::DistanceType metric, \
const uint32_t n_probes, \
const uint32_t k, \
const bool select_min, \
IdxT* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \
T, AccT, IdxT, IvfSampleFilterT) \
extern template void \
raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan<T, AccT, IdxT, IvfSampleFilterT>( \
const raft::neighbors::ivf_flat::index<T, IdxT>& index, \
const T* queries, \
const uint32_t* coarse_query_results, \
const uint32_t n_queries, \
const uint32_t queries_offset, \
const raft::distance::DistanceType metric, \
const uint32_t n_probes, \
const uint32_t k, \
const bool select_min, \
IvfSampleFilterT sample_filter, \
IdxT* neighbors, \
float* distances, \
uint32_t& grid_dim_x, \
rmm::cuda_stream_view stream)

instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(
float, float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(
int8_t, int32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);
instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(
uint8_t, uint32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter);

#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan
86 changes: 63 additions & 23 deletions cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -646,6 +646,7 @@ struct loadAndComputeDist<kUnroll, Lambda, 1, int8_t, int32_t> {
* @param n_probes
* @param k
* @param dim
* @param sample_filter
* @param[out] neighbors
* @param[out] distances
*/
Expand All @@ -655,6 +656,7 @@ template <int Capacity,
typename T,
typename AccT,
typename IdxT,
typename IvfSampleFilterT,
typename Lambda,
typename PostLambda>
__global__ void __launch_bounds__(kThreadsPerBlock)
Expand All @@ -666,9 +668,11 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
const IdxT* const* list_indices_ptrs,
const T* const* list_data_ptrs,
const uint32_t* list_sizes,
const uint32_t queries_offset,
const uint32_t n_probes,
const uint32_t k,
const uint32_t dim,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
float* distances)
{
Expand Down Expand Up @@ -736,7 +740,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
const bool valid = vec_id < list_length;

// Process first shm_assisted_dim dimensions (always using shared memory)
if (valid) {
if (valid && sample_filter(queries_offset + blockIdx.y, probe_id, vec_id)) {
loadAndComputeDist<kUnroll, decltype(compute_dist), Veclen, T, AccT> lc(dist,
compute_dist);
for (int pos = 0; pos < shm_assisted_dim;
Expand Down Expand Up @@ -803,6 +807,7 @@ template <int Capacity,
typename T,
typename AccT,
typename IdxT,
typename IvfSampleFilterT,
typename Lambda,
typename PostLambda>
void launch_kernel(Lambda lambda,
Expand All @@ -811,17 +816,26 @@ void launch_kernel(Lambda lambda,
const T* queries,
const uint32_t* coarse_index,
const uint32_t num_queries,
const uint32_t queries_offset,
const uint32_t n_probes,
const uint32_t k,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream)
{
RAFT_EXPECTS(Veclen == index.veclen(),
"Configured Veclen does not match the index interleaving pattern.");
constexpr auto kKernel =
interleaved_scan_kernel<Capacity, Veclen, Ascending, T, AccT, IdxT, Lambda, PostLambda>;
constexpr auto kKernel = interleaved_scan_kernel<Capacity,
Veclen,
Ascending,
T,
AccT,
IdxT,
IvfSampleFilterT,
Lambda,
PostLambda>;
const int max_query_smem = 16384;
int query_smem_elems =
std::min<int>(max_query_smem / sizeof(T), Pow2<Veclen * WarpSize>::roundUp(index.dim()));
Expand Down Expand Up @@ -860,9 +874,11 @@ void launch_kernel(Lambda lambda,
index.inds_ptrs().data_handle(),
index.data_ptrs().data_handle(),
index.list_sizes().data_handle(),
queries_offset + query_offset,
n_probes,
k,
index.dim(),
sample_filter,
neighbors,
distances);
queries += grid_dim_y * index.dim();
Expand Down Expand Up @@ -931,6 +947,7 @@ template <int Capacity,
typename T,
typename AccT,
typename IdxT,
typename IvfSampleFilterT,
typename... Args>
void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... args)
{
Expand All @@ -943,6 +960,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg
T,
AccT,
IdxT,
IvfSampleFilterT,
euclidean_dist<Veclen, T, AccT>,
raft::identity_op>({}, {}, std::forward<Args>(args)...);
case raft::distance::DistanceType::L2SqrtExpanded:
Expand All @@ -953,6 +971,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg
T,
AccT,
IdxT,
IvfSampleFilterT,
euclidean_dist<Veclen, T, AccT>,
raft::sqrt_op>({}, {}, std::forward<Args>(args)...);
case raft::distance::DistanceType::InnerProduct:
Expand All @@ -962,6 +981,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg
T,
AccT,
IdxT,
IvfSampleFilterT,
inner_prod_dist<Veclen, T, AccT>,
raft::identity_op>({}, {}, std::forward<Args>(args)...);
// NB: update the description of `knn::ivf_flat::build` when adding here a new metric.
Expand All @@ -976,6 +996,7 @@ void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... arg
template <typename T,
typename AccT,
typename IdxT,
typename IvfSampleFilterT,
int Capacity = matrix::detail::select::warpsort::kMaxCapacity,
int Veclen = std::max<int>(1, 16 / sizeof(T))>
struct select_interleaved_scan_kernel {
Expand All @@ -990,13 +1011,20 @@ struct select_interleaved_scan_kernel {
{
if constexpr (Capacity > 1) {
if (capacity * 2 <= Capacity) {
return select_interleaved_scan_kernel<T, AccT, IdxT, Capacity / 2, Veclen>::run(
capacity, veclen, select_min, std::forward<Args>(args)...);
return select_interleaved_scan_kernel<T,
AccT,
IdxT,
IvfSampleFilterT,
Capacity / 2,
Veclen>::run(capacity,
veclen,
select_min,
std::forward<Args>(args)...);
}
}
if constexpr (Veclen > 1) {
if (veclen % Veclen != 0) {
return select_interleaved_scan_kernel<T, AccT, IdxT, Capacity, 1>::run(
return select_interleaved_scan_kernel<T, AccT, IdxT, IvfSampleFilterT, Capacity, 1>::run(
capacity, 1, select_min, std::forward<Args>(args)...);
}
}
Expand All @@ -1010,9 +1038,11 @@ struct select_interleaved_scan_kernel {
veclen == Veclen,
"Veclen must be power-of-two not bigger than the maximum allowed size for this data type.");
if (select_min) {
launch_with_fixed_consts<Capacity, Veclen, true, T, AccT, IdxT>(std::forward<Args>(args)...);
launch_with_fixed_consts<Capacity, Veclen, true, T, AccT, IdxT, IvfSampleFilterT>(
std::forward<Args>(args)...);
} else {
launch_with_fixed_consts<Capacity, Veclen, false, T, AccT, IdxT>(std::forward<Args>(args)...);
launch_with_fixed_consts<Capacity, Veclen, false, T, AccT, IdxT, IvfSampleFilterT>(
std::forward<Args>(args)...);
}
}
};
Expand All @@ -1028,6 +1058,9 @@ struct select_interleaved_scan_kernel {
* @param[in] queries device pointer to the query vectors [batch_size, dim]
* @param[in] coarse_query_results device pointer to the cluster (list) ids [batch_size, n_probes]
* @param n_queries batch size
* @param[in] queries_offset
* An offset of the current query batch. It is used for feeding sample_filter with the
* correct query index.
* @param metric type of the measured distance
* @param n_probes number of nearest clusters to query
* @param k number of nearest neighbors.
Expand All @@ -1041,36 +1074,43 @@ struct select_interleaved_scan_kernel {
* @param[inout] grid_dim_x number of blocks launched across all n_probes clusters;
* (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes)
* @param stream
* @param sample_filter
* A filter that selects samples for a given query. Use an instance of none_ivf_sample_filter to
* provide a green light for every sample.
*/
template <typename T, typename AccT, typename IdxT>
template <typename T, typename AccT, typename IdxT, typename IvfSampleFilterT>
void ivfflat_interleaved_scan(const index<T, IdxT>& index,
const T* queries,
const uint32_t* coarse_query_results,
const uint32_t n_queries,
const uint32_t queries_offset,
const raft::distance::DistanceType metric,
const uint32_t n_probes,
const uint32_t k,
const bool select_min,
IvfSampleFilterT sample_filter,
IdxT* neighbors,
float* distances,
uint32_t& grid_dim_x,
rmm::cuda_stream_view stream)
{
const int capacity = bound_by_power_of_two(k);
select_interleaved_scan_kernel<T, AccT, IdxT>::run(capacity,
index.veclen(),
select_min,
metric,
index,
queries,
coarse_query_results,
n_queries,
n_probes,
k,
neighbors,
distances,
grid_dim_x,
stream);
select_interleaved_scan_kernel<T, AccT, IdxT, IvfSampleFilterT>::run(capacity,
index.veclen(),
select_min,
metric,
index,
queries,
coarse_query_results,
n_queries,
queries_offset,
n_probes,
k,
sample_filter,
neighbors,
distances,
grid_dim_x,
stream);
}

} // namespace raft::neighbors::ivf_flat::detail
Loading

0 comments on commit 9bf7b4b

Please sign in to comment.