Skip to content

Commit

Permalink
batched/dense: Add gesv DynRankView runtime checks
Browse files Browse the repository at this point in the history
  • Loading branch information
e10harvey committed Jun 5, 2023
1 parent fa2bdef commit effa886
Showing 1 changed file with 25 additions and 12 deletions.
37 changes: 25 additions & 12 deletions batched/dense/impl/KokkosBatched_Gemv_Team_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,13 @@ struct TeamGemv<MemberType, Trans::NoTranspose, Algo::Gemv::Unblocked> {
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const xViewType &x, const ScalarType beta, const yViewType &y) {
static_assert(AViewType::rank == 3,
"Batched TeamGemv requires rank-3 A matrix (use "
"KokkosBlas::TeamGemv for regular rank-2 matrix)");
constexpr char *assert_msg =
"Batched TeamGemv requires rank-3 A matrix (use "
"KokkosBlas::TeamGemv for regular rank-2 matrix)";
if constexpr (Kokkos::is_dyn_rank_view<AViewType>)
assertm(A.rank_dynamic() == 3, assert_msg) else static_assert(
AViewType::rank == 3, assert_msg);

if (A.extent(0) == 1) {
KokkosBlas::TeamGemv<
MemberType, Trans::NoTranspose,
Expand Down Expand Up @@ -79,9 +83,12 @@ struct TeamGemv<MemberType, Trans::NoTranspose, Algo::Gemv::Blocked> {
const xViewType & /*x*/,
const ScalarType /*beta*/,
const yViewType & /*y*/) {
static_assert(AViewType::rank == 3,
"Batched TeamGemv requires rank-3 A matrix (use "
"KokkosBlas::TeamGemv for regular rank-2 matrix)");
constexpr char *assert_msg =
"Batched TeamGemv requires rank-3 A matrix (use "
"KokkosBlas::TeamGemv for regular rank-2 matrix)";
if constexpr (Kokkos::is_dyn_rank_view<AViewType>)
assertm(A.rank_dynamic() == 3, assert_msg) else static_assert(
AViewType::rank == 3, assert_msg);
Kokkos::abort(
"KokkosBlas::TeamGemv<Algo::Gemv::Blocked> for rank-3 matrix is NOT "
"implemented");
Expand All @@ -99,9 +106,12 @@ struct TeamGemv<MemberType, Trans::Transpose, Algo::Gemv::Unblocked> {
KOKKOS_INLINE_FUNCTION static int invoke(
const MemberType &member, const ScalarType alpha, const AViewType &A,
const xViewType &x, const ScalarType beta, const yViewType &y) {
static_assert(AViewType::rank == 3,
"Batched TeamGemv requires rank-3 A matrix (use "
"KokkosBlas::TeamGemv for regular rank-2 matrix)");
constexpr char *assert_msg =
"Batched TeamGemv requires rank-3 A matrix (use "
"KokkosBlas::TeamGemv for regular rank-2 matrix)";
if constexpr (Kokkos::is_dyn_rank_view<AViewType>)
assertm(A.rank_dynamic() == 3, assert_msg) else static_assert(
AViewType::rank == 3, assert_msg);
if (A.extent(0) == 1) {
KokkosBlas::
TeamGemv<MemberType, Trans::Transpose, Algo::Gemv::Unblocked>::invoke(
Expand Down Expand Up @@ -129,9 +139,12 @@ struct TeamGemv<MemberType, Trans::Transpose, Algo::Gemv::Blocked> {
const xViewType & /*x*/,
const ScalarType /*beta*/,
const yViewType & /*y*/) {
static_assert(AViewType::rank == 3,
"Batched TeamGemv requires rank-3 A matrix (use "
"KokkosBlas::TeamGemv for regular rank-2 matrix)");
constexpr char *assert_msg =
"Batched TeamGemv requires rank-3 A matrix (use "
"KokkosBlas::TeamGemv for regular rank-2 matrix)";
if constexpr (Kokkos::is_dyn_rank_view<AViewType>)
assertm(A.rank_dynamic() == 3, assert_msg) else static_assert(
AViewType::rank == 3, assert_msg);
Kokkos::abort(
"KokkosBlas::TeamGemv<Algo::Gemv::Blocked> for rank-3 matrix is NOT "
"implemented");
Expand Down

0 comments on commit effa886

Please sign in to comment.