Merge pull request #3 from e10harvey/batched_gemm_scaling_beyond_32
Merge from Evan's batched gemm scaling beyond 32
vqd8a authored Dec 5, 2021
2 parents 133b7fc + d9252db commit 3ce0ee5
Showing 2 changed files with 54 additions and 24 deletions.
4 changes: 4 additions & 0 deletions src/batched/KokkosBatched_Util.hpp
@@ -826,6 +826,10 @@ KOKKOS_INLINE_FUNCTION ViewValueType
access_view_bounds_check(ViewType v, int m, int n, const BoundsCheck::Yes &) {
if (m < v.extent_int(0) && n < v.extent_int(1)) return v(m, n);
return (ViewValueType)0.0F;
+  // TODO: use compile-time extents
+  // if (m > scr.extent(0) || n > scr.extent(1))
+  //   return 0;
+  // return v(m, n);
}
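
For readers skimming the hunk above: access_view_bounds_check zero-pads reads that fall outside the view, so a partial tile can be multiplied without branching in the caller. A minimal stand-alone analogue in plain C++ (hypothetical Matrix/bounded_read names, not the KokkosKernels API):

#include <cstdio>
#include <vector>

// Hypothetical stand-alone analogue of access_view_bounds_check: reads
// inside the matrix return the stored value; reads past either extent
// return 0 so a partial tile multiplies as if zero-padded.
struct Matrix {
  int rows, cols;
  std::vector<float> data;  // row-major storage
  float at(int m, int n) const { return data[m * cols + n]; }
};

float bounded_read(const Matrix &v, int m, int n) {
  if (m < v.rows && n < v.cols) return v.at(m, n);
  return 0.0f;  // out of bounds -> zero pad
}

int main() {
  Matrix v{2, 2, {1.f, 2.f, 3.f, 4.f}};
  std::printf("%g %g\n", bounded_read(v, 1, 1), bounded_read(v, 2, 0));
  // prints: 4 0
}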

template <class ViewValueType, class ViewType>
74 changes: 50 additions & 24 deletions src/batched/dense/impl/KokkosBatched_Gemm_DblBuf_Impl.hpp
@@ -58,6 +58,15 @@ namespace Impl {
/// CT/NT, NT/CT, CT/CT
///

+// TODO - scaling between (32x32, 64x64)
+// Option 0: Increase the number of tiles and figure out how to map Kokkos
+//           teams into the CUDA grid. Keep team size and vector lanes
+//           constant.
+//           TODO: write up small example and ask Christian. [DONE,
+//           MdRangePolicy not applicable here]
+// Option 1: Increase register sizes to handle rows/cols past the tile size
+// Option 2: Fix league_size and have a single team solve the full tile,
+//           followed by the same team solving the extra rows/cols (without
+//           multiplying by the zero rows/cols); sketched below.
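
Option 2 is a classic tile-plus-remainder decomposition. A hedged serial sketch of the idea (plain C++, hypothetical function name, not the kernel's actual control flow) that covers every (m, n) exactly once and never multiplies into the zero padding:

// Hypothetical serial sketch of Option 2: one "team" covers the full
// TILE x TILE region, then the same team sweeps the remainder rows/cols.
// Assumes row-major A (MxK), B (KxN), C (MxN).
constexpr int TILE = 32;

void gemm_tile_plus_remainder(int M, int N, int K,
                              const float *A, const float *B, float *C) {
  auto dot = [&](int m, int n) {
    float acc = 0.0f;
    for (int k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
    C[m * N + n] = acc;
  };
  int Mt = (M / TILE) * TILE, Nt = (N / TILE) * TILE;
  // Full-tile region.
  for (int m = 0; m < Mt; ++m)
    for (int n = 0; n < Nt; ++n) dot(m, n);
  // Remainder rows (m >= Mt), all columns.
  for (int m = Mt; m < M; ++m)
    for (int n = 0; n < N; ++n) dot(m, n);
  // Remainder columns (n >= Nt) of the full-tile rows.
  for (int m = 0; m < Mt; ++m)
    for (int n = Nt; n < N; ++n) dot(m, n);
}
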
template <class ArgTransA, class ArgTransB, class ArgBatchSzDim,
class HandleType, class ScalarType, class AViewType, class BViewType,
class CViewType, class ArgBoundsCheck, int TILE_M, int TILE_N,
@@ -83,6 +92,7 @@ class BatchedDblBufGemm {
typename execution_space_type::scratch_memory_space;
using view_type_2d_scratch =
Kokkos::View<view_value_type **, Kokkos::LayoutRight, scratch_space_type>;
+  // TODO: add compile-time extents

public:
BatchedDblBufGemm(HandleType *const handle, ScalarType alpha, AViewType A,
@@ -252,6 +262,9 @@ class BatchedDblBufGemm {
view_value_type reg_c[REG_M][REG_N],
view_type_2d_scratch &svA_scr,
view_type_2d_scratch &svB_scr) const {
+    // view_type_2d_scratch svA_scr(member.team_scratch(0), __tile_m, __tile_k);
+    // view_type_2d_scratch svB_scr(member.team_scratch(0), __tile_k, __tile_n);
Kokkos::parallel_for(
Kokkos::TeamThreadRange(member, 0, tile_m / REG_M),
[&](const int &thread_id) {
@@ -261,13 +274,22 @@
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
-          for (unsigned k = 0; k < nk; ++k) {
+          // TODO: would have to invert this loop for the threadVectorRange
+          // copy TODOs below
+          for (unsigned k = 0; k < nk; ++k) {
+            // TODO: svA_scr coalesced access. All vlanes are reading the
+            // same data from svA_scr.
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int m = 0; m < REG_M; ++m)
-              reg_a[m] = svA_scr(k, thread_id + m * STRIDE_M);
+              // TODO: this could be a threadVectorRange copy
+              reg_a[m] = svA_scr(thread_id + m * STRIDE_M, k);
+            // TODO: reg_a could be a thread shared buffer

+            // view_type_2d_scratch svA_scr(member.team_scratch(0), __tile_m, __tile_k);
+            // view_type_2d_scratch svB_scr(member.team_scratch(0), __tile_k, __tile_n);
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
@@ -291,12 +313,17 @@

KOKKOS_INLINE_FUNCTION
void operator()(const MemberType &member) const {
+    // TODO: use a Kokkos view with a compile-time size for allocating
+    // registers? Then we could use local deep copy for prefetch_reg
+    // population.
// Allocate registers used for prefetching
view_value_type prefetch_reg_a[REG_M] = {0}, prefetch_reg_b[REG_N] = {0};

// Allocate registers used for FMAs
view_value_type reg_a[REG_M] = {0}, reg_b[REG_N] = {0},
reg_c[REG_M][REG_N] = {{0}};
+    // TODO: look at local loads and stores via nvprof
+    // TODO: look at the GPU trace in nvprof to find out how many registers
+    // are used.
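
The REG_M/REG_N arrays above implement register blocking: each thread keeps a small patch of C entirely in registers and applies one rank-1 update per k step. A minimal serial sketch of one such step (hypothetical fma_step, assumed 4x4 blocking):

// Hypothetical sketch of one k-step of register blocking: the thread owns a
// REG_M x REG_N patch of C in registers, loads REG_M values of A and REG_N
// values of B, and performs REG_M * REG_N FMAs.
constexpr int REG_M = 4, REG_N = 4;

void fma_step(const float a_col[REG_M], const float b_row[REG_N],
              float c[REG_M][REG_N]) {
  for (int m = 0; m < REG_M; ++m)
    for (int n = 0; n < REG_N; ++n)
      c[m][n] += a_col[m] * b_row[n];  // rank-1 update of the C patch
}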

unsigned batch_idx = member.league_rank() / __n_sub_tiles;
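
For orientation, one plausible reading of this league-rank decomposition: each team owns a (batch, sub-tile) pair. The tiles_per_row mapping below is an assumption; the actual arithmetic sits in the collapsed lines that follow.

// Hypothetical sketch of mapping league_rank -> (batch, tile offsets).
struct TileCoords { unsigned batch_idx, start_m, start_n; };

TileCoords decompose_league_rank(unsigned league_rank, unsigned n_sub_tiles,
                                 unsigned tiles_per_row, unsigned tile_m,
                                 unsigned tile_n) {
  unsigned batch_idx = league_rank / n_sub_tiles;  // which matrix in the batch
  unsigned sub_tile  = league_rank % n_sub_tiles;  // which C tile of it
  return {batch_idx,
          (sub_tile / tiles_per_row) * tile_m,     // row offset of the tile
          (sub_tile % tiles_per_row) * tile_n};    // col offset of the tile
}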

@@ -314,25 +341,25 @@
__ei.__batch_layout_tag);

// Allocate scratch memory buffers used for prefetching
-    view_type_2d_scratch svA_scr(member.team_scratch(0), __tile_k, __tile_m);
+    view_type_2d_scratch svA_scr(member.team_scratch(0), __tile_m, __tile_k);
view_type_2d_scratch svB_scr(member.team_scratch(0), __tile_k, __tile_n);
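
A minimal sketch of how team scratch like svA_scr/svB_scr is sized and requested in Kokkos, with assumed tile sizes; the real kernel derives them from the TILE_M/TILE_N/TILE_K template parameters:

#include <Kokkos_Core.hpp>

int main(int argc, char *argv[]) {
  Kokkos::initialize(argc, argv);
  {
    using scr_space = Kokkos::DefaultExecutionSpace::scratch_memory_space;
    using scr_view  = Kokkos::View<float **, Kokkos::LayoutRight, scr_space>;
    constexpr int tile_m = 32, tile_n = 32, tile_k = 8;  // assumed sizes
    // Request enough level-0 scratch for both buffers per team.
    size_t bytes = scr_view::shmem_size(tile_m, tile_k)   // svA_scr
                 + scr_view::shmem_size(tile_k, tile_n);  // svB_scr
    auto policy = Kokkos::TeamPolicy<>(1, Kokkos::AUTO)
                      .set_scratch_size(0, Kokkos::PerTeam(bytes));
    Kokkos::parallel_for(
        "alloc_scratch", policy,
        KOKKOS_LAMBDA(const Kokkos::TeamPolicy<>::member_type &member) {
          scr_view svA_scr(member.team_scratch(0), tile_m, tile_k);
          scr_view svB_scr(member.team_scratch(0), tile_k, tile_n);
          svA_scr(0, 0) = 1.0f;  // buffers live for the team's lifetime
          svB_scr(0, 0) = 2.0f;
        });
  }
  Kokkos::finalize();
  return 0;
}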

// Here we populate scratch memory with one or more "k" tiles for every
// thread of the team!
Kokkos::parallel_for(
-        Kokkos::TeamThreadRange(member, 0, __tile_n / REG_N),
+        Kokkos::TeamThreadRange(member, 0, __tile_k),
[&](const int &thread_id) {
-          auto thread_offset = thread_id + start_n;
Kokkos::parallel_for(
-              Kokkos::ThreadVectorRange(member, 0, __tile_k),
+              Kokkos::ThreadVectorRange(member, 0, __tile_n / REG_N),
[&](const int &vlane_id) {
+                auto vlane_offset = vlane_id + start_n;
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int i = 0; i < REG_N * STRIDE_N; i += STRIDE_N)
-                  svB_scr(vlane_id, thread_id + i) =
+                  svB_scr(thread_id, vlane_id + i) =
access_view_bounds_check<view_value_type>(
-                      svB, vlane_id, thread_offset + i,
+                      svB, thread_id, vlane_offset + i,
__ei.__bounds_check_tag);
});
});
@@ -347,10 +374,11 @@
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int i = 0; i < REG_M * STRIDE_M; i += STRIDE_M)
-                  svA_scr(vlane_id, thread_id + i) =
+                  svA_scr(thread_id + i, vlane_id) =
access_view_bounds_check<view_value_type>(
svA, thread_offset + i, vlane_id,
__ei.__bounds_check_tag);
+                // TODO: might be able to use local deep copy here.
});
});
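
The index swaps in the two hunks above are the heart of this commit's scratch changes: with Kokkos::LayoutRight scratch, the second index is the contiguous one, so making vlane_id the fastest-varying second index lets consecutive vector lanes touch consecutive addresses (coalesced on GPU backends). A tiny compile-time illustration of the addressing (hypothetical offset helper):

// Row-major (LayoutRight) addressing: S(r, c) lives at offset r * cols + c,
// so consecutive values of the *second* index are contiguous in memory.
constexpr int cols = 32;
constexpr int offset(int r, int c) { return r * cols + c; }
static_assert(offset(3, 5) + 1 == offset(3, 6), "second index is contiguous");
static_assert(offset(3, 5) + cols == offset(4, 5), "first index strides by cols");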

@@ -373,20 +401,20 @@
// Each thread has its own copy of prefetch_reg_b. TeamThreadRange runs
// over all threads in the team.
Kokkos::parallel_for(
-          Kokkos::TeamThreadRange(member, 0, __tile_n / REG_N),
-          [&](const int &thread_id) {
-            auto thread_offset = thread_id + start_n;
+          Kokkos::TeamThreadRange(member, k_tile_offset,
+                                  k_tile_offset + __tile_k),
+          [&](const int &thread_offset) {
Kokkos::parallel_for(
-                Kokkos::ThreadVectorRange(member, 0, __tile_k),
+                Kokkos::ThreadVectorRange(member, 0, __tile_n / REG_N),
[&](const int &vlane_id) {
+                  auto vlane_offset = vlane_id + start_n;
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int i = 0; i < REG_N; ++i)
prefetch_reg_b[i] =
access_view_bounds_check<view_value_type>(
-                        svB, vlane_id + k_tile_offset,
-                        thread_offset + i * STRIDE_N,
+                        svB, thread_offset, vlane_offset + i * STRIDE_N,
__ei.__bounds_check_tag);
});
});
@@ -399,16 +427,16 @@
[&](const int &thread_id) {
auto thread_offset = thread_id + start_m;
Kokkos::parallel_for(
-                Kokkos::ThreadVectorRange(member, 0, __tile_k),
+                Kokkos::ThreadVectorRange(member, k_tile_offset,
+                                          k_tile_offset + __tile_k),
[&](const int &vlane_id) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int i = 0; i < REG_M; ++i)
prefetch_reg_a[i] =
access_view_bounds_check<view_value_type>(
-                        svA, thread_offset + i * STRIDE_M,
-                        vlane_id + k_tile_offset,
+                        svA, thread_offset + i * STRIDE_M, vlane_id,
__ei.__bounds_check_tag);
});
});
@@ -424,17 +452,16 @@
    // populate shmem from the prefetch registers. Each thread has its own
    // copies of prefetch_reg_a and prefetch_reg_b.
Kokkos::parallel_for(
-          Kokkos::TeamThreadRange(member, 0, __tile_n / REG_N),
+          Kokkos::TeamThreadRange(member, 0, __tile_k),
[&](const int &thread_id) {
-            auto thread_offset = thread_id;
Kokkos::parallel_for(
-                Kokkos::ThreadVectorRange(member, 0, __tile_k),
+                Kokkos::ThreadVectorRange(member, 0, __tile_n / REG_N),
[&](const int &vlane_id) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int i = 0; i < REG_N; ++i)
-                  svB_scr(vlane_id, thread_offset + i * STRIDE_N) =
+                  svB_scr(thread_id, vlane_id + i * STRIDE_N) =
prefetch_reg_b[i];
});
});
@@ -444,15 +471,14 @@
Kokkos::parallel_for(
Kokkos::TeamThreadRange(member, 0, __tile_m / REG_M),
[&](const int &thread_id) {
-            auto thread_offset = thread_id;
Kokkos::parallel_for(
Kokkos::ThreadVectorRange(member, 0, __tile_k),
[&](const int &vlane_id) {
#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
#endif // KOKKOS_ENABLE_PRAGMA_UNROLL
for (int i = 0; i < REG_M; ++i)
-                  svA_scr(vlane_id, thread_offset + i * STRIDE_M) =
+                  svA_scr(thread_id + i * STRIDE_M, vlane_id) =
prefetch_reg_a[i];
});
});
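
Taken together, the hunks implement a double-buffered pipeline over k tiles: compute on the current scratch tile while the next tile is prefetched into registers, then flush those registers into scratch behind a barrier. A hedged structural sketch (the helper names are hypothetical stand-ins for the inline blocks in this file):

// Hedged structural sketch of the double-buffered k-loop; stubs only.
static void prefetch_into_scratch(int) {}
static void prefetch_into_registers(int) {}
static void multiply_accumulate_from_scratch() {}
static void flush_registers_to_scratch() {}
static void team_barrier() {}
static void write_back_c() {}

void double_buffered_gemm_sketch(int k_tiles) {
  prefetch_into_scratch(0);             // fill scratch with the first tile
  team_barrier();
  for (int kt = 1; kt < k_tiles; ++kt) {
    prefetch_into_registers(kt);        // overlap next-tile loads with FMAs
    multiply_accumulate_from_scratch(); // consume the current tile
    team_barrier();
    flush_registers_to_scratch();       // publish the prefetched tile
    team_barrier();
  }
  multiply_accumulate_from_scratch();   // last tile has no successor
  write_back_c();
}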
