Skip to content

Commit

Permalink
Set Kokkos Kernels to 4.5.01
Browse files Browse the repository at this point in the history
  • Loading branch information
tpadioleau committed Dec 24, 2024
1 parent 0bc2306 commit 651307f
Show file tree
Hide file tree
Showing 8 changed files with 406 additions and 8 deletions.
2 changes: 1 addition & 1 deletion .gitmodules
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
url = https://github.com/kokkos/kokkos.git
[submodule "vendor/kokkos-kernels"]
path = vendor/kokkos-kernels
url = https://github.com/yasahi-hpc/kokkos-kernels.git
url = https://github.com/kokkos/kokkos-kernels.git
[submodule "vendor/doxygen-awesome-css"]
path = vendor/doxygen-awesome-css
url = https://github.com/jothepro/doxygen-awesome-css.git
Expand Down
64 changes: 64 additions & 0 deletions include/ddc/kernels/splines/KokkosBatched_Gbtrs.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef KOKKOSBATCHED_GBTRS_HPP_
#define KOKKOSBATCHED_GBTRS_HPP_

#include <KokkosBatched_Util.hpp>

/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr)

namespace KokkosBatched {

/// \brief Serial Batched Gbtrs:
///
/// Solve A_l x_l = b_l for all l = 0, ..., N
/// with a general band matrix A using the LU factorization computed
/// by gbtrf.
///
/// \tparam ArgTrans: Tag selecting the non-transposed or the transposed solve;
///        KokkosBatched_Gbtrs_Serial_Impl.hpp specializes this struct for
///        Trans::NoTranspose and Trans::Transpose
/// \tparam ArgAlgo: Algorithm tag; the provided specializations use
///        Algo::Level3::Unblocked
/// \tparam AViewType: Input type for the matrix, needs to be a 2D view
/// \tparam BViewType: Input type for the right-hand side and the solution,
///        needs to be a 1D view
/// \tparam PivViewType: Integer type for pivot indices, needs to be a 1D view
///
/// \param A [in]: A is a ldab by n banded matrix.
///        Details of the LU factorization of the band matrix A, as computed by
///        gbtrf. Using 1-based row numbering, U is stored as an upper
///        triangular band matrix with KL+KU superdiagonals in rows 1 to
///        KL+KU+1, and the multipliers used during the factorization are
///        stored in rows KL+KU+2 to 2*KL+KU+1 (i.e. starting at 0-based row
///        KL+KU+1, which is where the implementation reads them).
/// \param b [inout]: right-hand side and the solution
/// \param piv [in]: The pivot indices; for 0 <= i < N, row i of the matrix
///        was interchanged with row piv(i). The indices are 0-based: the
///        implementation compares piv(i) with its 0-based loop index and uses
///        it directly to index b.
/// \param kl [in]: kl specifies the number of subdiagonals within the band
///        of A. kl >= 0
/// \param ku [in]: ku specifies the number of superdiagonals within the band
///        of A. ku >= 0
///
/// \return 0 on success. The provided specializations return 1 when the
///         argument checks, which are only compiled in when
///         KOKKOSKERNELS_DEBUG_LEVEL > 0, reject an input.
///
/// No nested parallel_for is used inside of the function.
///

template <typename ArgTrans, typename ArgAlgo>
struct SerialGbtrs {
template <typename AViewType, typename BViewType, typename PivViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A,
const BViewType &b,
const PivViewType &piv, const int kl,
const int ku);
};
} // namespace KokkosBatched

#include "KokkosBatched_Gbtrs_Serial_Impl.hpp"

#endif // KOKKOSBATCHED_GBTRS_HPP_
176 changes: 176 additions & 0 deletions include/ddc/kernels/splines/KokkosBatched_Gbtrs_Serial_Impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,176 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef KOKKOSBATCHED_GBTRS_SERIAL_IMPL_HPP_
#define KOKKOSBATCHED_GBTRS_SERIAL_IMPL_HPP_

#include <Kokkos_Swap.hpp>
#include <KokkosBatched_Util.hpp>
#include <KokkosBlas2_gemv.hpp>
#include <KokkosBatched_Tbsv.hpp>

namespace KokkosBatched {

/// \brief Validates the arguments of the serial gbtrs solve.
///
/// The view kinds and ranks are always checked at compile time. The runtime
/// checks on kl, ku and the extents are only compiled in when
/// KOKKOSKERNELS_DEBUG_LEVEL > 0; otherwise the function returns 0
/// unconditionally.
///
/// \tparam AViewType: Type of the band matrix, must be a rank 2 Kokkos::View
/// \tparam BViewType: Type of the right-hand side, must be a rank 1 Kokkos::View
///
/// \param A [in]: Band matrix; extent(0) must be at least 2 * kl + ku + 1
/// \param b [in]: Right-hand side; extent(0) must be at least
///        max(1, A.extent(1))
/// \param kl [in]: Number of subdiagonals, must be >= 0
/// \param ku [in]: Number of superdiagonals, must be >= 0
///
/// \return 0 if every check passes, 1 (after printing a diagnostic) on the
///         first failing check.
template <typename AViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int checkGbtrsInput(
    [[maybe_unused]] const AViewType &A, [[maybe_unused]] const BViewType &b,
    [[maybe_unused]] const int kl, [[maybe_unused]] const int ku) {
  static_assert(Kokkos::is_view_v<AViewType>,
                "KokkosBatched::gbtrs: AViewType is not a Kokkos::View.");
  static_assert(Kokkos::is_view_v<BViewType>,
                "KokkosBatched::gbtrs: BViewType is not a Kokkos::View.");
  static_assert(AViewType::rank == 2,
                "KokkosBatched::gbtrs: AViewType must have rank 2.");
  static_assert(BViewType::rank == 1,
                "KokkosBatched::gbtrs: BViewType must have rank 1.");
#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
  if (kl < 0) {
    Kokkos::printf(
        "KokkosBatched::gbtrs: input parameter kl must not be less than 0: kl "
        "= "
        "%d\n",
        kl);
    return 1;
  }

  if (ku < 0) {
    Kokkos::printf(
        "KokkosBatched::gbtrs: input parameter ku must not be less than 0: ku "
        "= "
        "%d\n",
        ku);
    return 1;
  }

  const int lda = A.extent(0);
  const int n   = A.extent(1);
  // The band storage needs 2 * kl + ku + 1 rows, so a SMALLER lda is the
  // error; the diagnostic states the requirement accordingly.
  if (lda < (2 * kl + ku + 1)) {
    Kokkos::printf(
        "KokkosBatched::gbtrs: leading dimension of A must not be smaller "
        "than 2 * kl + ku + 1: "
        "lda = %d, kl = %d, ku = %d\n",
        lda, kl, ku);
    return 1;
  }

  const int ldb = b.extent(0);
  // b must hold at least one entry per column of A.
  if (ldb < Kokkos::max(1, n)) {
    Kokkos::printf(
        "KokkosBatched::gbtrs: leading dimension of b must not be smaller "
        "than max(1, n): "
        "ldb = %d, n = %d\n",
        ldb, n);
    return 1;
  }

#endif
  return 0;
}

//// Non-transpose ////
/// Non-transposed serial gbtrs for a single rank-1 right-hand side.
///
/// When kl > 0, the row interchanges recorded in piv and the stored
/// multipliers (rows kl + ku + 1 and below of A, 0-based) are first applied to
/// b column by column; the banded upper triangular system with kl + ku
/// superdiagonals is then solved in place with SerialTbsv.
///
/// \return 0 on success (including the n == 0 quick return), or the non-zero
///         value reported by checkGbtrsInput.
template <>
struct SerialGbtrs<Trans::NoTranspose, Algo::Level3::Unblocked> {
  template <typename AViewType, typename BViewType, typename PivViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A,
                                           const BViewType &b,
                                           const PivViewType &piv, const int kl,
                                           const int ku) {
    // Quick return if possible
    const int n = A.extent(1);
    if (n == 0) return 0;

    const int info = checkGbtrsInput(A, b, kl, ku);
    if (info) return info;

    // Without subdiagonals there are no multipliers to apply.
    const bool lnoti = kl > 0;
    // 0-based row of A holding the first multiplier of each column.
    const int kd = ku + kl + 1;
    if (lnoti) {
      for (int j = 0; j < n - 1; ++j) {
        const int lm = Kokkos::min(kl, n - j - 1);
        const auto l = piv(j);
        // If pivot index is not j, swap rows l and j in b
        if (l != j) {
          Kokkos::kokkos_swap(b(l), b(j));
        }

        // Perform a rank-1 update of the remaining part of the current column
        // (ger)
        for (int i = 0; i < lm; ++i) {
          b(j + 1 + i) = b(j + 1 + i) - A(kd + i, j) * b(j);
        }
      }
    }

    // Solve U*x = b, overwriting b with x. The Tbsv algorithm tag is used for
    // consistency with the transposed specialization.
    [[maybe_unused]] auto info_tbsv =
        KokkosBatched::SerialTbsv<Uplo::Upper, Trans::NoTranspose,
                                  Diag::NonUnit,
                                  Algo::Tbsv::Unblocked>::invoke(A, b, kl + ku);

    return 0;
  }
};

//// Transpose ////
/// Transposed serial gbtrs for a single rank-1 right-hand side.
///
/// The banded upper triangular factor (kl + ku superdiagonals) is applied
/// first through SerialTbsv with Trans::Transpose. When kl > 0, the columns
/// are then visited from n - 2 down to 0: b(j) is updated from the stored
/// multipliers of column j and the entries of b below it, after which the row
/// interchange recorded in piv(j) is applied.
///
/// \return 0 on success (including the n == 0 quick return), or the non-zero
///         value reported by checkGbtrsInput.
template <>
struct SerialGbtrs<Trans::Transpose, Algo::Level3::Unblocked> {
  template <typename AViewType, typename BViewType, typename PivViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A,
                                           const BViewType &b,
                                           const PivViewType &piv, const int kl,
                                           const int ku) {
    // Nothing to solve for an empty system.
    const int num_cols = A.extent(1);
    if (num_cols == 0) return 0;

    const auto check_status = checkGbtrsInput(A, b, kl, ku);
    if (check_status) return check_status;

    // Banded triangular solve against the upper factor, in place on b.
    using UpperBandSolve =
        KokkosBatched::SerialTbsv<Uplo::Upper, Trans::Transpose, Diag::NonUnit,
                                  Algo::Tbsv::Unblocked>;
    [[maybe_unused]] auto upper_status = UpperBandSolve::invoke(A, b, kl + ku);

    // No subdiagonals means no multipliers and no interchanges to apply.
    if (kl <= 0) return 0;

    // 0-based row of A holding the first multiplier of each column.
    const int first_mult_row = ku + kl + 1;

    for (int col = num_cols - 2; col >= 0; --col) {
      const int num_mult = Kokkos::min(kl, num_cols - col - 1);

      // Slices handed to the transposed gemv: the entries of b below `col`,
      // the multipliers stored for `col`, and the slice of b starting at
      // `col` that receives the update.
      auto rhs_below =
          Kokkos::subview(b, Kokkos::pair(col + 1, col + 1 + num_mult));
      auto multipliers = Kokkos::subview(
          A, Kokkos::pair(first_mult_row, first_mult_row + num_mult), col);
      auto rhs_target = Kokkos::subview(b, Kokkos::pair(col, col + num_mult));

      // Single-row gemv (m = 1) with alpha = -1 and beta = 1.
      [[maybe_unused]] auto gemv_status =
          KokkosBlas::Impl::SerialGemvInternal<Algo::Gemv::Unblocked>::invoke(
              1, rhs_below.extent(0), -1.0, rhs_below.data(),
              rhs_below.stride_0(), rhs_below.stride_0(), multipliers.data(),
              multipliers.stride_0(), 1.0, rhs_target.data(),
              rhs_target.stride_0());

      // Apply the interchange recorded for this column, if any.
      auto pivot_row = piv(col);
      if (pivot_row != col) {
        Kokkos::kokkos_swap(b(pivot_row), b(col));
      }
    }

    return 0;
  }
};
} // namespace KokkosBatched

#endif // KOKKOSBATCHED_GBTRS_SERIAL_IMPL_HPP_
53 changes: 53 additions & 0 deletions include/ddc/kernels/splines/KokkosBatched_Getrs.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef KOKKOSBATCHED_GETRS_HPP_
#define KOKKOSBATCHED_GETRS_HPP_

#include <KokkosBatched_Util.hpp>

/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr)

namespace KokkosBatched {

/// \brief Serial Batched Getrs:
/// Solve a system of linear equations
/// A * x = b or A**T * x = b
/// with a general N-by-N matrix A using LU factorization computed
/// by Getrf.
/// \tparam ArgTrans: Tag selecting which of the two systems above is solved
/// \tparam ArgAlgo: Algorithm tag
/// \tparam AViewType: Input type for the matrix, needs to be a 2D view
/// \tparam PivViewType: Input type for the pivot indices, needs to be a 1D view
/// \tparam BViewType: Input type for the right-hand side and the solution,
/// needs to be a 1D view
///
/// \param A [inout]: A is a m by n general matrix, a rank 2 view
/// NOTE(review): the brief describes A as N-by-N and already factorized by
/// Getrf, so this is presumably an input holding the LU factors; confirm
/// against the implementation.
/// \param piv [in]: The pivot indices, a rank 1 view (presumably those
/// produced by Getrf; confirm against the implementation)
/// \param b [inout]: right-hand side and the solution, a rank 1 view
///
/// No nested parallel_for is used inside of the function.
///

template <typename ArgTrans, typename ArgAlgo>
struct SerialGetrs {
template <typename AViewType, typename PivViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A,
const PivViewType &piv,
const BViewType &b);
};
} // namespace KokkosBatched

#include "KokkosBatched_Getrs_Serial_Impl.hpp"

#endif // KOKKOSBATCHED_GETRS_HPP_
Loading

0 comments on commit 651307f

Please sign in to comment.