Skip to content

Commit

Permalink
Fix #2156
Browse files Browse the repository at this point in the history
spmv: add a special path for rank-2 x/y where both have exactly 1 column
and a TPL is available for rank-1 but not for rank-2.

Also call "subhandle->set_exec_space" correctly in the TPLs to ensure
proper synchronization between setup, spmv and cleanup (in the case that
different exec instances are used in different calls)
  • Loading branch information
brian-kelley committed Apr 2, 2024
1 parent 02ea952 commit c402d96
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 3 deletions.
54 changes: 54 additions & 0 deletions sparse/src/KokkosSparse_spmv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,31 @@ struct RANK_ONE {};
struct RANK_TWO {};
} // namespace

namespace Impl {
template <typename ExecutionSpace, typename Handle, typename AMatrix,
typename XVector, class YVector>
inline constexpr bool spmv_general_tpl_avail() {
constexpr bool isBSR = ::KokkosSparse::Experimental::is_bsr_matrix_v<AMatrix>;
if constexpr (!isBSR) {
// CRS
if constexpr (XVector::rank() == 1)
return spmv_tpl_spec_avail<ExecutionSpace, Handle, AMatrix, XVector,
YVector>::value;
else
return spmv_mv_tpl_spec_avail<ExecutionSpace, Handle, AMatrix, XVector,
YVector>::value;
} else {
// BSR
if constexpr (XVector::rank() == 1)
return spmv_bsrmatrix_tpl_spec_avail<ExecutionSpace, Handle, AMatrix,
XVector, YVector>::value;
else
return spmv_mv_bsrmatrix_tpl_spec_avail<ExecutionSpace, Handle, AMatrix,
XVector, YVector>::value;
}
}
} // namespace Impl

// clang-format off
/// \brief Kokkos sparse matrix-vector multiply.
/// Computes y := alpha*Op(A)*x + beta*y, where Op(A) is
Expand Down Expand Up @@ -221,6 +246,35 @@ void spmv(const ExecutionSpace& space, Handle* handle, const char mode[],
typename KokkosKernels::Impl::GetUnifiedLayout<YVector>::array_layout,
typename YVector::device_type, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

// Special case: XVector/YVector are rank-2 but x,y both have one column and
// are contiguous. If a TPL is available for rank-1 vectors but not rank-2,
// take rank-1 subviews of x,y and call the rank-1 version.
if constexpr (XVector::rank() == 2) {
// Rank-1, unmanaged counterparts of the rank-2 view types. They are built
// from x.data()/y.data() below, so they alias the caller's memory.
using XVector_SubInternal = Kokkos::View<
typename XVector::const_value_type*,
typename KokkosKernels::Impl::GetUnifiedLayout<XVector>::array_layout,
typename XVector::device_type,
Kokkos::MemoryTraits<Kokkos::Unmanaged | Kokkos::RandomAccess>>;
using YVector_SubInternal = Kokkos::View<
typename YVector::non_const_value_type*,
typename KokkosKernels::Impl::GetUnifiedLayout<YVector>::array_layout,
typename YVector::device_type, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
// Compile-time gate: no TPL for the rank-2 internal view types, but one is
// available for the rank-1 subview types.
if constexpr (!Impl::spmv_general_tpl_avail<
ExecutionSpace, HandleImpl, AMatrix_Internal,
XVector_Internal, YVector_Internal>() &&
Impl::spmv_general_tpl_avail<
ExecutionSpace, HandleImpl, AMatrix_Internal,
XVector_SubInternal, YVector_SubInternal>()) {
// Run-time gate: x has a single column and both views are contiguous, so
// (data(), extent(0)) describes each vector completely.
// NOTE(review): y.extent(1) is not tested here - presumably the x/y column
// counts are validated earlier in this function; confirm.
if (x.extent(1) == size_t(1) && x.span_is_contiguous() &&
y.span_is_contiguous()) {
XVector_SubInternal xsub(x.data(), x.extent(0));
YVector_SubInternal ysub(y.data(), y.extent(0));
// Re-dispatch to spmv with the rank-1 views, passing handle->get_impl()
// as the handle.
spmv(space, handle->get_impl(), mode, alpha, A, xsub, beta, ysub);
// Done; skip the general path that follows.
return;
}
}
}

XVector_Internal x_i(x);
YVector_Internal y_i(y);

Expand Down
5 changes: 2 additions & 3 deletions sparse/src/KokkosSparse_spmv_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,8 @@ struct SPMVHandleImpl {
/// Destroy the handle, releasing the TPL sub-handle object if one is set.
~SPMVHandleImpl() {
  // Deleting a null pointer is a no-op, so no explicit null test is needed.
  delete tpl;
}
/// Forward \c exec to the TPL sub-handle's set_exec_space().
/// Does nothing when no sub-handle is set (\c tpl is null).
void set_exec_space(const ExecutionSpace& exec) {
  if (!tpl) {
    return;
  }
  tpl->set_exec_space(exec);
}

/// Return a pointer to this object as an \c ImplType*.
ImplType* get_impl() {
  return this;
}

/// \brief Return the SPMVAlgorithm stored in this handle.
/// \return The current value of the \c algo member.
SPMVAlgorithm get_algorithm() const {
  return this->algo;
}
Expand Down
5 changes: 5 additions & 0 deletions sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ inline void spmv_bsr_mkl(Handle* handle, sparse_operation_t op, Scalar alpha,
if (!subhandle)
throw std::runtime_error(
"KokkosSparse::spmv: subhandle is not set up for MKL BSR");
subhandle->set_exec_space(exec);
} else {
// Use the default execution space instance, as classic MKL does not use
// a specific instance.
Expand Down Expand Up @@ -127,6 +128,7 @@ inline void spmv_mv_bsr_mkl(Handle* handle, sparse_operation_t op, Scalar alpha,
if (!subhandle)
throw std::runtime_error(
"KokkosSparse::spmv: subhandle is not set up for MKL BSR");
subhandle->set_exec_space(exec);
} else {
// Use the default execution space instance, as classic MKL does not use
// a specific instance.
Expand Down Expand Up @@ -378,6 +380,7 @@ void spmv_bsr_cusparse(const Kokkos::Cuda& exec, Handle* handle,
if (!subhandle)
throw std::runtime_error(
"KokkosSparse::spmv: subhandle is not set up for cusparse");
subhandle->set_exec_space(exec);
} else {
/* create and set the subhandle and matrix descriptor */
subhandle = new KokkosSparse::Impl::CuSparse9_SpMV_Data(exec);
Expand Down Expand Up @@ -505,6 +508,7 @@ void spmv_mv_bsr_cusparse(const Kokkos::Cuda& exec, Handle* handle,
if (!subhandle)
throw std::runtime_error(
"KokkosSparse::spmv: subhandle is not set up for cusparse");
subhandle->set_exec_space(exec);
} else {
/* create and set the subhandle and matrix descriptor */
subhandle = new KokkosSparse::Impl::CuSparse9_SpMV_Data(exec);
Expand Down Expand Up @@ -855,6 +859,7 @@ void spmv_bsr_rocsparse(const Kokkos::HIP& exec, Handle* handle,
if (!subhandle)
throw std::runtime_error(
"KokkosSparse::spmv: subhandle is not set up for rocsparse BSR");
subhandle->set_exec_space(exec);
} else {
subhandle = new KokkosSparse::Impl::RocSparse_BSR_SpMV_Data(exec);
handle->tpl = subhandle;
Expand Down
1 change: 1 addition & 0 deletions sparse/tpls/KokkosSparse_spmv_mv_tpl_spec_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ void spmv_mv_cusparse(const Kokkos::Cuda &exec, Handle *handle,
if (!subhandle)
throw std::runtime_error(
"KokkosSparse::spmv: subhandle is not set up for cusparse");
subhandle->set_exec_space(exec);
} else {
subhandle = new KokkosSparse::Impl::CuSparse10_SpMV_Data(exec);
handle->tpl = subhandle;
Expand Down
6 changes: 6 additions & 0 deletions sparse/tpls/KokkosSparse_spmv_tpl_spec_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ void spmv_cusparse(const Kokkos::Cuda& exec, Handle* handle, const char mode[],
if (!subhandle)
throw std::runtime_error(
"KokkosSparse::spmv: subhandle is not set up for cusparse");
subhandle->set_exec_space(exec);
} else {
subhandle = new KokkosSparse::Impl::CuSparse10_SpMV_Data(exec);
handle->tpl = subhandle;
Expand Down Expand Up @@ -155,6 +156,7 @@ void spmv_cusparse(const Kokkos::Cuda& exec, Handle* handle, const char mode[],
if (!subhandle)
throw std::runtime_error(
"KokkosSparse::spmv: subhandle is not set up for cusparse");
subhandle->set_exec_space(exec);
} else {
/* create and set the subhandle and matrix descriptor */
subhandle = new KokkosSparse::Impl::CuSparse9_SpMV_Data(exec);
Expand Down Expand Up @@ -390,6 +392,7 @@ void spmv_rocsparse(const Kokkos::HIP& exec, Handle* handle, const char mode[],
if (!subhandle)
throw std::runtime_error(
"KokkosSparse::spmv: subhandle is not set up for rocsparse CRS");
subhandle->set_exec_space(exec);
} else {
subhandle = new KokkosSparse::Impl::RocSparse_CRS_SpMV_Data(exec);
handle->tpl = subhandle;
Expand Down Expand Up @@ -550,6 +553,8 @@ inline void spmv_mkl(Handle* handle, sparse_operation_t op, Scalar alpha,
MKLScalar* y_mkl = reinterpret_cast<MKLScalar*>(y);
if (handle->is_set_up) {
subhandle = dynamic_cast<Subhandle*>(handle->tpl);
// note: classic mkl only runs on synchronous host exec spaces, so no need
// to call set_exec_space on the subhandle here
if (!subhandle)
throw std::runtime_error(
"KokkosSparse::spmv: subhandle is not set up for MKL CRS");
Expand Down Expand Up @@ -710,6 +715,7 @@ inline void spmv_onemkl(const execution_space& exec, Handle* handle,
if (!subhandle)
throw std::runtime_error(
"KokkosSparse::spmv: subhandle is not set up for OneMKL CRS");
subhandle->set_exec_space(exec);
} else {
subhandle = new OneMKL_SpMV_Data(exec);
handle->tpl = subhandle;
Expand Down

0 comments on commit c402d96

Please sign in to comment.