Skip to content

Commit

Permalink
Merge pull request #1435 from NexGenAnalytics/blas-teamvector-gemv
Browse files Browse the repository at this point in the history
Move {Team,TeamVector}Gemv to KokkosBlas
  • Loading branch information
lucbv authored Aug 17, 2022
2 parents a730065 + da2149a commit 1d37bad
Show file tree
Hide file tree
Showing 30 changed files with 1,261 additions and 1,255 deletions.
13 changes: 1 addition & 12 deletions src/batched/KokkosBatched_Util.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,19 +283,8 @@ struct Direct {
struct Backward {};
};

struct Mode {
struct Serial {
static const char *name() { return "Serial"; }
};
struct Team {
static const char *name() { return "Team"; }
};
struct TeamVector {
static const char *name() { return "TeamVector"; }
};
};

using KokkosBlas::Algo;
using KokkosBlas::Mode;

struct Util {
template <typename ValueType>
Expand Down
76 changes: 42 additions & 34 deletions src/batched/dense/impl/KokkosBatched_Gemv_TeamVector_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,34 @@ struct TeamVectorGemv<MemberType, Trans::NoTranspose, Algo::Gemv::Unblocked> {
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const xViewType &x, const ScalarType beta, const yViewType &y) {
if (AViewType::Rank == 2)
return TeamVectorGemvInternal<Algo::Gemv::Unblocked>::invoke(
member, A.extent(0), A.extent(1), alpha, A.data(), A.stride_0(),
A.stride_1(), x.data(), x.stride_0(), beta, y.data(), y.stride_0());
else
return TeamVectorGemvInternal<Algo::Gemv::Unblocked>::template invoke<
MemberType, ScalarType, typename AViewType::array_layout,
typename AViewType::non_const_value_type>(
member, A.extent(0), A.extent(1), A.extent(2), alpha, A.data(),
A.stride_0(), A.stride_1(), A.stride_2(), x.data(), x.stride_0(),
x.stride_1(), beta, y.data(), y.stride_0(), y.stride_1());
static_assert(AViewType::Rank == 3,
"Batched TeamVectorGemv requires rank-3 A matrix (use "
"KokkosBlas::TeamVectorGemv for regular rank-2 matrix)");
return TeamVectorGemvInternal<Algo::Gemv::Unblocked>::template invoke<
MemberType, ScalarType, typename AViewType::array_layout,
typename AViewType::non_const_value_type>(
member, A.extent(0), A.extent(1), A.extent(2), alpha, A.data(),
A.stride_0(), A.stride_1(), A.stride_2(), x.data(), x.stride_0(),
x.stride_1(), beta, y.data(), y.stride_0(), y.stride_1());
}
};

template <typename MemberType>
struct TeamVectorGemv<MemberType, Trans::NoTranspose, Algo::Gemv::Blocked> {
template <typename ScalarType, typename AViewType, typename xViewType,
typename yViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const xViewType &x, const ScalarType beta, const yViewType &y) {
return TeamVectorGemvInternal<Algo::Gemv::Blocked>::invoke(
member, A.extent(0), A.extent(1), alpha, A.data(), A.stride_0(),
A.stride_1(), x.data(), x.stride_0(), beta, y.data(), y.stride_0());
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType & /*member*/,
const ScalarType /*alpha*/,
const AViewType & /*A*/,
const xViewType & /*x*/,
const ScalarType /*beta*/,
const yViewType & /*y*/) {
static_assert(AViewType::Rank == 3,
"Batched TeamVectorGemv requires rank-3 A matrix (use "
"KokkosBlas::TeamVectorGemv for regular rank-2 matrix)");
Kokkos::abort(
"KokkosBatched::TeamVectorGemv<Algo::Gemv::Blocked> for rank-3 matrix "
"is NOT implemented");
}
};

Expand All @@ -68,30 +72,34 @@ struct TeamVectorGemv<MemberType, Trans::Transpose, Algo::Gemv::Unblocked> {
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const xViewType &x, const ScalarType beta, const yViewType &y) {
if (AViewType::Rank == 2)
return TeamVectorGemvInternal<Algo::Gemv::Unblocked>::invoke(
member, A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(),
A.stride_0(), x.data(), x.stride_0(), beta, y.data(), y.stride_0());
else
return TeamVectorGemvInternal<Algo::Gemv::Unblocked>::template invoke<
MemberType, ScalarType, typename AViewType::array_layout,
typename AViewType::non_const_value_type>(
member, A.extent(0), A.extent(2), A.extent(1), alpha, A.data(),
A.stride_0(), A.stride_2(), A.stride_1(), x.data(), x.stride_0(),
x.stride_1(), beta, y.data(), y.stride_0(), y.stride_1());
static_assert(AViewType::Rank == 3,
"Batched TeamVectorGemv requires rank-3 A matrix (use "
"KokkosBlas::TeamVectorGemv for regular rank-2 matrix)");
return TeamVectorGemvInternal<Algo::Gemv::Unblocked>::template invoke<
MemberType, ScalarType, typename AViewType::array_layout,
typename AViewType::non_const_value_type>(
member, A.extent(0), A.extent(2), A.extent(1), alpha, A.data(),
A.stride_0(), A.stride_2(), A.stride_1(), x.data(), x.stride_0(),
x.stride_1(), beta, y.data(), y.stride_0(), y.stride_1());
}
};

template <typename MemberType>
struct TeamVectorGemv<MemberType, Trans::Transpose, Algo::Gemv::Blocked> {
template <typename ScalarType, typename AViewType, typename xViewType,
typename yViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const xViewType &x, const ScalarType beta, const yViewType &y) {
return TeamVectorGemvInternal<Algo::Gemv::Blocked>::invoke(
member, A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(),
A.stride_0(), x.data(), x.stride_0(), beta, y.data(), y.stride_0());
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType & /*member*/,
const ScalarType /*alpha*/,
const AViewType & /*A*/,
const xViewType & /*x*/,
const ScalarType /*beta*/,
const yViewType & /*y*/) {
static_assert(AViewType::Rank == 3,
"Batched TeamVectorGemv requires rank-3 A matrix (use "
"KokkosBlas::TeamVectorGemv for regular rank-2 matrix)");
Kokkos::abort(
"KokkosBatched::TeamVectorGemv<Algo::Gemv::Blocked> for rank-3 matrix "
"is NOT implemented");
}
};

Expand Down
61 changes: 7 additions & 54 deletions src/batched/dense/impl/KokkosBatched_Gemv_TeamVector_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@
/// \author Kyungjoo Kim (kyukim@sandia.gov)

#include "KokkosBatched_Util.hpp"

#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBlas1_team_scal_impl.hpp"
#include "KokkosBlas2_serial_gemv_inner_multiple_dot.hpp"
// #include "KokkosBlas1_set_impl.hpp"
// #include "KokkosBlas1_team_scal_impl.hpp"
// #include "KokkosBlas2_serial_gemv_inner_multiple_dot.hpp"

namespace KokkosBatched {

Expand All @@ -16,17 +15,6 @@ namespace KokkosBatched {
/// ====================
template <typename ArgAlgo>
struct TeamVectorGemvInternal {
template <typename MemberType, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType & /*member*/, const int /*m*/, const int /*n*/,
const ScalarType /*alpha*/, const ValueType *KOKKOS_RESTRICT /*A*/,
const int /*as0*/, const int /*as1*/,
const ValueType *KOKKOS_RESTRICT /*x*/, const int /*xs0*/,
const ScalarType /*beta*/,
/**/ ValueType *KOKKOS_RESTRICT /*y*/, const int /*ys0*/) {
assert(false && "Error: encounter dummy impl");
return 0;
}
template <typename MemberType, typename ScalarType, typename layout,
typename ValueType>
KOKKOS_INLINE_FUNCTION static int invoke(
Expand All @@ -43,45 +31,6 @@ struct TeamVectorGemvInternal {
}
};

template <>
template <typename MemberType, typename ScalarType, typename ValueType>
KOKKOS_INLINE_FUNCTION int
TeamVectorGemvInternal<Algo::Gemv::Unblocked>::invoke(
const MemberType &member, const int m, const int n, const ScalarType alpha,
const ValueType *KOKKOS_RESTRICT A, const int as0, const int as1,
const ValueType *KOKKOS_RESTRICT x, const int xs0, const ScalarType beta,
/**/ ValueType *KOKKOS_RESTRICT y, const int ys0) {
const ScalarType one(1.0), zero(0.0);

// y = beta y + alpha A x
// y (m), A(m x n), B(n)

if (beta == zero)
KokkosBlas::Impl::TeamVectorSetInternal::invoke(member, m, zero, y, ys0);
else if (beta != one)
KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, beta, y, ys0);

if (alpha != zero) {
if (m <= 0 || n <= 0) return 0;

if (beta != one) member.team_barrier();

Kokkos::parallel_for(Kokkos::TeamThreadRange(member, m), [&](const int &i) {
ValueType t(0);
const ValueType *KOKKOS_RESTRICT tA = (A + i * as0);
Kokkos::parallel_reduce(
Kokkos::ThreadVectorRange(member, n),
[&](const int &j, ValueType &update) {
update += tA[j * as1] * x[j * xs0];
},
t);
Kokkos::single(Kokkos::PerThread(member),
[&]() { y[i * ys0] += alpha * t; });
});
}
return 0;
}

template <>
template <typename MemberType, typename ScalarType, typename layout,
typename ValueType>
Expand All @@ -98,13 +47,17 @@ TeamVectorGemvInternal<Algo::Gemv::Unblocked>::invoke(
// y_l (m), A_l(m x n), B_l(n)

if (beta == zero)
// TODO: KokkosBlas::Impl::TeamVectorSetInternal::invoke(member, m, zero, y,
// ys0);
Kokkos::parallel_for(Kokkos::TeamVectorRange(member, 0, N * m),
[&](const int &iTemp) {
int iRow, iMatrix;
getIndices<int, layout>(iTemp, m, N, iRow, iMatrix);
Y[ys0 * iMatrix + ys1 * iRow] = zero;
});
else if (beta != one)
// TODO: KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, beta,
// y, ys0);
Kokkos::parallel_for(Kokkos::TeamVectorRange(member, 0, N * m),
[&](const int &iTemp) {
int iRow, iMatrix;
Expand Down
76 changes: 42 additions & 34 deletions src/batched/dense/impl/KokkosBatched_Gemv_Team_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,30 +30,34 @@ struct TeamGemv<MemberType, Trans::NoTranspose, Algo::Gemv::Unblocked> {
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const xViewType &x, const ScalarType beta, const yViewType &y) {
if (AViewType::Rank == 2)
return TeamGemvInternal<Algo::Gemv::Unblocked>::invoke(
member, A.extent(0), A.extent(1), alpha, A.data(), A.stride_0(),
A.stride_1(), x.data(), x.stride_0(), beta, y.data(), y.stride_0());
else
return TeamGemvInternal<Algo::Gemv::Unblocked>::template invoke<
MemberType, ScalarType, typename AViewType::array_layout,
typename AViewType::non_const_value_type>(
member, A.extent(0), A.extent(1), A.extent(2), alpha, A.data(),
A.stride_0(), A.stride_1(), A.stride_2(), x.data(), x.stride_0(),
x.stride_1(), beta, y.data(), y.stride_0(), y.stride_1());
static_assert(AViewType::Rank == 3,
"Batched TeamGemv requires rank-3 A matrix (use "
"KokkosBlas::TeamGemv for regular rank-2 matrix)");
return TeamGemvInternal<Algo::Gemv::Unblocked>::template invoke<
MemberType, ScalarType, typename AViewType::array_layout,
typename AViewType::non_const_value_type>(
member, A.extent(0), A.extent(1), A.extent(2), alpha, A.data(),
A.stride_0(), A.stride_1(), A.stride_2(), x.data(), x.stride_0(),
x.stride_1(), beta, y.data(), y.stride_0(), y.stride_1());
}
};

template <typename MemberType>
struct TeamGemv<MemberType, Trans::NoTranspose, Algo::Gemv::Blocked> {
template <typename ScalarType, typename AViewType, typename xViewType,
typename yViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const xViewType &x, const ScalarType beta, const yViewType &y) {
return TeamGemvInternal<Algo::Gemv::Blocked>::invoke(
member, A.extent(0), A.extent(1), alpha, A.data(), A.stride_0(),
A.stride_1(), x.data(), x.stride_0(), beta, y.data(), y.stride_0());
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType & /*member*/,
const ScalarType /*alpha*/,
const AViewType & /*A*/,
const xViewType & /*x*/,
const ScalarType /*beta*/,
const yViewType & /*y*/) {
static_assert(AViewType::Rank == 3,
"Batched TeamGemv requires rank-3 A matrix (use "
"KokkosBlas::TeamGemv for regular rank-2 matrix)");
Kokkos::abort(
"KokkosBlas::TeamGemv<Algo::Gemv::Blocked> for rank-3 matrix is NOT "
"implemented");
}
};

Expand All @@ -68,30 +72,34 @@ struct TeamGemv<MemberType, Trans::Transpose, Algo::Gemv::Unblocked> {
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const xViewType &x, const ScalarType beta, const yViewType &y) {
if (AViewType::Rank == 2)
return TeamGemvInternal<Algo::Gemv::Unblocked>::invoke(
member, A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(),
A.stride_0(), x.data(), x.stride_0(), beta, y.data(), y.stride_0());
else
return TeamGemvInternal<Algo::Gemv::Unblocked>::template invoke<
MemberType, ScalarType, typename AViewType::array_layout,
typename AViewType::non_const_value_type>(
member, A.extent(0), A.extent(2), A.extent(1), alpha, A.data(),
A.stride_0(), A.stride_2(), A.stride_1(), x.data(), x.stride_0(),
x.stride_1(), beta, y.data(), y.stride_0(), y.stride_1());
static_assert(AViewType::Rank == 3,
"Batched TeamGemv requires rank-3 A matrix (use "
"KokkosBlas::TeamGemv for regular rank-2 matrix)");
return TeamGemvInternal<Algo::Gemv::Unblocked>::template invoke<
MemberType, ScalarType, typename AViewType::array_layout,
typename AViewType::non_const_value_type>(
member, A.extent(0), A.extent(2), A.extent(1), alpha, A.data(),
A.stride_0(), A.stride_2(), A.stride_1(), x.data(), x.stride_0(),
x.stride_1(), beta, y.data(), y.stride_0(), y.stride_1());
}
};

template <typename MemberType>
struct TeamGemv<MemberType, Trans::Transpose, Algo::Gemv::Blocked> {
template <typename ScalarType, typename AViewType, typename xViewType,
typename yViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const xViewType &x, const ScalarType beta, const yViewType &y) {
return TeamGemvInternal<Algo::Gemv::Blocked>::invoke(
member, A.extent(1), A.extent(0), alpha, A.data(), A.stride_1(),
A.stride_0(), x.data(), x.stride_0(), beta, y.data(), y.stride_0());
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType & /*member*/,
const ScalarType /*alpha*/,
const AViewType & /*A*/,
const xViewType & /*x*/,
const ScalarType /*beta*/,
const yViewType & /*y*/) {
static_assert(AViewType::Rank == 3,
"Batched TeamGemv requires rank-3 A matrix (use "
"KokkosBlas::TeamGemv for regular rank-2 matrix)");
Kokkos::abort(
"KokkosBlas::TeamGemv<Algo::Gemv::Blocked> for rank-3 matrix is NOT "
"implemented");
}
};

Expand Down
Loading

0 comments on commit 1d37bad

Please sign in to comment.