Skip to content

Commit

Permalink
Lapack: change according to Brian's review
Browse files Browse the repository at this point in the history
The SpaceAccessibility of IPIVV needs to be modified for MAGMA.
The value_type of IPIVV needs to be rocblas_int when running with
rocSOLVER.

The types used for gesv_tpl_spec_avail and the actual TPL
instantiation were mismatched, leading to a linker error.
  • Loading branch information
lucbv committed Nov 15, 2023
1 parent 24c73c8 commit c06b8db
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 46 deletions.
8 changes: 8 additions & 0 deletions lapack/src/KokkosLapack_gesv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,17 @@ void gesv(const ExecutionSpace& space, const AMatrix& A, const BXMV& B,
static_assert(
Kokkos::SpaceAccessibility<ExecutionSpace,
typename BXMV::memory_space>::accessible);
#if defined(KOKKOSKERNELS_ENABLE_TPL_MAGMA)
if constexpr (!std::is_same_v<ExecutionSpace, Kokkos::Cuda>) {
static_assert(
Kokkos::SpaceAccessibility<ExecutionSpace,
typename IPIVV::memory_space>::accessible);
}
#else
static_assert(
Kokkos::SpaceAccessibility<ExecutionSpace,
typename IPIVV::memory_space>::accessible);
#endif
static_assert(Kokkos::is_view<AMatrix>::value,
"KokkosLapack::gesv: A must be a Kokkos::View.");
static_assert(Kokkos::is_view<BXMV>::value,
Expand Down
13 changes: 9 additions & 4 deletions lapack/tpls/KokkosLapack_gesv_tpl_spec_avail.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,15 @@ KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(Kokkos::complex<double>,
Kokkos::LayoutLeft, Kokkos::CudaSpace)
KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(Kokkos::complex<float>,
Kokkos::LayoutLeft, Kokkos::CudaSpace)

#endif
} // namespace Impl
} // namespace KokkosLapack

#ifdef KOKKOSKERNELS_ENABLE_TPL_ROCSOLVER
#include <rocsolver/rocsolver.h>

namespace KokkosLapack {
namespace Impl {

#define KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_ROCSOLVER(SCALAR, LAYOUT, MEMSPACE) \
template <> \
Expand All @@ -88,7 +93,8 @@ KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(Kokkos::complex<float>,
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<Kokkos::HIP, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> >, \
Kokkos::View<int*, LAYOUT, Kokkos::Device<Kokkos::HIP, MEMSPACE>, \
Kokkos::View<rocblas_int*, LAYOUT, \
Kokkos::Device<Kokkos::HIP, MEMSPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged> > > { \
enum : bool { value = true }; \
};
Expand All @@ -102,9 +108,8 @@ KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_ROCSOLVER(Kokkos::complex<double>,
KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_ROCSOLVER(Kokkos::complex<float>,
Kokkos::LayoutLeft, Kokkos::HIPSpace)

#endif

} // namespace Impl
} // namespace KokkosLapack
#endif // KOKKOSKERNELS_ENABLE_TPL_ROCSOLVER

#endif
81 changes: 39 additions & 42 deletions lapack/tpls/KokkosLapack_gesv_tpl_spec_decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,36 +76,35 @@ void lapackGesvWrapper(const AViewType& A, const BViewType& B,
}
}

#define KOKKOSLAPACK_GESV_LAPACK(SCALAR, LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \
template <class ExecSpace> \
struct GESV< \
ExecSpace, \
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<ExecSpace, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<ExecSpace, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<int*, LAYOUT, \
Kokkos::Device<Kokkos::DefaultHostExecutionSpace, \
Kokkos::HostSpace>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
true, ETI_SPEC_AVAIL> { \
using AViewType = \
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<ExecSpace, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>; \
using BViewType = \
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<ExecSpace, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>; \
using PViewType = Kokkos::View<int*, LAYOUT, Kokkos::HostSpace, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>; \
\
static void gesv(const ExecSpace& /* space */, const AViewType& A, \
const BViewType& B, const PViewType& IPIV) { \
Kokkos::Profiling::pushRegion("KokkosLapack::gesv[TPL_LAPACK," #SCALAR \
"]"); \
gesv_print_specialization<AViewType, BViewType, PViewType>(); \
lapackGesvWrapper(A, B, IPIV); \
Kokkos::Profiling::popRegion(); \
} \
// KOKKOSLAPACK_GESV_LAPACK(SCALAR, LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL)
//
// Expands to a partial specialization of GESV, still templated on ExecSpace,
// whose static gesv() forwards the solve to lapackGesvWrapper(A, B, IPIV).
//
//   SCALAR          element type of the A and B views (SCALAR**)
//   LAYOUT          layout used for the A, B and IPIV views
//   MEM_SPACE       memory space paired with ExecSpace in the Kokkos::Device
//                   of the A, B and PViewType views
//   ETI_SPEC_AVAIL  forwarded verbatim as the last GESV template argument
//
// The literal `true` preceding ETI_SPEC_AVAIL is presumably the
// "TPL specialization available" flag of the primary GESV template --
// TODO confirm against the primary template declaration.
//
// NOTE(review): the IPIV view in the specialization argument list hard-codes
// Kokkos::HostSpace, whereas PViewType below uses MEM_SPACE. The two types
// agree only when the macro is instantiated with MEM_SPACE set to
// Kokkos::HostSpace, as the instantiation that follows this macro does.
//
// The generated gesv() ignores its execution-space argument, brackets the
// call with a Kokkos::Profiling region named
// "KokkosLapack::gesv[TPL_LAPACK,<SCALAR>]" and calls
// gesv_print_specialization before invoking the wrapper.
#define KOKKOSLAPACK_GESV_LAPACK(SCALAR, LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \
template <class ExecSpace> \
struct GESV< \
ExecSpace, \
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<ExecSpace, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<ExecSpace, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
Kokkos::View<int*, LAYOUT, Kokkos::Device<ExecSpace, Kokkos::HostSpace>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>, \
true, ETI_SPEC_AVAIL> { \
using AViewType = \
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<ExecSpace, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>; \
using BViewType = \
Kokkos::View<SCALAR**, LAYOUT, Kokkos::Device<ExecSpace, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>; \
using PViewType = \
Kokkos::View<int*, LAYOUT, Kokkos::Device<ExecSpace, MEM_SPACE>, \
Kokkos::MemoryTraits<Kokkos::Unmanaged>>; \
\
static void gesv(const ExecSpace& /* space */, const AViewType& A, \
const BViewType& B, const PViewType& IPIV) { \
Kokkos::Profiling::pushRegion("KokkosLapack::gesv[TPL_LAPACK," #SCALAR \
"]"); \
gesv_print_specialization<AViewType, BViewType, PViewType>(); \
lapackGesvWrapper(A, B, IPIV); \
Kokkos::Profiling::popRegion(); \
} \
};

KOKKOSLAPACK_GESV_LAPACK(float, Kokkos::LayoutLeft, Kokkos::HostSpace, true)
Expand Down Expand Up @@ -422,28 +421,26 @@ void rocsolverGesvWrapper(const ExecutionSpace& space, const IPIVViewType& IPIV,
KOKKOS_ROCBLAS_SAFE_CALL_IMPL(
rocblas_set_stream(s.handle, space.hip_stream()));
if constexpr (std::is_same_v<Scalar, float>) {
KOKKOS_ROCBLAS_SAFE_CALL_IMPL(
rocsolver_sgesv(s.handle, N, nrhs, A.data(), lda,
reinterpret_cast<rocblas_int*>(IPIV.data()), B.data(),
ldb, info.data()));
KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocsolver_sgesv(s.handle, N, nrhs, A.data(),
lda, IPIV.data(), B.data(),
ldb, info.data()));
}
if constexpr (std::is_same_v<Scalar, double>) {
KOKKOS_ROCBLAS_SAFE_CALL_IMPL(
rocsolver_dgesv(s.handle, N, nrhs, A.data(), lda,
reinterpret_cast<rocblas_int*>(IPIV.data()), B.data(),
ldb, info.data()));
KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocsolver_dgesv(s.handle, N, nrhs, A.data(),
lda, IPIV.data(), B.data(),
ldb, info.data()));
}
if constexpr (std::is_same_v<Scalar, Kokkos::complex<float>>) {
KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocsolver_cgesv(
s.handle, N, nrhs, reinterpret_cast<rocblas_float_complex*>(A.data()),
lda, reinterpret_cast<rocblas_int*>(IPIV.data()),
reinterpret_cast<rocblas_float_complex*>(B.data()), ldb, info.data()));
lda, IPIV.data(), reinterpret_cast<rocblas_float_complex*>(B.data()),
ldb, info.data()));
}
if constexpr (std::is_same_v<Scalar, Kokkos::complex<double>>) {
KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocsolver_zgesv(
s.handle, N, nrhs, reinterpret_cast<rocblas_double_complex*>(A.data()),
lda, reinterpret_cast<rocblas_int*>(IPIV.data()),
reinterpret_cast<rocblas_double_complex*>(B.data()), ldb, info.data()));
lda, IPIV.data(), reinterpret_cast<rocblas_double_complex*>(B.data()),
ldb, info.data()));
}
KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_set_stream(s.handle, NULL));
}
Expand Down

0 comments on commit c06b8db

Please sign in to comment.