Skip to content

Commit

Permalink
Fix spmv regressions (#2204)
Browse files Browse the repository at this point in the history
* Restore cusparse spmv ALG2 path for imbalanced

With correct version cutoffs

* spmv: use separate rank-1 and rank-2 tpl subhandles

* Remove redundant single-column path in native spmv_mv

* Fix unused param warning
  • Loading branch information
brian-kelley authored May 22, 2024
1 parent 6204151 commit feb1f55
Show file tree
Hide file tree
Showing 6 changed files with 90 additions and 151 deletions.
56 changes: 13 additions & 43 deletions sparse/impl/KokkosSparse_spmv_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,54 +203,24 @@ struct SPMV_MV<ExecutionSpace, Handle, AMatrix, XVector, YVector, false, false,
KOKKOSKERNELS_IMPL_COMPILE_LIBRARY> {
typedef typename YVector::non_const_value_type coefficient_type;

static void spmv_mv(const ExecutionSpace& space, Handle* handle,
// TODO: pass handle through to implementation and use tuning parameters
static void spmv_mv(const ExecutionSpace& space, Handle* /* handle */,
const char mode[], const coefficient_type& alpha,
const AMatrix& A, const XVector& x,
const coefficient_type& beta, const YVector& y) {
typedef Kokkos::ArithTraits<coefficient_type> KAT;
// Intercept special case: if x/y have only 1 column and both are
// contiguous, use the more efficient single-vector impl.
//
// We cannot do this if x or y is noncontiguous, because the column subview
// must be LayoutStride which is not ETI'd.
//
// Do not use a TPL even if one is available for the types:
// we don't want the same handle being used in both TPL and non-TPL versions
if (x.extent(1) == size_t(1) && x.span_is_contiguous() &&
y.span_is_contiguous()) {
Kokkos::View<typename XVector::const_value_type*, default_layout,
typename XVector::device_type>
x0(x.data(), x.extent(0));
Kokkos::View<typename YVector::non_const_value_type*, default_layout,
typename YVector::device_type>
y0(y.data(), y.extent(0));
if (beta == KAT::zero()) {
spmv_beta<ExecutionSpace, Handle, AMatrix, decltype(x0), decltype(y0),
0>(space, handle, mode, alpha, A, x0, beta, y0);
} else if (beta == KAT::one()) {
spmv_beta<ExecutionSpace, Handle, AMatrix, decltype(x0), decltype(y0),
1>(space, handle, mode, alpha, A, x0, beta, y0);
} else if (beta == -KAT::one()) {
spmv_beta<ExecutionSpace, Handle, AMatrix, decltype(x0), decltype(y0),
-1>(space, handle, mode, alpha, A, x0, beta, y0);
} else {
spmv_beta<ExecutionSpace, Handle, AMatrix, decltype(x0), decltype(y0),
2>(space, handle, mode, alpha, A, x0, beta, y0);
}
if (alpha == KAT::zero()) {
spmv_alpha_mv<ExecutionSpace, AMatrix, XVector, YVector, 0>(
space, mode, alpha, A, x, beta, y);
} else if (alpha == KAT::one()) {
spmv_alpha_mv<ExecutionSpace, AMatrix, XVector, YVector, 1>(
space, mode, alpha, A, x, beta, y);
} else if (alpha == -KAT::one()) {
spmv_alpha_mv<ExecutionSpace, AMatrix, XVector, YVector, -1>(
space, mode, alpha, A, x, beta, y);
} else {
if (alpha == KAT::zero()) {
spmv_alpha_mv<ExecutionSpace, AMatrix, XVector, YVector, 0>(
space, mode, alpha, A, x, beta, y);
} else if (alpha == KAT::one()) {
spmv_alpha_mv<ExecutionSpace, AMatrix, XVector, YVector, 1>(
space, mode, alpha, A, x, beta, y);
} else if (alpha == -KAT::one()) {
spmv_alpha_mv<ExecutionSpace, AMatrix, XVector, YVector, -1>(
space, mode, alpha, A, x, beta, y);
} else {
spmv_alpha_mv<ExecutionSpace, AMatrix, XVector, YVector, 2>(
space, mode, alpha, A, x, beta, y);
}
spmv_alpha_mv<ExecutionSpace, AMatrix, XVector, YVector, 2>(
space, mode, alpha, A, x, beta, y);
}
}
};
Expand Down
49 changes: 8 additions & 41 deletions sparse/src/KokkosSparse_spmv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,32 +40,6 @@ struct RANK_ONE {};
struct RANK_TWO {};
} // namespace

namespace Impl {
template <typename ExecutionSpace, typename Handle, typename AMatrix,
typename XVector, class YVector>
inline constexpr bool spmv_general_tpl_avail() {
constexpr bool isBSR = ::KokkosSparse::Experimental::is_bsr_matrix_v<AMatrix>;
if constexpr (!isBSR) {
// CRS
if constexpr (XVector::rank() == 1)
return spmv_tpl_spec_avail<ExecutionSpace, Handle, AMatrix, XVector,
YVector>::value;
else
return spmv_mv_tpl_spec_avail<ExecutionSpace, Handle, AMatrix, XVector,
YVector>::value;
} else {
// BSR
if constexpr (XVector::rank() == 1)
return spmv_bsrmatrix_tpl_spec_avail<ExecutionSpace, Handle, AMatrix,
XVector, YVector>::value;
else
return spmv_mv_bsrmatrix_tpl_spec_avail<ExecutionSpace, Handle, AMatrix,
XVector, YVector>::value;
}
return false;
}
} // namespace Impl

// clang-format off
/// \brief Kokkos sparse matrix-vector multiply.
/// Computes y := alpha*Op(A)*x + beta*y, where Op(A) is
Expand Down Expand Up @@ -248,8 +222,8 @@ void spmv(const ExecutionSpace& space, Handle* handle, const char mode[],
typename YVector::device_type, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;

// Special case: XVector/YVector are rank-2 but x,y both have one column and
// are contiguous. If a TPL is available for rank-1 vectors but not rank-2,
// take rank-1 subviews of x,y and call the rank-1 version.
// are contiguous. In this case take rank-1 subviews of x,y and call the
// rank-1 version.
if constexpr (XVector::rank() == 2) {
using XVector_SubInternal = Kokkos::View<
typename XVector::const_value_type*,
Expand All @@ -260,19 +234,12 @@ void spmv(const ExecutionSpace& space, Handle* handle, const char mode[],
typename YVector::non_const_value_type*,
typename KokkosKernels::Impl::GetUnifiedLayout<YVector>::array_layout,
typename YVector::device_type, Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
if constexpr (!Impl::spmv_general_tpl_avail<
ExecutionSpace, HandleImpl, AMatrix_Internal,
XVector_Internal, YVector_Internal>() &&
Impl::spmv_general_tpl_avail<
ExecutionSpace, HandleImpl, AMatrix_Internal,
XVector_SubInternal, YVector_SubInternal>()) {
if (x.extent(1) == size_t(1) && x.span_is_contiguous() &&
y.span_is_contiguous()) {
XVector_SubInternal xsub(x.data(), x.extent(0));
YVector_SubInternal ysub(y.data(), y.extent(0));
spmv(space, handle->get_impl(), mode, alpha, A, xsub, beta, ysub);
return;
}
if (x.extent(1) == size_t(1) && x.span_is_contiguous() &&
y.span_is_contiguous()) {
XVector_SubInternal xsub(x.data(), x.extent(0));
YVector_SubInternal ysub(y.data(), y.extent(0));
spmv(space, handle->get_impl(), mode, alpha, A, xsub, beta, ysub);
return;
}
}

Expand Down
9 changes: 5 additions & 4 deletions sparse/src/KokkosSparse_spmv_handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -234,17 +234,18 @@ struct SPMVHandleImpl {
"SPMVHandleImpl: Ordinal must not be a const type");
SPMVHandleImpl(SPMVAlgorithm algo_) : algo(algo_) {}
~SPMVHandleImpl() {
if (tpl) delete tpl;
if (tpl_rank1) delete tpl_rank1;
if (tpl_rank2) delete tpl_rank2;
}

ImplType* get_impl() { return this; }

/// Get the SPMVAlgorithm used by this handle
SPMVAlgorithm get_algorithm() const { return this->algo; }

bool is_set_up = false;
const SPMVAlgorithm algo = SPMV_DEFAULT;
TPL_SpMV_Data<ExecutionSpace>* tpl = nullptr;
const SPMVAlgorithm algo = SPMV_DEFAULT;
TPL_SpMV_Data<ExecutionSpace>* tpl_rank1 = nullptr;
TPL_SpMV_Data<ExecutionSpace>* tpl_rank2 = nullptr;
// Expert tuning parameters for native SpMV
// TODO: expose a proper Experimental interface to set these. Currently they
// can be assigned directly in the SPMVHandle as they are public members.
Expand Down
47 changes: 21 additions & 26 deletions sparse/tpls/KokkosSparse_spmv_bsrmatrix_tpl_spec_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ inline void spmv_bsr_mkl(Handle* handle, sparse_operation_t op, Scalar alpha,
Subhandle* subhandle;
const MKLScalar* x_mkl = reinterpret_cast<const MKLScalar*>(x);
MKLScalar* y_mkl = reinterpret_cast<MKLScalar*>(y);
if (handle->is_set_up) {
subhandle = dynamic_cast<Subhandle*>(handle->tpl);
if (handle->tpl_rank1) {
subhandle = dynamic_cast<Subhandle*>(handle->tpl_rank1);
if (!subhandle)
throw std::runtime_error(
"KokkosSparse::spmv: subhandle is not set up for MKL BSR");
Expand All @@ -54,7 +54,7 @@ inline void spmv_bsr_mkl(Handle* handle, sparse_operation_t op, Scalar alpha,
// Use the default execution space instance, as classic MKL does not use
// a specific instance.
subhandle = new Subhandle(ExecSpace());
handle->tpl = subhandle;
handle->tpl_rank1 = subhandle;
subhandle->descr.type = SPARSE_MATRIX_TYPE_GENERAL;
subhandle->descr.mode = SPARSE_FILL_MODE_FULL;
subhandle->descr.diag = SPARSE_DIAG_NON_UNIT;
Expand Down Expand Up @@ -87,7 +87,6 @@ inline void spmv_bsr_mkl(Handle* handle, sparse_operation_t op, Scalar alpha,
const_cast<MKL_INT*>(Arowptrs + 1), const_cast<MKL_INT*>(Aentries),
Avalues_mkl));
}
handle->is_set_up = true;
}
MKLScalar alpha_mkl = KokkosSparse::Impl::KokkosToMKLScalar<Scalar>(alpha);
MKLScalar beta_mkl = KokkosSparse::Impl::KokkosToMKLScalar<Scalar>(beta);
Expand Down Expand Up @@ -124,8 +123,8 @@ inline void spmv_mv_bsr_mkl(Handle* handle, sparse_operation_t op, Scalar alpha,
Subhandle* subhandle;
const MKLScalar* x_mkl = reinterpret_cast<const MKLScalar*>(x);
MKLScalar* y_mkl = reinterpret_cast<MKLScalar*>(y);
if (handle->is_set_up) {
subhandle = dynamic_cast<Subhandle*>(handle->tpl);
if (handle->tpl_rank2) {
subhandle = dynamic_cast<Subhandle*>(handle->tpl_rank2);
if (!subhandle)
throw std::runtime_error(
"KokkosSparse::spmv: subhandle is not set up for MKL BSR");
Expand All @@ -135,7 +134,7 @@ inline void spmv_mv_bsr_mkl(Handle* handle, sparse_operation_t op, Scalar alpha,
// Use the default execution space instance, as classic MKL does not use
// a specific instance.
subhandle = new Subhandle(ExecSpace());
handle->tpl = subhandle;
handle->tpl_rank2 = subhandle;
subhandle->descr.type = SPARSE_MATRIX_TYPE_GENERAL;
subhandle->descr.mode = SPARSE_FILL_MODE_FULL;
subhandle->descr.diag = SPARSE_DIAG_NON_UNIT;
Expand Down Expand Up @@ -168,7 +167,6 @@ inline void spmv_mv_bsr_mkl(Handle* handle, sparse_operation_t op, Scalar alpha,
const_cast<MKL_INT*>(Arowptrs + 1), const_cast<MKL_INT*>(Aentries),
Avalues_mkl));
}
handle->is_set_up = true;
}
MKLScalar alpha_mkl = KokkosSparse::Impl::KokkosToMKLScalar<Scalar>(alpha);
MKLScalar beta_mkl = KokkosSparse::Impl::KokkosToMKLScalar<Scalar>(beta);
Expand Down Expand Up @@ -376,23 +374,22 @@ void spmv_bsr_cusparse(const Kokkos::Cuda& exec, Handle* handle,

KokkosSparse::Impl::CuSparse9_SpMV_Data* subhandle;

if (handle->is_set_up) {
subhandle =
dynamic_cast<KokkosSparse::Impl::CuSparse9_SpMV_Data*>(handle->tpl);
if (handle->tpl_rank1) {
subhandle = dynamic_cast<KokkosSparse::Impl::CuSparse9_SpMV_Data*>(
handle->tpl_rank1);
if (!subhandle)
throw std::runtime_error(
"KokkosSparse::spmv: subhandle is not set up for cusparse");
subhandle->set_exec_space(exec);
} else {
/* create and set the subhandle and matrix descriptor */
subhandle = new KokkosSparse::Impl::CuSparse9_SpMV_Data(exec);
handle->tpl = subhandle;
subhandle = new KokkosSparse::Impl::CuSparse9_SpMV_Data(exec);
handle->tpl_rank1 = subhandle;
KOKKOS_CUSPARSE_SAFE_CALL(cusparseCreateMatDescr(&subhandle->mat));
KOKKOS_CUSPARSE_SAFE_CALL(
cusparseSetMatType(subhandle->mat, CUSPARSE_MATRIX_TYPE_GENERAL));
KOKKOS_CUSPARSE_SAFE_CALL(
cusparseSetMatIndexBase(subhandle->mat, CUSPARSE_INDEX_BASE_ZERO));
handle->is_set_up = true;
}

cusparseDirection_t dirA = CUSPARSE_DIRECTION_ROW;
Expand Down Expand Up @@ -504,23 +501,22 @@ void spmv_mv_bsr_cusparse(const Kokkos::Cuda& exec, Handle* handle,

KokkosSparse::Impl::CuSparse9_SpMV_Data* subhandle;

if (handle->is_set_up) {
subhandle =
dynamic_cast<KokkosSparse::Impl::CuSparse9_SpMV_Data*>(handle->tpl);
if (handle->tpl_rank2) {
subhandle = dynamic_cast<KokkosSparse::Impl::CuSparse9_SpMV_Data*>(
handle->tpl_rank2);
if (!subhandle)
throw std::runtime_error(
"KokkosSparse::spmv: subhandle is not set up for cusparse");
subhandle->set_exec_space(exec);
} else {
/* create and set the subhandle and matrix descriptor */
subhandle = new KokkosSparse::Impl::CuSparse9_SpMV_Data(exec);
handle->tpl = subhandle;
subhandle = new KokkosSparse::Impl::CuSparse9_SpMV_Data(exec);
handle->tpl_rank2 = subhandle;
KOKKOS_CUSPARSE_SAFE_CALL(cusparseCreateMatDescr(&subhandle->mat));
KOKKOS_CUSPARSE_SAFE_CALL(
cusparseSetMatType(subhandle->mat, CUSPARSE_MATRIX_TYPE_GENERAL));
KOKKOS_CUSPARSE_SAFE_CALL(
cusparseSetMatIndexBase(subhandle->mat, CUSPARSE_INDEX_BASE_ZERO));
handle->is_set_up = true;
}
cusparseDirection_t dirA = CUSPARSE_DIRECTION_ROW;

Expand Down Expand Up @@ -855,16 +851,16 @@ void spmv_bsr_rocsparse(const Kokkos::HIP& exec, Handle* handle,
rocsparse_value_type* y_ = reinterpret_cast<rocsparse_value_type*>(y.data());

KokkosSparse::Impl::RocSparse_BSR_SpMV_Data* subhandle;
if (handle->is_set_up) {
subhandle =
dynamic_cast<KokkosSparse::Impl::RocSparse_BSR_SpMV_Data*>(handle->tpl);
if (handle->tpl_rank1) {
subhandle = dynamic_cast<KokkosSparse::Impl::RocSparse_BSR_SpMV_Data*>(
handle->tpl_rank1);
if (!subhandle)
throw std::runtime_error(
"KokkosSparse::spmv: subhandle is not set up for rocsparse BSR");
subhandle->set_exec_space(exec);
} else {
subhandle = new KokkosSparse::Impl::RocSparse_BSR_SpMV_Data(exec);
handle->tpl = subhandle;
subhandle = new KokkosSparse::Impl::RocSparse_BSR_SpMV_Data(exec);
handle->tpl_rank1 = subhandle;
KOKKOS_ROCSPARSE_SAFE_CALL_IMPL(
rocsparse_create_mat_descr(&subhandle->mat));
// *_ex* functions deprecated in introduced in 6+
Expand Down Expand Up @@ -918,7 +914,6 @@ void spmv_bsr_rocsparse(const Kokkos::HIP& exec, Handle* handle,
"unsupported value type for rocsparse_*bsrmv");
}
#endif
handle->is_set_up = true;
}

// *_ex* functions deprecated in introduced in 6+
Expand Down
12 changes: 5 additions & 7 deletions sparse/tpls/KokkosSparse_spmv_mv_tpl_spec_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,16 +186,16 @@ void spmv_mv_cusparse(const Kokkos::Cuda &exec, Handle *handle,
}

KokkosSparse::Impl::CuSparse10_SpMV_Data *subhandle;
if (handle->is_set_up) {
subhandle =
dynamic_cast<KokkosSparse::Impl::CuSparse10_SpMV_Data *>(handle->tpl);
if (handle->tpl_rank2) {
subhandle = dynamic_cast<KokkosSparse::Impl::CuSparse10_SpMV_Data *>(
handle->tpl_rank2);
if (!subhandle)
throw std::runtime_error(
"KokkosSparse::spmv: subhandle is not set up for cusparse");
subhandle->set_exec_space(exec);
} else {
subhandle = new KokkosSparse::Impl::CuSparse10_SpMV_Data(exec);
handle->tpl = subhandle;
subhandle = new KokkosSparse::Impl::CuSparse10_SpMV_Data(exec);
handle->tpl_rank2 = subhandle;
/* create matrix */
KOKKOS_CUSPARSE_SAFE_CALL(cusparseCreateCsr(
&subhandle->mat, A.numRows(), A.numCols(), A.nnz(),
Expand All @@ -209,8 +209,6 @@ void spmv_mv_cusparse(const Kokkos::Cuda &exec, Handle *handle,

KOKKOS_IMPL_CUDA_SAFE_CALL(
cudaMalloc(&subhandle->buffer, subhandle->bufferSize));

handle->is_set_up = true;
}

KOKKOS_CUSPARSE_SAFE_CALL(cusparseSpMM(cusparseHandle, opA, opB, &alpha,
Expand Down
Loading

0 comments on commit feb1f55

Please sign in to comment.