diff --git a/src/batched/dense/KokkosBatched_Gesv.hpp b/src/batched/dense/KokkosBatched_Gesv.hpp index 08ad9644a0..cda2225c43 100644 --- a/src/batched/dense/KokkosBatched_Gesv.hpp +++ b/src/batched/dense/KokkosBatched_Gesv.hpp @@ -62,16 +62,15 @@ struct Gesv { /// using a batched LU decomposition, 2 batched triangular solves, and a batched /// static pivoting. /// -/// \tparam MatrixType: Input type for the matrix, needs to be a 3D view +/// \tparam MatrixType: Input type for the matrix, needs to be a 2D view /// \tparam VectorType: Input type for the right-hand side and the solution, -/// needs to be a 2D view +/// needs to be a 1D view /// -/// \param A [in]: batched matrix, a rank 3 view -/// \param X [out]: solution, a rank 2 view -/// \param B [in]: right-hand side, a rank 2 view -/// \param tmp [in]: a rank 3 view used to store temporary variable; dimension -/// must be N x n x (n+4) where N is the batched size and n is the number of -/// rows. +/// \param A [in]: matrix, a rank 2 view +/// \param X [out]: solution, a rank 1 view +/// \param B [in]: right-hand side, a rank 1 view +/// \param tmp [in]: a rank 2 view used to store temporary variable; dimension +/// must be n x (n+4) where n is the number of rows. /// /// /// Two versions are available (those are chosen based on ArgAlgo): @@ -103,14 +102,14 @@ struct SerialGesv { /// using a batched LU decomposition, 2 batched triangular solves, and a batched /// static pivoting. /// -/// \tparam MatrixType: Input type for the matrix, needs to be a 3D view +/// \tparam MatrixType: Input type for the matrix, needs to be a 2D view /// \tparam VectorType: Input type for the right-hand side and the solution, -/// needs to be a 2D view +/// needs to be a 1D view /// /// \param member [in]: TeamPolicy member -/// \param A [in]: batched matrix, a rank 3 view -/// \param X [out]: solution, a rank 2 view -/// \param B [in]: right-hand side, a rank 2 view +/// \param A [in]: matrix, a rank 2 view +/// \param X [out]: solution, a rank 1 view +/// \param B [in]: right-hand side, a rank 1 view /// /// Two versions are available (those are chosen based on ArgAlgo): /// @@ -141,14 +140,14 @@ struct TeamGesv { /// using a batched LU decomposition, 2 batched triangular solves, and a batched /// static pivoting. /// -/// \tparam MatrixType: Input type for the matrix, needs to be a 3D view +/// \tparam MatrixType: Input type for the matrix, needs to be a 2D view /// \tparam VectorType: Input type for the right-hand side and the solution, -/// needs to be a 2D view +/// needs to be a 1D view /// /// \param member [in]: TeamPolicy member -/// \param A [in]: batched matrix, a rank 3 view -/// \param X [out]: solution, a rank 2 view -/// \param B [in]: right-hand side, a rank 2 view +/// \param A [in]: matrix, a rank 2 view +/// \param X [out]: solution, a rank 1 view +/// \param B [in]: right-hand side, a rank 1 view /// /// Two versions are available (those are chosen based on ArgAlgo): /// diff --git a/src/batched/dense/KokkosBatched_LU_Decl.hpp b/src/batched/dense/KokkosBatched_LU_Decl.hpp index 8cffbdc766..9fa2e2b6e3 100644 --- a/src/batched/dense/KokkosBatched_LU_Decl.hpp +++ b/src/batched/dense/KokkosBatched_LU_Decl.hpp @@ -51,4 +51,7 @@ struct LU { } // namespace KokkosBatched +#include "KokkosBatched_LU_Serial_Impl.hpp" +#include "KokkosBatched_LU_Team_Impl.hpp" + #endif diff --git a/src/batched/dense/impl/KokkosBatched_Gesv_Impl.hpp b/src/batched/dense/impl/KokkosBatched_Gesv_Impl.hpp index 5a07a58990..616df45df9 100644 --- a/src/batched/dense/impl/KokkosBatched_Gesv_Impl.hpp +++ b/src/batched/dense/impl/KokkosBatched_Gesv_Impl.hpp @@ -446,16 +446,20 @@ struct SerialGesv { return 1; } - SerialLU::invoke(PDAD); + int r_val = SerialLU::invoke(PDAD); - SerialTrsm::invoke(1.0, PDAD, PDY); + if (r_val == 0) + r_val = + SerialTrsm::invoke(1.0, PDAD, PDY); - SerialTrsm::invoke(1.0, PDAD, PDY); + if (r_val == 0) + r_val = + SerialTrsm::invoke(1.0, PDAD, PDY); - SerialHadamard1D(PDY, D2, X); - return 0; + if (r_val == 0) SerialHadamard1D(PDY, D2, X); + return r_val; } }; @@ -489,16 +493,21 @@ struct SerialGesv { } #endif - SerialLU::invoke(A); + int r_val = SerialLU::invoke(A); - SerialCopy::invoke(Y, X); - SerialTrsm::invoke(1.0, A, X); + if (r_val == 0) r_val = SerialCopy::invoke(Y, X); - SerialTrsm::invoke(1.0, A, X); + if (r_val == 0) + r_val = + SerialTrsm::invoke(1.0, A, X); - return 0; + if (r_val == 0) + r_val = + SerialTrsm::invoke(1.0, A, X); + + return r_val; } }; @@ -557,22 +566,31 @@ struct TeamGesv { } member.team_barrier(); - TeamLU::invoke(member, PDAD); + int r_val = + TeamLU::invoke(member, PDAD); member.team_barrier(); - TeamTrsm::invoke(member, 1.0, PDAD, - PDY); - member.team_barrier(); + if (r_val == 0) { + r_val = TeamTrsm::invoke(member, 1.0, + PDAD, PDY); + member.team_barrier(); + } - TeamTrsm::invoke(member, 1.0, PDAD, - PDY); - member.team_barrier(); + if (r_val == 0) { + r_val = + TeamTrsm::invoke(member, 1.0, + PDAD, PDY); + member.team_barrier(); + } - TeamHadamard1D(member, PDY, D2, X); - member.team_barrier(); - return 0; + if (r_val == 0) { + TeamHadamard1D(member, PDY, D2, X); + member.team_barrier(); + } + + return r_val; } }; @@ -605,21 +623,28 @@ struct TeamGesv { } #endif - TeamLU::invoke(member, A); + int r_val = TeamLU::invoke(member, A); member.team_barrier(); - TeamCopy::invoke(member, Y, X); - member.team_barrier(); + if (r_val == 0) { + TeamCopy::invoke(member, Y, X); + member.team_barrier(); + } - TeamTrsm::invoke(member, 1.0, A, X); - member.team_barrier(); + if (r_val == 0) { + TeamTrsm::invoke(member, 1.0, A, X); + member.team_barrier(); + } - TeamTrsm::invoke(member, 1.0, A, X); - member.team_barrier(); + if (r_val == 0) { + TeamTrsm::invoke(member, 1.0, A, + X); + member.team_barrier(); + } - return 0; + return r_val; } }; @@ -679,22 +704,31 @@ struct TeamVectorGesv { member.team_barrier(); - TeamLU::invoke(member, PDAD); + int r_val = + TeamLU::invoke(member, PDAD); member.team_barrier(); - TeamVectorTrsm::invoke(member, 1.0, - PDAD, PDY); - member.team_barrier(); + if (r_val == 0) { + TeamVectorTrsm::invoke(member, 1.0, + PDAD, PDY); + member.team_barrier(); + } - TeamVectorTrsm::invoke(member, 1.0, - PDAD, PDY); - member.team_barrier(); + if (r_val == 0) { + TeamVectorTrsm::invoke(member, + 1.0, PDAD, + PDY); + member.team_barrier(); + } - TeamVectorHadamard1D(member, PDY, D2, X); - member.team_barrier(); - return 0; + if (r_val == 0) { + TeamVectorHadamard1D(member, PDY, D2, X); + member.team_barrier(); + } + + return r_val; } }; @@ -727,23 +761,29 @@ struct TeamVectorGesv { } #endif - TeamLU::invoke(member, A); + int r_val = TeamLU::invoke(member, A); member.team_barrier(); - TeamVectorCopy::invoke(member, Y, X); - member.team_barrier(); + if (r_val == 0) { + TeamVectorCopy::invoke(member, Y, X); + member.team_barrier(); + } - TeamVectorTrsm::invoke(member, 1.0, A, - X); - member.team_barrier(); + if (r_val == 0) { + TeamVectorTrsm::invoke(member, 1.0, + A, X); + member.team_barrier(); + } - TeamVectorTrsm::invoke(member, 1.0, - A, X); - member.team_barrier(); + if (r_val == 0) { + TeamVectorTrsm::invoke(member, + 1.0, A, X); + member.team_barrier(); + } - return 0; + return r_val; } };