Skip to content

Commit

Permalink
Lapack - SVD: fixing some issues and adding MKL support
Browse files Browse the repository at this point in the history
Will probably need to retest the other TPLs, since a sign in the
test seemed to be wrong :(
  • Loading branch information
lucbv committed Jan 25, 2024
1 parent 712fc44 commit 7eb04dd
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 10 deletions.
5 changes: 3 additions & 2 deletions lapack/impl/KokkosLapack_svd_spec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,9 @@ template <class ExecutionSpace, class AMatrix, class SVector, class UMatrix,
bool eti_spec_avail = svd_eti_spec_avail<
ExecutionSpace, AMatrix, SVector, UMatrix, VMatrix>::value>
struct SVD {
static void svd(const ExecutionSpace &space, const AMatrix &A,
const SVector &S, const UMatrix &U, const VMatrix &Vt);
static void svd(const ExecutionSpace &space, const char jobu[],
const char jobvt[], const AMatrix &A, const SVector &S,
const UMatrix &U, const VMatrix &Vt);
};

#if !defined(KOKKOSKERNELS_ETI_ONLY) || KOKKOSKERNELS_IMPL_COMPILE_LIBRARY
Expand Down
6 changes: 3 additions & 3 deletions lapack/src/KokkosLapack_svd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,21 +121,21 @@ void svd(const ExecutionSpace& space, const char jobu[], const char jobvt[],
oss << "KokkosLapack::svd: both jobu and jobvt are invalid!\n"
<< "Possible values are A, S, O or N, submitted values are " << jobu[0]
<< " and " << jobvt[0] << "\n";
KokkosKernels::Impl::throw_runtime_exception(os.str());
KokkosKernels::Impl::throw_runtime_exception(oss.str());
}
if (is_jobu_invalid) {
std::ostringstream oss;
oss << "KokkosLapack::svd: jobu is invalid!\n"
<< "Possible values are A, S, O or N, submitted value is " << jobu[0]
<< "\n";
KokkosKernels::Impl::throw_runtime_exception(os.str());
KokkosKernels::Impl::throw_runtime_exception(oss.str());
}
if (is_jobvt_invalid) {
std::ostringstream oss;
oss << "KokkosLapack::svd: jobvt is invalid!\n"
<< "Possible values are A, S, O or N, submitted value is " << jobvt[0]
<< "\n";
KokkosKernels::Impl::throw_runtime_exception(os.str());
KokkosKernels::Impl::throw_runtime_exception(oss.str());
}

using AMatrix_Internal = Kokkos::View<
Expand Down
4 changes: 2 additions & 2 deletions lapack/tpls/KokkosLapack_svd_tpl_spec_avail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ struct svd_tpl_spec_avail {
};

// LAPACK
#ifdef KOKKOSKERNELS_ENABLE_TPL_LAPACK
#if defined(KOKKOSKERNELS_ENABLE_TPL_LAPACK) || defined(KOKKOSKERNELS_ENABLE_TPL_MKL)
#define KOKKOSLAPACK_SVD_TPL_SPEC_AVAIL_LAPACK(SCALAR, LAYOUT, EXECSPACE) \
template <> \
struct svd_tpl_spec_avail< \
Expand Down Expand Up @@ -80,7 +80,7 @@ KOKKOSLAPACK_SVD_TPL_SPEC_AVAIL_LAPACK(Kokkos::complex<double>,
Kokkos::LayoutLeft, Kokkos::Threads)
#endif

#endif
#endif // KOKKOSKERNELS_ENABLE_TPL_LAPACK || KOKKOSKERNELS_ENABLE_TPL_MKL

// CUSOLVER
#ifdef KOKKOSKERNELS_ENABLE_TPL_CUSOLVER
Expand Down
160 changes: 160 additions & 0 deletions lapack/tpls/KokkosLapack_svd_tpl_spec_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#ifndef KOKKOSLAPACK_SVD_TPL_SPEC_DECL_HPP_
#define KOKKOSLAPACK_SVD_TPL_SPEC_DECL_HPP_

#include "KokkosKernels_Error.hpp"
#include "Kokkos_ArithTraits.hpp"

namespace KokkosLapack {
Expand Down Expand Up @@ -197,6 +198,165 @@ KOKKOSLAPACK_SVD_LAPACK(Kokkos::complex<double>, Kokkos::LayoutLeft,
} // namespace KokkosLapack
#endif // KOKKOSKERNELS_ENABLE_TPL_LAPACK

#ifdef KOKKOSKERNELS_ENABLE_TPL_MKL
#include "mkl.h"

namespace KokkosLapack {
namespace Impl {

template <class ExecutionSpace, class AMatrix, class SVector, class UMatrix,
class VMatrix>
void mklSvdWrapper(const ExecutionSpace& /* space */, const char jobu[],
const char jobvt[], const AMatrix& A, const SVector& S,
const UMatrix& U, const VMatrix& Vt) {
using memory_space = typename AMatrix::memory_space;
using Scalar = typename AMatrix::non_const_value_type;
using Magnitude = typename SVector::non_const_value_type;
using ALayout_t = typename AMatrix::array_layout;
using ULayout_t = typename UMatrix::array_layout;
using VLayout_t = typename VMatrix::array_layout;

const lapack_int m = A.extent_int(0);
const lapack_int n = A.extent_int(1);
const lapack_int lda = std::is_same_v<ALayout_t, Kokkos::LayoutRight> ? A.stride(0)
: A.stride(1);
const lapack_int ldu = std::is_same_v<ULayout_t, Kokkos::LayoutRight> ? U.stride(0)
: U.stride(1);
const lapack_int ldvt = std::is_same_v<VLayout_t, Kokkos::LayoutRight>
? Vt.stride(0)
: Vt.stride(1);

Kokkos::View<Magnitude*, memory_space> rwork("svd rwork buffer",
Kokkos::min(m, n) - 1);
lapack_int ret = 0;
if constexpr (std::is_same_v<Scalar, float>) {
ret = LAPACKE_sgesvd(LAPACK_COL_MAJOR, jobu[0], jobvt[0], m, n,
A.data(), lda, S.data(), U.data(), ldu,
Vt.data(), ldvt, rwork.data());
}
if constexpr (std::is_same_v<Scalar, double>) {
ret = LAPACKE_dgesvd(LAPACK_COL_MAJOR, jobu[0], jobvt[0], m, n,
A.data(), lda, S.data(), U.data(), ldu,
Vt.data(), ldvt, rwork.data());
}
if constexpr (std::is_same_v<Scalar, Kokkos::complex<float>>) {
ret = LAPACKE_cgesvd(LAPACK_COL_MAJOR, jobu[0], jobvt[0], m, n,
reinterpret_cast<lapack_complex_float*>(A.data()),
lda, S.data(),
reinterpret_cast<lapack_complex_float*>(U.data()),
ldu,
reinterpret_cast<lapack_complex_float*>(Vt.data()),
ldvt, rwork.data());
}
if constexpr (std::is_same_v<Scalar, Kokkos::complex<double>>) {
ret = LAPACKE_zgesvd(LAPACK_COL_MAJOR, jobu[0], jobvt[0], m, n,
reinterpret_cast<lapack_complex_double*>(A.data()),
lda, S.data(),
reinterpret_cast<lapack_complex_double*>(U.data()),
ldu,
reinterpret_cast<lapack_complex_double*>(Vt.data()),
ldvt, rwork.data());
}

if(ret != 0) {
std::ostringstream os;
os << "KokkosLapack::svd: MKL failed with return value: " << ret << "\n";
KokkosKernels::Impl::throw_runtime_exception(os.str());
}
}

// KOKKOSLAPACK_SVD_MKL(SCALAR, LAYOUT, EXEC_SPACE)
//
// Defines the full specialization of SVD<> that routes KokkosLapack::svd to
// MKL (through mklSvdWrapper) for unmanaged HostSpace views of the given
// scalar type and layout on the given execution space. A, U and Vt are rank-2
// views of SCALAR; S is a rank-1 view of the matching magnitude type.
//
// The profiling region is labelled TPL_MKL (it used to read TPL_LAPACK, which
// mis-attributed the time to the generic LAPACK TPL), and it is now popped
// even when mklSvdWrapper reports a failure by throwing.
#define KOKKOSLAPACK_SVD_MKL(SCALAR, LAYOUT, EXEC_SPACE)                       \
  template <>                                                                  \
  struct SVD<                                                                  \
      EXEC_SPACE,                                                              \
      Kokkos::View<SCALAR**, LAYOUT,                                           \
                   Kokkos::Device<EXEC_SPACE, Kokkos::HostSpace>,              \
                   Kokkos::MemoryTraits<Kokkos::Unmanaged>>,                   \
      Kokkos::View<Kokkos::ArithTraits<SCALAR>::mag_type*, LAYOUT,             \
                   Kokkos::Device<EXEC_SPACE, Kokkos::HostSpace>,              \
                   Kokkos::MemoryTraits<Kokkos::Unmanaged>>,                   \
      Kokkos::View<SCALAR**, LAYOUT,                                           \
                   Kokkos::Device<EXEC_SPACE, Kokkos::HostSpace>,              \
                   Kokkos::MemoryTraits<Kokkos::Unmanaged>>,                   \
      Kokkos::View<SCALAR**, LAYOUT,                                           \
                   Kokkos::Device<EXEC_SPACE, Kokkos::HostSpace>,              \
                   Kokkos::MemoryTraits<Kokkos::Unmanaged>>,                   \
      true,                                                                    \
      svd_eti_spec_avail<                                                      \
          EXEC_SPACE,                                                          \
          Kokkos::View<SCALAR**, LAYOUT,                                       \
                       Kokkos::Device<EXEC_SPACE, Kokkos::HostSpace>,          \
                       Kokkos::MemoryTraits<Kokkos::Unmanaged>>,               \
          Kokkos::View<Kokkos::ArithTraits<SCALAR>::mag_type*, LAYOUT,         \
                       Kokkos::Device<EXEC_SPACE, Kokkos::HostSpace>,          \
                       Kokkos::MemoryTraits<Kokkos::Unmanaged>>,               \
          Kokkos::View<SCALAR**, LAYOUT,                                       \
                       Kokkos::Device<EXEC_SPACE, Kokkos::HostSpace>,          \
                       Kokkos::MemoryTraits<Kokkos::Unmanaged>>,               \
          Kokkos::View<SCALAR**, LAYOUT,                                       \
                       Kokkos::Device<EXEC_SPACE, Kokkos::HostSpace>,          \
                       Kokkos::MemoryTraits<Kokkos::Unmanaged>>>::value> {     \
    using AMatrix =                                                            \
        Kokkos::View<SCALAR**, LAYOUT,                                         \
                     Kokkos::Device<EXEC_SPACE, Kokkos::HostSpace>,            \
                     Kokkos::MemoryTraits<Kokkos::Unmanaged>>;                 \
    using SVector =                                                            \
        Kokkos::View<Kokkos::ArithTraits<SCALAR>::mag_type*, LAYOUT,           \
                     Kokkos::Device<EXEC_SPACE, Kokkos::HostSpace>,            \
                     Kokkos::MemoryTraits<Kokkos::Unmanaged>>;                 \
    using UMatrix =                                                            \
        Kokkos::View<SCALAR**, LAYOUT,                                         \
                     Kokkos::Device<EXEC_SPACE, Kokkos::HostSpace>,            \
                     Kokkos::MemoryTraits<Kokkos::Unmanaged>>;                 \
    using VMatrix =                                                            \
        Kokkos::View<SCALAR**, LAYOUT,                                         \
                     Kokkos::Device<EXEC_SPACE, Kokkos::HostSpace>,            \
                     Kokkos::MemoryTraits<Kokkos::Unmanaged>>;                 \
                                                                               \
    static void svd(const EXEC_SPACE& space, const char jobu[],                \
                    const char jobvt[], const AMatrix& A, const SVector& S,    \
                    const UMatrix& U, const VMatrix& Vt) {                     \
      Kokkos::Profiling::pushRegion("KokkosLapack::svd[TPL_MKL," #SCALAR       \
                                    "]");                                      \
      svd_print_specialization<EXEC_SPACE, AMatrix, SVector, UMatrix,          \
                               VMatrix>();                                     \
                                                                               \
      try {                                                                    \
        mklSvdWrapper(space, jobu, jobvt, A, S, U, Vt);                        \
      } catch (...) {                                                          \
        /* keep push/pop balanced when the wrapper reports a failure */        \
        Kokkos::Profiling::popRegion();                                        \
        throw;                                                                 \
      }                                                                        \
      Kokkos::Profiling::popRegion();                                          \
    }                                                                          \
  };

// Emit the MKL-backed SVD specializations for each host execution space that
// is enabled in this build. Every space gets the same four scalar types
// (float, double and their Kokkos::complex counterparts), LayoutLeft only.

// Serial backend
#ifdef KOKKOS_ENABLE_SERIAL
KOKKOSLAPACK_SVD_MKL(float, Kokkos::LayoutLeft, Kokkos::Serial)
KOKKOSLAPACK_SVD_MKL(double, Kokkos::LayoutLeft, Kokkos::Serial)
KOKKOSLAPACK_SVD_MKL(Kokkos::complex<float>, Kokkos::LayoutLeft, Kokkos::Serial)
KOKKOSLAPACK_SVD_MKL(Kokkos::complex<double>, Kokkos::LayoutLeft, Kokkos::Serial)
#endif  // KOKKOS_ENABLE_SERIAL

// OpenMP backend
#ifdef KOKKOS_ENABLE_OPENMP
KOKKOSLAPACK_SVD_MKL(float, Kokkos::LayoutLeft, Kokkos::OpenMP)
KOKKOSLAPACK_SVD_MKL(double, Kokkos::LayoutLeft, Kokkos::OpenMP)
KOKKOSLAPACK_SVD_MKL(Kokkos::complex<float>, Kokkos::LayoutLeft, Kokkos::OpenMP)
KOKKOSLAPACK_SVD_MKL(Kokkos::complex<double>, Kokkos::LayoutLeft, Kokkos::OpenMP)
#endif  // KOKKOS_ENABLE_OPENMP

// Threads backend
#ifdef KOKKOS_ENABLE_THREADS
KOKKOSLAPACK_SVD_MKL(float, Kokkos::LayoutLeft, Kokkos::Threads)
KOKKOSLAPACK_SVD_MKL(double, Kokkos::LayoutLeft, Kokkos::Threads)
KOKKOSLAPACK_SVD_MKL(Kokkos::complex<float>, Kokkos::LayoutLeft, Kokkos::Threads)
KOKKOSLAPACK_SVD_MKL(Kokkos::complex<double>, Kokkos::LayoutLeft, Kokkos::Threads)
#endif  // KOKKOS_ENABLE_THREADS

} // namespace Impl
} // namespace KokkosLapack
#endif // KOKKOSKERNELS_ENABLE_TPL_MKL

// CUSOLVER
#ifdef KOKKOSKERNELS_ENABLE_TPL_CUSOLVER
#include "KokkosLapack_cusolver.hpp"
Expand Down
6 changes: 3 additions & 3 deletions lapack/unit_test/Test_Lapack_svd.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -150,9 +150,9 @@ int impl_analytic_svd() {
static_cast<scalar_type>(-3 / Kokkos::sqrt(10)),
static_cast<scalar_type>(1 / Kokkos::sqrt(10))};
std::vector<scalar_type> Vtref = {
static_cast<scalar_type>(1 / Kokkos::sqrt(2)),
static_cast<scalar_type>(1 / Kokkos::sqrt(2)),
static_cast<scalar_type>(-1 / Kokkos::sqrt(2)),
static_cast<scalar_type>(1 / Kokkos::sqrt(2)),
static_cast<scalar_type>(1 / Kokkos::sqrt(2))};

// Both rotations and reflections are valid
Expand All @@ -170,7 +170,7 @@ int impl_analytic_svd() {
sign_vec = (Vt_h(0, 0).real() > KAT_S::zero().real()) ? mag_one : -mag_one;
EXPECT_NEAR_KK_REL(sign_vec * Vt_h(0, 0), Vtref[0], 100 * eps);
EXPECT_NEAR_KK_REL(sign_vec * Vt_h(1, 0), Vtref[1], 100 * eps);
sign_vec = (Vt_h(0, 1).real() < KAT_S::zero().real()) ? mag_one : -mag_one;
sign_vec = (Vt_h(0, 1).real() > KAT_S::zero().real()) ? mag_one : -mag_one;
EXPECT_NEAR_KK_REL(sign_vec * Vt_h(0, 1), Vtref[2], 100 * eps);
EXPECT_NEAR_KK_REL(sign_vec * Vt_h(1, 1), Vtref[3], 100 * eps);
} else {
Expand All @@ -184,7 +184,7 @@ int impl_analytic_svd() {
sign_vec = (Vt_h(0, 0) > KAT_S::zero()) ? mag_one : -mag_one;
EXPECT_NEAR_KK_REL(sign_vec * Vt_h(0, 0), Vtref[0], 100 * eps);
EXPECT_NEAR_KK_REL(sign_vec * Vt_h(1, 0), Vtref[1], 100 * eps);
sign_vec = (Vt_h(0, 1) < KAT_S::zero()) ? mag_one : -mag_one;
sign_vec = (Vt_h(0, 1) > KAT_S::zero()) ? mag_one : -mag_one;
EXPECT_NEAR_KK_REL(sign_vec * Vt_h(0, 1), Vtref[2], 100 * eps);
EXPECT_NEAR_KK_REL(sign_vec * Vt_h(1, 1), Vtref[3], 100 * eps);
}
Expand Down

0 comments on commit 7eb04dd

Please sign in to comment.