Skip to content

Commit

Permalink
Adds the merge-based spmv behind an opt-in `controls.setParameter("al…
Browse files Browse the repository at this point in the history
…gorithm", "merge")`.

This SpMV is up to 200x faster than the existing native implementation for matrices with highly skewed row lengths.

* Adds `KokkosSparse::Impl::diagonal_search` to support the merge-based SpMV
* Moves `KokkosKernels_Iota.hpp` into the `impl` directory
* Removes the static assert for `Kokkos::View` and `Kokkos::Iota` from `lower_bound` and friends, since they can also operate on things that appear like a view but are not present in the common component
* Changes `KokkosSparse::Impl::MergeMatrixDiagonal::MatrixPosition` -> `KokkosSparse::Impl::MergeMatrixPosition`. Various `MergeMatrixDiagonals` have the same `MatrixPosition` type, which were technically different because they were scoped.
  • Loading branch information
cwpearson committed Jul 20, 2023
1 parent ce0653d commit ba7a062
Show file tree
Hide file tree
Showing 8 changed files with 551 additions and 63 deletions.
File renamed without changes.
22 changes: 6 additions & 16 deletions common/src/KokkosKernels_LowerBound.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ namespace Impl {

/*! \brief Single-thread sequential lower-bound search
\tparam ViewLike A Kokkos::View or KokkosKernels::Impl::Iota
\tparam Pred a binary predicate function
\tparam ViewLike A Kokkos::View, KokkosKernels::Impl::Iota, or
KokkosSparse::MergeMatrixDiagonal \tparam Pred a binary predicate function
\param view the view to search
\param value the value to search for
\param pred a binary predicate function
Expand All @@ -96,9 +96,6 @@ lower_bound_sequential_thread(
using size_type = typename ViewLike::size_type;
static_assert(1 == ViewLike::rank,
"lower_bound_sequential_thread requires rank-1 views");
static_assert(is_iota_v<ViewLike> || Kokkos::is_view<ViewLike>::value,
"lower_bound_sequential_thread requires a "
"KokkosKernels::Impl::Iota or a Kokkos::View");

size_type i = 0;
while (i < view.size() && pred(view(i), value)) {
Expand All @@ -109,8 +106,8 @@ lower_bound_sequential_thread(

/*! \brief Single-thread binary lower-bound search
\tparam ViewLike A Kokkos::View or KokkosKernels::Impl::Iota
\tparam Pred a binary predicate function
\tparam ViewLike A Kokkos::View, KokkosKernels::Impl::Iota, or
KokkosSparse::MergeMatrixDiagonal \tparam Pred a binary predicate function
\param view the view to search
\param value the value to search for
\param pred a binary predicate function
Expand All @@ -127,9 +124,6 @@ KOKKOS_INLINE_FUNCTION typename ViewLike::size_type lower_bound_binary_thread(
using size_type = typename ViewLike::size_type;
static_assert(1 == ViewLike::rank,
"lower_bound_binary_thread requires rank-1 views");
static_assert(is_iota_v<ViewLike> || Kokkos::is_view<ViewLike>::value,
"lower_bound_binary_thread requires a "
"KokkosKernels::Impl::Iota or a Kokkos::View");

size_type lo = 0;
size_type hi = view.size();
Expand All @@ -150,8 +144,8 @@ KOKKOS_INLINE_FUNCTION typename ViewLike::size_type lower_bound_binary_thread(

/*! \brief single-thread lower-bound search
\tparam ViewLike A Kokkos::View or KokkosKernels::Impl::Iota
\tparam Pred a binary predicate function
\tparam ViewLike A Kokkos::View, KokkosKernels::Impl::Iota, or
KokkosSparse::MergeMatrixDiagonal \tparam Pred a binary predicate function
\param view the view to search
\param value the value to search for
\param pred a binary predicate function
Expand All @@ -168,10 +162,6 @@ KOKKOS_INLINE_FUNCTION typename ViewLike::size_type lower_bound_thread(
Pred pred = Pred()) {
static_assert(1 == ViewLike::rank,
"lower_bound_thread requires rank-1 views");
static_assert(KokkosKernels::Impl::is_iota_v<ViewLike> ||
Kokkos::is_view<ViewLike>::value,
"lower_bound_thread requires a "
"KokkosKernels::Impl::Iota or a Kokkos::View");
/*
sequential search makes on average 0.5 * view.size memory accesses
binary search makes log2(view.size)+1 accesses
Expand Down
1 change: 0 additions & 1 deletion perf_test/sparse/KokkosSparse_spmv_bsr_benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,6 @@ void run(benchmark::State &state, const Bsr &bsr, const size_t k) {
template <typename Bsr, typename Spmv>
void read_expand_run(benchmark::State &state, const fs::path &path,
const size_t blockSize, const size_t k) {
using device_type = typename Bsr::device_type;
using scalar_type = typename Bsr::non_const_value_type;
using ordinal_type = typename Bsr::non_const_ordinal_type;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,20 @@
#include <type_traits>

#include "KokkosKernels_Iota.hpp"
#include "KokkosKernels_LowerBound.hpp"
#include "KokkosKernels_Predicates.hpp"
#include "KokkosKernels_SafeCompare.hpp"

/// \file KokkosSparse_MergeMatrix.hpp
/// \file KokkosSparse_merge_matrix.hpp

namespace KokkosSparse {
namespace Experimental {
namespace Impl {
namespace KokkosSparse::Impl {

// a joint index into a and b
template <typename AIndex, typename BIndex>
struct MergeMatrixPosition {
AIndex ai;
BIndex bi;
};

/*! \class MergeMatrixDiagonal
\brief a view into the entries of the Merge Matrix along a diagonal
Expand Down Expand Up @@ -88,14 +95,7 @@ class MergeMatrixDiagonal {
using a_value_type = typename AView::non_const_value_type;
using b_value_type = typename BViewLike::non_const_value_type;

/*! \struct MatrixPosition
* \brief indices into the a_ and b_ views.
*/
struct MatrixPosition {
a_index_type ai;
b_index_type bi;
};
using position_type = MatrixPosition;
using position_type = MergeMatrixPosition<a_index_type, b_index_type>;

// implement bare minimum parts of the view interface
enum { rank = 1 };
Expand Down Expand Up @@ -145,9 +145,9 @@ class MergeMatrixDiagonal {
KOKKOS_INLINE_FUNCTION
bool operator()(const size_type di) const {
position_type pos = diag_to_a_b(di);
if (pos.ai >= a_.size()) {
if (size_t(pos.ai) >= a_.size()) {
return true; // on the +a side out of matrix bounds is 1
} else if (pos.bi >= b_.size()) {
} else if (size_t(pos.bi) >= b_.size()) {
return false; // on the +b side out of matrix bounds is 0
} else {
return KokkosKernels::Impl::safe_gt(a_(pos.ai), b_(pos.bi));
Expand Down Expand Up @@ -192,8 +192,106 @@ class MergeMatrixDiagonal {
size_type d_; ///< diagonal
};

} // namespace Impl
} // namespace Experimental
} // namespace KokkosSparse
/*! \brief Return the first index on diagonal \code diag
in the merge matrix of \code a and \code b that is not 1
This is effectively a lower-bound search on the merge matrix diagonal
where the predicate is "equals 1"
*/
template <typename AView, typename BViewLike>
KOKKOS_INLINE_FUNCTION
typename MergeMatrixDiagonal<AView, BViewLike>::position_type
diagonal_search(
const AView &a, const BViewLike &b,
typename MergeMatrixDiagonal<AView, BViewLike>::size_type diag) {
// unmanaged view types for a and b
using um_a_view =
Kokkos::View<typename AView::value_type *, typename AView::device_type,
Kokkos::MemoryUnmanaged>;
using um_b_view =
Kokkos::View<typename BViewLike::value_type *,
typename BViewLike::device_type, Kokkos::MemoryUnmanaged>;

um_a_view ua(a.data(), a.size());

// if BViewLike is an Iota, pass it on directly to MMD,
// otherwise, create an unmanaged view of B
using b_type =
typename std::conditional<KokkosKernels::Impl::is_iota<BViewLike>::value,
BViewLike, um_b_view>::type;

using MMD = MergeMatrixDiagonal<um_a_view, b_type>;
MMD mmd;
if constexpr (KokkosKernels::Impl::is_iota<BViewLike>::value) {
mmd = MMD(ua, b, diag);
} else {
b_type ub(b.data(), b.size());
mmd = MMD(ua, ub, diag);
}

// returns index of the first element that does not satisfy pred(element,
// value) our input view is the merge matrix entry along the diagonal, and we
// want the first one that is not true. so our predicate just tells us if the
// merge matrix diagonal entry is equal to true or not
const typename MMD::size_type idx = KokkosKernels::lower_bound_thread(
mmd, true, KokkosKernels::Equal<bool>());
return mmd.position(idx);
}

template <typename TeamMember, typename AView, typename BViewLike>
KOKKOS_INLINE_FUNCTION
typename MergeMatrixDiagonal<AView, BViewLike>::position_type
diagonal_search(
const TeamMember &handle, const AView &a, const BViewLike &b,
typename MergeMatrixDiagonal<AView, BViewLike>::size_type diag) {
// unmanaged view types for a and b
using um_a_view =
Kokkos::View<typename AView::value_type *, typename AView::device_type,
Kokkos::MemoryUnmanaged>;
using um_b_view =
Kokkos::View<typename BViewLike::value_type *,
typename BViewLike::device_type, Kokkos::MemoryUnmanaged>;

um_a_view ua(a.data(), a.size());

// if BViewLike is an Iota, pass it on directly to MMD,
// otherwise, create an unmanaged view of B
using b_type =
typename std::conditional<KokkosKernels::Impl::is_iota<BViewLike>::value,
BViewLike, um_b_view>::type;

using MMD = MergeMatrixDiagonal<um_a_view, b_type>;
MMD mmd;
if constexpr (KokkosKernels::Impl::is_iota<BViewLike>::value) {
mmd = MMD(ua, b, diag);
} else {
b_type ub(b.data(), b.size());
mmd = MMD(ua, ub, diag);
}

// returns index of the first element that does not satisfy pred(element,
// value) our input view is the merge matrix entry along the diagonal, and we
// want the first one that is not true. so our predicate just tells us if the
// merge matrix diagonal entry is equal to true or not
const typename MMD::size_type idx = KokkosKernels::lower_bound_team(
handle, mmd, true, KokkosKernels::Equal<bool>());
return mmd.position(idx);
}

/*! \brief
\return A MergeMatrixDiagonal::position_type
*/
template <typename View>
KOKKOS_INLINE_FUNCTION auto diagonal_search(
const View &a, typename View::non_const_value_type totalWork,
typename View::size_type diag) {
using value_type = typename View::non_const_value_type;
using size_type = typename View::size_type;

KokkosKernels::Impl::Iota<value_type, size_type> iota(totalWork);
return diagonal_search(a, iota, diag);
}

} // namespace KokkosSparse::Impl

#endif // KOKKOSSPARSE_MERGEMATRIX_HPP
21 changes: 17 additions & 4 deletions sparse/impl/KokkosSparse_spmv_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,14 @@
#include "KokkosKernels_ExecSpaceUtils.hpp"
#include "KokkosSparse_CrsMatrix.hpp"
#include "KokkosSparse_spmv_impl_omp.hpp"
#include "KokkosSparse_spmv_impl_merge.hpp"
#include "KokkosKernels_Error.hpp"

namespace KokkosSparse {
namespace Impl {

constexpr const char* KOKKOSSPARSE_ALG_MERGE = "merge";

template <class InputType, class DeviceType>
struct GetCoeffView {
typedef Kokkos::View<InputType*, Kokkos::LayoutLeft, DeviceType> view_type;
Expand Down Expand Up @@ -645,11 +648,21 @@ static void spmv_beta(const KokkosKernels::Experimental::Controls& controls,
typename YVector::const_value_type& beta,
const YVector& y) {
if (mode[0] == NoTranspose[0]) {
spmv_beta_no_transpose<AMatrix, XVector, YVector, dobeta, false>(
controls, alpha, A, x, beta, y);
if (controls.getParameter("algorithm") == KOKKOSSPARSE_ALG_MERGE) {
SpmvMergeHierarchical<AMatrix, XVector, YVector>::spmv(mode, alpha, A, x,
beta, y);
} else {
spmv_beta_no_transpose<AMatrix, XVector, YVector, dobeta, false>(
controls, alpha, A, x, beta, y);
}
} else if (mode[0] == Conjugate[0]) {
spmv_beta_no_transpose<AMatrix, XVector, YVector, dobeta, true>(
controls, alpha, A, x, beta, y);
if (controls.getParameter("algorithm") == KOKKOSSPARSE_ALG_MERGE) {
SpmvMergeHierarchical<AMatrix, XVector, YVector>::spmv(mode, alpha, A, x,
beta, y);
} else {
spmv_beta_no_transpose<AMatrix, XVector, YVector, dobeta, true>(
controls, alpha, A, x, beta, y);
}
} else if (mode[0] == Transpose[0]) {
spmv_beta_transpose<AMatrix, XVector, YVector, dobeta, false>(alpha, A, x,
beta, y);
Expand Down
Loading

0 comments on commit ba7a062

Please sign in to comment.