diff --git a/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_avail.hpp b/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_avail.hpp index 3ce22c630a..6846e27748 100644 --- a/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_avail.hpp +++ b/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_avail.hpp @@ -265,8 +265,7 @@ KOKKOSSPARSE_SPMV_MV_BSRMATRIX_TPL_SPEC_AVAIL_MKL(Kokkos::complex, enum : bool { value = true }; \ }; -// These things may also be valid before 5.4, but I haven't tested it. -#if KOKKOSSPARSE_IMPL_ROCM_VERSION >= 50400 +#if KOKKOSSPARSE_IMPL_ROCM_VERSION >= 50200 KOKKOSSPARSE_SPMV_BSRMATRIX_TPL_SPEC_AVAIL_ROCSPARSE(float, rocsparse_int, rocsparse_int, @@ -305,7 +304,7 @@ KOKKOSSPARSE_SPMV_BSRMATRIX_TPL_SPEC_AVAIL_ROCSPARSE(Kokkos::complex, Kokkos::LayoutRight, Kokkos::HIPSpace) -#endif // KOKKOSSPARSE_IMPL_ROCM_VERSION >= 50400 +#endif // KOKKOSSPARSE_IMPL_ROCM_VERSION >= 50200 #undef KOKKOSSPARSE_SPMV_BSRMATRIX_TPL_SPEC_AVAIL_ROCSPARSE diff --git a/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp b/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp index cc3e2a6b1e..36a64228b8 100644 --- a/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp +++ b/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp @@ -929,6 +929,30 @@ void spmv_block_impl_rocsparse( KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_create_mat_descr(&descr)); rocsparse_mat_info info; KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_create_mat_info(&info)); + + // *_ex* functions introduced in 5.4.0 +#if KOKKOSSPARSE_IMPL_ROCM_VERSION < 50400 + if constexpr (std::is_same_v) { + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_sbsrmv( + handle, dir, trans, mb, nb, nnzb, alpha_, descr, bsr_val, bsr_row_ptr, + bsr_col_ind, block_dim, x_, beta_, y_)); + } else if constexpr (std::is_same_v) { + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_dbsrmv( + handle, dir, trans, mb, nb, nnzb, alpha_, descr, bsr_val, bsr_row_ptr, + bsr_col_ind, block_dim, x_, beta_, y_)); + } else if constexpr (std::is_same_v>) { + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_cbsrmv( + handle, dir, trans, mb, nb, nnzb, alpha_, descr, bsr_val, bsr_row_ptr, + bsr_col_ind, block_dim, x_, beta_, y_)); + } else if constexpr (std::is_same_v>) { + KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_zbsrmv( + handle, dir, trans, mb, nb, nnzb, alpha_, descr, bsr_val, bsr_row_ptr, + bsr_col_ind, block_dim, x_, beta_, y_)); + } else { + static_assert(KokkosKernels::Impl::always_false_v, + "unsupported value type for rocsparse_*bsrmv"); + } +#else if constexpr (std::is_same_v) { KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(rocsparse_sbsrmv_ex_analysis( handle, dir, trans, mb, nb, nnzb, descr, bsr_val, bsr_row_ptr, @@ -965,6 +989,7 @@ void spmv_block_impl_rocsparse( static_assert(KokkosKernels::Impl::always_false_v, "unsupported value type for rocsparse_*bsrmv"); } +#endif rocsparse_destroy_mat_descr(descr); rocsparse_destroy_mat_info(info);