From c06b8db52636128517e169e1952b8b06f0a37335 Mon Sep 17 00:00:00 2001 From: Luc Berger-Vergiat Date: Tue, 14 Nov 2023 20:50:02 -0700 Subject: [PATCH] Lapack: change according to Brian's review The SpaceAccessibility of IPIVV needs to be modified for MAGMA. The value_type of IPIVV needs to be rocblas_int when running with rocSOLVER. The types used for gesv_tpl_spec_avail and the actual TPL instantiation where mismatched leading to linker error. --- lapack/src/KokkosLapack_gesv.hpp | 8 ++ .../tpls/KokkosLapack_gesv_tpl_spec_avail.hpp | 13 ++- .../tpls/KokkosLapack_gesv_tpl_spec_decl.hpp | 81 +++++++++---------- 3 files changed, 56 insertions(+), 46 deletions(-) diff --git a/lapack/src/KokkosLapack_gesv.hpp b/lapack/src/KokkosLapack_gesv.hpp index 74d2e01cf9..a37cfd95fe 100644 --- a/lapack/src/KokkosLapack_gesv.hpp +++ b/lapack/src/KokkosLapack_gesv.hpp @@ -65,9 +65,17 @@ void gesv(const ExecutionSpace& space, const AMatrix& A, const BXMV& B, static_assert( Kokkos::SpaceAccessibility::accessible); +#if defined(KOKKOSKERNELS_ENABLE_TPL_MAGMA) + if constexpr (!std::is_same_v) { + static_assert( + Kokkos::SpaceAccessibility::accessible); + } +#else static_assert( Kokkos::SpaceAccessibility::accessible); +#endif static_assert(Kokkos::is_view::value, "KokkosLapack::gesv: A must be a Kokkos::View."); static_assert(Kokkos::is_view::value, diff --git a/lapack/tpls/KokkosLapack_gesv_tpl_spec_avail.hpp b/lapack/tpls/KokkosLapack_gesv_tpl_spec_avail.hpp index fc8f634078..e7bc5425f7 100644 --- a/lapack/tpls/KokkosLapack_gesv_tpl_spec_avail.hpp +++ b/lapack/tpls/KokkosLapack_gesv_tpl_spec_avail.hpp @@ -75,10 +75,15 @@ KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(Kokkos::complex, Kokkos::LayoutLeft, Kokkos::CudaSpace) KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(Kokkos::complex, Kokkos::LayoutLeft, Kokkos::CudaSpace) - #endif +} // namespace Impl +} // namespace KokkosLapack #ifdef KOKKOSKERNELS_ENABLE_TPL_ROCSOLVER +#include + +namespace KokkosLapack { +namespace Impl { #define KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_ROCSOLVER(SCALAR, LAYOUT, MEMSPACE) \ template <> \ @@ -88,7 +93,8 @@ KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_MAGMA(Kokkos::complex, Kokkos::MemoryTraits >, \ Kokkos::View, \ Kokkos::MemoryTraits >, \ - Kokkos::View, \ + Kokkos::View, \ Kokkos::MemoryTraits > > { \ enum : bool { value = true }; \ }; @@ -102,9 +108,8 @@ KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_ROCSOLVER(Kokkos::complex, KOKKOSLAPACK_GESV_TPL_SPEC_AVAIL_ROCSOLVER(Kokkos::complex, Kokkos::LayoutLeft, Kokkos::HIPSpace) -#endif - } // namespace Impl } // namespace KokkosLapack +#endif // KOKKOSKERNELS_ENABLE_TPL_ROCSOLVER #endif diff --git a/lapack/tpls/KokkosLapack_gesv_tpl_spec_decl.hpp b/lapack/tpls/KokkosLapack_gesv_tpl_spec_decl.hpp index 957ac7c138..d3a71a0cfa 100644 --- a/lapack/tpls/KokkosLapack_gesv_tpl_spec_decl.hpp +++ b/lapack/tpls/KokkosLapack_gesv_tpl_spec_decl.hpp @@ -76,36 +76,35 @@ void lapackGesvWrapper(const AViewType& A, const BViewType& B, } } -#define KOKKOSLAPACK_GESV_LAPACK(SCALAR, LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \ - template \ - struct GESV< \ - ExecSpace, \ - Kokkos::View, \ - Kokkos::MemoryTraits>, \ - Kokkos::View, \ - Kokkos::MemoryTraits>, \ - Kokkos::View, \ - Kokkos::MemoryTraits>, \ - true, ETI_SPEC_AVAIL> { \ - using AViewType = \ - Kokkos::View, \ - Kokkos::MemoryTraits>; \ - using BViewType = \ - Kokkos::View, \ - Kokkos::MemoryTraits>; \ - using PViewType = Kokkos::View>; \ - \ - static void gesv(const ExecSpace& /* space */, const AViewType& A, \ - const BViewType& B, const PViewType& IPIV) { \ - Kokkos::Profiling::pushRegion("KokkosLapack::gesv[TPL_LAPACK," #SCALAR \ - "]"); \ - gesv_print_specialization(); \ - lapackGesvWrapper(A, B, IPIV); \ - Kokkos::Profiling::popRegion(); \ - } \ +#define KOKKOSLAPACK_GESV_LAPACK(SCALAR, LAYOUT, MEM_SPACE, ETI_SPEC_AVAIL) \ + template \ + struct GESV< \ + ExecSpace, \ + Kokkos::View, \ + Kokkos::MemoryTraits>, \ + Kokkos::View, \ + Kokkos::MemoryTraits>, \ + Kokkos::View, \ + Kokkos::MemoryTraits>, \ + true, ETI_SPEC_AVAIL> { \ + using AViewType = \ + Kokkos::View, \ + Kokkos::MemoryTraits>; \ + using BViewType = \ + Kokkos::View, \ + Kokkos::MemoryTraits>; \ + using PViewType = \ + Kokkos::View, \ + Kokkos::MemoryTraits>; \ + \ + static void gesv(const ExecSpace& /* space */, const AViewType& A, \ + const BViewType& B, const PViewType& IPIV) { \ + Kokkos::Profiling::pushRegion("KokkosLapack::gesv[TPL_LAPACK," #SCALAR \ + "]"); \ + gesv_print_specialization(); \ + lapackGesvWrapper(A, B, IPIV); \ + Kokkos::Profiling::popRegion(); \ + } \ }; KOKKOSLAPACK_GESV_LAPACK(float, Kokkos::LayoutLeft, Kokkos::HostSpace, true) @@ -422,28 +421,26 @@ void rocsolverGesvWrapper(const ExecutionSpace& space, const IPIVViewType& IPIV, KOKKOS_ROCBLAS_SAFE_CALL_IMPL( rocblas_set_stream(s.handle, space.hip_stream())); if constexpr (std::is_same_v) { - KOKKOS_ROCBLAS_SAFE_CALL_IMPL( - rocsolver_sgesv(s.handle, N, nrhs, A.data(), lda, - reinterpret_cast(IPIV.data()), B.data(), - ldb, info.data())); + KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocsolver_sgesv(s.handle, N, nrhs, A.data(), + lda, IPIV.data(), B.data(), + ldb, info.data())); } if constexpr (std::is_same_v) { - KOKKOS_ROCBLAS_SAFE_CALL_IMPL( - rocsolver_dgesv(s.handle, N, nrhs, A.data(), lda, - reinterpret_cast(IPIV.data()), B.data(), - ldb, info.data())); + KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocsolver_dgesv(s.handle, N, nrhs, A.data(), + lda, IPIV.data(), B.data(), + ldb, info.data())); } if constexpr (std::is_same_v>) { KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocsolver_cgesv( s.handle, N, nrhs, reinterpret_cast(A.data()), - lda, reinterpret_cast(IPIV.data()), - reinterpret_cast(B.data()), ldb, info.data())); + lda, IPIV.data(), reinterpret_cast(B.data()), + ldb, info.data())); } if constexpr (std::is_same_v>) { KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocsolver_zgesv( s.handle, N, nrhs, reinterpret_cast(A.data()), - lda, reinterpret_cast(IPIV.data()), - reinterpret_cast(B.data()), ldb, info.data())); + lda, IPIV.data(), reinterpret_cast(B.data()), + ldb, info.data())); } KOKKOS_ROCBLAS_SAFE_CALL_IMPL(rocblas_set_stream(s.handle, NULL)); }