Skip to content

Commit

Permalink
TridiagSolver (local): reduce GEMM step computational cost (#951)
Browse files Browse the repository at this point in the history
  • Loading branch information
albestro authored Nov 30, 2023
1 parent 0655f5e commit ab2bb6f
Showing 1 changed file with 154 additions and 52 deletions.
206 changes: 154 additions & 52 deletions include/dlaf/eigensolver/tridiag_solver/merge.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
#pragma once

#include <algorithm>
#include <array>
#include <functional>
#include <numeric>
#include <tuple>

#include <pika/barrier.hpp>
#include <pika/execution.hpp>
Expand All @@ -35,8 +37,10 @@
#include <dlaf/matrix/distribution.h>
#include <dlaf/matrix/index.h>
#include <dlaf/matrix/matrix.h>
#include <dlaf/matrix/matrix_ref.h>
#include <dlaf/memory/memory_view.h>
#include <dlaf/multiplication/general.h>
#include <dlaf/multiplication/general/api.h>
#include <dlaf/permutations/general.h>
#include <dlaf/permutations/general/impl.h>
#include <dlaf/schedulers.h>
Expand Down Expand Up @@ -252,6 +256,23 @@ auto calcTolerance(const SizeType i_begin, const SizeType i_end, Matrix<const T,
ex::ensure_started();
}

// Note:
// This is the order how we want the eigenvectors to be sorted, since it leads to a nicer matrix
// shape that allows to reduce the number of following operations (i.e. gemm)
inline std::size_t ev_sort_order(const ColType coltype) {
switch (coltype) {
case ColType::UpperHalf:
return 0;
case ColType::Dense:
return 1;
case ColType::LowerHalf:
return 2;
case ColType::Deflated:
return 3;
}
return DLAF_UNREACHABLE(std::size_t);
}

// This function returns number of non-deflated eigenvectors, together with a permutation @p out_ptr
// that represent mapping (sorted non-deflated | sorted deflated) -> initial.
//
Expand Down Expand Up @@ -312,8 +333,9 @@ SizeType stablePartitionIndexForDeflationArrays(const SizeType n, const ColType*
return k;
}

// This function returns number of non-deflated eigenvectors, together with two permutations
// - @p index_sorted (sort(non-deflated)|sort(deflated)) -> initial.
// This function returns number of non-deflated eigenvectors and a tuple with number of upper, dense
// and lower non-deflated eigenvectors, together with two permutations:
// - @p index_sorted (sort(non-deflated)|sorted(deflated) -> initial.
// - @p index_sorted_coltype (upper|dense|lower|sort(deflated)) -> initial
//
// The permutations will allow to keep the mapping between sorted eigenvalues and unsorted eigenvectors,
Expand All @@ -328,10 +350,11 @@ SizeType stablePartitionIndexForDeflationArrays(const SizeType n, const ColType*
// @param index_sorted_coltype array[n] (upper|dense|lower|sort(deflated)) -> initial
//
// @return k number of non-deflated eigenvectors
// @return n_udl tuple with number of upper, dense and lower eigenvectors
template <class T>
SizeType stablePartitionIndexForDeflationArrays(const SizeType n, const ColType* types, const T* evals,
const SizeType* perm_sorted, SizeType* index_sorted,
SizeType* index_sorted_coltype) {
auto stablePartitionIndexForDeflationArrays(const SizeType n, const ColType* types, const T* evals,
const SizeType* perm_sorted, SizeType* index_sorted,
SizeType* index_sorted_coltype) {
// Note:
// (in) types
// column type of the initial indexing
Expand All @@ -342,31 +365,19 @@ SizeType stablePartitionIndexForDeflationArrays(const SizeType n, const ColType*
// (out) index_sorted_coltype
// initial <-- (upper | dense | lower | sort(deflated))

// Note:
// This is the order how we want the eigenvectors to be sorted, since it leads to a nicer matrix
// shape that allows to reduce the number of following operations (i.e. gemm)
auto coltype_index = [](const ColType coltype) -> std::size_t {
switch (coltype) {
case ColType::UpperHalf:
return 0;
case ColType::Dense:
return 1;
case ColType::LowerHalf:
return 2;
case ColType::Deflated:
return 3;
}
return DLAF_UNREACHABLE(std::size_t);
};

std::array<std::size_t, 4> offsets{0, 0, 0, 0};
std::for_each(types, types + n, [&offsets, &coltype_index](const auto& coltype) {
std::for_each(types, types + n, [&offsets](const auto& coltype) {
if (coltype != ColType::Deflated)
offsets[1 + coltype_index(coltype)]++;
offsets[1 + ev_sort_order(coltype)]++;
});

std::array<std::size_t, 3> n_udl{offsets[1 + ev_sort_order(ColType::UpperHalf)],
offsets[1 + ev_sort_order(ColType::Dense)],
offsets[1 + ev_sort_order(ColType::LowerHalf)]};

std::partial_sum(offsets.cbegin(), offsets.cend(), offsets.begin());

const SizeType k = to_SizeType(offsets[coltype_index(ColType::Deflated)]);
const SizeType k = to_SizeType(offsets[ev_sort_order(ColType::Deflated)]);

// Create the permutation (sorted non-deflated | sorted deflated) -> initial
// Note:
Expand Down Expand Up @@ -411,14 +422,14 @@ SizeType stablePartitionIndexForDeflationArrays(const SizeType n, const ColType*
for (SizeType j = 0; j < n; ++j) {
const ColType& coltype = types[to_sizet(j)];
if (coltype != ColType::Deflated) {
auto& index_for_coltype = offsets[coltype_index(coltype)];
auto& index_for_coltype = offsets[ev_sort_order(coltype)];
index_sorted_coltype[index_for_coltype] = j;
++index_for_coltype;
}
}
std::copy(index_sorted + k, index_sorted + n, index_sorted_coltype + k);

return k;
return std::tuple(k, std::move(n_udl));
}

template <class T>
Expand Down Expand Up @@ -648,21 +659,7 @@ void solveRank1Problem(const SizeType i_begin, const SizeType i_end, KSender&& k
const std::size_t begin = thread_idx * batch_size;
const std::size_t end = std::min(thread_idx * batch_size + batch_size, to_sizet(k));

// STEP 0a: Fill ones for deflated Eigenvectors. (single-thread)
// Note: this step is completely independent from the rest, but it is small and it is going
// to be dropped soon.
// Note: use last thread that in principle should have less work to do
if (thread_idx == nthreads - 1) {
for (SizeType j = k; j < n; ++j) {
const GlobalElementIndex jj(j, j);
const auto linear_jj = distr.globalTileLinearIndex(jj);
const auto jj_el = distr.tileElementIndex(jj);

evec_tiles[to_sizet(linear_jj)](jj_el) = 1;
}
}

// STEP 0b: Initialize workspaces (single-thread)
// STEP 0: Initialize workspaces (single-thread)
if (thread_idx == 0) {
ws_vecs.reserve(nthreads);
for (std::size_t i = 0; i < nthreads; ++i)
Expand Down Expand Up @@ -797,6 +794,107 @@ void solveRank1Problem(const SizeType i_begin, const SizeType i_end, KSender&& k
}));
}

template <Backend B, class T, Device D, class KSender, class UDLSenders>
void multiplyEigenvectors(const SizeType sub_offset, const SizeType n, const SizeType n_upper,
const SizeType n_lower, Matrix<T, D>& e0, Matrix<T, D>& e1, Matrix<T, D>& e2,
KSender&& k, UDLSenders&& n_udl) {
// Note:
// This function computes E0 = E1 . E2
//
// where E1 is the matrix with eigenvectors and it looks like this
//
// ┌──────────┐ k
// │ a │ │
//
// ┌── ┌───┬──────┬─┬────┐
// │ │UUU│DDDDDD│ │XXXX│
// │ │UUU│DDDDDD│ │XXXX│
// n_upper │ │UUU│DDDDDD│ │XXXX│
// │ │UUU│DDDDDD│ │XXXX│
// │ │UUU│DDDDDD│ │XXXX│
// ├── ├───┼──────┼─┤XXXX│
// │ │ │DDDDDD│L│XXXX│
// n_lower │ │ │DDDDDD│L│XXXX│
// │ │ │DDDDDD│L│XXXX│
// └── └───┴──────┴─┴────┘
// │ b │
// └────────┘
//
// The multiplication in two different steps in order to skip zero blocks of the matrix, created by
// the grouping of eigenvectors of different lengths (UPPER, DENSE and LOWER).
//
// 1. GEMM1 = TL . TOP
// 2. GEMM2 = BR . BOTTOM
// 3. copy DEFLATED
//
// ┌────────────┬────┐
// │ │ │
// │ │ │
// │ T O P │ │
// │ │ │
// │ │ │
// ├────────────┤ │
// │ │ │
// │ │ │
// │B O T T O M │ │
// │ │ │
// └────────────┴────┘
//
// ┌──────────┬─┬────┐ ┌────────────┬────┐
// │ │0│ │ │ │ │
// │ │0│ D │ │ │ │
// │ TL │0│ E │ │ GEMM 1 │ C │
// │ │0│ F │ │ │ │
// │ │0│ L │ │ │ O │
// ├───┬──────┴─┤ A │ ├────────────┤ │
// │000│ │ T │ │ │ P │
// │000│ │ E │ │ │ │
// │000│ BR │ D │ │ GEMM 2 │ Y │
// │000│ │ │ │ │ │
// └───┴────────┴────┘ └────────────┴────┘

namespace ex = pika::execution::experimental;

ex::start_detached(
ex::when_all(std::forward<KSender>(k), std::forward<UDLSenders>(n_udl)) |
ex::then([sub_offset, n, n_upper, n_lower, e0 = e0.subPipeline(), e1 = e1.subPipelineConst(),
e2 = e2.subPipelineConst()](const SizeType k, std::array<std::size_t, 3> n_udl) mutable {
using dlaf::matrix::internal::MatrixRef;

const SizeType n_uh = to_SizeType(n_udl[ev_sort_order(ColType::UpperHalf)]);
const SizeType n_de = to_SizeType(n_udl[ev_sort_order(ColType::Dense)]);
const SizeType n_lh = to_SizeType(n_udl[ev_sort_order(ColType::LowerHalf)]);

const SizeType a = n_uh + n_de;
const SizeType b = n_de + n_lh;

using GEMM = dlaf::multiplication::internal::General<B, D, T>;
{
MatrixRef<const T, D> e1_sub(e1, {{sub_offset, sub_offset}, {n_upper, a}});
MatrixRef<const T, D> e2_sub(e2, {{sub_offset, sub_offset}, {a, k}});
MatrixRef<T, D> e0_sub(e0, {{sub_offset, sub_offset}, {n_upper, k}});
GEMM::callNN(T(1), e1_sub, e2_sub, T(0), e0_sub);
}

{
MatrixRef<const T, D> e1_sub(e1, {{sub_offset + n_upper, sub_offset + n_uh}, {n_lower, b}});
MatrixRef<const T, D> e2_sub(e2, {{sub_offset + n_uh, sub_offset}, {b, k}});
MatrixRef<T, D> e0_sub(e0, {{sub_offset + n_upper, sub_offset}, {n_lower, k}});

GEMM::callNN(T(1), e1_sub, e2_sub, T(0), e0_sub);
}

{
const matrix::internal::SubMatrixSpec deflated_submat{{sub_offset, sub_offset + k},
{n, n - k}};
MatrixRef<T, D> sub_e0(e0, deflated_submat);
MatrixRef<const T, D> sub_e1(e1, deflated_submat);

copy(sub_e1, sub_e0);
}
}));
}

template <Backend B, Device D, class T, class RhoSender>
void mergeSubproblems(const SizeType i_begin, const SizeType i_split, const SizeType i_end,
RhoSender&& rho, WorkSpace<T, D>& ws, WorkSpaceHost<T>& ws_h,
Expand All @@ -813,8 +911,12 @@ void mergeSubproblems(const SizeType i_begin, const SizeType i_split, const Size
const LocalTileIndex idx_begin_tiles_vec(i_begin, 0);
const LocalTileSize sz_tiles_vec(nrtiles, 1);

// Calculate the size of the upper subproblem
const SizeType n1 = problemSize(i_begin, i_split, ws.e0.distribution());
const auto& dist = ws.e0.distribution();

// Calculate the size of the upper, lower and full problem
const SizeType n = problemSize(i_begin, i_end, dist);
const SizeType n_upper = problemSize(i_begin, i_split, dist);
const SizeType n_lower = problemSize(i_split, i_end, dist);

// Assemble the rank-1 update vector `z` from the last row of Q1 and the first row of Q2
assembleZVec(i_begin, i_split, i_end, rho, ws.e0, ws.z0);
Expand All @@ -838,7 +940,7 @@ void mergeSubproblems(const SizeType i_begin, const SizeType i_split, const Size
}

// Update indices of second sub-problem
addIndex(i_split, i_end, n1, ws_h.i1);
addIndex(i_split, i_end, n_upper, ws_h.i1);

// Step #1
//
Expand All @@ -848,7 +950,7 @@ void mergeSubproblems(const SizeType i_begin, const SizeType i_split, const Size
// - deflate `d`, `z` and `c`
// - apply Givens rotations to `Q` - `evecs`
//
sortIndex(i_begin, i_end, ex::just(n1), ws_h.d0, ws_h.i1, ws_hm.i2);
sortIndex(i_begin, i_end, ex::just(n_upper), ws_h.d0, ws_h.i1, ws_hm.i2);

auto rots =
applyDeflation(i_begin, i_end, scaled_rho, std::move(tol), ws_hm.i2, ws_h.d0, ws_hm.z0, ws_h.c);
Expand Down Expand Up @@ -877,9 +979,10 @@ void mergeSubproblems(const SizeType i_begin, const SizeType i_split, const Size
// | | | D | D | L | L | DF | DF | DF: Deflated
// | | | D | D | L | L | DF | DF |
//
auto k =
auto [k_unique, n_udl] =
stablePartitionIndexForDeflation(i_begin, i_end, ws_h.c, ws_h.d0, ws_hm.i2, ws_h.i3, ws_hm.i5) |
ex::split();
ex::split_tuple();
auto k = ex::split(std::move(k_unique));

copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.i5, ws.i5);
dlaf::permutations::permute<B, D, T, Coord::Col>(i_begin, i_end, ws.i5, ws.e0, ws.e1);
Expand All @@ -906,17 +1009,16 @@ void mergeSubproblems(const SizeType i_begin, const SizeType i_split, const Size
// Note:
// This is needed to set to zero elements of e2 outside of the k by k top-left part.
// The input is not required to be zero for solveRank1Problem.
matrix::util::set0<Backend::MC>(thread_priority::normal, idx_loc_begin, sz_loc_tiles, ws_hm.e2);
solveRank1Problem(i_begin, i_end, k, scaled_rho, ws_hm.d1, ws_hm.z1, ws_h.d0, ws_h.i4, ws_hm.e2);
copy(idx_loc_begin, sz_loc_tiles, ws_hm.e2, ws.e2);

// Step #3: Eigenvectors of the tridiagonal system: Q * U
//
// The eigenvectors resulting from the multiplication are already in the order of the eigenvalues as
// prepared for the deflated system.
dlaf::multiplication::internal::generalSubMatrix<B, D, T>(i_begin, i_end, blas::Op::NoTrans,
blas::Op::NoTrans, T(1), ws.e1, ws.e2, T(0),
ws.e0);
const SizeType sub_offset = dist.template globalTileElementDistance<Coord::Row>(0, i_begin);

multiplyEigenvectors<B>(sub_offset, n, n_upper, n_lower, ws.e0, ws.e1, ws.e2, k, std::move(n_udl));

// Step #4: Final permutation to sort eigenvalues and eigenvectors
//
Expand Down

0 comments on commit ab2bb6f

Please sign in to comment.