Skip to content

Commit

Permalink
implement batched serial pbtrs (#2330)
Browse files Browse the repository at this point in the history
* implement batched serial pbtrs

Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>

* format

Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>

* fix: docstrings for pbtrs

Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>

* move implementation details under Impl namespace

Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>

* Add missing check for pbtrs

Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>

* fix: conflicts

Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>

* fix: use EXPECT_NEAR_KK_REL for check

Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>

* remove unused variable xm from pbtrs impl

Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>

---------

Signed-off-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>
Co-authored-by: Yuuichi Asahi <y.asahi@nr.titech.ac.jp>
  • Loading branch information
yasahi-hpc and Yuuichi Asahi authored Oct 25, 2024
1 parent fc39467 commit c283c44
Show file tree
Hide file tree
Showing 8 changed files with 631 additions and 0 deletions.
95 changes: 95 additions & 0 deletions batched/dense/impl/KokkosBatched_Pbtrs_Serial_Impl.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef KOKKOSBATCHED_PBTRS_SERIAL_IMPL_HPP_
#define KOKKOSBATCHED_PBTRS_SERIAL_IMPL_HPP_

#include <KokkosBatched_Util.hpp>
#include "KokkosBatched_Pbtrs_Serial_Internal.hpp"

/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr)

namespace KokkosBatched {
namespace Impl {

/// \brief Validate the arguments handed to pbtrs.
///
/// Compile time: both arguments must be Kokkos::Views, A of rank 2 and x of
/// rank 1.
/// Run time (only when KOKKOSKERNELS_DEBUG_LEVEL > 0): A must have at least
/// one row and x must hold at least max(1, A.extent(1)) entries. A diagnostic
/// is printed for the first violated condition.
///
/// \tparam AViewType: type of the banded matrix view, must be rank 2
/// \tparam XViewType: type of the right-hand side / solution view, must be rank 1
///
/// \param A [in]: banded matrix view; only its extents are inspected
/// \param x [in]: right-hand side / solution view; only its extent is inspected
///
/// \return 0 if the inputs are accepted (always 0 when the run-time checks are
///         compiled out), 1 if a run-time check failed
template <typename AViewType, typename XViewType>
KOKKOS_INLINE_FUNCTION static int checkPbtrsInput([[maybe_unused]] const AViewType &A,
                                                  [[maybe_unused]] const XViewType &x) {
  static_assert(Kokkos::is_view_v<AViewType>, "KokkosBatched::pbtrs: AViewType is not a Kokkos::View.");
  static_assert(Kokkos::is_view_v<XViewType>, "KokkosBatched::pbtrs: XViewType is not a Kokkos::View.");
  static_assert(AViewType::rank == 2, "KokkosBatched::pbtrs: AViewType must have rank 2.");
  static_assert(XViewType::rank == 1, "KokkosBatched::pbtrs: XViewType must have rank 1.");

#if (KOKKOSKERNELS_DEBUG_LEVEL > 0)
  const int ldb = x.extent(0);
  const int lda = A.extent(0), n = A.extent(1);
  // The number of off-diagonals is derived from the row count of A.
  const int kd = lda - 1;
  if (kd < 0) {
    // Three conversions in the format string require three arguments:
    // the offending leading dimension, followed by the shape of A.
    Kokkos::printf(
        "KokkosBatched::pbtrs: leading dimension of A must not be less than 1: %d, A: "
        "%d "
        "x %d \n",
        lda, lda, n);
    return 1;
  }
  if (ldb < Kokkos::max(1, n)) {
    Kokkos::printf(
        "KokkosBatched::pbtrs: Dimensions of x and A do not match: x: %d, A: "
        "%d "
        "x %d \n"
        "x.extent(0) must be larger or equal to A.extent(1) \n",
        ldb, lda, n);
    return 1;
  }
#endif
  return 0;
}
} // namespace Impl

//// Lower ////
template <>
struct SerialPbtrs<Uplo::Lower, Algo::Pbtrs::Unblocked> {
  /// Lower-triangular, unblocked pbtrs: validates the views and forwards
  /// their raw pointers, extents and strides to the internal solver.
  ///
  /// \param A [in]: rank-2 banded matrix view; A.extent(0) - 1 is passed on as kd
  /// \param x [inout]: rank-1 view forwarded to the solver through a non-const pointer
  ///
  /// \return 0 for an empty system, the non-zero code from the input check if
  ///         it fails, otherwise the value returned by the internal solver
  template <typename AViewType, typename XViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x) {
    // Nothing to solve when the matrix has no columns.
    if (A.extent(1) == 0) return 0;

    // Bail out with the checker's status if the inputs are rejected.
    if (const auto status = KokkosBatched::Impl::checkPbtrsInput(A, x); status) return status;

    using InternalSolver = KokkosBatched::Impl::SerialPbtrsInternalLower<Algo::Pbtrs::Unblocked>;
    const int num_off_diagonals = A.extent(0) - 1;
    return InternalSolver::invoke(A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(),
                                  num_off_diagonals);
  }
};

//// Upper ////
template <>
struct SerialPbtrs<Uplo::Upper, Algo::Pbtrs::Unblocked> {
  /// Upper-triangular, unblocked pbtrs: validates the views and forwards
  /// their raw pointers, extents and strides to the internal solver.
  ///
  /// \param A [in]: rank-2 banded matrix view; A.extent(0) - 1 is passed on as kd
  /// \param x [inout]: rank-1 view forwarded to the solver through a non-const pointer
  ///
  /// \return 0 for an empty system, the non-zero code from the input check if
  ///         it fails, otherwise the value returned by the internal solver
  template <typename AViewType, typename XViewType>
  KOKKOS_INLINE_FUNCTION static int invoke(const AViewType &A, const XViewType &x) {
    // Nothing to solve when the matrix has no columns.
    if (A.extent(1) == 0) return 0;

    // Bail out with the checker's status if the inputs are rejected.
    if (const auto status = KokkosBatched::Impl::checkPbtrsInput(A, x); status) return status;

    using InternalSolver = KokkosBatched::Impl::SerialPbtrsInternalUpper<Algo::Pbtrs::Unblocked>;
    const int num_off_diagonals = A.extent(0) - 1;
    return InternalSolver::invoke(A.extent(1), A.data(), A.stride_0(), A.stride_1(), x.data(), x.stride_0(),
                                  num_off_diagonals);
  }
};

} // namespace KokkosBatched

#endif // KOKKOSBATCHED_PBTRS_SERIAL_IMPL_HPP_
91 changes: 91 additions & 0 deletions batched/dense/impl/KokkosBatched_Pbtrs_Serial_Internal.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER

#ifndef KOKKOSBATCHED_PBTRS_SERIAL_INTERNAL_HPP_
#define KOKKOSBATCHED_PBTRS_SERIAL_INTERNAL_HPP_

#include "KokkosBatched_Util.hpp"
#include "KokkosBatched_Tbsv_Serial_Internal.hpp"

namespace KokkosBatched {
namespace Impl {

///
/// Serial Internal Impl
/// ====================

///
/// Lower
///

/// Pointer/stride level pbtrs for the lower-triangular case.
/// The primary template only declares invoke; each algorithm tag supplies
/// its own definition.
template <typename AlgoType>
struct SerialPbtrsInternalLower {
  template <typename ValueType>
  KOKKOS_INLINE_FUNCTION static int invoke(const int an, const ValueType *KOKKOS_RESTRICT A, const int as0,
                                           const int as1,
                                           /**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int kd);
};

/// Unblocked variant: two banded triangular solves applied in place to x,
/// first with L and then with its (conjugate) transpose.
///
/// \param an  [in]: forwarded to both Tbsv calls as their size argument
/// \param A   [in]: pointer to the banded matrix data
/// \param as0 [in]: stride of A along its first dimension
/// \param as1 [in]: stride of A along its second dimension
/// \param x   [inout]: right-hand side on entry, overwritten by the solves
/// \param xs0 [in]: stride of x
/// \param kd  [in]: forwarded to both Tbsv calls as their band-width argument
///
/// \return always 0
template <>
template <typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialPbtrsInternalLower<Algo::Pbtrs::Unblocked>::invoke(const int an,
                                                                                    const ValueType *KOKKOS_RESTRICT A,
                                                                                    const int as0, const int as1,
                                                                                    /**/ ValueType *KOKKOS_RESTRICT x,
                                                                                    const int xs0, const int kd) {
  using ForwardSolve  = SerialTbsvInternalLower<Algo::Tbsv::Unblocked>;
  using BackwardSolve = SerialTbsvInternalLowerTranspose<Algo::Tbsv::Unblocked>;

  // Conjugation is only requested for complex value types.
  constexpr bool conjugate = Kokkos::ArithTraits<ValueType>::is_complex;

  // NOTE(review): the leading `false` is presumably Tbsv's unit-diagonal
  // flag -- confirm against KokkosBatched_Tbsv_Serial_Internal.hpp.

  // Step 1: solve L*X = B, overwriting B with X.
  ForwardSolve::invoke(false, an, A, as0, as1, x, xs0, kd);

  // Step 2: solve L**T *X = B, overwriting B with X.
  BackwardSolve::invoke(false, conjugate, an, A, as0, as1, x, xs0, kd);

  return 0;
}

///
/// Upper
///

/// Pointer/stride level pbtrs for the upper-triangular case.
/// The primary template only declares invoke; each algorithm tag supplies
/// its own definition.
template <typename AlgoType>
struct SerialPbtrsInternalUpper {
  template <typename ValueType>
  KOKKOS_INLINE_FUNCTION static int invoke(const int an, const ValueType *KOKKOS_RESTRICT A, const int as0,
                                           const int as1,
                                           /**/ ValueType *KOKKOS_RESTRICT x, const int xs0, const int kd);
};

/// Unblocked variant: two banded triangular solves applied in place to x,
/// first with the (conjugate) transpose of U and then with U itself.
///
/// \param an  [in]: forwarded to both Tbsv calls as their size argument
/// \param A   [in]: pointer to the banded matrix data
/// \param as0 [in]: stride of A along its first dimension
/// \param as1 [in]: stride of A along its second dimension
/// \param x   [inout]: right-hand side on entry, overwritten by the solves
/// \param xs0 [in]: stride of x
/// \param kd  [in]: forwarded to both Tbsv calls as their band-width argument
///
/// \return always 0
template <>
template <typename ValueType>
KOKKOS_INLINE_FUNCTION int SerialPbtrsInternalUpper<Algo::Pbtrs::Unblocked>::invoke(const int an,
                                                                                    const ValueType *KOKKOS_RESTRICT A,
                                                                                    const int as0, const int as1,
                                                                                    /**/ ValueType *KOKKOS_RESTRICT x,
                                                                                    const int xs0, const int kd) {
  using TransposedSolve = SerialTbsvInternalUpperTranspose<Algo::Tbsv::Unblocked>;
  using DirectSolve     = SerialTbsvInternalUpper<Algo::Tbsv::Unblocked>;

  // Conjugation is only requested for complex value types.
  constexpr bool conjugate = Kokkos::ArithTraits<ValueType>::is_complex;

  // NOTE(review): the leading `false` is presumably Tbsv's unit-diagonal
  // flag -- confirm against KokkosBatched_Tbsv_Serial_Internal.hpp.

  // Step 1: solve U**T *X = B, overwriting B with X.
  TransposedSolve::invoke(false, conjugate, an, A, as0, as1, x, xs0, kd);

  // Step 2: solve U*X = B, overwriting B with X.
  DirectSolve::invoke(false, an, A, as0, as1, x, xs0, kd);

  return 0;
}

} // namespace Impl
} // namespace KokkosBatched

#endif // KOKKOSBATCHED_PBTRS_SERIAL_INTERNAL_HPP_
56 changes: 56 additions & 0 deletions batched/dense/src/KokkosBatched_Pbtrs.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
//@HEADER
// ************************************************************************
//
// Kokkos v. 4.0
// Copyright (2022) National Technology & Engineering
// Solutions of Sandia, LLC (NTESS).
//
// Under the terms of Contract DE-NA0003525 with NTESS,
// the U.S. Government retains certain rights in this software.
//
// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
// See https://kokkos.org/LICENSE for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//@HEADER
#ifndef KOKKOSBATCHED_PBTRS_HPP_
#define KOKKOSBATCHED_PBTRS_HPP_

#include <KokkosBatched_Util.hpp>

/// \author Yuuichi Asahi (yuuichi.asahi@cea.fr)

namespace KokkosBatched {

/// \brief Serial Batched Pbtrs:
/// Solve Ab_l x_l = b_l for all l = 0, ..., N
/// using the Cholesky factorization A = U**H * U or A = L * L**H computed by
/// Pbtrf. A single call to invoke handles one system (a rank-2 matrix view
/// and a rank-1 right-hand side view).
/// The matrix has the form
///    A = U**H * U , if ArgUplo = KokkosBatched::Uplo::Upper, or
///    A = L * L**H, if ArgUplo = KokkosBatched::Uplo::Lower,
/// where U is an upper triangular matrix, U**H is the conjugate transpose of
/// U, and L is a lower triangular matrix, L**H is the conjugate transpose of
/// L. For real value types no conjugation is applied, so these are plain
/// transposes.
///
/// \tparam ArgUplo: KokkosBatched::Uplo::Upper or KokkosBatched::Uplo::Lower,
/// selects which triangular factor ab holds
/// \tparam ArgAlgo: algorithm tag; the accompanying implementation header
/// provides specializations for Algo::Pbtrs::Unblocked
/// \tparam ABViewType: Input type for a banded matrix, needs to be a 2D
/// view
/// \tparam BViewType: Input type for a right-hand side and the solution,
/// needs to be a 1D view
///
/// \param ab [in]: ab is a ldab by n banded matrix, with ( kd + 1 ) diagonals;
/// the implementation takes kd as ab.extent(0) - 1
/// \param b [inout]: right-hand side and the solution, a rank 1 view
///
/// \return 0 on success (including the empty case ab.extent(1) == 0); 1 if the
/// input checks fail, which are only performed when
/// KOKKOSKERNELS_DEBUG_LEVEL > 0
///
/// No nested parallel_for is used inside of the function.
///

template <typename ArgUplo, typename ArgAlgo>
struct SerialPbtrs {
template <typename ABViewType, typename BViewType>
KOKKOS_INLINE_FUNCTION static int invoke(const ABViewType &ab, const BViewType &b);
};

} // namespace KokkosBatched

#include "KokkosBatched_Pbtrs_Serial_Impl.hpp"

#endif // KOKKOSBATCHED_PBTRS_HPP_
3 changes: 3 additions & 0 deletions batched/dense/unit_test/Test_Batched_Dense.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@
#include "Test_Batched_SerialPbtrf.hpp"
#include "Test_Batched_SerialPbtrf_Real.hpp"
#include "Test_Batched_SerialPbtrf_Complex.hpp"
#include "Test_Batched_SerialPbtrs.hpp"
#include "Test_Batched_SerialPbtrs_Real.hpp"
#include "Test_Batched_SerialPbtrs_Complex.hpp"
#include "Test_Batched_SerialLaswp.hpp"

// Team Kernels
Expand Down
Loading

0 comments on commit c283c44

Please sign in to comment.