From c402d960b992db3bd7e68c79bf556a746e6097a1 Mon Sep 17 00:00:00 2001 From: Brian Kelley Date: Mon, 1 Apr 2024 15:41:34 -0600 Subject: [PATCH] Fix #2156 spmv: add special path for rank-2 x/y, but where both have 1 column and a TPL is available for rank-1 but not rank-2. Also call "subhandle->set_exec_space" correctly in the TPLs to ensure proper synchronization between setup, spmv and cleanup (in the case that different exec instances are used in different calls) --- sparse/src/KokkosSparse_spmv.hpp | 54 +++++++++++++++++++ sparse/src/KokkosSparse_spmv_handle.hpp | 5 +- ...kosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp | 5 ++ .../KokkosSparse_spmv_mv_tpl_spec_decl.hpp | 1 + .../tpls/KokkosSparse_spmv_tpl_spec_decl.hpp | 6 +++ 5 files changed, 68 insertions(+), 3 deletions(-) diff --git a/sparse/src/KokkosSparse_spmv.hpp b/sparse/src/KokkosSparse_spmv.hpp index 2391291695..f11b61f675 100644 --- a/sparse/src/KokkosSparse_spmv.hpp +++ b/sparse/src/KokkosSparse_spmv.hpp @@ -40,6 +40,31 @@ struct RANK_ONE {}; struct RANK_TWO {}; } // namespace +namespace Impl { +template +inline constexpr bool spmv_general_tpl_avail() { + constexpr bool isBSR = ::KokkosSparse::Experimental::is_bsr_matrix_v; + if constexpr (!isBSR) { + // CRS + if constexpr (XVector::rank() == 1) + return spmv_tpl_spec_avail::value; + else + return spmv_mv_tpl_spec_avail::value; + } else { + // BSR + if constexpr (XVector::rank() == 1) + return spmv_bsrmatrix_tpl_spec_avail::value; + else + return spmv_mv_bsrmatrix_tpl_spec_avail::value; + } +} +} // namespace Impl + // clang-format off /// \brief Kokkos sparse matrix-vector multiply. /// Computes y := alpha*Op(A)*x + beta*y, where Op(A) is @@ -221,6 +246,35 @@ void spmv(const ExecutionSpace& space, Handle* handle, const char mode[], typename KokkosKernels::Impl::GetUnifiedLayout::array_layout, typename YVector::device_type, Kokkos::MemoryTraits>; + // Special case: XVector/YVector are rank-2 but x,y both have one column and + // are contiguous. If a TPL is available for rank-1 vectors but not rank-2, + // take rank-1 subviews of x,y and call the rank-1 version. + if constexpr (XVector::rank() == 2) { + using XVector_SubInternal = Kokkos::View< + typename XVector::const_value_type*, + typename KokkosKernels::Impl::GetUnifiedLayout::array_layout, + typename XVector::device_type, + Kokkos::MemoryTraits>; + using YVector_SubInternal = Kokkos::View< + typename YVector::non_const_value_type*, + typename KokkosKernels::Impl::GetUnifiedLayout::array_layout, + typename YVector::device_type, Kokkos::MemoryTraits>; + if constexpr (!Impl::spmv_general_tpl_avail< + ExecutionSpace, HandleImpl, AMatrix_Internal, + XVector_Internal, YVector_Internal>() && + Impl::spmv_general_tpl_avail< + ExecutionSpace, HandleImpl, AMatrix_Internal, + XVector_SubInternal, YVector_SubInternal>()) { + if (x.extent(1) == size_t(1) && x.span_is_contiguous() && + y.span_is_contiguous()) { + XVector_SubInternal xsub(x.data(), x.extent(0)); + YVector_SubInternal ysub(y.data(), y.extent(0)); + spmv(space, handle->get_impl(), mode, alpha, A, xsub, beta, ysub); + return; + } + } + } + XVector_Internal x_i(x); YVector_Internal y_i(y); diff --git a/sparse/src/KokkosSparse_spmv_handle.hpp b/sparse/src/KokkosSparse_spmv_handle.hpp index 9e7295c72c..a2eecfd1ce 100644 --- a/sparse/src/KokkosSparse_spmv_handle.hpp +++ b/sparse/src/KokkosSparse_spmv_handle.hpp @@ -237,9 +237,8 @@ struct SPMVHandleImpl { ~SPMVHandleImpl() { if (tpl) delete tpl; } - void set_exec_space(const ExecutionSpace& exec) { - if (tpl) tpl->set_exec_space(exec); - } + + ImplType* get_impl() { return this; } /// Get the SPMVAlgorithm used by this handle SPMVAlgorithm get_algorithm() const { return this->algo; } diff --git a/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp b/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp index 72062c26fb..188bc5580d 100644 --- a/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp +++ b/sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp @@ -48,6 +48,7 @@ inline void spmv_bsr_mkl(Handle* handle, sparse_operation_t op, Scalar alpha, if (!subhandle) throw std::runtime_error( "KokkosSparse::spmv: subhandle is not set up for MKL BSR"); + subhandle->set_exec_space(exec); } else { // Use the default execution space instance, as classic MKL does not use // a specific instance. @@ -127,6 +128,7 @@ inline void spmv_mv_bsr_mkl(Handle* handle, sparse_operation_t op, Scalar alpha, if (!subhandle) throw std::runtime_error( "KokkosSparse::spmv: subhandle is not set up for MKL BSR"); + subhandle->set_exec_space(exec); } else { // Use the default execution space instance, as classic MKL does not use // a specific instance. @@ -378,6 +380,7 @@ void spmv_bsr_cusparse(const Kokkos::Cuda& exec, Handle* handle, if (!subhandle) throw std::runtime_error( "KokkosSparse::spmv: subhandle is not set up for cusparse"); + subhandle->set_exec_space(exec); } else { /* create and set the subhandle and matrix descriptor */ subhandle = new KokkosSparse::Impl::CuSparse9_SpMV_Data(exec); @@ -505,6 +508,7 @@ void spmv_mv_bsr_cusparse(const Kokkos::Cuda& exec, Handle* handle, if (!subhandle) throw std::runtime_error( "KokkosSparse::spmv: subhandle is not set up for cusparse"); + subhandle->set_exec_space(exec); } else { /* create and set the subhandle and matrix descriptor */ subhandle = new KokkosSparse::Impl::CuSparse9_SpMV_Data(exec); @@ -855,6 +859,7 @@ void spmv_bsr_rocsparse(const Kokkos::HIP& exec, Handle* handle, if (!subhandle) throw std::runtime_error( "KokkosSparse::spmv: subhandle is not set up for rocsparse BSR"); + subhandle->set_exec_space(exec); } else { subhandle = new KokkosSparse::Impl::RocSparse_BSR_SpMV_Data(exec); handle->tpl = subhandle; diff --git a/sparse/tpls/KokkosSparse_spmv_mv_tpl_spec_decl.hpp b/sparse/tpls/KokkosSparse_spmv_mv_tpl_spec_decl.hpp index 2ccfd89d73..500fbddbe7 100644 --- a/sparse/tpls/KokkosSparse_spmv_mv_tpl_spec_decl.hpp +++ b/sparse/tpls/KokkosSparse_spmv_mv_tpl_spec_decl.hpp @@ -192,6 +192,7 @@ void spmv_mv_cusparse(const Kokkos::Cuda &exec, Handle *handle, if (!subhandle) throw std::runtime_error( "KokkosSparse::spmv: subhandle is not set up for cusparse"); + subhandle->set_exec_space(exec); } else { subhandle = new KokkosSparse::Impl::CuSparse10_SpMV_Data(exec); handle->tpl = subhandle; diff --git a/sparse/tpls/KokkosSparse_spmv_tpl_spec_decl.hpp b/sparse/tpls/KokkosSparse_spmv_tpl_spec_decl.hpp index a11fdf68b2..cd3e99ef81 100644 --- a/sparse/tpls/KokkosSparse_spmv_tpl_spec_decl.hpp +++ b/sparse/tpls/KokkosSparse_spmv_tpl_spec_decl.hpp @@ -111,6 +111,7 @@ void spmv_cusparse(const Kokkos::Cuda& exec, Handle* handle, const char mode[], if (!subhandle) throw std::runtime_error( "KokkosSparse::spmv: subhandle is not set up for cusparse"); + subhandle->set_exec_space(exec); } else { subhandle = new KokkosSparse::Impl::CuSparse10_SpMV_Data(exec); handle->tpl = subhandle; @@ -155,6 +156,7 @@ void spmv_cusparse(const Kokkos::Cuda& exec, Handle* handle, const char mode[], if (!subhandle) throw std::runtime_error( "KokkosSparse::spmv: subhandle is not set up for cusparse"); + subhandle->set_exec_space(exec); } else { /* create and set the subhandle and matrix descriptor */ subhandle = new KokkosSparse::Impl::CuSparse9_SpMV_Data(exec); @@ -390,6 +392,7 @@ void spmv_rocsparse(const Kokkos::HIP& exec, Handle* handle, const char mode[], if (!subhandle) throw std::runtime_error( "KokkosSparse::spmv: subhandle is not set up for rocsparse CRS"); + subhandle->set_exec_space(exec); } else { subhandle = new KokkosSparse::Impl::RocSparse_CRS_SpMV_Data(exec); handle->tpl = subhandle; @@ -550,6 +553,8 @@ inline void spmv_mkl(Handle* handle, sparse_operation_t op, Scalar alpha, MKLScalar* y_mkl = reinterpret_cast(y); if (handle->is_set_up) { subhandle = dynamic_cast(handle->tpl); + // note: classic mkl only runs on synchronous host exec spaces, so no need + // to call set_exec_space on the subhandle here if (!subhandle) throw std::runtime_error( "KokkosSparse::spmv: subhandle is not set up for MKL CRS"); @@ -710,6 +715,7 @@ inline void spmv_onemkl(const execution_space& exec, Handle* handle, if (!subhandle) throw std::runtime_error( "KokkosSparse::spmv: subhandle is not set up for OneMKL CRS"); + subhandle->set_exec_space(exec); } else { subhandle = new OneMKL_SpMV_Data(exec); handle->tpl = subhandle;