Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes while documenting #2466

Merged
merged 11 commits into from
Jan 6, 2025
10 changes: 10 additions & 0 deletions blas/src/KokkosBlas1_rot.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ void rot(execution_space const& space, VectorView const& X, VectorView const& Y,
static_assert(Kokkos::is_execution_space<execution_space>::value,
"rot: execution_space template parameter is not a Kokkos "
"execution space.");
static_assert(Kokkos::is_view_v<VectorView>, "KokkosBlas::rot: VectorView is not a Kokkos::View.");
static_assert(Kokkos::is_view_v<ScalarView>, "KokkosBlas::rot: ScalarView is not a Kokkos::View.");
static_assert(VectorView::rank == 1, "rot: VectorView template parameter needs to be a rank 1 view");
static_assert(ScalarView::rank == 0, "rot: ScalarView template parameter needs to be a rank 0 view");
static_assert(Kokkos::SpaceAccessibility<execution_space, typename VectorView::memory_space>::accessible,
Expand All @@ -40,6 +42,14 @@ void rot(execution_space const& space, VectorView const& X, VectorView const& Y,
static_assert(std::is_same<typename VectorView::non_const_value_type, typename VectorView::value_type>::value,
"rot: VectorView template parameter needs to store non-const values");

// Check compatibility of dimensions at run time.
if (X.extent(0) != Y.extent(0)) {
std::ostringstream os;
os << "KokkosBlas::rot: Dimensions of X and Y do not match: "
<< "X: " << X.extent(0) << ", Y: " << Y.extent(0);
KokkosKernels::Impl::throw_runtime_exception(os.str());
}

using VectorView_Internal = Kokkos::View<typename VectorView::non_const_value_type*,
typename KokkosKernels::Impl::GetUnifiedLayout<VectorView>::array_layout,
Kokkos::Device<execution_space, typename VectorView::memory_space>,
Expand Down
2 changes: 2 additions & 0 deletions blas/src/KokkosBlas1_rotg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ void rotg(execution_space const& space, SViewType const& a, SViewType const& b,
"rotg: execution_space cannot access data in SViewType");
static_assert(Kokkos::SpaceAccessibility<execution_space, typename MViewType::memory_space>::accessible,
"rotg: execution_space cannot access data in MViewType");
static_assert(!Kokkos::ArithTraits<typename MViewType::value_type>::is_complex,
"rotg: MViewType cannot hold complex values.");

using SView_Internal = Kokkos::View<
typename SViewType::value_type, typename KokkosKernels::Impl::GetUnifiedLayout<SViewType>::array_layout,
Expand Down
2 changes: 0 additions & 2 deletions blas/src/KokkosBlas1_scal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,6 @@ void scal(const execution_space& space, const RMV& R, const AV& a, const XMV& X)
"X is not a Kokkos::View.");
static_assert(Kokkos::SpaceAccessibility<execution_space, typename XMV::memory_space>::accessible,
"KokkosBlas::scal: XMV must be accessible from execution_space");
static_assert(Kokkos::SpaceAccessibility<typename RMV::memory_space, typename XMV::memory_space>::assignable,
"KokkosBlas::scal: XMV must be assignable to RMV");
static_assert(std::is_same<typename RMV::value_type, typename RMV::non_const_value_type>::value,
"KokkosBlas::scal: R is const. "
"It must be nonconst, because it is an output argument "
Expand Down
21 changes: 12 additions & 9 deletions blas/src/KokkosBlas2_ger.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,22 @@ template <class ExecutionSpace, class XViewType, class YViewType, class AViewTyp
void ger(const ExecutionSpace& space, const char trans[], const typename AViewType::const_value_type& alpha,
const XViewType& x, const YViewType& y, const AViewType& A) {
static_assert(Kokkos::SpaceAccessibility<ExecutionSpace, typename AViewType::memory_space>::accessible,
"AViewType memory space must be accessible from ExecutionSpace");
"ger: AViewType memory space must be accessible from ExecutionSpace");
static_assert(Kokkos::SpaceAccessibility<ExecutionSpace, typename XViewType::memory_space>::accessible,
"XViewType memory space must be accessible from ExecutionSpace");
"ger: XViewType memory space must be accessible from ExecutionSpace");
static_assert(Kokkos::SpaceAccessibility<ExecutionSpace, typename YViewType::memory_space>::accessible,
"YViewType memory space must be accessible from ExecutionSpace");
"ger: YViewType memory space must be accessible from ExecutionSpace");

static_assert(Kokkos::is_view<AViewType>::value, "AViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<XViewType>::value, "XViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<YViewType>::value, "YViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<AViewType>::value, "ger: AViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<XViewType>::value, "ger: XViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<YViewType>::value, "ger: YViewType must be a Kokkos::View.");

static_assert(static_cast<int>(AViewType::rank) == 2, "AViewType must have rank 2.");
static_assert(static_cast<int>(XViewType::rank) == 1, "XViewType must have rank 1.");
static_assert(static_cast<int>(YViewType::rank) == 1, "YViewType must have rank 1.");
static_assert(static_cast<int>(AViewType::rank) == 2, "ger: AViewType must have rank 2.");
static_assert(static_cast<int>(XViewType::rank) == 1, "ger: XViewType must have rank 1.");
static_assert(static_cast<int>(YViewType::rank) == 1, "ger: YViewType must have rank 1.");

static_assert(std::is_same_v<typename AViewType::value_type, typename AViewType::non_const_value_type>,
"ger: AViewType must store non const values.");

// Check compatibility of dimensions at run time.
if ((A.extent(0) != x.extent(0)) || (A.extent(1) != y.extent(0))) {
Expand Down
12 changes: 8 additions & 4 deletions blas/src/KokkosBlas3_trmm.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,14 @@ namespace KokkosBlas {
template <class execution_space, class AViewType, class BViewType>
void trmm(const execution_space& space, const char side[], const char uplo[], const char trans[], const char diag[],
typename BViewType::const_value_type& alpha, const AViewType& A, const BViewType& B) {
static_assert(Kokkos::is_view<AViewType>::value, "AViewType must be a Kokkos::View.");
static_assert(Kokkos::is_view<BViewType>::value, "BViewType must be a Kokkos::View.");
static_assert(static_cast<int>(AViewType::rank) == 2, "AViewType must have rank 2.");
static_assert(static_cast<int>(BViewType::rank) == 2, "BViewType must have rank 2.");
static_assert(Kokkos::is_execution_space_v<execution_space>,
"trmm: execution_space must be a Kokkos::execution_space.");
static_assert(Kokkos::is_view_v<AViewType>,
"trmm: AViewType must be a "
"Kokkos::View.");
static_assert(Kokkos::is_view_v<BViewType>, "trmm: BViewType must be a Kokkos::View.");
static_assert(static_cast<int>(AViewType::rank) == 2, "trmm: AViewType must have rank 2.");
static_assert(static_cast<int>(BViewType::rank) == 2, "trmm: BViewType must have rank 2.");

// Check validity of indicator argument
bool valid_side = (side[0] == 'L') || (side[0] == 'l') || (side[0] == 'R') || (side[0] == 'r');
Expand Down
Loading