Skip to content

Commit

Permalink
src/batched: Add armpl dgemm support
Browse files Browse the repository at this point in the history
Conflicts:
	src/batched/dense/KokkosBatched_Gemm_Decl.hpp
	unit_test/batched/dense/Test_Batched_BatchedGemm.hpp
  • Loading branch information
e10harvey committed Jan 11, 2022
1 parent 2a3e31c commit 408e194
Show file tree
Hide file tree
Showing 11 changed files with 536 additions and 89 deletions.
4 changes: 0 additions & 4 deletions perf_test/sparse/KokkosSparse_spmv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,6 @@
#include <OpenMPSmartStatic_SPMV.hpp>
#endif

#ifdef KOKKOSKERNELS_ENABLE_TPL_ARMPL
#include <spmv/ArmPL_SPMV.hpp>
#endif

int test_crs_matrix_singlevec(Ordinal numRows, Ordinal numCols, int test,
const char* filename, Ordinal rows_per_thread,
int team_size, int vector_length, int schedule,
Expand Down
4 changes: 4 additions & 0 deletions perf_test/sparse/KokkosSparse_spmv_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,10 @@
#include <PerfTestUtilities.hpp>
#endif

#ifdef KOKKOSKERNELS_ENABLE_TPL_ARMPL
#include <spmv/ArmPL_SPMV.hpp>
#endif

// return std::make_tuple(newnumRows, newnumCols, A, x1, y1,
// rows_per_thread, team_size, vector_length,
// test, schedule, ave_time, max_time, min_time);
Expand Down
3 changes: 3 additions & 0 deletions perf_test/sparse/KokkosSparse_spmv_test.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
#include <spmv/MKL_SPMV.hpp>
#endif

template <typename AType, typename XType, typename YType>
void armpl_matvec(AType /*A*/, XType x, YType y, spmv_additional_data* data);

enum {
KOKKOS,
MKL,
Expand Down
21 changes: 11 additions & 10 deletions perf_test/sparse/spmv/ArmPL_SPMV.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,21 +48,22 @@
#ifdef KOKKOSKERNELS_ENABLE_TPL_ARMPL
#include <armpl.h>

template <typename Scalar>
void armpl_matvec_wrapper(armpl_spmat_t A, Scalar* x, Scalar* y) {
throw std::runtime_error(
"Can't use ArmPL mat-vec for scalar types other than double and float.");
}
// template <typename Scalar>
// void armpl_matvec_wrapper(armpl_spmat_t A, Scalar* x, Scalar* y) {
// throw std::runtime_error(
// "Can't use ArmPL mat-vec for scalar types other than double and
// float.");
//}

template <>
void armpl_matvec_wrapper<float>(armpl_spmat_t A, float* x, float* y) {
// template <>
/// \brief Single-precision sparse mat-vec through ArmPL.
///
/// Calls armpl_spmv_exec_s with the non-transposed operation, alpha = 1 and
/// beta = 0, so the previous contents of y do not contribute to the result.
///
/// Declared inline because this is an ordinary (non-template) function
/// defined in a header; without it, including the header from more than one
/// translation unit produces multiple-definition link errors.
///
/// \param A [in]  ArmPL sparse matrix handle
/// \param x [in]  Input vector
/// \param y [out] Output vector
///
/// NOTE(review): the status returned by armpl_spmv_exec_s is not checked here.
inline void armpl_matvec_wrapper(armpl_spmat_t A, float* x, float* y) {
  const float alpha = 1.0;
  const float beta  = 0.0;
  armpl_spmv_exec_s(ARMPL_SPARSE_OPERATION_NOTRANS, alpha, A, x, beta, y);
}

template <>
void armpl_matvec_wrapper<double>(armpl_spmat_t A, double* x, double* y) {
// template <>
void armpl_matvec_wrapper(armpl_spmat_t A, double* x, double* y) {
const double alpha = 1.0;
const double beta = 0.0;
armpl_spmv_exec_d(ARMPL_SPARSE_OPERATION_NOTRANS, alpha, A, x, beta, y);
Expand All @@ -72,7 +73,7 @@ template <typename AType, typename XType, typename YType>
void armpl_matvec(AType /*A*/, XType x, YType y, spmv_additional_data* data) {
using Scalar = typename AType::non_const_value_type;
// Run armpl spmv corresponding to scalar type
armpl_matvec_wrapper<Scalar>(data->A, x.data(), y.data());
armpl_matvec_wrapper(data->A, x.data(), y.data());
}

#endif // KOKKOSKERNELS_ENABLE_TPL_ARMPL
Expand Down
141 changes: 99 additions & 42 deletions src/batched/dense/KokkosBatched_Gemm_Decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,7 @@ class BatchedSerialGemm;
/// ResultsPerThread::Rank0 Each thread computes a scalar of C
/// ResultsPerThread::Rank1 Each thread computes a 1-rank chunk of C
/// ResultsPerThread::Rank2 Each thread computes a 2-rank chunk of C
/// \tparam HandleType Specifies the handle type of the kernel handle
/// \tparam ScalarType Specifies the scalar type of alpha and beta
/// \tparam AViewType Input matrix, as either a 3-rank Kokkos::View or a
/// 4-rank Kokkos::View for SIMD operations.
Expand Down Expand Up @@ -256,6 +257,52 @@ template <class ArgTransA, class ArgTransB, class ArgBatchSzDim,
class CViewType, class ArgBoundsCheck, class ArgAlphaFmaTag,
int tile_m, int tile_n, int tile_k>
class BatchedDblBufGemm;

// clang-format off
/// \brief Blocking solve of general matrix multiply on a batch of uniform matrices.
///
///
///        C = alpha * op(A) * op(B) + beta * C
///
/// \tparam ArgTransA      Specifies what op does to A:
///                        Trans::NoTranspose for non-transpose
///                        Trans::Transpose for transpose
///                        Trans::ConjTranspose for conjugate transpose (unsupported)
/// \tparam ArgTransB      Specifies what op does to B:
///                        Trans::NoTranspose for non-transpose
///                        Trans::Transpose for transpose
///                        Trans::ConjTranspose for conjugate transpose (unsupported)
/// \tparam ArgBatchSzDim  Specifies which dimension of the views is the batch dimension:
///                        BatchSzDim::Right  Batch dimension is the right-most dimension
///                        BatchSzDim::Left   Batch dimension is the left-most dimension
/// \tparam HandleType     Specifies the handle type of the kernel handle
/// \tparam ScalarType     Specifies the scalar type of alpha and beta
/// \tparam AViewType      Input matrix, as a 3-rank Kokkos::View
/// \tparam BViewType      Input matrix, as a 3-rank Kokkos::View
/// \tparam CViewType      Input(RHS)/Output(LHS) matrix, as a 3-rank
///                        Kokkos::View
///
/// \param handle [in]     Kernel handle; see struct BatchedGemmHandle for details.
///                        NOTE(review): the number of matrices to interleave (ninter)
///                        is presumably read from the handle's TPL parameters rather
///                        than passed directly -- confirm against the implementation.
/// \param alpha [in]      Input coefficient used for multiplication with A
/// \param A [in]          Input matrix, as a 3-rank Kokkos::View
///                        If ArgBatchSzDim == "BatchSzDim::Right", matrix A is MxKxB
///                        If ArgBatchSzDim == "BatchSzDim::Left", matrix A is BxMxK
/// \param B [in]          Input matrix, as a 3-rank Kokkos::View
///                        If ArgBatchSzDim == "BatchSzDim::Right", matrix B is KxNxB
///                        If ArgBatchSzDim == "BatchSzDim::Left", matrix B is BxKxN
/// \param beta [in]       Input coefficient used for multiplication with C
/// \param C [in/out]      Input/Output matrix, as a 3-rank Kokkos::View
///                        If ArgBatchSzDim == "BatchSzDim::Right", matrix C is MxNxB
///                        If ArgBatchSzDim == "BatchSzDim::Left", matrix C is BxMxN
/// \return 0 upon success, non-zero otherwise
///
/// Usage Example:
///   BatchedArmplGemm<ArgTransA, ArgTransB, ArgBatchSzDim, HandleType,
///                    ScalarType, AViewType, BViewType, CViewType>
///                    (handle, alpha, A, B, beta, C).invoke();
// clang-format on
template <class ArgTransA, class ArgTransB, class ArgBatchSzDim,
class HandleType, class ScalarType, class AViewType, class BViewType,
class CViewType>
class BatchedArmplGemm;
/********************* END forward declarations *********************/
} // namespace Impl

Expand Down Expand Up @@ -328,6 +375,14 @@ int BatchedGemm(BatchedGemmHandleType *const handle, const ScalarType alpha,
"BViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<CViewType>::value,
"CViewType must be a Kokkos::View.");
static_assert(
std::is_same<ArgTransA, Trans::NoTranspose>::value ||
std::is_same<ArgTransA, Trans::Transpose>::value,
"ArgTransA must be either Trans::Transpose or Trans::NoTranspose.");
static_assert(
std::is_same<ArgTransB, Trans::NoTranspose>::value ||
std::is_same<ArgTransB, Trans::Transpose>::value,
"ArgTransB must be either Trans::Transpose or Trans::NoTranspose.");
if (is_vector<ViewValueType>::value) {
// Check ranks of view with underlying SIMD value types
// For SIMD views, we can have either 3-rank or 4-ranks inputs.
Expand Down Expand Up @@ -419,15 +474,6 @@ int BatchedGemm(BatchedGemmHandleType *const handle, const ScalarType alpha,
}

switch (handle->get_kernel_algo_type()) {
case BaseKokkosBatchedAlgos::KK_SERIAL:
ret =
Impl::BatchedSerialGemm<ArgTransA, ArgTransB, Algo::Gemm::Unblocked,
ArgBatchSzDim, ResultsPerThread::Rank2,
ScalarType, AViewType, BViewType, CViewType>(
alpha, A, B, beta, C)
.invoke();
break;

////////////// HEURISTIC ALGOS //////////////
case BaseHeuristicAlgos::SQUARE:
if (c_m != c_n) {
Expand Down Expand Up @@ -524,54 +570,64 @@ int BatchedGemm(BatchedGemmHandleType *const handle, const ScalarType alpha,
}
break;

case GemmKokkosBatchedAlgos::KK_DBLBUF:
// Note: The tile sizes of 1x1x1 here will not perform well but must be
// selected in order to function on all devices since the serial execution
// space has a max team size of 1. KokkosKernels API users will need to
// follow an approach similar to KK_SQUARE above for best performance.
// case BaseHeuristicAlgos::TALL:
//
// case BaseHeuristicAlgos::WIDE:
////////////// TPL ALGOS //////////////

// TODO: Add auto-selection of tile size based on inputs and device type
ret = Impl::BatchedDblBufGemm<ArgTransA, ArgTransB, ArgBatchSzDim,
BatchedGemmHandleType, ScalarType,
AViewType, BViewType, CViewType,
BoundsCheck::Yes, AlphaTag::No, 1, 1, 1>(
handle, alpha, A, B, beta, C)
case BaseTplAlgos::ARMPL:
ret = Impl::BatchedArmplGemm<ArgTransA, ArgTransB, ArgBatchSzDim,
BatchedGemmHandleType, ScalarType, AViewType,
BViewType, CViewType>(handle, alpha, A, B,
beta, C)
.invoke();
break;
// case BaseTplAlgos::MKL:
//
// case GemmTplAlgos::CUBLAS:
//
// case GemmTplAlgos::MAGMA:

case GemmKokkosBatchedAlgos::KK_SERIAL_RANK0:
////////////// KokkosBatched ALGOS //////////////
case BaseKokkosBatchedAlgos::KK_SERIAL:
ret =
Impl::BatchedSerialGemm<ArgTransA, ArgTransB, Algo::Gemm::Unblocked,
ArgBatchSzDim, ResultsPerThread::Rank0,
ArgBatchSzDim, ResultsPerThread::Rank2,
ScalarType, AViewType, BViewType, CViewType>(
alpha, A, B, beta, C)
.invoke();
break;

case BaseHeuristicAlgos::TALL:

case BaseHeuristicAlgos::WIDE:

////////////// TPL ALGOS //////////////
case BaseTplAlgos::ARMPL:

case BaseTplAlgos::MKL:

case GemmTplAlgos::CUBLAS:

case GemmTplAlgos::MAGMA:

////////////// KokkosBatched ALGOS //////////////

case GemmKokkosBatchedAlgos::KK_TEAM:
// case GemmKokkosBatchedAlgos::KK_SERIALSIMD:

case GemmKokkosBatchedAlgos::KK_TEAMVECTOR:
case GemmKokkosBatchedAlgos::KK_SERIAL_RANK0:
ret =
Impl::BatchedSerialGemm<ArgTransA, ArgTransB, Algo::Gemm::Unblocked,
ArgBatchSzDim, ResultsPerThread::Rank0,
ScalarType, AViewType, BViewType, CViewType>(
alpha, A, B, beta, C)
.invoke();
break;

case GemmKokkosBatchedAlgos::KK_SERIALSIMD:
// case GemmKokkosBatchedAlgos::KK_SERIAL_SHMEM:
// case GemmKokkosBatchedAlgos::KK_TEAM:
// case GemmKokkosBatchedAlgos::KK_TEAMVECTOR:
// case GemmKokkosBatchedAlgos::KK_TEAMSIMD:

case GemmKokkosBatchedAlgos::KK_TEAMSIMD:
case GemmKokkosBatchedAlgos::KK_DBLBUF:
// Note: The tile sizes of 1x1x1 here will not perform well but must be
// selected in order to function on all devices since the serial execution
// space has a max team size of 1. KokkosKernels API users will need to
// follow an approach similar to KK_SQUARE above for best performance.

case GemmKokkosBatchedAlgos::KK_SERIAL_SHMEM:
// TODO: Add auto-selection of tile size based on inputs and device type
ret = Impl::BatchedDblBufGemm<ArgTransA, ArgTransB, ArgBatchSzDim,
BatchedGemmHandleType, ScalarType,
AViewType, BViewType, CViewType,
BoundsCheck::Yes, AlphaTag::No, 1, 1, 1>(
handle, alpha, A, B, beta, C)
.invoke();
break;

default:
std::ostringstream os;
Expand All @@ -589,5 +645,6 @@ int BatchedGemm(BatchedGemmHandleType *const handle, const ScalarType alpha,
#include "KokkosBatched_Gemm_Team_Impl.hpp"
#include "KokkosBatched_Gemm_TeamVector_Impl.hpp"
#include "KokkosBatched_Gemm_DblBuf_Impl.hpp"
#include "KokkosBatched_Gemm_Armpl_Impl.hpp"

#endif
21 changes: 20 additions & 1 deletion src/batched/dense/KokkosBatched_Gemm_Handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ enum GEMM_KOKKOS_BATCHED_ALGOS : int {
};
}

// Printable names of the GEMM-specific algorithm types. Appended after
// BASE_ALGO_STRS to build the name table that is indexed by the handle's
// kernel algorithm type, so the order of the entries is significant.
#define GEMM_ALGO_STRS                         \
  "GemmTplAlgos::CUBLAS",                      \
  "GemmTplAlgos::MAGMA",                       \
  "GemmKokkosBatchedAlgos::KK_TEAM",           \
  "GemmKokkosBatchedAlgos::KK_TEAMVECTOR",     \
  "GemmKokkosBatchedAlgos::KK_SERIALSIMD",     \
  "GemmKokkosBatchedAlgos::KK_TEAMSIMD",       \
  "GemmKokkosBatchedAlgos::KK_SERIAL_RANK0",   \
  "GemmKokkosBatchedAlgos::KK_SERIAL_SHMEM",   \
  "GemmKokkosBatchedAlgos::KK_DBLBUF"
// clang-format off
/// \brief Handle for selecting runtime behavior of the BatchedGemm interface.
///
Expand Down Expand Up @@ -163,14 +172,24 @@ class BatchedGemmHandle : public BatchedKernelHandle {
#endif // MAGMA

decltype(auto) get_tpl_params() {
#if _kernelAlgoType == CUBLAS && defined(KOKKOSKERNELS_ENABLE_TPL_CUBLAS)
#if _kernelAlgoType == ARMPL && defined(KOKKOSKERNELS_ENABLE_TPL_ARMPL)
return &_tplParamsSingleton.ninter;
#elif _kernelAlgoType == CUBLAS && defined(KOKKOSKERNELS_ENABLE_TPL_CUBLAS)
return _tplParamsSingleton.cublas_handle;
#elif _kernelAlgoType == MAGMA && defined(KOKKOSKERNELS_ENABLE_TPL_MAGMA)
return _tplParamsSingleton.magma_queue;
#else
return this->BatchedKernelHandle::get_tpl_params();
#endif
}

std::string get_kernel_algo_type_str() const {
return gemm_algo_type_strs[_kernelAlgoType];
}

private:
const char *gemm_algo_type_strs[GemmKokkosBatchedAlgos::N] = {BASE_ALGO_STRS,
GEMM_ALGO_STRS};
};

} // namespace KokkosBatched
Expand Down
24 changes: 17 additions & 7 deletions src/batched/dense/KokkosBatched_Kernel_Handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
#endif // KOKKOSKERNELS_ENABLE_TPL_MKL

#if defined(KOKKOSKERNELS_ENABLE_TPL_ARMPL)
// TODO: Add armpl handle type to expose nintern & nbatch?
#include "armpl.h"
#endif // KOKKOSKERNELS_ENABLE_TPL_ARMPL

#if defined(KOKKOSKERNELS_ENABLE_TPL_CUBLAS)
Expand Down Expand Up @@ -83,6 +83,10 @@ enum BASE_KOKKOS_BATCHED_ALGOS : int { KK_SERIAL = BaseTplAlgos::N, N };
}

#define N_BASE_ALGOS BaseKokkosBatchedAlgos::N
// Printable names of the base algorithm types. Used to initialize the name
// table that is indexed by the handle's kernel algorithm type, so the order
// of the entries is significant.
#define BASE_ALGO_STRS                                              \
  "BaseHeuristicAlgos::SQUARE", "BaseHeuristicAlgos::TALL",         \
  "BaseHeuristicAlgos::WIDE", "BaseTplAlgos::ARMPL",                \
  "BaseTplAlgos::MKL", "BaseKokkosBatchedAlgos::KK_SERIAL"

/// \brief TplParams abstracts underlying handle or execution queue type.
struct TplParams {
Expand All @@ -94,7 +98,7 @@ struct TplParams {
#endif // KOKKOSKERNELS_ENABLE_TPL_MKL

#if defined(KOKKOSKERNELS_ENABLE_TPL_ARMPL)
// TODO: Add armpl handle type in KokkosKernels to expose nintern & nbatch?
armpl_int_t ninter = 1;
#endif // KOKKOSKERNELS_ENABLE_TPL_ARMPL

#if defined(KOKKOSKERNELS_ENABLE_TPL_CUBLAS)
Expand Down Expand Up @@ -172,19 +176,22 @@ class BatchedKernelHandle {
int teamSize = 0, int vecLength = 0)
: teamSz(teamSize), vecLen(vecLength), _kernelAlgoType(kernelAlgoType){};

decltype(auto) get_tpl_params() {
int get_kernel_algo_type() const { return _kernelAlgoType; }

std::string get_kernel_algo_type_str() const {
return algo_type_strs[_kernelAlgoType];
}

decltype(auto) get_tpl_params() const {
#if _kernelAlgoType == ARMPL && defined(KOKKOSKERNELS_ENABLE_TPL_ARMPL)
return "BaseTplAlgos::ARMPL does not support any tpl parameters";
#elif _kernelAlgoType == MKL && defined(KOKKOSKERNELS_ENABLE_TPL_MKL)
return "BaseTplAlgos::MKL does not support any tpl parameters";
#else
return "Unsupported kernelAlgoType = " + std::to_string(_kernelAlgoType) +
".";
return "Unsupported kernelAlgoType:" + get_kernel_algo_type_str() + ".";
#endif
}

int get_kernel_algo_type() const { return _kernelAlgoType; }

// clang-format off
/// \var _kernelAlgoType Specifies which algorithm to use for invocation (default, SQUARE).
/// \var _tplParams a handle or queue specific to the TPL API.
Expand All @@ -202,6 +209,9 @@ class BatchedKernelHandle {
int _kernelAlgoType = BaseHeuristicAlgos::SQUARE;
TplParams &_tplParamsSingleton = _get_tpl_params_singleton();
bool _tplParamsSet = false;

private:
const char *algo_type_strs[N_BASE_ALGOS] = {BASE_ALGO_STRS};
};

} // namespace KokkosBatched
Expand Down
Loading

0 comments on commit 408e194

Please sign in to comment.