Skip to content

Commit

Permalink
spmv_mv wrappers for rocsparse (#2233)
Browse files Browse the repository at this point in the history
* spmv_mv wrappers for rocsparse (rocsparse_spmm())

* Use consistent types for alpha/beta in spmv wrappers
  • Loading branch information
brian-kelley authored Jun 7, 2024
1 parent a7d02ca commit a955b8b
Show file tree
Hide file tree
Showing 5 changed files with 283 additions and 45 deletions.
54 changes: 25 additions & 29 deletions sparse/src/KokkosSparse_spmv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -263,10 +263,7 @@ void spmv(const ExecutionSpace& space, Handle* handle, const char mode[],
/////////////////
#ifdef KOKKOSKERNELS_ENABLE_TPL_CUSPARSE
// cuSPARSE does not support the conjugate mode (C)
if constexpr (std::is_same_v<typename AMatrix_Internal::memory_space,
Kokkos::CudaSpace> ||
std::is_same_v<typename AMatrix_Internal::memory_space,
Kokkos::CudaUVMSpace>) {
if constexpr (std::is_same_v<ExecutionSpace, Kokkos::Cuda>) {
useNative = useNative || (mode[0] == Conjugate[0]);
}
// cuSPARSE 12 requires that the output (y) vector is 16-byte aligned for
Expand All @@ -278,20 +275,19 @@ void spmv(const ExecutionSpace& space, Handle* handle, const char mode[],
#endif

#ifdef KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE
if (std::is_same<typename AMatrix_Internal::memory_space,
Kokkos::HIPSpace>::value) {
if (std::is_same_v<ExecutionSpace, Kokkos::HIP>) {
useNative = useNative || (mode[0] != NoTranspose[0]);
}
#endif

#ifdef KOKKOSKERNELS_ENABLE_TPL_MKL
if (std::is_same_v<typename AMatrix_Internal::memory_space,
Kokkos::HostSpace>) {
if constexpr (std::is_same_v<typename AMatrix_Internal::memory_space,
Kokkos::HostSpace>) {
useNative = useNative || (mode[0] == Conjugate[0]);
}
#ifdef KOKKOS_ENABLE_SYCL
if (std::is_same_v<typename AMatrix_Internal::memory_space,
Kokkos::Experimental::SYCLDeviceUSMSpace>) {
if constexpr (std::is_same_v<ExecutionSpace,
Kokkos::Experimental::SYCL>) {
useNative = useNative || (mode[0] == Conjugate[0]);
}
#endif
Expand Down Expand Up @@ -324,7 +320,14 @@ void spmv(const ExecutionSpace& space, Handle* handle, const char mode[],
// CRS, rank 2 //
/////////////////
#ifdef KOKKOSKERNELS_ENABLE_TPL_CUSPARSE
useNative = useNative || (Conjugate[0] == mode[0]);
if constexpr (std::is_same_v<ExecutionSpace, Kokkos::Cuda>) {
useNative = useNative || (Conjugate[0] == mode[0]);
}
#endif
#ifdef KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE
if constexpr (std::is_same_v<ExecutionSpace, Kokkos::HIP>) {
useNative = useNative || (Conjugate[0] == mode[0]);
}
#endif

if (useNative) {
Expand Down Expand Up @@ -355,25 +358,21 @@ void spmv(const ExecutionSpace& space, Handle* handle, const char mode[],
/////////////////
#ifdef KOKKOSKERNELS_ENABLE_TPL_CUSPARSE
// cuSPARSE does not support the modes (C), (T), (H)
if (std::is_same<typename AMatrix_Internal::memory_space,
Kokkos::CudaSpace>::value ||
std::is_same<typename AMatrix_Internal::memory_space,
Kokkos::CudaUVMSpace>::value) {
if constexpr (std::is_same_v<ExecutionSpace, Kokkos::Cuda>) {
useNative = useNative || (mode[0] != NoTranspose[0]);
}
#endif

#ifdef KOKKOSKERNELS_ENABLE_TPL_MKL
if (std::is_same<typename AMatrix_Internal::memory_space,
Kokkos::HostSpace>::value) {
if constexpr (std::is_same_v<typename AMatrix_Internal::memory_space,
Kokkos::HostSpace>) {
useNative = useNative || (mode[0] == Conjugate[0]);
}
#endif

#ifdef KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE
// rocSparse does not support the modes (C), (T), (H)
if constexpr (std::is_same_v<typename AMatrix_Internal::memory_space,
Kokkos::HIPSpace>) {
if constexpr (std::is_same_v<ExecutionSpace, Kokkos::HIP>) {
useNative = useNative || (mode[0] != NoTranspose[0]);
}
#endif
Expand Down Expand Up @@ -403,17 +402,14 @@ void spmv(const ExecutionSpace& space, Handle* handle, const char mode[],
/////////////////
#ifdef KOKKOSKERNELS_ENABLE_TPL_CUSPARSE
// cuSPARSE does not support the modes (C), (T), (H)
if (std::is_same<typename AMatrix_Internal::memory_space,
Kokkos::CudaSpace>::value ||
std::is_same<typename AMatrix_Internal::memory_space,
Kokkos::CudaUVMSpace>::value) {
if constexpr (std::is_same_v<ExecutionSpace, Kokkos::Cuda>) {
useNative = useNative || (mode[0] != NoTranspose[0]);
}
#endif

#ifdef KOKKOSKERNELS_ENABLE_TPL_MKL
if (std::is_same<typename AMatrix_Internal::memory_space,
Kokkos::HostSpace>::value) {
if constexpr (std::is_same_v<typename AMatrix_Internal::memory_space,
Kokkos::HostSpace>) {
useNative = useNative || (mode[0] == Conjugate[0]);
}
#endif
Expand Down Expand Up @@ -593,8 +589,8 @@ void spmv_struct(const ExecutionSpace& space, const char mode[],
"KokkosSparse::spmv_struct: Both Vector inputs must have rank 1 in "
"order to call this specialization of spmv.");
// Make sure that y is non-const.
static_assert(std::is_same<typename YVector::value_type,
typename YVector::non_const_value_type>::value,
static_assert(std::is_same_v<typename YVector::value_type,
typename YVector::non_const_value_type>,
"KokkosSparse::spmv_struct: Output Vector must be non-const.");

// Check compatibility of dimensions at run time.
Expand Down Expand Up @@ -886,8 +882,8 @@ void spmv_struct(const ExecutionSpace& space, const char mode[],
static_assert(XVector::rank == YVector::rank,
"KokkosSparse::spmv: Vector ranks do not match.");
// Make sure that y is non-const.
static_assert(std::is_same<typename YVector::value_type,
typename YVector::non_const_value_type>::value,
static_assert(std::is_same_v<typename YVector::value_type,
typename YVector::non_const_value_type>,
"KokkosSparse::spmv: Output Vector must be non-const.");

// Check compatibility of dimensions at run time.
Expand Down
12 changes: 6 additions & 6 deletions sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,9 +348,9 @@ namespace Impl {
template <class Handle, class AMatrix, class XVector, class YVector>
void spmv_bsr_cusparse(const Kokkos::Cuda& exec, Handle* handle,
const char mode[],
typename YVector::non_const_value_type const& alpha,
typename YVector::const_value_type& alpha,
const AMatrix& A, const XVector& x,
typename YVector::non_const_value_type const& beta,
typename YVector::const_value_type& beta,
const YVector& y) {
using offset_type = typename AMatrix::non_const_size_type;
using entry_type = typename AMatrix::non_const_ordinal_type;
Expand Down Expand Up @@ -463,9 +463,9 @@ void spmv_bsr_cusparse(const Kokkos::Cuda& exec, Handle* handle,
template <class Handle, class AMatrix, class XVector, class YVector>
void spmv_mv_bsr_cusparse(const Kokkos::Cuda& exec, Handle* handle,
const char mode[],
typename YVector::non_const_value_type const& alpha,
typename YVector::const_value_type& alpha,
const AMatrix& A, const XVector& x,
typename YVector::non_const_value_type const& beta,
typename YVector::const_value_type& beta,
const YVector& y) {
using offset_type = typename AMatrix::non_const_size_type;
using entry_type = typename AMatrix::non_const_ordinal_type;
Expand Down Expand Up @@ -751,9 +751,9 @@ namespace Impl {
template <class Handle, class AMatrix, class XVector, class YVector>
void spmv_bsr_rocsparse(const Kokkos::HIP& exec, Handle* handle,
const char mode[],
typename YVector::non_const_value_type const& alpha,
typename YVector::const_value_type& alpha,
const AMatrix& A, const XVector& x,
typename YVector::non_const_value_type const& beta,
typename YVector::const_value_type& beta,
const YVector& y) {
/*
rocm 5.4.0 rocsparse_*bsrmv reference:
Expand Down
45 changes: 45 additions & 0 deletions sparse/tpls/KokkosSparse_spmv_mv_tpl_spec_avail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,51 @@ KOKKOSSPARSE_SPMV_MV_TPL_SPEC_AVAIL_CUSPARSE(Kokkos::Experimental::half_t, int,
#endif // defined(CUSPARSE_VERSION) && (10300 <= CUSPARSE_VERSION)
#endif

#ifdef KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE
#define KOKKOSSPARSE_SPMV_MV_TPL_SPEC_AVAIL_ROCSPARSE(SCALAR, XL, YL, \
MEMSPACE) \
template <> \
struct spmv_mv_tpl_spec_avail< \
Kokkos::HIP, \
KokkosSparse::Impl::SPMVHandleImpl<Kokkos::HIP, MEMSPACE, SCALAR, \
rocsparse_int, rocsparse_int>, \
KokkosSparse::CrsMatrix<const SCALAR, const rocsparse_int, \
Kokkos::Device<Kokkos::HIP, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>, \
const rocsparse_int>, \
Kokkos::View< \
const SCALAR**, XL, Kokkos::Device<Kokkos::HIP, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged | Kokkos::RandomAccess>>, \
Kokkos::View<SCALAR**, YL, Kokkos::Device<Kokkos::HIP, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>> { \
enum : bool { value = true }; \
};

#define AVAIL_ROCSPARSE_SCALAR_MEMSPACE(SCALAR, MEMSPACE) \
KOKKOSSPARSE_SPMV_MV_TPL_SPEC_AVAIL_ROCSPARSE(SCALAR, Kokkos::LayoutLeft, \
Kokkos::LayoutLeft, MEMSPACE) \
KOKKOSSPARSE_SPMV_MV_TPL_SPEC_AVAIL_ROCSPARSE(SCALAR, Kokkos::LayoutLeft, \
Kokkos::LayoutRight, MEMSPACE) \
KOKKOSSPARSE_SPMV_MV_TPL_SPEC_AVAIL_ROCSPARSE(SCALAR, Kokkos::LayoutRight, \
Kokkos::LayoutLeft, MEMSPACE) \
KOKKOSSPARSE_SPMV_MV_TPL_SPEC_AVAIL_ROCSPARSE(SCALAR, Kokkos::LayoutRight, \
Kokkos::LayoutRight, MEMSPACE)

#define AVAIL_ROCSPARSE_SCALAR(SCALAR) \
AVAIL_ROCSPARSE_SCALAR_MEMSPACE(SCALAR, Kokkos::HIPSpace) \
AVAIL_ROCSPARSE_SCALAR_MEMSPACE(SCALAR, Kokkos::HIPManagedSpace)

AVAIL_ROCSPARSE_SCALAR(float)
AVAIL_ROCSPARSE_SCALAR(double)
AVAIL_ROCSPARSE_SCALAR(Kokkos::complex<float>)
AVAIL_ROCSPARSE_SCALAR(Kokkos::complex<double>)

#undef AVAIL_ROCSPARSE_SCALAR_MEMSPACE
#undef AVAIL_ROCSPARSE_SCALAR
#undef KOKKOSSPARSE_SPMV_MV_TPL_SPEC_AVAIL_ROCSPARSE

#endif // KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE

} // namespace Impl
} // namespace KokkosSparse

Expand Down
Loading

0 comments on commit a955b8b

Please sign in to comment.