diff --git a/sparse/src/KokkosSparse_Utils_rocsparse.hpp b/sparse/src/KokkosSparse_Utils_rocsparse.hpp index dd479610ca..b146aff782 100644 --- a/sparse/src/KokkosSparse_Utils_rocsparse.hpp +++ b/sparse/src/KokkosSparse_Utils_rocsparse.hpp @@ -18,6 +18,7 @@ #define _KOKKOSKERNELS_SPARSEUTILS_ROCSPARSE_HPP #ifdef KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE +#include #include "rocsparse/rocsparse.h" namespace KokkosSparse { @@ -164,6 +165,9 @@ struct kokkos_to_rocsparse_type> { using type = rocsparse_double_complex; }; +#define KOKKOSSPARSE_IMPL_ROCM_VERSION \ + ROCM_VERSION_MAJOR * 10000 + ROCM_VERSION_MINOR * 100 + ROCM_VERSION_PATCH + } // namespace Impl } // namespace KokkosSparse diff --git a/sparse/tpls/KokkosSparse_spmv_tpl_spec_decl.hpp b/sparse/tpls/KokkosSparse_spmv_tpl_spec_decl.hpp index f223ed0e5a..db719b43d8 100644 --- a/sparse/tpls/KokkosSparse_spmv_tpl_spec_decl.hpp +++ b/sparse/tpls/KokkosSparse_spmv_tpl_spec_decl.hpp @@ -343,6 +343,7 @@ KOKKOSSPARSE_SPMV_CUSPARSE(Kokkos::complex, int64_t, size_t, // rocSPARSE #if defined(KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE) #include +#include #include "KokkosSparse_Utils_rocsparse.hpp" namespace KokkosSparse { @@ -421,13 +422,24 @@ void spmv_rocsparse(const KokkosKernels::Experimental::Controls& controls, else if (algName == "merge") alg = rocsparse_spmv_alg_csr_stream; } - KOKKOS_ROCSPARSE_SAFE_CALL_IMPL( - rocsparse_spmv(handle, myRocsparseOperation, &alpha, Aspmat, vecX, &beta, - vecY, compute_type, alg, &buffer_size, tmp_buffer)); + +#if KOKKOSSPARSE_IMPL_ROCM_VERSION >= 50400 + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_spmv_ex( + handle, myRocsparseOperation, &alpha, Aspmat, vecX, &beta, vecY, + compute_type, alg, rocsparse_spmv_stage_auto, &buffer_size, tmp_buffer)); KOKKOS_IMPL_HIP_SAFE_CALL(hipMalloc(&tmp_buffer, buffer_size)); + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_spmv_ex( + handle, myRocsparseOperation, &alpha, Aspmat, vecX, &beta, vecY, + compute_type, alg, rocsparse_spmv_stage_auto, &buffer_size, tmp_buffer)); +#else KOKKOS_ROCSPARSE_SAFE_CALL_IMPL( rocsparse_spmv(handle, myRocsparseOperation, &alpha, Aspmat, vecX, &beta, vecY, compute_type, alg, &buffer_size, tmp_buffer)); + KOKKOS_IMPL_HIP_SAFE_CALL(hipMalloc(&tmp_buffer, buffer_size)); + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_spmv_ex( + handle, myRocsparseOperation, &alpha, Aspmat, vecX, &beta, vecY, + compute_type, alg, &buffer_size, tmp_buffer)); +#endif KOKKOS_IMPL_HIP_SAFE_CALL(hipFree(tmp_buffer)); KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_destroy_dnvec_descr(vecY));