Skip to content

Commit

Permalink
Disallow BsrMatrix tensor core SpMV for non-scalar types
Browse files Browse the repository at this point in the history
  • Loading branch information
cwpearson committed Aug 10, 2023
1 parent 95be1a4 commit 6ca39f1
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 35 deletions.
56 changes: 36 additions & 20 deletions sparse/impl/KokkosSparse_spmv_bsrmatrix_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,40 @@ struct BsrMatrixSpMVTensorCoreFunctorParams {
int leagueDim_y;
};

/*! \brief Compile-time predicate: may the tensor core SpMV implementation be
used from ExecutionSpace on operands AMatrix, XMatrix, and YMatrix?

\tparam ExecutionSpace execution space the kernel would be launched in
\tparam AMatrix type of the A operand
\tparam XMatrix type of the X operand
\tparam YMatrix type of the Y operand

With KOKKOS_ENABLE_CUDA defined, `value` is true only when all of these hold:
 - the memory space of each operand is accessible from ExecutionSpace,
 - the non-const value type of each operand satisfies std::is_scalar_v or is
   Kokkos::Experimental::half_t,
 - ExecutionSpace is Kokkos::Cuda.
Without KOKKOS_ENABLE_CUDA, `value` is always false.
*/
template <typename ExecutionSpace, typename AMatrix, typename XMatrix,
          typename YMatrix>
class TensorCoresAvailable {
#if defined(KOKKOS_ENABLE_CUDA)
  // element types of the three operands
  using a_value_type = typename AMatrix::non_const_value_type;
  using x_value_type = typename XMatrix::non_const_value_type;
  using y_value_type = typename YMatrix::non_const_value_type;

  // accessibility of each operand's memory space from ExecutionSpace
  using a_access =
      Kokkos::SpaceAccessibility<ExecutionSpace,
                                 typename AMatrix::memory_space>;
  using x_access =
      Kokkos::SpaceAccessibility<ExecutionSpace,
                                 typename XMatrix::memory_space>;
  using y_access =
      Kokkos::SpaceAccessibility<ExecutionSpace,
                                 typename YMatrix::memory_space>;

  // A type qualifies if it is half_t or satisfies std::is_scalar_v.
  // NOTE(review): std::is_scalar_v also holds for enum and pointer types;
  // confirm that accepting those here is intended.
  template <typename Scalar>
  constexpr static bool scalar_like() {
    return std::is_same_v<Scalar, Kokkos::Experimental::half_t> ||
           std::is_scalar_v<Scalar>;
  }

  constexpr static bool operands_accessible =
      a_access::accessible && x_access::accessible && y_access::accessible;

  constexpr static bool operands_scalar = scalar_like<a_value_type>() &&
                                          scalar_like<x_value_type>() &&
                                          scalar_like<y_value_type>();

  constexpr static bool exec_is_cuda =
      std::is_same_v<ExecutionSpace, Kokkos::Cuda>;

 public:
  constexpr static inline bool value =
      operands_accessible && operands_scalar && exec_is_cuda;
#else
 public:
  // no CUDA backend in this build, so tensor cores are never available
  constexpr static inline bool value = false;
#endif
};

/// \brief Functor for the BsrMatrix SpMV multivector implementation utilizing
/// tensor cores.
///
Expand Down Expand Up @@ -471,30 +505,12 @@ struct BsrMatrixSpMVTensorCoreDispatcher {
"execution spaces");
}

/* value is true exactly when Kokkos::ArithTraits reports is_complex as false
   for every one of the three types S1, S2, and S3 */
template <typename S1, typename S2, typename S3>
struct none_complex {
  const static bool value =
      !(Kokkos::ArithTraits<S1>::is_complex ||
        Kokkos::ArithTraits<S2>::is_complex ||
        Kokkos::ArithTraits<S3>::is_complex);
};

/* value is true exactly when KokkosKernels::Impl::kk_is_gpu_exec_space
   returns true for each of E1, E2, and E3; the template arguments are
   themselves the execution-space types being tested */
template <typename E1, typename E2, typename E3>
struct all_gpu {
  const static bool value =
      KokkosKernels::Impl::kk_is_gpu_exec_space<E1>() &&
      KokkosKernels::Impl::kk_is_gpu_exec_space<E2>() &&
      KokkosKernels::Impl::kk_is_gpu_exec_space<E3>();
};

/// \brief Choose an implementation at compile time and forward the operands.
///
/// Builds a std::integral_constant<bool, ...> tag from TensorCoresAvailable,
/// evaluated for AMatrix's execution space and the three operand types, then
/// passes that tag with all arguments to tag_dispatch.
///
/// \param alpha scalar forwarded unchanged to tag_dispatch
/// \param a     A operand
/// \param x     X operand
/// \param beta  scalar forwarded unchanged to tag_dispatch
/// \param y     Y operand
static void dispatch(YScalar alpha, AMatrix a, XMatrix x, YScalar beta,
                     YMatrix y) {
  // A single predicate decides the tag. The earlier none_complex/all_gpu
  // conjunction was left interleaved with it, which made the alias
  // declaration ill-formed.
  using tag = std::integral_constant<
      bool, TensorCoresAvailable<typename AMatrix::execution_space, AMatrix,
                                 XMatrix, YMatrix>::value>;
  tag_dispatch(tag{}, alpha, a, x, beta, y);
}
};
Expand Down
20 changes: 5 additions & 15 deletions sparse/impl/KokkosSparse_spmv_bsrmatrix_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,25 +228,15 @@ struct SPMV_MV_BSRMATRIX<AT, AO, AD, AM, AS, XT, XL, XD, XM, YT, YL, YD, YM,
#if defined(KOKKOS_ARCH_AMPERE) || defined(KOKKOS_ARCH_VOLTA)
Method method = Method::Fallback;
{
typedef typename AMatrix::non_const_value_type AScalar;
typedef typename XVector::non_const_value_type XScalar;
// try to use tensor cores if requested
if (controls.getParameter("algorithm") == ALG_TC)
method = Method::TensorCores;
// can't use tensor cores for complex
if (Kokkos::ArithTraits<YScalar>::is_complex) method = Method::Fallback;
if (Kokkos::ArithTraits<XScalar>::is_complex) method = Method::Fallback;
if (Kokkos::ArithTraits<AScalar>::is_complex) method = Method::Fallback;
// can't use tensor cores outside GPU
if (!KokkosKernels::Impl::kk_is_gpu_exec_space<
typename AMatrix::execution_space>())
method = Method::Fallback;
if (!KokkosKernels::Impl::kk_is_gpu_exec_space<
typename XVector::execution_space>())
method = Method::Fallback;
if (!KokkosKernels::Impl::kk_is_gpu_exec_space<
typename YVector::execution_space>())

if (!KokkosSparse::Experimental::Impl::TensorCoresAvailable<
typename AMatrix::execution_space, AMatrix, XMatrix,
YMatrix>::value) {
method = Method::Fallback;
}
// can't use tensor cores unless mode is no-transpose
if (mode[0] != KokkosSparse::NoTranspose[0]) method = Method::Fallback;
#if KOKKOS_HALF_T_IS_FLOAT
Expand Down

0 comments on commit 6ca39f1

Please sign in to comment.