Skip to content

Commit

Permalink
use kk_is_gpu_exec_space() to check if it is on device
Browse files Browse the repository at this point in the history
  • Loading branch information
iyamazaki committed Jul 21, 2021
1 parent fdad674 commit 6d6ee66
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions src/blas/impl/KokkosBlas3_gemm_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@
#include "Kokkos_InnerProductSpaceTraits.hpp"

#if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY
#include<KokkosBlas3_gemm_impl.hpp>
#include<KokkosBlas3_gemm_dotbased_impl.hpp>
#include "KokkosBlas3_gemm_impl.hpp"
#include "KokkosBlas3_gemm_dotbased_impl.hpp"
#include "KokkosKernels_ExecSpaceUtils.hpp"
#endif

namespace KokkosBlas {
Expand Down Expand Up @@ -135,15 +136,13 @@ struct GEMM {
typedef typename AViewType::non_const_value_type ScalarA;
typedef typename BViewType::non_const_value_type ScalarB;
typedef typename CViewType::non_const_value_type ScalarC;


typedef typename CViewType::execution_space ExecSpace;

// Figure out whether to use DotBased implementation
const int M = static_cast<int> (C.extent(0));
const int N = static_cast<int> (C.extent(1));

const bool host_space = std::is_same<Kokkos::HostSpace, ExecSpace>::value;
const bool is_device_space = KokkosKernels::Impl::kk_is_gpu_exec_space<ExecSpace>();
const bool A_is_lr = std::is_same<Kokkos::LayoutRight, typename AViewType::array_layout>::value;
const bool A_is_tr = ((transA[0]=='T') || (transA[0]=='t') || (transA[0]=='C') || (transA[0]=='c'));
const bool B_is_tr = ((transB[0]=='T') || (transB[0]=='t') || (transB[0]=='C') || (transB[0]=='c'));
Expand All @@ -153,7 +152,7 @@ struct GEMM {
constexpr int numDotsLayoutRightThreshold = 100;
if(( (!A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutLeftThreshold)
|| ( A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutRightThreshold))
&& !host_space) {
&& is_device_space) {

// call dot-based GEMM, only for C := beta * C + alpha * A^T * B, on device
bool A_is_conj = ((transA[0]=='C') || (transA[0]=='c'));
Expand Down

0 comments on commit 6d6ee66

Please sign in to comment.