Skip to content

Commit

Permalink
Merge pull request #1217 from vqd8a/batched_gemm_fix
Browse files Browse the repository at this point in the history
Improve double buffering batched gemm performance
  • Loading branch information
e10harvey authored Dec 9, 2021
2 parents b609e0b + 2602b97 commit 446e9c2
Show file tree
Hide file tree
Showing 3 changed files with 292 additions and 43 deletions.
9 changes: 7 additions & 2 deletions src/batched/KokkosBatched_Util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "Kokkos_Timer.hpp"

#include "KokkosKernels_config.h"
#include "KokkosKernels_Utils.hpp"

// TPL macros
#if defined(KOKKOSKERNELS_ENABLE_TPL_MKL)
Expand Down Expand Up @@ -824,8 +825,12 @@ KOKKOS_INLINE_FUNCTION auto subview_wrapper(
template <class ViewValueType, class ViewType>
KOKKOS_INLINE_FUNCTION ViewValueType
access_view_bounds_check(ViewType v, int m, int n, const BoundsCheck::Yes &) {
if (m < v.extent_int(0) && n < v.extent_int(1)) return v(m, n);
return (ViewValueType)0.0F;
return v(KOKKOSKERNELS_MACRO_MIN(m, v.extent_int(0)),
KOKKOSKERNELS_MACRO_MIN(n, v.extent_int(1)));
//// TODO: use compile-time extents
//// if (m > scr.extent(0) || n > scr.extent(1))
//// return 0;
//// return v(m, n);
}

template <class ViewValueType, class ViewType>
Expand Down
13 changes: 7 additions & 6 deletions src/batched/dense/KokkosBatched_Gemm_Decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -463,28 +463,29 @@ int BatchedGemm(BatchedGemmHandleType *const handle, const ScalarType alpha,
// (alpha == 1.0F && beta == 0.0F) ? c_m <= 24 : c_m <= 21) {
// // TODO: invoke TeamShmem
// } else
if (on_gpu &&
((std::is_same<layout_type, Kokkos::LayoutLeft>::value)
? (c_m >= 16)
: (c_m >= 24 && c_m <= 32) || (c_m >= 45 && c_m <= 64))) {
if (on_gpu && ((std::is_same<layout_type, Kokkos::LayoutLeft>::value)
? (c_m >= 16)
: (c_m >= 24))) { // Vinh's note: use this condition
// for now, might need to revisit
handle->teamSz = handle->vecLen = 8;
constexpr int tile_m = 32, tile_n = 32, tile_k = 8;
if (c_m % 32 == 0) // No bounds checking
if (c_m % 32 == 0) { // No bounds checking
ret =
Impl::BatchedDblBufGemm<ArgTransA, ArgTransB, ArgBatchSzDim,
BatchedGemmHandleType, ScalarType,
AViewType, BViewType, CViewType,
BoundsCheck::No, tile_m, tile_n, tile_k>(
handle, alpha, A, B, beta, C)
.invoke();
else
} else {
ret =
Impl::BatchedDblBufGemm<ArgTransA, ArgTransB, ArgBatchSzDim,
BatchedGemmHandleType, ScalarType,
AViewType, BViewType, CViewType,
BoundsCheck::Yes, tile_m, tile_n, tile_k>(
handle, alpha, A, B, beta, C)
.invoke();
}
} else {
ret = Impl::BatchedSerialGemm<ArgTransA, ArgTransB, bsgModeType,
ArgBatchSzDim, bsgResultsPerThread,
Expand Down
Loading

0 comments on commit 446e9c2

Please sign in to comment.