From 40410ecb0df832e48e33b354d847515517cce93b Mon Sep 17 00:00:00 2001 From: iyamaza Date: Tue, 13 Jul 2021 15:21:38 -0600 Subject: [PATCH 1/8] move dot-based GEMM out of TPL CUBLAS --- .../impl/KokkosBlas3_gemm_dotbased_impl.hpp | 206 ++++++++++++++++ src/blas/impl/KokkosBlas3_gemm_spec.hpp | 160 +++++++----- .../tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp | 229 ++---------------- 3 files changed, 319 insertions(+), 276 deletions(-) create mode 100644 src/blas/impl/KokkosBlas3_gemm_dotbased_impl.hpp diff --git a/src/blas/impl/KokkosBlas3_gemm_dotbased_impl.hpp b/src/blas/impl/KokkosBlas3_gemm_dotbased_impl.hpp new file mode 100644 index 0000000000..775443d991 --- /dev/null +++ b/src/blas/impl/KokkosBlas3_gemm_dotbased_impl.hpp @@ -0,0 +1,206 @@ +/* +//@HEADER +// ************************************************************************ +// +// Kokkos v. 3.0 +// Copyright (2020) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// 1. Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// +// 2. Redistributions in binary form must reproduce the above copyright +// notice, this list of conditions and the following disclaimer in the +// documentation and/or other materials provided with the distribution. +// +// 3. Neither the name of the Corporation nor the names of the +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY NTESS "AS IS" AND ANY +// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NTESS OR THE +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING +// NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Questions? Contact Siva Rajamanickam (srajama@sandia.gov) +// +// ************************************************************************ +//@HEADER +*/ + +#ifndef KOKKOS_BLAS3_GEMM_DOTBASED_IMPL_HPP_ +#define KOKKOS_BLAS3_GEMM_DOTBASED_IMPL_HPP_ + +namespace KokkosBlas { +namespace Impl { + + +// DotBasedGEMM implements the optimization for C = beta*C + alpha*A^TB +// with A and B matrices both being tall and skinny. C matrix is assumably +// small, so, each entry of C is computed by performing the dot product of +// respective columns of A and B matrices. Note that the dot products are +// performed on very long vectors, so, each dot product is distributed among +// numDivPerDot teams. + +struct TagZero{}; // The init tag for beta=0 +struct TagInit{}; // The init tag for beta!=0 and beta !=1 +struct TagMult{}; // The multiplication tag for transposed A +struct TagMultCT{}; // The multiplication tag for conjugate-transposed A +template +struct DotBasedGEMM{ + + const AV A; + const BV B; + CV C; + + using scalar_A = typename AV::non_const_value_type; + using size_A = typename AV::size_type; + using scalar_C = typename CV::non_const_value_type; + using size_C = typename CV::size_type; + using AVT = Kokkos::Details::ArithTraits; + using CVT = Kokkos::Details::ArithTraits; + + const scalar_A alpha; + const scalar_C beta; + + // The following types (especially dotSize) could have simply been int, + const size_C numCrows; + const size_C numCcols; + + size_C numDivPerDot; // number of teams collectively performing a dot product + size_C numTeams; // total number of teams + + const size_A dotSize; // the length of the vectors in the dot products + size_A chunkSize; // the local length of each team's share on the dot product + + + DotBasedGEMM(const scalar_A& alpha_, const AV& A_, const BV& B_, const scalar_C& beta_, const CV& C_):A(A_),B(B_),C(C_),alpha(alpha_),beta(beta_),numCrows(C.extent(0)),numCcols(C.extent(1)),dotSize(A.extent(0)) + { } + + void run(bool conjugateTranspose) { + + constexpr size_C workPerTeam = 4096; // Amount of work per team + const size_C ndots = numCrows * numCcols; // Number of dot products + size_C appxNumTeams = (dotSize * ndots) / workPerTeam; // Estimation for appxNumTeams + + // Adjust appxNumTeams in case it is too small or too large + if(appxNumTeams < 1) + appxNumTeams = 1; + if(appxNumTeams > 1024) + appxNumTeams = 1024; + + // If there are more dot products than the number of teams, + // then set the number of teams to be number of dot products + // and each team will perform only one dot product. + // We don't want a team to perform more than one dot product. + if(ndots >= appxNumTeams) { + numTeams = ndots; + numDivPerDot = 1; + } + // If there are more teams than dot products, each dot product can + // potentially be performed by multiple teams. First, compute + // numDivPerDot as an integer (take the floor, not ceiling), then, + // compute actual number of teams by using this factor. + else{ + numDivPerDot = appxNumTeams / ndots; + numTeams = ndots * numDivPerDot; + } + + // Determine the local length for the dot product + chunkSize = dotSize / numDivPerDot; + if(numDivPerDot > 1) + chunkSize++; + + // Initialize C matrix if beta != 1 + if(beta == CVT::zero()) { + Kokkos::MDRangePolicy> policyInit({0,0}, {numCrows, numCcols}); + Kokkos::parallel_for("Initialize C for Dot Product Based GEMM", policyInit, *this); + } + else if(beta != CVT::one()) { + Kokkos::MDRangePolicy> policyInit({0,0}, {numCrows, numCcols}); + Kokkos::parallel_for("Initialize C for Dot Product Based GEMM", policyInit, *this); + } + + // Multiply alpha*A^TB and add it to beta*C + if(conjugateTranspose) { + Kokkos::TeamPolicy policyMult(numTeams, Kokkos::AUTO); + Kokkos::parallel_for("Perform Dot Product Based GEMM", policyMult, *this); + } + else{ + Kokkos::TeamPolicy policyMult(numTeams, Kokkos::AUTO); + Kokkos::parallel_for("Perform Dot Product Based GEMM", policyMult, *this); + } + } + + KOKKOS_INLINE_FUNCTION + void operator() (const TagZero&, const size_C &rowId, const size_C &colId ) const { + C(rowId, colId) = CVT::zero(); + } + + KOKKOS_INLINE_FUNCTION + void operator() (const TagInit&, const size_C &rowId, const size_C &colId ) const { + C(rowId, colId) = beta * C(rowId, colId); + } + + KOKKOS_INLINE_FUNCTION + void operator() (const TagMult&, const typename Kokkos::TeamPolicy::member_type& teamMember) const { + + const size_C globalRank = teamMember.league_rank(); + const size_C localRank = globalRank % numDivPerDot; + const size_C i = globalRank / numDivPerDot; + const size_C rowId = i / numCcols; + const size_C colId = i % numCcols; + + scalar_C result = CVT::zero(); + const size_A baseInd = chunkSize*localRank; + Kokkos::parallel_reduce( Kokkos::TeamThreadRange(teamMember, chunkSize), [&]( const size_A k, scalar_C &update ) { + if(baseInd + k < dotSize) + update += alpha * A(baseInd+k, rowId) * B(baseInd+k, colId); + }, result ); + + Kokkos::single(Kokkos::PerTeam(teamMember), [&] () { + Kokkos::atomic_add(&C(rowId, colId), result); + }); + } + + KOKKOS_INLINE_FUNCTION + void operator() (const TagMultCT&, const typename Kokkos::TeamPolicy::member_type& teamMember) const { + + const size_C globalRank = teamMember.league_rank(); + const size_C localRank = globalRank % numDivPerDot; + const size_C i = globalRank / numDivPerDot; + const size_C rowId = i / numCcols; + const size_C colId = i % numCcols; + + scalar_C result = CVT::zero(); + const size_A baseInd = chunkSize*localRank; + Kokkos::parallel_reduce( Kokkos::TeamThreadRange(teamMember, chunkSize), [&]( const size_A k, scalar_C &update ) { + if(baseInd + k < dotSize) + update += alpha * AVT::conj(A(baseInd+k, rowId)) * B(baseInd+k, colId); + }, result ); + + Kokkos::single(Kokkos::PerTeam(teamMember), [&] () { + Kokkos::atomic_add(&C(rowId, colId), result); + }); + } + +}; + +} +} + +#endif diff --git a/src/blas/impl/KokkosBlas3_gemm_spec.hpp b/src/blas/impl/KokkosBlas3_gemm_spec.hpp index 2a63c3736f..ca05297bda 100644 --- a/src/blas/impl/KokkosBlas3_gemm_spec.hpp +++ b/src/blas/impl/KokkosBlas3_gemm_spec.hpp @@ -50,6 +50,7 @@ #if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY #include +#include #endif namespace KokkosBlas { @@ -135,73 +136,98 @@ struct GEMM { typedef typename BViewType::non_const_value_type ScalarB; typedef typename CViewType::non_const_value_type ScalarC; - // Define Blocking sizes (this will be used for scratch spaces) - static constexpr int blockA0 = 24; - static constexpr int blockB1 = 64; - static constexpr int blockA1 = (sizeof(ScalarA)*blockA0*16 + sizeof(ScalarB)*16*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 16 : - (sizeof(ScalarA)*blockA0*8 + sizeof(ScalarB)*8*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 8 : - (sizeof(ScalarA)*blockA0*4 + sizeof(ScalarB)*4*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 4 : 16 ; - static constexpr int vector_length = blockB1/4; - - // Compute scratch space size - typedef KokkosBlas::Impl::GEMMImpl gemm_dummy_type; - const int scratch_memory_size = - gemm_dummy_type::ViewTypeAScratch::required_allocation_size() + - gemm_dummy_type::ViewTypeBScratch::required_allocation_size() + - gemm_dummy_type::ViewTypeCScratch::required_allocation_size(); - const int scratch_level = scratch_memory_size < 24000 ? 0 : 1; - - // Figure out Team Sizes - int team_size = 1; - #if defined(KOKKOS_ENABLE_CUDA) - if(std::is_same::value) - team_size = blockA0; - #endif - #if defined(KOKKOS_ENABLE_HIP) - if(std::is_same::value) - team_size = blockA0; - #endif - #if defined(KOKKOS_ENABLE_ROCM) - if(std::is_same::value) - team_size = blockA0; - #endif - - // Call the correct kernel - if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='N' || transB[0]=='n')) { - KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); - gemm.run(team_size,vector_length,scratch_level); - } - if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='N' || transB[0]=='n')) { - KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); - gemm.run(team_size,vector_length,scratch_level); - } - if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='N' || transB[0]=='n')) { - KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); - gemm.run(team_size,vector_length,scratch_level); - } - if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='T' || transB[0]=='t')) { - KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); - gemm.run(team_size,vector_length,scratch_level); - } - if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='T' || transB[0]=='t')) { - KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); - gemm.run(team_size,vector_length,scratch_level); - } - if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='T' || transB[0]=='t')) { - KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); - gemm.run(team_size,vector_length,scratch_level); - } - if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='C' || transB[0]=='c')) { - KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); - gemm.run(team_size,vector_length,scratch_level); - } - if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='C' || transB[0]=='c')) { - KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); - gemm.run(team_size,vector_length,scratch_level); - } - if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='C' || transB[0]=='c')) { - KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); - gemm.run(team_size,vector_length,scratch_level); + +#if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY + typedef typename CViewType::execution_space ExecSpace; + + // Figure out if we used use DotBased implementation + const int M = static_cast (C.extent(0)); + const int N = static_cast (C.extent(1)); + + const bool A_is_lr = std::is_same::value; + const bool A_is_tr = ((transA[0]=='T') || (transA[0]=='t') || (transA[0]=='C') || (transA[0]=='c')); + const bool B_is_tr = ((transB[0]=='T') || (transB[0]=='t') || (transB[0]=='C') || (transB[0]=='c')); + + constexpr int numDotsLayoutLeftThreshold = 1600; + constexpr int numDotsLayoutRightThreshold = 100; + if( (!A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutLeftThreshold) + || ( A_is_lr && !A_is_tr && B_is_tr && M*N < numDotsLayoutRightThreshold)) { + // call dot-based GEMM + bool A_is_conj = ((transA[0]=='C') || (transA[0]=='c')); + DotBasedGEMM dotBasedGemm(alpha, A, B, beta, C); + dotBasedGemm.run(A_is_conj ? true : false); + } else +#endif + { + + // Define Blocking sizes (this will be used for scratch spaces) + static constexpr int blockA0 = 24; + static constexpr int blockB1 = 64; + static constexpr int blockA1 = (sizeof(ScalarA)*blockA0*16 + sizeof(ScalarB)*16*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 16 : + (sizeof(ScalarA)*blockA0*8 + sizeof(ScalarB)*8*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 8 : + (sizeof(ScalarA)*blockA0*4 + sizeof(ScalarB)*4*blockB1 + sizeof(ScalarC)*blockA0*blockB1 < 24000) ? 4 : 16 ; + static constexpr int vector_length = blockB1/4; + + // Compute scratch space size + typedef KokkosBlas::Impl::GEMMImpl gemm_dummy_type; + const int scratch_memory_size = + gemm_dummy_type::ViewTypeAScratch::required_allocation_size() + + gemm_dummy_type::ViewTypeBScratch::required_allocation_size() + + gemm_dummy_type::ViewTypeCScratch::required_allocation_size(); + const int scratch_level = scratch_memory_size < 24000 ? 0 : 1; + + // Figure out Team Sizes + int team_size = 1; + #if defined(KOKKOS_ENABLE_CUDA) + if(std::is_same::value) + team_size = blockA0; + #endif + #if defined(KOKKOS_ENABLE_HIP) + if(std::is_same::value) + team_size = blockA0; + #endif + #if defined(KOKKOS_ENABLE_ROCM) + if(std::is_same::value) + team_size = blockA0; + #endif + + // Call the correct kernel + if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='N' || transB[0]=='n')) { + KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); + gemm.run(team_size,vector_length,scratch_level); + } + if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='N' || transB[0]=='n')) { + KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); + gemm.run(team_size,vector_length,scratch_level); + } + if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='N' || transB[0]=='n')) { + KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); + gemm.run(team_size,vector_length,scratch_level); + } + if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='T' || transB[0]=='t')) { + KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); + gemm.run(team_size,vector_length,scratch_level); + } + if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='T' || transB[0]=='t')) { + KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); + gemm.run(team_size,vector_length,scratch_level); + } + if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='T' || transB[0]=='t')) { + KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); + gemm.run(team_size,vector_length,scratch_level); + } + if((transA[0]=='N' || transA[0]=='n') && (transB[0]=='C' || transB[0]=='c')) { + KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); + gemm.run(team_size,vector_length,scratch_level); + } + if((transA[0]=='T' || transA[0]=='t') && (transB[0]=='C' || transB[0]=='c')) { + KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); + gemm.run(team_size,vector_length,scratch_level); + } + if((transA[0]=='C' || transA[0]=='c') && (transB[0]=='C' || transB[0]=='c')) { + KokkosBlas::Impl::GEMMImpl gemm(alpha,A,B,beta,C); + gemm.run(team_size,vector_length,scratch_level); + } } Kokkos::Profiling::popRegion(); } diff --git a/src/impl/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp b/src/impl/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp index 3e4bc5ed53..c6cf04c4a6 100644 --- a/src/impl/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp +++ b/src/impl/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp @@ -336,159 +336,6 @@ KOKKOSBLAS3_CGEMM_BLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::Layout namespace KokkosBlas { namespace Impl { - -// DotBasedGEMM implements the optimization for C = beta*C + alpha*A^TB -// with A and B matrices both being tall and skinny. C matrix is assumably -// small, so, each entry of C is computed by performing the dot product of -// respective columns of A and B matrices. Note that the dot products are -// performed on very long vectors, so, each dot product is distributed among -// numDivPerDot teams. - -struct TagZero{}; // The init tag for beta=0 -struct TagInit{}; // The init tag for beta!=0 and beta !=1 -struct TagMult{}; // The multiplication tag for transposed A -struct TagMultCT{}; // The multiplication tag for conjugate-transposed A -template -struct DotBasedGEMM{ - - const AV A; - const BV B; - CV C; - - using scalar_A = typename AV::non_const_value_type; - using size_A = typename AV::size_type; - using scalar_C = typename CV::non_const_value_type; - using size_C = typename CV::size_type; - using AVT = Kokkos::Details::ArithTraits; - using CVT = Kokkos::Details::ArithTraits; - - const scalar_A alpha; - const scalar_C beta; - - // The following types (especially dotSize) could have simply been int, - const size_C numCrows; - const size_C numCcols; - - size_C numDivPerDot; // number of teams collectively performing a dot product - size_C numTeams; // total number of teams - - const size_A dotSize; // the length of the vectors in the dot products - size_A chunkSize; // the local length of each team's share on the dot product - - - DotBasedGEMM(const scalar_A& alpha_, const AV& A_, const BV& B_, const scalar_C& beta_, const CV& C_):A(A_),B(B_),C(C_),alpha(alpha_),beta(beta_),numCrows(C.extent(0)),numCcols(C.extent(1)),dotSize(A.extent(0)) - { } - - void run(bool conjugateTranspose) { - - constexpr size_C workPerTeam = 4096; // Amount of work per team - const size_C ndots = numCrows * numCcols; // Number of dot products - size_C appxNumTeams = (dotSize * ndots) / workPerTeam; // Estimation for appxNumTeams - - // Adjust appxNumTeams in case it is too small or too large - if(appxNumTeams < 1) - appxNumTeams = 1; - if(appxNumTeams > 1024) - appxNumTeams = 1024; - - // If there are more dot products than the number of teams, - // then set the number of teams to be number of dot products - // and each team will perform only one dot product. - // We don't want a team to perform more than one dot product. - if(ndots >= appxNumTeams) { - numTeams = ndots; - numDivPerDot = 1; - } - // If there are more teams than dot products, each dot product can - // potentially be performed by multiple teams. First, compute - // numDivPerDot as an integer (take the floor, not ceiling), then, - // compute actual number of teams by using this factor. - else{ - numDivPerDot = appxNumTeams / ndots; - numTeams = ndots * numDivPerDot; - } - - // Determine the local length for the dot product - chunkSize = dotSize / numDivPerDot; - if(numDivPerDot > 1) - chunkSize++; - - // Initialize C matrix if beta != 1 - if(beta == CVT::zero()) { - Kokkos::MDRangePolicy> policyInit({0,0}, {numCrows, numCcols}); - Kokkos::parallel_for("Initialize C for Dot Product Based GEMM", policyInit, *this); - } - else if(beta != CVT::one()) { - Kokkos::MDRangePolicy> policyInit({0,0}, {numCrows, numCcols}); - Kokkos::parallel_for("Initialize C for Dot Product Based GEMM", policyInit, *this); - } - - // Multiply alpha*A^TB and add it to beta*C - if(conjugateTranspose) { - Kokkos::TeamPolicy policyMult(numTeams, Kokkos::AUTO); - Kokkos::parallel_for("Perform Dot Product Based GEMM", policyMult, *this); - } - else{ - Kokkos::TeamPolicy policyMult(numTeams, Kokkos::AUTO); - Kokkos::parallel_for("Perform Dot Product Based GEMM", policyMult, *this); - } - } - - KOKKOS_INLINE_FUNCTION - void operator() (const TagZero&, const size_C &rowId, const size_C &colId ) const { - C(rowId, colId) = CVT::zero(); - } - - KOKKOS_INLINE_FUNCTION - void operator() (const TagInit&, const size_C &rowId, const size_C &colId ) const { - C(rowId, colId) = beta * C(rowId, colId); - } - - KOKKOS_INLINE_FUNCTION - void operator() (const TagMult&, const typename Kokkos::TeamPolicy<>::member_type& teamMember) const { - - const size_C globalRank = teamMember.league_rank(); - const size_C localRank = globalRank % numDivPerDot; - const size_C i = globalRank / numDivPerDot; - const size_C rowId = i / numCcols; - const size_C colId = i % numCcols; - - scalar_C result = CVT::zero(); - const size_A baseInd = chunkSize*localRank; - Kokkos::parallel_reduce( Kokkos::TeamThreadRange(teamMember, chunkSize), [&]( const size_A k, scalar_C &update ) { - if(baseInd + k < dotSize) - update += alpha * A(baseInd+k, rowId) * B(baseInd+k, colId); - }, result ); - - Kokkos::single(Kokkos::PerTeam(teamMember), [&] () { - Kokkos::atomic_add(&C(rowId, colId), result); - }); - } - - KOKKOS_INLINE_FUNCTION - void operator() (const TagMultCT&, const typename Kokkos::TeamPolicy<>::member_type& teamMember) const { - - const size_C globalRank = teamMember.league_rank(); - const size_C localRank = globalRank % numDivPerDot; - const size_C i = globalRank / numDivPerDot; - const size_C rowId = i / numCcols; - const size_C colId = i % numCcols; - - scalar_C result = CVT::zero(); - const size_A baseInd = chunkSize*localRank; - Kokkos::parallel_reduce( Kokkos::TeamThreadRange(teamMember, chunkSize), [&]( const size_A k, scalar_C &update ) { - if(baseInd + k < dotSize) - update += alpha * AVT::conj(A(baseInd+k, rowId)) * B(baseInd+k, colId); - }, result ); - - Kokkos::single(Kokkos::PerTeam(teamMember), [&] () { - Kokkos::atomic_add(&C(rowId, colId), result); - }); - } - -}; - - #define KOKKOSBLAS3_DGEMM_CUBLAS( LAYOUTA, LAYOUTB, LAYOUTC, MEM_SPACE, ETI_SPEC_AVAIL ) \ template \ struct GEMM< \ @@ -544,20 +391,11 @@ struct GEMM< \ else \ transb = CUBLAS_OP_C; \ \ - constexpr int numDotsLayoutLeftThreshold = 1600; \ - constexpr int numDotsLayoutRightThreshold = 100; \ - if( (!A_is_lr && transa != CUBLAS_OP_N && transb == CUBLAS_OP_N && M*N < numDotsLayoutLeftThreshold) \ - || ( A_is_lr && transa != CUBLAS_OP_N && transb == CUBLAS_OP_N && M*N < numDotsLayoutRightThreshold)) { \ - DotBasedGEMM gemm(alpha,A,B,beta,C); \ - gemm.run(false); \ - } \ - else { \ - KokkosBlas::Impl::CudaBlasSingleton & s = KokkosBlas::Impl::CudaBlasSingleton::singleton(); \ - if(!A_is_lr && !B_is_lr && !C_is_lr ) \ - cublasDgemm(s.handle, transa, transb, M, N, K, &alpha, A.data(), LDA, B.data(), LDB, &beta, C.data(), LDC); \ - if(A_is_lr && B_is_lr && C_is_lr ) \ - cublasDgemm(s.handle, transb, transa, N, M, K, &alpha, B.data(), LDB, A.data(), LDA, &beta, C.data(), LDC); \ - } \ + KokkosBlas::Impl::CudaBlasSingleton & s = KokkosBlas::Impl::CudaBlasSingleton::singleton(); \ + if(!A_is_lr && !B_is_lr && !C_is_lr ) \ + cublasDgemm(s.handle, transa, transb, M, N, K, &alpha, A.data(), LDA, B.data(), LDB, &beta, C.data(), LDC); \ + if(A_is_lr && B_is_lr && C_is_lr ) \ + cublasDgemm(s.handle, transb, transa, N, M, K, &alpha, B.data(), LDB, A.data(), LDA, &beta, C.data(), LDC); \ Kokkos::Profiling::popRegion(); \ } \ }; @@ -617,20 +455,11 @@ struct GEMM< \ else \ transb = CUBLAS_OP_C; \ \ - constexpr int numDotsLayoutLeftThreshold = 1600; \ - constexpr int numDotsLayoutRightThreshold = 100; \ - if( (!A_is_lr && transa != CUBLAS_OP_N && transb == CUBLAS_OP_N && M*N < numDotsLayoutLeftThreshold) \ - || ( A_is_lr && transa != CUBLAS_OP_N && transb == CUBLAS_OP_N && M*N < numDotsLayoutRightThreshold)) { \ - DotBasedGEMM gemm(alpha,A,B,beta,C); \ - gemm.run(false); \ - } \ - else { \ - KokkosBlas::Impl::CudaBlasSingleton & s = KokkosBlas::Impl::CudaBlasSingleton::singleton(); \ - if(!A_is_lr && !B_is_lr && !C_is_lr ) \ - cublasSgemm(s.handle, transa, transb, M, N, K, &alpha, A.data(), LDA, B.data(), LDB, &beta, C.data(), LDC); \ - if(A_is_lr && B_is_lr && C_is_lr ) \ - cublasSgemm(s.handle, transb, transa, N, M, K, &alpha, B.data(), LDB, A.data(), LDA, &beta, C.data(), LDC); \ - } \ + KokkosBlas::Impl::CudaBlasSingleton & s = KokkosBlas::Impl::CudaBlasSingleton::singleton(); \ + if(!A_is_lr && !B_is_lr && !C_is_lr ) \ + cublasSgemm(s.handle, transa, transb, M, N, K, &alpha, A.data(), LDA, B.data(), LDB, &beta, C.data(), LDC); \ + if(A_is_lr && B_is_lr && C_is_lr ) \ + cublasSgemm(s.handle, transb, transa, N, M, K, &alpha, B.data(), LDB, A.data(), LDA, &beta, C.data(), LDC); \ Kokkos::Profiling::popRegion(); \ } \ }; @@ -690,20 +519,11 @@ struct GEMM< \ else \ transb = CUBLAS_OP_C; \ \ - constexpr int numDotsLayoutLeftThreshold = 1600; \ - constexpr int numDotsLayoutRightThreshold = 100; \ - if( (!A_is_lr && transa != CUBLAS_OP_N && transb == CUBLAS_OP_N && M*N < numDotsLayoutLeftThreshold) \ - || ( A_is_lr && transa != CUBLAS_OP_N && transb == CUBLAS_OP_N && M*N < numDotsLayoutRightThreshold)) { \ - DotBasedGEMM gemm(alpha,A,B,beta,C); \ - gemm.run(transa == CUBLAS_OP_C ? true : false); \ - } \ - else { \ - KokkosBlas::Impl::CudaBlasSingleton & s = KokkosBlas::Impl::CudaBlasSingleton::singleton(); \ - if(!A_is_lr && !B_is_lr && !C_is_lr ) \ - cublasZgemm(s.handle, transa, transb, M, N, K, reinterpret_cast(&alpha), reinterpret_cast(A.data()), LDA, reinterpret_cast(B.data()), LDB, reinterpret_cast(&beta), reinterpret_cast(C.data()), LDC); \ - if(A_is_lr && B_is_lr && C_is_lr ) \ - cublasZgemm(s.handle, transb, transa, N, M, K, reinterpret_cast(&alpha), reinterpret_cast(B.data()), LDB, reinterpret_cast(A.data()), LDA, reinterpret_cast(&beta), reinterpret_cast(C.data()), LDC); \ - } \ + KokkosBlas::Impl::CudaBlasSingleton & s = KokkosBlas::Impl::CudaBlasSingleton::singleton(); \ + if(!A_is_lr && !B_is_lr && !C_is_lr ) \ + cublasZgemm(s.handle, transa, transb, M, N, K, reinterpret_cast(&alpha), reinterpret_cast(A.data()), LDA, reinterpret_cast(B.data()), LDB, reinterpret_cast(&beta), reinterpret_cast(C.data()), LDC); \ + if(A_is_lr && B_is_lr && C_is_lr ) \ + cublasZgemm(s.handle, transb, transa, N, M, K, reinterpret_cast(&alpha), reinterpret_cast(B.data()), LDB, reinterpret_cast(A.data()), LDA, reinterpret_cast(&beta), reinterpret_cast(C.data()), LDC); \ Kokkos::Profiling::popRegion(); \ } \ }; \ @@ -763,20 +583,11 @@ struct GEMM< \ else \ transb = CUBLAS_OP_C; \ \ - constexpr int numDotsLayoutLeftThreshold = 1600; \ - constexpr int numDotsLayoutRightThreshold = 100; \ - if( (!A_is_lr && transa != CUBLAS_OP_N && transb == CUBLAS_OP_N && M*N < numDotsLayoutLeftThreshold) \ - || ( A_is_lr && transa != CUBLAS_OP_N && transb == CUBLAS_OP_N && M*N < numDotsLayoutRightThreshold)) { \ - DotBasedGEMM gemm(alpha,A,B,beta,C); \ - gemm.run(transa == CUBLAS_OP_C ? true : false); \ - } \ - else { \ - KokkosBlas::Impl::CudaBlasSingleton & s = KokkosBlas::Impl::CudaBlasSingleton::singleton(); \ - if(!A_is_lr && !B_is_lr && !C_is_lr ) \ - cublasCgemm(s.handle, transa, transb, M, N, K, reinterpret_cast(&alpha), reinterpret_cast(A.data()), LDA, reinterpret_cast(B.data()), LDB, reinterpret_cast(&beta), reinterpret_cast(C.data()), LDC); \ - if(A_is_lr && B_is_lr && C_is_lr ) \ - cublasCgemm(s.handle, transb, transa, N, M, K, reinterpret_cast(&alpha), reinterpret_cast(B.data()), LDB, reinterpret_cast(A.data()), LDA, reinterpret_cast(&beta), reinterpret_cast(C.data()), LDC); \ - } \ + KokkosBlas::Impl::CudaBlasSingleton & s = KokkosBlas::Impl::CudaBlasSingleton::singleton(); \ + if(!A_is_lr && !B_is_lr && !C_is_lr ) \ + cublasCgemm(s.handle, transa, transb, M, N, K, reinterpret_cast(&alpha), reinterpret_cast(A.data()), LDA, reinterpret_cast(B.data()), LDB, reinterpret_cast(&beta), reinterpret_cast(C.data()), LDC); \ + if(A_is_lr && B_is_lr && C_is_lr ) \ + cublasCgemm(s.handle, transb, transa, N, M, K, reinterpret_cast(&alpha), reinterpret_cast(B.data()), LDB, reinterpret_cast(A.data()), LDA, reinterpret_cast(&beta), reinterpret_cast(C.data()), LDC); \ Kokkos::Profiling::popRegion(); \ } \ }; From 4be3c652759c829d1e1db7189b8930f57c6f0a78 Mon Sep 17 00:00:00 2001 From: iyamazaki Date: Tue, 13 Jul 2021 22:40:00 -0600 Subject: [PATCH 2/8] fix a bug (dot-based Gemm should be called only for C := beta * C + alpha * A^T*B case). --- src/blas/impl/KokkosBlas3_gemm_spec.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/blas/impl/KokkosBlas3_gemm_spec.hpp b/src/blas/impl/KokkosBlas3_gemm_spec.hpp index ca05297bda..81bdbd950d 100644 --- a/src/blas/impl/KokkosBlas3_gemm_spec.hpp +++ b/src/blas/impl/KokkosBlas3_gemm_spec.hpp @@ -150,9 +150,9 @@ struct GEMM { constexpr int numDotsLayoutLeftThreshold = 1600; constexpr int numDotsLayoutRightThreshold = 100; - if( (!A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutLeftThreshold) - || ( A_is_lr && !A_is_tr && B_is_tr && M*N < numDotsLayoutRightThreshold)) { - // call dot-based GEMM + if( (!A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutLeftThreshold) + || ( A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutRightThreshold)) { + // call dot-based GEMM, only for C := beta * C + alpha * A^T * B bool A_is_conj = ((transA[0]=='C') || (transA[0]=='c')); DotBasedGEMM dotBasedGemm(alpha, A, B, beta, C); dotBasedGemm.run(A_is_conj ? true : false); From a9f944c99026373c996a9b365095c01b627ccde4 Mon Sep 17 00:00:00 2001 From: iyamazaki Date: Wed, 14 Jul 2021 08:52:49 -0600 Subject: [PATCH 3/8] cleaning up the code --- src/blas/impl/KokkosBlas3_gemm_dotbased_impl.hpp | 7 ++++--- src/blas/impl/KokkosBlas3_gemm_spec.hpp | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/blas/impl/KokkosBlas3_gemm_dotbased_impl.hpp b/src/blas/impl/KokkosBlas3_gemm_dotbased_impl.hpp index 775443d991..becd6a6c7d 100644 --- a/src/blas/impl/KokkosBlas3_gemm_dotbased_impl.hpp +++ b/src/blas/impl/KokkosBlas3_gemm_dotbased_impl.hpp @@ -77,7 +77,6 @@ struct DotBasedGEMM{ const scalar_A alpha; const scalar_C beta; - // The following types (especially dotSize) could have simply been int, const size_C numCrows; const size_C numCcols; @@ -88,7 +87,9 @@ struct DotBasedGEMM{ size_A chunkSize; // the local length of each team's share on the dot product - DotBasedGEMM(const scalar_A& alpha_, const AV& A_, const BV& B_, const scalar_C& beta_, const CV& C_):A(A_),B(B_),C(C_),alpha(alpha_),beta(beta_),numCrows(C.extent(0)),numCcols(C.extent(1)),dotSize(A.extent(0)) + DotBasedGEMM(const scalar_A& alpha_, const AV& A_, const BV& B_, const scalar_C& beta_, const CV& C_) : + A(A_), B(B_), C(C_), alpha(alpha_), beta(beta_), + numCrows(C.extent(0)), numCcols(C.extent(1)), dotSize(A.extent(0)) { } void run(bool conjugateTranspose) { @@ -115,7 +116,7 @@ struct DotBasedGEMM{ // potentially be performed by multiple teams. First, compute // numDivPerDot as an integer (take the floor, not ceiling), then, // compute actual number of teams by using this factor. - else{ + else { numDivPerDot = appxNumTeams / ndots; numTeams = ndots * numDivPerDot; } diff --git a/src/blas/impl/KokkosBlas3_gemm_spec.hpp b/src/blas/impl/KokkosBlas3_gemm_spec.hpp index 81bdbd950d..aab3d8196e 100644 --- a/src/blas/impl/KokkosBlas3_gemm_spec.hpp +++ b/src/blas/impl/KokkosBlas3_gemm_spec.hpp @@ -140,7 +140,7 @@ struct GEMM { #if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY typedef typename CViewType::execution_space ExecSpace; - // Figure out if we used use DotBased implementation + // Figure out whether to use DotBased implementation const int M = static_cast (C.extent(0)); const int N = static_cast (C.extent(1)); From 639d483d84f57c94c5cfaf4766417c60489c3fb7 Mon Sep 17 00:00:00 2001 From: iyamazaki Date: Wed, 14 Jul 2021 08:54:47 -0600 Subject: [PATCH 4/8] remove the redundant ifdef's --- src/blas/impl/KokkosBlas3_gemm_spec.hpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/blas/impl/KokkosBlas3_gemm_spec.hpp b/src/blas/impl/KokkosBlas3_gemm_spec.hpp index aab3d8196e..1d068c6572 100644 --- a/src/blas/impl/KokkosBlas3_gemm_spec.hpp +++ b/src/blas/impl/KokkosBlas3_gemm_spec.hpp @@ -137,7 +137,6 @@ struct GEMM { typedef typename CViewType::non_const_value_type ScalarC; -#if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY typedef typename CViewType::execution_space ExecSpace; // Figure out whether to use DotBased implementation @@ -152,13 +151,13 @@ struct GEMM { constexpr int numDotsLayoutRightThreshold = 100; if( (!A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutLeftThreshold) || ( A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutRightThreshold)) { + // call dot-based GEMM, only for C := beta * C + alpha * A^T * B bool A_is_conj = ((transA[0]=='C') || (transA[0]=='c')); DotBasedGEMM dotBasedGemm(alpha, A, B, beta, C); dotBasedGemm.run(A_is_conj ? true : false); - } else -#endif - { + + } else { // Define Blocking sizes (this will be used for scratch spaces) static constexpr int blockA0 = 24; From 81b359da92e82ab56dbe8f5c241da948d6f73177 Mon Sep 17 00:00:00 2001 From: iyamazaki Date: Wed, 14 Jul 2021 16:35:27 -0600 Subject: [PATCH 5/8] bring back calls to dot-based GEMM in the GEMMs in TPL CUBLAS --- .../impl/KokkosBlas3_gemm_dotbased_impl.hpp | 10 ++- src/blas/impl/KokkosBlas3_gemm_spec.hpp | 1 + .../tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp | 76 ++++++++++++++----- 3 files changed, 63 insertions(+), 24 deletions(-) diff --git a/src/blas/impl/KokkosBlas3_gemm_dotbased_impl.hpp b/src/blas/impl/KokkosBlas3_gemm_dotbased_impl.hpp index becd6a6c7d..8c4e404e9f 100644 --- a/src/blas/impl/KokkosBlas3_gemm_dotbased_impl.hpp +++ b/src/blas/impl/KokkosBlas3_gemm_dotbased_impl.hpp @@ -94,6 +94,8 @@ struct DotBasedGEMM{ void run(bool conjugateTranspose) { + // NOTE: these workPerTeam and approxNumTeams were used for TPL CUBLAS, + // and may need to be retuned for other architectures constexpr size_C workPerTeam = 4096; // Amount of work per team const size_C ndots = numCrows * numCcols; // Number of dot products size_C appxNumTeams = (dotSize * ndots) / workPerTeam; // Estimation for appxNumTeams @@ -169,8 +171,8 @@ struct DotBasedGEMM{ scalar_C result = CVT::zero(); const size_A baseInd = chunkSize*localRank; Kokkos::parallel_reduce( Kokkos::TeamThreadRange(teamMember, chunkSize), [&]( const size_A k, scalar_C &update ) { - if(baseInd + k < dotSize) - update += alpha * A(baseInd+k, rowId) * B(baseInd+k, colId); + if(baseInd + k < dotSize) + update += alpha * A(baseInd+k, rowId) * B(baseInd+k, colId); }, result ); Kokkos::single(Kokkos::PerTeam(teamMember), [&] () { @@ -190,8 +192,8 @@ struct DotBasedGEMM{ scalar_C result = CVT::zero(); const size_A baseInd = chunkSize*localRank; Kokkos::parallel_reduce( Kokkos::TeamThreadRange(teamMember, chunkSize), [&]( const size_A k, scalar_C &update ) { - if(baseInd + k < dotSize) - update += alpha * AVT::conj(A(baseInd+k, rowId)) * B(baseInd+k, colId); + if(baseInd + k < dotSize) + update += alpha * AVT::conj(A(baseInd+k, rowId)) * B(baseInd+k, colId); }, result ); Kokkos::single(Kokkos::PerTeam(teamMember), [&] () { diff --git a/src/blas/impl/KokkosBlas3_gemm_spec.hpp b/src/blas/impl/KokkosBlas3_gemm_spec.hpp index 1d068c6572..6780cfbb60 100644 --- a/src/blas/impl/KokkosBlas3_gemm_spec.hpp +++ b/src/blas/impl/KokkosBlas3_gemm_spec.hpp @@ -147,6 +147,7 @@ struct GEMM { const bool A_is_tr = ((transA[0]=='T') || (transA[0]=='t') || (transA[0]=='C') || (transA[0]=='c')); const bool B_is_tr = ((transB[0]=='T') || (transB[0]=='t') || (transB[0]=='C') || (transB[0]=='c')); + // NOTE: these thresholds were copied from TPL CUBLAS, and may need to be retuned constexpr int numDotsLayoutLeftThreshold = 1600; constexpr int numDotsLayoutRightThreshold = 100; if( (!A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutLeftThreshold) diff --git a/src/impl/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp b/src/impl/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp index c6cf04c4a6..6ab0181881 100644 --- a/src/impl/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp +++ b/src/impl/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp @@ -391,11 +391,20 @@ struct GEMM< \ else \ transb = CUBLAS_OP_C; \ \ - KokkosBlas::Impl::CudaBlasSingleton & s = KokkosBlas::Impl::CudaBlasSingleton::singleton(); \ - if(!A_is_lr && !B_is_lr && !C_is_lr ) \ - cublasDgemm(s.handle, transa, transb, M, N, K, &alpha, A.data(), LDA, B.data(), LDB, &beta, C.data(), LDC); \ - if(A_is_lr && B_is_lr && C_is_lr ) \ - cublasDgemm(s.handle, transb, transa, N, M, K, &alpha, B.data(), LDB, A.data(), LDA, &beta, C.data(), LDC); \ + constexpr int numDotsLayoutLeftThreshold = 1600; \ + constexpr int numDotsLayoutRightThreshold = 100; \ + if( (!A_is_lr && transa != CUBLAS_OP_N && transb == CUBLAS_OP_N && M*N < numDotsLayoutLeftThreshold) \ + || ( A_is_lr && transa != CUBLAS_OP_N && transb == CUBLAS_OP_N && M*N < numDotsLayoutRightThreshold)) { \ + DotBasedGEMM gemm(alpha,A,B,beta,C); \ + gemm.run(false); \ + } \ + else { \ + KokkosBlas::Impl::CudaBlasSingleton & s = KokkosBlas::Impl::CudaBlasSingleton::singleton(); \ + if(!A_is_lr && !B_is_lr && !C_is_lr ) \ + cublasDgemm(s.handle, transa, transb, M, N, K, &alpha, A.data(), LDA, B.data(), LDB, &beta, C.data(), LDC); \ + if(A_is_lr && B_is_lr && C_is_lr ) \ + cublasDgemm(s.handle, transb, transa, N, M, K, &alpha, B.data(), LDB, A.data(), LDA, &beta, C.data(), LDC); \ + } \ Kokkos::Profiling::popRegion(); \ } \ }; @@ -455,11 +464,20 @@ struct GEMM< \ else \ transb = CUBLAS_OP_C; \ \ - KokkosBlas::Impl::CudaBlasSingleton & s = KokkosBlas::Impl::CudaBlasSingleton::singleton(); \ - if(!A_is_lr && !B_is_lr && !C_is_lr ) \ - cublasSgemm(s.handle, transa, transb, M, N, K, &alpha, A.data(), LDA, B.data(), LDB, &beta, C.data(), LDC); \ - if(A_is_lr && B_is_lr && C_is_lr ) \ - cublasSgemm(s.handle, transb, transa, N, M, K, &alpha, B.data(), LDB, A.data(), LDA, &beta, C.data(), LDC); \ + constexpr int numDotsLayoutLeftThreshold = 1600; \ + constexpr int numDotsLayoutRightThreshold = 100; \ + if( (!A_is_lr && transa != CUBLAS_OP_N && transb == CUBLAS_OP_N && M*N < numDotsLayoutLeftThreshold) \ + || ( A_is_lr && transa != CUBLAS_OP_N && transb == CUBLAS_OP_N && M*N < numDotsLayoutRightThreshold)) { \ + DotBasedGEMM gemm(alpha,A,B,beta,C); \ + gemm.run(false); \ + } \ + else { \ + KokkosBlas::Impl::CudaBlasSingleton & s = KokkosBlas::Impl::CudaBlasSingleton::singleton(); \ + if(!A_is_lr && !B_is_lr && !C_is_lr ) \ + cublasSgemm(s.handle, transa, transb, M, N, K, &alpha, A.data(), LDA, B.data(), LDB, &beta, C.data(), LDC); \ + if(A_is_lr && B_is_lr && C_is_lr ) \ + cublasSgemm(s.handle, transb, transa, N, M, K, &alpha, B.data(), LDB, A.data(), LDA, &beta, C.data(), LDC); \ + } \ Kokkos::Profiling::popRegion(); \ } \ }; @@ -519,11 +537,20 @@ struct GEMM< \ else \ transb = CUBLAS_OP_C; \ \ - KokkosBlas::Impl::CudaBlasSingleton & s = KokkosBlas::Impl::CudaBlasSingleton::singleton(); \ - if(!A_is_lr && !B_is_lr && !C_is_lr ) \ - cublasZgemm(s.handle, transa, transb, M, N, K, reinterpret_cast(&alpha), reinterpret_cast(A.data()), LDA, reinterpret_cast(B.data()), LDB, reinterpret_cast(&beta), reinterpret_cast(C.data()), LDC); \ - if(A_is_lr && B_is_lr && C_is_lr ) \ - cublasZgemm(s.handle, transb, transa, N, M, K, reinterpret_cast(&alpha), reinterpret_cast(B.data()), LDB, reinterpret_cast(A.data()), LDA, reinterpret_cast(&beta), reinterpret_cast(C.data()), LDC); \ + constexpr int numDotsLayoutLeftThreshold = 1600; \ + constexpr int numDotsLayoutRightThreshold = 100; \ + if( (!A_is_lr && transa != CUBLAS_OP_N && transb == CUBLAS_OP_N && M*N < numDotsLayoutLeftThreshold) \ + || ( A_is_lr && transa != CUBLAS_OP_N && transb == CUBLAS_OP_N && M*N < numDotsLayoutRightThreshold)) { \ + DotBasedGEMM gemm(alpha,A,B,beta,C); \ + gemm.run(transa == CUBLAS_OP_C ? true : false); \ + } \ + else { \ + KokkosBlas::Impl::CudaBlasSingleton & s = KokkosBlas::Impl::CudaBlasSingleton::singleton(); \ + if(!A_is_lr && !B_is_lr && !C_is_lr ) \ + cublasZgemm(s.handle, transa, transb, M, N, K, reinterpret_cast(&alpha), reinterpret_cast(A.data()), LDA, reinterpret_cast(B.data()), LDB, reinterpret_cast(&beta), reinterpret_cast(C.data()), LDC); \ + if(A_is_lr && B_is_lr && C_is_lr ) \ + cublasZgemm(s.handle, transb, transa, N, M, K, reinterpret_cast(&alpha), reinterpret_cast(B.data()), LDB, reinterpret_cast(A.data()), LDA, reinterpret_cast(&beta), reinterpret_cast(C.data()), LDC); \ + } \ Kokkos::Profiling::popRegion(); \ } \ }; \ @@ -583,11 +610,20 @@ struct GEMM< \ else \ transb = CUBLAS_OP_C; \ \ - KokkosBlas::Impl::CudaBlasSingleton & s = KokkosBlas::Impl::CudaBlasSingleton::singleton(); \ - if(!A_is_lr && !B_is_lr && !C_is_lr ) \ - cublasCgemm(s.handle, transa, transb, M, N, K, reinterpret_cast(&alpha), reinterpret_cast(A.data()), LDA, reinterpret_cast(B.data()), LDB, reinterpret_cast(&beta), reinterpret_cast(C.data()), LDC); \ - if(A_is_lr && B_is_lr && C_is_lr ) \ - cublasCgemm(s.handle, transb, transa, N, M, K, reinterpret_cast(&alpha), reinterpret_cast(B.data()), LDB, reinterpret_cast(A.data()), LDA, reinterpret_cast(&beta), reinterpret_cast(C.data()), LDC); \ + constexpr int numDotsLayoutLeftThreshold = 1600; \ + constexpr int numDotsLayoutRightThreshold = 100; \ + if( (!A_is_lr && transa != CUBLAS_OP_N && transb == CUBLAS_OP_N && M*N < numDotsLayoutLeftThreshold) \ + || ( A_is_lr && transa != CUBLAS_OP_N && transb == CUBLAS_OP_N && M*N < numDotsLayoutRightThreshold)) { \ + DotBasedGEMM gemm(alpha,A,B,beta,C); \ + gemm.run(transa == CUBLAS_OP_C ? true : false); \ + } \ + else { \ + KokkosBlas::Impl::CudaBlasSingleton & s = KokkosBlas::Impl::CudaBlasSingleton::singleton(); \ + if(!A_is_lr && !B_is_lr && !C_is_lr ) \ + cublasCgemm(s.handle, transa, transb, M, N, K, reinterpret_cast(&alpha), reinterpret_cast(A.data()), LDA, reinterpret_cast(B.data()), LDB, reinterpret_cast(&beta), reinterpret_cast(C.data()), LDC); \ + if(A_is_lr && B_is_lr && C_is_lr ) \ + cublasCgemm(s.handle, transb, transa, N, M, K, reinterpret_cast(&alpha), reinterpret_cast(B.data()), LDB, reinterpret_cast(A.data()), LDA, reinterpret_cast(&beta), reinterpret_cast(C.data()), LDC); \ + } \ Kokkos::Profiling::popRegion(); \ } \ }; From dcd8f367df71302f75f253f27094a4f507757c3b Mon Sep 17 00:00:00 2001 From: iyamazaki Date: Thu, 15 Jul 2021 10:04:54 -0600 Subject: [PATCH 6/8] include dot-based GEMM for TPL CUBLAS --- src/impl/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/impl/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp b/src/impl/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp index 6ab0181881..4eba52efac 100644 --- a/src/impl/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp +++ b/src/impl/tpls/KokkosBlas3_gemm_tpl_spec_decl.hpp @@ -332,6 +332,7 @@ KOKKOSBLAS3_CGEMM_BLAS( Kokkos::LayoutRight, Kokkos::LayoutRight, Kokkos::Layout // cuBLAS #ifdef KOKKOSKERNELS_ENABLE_TPL_CUBLAS #include +#include namespace KokkosBlas { namespace Impl { From fdad674d9a45706baeb14b808d35cabddef97f94 Mon Sep 17 00:00:00 2001 From: iyamaza Date: Tue, 20 Jul 2021 21:13:22 -0600 Subject: [PATCH 7/8] perform dot-based GEMM, only on device --- src/blas/impl/KokkosBlas3_gemm_spec.hpp | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/blas/impl/KokkosBlas3_gemm_spec.hpp b/src/blas/impl/KokkosBlas3_gemm_spec.hpp index 6780cfbb60..739b4e1c6f 100644 --- a/src/blas/impl/KokkosBlas3_gemm_spec.hpp +++ b/src/blas/impl/KokkosBlas3_gemm_spec.hpp @@ -143,6 +143,7 @@ struct GEMM { const int M = static_cast (C.extent(0)); const int N = static_cast (C.extent(1)); + const bool host_space = std::is_same::value; const bool A_is_lr = std::is_same::value; const bool A_is_tr = ((transA[0]=='T') || (transA[0]=='t') || (transA[0]=='C') || (transA[0]=='c')); const bool B_is_tr = ((transB[0]=='T') || (transB[0]=='t') || (transB[0]=='C') || (transB[0]=='c')); @@ -150,13 +151,14 @@ struct GEMM { // NOTE: these thresholds were copied from TPL CUBLAS, and may need to be retuned constexpr int numDotsLayoutLeftThreshold = 1600; constexpr int numDotsLayoutRightThreshold = 100; - if( (!A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutLeftThreshold) - || ( A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutRightThreshold)) { + if(( (!A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutLeftThreshold) + || ( A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutRightThreshold)) + && !host_space) { - // call dot-based GEMM, only for C := beta * C + alpha * A^T * B + // call dot-based GEMM, only for C := beta * C + alpha * A^T * B, on device bool A_is_conj = ((transA[0]=='C') || (transA[0]=='c')); DotBasedGEMM dotBasedGemm(alpha, A, B, beta, C); - dotBasedGemm.run(A_is_conj ? true : false); + dotBasedGemm.run(A_is_conj); } else { From 6d6ee66946775e654cdf9f2ac34f41cb486e01bd Mon Sep 17 00:00:00 2001 From: iyamaza Date: Wed, 21 Jul 2021 11:53:09 -0600 Subject: [PATCH 8/8] use kk_is_gpu_exec_space() to check if it is on device --- src/blas/impl/KokkosBlas3_gemm_spec.hpp | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/blas/impl/KokkosBlas3_gemm_spec.hpp b/src/blas/impl/KokkosBlas3_gemm_spec.hpp index 739b4e1c6f..8064219a93 100644 --- a/src/blas/impl/KokkosBlas3_gemm_spec.hpp +++ b/src/blas/impl/KokkosBlas3_gemm_spec.hpp @@ -49,8 +49,9 @@ #include "Kokkos_InnerProductSpaceTraits.hpp" #if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY -#include -#include +#include "KokkosBlas3_gemm_impl.hpp" +#include "KokkosBlas3_gemm_dotbased_impl.hpp" +#include "KokkosKernels_ExecSpaceUtils.hpp" #endif namespace KokkosBlas { @@ -135,15 +136,13 @@ struct GEMM { typedef typename AViewType::non_const_value_type ScalarA; typedef typename BViewType::non_const_value_type ScalarB; typedef typename CViewType::non_const_value_type ScalarC; - - typedef typename CViewType::execution_space ExecSpace; // Figure out whether to use DotBased implementation const int M = static_cast (C.extent(0)); const int N = static_cast (C.extent(1)); - const bool host_space = std::is_same::value; + const bool is_device_space = KokkosKernels::Impl::kk_is_gpu_exec_space(); const bool A_is_lr = std::is_same::value; const bool A_is_tr = ((transA[0]=='T') || (transA[0]=='t') || (transA[0]=='C') || (transA[0]=='c')); const bool B_is_tr = ((transB[0]=='T') || (transB[0]=='t') || (transB[0]=='C') || (transB[0]=='c')); @@ -153,7 +152,7 @@ struct GEMM { constexpr int numDotsLayoutRightThreshold = 100; if(( (!A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutLeftThreshold) || ( A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutRightThreshold)) - && !host_space) { + && is_device_space) { // call dot-based GEMM, only for C := beta * C + alpha * A^T * B, on device bool A_is_conj = ((transA[0]=='C') || (transA[0]=='c'));