diff --git a/batched/dense/unit_test/Test_Batched_Dense.hpp b/batched/dense/unit_test/Test_Batched_Dense.hpp index a6a3a25c45..0455022740 100644 --- a/batched/dense/unit_test/Test_Batched_Dense.hpp +++ b/batched/dense/unit_test/Test_Batched_Dense.hpp @@ -58,6 +58,9 @@ #include "Test_Batched_SerialPbtrf.hpp" #include "Test_Batched_SerialPbtrf_Real.hpp" #include "Test_Batched_SerialPbtrf_Complex.hpp" +#include "Test_Batched_SerialPbtrs.hpp" +#include "Test_Batched_SerialPbtrs_Real.hpp" +#include "Test_Batched_SerialPbtrs_Complex.hpp" // Team Kernels #include "Test_Batched_TeamAxpy.hpp" diff --git a/batched/dense/unit_test/Test_Batched_SerialPbtrs.hpp b/batched/dense/unit_test/Test_Batched_SerialPbtrs.hpp index 054bb57d06..ddd58229ce 100644 --- a/batched/dense/unit_test/Test_Batched_SerialPbtrs.hpp +++ b/batched/dense/unit_test/Test_Batched_SerialPbtrs.hpp @@ -17,7 +17,7 @@ #include #include #include - +#include #include "KokkosBatched_Util.hpp" #include "KokkosBatched_Pbtrf.hpp" #include "KokkosBatched_Pbtrs.hpp" @@ -54,7 +54,7 @@ struct Functor_BatchedSerialPbtrf { const std::string name_value_type = Test::value_type_name(); std::string name = name_region + name_value_type; Kokkos::RangePolicy policy(0, _ab.extent(0)); - Kokkos::parallel_for(name.c_str(), policy, *this, info_sum); + Kokkos::parallel_for(name.c_str(), policy, *this); } }; @@ -89,35 +89,34 @@ struct Functor_BatchedSerialPbtrs { } }; -template -struct Functor_BatchedSerialGemm { +template +struct Functor_BatchedSerialGemv { using execution_space = typename DeviceType::execution_space; AViewType _a; - BViewType _b; - CViewType _c; + xViewType _x; + yViewType _y; ScalarType _alpha, _beta; KOKKOS_INLINE_FUNCTION - Functor_BatchedSerialGemm(const ScalarType alpha, const AViewType &a, const BViewType &b, const ScalarType beta, - const CViewType &c) - : _a(a), _b(b), _c(c), _alpha(alpha), _beta(beta) {} + Functor_BatchedSerialGemv(const ScalarType alpha, const AViewType &a, const xViewType &x, const ScalarType beta, + const yViewType &y) + : _a(a), _x(x), _y(y), _alpha(alpha), _beta(beta) {} KOKKOS_INLINE_FUNCTION void operator()(const int k) const { auto aa = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL()); - auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), Kokkos::ALL()); - auto cc = Kokkos::subview(_c, k, Kokkos::ALL(), Kokkos::ALL()); + auto xx = Kokkos::subview(_x, k, Kokkos::ALL()); + auto yy = Kokkos::subview(_y, k, Kokkos::ALL()); - KokkosBatched::SerialGemm::invoke(_alpha, aa, bb, _beta, cc); + KokkosBlas::SerialGemv::invoke(_alpha, aa, xx, _beta, yy); } inline void run() { using value_type = typename AViewType::non_const_value_type; - std::string name_region("KokkosBatched::Test::SerialPbtrf"); + std::string name_region("KokkosBatched::Test::SerialPbtrs"); const std::string name_value_type = Test::value_type_name(); std::string name = name_region + name_value_type; - Kokkos::RangePolicy policy(0, _a.extent(0)); + Kokkos::RangePolicy policy(0, _x.extent(0)); Kokkos::parallel_for(name.c_str(), policy, *this); } }; @@ -190,9 +189,7 @@ void impl_test_batched_pbtrs_analytical(const int N) { // Check x0 = x1 for (int ib = 0; ib < N; ib++) { for (int i = 0; i < BlkSize; i++) { - for (int j = 0; j < BlkSize; j++) { - EXPECT_NEAR_KK(h_x0(ib, i, j), h_x_ref(ib, i, j), eps); - } + EXPECT_NEAR_KK(h_x0(ib, i), h_x_ref(ib, i), eps); } } } @@ -259,9 +256,7 @@ void impl_test_batched_pbtrs(const int N, const int k, const int BlkSize) { // Check A * x0 = x_ref for (int ib = 0; ib < N; ib++) { for (int i = 0; i < BlkSize; i++) { - for (int j = 0; j < BlkSize; j++) { - EXPECT_NEAR_KK(h_y0(ib, i, j), h_x_ref(ib, i, j), eps); - } + EXPECT_NEAR_KK(h_y0(ib, i), h_x_ref(ib, i), eps); } } } @@ -274,12 +269,12 @@ int test_batched_pbtrs() { #if defined(KOKKOSKERNELS_INST_LAYOUTLEFT) { using LayoutType = Kokkos::LayoutLeft; - Test::pbtrs::impl_test_batched_pbtrs_analytical(1); + Test::Pbtrs::impl_test_batched_pbtrs_analytical(1); Test::Pbtrs::impl_test_batched_pbtrs_analytical(2); for (int i = 0; i < 10; i++) { int k = 1; - Test::pbtrs::impl_test_batched_pbtrs(1, k, i); - Test::pbtrs::impl_test_batched_pbtrs(2, k, i); + Test::Pbtrs::impl_test_batched_pbtrs(1, k, i); + Test::Pbtrs::impl_test_batched_pbtrs(2, k, i); } } #endif