From e311f6afa75886df3699d09b32a78f7b41f79c37 Mon Sep 17 00:00:00 2001 From: Alberto Invernizzi <9337627+albestro@users.noreply.github.com> Date: Fri, 1 Dec 2023 10:36:19 +0100 Subject: [PATCH] Enable Gemm (distributed) to be used with MatrixRef (#1022) --- include/dlaf/multiplication/general.h | 27 ++++ include/dlaf/multiplication/general/api.h | 4 + include/dlaf/multiplication/general/impl.h | 88 ++++++++++++ include/dlaf/util_matrix.h | 6 +- .../test_multiplication_general.cpp | 135 ++++++++++++++++++ 5 files changed, 257 insertions(+), 3 deletions(-) diff --git a/include/dlaf/multiplication/general.h b/include/dlaf/multiplication/general.h index 6eaa3e1f22..d9e06552e8 100644 --- a/include/dlaf/multiplication/general.h +++ b/include/dlaf/multiplication/general.h @@ -64,6 +64,33 @@ void generalMatrix(const blas::Op opA, const blas::Op opB, const T alpha, Matrix DLAF_UNIMPLEMENTED(opA, opB); } +/// General sub-matrix distributed multiplication, computing +/// C = alpha * A * B + beta * C +/// +/// @param mat_a contains the input matrix A. +/// @param mat_b contains the input matrix B. +/// @param mat_c On entry it contains the input matrix C. On exit matrix tiles in the range will be +/// overwritten with the result, while others are left untouched. +/// +/// @pre @p mat_a, @p mat_b and @p mat_c are distributed on the same grid, +/// @pre multipliable_sizes(mat_a.size(), mat_b.size(), mat_c.size(), opA, opB) +/// @pre multipliable_sizes(mat_a.tile_size(), mat_b.tile_size(), mat_c.tile_size(), opA, opB) +/// @pre multipliable_sizes(mat_a.tile_size_of({0, 0}), mat_b.tile_size_of({0, 0}), +/// mat_c.tile_size_of({0, 0}), opA, opB) +template +void generalMatrix(common::Pipeline& row_task_chain, + common::Pipeline& col_task_chain, const T alpha, + MatrixRef& mat_a, MatrixRef& mat_b, const T beta, + MatrixRef& mat_c) { + DLAF_ASSERT(matrix::same_process_grid(mat_c, mat_a), mat_c, mat_b); + DLAF_ASSERT(matrix::same_process_grid(mat_c, mat_b), mat_c, mat_b); + + DLAF_ASSERT_HEAVY(matrix::multipliable(mat_a, mat_b, mat_c, blas::Op::NoTrans, blas::Op::NoTrans), + mat_a, mat_b, mat_c); + + internal::General::callNN(row_task_chain, col_task_chain, alpha, mat_a, mat_b, beta, mat_c); +} + /// General sub-matrix multiplication implementation on local memory, computing /// C[a:b][a:b] = alpha * opA(A[a:b][a:b]) * opB(B[a:b][a:b]) + beta * C[a:b][a:b] /// where [a:b] is the range of tiles starting from tile index @p a to tile index @p b (excluded) diff --git a/include/dlaf/multiplication/general/api.h b/include/dlaf/multiplication/general/api.h index 3bb418cac4..058c1ba8de 100644 --- a/include/dlaf/multiplication/general/api.h +++ b/include/dlaf/multiplication/general/api.h @@ -25,6 +25,10 @@ template struct General { static void callNN(const T alpha, MatrixRef& mat_a, MatrixRef& mat_b, const T beta, MatrixRef& mat_c); + static void callNN(common::Pipeline& row_task_chain, + common::Pipeline& col_task_chain, const T alpha, + MatrixRef& mat_a, MatrixRef& mat_b, const T beta, + MatrixRef& mat_c); }; template diff --git a/include/dlaf/multiplication/general/impl.h b/include/dlaf/multiplication/general/impl.h index 34abe4adef..fc4c5318f9 100644 --- a/include/dlaf/multiplication/general/impl.h +++ b/include/dlaf/multiplication/general/impl.h @@ -60,6 +60,94 @@ void General::callNN(const T alpha, MatrixRef& mat_a, Matri } } +template +void General::callNN(common::Pipeline& row_task_chain, + common::Pipeline& col_task_chain, const T alpha, + MatrixRef& mat_a, MatrixRef& mat_b, const T beta, + MatrixRef& mat_c) { + namespace ex = pika::execution::experimental; + + if (mat_c.size().isEmpty()) + return; + + const matrix::Distribution& dist_a = mat_a.distribution(); + const matrix::Distribution& dist_b = mat_b.distribution(); + const matrix::Distribution& dist_c = mat_c.distribution(); + const auto rank = dist_c.rank_index(); + + if (mat_a.nr_tiles().cols() == 0) { + // Note: if beta == 1, we optimize by not even scheduling anything + if (beta != T(1)) { + for (SizeType j = 0; j < mat_c.distribution().local_nr_tiles().cols(); ++j) + for (SizeType i = 0; i < mat_c.distribution().local_nr_tiles().rows(); ++i) + ex::start_detached(dlaf::internal::whenAllLift(beta, mat_c.readwrite(LocalTileIndex(i, j))) | + tile::scal(dlaf::internal::Policy())); + } + return; + } + + constexpr std::size_t n_workspaces = 2; + common::RoundRobin> panelsA(n_workspaces, dist_a); + common::RoundRobin> panelsB(n_workspaces, dist_b); + + DLAF_ASSERT_HEAVY(mat_a.nr_tiles().cols() == mat_b.nr_tiles().rows(), mat_a.nr_tiles(), + mat_b.nr_tiles()); + + // This loops over the global indices for k, because every rank has to participate in communication + for (SizeType k = 0; k < mat_a.nr_tiles().cols(); ++k) { + auto& panelA = panelsA.nextResource(); + auto& panelB = panelsB.nextResource(); + + if (k == 0 || k == mat_a.nr_tiles().cols() - 1) { + DLAF_ASSERT_HEAVY(dist_a.tile_size_of(k) == dist_b.tile_size_of(k), + dist_a.tile_size_of(k), dist_b.tile_size_of(k)); + const SizeType kSize = dist_a.tile_size_of(k); + panelA.setWidth(kSize); + panelB.setHeight(kSize); + } + + // Setup the column workspace for the root ranks, i.e. the ones in the current col + const auto rank_k_col = dist_a.rank_global_tile(k); + if (rank_k_col == rank.col()) { + const auto k_local = dist_a.local_tile_from_global_tile(k); + for (SizeType i = 0; i < dist_c.local_nr_tiles().rows(); ++i) { + const LocalTileIndex ik(i, k_local); + panelA.setTile(ik, mat_a.read(ik)); + } + } + // Setup the row workspace for the root ranks, i.e. the ones in the current row + const auto rank_k_row = dist_b.rank_global_tile(k); + if (rank_k_row == rank.row()) { + const auto k_local = dist_b.local_tile_from_global_tile(k); + for (SizeType j = 0; j < dist_c.local_nr_tiles().cols(); ++j) { + const LocalTileIndex kj(k_local, j); + panelB.setTile(kj, mat_b.read(kj)); + } + } + + // Broadcast both column and row panel from root to others (row-wise and col-wise, respectively) + broadcast(rank_k_col, panelA, row_task_chain); + broadcast(rank_k_row, panelB, col_task_chain); + + // This is the core loop where the k step performs the update over the entire local matrix using + // the col and row workspaces. + // Everything needed for the update is available locally thanks to previous broadcasts. + for (SizeType i = 0; i < dist_c.local_nr_tiles().rows(); ++i) { + for (SizeType j = 0; j < dist_c.local_nr_tiles().cols(); ++j) { + const LocalTileIndex ij(i, j); + + ex::start_detached(dlaf::internal::whenAllLift(blas::Op::NoTrans, blas::Op::NoTrans, alpha, + panelA.read(ij), panelB.read(ij), + k == 0 ? beta : T(1), mat_c.readwrite(ij)) | + tile::gemm(dlaf::internal::Policy())); + } + } + + panelA.reset(); + panelB.reset(); + } +} + template void GeneralSub::callNN(const SizeType idx_begin, const SizeType idx_end, const blas::Op opA, const blas::Op opB, const T alpha, Matrix& mat_a, diff --git a/include/dlaf/util_matrix.h b/include/dlaf/util_matrix.h index 2c07fef1a6..9977d174f6 100644 --- a/include/dlaf/util_matrix.h +++ b/include/dlaf/util_matrix.h @@ -74,12 +74,12 @@ bool local_matrix(const MatrixLike& m) noexcept { } /// Returns true if the matrix is distributed on the communication grid. -template -bool equal_process_grid(const Matrix& m, const comm::CommunicatorGrid& g) noexcept { +template