Skip to content

Commit

Permalink
src/batched: Add switch, tests, and singleton
Browse files Browse the repository at this point in the history
  • Loading branch information
e10harvey committed Jul 15, 2021
1 parent 49a2af5 commit 5bf79ef
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 47 deletions.
53 changes: 50 additions & 3 deletions src/batched/KokkosBatched_Gemm_Decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
#include "KokkosBatched_Vector.hpp"
#include "KokkosBatched_Gemm_Handle.hpp"

// Includes for non-functor-level routines
#include <KokkosBatched_Gemm_Handle.hpp>

namespace KokkosBatched {
/********************* BEGIN functor-level routines *********************/
///
Expand Down Expand Up @@ -273,9 +276,53 @@ int BatchedGemm(const BatchedGemmHandleType *handle, const ScalarType alpha,
"on_intel:" << on_intel << std::endl <<
"on_a64fx:" << on_a64fx << std::endl;
#endif
BatchedSerialGemm<ArgTransA, ArgTransB, mode_type, ArgBatchSzDim,
resultsPerThread, ScalarType, AViewType, BViewType,
CViewType>(alpha, A, B, beta, C);
switch (handle->get_kernel_algo_type()) {
case BaseKokkosBatchedAlgos::KK_SERIAL:
ret = BatchedSerialGemm<ArgTransA, ArgTransB, mode_type, ArgBatchSzDim,
resultsPerThread, ScalarType, AViewType,
BViewType, CViewType>(alpha, A, B, beta, C)
.invoke();
break;

////////////// HEURISTIC ALGOS //////////////
case BaseHeuristicAlgos::SQUARE:

case BaseHeuristicAlgos::TALL:

case BaseHeuristicAlgos::WIDE:

////////////// TPL ALGOS //////////////
case BaseTplAlgos::ARMPL:

case BaseTplAlgos::MKL:

case GemmTplAlgos::CUBLAS:

case GemmTplAlgos::MAGMA:

////////////// KokkosBatched ALGOS //////////////

case GemmKokkosBatchedAlgos::KK_TEAM:

case GemmKokkosBatchedAlgos::KK_TEAMVECTOR:

case GemmKokkosBatchedAlgos::KK_SERIALSIMD:

case GemmKokkosBatchedAlgos::KK_TEAMSIMD:

case GemmKokkosBatchedAlgos::KK_SERIAL_OPT2:

case GemmKokkosBatchedAlgos::KK_TEAMVECTOR_SHMEM:

case GemmKokkosBatchedAlgos::KK_TEAMVECTOR_DBLBUF:

default:
std::ostringstream os;
os << "KokkosBatched::BatchedGemm does not support kernelAlgoType = "
<< std::to_string(handle->get_kernel_algo_type()) << "." << std::endl;
Kokkos::Impl::throw_runtime_exception(os.str());
break;
}
return ret;
}
/********************* END non-functor-level routines *********************/
Expand Down
9 changes: 5 additions & 4 deletions src/batched/KokkosBatched_Gemm_Handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,14 @@ enum GEMM_KOKKOS_BATCHED_ALGOS : int {
// clang-format on
class BatchedGemmHandle : public BatchedKernelHandle {
public:
BatchedGemmHandle() = default;

BatchedGemmHandle(int kernelAlgoType = BaseHeuristicAlgos::SQUARE,
int teamSize = 0, int vecLength = 0)
: BatchedKernelHandle(kernelAlgoType, teamSize, vecLength){};
decltype(auto) get_tpl_params() {
#if _kernelAlgoType == CUBLAS && defined(KOKKOSKERNELS_ENABLE_TPL_CUBLAS)
return _tplParams.tplHandle.cublas_handle;
return _tplParamsSingleton.cublas_handle;
#elif _kernelAlgoType == MAGMA && defined(KOKKOSKERNELS_ENABLE_TPL_MAGMA)
return _tplParams.tplQueue.magma_queue;
return _tplParamsSingleton.magma_queue;
#else
return this->BatchedKernelHandle::get_tpl_params();
#endif
Expand Down
6 changes: 2 additions & 4 deletions src/batched/KokkosBatched_Gemm_Serial_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,7 @@ class BatchedSerialGemm {
*this);
}

public:
int invoke() {
if (std::is_same<ArgResultsPerThread, ResultsPerThread::Rank0>::value) {
// Set members for ResultsPerThread::Rank0 operator; these members allow
Expand Down Expand Up @@ -476,12 +477,9 @@ class BatchedSerialGemm {
return 0;
}

public:
BatchedSerialGemm(ScalarType _alpha, AViewType _A, BViewType _B,
ScalarType _beta, CViewType _C)
: A(_A), B(_B), C(_C), alpha(_alpha), beta(_beta) {
invoke();
}
: A(_A), B(_B), C(_C), alpha(_alpha), beta(_beta) {}

KOKKOS_INLINE_FUNCTION
void operator()(const ResultsPerThread::Rank0 &, const int &i) const {
Expand Down
67 changes: 39 additions & 28 deletions src/batched/KokkosBatched_Kernel_Handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,24 @@
#ifndef KOKKOSKERNELS_KOKKOSBATCHED_KERNEL_HEADER_HPP
#define KOKKOSKERNELS_KOKKOSBATCHED_KERNEL_HEADER_HPP

#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL)
#include <mkl.h>
#endif // KOKKOSKERNELS_ENABLE_TPL_MKL

#if defined(KOKKOSKERNELS_ENABLE_TPL_ARMPL)
// TODO: Add armpl handle type to expose nintern & nbatch?
#endif // KOKKOSKERNELS_ENABLE_TPL_ARMPL

#if defined(KOKKOSKERNELS_ENABLE_TPL_CUBLAS)
#include "cuda_runtime.h"
#include "cublas_v2.h"
#endif // KOKKOSKERNELS_ENABLE_TPL_CUBLAS

#if defined(KOKKOSKERNELS_ENABLE_TPL_MAGMA)
#include <magma_v2.h>
#include <magma_batched.h>
#endif // KOKKOSKERNELS_ENABLE_TPL_MAGMA

namespace KokkosBatched {

/// \brief Heuristic algorithm types. See BatchedKernelHandle for details.
Expand All @@ -66,36 +84,25 @@ enum BASE_KOKKOS_BATCHED_ALGOS : int { KK_SERIAL = BaseTplAlgos::N, N };

#define N_BASE_ALGOS BaseKokkosBatchedAlgos::N

/// \brief TplHandle abstracts underlying handle type.
union TplHandle {
/// \brief TplParams abstracts underlying handle or execution queue type.
struct TplParams {
union {
#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL)
queue mkl_queue;
#endif // KOKKOSKERNELS_ENABLE_TPL_MKL

#if defined(KOKKOSKERNELS_ENABLE_TPL_ARMPL)
// TODO: Add armpl handle type to expose nintern & nbatch?
// TODO: Add armpl handle type to expose nintern & nbatch?
#endif // KOKKOSKERNELS_ENABLE_TPL_ARMPL

#if defined(KOKKOSKERNELS_ENABLE_TPL_CUBLAS)
#include <cublas.h>
cublasHandle_t &cublas_handle;
cublasHandle_t cublas_handle;
#endif // KOKKOSKERNELS_ENABLE_TPL_CUBLAS
// char no_handles;
};

/// \brief TplExecQueue abstracts underlying execution queue type.
union TplExecQueue {
#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL)
#include <mkl.h>
queue &mkl_queue;
#endif // KOKKOSKERNELS_ENABLE_TPL_MKL

#if defined(KOKKOSKERNELS_ENABLE_TPL_MAGMA)
#include <magma.h>
magma_queue_t &magma_queue;
magma_queue_t magma_queue;
#endif // KOKKOSKERNELS_ENABLE_TPL_MAGMA
// char no_queues;
};

union TplParams {
TplHandle tplHandle;
TplExecQueue tplQueue;
};
};

// clang-format off
Expand Down Expand Up @@ -142,22 +149,26 @@ class BatchedKernelHandle {

/// \brief Returns the TPL parameter (or a diagnostic string) picked by the
///        preprocessor branches below; because the return type is
///        decltype(auto), the deduced type differs from branch to branch.
// NOTE(review): `#if` is evaluated by the preprocessor, which cannot see the
// data member `_kernelAlgoType` or enumerators such as ARMPL / MKL. Unless
// those names are defined as macros elsewhere, each one is replaced by 0, so
// both comparisons reduce to `0 == 0` and only the `defined(...)` tests
// actually select a branch -- confirm this is the intended behaviour.
decltype(auto) get_tpl_params() {
#if _kernelAlgoType == ARMPL && defined(KOKKOSKERNELS_ENABLE_TPL_ARMPL)
Kokkos::abort("BaseTplAlgos::ARMPL does not support any parameters");
return nullptr;
return "BaseTplAlgos::ARMPL does not support any tpl parameters";
#elif _kernelAlgoType == MKL && defined(KOKKOSKERNELS_ENABLE_TPL_MKL)
return _tplParams.tplQueue.mkl_queue;
return _tplParamsSingleton.mkl_queue;
#else
// Fallback branch: builds a message from the stored algorithm type.
return "Unsupported kernelAlgoType = " + std::to_string(_kernelAlgoType) +
".";
#endif
}

/// \brief Reports the kernel algorithm type stored in this handle.
/// \return The current value of _kernelAlgoType.
int get_kernel_algo_type() const {
  return _kernelAlgoType;
}

/// \var _kernelAlgoType Specifies which algorithm to use for invocation
///                      (default, SQUARE).
/// \var _enableDebug    Toggle debug messages.
/// \var _tplParams      A handle or queue specific to the TPL API,
///                      managed internally unless provided by user via
///                      constructor overload.
protected:
int _kernelAlgoType = BaseHeuristicAlgos::SQUARE;
bool _enableDebug = false;
TplParams _tplParams;
int _kernelAlgoType = BaseHeuristicAlgos::SQUARE;
constexpr static bool _enableDebug = false;
static TplParams &_tplParamsSingleton;
};

} // namespace KokkosBatched
Expand Down
58 changes: 50 additions & 8 deletions unit_test/batched/Test_Batched_BatchedGemm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,20 +11,21 @@ using namespace KokkosBatched;
namespace Test {
template <typename DeviceType, typename ViewType, typename ScalarType,
typename ParamTagType>
void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2,
const int matBdim1, const int matBdim2,
const int matCdim1, const int matCdim2) {
void impl_test_batched_gemm_handle(BatchedGemmHandle* batchedGemmHandle,
const int N, const int matAdim1,
const int matAdim2, const int matBdim1,
const int matBdim2, const int matCdim1,
const int matCdim2) {
using execution_space = typename DeviceType::execution_space;
using transA = typename ParamTagType::transA;
using transB = typename ParamTagType::transB;
using batchLayout = typename ParamTagType::batchLayout;
using view_layout = typename ViewType::array_layout;
using ats = Kokkos::Details::ArithTraits<ScalarType>;

int ret = 0;
ScalarType alpha = ScalarType(1.5);
ScalarType beta = ScalarType(3.0);
BatchedGemmHandle batchedGemmHandle;
int ret = 0;
ScalarType alpha = ScalarType(1.5);
ScalarType beta = ScalarType(3.0);

ViewType a_expected, a_actual, b_expected, b_actual, c_expected, c_actual;
if (std::is_same<batchLayout, BatchLayout::Left>::value) {
Expand Down Expand Up @@ -58,7 +59,7 @@ void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2,
// Check for expected runtime errors due to non-optimal BatchedGemm invocation
try {
ret = BatchedGemm<transA, transB, batchLayout>(
&batchedGemmHandle, alpha, a_actual, b_actual, beta,
batchedGemmHandle, alpha, a_actual, b_actual, beta,
c_actual); // Compute c_actual
} catch (const std::runtime_error& error) {
if (!((std::is_same<view_layout, Kokkos::LayoutLeft>::value &&
Expand Down Expand Up @@ -119,6 +120,47 @@ void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2,
}
EXPECT_NEAR_KK(diff / sum, 0, eps);
}

template <typename DeviceType, typename ViewType, typename ScalarType,
typename ParamTagType>
void impl_test_batched_gemm(const int N, const int matAdim1, const int matAdim2,
const int matBdim1, const int matBdim2,
const int matCdim1, const int matCdim2) {
{
BatchedGemmHandle batchedGemmHandle;

ASSERT_EQ(batchedGemmHandle.get_kernel_algo_type(),
BaseHeuristicAlgos::SQUARE);
ASSERT_EQ(batchedGemmHandle.teamSz, 0);
ASSERT_EQ(batchedGemmHandle.vecLen, 0);
}

for (int algo_type = BaseHeuristicAlgos::SQUARE;
algo_type < GemmKokkosBatchedAlgos::N; ++algo_type) {
{
BatchedGemmHandle batchedGemmHandle(algo_type);

ASSERT_EQ(batchedGemmHandle.get_kernel_algo_type(), algo_type);

if (algo_type == BaseKokkosBatchedAlgos::KK_SERIAL) {
impl_test_batched_gemm_handle<DeviceType, ViewType, ScalarType,
ParamTagType>(
&batchedGemmHandle, N, matAdim1, matAdim2, matBdim1, matBdim2,
matCdim1, matCdim2);
} else {
try {
impl_test_batched_gemm_handle<DeviceType, ViewType, ScalarType,
ParamTagType>(
&batchedGemmHandle, N, matAdim1, matAdim2, matBdim1, matBdim2,
matCdim1, matCdim2);
FAIL();
} catch (const std::runtime_error& error) {
;
}
}
}
}
}
} // namespace Test
template <typename DeviceType, typename ValueType, typename ScalarType,
typename ParamTagType>
Expand Down

0 comments on commit 5bf79ef

Please sign in to comment.