diff --git a/batched/dense/impl/KokkosBatched_Getrf_Serial_Impl.hpp b/batched/dense/impl/KokkosBatched_Getrf_Serial_Impl.hpp new file mode 100644 index 0000000000..1995f2bab6 --- /dev/null +++ b/batched/dense/impl/KokkosBatched_Getrf_Serial_Impl.hpp @@ -0,0 +1,67 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +#ifndef KOKKOSBATCHED_GETRF_SERIAL_IMPL_HPP_ +#define KOKKOSBATCHED_GETRF_SERIAL_IMPL_HPP_ + +#include +#include "KokkosBatched_Getrf_Serial_Internal.hpp" + +namespace KokkosBatched { +namespace Impl { +template +KOKKOS_INLINE_FUNCTION static int checkGetrfInput([[maybe_unused]] const AViewType &A, + [[maybe_unused]] const PivViewType &ipiv) { + static_assert(Kokkos::is_view_v, "KokkosBatched::getrf: AViewType is not a Kokkos::View."); + static_assert(Kokkos::is_view_v, "KokkosBatched::getrf: PivViewType is not a Kokkos::View."); + static_assert(AViewType::rank == 2, "KokkosBatched::getrf: AViewType must have rank 2."); + static_assert(PivViewType::rank == 1, "KokkosBatched::getrf: PivViewType must have rank 1."); +#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) + const int m = A.extent(0), n = A.extent(1); + const int npiv = ipiv.extent(0); + if (npiv != Kokkos::min(m, n)) { + Kokkos::printf( + "KokkosBatched::getrf: the dimension of the ipiv array must " + "satisfy ipiv.extent(0) == max(m, n): ipiv: %d, A: " + "%d " + "x %d \n", + npiv, m, n); + return 1; + } + +#endif + return 0; +} +} // namespace Impl + +template <> +struct SerialGetrf { + template + KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &ipiv) { + // Quick return if possible + if (A.extent(0) == 0 || A.extent(1) == 0) return 0; + + auto info = KokkosBatched::Impl::checkGetrfInput(A, ipiv); + if (info) return info; + KOKKOS_IF_ON_HOST((return KokkosBatched::Impl::SerialGetrfInternalHost::invoke(A, ipiv);)) + KOKKOS_IF_ON_DEVICE( + (return KokkosBatched::Impl::SerialGetrfInternalDevice::invoke(A, ipiv);)) + } +}; + +} // namespace KokkosBatched + +#endif // KOKKOSBATCHED_GETRF_SERIAL_IMPL_HPP_ diff --git a/batched/dense/impl/KokkosBatched_Getrf_Serial_Internal.hpp b/batched/dense/impl/KokkosBatched_Getrf_Serial_Internal.hpp new file mode 100644 index 0000000000..bfbe65f9c8 --- /dev/null +++ b/batched/dense/impl/KokkosBatched_Getrf_Serial_Internal.hpp @@ -0,0 +1,315 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +#ifndef KOKKOSBATCHED_GETRF_SERIAL_INTERNAL_HPP_ +#define KOKKOSBATCHED_GETRF_SERIAL_INTERNAL_HPP_ + +#include +#include +#include +#include +#include +#include + +namespace KokkosBatched { +namespace Impl { + +struct Stack { + private: + constexpr static int STACK_SIZE = 48; + + // (state, m_start, n_start, piv_start, m_size, n_size, piv_size) + int m_stack[7][STACK_SIZE]; + int m_top; + + public: + KOKKOS_FUNCTION + Stack() : m_top(-1) {} // Initialize top to -1, indicating the stack is empty + + KOKKOS_INLINE_FUNCTION + void push(int values[]) { +#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) + if (m_top >= STACK_SIZE - 1) { + Kokkos::printf("Stack overflow: Cannot push, the stack is full.\n"); + return; + } +#endif + ++m_top; + for (int i = 0; i < 7; i++) { + // Increment top and add value + m_stack[i][m_top] = values[i]; + } + } + + KOKKOS_INLINE_FUNCTION + void pop(int values[]) { +#if (KOKKOSKERNELS_DEBUG_LEVEL > 0) + if (m_top < 0) { + // Check if the stack is empty + Kokkos::printf("Stack underflow: Cannot pop, the stack is empty."); + return; + } +#endif + for (int i = 0; i < 7; i++) { + // Return the top value and decrement top + values[i] = m_stack[i][m_top]; + } + m_top--; + } + + KOKKOS_INLINE_FUNCTION + bool isEmpty() const { return m_top == -1; } +}; + +// Host only implementation with recursive algorithm +template +struct SerialGetrfInternalHost { + template + KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &ipiv); +}; + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGetrfInternalHost::invoke(const AViewType &A, + const PivViewType &ipiv) { + using ScalarType = typename AViewType::non_const_value_type; + + const int m = A.extent(0), n = A.extent(1); + + // Quick return if possible + if (m <= 0 || n <= 0) return 0; + + int info = 0; + + // Use unblocked code for one row case + // Just need to handle ipiv and info + if (m == 1) { + ipiv(0) = 0; + if (A(0, 0) == 0) return 1; + + return 0; + } else if (n == 1) { + // Use unblocked code for one column case + // Compute machine safe minimum + auto col_A = Kokkos::subview(A, Kokkos::ALL, 0); + + int i = SerialIamax::invoke(col_A); + ipiv(0) = i; + + if (A(i, 0) == 0) return 1; + + // Apply the interchange + if (i != 0) { + Kokkos::kokkos_swap(A(i, 0), A(0, 0)); + } + + // Compute elements + const ScalarType alpha = 1.0 / A(0, 0); + auto sub_col_A = Kokkos::subview(A, Kokkos::pair(1, m), 0); + [[maybe_unused]] auto info_scal = KokkosBlas::SerialScale::invoke(alpha, sub_col_A); + + return 0; + } else { + // Use recursive code + auto n1 = Kokkos::min(m, n) / 2; + + // Factor A0 = [[A00], + // [A10]] + + // split A into two submatrices A = [A0, A1] + auto A0 = Kokkos::subview(A, Kokkos::ALL, Kokkos::pair(0, n1)); + auto A1 = Kokkos::subview(A, Kokkos::ALL, Kokkos::pair(n1, n)); + auto ipiv0 = Kokkos::subview(ipiv, Kokkos::pair(0, n1)); + auto iinfo = invoke(A0, ipiv0); + + if (info == 0 && iinfo > 0) info = iinfo; + + // Apply interchanges to A1 = [[A01], + // [A11]] + + [[maybe_unused]] auto info_laswp = KokkosBatched::SerialLaswp::invoke(ipiv0, A1); + + // split A into four submatrices + // A = [[A00, A01], + // [A10, A11]] + auto A00 = Kokkos::subview(A, Kokkos::pair(0, n1), Kokkos::pair(0, n1)); + auto A01 = Kokkos::subview(A, Kokkos::pair(0, n1), Kokkos::pair(n1, n)); + auto A10 = Kokkos::subview(A, Kokkos::pair(n1, m), Kokkos::pair(0, n1)); + auto A11 = Kokkos::subview(A, Kokkos::pair(n1, m), Kokkos::pair(n1, n)); + + // Solve A00 * X = A01 + [[maybe_unused]] auto info_trsm = KokkosBatched::SerialTrsm::invoke(1.0, A00, A01); + + // Update A11 = A11 - A10 * A01 + [[maybe_unused]] auto info_gemm = + KokkosBatched::SerialGemm::invoke(-1.0, A10, A01, + 1.0, A11); + + // Factor A11 + auto ipiv1 = Kokkos::subview(ipiv, Kokkos::pair(n1, Kokkos::min(m, n))); + iinfo = invoke(A11, ipiv1); + + if (info == 0 && iinfo > 0) info = iinfo + n1; + + // Apply interchanges to A10 + info_laswp = KokkosBatched::SerialLaswp::invoke(ipiv1, A10); + + // Pivot indices + for (int i = n1; i < Kokkos::min(m, n); i++) { + ipiv(i) += n1; + } + + return info; + } +} + +// Device only implementation with recursive algorithm +template +struct SerialGetrfInternalDevice { + template + KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &ipiv); +}; + +template <> +template +KOKKOS_INLINE_FUNCTION int SerialGetrfInternalDevice::invoke(const AViewType &A, + const PivViewType &ipiv) { + using ScalarType = typename AViewType::non_const_value_type; + + const int m = A.extent(0), n = A.extent(1), init_piv_size = ipiv.extent(0); + + Stack stack; + int initial[7] = {0, 0, 0, 0, m, n, init_piv_size}; + stack.push(initial); + + // Quick return if possible + if (m <= 0 || n <= 0) return 0; + + while (!stack.isEmpty()) { + // Firstly, make a subview based on the current state + int current[7]; + stack.pop(current); + + int state = current[0], m_start = current[1], n_start = current[2], piv_start = current[3], m_size = current[4], + n_size = current[5], piv_size = current[6]; + + // Quick return if possible + if (m_size <= 0 || n_size <= 0) continue; + + auto A_current = Kokkos::subview(A, Kokkos::pair(m_start, m_start + m_size), + Kokkos::pair(n_start, n_start + n_size)); + + auto ipiv_current = Kokkos::subview(ipiv, Kokkos::pair(piv_start, piv_start + piv_size)); + auto n1 = Kokkos::min(m_size, n_size) / 2; + + // split A into two submatrices A = [A0, A1] + auto A0 = Kokkos::subview(A_current, Kokkos::ALL, Kokkos::pair(0, n1)); + auto A1 = Kokkos::subview(A_current, Kokkos::ALL, Kokkos::pair(n1, n_size)); + auto ipiv0 = Kokkos::subview(ipiv_current, Kokkos::pair(0, n1)); + auto ipiv1 = Kokkos::subview(ipiv_current, Kokkos::pair(n1, Kokkos::min(m_size, n_size))); + + // split A into four submatrices + // A = [[A00, A01], + // [A10, A11]] + auto A00 = Kokkos::subview(A_current, Kokkos::pair(0, n1), Kokkos::pair(0, n1)); + auto A01 = Kokkos::subview(A_current, Kokkos::pair(0, n1), Kokkos::pair(n1, n_size)); + auto A10 = Kokkos::subview(A_current, Kokkos::pair(n1, m_size), Kokkos::pair(0, n1)); + auto A11 = Kokkos::subview(A_current, Kokkos::pair(n1, m_size), Kokkos::pair(n1, n_size)); + + if (state == 0) { + // start state + if (m_size == 1) { + ipiv_current(0) = 0; + if (A_current(0, 0) == 0) return 1; + continue; + } else if (n_size == 1) { + // Use unblocked code for one column case + // Compute machine safe minimum + auto col_A = Kokkos::subview(A_current, Kokkos::ALL, 0); + + int i = SerialIamax::invoke(col_A); + ipiv_current(0) = i; + + if (A_current(i, 0) == 0) return 1; + + // Apply the interchange + if (i != 0) { + Kokkos::kokkos_swap(A_current(i, 0), A_current(0, 0)); + } + + // Compute elements + const ScalarType alpha = 1.0 / A_current(0, 0); + auto sub_col_A = Kokkos::subview(A_current, Kokkos::pair(1, m_size), 0); + [[maybe_unused]] auto info_scal = KokkosBlas::SerialScale::invoke(alpha, sub_col_A); + continue; + } + + // Push states onto the stack in reverse order of how they are executed + // in the recursive version + int after_second[7] = {2, m_start, n_start, piv_start, m_size, n_size, piv_size}; + int second[7] = {0, + m_start + n1, + n_start + n1, + piv_start + n1, + m_size - n1, + n_size - n1, + static_cast(Kokkos::min(m_size, n_size)) - n1}; + int after_first[7] = {1, m_start, n_start, piv_start, m_size, n_size, piv_size}; + int first[7] = {0, m_start, n_start, piv_start, m_size, n1, n1}; + + stack.push(after_second); + stack.push(second); + stack.push(after_first); + stack.push(first); + + } else if (state == 1) { + // after first recursive call + // Factor A0 = [[A00], + // [A10]] + + // Apply interchanges to A1 = [[A01], + // [A11]] + KokkosBatched::SerialLaswp::invoke(ipiv0, A1); + + // Solve A00 * X = A01 + [[maybe_unused]] auto info_trsm = + KokkosBatched::SerialTrsm::invoke(1.0, A00, A01); + + // Update A11 = A11 - A10 * A01 + [[maybe_unused]] auto info_gemm = + KokkosBatched::SerialGemm::invoke( + -1.0, A10, A01, 1.0, A11); + + } else if (state == 2) { + // after second recursive call + // Apply interchanges to A10 + KokkosBatched::SerialLaswp::invoke(ipiv1, A10); + + // Pivot indices + for (int i = n1; i < Kokkos::min(m_size, n_size); i++) { + ipiv_current(i) += n1; + } + } + } + return 0; +} + +} // namespace Impl +} // namespace KokkosBatched + +#endif // KOKKOSBATCHED_GETRF_SERIAL_INTERNAL_HPP_ diff --git a/batched/dense/src/KokkosBatched_Getrf.hpp b/batched/dense/src/KokkosBatched_Getrf.hpp new file mode 100644 index 0000000000..1b1bcac903 --- /dev/null +++ b/batched/dense/src/KokkosBatched_Getrf.hpp @@ -0,0 +1,65 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER +#ifndef KOKKOSBATCHED_GETRF_HPP_ +#define KOKKOSBATCHED_GETRF_HPP_ + +#include + +/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) + +namespace KokkosBatched { + +/// \brief Serial Batched Getrf: +/// Compute a LU factorization of a general m-by-n matrix A using partial +/// pivoting with row interchanges. +/// The factorization has the form +/// A = P * L * U +/// where P is a permutation matrix, L is lower triangular with unit +/// diagonal elements (lower trapezoidal if m > n), and U is upper +/// triangular (upper trapezoidal if m < n). +/// +/// This is the recusive version of the algorithm. It divides the matrix +/// into four submatrices: +/// A = [[A00, A01], +/// [A10, A11]] +/// where A00 is a square matrix of size n0, A11 is a matrix of size n1 by n1 +/// with n0 = min(m, n) / 2 and n1 = n - n0. +/// +/// This function calls itself to factorize A0 = [[A00], +// [A10]] +/// do the swaps on A1 = [[A01], +/// [A11]] +/// solve A01, update A11, then calls itself to factorize A11 +/// and do the swaps on A10. +/// \tparam AViewType: Input type for the matrix, needs to be a 2D view +/// \tparam PivViewType: Input type for the pivot indices, needs to be a 1D view +/// +/// \param A [inout]: A is a m by n general matrix, a rank 2 view +/// \param piv [out]: On exit, the pivot indices, a rank 1 view +/// +/// No nested parallel_for is used inside of the function. +/// + +template +struct SerialGetrf { + template + KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const PivViewType &piv); +}; +} // namespace KokkosBatched + +#include "KokkosBatched_Getrf_Serial_Impl.hpp" + +#endif // KOKKOSBATCHED_GETRF_HPP_ diff --git a/batched/dense/unit_test/Test_Batched_Dense.hpp b/batched/dense/unit_test/Test_Batched_Dense.hpp index 2378e5ff01..37673e1a5e 100644 --- a/batched/dense/unit_test/Test_Batched_Dense.hpp +++ b/batched/dense/unit_test/Test_Batched_Dense.hpp @@ -63,6 +63,7 @@ #include "Test_Batched_SerialPbtrs_Complex.hpp" #include "Test_Batched_SerialLaswp.hpp" #include "Test_Batched_SerialIamax.hpp" +#include "Test_Batched_SerialGetrf.hpp" // Team Kernels #include "Test_Batched_TeamAxpy.hpp" diff --git a/batched/dense/unit_test/Test_Batched_SerialGetrf.hpp b/batched/dense/unit_test/Test_Batched_SerialGetrf.hpp new file mode 100644 index 0000000000..24d99594ca --- /dev/null +++ b/batched/dense/unit_test/Test_Batched_SerialGetrf.hpp @@ -0,0 +1,458 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER +/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr) +#include +#include +#include +#include +#include +#include +#include "Test_Batched_DenseUtils.hpp" + +namespace Test { +namespace Getrf { + +template +struct Functor_BatchedSerialGetrf { + using execution_space = typename DeviceType::execution_space; + AViewType m_a; + PivViewType m_ipiv; + + KOKKOS_INLINE_FUNCTION + Functor_BatchedSerialGetrf(const AViewType &a, const PivViewType &ipiv) : m_a(a), m_ipiv(ipiv) {} + + KOKKOS_INLINE_FUNCTION + void operator()(const int k, int &info) const { + auto sub_a = Kokkos::subview(m_a, k, Kokkos::ALL(), Kokkos::ALL()); + auto sub_ipiv = Kokkos::subview(m_ipiv, k, Kokkos::ALL()); + + info += KokkosBatched::SerialGetrf::invoke(sub_a, sub_ipiv); + } + + inline int run() { + using value_type = typename AViewType::non_const_value_type; + std::string name_region("KokkosBatched::Test::SerialGetrf"); + const std::string name_value_type = Test::value_type_name(); + std::string name = name_region + name_value_type; + int info_sum = 0; + Kokkos::Profiling::pushRegion(name.c_str()); + Kokkos::RangePolicy policy(0, m_a.extent(0)); + Kokkos::parallel_reduce(name.c_str(), policy, *this, info_sum); + Kokkos::Profiling::popRegion(); + return info_sum; + } +}; + +template +struct Functor_BatchedSerialGemm { + using execution_space = typename DeviceType::execution_space; + AViewType m_a; + BViewType m_b; + CViewType m_c; + ScalarType m_alpha, m_beta; + + KOKKOS_INLINE_FUNCTION + Functor_BatchedSerialGemm(const ScalarType alpha, const AViewType &a, const BViewType &b, const ScalarType beta, + const CViewType &c) + : m_a(a), m_b(b), m_c(c), m_alpha(alpha), m_beta(beta) {} + + KOKKOS_INLINE_FUNCTION + void operator()(const int k) const { + auto sub_a = Kokkos::subview(m_a, k, Kokkos::ALL(), Kokkos::ALL()); + auto sub_b = Kokkos::subview(m_b, k, Kokkos::ALL(), Kokkos::ALL()); + auto sub_c = Kokkos::subview(m_c, k, Kokkos::ALL(), Kokkos::ALL()); + + KokkosBatched::SerialGemm::invoke( + m_alpha, sub_a, sub_b, m_beta, sub_c); + } + + inline void run() { + using value_type = typename AViewType::non_const_value_type; + std::string name_region("KokkosBatched::Test::SerialGetrf"); + const std::string name_value_type = Test::value_type_name(); + std::string name = name_region + name_value_type; + Kokkos::RangePolicy policy(0, m_a.extent(0)); + Kokkos::parallel_for(name.c_str(), policy, *this); + } +}; + +/// \brief Implementation details of batched getrf test +/// LU factorization with partial pivoting +/// 4x4 matrix +/// A = [[1. 0. 0. 0.] +/// [0. 1. 0. 0.] +/// [0. 0. 1. 0.] +/// [0. 0. 0. 1.]] +/// LU = [[1. 0. 0. 0.] +/// [0. 1. 0. 0.] +/// [0. 0. 1. 0.] +/// [0. 0. 0. 1.]] +/// piv = [0 1 2 3] +/// +/// 3x4 matrix +/// A1 = [[1. 0. 0. 0.] +/// [0. 1. 0. 0.] +/// [0. 0. 1. 0.]] +/// LU1 = [[1. 0. 0. 0.] +/// [0. 1. 0. 0.] +/// [0. 0. 1. 0.]] +/// piv1 = [0 1 2] +/// +/// 4x3 matrix +/// A2 = [[1. 0. 0.] +/// [0. 1. 0.] +/// [0. 0. 1.] +/// [0. 0. 0.]] +/// LU2 = [[1. 0. 0.] +/// [0. 1. 0.] +/// [0. 0. 1.] +/// [0. 0. 0.]] +/// piv2 = [0 1 2] +/// 3x3 more general matrix +/// which satisfies PA = LU +/// P = [[0 0 1] +/// [1 0 0] +/// [0 1 0]] +/// A = [[1 2 3] +/// [2 -4 6] +/// [3 -9 -3]] +/// L = [[1 0 0] +/// [1/3 1 0] +/// [2/3 2/5 1]] +/// U = [[-3 -9 -3] +/// [ 0 5 4] +/// [ 0 0 32/5]] +/// Note P is obtained by piv = [2 2 2] +/// We compare the non-diagnoal elements of L only, which is +/// NL = [[0 0 0] +/// [1/3 0 0] +/// [2/3 2/5 0]] +/// \param Nb [in] Batch size of matrices +template +void impl_test_batched_getrf_analytical(const int Nb) { + using ats = typename Kokkos::ArithTraits; + using RealType = typename ats::mag_type; + using View3DType = Kokkos::View; + using PivView2DType = Kokkos::View; + + constexpr int M = 4, N = 3; + View3DType A0("A0", Nb, M, M), LU0("LU0", Nb, M, M); + PivView2DType ipiv0("ipiv0", Nb, M), ipiv0_ref("ipiv0_ref", Nb, M); + + // Non-square matrix + View3DType A1("A1", Nb, N, M), LU1("LU1", Nb, N, M); + PivView2DType ipiv1("ipiv1", Nb, N), ipiv1_ref("ipiv1_ref", Nb, N); + + View3DType A2("A2", Nb, M, N), LU2("LU2", Nb, M, N); + PivView2DType ipiv2("ipiv2", Nb, N), ipiv2_ref("ipiv1_ref", Nb, N); + + // Complicated matrix + View3DType A3("A3", Nb, N, N), LU3("LU3", Nb, N, N), L3("L3", Nb, N, N), U3("U3", Nb, N, N), + L3_ref("L3_ref", Nb, N, N), U3_ref("U3_ref", Nb, N, N); + PivView2DType ipiv3("ipiv3", Nb, N), ipiv3_ref("ipiv3_ref", Nb, N); + + auto h_A0 = Kokkos::create_mirror_view(A0); + auto h_A1 = Kokkos::create_mirror_view(A1); + auto h_A2 = Kokkos::create_mirror_view(A2); + auto h_A3 = Kokkos::create_mirror_view(A3); + auto h_L3_ref = Kokkos::create_mirror_view(L3_ref); + auto h_U3_ref = Kokkos::create_mirror_view(U3_ref); + auto h_ipiv0_ref = Kokkos::create_mirror_view(ipiv0_ref); + auto h_ipiv1_ref = Kokkos::create_mirror_view(ipiv1_ref); + auto h_ipiv2_ref = Kokkos::create_mirror_view(ipiv2_ref); + auto h_ipiv3_ref = Kokkos::create_mirror_view(ipiv3_ref); + for (int ib = 0; ib < Nb; ib++) { + for (int i = 0; i < M; i++) { + h_ipiv0_ref(ib, i) = i; + for (int j = 0; j < M; j++) { + h_A0(ib, i, j) = i == j ? 1.0 : 0.0; + } + } + + for (int i = 0; i < N; i++) { + h_ipiv1_ref(ib, i) = i; + h_ipiv2_ref(ib, i) = i; + for (int j = 0; j < M; j++) { + h_A1(ib, i, j) = i == j ? 1.0 : 0.0; + h_A2(ib, j, i) = i == j ? 1.0 : 0.0; + } + } + + h_A3(ib, 0, 0) = 1.0; + h_A3(ib, 0, 1) = 2.0; + h_A3(ib, 0, 2) = 3.0; + h_A3(ib, 1, 0) = 2.0; + h_A3(ib, 1, 1) = -4.0; + h_A3(ib, 1, 2) = 6.0; + h_A3(ib, 2, 0) = 3.0; + h_A3(ib, 2, 1) = -9.0; + h_A3(ib, 2, 2) = -3.0; + + h_L3_ref(ib, 0, 0) = 0.0; + h_L3_ref(ib, 0, 1) = 0.0; + h_L3_ref(ib, 0, 2) = 0.0; + h_L3_ref(ib, 1, 0) = 1.0 / 3.0; + h_L3_ref(ib, 1, 1) = 0.0; + h_L3_ref(ib, 1, 2) = 0.0; + h_L3_ref(ib, 2, 0) = 2.0 / 3.0; + h_L3_ref(ib, 2, 1) = 2.0 / 5.0; + h_L3_ref(ib, 2, 2) = 0.0; + + h_U3_ref(ib, 0, 0) = 3.0; + h_U3_ref(ib, 0, 1) = -9.0; + h_U3_ref(ib, 0, 2) = -3.0; + h_U3_ref(ib, 1, 0) = 0.0; + h_U3_ref(ib, 1, 1) = 5.0; + h_U3_ref(ib, 1, 2) = 4.0; + h_U3_ref(ib, 2, 0) = 0.0; + h_U3_ref(ib, 2, 1) = 0.0; + h_U3_ref(ib, 2, 2) = 32.0 / 5.0; + + h_ipiv3_ref(ib, 0) = 2; + h_ipiv3_ref(ib, 1) = 2; + h_ipiv3_ref(ib, 2) = 2; + } + + Kokkos::deep_copy(A0, h_A0); + Kokkos::deep_copy(A1, h_A1); + Kokkos::deep_copy(A2, h_A2); + Kokkos::deep_copy(A3, h_A3); + Kokkos::deep_copy(LU0, A0); + Kokkos::deep_copy(LU1, A1); + Kokkos::deep_copy(LU2, A2); + Kokkos::deep_copy(LU3, A3); + + // getrf to factorize matrix A = P * L * U + auto info0 = Functor_BatchedSerialGetrf(LU0, ipiv0).run(); + auto info1 = Functor_BatchedSerialGetrf(LU1, ipiv1).run(); + auto info2 = Functor_BatchedSerialGetrf(LU2, ipiv2).run(); + auto info3 = Functor_BatchedSerialGetrf(LU3, ipiv3).run(); + + Kokkos::fence(); + EXPECT_EQ(info0, 0); + EXPECT_EQ(info1, 0); + EXPECT_EQ(info2, 0); + EXPECT_EQ(info3, 0); + + auto h_ipiv0 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), ipiv0); + auto h_ipiv1 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), ipiv1); + auto h_ipiv2 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), ipiv2); + auto h_ipiv3 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), ipiv3); + + for (int ib = 0; ib < Nb; ib++) { + // Check if piv0 = [0 1 2 3] + for (int i = 0; i < M; i++) { + EXPECT_EQ(h_ipiv0(ib, i), h_ipiv0_ref(ib, i)); + } + // Check if piv1 = [0 1 2] and piv2 = [0 1 2] + for (int i = 0; i < N; i++) { + EXPECT_EQ(h_ipiv1(ib, i), h_ipiv1_ref(ib, i)); + EXPECT_EQ(h_ipiv2(ib, i), h_ipiv2_ref(ib, i)); + } + // Check if piv3 = [2 2 2] + for (int i = 0; i < N; i++) { + EXPECT_EQ(h_ipiv3(ib, i), h_ipiv3_ref(ib, i)); + } + } + + // Reconstruct L and U from Factorized matrix A + // Copy non-diagonal lower triangular components to NL + create_triangular_matrix(LU3, L3, + -1); + + // Copy upper triangular components to U + create_triangular_matrix(LU3, U3); + + RealType eps = 1.0e1 * ats::epsilon(); + auto h_LU0 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), LU0); + auto h_LU1 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), LU1); + auto h_LU2 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), LU2); + + // Check if LU = A (permuted) + for (int ib = 0; ib < Nb; ib++) { + for (int i = 0; i < M; i++) { + for (int j = 0; j < M; j++) { + EXPECT_NEAR_KK(h_LU0(ib, i, j), h_A0(ib, i, j), eps); + } + } + for (int i = 0; i < N; i++) { + for (int j = 0; j < M; j++) { + EXPECT_NEAR_KK(h_LU1(ib, i, j), h_A1(ib, i, j), eps); + EXPECT_NEAR_KK(h_LU2(ib, j, i), h_A2(ib, j, i), eps); + } + } + } + + // For complicated matrix, we compare L and U with reference L and U + auto h_L3 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), L3); + auto h_U3 = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), U3); + for (int ib = 0; ib < Nb; ib++) { + for (int i = 0; i < N; i++) { + for (int j = 0; j < N; j++) { + EXPECT_NEAR_KK(h_L3(ib, i, j), h_L3_ref(ib, i, j), eps); + EXPECT_NEAR_KK(h_U3(ib, i, j), h_U3_ref(ib, i, j), eps); + } + } + } +} + +/// \brief Implementation details of batched getrf test +/// LU factorization with partial pivoting +/// +/// \param N [in] Batch size of matrices +/// \param BlkSize [in] Block size of matrix A +template +void impl_test_batched_getrf(const int N, const int BlkSize) { + using ats = typename Kokkos::ArithTraits; + using RealType = typename ats::mag_type; + using RealView2DType = Kokkos::View; + using View3DType = Kokkos::View; + using PivView2DType = Kokkos::View; + + View3DType A("A", N, BlkSize, BlkSize), A_reconst("A_reconst", N, BlkSize, BlkSize), NL("NL", N, BlkSize, BlkSize), + L("L", N, BlkSize, BlkSize), U("U", N, BlkSize, BlkSize), LU("LU", N, BlkSize, BlkSize), + I("I", N, BlkSize, BlkSize); + RealView2DType ones(Kokkos::view_alloc("ones", Kokkos::WithoutInitializing), N, BlkSize); + PivView2DType ipiv("ipiv", N, BlkSize); + + using execution_space = typename DeviceType::execution_space; + Kokkos::Random_XorShift64_Pool rand_pool(13718); + ScalarType randStart, randEnd; + + // Initialize A_reconst with random matrix + KokkosKernels::Impl::getRandomBounds(1.0, randStart, randEnd); + Kokkos::fill_random(A, rand_pool, randStart, randEnd); + Kokkos::deep_copy(LU, A); + + // Unit matrix I + Kokkos::deep_copy(ones, RealType(1.0)); + create_diagonal_matrix(ones, I); + + Kokkos::fence(); + + // getrf to factorize matrix A = P * L * U + auto info = Functor_BatchedSerialGetrf(LU, ipiv).run(); + + Kokkos::fence(); + EXPECT_EQ(info, 0); + + // Reconstruct L and U from Factorized matrix A + // Copy non-diagonal lower triangular components to NL + create_triangular_matrix(LU, NL, + -1); + + // Copy upper triangular components to U + create_triangular_matrix(LU, U); + + // Copy I to L + Kokkos::deep_copy(L, I); + + // Matrix matrix addition by Gemm + // NL + I by NL * I + L (==I) (result stored in L) + Functor_BatchedSerialGemm(1.0, NL, I, 1.0, L).run(); + + // LU = L * U + Functor_BatchedSerialGemm(1.0, L, U, 0.0, LU).run(); + + Kokkos::fence(); + + // permute A by ipiv + auto h_ipiv = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), ipiv); + auto h_A = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), A); + for (int ib = 0; ib < N; ib++) { + // Permute A by pivot vector + for (int i = 0; i < BlkSize; i++) { + for (int j = 0; j < BlkSize; j++) { + Kokkos::kokkos_swap(h_A(ib, h_ipiv(ib, i), j), h_A(ib, i, j)); + } + } + } + + RealType eps = 1.0e1 * ats::epsilon(); + + auto h_LU = Kokkos::create_mirror_view_and_copy(Kokkos::HostSpace(), LU); + // Check if LU = A (permuted) + for (int ib = 0; ib < N; ib++) { + for (int i = 0; i < BlkSize; i++) { + for (int j = 0; j < BlkSize; j++) { + EXPECT_NEAR_KK(h_LU(ib, i, j), h_A(ib, i, j), eps); + } + } + } +} + +} // namespace Getrf +} // namespace Test + +template +int test_batched_getrf() { +#if defined(KOKKOSKERNELS_INST_LAYOUTLEFT) + { + using LayoutType = Kokkos::LayoutLeft; + Test::Getrf::impl_test_batched_getrf_analytical(1); + Test::Getrf::impl_test_batched_getrf_analytical(2); + for (int i = 0; i < 10; i++) { + Test::Getrf::impl_test_batched_getrf(1, i); + Test::Getrf::impl_test_batched_getrf(2, i); + } + } +#endif +#if defined(KOKKOSKERNELS_INST_LAYOUTRIGHT) + { + using LayoutType = Kokkos::LayoutRight; + Test::Getrf::impl_test_batched_getrf_analytical(1); + Test::Getrf::impl_test_batched_getrf_analytical(2); + for (int i = 0; i < 10; i++) { + Test::Getrf::impl_test_batched_getrf(1, i); + Test::Getrf::impl_test_batched_getrf(2, i); + } + } +#endif + + return 0; +} + +#if defined(KOKKOSKERNELS_INST_FLOAT) +TEST_F(TestCategory, test_batched_getrf_float) { + using algo_tag_type = typename KokkosBatched::Algo::Getrf::Unblocked; + + test_batched_getrf(); +} +#endif + +#if defined(KOKKOSKERNELS_INST_DOUBLE) +TEST_F(TestCategory, test_batched_getrf_double) { + using algo_tag_type = typename KokkosBatched::Algo::Getrf::Unblocked; + + test_batched_getrf(); +} +#endif + +#if defined(KOKKOSKERNELS_INST_COMPLEX_FLOAT) +TEST_F(TestCategory, test_batched_getrf_fcomplex) { + using algo_tag_type = typename KokkosBatched::Algo::Getrf::Unblocked; + + test_batched_getrf, algo_tag_type>(); +} +#endif + +#if defined(KOKKOSKERNELS_INST_COMPLEX_DOUBLE) +TEST_F(TestCategory, test_batched_getrf_dcomplex) { + using algo_tag_type = typename KokkosBatched::Algo::Getrf::Unblocked; + + test_batched_getrf, algo_tag_type>(); +} +#endif diff --git a/blas/impl/KokkosBlas_util.hpp b/blas/impl/KokkosBlas_util.hpp index c0777ac9ea..53916bd23e 100644 --- a/blas/impl/KokkosBlas_util.hpp +++ b/blas/impl/KokkosBlas_util.hpp @@ -87,6 +87,7 @@ struct Algo { using UTV = Level3; using Pttrf = Level3; using Pttrs = Level3; + using Getrf = Level3; struct Level2 { struct Unblocked {};