Skip to content

Commit

Permalink
Merge Pull Request #13536 from ndellingwood/Trilinos/kernels-blsptrsv…
Browse files Browse the repository at this point in the history
…-2376

Automatically Merged using Trilinos Pull Request AutoTester
PR Title: b'kokkos-kernels: Block Sptrsv fixes (kokkos/kokkos-kernels#2376)'
PR Author: ndellingwood
  • Loading branch information
trilinos-autotester authored Oct 17, 2024
2 parents ffcdd92 + 699b67a commit 1beef8c
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 96 deletions.
275 changes: 179 additions & 96 deletions packages/kokkos-kernels/sparse/impl/KokkosSparse_sptrsv_solve_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ struct SptrsvWrap {
using range_type = Kokkos::pair<int, int>;

// Tag structs
struct UnsortedTag {}; // This doesn't appear to be supported
struct UnsortedTag {};
struct LargerCutoffTag {};
struct UnsortedLargerCutoffTag {};

Expand Down Expand Up @@ -115,7 +115,9 @@ struct SptrsvWrap {
RHSType rhs;
entries_t nodes_grouped_by_level;

using reftype = scalar_t &;
using reftype = scalar_t &;
using ArrayType = reftype;
using SumArray = reftype;

struct SBlock {
template <typename T>
Expand All @@ -141,6 +143,16 @@ struct SptrsvWrap {
KOKKOS_INLINE_FUNCTION
size_type get_block_size() const { return 0; }

// multiply_subtract. C -= A * B
KOKKOS_INLINE_FUNCTION
static void multiply_subtract(const scalar_t &a, const scalar_t &b, scalar_t &c) { c -= a * b; }

KOKKOS_INLINE_FUNCTION
static void copy(const member_type &, scalar_t &, const scalar_t &) {}

KOKKOS_INLINE_FUNCTION
static void copy(scalar_t &, const scalar_t &) {}

// lget
KOKKOS_INLINE_FUNCTION
scalar_t &lget(const size_type row) const { return lhs(row); }
Expand Down Expand Up @@ -195,6 +207,60 @@ struct SptrsvWrap {

using reftype = Vector;

struct ArrayType {
scalar_t m_data[MAX_VEC_SIZE];

KOKKOS_INLINE_FUNCTION
ArrayType() { init(); }

KOKKOS_INLINE_FUNCTION
ArrayType(const ArrayType &rhs_) {
for (size_type i = 0; i < MAX_VEC_SIZE; ++i) m_data[i] = rhs_.m_data[i];
}

KOKKOS_INLINE_FUNCTION
ArrayType(const Vector &) { init(); }

KOKKOS_INLINE_FUNCTION
void init() {
for (size_type i = 0; i < MAX_VEC_SIZE; ++i) m_data[i] = 0;
}

KOKKOS_INLINE_FUNCTION
ArrayType &operator+=(const ArrayType &rhs_) {
for (size_type i = 0; i < MAX_VEC_SIZE; ++i) m_data[i] += rhs_.m_data[i];
return *this;
}
};

struct SumArray {
using reducer = SumArray;
using value_type = ArrayType;
using result_view_type = Kokkos::View<value_type *, execution_space, Kokkos::MemoryUnmanaged>;

private:
value_type &m_value;

public:
KOKKOS_INLINE_FUNCTION
SumArray(value_type &value) : m_value(value) {}

KOKKOS_INLINE_FUNCTION
void join(value_type &dest, const value_type &src) const { dest += src; }

KOKKOS_INLINE_FUNCTION
void init(value_type &val) const { val.init(); }

KOKKOS_INLINE_FUNCTION
value_type &reference() const { return m_value; }

KOKKOS_INLINE_FUNCTION
result_view_type view() const { return result_view_type(&m_value, 1); }

KOKKOS_INLINE_FUNCTION
bool reference_scalar() const { return true; }
};

RowMapType row_map;
EntriesType entries;
ValuesType values;
Expand All @@ -215,6 +281,7 @@ struct SptrsvWrap {
block_size(block_size_),
block_items(block_size * block_size) {
KK_REQUIRE_MSG(block_size > 0, "Tried to use block_size=0 with the blocked Common?");
KK_REQUIRE_MSG(block_size <= MAX_VEC_SIZE, "Max supported block size is " << MAX_VEC_SIZE);
}

KOKKOS_INLINE_FUNCTION
Expand Down Expand Up @@ -257,17 +324,17 @@ struct SptrsvWrap {
team.team_barrier();
KokkosBatched::TeamLU<member_type, KokkosBatched::Algo::LU::Blocked>::invoke(team, LU);

// A = LU
// A^-1 = U^-1 * L^-1
// b = (b * U^-1) * L^-1, so do U trsv first
// Ax = LUx = Lz = b, we use the change of variable z = U*x
// z = L^-1 * b, first we solve for z, storing the result back into b
// x = U^-1 * z, second we solve for x, again storing the result back into b
team.team_barrier();
KokkosBatched::TeamTrsv<member_type, KokkosBatched::Uplo::Lower, KokkosBatched::Trans::NoTranspose,
KokkosBatched::Diag::Unit, KokkosBatched::Algo::Trsv::Blocked>::invoke(team, 1.0, LU, b);

team.team_barrier();
KokkosBatched::TeamTrsv<member_type, KokkosBatched::Uplo::Upper, KokkosBatched::Trans::NoTranspose,
KokkosBatched::Diag::NonUnit, KokkosBatched::Algo::Trsv::Blocked>::invoke(team, 1.0, LU,
b);

team.team_barrier();
KokkosBatched::TeamTrsv<member_type, KokkosBatched::Uplo::Lower, KokkosBatched::Trans::NoTranspose,
KokkosBatched::Diag::Unit, KokkosBatched::Algo::Trsv::Blocked>::invoke(team, 1.0, LU, b);
}

// serial divide. b /= A (b = b * A^-1)
Expand All @@ -278,21 +345,44 @@ struct SptrsvWrap {

// Need a temp block to do LU of A
const auto block_size_ = b.size();
KK_KERNEL_REQUIRE_MSG(block_size_ <= MAX_VEC_SIZE,
"Max supported block size for range-policy is 16. Use team-policy alg if you need more.");

Block LU(&buff[0], block_size_, block_size_);
assign(LU, A);
KokkosBatched::SerialLU<KokkosBatched::Algo::LU::Blocked>::invoke(LU);

// A = LU
// A^-1 = U^-1 * L^-1
// b = (b * U^-1) * L^-1, so do U trsv first
// Ax = LUx = Lz = b, we use the change of variable z = U*x
// z = L^-1 * b, first we solve for z, storing the result back into b
// x = U^-1 * z, second we solve for x, again storing the result back into b
KokkosBatched::SerialTrsv<KokkosBatched::Uplo::Lower, KokkosBatched::Trans::NoTranspose,
KokkosBatched::Diag::Unit, KokkosBatched::Algo::Trsv::Blocked>::invoke(1.0, LU, b);

KokkosBatched::SerialTrsv<KokkosBatched::Uplo::Upper, KokkosBatched::Trans::NoTranspose,
KokkosBatched::Diag::NonUnit, KokkosBatched::Algo::Trsv::Blocked>::invoke(1.0, LU, b);
}

KokkosBatched::SerialTrsv<KokkosBatched::Uplo::Lower, KokkosBatched::Trans::NoTranspose,
KokkosBatched::Diag::Unit, KokkosBatched::Algo::Trsv::Blocked>::invoke(1.0, LU, b);
// multiply_subtract. C -= A * B
KOKKOS_INLINE_FUNCTION
static void multiply_subtract(const CBlock &A, const CVector &b, ArrayType &ca) {
Vector c(&ca.m_data[0], b.size());
multiply_subtract(A, b, c);
}

KOKKOS_INLINE_FUNCTION
static void multiply_subtract(const CBlock &A, const CVector &b, Vector &c) {
// Use gemv. alpha is hardcoded to -1, beta hardcoded to 1
KokkosBlas::SerialGemv<KokkosBlas::Trans::NoTranspose, KokkosBlas::Algo::Gemv::Blocked>::invoke(-1.0, A, b, 1.0,
c);
}

KOKKOS_INLINE_FUNCTION
static void copy(const member_type &team, const Vector &lhs_, ArrayType &rhsa) {
CVector rhs_(&rhsa.m_data[0], lhs_.size());
assign(team, lhs_, rhs_);
}

KOKKOS_INLINE_FUNCTION
static void copy(const Vector &lhs_, ArrayType &rhsa) {
CVector rhs_(&rhsa.m_data[0], lhs_.size());
assign(lhs_, rhs_);
}

// lget
Expand Down Expand Up @@ -331,58 +421,69 @@ struct SptrsvWrap {
*/
template <class RowMapType, class EntriesType, class ValuesType, class LHSType, class RHSType, bool BlockEnabled>
struct Intermediate : public Common<RowMapType, EntriesType, ValuesType, LHSType, RHSType, BlockEnabled> {
using Base = Common<RowMapType, EntriesType, ValuesType, LHSType, RHSType, BlockEnabled>;
using Base = Common<RowMapType, EntriesType, ValuesType, LHSType, RHSType, BlockEnabled>;
using accum_t = std::conditional_t<BlockEnabled, typename Base::ArrayType, scalar_t>;

Intermediate(const RowMapType &row_map_, const EntriesType &entries_, const ValuesType &values_, LHSType &lhs_,
const RHSType &rhs_, const entries_t &nodes_grouped_by_level_, const size_type block_size_ = 0)
: Base(row_map_, entries_, values_, lhs_, rhs_, nodes_grouped_by_level_, block_size_) {}

struct ReduceFunctorBasic {
struct ReduceSumFunctor {
const Base *m_obj;
const lno_t rowid;
lno_t diag;

KOKKOS_INLINE_FUNCTION
ReduceFunctorBasic(const Base *obj, const lno_t = 0) : m_obj(obj) {}

KOKKOS_INLINE_FUNCTION
static void multiply_subtract(const scalar_t &val, const scalar_t &lhs_col_val, scalar_t &accum) {
accum -= val * lhs_col_val;
}

KOKKOS_INLINE_FUNCTION
void operator()(size_type i, scalar_t &accum) const {
void operator()(size_type i, accum_t &accum) const {
const auto colid = m_obj->entries(i);
multiply_subtract(m_obj->vget(i), m_obj->lget(colid), accum);
auto val = m_obj->vget(i);
auto lhs_colid = m_obj->lget(colid);
// accum -= val * lhs_colid;
if constexpr (BlockEnabled) {
accum_t temp;
Base::multiply_subtract(val, lhs_colid, temp);
accum += temp;
} else {
Base::multiply_subtract(val, lhs_colid, accum);
}
KK_KERNEL_ASSERT_MSG(colid != rowid, "Should not have hit diag");
}
};

struct ReduceFunctorBlock : public ReduceFunctorBasic {
using P = ReduceFunctorBasic;

const size_type block_size;
const size_type b;

KOKKOS_INLINE_FUNCTION
ReduceFunctorBlock(const Base *obj, const size_type block_size_, const size_type b_, const lno_t = 0)
: P(obj), block_size(block_size_), b(b_) {}
struct ReduceSumDiagFunctor {
const Base *m_obj;
const lno_t rowid;
mutable lno_t diag;

KOKKOS_INLINE_FUNCTION
void operator()(size_type i, scalar_t &accum) const {
const auto idx = i / block_size;
const auto colid = P::m_obj->entries(idx);
P::multiply_subtract(P::m_obj->vget(idx)(b, i % block_size), P::m_obj->lget(colid)(b), accum);
void operator()(size_type i, accum_t &accum) const {
const auto colid = m_obj->entries(i);
if (colid != rowid) {
auto val = m_obj->vget(i);
auto lhs_colid = m_obj->lget(colid);
// accum -= val * lhs_colid;
if constexpr (BlockEnabled) {
accum_t temp;
Base::multiply_subtract(val, lhs_colid, temp);
accum += temp;
} else {
Base::multiply_subtract(val, lhs_colid, accum);
}
} else {
diag = i;
}
}
};

/**
* If we want to support Unsorted, we'll need a Functor that returns the ptr
* of the diag item (colid == rowid). Possibly via multi-reduce? The UnsortedTag
* is defined above but no policies actually use it.
*/

template <bool IsSerial, bool IsSorted, bool IsLower, bool UseThreadVec = false>
KOKKOS_INLINE_FUNCTION void solve_impl(const member_type *team, const int my_rank, const long node_count) const {
using reduce_item_t = typename Base::ArrayType;
using reducer_t = typename Base::SumArray;
using functor_t = std::conditional_t<IsSorted, ReduceSumFunctor, ReduceSumDiagFunctor>;

static_assert(!((!IsSerial && BlockEnabled) && UseThreadVec),
"ThreadVectorRanges are not yet supported for block-enabled");
static_assert(!(IsSerial && UseThreadVec), "Requested thread vector range in serial?");
static_assert(IsSorted, "Unsorted is not yet supported.");

const auto rowid = Base::nodes_grouped_by_level(my_rank + node_count);
const auto soffset = Base::row_map(rowid);
Expand All @@ -394,76 +495,58 @@ struct SptrsvWrap {
const auto itr_e = eoffset - (IsSorted ? (IsLower ? 1 : 0) : 0);

// We don't need the reducer to find the diag item if sorted
functor_t rf{this, rowid, -1};
typename Base::reftype lhs_val = Base::lget(rowid);

const auto block_size_ = BlockEnabled ? Base::get_block_size() : 1;
(void)block_size_; // Some settings do not use this var
reduce_item_t reduce = lhs_val;

if constexpr (IsSerial) {
KK_KERNEL_ASSERT_MSG(my_rank == 0, "Non zero rank in serial");
KK_KERNEL_ASSERT_MSG(team == nullptr, "Team provided in serial?");
if constexpr (BlockEnabled) {
for (size_type b = 0; b < block_size_; ++b) {
ReduceFunctorBlock rf(this, block_size_, b, rowid);
for (size_type i = itr_b * block_size_; i < itr_e * block_size_; ++i) {
rf(i, lhs_val(b));
}
}
} else {
ReduceFunctorBasic rf(this, rowid);
for (size_type i = itr_b; i < itr_e; ++i) {
rf(i, lhs_val);
}
for (auto ptr = itr_b; ptr < itr_e; ++ptr) {
rf(ptr, reduce);
}
Base::copy(lhs_val, reduce);
} else {
KK_KERNEL_ASSERT_MSG(team != nullptr, "Cannot do team operations without team");
if constexpr (!UseThreadVec) {
if constexpr (BlockEnabled) {
Kokkos::parallel_for(Kokkos::TeamThreadRange(*team, block_size_), [&](size_type b) {
ReduceFunctorBlock rf(this, block_size_, b, rowid);
Kokkos::parallel_reduce(Kokkos::ThreadVectorRange(*team, itr_b * block_size_, itr_e * block_size_), rf,
lhs_val(b));
});
} else {
ReduceFunctorBasic rf(this, rowid);
Kokkos::parallel_reduce(Kokkos::TeamThreadRange(*team, itr_b, itr_e), rf, lhs_val);
}
Kokkos::parallel_reduce(Kokkos::TeamThreadRange(*team, itr_b, itr_e), rf, reducer_t(reduce));
team->team_barrier();
Base::copy(*team, lhs_val, reduce);
team->team_barrier();
} else {
if constexpr (BlockEnabled) {
Kokkos::parallel_for(Kokkos::ThreadVectorRange(*team, block_size_), [&](size_type b) {
ReduceFunctorBlock rf(this, block_size_, b, rowid);
for (size_type i = itr_b * block_size_; i < itr_e * block_size_; ++i) {
rf(i, lhs_val(b));
}
});
} else {
ReduceFunctorBasic rf(this, rowid);
Kokkos::parallel_reduce(Kokkos::ThreadVectorRange(*team, itr_b, itr_e), rf, lhs_val);
}
Kokkos::parallel_reduce(Kokkos::ThreadVectorRange(*team, itr_b, itr_e), rf, reducer_t(reduce));
Base::copy(lhs_val, reduce);
}
}

// If sorted, we already know the diag. Otherwise, get it from the reducer
const lno_t diag = IsLower ? eoffset - 1 : soffset;
rf.diag = IsSorted ? (IsLower ? eoffset - 1 : soffset) : rf.diag;

// At end, handle the diag element. We need to be careful to avoid race
// conditions here.
if constexpr (IsSerial) {
// Serial case is easy, there's only 1 thread so just do the
// add_and_divide
KK_KERNEL_ASSERT_MSG(diag != -1, "Serial should always know diag");
Base::add_and_divide(lhs_val, rhs_val, Base::vget(diag));
KK_KERNEL_ASSERT_MSG(rf.diag != -1, "Serial should always know diag");
Base::add_and_divide(lhs_val, rhs_val, Base::vget(rf.diag));
} else {
// Parallel sorted case is complex. All threads know what the diag is.
// If we have a team sharing the work, we need to ensure only one
// thread performs the add_and_divide (except in BlockEnabled, then
// we can use team operations).
KK_KERNEL_ASSERT_MSG(diag != -1, "Sorted should always know diag");
if constexpr (!UseThreadVec) {
Base::add_and_divide(*team, lhs_val, rhs_val, Base::vget(diag));
if constexpr (IsSorted) {
// Parallel sorted case is complex. All threads know what the diag is.
// If we have a team sharing the work, we need to ensure only one
// thread performs the add_and_divide (except in BlockEnabled, then
// we can use team operations).
KK_KERNEL_ASSERT_MSG(rf.diag != -1, "Sorted should always know diag");
if constexpr (!UseThreadVec) {
Base::add_and_divide(*team, lhs_val, rhs_val, Base::vget(rf.diag));
} else {
Base::add_and_divide(lhs_val, rhs_val, Base::vget(rf.diag));
}
} else {
Base::add_and_divide(lhs_val, rhs_val, Base::vget(diag));
// Parallel unsorted case. Only one thread should know what the diag
// item is. We have that one do the add_and_divide.
if (rf.diag != -1) {
Base::add_and_divide(lhs_val, rhs_val, Base::vget(rf.diag));
}
}
}
}
Expand Down
Loading

0 comments on commit 1beef8c

Please sign in to comment.