Added host multithreading support for FFTW (#652)
Added multi-threading support for FFTW and element-wise ops on the host
aayushg55 authored Jun 25, 2024
1 parent ed09e1c commit 7680716
Showing 13 changed files with 192 additions and 87 deletions.
30 changes: 23 additions & 7 deletions CMakeLists.txt
@@ -58,6 +58,7 @@ option(MATX_EN_VISUALIZATION "Enable visualization support" OFF)
#option(MATX_EN_CUTLASS OFF)
option(MATX_EN_CUTENSOR OFF)
option(MATX_EN_FILEIO OFF)
option(MATX_EN_X86_FFTW OFF "Enable x86 FFTW support")
option(MATX_EN_NVPL OFF, "Enable NVIDIA Performance Libraries for optimized ARM CPU support")
option(MATX_DISABLE_CUB_CACHE "Disable caching for CUB allocations" ON)
option(MATX_EN_COVERAGE OFF "Enable code coverage reporting")
@@ -171,13 +172,28 @@ set(WARN_FLAGS ${WARN_FLAGS} $<$<COMPILE_LANGUAGE:CXX>:-Werror>)
set (CUTLASS_INC "")
target_compile_definitions(matx INTERFACE MATX_ENABLE_CUTLASS=0)

if (MATX_EN_NVPL)
message(STATUS "Enabling NVPL library support")
# find_package is currently broken in NVPL. Use proper targets once working
#find_package(nvpl REQUIRED COMPONENTS fft)
#target_link_libraries(matx INTERFACE nvpl::fftw)
target_link_libraries(matx INTERFACE nvpl_fftw)
target_compile_definitions(matx INTERFACE MATX_EN_NVPL=1)
# Host support
if (MATX_EN_NVPL OR MATX_EN_X86_FFTW)
message(STATUS "Enabling OpenMP support")
find_package(OpenMP REQUIRED)
target_link_libraries(matx INTERFACE OpenMP::OpenMP_CXX)
target_compile_definitions(matx INTERFACE MATX_EN_OMP=1)
if (MATX_EN_NVPL)
message(STATUS "Enabling NVPL library support for ARM CPUs")
find_package(nvpl REQUIRED COMPONENTS fft)
target_link_libraries(matx INTERFACE nvpl::fftw)
target_compile_definitions(matx INTERFACE MATX_EN_NVPL=1)
else()
if (MATX_EN_X86_FFTW)
message(STATUS "Enabling x86 FFTW")
find_library(FFTW_LIB fftw3 REQUIRED)
find_library(FFTWF_LIB fftw3f REQUIRED)
find_library(FFTW_OMP_LIB fftw3_omp REQUIRED)
find_library(FFTWF_OMP_LIB fftw3f_omp REQUIRED)
target_link_libraries(matx INTERFACE ${FFTW_LIB} ${FFTWF_LIB} ${FFTW_OMP_LIB} ${FFTWF_OMP_LIB})
target_compile_definitions(matx INTERFACE MATX_EN_X86_FFTW=1)
endif()
endif()
endif()

if (MATX_DISABLE_CUB_CACHE)
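For orientation (an illustration, not part of this commit): the options above surface in C++ as the compile definitions MATX_EN_OMP, MATX_EN_NVPL, and MATX_EN_X86_FFTW, so host code can branch on them at compile time. The helper below is hypothetical.

```cpp
#include <cstdio>

// Hypothetical helper: report which host backends this build enabled.
// The macros come from the target_compile_definitions() calls above.
void report_host_backends() {
#if defined(MATX_EN_OMP)
  std::printf("OpenMP host threading: enabled\n");
#else
  std::printf("OpenMP host threading: disabled\n");
#endif
#if defined(MATX_EN_NVPL)
  std::printf("Host FFT backend: NVPL (ARM)\n");
#elif defined(MATX_EN_X86_FFTW)
  std::printf("Host FFT backend: FFTW (x86)\n");
#else
  std::printf("Host FFT backend: none\n");
#endif
}
```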
2 changes: 1 addition & 1 deletion docs_input/api/logic/truth/allclose.rst
@@ -7,7 +7,7 @@ Reduce the closeness of two operators to a single scalar (0D) output. The output
from allclose is an ``int`` value since boolean reductions are not available in hardware


.. doxygenfunction:: allclose(OutType dest, const InType1 &in1, const InType2 &in2, double rtol, double atol, HostExecutor exec)
.. doxygenfunction:: allclose(OutType dest, const InType1 &in1, const InType2 &in2, double rtol, double atol, HostExecutor<MODE> &exec)
.. doxygenfunction:: allclose(OutType dest, const InType1 &in1, const InType2 &in2, double rtol, double atol, cudaExecutor exec = 0)

Examples
16 changes: 13 additions & 3 deletions include/matx/core/type_utils.h
@@ -256,7 +256,7 @@ inline constexpr bool is_settable_xform_v = std::conjunction_v<detail::is_matx_s
namespace detail {
template <typename T> struct is_executor : std::false_type {};
template <> struct is_executor<cudaExecutor> : std::true_type {};
template <> struct is_executor<HostExecutor> : std::true_type {};
template <ThreadsMode MODE> struct is_executor<HostExecutor<MODE>> : std::true_type {};
}

/**
@@ -286,17 +286,27 @@ inline constexpr bool is_cuda_executor_v = detail::is_cuda_executor<typename rem

namespace detail {
template<typename T> struct is_host_executor : std::false_type {};
template<> struct is_host_executor<matx::HostExecutor> : std::true_type {};
template<ThreadsMode MODE> struct is_host_executor<matx::HostExecutor<MODE>> : std::true_type {};

template<typename T> struct is_select_threads_host_executor : std::false_type {};
template<> struct is_select_threads_host_executor<matx::SelectThreadsHostExecutor> : std::true_type {};
}

/**
* @brief Determine if a type is a single-threaded host executor executor
* @brief Determine if a type is a host executor
*
* @tparam T Type to test
*/
template <typename T>
inline constexpr bool is_host_executor_v = detail::is_host_executor<remove_cvref_t<T>>::value;

/**
* @brief Determine if a type is a select threads host executor
*
* @tparam T Type to test
*/
template <typename T>
inline constexpr bool is_select_threads_host_executor_v = detail::is_select_threads_host_executor<remove_cvref_t<T>>::value;

namespace detail {
template <typename T, typename = void>
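Usage note (illustrative, not from the diff): the updated traits allow compile-time dispatch on the executor kind. The `executor_kind` helper below is hypothetical.

```cpp
#include <matx.h>

// Hypothetical helper: classify an executor type via the MatX traits above.
template <typename Exec>
constexpr const char* executor_kind() {
  if constexpr (matx::is_cuda_executor_v<Exec>) {
    return "CUDA executor";
  } else if constexpr (matx::is_select_threads_host_executor_v<Exec>) {
    return "host executor with a user-selected thread count";
  } else if constexpr (matx::is_host_executor_v<Exec>) {
    return "host executor (single-threaded or all-threads)";
  } else {
    return "not an executor";
  }
}
```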
69 changes: 58 additions & 11 deletions include/matx/executors/host.h
@@ -36,7 +36,9 @@

#include "matx/core/error.h"
#include "matx/core/get_grid_dims.h"

#ifdef MATX_EN_OMP
#include <omp.h>
#endif
namespace matx
{

@@ -48,9 +50,15 @@ struct cpu_set_t {
cuda::std::array<set_type, MAX_CPUS / (8 * sizeof(set_type))> bits_;
};

enum class ThreadsMode {
SINGLE,
SELECT,
ALL,
};

struct HostExecParams {
HostExecParams(int threads = 1) : threads_(threads) {}
HostExecParams(cpu_set_t cpu_set) : cpu_set_(cpu_set) {
HostExecParams(cpu_set_t cpu_set) : cpu_set_(cpu_set), threads_(1) {
MATX_ASSERT_STR(false, matxNotSupported, "CPU affinity not supported yet");
}

@@ -62,18 +70,42 @@ struct HostExecParams {
};

/**
* @brief Executor for running an operator on a single host thread
* @brief Executor for running an operator on a single or multi-threaded host
*
* @tparam MODE Threading policy
*
*/
template <ThreadsMode MODE = ThreadsMode::SINGLE>
class HostExecutor {
public:
using matx_cpu = bool; ///< Type trait indicating this is a CPU executor
using matx_executor = bool; ///< Type trait indicating this is an executor

HostExecutor(const HostExecParams &params = HostExecParams{}) : params_(params) {}
HostExecutor() {
int n_threads = 1;
if constexpr (MODE == ThreadsMode::SINGLE) {
n_threads = 1;
}
else if constexpr (MODE == ThreadsMode::ALL) {
#if MATX_EN_OMP
n_threads = omp_get_num_procs();
#endif
}
params_ = HostExecParams(n_threads);

#if MATX_EN_OMP
omp_set_num_threads(params_.GetNumThreads());
#endif
}

HostExecutor(const HostExecParams &params) : params_(params) {
#if MATX_EN_OMP
omp_set_num_threads(params_.GetNumThreads());
#endif
}

/**
* @brief Synchronize the host executor's threads. Currently a noop as the executor is single-threaded.
* @brief Synchronize the host executor's threads.
*
*/
void sync() {}
@@ -86,12 +118,23 @@
*/
template <typename Op>
void Exec(Op &op) const noexcept {
if (params_.GetNumThreads() == 1) {
if constexpr (Op::Rank() == 0) {
op();
}
else {
index_t size = TotalSize(op);
if constexpr (Op::Rank() == 0) {
op();
}
else {
index_t size = TotalSize(op);
#ifdef MATX_EN_OMP
if (params_.GetNumThreads() > 1) {
#pragma omp parallel for num_threads(params_.GetNumThreads())
for (index_t i = 0; i < size; i++) {
auto idx = GetIdxFromAbs(op, i);
cuda::std::apply([&](auto... args) {
return op(args...);
}, idx);
}
} else
#endif
{
for (index_t i = 0; i < size; i++) {
auto idx = GetIdxFromAbs(op, i);
cuda::std::apply([&](auto... args) {
@@ -108,4 +108,8 @@
HostExecParams params_;
};

using SingleThreadedHostExecutor = HostExecutor<ThreadsMode::SINGLE>;
using SelectThreadsHostExecutor = HostExecutor<ThreadsMode::SELECT>;
using AllThreadsHostExecutor = HostExecutor<ThreadsMode::ALL>;

}
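For reference, a minimal usage sketch of the three aliases (assuming the usual MatX `(lhs = op).run(exec)` pattern and a build with MATX_EN_OMP defined; not code from this commit):

```cpp
#include <matx.h>

int main() {
  auto a = matx::make_tensor<float>({1 << 20});
  auto b = matx::make_tensor<float>({1 << 20});

  // Element-wise op on a single host thread (the previous default behavior).
  (b = a * 2.0f).run(matx::SingleThreadedHostExecutor{});

  // Spread the same op across all CPU cores reported by OpenMP.
  (b = a * 2.0f).run(matx::AllThreadsHostExecutor{});

  // Explicitly request four threads through HostExecParams.
  matx::SelectThreadsHostExecutor exec{matx::HostExecParams{4}};
  (b = a * 2.0f).run(exec);

  return 0;
}
```

With MATX_EN_OMP undefined, the multi-threaded executors fall back to the sequential loop in Exec().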
4 changes: 2 additions & 2 deletions include/matx/executors/support.h
@@ -40,11 +40,11 @@ namespace matx {
namespace detail {

// FFT
#if defined(MATX_EN_NVPL)
#if defined(MATX_EN_NVPL) || defined(MATX_EN_X86_FFTW)
#define MATX_EN_CPU_FFT 1
#else
#define MATX_EN_CPU_FFT 0
#endif
#endif

template <typename Exec>
constexpr bool CheckFFTSupport() {
2 changes: 1 addition & 1 deletion include/matx/operators/legendre.h
@@ -103,7 +103,7 @@ namespace matx
__MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ scalar_type operator()(Is... indices) const
{
cuda::std::array<index_t, Rank()> inds{indices...};
cuda::std::array<index_t, T3::Rank()> xinds;
cuda::std::array<index_t, T3::Rank()> xinds{};

int axis1 = axis_[0];
int axis2 = axis_[1];
20 changes: 10 additions & 10 deletions include/matx/transforms/cub.h
@@ -1409,10 +1409,10 @@ void sort_impl(OutputTensor &a_out, const InputOperator &a,
#endif
}

template <typename OutputTensor, typename InputOperator>
template <typename OutputTensor, typename InputOperator, ThreadsMode MODE>
void sort_impl(OutputTensor &a_out, const InputOperator &a,
const SortDirection_t dir,
[[maybe_unused]] HostExecutor exec)
[[maybe_unused]] HostExecutor<MODE> &exec)
{
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)

@@ -1507,9 +1507,9 @@ void cumsum_impl(OutputTensor &a_out, const InputOperator &a,
#endif
}

template <typename OutputTensor, typename InputOperator>
template <typename OutputTensor, typename InputOperator, ThreadsMode MODE>
void cumsum_impl(OutputTensor &a_out, const InputOperator &a,
[[maybe_unused]] HostExecutor exec)
[[maybe_unused]] HostExecutor<MODE> &exec)
{
#ifdef __CUDACC__
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
@@ -1782,8 +1782,8 @@ void find_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator
* @param exec
* Single-threaded host executor
*/
template <typename SelectType, typename CountTensor, typename OutputTensor, typename InputOperator>
void find_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, SelectType sel, [[maybe_unused]] HostExecutor exec)
template <typename SelectType, typename CountTensor, typename OutputTensor, typename InputOperator, ThreadsMode MODE>
void find_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, SelectType sel, [[maybe_unused]] HostExecutor<MODE> &exec)
{
static_assert(CountTensor::Rank() == 0, "Num found output tensor rank must be 0");
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
@@ -1903,8 +1903,8 @@ void find_idx_impl(OutputTensor &a_out, CountTensor &num_found, const InputOpera
* @param exec
* Single host executor
*/
template <typename SelectType, typename CountTensor, typename OutputTensor, typename InputOperator>
void find_idx_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, SelectType sel, [[maybe_unused]] HostExecutor exec)
template <typename SelectType, typename CountTensor, typename OutputTensor, typename InputOperator, ThreadsMode MODE>
void find_idx_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, SelectType sel, [[maybe_unused]] HostExecutor<MODE> &exec)
{
static_assert(CountTensor::Rank() == 0, "Num found output tensor rank must be 0");
MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
@@ -2018,8 +2018,8 @@ void unique_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperato
* @param exec
* Single thread executor
*/
template <typename CountTensor, typename OutputTensor, typename InputOperator>
void unique_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, [[maybe_unused]] HostExecutor exec)
template <typename CountTensor, typename OutputTensor, typename InputOperator, ThreadsMode MODE>
void unique_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, [[maybe_unused]] HostExecutor<MODE> &exec)
{
#ifdef __CUDACC__
static_assert(CountTensor::Rank() == 0, "Num found output tensor rank must be 0");
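An illustrative call sketch (assumption: matx::sort and matx::cumsum are the public wrappers over the *_impl functions above; signatures not verified against this commit):

```cpp
#include <matx.h>

int main() {
  auto in  = matx::make_tensor<float>({1024});
  auto out = matx::make_tensor<float>({1024});

  // Any HostExecutor<MODE> specialization is now accepted by the host paths.
  matx::AllThreadsHostExecutor exec{};
  matx::sort(out, in, matx::SORT_DIR_ASC, exec);   // host-side sort
  matx::cumsum(out, in, exec);                     // host-side cumulative sum

  return 0;
}
```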
(Remaining changed files not shown.)
