Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Creation of the 'lapack' subdirectory, parallel to 'blas' #1985

Merged
merged 22 commits into from
Oct 26, 2023
Merged
Show file tree
Hide file tree
Changes from 20 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ IF (KokkosKernels_INSTALL_TESTING)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(batched/dense/unit_test)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(batched/sparse/unit_test)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(blas/unit_test)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(lapack/unit_test)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(graph/unit_test)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(sparse/unit_test)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(ode/unit_test)
Expand Down Expand Up @@ -192,7 +193,7 @@ ELSE()
"ALL"
STRING
"A list of components to enable in testing and building"
VALID_ENTRIES BATCHED BLAS GRAPH SPARSE ALL
VALID_ENTRIES BATCHED BLAS LAPACK GRAPH SPARSE ALL
)

# ==================================================================
Expand Down Expand Up @@ -243,6 +244,7 @@ ELSE()
MESSAGE(" COMMON: ON")
MESSAGE(" BATCHED: ${KokkosKernels_ENABLE_COMPONENT_BATCHED}")
MESSAGE(" BLAS: ${KokkosKernels_ENABLE_COMPONENT_BLAS}")
MESSAGE(" LAPACK: ${KokkosKernels_ENABLE_COMPONENT_LAPACK}")
MESSAGE(" GRAPH: ${KokkosKernels_ENABLE_COMPONENT_GRAPH}")
MESSAGE(" SPARSE: ${KokkosKernels_ENABLE_COMPONENT_SPARSE}")
MESSAGE(" ODE: ${KokkosKernels_ENABLE_COMPONENT_ODE}")
Expand Down Expand Up @@ -287,6 +289,9 @@ ELSE()
IF (KokkosKernels_ENABLE_COMPONENT_BLAS)
INCLUDE(blas/CMakeLists.txt)
ENDIF()
IF (KokkosKernels_ENABLE_COMPONENT_LAPACK)
INCLUDE(lapack/CMakeLists.txt)
ENDIF()
IF (KokkosKernels_ENABLE_COMPONENT_GRAPH)
INCLUDE(graph/CMakeLists.txt)
ENDIF()
Expand Down Expand Up @@ -405,6 +410,9 @@ ELSE()
IF (KokkosKernels_ENABLE_COMPONENT_BLAS)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(blas/unit_test)
ENDIF()
IF (KokkosKernels_ENABLE_COMPONENT_LAPACK)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(lapack/unit_test)
ENDIF()
IF (KokkosKernels_ENABLE_COMPONENT_GRAPH)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(graph/unit_test)
ENDIF()
Expand Down
14 changes: 0 additions & 14 deletions blas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,6 @@ KOKKOSKERNELS_GENERATE_ETI(Blas1_dot_mv dot
TYPE_LISTS FLOATS LAYOUTS DEVICES
)

KOKKOSKERNELS_GENERATE_ETI(Blas_gesv gesv
COMPONENTS blas
HEADER_LIST ETI_HEADERS
SOURCE_LIST SOURCES
TYPE_LISTS FLOATS LAYOUTS DEVICES
)

KOKKOSKERNELS_GENERATE_ETI(Blas1_axpby axpby
COMPONENTS blas
HEADER_LIST ETI_HEADERS
Expand Down Expand Up @@ -331,10 +324,3 @@ KOKKOSKERNELS_GENERATE_ETI(Blas3_trmm trmm
SOURCE_LIST SOURCES
TYPE_LISTS FLOATS LAYOUTS DEVICES
)

KOKKOSKERNELS_GENERATE_ETI(Blas_trtri trtri
COMPONENTS blas
HEADER_LIST ETI_HEADERS
SOURCE_LIST SOURCES
TYPE_LISTS FLOATS LAYOUTS DEVICES
)
101 changes: 3 additions & 98 deletions blas/src/KokkosBlas_gesv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@
#ifndef KOKKOSBLAS_GESV_HPP_
#define KOKKOSBLAS_GESV_HPP_

#include <type_traits>

#include "KokkosBlas_gesv_spec.hpp"
#include "KokkosKernels_Error.hpp"
#include "KokkosLapack_gesv.hpp"

namespace KokkosBlas {

Expand All @@ -49,100 +46,8 @@ namespace KokkosBlas {
/// its data pointer is NULL, pivoting is not used.
///
template <class AMatrix, class BXMV, class IPIVV>
void gesv(const AMatrix& A, const BXMV& B, const IPIVV& IPIV) {
// NOTE: Currently, KokkosBlas::gesv only supports for MAGMA TPL and BLAS TPL.
// MAGMA TPL should be enabled to call the MAGMA GPU interface for
// device views BLAS TPL should be enabled to call the BLAS interface
// for host views

static_assert(Kokkos::is_view<AMatrix>::value,
"KokkosBlas::gesv: A must be a Kokkos::View.");
static_assert(Kokkos::is_view<BXMV>::value,
"KokkosBlas::gesv: B must be a Kokkos::View.");
static_assert(Kokkos::is_view<IPIVV>::value,
"KokkosBlas::gesv: IPIV must be a Kokkos::View.");
static_assert(static_cast<int>(AMatrix::rank) == 2,
"KokkosBlas::gesv: A must have rank 2.");
static_assert(
static_cast<int>(BXMV::rank) == 1 || static_cast<int>(BXMV::rank) == 2,
"KokkosBlas::gesv: B must have either rank 1 or rank 2.");
static_assert(static_cast<int>(IPIVV::rank) == 1,
"KokkosBlas::gesv: IPIV must have rank 1.");

int64_t IPIV0 = IPIV.extent(0);
int64_t A0 = A.extent(0);
int64_t A1 = A.extent(1);
int64_t B0 = B.extent(0);

// Check validity of pivot argument
bool valid_pivot =
(IPIV0 == A1) || ((IPIV0 == 0) && (IPIV.data() == nullptr));
if (!(valid_pivot)) {
std::ostringstream os;
os << "KokkosBlas::gesv: IPIV: " << IPIV0 << ". "
<< "Valid options include zero-extent 1-D view (no pivoting), or 1-D "
"View with size of "
<< A0 << " (partial pivoting).";
KokkosKernels::Impl::throw_runtime_exception(os.str());
}

// Check for no pivoting case. Only MAGMA supports no pivoting interface
#ifdef KOKKOSKERNELS_ENABLE_TPL_MAGMA // have MAGMA TPL
#ifdef KOKKOSKERNELS_ENABLE_TPL_BLAS // and have BLAS TPL
if ((!std::is_same<typename AMatrix::device_type::memory_space,
Kokkos::CudaSpace>::value) &&
(IPIV0 == 0) && (IPIV.data() == nullptr)) {
std::ostringstream os;
os << "KokkosBlas::gesv: IPIV: " << IPIV0 << ". "
<< "BLAS TPL does not support no pivoting.";
KokkosKernels::Impl::throw_runtime_exception(os.str());
}
#endif
#else // not have MAGMA TPL
#ifdef KOKKOSKERNELS_ENABLE_TPL_BLAS // but have BLAS TPL
if ((IPIV0 == 0) && (IPIV.data() == nullptr)) {
std::ostringstream os;
os << "KokkosBlas::gesv: IPIV: " << IPIV0 << ". "
<< "BLAS TPL does not support no pivoting.";
KokkosKernels::Impl::throw_runtime_exception(os.str());
}
#endif
#endif

// Check compatibility of dimensions at run time.
if ((A0 < A1) || (A0 != B0)) {
std::ostringstream os;
os << "KokkosBlas::gesv: Dimensions of A, and B do not match: "
<< " A: " << A.extent(0) << " x " << A.extent(1) << " B: " << B.extent(0)
<< " x " << B.extent(1);
KokkosKernels::Impl::throw_runtime_exception(os.str());
}

typedef Kokkos::View<
typename AMatrix::non_const_value_type**, typename AMatrix::array_layout,
typename AMatrix::device_type, Kokkos::MemoryTraits<Kokkos::Unmanaged> >
AMatrix_Internal;
typedef Kokkos::View<typename BXMV::non_const_value_type**,
typename BXMV::array_layout, typename BXMV::device_type,
Kokkos::MemoryTraits<Kokkos::Unmanaged> >
BXMV_Internal;
typedef Kokkos::View<
typename IPIVV::non_const_value_type*, typename IPIVV::array_layout,
typename IPIVV::device_type, Kokkos::MemoryTraits<Kokkos::Unmanaged> >
IPIVV_Internal;
AMatrix_Internal A_i = A;
// BXMV_Internal B_i = B;
IPIVV_Internal IPIV_i = IPIV;

if (BXMV::rank == 1) {
auto B_i = BXMV_Internal(B.data(), B.extent(0), 1);
KokkosBlas::Impl::GESV<AMatrix_Internal, BXMV_Internal,
IPIVV_Internal>::gesv(A_i, B_i, IPIV_i);
} else { // BXMV::rank == 2
auto B_i = BXMV_Internal(B.data(), B.extent(0), B.extent(1));
KokkosBlas::Impl::GESV<AMatrix_Internal, BXMV_Internal,
IPIVV_Internal>::gesv(A_i, B_i, IPIV_i);
}
[[deprecated]] void gesv(const AMatrix& A, const BXMV& B, const IPIVV& IPIV) {
KokkosLapack::gesv(A, B, IPIV);
}

} // namespace KokkosBlas
Expand Down
74 changes: 4 additions & 70 deletions blas/src/KokkosBlas_trtri.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@

/// \file KokkosBlas_trtri.hpp

#include "KokkosKernels_Macros.hpp"
#include "KokkosBlas_trtri_spec.hpp"
#include "KokkosKernels_helpers.hpp"
#include <sstream>
#include <type_traits>
#include "KokkosKernels_Error.hpp"
#include "KokkosLapack_trtri.hpp"

namespace KokkosBlas {

Expand All @@ -48,70 +43,9 @@ namespace KokkosBlas {
// and the inversion could not be completed.
// source: https://software.intel.com/en-us/mkl-developer-reference-c-trtri
template <class AViewType>
int trtri(const char uplo[], const char diag[], const AViewType& A) {
static_assert(Kokkos::is_view<AViewType>::value,
"AViewType must be a Kokkos::View.");
static_assert(static_cast<int>(AViewType::rank) == 2,
"AViewType must have rank 2.");

// Check validity of indicator argument
bool valid_uplo = (uplo[0] == 'U') || (uplo[0] == 'u') || (uplo[0] == 'L') ||
(uplo[0] == 'l');
bool valid_diag = (diag[0] == 'U') || (diag[0] == 'u') || (diag[0] == 'N') ||
(diag[0] == 'n');

if (!valid_uplo) {
std::ostringstream os;
os << "KokkosBlas::trtri: uplo = '" << uplo[0] << "'. "
<< "Valid values include 'U' or 'u' (A is upper triangular), "
"'L' or 'l' (A is lower triangular).";
KokkosKernels::Impl::throw_runtime_exception(os.str());
}
if (!valid_diag) {
std::ostringstream os;
os << "KokkosBlas::trtri: diag = '" << diag[0] << "'. "
<< "Valid values include 'U' or 'u' (the diagonal of A is assumed to be "
"unit), "
"'N' or 'n' (the diagonal of A is assumed to be non-unit).";
KokkosKernels::Impl::throw_runtime_exception(os.str());
}

int64_t A_m = A.extent(0);
int64_t A_n = A.extent(1);

// Return if degenerated matrices are provided
if (A_m == 0 || A_n == 0)
return 0; // This is success as the inverse of a matrix with no elements is
// itself.

// Ensure that the dimensions of A match and that we can legally perform A*B
// or B*A
if (A_m != A_n) {
std::ostringstream os;
os << "KokkosBlas::trtri: Dimensions of A do not match,"
<< " A: " << A.extent(0) << " x " << A.extent(1);
KokkosKernels::Impl::throw_runtime_exception(os.str());
}

// Create A matrix view type alias
using AViewInternalType =
Kokkos::View<typename AViewType::non_const_value_type**,
typename AViewType::array_layout,
typename AViewType::device_type,
Kokkos::MemoryTraits<Kokkos::Unmanaged> >;

// This is the return value type and should always reside on host
using RViewInternalType =
Kokkos::View<int, Kokkos::LayoutRight, Kokkos::HostSpace,
Kokkos::MemoryTraits<Kokkos::Unmanaged> >;

int result;
RViewInternalType R = RViewInternalType(&result);

KokkosBlas::Impl::TRTRI<RViewInternalType, AViewInternalType>::trtri(R, uplo,
diag, A);

return result;
[[deprecated]] int trtri(const char uplo[], const char diag[],
const AViewType& A) {
return KokkosLapack::trtri(uplo, diag, A);
}

} // namespace KokkosBlas
Expand Down
Loading