Skip to content

Commit

Permalink
Update the batched sparse kernels to add SIMD data types and options …
Browse files Browse the repository at this point in the history
…for the temporary variable allocations.
  • Loading branch information
kliegeois committed Sep 26, 2022
1 parent 7d25cf5 commit 0b16cb0
Show file tree
Hide file tree
Showing 10 changed files with 665 additions and 192 deletions.
103 changes: 78 additions & 25 deletions batched/sparse/impl/KokkosBatched_CG_TeamVector_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@ namespace KokkosBatched {

template <typename MemberType>
template <typename OperatorType, typename VectorViewType,
typename KrylovHandleType>
typename KrylovHandleType, typename TMPViewType,
typename TMPNormViewType>
KOKKOS_INLINE_FUNCTION int TeamVectorCG<MemberType>::invoke(
const MemberType& member, const OperatorType& A, const VectorViewType& _B,
const VectorViewType& _X, const KrylovHandleType& handle) {
const VectorViewType& _X, const KrylovHandleType& handle,
const TMPViewType& _TMPView, const TMPNormViewType& _TMPNormView) {
typedef int OrdinalType;
typedef typename Kokkos::Details::ArithTraits<
typename VectorViewType::non_const_value_type>::mag_type MagnitudeType;
Expand All @@ -85,29 +87,25 @@ KOKKOS_INLINE_FUNCTION int TeamVectorCG<MemberType>::invoke(
const OrdinalType numMatrices = _X.extent(0);
const OrdinalType numRows = _X.extent(1);

ScratchPadVectorViewType P(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
numRows);
ScratchPadVectorViewType Q(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
numRows);
ScratchPadVectorViewType R(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
numRows);
ScratchPadVectorViewType X(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
numRows);

ScratchPadNormViewType sqr_norm_0(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
ScratchPadNormViewType sqr_norm_j(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
ScratchPadNormViewType alpha(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
ScratchPadNormViewType mask(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
ScratchPadNormViewType tmp(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
int offset_P = 0;
int offset_Q = offset_P + numRows;
int offset_R = offset_Q + numRows;
int offset_X = offset_R + numRows;

auto P = Kokkos::subview(_TMPView, Kokkos::ALL,
Kokkos::make_pair(offset_P, offset_P + numRows));
auto Q = Kokkos::subview(_TMPView, Kokkos::ALL,
Kokkos::make_pair(offset_Q, offset_Q + numRows));
auto R = Kokkos::subview(_TMPView, Kokkos::ALL,
Kokkos::make_pair(offset_R, offset_R + numRows));
auto X = Kokkos::subview(_TMPView, Kokkos::ALL,
Kokkos::make_pair(offset_X, offset_X + numRows));

auto sqr_norm_0 = Kokkos::subview(_TMPNormView, Kokkos::ALL, 0);
auto sqr_norm_j = Kokkos::subview(_TMPNormView, Kokkos::ALL, 1);
auto alpha = Kokkos::subview(_TMPNormView, Kokkos::ALL, 2);
auto mask = Kokkos::subview(_TMPNormView, Kokkos::ALL, 3);
auto tmp = Kokkos::subview(_TMPNormView, Kokkos::ALL, 4);

TeamVectorCopy<MemberType>::invoke(member, _X, X);
// Deep copy of b into r_0:
Expand Down Expand Up @@ -200,6 +198,61 @@ KOKKOS_INLINE_FUNCTION int TeamVectorCG<MemberType>::invoke(
TeamVectorCopy<MemberType>::invoke(member, X, _X);
return status;
}

/// \brief Workspace-allocating overload of TeamVectorCG::invoke.
///
/// Builds the two temporary views required by the seven-argument overload and
/// forwards to it. That overload slices the vector workspace into four
/// numRows-wide column blocks (P, Q, R, X) and reads five columns of the norm
/// workspace (sqr_norm_0, sqr_norm_j, alpha, mask, tmp), hence the extents
/// used below.
///
/// The backing storage of the vector workspace is selected by
/// handle.get_memory_strategy():
///   - 0: both workspaces are constructed on
///        member.team_scratch(handle.get_scratch_pad_level()).
///   - 1: the vector workspace is a subview of handle.tmp_view restricted to
///        the rows make_pair(handle.first_index(league_rank),
///        handle.last_index(league_rank)); only the norm workspace is
///        constructed on the team scratch pad.
///
/// \param member  Team member handle.
/// \param A       Operator forwarded unchanged to the solver.
/// \param _B      Right-hand sides, forwarded unchanged.
/// \param _X      Solution view; its extents size the workspaces.
/// \param handle  Krylov handle providing the memory strategy, scratch pad
///                level and (for strategy 1) tmp_view and index ranges.
/// \return The value returned by the forwarded call, or 0 when the strategy
///         is neither 0 nor 1.
///
/// NOTE(review): an unrecognized strategy returns 0 without invoking the
/// solver -- confirm callers cannot mistake this for a successful solve.
template <typename MemberType>
template <typename OperatorType, typename VectorViewType,
          typename KrylovHandleType>
KOKKOS_INLINE_FUNCTION int TeamVectorCG<MemberType>::invoke(
    const MemberType& member, const OperatorType& A, const VectorViewType& _B,
    const VectorViewType& _X, const KrylovHandleType& handle) {
  using ScalarType = typename VectorViewType::non_const_value_type;
  using ScratchSpaceType =
      typename VectorViewType::execution_space::scratch_memory_space;
  // Same norm workspace type for both strategies.
  using NormWorkspaceType =
      Kokkos::View<typename Kokkos::Details::ArithTraits<ScalarType>::mag_type**,
                   ScratchSpaceType>;

  const int batchSize = _X.extent(0);
  const int strategy  = handle.get_memory_strategy();

  switch (strategy) {
    case 0: {
      using VectorWorkspaceType =
          Kokkos::View<ScalarType**, typename VectorViewType::array_layout,
                       ScratchSpaceType>;

      const int rowCount = _X.extent(1);

      // Vector workspace first, then norm workspace: keep this construction
      // order on the scratch pad.
      VectorWorkspaceType vectorWorkspace(
          member.team_scratch(handle.get_scratch_pad_level()), batchSize,
          4 * rowCount);
      NormWorkspaceType normWorkspace(
          member.team_scratch(handle.get_scratch_pad_level()), batchSize, 5);

      return invoke<OperatorType, VectorViewType, KrylovHandleType>(
          member, A, _B, _X, handle, vectorWorkspace, normWorkspace);
    }
    case 1: {
      const int rowBegin = handle.first_index(member.league_rank());
      const int rowEnd   = handle.last_index(member.league_rank());

      // Rows of the handle-owned temporary view assigned to this team.
      auto vectorWorkspace = Kokkos::subview(
          handle.tmp_view, Kokkos::make_pair(rowBegin, rowEnd), Kokkos::ALL);
      NormWorkspaceType normWorkspace(
          member.team_scratch(handle.get_scratch_pad_level()), batchSize, 5);

      return invoke<OperatorType, VectorViewType, KrylovHandleType>(
          member, A, _B, _X, handle, vectorWorkspace, normWorkspace);
    }
    default: break;
  }
  return 0;
}

} // namespace KokkosBatched

#endif
102 changes: 77 additions & 25 deletions batched/sparse/impl/KokkosBatched_CG_Team_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ namespace KokkosBatched {
///

template <typename MemberType>
template <typename OperatorType, typename VectorViewType, typename KrylovHandle>
template <typename OperatorType, typename VectorViewType, typename KrylovHandle,
typename TMPViewType, typename TMPNormViewType>
KOKKOS_INLINE_FUNCTION int TeamCG<MemberType>::invoke(
const MemberType& member, const OperatorType& A, const VectorViewType& _B,
const VectorViewType& _X, const KrylovHandle& handle) {
const VectorViewType& _X, const KrylovHandle& handle,
const TMPViewType& _TMPView, const TMPNormViewType& _TMPNormView) {
typedef int OrdinalType;
typedef typename Kokkos::Details::ArithTraits<
typename VectorViewType::non_const_value_type>::mag_type MagnitudeType;
Expand All @@ -83,29 +85,25 @@ KOKKOS_INLINE_FUNCTION int TeamCG<MemberType>::invoke(
const OrdinalType numMatrices = _X.extent(0);
const OrdinalType numRows = _X.extent(1);

ScratchPadVectorViewType P(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
numRows);
ScratchPadVectorViewType Q(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
numRows);
ScratchPadVectorViewType R(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
numRows);
ScratchPadVectorViewType X(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
numRows);

ScratchPadNormViewType sqr_norm_0(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
ScratchPadNormViewType sqr_norm_j(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
ScratchPadNormViewType alpha(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
ScratchPadNormViewType mask(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
ScratchPadNormViewType tmp(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
int offset_P = 0;
int offset_Q = offset_P + numRows;
int offset_R = offset_Q + numRows;
int offset_X = offset_R + numRows;

auto P = Kokkos::subview(_TMPView, Kokkos::ALL,
Kokkos::make_pair(offset_P, offset_P + numRows));
auto Q = Kokkos::subview(_TMPView, Kokkos::ALL,
Kokkos::make_pair(offset_Q, offset_Q + numRows));
auto R = Kokkos::subview(_TMPView, Kokkos::ALL,
Kokkos::make_pair(offset_R, offset_R + numRows));
auto X = Kokkos::subview(_TMPView, Kokkos::ALL,
Kokkos::make_pair(offset_X, offset_X + numRows));

auto sqr_norm_0 = Kokkos::subview(_TMPNormView, Kokkos::ALL, 0);
auto sqr_norm_j = Kokkos::subview(_TMPNormView, Kokkos::ALL, 1);
auto alpha = Kokkos::subview(_TMPNormView, Kokkos::ALL, 2);
auto mask = Kokkos::subview(_TMPNormView, Kokkos::ALL, 3);
auto tmp = Kokkos::subview(_TMPNormView, Kokkos::ALL, 4);

TeamCopy<MemberType>::invoke(member, _X, X);
// Deep copy of b into r_0:
Expand Down Expand Up @@ -199,6 +197,60 @@ KOKKOS_INLINE_FUNCTION int TeamCG<MemberType>::invoke(
return status;
}

/// \brief Workspace-allocating overload of TeamCG::invoke.
///
/// Builds the two temporary views required by the seven-argument overload and
/// forwards to it. That overload slices the vector workspace into four
/// numRows-wide column blocks (P, Q, R, X) and reads five columns of the norm
/// workspace (sqr_norm_0, sqr_norm_j, alpha, mask, tmp), hence the extents
/// used below.
///
/// The backing storage of the vector workspace is selected by
/// handle.get_memory_strategy():
///   - 0: both workspaces are constructed on
///        member.team_scratch(handle.get_scratch_pad_level()).
///   - 1: the vector workspace is a subview of handle.tmp_view restricted to
///        the rows make_pair(handle.first_index(league_rank),
///        handle.last_index(league_rank)); only the norm workspace is
///        constructed on the team scratch pad.
///
/// \param member  Team member handle.
/// \param A       Operator forwarded unchanged to the solver.
/// \param _B      Right-hand sides, forwarded unchanged.
/// \param _X      Solution view; its extents size the workspaces.
/// \param handle  Krylov handle providing the memory strategy, scratch pad
///                level and (for strategy 1) tmp_view and index ranges.
/// \return The value returned by the forwarded call, or 0 when the strategy
///         is neither 0 nor 1.
///
/// NOTE(review): an unrecognized strategy returns 0 without invoking the
/// solver -- confirm callers cannot mistake this for a successful solve.
template <typename MemberType>
template <typename OperatorType, typename VectorViewType,
          typename KrylovHandleType>
KOKKOS_INLINE_FUNCTION int TeamCG<MemberType>::invoke(
    const MemberType& member, const OperatorType& A, const VectorViewType& _B,
    const VectorViewType& _X, const KrylovHandleType& handle) {
  using ScalarType = typename VectorViewType::non_const_value_type;
  using ScratchSpaceType =
      typename VectorViewType::execution_space::scratch_memory_space;
  // Same norm workspace type for both strategies.
  using NormWorkspaceType =
      Kokkos::View<typename Kokkos::Details::ArithTraits<ScalarType>::mag_type**,
                   ScratchSpaceType>;

  const int batchSize = _X.extent(0);
  const int strategy  = handle.get_memory_strategy();

  switch (strategy) {
    case 0: {
      using VectorWorkspaceType =
          Kokkos::View<ScalarType**, typename VectorViewType::array_layout,
                       ScratchSpaceType>;

      const int rowCount = _X.extent(1);

      // Vector workspace first, then norm workspace: keep this construction
      // order on the scratch pad.
      VectorWorkspaceType vectorWorkspace(
          member.team_scratch(handle.get_scratch_pad_level()), batchSize,
          4 * rowCount);
      NormWorkspaceType normWorkspace(
          member.team_scratch(handle.get_scratch_pad_level()), batchSize, 5);

      return invoke<OperatorType, VectorViewType, KrylovHandleType>(
          member, A, _B, _X, handle, vectorWorkspace, normWorkspace);
    }
    case 1: {
      const int rowBegin = handle.first_index(member.league_rank());
      const int rowEnd   = handle.last_index(member.league_rank());

      // Rows of the handle-owned temporary view assigned to this team.
      auto vectorWorkspace = Kokkos::subview(
          handle.tmp_view, Kokkos::make_pair(rowBegin, rowEnd), Kokkos::ALL);
      NormWorkspaceType normWorkspace(
          member.team_scratch(handle.get_scratch_pad_level()), batchSize, 5);

      return invoke<OperatorType, VectorViewType, KrylovHandleType>(
          member, A, _B, _X, handle, vectorWorkspace, normWorkspace);
    }
    default: break;
  }
  return 0;
}

} // namespace KokkosBatched

#endif
Loading

0 comments on commit 0b16cb0

Please sign in to comment.