From 76807166059b34d2267c62c6685672ff2ff6a140 Mon Sep 17 00:00:00 2001
From: Aayush Gupta <19579293+aayushg55@users.noreply.github.com>
Date: Mon, 24 Jun 2024 20:45:07 -0700
Subject: [PATCH] Added host multithreading support for FFTW (#652)

Added multi-threading support for FFTW and element-wise ops for the host
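Rough usage sketch of the new executor types (illustrative only; the tensor
shapes, thread count, and operator calls below are not part of the diff):

    // Hypothetical caller-side code, assuming the MatX operator API:
    auto exec4 = matx::SelectThreadsHostExecutor{matx::HostExecParams{4}};
    auto execN = matx::AllThreadsHostExecutor{};  // omp_get_num_procs() threads

    auto a = matx::make_tensor<float>({1024});
    auto b = matx::make_tensor<float>({1024});
    auto c = matx::make_tensor<float>({1024});
    (c = a + b).run(execN);            // element-wise op across all cores

    auto in  = matx::make_tensor<cuda::std::complex<float>>({1024});
    auto out = matx::make_tensor<cuda::std::complex<float>>({1024});
    (out = matx::fft(in)).run(exec4);  // FFTW-backed host FFT on 4 threads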
---
 CMakeLists.txt                          | 30 +++++++---
 docs_input/api/logic/truth/allclose.rst |  2 +-
 include/matx/core/type_utils.h          | 16 +++++-
 include/matx/executors/host.h           | 69 ++++++++++++++++----
 include/matx/executors/support.h        |  4 +-
 include/matx/operators/legendre.h       |  2 +-
 include/matx/transforms/cub.h           | 20 +++----
 include/matx/transforms/fft/fft_fftw.h  | 76 +++++++++++++++--------
 include/matx/transforms/reduce.h        | 44 +++++++-------
 include/matx/transforms/transpose.h     |  4 +-
 test/00_operators/OperatorTests.cu      |  2 +-
 test/00_transform/FFT.cu                |  6 ++
 test/include/test_types.h               |  4 +-
 13 files changed, 192 insertions(+), 87 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 02678641..55eb72ec 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -58,6 +58,7 @@ option(MATX_EN_VISUALIZATION "Enable visualization support" OFF)
 #option(MATX_EN_CUTLASS OFF)
 option(MATX_EN_CUTENSOR OFF)
 option(MATX_EN_FILEIO OFF)
+option(MATX_EN_X86_FFTW OFF "Enable x86 FFTW support")
 option(MATX_EN_NVPL OFF, "Enable NVIDIA Performance Libraries for optimized ARM CPU support")
 option(MATX_DISABLE_CUB_CACHE "Disable caching for CUB allocations" ON)
 option(MATX_EN_COVERAGE OFF "Enable code coverage reporting")
@@ -171,13 +172,28 @@ set(WARN_FLAGS ${WARN_FLAGS} $<$<COMPILE_LANGUAGE:CXX>:-Werror>)
 set (CUTLASS_INC "")
 target_compile_definitions(matx INTERFACE MATX_ENABLE_CUTLASS=0)
 
-if (MATX_EN_NVPL)
-  message(STATUS "Enabling NVPL library support")
-  # find_package is currently broken in NVPL. Use proper targets once working
-  #find_package(nvpl REQUIRED COMPONENTS fft)
-  #target_link_libraries(matx INTERFACE nvpl::fftw)
-  target_link_libraries(matx INTERFACE nvpl_fftw)
-  target_compile_definitions(matx INTERFACE MATX_EN_NVPL=1)
+# Host support
+if (MATX_EN_NVPL OR MATX_EN_X86_FFTW)
+  message(STATUS "Enabling OpenMP support")
+  find_package(OpenMP REQUIRED)
+  target_link_libraries(matx INTERFACE OpenMP::OpenMP_CXX)
+  target_compile_definitions(matx INTERFACE MATX_EN_OMP=1)
+  if (MATX_EN_NVPL)
+    message(STATUS "Enabling NVPL library support for ARM CPUs")
+    find_package(nvpl REQUIRED COMPONENTS fft)
+    target_link_libraries(matx INTERFACE nvpl::fftw)
+    target_compile_definitions(matx INTERFACE MATX_EN_NVPL=1)
+  else()
+    if (MATX_EN_X86_FFTW)
+      message(STATUS "Enabling x86 FFTW")
+      find_library(FFTW_LIB fftw3 REQUIRED)
+      find_library(FFTWF_LIB fftw3f REQUIRED)
+      find_library(FFTW_OMP_LIB fftw3_omp REQUIRED)
+      find_library(FFTWF_OMP_LIB fftw3f_omp REQUIRED)
+      target_link_libraries(matx INTERFACE ${FFTW_LIB} ${FFTWF_LIB} ${FFTW_OMP_LIB} ${FFTWF_OMP_LIB})
+      target_compile_definitions(matx INTERFACE MATX_EN_X86_FFTW=1)
+    endif()
+  endif()
 endif()
 
 if (MATX_DISABLE_CUB_CACHE)
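For context, a downstream project would opt into the new x86 path roughly like
this (a sketch; the vendored MatX path and application target are assumptions,
not part of this patch). FFTW's single- and double-precision libraries and
their _omp companions must be discoverable by the find_library() calls above.

    # Hypothetical consumer CMakeLists.txt snippet; third_party/MatX is an
    # illustrative path, not one referenced by this patch.
    set(MATX_EN_X86_FFTW ON CACHE BOOL "" FORCE)
    add_subdirectory(third_party/MatX)
    target_link_libraries(my_app PRIVATE matx::matx)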
diff --git a/docs_input/api/logic/truth/allclose.rst b/docs_input/api/logic/truth/allclose.rst
index e6fd82e4..bca253e8 100644
--- a/docs_input/api/logic/truth/allclose.rst
+++ b/docs_input/api/logic/truth/allclose.rst
@@ -7,7 +7,7 @@ Reduce the closeness of two operators to a single scalar (0D) output. The output
 from allclose is an ``int`` value since boolean reductions are not available
 in hardware
 
-.. doxygenfunction:: allclose(OutType dest, const InType1 &in1, const InType2 &in2, double rtol, double atol, HostExecutor exec)
+.. doxygenfunction:: allclose(OutType dest, const InType1 &in1, const InType2 &in2, double rtol, double atol, HostExecutor<MODE> &exec)
 .. doxygenfunction:: allclose(OutType dest, const InType1 &in1, const InType2 &in2, double rtol, double atol, cudaExecutor exec = 0)
 
 Examples

diff --git a/include/matx/core/type_utils.h b/include/matx/core/type_utils.h
index 78e7e397..9c59cfbf 100644
--- a/include/matx/core/type_utils.h
+++ b/include/matx/core/type_utils.h
@@ -256,7 +256,7 @@ inline constexpr bool is_settable_xform_v = std::conjunction_v
 template <typename T> struct is_executor : std::false_type {};
 template <> struct is_executor<cudaExecutor> : std::true_type {};
-template <> struct is_executor<HostExecutor> : std::true_type {};
+template <ThreadsMode MODE> struct is_executor<HostExecutor<MODE>> : std::true_type {};
 }
@@ -286,17 +286,27 @@ inline constexpr bool is_cuda_executor_v = detail::is_cuda_executor<remove_cvref_t<T>>::value;
 
 namespace detail {
 template <typename T> struct is_host_executor : std::false_type {};
-template<> struct is_host_executor<HostExecutor> : std::true_type {};
+template <ThreadsMode MODE> struct is_host_executor<HostExecutor<MODE>> : std::true_type {};
+
+template <typename T> struct is_select_threads_host_executor : std::false_type {};
+template<> struct is_select_threads_host_executor<SelectThreadsHostExecutor> : std::true_type {};
 }
 
 /**
- * @brief Determine if a type is a single-threaded host executor
+ * @brief Determine if a type is a host executor
  *
  * @tparam T Type to test
  */
 template <typename T>
 inline constexpr bool is_host_executor_v = detail::is_host_executor<remove_cvref_t<T>>::value;
 
+/**
+ * @brief Determine if a type is a select threads host executor
+ *
+ * @tparam T Type to test
+ */
+template <typename T>
+inline constexpr bool is_select_threads_host_executor_v = detail::is_select_threads_host_executor<remove_cvref_t<T>>::value;
 
 namespace detail {
 template

diff --git a/include/matx/executors/host.h b/include/matx/executors/host.h
index 728c6ec2..de0d50e9 100644
--- a/include/matx/executors/host.h
+++ b/include/matx/executors/host.h
@@ -36,7 +36,9 @@
 
 #include "matx/core/error.h"
 #include "matx/core/get_grid_dims.h"
-
+#ifdef MATX_EN_OMP
+#include <omp.h>
+#endif
 namespace matx
 {
@@ -48,9 +50,15 @@ struct cpu_set_t {
   cuda::std::array bits_;
 };
 
+enum class ThreadsMode {
+  SINGLE,
+  SELECT,
+  ALL,
+};
+
 struct HostExecParams {
   HostExecParams(int threads = 1) : threads_(threads) {}
-  HostExecParams(cpu_set_t cpu_set) : cpu_set_(cpu_set) {
+  HostExecParams(cpu_set_t cpu_set) : cpu_set_(cpu_set), threads_(1) {
     MATX_ASSERT_STR(false, matxNotSupported, "CPU affinity not supported yet");
   }
 
   int GetNumThreads() const { return threads_; }
 
 private:
   int threads_;
   cpu_set_t cpu_set_;
@@ -62,18 +70,42 @@ struct HostExecParams {
 };
 
 /**
- * @brief Executor for running an operator on a single host thread
+ * @brief Executor for running an operator on a single or multi-threaded host
+ *
+ * @tparam MODE Threading policy
  *
  */
+template <ThreadsMode MODE = ThreadsMode::SINGLE>
 class HostExecutor {
   public:
     using matx_cpu = bool; ///< Type trait indicating this is a CPU executor
     using matx_executor = bool; ///< Type trait indicating this is an executor
 
-    HostExecutor(const HostExecParams &params = HostExecParams{}) : params_(params) {}
+    HostExecutor() {
+      int n_threads = 1;
+      if constexpr (MODE == ThreadsMode::SINGLE) {
+        n_threads = 1;
+      }
+      else if constexpr (MODE == ThreadsMode::ALL) {
+#if MATX_EN_OMP
+        n_threads = omp_get_num_procs();
+#endif
+      }
+      params_ = HostExecParams(n_threads);
+
+#if MATX_EN_OMP
+      omp_set_num_threads(params_.GetNumThreads());
+#endif
+    }
+
+    HostExecutor(const HostExecParams &params) : params_(params) {
+#if MATX_EN_OMP
+      omp_set_num_threads(params_.GetNumThreads());
+#endif
+    }
 
     /**
-     * @brief Synchronize the host executor's threads. Currently a noop as the executor is single-threaded.
+     * @brief Synchronize the host executor's threads.
      *
      */
     void sync() {}
@@ -86,12 +118,23 @@ class HostExecutor {
     /**
      * @brief Execute an operator
      *
      * @tparam Op Operator type
      * @param op Operator to execute
      */
     template <typename Op>
     void Exec(Op &op) const noexcept {
-      if (params_.GetNumThreads() == 1) {
-        if constexpr (Op::Rank() == 0) {
-          op();
-        }
-        else {
-          index_t size = TotalSize(op);
+      if constexpr (Op::Rank() == 0) {
+        op();
+      }
+      else {
+        index_t size = TotalSize(op);
+        #ifdef MATX_EN_OMP
+        if (params_.GetNumThreads() > 1) {
+          #pragma omp parallel for num_threads(params_.GetNumThreads())
+          for (index_t i = 0; i < size; i++) {
+            auto idx = GetIdxFromAbs(op, i);
+            cuda::std::apply([&](auto... args) {
+              return op(args...);
+            }, idx);
+          }
+        } else
+        #endif
+        {
           for (index_t i = 0; i < size; i++) {
             auto idx = GetIdxFromAbs(op, i);
             cuda::std::apply([&](auto... args) {
               return op(args...);
             }, idx);
           }
         }
       }
     }
 
@@ -108,4 +151,8 @@ class HostExecutor {
   private:
     HostExecParams params_;
 };
 
+using SingleThreadedHostExecutor = HostExecutor<ThreadsMode::SINGLE>;
+using SelectThreadsHostExecutor = HostExecutor<ThreadsMode::SELECT>;
+using AllThreadsHostExecutor = HostExecutor<ThreadsMode::ALL>;
+
 }
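The Exec() change above is the core of the host-side parallelism: when OpenMP
is enabled and the executor holds more than one thread, the flat element loop
is split with a parallel-for, and each absolute index is mapped back to N-D
coordinates via GetIdxFromAbs before the operator is invoked. A minimal
standalone sketch of the same pattern (plain arrays instead of MatX operators;
the size and thread count are arbitrary; compile with -fopenmp):

    #include <omp.h>
    #include <cstdio>
    #include <vector>

    int main() {
      const long size = 1 << 20;
      std::vector<float> out(size);

      // Same shape as HostExecutor<MODE>::Exec(): one parallel flat loop,
      // each iteration computing one output element independently.
      #pragma omp parallel for num_threads(4)
      for (long i = 0; i < size; i++) {
        out[i] = 2.0f * static_cast<float>(i);  // stand-in for op(args...)
      }

      std::printf("out[123] = %f\n", out[123]);
      return 0;
    }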
diff --git a/include/matx/executors/support.h b/include/matx/executors/support.h
index d6b8d3d1..bdcac763 100644
--- a/include/matx/executors/support.h
+++ b/include/matx/executors/support.h
@@ -40,11 +40,11 @@ namespace matx {
 namespace detail {
 
 // FFT
-#if defined(MATX_EN_NVPL)
+#if defined(MATX_EN_NVPL) || defined(MATX_EN_X86_FFTW)
 #define MATX_EN_CPU_FFT 1
 #else
 #define MATX_EN_CPU_FFT 0
-#endif
+#endif
 
 template <typename T>
 constexpr bool CheckFFTSupport() {

diff --git a/include/matx/operators/legendre.h b/include/matx/operators/legendre.h
index 3b2cb3d0..c1606053 100644
--- a/include/matx/operators/legendre.h
+++ b/include/matx/operators/legendre.h
@@ -103,7 +103,7 @@ namespace matx
     __MATX_INLINE__ __MATX_DEVICE__ __MATX_HOST__ scalar_type operator()(Is... indices) const
     {
       cuda::std::array inds{indices...};
-      cuda::std::array xinds;
+      cuda::std::array xinds{};
 
       int axis1 = axis_[0];
       int axis2 = axis_[1];

diff --git a/include/matx/transforms/cub.h b/include/matx/transforms/cub.h
index bb8f5412..b8fae988 100644
--- a/include/matx/transforms/cub.h
+++ b/include/matx/transforms/cub.h
@@ -1409,10 +1409,10 @@ void sort_impl(OutputTensor &a_out, const InputOperator &a,
 #endif
 }
 
-template <typename OutputTensor, typename InputOperator>
+template <typename OutputTensor, typename InputOperator, ThreadsMode MODE>
 void sort_impl(OutputTensor &a_out, const InputOperator &a,
               const SortDirection_t dir,
-              [[maybe_unused]] HostExecutor exec)
+              [[maybe_unused]] HostExecutor<MODE> &exec)
 {
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
 
@@ -1507,9 +1507,9 @@ void cumsum_impl(OutputTensor &a_out, const InputOperator &a,
 #endif
 }
 
-template <typename OutputTensor, typename InputOperator>
+template <typename OutputTensor, typename InputOperator, ThreadsMode MODE>
 void cumsum_impl(OutputTensor &a_out, const InputOperator &a,
-                 [[maybe_unused]] HostExecutor exec)
+                 [[maybe_unused]] HostExecutor<MODE> &exec)
 {
 #ifdef __CUDACC__
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
@@ -1782,8 +1782,8 @@ void find_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator
  * @param exec
  *   Single-threaded host executor
  */
-template <typename OutputTensor, typename CountTensor, typename InputOperator, typename SelectType>
-void find_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, SelectType sel, [[maybe_unused]] HostExecutor exec)
+template <typename OutputTensor, typename CountTensor, typename InputOperator, typename SelectType, ThreadsMode MODE>
+void find_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, SelectType sel, [[maybe_unused]] HostExecutor<MODE> &exec)
 {
   static_assert(CountTensor::Rank() == 0, "Num found output tensor rank must be 0");
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
@@ -1903,8 +1903,8 @@ void find_idx_impl(OutputTensor &a_out, CountTensor &num_found, const InputOpera
  * @param exec
  *   Single host executor
 */
-template <typename OutputTensor, typename CountTensor, typename InputOperator, typename SelectType>
-void find_idx_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, SelectType sel, [[maybe_unused]] HostExecutor exec)
+template <typename OutputTensor, typename CountTensor, typename InputOperator, typename SelectType, ThreadsMode MODE>
+void find_idx_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, SelectType sel, [[maybe_unused]] HostExecutor<MODE> &exec)
 {
   static_assert(CountTensor::Rank() == 0, "Num found output tensor rank must be 0");
   MATX_NVTX_START("", matx::MATX_NVTX_LOG_API)
@@ -2018,8 +2018,8 @@ void unique_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperato
  * @param exec
  *   Single thread executor
 */
-template <typename OutputTensor, typename CountTensor, typename InputOperator>
-void unique_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, [[maybe_unused]] HostExecutor exec)
+template <typename OutputTensor, typename CountTensor, typename InputOperator, ThreadsMode MODE>
+void unique_impl(OutputTensor &a_out, CountTensor &num_found, const InputOperator &a, [[maybe_unused]] HostExecutor<MODE> &exec)
 {
 #ifdef __CUDACC__
   static_assert(CountTensor::Rank() == 0, "Num found output tensor rank must be 0");
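These signature changes let the host fallbacks for sort, cumsum, find,
find_idx, and unique accept any HostExecutor<MODE> by reference. Hypothetical
caller-side usage (the operator-style API is assumed from MatX's docs and is
not shown in this diff):

    // Illustrative only; tensor contents and sizes are arbitrary.
    auto exec = matx::SelectThreadsHostExecutor{matx::HostExecParams{8}};
    auto t  = matx::make_tensor<float>({1000});
    auto ts = matx::make_tensor<float>({1000});
    (ts = matx::sort(t, matx::SORT_DIR_ASC)).run(exec);  // host sort fallback
    (ts = matx::cumsum(t)).run(exec);                    // host cumsum fallback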
diff --git a/include/matx/transforms/fft/fft_fftw.h b/include/matx/transforms/fft/fft_fftw.h
index 27de9f36..f263c573 100644
--- a/include/matx/transforms/fft/fft_fftw.h
+++ b/include/matx/transforms/fft/fft_fftw.h
@@ -44,7 +44,12 @@
 #ifdef MATX_EN_NVPL
 #include <nvpl_fftw.h>
 #endif
-
+#ifdef MATX_EN_X86_FFTW
+#include <fftw3.h>
+#endif
+#ifdef MATX_EN_OMP
+#include <omp.h>
+#endif
 #include
 #include
 #include
@@ -130,6 +135,8 @@ struct FftFFTWparams_t {
       }
       params.batch = (RANK == 2) ? 1 :
           static_cast<int>(TotalSize(i) / ((i.Size(RANK - 1) * i.Size(RANK - 2))));
+      params.inembed[0] = static_cast<int>(i.Size(RANK-2));
+      params.onembed[0] = static_cast<int>(o.Size(RANK-2));
       params.inembed[1] = static_cast<int>(i.Size(RANK-1));
       params.onembed[1] = static_cast<int>(o.Size(RANK-1));
       params.istride = static_cast<int>(i.Stride(RANK-1));
@@ -212,6 +219,17 @@ struct FftFFTWparams_t {
     }
     else {
       bool supported = true;
 
+      // only a subset of strides are supported per fftw indexing scheme.
+      if constexpr (RANK >= 2) {
+        if (in.Stride(RANK - 2) != in.Stride(RANK - 1) * in.Size(RANK - 1)) {
+          supported = false;
+        }
+      }
+      if constexpr (RANK > 2) {
+        if (in.Stride(RANK - 3) != in.Size(RANK - 2) * in.Stride(RANK - 2)) {
+          supported = false;
+        }
+      }
 
       // If there are any unsupported layouts for fftw add them here
       if (supported) {
@@ -249,11 +267,12 @@ struct FftFFTWparams_t {
     }
   }
 
-  template <typename OutputTensor, typename InputTensor>
+  template <typename OutputTensor, typename InputTensor, ThreadsMode MODE>
  __MATX_INLINE__ void fft_exec([[maybe_unused]] OutputTensor &o,
                                [[maybe_unused]] const InputTensor &i,
                                [[maybe_unused]] const FftFFTWparams_t &params,
-                               [[maybe_unused]] detail::FFTDirection dir) {
+                               [[maybe_unused]] detail::FFTDirection dir,
+                               [[maybe_unused]] const HostExecutor<MODE> &exec) {
     [[maybe_unused]] static constexpr bool fp32 = std::is_same_v<typename inner_op_type_t<typename OutputTensor::value_type>::type, float>;
 
 #if MATX_EN_CPU_FFT
@@ -262,6 +281,10 @@ struct FftFFTWparams_t {
     auto exec_plans = [&](typename OutputTensor::value_type *out_ptr, typename InputTensor::value_type *in_ptr) {
       if constexpr (fp32) {
+        int ret = fftwf_init_threads();
+        MATX_ASSERT_STR(ret != 0, matxAssertError, "fftwf_init_threads() failed");
+
+        fftwf_plan_with_nthreads(exec.GetNumThreads());
         fftwf_plan plan{};
         if constexpr (DeduceFFTTransformType<OutputTensor, InputTensor>() == FFTType::C2C) {
           plan = fftwf_plan_many_dft( params.fft_rank,
@@ -306,12 +329,18 @@ struct FftFFTWparams_t {
                      params.odist,
                      FFTW_ESTIMATE);
         }
+        MATX_ASSERT_STR(plan != nullptr, matxAssertError, "fftwf plan creation failed");
         fftwf_execute(plan);
         fftwf_destroy_plan(plan);
-        fftwf_cleanup();
+        fftwf_cleanup_threads();
+        fftwf_cleanup();
       }
       else {
+        int ret = fftw_init_threads();
+        MATX_ASSERT_STR(ret != 0, matxAssertError, "fftw_init_threads() failed");
+
+        fftw_plan_with_nthreads(exec.GetNumThreads());
         fftw_plan plan{};
         if constexpr (DeduceFFTTransformType<OutputTensor, InputTensor>() == FFTType::Z2Z) {
           plan = fftw_plan_many_dft( params.fft_rank,
@@ -356,10 +385,12 @@ struct FftFFTWparams_t {
                     params.odist,
                     FFTW_ESTIMATE);
         }
-
+        MATX_ASSERT_STR(plan != nullptr, matxAssertError, "fftw plan creation failed");
+
         fftw_execute(plan);
         fftw_destroy_plan(plan);
-        fftw_cleanup();
+        fftw_cleanup_threads();
+        fftw_cleanup();
       }
     };
 
@@ -367,22 +398,21 @@ struct FftFFTWparams_t {
 #endif
   }
 
-  template <typename OutputTensor, typename InputTensor>
+  template <typename OutputTensor, typename InputTensor, ThreadsMode MODE>
  __MATX_INLINE__ void fft1d_dispatch(OutputTensor o, const InputTensor i,
-          uint64_t fft_size, detail::FFTDirection dir, FFTNorm norm, const HostExecutor &exec)
+          uint64_t fft_size, detail::FFTDirection dir, FFTNorm norm, const HostExecutor<MODE> &exec)
   {
     MATX_STATIC_ASSERT_STR(OutputTensor::Rank() == InputTensor::Rank(), matxInvalidDim,
       "Input and output tensor ranks must match");
 
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
 
-    MATX_ASSERT_STR(exec.GetNumThreads() == 1, matxInvalidParameter, "Only single-threaded host FFT supported");
     MATX_ASSERT_STR(TotalSize(i) < std::numeric_limits<index_t>::max(), matxInvalidSize,
       "Dimensions too large for host FFT currently");
 
     // converts operators to tensors
     auto out = getFFTW1DSupportedTensor(o);
     auto in_t = getFFTW1DSupportedTensor(i);
-
+
     if(!in_t.isSameView(i)) {
       (in_t = i).run(exec);
     }
@@ -392,7 +422,7 @@ struct FftFFTWparams_t {
     // Get parameters required by these tensors
     auto params = GetFFTParams(out, in, 1);
 
-    fft_exec(o, in, params, dir);
+    fft_exec(out, in, params, dir, exec);
 
     if(!out.isSameView(o)) {
       (o = out).run(exec);
@@ -423,9 +453,9 @@ struct FftFFTWparams_t {
   }
 
-  template <typename OutputTensor, typename InputTensor>
+  template <typename OutputTensor, typename InputTensor, ThreadsMode MODE>
  __MATX_INLINE__ void fft2d_dispatch(OutputTensor o, const InputTensor i,
-          detail::FFTDirection dir, FFTNorm norm, const HostExecutor &exec)
+          detail::FFTDirection dir, FFTNorm norm, const HostExecutor<MODE> &exec)
   {
     MATX_STATIC_ASSERT_STR(OutputTensor::Rank() == InputTensor::Rank(), matxInvalidDim,
       "Input and output tensor ranks must match");
 
     MATX_NVTX_START("", matx::MATX_NVTX_LOG_INTERNAL)
 
-    MATX_ASSERT_STR(exec.GetNumThreads() == 1, matxInvalidParameter, "Only single-threaded host FFT supported");
-
     // converts operators to tensors
     auto out = getFFTW2DSupportedTensor(o);
     auto in = getFFTW2DSupportedTensor(i);
@@ -446,7 +474,7 @@ struct FftFFTWparams_t {
     // Get parameters required by these tensors
     auto params = GetFFTParams(out, in, 2);
 
-    fft_exec(o, in, params, dir);
+    fft_exec(out, in, params, dir, exec);
 
     if(!out.isSameView(o)) {
       (o = out).run(exec);
@@ -477,9 +505,9 @@ struct FftFFTWparams_t {
   }
 
-  template <typename OutputTensor, typename InputTensor>
+  template <typename OutputTensor, typename InputTensor, ThreadsMode MODE>
  __MATX_INLINE__ void fft_impl(OutputTensor o, const InputTensor i,
-          uint64_t fft_size, FFTNorm norm, const HostExecutor &exec)
+          uint64_t fft_size, FFTNorm norm, const HostExecutor<MODE> &exec)
   {
     MATX_STATIC_ASSERT_STR(OutputTensor::Rank() == InputTensor::Rank(), matxInvalidDim,
       "Input and output tensor ranks must match");
@@ -495,9 +523,9 @@ struct FftFFTWparams_t {
   }
 
-  template <typename OutputTensor, typename InputTensor>
+  template <typename OutputTensor, typename InputTensor, ThreadsMode MODE>
  __MATX_INLINE__ void ifft_impl(OutputTensor o, const InputTensor i,
-          uint64_t fft_size, FFTNorm norm, const HostExecutor &exec)
+          uint64_t fft_size, FFTNorm norm, const HostExecutor<MODE> &exec)
   {
     MATX_STATIC_ASSERT_STR(OutputTensor::Rank() == InputTensor::Rank(), matxInvalidDim,
       "Input and output tensor ranks must match");
@@ -514,9 +542,9 @@ struct FftFFTWparams_t {
   }
 
-  template <typename OutputTensor, typename InputTensor>
+  template <typename OutputTensor, typename InputTensor, ThreadsMode MODE>
  __MATX_INLINE__ void fft2_impl(OutputTensor o, const InputTensor i, FFTNorm norm,
-          const HostExecutor &exec)
+          const HostExecutor<MODE> &exec)
   {
     MATX_STATIC_ASSERT_STR(OutputTensor::Rank() == InputTensor::Rank(), matxInvalidDim,
       "Input and output tensor ranks must match");
@@ -532,9 +560,9 @@ struct FftFFTWparams_t {
     fft2d_dispatch(o, i, FFTDirection::FORWARD, norm, exec);
   }
 
-  template <typename OutputTensor, typename InputTensor>
+  template <typename OutputTensor, typename InputTensor, ThreadsMode MODE>
  __MATX_INLINE__ void ifft2_impl(OutputTensor o, const InputTensor i, FFTNorm norm,
-          const HostExecutor &exec)
+          const HostExecutor<MODE> &exec)
   {
     MATX_STATIC_ASSERT_STR(OutputTensor::Rank() == InputTensor::Rank(), matxInvalidDim,
       "Input and output tensor ranks must match");
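The FFTW calls added above follow the standard multi-threaded FFTW lifecycle:
initialize threading once before any plan is created, set the team size with
plan_with_nthreads, and tear the threading state down with cleanup_threads. A
self-contained sketch of that lifecycle outside MatX (sizes and thread count
are arbitrary; link flags assumed from the CMake changes in this patch, e.g.
-lfftw3f -lfftw3f_omp -fopenmp):

    #include <fftw3.h>

    int main() {
      // Threading must be initialized before any plan is created.
      if (fftwf_init_threads() == 0) {
        return 1;
      }
      fftwf_plan_with_nthreads(4);  // subsequent plans use a team of 4 threads

      const int n = 1024;
      fftwf_complex *buf = fftwf_alloc_complex(n);
      for (int i = 0; i < n; i++) {
        buf[i][0] = 1.0f;  // arbitrary input data
        buf[i][1] = 0.0f;
      }

      fftwf_plan p = fftwf_plan_dft_1d(n, buf, buf, FFTW_FORWARD, FFTW_ESTIMATE);
      fftwf_execute(p);

      fftwf_destroy_plan(p);
      fftwf_free(buf);
      fftwf_cleanup_threads();  // tears down threading state, as in the patch
      return 0;
    }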
diff --git a/include/matx/transforms/reduce.h b/include/matx/transforms/reduce.h
index f7d7df7a..708d6ac0 100644
--- a/include/matx/transforms/reduce.h
+++ b/include/matx/transforms/reduce.h
@@ -1529,8 +1529,8 @@ void __MATX_INLINE__ mean_impl(OutType dest, const InType &in,
  * @param exec
  *   Single thread host executor
  */
-template <typename OutType, typename InType>
-void __MATX_INLINE__ mean_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor exec)
+template <typename OutType, typename InType, ThreadsMode MODE>
+void __MATX_INLINE__ mean_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor<MODE> &exec)
 {
   MATX_NVTX_START("mean_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API)
 
@@ -1795,8 +1795,8 @@ void __MATX_INLINE__ median_impl(OutType dest,
  * @param exec
  *   Single thread host executor
  */
-template <typename OutType, typename InType>
-void __MATX_INLINE__ median_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor exec)
+template <typename OutType, typename InType, ThreadsMode MODE>
+void __MATX_INLINE__ median_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor<MODE> &exec)
 {
   MATX_NVTX_START("median_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API)
   auto ft = [&](auto &&lin, auto &&lout, [[maybe_unused]] auto &&lbegin, [[maybe_unused]] auto &&lend) {
@@ -1887,8 +1887,8 @@ void __MATX_INLINE__ sum_impl(OutType dest, const InType &in, cudaExecutor exec
  * @param exec
  *   Single thread host executor
  */
-template <typename OutType, typename InType>
-void __MATX_INLINE__ sum_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor exec)
+template <typename OutType, typename InType, ThreadsMode MODE>
+void __MATX_INLINE__ sum_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor<MODE> &exec)
 {
   MATX_NVTX_START("sum_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API)
   auto ft = [&](auto &&lin, auto &&lout, [[maybe_unused]] auto &&lbegin, [[maybe_unused]] auto &&lend) {
@@ -1956,8 +1956,8 @@ void __MATX_INLINE__ prod_impl(OutType dest, const InType &in, cudaExecutor exec
  * @param exec
  *   Single thread host executor
  */
-template <typename OutType, typename InType>
-void __MATX_INLINE__ prod_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor exec)
+template <typename OutType, typename InType, ThreadsMode MODE>
+void __MATX_INLINE__ prod_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor<MODE> &exec)
 {
   MATX_NVTX_START("prod_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API)
   auto ft = [&](auto &&lin, auto &&lout, [[maybe_unused]] auto &&lbegin, [[maybe_unused]] auto &&lend) {
@@ -2033,8 +2033,8 @@ void __MATX_INLINE__ max_impl(OutType dest, const InType &in, cudaExecutor exec
  * @param exec
  *   Single threaded host executor
 */
-template <typename OutType, typename InType>
-void __MATX_INLINE__ max_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor exec)
+template <typename OutType, typename InType, ThreadsMode MODE>
+void __MATX_INLINE__ max_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor<MODE> &exec)
 {
   MATX_NVTX_START("max_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API)
 
@@ -2111,8 +2111,8 @@ void __MATX_INLINE__ argmax_impl(OutType dest, TensorIndexType &idest, const InT
  * @param exec
  *   Single threaded host executor
 */
-template <typename OutType, typename TensorIndexType, typename InType>
-void __MATX_INLINE__ argmax_impl(OutType dest, TensorIndexType &idest, const InType &in, [[maybe_unused]] HostExecutor exec)
+template <typename OutType, typename TensorIndexType, typename InType, ThreadsMode MODE>
+void __MATX_INLINE__ argmax_impl(OutType dest, TensorIndexType &idest, const InType &in, [[maybe_unused]] HostExecutor<MODE> &exec)
 {
   MATX_NVTX_START("argmax_impl(" + get_type_str(in) + ")",
     matx::MATX_NVTX_LOG_API)
@@ -2184,8 +2184,8 @@ void __MATX_INLINE__ min_impl(OutType dest, const InType &in, cudaExecutor exec
  * @param exec
  *   Single threaded host executor
 */
-template <typename OutType, typename InType>
-void __MATX_INLINE__ min_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor exec)
+template <typename OutType, typename InType, ThreadsMode MODE>
+void __MATX_INLINE__ min_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor<MODE> &exec)
 {
   MATX_NVTX_START("min_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API)
   auto ft = [&](auto &&lin, auto &&lout, [[maybe_unused]] auto &&lbegin, [[maybe_unused]] auto &&lend) {
@@ -2258,8 +2258,8 @@ void __MATX_INLINE__ argmin_impl(OutType dest, TensorIndexType &idest, const InT
  * @param exec
  *   Single host executor
 */
-template <typename OutType, typename TensorIndexType, typename InType>
-void __MATX_INLINE__ argmin_impl(OutType dest, TensorIndexType &idest, const InType &in, [[maybe_unused]] HostExecutor exec)
+template <typename OutType, typename TensorIndexType, typename InType, ThreadsMode MODE>
+void __MATX_INLINE__ argmin_impl(OutType dest, TensorIndexType &idest, const InType &in, [[maybe_unused]] HostExecutor<MODE> &exec)
 {
   MATX_NVTX_START("argmin_impl(" + get_type_str(in) + ")",
     matx::MATX_NVTX_LOG_API)
@@ -2331,8 +2331,8 @@ void __MATX_INLINE__ any_impl(OutType dest, const InType &in, cudaExecutor exec
  * @param exec
  *   Single threaded host executor
 */
-template <typename OutType, typename InType>
-void __MATX_INLINE__ any_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor exec)
+template <typename OutType, typename InType, ThreadsMode MODE>
+void __MATX_INLINE__ any_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor<MODE> &exec)
 {
   MATX_NVTX_START("any_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API)
 
@@ -2404,8 +2404,8 @@ void __MATX_INLINE__ all_impl(OutType dest, const InType &in, cudaExecutor exec
  * @param exec
  *   Single threaded host executor
 */
-template <typename OutType, typename InType>
-void __MATX_INLINE__ all_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor exec)
+template <typename OutType, typename InType, ThreadsMode MODE>
+void __MATX_INLINE__ all_impl(OutType dest, const InType &in, [[maybe_unused]] HostExecutor<MODE> &exec)
 {
   MATX_NVTX_START("all_impl(" + get_type_str(in) + ")", matx::MATX_NVTX_LOG_API)
 
@@ -2490,8 +2490,8 @@ void __MATX_INLINE__ allclose(OutType dest, const InType1 &in1, const InType2 &i
  * @param exec
  *   Single threaded host executor
 */
-template <typename OutType, typename InType1, typename InType2>
-void __MATX_INLINE__ allclose(OutType dest, const InType1 &in1, const InType2 &in2, double rtol, double atol, [[maybe_unused]] HostExecutor exec)
+template <typename OutType, typename InType1, typename InType2, ThreadsMode MODE>
+void __MATX_INLINE__ allclose(OutType dest, const InType1 &in1, const InType2 &in2, double rtol, double atol, [[maybe_unused]] HostExecutor<MODE> &exec)
 {
   MATX_NVTX_START("allclose(" + get_type_str(in1) + ", " + get_type_str(in2) + ")", matx::MATX_NVTX_LOG_API)
   static_assert(OutType::Rank() == 0, "allclose output must be rank 0");

diff --git a/include/matx/transforms/transpose.h b/include/matx/transforms/transpose.h
index dc1fa300..a75ac89d 100644
--- a/include/matx/transforms/transpose.h
+++ b/include/matx/transforms/transpose.h
@@ -102,9 +102,9 @@ namespace matx
 #endif
   };
 
-  template <typename OutputTensor, typename InputTensor>
+  template <typename OutputTensor, typename InputTensor, ThreadsMode MODE>
  __MATX_INLINE__ void transpose_matrix_impl([[maybe_unused]] OutputTensor &out,
-        const InputTensor &in, HostExecutor exec)
+        const InputTensor &in, HostExecutor<MODE> &exec)
   {
     static_assert(InputTensor::Rank() >= 2, "transpose_matrix operator must be on rank 2 or greater");
 
diff --git a/test/00_operators/OperatorTests.cu b/test/00_operators/OperatorTests.cu
index 1d5d6b5f..b232f1ef 100644
--- a/test/00_operators/OperatorTests.cu
+++ b/test/00_operators/OperatorTests.cu
@@ -3780,7 +3780,7 @@ TYPED_TEST(OperatorTestsFloatNonComplexNonHalfAllExecs, R2COp)
   MATX_EXIT_HANDLER();
 }
 
-TYPED_TEST(OperatorTestsFloatNonHalf, FftShiftWithTransform)
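The reduction entry points all change the same way: the executor becomes a
templated reference instead of a by-value HostExecutor. Hypothetical
caller-side usage (operator-style reduction API assumed, not shown in this
diff; shapes are arbitrary):

    auto exec = matx::AllThreadsHostExecutor{};
    auto in   = matx::make_tensor<float>({16, 1024});
    auto out  = matx::make_tensor<float>({});  // 0-D (scalar) output
    (out = matx::sum(in)).run(exec);
    (out = matx::mean(in)).run(exec);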
+TYPED_TEST(OperatorTestsFloatNonHalf, FFTShiftWithTransform)
 {
   MATX_ENTER_HANDLER();
   using TestType = cuda::std::tuple_element_t<0, TypeParam>;

diff --git a/test/00_transform/FFT.cu b/test/00_transform/FFT.cu
index 2e165c4b..d77a3786 100644
--- a/test/00_transform/FFT.cu
+++ b/test/00_transform/FFT.cu
@@ -52,6 +52,12 @@ protected:
       GTEST_SKIP();
     }
 
+    // Use an arbitrary number of threads for the select threads host exec.
+    if constexpr (is_select_threads_host_executor_v<ExecType>) {
+      HostExecParams params{4};
+      exec = SelectThreadsHostExecutor{params};
+    }
+
     pb = std::make_unique<detail::MatXPybind>();
 
     // Half precision needs a bit more tolerance when compared to fp32

diff --git a/test/include/test_types.h b/test/include/test_types.h
index 9fb1d6eb..ae114943 100644
--- a/test/include/test_types.h
+++ b/test/include/test_types.h
@@ -72,10 +72,9 @@ template <> auto inline GenerateData<cuda::std::complex<double>>()
   return cuda::std::complex<double>(1.5, -2.5);
 }
 
-using ExecutorTypesAll = cuda::std::tuple<matx::cudaExecutor, matx::HostExecutor>;
+using ExecutorTypesAll = cuda::std::tuple<matx::cudaExecutor, matx::SelectThreadsHostExecutor>;
 using ExecutorTypesCUDAOnly = cuda::std::tuple<matx::cudaExecutor>;
 
-
 /* Taken from https://stackoverflow.com/questions/70404549/cartesian-product-of-stdtuple */
 template
 class TypedCartesianProduct {
@@ -151,7 +150,6 @@ using MatXNumericNonComplexTypesCUDAExec = TupleToTypes<typename TypedCartesianProduct<MatXNumericNonComplexTypes, ExecutorTypesCUDAOnly>::type>::type;
 using MatXDoubleOnlyTypeCUDAExec = TupleToTypes<typename TypedCartesianProduct<MatXDoubleOnlyType, ExecutorTypesCUDAOnly>::type>::type;
 
-
 // All executor types
 using MatXNumericNonComplexTypesAllExecs = TupleToTypes<typename TypedCartesianProduct<MatXNumericNonComplexTypes, ExecutorTypesAll>::type>::type;
 using MatXFloatNonHalfTypesAllExecs = TupleToTypes<typename TypedCartesianProduct<MatXFloatNonHalfTypes, ExecutorTypesAll>::type>::type;
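The FFT.cu hook above relies on the traits added in type_utils.h to detect the
executor family at compile time. A small illustrative sketch of what those
traits enable (not part of the patch):

    // All of these should hold given the trait specializations in this diff:
    static_assert(matx::is_host_executor_v<matx::AllThreadsHostExecutor>);
    static_assert(matx::is_select_threads_host_executor_v<matx::SelectThreadsHostExecutor>);
    static_assert(!matx::is_select_threads_host_executor_v<matx::cudaExecutor>);
    static_assert(!matx::is_host_executor_v<matx::cudaExecutor>);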