Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Address #1409 #1410

Merged
merged 1 commit into from
May 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 17 additions & 18 deletions src/batched/dense/KokkosBatched_Gesv.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,15 @@ struct Gesv {
/// using a batched LU decomposition, 2 batched triangular solves, and batched
/// static pivoting.
///
/// \tparam MatrixType: Input type for the matrix, needs to be a 3D view
/// \tparam MatrixType: Input type for the matrix, needs to be a 2D view
/// \tparam VectorType: Input type for the right-hand side and the solution,
/// needs to be a 2D view
/// needs to be a 1D view
///
/// \param A [in]: batched matrix, a rank 3 view
/// \param X [out]: solution, a rank 2 view
/// \param B [in]: right-hand side, a rank 2 view
/// \param tmp [in]: a rank 3 view used to store temporary variable; dimension
/// must be N x n x (n+4) where N is the batched size and n is the number of
/// rows.
/// \param A [in]: matrix, a rank 2 view
/// \param X [out]: solution, a rank 1 view
/// \param B [in]: right-hand side, a rank 1 view
/// \param tmp [in]: a rank 2 view used to store temporary variables; its
/// dimensions must be n x (n+4), where n is the number of rows.
///
///
/// Two versions are available (those are chosen based on ArgAlgo):
Expand Down Expand Up @@ -103,14 +102,14 @@ struct SerialGesv {
/// using a batched LU decomposition, 2 batched triangular solves, and batched
/// static pivoting.
///
/// \tparam MatrixType: Input type for the matrix, needs to be a 3D view
/// \tparam MatrixType: Input type for the matrix, needs to be a 2D view
/// \tparam VectorType: Input type for the right-hand side and the solution,
/// needs to be a 2D view
/// needs to be a 1D view
///
/// \param member [in]: TeamPolicy member
/// \param A [in]: batched matrix, a rank 3 view
/// \param X [out]: solution, a rank 2 view
/// \param B [in]: right-hand side, a rank 2 view
/// \param A [in]: matrix, a rank 2 view
/// \param X [out]: solution, a rank 1 view
/// \param B [in]: right-hand side, a rank 1 view
///
/// Two versions are available (those are chosen based on ArgAlgo):
///
Expand Down Expand Up @@ -141,14 +140,14 @@ struct TeamGesv {
/// using a batched LU decomposition, 2 batched triangular solves, and batched
/// static pivoting.
///
/// \tparam MatrixType: Input type for the matrix, needs to be a 3D view
/// \tparam MatrixType: Input type for the matrix, needs to be a 2D view
/// \tparam VectorType: Input type for the right-hand side and the solution,
/// needs to be a 2D view
/// needs to be a 1D view
///
/// \param member [in]: TeamPolicy member
/// \param A [in]: batched matrix, a rank 3 view
/// \param X [out]: solution, a rank 2 view
/// \param B [in]: right-hand side, a rank 2 view
/// \param A [in]: matrix, a rank 2 view
/// \param X [out]: solution, a rank 1 view
/// \param B [in]: right-hand side, a rank 1 view
///
/// Two versions are available (those are chosen based on ArgAlgo):
///
Expand Down
3 changes: 3 additions & 0 deletions src/batched/dense/KokkosBatched_LU_Decl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,4 +51,7 @@ struct LU {

} // namespace KokkosBatched

#include "KokkosBatched_LU_Serial_Impl.hpp"
#include "KokkosBatched_LU_Team_Impl.hpp"

#endif
160 changes: 100 additions & 60 deletions src/batched/dense/impl/KokkosBatched_Gesv_Impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -446,16 +446,20 @@ struct SerialGesv<Gesv::StaticPivoting> {
return 1;
}

SerialLU<Algo::Level3::Unblocked>::invoke(PDAD);
int r_val = SerialLU<Algo::Level3::Unblocked>::invoke(PDAD);

SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, Diag::Unit,
Algo::Level3::Unblocked>::invoke(1.0, PDAD, PDY);
if (r_val == 0)
r_val =
SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, Diag::Unit,
Algo::Level3::Unblocked>::invoke(1.0, PDAD, PDY);

SerialTrsm<Side::Left, Uplo::Upper, Trans::NoTranspose, Diag::NonUnit,
Algo::Level3::Unblocked>::invoke(1.0, PDAD, PDY);
if (r_val == 0)
r_val =
SerialTrsm<Side::Left, Uplo::Upper, Trans::NoTranspose, Diag::NonUnit,
Algo::Level3::Unblocked>::invoke(1.0, PDAD, PDY);

SerialHadamard1D(PDY, D2, X);
return 0;
if (r_val == 0) SerialHadamard1D(PDY, D2, X);
return r_val;
}
};

Expand Down Expand Up @@ -489,16 +493,21 @@ struct SerialGesv<Gesv::NoPivoting> {
}
#endif

SerialLU<Algo::Level3::Unblocked>::invoke(A);
int r_val = SerialLU<Algo::Level3::Unblocked>::invoke(A);

SerialCopy<Trans::NoTranspose, 1>::invoke(Y, X);
SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, Diag::Unit,
Algo::Level3::Unblocked>::invoke(1.0, A, X);
if (r_val == 0) r_val = SerialCopy<Trans::NoTranspose, 1>::invoke(Y, X);

SerialTrsm<Side::Left, Uplo::Upper, Trans::NoTranspose, Diag::NonUnit,
Algo::Level3::Unblocked>::invoke(1.0, A, X);
if (r_val == 0)
r_val =
SerialTrsm<Side::Left, Uplo::Lower, Trans::NoTranspose, Diag::Unit,
Algo::Level3::Unblocked>::invoke(1.0, A, X);

return 0;
if (r_val == 0)
r_val =
SerialTrsm<Side::Left, Uplo::Upper, Trans::NoTranspose, Diag::NonUnit,
Algo::Level3::Unblocked>::invoke(1.0, A, X);

return r_val;
}
};

Expand Down Expand Up @@ -557,22 +566,31 @@ struct TeamGesv<MemberType, Gesv::StaticPivoting> {
}
member.team_barrier();

TeamLU<MemberType, Algo::Level3::Unblocked>::invoke(member, PDAD);
int r_val =
TeamLU<MemberType, Algo::Level3::Unblocked>::invoke(member, PDAD);
member.team_barrier();

TeamTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
Diag::Unit, Algo::Level3::Unblocked>::invoke(member, 1.0, PDAD,
PDY);
member.team_barrier();
if (r_val == 0) {
r_val = TeamTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
Diag::Unit, Algo::Level3::Unblocked>::invoke(member, 1.0,
PDAD, PDY);
member.team_barrier();
}

TeamTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
Diag::NonUnit, Algo::Level3::Unblocked>::invoke(member, 1.0, PDAD,
PDY);
member.team_barrier();
if (r_val == 0) {
r_val =
TeamTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
Diag::NonUnit, Algo::Level3::Unblocked>::invoke(member, 1.0,
PDAD, PDY);
member.team_barrier();
}

TeamHadamard1D(member, PDY, D2, X);
member.team_barrier();
return 0;
if (r_val == 0) {
TeamHadamard1D(member, PDY, D2, X);
member.team_barrier();
}

return r_val;
}
};

Expand Down Expand Up @@ -605,21 +623,28 @@ struct TeamGesv<MemberType, Gesv::NoPivoting> {
}
#endif

TeamLU<MemberType, Algo::Level3::Unblocked>::invoke(member, A);
int r_val = TeamLU<MemberType, Algo::Level3::Unblocked>::invoke(member, A);
member.team_barrier();

TeamCopy<MemberType, Trans::NoTranspose, 1>::invoke(member, Y, X);
member.team_barrier();
if (r_val == 0) {
TeamCopy<MemberType, Trans::NoTranspose, 1>::invoke(member, Y, X);
member.team_barrier();
}

TeamTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
Diag::Unit, Algo::Level3::Unblocked>::invoke(member, 1.0, A, X);
member.team_barrier();
if (r_val == 0) {
TeamTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
Diag::Unit, Algo::Level3::Unblocked>::invoke(member, 1.0, A, X);
member.team_barrier();
}

TeamTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
Diag::NonUnit, Algo::Level3::Unblocked>::invoke(member, 1.0, A, X);
member.team_barrier();
if (r_val == 0) {
TeamTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
Diag::NonUnit, Algo::Level3::Unblocked>::invoke(member, 1.0, A,
X);
member.team_barrier();
}

return 0;
return r_val;
}
};

Expand Down Expand Up @@ -679,22 +704,31 @@ struct TeamVectorGesv<MemberType, Gesv::StaticPivoting> {

member.team_barrier();

TeamLU<MemberType, Algo::Level3::Unblocked>::invoke(member, PDAD);
int r_val =
TeamLU<MemberType, Algo::Level3::Unblocked>::invoke(member, PDAD);
member.team_barrier();

TeamVectorTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
Diag::Unit, Algo::Level3::Unblocked>::invoke(member, 1.0,
PDAD, PDY);
member.team_barrier();
if (r_val == 0) {
TeamVectorTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
Diag::Unit, Algo::Level3::Unblocked>::invoke(member, 1.0,
PDAD, PDY);
member.team_barrier();
}

TeamVectorTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
Diag::NonUnit, Algo::Level3::Unblocked>::invoke(member, 1.0,
PDAD, PDY);
member.team_barrier();
if (r_val == 0) {
TeamVectorTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
Diag::NonUnit, Algo::Level3::Unblocked>::invoke(member,
1.0, PDAD,
PDY);
member.team_barrier();
}

TeamVectorHadamard1D(member, PDY, D2, X);
member.team_barrier();
return 0;
if (r_val == 0) {
TeamVectorHadamard1D(member, PDY, D2, X);
member.team_barrier();
}

return r_val;
}
};

Expand Down Expand Up @@ -727,23 +761,29 @@ struct TeamVectorGesv<MemberType, Gesv::NoPivoting> {
}
#endif

TeamLU<MemberType, Algo::Level3::Unblocked>::invoke(member, A);
int r_val = TeamLU<MemberType, Algo::Level3::Unblocked>::invoke(member, A);
member.team_barrier();

TeamVectorCopy<MemberType, Trans::NoTranspose, 1>::invoke(member, Y, X);
member.team_barrier();
if (r_val == 0) {
TeamVectorCopy<MemberType, Trans::NoTranspose, 1>::invoke(member, Y, X);
member.team_barrier();
}

TeamVectorTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
Diag::Unit, Algo::Level3::Unblocked>::invoke(member, 1.0, A,
X);
member.team_barrier();
if (r_val == 0) {
TeamVectorTrsm<MemberType, Side::Left, Uplo::Lower, Trans::NoTranspose,
Diag::Unit, Algo::Level3::Unblocked>::invoke(member, 1.0,
A, X);
member.team_barrier();
}

TeamVectorTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
Diag::NonUnit, Algo::Level3::Unblocked>::invoke(member, 1.0,
A, X);
member.team_barrier();
if (r_val == 0) {
TeamVectorTrsm<MemberType, Side::Left, Uplo::Upper, Trans::NoTranspose,
Diag::NonUnit, Algo::Level3::Unblocked>::invoke(member,
1.0, A, X);
member.team_barrier();
}

return 0;
return r_val;
}
};

Expand Down