diff --git a/cpp/src/neighbors/detail/faiss_distance_utils.h b/cpp/src/neighbors/detail/faiss_distance_utils.h
index e8a41c1aa..63f0c88c2 100644
--- a/cpp/src/neighbors/detail/faiss_distance_utils.h
+++ b/cpp/src/neighbors/detail/faiss_distance_utils.h
@@ -14,10 +14,18 @@ inline void chooseTileSize(size_t numQueries,
                            size_t numCentroids,
                            size_t dim,
                            size_t elementSize,
-                           size_t totalMem,
                            size_t& tileRows,
                            size_t& tileCols)
 {
+  // 512 seems to be a batch size sweetspot for float32.
+  // If we are on float16, increase to 512.
+  // If the k size (vec dim) of the matrix multiplication is small (<= 32),
+  // increase to 1024.
+  size_t preferredTileRows = 512;
+  if (dim <= 32) { preferredTileRows = 1024; }
+
+  tileRows = std::min(preferredTileRows, numQueries);
+
   // The matrix multiplication should be large enough to be efficient, but if
   // it is too large, we seem to lose efficiency as opposed to
   // double-streaming. Each tile size here defines 1/2 of the memory use due
@@ -25,28 +33,20 @@ inline void chooseTileSize(size_t numQueries,
   // adjusted independently by the user and can thus meet these requirements
   // (or not). For <= 4 GB GPUs, prefer 512 MB of usage. For <= 8 GB GPUs,
   // prefer 768 MB of usage. Otherwise, prefer 1 GB of usage.
-  size_t targetUsage = 0;
-
-  if (totalMem <= ((size_t)4) * 1024 * 1024 * 1024) {
-    targetUsage = 512 * 1024 * 1024;
-  } else if (totalMem <= ((size_t)8) * 1024 * 1024 * 1024) {
-    targetUsage = 768 * 1024 * 1024;
+  size_t targetUsage = 512 * 1024 * 1024;
+  if (tileRows * numCentroids * elementSize * 2 <= targetUsage) {
+    tileCols = numCentroids;
   } else {
-    targetUsage = 1024 * 1024 * 1024;
-  }
+    // only query total memory in case it potentially impacts tilesize
+    size_t totalMem = rmm::available_device_memory().second;
 
-  targetUsage /= 2 * elementSize;
+    if (totalMem > ((size_t)8) * 1024 * 1024 * 1024) {
+      targetUsage = 1024 * 1024 * 1024;
+    } else if (totalMem > ((size_t)4) * 1024 * 1024 * 1024) {
+      targetUsage = 768 * 1024 * 1024;
+    }
 
-  // 512 seems to be a batch size sweetspot for float32.
-  // If we are on float16, increase to 512.
-  // If the k size (vec dim) of the matrix multiplication is small (<= 32),
-  // increase to 1024.
-  size_t preferredTileRows = 512;
-  if (dim <= 32) { preferredTileRows = 1024; }
-
-  tileRows = std::min(preferredTileRows, numQueries);
-
-  // tileCols is the remainder size
-  tileCols = std::min(targetUsage / preferredTileRows, numCentroids);
+    tileCols = std::min(targetUsage / (2 * elementSize * tileRows), numCentroids);
+  }
 }
 
 }  // namespace cuvs::neighbors::detail::faiss_select
diff --git a/cpp/src/neighbors/detail/knn_brute_force.cuh b/cpp/src/neighbors/detail/knn_brute_force.cuh
index e3f7acc96..88986af7d 100644
--- a/cpp/src/neighbors/detail/knn_brute_force.cuh
+++ b/cpp/src/neighbors/detail/knn_brute_force.cuh
@@ -81,14 +81,12 @@ void tiled_brute_force_knn(const raft::resources& handle,
                            const uint32_t* filter_bitmap = nullptr)
 {
   // Figure out the number of rows/cols to tile for
-  size_t tile_rows   = 0;
-  size_t tile_cols   = 0;
-  auto stream        = raft::resource::get_cuda_stream(handle);
-  auto device_memory = raft::resource::get_workspace_resource(handle);
-  auto total_mem     = rmm::available_device_memory().second;
+  size_t tile_rows = 0;
+  size_t tile_cols = 0;
+  auto stream      = raft::resource::get_cuda_stream(handle);
 
   cuvs::neighbors::detail::faiss_select::chooseTileSize(
-    m, n, d, sizeof(DistanceT), total_mem, tile_rows, tile_cols);
+    m, n, d, sizeof(DistanceT), tile_rows, tile_cols);
 
   // for unittesting, its convenient to be able to put a max size on the tiles
   // so we can test the tiling logic without having to use huge inputs.
@@ -356,27 +354,26 @@ void brute_force_knn_impl(
 
   ASSERT(input.size() == sizes.size(), "input and sizes vectors should be the same size");
 
-  std::vector<IdxType>* id_ranges;
-  if (translations == nullptr) {
+  std::vector<IdxType> id_ranges;
+  if (translations != nullptr) {
+    // use the given translations
+    id_ranges.insert(id_ranges.end(), translations->begin(), translations->end());
+  } else if (input.size() > 1) {
     // If we don't have explicit translations
     // for offsets of the indices, build them
     // from the local partitions
-    id_ranges = new std::vector<IdxType>();
     IdxType total_n = 0;
     for (size_t i = 0; i < input.size(); i++) {
-      id_ranges->push_back(total_n);
+      id_ranges.push_back(total_n);
       total_n += sizes[i];
     }
-  } else {
-    // otherwise, use the given translations
-    id_ranges = translations;
   }
 
-  int device;
-  RAFT_CUDA_TRY(cudaGetDevice(&device));
-
-  rmm::device_uvector<IdxType> trans(id_ranges->size(), userStream);
-  raft::update_device(trans.data(), id_ranges->data(), id_ranges->size(), userStream);
+  rmm::device_uvector<IdxType> trans(0, userStream);
+  if (id_ranges.size() > 0) {
+    trans.resize(id_ranges.size(), userStream);
+    raft::update_device(trans.data(), id_ranges.data(), id_ranges.size(), userStream);
+  }
 
   rmm::device_uvector<DistanceT> all_D(0, userStream);
   rmm::device_uvector<IdxType> all_I(0, userStream);
@@ -513,8 +510,6 @@ void brute_force_knn_impl(
     // no translations or partitions to combine, it can be skipped.
     knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream, trans.data());
   }
-
-  if (translations == nullptr) delete id_ranges;
 };
 
 template
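
Not part of the patch: a minimal, self-contained sketch of the tile-size selection introduced in the first hunk, using hypothetical query/centroid counts. It is only meant to show when the lazy rmm::available_device_memory() query is skipped; the real implementation is chooseTileSize above.

// Standalone sketch (not from the patch) of the updated tile-size logic.
// The sizes below are hypothetical inputs chosen for illustration.
#include <algorithm>
#include <cstdio>

int main()
{
  size_t numQueries   = 100000;   // hypothetical number of query rows
  size_t numCentroids = 1 << 20;  // hypothetical number of index rows
  size_t dim          = 128;
  size_t elementSize  = 4;        // sizeof(float)

  size_t preferredTileRows = (dim <= 32) ? 1024 : 512;
  size_t tileRows          = std::min(preferredTileRows, numQueries);

  // Default 512 MB target; the patch only raises it (after querying device
  // memory via rmm::available_device_memory()) when this fast path fails.
  size_t targetUsage = 512 * 1024 * 1024;
  size_t tileCols    = 0;
  if (tileRows * numCentroids * elementSize * 2 <= targetUsage) {
    tileCols = numCentroids;  // whole tile fits: no device-memory query needed
  } else {
    tileCols = std::min(targetUsage / (2 * elementSize * tileRows), numCentroids);
  }
  std::printf("tileRows=%zu tileCols=%zu\n", tileRows, tileCols);
  return 0;
}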