Skip to content

Commit

Permalink
Merge pull request #1985 from eeprude/lapackDir
Browse files Browse the repository at this point in the history
Creation of the 'lapack' subdirectory, parallel to 'blas'
  • Loading branch information
lucbv authored Oct 26, 2023
2 parents 69379cb + d674964 commit e7b6c12
Show file tree
Hide file tree
Showing 40 changed files with 1,451 additions and 814 deletions.
10 changes: 9 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ IF (KokkosKernels_INSTALL_TESTING)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(batched/dense/unit_test)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(batched/sparse/unit_test)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(blas/unit_test)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(lapack/unit_test)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(graph/unit_test)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(sparse/unit_test)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(ode/unit_test)
Expand Down Expand Up @@ -192,7 +193,7 @@ ELSE()
"ALL"
STRING
"A list of components to enable in testing and building"
VALID_ENTRIES BATCHED BLAS GRAPH SPARSE ALL
VALID_ENTRIES BATCHED BLAS LAPACK GRAPH SPARSE ALL
)

# ==================================================================
Expand Down Expand Up @@ -243,6 +244,7 @@ ELSE()
MESSAGE(" COMMON: ON")
MESSAGE(" BATCHED: ${KokkosKernels_ENABLE_COMPONENT_BATCHED}")
MESSAGE(" BLAS: ${KokkosKernels_ENABLE_COMPONENT_BLAS}")
MESSAGE(" LAPACK: ${KokkosKernels_ENABLE_COMPONENT_LAPACK}")
MESSAGE(" GRAPH: ${KokkosKernels_ENABLE_COMPONENT_GRAPH}")
MESSAGE(" SPARSE: ${KokkosKernels_ENABLE_COMPONENT_SPARSE}")
MESSAGE(" ODE: ${KokkosKernels_ENABLE_COMPONENT_ODE}")
Expand Down Expand Up @@ -287,6 +289,9 @@ ELSE()
IF (KokkosKernels_ENABLE_COMPONENT_BLAS)
INCLUDE(blas/CMakeLists.txt)
ENDIF()
IF (KokkosKernels_ENABLE_COMPONENT_LAPACK)
INCLUDE(lapack/CMakeLists.txt)
ENDIF()
IF (KokkosKernels_ENABLE_COMPONENT_GRAPH)
INCLUDE(graph/CMakeLists.txt)
ENDIF()
Expand Down Expand Up @@ -405,6 +410,9 @@ ELSE()
IF (KokkosKernels_ENABLE_COMPONENT_BLAS)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(blas/unit_test)
ENDIF()
IF (KokkosKernels_ENABLE_COMPONENT_LAPACK)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(lapack/unit_test)
ENDIF()
IF (KokkosKernels_ENABLE_COMPONENT_GRAPH)
KOKKOSKERNELS_ADD_TEST_DIRECTORIES(graph/unit_test)
ENDIF()
Expand Down
14 changes: 0 additions & 14 deletions blas/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,6 @@ KOKKOSKERNELS_GENERATE_ETI(Blas1_dot_mv dot
TYPE_LISTS FLOATS LAYOUTS DEVICES
)

KOKKOSKERNELS_GENERATE_ETI(Blas_gesv gesv
COMPONENTS blas
HEADER_LIST ETI_HEADERS
SOURCE_LIST SOURCES
TYPE_LISTS FLOATS LAYOUTS DEVICES
)

KOKKOSKERNELS_GENERATE_ETI(Blas1_axpby axpby
COMPONENTS blas
HEADER_LIST ETI_HEADERS
Expand Down Expand Up @@ -331,10 +324,3 @@ KOKKOSKERNELS_GENERATE_ETI(Blas3_trmm trmm
SOURCE_LIST SOURCES
TYPE_LISTS FLOATS LAYOUTS DEVICES
)

KOKKOSKERNELS_GENERATE_ETI(Blas_trtri trtri
COMPONENTS blas
HEADER_LIST ETI_HEADERS
SOURCE_LIST SOURCES
TYPE_LISTS FLOATS LAYOUTS DEVICES
)
101 changes: 3 additions & 98 deletions blas/src/KokkosBlas_gesv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@
#ifndef KOKKOSBLAS_GESV_HPP_
#define KOKKOSBLAS_GESV_HPP_

#include <type_traits>

#include "KokkosBlas_gesv_spec.hpp"
#include "KokkosKernels_Error.hpp"
#include "KokkosLapack_gesv.hpp"

namespace KokkosBlas {

Expand All @@ -49,100 +46,8 @@ namespace KokkosBlas {
/// its data pointer is NULL, pivoting is not used.
///
template <class AMatrix, class BXMV, class IPIVV>
void gesv(const AMatrix& A, const BXMV& B, const IPIVV& IPIV) {
// NOTE: Currently, KokkosBlas::gesv is only supported through the MAGMA TPL
// and the BLAS TPL. The MAGMA TPL should be enabled to call the MAGMA GPU
// interface for device views; the BLAS TPL should be enabled to call the
// BLAS interface for host views.

static_assert(Kokkos::is_view<AMatrix>::value,
"KokkosBlas::gesv: A must be a Kokkos::View.");
static_assert(Kokkos::is_view<BXMV>::value,
"KokkosBlas::gesv: B must be a Kokkos::View.");
static_assert(Kokkos::is_view<IPIVV>::value,
"KokkosBlas::gesv: IPIV must be a Kokkos::View.");
static_assert(static_cast<int>(AMatrix::rank) == 2,
"KokkosBlas::gesv: A must have rank 2.");
static_assert(
static_cast<int>(BXMV::rank) == 1 || static_cast<int>(BXMV::rank) == 2,
"KokkosBlas::gesv: B must have either rank 1 or rank 2.");
static_assert(static_cast<int>(IPIVV::rank) == 1,
"KokkosBlas::gesv: IPIV must have rank 1.");

int64_t IPIV0 = IPIV.extent(0);
int64_t A0 = A.extent(0);
int64_t A1 = A.extent(1);
int64_t B0 = B.extent(0);

// Check validity of pivot argument
bool valid_pivot =
(IPIV0 == A1) || ((IPIV0 == 0) && (IPIV.data() == nullptr));
if (!(valid_pivot)) {
std::ostringstream os;
os << "KokkosBlas::gesv: IPIV: " << IPIV0 << ". "
<< "Valid options include zero-extent 1-D view (no pivoting), or 1-D "
"View with size of "
<< A0 << " (partial pivoting).";
KokkosKernels::Impl::throw_runtime_exception(os.str());
}

// Check for no pivoting case. Only MAGMA supports no pivoting interface
#ifdef KOKKOSKERNELS_ENABLE_TPL_MAGMA // have MAGMA TPL
#ifdef KOKKOSKERNELS_ENABLE_TPL_BLAS // and have BLAS TPL
if ((!std::is_same<typename AMatrix::device_type::memory_space,
Kokkos::CudaSpace>::value) &&
(IPIV0 == 0) && (IPIV.data() == nullptr)) {
std::ostringstream os;
os << "KokkosBlas::gesv: IPIV: " << IPIV0 << ". "
<< "BLAS TPL does not support no pivoting.";
KokkosKernels::Impl::throw_runtime_exception(os.str());
}
#endif
#else // not have MAGMA TPL
#ifdef KOKKOSKERNELS_ENABLE_TPL_BLAS // but have BLAS TPL
if ((IPIV0 == 0) && (IPIV.data() == nullptr)) {
std::ostringstream os;
os << "KokkosBlas::gesv: IPIV: " << IPIV0 << ". "
<< "BLAS TPL does not support no pivoting.";
KokkosKernels::Impl::throw_runtime_exception(os.str());
}
#endif
#endif

// Check compatibility of dimensions at run time.
if ((A0 < A1) || (A0 != B0)) {
std::ostringstream os;
os << "KokkosBlas::gesv: Dimensions of A, and B do not match: "
<< " A: " << A.extent(0) << " x " << A.extent(1) << " B: " << B.extent(0)
<< " x " << B.extent(1);
KokkosKernels::Impl::throw_runtime_exception(os.str());
}

typedef Kokkos::View<
typename AMatrix::non_const_value_type**, typename AMatrix::array_layout,
typename AMatrix::device_type, Kokkos::MemoryTraits<Kokkos::Unmanaged> >
AMatrix_Internal;
typedef Kokkos::View<typename BXMV::non_const_value_type**,
typename BXMV::array_layout, typename BXMV::device_type,
Kokkos::MemoryTraits<Kokkos::Unmanaged> >
BXMV_Internal;
typedef Kokkos::View<
typename IPIVV::non_const_value_type*, typename IPIVV::array_layout,
typename IPIVV::device_type, Kokkos::MemoryTraits<Kokkos::Unmanaged> >
IPIVV_Internal;
AMatrix_Internal A_i = A;
// BXMV_Internal B_i = B;
IPIVV_Internal IPIV_i = IPIV;

if (BXMV::rank == 1) {
auto B_i = BXMV_Internal(B.data(), B.extent(0), 1);
KokkosBlas::Impl::GESV<AMatrix_Internal, BXMV_Internal,
IPIVV_Internal>::gesv(A_i, B_i, IPIV_i);
} else { // BXMV::rank == 2
auto B_i = BXMV_Internal(B.data(), B.extent(0), B.extent(1));
KokkosBlas::Impl::GESV<AMatrix_Internal, BXMV_Internal,
IPIVV_Internal>::gesv(A_i, B_i, IPIV_i);
}
[[deprecated("KokkosBlas::gesv is deprecated; use KokkosLapack::gesv instead")]]
void gesv(const AMatrix& A, const BXMV& B, const IPIVV& IPIV) {
  // Backward-compatibility shim: forward A, B and IPIV unchanged to
  // KokkosLapack::gesv so existing KokkosBlas callers keep compiling
  // (with a deprecation warning that names the replacement).
  KokkosLapack::gesv(A, B, IPIV);
}

} // namespace KokkosBlas
Expand Down
74 changes: 4 additions & 70 deletions blas/src/KokkosBlas_trtri.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,7 @@

/// \file KokkosBlas_trtri.hpp

#include "KokkosKernels_Macros.hpp"
#include "KokkosBlas_trtri_spec.hpp"
#include "KokkosKernels_helpers.hpp"
#include <sstream>
#include <type_traits>
#include "KokkosKernels_Error.hpp"
#include "KokkosLapack_trtri.hpp"

namespace KokkosBlas {

Expand All @@ -48,70 +43,9 @@ namespace KokkosBlas {
// and the inversion could not be completed.
// source: https://software.intel.com/en-us/mkl-developer-reference-c-trtri
template <class AViewType>
int trtri(const char uplo[], const char diag[], const AViewType& A) {
static_assert(Kokkos::is_view<AViewType>::value,
"AViewType must be a Kokkos::View.");
static_assert(static_cast<int>(AViewType::rank) == 2,
"AViewType must have rank 2.");

// Check validity of indicator argument
bool valid_uplo = (uplo[0] == 'U') || (uplo[0] == 'u') || (uplo[0] == 'L') ||
(uplo[0] == 'l');
bool valid_diag = (diag[0] == 'U') || (diag[0] == 'u') || (diag[0] == 'N') ||
(diag[0] == 'n');

if (!valid_uplo) {
std::ostringstream os;
os << "KokkosBlas::trtri: uplo = '" << uplo[0] << "'. "
<< "Valid values include 'U' or 'u' (A is upper triangular), "
"'L' or 'l' (A is lower triangular).";
KokkosKernels::Impl::throw_runtime_exception(os.str());
}
if (!valid_diag) {
std::ostringstream os;
os << "KokkosBlas::trtri: diag = '" << diag[0] << "'. "
<< "Valid values include 'U' or 'u' (the diagonal of A is assumed to be "
"unit), "
"'N' or 'n' (the diagonal of A is assumed to be non-unit).";
KokkosKernels::Impl::throw_runtime_exception(os.str());
}

int64_t A_m = A.extent(0);
int64_t A_n = A.extent(1);

// Return early if a degenerate (zero-extent) matrix is provided
if (A_m == 0 || A_n == 0)
return 0; // This is success as the inverse of a matrix with no elements is
// itself.

// Ensure that A is square (its two extents match); otherwise report the
// mismatched dimensions.
if (A_m != A_n) {
std::ostringstream os;
os << "KokkosBlas::trtri: Dimensions of A do not match,"
<< " A: " << A.extent(0) << " x " << A.extent(1);
KokkosKernels::Impl::throw_runtime_exception(os.str());
}

// Create A matrix view type alias
using AViewInternalType =
Kokkos::View<typename AViewType::non_const_value_type**,
typename AViewType::array_layout,
typename AViewType::device_type,
Kokkos::MemoryTraits<Kokkos::Unmanaged> >;

// This is the return value type and should always reside on host
using RViewInternalType =
Kokkos::View<int, Kokkos::LayoutRight, Kokkos::HostSpace,
Kokkos::MemoryTraits<Kokkos::Unmanaged> >;

int result;
RViewInternalType R = RViewInternalType(&result);

KokkosBlas::Impl::TRTRI<RViewInternalType, AViewInternalType>::trtri(R, uplo,
diag, A);

return result;
[[deprecated(
    "KokkosBlas::trtri is deprecated; use KokkosLapack::trtri instead")]]
int trtri(const char uplo[], const char diag[], const AViewType& A) {
  // Backward-compatibility shim: forward uplo, diag and A unchanged to
  // KokkosLapack::trtri and return its status code as-is, so existing
  // KokkosBlas callers keep compiling (with a deprecation warning that
  // names the replacement).
  return KokkosLapack::trtri(uplo, diag, A);
}

} // namespace KokkosBlas
Expand Down
Loading

0 comments on commit e7b6c12

Please sign in to comment.