Skip to content

Commit

Permalink
Merge pull request #1107 from brian-kelley/BatchedSVD
Browse files Browse the repository at this point in the history
Batched serial SVD
  • Loading branch information
brian-kelley authored Sep 20, 2021
2 parents c8c0f21 + f270e32 commit 4faf97b
Show file tree
Hide file tree
Showing 6 changed files with 907 additions and 0 deletions.
66 changes: 66 additions & 0 deletions src/batched/dense/KokkosBatched_SVD_Decl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#ifndef __KOKKOSBATCHED_SVD_DECL_HPP__
#define __KOKKOSBATCHED_SVD_DECL_HPP__

/// \author Brian Kelley (bmkelle@sandia.gov)

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Vector.hpp"

namespace KokkosBatched {

/// Given a general matrix A (m x n), compute the full singular value decomposition (SVD):
/// U * diag(s) * V^T = A. U/V are orthogonal and s contains nonnegative values in descending order.
///
/// Currently only supports real-valued matrices.
///
/// Parameters:
/// [in] A
/// General matrix (rank 2 view), m x n.
/// The contents of A are overwritten and undefined after calling this function.
/// [out] U
/// m left singular vectors (in columns). Dimensions m*m.
/// [out] Vt
/// n right singular vectors (in rows). Dimensions n*n.
/// [out] s
/// min(m, n) singular values.
/// [in] W
/// 1D contiguous workspace. The required size is max(m, n).
///
/// Preconditions:
/// m == A.extent(0) == U.extent(0) == U.extent(1)
/// n == A.extent(1) == V.extent(0) == V.extent(1)
/// min(m, n) == s.extent(0)
/// W.extent(0) >= max(m, n)
/// W.stride(0) == 1 (contiguous)

struct SVD_USV_Tag {};
struct SVD_S_Tag {};
// Note: Could easily add SV or US tags later if needed

struct SerialSVD {
//Version to compute full factorization: A == U * diag(s) * Vt
template<typename AViewType,
typename UViewType,
typename VtViewType,
typename SViewType,
typename WViewType>
KOKKOS_INLINE_FUNCTION
static int
invoke(SVD_USV_Tag, const AViewType &A,
const UViewType &U, const SViewType &s,
const VtViewType &Vt, const WViewType &W);

//Version which computes only singular values
template<typename AViewType,
typename SViewType,
typename WViewType>
KOKKOS_INLINE_FUNCTION
static int
invoke(SVD_S_Tag, const AViewType &A, const SViewType &s, const WViewType &W);
};

} /// end namespace KokkosBatched

#include "KokkosBatched_SVD_Serial_Impl.hpp"

#endif
51 changes: 51 additions & 0 deletions src/batched/dense/impl/KokkosBatched_SVD_Serial_Impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
#ifndef __KOKKOSBATCHED_SVD_SERIAL_IMPL_HPP__
#define __KOKKOSBATCHED_SVD_SERIAL_IMPL_HPP__

/// \author Brian Kelley (bmkelle@sandia.gov)

#include "KokkosBatched_SVD_Serial_Internal.hpp"

namespace KokkosBatched {
//Version which computes the full factorization
template<typename AViewType,
typename UViewType,
typename VViewType,
typename SViewType,
typename WViewType>
KOKKOS_INLINE_FUNCTION
int SerialSVD::
invoke(SVD_USV_Tag, const AViewType &A,
const UViewType &U, const SViewType &sigma,
const VViewType &Vt, const WViewType &work)
{
using value_type = typename AViewType::non_const_value_type;
return KokkosBatched::SerialSVDInternal::invoke<value_type>
(A.extent(0), A.extent(1),
A.data(), A.stride(0), A.stride(1),
U.data(), U.stride(0), U.stride(1),
Vt.data(), Vt.stride(0), Vt.stride(1),
sigma.data(), sigma.stride(0),
work.data());
}

//Version which computes only singular values
template<typename AViewType,
typename SViewType,
typename WViewType>
KOKKOS_INLINE_FUNCTION
int SerialSVD::
invoke(SVD_S_Tag, const AViewType &A, const SViewType &sigma, const WViewType &work)
{
using value_type = typename AViewType::non_const_value_type;
return KokkosBatched::SerialSVDInternal::invoke<value_type>
(A.extent(0), A.extent(1),
A.data(), A.stride(0), A.stride(1),
nullptr, 0, 0,
nullptr, 0, 0,
sigma.data(), sigma.stride(0),
work.data());
}

} /// end namespace KokkosBatched

#endif
Loading

0 comments on commit 4faf97b

Please sign in to comment.