Skip to content

Commit

Permalink
Merge branch 'blas-serial-scale' into move_set-scale-test_to_blas
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikołaj Zuzek committed Jun 30, 2022
2 parents d273957 + 5bec42c commit a72302a
Show file tree
Hide file tree
Showing 25 changed files with 317 additions and 176 deletions.
53 changes: 32 additions & 21 deletions src/batched/dense/KokkosBatched_Scale_Decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,47 +3,58 @@

/// \author Kyungjoo Kim (kyukim@sandia.gov)

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Vector.hpp"
#include "impl/Kokkos_Error.hpp"

namespace KokkosBatched {

///
/// Serial Scale
///

struct SerialScale {
template <typename ScalarType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha,
const AViewType &A);
};
struct [[deprecated]] SerialScale{
template <typename ScalarType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ScalarType alpha,
const AViewType &A){Kokkos::abort(
"KokkosBatched::SerialScale is deprecated: use KokkosBlas::SerialScale "
"instead");
return 0;
} // namespace KokkosBatched
}
;

///
/// Team Scale
///

template <typename MemberType>
struct TeamScale {
template <typename ScalarType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member,
const ScalarType alpha,
const AViewType &A);
};
struct [[deprecated]] TeamScale{
template <typename ScalarType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member,
const ScalarType alpha,
const AViewType &A){Kokkos::abort(
"KokkosBatched::TeamScale is deprecated: use KokkosBlas::TeamScale "
"instead");
return 0;
}
}
;

///
/// TeamVector Scale
///

template <typename MemberType>
struct TeamVectorScale {
template <typename ScalarType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const MemberType &member,
const ScalarType alpha,
const AViewType &A);
};
struct [[deprecated]] TeamVectorScale{
template <typename ScalarType, typename AViewType>
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A){
Kokkos::abort("KokkosBatched::TeamVectorScale is deprecated: use "
"KokkosBlas::TeamVectorScale instead");
return 0;
}
}
;

} // namespace KokkosBatched

#include "KokkosBatched_Scale_Impl.hpp"

#endif
6 changes: 3 additions & 3 deletions src/batched/dense/impl/KokkosBatched_Gemm_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "KokkosBatched_Util.hpp"

#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_Scale_Internal.hpp"
#include "KokkosBlas1_serial_scal_impl.hpp"

#include "KokkosBatched_InnerGemmFixC_Serial_Impl.hpp"

Expand Down Expand Up @@ -43,7 +43,7 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Unblocked>::invoke(
if (beta == zero)
KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, C, cs0, cs1);
else if (beta != one)
SerialScaleInternal::invoke(m, n, beta, C, cs0, cs1);
KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, beta, C, cs0, cs1);

if (alpha != zero) {
if (m <= 0 || n <= 0 || k <= 0) return 0;
Expand Down Expand Up @@ -83,7 +83,7 @@ KOKKOS_INLINE_FUNCTION int SerialGemmInternal<Algo::Gemm::Blocked>::invoke(
if (beta == zero)
KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, C, cs0, cs1);
else if (beta != one)
SerialScaleInternal::invoke(m, n, beta, C, cs0, cs1);
KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, beta, C, cs0, cs1);

if (alpha != zero) {
if (m <= 0 || n <= 0 || k <= 0) return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
#include "KokkosBatched_Util.hpp"

#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_Scale_Internal.hpp"
#include "KokkosBlas1_team_scal_impl.hpp"

namespace KokkosBatched {

Expand Down Expand Up @@ -42,7 +42,8 @@ TeamVectorGemmInternal<Algo::Gemm::Unblocked, false>::invoke(
KokkosBlas::Impl::TeamVectorSetInternal::invoke(member, m, n, zero, C, cs0,
cs1);
else if (beta != one)
TeamVectorScaleInternal::invoke(member, m, n, beta, C, cs0, cs1);
KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, n, beta, C,
cs0, cs1);

if (alpha != ScalarType(0.0)) {
if (m <= 0 || n <= 0 || k <= 0) return 0;
Expand Down Expand Up @@ -83,7 +84,8 @@ TeamVectorGemmInternal<Algo::Gemm::Unblocked, true>::invoke(
KokkosBlas::Impl::TeamVectorSetInternal::invoke(member, m, n, zero, C, cs0,
cs1);
else if (beta != one)
TeamVectorScaleInternal::invoke(member, m, n, beta, C, cs0, cs1);
KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, n, beta, C,
cs0, cs1);

if (alpha != ScalarType(0.0)) {
if (m <= 0 || n <= 0 || k <= 0) return 0;
Expand Down
8 changes: 5 additions & 3 deletions src/batched/dense/impl/KokkosBatched_Gemm_Team_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include "KokkosKernels_ExecSpaceUtils.hpp"

#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_Scale_Internal.hpp"
#include "KokkosBlas1_team_scal_impl.hpp"

#include "KokkosBatched_InnerGemmFixC_Serial_Impl.hpp"

Expand Down Expand Up @@ -43,7 +43,8 @@ KOKKOS_INLINE_FUNCTION int TeamGemmInternal<Algo::Gemm::Unblocked>::invoke(
if (beta == zero)
KokkosBlas::Impl::TeamSetInternal::invoke(member, m, n, zero, C, cs0, cs1);
else if (beta != one)
TeamScaleInternal::invoke(member, m, n, beta, C, cs0, cs1);
KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, n, beta, C, cs0,
cs1);

if (alpha != ScalarType(0.0)) {
if (m <= 0 || n <= 0 || k <= 0) return 0;
Expand Down Expand Up @@ -84,7 +85,8 @@ KOKKOS_INLINE_FUNCTION int TeamGemmInternal<Algo::Gemm::Blocked>::invoke(
if (beta == zero)
KokkosBlas::Impl::TeamSetInternal::invoke(member, m, n, zero, C, cs0, cs1);
else if (beta != one)
TeamScaleInternal::invoke(member, m, n, beta, C, cs0, cs1);
KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, n, beta, C, cs0,
cs1);

if (alpha != ScalarType(0.0)) {
if (m <= 0 || n <= 0 || k <= 0) return 0;
Expand Down
7 changes: 3 additions & 4 deletions src/batched/dense/impl/KokkosBatched_Gemv_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
#include "KokkosBatched_Util.hpp"

#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_Scale_Internal.hpp"

#include "KokkosBlas1_serial_scal_impl.hpp"
#include "KokkosBatched_InnerMultipleDotProduct_Serial_Impl.hpp"

namespace KokkosBatched {
Expand Down Expand Up @@ -41,7 +40,7 @@ KOKKOS_INLINE_FUNCTION int SerialGemvInternal<Algo::Gemv::Unblocked>::invoke(
if (beta == zero)
KokkosBlas::Impl::SerialSetInternal::invoke(m, zero, y, ys0);
else if (beta != one)
SerialScaleInternal::invoke(m, beta, y, ys0);
KokkosBlas::Impl::SerialScaleInternal::invoke(m, beta, y, ys0);

if (alpha != zero) {
if (m <= 0 || n <= 0) return 0;
Expand Down Expand Up @@ -80,7 +79,7 @@ KOKKOS_INLINE_FUNCTION int SerialGemvInternal<Algo::Gemv::Blocked>::invoke(
if (beta == zero)
KokkosBlas::Impl::SerialSetInternal::invoke(m, zero, y, ys0);
else if (beta != one)
SerialScaleInternal::invoke(m, beta, y, ys0);
KokkosBlas::Impl::SerialScaleInternal::invoke(m, beta, y, ys0);

if (alpha != zero) {
if (m <= 0 || n <= 0) return 0;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
#include "KokkosBatched_Util.hpp"

#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_Scale_Internal.hpp"

#include "KokkosBlas1_team_scal_impl.hpp"
#include "KokkosBatched_InnerMultipleDotProduct_Serial_Impl.hpp"

namespace KokkosBatched {
Expand Down Expand Up @@ -60,7 +59,7 @@ TeamVectorGemvInternal<Algo::Gemv::Unblocked>::invoke(
if (beta == zero)
KokkosBlas::Impl::TeamVectorSetInternal::invoke(member, m, zero, y, ys0);
else if (beta != one)
TeamVectorScaleInternal::invoke(member, m, beta, y, ys0);
KokkosBlas::Impl::TeamVectorScaleInternal::invoke(member, m, beta, y, ys0);

if (alpha != zero) {
if (m <= 0 || n <= 0) return 0;
Expand Down
7 changes: 3 additions & 4 deletions src/batched/dense/impl/KokkosBatched_Gemv_Team_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
#include "KokkosBatched_Util.hpp"

#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_Scale_Internal.hpp"

#include "KokkosBlas1_team_scal_impl.hpp"
#include "KokkosBatched_InnerMultipleDotProduct_Serial_Impl.hpp"

namespace KokkosBatched {
Expand Down Expand Up @@ -50,7 +49,7 @@ KOKKOS_INLINE_FUNCTION int TeamGemvInternal<Algo::Gemv::Unblocked>::invoke(
if (beta == zero)
KokkosBlas::Impl::TeamSetInternal::invoke(member, m, zero, y, ys0);
else if (beta != one)
TeamScaleInternal::invoke(member, m, beta, y, ys0);
KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, beta, y, ys0);

if (alpha != zero) {
if (m <= 0 || n <= 0) return 0;
Expand Down Expand Up @@ -89,7 +88,7 @@ KOKKOS_INLINE_FUNCTION int TeamGemvInternal<Algo::Gemv::Blocked>::invoke(
if (beta == zero)
KokkosBlas::Impl::TeamSetInternal::invoke(member, m, zero, y, ys0);
else if (beta != one)
TeamScaleInternal::invoke(member, m, beta, y, ys0);
KokkosBlas::Impl::TeamScaleInternal::invoke(member, m, beta, y, ys0);

if (alpha != zero) {
if (m <= 0 || n <= 0) return 0;
Expand Down
48 changes: 0 additions & 48 deletions src/batched/dense/impl/KokkosBatched_Scale_Impl.hpp

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#include "KokkosBatched_Util.hpp"

#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_Scale_Internal.hpp"

#include "KokkosBatched_InnerTrsm_Serial_Impl.hpp"
#include "KokkosBatched_Gemv_Serial_Internal.hpp"
Expand Down
14 changes: 9 additions & 5 deletions src/batched/dense/impl/KokkosBatched_Trmm_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
#include "KokkosBatched_Util.hpp"

#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_Scale_Internal.hpp"
#include "KokkosBlas1_serial_scal_impl.hpp"

namespace KokkosBatched {

Expand Down Expand Up @@ -154,7 +154,8 @@ SerialTrmmInternalLeftLower<Algo::Trmm::Unblocked>::invoke(
if (alpha == zero)
KokkosBlas::Impl::SerialSetInternal::invoke(bm, bn, zero, B, bs0, bs1);
else {
if (alpha != one) SerialScaleInternal::invoke(bm, bn, alpha, B, bs0, bs1);
if (alpha != one)
KokkosBlas::Impl::SerialScaleInternal::invoke(bm, bn, alpha, B, bs0, bs1);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
Expand Down Expand Up @@ -242,7 +243,8 @@ SerialTrmmInternalRightLower<Algo::Trmm::Unblocked>::invoke(
if (alpha == zero)
KokkosBlas::Impl::SerialSetInternal::invoke(bm, bn, zero, B, bs0, bs1);
else {
if (alpha != one) SerialScaleInternal::invoke(bm, bn, alpha, B, bs0, bs1);
if (alpha != one)
KokkosBlas::Impl::SerialScaleInternal::invoke(bm, bn, alpha, B, bs0, bs1);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
Expand Down Expand Up @@ -323,7 +325,8 @@ SerialTrmmInternalLeftUpper<Algo::Trmm::Unblocked>::invoke(
if (alpha == zero)
KokkosBlas::Impl::SerialSetInternal::invoke(bm, bn, zero, B, bs0, bs1);
else {
if (alpha != one) SerialScaleInternal::invoke(bm, bn, alpha, B, bs0, bs1);
if (alpha != one)
KokkosBlas::Impl::SerialScaleInternal::invoke(bm, bn, alpha, B, bs0, bs1);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
Expand Down Expand Up @@ -403,7 +406,8 @@ SerialTrmmInternalRightUpper<Algo::Trmm::Unblocked>::invoke(
if (alpha == zero)
KokkosBlas::Impl::SerialSetInternal::invoke(bm, bn, zero, B, bs0, bs1);
else {
if (alpha != one) SerialScaleInternal::invoke(bm, bn, alpha, B, bs0, bs1);
if (alpha != one)
KokkosBlas::Impl::SerialScaleInternal::invoke(bm, bn, alpha, B, bs0, bs1);

#if defined(KOKKOS_ENABLE_PRAGMA_UNROLL)
#pragma unroll
Expand Down
15 changes: 9 additions & 6 deletions src/batched/dense/impl/KokkosBatched_Trsm_Serial_Internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
#include "KokkosBatched_Util.hpp"

#include "KokkosBlas1_set_impl.hpp"
#include "KokkosBatched_Scale_Internal.hpp"

#include "KokkosBlas1_serial_scal_impl.hpp"
#include "KokkosBatched_InnerGemmFixA_Serial_Impl.hpp"
#include "KokkosBatched_InnerTrsm_Serial_Impl.hpp"

Expand Down Expand Up @@ -41,7 +40,8 @@ SerialTrsmInternalLeftLower<Algo::Trsm::Unblocked>::invoke(
if (alpha == zero)
KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1);
else {
if (alpha != one) SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1);
if (alpha != one)
KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1);
if (m <= 0 || n <= 0) return 0;

for (int p = 0; p < m; ++p) {
Expand Down Expand Up @@ -89,7 +89,8 @@ SerialTrsmInternalLeftLower<Algo::Trsm::Blocked>::invoke(
if (alpha == zero)
KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1);
else {
if (alpha != one) SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1);
if (alpha != one)
KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1);
if (m <= 0 || n <= 0) return 0;

InnerTrsmLeftLowerUnitDiag<mbAlgo> trsm_u(as0, as1, bs0, bs1);
Expand Down Expand Up @@ -156,7 +157,8 @@ SerialTrsmInternalLeftUpper<Algo::Trsm::Unblocked>::invoke(
if (alpha == zero)
KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1);
else {
if (alpha != one) SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1);
if (alpha != one)
KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1);
if (m <= 0 || n <= 0) return 0;

ValueType *KOKKOS_RESTRICT B0 = B;
Expand Down Expand Up @@ -204,7 +206,8 @@ SerialTrsmInternalLeftUpper<Algo::Trsm::Blocked>::invoke(
if (alpha == zero)
KokkosBlas::Impl::SerialSetInternal::invoke(m, n, zero, B, bs0, bs1);
else {
if (alpha != one) SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1);
if (alpha != one)
KokkosBlas::Impl::SerialScaleInternal::invoke(m, n, alpha, B, bs0, bs1);
if (m <= 0 || n <= 0) return 0;

InnerTrsmLeftUpperUnitDiag<mbAlgo> trsm_u(as0, as1, bs0, bs1);
Expand Down
Loading

0 comments on commit a72302a

Please sign in to comment.