Enable Gemm (distributed) to be used with MatrixRef (#1022)
albestro authored Dec 1, 2023
1 parent ab2bb6f commit e311f6a
Showing 5 changed files with 257 additions and 3 deletions.
27 changes: 27 additions & 0 deletions include/dlaf/multiplication/general.h
@@ -64,6 +64,33 @@ void generalMatrix(const blas::Op opA, const blas::Op opB, const T alpha, Matrix
DLAF_UNIMPLEMENTED(opA, opB);
}

/// General sub-matrix distributed multiplication, computing
/// C = alpha * A * B + beta * C
///
/// @param mat_a contains the input matrix A.
/// @param mat_b contains the input matrix B.
/// @param mat_c On entry it contains the input matrix C. On exit matrix tiles in the range will be
/// overwritten with the result, while others are left untouched.
///
/// @pre @p mat_a, @p mat_b and @p mat_c are distributed on the same grid
/// @pre multipliable_sizes(mat_a.size(), mat_b.size(), mat_c.size(), opA, opB)
/// @pre multipliable_sizes(mat_a.tile_size(), mat_b.tile_size(), mat_c.tile_size(), opA, opB)
/// @pre multipliable_sizes(mat_a.tile_size_of({0, 0}), mat_b.tile_size_of({0, 0}),
/// mat_c.tile_size_of({0, 0}), opA, opB)
template <Backend B, Device D, class T>
void generalMatrix(common::Pipeline<comm::Communicator>& row_task_chain,
common::Pipeline<comm::Communicator>& col_task_chain, const T alpha,
MatrixRef<const T, D>& mat_a, MatrixRef<const T, D>& mat_b, const T beta,
MatrixRef<T, D>& mat_c) {
DLAF_ASSERT(matrix::same_process_grid(mat_c, mat_a), mat_c, mat_a);
DLAF_ASSERT(matrix::same_process_grid(mat_c, mat_b), mat_c, mat_b);

DLAF_ASSERT_HEAVY(matrix::multipliable(mat_a, mat_b, mat_c, blas::Op::NoTrans, blas::Op::NoTrans),
mat_a, mat_b, mat_c);

internal::General<B, D, T>::callNN(row_task_chain, col_task_chain, alpha, mat_a, mat_b, beta, mat_c);
}
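
For reference, a minimal call-site sketch of this new overload (hypothetical setup: it assumes a comm::CommunicatorGrid grid and MatrixRef views mat_a, mat_b, mat_c already distributed on that grid, mirroring the test added below):

  common::Pipeline<comm::Communicator> row_chain(grid.rowCommunicator());
  common::Pipeline<comm::Communicator> col_chain(grid.colCommunicator());

  // C = alpha * A * B + beta * C (NoTrans/NoTrans only)
  multiplication::internal::generalMatrix<Backend::MC>(row_chain, col_chain, alpha, mat_a, mat_b, beta, mat_c);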

/// General sub-matrix multiplication implementation on local memory, computing
/// C[a:b][a:b] = alpha * opA(A[a:b][a:b]) * opB(B[a:b][a:b]) + beta * C[a:b][a:b]
/// where [a:b] is the range of tiles starting from tile index @p a to tile index @p b (excluded)
4 changes: 4 additions & 0 deletions include/dlaf/multiplication/general/api.h
@@ -25,6 +25,10 @@ template <Backend B, Device D, class T>
struct General {
static void callNN(const T alpha, MatrixRef<const T, D>& mat_a, MatrixRef<const T, D>& mat_b,
const T beta, MatrixRef<T, D>& mat_c);
static void callNN(common::Pipeline<comm::Communicator>& row_task_chain,
common::Pipeline<comm::Communicator>& col_task_chain, const T alpha,
MatrixRef<const T, D>& mat_a, MatrixRef<const T, D>& mat_b, const T beta,
MatrixRef<T, D>& mat_c);
};

template <Backend B, Device D, class T>
88 changes: 88 additions & 0 deletions include/dlaf/multiplication/general/impl.h
@@ -60,6 +60,94 @@ void General<B, D, T>::callNN(const T alpha, MatrixRef<const T, D>& mat_a, Matri
}
}

template <Backend B, Device D, class T>
void General<B, D, T>::callNN(common::Pipeline<comm::Communicator>& row_task_chain,
common::Pipeline<comm::Communicator>& col_task_chain, const T alpha,
MatrixRef<const T, D>& mat_a, MatrixRef<const T, D>& mat_b, const T beta,
MatrixRef<T, D>& mat_c) {
namespace ex = pika::execution::experimental;

if (mat_c.size().isEmpty())
return;

const matrix::Distribution& dist_a = mat_a.distribution();
const matrix::Distribution& dist_b = mat_b.distribution();
const matrix::Distribution& dist_c = mat_c.distribution();
const auto rank = dist_c.rank_index();

if (mat_a.nr_tiles().cols() == 0) {
// Note: if beta == 1, we optimize by not even scheduling anything
if (beta != T(1)) {
for (SizeType j = 0; j < mat_c.distribution().local_nr_tiles().cols(); ++j)
for (SizeType i = 0; i < mat_c.distribution().local_nr_tiles().rows(); ++i)
ex::start_detached(dlaf::internal::whenAllLift(beta, mat_c.readwrite(LocalTileIndex(i, j))) |
tile::scal(dlaf::internal::Policy<B>()));
}
return;
}

constexpr std::size_t n_workspaces = 2;
common::RoundRobin<matrix::Panel<Coord::Col, T, D>> panelsA(n_workspaces, dist_a);
common::RoundRobin<matrix::Panel<Coord::Row, T, D>> panelsB(n_workspaces, dist_b);

DLAF_ASSERT_HEAVY(mat_a.nr_tiles().cols() == mat_b.nr_tiles().rows(), mat_a.nr_tiles(),
mat_b.nr_tiles());

// This loops over the global indices for k, because every rank has to participate in communication
for (SizeType k = 0; k < mat_a.nr_tiles().cols(); ++k) {
auto& panelA = panelsA.nextResource();
auto& panelB = panelsB.nextResource();

if (k == 0 || k == mat_a.nr_tiles().cols() - 1) {
DLAF_ASSERT_HEAVY(dist_a.tile_size_of<Coord::Col>(k) == dist_b.tile_size_of<Coord::Row>(k),
dist_a.tile_size_of<Coord::Col>(k), dist_b.tile_size_of<Coord::Row>(k));
const SizeType kSize = dist_a.tile_size_of<Coord::Col>(k);
panelA.setWidth(kSize);
panelB.setHeight(kSize);
}

// Set up the column workspace for the root ranks, i.e. the ones in the current col
const auto rank_k_col = dist_a.rank_global_tile<Coord::Col>(k);
if (rank_k_col == rank.col()) {
const auto k_local = dist_a.local_tile_from_global_tile<Coord::Col>(k);
for (SizeType i = 0; i < dist_c.local_nr_tiles().rows(); ++i) {
const LocalTileIndex ik(i, k_local);
panelA.setTile(ik, mat_a.read(ik));
}
}
// Set up the row workspace for the root ranks, i.e. the ones in the current row
const auto rank_k_row = dist_b.rank_global_tile<Coord::Row>(k);
if (rank_k_row == rank.row()) {
const auto k_local = dist_b.local_tile_from_global_tile<Coord::Row>(k);
for (SizeType j = 0; j < dist_c.local_nr_tiles().cols(); ++j) {
const LocalTileIndex kj(k_local, j);
panelB.setTile(kj, mat_b.read(kj));
}
}

// Broadcast both column and row panel from root to others (row-wise and col-wise, respectively)
broadcast(rank_k_col, panelA, row_task_chain);
broadcast(rank_k_row, panelB, col_task_chain);

// This is the core loop where the k step performs the update over the entire local matrix using
// the col and row workspaces.
// Everything needed for the update is available locally thanks to previous broadcasts.
for (SizeType i = 0; i < dist_c.local_nr_tiles().rows(); ++i) {
for (SizeType j = 0; j < dist_c.local_nr_tiles().cols(); ++j) {
const LocalTileIndex ij(i, j);

ex::start_detached(dlaf::internal::whenAllLift(blas::Op::NoTrans, blas::Op::NoTrans, alpha,
panelA.read(ij), panelB.read(ij),
k == 0 ? beta : T(1), mat_c.readwrite(ij)) |
tile::gemm(dlaf::internal::Policy<B>()));
}
}

panelA.reset();
panelB.reset();
}
}
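
Conceptually, once the two broadcasts have completed, every rank holds the k-th column panel of A and the k-th row panel of B locally, so the distributed k-loop computes the same tile-level update as this sequential analogue (an illustrative sketch, not code from this commit):

  // C(i, j) = alpha * A(i, k) * B(k, j) + (k == 0 ? beta : 1) * C(i, j)
  for (SizeType k = 0; k < nr_tiles_k; ++k)
    for (SizeType i = 0; i < nr_tiles_m; ++i)
      for (SizeType j = 0; j < nr_tiles_n; ++j)
        gemm(alpha, A(i, k), B(k, j), k == 0 ? beta : T(1), C(i, j));

beta enters only at the first k step, so C is scaled exactly once and all later steps accumulate with factor 1.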

template <Backend B, Device D, class T>
void GeneralSub<B, D, T>::callNN(const SizeType idx_begin, const SizeType idx_end, const blas::Op opA,
const blas::Op opB, const T alpha, Matrix<const T, D>& mat_a,
6 changes: 3 additions & 3 deletions include/dlaf/util_matrix.h
@@ -74,12 +74,12 @@ bool local_matrix(const MatrixLike<const T, D>& m) noexcept {
}

/// Returns true if the matrix is distributed on the communication grid.
-template <class T, Device D>
-bool equal_process_grid(const Matrix<const T, D>& m, const comm::CommunicatorGrid& g) noexcept {
+template <template <class, Device> class MatrixLike, class T, Device D>
+bool equal_process_grid(const MatrixLike<const T, D>& m, const comm::CommunicatorGrid& g) noexcept {
return m.commGridSize() == g.size() && m.rankIndex() == g.rank();
}

-/// Returns true if the matrix is distributed on the communication grid.
+/// Returns true if the two matrices are distributed on the same grid
template <template <class, Device> class MatrixLikeA, template <class, Device> class MatrixLikeB,
class T, Device D1, Device D2>
bool same_process_grid(const MatrixLikeA<const T, D1>& a, const MatrixLikeB<const T, D2>& b) noexcept {
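
With the template-template parameter, equal_process_grid now accepts any matrix-like type, e.g. both Matrix and MatrixRef (a sketch, assuming mat is a Matrix<const T, D>, ref a matrix::internal::MatrixRef<const T, D>, and g a comm::CommunicatorGrid):

  bool ok_matrix = equal_process_grid(mat, g);  // already compiled before this change
  bool ok_ref = equal_process_grid(ref, g);     // enabled by this change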
135 changes: 135 additions & 0 deletions test/unit/multiplication/test_multiplication_general.cpp
@@ -13,6 +13,8 @@
#include <dlaf/blas/enum_output.h>
#include <dlaf/common/assert.h>
#include <dlaf/common/index2d.h>
#include <dlaf/common/pipeline.h>
#include <dlaf/communication/communicator.h>
#include <dlaf/communication/communicator_grid.h>
#include <dlaf/matrix/index.h>
#include <dlaf/matrix/matrix.h>
@@ -40,6 +42,10 @@ template <class T>
struct GeneralMultiplicationTestMC : public ::testing::Test {};
TYPED_TEST_SUITE(GeneralMultiplicationTestMC, MatrixElementTypes);

template <class T>
struct GeneralMultiplicationDistTestMC : public TestWithCommGrids {};
TYPED_TEST_SUITE(GeneralMultiplicationDistTestMC, MatrixElementTypes);

template <class T>
struct GeneralSubMultiplicationTestMC : public ::testing::Test {};
TYPED_TEST_SUITE(GeneralSubMultiplicationTestMC, MatrixElementTypes);
@@ -53,6 +59,10 @@ template <class T>
struct GeneralMultiplicationTestGPU : public ::testing::Test {};
TYPED_TEST_SUITE(GeneralMultiplicationTestGPU, MatrixElementTypes);

template <class T>
struct GeneralMultiplicationDistTestGPU : public TestWithCommGrids {};
TYPED_TEST_SUITE(GeneralMultiplicationDistTestGPU, MatrixElementTypes);

template <class T>
struct GeneralSubMultiplicationTestGPU : public ::testing::Test {};
TYPED_TEST_SUITE(GeneralSubMultiplicationTestGPU, MatrixElementTypes);
@@ -136,6 +146,78 @@ void testGeneralMultiplication(const T alpha, const T beta, const GemmConfig& co
2 * (mat_ah.size().cols() + 1) * TypeUtilities<T>::error);
}

template <class T, Backend B, Device D>
void testGeneralMultiplication(const T alpha, const T beta, const GemmConfig& config,
comm::CommunicatorGrid& grid) {
using dlaf::matrix::internal::MatrixRef;

common::Pipeline<comm::Communicator> mpi_row_chain(grid.rowCommunicator());
common::Pipeline<comm::Communicator> mpi_col_chain(grid.colCommunicator());

const TileElementSize blocksize_a(config.mb, config.kb);
const TileElementSize blocksize_b(config.kb, config.nb);
const TileElementSize blocksize_c(config.mb, config.nb);

const comm::Index2D src_rank_c(std::max(0, grid.size().rows() - 1),
std::min(1, grid.size().cols() - 1));
const matrix::Distribution dist_c(config.full_c(), blocksize_c, grid.size(), grid.rank(), src_rank_c);

const comm::IndexT_MPI rank_aligned_row =
align_sub_rank_index<Coord::Row>(dist_c, config.sub_c().origin, blocksize_a,
config.sub_a().origin);
const comm::IndexT_MPI rank_aligned_col =
align_sub_rank_index<Coord::Col>(dist_c, config.sub_c().origin, blocksize_b,
config.sub_b().origin);

// Note:
// GEMM(NoTrans, NoTrans) requires:
// - a is rank aligned with c for what concerns rows
// - b is rank aligned with c for what concerns cols
const comm::Index2D src_rank_a{rank_aligned_row, 0};
const comm::Index2D src_rank_b{0, rank_aligned_col};

const matrix::Distribution dist_a(config.full_a(), blocksize_a, grid.size(), grid.rank(), src_rank_a);
const matrix::Distribution dist_b(config.full_b(), blocksize_b, grid.size(), grid.rank(), src_rank_b);

auto setMatrix = [&](auto&& elSetter, matrix::Distribution dist) {
Matrix<T, Device::CPU> matrix(std::move(dist));
dlaf::matrix::util::set(matrix, elSetter);
return matrix;
};

auto [subValuesA, subValuesB, subValuesC, subValuesResult] =
matrix::test::getMatrixMatrixMultiplication<GlobalElementIndex, T>(config.opA, config.opB,
config.k, alpha, beta);

const auto fullValuesA = mix_values(config.sub_a(), subValuesA, [](auto) { return T(-99); });
const auto fullValuesB = mix_values(config.sub_b(), subValuesB, [](auto) { return T(-99); });
const auto fullValuesC = mix_values(config.sub_c(), subValuesC, [](auto) { return T(-99); });

Matrix<const T, Device::CPU> mat_ah = setMatrix(fullValuesA, dist_a);
Matrix<const T, Device::CPU> mat_bh = setMatrix(fullValuesB, dist_b);
Matrix<T, Device::CPU> mat_ch = setMatrix(fullValuesC, dist_c);

{
MatrixMirror<const T, D, Device::CPU> mat_a(mat_ah);
MatrixMirror<const T, D, Device::CPU> mat_b(mat_bh);
MatrixMirror<T, D, Device::CPU> mat_c(mat_ch);

MatrixRef<const T, D> mat_sub_a(mat_a.get(), config.sub_a());
MatrixRef<const T, D> mat_sub_b(mat_b.get(), config.sub_b());
MatrixRef<T, D> mat_sub_c(mat_c.get(), config.sub_c());

// Note: currently only the NoTrans/NoTrans case is implemented
ASSERT_EQ(config.opA, blas::Op::NoTrans);
ASSERT_EQ(config.opB, blas::Op::NoTrans);
multiplication::internal::generalMatrix<B>(mpi_row_chain, mpi_col_chain, alpha, mat_sub_a, mat_sub_b,
beta, mat_sub_c);
}

const auto fullValuesResult = mix_values(config.sub_c(), subValuesResult, fullValuesC);
CHECK_MATRIX_NEAR(fullValuesResult, mat_ch, 2 * (mat_ah.size().cols() + 1) * TypeUtilities<T>::error,
2 * (mat_ah.size().cols() + 1) * TypeUtilities<T>::error);
}
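
The two align_sub_rank_index calls above construct the distribution invariant that the NoTrans/NoTrans algorithm relies on (a restatement for clarity, not code from this commit):

  // row alignment: the rank owning row-tile i of sub-A also owns row-tile i of sub-C
  // col alignment: the rank owning col-tile j of sub-B also owns col-tile j of sub-C

so that, at each k step, panelA can be filled from local tiles of A on the root column and panelB from local tiles of B on the root row.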

std::vector<GemmConfig> gemm_configs = {
// empty matrices
{blas::Op::NoTrans, blas::Op::NoTrans, 0, 0, 7, 3, 6, 2},
@@ -164,6 +246,7 @@ std::vector<GemmConfig> sub_gemm_configs = {
{blas::Op::NoTrans, blas::Op::NoTrans, 8, 8, 11, 10, 9, 13, {{2, 1}}, {{1, 1}}, {{0, 0}}},
// multi-tile
{blas::Op::NoTrans, blas::Op::NoTrans, 12, 20, 11, 3, 4, 5, {{7, 1}}, {{11, 10}}, {{4, 2}}},
{blas::Op::NoTrans, blas::Op::NoTrans, 12, 20, 11, 3, 4, 5, {{6, 10}}, {{5, 8}}, {{9, 12}}},
};

TYPED_TEST(GeneralMultiplicationTestMC, CorrectnessLocal) {
@@ -311,6 +394,32 @@ void testGeneralSubMultiplication(comm::CommunicatorGrid grid, const SizeType a,
2 * (mat_ah.size().cols() + 1) * TypeUtilities<T>::error);
}

TYPED_TEST(GeneralMultiplicationDistTestMC, CorrectnessDistributed) {
constexpr TypeParam alpha = TypeUtilities<TypeParam>::element(-1.3, .5);
constexpr TypeParam beta = TypeUtilities<TypeParam>::element(-2.6, .7);

for (auto comm_grid : this->commGrids()) {
for (const GemmConfig& test_config : gemm_configs) {
testGeneralMultiplication<TypeParam, Backend::MC, Device::CPU>(alpha, beta, test_config,
comm_grid);
pika::wait();
}
}
}

TYPED_TEST(GeneralMultiplicationDistTestMC, CorrectnessDistributedSub) {
constexpr TypeParam alpha = TypeUtilities<TypeParam>::element(-1.3, .5);
constexpr TypeParam beta = TypeUtilities<TypeParam>::element(-2.6, .7);

for (auto comm_grid : this->commGrids()) {
for (const GemmConfig& test_config : sub_gemm_configs) {
testGeneralMultiplication<TypeParam, Backend::MC, Device::CPU>(alpha, beta, test_config,
comm_grid);
pika::wait();
}
}
}

TYPED_TEST(GeneralSubMultiplicationDistTestMC, CorrectnessDistributed) {
for (auto comm_grid : this->commGrids()) {
for (const auto& [m, mb, a, b] : sizes) {
@@ -324,6 +433,32 @@ TYPED_TEST(GeneralSubMultiplicationDistTestMC, CorrectnessDistributed) {
}

#ifdef DLAF_WITH_GPU
TYPED_TEST(GeneralMultiplicationDistTestGPU, CorrectnessDistributed) {
constexpr TypeParam alpha = TypeUtilities<TypeParam>::element(-1.3, .5);
constexpr TypeParam beta = TypeUtilities<TypeParam>::element(-2.6, .7);

for (auto comm_grid : this->commGrids()) {
for (const GemmConfig& test_config : gemm_configs) {
testGeneralMultiplication<TypeParam, Backend::GPU, Device::GPU>(alpha, beta, test_config,
comm_grid);
pika::wait();
}
}
}

TYPED_TEST(GeneralMultiplicationDistTestGPU, CorrectnessDistributedSub) {
constexpr TypeParam alpha = TypeUtilities<TypeParam>::element(-1.3, .5);
constexpr TypeParam beta = TypeUtilities<TypeParam>::element(-2.6, .7);

for (auto comm_grid : this->commGrids()) {
for (const GemmConfig& test_config : sub_gemm_configs) {
testGeneralMultiplication<TypeParam, Backend::GPU, Device::GPU>(alpha, beta, test_config,
comm_grid);
pika::wait();
}
}
}

TYPED_TEST(GeneralSubMultiplicationDistTestGPU, CorrectnessDistributed) {
for (auto comm_grid : this->commGrids()) {
for (const auto& [m, mb, a, b] : sizes) {
