From 69d97b0abca4781e6a17653e1f87517153f3b816 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Miko=C5=82aj=20Zuzek?= Date: Fri, 9 Sep 2022 17:07:41 +0200 Subject: [PATCH] GEMM: implicit MemberType --- .../blas/blas3/KokkosBlas3_gemm_perf_test.hpp | 50 +++------ src/blas/KokkosBlas3_gemm.hpp | 55 +++++---- src/blas/impl/KokkosBlas3_team_gemm_impl.hpp | 104 ++++++++---------- .../dense/Test_Batched_TeamInverseLU.hpp | 2 +- .../dense/Test_Batched_TeamSolveLU.hpp | 2 +- .../Test_Batched_TeamVectorSolveUTV2.hpp | 3 +- unit_test/blas/Test_Blas3_team_gemm.hpp | 2 +- unit_test/blas/Test_Blas3_teamvector_gemm.hpp | 2 +- 8 files changed, 96 insertions(+), 124 deletions(-) diff --git a/perf_test/blas/blas3/KokkosBlas3_gemm_perf_test.hpp b/perf_test/blas/blas3/KokkosBlas3_gemm_perf_test.hpp index d24431ec95..5be636277d 100644 --- a/perf_test/blas/blas3/KokkosBlas3_gemm_perf_test.hpp +++ b/perf_test/blas/blas3/KokkosBlas3_gemm_perf_test.hpp @@ -719,9 +719,8 @@ struct parallel_batched_gemm { auto svB = Kokkos::subview(gemm_args_.B, i, Kokkos::ALL(), Kokkos::ALL()); auto svC = Kokkos::subview(gemm_args_.C, i, Kokkos::ALL(), Kokkos::ALL()); - KokkosBlas::TeamGemm::invoke(member, gemm_args_.alpha, svA, - svB, gemm_args_.beta, svC); + KokkosBlas::TeamGemm::invoke( + member, gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); } KOKKOS_INLINE_FUNCTION @@ -731,9 +730,8 @@ struct parallel_batched_gemm { auto svB = Kokkos::subview(gemm_args_.B, Kokkos::ALL(), Kokkos::ALL(), i); auto svC = Kokkos::subview(gemm_args_.C, Kokkos::ALL(), Kokkos::ALL(), i); - KokkosBlas::TeamGemm::invoke(member, gemm_args_.alpha, svA, - svB, gemm_args_.beta, svC); + KokkosBlas::TeamGemm::invoke( + member, gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); } KOKKOS_INLINE_FUNCTION @@ -746,10 +744,8 @@ struct parallel_batched_gemm { auto svC = Kokkos::subview(gemm_args_.C, team_idx, Kokkos::ALL(), Kokkos::ALL()); - KokkosBlas::TeamVectorGemm::invoke(member, gemm_args_.alpha, - svA, svB, gemm_args_.beta, - svC); + KokkosBlas::TeamVectorGemm::invoke( + member, gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); } KOKKOS_INLINE_FUNCTION @@ -763,10 +759,8 @@ struct parallel_batched_gemm { auto svC = Kokkos::subview(gemm_args_.C, Kokkos::ALL(), Kokkos::ALL(), team_idx); - KokkosBlas::TeamVectorGemm::invoke(member, gemm_args_.alpha, - svA, svB, gemm_args_.beta, - svC); + KokkosBlas::TeamVectorGemm::invoke( + member, gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); } KOKKOS_INLINE_FUNCTION @@ -782,7 +776,7 @@ struct parallel_batched_gemm { auto svC = Kokkos::subview(gemm_args_.Cv.ivec_4d, i, Kokkos::ALL(), Kokkos::ALL(), vector_lane); - KokkosBlas::Gemm::invoke(member, gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); }); @@ -802,7 +796,7 @@ struct parallel_batched_gemm { auto svC = Kokkos::subview(gemm_args_.Cv.ivec_4d, vector_lane, Kokkos::ALL(), Kokkos::ALL(), i); - KokkosBlas::Gemm::invoke(member, gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); }); @@ -1066,10 +1060,8 @@ struct parallel_batched_gemm_experiment2_3_4 { // Uses TeamThreadRange over C-rows // ThreadVectorRange over C-cols - KokkosBlas::TeamVectorGemm::invoke(member, gemm_args_.alpha, - svA, svB, gemm_args_.beta, - svC); + KokkosBlas::TeamVectorGemm::invoke( + member, gemm_args_.alpha, svA, svB, gemm_args_.beta, svC); } // Experiment 3 @@ -1096,10 +1088,8 @@ struct parallel_batched_gemm_experiment2_3_4 { auto svC_col = Kokkos::subview(svC, Kokkos::ALL(), lane_idx); // TeamGemm Calls TeamThreadRange over M*N meaning the flat M*N array // is split over all threads of the team - KokkosBlas::TeamGemm::invoke(member, gemm_args_.alpha, - svA, svB_col, - gemm_args_.beta, svC_col); + KokkosBlas::TeamGemm::invoke( + member, gemm_args_.alpha, svA, svB_col, gemm_args_.beta, svC_col); }); } @@ -1128,10 +1118,8 @@ struct parallel_batched_gemm_experiment2_3_4 { auto svC_row = Kokkos::subview(svC, lane_idx, Kokkos::ALL()); // TeamGemm Calls TeamThreadRange over M*N meaning the flat M*N array // is split over all threads of the team - KokkosBlas::TeamGemm::invoke(member, gemm_args_.alpha, - svA_row, svB, - gemm_args_.beta, svC_row); + KokkosBlas::TeamGemm::invoke( + member, gemm_args_.alpha, svA_row, svB, gemm_args_.beta, svC_row); }); } }; @@ -1412,10 +1400,8 @@ class parallel_batched_gemm_experiment6 { auto svC = Kokkos::subview(C, i, Kokkos::ALL(), Kokkos::ALL()); // Uses two serial for-loops internally - KokkosBlas::TeamVectorGemm::invoke(member, gemm_args.alpha, - svA, svB, gemm_args.beta, - svC); + KokkosBlas::TeamVectorGemm::invoke( + member, gemm_args.alpha, svA, svB, gemm_args.beta, svC); } }; diff --git a/src/blas/KokkosBlas3_gemm.hpp b/src/blas/KokkosBlas3_gemm.hpp index 659f695290..1719075661 100644 --- a/src/blas/KokkosBlas3_gemm.hpp +++ b/src/blas/KokkosBlas3_gemm.hpp @@ -277,11 +277,10 @@ struct SerialGemm { /// Team Impl /// ========= -template +template struct TeamGemm { - template + template KOKKOS_INLINE_FUNCTION static int invoke( const MemberType& member, const ScalarType alpha, const AViewType& A, const BViewType& B, const ScalarType beta, const CViewType& C); @@ -291,11 +290,10 @@ struct TeamGemm { /// TeamVector Impl /// ========= -template +template struct TeamVectorGemm { - template + template KOKKOS_INLINE_FUNCTION static int invoke( const MemberType& member, const ScalarType alpha, const AViewType& A, const BViewType& B, const ScalarType beta, const CViewType& C); @@ -304,20 +302,19 @@ struct TeamVectorGemm { /// /// Selective Interface /// -template +template struct Gemm { - template + template KOKKOS_FORCEINLINE_FUNCTION static int invoke( const MemberType& member, const ScalarType alpha, const AViewType& A, const BViewType& B, const ScalarType beta, const CViewType& C); }; -template -struct Gemm { - template +struct Gemm { + template KOKKOS_FORCEINLINE_FUNCTION static int invoke(const MemberType& /* member */, const ScalarType alpha, @@ -330,29 +327,27 @@ struct Gemm { } }; -template -struct Gemm { - template +template +struct Gemm { + template KOKKOS_FORCEINLINE_FUNCTION static int invoke( const MemberType& member, const ScalarType alpha, const AViewType& A, const BViewType& B, const ScalarType beta, const CViewType& C) { - return TeamGemm::invoke( - member, alpha, A, B, beta, C); + return TeamGemm::invoke(member, alpha, A, B, + beta, C); } }; -template -struct Gemm { - template +template +struct Gemm { + template KOKKOS_FORCEINLINE_FUNCTION static int invoke( const MemberType& member, const ScalarType alpha, const AViewType& A, const BViewType& B, const ScalarType beta, const CViewType& C) { - return TeamVectorGemm::invoke( - member, alpha, A, B, beta, C); + return TeamVectorGemm::invoke(member, alpha, + A, B, beta, C); } }; diff --git a/src/blas/impl/KokkosBlas3_team_gemm_impl.hpp b/src/blas/impl/KokkosBlas3_team_gemm_impl.hpp index 0b1ed597af..f4cc752d1e 100644 --- a/src/blas/impl/KokkosBlas3_team_gemm_impl.hpp +++ b/src/blas/impl/KokkosBlas3_team_gemm_impl.hpp @@ -64,11 +64,10 @@ namespace KokkosBlas { /// NT/NT /// -template -struct TeamGemm { - template +template <> +struct TeamGemm { + template KOKKOS_INLINE_FUNCTION static int invoke( const MemberType &member, const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { @@ -81,11 +80,10 @@ struct TeamGemm -struct TeamGemm { - template +template <> +struct TeamGemm { + template KOKKOS_INLINE_FUNCTION static int invoke( const MemberType &member, const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { @@ -102,11 +100,10 @@ struct TeamGemm -struct TeamGemm { - template +template <> +struct TeamGemm { + template KOKKOS_INLINE_FUNCTION static int invoke( const MemberType &member, const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { @@ -119,11 +116,10 @@ struct TeamGemm -struct TeamGemm { - template +template <> +struct TeamGemm { + template KOKKOS_INLINE_FUNCTION static int invoke( const MemberType &member, const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { @@ -140,11 +136,10 @@ struct TeamGemm -struct TeamGemm { - template +template <> +struct TeamGemm { + template KOKKOS_INLINE_FUNCTION static int invoke( const MemberType &member, const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { @@ -157,11 +152,10 @@ struct TeamGemm -struct TeamGemm { - template +template <> +struct TeamGemm { + template KOKKOS_INLINE_FUNCTION static int invoke( const MemberType &member, const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { @@ -178,11 +172,10 @@ struct TeamGemm -struct TeamGemm { - template +template <> +struct TeamGemm { + template KOKKOS_INLINE_FUNCTION static int invoke( const MemberType &member, const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { @@ -195,11 +188,10 @@ struct TeamGemm -struct TeamGemm { - template +template <> +struct TeamGemm { + template KOKKOS_INLINE_FUNCTION static int invoke( const MemberType &member, const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { @@ -224,11 +216,11 @@ struct TeamGemm -struct TeamVectorGemm +struct TeamVectorGemm { - template + template KOKKOS_INLINE_FUNCTION static int invoke( const MemberType &member, const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { @@ -245,11 +237,11 @@ struct TeamVectorGemm -struct TeamVectorGemm +struct TeamVectorGemm { - template + template KOKKOS_INLINE_FUNCTION static int invoke( const MemberType &member, const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { @@ -266,11 +258,11 @@ struct TeamVectorGemm -struct TeamVectorGemm +struct TeamVectorGemm { - template + template KOKKOS_INLINE_FUNCTION static int invoke( const MemberType &member, const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { @@ -287,11 +279,11 @@ struct TeamVectorGemm -struct TeamVectorGemm +struct TeamVectorGemm { - template + template KOKKOS_INLINE_FUNCTION static int invoke( const MemberType &member, const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C) { diff --git a/unit_test/batched/dense/Test_Batched_TeamInverseLU.hpp b/unit_test/batched/dense/Test_Batched_TeamInverseLU.hpp index 712f7b235b..f5f991effc 100644 --- a/unit_test/batched/dense/Test_Batched_TeamInverseLU.hpp +++ b/unit_test/batched/dense/Test_Batched_TeamInverseLU.hpp @@ -52,7 +52,7 @@ struct Functor_BatchedTeamGemm { } member.team_barrier(); - KokkosBlas::TeamGemm::invoke(member, _alpha, aa, bb, _beta, cc); diff --git a/unit_test/batched/dense/Test_Batched_TeamSolveLU.hpp b/unit_test/batched/dense/Test_Batched_TeamSolveLU.hpp index 1f1245db49..ee2a0b703d 100644 --- a/unit_test/batched/dense/Test_Batched_TeamSolveLU.hpp +++ b/unit_test/batched/dense/Test_Batched_TeamSolveLU.hpp @@ -52,7 +52,7 @@ struct Functor_BatchedTeamGemm { } member.team_barrier(); - KokkosBlas::TeamGemm::invoke(member, _alpha, aa, bb, _beta, cc); diff --git a/unit_test/batched/dense/Test_Batched_TeamVectorSolveUTV2.hpp b/unit_test/batched/dense/Test_Batched_TeamVectorSolveUTV2.hpp index 7ee69fb1de..5760a80c89 100644 --- a/unit_test/batched/dense/Test_Batched_TeamVectorSolveUTV2.hpp +++ b/unit_test/batched/dense/Test_Batched_TeamVectorSolveUTV2.hpp @@ -81,8 +81,7 @@ struct Functor_TestBatchedTeamVectorSolveUTV2 { TeamVectorCopy::invoke(member, aa, ac); /// bb = AA*xx - KokkosBlas::TeamVectorGemm::invoke(member, one, aa, xx, zero, bb); member.team_barrier(); diff --git a/unit_test/blas/Test_Blas3_team_gemm.hpp b/unit_test/blas/Test_Blas3_team_gemm.hpp index def6548bb4..7e8267efac 100644 --- a/unit_test/blas/Test_Blas3_team_gemm.hpp +++ b/unit_test/blas/Test_Blas3_team_gemm.hpp @@ -42,7 +42,7 @@ struct Functor_TestBatchedTeamGemm { auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), Kokkos::ALL()); auto cc = Kokkos::subview(_c, k, Kokkos::ALL(), Kokkos::ALL()); - KokkosBlas::TeamGemm::invoke(member, _alpha, aa, bb, _beta, cc); diff --git a/unit_test/blas/Test_Blas3_teamvector_gemm.hpp b/unit_test/blas/Test_Blas3_teamvector_gemm.hpp index 095cd86195..add323d425 100644 --- a/unit_test/blas/Test_Blas3_teamvector_gemm.hpp +++ b/unit_test/blas/Test_Blas3_teamvector_gemm.hpp @@ -39,7 +39,7 @@ struct Functor_TestBatchedTeamVector { auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), Kokkos::ALL()); auto cc = Kokkos::subview(_c, k, Kokkos::ALL(), Kokkos::ALL()); - KokkosBlas::TeamVectorGemm::invoke(member, _alpha, aa, bb, _beta, cc);