diff --git a/src/batched/KokkosBatched_Gemm_Decl.hpp b/src/batched/KokkosBatched_Gemm_Decl.hpp index a302f882dc..1c16d4ce4e 100644 --- a/src/batched/KokkosBatched_Gemm_Decl.hpp +++ b/src/batched/KokkosBatched_Gemm_Decl.hpp @@ -46,6 +46,9 @@ #include "KokkosBatched_Vector.hpp" #include "KokkosBatched_Gemm_Handle.hpp" +// Includes for non-functor-level routines +#include + namespace KokkosBatched { /********************* BEGIN functor-level routines *********************/ /// @@ -273,9 +276,53 @@ int BatchedGemm(const BatchedGemmHandleType *handle, const ScalarType alpha, "on_intel:" << on_intel << std::endl << "on_a64fx:" << on_a64fx << std::endl; #endif - BatchedSerialGemm(alpha, A, B, beta, C); + switch (handle->get_kernel_algo_type()) { + case BaseKokkosBatchedAlgos::KK_SERIAL: + ret = BatchedSerialGemm(alpha, A, B, beta, C) + .invoke(); + break; + + ////////////// HEURISTIC ALGOS ////////////// + case BaseHeuristicAlgos::SQUARE: + + case BaseHeuristicAlgos::TALL: + + case BaseHeuristicAlgos::WIDE: + + ////////////// TPL ALGOS ////////////// + case BaseTplAlgos::ARMPL: + + case BaseTplAlgos::MKL: + + case GemmTplAlgos::CUBLAS: + + case GemmTplAlgos::MAGMA: + + ////////////// KokkosBatched ALGOS ////////////// + + case GemmKokkosBatchedAlgos::KK_TEAM: + + case GemmKokkosBatchedAlgos::KK_TEAMVECTOR: + + case GemmKokkosBatchedAlgos::KK_SERIALSIMD: + + case GemmKokkosBatchedAlgos::KK_TEAMSIMD: + + case GemmKokkosBatchedAlgos::KK_SERIAL_OPT2: + + case GemmKokkosBatchedAlgos::KK_TEAMVECTOR_SHMEM: + + case GemmKokkosBatchedAlgos::KK_TEAMVECTOR_DBLBUF: + + default: + std::ostringstream os; + os << "KokkosBatched::BatchedGemm does not support kernelAlgoType = " + << std::to_string(handle->get_kernel_algo_type()) << "." << std::endl; + Kokkos::Impl::throw_runtime_exception(os.str()); + break; + } return ret; } /********************* END non-functor-level routines *********************/ diff --git a/src/batched/KokkosBatched_Gemm_Handle.hpp b/src/batched/KokkosBatched_Gemm_Handle.hpp index efb2d6bec9..9df9944e97 100644 --- a/src/batched/KokkosBatched_Gemm_Handle.hpp +++ b/src/batched/KokkosBatched_Gemm_Handle.hpp @@ -119,13 +119,14 @@ enum GEMM_KOKKOS_BATCHED_ALGOS : int { // clang-format on class BatchedGemmHandle : public BatchedKernelHandle { public: - BatchedGemmHandle() = default; - + BatchedGemmHandle(int kernelAlgoType = BaseHeuristicAlgos::SQUARE, + int teamSize = 0, int vecLength = 0) + : BatchedKernelHandle(kernelAlgoType, teamSize, vecLength){}; decltype(auto) get_tpl_params() { #if _kernelAlgoType == CUBLAS && defined(KOKKOSKERNELS_ENABLE_TPL_CUBLAS) - return _tplParams.tplHandle.cublas_handle; + return _tplParamsSingleton.cublas_handle; #elif _kernelAlgoType == MAGMA && defined(KOKKOSKERNELS_ENABLE_TPL_MAGMA) - return _tplParams.tplQueue.magma_queue; + return _tplParamsSingleton.magma_queue; #else return this->BatchedKernelHandle::get_tpl_params(); #endif diff --git a/src/batched/KokkosBatched_Gemm_Serial_Impl.hpp b/src/batched/KokkosBatched_Gemm_Serial_Impl.hpp index 753d5522b8..49bdde18ab 100644 --- a/src/batched/KokkosBatched_Gemm_Serial_Impl.hpp +++ b/src/batched/KokkosBatched_Gemm_Serial_Impl.hpp @@ -443,6 +443,7 @@ class BatchedSerialGemm { *this); } + public: int invoke() { if (std::is_same::value) { // Set members for ResultsPerThread::Rank0 operator; these members allow @@ -476,12 +477,9 @@ class BatchedSerialGemm { return 0; } - public: BatchedSerialGemm(ScalarType _alpha, AViewType _A, BViewType _B, ScalarType _beta, CViewType _C) - : A(_A), B(_B), C(_C), alpha(_alpha), beta(_beta) { - invoke(); - } + : A(_A), B(_B), C(_C), alpha(_alpha), beta(_beta) {} KOKKOS_INLINE_FUNCTION void operator()(const ResultsPerThread::Rank0 &, const int &i) const { diff --git a/src/batched/KokkosBatched_Kernel_Handle.hpp b/src/batched/KokkosBatched_Kernel_Handle.hpp index e815730ad4..27ca4b9efa 100644 --- a/src/batched/KokkosBatched_Kernel_Handle.hpp +++ b/src/batched/KokkosBatched_Kernel_Handle.hpp @@ -47,6 +47,24 @@ #ifndef KOKKOSKERNELS_KOKKOSBATCHED_KERNEL_HEADER_HPP #define KOKKOSKERNELS_KOKKOSBATCHED_KERNEL_HEADER_HPP +#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) +#include +#endif // KOKKOSKERNELS_ENABLE_TPL_MKL + +#if defined(KOKKOSKERNELS_ENABLE_TPL_ARMPL) +// TODO: Add armpl handle type to expose nintern & nbatch? +#endif // KOKKOSKERNELS_ENABLE_TPL_ARMPL + +#if defined(KOKKOSKERNELS_ENABLE_TPL_CUBLAS) +#include "cuda_runtime.h" +#include "cublas_v2.h" +#endif // KOKKOSKERNELS_ENABLE_TPL_CUBLAS + +#if defined(KOKKOSKERNELS_ENABLE_TPL_MAGMA) +#include +#include +#endif // KOKKOSKERNELS_ENABLE_TPL_MAGMA + namespace KokkosBatched { /// \brief Heuristic algorithm types. See BatchedKernelHandle for details. @@ -66,36 +84,25 @@ enum BASE_KOKKOS_BATCHED_ALGOS : int { KK_SERIAL = BaseTplAlgos::N, N }; #define N_BASE_ALGOS BaseKokkosBatchedAlgos::N -/// \brief TplHandle abstracts underlying handle type. -union TplHandle { +/// \brief TplParams abstracts underlying handle or execution queue type. +struct TplParams { + union { +#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) + queue mkl_queue; +#endif // KOKKOSKERNELS_ENABLE_TPL_MKL + #if defined(KOKKOSKERNELS_ENABLE_TPL_ARMPL) -// TODO: Add armpl handle type to expose nintern & nbatch? + // TODO: Add armpl handle type to expose nintern & nbatch? #endif // KOKKOSKERNELS_ENABLE_TPL_ARMPL #if defined(KOKKOSKERNELS_ENABLE_TPL_CUBLAS) -#include - cublasHandle_t &cublas_handle; + cublasHandle_t cublas_handle; #endif // KOKKOSKERNELS_ENABLE_TPL_CUBLAS - // char no_handles; -}; - -/// \brief TplExecQueue abstracts underlying execution queue type. -union TplExecQueue { -#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL) -#include - queue &mkl_queue; -#endif // KOKKOSKERNELS_ENABLE_TPL_MKL #if defined(KOKKOSKERNELS_ENABLE_TPL_MAGMA) -#include - magma_queue_t &magma_queue; + magma_queue_t magma_queue; #endif // KOKKOSKERNELS_ENABLE_TPL_MAGMA - // char no_queues; -}; - -union TplParams { - TplHandle tplHandle; - TplExecQueue tplQueue; + }; }; // clang-format off @@ -142,22 +149,26 @@ class BatchedKernelHandle { decltype(auto) get_tpl_params() { #if _kernelAlgoType == ARMPL && defined(KOKKOSKERNELS_ENABLE_TPL_ARMPL) - Kokkos::abort("BaseTplAlgos::ARMPL does not support any parameters"); - return nullptr; + return "BaseTplAlgos::ARMPL does not support any tpl parameters"; #elif _kernelAlgoType == MKL && defined(KOKKOSKERNELS_ENABLE_TPL_MKL) - return _tplParams.tplQueue.mkl_queue; + return _tplParamsSingleton.mkl_queue; +#else + return "Unsupported kernelAlgoType = " + std::to_string(_kernelAlgoType) + + "."; #endif } + int get_kernel_algo_type() const { return _kernelAlgoType; } + /// \var _kernelAlgoType Specifies which algorithm to use for invocation /// (default, SQUARE). \var _enabledDebug toggle debug messages. \var /// _tplParams a handle or queue specific to the TPL API. /// managed internally unless provided by user via /// constructor overload protected: - int _kernelAlgoType = BaseHeuristicAlgos::SQUARE; - bool _enableDebug = false; - TplParams _tplParams; + int _kernelAlgoType = BaseHeuristicAlgos::SQUARE; + constexpr static bool _enableDebug = false; + static TplParams &_tplParamsSingleton; }; } // namespace KokkosBatched diff --git a/unit_test/batched/Test_Batched_BatchedGemm.hpp b/unit_test/batched/Test_Batched_BatchedGemm.hpp index 40695cc0e4..9eef498e7a 100644 --- a/unit_test/batched/Test_Batched_BatchedGemm.hpp +++ b/unit_test/batched/Test_Batched_BatchedGemm.hpp @@ -11,9 +11,11 @@ using namespace KokkosBatched; namespace Test { template -void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2, - const int matBdim1, const int matBdim2, - const int matCdim1, const int matCdim2) { +void impl_test_batched_gemm_handle(BatchedGemmHandle* batchedGemmHandle, + const int N, const int matAdim1, + const int matAdim2, const int matBdim1, + const int matBdim2, const int matCdim1, + const int matCdim2) { using execution_space = typename DeviceType::execution_space; using transA = typename ParamTagType::transA; using transB = typename ParamTagType::transB; @@ -21,10 +23,9 @@ void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2, using view_layout = typename ViewType::array_layout; using ats = Kokkos::Details::ArithTraits; - int ret = 0; - ScalarType alpha = ScalarType(1.5); - ScalarType beta = ScalarType(3.0); - BatchedGemmHandle batchedGemmHandle; + int ret = 0; + ScalarType alpha = ScalarType(1.5); + ScalarType beta = ScalarType(3.0); ViewType a_expected, a_actual, b_expected, b_actual, c_expected, c_actual; if (std::is_same::value) { @@ -58,7 +59,7 @@ void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2, // Check for expected runtime errors due to non-optimal BatchedGemm invocation try { ret = BatchedGemm( - &batchedGemmHandle, alpha, a_actual, b_actual, beta, + batchedGemmHandle, alpha, a_actual, b_actual, beta, c_actual); // Compute c_actual } catch (const std::runtime_error& error) { if (!((std::is_same::value && @@ -119,6 +120,47 @@ void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2, } EXPECT_NEAR_KK(diff / sum, 0, eps); } + +template +void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2, + const int matBdim1, const int matBdim2, + const int matCdim1, const int matCdim2) { + { + BatchedGemmHandle batchedGemmHandle; + + ASSERT_EQ(batchedGemmHandle.get_kernel_algo_type(), + BaseHeuristicAlgos::SQUARE); + ASSERT_EQ(batchedGemmHandle.teamSz, 0); + ASSERT_EQ(batchedGemmHandle.vecLen, 0); + } + + for (int algo_type = BaseHeuristicAlgos::SQUARE; + algo_type < GemmKokkosBatchedAlgos::N; ++algo_type) { + { + BatchedGemmHandle batchedGemmHandle(algo_type); + + ASSERT_EQ(batchedGemmHandle.get_kernel_algo_type(), algo_type); + + if (algo_type == BaseKokkosBatchedAlgos::KK_SERIAL) { + impl_test_batched_gemm_handle( + &batchedGemmHandle, N, matAdim1, matAdim2, matBdim1, matBdim2, + matCdim1, matCdim2); + } else { + try { + impl_test_batched_gemm_handle( + &batchedGemmHandle, N, matAdim1, matAdim2, matBdim1, matBdim2, + matCdim1, matCdim2); + FAIL(); + } catch (const std::runtime_error& error) { + ; + } + } + } + } +} } // namespace Test template