Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Batched sparse kernels update #1546

Merged
merged 4 commits into from
Sep 27, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 78 additions & 25 deletions batched/sparse/impl/KokkosBatched_CG_TeamVector_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,12 @@ namespace KokkosBatched {

template <typename MemberType>
template <typename OperatorType, typename VectorViewType,
typename KrylovHandleType>
typename KrylovHandleType, typename TMPViewType,
typename TMPNormViewType>
KOKKOS_INLINE_FUNCTION int TeamVectorCG<MemberType>::invoke(
const MemberType& member, const OperatorType& A, const VectorViewType& _B,
const VectorViewType& _X, const KrylovHandleType& handle) {
const VectorViewType& _X, const KrylovHandleType& handle,
const TMPViewType& _TMPView, const TMPNormViewType& _TMPNormView) {
typedef int OrdinalType;
typedef typename Kokkos::Details::ArithTraits<
typename VectorViewType::non_const_value_type>::mag_type MagnitudeType;
Expand All @@ -85,29 +87,25 @@ KOKKOS_INLINE_FUNCTION int TeamVectorCG<MemberType>::invoke(
const OrdinalType numMatrices = _X.extent(0);
const OrdinalType numRows = _X.extent(1);

ScratchPadVectorViewType P(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
numRows);
ScratchPadVectorViewType Q(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
numRows);
ScratchPadVectorViewType R(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
numRows);
ScratchPadVectorViewType X(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
numRows);

ScratchPadNormViewType sqr_norm_0(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
ScratchPadNormViewType sqr_norm_j(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
ScratchPadNormViewType alpha(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
ScratchPadNormViewType mask(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
ScratchPadNormViewType tmp(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
int offset_P = 0;
int offset_Q = offset_P + numRows;
int offset_R = offset_Q + numRows;
int offset_X = offset_R + numRows;

auto P = Kokkos::subview(_TMPView, Kokkos::ALL,
Kokkos::make_pair(offset_P, offset_P + numRows));
auto Q = Kokkos::subview(_TMPView, Kokkos::ALL,
Kokkos::make_pair(offset_Q, offset_Q + numRows));
auto R = Kokkos::subview(_TMPView, Kokkos::ALL,
Kokkos::make_pair(offset_R, offset_R + numRows));
auto X = Kokkos::subview(_TMPView, Kokkos::ALL,
Kokkos::make_pair(offset_X, offset_X + numRows));

auto sqr_norm_0 = Kokkos::subview(_TMPNormView, Kokkos::ALL, 0);
auto sqr_norm_j = Kokkos::subview(_TMPNormView, Kokkos::ALL, 1);
auto alpha = Kokkos::subview(_TMPNormView, Kokkos::ALL, 2);
auto mask = Kokkos::subview(_TMPNormView, Kokkos::ALL, 3);
auto tmp = Kokkos::subview(_TMPNormView, Kokkos::ALL, 4);

TeamVectorCopy<MemberType>::invoke(member, _X, X);
// Deep copy of b into r_0:
Expand Down Expand Up @@ -200,6 +198,61 @@ KOKKOS_INLINE_FUNCTION int TeamVectorCG<MemberType>::invoke(
TeamVectorCopy<MemberType>::invoke(member, X, _X);
return status;
}

template <typename MemberType>
template <typename OperatorType, typename VectorViewType,
typename KrylovHandleType>
KOKKOS_INLINE_FUNCTION int TeamVectorCG<MemberType>::invoke(
const MemberType& member, const OperatorType& A, const VectorViewType& _B,
const VectorViewType& _X, const KrylovHandleType& handle) {
const int strategy = handle.get_memory_strategy();
if (strategy == 0) {
using ScratchPadVectorViewType = Kokkos::View<
typename VectorViewType::non_const_value_type**,
typename VectorViewType::array_layout,
typename VectorViewType::execution_space::scratch_memory_space>;
using ScratchPadNormViewType = Kokkos::View<
typename Kokkos::Details::ArithTraits<
typename VectorViewType::non_const_value_type>::mag_type**,
typename VectorViewType::execution_space::scratch_memory_space>;

const int numMatrices = _X.extent(0);
const int numRows = _X.extent(1);

ScratchPadVectorViewType _TMPView(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
4 * numRows);

ScratchPadNormViewType _TMPNormView(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices, 5);

return invoke<OperatorType, VectorViewType, KrylovHandleType>(
member, A, _B, _X, handle, _TMPView, _TMPNormView);
}
if (strategy == 1) {
const int first_matrix = handle.first_index(member.league_rank());
const int last_matrix = handle.last_index(member.league_rank());

using ScratchPadNormViewType = Kokkos::View<
typename Kokkos::Details::ArithTraits<
typename VectorViewType::non_const_value_type>::mag_type**,
typename VectorViewType::execution_space::scratch_memory_space>;

const int numMatrices = _X.extent(0);

auto _TMPView = Kokkos::subview(
handle.tmp_view, Kokkos::make_pair(first_matrix, last_matrix),
Kokkos::ALL);

ScratchPadNormViewType _TMPNormView(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices, 5);

return invoke<OperatorType, VectorViewType, KrylovHandleType>(
member, A, _B, _X, handle, _TMPView, _TMPNormView);
}
return 0;
}

} // namespace KokkosBatched

#endif
102 changes: 77 additions & 25 deletions batched/sparse/impl/KokkosBatched_CG_Team_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,12 @@ namespace KokkosBatched {
///

template <typename MemberType>
template <typename OperatorType, typename VectorViewType, typename KrylovHandle>
template <typename OperatorType, typename VectorViewType, typename KrylovHandle,
typename TMPViewType, typename TMPNormViewType>
KOKKOS_INLINE_FUNCTION int TeamCG<MemberType>::invoke(
const MemberType& member, const OperatorType& A, const VectorViewType& _B,
const VectorViewType& _X, const KrylovHandle& handle) {
const VectorViewType& _X, const KrylovHandle& handle,
const TMPViewType& _TMPView, const TMPNormViewType& _TMPNormView) {
typedef int OrdinalType;
typedef typename Kokkos::Details::ArithTraits<
typename VectorViewType::non_const_value_type>::mag_type MagnitudeType;
Expand All @@ -83,29 +85,25 @@ KOKKOS_INLINE_FUNCTION int TeamCG<MemberType>::invoke(
const OrdinalType numMatrices = _X.extent(0);
const OrdinalType numRows = _X.extent(1);

ScratchPadVectorViewType P(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
numRows);
ScratchPadVectorViewType Q(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
numRows);
ScratchPadVectorViewType R(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
numRows);
ScratchPadVectorViewType X(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
numRows);

ScratchPadNormViewType sqr_norm_0(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
ScratchPadNormViewType sqr_norm_j(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
ScratchPadNormViewType alpha(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
ScratchPadNormViewType mask(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
ScratchPadNormViewType tmp(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices);
int offset_P = 0;
int offset_Q = offset_P + numRows;
int offset_R = offset_Q + numRows;
int offset_X = offset_R + numRows;

auto P = Kokkos::subview(_TMPView, Kokkos::ALL,
Kokkos::make_pair(offset_P, offset_P + numRows));
auto Q = Kokkos::subview(_TMPView, Kokkos::ALL,
Kokkos::make_pair(offset_Q, offset_Q + numRows));
auto R = Kokkos::subview(_TMPView, Kokkos::ALL,
Kokkos::make_pair(offset_R, offset_R + numRows));
auto X = Kokkos::subview(_TMPView, Kokkos::ALL,
Kokkos::make_pair(offset_X, offset_X + numRows));

auto sqr_norm_0 = Kokkos::subview(_TMPNormView, Kokkos::ALL, 0);
auto sqr_norm_j = Kokkos::subview(_TMPNormView, Kokkos::ALL, 1);
auto alpha = Kokkos::subview(_TMPNormView, Kokkos::ALL, 2);
auto mask = Kokkos::subview(_TMPNormView, Kokkos::ALL, 3);
auto tmp = Kokkos::subview(_TMPNormView, Kokkos::ALL, 4);

TeamCopy<MemberType>::invoke(member, _X, X);
// Deep copy of b into r_0:
Expand Down Expand Up @@ -199,6 +197,60 @@ KOKKOS_INLINE_FUNCTION int TeamCG<MemberType>::invoke(
return status;
}

template <typename MemberType>
template <typename OperatorType, typename VectorViewType,
typename KrylovHandleType>
KOKKOS_INLINE_FUNCTION int TeamCG<MemberType>::invoke(
const MemberType& member, const OperatorType& A, const VectorViewType& _B,
const VectorViewType& _X, const KrylovHandleType& handle) {
const int strategy = handle.get_memory_strategy();
if (strategy == 0) {
using ScratchPadVectorViewType = Kokkos::View<
typename VectorViewType::non_const_value_type**,
typename VectorViewType::array_layout,
typename VectorViewType::execution_space::scratch_memory_space>;
using ScratchPadNormViewType = Kokkos::View<
typename Kokkos::Details::ArithTraits<
typename VectorViewType::non_const_value_type>::mag_type**,
typename VectorViewType::execution_space::scratch_memory_space>;

const int numMatrices = _X.extent(0);
const int numRows = _X.extent(1);

ScratchPadVectorViewType _TMPView(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices,
4 * numRows);

ScratchPadNormViewType _TMPNormView(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices, 5);

return invoke<OperatorType, VectorViewType, KrylovHandleType>(
member, A, _B, _X, handle, _TMPView, _TMPNormView);
}
if (strategy == 1) {
const int first_matrix = handle.first_index(member.league_rank());
const int last_matrix = handle.last_index(member.league_rank());

using ScratchPadNormViewType = Kokkos::View<
typename Kokkos::Details::ArithTraits<
typename VectorViewType::non_const_value_type>::mag_type**,
typename VectorViewType::execution_space::scratch_memory_space>;

const int numMatrices = _X.extent(0);

auto _TMPView = Kokkos::subview(
handle.tmp_view, Kokkos::make_pair(first_matrix, last_matrix),
Kokkos::ALL);

ScratchPadNormViewType _TMPNormView(
member.team_scratch(handle.get_scratch_pad_level()), numMatrices, 5);

return invoke<OperatorType, VectorViewType, KrylovHandleType>(
member, A, _B, _X, handle, _TMPView, _TMPNormView);
}
return 0;
}

} // namespace KokkosBatched

#endif
Loading