From 4f574d59b047a133eb4f56e6202b24c2ee26fb31 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 3 May 2023 13:13:45 -0700 Subject: [PATCH] Integrate RAFT FlatIndex / Distances (#2707) Summary: I've broken out the FlatIndex / Distances changes from https://github.com/facebookresearch/faiss/issues/2521 into a separate PR to make things a litle easier to review. This does also include a couple other minor changes to the IVF Flat index which I used to make it easier to dispatch to the RAFT version. I can revert that change too if desired. Pull Request resolved: https://github.com/facebookresearch/faiss/pull/2707 Reviewed By: wickedfoo Differential Revision: D44758912 Pulled By: algoriddle fbshipit-source-id: b2544990b4e941a2558f5004bceec4af4fa1ad09 --- CMakeLists.txt | 44 +++++- INSTALL.md | 3 + cmake/thirdparty/fetch_rapids.cmake | 24 ++++ cmake/thirdparty/get_cutlass.cmake | 107 +++++++++++++++ cmake/thirdparty/get_raft.cmake | 51 +++++++ faiss/gpu/CMakeLists.txt | 31 ++++- faiss/gpu/GpuDistance.cu | 178 ++++++++++++++++++++++++- faiss/gpu/GpuDistance.h | 3 + faiss/gpu/GpuIndex.h | 18 +++ faiss/gpu/GpuIndexFlat.cu | 44 +++--- faiss/gpu/GpuIndexFlat.h | 2 + faiss/gpu/GpuResources.cpp | 21 +++ faiss/gpu/GpuResources.h | 27 ++++ faiss/gpu/StandardGpuResources.cpp | 85 +++++++++++- faiss/gpu/StandardGpuResources.h | 63 ++++++++- faiss/gpu/impl/FlatIndex.cu | 15 +++ faiss/gpu/impl/FlatIndex.cuh | 21 ++- faiss/gpu/impl/RaftFlatIndex.cu | 157 ++++++++++++++++++++++ faiss/gpu/impl/RaftFlatIndex.cuh | 69 ++++++++++ faiss/gpu/impl/RaftUtils.h | 55 ++++++++ faiss/gpu/test/CMakeLists.txt | 16 ++- faiss/gpu/test/TestGpuDistance.cu | 168 ++++++++++++++++++++--- faiss/gpu/test/TestGpuIndexFlat.cpp | 138 +++++++++++++++++-- faiss/gpu/test/TestGpuIndexIVFFlat.cpp | 15 +++ 24 files changed, 1293 insertions(+), 62 deletions(-) create mode 100644 cmake/thirdparty/fetch_rapids.cmake create mode 100644 cmake/thirdparty/get_cutlass.cmake create mode 100644 cmake/thirdparty/get_raft.cmake create mode 100644 faiss/gpu/impl/RaftFlatIndex.cu create mode 100644 faiss/gpu/impl/RaftFlatIndex.cuh create mode 100644 faiss/gpu/impl/RaftUtils.h diff --git a/CMakeLists.txt b/CMakeLists.txt index fa1e312456..17c8d7ee3d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,22 +4,58 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# ============================================================================= +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR) +set(FAISS_LANGUAGES CXX) + +if(FAISS_ENABLE_GPU) + list(APPEND FAISS_LANGUAGES CUDA) +endif() + +if(FAISS_ENABLE_RAFT) +include(cmake/thirdparty/fetch_rapids.cmake) +include(rapids-cmake) +include(rapids-cpm) +include(rapids-cuda) +include(rapids-export) +include(rapids-find) + +rapids_cuda_init_architectures(faiss) +endif() + project(faiss VERSION 1.7.4 DESCRIPTION "A library for efficient similarity search and clustering of dense vectors." HOMEPAGE_URL "https://github.com/facebookresearch/faiss" - LANGUAGES CXX) + LANGUAGES ${FAISS_LANGUAGES}) include(GNUInstallDirs) +if(FAISS_ENABLE_RAFT) +set(CMAKE_CXX_STANDARD 17) +else() set(CMAKE_CXX_STANDARD 11) +endif() list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake") # Valid values are "generic", "avx2". option(FAISS_OPT_LEVEL "" "generic") option(FAISS_ENABLE_GPU "Enable support for GPU indexes." ON) +option(FAISS_ENABLE_RAFT "Enable RAFT for GPU indexes." OFF) option(FAISS_ENABLE_PYTHON "Build Python extension." ON) option(FAISS_ENABLE_C_API "Build C API." OFF) @@ -28,6 +64,12 @@ if(FAISS_ENABLE_GPU) enable_language(CUDA) endif() +if(FAISS_ENABLE_RAFT) + rapids_cpm_init() + include(cmake/thirdparty/get_raft.cmake) + include(cmake/thirdparty/get_cutlass.cmake) +endif() + add_subdirectory(faiss) if(FAISS_ENABLE_GPU) diff --git a/INSTALL.md b/INSTALL.md index 1ff021d82e..c82f0dd507 100644 --- a/INSTALL.md +++ b/INSTALL.md @@ -105,6 +105,9 @@ Several options can be passed to CMake, among which: values are `ON` and `OFF`), - `-DFAISS_ENABLE_PYTHON=OFF` in order to disable building python bindings (possible values are `ON` and `OFF`), + - `-DFAISS_ENABLE_RAFT=ON` in order to enable building the RAFT implementations + of the IVF-Flat and IVF-PQ GPU-accelerated indices (default is `OFF`, possible + values are `ON` and `OFF`) - `-DBUILD_TESTING=OFF` in order to disable building C++ tests, - `-DBUILD_SHARED_LIBS=ON` in order to build a shared library (possible values are `ON` and `OFF`), diff --git a/cmake/thirdparty/fetch_rapids.cmake b/cmake/thirdparty/fetch_rapids.cmake new file mode 100644 index 0000000000..013bdc90b3 --- /dev/null +++ b/cmake/thirdparty/fetch_rapids.cmake @@ -0,0 +1,24 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# ============================================================================= +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= +set(RAPIDS_VERSION "23.02") + +if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/FAISS_RAPIDS.cmake) + file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake + ${CMAKE_CURRENT_BINARY_DIR}/FAISS_RAPIDS.cmake) +endif() +include(${CMAKE_CURRENT_BINARY_DIR}/FAISS_RAPIDS.cmake) diff --git a/cmake/thirdparty/get_cutlass.cmake b/cmake/thirdparty/get_cutlass.cmake new file mode 100644 index 0000000000..df605d887d --- /dev/null +++ b/cmake/thirdparty/get_cutlass.cmake @@ -0,0 +1,107 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# ============================================================================= +# Copyright (c) 2021-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +function(find_and_configure_cutlass) + set(oneValueArgs VERSION REPOSITORY PINNED_TAG) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + # if(RAFT_ENABLE_DIST_DEPENDENCIES OR RAFT_COMPILE_LIBRARIES) + set(CUTLASS_ENABLE_HEADERS_ONLY + ON + CACHE BOOL "Enable only the header library" + ) + set(CUTLASS_NAMESPACE + "raft_cutlass" + CACHE STRING "Top level namespace of CUTLASS" + ) + set(CUTLASS_ENABLE_CUBLAS + OFF + CACHE BOOL "Disable CUTLASS to build with cuBLAS library." + ) + + if (CUDA_STATIC_RUNTIME) + set(CUDART_LIBRARY "${CUDA_cudart_static_LIBRARY}" CACHE FILEPATH "fixing cutlass cmake code" FORCE) + endif() + + rapids_cpm_find( + NvidiaCutlass ${PKG_VERSION} + GLOBAL_TARGETS nvidia::cutlass::cutlass + CPM_ARGS + GIT_REPOSITORY ${PKG_REPOSITORY} + GIT_TAG ${PKG_PINNED_TAG} + GIT_SHALLOW TRUE + OPTIONS "CUDAToolkit_ROOT ${CUDAToolkit_LIBRARY_DIR}" + ) + + if(TARGET CUTLASS AND NOT TARGET nvidia::cutlass::cutlass) + add_library(nvidia::cutlass::cutlass ALIAS CUTLASS) + endif() + + if(NvidiaCutlass_ADDED) + rapids_export( + BUILD NvidiaCutlass + EXPORT_SET NvidiaCutlass + GLOBAL_TARGETS nvidia::cutlass::cutlass + NAMESPACE nvidia::cutlass:: + ) + endif() + # endif() + + # We generate the cutlass-config files when we built cutlass locally, so always do + # `find_dependency` + rapids_export_package( + BUILD NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass + ) + rapids_export_package( + INSTALL NvidiaCutlass raft-distance-exports GLOBAL_TARGETS nvidia::cutlass::cutlass + ) + rapids_export_package( + BUILD NvidiaCutlass raft-nn-exports GLOBAL_TARGETS nvidia::cutlass::cutlass + ) + rapids_export_package( + INSTALL NvidiaCutlass raft-nn-exports GLOBAL_TARGETS nvidia::cutlass::cutlass + ) + + # Tell cmake where it can find the generated NvidiaCutlass-config.cmake we wrote. + include("${rapids-cmake-dir}/export/find_package_root.cmake") + rapids_export_find_package_root( + INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-distance-exports + ) + rapids_export_find_package_root( + BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-distance-exports + ) + include("${rapids-cmake-dir}/export/find_package_root.cmake") + rapids_export_find_package_root( + INSTALL NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}/../]=] raft-nn-exports + ) + rapids_export_find_package_root( + BUILD NvidiaCutlass [=[${CMAKE_CURRENT_LIST_DIR}]=] raft-nn-exports + ) +endfunction() + +if(NOT RAFT_CUTLASS_GIT_TAG) + set(RAFT_CUTLASS_GIT_TAG v2.9.1) +endif() + +if(NOT RAFT_CUTLASS_GIT_REPOSITORY) + set(RAFT_CUTLASS_GIT_REPOSITORY https://github.com/NVIDIA/cutlass.git) +endif() + +find_and_configure_cutlass( + VERSION 2.9.1 REPOSITORY ${RAFT_CUTLASS_GIT_REPOSITORY} PINNED_TAG ${RAFT_CUTLASS_GIT_TAG} +) diff --git a/cmake/thirdparty/get_raft.cmake b/cmake/thirdparty/get_raft.cmake new file mode 100644 index 0000000000..df5aa448e4 --- /dev/null +++ b/cmake/thirdparty/get_raft.cmake @@ -0,0 +1,51 @@ +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +# ============================================================================= +# Copyright (c) 2021-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= + +set(RAFT_VERSION "${RAPIDS_VERSION}") +set(RAFT_FORK "rapidsai") +set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") + +function(find_and_configure_raft) + set(oneValueArgs VERSION FORK PINNED_TAG) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN} ) + + #----------------------------------------------------- + # Invoke CPM find_package() + #----------------------------------------------------- + rapids_cpm_find(raft ${PKG_VERSION} + GLOBAL_TARGETS raft::raft + BUILD_EXPORT_SET faiss-exports + INSTALL_EXPORT_SET faiss-exports + CPM_ARGS + GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git + GIT_TAG ${PKG_PINNED_TAG} + SOURCE_SUBDIR cpp + OPTIONS + "BUILD_TESTS OFF" + "RAFT_COMPILE_LIBRARY OFF" + ) +endfunction() + +# Change pinned tag here to test a commit in CI +# To use a different RAFT locally, set the CMake variable +# CPM_raft_SOURCE=/path/to/local/raft +find_and_configure_raft(VERSION ${RAFT_VERSION}.00 + FORK ${RAFT_FORK} + PINNED_TAG ${RAFT_PINNED_TAG} + ) diff --git a/faiss/gpu/CMakeLists.txt b/faiss/gpu/CMakeLists.txt index 6fb694764b..cfe4e0b195 100644 --- a/faiss/gpu/CMakeLists.txt +++ b/faiss/gpu/CMakeLists.txt @@ -3,6 +3,19 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# ============================================================================= +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= set(FAISS_GPU_SRC GpuAutoTune.cpp @@ -163,6 +176,16 @@ set(FAISS_GPU_HEADERS utils/warpselect/WarpSelectImpl.cuh ) +if(FAISS_ENABLE_RAFT) + list(APPEND FAISS_GPU_HEADERS + impl/RaftFlatIndex.cuh) + list(APPEND FAISS_GPU_SRC + impl/RaftFlatIndex.cu) + + target_compile_definitions(faiss PUBLIC USE_NVIDIA_RAFT=1) + target_compile_definitions(faiss_avx2 PUBLIC USE_NVIDIA_RAFT=1) +endif() + # Export FAISS_GPU_HEADERS variable to parent scope. set(FAISS_GPU_HEADERS ${FAISS_GPU_HEADERS} PARENT_SCOPE) @@ -177,7 +200,7 @@ foreach(header ${FAISS_GPU_HEADERS}) endforeach() find_package(CUDAToolkit REQUIRED) -target_link_libraries(faiss PRIVATE CUDA::cudart CUDA::cublas) -target_link_libraries(faiss_avx2 PRIVATE CUDA::cudart CUDA::cublas) -target_compile_options(faiss PRIVATE $<$:-Xfatbin=-compress-all>) -target_compile_options(faiss_avx2 PRIVATE $<$:-Xfatbin=-compress-all>) +target_link_libraries(faiss PRIVATE CUDA::cudart CUDA::cublas $<$:raft::raft> $<$:nvidia::cutlass::cutlass>) +target_link_libraries(faiss_avx2 PRIVATE CUDA::cudart CUDA::cublas $<$:raft::raft> $<$:nvidia::cutlass::cutlass>) +target_compile_options(faiss PRIVATE $<$:-Xfatbin=-compress-all --expt-extended-lambda --expt-relaxed-constexpr>) +target_compile_options(faiss_avx2 PRIVATE $<$:-Xfatbin=-compress-all --expt-extended-lambda --expt-relaxed-constexpr>) diff --git a/faiss/gpu/GpuDistance.cu b/faiss/gpu/GpuDistance.cu index 41596fa220..daf4710aec 100644 --- a/faiss/gpu/GpuDistance.cu +++ b/faiss/gpu/GpuDistance.cu @@ -4,6 +4,21 @@ * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include #include @@ -14,9 +29,27 @@ #include #include +#if defined USE_NVIDIA_RAFT +#include +#include +#include +#include +#include +#include +#include +#include +#include +#define RAFT_NAME "raft" +#endif + namespace faiss { namespace gpu { +#if defined USE_NVIDIA_RAFT +using namespace raft::distance; +using namespace raft::neighbors; +#endif + template void bfKnnConvert(GpuResourcesProvider* prov, const GpuDistanceParams& args) { // Validate the input data @@ -165,7 +198,6 @@ void bfKnnConvert(GpuResourcesProvider* prov, const GpuDistanceParams& args) { tOutDistances, tIntIndices, args.ignoreOutDistances); - // Convert and copy int indices out auto tOutIntIndices = toDeviceTemporary( res, @@ -186,17 +218,153 @@ void bfKnnConvert(GpuResourcesProvider* prov, const GpuDistanceParams& args) { fromDevice(tOutDistances, args.outDistances, stream); } -void bfKnn(GpuResourcesProvider* res, const GpuDistanceParams& args) { +void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) { // For now, both vectors and queries must be of the same data type FAISS_THROW_IF_NOT_MSG( args.vectorType == args.queryType, "limitation: both vectorType and queryType must currently " "be the same (F32 or F16"); - if (args.vectorType == DistanceDataType::F32) { - bfKnnConvert(res, args); +#if defined USE_NVIDIA_RAFT + // Note: For now, RAFT bfknn requires queries and vectors to be same layout + if (args.use_raft && args.queriesRowMajor == args.vectorsRowMajor) { + DistanceType distance = faiss_to_raft(args.metric, false); + + auto resImpl = prov->getResources(); + auto res = resImpl.get(); + raft::device_resources& handle = res->getRaftHandleCurrentDevice(); + auto stream = res->getDefaultStreamCurrentDevice(); + + idx_t dims = args.dims; + idx_t num_vectors = args.numVectors; + idx_t num_queries = args.numQueries; + int k = args.k; + float metric_arg = args.metricArg; + + auto inds = raft::make_writeback_temporary_device_buffer( + handle, + reinterpret_cast(args.outIndices), + raft::matrix_extent(num_queries, (idx_t)k)); + auto dists = raft::make_writeback_temporary_device_buffer( + handle, + reinterpret_cast(args.outDistances), + raft::matrix_extent(num_queries, (idx_t)k)); + + if (args.queriesRowMajor) { + auto index = raft::make_readonly_temporary_device_buffer< + const float, + idx_t, + raft::row_major>( + handle, + const_cast( + reinterpret_cast(args.vectors)), + raft::matrix_extent(num_vectors, dims)); + + auto search = raft::make_readonly_temporary_device_buffer< + const float, + idx_t, + raft::row_major>( + handle, + const_cast( + reinterpret_cast(args.queries)), + raft::matrix_extent(num_queries, dims)); + + // For now, use RAFT's fused KNN when k <= 64 and L2 metric is used + if (args.k <= 64 && args.metric == MetricType::METRIC_L2 && + args.numVectors > 0) { + RAFT_LOG_INFO("Invoking flat fused_l2_knn"); + brute_force::fused_l2_knn( + handle, + index.view(), + search.view(), + inds.view(), + dists.view(), + distance); + } else { + std::vector> + index_vec = {index.view()}; + RAFT_LOG_INFO("Invoking flat bfknn"); + brute_force::knn( + handle, + index_vec, + search.view(), + inds.view(), + dists.view(), + k, + distance, + metric_arg); + } + } else { + auto index = raft::make_readonly_temporary_device_buffer< + const float, + idx_t, + raft::col_major>( + handle, + const_cast( + reinterpret_cast(args.vectors)), + raft::matrix_extent(num_vectors, dims)); + + auto search = raft::make_readonly_temporary_device_buffer< + const float, + idx_t, + raft::col_major>( + handle, + const_cast( + reinterpret_cast(args.queries)), + raft::matrix_extent(num_queries, dims)); + + std::vector> + index_vec = {index.view()}; + RAFT_LOG_INFO("Invoking flat bfknn"); + brute_force::knn( + handle, + index_vec, + search.view(), + inds.view(), + dists.view(), + k, + distance, + metric_arg); + } + + if (args.metric == MetricType::METRIC_Lp) { + raft::linalg::unary_op( + handle, + raft::make_const_mdspan(dists.view()), + dists.view(), + [metric_arg] __device__(const float& a) { + return powf(a, metric_arg); + }); + } else if (args.metric == MetricType::METRIC_JensenShannon) { + raft::linalg::unary_op( + handle, + raft::make_const_mdspan(dists.view()), + dists.view(), + [] __device__(const float& a) { return powf(a, 2); }); + } + + RAFT_LOG_INFO("Done."); + + handle.sync_stream(); + RAFT_LOG_INFO("All synced."); + } else +#else + if (args.use_raft) { + FAISS_THROW_IF_NOT_MSG( + !args.use_raft, + "RAFT has not been compiled into the current version so it cannot be used."); + } else +#endif + if (args.vectorType == DistanceDataType::F32) { + bfKnnConvert(prov, args); } else if (args.vectorType == DistanceDataType::F16) { - bfKnnConvert(res, args); + bfKnnConvert(prov, args); } else { FAISS_THROW_MSG("unknown vectorType"); } diff --git a/faiss/gpu/GpuDistance.h b/faiss/gpu/GpuDistance.h index d69edcf9f7..3d9a318990 100644 --- a/faiss/gpu/GpuDistance.h +++ b/faiss/gpu/GpuDistance.h @@ -124,6 +124,9 @@ struct GpuDistanceParams { /// Otherwise, an integer 0 <= device < numDevices indicates the device for /// execution int device; + + /// Should the index dispatch down to RAFT? + bool use_raft = false; }; /// A wrapper for gpu/impl/Distance.cuh to expose direct brute-force k-nearest diff --git a/faiss/gpu/GpuIndex.h b/faiss/gpu/GpuIndex.h index 0ab2e7983e..8f981ccd74 100644 --- a/faiss/gpu/GpuIndex.h +++ b/faiss/gpu/GpuIndex.h @@ -4,6 +4,21 @@ * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once @@ -23,6 +38,9 @@ struct GpuIndexConfig { /// On Pascal and above (CC 6+) architectures, allows GPUs to use /// more memory than is available on the GPU. MemorySpace memorySpace; + + /// Should the index dispatch down to RAFT? + bool use_raft = false; }; class GpuIndex : public faiss::Index { diff --git a/faiss/gpu/GpuIndexFlat.cu b/faiss/gpu/GpuIndexFlat.cu index 833f9ff3eb..ef5757fbbd 100644 --- a/faiss/gpu/GpuIndexFlat.cu +++ b/faiss/gpu/GpuIndexFlat.cu @@ -12,6 +12,7 @@ #include #include #include +#include #include #include #include @@ -67,11 +68,7 @@ GpuIndexFlat::GpuIndexFlat( this->is_trained = true; // Construct index - data_.reset(new FlatIndex( - resources_.get(), - dims, - flatConfig_.useFloat16, - config_.memorySpace)); + resetIndex_(dims); } GpuIndexFlat::GpuIndexFlat( @@ -86,26 +83,43 @@ GpuIndexFlat::GpuIndexFlat( this->is_trained = true; // Construct index - data_.reset(new FlatIndex( - resources_.get(), - dims, - flatConfig_.useFloat16, - config_.memorySpace)); + resetIndex_(dims); } GpuIndexFlat::~GpuIndexFlat() {} +void GpuIndexFlat::resetIndex_(int dims) { +#if defined USE_NVIDIA_RAFT + + if (flatConfig_.use_raft) { + data_.reset(new RaftFlatIndex( + resources_.get(), + dims, + flatConfig_.useFloat16, + config_.memorySpace)); + } else +#else + if (flatConfig_.use_raft) { + FAISS_THROW_MSG( + "RAFT has not been compiled into the current version so it cannot be used."); + } else +#endif + { + data_.reset(new FlatIndex( + resources_.get(), + dims, + flatConfig_.useFloat16, + config_.memorySpace)); + } +} + void GpuIndexFlat::copyFrom(const faiss::IndexFlat* index) { DeviceScope scope(config_.device); GpuIndex::copyFrom(index); data_.reset(); - data_.reset(new FlatIndex( - resources_.get(), - this->d, - flatConfig_.useFloat16, - config_.memorySpace)); + resetIndex_(this->d); // The index could be empty if (index->ntotal > 0) { diff --git a/faiss/gpu/GpuIndexFlat.h b/faiss/gpu/GpuIndexFlat.h index bf933bc772..514220039d 100644 --- a/faiss/gpu/GpuIndexFlat.h +++ b/faiss/gpu/GpuIndexFlat.h @@ -115,6 +115,8 @@ class GpuIndexFlat : public GpuIndex { } protected: + void resetIndex_(int dims); + /// Flat index does not require IDs as there is no storage available for /// them bool addImplRequiresIDs_() const override; diff --git a/faiss/gpu/GpuResources.cpp b/faiss/gpu/GpuResources.cpp index b3dca0895d..e745263dec 100644 --- a/faiss/gpu/GpuResources.cpp +++ b/faiss/gpu/GpuResources.cpp @@ -4,6 +4,21 @@ * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include #include @@ -153,6 +168,12 @@ cudaStream_t GpuResources::getDefaultStreamCurrentDevice() { return getDefaultStream(getCurrentDevice()); } +#if defined USE_NVIDIA_RAFT +raft::device_resources& GpuResources::getRaftHandleCurrentDevice() { + return getRaftHandle(getCurrentDevice()); +} +#endif + std::vector GpuResources::getAlternateStreamsCurrentDevice() { return getAlternateStreams(getCurrentDevice()); } diff --git a/faiss/gpu/GpuResources.h b/faiss/gpu/GpuResources.h index 3ae2dfbe19..5177065374 100644 --- a/faiss/gpu/GpuResources.h +++ b/faiss/gpu/GpuResources.h @@ -4,16 +4,36 @@ * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once #include #include #include + #include #include #include +#if defined USE_NVIDIA_RAFT +#include +#endif + namespace faiss { namespace gpu { @@ -190,6 +210,13 @@ class GpuResources { /// given device virtual cudaStream_t getDefaultStream(int device) = 0; +#if defined USE_NVIDIA_RAFT + /// Returns the raft handle for the given device which can be used to + /// make calls to other raft primitives. + virtual raft::device_resources& getRaftHandle(int device) = 0; + raft::device_resources& getRaftHandleCurrentDevice(); +#endif + /// Overrides the default stream for a device to the user-supplied stream. /// The resources object does not own this stream (i.e., it will not destroy /// it). diff --git a/faiss/gpu/StandardGpuResources.cpp b/faiss/gpu/StandardGpuResources.cpp index 80146a2e59..af0f24c51e 100644 --- a/faiss/gpu/StandardGpuResources.cpp +++ b/faiss/gpu/StandardGpuResources.cpp @@ -4,6 +4,29 @@ * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#if defined USE_NVIDIA_RAFT +#include +#include +#include +#include + +#endif #include #include @@ -74,7 +97,15 @@ StandardGpuResourcesImpl::StandardGpuResourcesImpl() -1, std::numeric_limits::max())), pinnedMemSize_(kDefaultPinnedMemoryAllocation), - allocLogging_(false) {} + allocLogging_(false) +#if defined USE_NVIDIA_RAFT + , + cmr(new rmm::mr::cuda_memory_resource), + mmr(new rmm::mr::managed_memory_resource), + pmr(new rmm::mr::pinned_memory_resource) +#endif +{ +} StandardGpuResourcesImpl::~StandardGpuResourcesImpl() { // The temporary memory allocator has allocated memory through us, so clean @@ -129,6 +160,9 @@ StandardGpuResourcesImpl::~StandardGpuResourcesImpl() { } if (pinnedMemAlloc_) { +#if defined USE_NVIDIA_RAFT + pmr->deallocate(pinnedMemAlloc_, pinnedMemAllocSize_); +#else auto err = cudaFreeHost(pinnedMemAlloc_); FAISS_ASSERT_FMT( err == cudaSuccess, @@ -136,6 +170,7 @@ StandardGpuResourcesImpl::~StandardGpuResourcesImpl() { pinnedMemAlloc_, (int)err, cudaGetErrorString(err)); +#endif } } @@ -274,6 +309,14 @@ void StandardGpuResourcesImpl::initializeForDevice(int device) { // If this is the first device that we're initializing, create our // pinned memory allocation if (defaultStreams_.empty() && pinnedMemSize_ > 0) { +#if defined USE_NVIDIA_RAFT + // If this is the first device that we're initializing, create our + // pinned memory allocation + if (defaultStreams_.empty() && pinnedMemSize_ > 0) { + pinnedMemAlloc_ = pmr->allocate(pinnedMemSize_); + pinnedMemAllocSize_ = pinnedMemSize_; + } +#else auto err = cudaHostAlloc( &pinnedMemAlloc_, pinnedMemSize_, cudaHostAllocDefault); @@ -286,6 +329,7 @@ void StandardGpuResourcesImpl::initializeForDevice(int device) { cudaGetErrorString(err)); pinnedMemAllocSize_ = pinnedMemSize_; +#endif } // Make sure that device properties for all devices are cached @@ -375,6 +419,22 @@ cudaStream_t StandardGpuResourcesImpl::getDefaultStream(int device) { return defaultStreams_[device]; } +#if defined USE_NVIDIA_RAFT +raft::device_resources& StandardGpuResourcesImpl::getRaftHandle(int device) { + initializeForDevice(device); + + auto it = raftHandles_.find(device); + if (it == raftHandles_.end()) { + // Make sure we are using the stream the user may have already assigned + // to the current GpuResources + raftHandles_.emplace(std::make_pair(device, getDefaultStream(device))); + } + + // Otherwise, our base default handle + return raftHandles_[device]; +} +#endif + std::vector StandardGpuResourcesImpl::getAlternateStreams( int device) { initializeForDevice(device); @@ -430,6 +490,9 @@ void* StandardGpuResourcesImpl::allocMemory(const AllocRequest& req) { p = tempMemory_[adjReq.device]->allocMemory(adjReq.stream, adjReq.size); } else if (adjReq.space == MemorySpace::Device) { +#if defined USE_NVIDIA_RAFT + p = cmr->allocate(adjReq.size, adjReq.stream); +#else auto err = cudaMalloc(&p, adjReq.size); // Throw if we fail to allocate @@ -451,7 +514,11 @@ void* StandardGpuResourcesImpl::allocMemory(const AllocRequest& req) { FAISS_THROW_IF_NOT_FMT(err == cudaSuccess, "%s", str.c_str()); } +#endif } else if (adjReq.space == MemorySpace::Unified) { +#if defined USE_NVIDIA_RAFT + p = mmr->allocate(adjReq.size, adjReq.stream); +#else auto err = cudaMallocManaged(&p, adjReq.size); if (err != cudaSuccess) { @@ -472,6 +539,7 @@ void* StandardGpuResourcesImpl::allocMemory(const AllocRequest& req) { FAISS_THROW_IF_NOT_FMT(err == cudaSuccess, "%s", str.c_str()); } +#endif } else { FAISS_ASSERT_FMT(false, "unknown MemorySpace %d", (int)adjReq.space); } @@ -509,6 +577,13 @@ void StandardGpuResourcesImpl::deallocMemory(int device, void* p) { } else if ( req.space == MemorySpace::Device || req.space == MemorySpace::Unified) { +#if defined USE_NVIDIA_RAFT + if (req.space == MemorySpace::Device) { + cmr->deallocate(p, req.size, req.stream); + } else if (req.space == MemorySpace::Unified) { + mmr->deallocate(p, req.size, req.stream); + } +#else auto err = cudaFree(p); FAISS_ASSERT_FMT( err == cudaSuccess, @@ -516,7 +591,7 @@ void StandardGpuResourcesImpl::deallocMemory(int device, void* p) { p, (int)err, cudaGetErrorString(err)); - +#endif } else { FAISS_ASSERT_FMT(false, "unknown MemorySpace %d", (int)req.space); } @@ -600,6 +675,12 @@ cudaStream_t StandardGpuResources::getDefaultStream(int device) { return res_->getDefaultStream(device); } +#if defined USE_NVIDIA_RAFT +raft::device_resources& StandardGpuResources::getRaftHandle(int device) { + return res_->getRaftHandle(device); +} +#endif + size_t StandardGpuResources::getTempMemoryAvailable(int device) const { return res_->getTempMemoryAvailable(device); } diff --git a/faiss/gpu/StandardGpuResources.h b/faiss/gpu/StandardGpuResources.h index d1edfb6673..9113de573c 100644 --- a/faiss/gpu/StandardGpuResources.h +++ b/faiss/gpu/StandardGpuResources.h @@ -4,9 +4,31 @@ * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once +#if defined USE_NVIDIA_RAFT +#include +#include +#include +#include +#endif + #include #include #include @@ -58,6 +80,12 @@ class StandardGpuResourcesImpl : public GpuResources { /// this stream upon exit from an index or other Faiss GPU call. cudaStream_t getDefaultStream(int device) override; +#if defined USE_NVIDIA_RAFT + /// Returns the raft handle for the given device which can be used to + /// make calls to other raft primitives. + raft::device_resources& getRaftHandle(int device) override; +#endif + /// Called to change the work ordering streams to the null stream /// for all devices void setDefaultNullStreamAllDevices(); @@ -92,7 +120,7 @@ class StandardGpuResourcesImpl : public GpuResources { cudaStream_t getAsyncCopyStream(int device) override; - private: + protected: /// Have GPU resources been initialized for this device yet? bool isInitialized(int device) const; @@ -100,7 +128,7 @@ class StandardGpuResourcesImpl : public GpuResources { /// memory size static size_t getDefaultTempMemForGPU(int device, size_t requested); - private: + protected: /// Set of currently outstanding memory allocations per device /// device -> (alloc request, allocated ptr) std::unordered_map> allocs_; @@ -124,6 +152,30 @@ class StandardGpuResourcesImpl : public GpuResources { /// cuBLAS handle for each device std::unordered_map blasHandles_; +#if defined USE_NVIDIA_RAFT + /// raft handle for each device + std::unordered_map raftHandles_; + + /** + * FIXME: Integrating these in a separate code path for now. Ultimately, + * it would be nice if we use a simple memory resource abstraction + * in FAISS so we could plug in whether to use RMM's memory resources + * or the default. + * + * There's enough duplicated logic that it doesn't *seem* to make sense + * to create a subclass only for the RMM memory resources. + */ + + // cuda_memory_resource + std::unique_ptr cmr; + + // managed_memory_resource + std::unique_ptr mmr; + + // pinned_memory_resource + std::unique_ptr pmr; +#endif + /// Pinned memory allocation for use with this GPU void* pinnedMemAlloc_; size_t pinnedMemAllocSize_; @@ -183,10 +235,15 @@ class StandardGpuResources : public GpuResourcesProvider { /// Export a description of memory used for Python std::map>> getMemoryInfo() const; - /// Returns the current default stream cudaStream_t getDefaultStream(int device); +#if defined USE_NVIDIA_RAFT + /// Returns the raft handle for the given device which can be used to + /// make calls to other raft primitives. + raft::device_resources& getRaftHandle(int device); +#endif + /// Returns the current amount of temp memory available size_t getTempMemoryAvailable(int device) const; diff --git a/faiss/gpu/impl/FlatIndex.cu b/faiss/gpu/impl/FlatIndex.cu index 95ae320a0a..64c4a3d7c0 100644 --- a/faiss/gpu/impl/FlatIndex.cu +++ b/faiss/gpu/impl/FlatIndex.cu @@ -4,6 +4,21 @@ * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include #include diff --git a/faiss/gpu/impl/FlatIndex.cuh b/faiss/gpu/impl/FlatIndex.cuh index 66f611dcbf..d1610f7244 100644 --- a/faiss/gpu/impl/FlatIndex.cuh +++ b/faiss/gpu/impl/FlatIndex.cuh @@ -4,6 +4,21 @@ * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #pragma once @@ -44,7 +59,7 @@ class FlatIndex { /// Returns a reference to our vectors currently in use (if useFloat16 mode) Tensor& getVectorsFloat16Ref(); - void query( + virtual void query( Tensor& vecs, int k, faiss::MetricType metric, @@ -53,7 +68,7 @@ class FlatIndex { Tensor& outIndices, bool exactDistance); - void query( + virtual void query( Tensor& vecs, int k, faiss::MetricType metric, @@ -81,7 +96,7 @@ class FlatIndex { /// Free all storage void reset(); - private: + protected: /// Collection of GPU resources that we use GpuResources* resources_; diff --git a/faiss/gpu/impl/RaftFlatIndex.cu b/faiss/gpu/impl/RaftFlatIndex.cu new file mode 100644 index 0000000000..d407c68680 --- /dev/null +++ b/faiss/gpu/impl/RaftFlatIndex.cu @@ -0,0 +1,157 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include + +#include +#include +#include +#include + +#define RAFT_NAME "raft" + +namespace faiss { +namespace gpu { + +using namespace raft::distance; +using namespace raft::neighbors; + +RaftFlatIndex::RaftFlatIndex( + GpuResources* res, + int dim, + bool useFloat16, + MemorySpace space) + : FlatIndex(res, dim, useFloat16, space) {} + +void RaftFlatIndex::query( + Tensor& input, + int k, + faiss::MetricType metric, + float metricArg, + Tensor& outDistances, + Tensor& outIndices, + bool exactDistance) { + /** + * RAFT doesn't yet support half-precision in bfknn. + * Use FlatIndex for float16 for now + */ + if (useFloat16_) { + auto stream = resources_->getDefaultStreamCurrentDevice(); + + // We need to convert the input to float16 for comparison to ourselves + auto inputHalf = convertTensorTemporary( + resources_, stream, input); + + FlatIndex::query( + inputHalf, + k, + metric, + metricArg, + outDistances, + outIndices, + exactDistance); + } else { + raft::device_resources& handle = + resources_->getRaftHandleCurrentDevice(); + + auto index = raft::make_device_matrix_view( + vectors_.data(), vectors_.getSize(0), vectors_.getSize(1)); + auto search = raft::make_device_matrix_view( + input.data(), input.getSize(0), input.getSize(1)); + auto inds = raft::make_device_matrix_view( + outIndices.data(), + outIndices.getSize(0), + outIndices.getSize(1)); + auto dists = raft::make_device_matrix_view( + outDistances.data(), + outDistances.getSize(0), + outDistances.getSize(1)); + + DistanceType distance = faiss_to_raft(metric, exactDistance); + + std::vector> index_vec = { + index}; + + // For now, use RAFT's fused KNN when k <= 64 and L2 metric is used + if (k <= 64 && metric == MetricType::METRIC_L2 && + vectors_.getSize(0) > 0) { + RAFT_LOG_INFO("Invoking flat fused_l2_knn"); + brute_force::fused_l2_knn( + handle, index, search, inds, dists, distance); + } else { + RAFT_LOG_INFO("Invoking flat bfknn"); + brute_force::knn( + handle, + index_vec, + search, + inds, + dists, + k, + distance, + metricArg); + } + + if (metric == MetricType::METRIC_Lp) { + raft::linalg::unary_op( + handle, + raft::make_const_mdspan(dists), + dists, + [metricArg] __device__(const float& a) { + return powf(a, metricArg); + }); + } else if (metric == MetricType::METRIC_JensenShannon) { + raft::linalg::unary_op( + handle, + raft::make_const_mdspan(dists), + dists, + [] __device__(const float& a) { return powf(a, 2); }); + } + } +} + +void RaftFlatIndex::query( + Tensor& vecs, + int k, + faiss::MetricType metric, + float metricArg, + Tensor& outDistances, + Tensor& outIndices, + bool exactDistance) { + FAISS_ASSERT(useFloat16_); + + // FIXME: ref https://github.com/rapidsai/raft/issues/1280 + FlatIndex::query( + vecs, + k, + metric, + metricArg, + outDistances, + outIndices, + exactDistance); +} + +} // namespace gpu +} // namespace faiss diff --git a/faiss/gpu/impl/RaftFlatIndex.cuh b/faiss/gpu/impl/RaftFlatIndex.cuh new file mode 100644 index 0000000000..010c5aebce --- /dev/null +++ b/faiss/gpu/impl/RaftFlatIndex.cuh @@ -0,0 +1,69 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include + +namespace faiss { +namespace gpu { + +class GpuResources; + +/// Holder of GPU resources for a particular flat index +/// Can be in either float16 or float32 mode. If float32, we only store +/// the vectors in float32. +/// If float16, we store the vectors in both float16 and float32, where float32 +/// data is possibly needed for certain residual operations +class RaftFlatIndex : public FlatIndex { + public: + RaftFlatIndex( + GpuResources* res, + int dim, + bool useFloat16, + MemorySpace space); + + void query( + Tensor& vecs, + int k, + faiss::MetricType metric, + float metricArg, + Tensor& outDistances, + Tensor& outIndices, + bool exactDistance) override; + + void query( + Tensor& vecs, + int k, + faiss::MetricType metric, + float metricArg, + Tensor& outDistances, + Tensor& outIndices, + bool exactDistance) override; +}; + +} // namespace gpu +} // namespace faiss diff --git a/faiss/gpu/impl/RaftUtils.h b/faiss/gpu/impl/RaftUtils.h new file mode 100644 index 0000000000..77c47999a5 --- /dev/null +++ b/faiss/gpu/impl/RaftUtils.h @@ -0,0 +1,55 @@ +/** + * Copyright (c) Facebook, Inc. and its affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace faiss { +namespace gpu { + +raft::distance::DistanceType faiss_to_raft( + MetricType metric, + bool exactDistance) { + switch (metric) { + case MetricType::METRIC_INNER_PRODUCT: + return raft::distance::DistanceType::InnerProduct; + case MetricType::METRIC_L2: + return exactDistance ? raft::distance::DistanceType::L2Unexpanded + : raft::distance::DistanceType::L2Expanded; + case MetricType::METRIC_L1: + return raft::distance::DistanceType::L1; + case MetricType::METRIC_Linf: + return raft::distance::DistanceType::Linf; + case MetricType::METRIC_Lp: + return raft::distance::DistanceType::LpUnexpanded; + case MetricType::METRIC_Canberra: + return raft::distance::DistanceType::Canberra; + case MetricType::METRIC_BrayCurtis: + return raft::distance::DistanceType::BrayCurtis; + case MetricType::METRIC_JensenShannon: + return raft::distance::DistanceType::JensenShannon; + default: + RAFT_FAIL("Distance type not supported"); + } +} +} // namespace gpu +} // namespace faiss diff --git a/faiss/gpu/test/CMakeLists.txt b/faiss/gpu/test/CMakeLists.txt index def3ef3151..ed68a142e8 100644 --- a/faiss/gpu/test/CMakeLists.txt +++ b/faiss/gpu/test/CMakeLists.txt @@ -3,6 +3,19 @@ # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# ============================================================================= +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. +# ============================================================================= find_package(CUDAToolkit REQUIRED) @@ -10,7 +23,7 @@ find_package(CUDAToolkit REQUIRED) include(GoogleTest) add_library(faiss_gpu_test_helper TestUtils.cpp) -target_link_libraries(faiss_gpu_test_helper PUBLIC faiss gtest CUDA::cudart) +target_link_libraries(faiss_gpu_test_helper PUBLIC faiss gtest CUDA::cudart $<$:raft::raft>) macro(faiss_gpu_test file) get_filename_component(test_name ${file} NAME_WE) @@ -19,6 +32,7 @@ macro(faiss_gpu_test file) gtest_discover_tests(${test_name}) endmacro() + faiss_gpu_test(TestCodePacking.cpp) faiss_gpu_test(TestGpuIndexFlat.cpp) faiss_gpu_test(TestGpuIndexIVFFlat.cpp) diff --git a/faiss/gpu/test/TestGpuDistance.cu b/faiss/gpu/test/TestGpuDistance.cu index eedec6b702..3c59cc1a5f 100644 --- a/faiss/gpu/test/TestGpuDistance.cu +++ b/faiss/gpu/test/TestGpuDistance.cu @@ -4,6 +4,21 @@ * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include #include @@ -16,10 +31,48 @@ #include #include +void evaluate_bfknn( + faiss::gpu::GpuDistanceParams& args, + faiss::gpu::GpuResourcesProvider* res, + std::vector& cpuDistance, + std::vector& cpuIndices, + std::vector& gpuDistance, + std::vector& gpuIndices, + int numQuery, + int k, + bool colMajorVecs, + bool colMajorQueries, + faiss::MetricType metric) { + using namespace faiss::gpu; + + bfKnn(res, args); + + std::stringstream str; + str << "using raft " << args.use_raft << "metric " << metric + << " colMajorVecs " << colMajorVecs << " colMajorQueries " + << colMajorQueries; + + compareLists( + cpuDistance.data(), + cpuIndices.data(), + gpuDistance.data(), + gpuIndices.data(), + numQuery, + k, + str.str(), + false, + false, + true, + 6e-3f, + 0.1f, + 0.015f); +} + void testTransposition( bool colMajorVecs, bool colMajorQueries, faiss::MetricType metric, + bool use_raft = false, float metricArg = 0) { using namespace faiss::gpu; @@ -115,26 +168,26 @@ void testTransposition( args.outIndices = gpuIndices.data(); args.device = device; - bfKnn(&res, args); - - std::stringstream str; - str << "metric " << metric << " colMajorVecs " << colMajorVecs - << " colMajorQueries " << colMajorQueries; - - compareLists( - cpuDistance.data(), - cpuIndices.data(), - gpuDistance.data(), - gpuIndices.data(), +#if defined USE_NVIDIA_RAFT + args.use_raft = use_raft; +#else + FAISS_THROW_IF_NOT_MSG( + !use_raft, + "RAFT has not been compiled into the current version so it cannot be used."); +#endif + + evaluate_bfknn( + args, + &res, + cpuDistance, + cpuIndices, + gpuDistance, + gpuIndices, numQuery, k, - str.str(), - false, - false, - true, - 6e-3f, - 0.1f, - 0.015f); + colMajorVecs, + colMajorQueries, + metric); } // Test different memory layouts for brute-force k-NN @@ -143,48 +196,118 @@ TEST(TestGpuDistance, Transposition_RR) { testTransposition(false, false, faiss::MetricType::METRIC_INNER_PRODUCT); } +#if defined USE_NVIDIA_RAFT +TEST(TestRaftGpuDistance, Transposition_RR) { + testTransposition(false, false, faiss::MetricType::METRIC_L2, true); + testTransposition( + false, false, faiss::MetricType::METRIC_INNER_PRODUCT, true); +} +#endif + TEST(TestGpuDistance, Transposition_RC) { testTransposition(false, true, faiss::MetricType::METRIC_L2); } +#if defined USE_NVIDIA_RAFT +TEST(TestRaftGpuDistance, Transposition_RC) { + testTransposition(false, true, faiss::MetricType::METRIC_L2, true); +} +#endif + TEST(TestGpuDistance, Transposition_CR) { testTransposition(true, false, faiss::MetricType::METRIC_L2); } +#if defined USE_NVIDIA_RAFT +TEST(TestRaftGpuDistance, Transposition_CR) { + testTransposition(true, false, faiss::MetricType::METRIC_L2, true); +} +#endif + TEST(TestGpuDistance, Transposition_CC) { testTransposition(true, true, faiss::MetricType::METRIC_L2); } +#if defined USE_NVIDIA_RAFT +TEST(TestRaftGpuDistance, Transposition_CC) { + testTransposition(true, true, faiss::MetricType::METRIC_L2, true); +} +#endif + TEST(TestGpuDistance, L1) { testTransposition(false, false, faiss::MetricType::METRIC_L1); } +#if defined USE_NVIDIA_RAFT +TEST(TestRaftGpuDistance, L1) { + testTransposition(false, false, faiss::MetricType::METRIC_L1, true); +} +#endif + // Test other transpositions with the general distance kernel TEST(TestGpuDistance, L1_RC) { testTransposition(false, true, faiss::MetricType::METRIC_L1); } +#if defined USE_NVIDIA_RAFT +// Test other transpositions with the general distance kernel +TEST(TestRaftGpuDistance, L1_RC) { + testTransposition(false, true, faiss::MetricType::METRIC_L1, true); +} +#endif + TEST(TestGpuDistance, L1_CR) { testTransposition(true, false, faiss::MetricType::METRIC_L1); } +#if defined USE_NVIDIA_RAFT +TEST(TestRaftGpuDistance, L1_CR) { + testTransposition(true, false, faiss::MetricType::METRIC_L1, true); +} +#endif + TEST(TestGpuDistance, L1_CC) { testTransposition(true, true, faiss::MetricType::METRIC_L1); } +#if defined USE_NVIDIA_RAFT +TEST(TestRaftGpuDistance, L1_CC) { + testTransposition(true, true, faiss::MetricType::METRIC_L1, true); +} +#endif + // Test remainder of metric types TEST(TestGpuDistance, Linf) { testTransposition(false, false, faiss::MetricType::METRIC_Linf); } +#if defined USE_NVIDIA_RAFT +// Test remainder of metric types +TEST(TestRaftGpuDistance, Linf) { + testTransposition(false, false, faiss::MetricType::METRIC_Linf, true); +} +#endif + TEST(TestGpuDistance, Lp) { - testTransposition(false, false, faiss::MetricType::METRIC_Lp, 3); + testTransposition(false, false, faiss::MetricType::METRIC_Lp, false, 3); +} + +#if defined USE_NVIDIA_RAFT +TEST(TestRaftGpuDistance, Lp) { + testTransposition(false, false, faiss::MetricType::METRIC_Lp, true, 3); } +#endif TEST(TestGpuDistance, Canberra) { testTransposition(false, false, faiss::MetricType::METRIC_Canberra); } +#if defined USE_NVIDIA_RAFT +TEST(TestRaftGpuDistance, Canberra) { + testTransposition(false, false, faiss::MetricType::METRIC_Canberra, true); +} +#endif + TEST(TestGpuDistance, BrayCurtis) { testTransposition(false, false, faiss::MetricType::METRIC_BrayCurtis); } @@ -193,6 +316,13 @@ TEST(TestGpuDistance, JensenShannon) { testTransposition(false, false, faiss::MetricType::METRIC_JensenShannon); } +#if defined USE_NVIDIA_RAFT +TEST(TestRaftGpuDistance, JensenShannon) { + testTransposition( + false, false, faiss::MetricType::METRIC_JensenShannon, true); +} +#endif + TEST(TestGpuDistance, Jaccard) { testTransposition(false, false, faiss::MetricType::METRIC_Jaccard); } diff --git a/faiss/gpu/test/TestGpuIndexFlat.cpp b/faiss/gpu/test/TestGpuIndexFlat.cpp index bcd49c68de..4f7c95deab 100644 --- a/faiss/gpu/test/TestGpuIndexFlat.cpp +++ b/faiss/gpu/test/TestGpuIndexFlat.cpp @@ -28,7 +28,8 @@ struct TestFlatOptions { numVecsOverride(-1), numQueriesOverride(-1), kOverride(-1), - dimOverride(-1) {} + dimOverride(-1), + use_raft(false) {} faiss::MetricType metric; float metricArg; @@ -38,6 +39,7 @@ struct TestFlatOptions { int numQueriesOverride; int kOverride; int dimOverride; + bool use_raft; }; void testFlat(const TestFlatOptions& opt) { @@ -73,6 +75,7 @@ void testFlat(const TestFlatOptions& opt) { faiss::gpu::GpuIndexFlatConfig config; config.device = device; config.useFloat16 = opt.useFloat16; + config.use_raft = opt.use_raft; faiss::gpu::GpuIndexFlat gpuIndex(&res, dim, opt.metric, config); gpuIndex.metric_arg = opt.metricArg; @@ -110,6 +113,11 @@ TEST(TestGpuIndexFlat, IP_Float32) { opt.useFloat16 = false; testFlat(opt); + +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + testFlat(opt); +#endif } } @@ -119,6 +127,11 @@ TEST(TestGpuIndexFlat, L1_Float32) { opt.useFloat16 = false; testFlat(opt); + +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + testFlat(opt); +#endif } TEST(TestGpuIndexFlat, Lp_Float32) { @@ -128,6 +141,10 @@ TEST(TestGpuIndexFlat, Lp_Float32) { opt.useFloat16 = false; testFlat(opt); +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + testFlat(opt); +#endif } TEST(TestGpuIndexFlat, L2_Float32) { @@ -138,6 +155,10 @@ TEST(TestGpuIndexFlat, L2_Float32) { opt.useFloat16 = false; testFlat(opt); +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + testFlat(opt); +#endif } } @@ -152,6 +173,10 @@ TEST(TestGpuIndexFlat, L2_k_2048) { opt.numVecsOverride = 10000; testFlat(opt); +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + testFlat(opt); +#endif } } @@ -164,6 +189,10 @@ TEST(TestGpuIndexFlat, L2_Float32_K1) { opt.kOverride = 1; testFlat(opt); +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + testFlat(opt); +#endif } } @@ -174,6 +203,10 @@ TEST(TestGpuIndexFlat, IP_Float16) { opt.useFloat16 = true; testFlat(opt); +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + testFlat(opt); +#endif } } @@ -184,6 +217,10 @@ TEST(TestGpuIndexFlat, L2_Float16) { opt.useFloat16 = true; testFlat(opt); +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + testFlat(opt); +#endif } } @@ -196,6 +233,10 @@ TEST(TestGpuIndexFlat, L2_Float16_K1) { opt.kOverride = 1; testFlat(opt); +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + testFlat(opt); +#endif } } @@ -213,6 +254,10 @@ TEST(TestGpuIndexFlat, L2_Tiling) { opt.kOverride = 64; testFlat(opt); +#if defined USE_NVIDIA_RAFT + opt.use_raft = true; + testFlat(opt); +#endif } } @@ -223,7 +268,6 @@ TEST(TestGpuIndexFlat, QueryEmpty) { faiss::gpu::GpuIndexFlatConfig config; config.device = 0; config.useFloat16 = false; - int dim = 128; faiss::gpu::GpuIndexFlatL2 gpuIndex(&res, dim, config); @@ -247,7 +291,7 @@ TEST(TestGpuIndexFlat, QueryEmpty) { } } -TEST(TestGpuIndexFlat, CopyFrom) { +void testCopyFrom(bool use_raft) { int numVecs = faiss::gpu::randVal(100, 200); int dim = faiss::gpu::randVal(1, 1000); @@ -265,6 +309,7 @@ TEST(TestGpuIndexFlat, CopyFrom) { faiss::gpu::GpuIndexFlatConfig config; config.device = device; config.useFloat16 = useFloat16; + config.use_raft = use_raft; // Fill with garbage values faiss::gpu::GpuIndexFlatL2 gpuIndex(&res, 2000, config); @@ -293,7 +338,17 @@ TEST(TestGpuIndexFlat, CopyFrom) { } } -TEST(TestGpuIndexFlat, CopyTo) { +TEST(TestGpuIndexFlat, CopyFrom) { + testCopyFrom(false); +} + +#if defined USE_NVIDIA_RAFT +TEST(TestRaftGpuIndexFlat, CopyFrom) { + testCopyFrom(true); +} +#endif + +void testCopyTo(bool use_raft) { faiss::gpu::StandardGpuResources res; res.noTempMemory(); @@ -307,6 +362,7 @@ TEST(TestGpuIndexFlat, CopyTo) { faiss::gpu::GpuIndexFlatConfig config; config.device = device; config.useFloat16 = useFloat16; + config.use_raft = use_raft; faiss::gpu::GpuIndexFlatL2 gpuIndex(&res, dim, config); gpuIndex.add(numVecs, vecs.data()); @@ -333,7 +389,17 @@ TEST(TestGpuIndexFlat, CopyTo) { } } -TEST(TestGpuIndexFlat, UnifiedMemory) { +TEST(TestGpuIndexFlat, CopyTo) { + testCopyTo(false); +} + +#if defined USE_NVIDIA_RAFT +TEST(TestRaftGpuIndexFlat, CopyTo) { + testCopyTo(true); +} +#endif + +void testUnifiedMemory(bool use_raft) { // Construct on a random device to test multi-device, if we have // multiple devices int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); @@ -359,6 +425,7 @@ TEST(TestGpuIndexFlat, UnifiedMemory) { faiss::gpu::GpuIndexFlatConfig config; config.device = device; config.memorySpace = faiss::gpu::MemorySpace::Unified; + config.use_raft = use_raft; faiss::gpu::GpuIndexFlatL2 gpuIndexL2(&res, dim, config); @@ -380,7 +447,17 @@ TEST(TestGpuIndexFlat, UnifiedMemory) { 0.015f); } -TEST(TestGpuIndexFlat, LargeIndex) { +TEST(TestGpuIndexFlat, UnifiedMemory) { + testUnifiedMemory(false); +} + +#if defined USE_NVIDIA_RAFT +TEST(TestRaftGpuIndexFlat, UnifiedMemory) { + testUnifiedMemory(true); +} +#endif + +void testLargeIndex(bool use_raft) { // Construct on a random device to test multi-device, if we have // multiple devices int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); @@ -411,6 +488,7 @@ TEST(TestGpuIndexFlat, LargeIndex) { faiss::gpu::GpuIndexFlatConfig config; config.device = device; + config.use_raft = use_raft; faiss::gpu::GpuIndexFlatL2 gpuIndexL2(&res, dim, config); cpuIndexL2.add(nb, xb.data()); @@ -430,7 +508,17 @@ TEST(TestGpuIndexFlat, LargeIndex) { 0.015f); } -TEST(TestGpuIndexFlat, Residual) { +TEST(TestGpuIndexFlat, LargeIndex) { + testLargeIndex(false); +} + +#if defined USE_NVIDIA_RAFT +TEST(TestRaftGpuIndexFlat, LargeIndex) { + testLargeIndex(true); +} +#endif + +void testResidual(bool use_raft) { // Construct on a random device to test multi-device, if we have // multiple devices int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); @@ -440,6 +528,7 @@ TEST(TestGpuIndexFlat, Residual) { faiss::gpu::GpuIndexFlatConfig config; config.device = device; + config.use_raft = use_raft; int dim = 32; faiss::IndexFlat cpuIndex(dim, faiss::MetricType::METRIC_L2); @@ -472,7 +561,17 @@ TEST(TestGpuIndexFlat, Residual) { EXPECT_EQ(residualsCpu, residualsGpu); } -TEST(TestGpuIndexFlat, Reconstruct) { +TEST(TestGpuIndexFlat, Residual) { + testResidual(false); +} + +#if defined USE_NVIDIA_RAFT +TEST(TestRaftGpuIndexFlat, Residual) { + testResidual(true); +} +#endif + +void testReconstruct(bool use_raft) { // Construct on a random device to test multi-device, if we have // multiple devices int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); @@ -489,6 +588,7 @@ TEST(TestGpuIndexFlat, Reconstruct) { faiss::gpu::GpuIndexFlatConfig config; config.device = device; config.useFloat16 = useFloat16; + config.use_raft = use_raft; faiss::gpu::GpuIndexFlat gpuIndex( &res, dim, faiss::MetricType::METRIC_L2, config); @@ -553,7 +653,16 @@ TEST(TestGpuIndexFlat, Reconstruct) { } } -TEST(TestGpuIndexFlat, SearchAndReconstruct) { +TEST(TestGpuIndexFlat, Reconstruct) { + testReconstruct(false); +} +#if defined USE_NVIDIA_RAFT +TEST(TestRaftGpuIndexFlat, Reconstruct) { + testReconstruct(true); +} +#endif + +void testSearchAndReconstruct(bool use_raft) { // Construct on a random device to test multi-device, if we have // multiple devices int device = faiss::gpu::randVal(0, faiss::gpu::getNumDevices() - 1); @@ -573,6 +682,7 @@ TEST(TestGpuIndexFlat, SearchAndReconstruct) { faiss::gpu::GpuIndexFlatConfig config; config.device = device; + config.use_raft = use_raft; faiss::gpu::GpuIndexFlatL2 gpuIndex(&res, dim, config); cpuIndex.add(nb, xb.data()); @@ -640,6 +750,16 @@ TEST(TestGpuIndexFlat, SearchAndReconstruct) { } } +TEST(TestGpuIndexFlat, SearchAndReconstruct) { + testSearchAndReconstruct(false); +} + +#if defined USE_NVIDIA_RAFT +TEST(TestRaftGpuIndexFlat, SearchAndReconstruct) { + testSearchAndReconstruct(true); +} +#endif + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); diff --git a/faiss/gpu/test/TestGpuIndexIVFFlat.cpp b/faiss/gpu/test/TestGpuIndexIVFFlat.cpp index 0a004bd9b0..c4fc95ef29 100644 --- a/faiss/gpu/test/TestGpuIndexIVFFlat.cpp +++ b/faiss/gpu/test/TestGpuIndexIVFFlat.cpp @@ -4,6 +4,21 @@ * This source code is licensed under the MIT license found in the * LICENSE file in the root directory of this source tree. */ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ #include #include