Skip to content

Commit

Permalink
Move Set (Serial, Team and TeamVector) from KokkosBatched to KokkosBlas
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikołaj Zuzek committed Jun 28, 2022
1 parent 673466c commit e0c69dc
Show file tree
Hide file tree
Showing 26 changed files with 242 additions and 125 deletions.
26 changes: 19 additions & 7 deletions src/batched/dense/KokkosBatched_Set_Decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@

/// \author Kyungjoo Kim (kyukim@sandia.gov)

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Vector.hpp"
#include "impl/Kokkos_Error.hpp"

namespace KokkosBatched {
///
Expand All @@ -14,7 +13,12 @@ namespace KokkosBatched {
struct SerialSet {
template <typename ScalarType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha,
const AViewType &A);
const AViewType &A) {
Kokkos::abort(
"KokkosBatched::SerialSet is deprecated: use KokkosBlas::SerialSet "
"instead");
return 0;
}
};

///
Expand All @@ -26,7 +30,12 @@ struct TeamSet {
template <typename ScalarType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member,
const ScalarType alpha,
const AViewType &A);
const AViewType &A) {
Kokkos::abort(
"KokkosBatched::TeamSet is deprecated: use KokkosBlas::TeamSet "
"instead");
return 0;
}
};

///
Expand All @@ -38,11 +47,14 @@ struct TeamVectorSet {
template <typename ScalarType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member,
const ScalarType alpha,
const AViewType &A);
const AViewType &A) {
Kokkos::abort(
"KokkosBatched::TeamVectorSet is deprecated: use "
"KokkosBlas::TeamVectorSet instead");
return 0;
}
};

} // namespace KokkosBatched

#include "KokkosBatched_Set_Impl.hpp"

#endif
6 changes: 3 additions & 3 deletions src/batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include "KokkosBatched_Util.hpp"

#include "KokkosBatched_Set_Internal.hpp"
#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_Scale_Internal.hpp"

#include "KokkosBatched_InnerGemmFixC_Serial_Impl.hpp"
Expand Down Expand Up @@ -41,7 +41,7 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Unblocked>::invoke(
const ScalarType one(1.0), zero(0.0);

if (beta == zero)
SerialSetInternal ::invoke(m, n, zero, C, cs0, cs1);
KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, C, cs0, cs1);
else if (beta != one)
SerialScaleInternal::invoke(m, n, beta, C, cs0, cs1);

Expand Down Expand Up @@ -81,7 +81,7 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
const ScalarType one(1.0), zero(0.0);

if (beta == zero)
SerialSetInternal ::invoke(m, n, zero, C, cs0, cs1);
KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, C, cs0, cs1);
else if (beta != one)
SerialScaleInternal::invoke(m, n, beta, C, cs0, cs1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include "KokkosBatched_Util.hpp"

#include "KokkosBatched_Set_Internal.hpp"
#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_Scale_Internal.hpp"

namespace KokkosBatched {
Expand Down Expand Up @@ -39,7 +39,8 @@ TeamVectorGemmInternal<Algo::Gemm::Unblocked, false>::invoke(
const ScalarType one(1.0), zero(0.0);

if (beta == zero)
TeamVectorSetInternal ::invoke(member, m, n, zero, C, cs0, cs1);
KokkosBlas::Impl::TeamVectorSetInternal::invoke(member, m, n, zero, C, cs0,
cs1);
else if (beta != one)
TeamVectorScaleInternal::invoke(member, m, n, beta, C, cs0, cs1);

Expand Down Expand Up @@ -79,7 +80,8 @@ TeamVectorGemmInternal<Algo::Gemm::Unblocked, true>::invoke(
const ScalarType one(1.0), zero(0.0);

if (beta == zero)
TeamVectorSetInternal ::invoke(member, m, n, zero, C, cs0, cs1);
KokkosBlas::Impl::TeamVectorSetInternal::invoke(member, m, n, zero, C, cs0,
cs1);
else if (beta != one)
TeamVectorScaleInternal::invoke(member, m, n, beta, C, cs0, cs1);

Expand Down
6 changes: 3 additions & 3 deletions src/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "KokkosBatched_Util.hpp"
#include "KokkosKernels_ExecSpaceUtils.hpp"

#include "KokkosBatched_Set_Internal.hpp"
#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_Scale_Internal.hpp"

#include "KokkosBatched_InnerGemmFixC_Serial_Impl.hpp"
Expand Down Expand Up @@ -41,7 +41,7 @@ KOKKOS_INLINE_FUNCTION int TeamGemmInternal<Algo::Gemm::Unblocked>::invoke(
const ScalarType one(1.0), zero(0.0);

if (beta == zero)
TeamSetInternal ::invoke(member, m, n, zero, C, cs0, cs1);
KokkosBlas::Impl::TeamSetInternal::invoke(member, m, n, zero, C, cs0, cs1);
else if (beta != one)
TeamScaleInternal::invoke(member, m, n, beta, C, cs0, cs1);

Expand Down Expand Up @@ -82,7 +82,7 @@ KOKKOS_INLINE_FUNCTION int TeamGemmInternal<Algo::Gemm::Blocked>::invoke(
const ScalarType one(1.0), zero(0.0);

if (beta == zero)
TeamSetInternal ::invoke(member, m, n, zero, C, cs0, cs1);
KokkosBlas::Impl::TeamSetInternal::invoke(member, m, n, zero, C, cs0, cs1);
else if (beta != one)
TeamScaleInternal::invoke(member, m, n, beta, C, cs0, cs1);

Expand Down
6 changes: 3 additions & 3 deletions src/batched/dense/impl/KokkosBatched_Gemv_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include "KokkosBatched_Util.hpp"

#include "KokkosBatched_Set_Internal.hpp"
#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_Scale_Internal.hpp"

#include "KokkosBatched_InnerMultipleDotProduct_Serial_Impl.hpp"
Expand Down Expand Up @@ -39,7 +39,7 @@ KOKKOS_INLINE_FUNCTION int SerialGemvInternal<Algo::Gemv::Unblocked>::invoke(
// y (m), A(m x n), B(n)

if (beta == zero)
SerialSetInternal ::invoke(m, zero, y, ys0);
KokkosBlas::Impl::SerialSetInternal::invoke(m, zero, y, ys0);
else if (beta != one)
SerialScaleInternal::invoke(m, beta, y, ys0);

Expand Down Expand Up @@ -78,7 +78,7 @@ KOKKOS_INLINE_FUNCTION int SerialGemvInternal<Algo::Gemv::Blocked>::invoke(
constexpr int mbAlgo = Algo::Gemv::Blocked::mb();

if (beta == zero)
SerialSetInternal ::invoke(m, zero, y, ys0);
KokkosBlas::Impl::SerialSetInternal::invoke(m, zero, y, ys0);
else if (beta != one)
SerialScaleInternal::invoke(m, beta, y, ys0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include "KokkosBatched_Util.hpp"

#include "KokkosBatched_Set_Internal.hpp"
#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_Scale_Internal.hpp"

#include "KokkosBatched_InnerMultipleDotProduct_Serial_Impl.hpp"
Expand Down Expand Up @@ -58,7 +58,7 @@ TeamVectorGemvInternal<Algo::Gemv::Unblocked>::invoke(
// y (m), A(m x n), B(n)

if (beta == zero)
TeamVectorSetInternal ::invoke(member, m, zero, y, ys0);
KokkosBlas::Impl::TeamVectorSetInternal::invoke(member, m, zero, y, ys0);
else if (beta != one)
TeamVectorScaleInternal::invoke(member, m, beta, y, ys0);

Expand Down
6 changes: 3 additions & 3 deletions src/batched/dense/impl/KokkosBatched_Gemv_Team_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include "KokkosBatched_Util.hpp"

#include "KokkosBatched_Set_Internal.hpp"
#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_Scale_Internal.hpp"

#include "KokkosBatched_InnerMultipleDotProduct_Serial_Impl.hpp"
Expand Down Expand Up @@ -48,7 +48,7 @@ KOKKOS_INLINE_FUNCTION int TeamGemvInternal<Algo::Gemv::Unblocked>::invoke(
// y (m), A(m x n), B(n)

if (beta == zero)
TeamSetInternal ::invoke(member, m, zero, y, ys0);
KokkosBlas::Impl::TeamSetInternal::invoke(member, m, zero, y, ys0);
else if (beta != one)
TeamScaleInternal::invoke(member, m, beta, y, ys0);

Expand Down Expand Up @@ -87,7 +87,7 @@ KOKKOS_INLINE_FUNCTION int TeamGemvInternal<Algo::Gemv::Blocked>::invoke(
constexpr int mbAlgo = Algo::Gemv::Blocked::mb();

if (beta == zero)
TeamSetInternal ::invoke(member, m, zero, y, ys0);
KokkosBlas::Impl::TeamSetInternal::invoke(member, m, zero, y, ys0);
else if (beta != one)
TeamScaleInternal::invoke(member, m, beta, y, ys0);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
/// \author Kyungjoo Kim (kyukim@sandia.gov)

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Set_Internal.hpp"
#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_SetIdentity_Internal.hpp"
#include "KokkosBatched_ApplyQ_Serial_Internal.hpp"

Expand Down Expand Up @@ -37,7 +37,8 @@ struct SerialHessenbergFormQInternal {
/// B is m x m
// set identity
if (is_Q_zero)
SerialSetInternal::invoke(m, value_type(1), Q, qs0 + qs1);
KokkosBlas::Impl::SerialSetInternal::invoke(m, value_type(1), Q,
qs0 + qs1);
else
SerialSetIdentityInternal::invoke(m, Q, qs0, qs1);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
/// \author Kyungjoo Kim (kyukim@sandia.gov)

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Set_Internal.hpp"
#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_SetIdentity_Internal.hpp"
#include "KokkosBatched_ApplyQ_Serial_Internal.hpp"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
/// \author Kyungjoo Kim (kyukim@sandia.gov)

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Set_Internal.hpp"
#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_SetIdentity_Internal.hpp"
#include "KokkosBatched_ApplyQ_TeamVector_Internal.hpp"

Expand Down Expand Up @@ -36,7 +36,8 @@ struct TeamVectorQR_FormQ_Internal {

// set identity
if (is_Q_zero)
TeamVectorSetInternal::invoke(member, m, value_type(1), Q, qs0 + qs1);
KokkosBlas::Impl::TeamVectorSetInternal::invoke(member, m, value_type(1),
Q, qs0 + qs1);
else
TeamVectorSetIdentityInternal::invoke(member, m, n, Q, qs0, qs1);
member.team_barrier();
Expand Down
48 changes: 0 additions & 48 deletions src/batched/dense/impl/KokkosBatched_Set_Impl.hpp

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include "KokkosBatched_Util.hpp"

#include "KokkosBatched_Set_Internal.hpp"
#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_Scale_Internal.hpp"

#include "KokkosBatched_InnerTrsm_Serial_Impl.hpp"
Expand Down
10 changes: 5 additions & 5 deletions src/batched/dense/impl/KokkosBatched_Trmm_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@

#include "KokkosBatched_Util.hpp"

#include "KokkosBatched_Set_Internal.hpp"
#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_Scale_Internal.hpp"

namespace KokkosBatched {
Expand Down Expand Up @@ -152,7 +152,7 @@ SerialTrmmInternalLeftLower<Algo::Trmm::Unblocked>::invoke(
if (bm <= 0 || bn <= 0 || am <= 0 || an <= 0) return 0;

if (alpha == zero)
SerialSetInternal::invoke(bm, bn, zero, B, bs0, bs1);
KokkosBlas::Impl::SerialSetInternal::invoke(bm, bn, zero, B, bs0, bs1);
else {
if (alpha != one) SerialScaleInternal::invoke(bm, bn, alpha, B, bs0, bs1);

Expand Down Expand Up @@ -240,7 +240,7 @@ SerialTrmmInternalRightLower<Algo::Trmm::Unblocked>::invoke(
if (bm <= 0 || bn <= 0 || am <= 0 || an <= 0) return 0;

if (alpha == zero)
SerialSetInternal::invoke(bm, bn, zero, B, bs0, bs1);
KokkosBlas::Impl::SerialSetInternal::invoke(bm, bn, zero, B, bs0, bs1);
else {
if (alpha != one) SerialScaleInternal::invoke(bm, bn, alpha, B, bs0, bs1);

Expand Down Expand Up @@ -321,7 +321,7 @@ SerialTrmmInternalLeftUpper<Algo::Trmm::Unblocked>::invoke(
if (bm <= 0 || bn <= 0 || am <= 0 || an <= 0) return 0;

if (alpha == zero)
SerialSetInternal::invoke(bm, bn, zero, B, bs0, bs1);
KokkosBlas::Impl::SerialSetInternal::invoke(bm, bn, zero, B, bs0, bs1);
else {
if (alpha != one) SerialScaleInternal::invoke(bm, bn, alpha, B, bs0, bs1);

Expand Down Expand Up @@ -401,7 +401,7 @@ SerialTrmmInternalRightUpper<Algo::Trmm::Unblocked>::invoke(
if (bm <= 0 || bn <= 0 || am <= 0 || an <= 0) return 0;

if (alpha == zero)
SerialSetInternal::invoke(bm, bn, zero, B, bs0, bs1);
KokkosBlas::Impl::SerialSetInternal::invoke(bm, bn, zero, B, bs0, bs1);
else {
if (alpha != one) SerialScaleInternal::invoke(bm, bn, alpha, B, bs0, bs1);

Expand Down
Loading

0 comments on commit e0c69dc

Please sign in to comment.