Skip to content

Commit

Permalink
fix: tests for pbtrs
Browse files Browse the repository at this point in the history
  • Loading branch information
Yuuichi Asahi committed Sep 25, 2024
1 parent 1ce6fd2 commit 99f223d
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 24 deletions.
3 changes: 3 additions & 0 deletions batched/dense/unit_test/Test_Batched_Dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@
#include "Test_Batched_SerialPbtrf.hpp"
#include "Test_Batched_SerialPbtrf_Real.hpp"
#include "Test_Batched_SerialPbtrf_Complex.hpp"
#include "Test_Batched_SerialPbtrs.hpp"
#include "Test_Batched_SerialPbtrs_Real.hpp"
#include "Test_Batched_SerialPbtrs_Complex.hpp"

// Team Kernels
#include "Test_Batched_TeamAxpy.hpp"
Expand Down
43 changes: 19 additions & 24 deletions batched/dense/unit_test/Test_Batched_SerialPbtrs.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
#include <gtest/gtest.h>
#include <Kokkos_Core.hpp>
#include <Kokkos_Random.hpp>

#include <KokkosBlas2_gemv.hpp>
#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Pbtrf.hpp"
#include "KokkosBatched_Pbtrs.hpp"
Expand Down Expand Up @@ -54,7 +54,7 @@ struct Functor_BatchedSerialPbtrf {
const std::string name_value_type = Test::value_type_name<value_type>();
std::string name = name_region + name_value_type;
Kokkos::RangePolicy<execution_space, ParamTagType> policy(0, _ab.extent(0));
Kokkos::parallel_for(name.c_str(), policy, *this, info_sum);
Kokkos::parallel_for(name.c_str(), policy, *this);
}
};

Expand Down Expand Up @@ -89,35 +89,34 @@ struct Functor_BatchedSerialPbtrs {
}
};

template <typename DeviceType, typename ScalarType, typename AViewType, typename BViewType, typename CViewType,
typename ArgTransA, typename ArgTransB>
struct Functor_BatchedSerialGemm {
template <typename DeviceType, typename ScalarType, typename AViewType, typename xViewType, typename yViewType>
struct Functor_BatchedSerialGemv {
using execution_space = typename DeviceType::execution_space;
AViewType _a;
BViewType _b;
CViewType _c;
xViewType _x;
yViewType _y;
ScalarType _alpha, _beta;

KOKKOS_INLINE_FUNCTION
Functor_BatchedSerialGemm(const ScalarType alpha, const AViewType &a, const BViewType &b, const ScalarType beta,
const CViewType &c)
: _a(a), _b(b), _c(c), _alpha(alpha), _beta(beta) {}
Functor_BatchedSerialGemv(const ScalarType alpha, const AViewType &a, const xViewType &x, const ScalarType beta,
const yViewType &y)
: _a(a), _x(x), _y(y), _alpha(alpha), _beta(beta) {}

KOKKOS_INLINE_FUNCTION
void operator()(const int k) const {
auto aa = Kokkos::subview(_a, k, Kokkos::ALL(), Kokkos::ALL());
auto bb = Kokkos::subview(_b, k, Kokkos::ALL(), Kokkos::ALL());
auto cc = Kokkos::subview(_c, k, Kokkos::ALL(), Kokkos::ALL());
auto xx = Kokkos::subview(_x, k, Kokkos::ALL());
auto yy = Kokkos::subview(_y, k, Kokkos::ALL());

KokkosBatched::SerialGemm<ArgTransA, ArgTransB, Algo::Gemm::Unblocked>::invoke(_alpha, aa, bb, _beta, cc);
KokkosBlas::SerialGemv<Trans::NoTranspose, Algo::Gemv::Unblocked>::invoke(_alpha, aa, xx, _beta, yy);
}

inline void run() {
using value_type = typename AViewType::non_const_value_type;
std::string name_region("KokkosBatched::Test::SerialPbtrf");
std::string name_region("KokkosBatched::Test::SerialPbtrs");
const std::string name_value_type = Test::value_type_name<value_type>();
std::string name = name_region + name_value_type;
Kokkos::RangePolicy<execution_space> policy(0, _a.extent(0));
Kokkos::RangePolicy<execution_space> policy(0, _x.extent(0));
Kokkos::parallel_for(name.c_str(), policy, *this);
}
};
Expand Down Expand Up @@ -190,9 +189,7 @@ void impl_test_batched_pbtrs_analytical(const int N) {
// Check x0 = x_ref
for (int ib = 0; ib < N; ib++) {
for (int i = 0; i < BlkSize; i++) {
for (int j = 0; j < BlkSize; j++) {
EXPECT_NEAR_KK(h_x0(ib, i, j), h_x_ref(ib, i, j), eps);
}
EXPECT_NEAR_KK(h_x0(ib, i), h_x_ref(ib, i), eps);
}
}
}
Expand Down Expand Up @@ -259,9 +256,7 @@ void impl_test_batched_pbtrs(const int N, const int k, const int BlkSize) {
// Check A * x0 = x_ref
for (int ib = 0; ib < N; ib++) {
for (int i = 0; i < BlkSize; i++) {
for (int j = 0; j < BlkSize; j++) {
EXPECT_NEAR_KK(h_y0(ib, i, j), h_x_ref(ib, i, j), eps);
}
EXPECT_NEAR_KK(h_y0(ib, i), h_x_ref(ib, i), eps);
}
}
}
Expand All @@ -274,12 +269,12 @@ int test_batched_pbtrs() {
#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT)
{
using LayoutType = Kokkos::LayoutLeft;
Test::pbtrs::impl_test_batched_pbtrs_analytical<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(1);
Test::Pbtrs::impl_test_batched_pbtrs_analytical<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(1);
Test::Pbtrs::impl_test_batched_pbtrs_analytical<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(2);
for (int i = 0; i < 10; i++) {
int k = 1;
Test::pbtrs::impl_test_batched_pbtrs<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(1, k, i);
Test::pbtrs::impl_test_batched_pbtrs<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(2, k, i);
Test::Pbtrs::impl_test_batched_pbtrs<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(1, k, i);
Test::Pbtrs::impl_test_batched_pbtrs<DeviceType, ScalarType, LayoutType, ParamTagType, AlgoTagType>(2, k, i);
}
}
#endif
Expand Down

0 comments on commit 99f223d

Please sign in to comment.