diff --git a/src/blas/impl/KokkosBlas3_gemm_spec.hpp b/src/blas/impl/KokkosBlas3_gemm_spec.hpp index 739b4e1c6f..8064219a93 100644 --- a/src/blas/impl/KokkosBlas3_gemm_spec.hpp +++ b/src/blas/impl/KokkosBlas3_gemm_spec.hpp @@ -49,8 +49,9 @@ #include "Kokkos_InnerProductSpaceTraits.hpp" #if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY -#include -#include +#include "KokkosBlas3_gemm_impl.hpp" +#include "KokkosBlas3_gemm_dotbased_impl.hpp" +#include "KokkosKernels_ExecSpaceUtils.hpp" #endif namespace KokkosBlas { @@ -135,15 +136,13 @@ struct GEMM { typedef typename AViewType::non_const_value_type ScalarA; typedef typename BViewType::non_const_value_type ScalarB; typedef typename CViewType::non_const_value_type ScalarC; - - typedef typename CViewType::execution_space ExecSpace; // Figure out whether to use DotBased implementation const int M = static_cast (C.extent(0)); const int N = static_cast (C.extent(1)); - const bool host_space = std::is_same::value; + const bool is_device_space = KokkosKernels::Impl::kk_is_gpu_exec_space(); const bool A_is_lr = std::is_same::value; const bool A_is_tr = ((transA[0]=='T') || (transA[0]=='t') || (transA[0]=='C') || (transA[0]=='c')); const bool B_is_tr = ((transB[0]=='T') || (transB[0]=='t') || (transB[0]=='C') || (transB[0]=='c')); @@ -153,7 +152,7 @@ struct GEMM { constexpr int numDotsLayoutRightThreshold = 100; if(( (!A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutLeftThreshold) || ( A_is_lr && A_is_tr && !B_is_tr && M*N < numDotsLayoutRightThreshold)) - && !host_space) { + && is_device_space) { // call dot-based GEMM, only for C := beta * C + alpha * A^T * B, on device bool A_is_conj = ((transA[0]=='C') || (transA[0]=='c'));