Skip to content

Commit

Permalink
Merge pull request #1050 from iyamazaki/dot-based-gemm
Browse files Browse the repository at this point in the history
move dot-based GEMM out of TPL CUBLAS..
  • Loading branch information
e10harvey authored Jul 26, 2021
2 parents 51d241c + 6d6ee66 commit 835f2ca
Show file tree
Hide file tree
Showing 3 changed files with 313 additions and 229 deletions.
209 changes: 209 additions & 0 deletions src/blas/impl/KokkosBlas3_gemm_dotbased_impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
/*
//@HEADER
// ************************************************************************
//
// Kokkos v. 3.0
// Copyright (2020) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions are
// met:
//
// 1. Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
//
// 2. Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
//
// 3. Neither the name of the Corporation nor the names of the
// contributors may be used to endorse or promote products derived from
// this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
// Questions? Contact Siva Rajamanickam (srajama@sandia.gov)
//
// ************************************************************************
//@HEADER
*/

#ifndef KOKKOS_BLAS3_GEMM_DOTBASED_IMPL_HPP_
#define KOKKOS_BLAS3_GEMM_DOTBASED_IMPL_HPP_

namespace KokkosBlas {
namespace Impl {


// DotBasedGEMM implements the optimization for C = beta*C + alpha*A^TB
// with A and B matrices both being tall and skinny. C matrix is assumably
// small, so, each entry of C is computed by performing the dot product of
// respective columns of A and B matrices. Note that the dot products are
// performed on very long vectors, so, each dot product is distributed among
// numDivPerDot teams.

struct TagZero{}; // The init tag for beta=0
struct TagInit{}; // The init tag for beta!=0 and beta !=1
struct TagMult{}; // The multiplication tag for transposed A
struct TagMultCT{}; // The multiplication tag for conjugate-transposed A
template<class ExecSpace, class AV, class BV, class CV>
struct DotBasedGEMM{

const AV A;
const BV B;
CV C;

using scalar_A = typename AV::non_const_value_type;
using size_A = typename AV::size_type;
using scalar_C = typename CV::non_const_value_type;
using size_C = typename CV::size_type;
using AVT = Kokkos::Details::ArithTraits<scalar_A>;
using CVT = Kokkos::Details::ArithTraits<scalar_C>;

const scalar_A alpha;
const scalar_C beta;

const size_C numCrows;
const size_C numCcols;

size_C numDivPerDot; // number of teams collectively performing a dot product
size_C numTeams; // total number of teams

const size_A dotSize; // the length of the vectors in the dot products
size_A chunkSize; // the local length of each team's share on the dot product


DotBasedGEMM(const scalar_A& alpha_, const AV& A_, const BV& B_, const scalar_C& beta_, const CV& C_) :
A(A_), B(B_), C(C_), alpha(alpha_), beta(beta_),
numCrows(C.extent(0)), numCcols(C.extent(1)), dotSize(A.extent(0))
{ }

void run(bool conjugateTranspose) {

// NOTE: these workPerTeam and approxNumTeams were used for TPL CUBLAS,
// and may need to be retuned for other architectures
constexpr size_C workPerTeam = 4096; // Amount of work per team
const size_C ndots = numCrows * numCcols; // Number of dot products
size_C appxNumTeams = (dotSize * ndots) / workPerTeam; // Estimation for appxNumTeams

// Adjust appxNumTeams in case it is too small or too large
if(appxNumTeams < 1)
appxNumTeams = 1;
if(appxNumTeams > 1024)
appxNumTeams = 1024;

// If there are more dot products than the number of teams,
// then set the number of teams to be number of dot products
// and each team will perform only one dot product.
// We don't want a team to perform more than one dot product.
if(ndots >= appxNumTeams) {
numTeams = ndots;
numDivPerDot = 1;
}
// If there are more teams than dot products, each dot product can
// potentially be performed by multiple teams. First, compute
// numDivPerDot as an integer (take the floor, not ceiling), then,
// compute actual number of teams by using this factor.
else {
numDivPerDot = appxNumTeams / ndots;
numTeams = ndots * numDivPerDot;
}

// Determine the local length for the dot product
chunkSize = dotSize / numDivPerDot;
if(numDivPerDot > 1)
chunkSize++;

// Initialize C matrix if beta != 1
if(beta == CVT::zero()) {
Kokkos::MDRangePolicy<TagZero, ExecSpace, Kokkos::Rank<2>> policyInit({0,0}, {numCrows, numCcols});
Kokkos::parallel_for("Initialize C for Dot Product Based GEMM", policyInit, *this);
}
else if(beta != CVT::one()) {
Kokkos::MDRangePolicy<TagInit, ExecSpace, Kokkos::Rank<2>> policyInit({0,0}, {numCrows, numCcols});
Kokkos::parallel_for("Initialize C for Dot Product Based GEMM", policyInit, *this);
}

// Multiply alpha*A^TB and add it to beta*C
if(conjugateTranspose) {
Kokkos::TeamPolicy<TagMultCT, ExecSpace> policyMult(numTeams, Kokkos::AUTO);
Kokkos::parallel_for("Perform Dot Product Based GEMM", policyMult, *this);
}
else{
Kokkos::TeamPolicy<TagMult, ExecSpace> policyMult(numTeams, Kokkos::AUTO);
Kokkos::parallel_for("Perform Dot Product Based GEMM", policyMult, *this);
}
}

KOKKOS_INLINE_FUNCTION
void operator() (const TagZero&, const size_C &rowId, const size_C &colId ) const {
C(rowId, colId) = CVT::zero();
}

KOKKOS_INLINE_FUNCTION
void operator() (const TagInit&, const size_C &rowId, const size_C &colId ) const {
C(rowId, colId) = beta * C(rowId, colId);
}

KOKKOS_INLINE_FUNCTION
void operator() (const TagMult&, const typename Kokkos::TeamPolicy<ExecSpace>::member_type& teamMember) const {

const size_C globalRank = teamMember.league_rank();
const size_C localRank = globalRank % numDivPerDot;
const size_C i = globalRank / numDivPerDot;
const size_C rowId = i / numCcols;
const size_C colId = i % numCcols;

scalar_C result = CVT::zero();
const size_A baseInd = chunkSize*localRank;
Kokkos::parallel_reduce( Kokkos::TeamThreadRange(teamMember, chunkSize), [&]( const size_A k, scalar_C &update ) {
if(baseInd + k < dotSize)
update += alpha * A(baseInd+k, rowId) * B(baseInd+k, colId);
}, result );

Kokkos::single(Kokkos::PerTeam(teamMember), [&] () {
Kokkos::atomic_add(&C(rowId, colId), result);
});
}

KOKKOS_INLINE_FUNCTION
void operator() (const TagMultCT&, const typename Kokkos::TeamPolicy<ExecSpace>::member_type& teamMember) const {

const size_C globalRank = teamMember.league_rank();
const size_C localRank = globalRank % numDivPerDot;
const size_C i = globalRank / numDivPerDot;
const size_C rowId = i / numCcols;
const size_C colId = i % numCcols;

scalar_C result = CVT::zero();
const size_A baseInd = chunkSize*localRank;
Kokkos::parallel_reduce( Kokkos::TeamThreadRange(teamMember, chunkSize), [&]( const size_A k, scalar_C &update ) {
if(baseInd + k < dotSize)
update += alpha * AVT::conj(A(baseInd+k, rowId)) * B(baseInd+k, colId);
}, result );

Kokkos::single(Kokkos::PerTeam(teamMember), [&] () {
Kokkos::atomic_add(&C(rowId, colId), result);
});
}

};

}
}

#endif
163 changes: 95 additions & 68 deletions src/blas/impl/KokkosBlas3_gemm_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,9 @@
#include "Kokkos_InnerProductSpaceTraits.hpp"

#if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY
#include<KokkosBlas3_gemm_impl.hpp>
#include "KokkosBlas3_gemm_impl.hpp"
#include "KokkosBlas3_gemm_dotbased_impl.hpp"
#include "KokkosKernels_ExecSpaceUtils.hpp"
#endif

namespace KokkosBlas {
Expand Down Expand Up @@ -134,74 +136,99 @@ struct GEMM {
typedef typename AViewType::non_const_value_type ScalarA;
typedef typename BViewType::non_const_value_type ScalarB;
typedef typename CViewType::non_const_value_type ScalarC;
typedef typename CViewType::execution_space ExecSpace;

// Define Blocking sizes (this will be used for scratch spaces)
static constexpr int blockA0 = 24;
static constexpr int blockB1 = 64;
static constexpr int blockA1 = (sizeof(ScalarA)*blockA0*16 + sizeof(ScalarB)*16*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 16 :
(sizeof(ScalarA)*blockA0*8 + sizeof(ScalarB)*8*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 8 :
(sizeof(ScalarA)*blockA0*4 + sizeof(ScalarB)*4*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 4 : 16 ;
static constexpr int vector_length = blockB1/4;

// Compute scratch space size
typedef KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,0,0> gemm_dummy_type;
const int scratch_memory_size =
gemm_dummy_type::ViewTypeAScratch::required_allocation_size() +
gemm_dummy_type::ViewTypeBScratch::required_allocation_size() +
gemm_dummy_type::ViewTypeCScratch::required_allocation_size();
const int scratch_level = scratch_memory_size < 24000 ? 0 : 1;

// Figure out Team Sizes
int team_size = 1;
#if defined(KOKKOS_ENABLE_CUDA)
if(std::is_same<typename CViewType::execution_space,Kokkos::Cuda>::value)
team_size = blockA0;
#endif
#if defined(KOKKOS_ENABLE_HIP)
if(std::is_same<typename CViewType::execution_space,Kokkos::Experimental::HIP>::value)
team_size = blockA0;
#endif
#if defined(KOKKOS_ENABLE_ROCM)
if(std::is_same<typename CViewType::execution_space,Kokkos::ROCm>::value)
team_size = blockA0;
#endif

// Call the correct kernel
if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='N' || transB[0]=='n')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,0,0> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
}
if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='N' || transB[0]=='n')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,1,0> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
}
if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='N' || transB[0]=='n')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,2,0> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
}
if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='T' || transB[0]=='t')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,0,1> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
}
if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='T' || transB[0]=='t')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,1,1> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
}
if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='T' || transB[0]=='t')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,2,1> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
}
if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='C' || transB[0]=='c')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,0,2> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
}
if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='C' || transB[0]=='c')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,1,2> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
}
if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='C' || transB[0]=='c')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,2,2> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
// Figure out whether to use DotBased implementation
const int M = static_cast<int> (C.extent(0));
const int N = static_cast<int> (C.extent(1));

const bool is_device_space = KokkosKernels::Impl::kk_is_gpu_exec_space<ExecSpace>();
const bool A_is_lr = std::is_same<Kokkos::LayoutRight, typename AViewType::array_layout>::value;
const bool A_is_tr = ((transA[0]=='T') || (transA[0]=='t') || (transA[0]=='C') || (transA[0]=='c'));
const bool B_is_tr = ((transB[0]=='T') || (transB[0]=='t') || (transB[0]=='C') || (transB[0]=='c'));

// NOTE: these thresholds were copied from TPL CUBLAS, and may need to be retuned
constexpr int numDotsLayoutLeftThreshold = 1600;
constexpr int numDotsLayoutRightThreshold = 100;
if(( (!A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutLeftThreshold)
|| ( A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutRightThreshold))
&& is_device_space) {

// call dot-based GEMM, only for C := beta * C + alpha * A^T * B, on device
bool A_is_conj = ((transA[0]=='C') || (transA[0]=='c'));
DotBasedGEMM<ExecSpace, AViewType, BViewType, CViewType> dotBasedGemm(alpha, A, B, beta, C);
dotBasedGemm.run(A_is_conj);

} else {

// Define Blocking sizes (this will be used for scratch spaces)
static constexpr int blockA0 = 24;
static constexpr int blockB1 = 64;
static constexpr int blockA1 = (sizeof(ScalarA)*blockA0*16 + sizeof(ScalarB)*16*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 16 :
(sizeof(ScalarA)*blockA0*8 + sizeof(ScalarB)*8*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 8 :
(sizeof(ScalarA)*blockA0*4 + sizeof(ScalarB)*4*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 4 : 16 ;
static constexpr int vector_length = blockB1/4;

// Compute scratch space size
typedef KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,0,0> gemm_dummy_type;
const int scratch_memory_size =
gemm_dummy_type::ViewTypeAScratch::required_allocation_size() +
gemm_dummy_type::ViewTypeBScratch::required_allocation_size() +
gemm_dummy_type::ViewTypeCScratch::required_allocation_size();
const int scratch_level = scratch_memory_size < 24000 ? 0 : 1;

// Figure out Team Sizes
int team_size = 1;
#if defined(KOKKOS_ENABLE_CUDA)
if(std::is_same<typename CViewType::execution_space,Kokkos::Cuda>::value)
team_size = blockA0;
#endif
#if defined(KOKKOS_ENABLE_HIP)
if(std::is_same<typename CViewType::execution_space,Kokkos::Experimental::HIP>::value)
team_size = blockA0;
#endif
#if defined(KOKKOS_ENABLE_ROCM)
if(std::is_same<typename CViewType::execution_space,Kokkos::ROCm>::value)
team_size = blockA0;
#endif

// Call the correct kernel
if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='N' || transB[0]=='n')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,0,0> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
}
if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='N' || transB[0]=='n')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,1,0> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
}
if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='N' || transB[0]=='n')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,2,0> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
}
if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='T' || transB[0]=='t')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,0,1> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
}
if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='T' || transB[0]=='t')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,1,1> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
}
if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='T' || transB[0]=='t')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,2,1> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
}
if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='C' || transB[0]=='c')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,0,2> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
}
if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='C' || transB[0]=='c')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,1,2> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
}
if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='C' || transB[0]=='c')) {
KokkosBlas::Impl::GEMMImpl<typename CViewType::execution_space,AViewType,BViewType,CViewType,blockA0,blockA1,blockB1,2,2> gemm(alpha,A,B,beta,C);
gemm.run(team_size,vector_length,scratch_level);
}
}
Kokkos::Profiling::popRegion();
}
Expand Down
Loading

0 comments on commit 835f2ca

Please sign in to comment.