Skip to content

Commit

Permalink
Update gen_to_std and cholesky after rebase
Browse files Browse the repository at this point in the history
  • Loading branch information
msimberg committed Jul 12, 2021
1 parent 6c394ab commit bb1ed9c
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 106 deletions.
108 changes: 61 additions & 47 deletions include/dlaf/eigensolver/gen_to_std/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,15 @@ namespace dlaf {
namespace eigensolver {
namespace internal {

template <Backend backend, typename AKKSender, typename LKKSender>
template <Backend backend, class AKKSender, class LKKSender>
void hegstDiagTile(AKKSender&& a_kk, LKKSender&& l_kk) {
dlaf::internal::whenAllLift(1, blas::Uplo::Lower, std::forward<AKKSender>(a_kk),
std::forward<LKKSender>(l_kk)) |
dlaf::tile::hegst(dlaf::internal::Policy<backend>(hpx::threads::thread_priority::high)) |
hpx::execution::experimental::detach();
}

template <Backend backend, typename T, typename LKKSender, typename AIKSender>
template <Backend backend, class T, class LKKSender, class AIKSender>
void trsmPanelTile(LKKSender&& l_kk, AIKSender&& a_ik) {
dlaf::internal::whenAllLift(blas::Side::Right, blas::Uplo::Lower, blas::Op::ConjTrans,
blas::Diag::NonUnit, T(1.0), std::forward<LKKSender>(l_kk),
Expand All @@ -53,7 +53,7 @@ void trsmPanelTile(LKKSender&& l_kk, AIKSender&& a_ik) {
hpx::execution::experimental::detach();
}

template <Backend backend, class T, typename AKKSender, typename LIKSender, typename AIKSender>
template <Backend backend, class T, class AKKSender, class LIKSender, class AIKSender>
void hemmPanelTile(AKKSender&& a_kk, LIKSender&& l_ik, AIKSender&& a_ik) {
dlaf::internal::whenAllLift(blas::Side::Right, blas::Uplo::Lower, T(-0.5),
std::forward<AKKSender>(a_kk), std::forward<LIKSender>(l_ik), T(1.0),
Expand All @@ -62,45 +62,48 @@ void hemmPanelTile(AKKSender&& a_kk, LIKSender&& l_ik, AIKSender&& a_ik) {
hpx::execution::experimental::detach();
}

template <class Executor, Device device, class T>
void her2kTrailingDiagTile(Executor&& ex, hpx::shared_future<matrix::Tile<const T, device>> a_jk,
hpx::shared_future<matrix::Tile<const T, device>> l_jk,
hpx::future<matrix::Tile<T, device>> a_kk) {
hpx::dataflow(ex, matrix::unwrapExtendTiles(tile::internal::her2k_o), blas::Uplo::Lower,
blas::Op::NoTrans, T(-1.0), a_jk, l_jk, BaseType<T>(1.0), std::move(a_kk));
template <Backend backend, class T, class AJKSender, class LJKSender, class AKKSender>
void her2kTrailingDiagTile(hpx::threads::thread_priority priority, AJKSender&& a_jk, LJKSender&& l_jk,
AKKSender&& a_kk) {
dlaf::internal::whenAllLift(blas::Uplo::Lower, blas::Op::NoTrans, T(-1.0),
std::forward<AJKSender>(a_jk), std::forward<LJKSender>(l_jk),
BaseType<T>(1.0), std::forward<AKKSender>(a_kk)) |
dlaf::tile::her2k(dlaf::internal::Policy<backend>(priority)) |
hpx::execution::experimental::detach();
}

template <class Executor, Device device, class T>
void gemmTrailingMatrixTile(Executor&& ex, hpx::shared_future<matrix::Tile<const T, device>> mat_ik,
hpx::shared_future<matrix::Tile<const T, device>> mat_jk,
hpx::future<matrix::Tile<T, device>> a_ij) {
hpx::dataflow(ex, matrix::unwrapExtendTiles(tile::internal::gemm_o), blas::Op::NoTrans,
blas::Op::ConjTrans, T(-1.0), mat_ik, mat_jk, T(1.0), std::move(a_ij));
template <Backend backend, class T, class MatIKSender, class MatJKSender, class AIJSender>
void gemmTrailingMatrixTile(hpx::threads::thread_priority priority, MatIKSender&& mat_ik,
MatJKSender&& mat_jk, AIJSender a_ij) {
dlaf::internal::whenAllLift(blas::Op::NoTrans, blas::Op::ConjTrans, T(-1.0),
std::forward<MatIKSender>(mat_ik), std::forward<MatJKSender>(mat_jk),
T(1.0), std::forward<AIJSender>(a_ij)) |
dlaf::tile::gemm(dlaf::internal::Policy<backend>(priority)) |
hpx::execution::experimental::detach();
}

template <class Executor, Device device, class T>
void trsmPanelUpdateTile(Executor&& executor_hp, hpx::shared_future<matrix::Tile<const T, device>> l_jj,
hpx::future<matrix::Tile<T, device>> a_jk) {
hpx::dataflow(executor_hp, matrix::unwrapExtendTiles(tile::internal::trsm_o), blas::Side::Left,
blas::Uplo::Lower, blas::Op::NoTrans, blas::Diag::NonUnit, T(1.0), l_jj,
std::move(a_jk));
template <Backend backend, class T, class LJJSender, class AJKSender>
void trsmPanelUpdateTile(LJJSender&& l_jj, AJKSender a_jk) {
dlaf::internal::whenAllLift(blas::Side::Left, blas::Uplo::Lower, blas::Op::NoTrans,
blas::Diag::NonUnit, T(1.0), std::forward<LJJSender>(l_jj),
std::forward<AJKSender>(a_jk)) |
dlaf::tile::trsm(dlaf::internal::Policy<backend>(hpx::threads::thread_priority::high)) |
hpx::execution::experimental::detach();
}

template <class Executor, Device device, class T>
void gemmPanelUpdateTile(Executor&& ex, hpx::shared_future<matrix::Tile<const T, device>> l_ij,
hpx::shared_future<matrix::Tile<const T, device>> a_jk,
hpx::future<matrix::Tile<T, device>> a_ik) {
hpx::dataflow(ex, matrix::unwrapExtendTiles(tile::internal::gemm_o), blas::Op::NoTrans,
blas::Op::NoTrans, T(-1.0), l_ij, a_jk, T(1.0), std::move(a_ik));
template <Backend backend, class T, class LIJSender, class AJKSender, class AIKSender>
void gemmPanelUpdateTile(LIJSender&& l_ij, AJKSender&& a_jk, AIKSender&& a_ik) {
dlaf::internal::whenAllLift(blas::Op::NoTrans, blas::Op::NoTrans, T(-1.0),
std::forward<LIJSender>(l_ij), std::forward<AJKSender>(a_jk), T(1.0),
std::forward<AIKSender>(a_ik)) |
dlaf::tile::gemm(dlaf::internal::Policy<backend>(hpx::threads::thread_priority::normal)) |
hpx::execution::experimental::detach();
}

// Implementation based on LAPACK Algorithm for the transformation from generalized to standard
// eigenproblem (xHEGST)
template <Backend backend, Device device, class T>
void GenToStd<backend, device, T>::call_L(Matrix<T, device>& mat_a, Matrix<T, device>& mat_l) {
auto executor_hp = dlaf::getHpExecutor<backend>();
auto executor_np = dlaf::getNpExecutor<backend>();

// Number of tile (rows = cols)
SizeType nrtile = mat_a.nrTiles().cols();

Expand All @@ -124,15 +127,19 @@ void GenToStd<backend, device, T>::call_L(Matrix<T, device>& mat_a, Matrix<T, de
const LocalTileIndex jj{j, j};
const LocalTileIndex jk{j, k};
// first trailing panel gets high priority (look ahead).
auto& trailing_matrix_executor = (j == k + 1) ? executor_hp : executor_np;
const auto trailing_matrix_priority =
(j == k + 1) ? hpx::threads::thread_priority::high : hpx::threads::thread_priority::normal;

her2kTrailingDiagTile(trailing_matrix_executor, mat_a.read(jk), mat_l.read(jk), mat_a(jj));
her2kTrailingDiagTile<backend, T>(trailing_matrix_priority, mat_a.read_sender(jk),
mat_l.read_sender(jk), mat_a.readwrite_sender(jj));

for (SizeType i = j + 1; i < nrtile; ++i) {
const LocalTileIndex ik{i, k};
const LocalTileIndex ij{i, j};
gemmTrailingMatrixTile(trailing_matrix_executor, mat_a.read(ik), mat_l.read(jk), mat_a(ij));
gemmTrailingMatrixTile(trailing_matrix_executor, mat_l.read(ik), mat_a.read(jk), mat_a(ij));
gemmTrailingMatrixTile<backend, T>(trailing_matrix_priority, mat_a.read_sender(ik),
mat_l.read_sender(jk), mat_a.readwrite_sender(ij));
gemmTrailingMatrixTile<backend, T>(trailing_matrix_priority, mat_l.read_sender(ik),
mat_a.read_sender(jk), mat_a.readwrite_sender(ij));
}
}

Expand All @@ -144,12 +151,13 @@ void GenToStd<backend, device, T>::call_L(Matrix<T, device>& mat_a, Matrix<T, de
for (SizeType j = k + 1; j < nrtile; ++j) {
const LocalTileIndex jj{j, j};
const LocalTileIndex jk{j, k};
trsmPanelUpdateTile(executor_hp, mat_l.read(jj), mat_a(jk));
trsmPanelUpdateTile<backend, T>(mat_l.read_sender(jj), mat_a.readwrite_sender(jk));

for (SizeType i = j + 1; i < nrtile; ++i) {
const LocalTileIndex ij{i, j};
const LocalTileIndex ik{i, k};
gemmPanelUpdateTile(executor_np, mat_l.read(ij), mat_a.read(jk), mat_a(ik));
gemmPanelUpdateTile<backend, T>(mat_l.read_sender(ij), mat_a.read_sender(jk),
mat_a.readwrite_sender(ik));
}
}
}
Expand All @@ -158,8 +166,6 @@ void GenToStd<backend, device, T>::call_L(Matrix<T, device>& mat_a, Matrix<T, de
template <Backend backend, Device device, class T>
void GenToStd<backend, device, T>::call_L(comm::CommunicatorGrid grid, Matrix<T, device>& mat_a,
Matrix<T, device>& mat_l) {
auto executor_hp = dlaf::getHpExecutor<backend>();
auto executor_np = dlaf::getNpExecutor<backend>();
auto executor_mpi = dlaf::getMPIExecutor<backend>();

// Set up MPI executor pipelines
Expand Down Expand Up @@ -223,7 +229,7 @@ void GenToStd<backend, device, T>::call_L(comm::CommunicatorGrid grid, Matrix<T,
const LocalTileIndex kj_panelT{Coord::Col, j_local};
const LocalTileIndex kj(kk_offset.rows(), j_local);

trsmPanelUpdateTile(executor_hp, l_panel.read(kk_panel), mat_a(kj));
trsmPanelUpdateTile<backend, T>(l_panel.read_sender(kk_panel), mat_a(kj));

a_panelT.setTile(kj_panelT, mat_a.read(kj));
}
Expand All @@ -245,7 +251,8 @@ void GenToStd<backend, device, T>::call_L(comm::CommunicatorGrid grid, Matrix<T,
}
for (SizeType j_local = 0; j_local < kk_offset.cols(); ++j_local) {
const LocalTileIndex kj(kk_offset.rows(), j_local);
trsmPanelUpdateTile(executor_hp, l_kk, mat_a(kj));
trsmPanelUpdateTile<backend, T>(hpx::execution::experimental::keep_future(l_kk),
mat_a.readwrite_sender(kj));
}
}
}
Expand All @@ -260,7 +267,8 @@ void GenToStd<backend, device, T>::call_L(comm::CommunicatorGrid grid, Matrix<T,
const LocalTileIndex kj_panelT{Coord::Col, j_local};
const LocalTileIndex ij{i_local, j_local};

gemmPanelUpdateTile(executor_np, l_panel.read(ik_panel), a_panelT.read(kj_panelT), mat_a(ij));
gemmPanelUpdateTile<backend, T>(l_panel.read_sender(ik_panel), a_panelT.read_sender(kj_panelT),
mat_a.readwrite_sender(ij));
}
}
}
Expand Down Expand Up @@ -322,13 +330,15 @@ void GenToStd<backend, device, T>::call_L(comm::CommunicatorGrid grid, Matrix<T,

const auto j_local = distr.localTileFromGlobalTile<Coord::Col>(j);
// first trailing panel gets high priority (look ahead).
auto& trailing_matrix_executor = (j == k + 1) ? executor_hp : executor_np;
const auto trailing_matrix_priority =
(j == k + 1) ? hpx::threads::thread_priority::high : hpx::threads::thread_priority::normal;
if (this_rank.row() == owner.row()) {
const auto i_local = distr.localTileFromGlobalTile<Coord::Row>(j);

her2kTrailingDiagTile(trailing_matrix_executor, a_panel.read({Coord::Row, i_local}),
l_panel.read({Coord::Row, i_local}),
mat_a(LocalTileIndex{i_local, j_local}));
her2kTrailingDiagTile<backend, T>(trailing_matrix_priority,
a_panel.read_sender({Coord::Row, i_local}),
l_panel.read_sender({Coord::Row, i_local}),
mat_a.readwrite_sender(LocalTileIndex{i_local, j_local}));
}

for (SizeType i = j + 1; i < nrtile; ++i) {
Expand All @@ -342,8 +352,12 @@ void GenToStd<backend, device, T>::call_L(comm::CommunicatorGrid grid, Matrix<T,
const LocalTileIndex kj_panelT{Coord::Col, j_local};
const LocalTileIndex ij{i_local, j_local};

gemmTrailingMatrixTile(executor_np, a_panel.read(ik_panel), l_panelT.read(kj_panelT), mat_a(ij));
gemmTrailingMatrixTile(executor_np, l_panel.read(ik_panel), a_panelT.read(kj_panelT), mat_a(ij));
gemmTrailingMatrixTile<backend, T>(hpx::threads::thread_priority::normal,
a_panel.read_sender(ik_panel),
l_panelT.read_sender(kj_panelT), mat_a.readwrite_sender(ij));
gemmTrailingMatrixTile<backend, T>(hpx::threads::thread_priority::normal,
l_panel.read_sender(ik_panel),
a_panelT.read_sender(kj_panelT), mat_a.readwrite_sender(ij));
}
}

Expand Down
Loading

0 comments on commit bb1ed9c

Please sign in to comment.