diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh index fe425fe8f..f05bebf3f 100644 --- a/cpp/src/neighbors/detail/knn_brute_force.cuh +++ b/cpp/src/neighbors/detail/knn_brute_force.cuh @@ -28,8 +28,8 @@ #include "./knn_utils.cuh" #include -#include #include +#include #include #include #include @@ -46,6 +46,7 @@ #include #include #include +#include #include #include @@ -591,10 +592,10 @@ void brute_force_search_filtered( auto nnz_view = raft::make_device_scalar_view(nnz.data()); auto filter_view = raft::make_device_vector_view(filter.data(), filter.n_elements()); + IdxT size_h = n_queries * n_dataset; + auto size_view = raft::make_host_scalar_view(&size_h); - // TODO(rhdong): Need to switch to the public API, - // with the issue: https://github.com/rapidsai/cuvs/issues/158 - raft::detail::popc(res, filter_view, n_queries * n_dataset, nnz_view); + raft::popc(res, filter_view, size_view, nnz_view); raft::copy(&nnz_h, nnz.data(), 1, stream); raft::resource::sync_stream(res, stream); diff --git a/cpp/test/neighbors/brute_force_prefiltered.cu b/cpp/test/neighbors/brute_force_prefiltered.cu index 17166fd7a..2b8ae9d9a 100644 --- a/cpp/test/neighbors/brute_force_prefiltered.cu +++ b/cpp/test/neighbors/brute_force_prefiltered.cu @@ -20,12 +20,13 @@ #include #include -#include +#include #include #include #include #include #include +#include #include @@ -192,12 +193,12 @@ class PrefilteredBruteForceTest auto nnz_view = raft::make_device_scalar_view(nnz.data()); auto filter_view = raft::make_device_vector_view(filter_d.data(), filter_d.size()); + index_t size_h = m * n; + auto size_view = raft::make_host_scalar_view(&size_h); set_bitmap(src, dst, bitmap, n_edges, n, stream); - // TODO(rhdong): Need to switch to the public API, - // with the issue: https://github.com/rapidsai/cuvs/issues/158 - raft::detail::popc(handle, filter_view, m * n, nnz_view); + raft::popc(handle, filter_view, size_view, nnz_view); raft::copy(&nnz_h, nnz.data(), 1, stream); raft::resource::sync_stream(handle, stream);