Skip to content

Commit

Permalink
Lapack - SVD: adding more unit-test and checks for TPL
Browse files Browse the repository at this point in the history
With the improvements made the unit-tests should now pass all CI
configurations. Next step will be to add the singular vectors
computation mode flags and some more testing.
  • Loading branch information
lucbv committed Jan 23, 2024
1 parent 2627498 commit ddd4cf3
Showing 1 changed file with 146 additions and 26 deletions.
172 changes: 146 additions & 26 deletions lapack/unit_test/Test_Lapack_svd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,81 @@
#include <gtest/gtest.h>
#include <Kokkos_Core.hpp>
#include <KokkosKernels_TestUtils.hpp>
#include <KokkosBlas3_gemm.hpp>

#include <KokkosLapack_svd.hpp>

#if defined(KOKKOSKERNELS_ENABLE_TPL_CUSOLVER)
#include "KokkosLapack_cusolver.hpp"
#endif

namespace Test {

template <class AMatrix, class SVector, class UMatrix, class VMatrix>
void check_triple_product(const AMatrix& A, const SVector& S, const UMatrix& U, const VMatrix& Vt) {
// After a successful SVD decomposition we have A=U*S*V
// So using gemm we should be able to compare the above
// triple product to the original matrix A.

using scalar_type = typename AMatrix::non_const_value_type;
using mag_type = typename Kokkos::ArithTraits<scalar_type>::mag_type;

AMatrix temp("intermediate U*S product", A.extent(0), A.extent(1));
AMatrix M("U*S*V product", A.extent(0), A.extent(1));

// First compute the left side of the product: temp = U*S
KokkosBlas::gemm("N", "N", 1, U, S, 0, temp);

// Second compute the right side of the product: M = temp*V = U*S*V
KokkosBlas::gemm("N", "C", 1, temp, Vt, 0, M);

typename AMatrix::HostMirror A_h = Kokkos::create_mirror_view(A);
typename AMatrix::HostMirror M_h = Kokkos::create_mirror_view(M);
Kokkos::deep_copy(A_h, A);
Kokkos::deep_copy(M_h, M);
constexpr mag_type tol = 100*Kokkos::ArithTraits<scalar_type>::eps();
for(int rowIdx = 0; rowIdx < A.extent_int(0); ++rowIdx) {
for(int colIdx = 0; colIdx < A.extent_int(1); ++colIdx) {
EXPECT_NEAR_KK_REL(A_h(rowIdx, colIdx), M_h(rowIdx, colIdx), tol);
}
}
}

template <class Matrix>
void check_unitary_orthogonal_matrix(const Matrix& M) {
// After a successful SVD decomposition the matrices
// U and V are unitary matrices. Thus we can check
// the property UUt=UtU=I and VVt=VtV=I using gemm.

using scalar_type = typename Matrix::non_const_value_type;
using mag_type = typename Kokkos::ArithTraits<scalar_type>::mag_type;
constexpr mag_type tol = 100*Kokkos::ArithTraits<scalar_type>::eps();

Matrix I0("M*Mt", M.extent(0), M.extent(0));
KokkosBlas::gemm("N", "C", 1, M, M, 0, I0);
typename Matrix::HostMirror I0_h = Kokkos::create_mirror_view(I0);
Kokkos::deep_copy(I0_h, I0);
for(int rowIdx = 0; rowIdx < M.extent_int(0); ++rowIdx) {
for(int colIdx = 0; colIdx < M.extent_int(0); ++colIdx) {
if(rowIdx == colIdx) {
EXPECT_NEAR_KK_REL(I0_h(rowIdx, colIdx), Kokkos::ArithTraits<scalar_type>::one(), tol);
} else {
EXPECT_NEAR_KK_REL(I0_h(rowIdx, colIdx), Kokkos::ArithTraits<scalar_type>::zero(), tol);
}
}
}

Matrix I1("Mt*M", M.extent(1), M.extent(1));
KokkosBlas::gemm("C", "N", 1, M, M, 0, I1);
typename Matrix::HostMirror I1_h = Kokkos::create_mirror_view(I1);
Kokkos::deep_copy(I1_h, I1);
for(int rowIdx = 0; rowIdx < M.extent_int(1); ++rowIdx) {
for(int colIdx = 0; colIdx < M.extent_int(1); ++colIdx) {
if(rowIdx == colIdx) {
EXPECT_NEAR_KK_REL(I1_h(rowIdx, colIdx), Kokkos::ArithTraits<scalar_type>::one(), tol);
} else {
EXPECT_NEAR_KK_REL(I1_h(rowIdx, colIdx), Kokkos::ArithTraits<scalar_type>::zero(), tol);
}
}
}
}

template <class AMatrix, class Device>
int impl_analytic_svd() {
using scalar_type = typename AMatrix::value_type;
Expand All @@ -34,20 +100,27 @@ int impl_analytic_svd() {
using KAT_S = Kokkos::ArithTraits<scalar_type>;
using KAT_M = Kokkos::ArithTraits<mag_type>;

const mag_type mag_zero = KAT_M::zero();
const mag_type mag_one = KAT_M::one();
const mag_type eps = KAT_S::eps();

AMatrix A("A", 2, 2), U("U", 2, 2), Vt("Vt", 2, 2);
vector_type S("S", 2);

typename AMatrix::HostMirror A_h = Kokkos::create_mirror_view(A);

// A = [3 0]
// [4 5]
// USV = 1/sqrt(10) [1 -3] * sqrt(5) [3 0] * 1/sqrt(2) [ 1 1]
// [3 1] [0 1] [-1 1]
A_h(0, 0) = 3;
A_h(1, 0) = 4;
A_h(1, 1) = 5;

Kokkos::deep_copy(A, A_h);

#if defined(KOKKOSKERNELS_ENABLE_TPL_LAPACK) || defined(KOKKOSKERNELS_ENABLE_TPL_CUSOLVER) || defined(KOKKOSKERNELS_ENABLE_TPL_ROCSOLVER)

KokkosLapack::svd(A, S, U, Vt);
// Don't really need to fence here as we deep_copy right after...

typename vector_type::HostMirror S_h = Kokkos::create_mirror_view(S);
Kokkos::deep_copy(S_h, S);
Expand All @@ -58,8 +131,8 @@ int impl_analytic_svd() {

// The singular values for this problem
// are known: sqrt(45) and sqrt(5)
EXPECT_NEAR_KK_REL(S_h(0), static_cast<mag_type>(Kokkos::sqrt(45)), 100*KAT_S::eps());
EXPECT_NEAR_KK_REL(S_h(1), static_cast<mag_type>(Kokkos::sqrt( 5)), 100*KAT_S::eps());
EXPECT_NEAR_KK_REL(S_h(0), static_cast<mag_type>(Kokkos::sqrt(45)), 100*eps);
EXPECT_NEAR_KK_REL(S_h(1), static_cast<mag_type>(Kokkos::sqrt( 5)), 100*eps);

// The singular vectors should be identical
// or of oposite sign we check the first
Expand All @@ -68,20 +141,38 @@ int impl_analytic_svd() {
std::vector<scalar_type> Uref = {static_cast<scalar_type>(1 / Kokkos::sqrt(10)), static_cast<scalar_type>(3 / Kokkos::sqrt(10)), static_cast<scalar_type>(-3 / Kokkos::sqrt(10)), static_cast<scalar_type>(1 / Kokkos::sqrt(10))};
std::vector<scalar_type> Vtref = {static_cast<scalar_type>(1 / Kokkos::sqrt(2)), static_cast<scalar_type>(1 / Kokkos::sqrt(2)), static_cast<scalar_type>(-1 / Kokkos::sqrt(2)), static_cast<scalar_type>(1 / Kokkos::sqrt(2))};

mag_type sign_vec = (U_h(0, 0) > KAT_S::zero()) ? KAT_M::one() : -KAT_M::one();
EXPECT_NEAR_KK_REL(sign_vec*U_h(0, 0), Uref[0], 100*KAT_S::eps());
EXPECT_NEAR_KK_REL(sign_vec*U_h(1, 0), Uref[1], 100*KAT_S::eps());
sign_vec = (U_h(0, 1) < KAT_S::zero()) ? KAT_M::one() : -KAT_M::one();
EXPECT_NEAR_KK_REL(sign_vec*U_h(0, 1), Uref[2], 100*KAT_S::eps());
EXPECT_NEAR_KK_REL(sign_vec*U_h(1, 1), Uref[3], 100*KAT_S::eps());

sign_vec = (Vt_h(0, 0) > KAT_S::zero()) ? KAT_M::one() : -KAT_M::one();
EXPECT_NEAR_KK_REL(sign_vec*Vt_h(0, 0), Vtref[0], 100*KAT_S::eps());
EXPECT_NEAR_KK_REL(sign_vec*Vt_h(1, 0), Vtref[1], 100*KAT_S::eps());
sign_vec = (Vt_h(0, 1) < KAT_S::zero()) ? KAT_M::one() : -KAT_M::one();
EXPECT_NEAR_KK_REL(sign_vec*Vt_h(0, 1), Vtref[2], 100*KAT_S::eps());
EXPECT_NEAR_KK_REL(sign_vec*Vt_h(1, 1), Vtref[3], 100*KAT_S::eps());
#endif
// Both rotations and reflections are valid
// vector basis so we need to check both signs
// to confirm proper SVD was achieved.
if constexpr(KAT_S::is_complex) {
mag_type sign_vec = (U_h(0, 0).real() > KAT_S::zero().real()) ? mag_one : -mag_one;
EXPECT_NEAR_KK_REL(sign_vec*U_h(0, 0), Uref[0], 100*eps);
EXPECT_NEAR_KK_REL(sign_vec*U_h(1, 0), Uref[1], 100*eps);
sign_vec = (U_h(0, 1).real() < KAT_S::zero().real()) ? mag_one : -mag_one;
EXPECT_NEAR_KK_REL(sign_vec*U_h(0, 1), Uref[2], 100*eps);
EXPECT_NEAR_KK_REL(sign_vec*U_h(1, 1), Uref[3], 100*eps);

sign_vec = (Vt_h(0, 0).real() > KAT_S::zero().real()) ? mag_one : -mag_one;
EXPECT_NEAR_KK_REL(sign_vec*Vt_h(0, 0), Vtref[0], 100*eps);
EXPECT_NEAR_KK_REL(sign_vec*Vt_h(1, 0), Vtref[1], 100*eps);
sign_vec = (Vt_h(0, 1).real() < KAT_S::zero().real()) ? mag_one : -mag_one;
EXPECT_NEAR_KK_REL(sign_vec*Vt_h(0, 1), Vtref[2], 100*eps);
EXPECT_NEAR_KK_REL(sign_vec*Vt_h(1, 1), Vtref[3], 100*eps);
} else {
mag_type sign_vec = (U_h(0, 0) > KAT_S::zero()) ? mag_one : -mag_one;
EXPECT_NEAR_KK_REL(sign_vec*U_h(0, 0), Uref[0], 100*eps);
EXPECT_NEAR_KK_REL(sign_vec*U_h(1, 0), Uref[1], 100*eps);
sign_vec = (U_h(0, 1) < KAT_S::zero()) ? mag_one : -mag_one;
EXPECT_NEAR_KK_REL(sign_vec*U_h(0, 1), Uref[2], 100*eps);
EXPECT_NEAR_KK_REL(sign_vec*U_h(1, 1), Uref[3], 100*eps);

sign_vec = (Vt_h(0, 0) > KAT_S::zero()) ? mag_one : -mag_one;
EXPECT_NEAR_KK_REL(sign_vec*Vt_h(0, 0), Vtref[0], 100*eps);
EXPECT_NEAR_KK_REL(sign_vec*Vt_h(1, 0), Vtref[1], 100*eps);
sign_vec = (Vt_h(0, 1) < KAT_S::zero()) ? mag_one : -mag_one;
EXPECT_NEAR_KK_REL(sign_vec*Vt_h(0, 1), Vtref[2], 100*eps);
EXPECT_NEAR_KK_REL(sign_vec*Vt_h(1, 1), Vtref[3], 100*eps);
}

return 0;
}
Expand Down Expand Up @@ -144,12 +235,41 @@ int test_svd() {
return 1;
}

template <class Scalar, class Device>
int test_wrapper() {

#if defined(KOKKOSKERNELS_ENABLE_TPL_LAPACK) || defined(KOKKOSKERNELS_ENABLE_TPL_MKL)
if constexpr(std::is_same_v<typename Device::memory_space, Kokkos::HostSpace>) {
// Using a device side space with LAPACK/MKL
return test_svd<Scalar, Device>();
}
#endif

#if defined(KOKKOSKERNELS_ENABLE_TPL_CUSOLVER)
if constexpr(std::is_same_v<typename Device::execution_space, Kokkos::Cuda>) {
// Using a Cuda device with CUSOLVER
return test_svd<Scalar, Device>();
}
#endif

#if defined(KOKKOSKERNELS_ENABLE_TPL_ROCSOLVER)
if constexpr(std::is_same_v<typename Device::execution_space, Kokkos::HIP>) {
// Using a HIP device with ROCSOLVER
return test_svd<Scalar, Device>();
}
#endif

std::cout << "No TPL support enabled, svd is not tested" << std::endl;
return 0;

}

#if defined(KOKKOSKERNELS_INST_FLOAT) || \
(!defined(KOKKOSKERNELS_ETI_ONLY) && \
!defined(KOKKOSKERNELS_IMPL_CHECK_ETI_CALLS))
TEST_F(TestCategory, svd_float) {
Kokkos::Profiling::pushRegion("KokkosLapack::Test::svd_float");
test_svd<float, TestDevice>();
test_wrapper<float, TestDevice>();
Kokkos::Profiling::popRegion();
}
#endif
Expand All @@ -159,7 +279,7 @@ TEST_F(TestCategory, svd_float) {
!defined(KOKKOSKERNELS_IMPL_CHECK_ETI_CALLS))
TEST_F(TestCategory, svd_double) {
Kokkos::Profiling::pushRegion("KokkosLapack::Test::svd_double");
test_svd<double, TestDevice>();
test_wrapper<double, TestDevice>();
Kokkos::Profiling::popRegion();
}
#endif
Expand All @@ -169,7 +289,7 @@ TEST_F(TestCategory, svd_double) {
!defined(KOKKOSKERNELS_IMPL_CHECK_ETI_CALLS))
TEST_F(TestCategory, svd_complex_float) {
Kokkos::Profiling::pushRegion("KokkosLapack::Test::svd_complex_float");
test_svd<Kokkos::complex<float>, TestDevice>();
test_wrapper<Kokkos::complex<float>, TestDevice>();
Kokkos::Profiling::popRegion();
}
#endif
Expand All @@ -179,7 +299,7 @@ TEST_F(TestCategory, svd_complex_float) {
!defined(KOKKOSKERNELS_IMPL_CHECK_ETI_CALLS))
TEST_F(TestCategory, svd_complex_double) {
Kokkos::Profiling::pushRegion("KokkosLapack::Test::svd_complex_double");
test_svd<Kokkos::complex<double>, TestDevice>();
test_wrapper<Kokkos::complex<double>, TestDevice>();
Kokkos::Profiling::popRegion();
}
#endif

0 comments on commit ddd4cf3

Please sign in to comment.