diff --git a/sparse/src/KokkosSparse_Utils_rocsparse.hpp b/sparse/src/KokkosSparse_Utils_rocsparse.hpp index cc34e55093..baf2d3a822 100644 --- a/sparse/src/KokkosSparse_Utils_rocsparse.hpp +++ b/sparse/src/KokkosSparse_Utils_rocsparse.hpp @@ -21,8 +21,12 @@ #include #ifdef KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE +#if __has_include() +#include +#else #include -#include "rocsparse/rocsparse.h" +#endif +#include namespace KokkosSparse { namespace Impl { diff --git a/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp b/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp index 97019e4682..75752190e7 100644 --- a/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp +++ b/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp @@ -869,8 +869,46 @@ void spmv_block_impl_rocsparse( rocsparse_mat_info info; KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_create_mat_info(&info)); + // *_ex* functions deprecated in introduced in 6+ +#if KOKKOSSPARSE_IMPL_ROCM_VERSION >= 60000 + if constexpr (std::is_same_v) { + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_sbsrmv_analysis( + handle, dir, trans, mb, nb, nnzb, descr, bsr_val, bsr_row_ptr, + bsr_col_ind, block_dim, info)); + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_sbsrmv( + handle, dir, trans, mb, nb, nnzb, alpha_, descr, bsr_val, bsr_row_ptr, + bsr_col_ind, block_dim, info, x_, beta_, y_)); + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_bsrsv_clear(handle, info)); + } else if constexpr (std::is_same_v) { + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_dbsrmv_analysis( + handle, dir, trans, mb, nb, nnzb, descr, bsr_val, bsr_row_ptr, + bsr_col_ind, block_dim, info)); + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_dbsrmv( + handle, dir, trans, mb, nb, nnzb, alpha_, descr, bsr_val, bsr_row_ptr, + bsr_col_ind, block_dim, info, x_, beta_, y_)); + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_bsrsv_clear(handle, info)); + } else if constexpr (std::is_same_v>) { + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_cbsrmv_analysis( + handle, dir, trans, mb, nb, nnzb, descr, bsr_val, bsr_row_ptr, + bsr_col_ind, block_dim, info)); + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_cbsrmv( + handle, dir, trans, mb, nb, nnzb, alpha_, descr, bsr_val, bsr_row_ptr, + bsr_col_ind, block_dim, info, x_, beta_, y_)); + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_bsrsv_clear(handle, info)); + } else if constexpr (std::is_same_v>) { + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_zbsrmv_analysis( + handle, dir, trans, mb, nb, nnzb, descr, bsr_val, bsr_row_ptr, + bsr_col_ind, block_dim, info)); + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_zbsrmv( + handle, dir, trans, mb, nb, nnzb, alpha_, descr, bsr_val, bsr_row_ptr, + bsr_col_ind, block_dim, info, x_, beta_, y_)); + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_bsrsv_clear(handle, info)); + } else { + static_assert(KokkosKernels::Impl::always_false_v, + "unsupported value type for rocsparse_*bsrmv"); + } // *_ex* functions introduced in 5.4.0 -#if KOKKOSSPARSE_IMPL_ROCM_VERSION < 50400 +#elif KOKKOSSPARSE_IMPL_ROCM_VERSION < 50400 if constexpr (std::is_same_v) { KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_sbsrmv( handle, dir, trans, mb, nb, nnzb, alpha_, descr, bsr_val, bsr_row_ptr, diff --git a/sparse/tpls/KokkosSparse_spmv_tpl_spec_decl.hpp b/sparse/tpls/KokkosSparse_spmv_tpl_spec_decl.hpp index e12ee23937..99799ced5d 100644 --- a/sparse/tpls/KokkosSparse_spmv_tpl_spec_decl.hpp +++ b/sparse/tpls/KokkosSparse_spmv_tpl_spec_decl.hpp @@ -359,8 +359,6 @@ KOKKOSSPARSE_SPMV_CUSPARSE(Kokkos::complex, int64_t, size_t, // rocSPARSE #if defined(KOKKOSKERNELS_ENABLE_TPL_ROCSPARSE) -#include -#include #include "KokkosSparse_Utils_rocsparse.hpp" namespace KokkosSparse { @@ -443,7 +441,17 @@ void spmv_rocsparse(const Kokkos::HIP& exec, alg = rocsparse_spmv_alg_csr_stream; } -#if KOKKOSSPARSE_IMPL_ROCM_VERSION >= 50400 +#if KOKKOSSPARSE_IMPL_ROCM_VERSION >= 60000 + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL( + rocsparse_spmv(handle, myRocsparseOperation, &alpha, Aspmat, vecX, &beta, + vecY, compute_type, alg, rocsparse_spmv_stage_buffer_size, + &buffer_size, tmp_buffer)); + KOKKOS_IMPL_HIP_SAFE_CALL(hipMalloc(&tmp_buffer, buffer_size)); + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL( + rocsparse_spmv(handle, myRocsparseOperation, &alpha, Aspmat, vecX, &beta, + vecY, compute_type, alg, rocsparse_spmv_stage_compute, + &buffer_size, tmp_buffer)); +#elif KOKKOSSPARSE_IMPL_ROCM_VERSION >= 50400 KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_spmv_ex( handle, myRocsparseOperation, &alpha, Aspmat, vecX, &beta, vecY, compute_type, alg, rocsparse_spmv_stage_auto, &buffer_size, tmp_buffer));