diff --git a/perf_test/sparse/KokkosSparse_spmv.cpp b/perf_test/sparse/KokkosSparse_spmv.cpp index 1b5e98d74a..6b67905adc 100644 --- a/perf_test/sparse/KokkosSparse_spmv.cpp +++ b/perf_test/sparse/KokkosSparse_spmv.cpp @@ -75,10 +75,6 @@ #include #endif -#ifdef KOKKOSKERNELS_ENABLE_TPL_ARMPL -#include -#endif - int test_crs_matrix_singlevec(Ordinal numRows, Ordinal numCols, int test, const char* filename, Ordinal rows_per_thread, int team_size, int vector_length, int schedule, diff --git a/perf_test/sparse/KokkosSparse_spmv_test.cpp b/perf_test/sparse/KokkosSparse_spmv_test.cpp index ffdd5e352f..098c7e923b 100644 --- a/perf_test/sparse/KokkosSparse_spmv_test.cpp +++ b/perf_test/sparse/KokkosSparse_spmv_test.cpp @@ -66,6 +66,10 @@ #include #endif +#ifdef KOKKOSKERNELS_ENABLE_TPL_ARMPL +#include +#endif + // return std::make_tuple(newnumRows, newnumCols, A, x1, y1, // rows_per_thread, team_size, vector_length, // test, schedule, ave_time, max_time, min_time); diff --git a/perf_test/sparse/KokkosSparse_spmv_test.hpp b/perf_test/sparse/KokkosSparse_spmv_test.hpp index 0603acf1d6..b6ff552faf 100644 --- a/perf_test/sparse/KokkosSparse_spmv_test.hpp +++ b/perf_test/sparse/KokkosSparse_spmv_test.hpp @@ -33,6 +33,9 @@ #include #endif +template +void armpl_matvec(AType /*A*/, XType x, YType y, spmv_additional_data* data); + enum { KOKKOS, MKL, diff --git a/perf_test/sparse/spmv/ArmPL_SPMV.hpp b/perf_test/sparse/spmv/ArmPL_SPMV.hpp index 0776f0d938..d36fd959f4 100644 --- a/perf_test/sparse/spmv/ArmPL_SPMV.hpp +++ b/perf_test/sparse/spmv/ArmPL_SPMV.hpp @@ -48,21 +48,22 @@ #ifdef KOKKOSKERNELS_ENABLE_TPL_ARMPL #include -template -void armpl_matvec_wrapper(armpl_spmat_t A, Scalar* x, Scalar* y) { - throw std::runtime_error( - "Can't use ArmPL mat-vec for scalar types other than double and float."); -} +// template +// void armpl_matvec_wrapper(armpl_spmat_t A, Scalar* x, Scalar* y) { +// throw std::runtime_error( +// "Can't use ArmPL mat-vec for scalar types other than double and +// float."); +//} -template <> -void armpl_matvec_wrapper(armpl_spmat_t A, float* x, float* y) { +// template <> +void armpl_matvec_wrapper(armpl_spmat_t A, float* x, float* y) { const float alpha = 1.0; const float beta = 0.0; armpl_spmv_exec_s(ARMPL_SPARSE_OPERATION_NOTRANS, alpha, A, x, beta, y); } -template <> -void armpl_matvec_wrapper(armpl_spmat_t A, double* x, double* y) { +// template <> +void armpl_matvec_wrapper(armpl_spmat_t A, double* x, double* y) { const double alpha = 1.0; const double beta = 0.0; armpl_spmv_exec_d(ARMPL_SPARSE_OPERATION_NOTRANS, alpha, A, x, beta, y); @@ -72,7 +73,7 @@ template void armpl_matvec(AType /*A*/, XType x, YType y, spmv_additional_data* data) { using Scalar = typename AType::non_const_value_type; // Run armpl spmv corresponding to scalar type - armpl_matvec_wrapper(data->A, x.data(), y.data()); + armpl_matvec_wrapper(data->A, x.data(), y.data()); } #endif // KOKKOSKERNELS_ENABLE_TPL_ARMPL diff --git a/src/batched/dense/KokkosBatched_Gemm_Decl.hpp b/src/batched/dense/KokkosBatched_Gemm_Decl.hpp index 9b1eb18e1f..77e56abf89 100644 --- a/src/batched/dense/KokkosBatched_Gemm_Decl.hpp +++ b/src/batched/dense/KokkosBatched_Gemm_Decl.hpp @@ -211,6 +211,7 @@ class BatchedSerialGemm; /// ResultsPerThread::Rank0 Each thread computes a scalar of C /// ResultsPerThread::Rank1 Each thread computes a 1-rank chunk of C /// ResultsPerThread::Rank2 Each thread computes a 2-rank chunk of C +/// \tparam HandleType Specifies the handle type of the kernel handle /// \tparam ScalarType Specifies the scalar type of alpha and beta /// \tparam AViewType Input matrix, as either a 3-rank Kokkos::View or a /// 4-rank Kokkos::View for SIMD operations. @@ -256,6 +257,52 @@ template class BatchedDblBufGemm; + +// clang-format off +/// \brief Blocking solve of general matrix multiply on a batch of uniform matrices. +/// +/// +/// C = alpha * op(A) * op(B) + beta * C +/// +/// \tparam ArgTransA Specifies what op does to A: +/// Trans::NoTranspose for non-transpose +/// Trans::Transpose for transpose +/// Trans::ConjTranspose for conjugate transpose (unsupported) +/// \tparam ArgTransB Specifies what op does to B: +/// Trans::NoTranspose for non-transpose +/// Trans::Transpose for transpose +/// Trans::ConjTranspose for conjugate transpose (unsupported) +/// \tparam HandleType Specifies the handle type of the kernel handle +/// \tparam ScalarType Specifies the scalar type of alpha and beta +/// \tparam AViewType Input matrix, as a 3-rank Kokkos::View +/// \tparam BViewType Input matrix, as a 3-rank Kokkos::View +/// \tparam CViewType Input(RHS)/Output(LHS) matrix, as a 3-rank +/// Kokkos::View +/// +/// See struct BatchedGemmHandle for details +/// \param ninter [in] The number of matrices to interleave +/// \param alpha [in] Input coefficient used for multiplication with A +/// \param A [in] Input matrix, as a 3-rank Kokkos::View +/// If ArgBatchSzDim == "BatchSzDim::Right", matrix A is MxKxB +/// If ArgBatchSzDim == "BatchSzDim::Left", matrix A is BxMxK +/// \param B [in] Input matrix, as a 3-rank Kokkos::View +/// If ArgBatchSzDim == "BatchSzDim::Right", matrix B is KxNxB +/// If ArgBatchSzDim == "BatchSzDim::Left", matrix B is BxKxN +/// \param beta [in] Input coefficient used for multiplication with C +/// \param C [in/out] Input/Output matrix, as a 3-rank Kokkos::View +/// If ArgBatchSzDim == "BatchSzDim::Right", matrix C is MxNxB +/// If ArgBatchSzDim == "BatchSzDim::Left", matrix C is BxMxN +/// \return 0 upon success, non-zero otherwise +/// +/// Usage Example: +/// BatchedArmplGemm +/// (ninter, alpha, A, B, beta, C).invoke(); +// clang-format on +template +class BatchedArmplGemm; /********************* END forward declarations *********************/ } // namespace Impl @@ -328,6 +375,14 @@ int BatchedGemm(BatchedGemmHandleType *const handle, const ScalarType alpha, "BViewType must be a Kokkos::View."); static_assert(Kokkos::is_view::value, "CViewType must be a Kokkos::View."); + static_assert( + std::is_same::value || + std::is_same::value, + "ArgTransA must be either Trans::Transpose or Trans::NoTranspose."); + static_assert( + std::is_same::value || + std::is_same::value, + "ArgTransB must be either Trans::Transpose or Trans::NoTranspose."); if (is_vector::value) { // Check ranks of view with underlying SIMD value types // For SIMD views, we can have either 3-rank or 4-ranks inputs. @@ -419,15 +474,6 @@ int BatchedGemm(BatchedGemmHandleType *const handle, const ScalarType alpha, } switch (handle->get_kernel_algo_type()) { - case BaseKokkosBatchedAlgos::KK_SERIAL: - ret = - Impl::BatchedSerialGemm( - alpha, A, B, beta, C) - .invoke(); - break; - ////////////// HEURISTIC ALGOS ////////////// case BaseHeuristicAlgos::SQUARE: if (c_m != c_n) { @@ -524,54 +570,64 @@ int BatchedGemm(BatchedGemmHandleType *const handle, const ScalarType alpha, } break; - case GemmKokkosBatchedAlgos::KK_DBLBUF: - // Note: The tile sizes of 1x1x1 here will not perform well but must be - // selected in order to function on all devices since the serial execution - // space has a max team size of 1. KokkosKernels API users will need to - // follow an approach similar to KK_SQUARE above for best performance. + // case BaseHeuristicAlgos::TALL: + // + // case BaseHeuristicAlgos::WIDE: + ////////////// TPL ALGOS ////////////// - // TODO: Add auto-selection of tile size based on inputs and device type - ret = Impl::BatchedDblBufGemm( - handle, alpha, A, B, beta, C) + case BaseTplAlgos::ARMPL: + ret = Impl::BatchedArmplGemm(handle, alpha, A, B, + beta, C) .invoke(); break; + // case BaseTplAlgos::MKL: + // + // case GemmTplAlgos::CUBLAS: + // + // case GemmTplAlgos::MAGMA: - case GemmKokkosBatchedAlgos::KK_SERIAL_RANK0: + ////////////// KokkosBatched ALGOS ////////////// + case BaseKokkosBatchedAlgos::KK_SERIAL: ret = Impl::BatchedSerialGemm( alpha, A, B, beta, C) .invoke(); break; - case BaseHeuristicAlgos::TALL: - - case BaseHeuristicAlgos::WIDE: - - ////////////// TPL ALGOS ////////////// - case BaseTplAlgos::ARMPL: - - case BaseTplAlgos::MKL: - - case GemmTplAlgos::CUBLAS: - - case GemmTplAlgos::MAGMA: - - ////////////// KokkosBatched ALGOS ////////////// - - case GemmKokkosBatchedAlgos::KK_TEAM: + // case GemmKokkosBatchedAlgos::KK_SERIALSIMD: - case GemmKokkosBatchedAlgos::KK_TEAMVECTOR: + case GemmKokkosBatchedAlgos::KK_SERIAL_RANK0: + ret = + Impl::BatchedSerialGemm( + alpha, A, B, beta, C) + .invoke(); + break; - case GemmKokkosBatchedAlgos::KK_SERIALSIMD: + // case GemmKokkosBatchedAlgos::KK_SERIAL_SHMEM: + // case GemmKokkosBatchedAlgos::KK_TEAM: + // case GemmKokkosBatchedAlgos::KK_TEAMVECTOR: + // case GemmKokkosBatchedAlgos::KK_TEAMSIMD: - case GemmKokkosBatchedAlgos::KK_TEAMSIMD: + case GemmKokkosBatchedAlgos::KK_DBLBUF: + // Note: The tile sizes of 1x1x1 here will not perform well but must be + // selected in order to function on all devices since the serial execution + // space has a max team size of 1. KokkosKernels API users will need to + // follow an approach similar to KK_SQUARE above for best performance. - case GemmKokkosBatchedAlgos::KK_SERIAL_SHMEM: + // TODO: Add auto-selection of tile size based on inputs and device type + ret = Impl::BatchedDblBufGemm( + handle, alpha, A, B, beta, C) + .invoke(); + break; default: std::ostringstream os; @@ -589,5 +645,6 @@ int BatchedGemm(BatchedGemmHandleType *const handle, const ScalarType alpha, #include "KokkosBatched_Gemm_Team_Impl.hpp" #include "KokkosBatched_Gemm_TeamVector_Impl.hpp" #include "KokkosBatched_Gemm_DblBuf_Impl.hpp" +#include "KokkosBatched_Gemm_Armpl_Impl.hpp" #endif diff --git a/src/batched/dense/KokkosBatched_Gemm_Handle.hpp b/src/batched/dense/KokkosBatched_Gemm_Handle.hpp index 47a2071851..27b8278e58 100644 --- a/src/batched/dense/KokkosBatched_Gemm_Handle.hpp +++ b/src/batched/dense/KokkosBatched_Gemm_Handle.hpp @@ -70,6 +70,15 @@ enum GEMM_KOKKOS_BATCHED_ALGOS : int { }; } +#define GEMM_ALGO_STRS \ + "GemmTplAlgos::CUBLAS", "GemmTplAlgos::MAGMA", \ + "GemmKokkosBatchedAlgos::KK_TEAM", \ + "GemmKokkosBatchedAlgos::KK_TEAMVECTOR", \ + "GemmKokkosBatchedAlgos::KK_SERIALSIMD", \ + "GemmKokkosBatchedAlgos::KK_TEAMSIMD", \ + "GemmKokkosBatchedAlgos::KK_SERIAL_RANK0", \ + "GemmKokkosBatchedAlgos::KK_SERIAL_SHMEM", \ + "GemmKokkosBatchedAlgos::KK_DBLBUF" // clang-format off /// \brief Handle for selecting runtime behavior of the BatchedGemm interface. /// @@ -163,7 +172,9 @@ class BatchedGemmHandle : public BatchedKernelHandle { #endif // MAGMA decltype(auto) get_tpl_params() { -#if _kernelAlgoType == CUBLAS && defined(KOKKOSKERNELS_ENABLE_TPL_CUBLAS) +#if _kernelAlgoType == ARMPL && defined(KOKKOSKERNELS_ENABLE_TPL_ARMPL) + return &_tplParamsSingleton.ninter; +#elif _kernelAlgoType == CUBLAS && defined(KOKKOSKERNELS_ENABLE_TPL_CUBLAS) return _tplParamsSingleton.cublas_handle; #elif _kernelAlgoType == MAGMA && defined(KOKKOSKERNELS_ENABLE_TPL_MAGMA) return _tplParamsSingleton.magma_queue; @@ -171,6 +182,14 @@ class BatchedGemmHandle : public BatchedKernelHandle { return this->BatchedKernelHandle::get_tpl_params(); #endif } + + std::string get_kernel_algo_type_str() const { + return gemm_algo_type_strs[_kernelAlgoType]; + } + + private: + const char *gemm_algo_type_strs[GemmKokkosBatchedAlgos::N] = {BASE_ALGO_STRS, + GEMM_ALGO_STRS}; }; } // namespace KokkosBatched diff --git a/src/batched/dense/KokkosBatched_Kernel_Handle.hpp b/src/batched/dense/KokkosBatched_Kernel_Handle.hpp index 22833b78a5..b9b75e2cc5 100644 --- a/src/batched/dense/KokkosBatched_Kernel_Handle.hpp +++ b/src/batched/dense/KokkosBatched_Kernel_Handle.hpp @@ -52,7 +52,7 @@ #endif // KOKKOSKERNELS_ENABLE_TPL_MKL #if defined(KOKKOSKERNELS_ENABLE_TPL_ARMPL) -// TODO: Add armpl handle type to expose nintern & nbatch? +#include "armpl.h" #endif // KOKKOSKERNELS_ENABLE_TPL_ARMPL #if defined(KOKKOSKERNELS_ENABLE_TPL_CUBLAS) @@ -83,6 +83,10 @@ enum BASE_KOKKOS_BATCHED_ALGOS : int { KK_SERIAL = BaseTplAlgos::N, N }; } #define N_BASE_ALGOS BaseKokkosBatchedAlgos::N +#define BASE_ALGO_STRS \ + "BaseHeuristicAlgos::SQUARE", "BaseHeuristicAlgos::TALL", \ + "BaseHeuristicAlgos::WIDE", "BaseTplAlgos::ARMPL", "BaseTplAlgosMKL", \ + "BaseKokkosBatchedAlgos::KK_SERIAL" /// \brief TplParams abstracts underlying handle or execution queue type. struct TplParams { @@ -94,7 +98,7 @@ struct TplParams { #endif // KOKKOSKERNELS_ENABLE_TPL_MKL #if defined(KOKKOSKERNELS_ENABLE_TPL_ARMPL) - // TODO: Add armpl handle type in KokkosKernels to expose nintern & nbatch? + armpl_int_t ninter = 1; #endif // KOKKOSKERNELS_ENABLE_TPL_ARMPL #if defined(KOKKOSKERNELS_ENABLE_TPL_CUBLAS) @@ -172,19 +176,22 @@ class BatchedKernelHandle { int teamSize = 0, int vecLength = 0) : teamSz(teamSize), vecLen(vecLength), _kernelAlgoType(kernelAlgoType){}; - decltype(auto) get_tpl_params() { + int get_kernel_algo_type() const { return _kernelAlgoType; } + + std::string get_kernel_algo_type_str() const { + return algo_type_strs[_kernelAlgoType]; + } + + decltype(auto) get_tpl_params() const { #if _kernelAlgoType == ARMPL && defined(KOKKOSKERNELS_ENABLE_TPL_ARMPL) return "BaseTplAlgos::ARMPL does not support any tpl parameters"; #elif _kernelAlgoType == MKL && defined(KOKKOSKERNELS_ENABLE_TPL_MKL) return "BaseTplAlgos::MKL does not support any tpl parameters"; #else - return "Unsupported kernelAlgoType = " + std::to_string(_kernelAlgoType) + - "."; + return "Unsupported kernelAlgoType:" + get_kernel_algo_type_str() + "."; #endif } - int get_kernel_algo_type() const { return _kernelAlgoType; } - // clang-format off /// \var _kernelAlgoType Specifies which algorithm to use for invocation (default, SQUARE). /// \var _tplParams a handle or queue specific to the TPL API. @@ -202,6 +209,9 @@ class BatchedKernelHandle { int _kernelAlgoType = BaseHeuristicAlgos::SQUARE; TplParams &_tplParamsSingleton = _get_tpl_params_singleton(); bool _tplParamsSet = false; + + private: + const char *algo_type_strs[N_BASE_ALGOS] = {BASE_ALGO_STRS}; }; } // namespace KokkosBatched diff --git a/src/batched/dense/impl/KokkosBatched_Gemm_Armpl_Impl.hpp b/src/batched/dense/impl/KokkosBatched_Gemm_Armpl_Impl.hpp new file mode 100644 index 0000000000..f66a236ffc --- /dev/null +++ b/src/batched/dense/impl/KokkosBatched_Gemm_Armpl_Impl.hpp @@ -0,0 +1,299 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.4 +// Copyright (2021) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +#ifndef __KOKKOSBATCHED_GEMM_ARMPL_IMPL_HPP__ +#define __KOKKOSBATCHED_GEMM_ARMPL_IMPL_HPP__ +#if defined(KOKKOSKERNELS_ENABLE_TPL_ARMPL) +#include "KokkosBatched_Util.hpp" + +namespace KokkosBatched { +/********************* BEGIN functor-level routines *********************/ +/// +/// Serial Impl +/// =========== +/********************* END functor-level routines *********************/ + +namespace Impl { +/********************* BEGIN non-functor-level routines *********************/ +// TODO: wrap this class in a macro for permutations of supported scalars. +template +class BatchedArmplGemm { + private: + HandleType *const __handle; + using avt = typename AViewType::value_type; + using bvt = typename BViewType::value_type; + using cvt = typename CViewType::value_type; + + AViewType __A; + avt *__Adp = nullptr; + armpl_int_t __Ajstrd, __Aistrd, __Abstrd; + + BViewType __B; + bvt *__Bdp = nullptr; + armpl_int_t __Bjstrd, __Bistrd, __Bbstrd; + + CViewType __C; + cvt *__Cdp = nullptr; + armpl_int_t __Cjstrd, __Cistrd, __Cbstrd; + + ScalarType __alpha, __beta; + armpl_int_t __ninter, __nbatch; + + char __transa, __transb; + + ArgTransA __transa_tag; + ArgTransB __transb_tag; + Trans::NoTranspose __no_trans_tag; + ArgBatchSzDim __batch_layout_tag; + + armpl_int_t __Am, __An, __Bm, __Bn, __Cm, __Cn; + + void __unpack_views() { + for (int ib = 0; ib < __nbatch; ++ib) { + for (int i = 0; i < __ninter; ++i) { + auto svA = + subview_wrapper(__A, ib * __ninter + i, Kokkos::ALL(), + Kokkos::ALL(), __batch_layout_tag, __no_trans_tag); + auto svB = + subview_wrapper(__B, ib * __ninter + i, Kokkos::ALL(), + Kokkos::ALL(), __batch_layout_tag, __no_trans_tag); + auto svC = + subview_wrapper(__C, ib * __ninter + i, Kokkos::ALL(), + Kokkos::ALL(), __batch_layout_tag, __no_trans_tag); + + auto info = armpl_dge_interleave( + __ninter, i, __Am, __An, reinterpret_cast(svA.data()), + svA.stride(0), svA.stride(1), + reinterpret_cast(&__Adp[__Abstrd * ib]), __Aistrd, + __Ajstrd); + if (info != ARMPL_STATUS_SUCCESS) { + std::ostringstream os; + os << "armpl_dge_interleave(A) returned:" << info << std::endl; + Kokkos::Impl::throw_runtime_exception(os.str()); + } + + info = armpl_dge_interleave( + __ninter, i, __Bm, __Bn, reinterpret_cast(svB.data()), + svB.stride(0), svB.stride(1), + reinterpret_cast(&__Bdp[__Bbstrd * ib]), __Bistrd, + __Bjstrd); + if (info != ARMPL_STATUS_SUCCESS) { + std::ostringstream os; + os << "armpl_dge_interleave(B) returned:" << info << std::endl; + Kokkos::Impl::throw_runtime_exception(os.str()); + } + + info = armpl_dge_interleave( + __ninter, i, __Cm, __Cn, reinterpret_cast(svC.data()), + svC.stride(0), svC.stride(1), + reinterpret_cast(&__Cdp[__Cbstrd * ib]), __Cistrd, + __Cjstrd); + if (info != ARMPL_STATUS_SUCCESS) { + std::ostringstream os; + os << "armpl_dge_interleave(C) returned:" << info << std::endl; + Kokkos::Impl::throw_runtime_exception(os.str()); + } + } + } + } + + void __repack_view() { + for (int ib = 0; ib < __nbatch; ++ib) { + for (int i = 0; i < __ninter; ++i) { + auto svC = + subview_wrapper(__C, ib * __ninter + i, Kokkos::ALL(), + Kokkos::ALL(), __batch_layout_tag, __no_trans_tag); + + auto info = armpl_dge_deinterleave( + __ninter, i, __Cm, __Cn, reinterpret_cast(svC.data()), + svC.stride(0), svC.stride(1), + reinterpret_cast(&__Cdp[__Cbstrd * ib]), __Cistrd, + __Cjstrd); + if (info != ARMPL_STATUS_SUCCESS) { + std::ostringstream os; + os << "armpl_dge_deinterleave returned:" << info << std::endl; + Kokkos::Impl::throw_runtime_exception(os.str()); + } + } + } + delete __Cdp; + } + + void __run() { + auto info = armpl_dgemm_interleave_batch( + __ninter, __nbatch, __transa, __transb, __Cm, __Cn, + std::is_same::value ? __An : __Am, + static_cast(__alpha), reinterpret_cast(__Adp), + __Abstrd, __Aistrd, __Ajstrd, reinterpret_cast(__Bdp), + __Bbstrd, __Bistrd, __Bjstrd, static_cast(__beta), + reinterpret_cast(__Cdp), __Cbstrd, __Cistrd, __Cjstrd); + if (info != ARMPL_STATUS_SUCCESS) { + std::ostringstream os; + os << "armpl_dgemm_interleave_batch returned :" << info << std::endl; + Kokkos::Impl::throw_runtime_exception(os.str()); + } + delete __Adp; + delete __Bdp; + } + + public: + BatchedArmplGemm(HandleType *const handle, ScalarType alpha, AViewType A, + BViewType B, ScalarType beta, CViewType C) + : __handle(handle), __A(A), __B(B), __C(C), __alpha(alpha), __beta(beta) { + __ninter = __handle->get_tpl_params()[0]; + + if (std::is_same::value) { + __Am = __A.extent(1); + __An = __A.extent(2); + __Bm = __B.extent(1); + __Bn = __B.extent(2); + __Cm = __C.extent(1); + __Cn = __C.extent(2); + __nbatch = __C.extent(0); + } else { + __Am = __A.extent(0); + __An = __A.extent(1); + __Bm = __B.extent(0); + __Bn = __B.extent(1); + __Cm = __C.extent(0); + __Cn = __C.extent(1); + __nbatch = __C.extent(2); + } + + __Ajstrd = __ninter; + __Aistrd = __Ajstrd * __An; + __Abstrd = __Aistrd * __Am; + + __Bjstrd = __ninter; + __Bistrd = __Bjstrd * __Bn; + __Bbstrd = __Bistrd * __Bm; + + __Cjstrd = __ninter; + __Cistrd = __Cjstrd * __Cn; + __Cbstrd = __Cistrd * __Cm; + + __transa = std::is_same::value ? 'N' : 'T'; + __transb = std::is_same::value ? 'N' : 'T'; + } + + int invoke() { + if (__handle->enableDebug) { + std::cerr << "__nbatch:" << std::to_string(__nbatch) + << ", __ninter:" << std::to_string(__ninter) + << ", __Am:" << std::to_string(__Am) + << ", __An:" << std::to_string(__An) + << ", __alpha:" << std::to_string(__alpha) + << ", __beta:" << std::to_string(__beta) << std::endl; + } + + if (!std::is_same::value || + !std::is_same::value || + !std::is_same::value || + !std::is_same::value) { + std::ostringstream os; + os << "KokkosBatched::Impl::BatchedArmplGemm only supports 'double' " + "scalar types." + << std::endl; + Kokkos::Impl::throw_runtime_exception(os.str()); + } + + if (__nbatch != 0) { + if (__ninter == 0 || __nbatch % __ninter) { + std::string msg = + "batch size must be evenly divisible by ninter. __nbatch: "; + msg += (std::to_string(__nbatch) + + ", __ninter: " + std::to_string(__ninter) + "\n"); + Kokkos::abort(msg.c_str()); + } + + // Calculate internal batch size for interleaving + __nbatch /= __ninter; + + // Allocate space for interleaving + // __Adp and __Bdp are deleted in __run() + // __Cdp is deleted in __repack_view() + __Adp = new avt[__Abstrd * __nbatch]; + __Bdp = new bvt[__Bbstrd * __nbatch]; + __Cdp = new cvt[__Cbstrd * __nbatch]; + + __unpack_views(); + __run(); + __repack_view(); + } + return 0; + } +}; +/********************* END non-functor-level routines *********************/ +} // namespace Impl +} // namespace KokkosBatched +#else // KOKKOSKERNELS_ENABLE_TPL_ARMPL +namespace KokkosBatched { +namespace Impl { +/********************* BEGIN non-functor-level routines *********************/ +// TODO: wrap this class in a macro for permutations of supported scalars. +template +class BatchedArmplGemm { + public: + BatchedArmplGemm(HandleType *const handle, ScalarType alpha, AViewType A, + BViewType B, ScalarType beta, CViewType C) { + (void)handle; + (void)alpha; + (void)A; + (void)B; + (void)beta; + (void)C; + } + + int invoke() { + std::ostringstream os; + os << "KokkosBatched::Impl::BatchedArmplGemm requires the ARMPL TPL" + << std::endl; + Kokkos::Impl::throw_runtime_exception(os.str()); + } +}; +} // namespace Impl +} // namespace KokkosBatched +#endif // KOKKOSKERNELS_ENABLE_TPL_ARMPL +#endif diff --git a/test_common/KokkosKernels_TestUtils.hpp b/test_common/KokkosKernels_TestUtils.hpp index 89ce643d82..34d0a65357 100644 --- a/test_common/KokkosKernels_TestUtils.hpp +++ b/test_common/KokkosKernels_TestUtils.hpp @@ -93,10 +93,11 @@ struct multivector_layout_adapter { }; template -void EXPECT_NEAR_KK(Scalar1 val1, Scalar2 val2, Scalar3 tol) { +void EXPECT_NEAR_KK(Scalar1 val1, Scalar2 val2, Scalar3 tol, + std::string msg = "") { typedef Kokkos::Details::ArithTraits AT1; typedef Kokkos::Details::ArithTraits AT3; - EXPECT_LE((double)AT1::abs(val1 - val2), (double)AT3::abs(tol)); + EXPECT_LE((double)AT1::abs(val1 - val2), (double)AT3::abs(tol)) << msg; } template @@ -116,6 +117,19 @@ void EXPECT_NEAR_KK_1DVIEW(ViewType1 v1, ViewType2 v2, Scalar tol) { } } +/// This function returns a descriptive user defined failure string for +/// insertion into gtest macros such as FAIL() and EXPECT_LE(). \param file The +/// filename where the failure originated \param func The function where the +/// failure originated \param line The line number where the failure originated +/// \return a new string containing: " > from file:func:line\n > " +static inline const std::string kk_failure_str(std::string file, + std::string func, + const int line) { + std::string failure_msg = " > from "; + failure_msg += (file + ":" + func + ":" + std::to_string(line) + "\n > "); + return std::string(failure_msg); +} + #if defined(KOKKOS_HALF_T_IS_FLOAT) using halfScalarType = Kokkos::Experimental::half_t; #endif // KOKKOS_HALF_T_IS_FLOAT diff --git a/tpls/gtest/gtest/gtest-all.cc b/tpls/gtest/gtest/gtest-all.cc index 735f581c95..15d0cb7bca 100644 --- a/tpls/gtest/gtest/gtest-all.cc +++ b/tpls/gtest/gtest/gtest-all.cc @@ -7482,8 +7482,8 @@ void StackLowerThanAddress(const void* ptr, bool* result) { } bool StackGrowsDown() { - int dummy; - bool result; + int dummy = 1; + bool result = 0; StackLowerThanAddress(&dummy, &result); return result; } diff --git a/unit_test/batched/dense/Test_Batched_BatchedGemm.hpp b/unit_test/batched/dense/Test_Batched_BatchedGemm.hpp index d230b39d14..4864728e07 100644 --- a/unit_test/batched/dense/Test_Batched_BatchedGemm.hpp +++ b/unit_test/batched/dense/Test_Batched_BatchedGemm.hpp @@ -26,8 +26,20 @@ void impl_test_batched_gemm_with_handle(BatchedGemmHandle* batchedGemmHandle, int ret = 0; auto algo_type = batchedGemmHandle->get_kernel_algo_type(); - ViewType a_expected, a_actual, b_expected, b_actual, c_expected, c_actual; + std::string fmsg; + std::string fmsg_rhs = + "algo_type:" + batchedGemmHandle->get_kernel_algo_type_str() + ", "; + fmsg_rhs += ("N:" + std::to_string(N) + ", "); + fmsg_rhs += + ("A:" + std::to_string(matAdim1) + "x" + std::to_string(matAdim2) + ", "); + fmsg_rhs += + ("B:" + std::to_string(matBdim1) + "x" + std::to_string(matBdim2) + ", "); + fmsg_rhs += + ("C:" + std::to_string(matCdim1) + "x" + std::to_string(matCdim2) + ", "); + fmsg_rhs += ("alpha:" + std::to_string(alpha) + ", "); + fmsg_rhs += ("beta:" + std::to_string(beta) + "\n"); + if (std::is_same::value) { a_expected = ViewType("a_expected", N, matAdim1, matAdim2); a_actual = ViewType("a_actual", N, matAdim1, matAdim2); @@ -60,26 +72,28 @@ void impl_test_batched_gemm_with_handle(BatchedGemmHandle* batchedGemmHandle, if (algo_type == GemmKokkosBatchedAlgos::KK_DBLBUF) { // Check for DblBuf runtime errors related to team_size try { + fmsg = kk_failure_str(__FILE__, __FUNCTION__, __LINE__); Impl::BatchedDblBufGemm( batchedGemmHandle, alpha, a_actual, b_actual, beta, c_actual) .invoke(); - FAIL(); + FAIL() << (fmsg + fmsg_rhs); } catch (const std::runtime_error& error) { ; } // Check for DblBuf runtime errors related to vector_len try { + fmsg = kk_failure_str(__FILE__, __FUNCTION__, __LINE__); Impl::BatchedDblBufGemm< transA, transB, batchLayout, BatchedGemmHandle, ScalarType, decltype(a_actual), decltype(b_actual), decltype(c_actual), BoundsCheck::No, AlphaTag::No, 65536, 65536 * 2, 65536>( batchedGemmHandle, alpha, a_actual, b_actual, beta, c_actual) .invoke(); - FAIL(); + FAIL() << (fmsg + fmsg_rhs); } catch (const std::runtime_error& error) { ; } @@ -87,22 +101,41 @@ void impl_test_batched_gemm_with_handle(BatchedGemmHandle* batchedGemmHandle, // Check for expected BatchedGemm runtime errors try { - ret = BatchedGemm( + if (algo_type == BaseTplAlgos::ARMPL && N % 2 == 0) { + batchedGemmHandle->get_tpl_params()[0] = N / 2; + } + + fmsg = kk_failure_str(__FILE__, __FUNCTION__, __LINE__); + ret = BatchedGemm( batchedGemmHandle, alpha, a_actual, b_actual, beta, c_actual); // Compute c_actual } catch (const std::runtime_error& error) { - // std::cout << "Caught expected runtime error" << std::endl; - if (algo_type == BaseHeuristicAlgos::SQUARE && matCdim1 != matCdim2) + bool is_invalid_layout = + !((std::is_same::value && + !std::is_same::value) || + (std::is_same::value && + !std::is_same::value)); + std::string error_msg = error.what(); + if (algo_type == BaseHeuristicAlgos::SQUARE && matCdim1 != matCdim2) { ; - else if (!((std::is_same::value && - !std::is_same::value) || - (std::is_same::value && - !std::is_same::value))) { - FAIL(); + } else if (algo_type == BaseTplAlgos::ARMPL) { +#if defined(KOKKOSKERNELS_ENABLE_TPL_ARMPL) + // No runtime errors expected since double is a supported type. + if (is_invalid_layout && + std::is_same::value) { + FAIL() << (error_msg + fmsg + fmsg_rhs); + } +#else + ; // We expect a runtime error if the ARMPL TPL is not enabled +#endif + } else if (is_invalid_layout) { + // No runtime errors expected since we only support certain BatchLayouts + // for LayoutLeft and LayoutRight. + FAIL() << (error_msg + fmsg + fmsg_rhs); } return; } - ASSERT_EQ(ret, 0); + ASSERT_EQ(ret, 0) << (fmsg + fmsg_rhs); Functor_BatchedVanillaGEMM vgemm; @@ -122,11 +155,12 @@ void impl_test_batched_gemm_with_handle(BatchedGemmHandle* batchedGemmHandle, typename ViewType::HostMirror c_expected_host = Kokkos::create_mirror_view(c_expected); - typename ViewType::HostMirror c1_host = Kokkos::create_mirror_view(c_actual); + typename ViewType::HostMirror c_actual_host = + Kokkos::create_mirror_view(c_actual); // Copy to host Kokkos::deep_copy(c_expected_host, c_expected); - Kokkos::deep_copy(c1_host, c_actual); + Kokkos::deep_copy(c_actual_host, c_actual); Kokkos::fence(); @@ -141,16 +175,25 @@ void impl_test_batched_gemm_with_handle(BatchedGemmHandle* batchedGemmHandle, for (int i = 0; i < matCdim1; ++i) { for (int j = 0; j < matCdim2; ++j) { if (std::is_same::value) { + // printf("c_expected_host(%d, %d, %d): %g\n", i, j, k, + // ats::abs(c_expected_host(i, j, k))); + // printf("c_actual_host(%d, %d, %d): %g\n", i, j, k, + // ats::abs(c_actual_host(i, j, k))); sum += ats::abs(c_expected_host(i, j, k)); - diff += ats::abs(c_expected_host(i, j, k) - c1_host(i, j, k)); + diff += ats::abs(c_expected_host(i, j, k) - c_actual_host(i, j, k)); } else { + // printf("c_expected_host(%d, %d, %d): %g\n", k, i, j, + // ats::abs(c_expected_host(k, i, j))); + // printf("c_actual_host(%d, %d, %d): %g\n", k, i, j, + // ats::abs(c_actual_host(k, i, j))); sum += ats::abs(c_expected_host(k, i, j)); - diff += ats::abs(c_expected_host(k, i, j) - c1_host(k, i, j)); + diff += ats::abs(c_expected_host(k, i, j) - c_actual_host(k, i, j)); } } } } - EXPECT_NEAR_KK(diff / sum, 0, eps); + + EXPECT_NEAR_KK(diff / sum, 0, eps, fmsg + fmsg_rhs); } template (