-
Notifications
You must be signed in to change notification settings - Fork 99
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1107 from brian-kelley/BatchedSVD
Batched serial SVD
- Loading branch information
Showing
6 changed files
with
907 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
#ifndef __KOKKOSBATCHED_SVD_DECL_HPP__ | ||
#define __KOKKOSBATCHED_SVD_DECL_HPP__ | ||
|
||
/// \author Brian Kelley (bmkelle@sandia.gov) | ||
|
||
#include "KokkosBatched_Util.hpp" | ||
#include "KokkosBatched_Vector.hpp" | ||
|
||
namespace KokkosBatched { | ||
|
||
/// Given a general matrix A (m x n), compute the full singular value decomposition (SVD): | ||
/// U * diag(s) * V^T = A. U/V are orthogonal and s contains nonnegative values in descending order. | ||
/// | ||
/// Currently only supports real-valued matrices. | ||
/// | ||
/// Parameters: | ||
/// [in] A | ||
/// General matrix (rank 2 view), m x n. | ||
/// The contents of A are overwritten and undefined after calling this function. | ||
/// [out] U | ||
/// m left singular vectors (in columns). Dimensions m*m. | ||
/// [out] Vt | ||
/// n right singular vectors (in rows). Dimensions n*n. | ||
/// [out] s | ||
/// min(m, n) singular values. | ||
/// [in] W | ||
/// 1D contiguous workspace. The required size is max(m, n). | ||
/// | ||
/// Preconditions: | ||
/// m == A.extent(0) == U.extent(0) == U.extent(1) | ||
/// n == A.extent(1) == V.extent(0) == V.extent(1) | ||
/// min(m, n) == s.extent(0) | ||
/// W.extent(0) >= max(m, n) | ||
/// W.stride(0) == 1 (contiguous) | ||
|
||
struct SVD_USV_Tag {}; | ||
struct SVD_S_Tag {}; | ||
// Note: Could easily add SV or US tags later if needed | ||
|
||
struct SerialSVD { | ||
//Version to compute full factorization: A == U * diag(s) * Vt | ||
template<typename AViewType, | ||
typename UViewType, | ||
typename VtViewType, | ||
typename SViewType, | ||
typename WViewType> | ||
KOKKOS_INLINE_FUNCTION | ||
static int | ||
invoke(SVD_USV_Tag, const AViewType &A, | ||
const UViewType &U, const SViewType &s, | ||
const VtViewType &Vt, const WViewType &W); | ||
|
||
//Version which computes only singular values | ||
template<typename AViewType, | ||
typename SViewType, | ||
typename WViewType> | ||
KOKKOS_INLINE_FUNCTION | ||
static int | ||
invoke(SVD_S_Tag, const AViewType &A, const SViewType &s, const WViewType &W); | ||
}; | ||
|
||
} /// end namespace KokkosBatched | ||
|
||
#include "KokkosBatched_SVD_Serial_Impl.hpp" | ||
|
||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
#ifndef __KOKKOSBATCHED_SVD_SERIAL_IMPL_HPP__ | ||
#define __KOKKOSBATCHED_SVD_SERIAL_IMPL_HPP__ | ||
|
||
/// \author Brian Kelley (bmkelle@sandia.gov) | ||
|
||
#include "KokkosBatched_SVD_Serial_Internal.hpp" | ||
|
||
namespace KokkosBatched { | ||
//Version which computes the full factorization | ||
template<typename AViewType, | ||
typename UViewType, | ||
typename VViewType, | ||
typename SViewType, | ||
typename WViewType> | ||
KOKKOS_INLINE_FUNCTION | ||
int SerialSVD:: | ||
invoke(SVD_USV_Tag, const AViewType &A, | ||
const UViewType &U, const SViewType &sigma, | ||
const VViewType &Vt, const WViewType &work) | ||
{ | ||
using value_type = typename AViewType::non_const_value_type; | ||
return KokkosBatched::SerialSVDInternal::invoke<value_type> | ||
(A.extent(0), A.extent(1), | ||
A.data(), A.stride(0), A.stride(1), | ||
U.data(), U.stride(0), U.stride(1), | ||
Vt.data(), Vt.stride(0), Vt.stride(1), | ||
sigma.data(), sigma.stride(0), | ||
work.data()); | ||
} | ||
|
||
//Version which computes only singular values | ||
template<typename AViewType, | ||
typename SViewType, | ||
typename WViewType> | ||
KOKKOS_INLINE_FUNCTION | ||
int SerialSVD:: | ||
invoke(SVD_S_Tag, const AViewType &A, const SViewType &sigma, const WViewType &work) | ||
{ | ||
using value_type = typename AViewType::non_const_value_type; | ||
return KokkosBatched::SerialSVDInternal::invoke<value_type> | ||
(A.extent(0), A.extent(1), | ||
A.data(), A.stride(0), A.stride(1), | ||
nullptr, 0, 0, | ||
nullptr, 0, 0, | ||
sigma.data(), sigma.stride(0), | ||
work.data()); | ||
} | ||
|
||
} /// end namespace KokkosBatched | ||
|
||
#endif |
Oops, something went wrong.