Skip to content

Commit

Permalink
Enable using Thrust from CCCL
Browse files Browse the repository at this point in the history
  • Loading branch information
bernhardmgruber committed Jun 26, 2024
1 parent f787167 commit be3132b
Showing 1 changed file with 34 additions and 27 deletions.
61 changes: 34 additions & 27 deletions src/thrust/model.cmake
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@

register_flag_optional(THRUST_IMPL
"Which Thrust implementation to use, supported options include:
- CUDA (via https://github.com/NVIDIA/thrust)
- CUDA (via https://github.com/NVIDIA/thrust or https://github.com/NVIDIA/CCCL)
- ROCM (via https://github.com/ROCmSoftwarePlatform/rocThrust)
"
"CUDA")

register_flag_optional(SDK_DIR
"Path to the selected Thrust implementation (e.g `/opt/nvidia/hpc_sdk/Linux_x86_64/21.9/cuda/include` for NVHPC, `/opt/rocm` for ROCm)"
"Path to the installation prefix for CCCL or Thrust (e.g `/opt/nvidia/hpc_sdk/Linux_x86_64/24.5/cuda/12.4/lib64/cmake` for NVHPC, or `/usr/local/cuda-12.5/lib64/cmake` for nvcc, or `/usr/local/cuda-11.4/include` for older nvcc, or `/opt/rocm` for ROCm)"
"")

register_flag_optional(BACKEND
Expand All @@ -18,7 +18,7 @@ register_flag_optional(BACKEND
"
"CUDA")

register_flag_optional(MANAGED "Enabled managed memory mode."
register_flag_optional(MANAGED "Enabled managed memory mode."
"OFF")

register_flag_optional(CMAKE_CUDA_COMPILER
Expand All @@ -34,6 +34,9 @@ register_flag_optional(CUDA_EXTRA_FLAGS
"[THRUST_IMPL==CUDA] Additional CUDA flags passed to nvcc, this is appended after `CUDA_ARCH`"
"")

option(FETCH_CCCL "Fetch (download) the CCCL library. This uses CMake's FetchContent feature.
Specify version by setting FETCH_CCCL_VERSION" OFF)
set(FETCH_CCCL_VERSION "v2.4.0" CACHE STRING "Specify version of CCCL to use if FETCH_CCCL is ON")

macro(setup)
set(CMAKE_CXX_STANDARD 14)
Expand All @@ -42,44 +45,48 @@ macro(setup)
endif ()

if (${THRUST_IMPL} STREQUAL "CUDA")

# see CUDA.cmake, we're only adding a few Thrust related libraries here

if (POLICY CMP0104)
cmake_policy(SET CMP0104 NEW)
endif ()

set(CMAKE_CUDA_ARCHITECTURES ${CUDA_ARCH})
# add -forward-unknown-to-host-compiler for compatibility reasons
set(CMAKE_CUDA_FLAGS ${CMAKE_CUDA_FLAGS} "--expt-extended-lambda " ${CUDA_EXTRA_FLAGS})
enable_language(CUDA)
# CMake defaults to -O2 for CUDA at Release, let's wipe that and use the global RELEASE_FLAG
# appended later
# CMake defaults to -O2 for CUDA at Release, let's wipe that and use the global RELEASE_FLAG appended later
wipe_gcc_style_optimisation_flags(CMAKE_CUDA_FLAGS_${BUILD_TYPE})

message(STATUS "NVCC flags: ${CMAKE_CUDA_FLAGS} ${CMAKE_CUDA_FLAGS_${BUILD_TYPE}}")


# XXX NVHPC <= 21.9 has cub-config in `Linux_x86_64/21.9/cuda/11.4/include/cub/cmake`
# XXX NVHPC >= 22.3 has cub-config in `Linux_x86_64/22.3/cuda/11.6/lib64/cmake/cub/`
# same thing for thrust
if (SDK_DIR)
# CMake tries several subdirectories below SDK_DIR, see documentation:
# https://cmake.org/cmake/help/latest/command/find_package.html#config-mode-search-procedure
list(APPEND CMAKE_PREFIX_PATH ${SDK_DIR})
find_package(CUB REQUIRED CONFIG PATHS ${SDK_DIR}/cub)
find_package(Thrust REQUIRED CONFIG PATHS ${SDK_DIR}/thrust)
else ()
find_package(CUB REQUIRED CONFIG)
find_package(Thrust REQUIRED CONFIG)
endif ()

message(STATUS "Using Thrust backend: ${BACKEND}")

# this creates the interface that we can link to
thrust_create_target(Thrust${BACKEND}
HOST CPP
DEVICE ${BACKEND})

register_link_library(Thrust${BACKEND})
set(CCCL_THRUST_DEVICE_SYSTEM ${BACKEND} CACHE STRING "" FORCE)

# fetch CCCL if user wants to
if (FETCH_CCCL)
FetchContent_Declare(
CCCL
GIT_REPOSITORY https://github.com/nvidia/cccl.git
GIT_TAG "${FETCH_CCCL_VERSION}"
)
FetchContent_MakeAvailable(CCCL)
register_link_library(CCCL::CCCL)
else()
# try to find CCCL locally
find_package(CCCL CONFIG)
if (CCCL_FOUND)
register_link_library(CCCL::CCCL)
else()
# backup: find legacy projects separately
message(WARNING "No CCCL found on your system. Trying Thrust and CUB legacy targets.")
find_package(CUB REQUIRED CONFIG)
find_package(Thrust REQUIRED CONFIG)
thrust_create_target(Thrust${BACKEND} HOST CPP DEVICE ${BACKEND})
register_link_library(Thrust${BACKEND})
endif()
endif()
elseif (${THRUST_IMPL} STREQUAL "ROCM")
if (SDK_DIR)
find_package(rocprim REQUIRED CONFIG PATHS ${SDK_DIR}/rocprim)
Expand Down

0 comments on commit be3132b

Please sign in to comment.