Skip to content

Commit

Permalink
Merge pull request #7 from devreal/spmm-with-ctlflow-const
Browse files Browse the repository at this point in the history
Sprinkle some consts across the SPMM code
  • Loading branch information
therault authored Oct 8, 2021
2 parents ac184c8 + 10340ee commit b5322ec
Showing 1 changed file with 29 additions and 26 deletions.
55 changes: 29 additions & 26 deletions examples/spmm/spmm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ inline int tile2rank(int i, int j, int P, int Q) {

// flow (move?) data into an existing SpMatrix on rank 0
template <typename Blk = blk_t>
class Write_SpMatrix : public Op<Key<2>, std::tuple<>, Write_SpMatrix<Blk>, Blk> {
class Write_SpMatrix : public Op<Key<2>, std::tuple<>, Write_SpMatrix<Blk>, const Blk> {
public:
using baseT = Op<Key<2>, std::tuple<>, Write_SpMatrix<Blk>, Blk>;
using baseT = Op<Key<2>, std::tuple<>, Write_SpMatrix<Blk>, const Blk>;

template <typename Keymap>
Write_SpMatrix(SpMatrix<Blk> &matrix, Edge<Key<2>, Blk> &in, Keymap &&keymap)
Expand Down Expand Up @@ -308,7 +308,7 @@ class SpMM {
};

using gemmset_t = std::set<std::tuple<long, long, long>>;
using step_vector_t = std::vector<std::tuple<gemmset_t, long>>;
using step_vector_t = std::vector<std::tuple<gemmset_t, long, gemmset_t>>;
using step_per_tile_t = std::unordered_map<std::tuple<long, long, long>, std::set<long>, long_tuple_hash>;
using bcastset_t = std::set<std::tuple<long, long>>;
using comm_plan_t = std::vector<std::vector<bcastset_t>>;
Expand Down Expand Up @@ -491,6 +491,7 @@ class SpMM {
for (long nn = 0; nn < nns; nn++) {
for (long kk = 0; kk < kns; kk++) {
gemmset_t gemms;
gemmset_t local_gemms;
long nb_local_gemms = 0;
for (long m = mm * cube_dim; m < (mm + 1) * cube_dim && m < mt; m++) {
if (m >= a_rowidx_to_colidx_.size() || a_rowidx_to_colidx_[m].empty()) continue;
Expand All @@ -505,7 +506,10 @@ class SpMM {
b_colidx_to_rowidx_[n].end())
continue;
auto r = keymap_(Key<2>({m, n}));
if (r == rank) nb_local_gemms++;
if (r == rank) {
local_gemms.insert({m, n, k});
nb_local_gemms++;
}
gemms.insert({m, n, k});
auto it = steps_per_tile_A.find(std::make_tuple(r, m, k));
if (it == steps_per_tile_A.end()) {
Expand Down Expand Up @@ -545,7 +549,7 @@ class SpMM {
}
}
}
steps.emplace_back(std::make_tuple(gemms, nb_local_gemms));
steps.emplace_back(std::make_tuple(gemms, nb_local_gemms, local_gemms));
}
}
}
Expand Down Expand Up @@ -762,6 +766,7 @@ class SpMM {
};

const gemmset_t &gemms(long s) const { return std::get<0>(std::get<0>(steps_)[s]); }
const gemmset_t &local_gemms(long s) const { return std::get<2>(std::get<0>(steps_)[s]); }

long nb_local_gemms(long s) const { return std::get<1>(std::get<0>(steps_)[s]); }

Expand Down Expand Up @@ -882,18 +887,18 @@ class SpMM {

/// Central coordinator: ensures that all progress according to the plan
class Coordinator : public Op<Key<2>, std::tuple<Out<Key<4>, Control>, Out<Key<4>, Control>, Out<Key<2>, Control>>,
Coordinator, Control> {
Coordinator, const Control> {
public:
using baseT =
Op<Key<2>, std::tuple<Out<Key<4>, Control>, Out<Key<4>, Control>, Out<Key<2>, Control>>, Coordinator, Control>;
Op<Key<2>, std::tuple<Out<Key<4>, Control>, Out<Key<4>, Control>, Out<Key<2>, Control>>, Coordinator, const Control>;

Coordinator(Edge<Key<2>, Control> progress_ctl, Edge<Key<4>, Control> &a_ctl, Edge<Key<4>, Control> &b_ctl,
Edge<Key<2>, Control> &c2c_ctl, std::shared_ptr<const Plan> plan, const Keymap &keymap)
: baseT(edges(fuse(progress_ctl, c2c_ctl)), edges(a_ctl, b_ctl, c2c_ctl), std::string("SpMM::Coordinator"),
{"ctl_rs"}, {"a_ctl_riks", "b_ctl_rkjs", "ctl_rs"}, [](const Key<2> &key) { return (int)key[0]; })
, plan_(plan)
, keymap_(keymap) {
baseT::template set_input_reducer<0>([](Control &&a, const Control &&b) { return a; });
baseT::template set_input_reducer<0>([](Control &&a, Control &&b) { return a; });
auto r = ttg_default_execution_context().rank();
for (long l = 0; l < plan_->lookahead_ && l < plan_->nb_steps(); l++) {
if (tracing())
Expand All @@ -918,7 +923,7 @@ class SpMM {
ttg::print("Coordinator(", r, ",", s, "): there are 0 local GEMMS in step", s,
"; triggering next coordinator step", s + plan_->lookahead_);
baseT::template set_argstream_size<0>(Key<2>({r, s + plan_->lookahead_}), 1);
::send<2>(Key<2>({r, s + plan_->lookahead_}), Control{}, out);
::send<2>(Key<2>({r, s + plan_->lookahead_}), std::get<0>(input), out);
}
}

Expand All @@ -930,22 +935,21 @@ class SpMM {

std::unordered_set<std::tuple<int, int>, tuple_hash> seen_a;
std::unordered_set<std::tuple<int, int>, tuple_hash> seen_b;
for (auto x : plan_->gemms(s)) {
for (auto x : plan_->local_gemms(s)) {
long gi, gj, gk;
std::tie(gi, gj, gk) = x;
if (keymap_(Key<2>{gi, gj}) != r) continue;
if (seen_a.find(std::make_tuple(gi, gk)) == seen_a.end()) {
if (tracing())
ttg::print("On rank", r, "Coordinator(", r, ", ", s, "): Sending control to LBCastA(", r, ",", gi, ",", gk,
",", s, ")");
::send<0>(Key<4>({r, gi, gk, s}), Control{}, out);
::send<0>(Key<4>({r, gi, gk, s}), std::get<0>(input), out);
seen_a.insert(std::make_tuple(gi, gk));
}
if (seen_b.find(std::make_tuple(gk, gj)) == seen_b.end()) {
if (tracing())
ttg::print("On rank", r, "Coordinator(", r, ", ", s, "): Sending control to LBCastB(", r, ",", gk, ",", gj,
",", s, ")");
::send<1>(Key<4>({r, gk, gj, s}), Control{}, out);
::send<1>(Key<4>({r, gk, gj, s}), std::get<0>(input), out);
seen_b.insert(std::make_tuple(gk, gj));
}
}
Expand All @@ -969,7 +973,7 @@ class SpMM {
, plan_(plan)
, is_a_(is_label_a(label))
, keymap_(keymap) {
baseT::template set_input_reducer<0>([](Control &&a, const Control &&b) { return a; });
baseT::template set_input_reducer<0>([](Control &&a, Control &&b) { return a; });
auto r = ttg_default_execution_context().rank();
if (tracing())
ttg::print("On rank ", r, ": at bootstrap, setting the number of local bcast on ", is_a_ ? "A" : "B",
Expand Down Expand Up @@ -1064,9 +1068,9 @@ class SpMM {
};

/// broadcast A[i][k] to all procs where B[j][k]
class BcastA : public Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastA, Blk> {
class BcastA : public Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastA, const Blk> {
public:
using baseT = Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastA, Blk>;
using baseT = Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastA, const Blk>;

BcastA(Edge<Key<2>, Blk> &a_mn, Edge<Key<3>, Blk> &a_rik, std::shared_ptr<const Plan> plan, const Keymap &keymap)
: baseT(edges(a_mn), edges(a_rik), "SpMM::BcastA", {"a_mn"}, {"a_rik"}, keymap), plan_(plan), keymap_(keymap) {}
Expand Down Expand Up @@ -1142,9 +1146,9 @@ class SpMM {
}; // class LStoreA

/// broadcast A[i][k] to all local tasks that belong to this step
class LBcastA : public Op<Key<4>, std::tuple<Out<Key<3>, Blk>, Out<Key<4>, Blk>>, LBcastA, Blk, Control> {
class LBcastA : public Op<Key<4>, std::tuple<Out<Key<3>, Blk>, Out<Key<4>, Blk>>, LBcastA, const Blk, const Control> {
public:
using baseT = Op<Key<4>, std::tuple<Out<Key<3>, Blk>, Out<Key<4>, Blk>>, LBcastA, Blk, Control>;
using baseT = Op<Key<4>, std::tuple<Out<Key<3>, Blk>, Out<Key<4>, Blk>>, LBcastA, const Blk, const Control>;

LBcastA(Edge<Key<4>, Blk> &a_riks, Edge<Key<4>, Control> &ctl_riks, Edge<Key<3>, Blk> &a_ijk,
std::shared_ptr<const Plan> plan, const Keymap &keymap)
Expand All @@ -1164,11 +1168,10 @@ class SpMM {
if (tracing()) ttg::print("On rank", rank, "LBcastA(", r, ",", i, ",", k, ",", s, ")");
// broadcast A[i][k] to all local GEMMs in step s, then pass the data to the next step
std::vector<Key<3>> ijk_keys;
for (auto x : plan_->gemms(s)) {
for (const auto& x : plan_->local_gemms(s)) {
long gi, gj, gk;
std::tie(gi, gj, gk) = x;
if (gi != i || gk != k) continue;
if (keymap_(Key<2>{gi, gj}) != r) continue;
ijk_keys.emplace_back(Key<3>{gi, gj, gk});
if (tracing())
ttg::print("On rank", rank, "Giving A[", gi, ",", gk, "]", "to GEMM(", gi, ",", gj, ",", gk, ") during step",
Expand All @@ -1194,9 +1197,9 @@ class SpMM {
}; // class LBcastA

/// broadcast B[k][j] to all procs where A[i][k]
class BcastB : public Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastB, Blk> {
class BcastB : public Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastB, const Blk> {
public:
using baseT = Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastB, Blk>;
using baseT = Op<Key<2>, std::tuple<Out<Key<3>, Blk>>, BcastB, const Blk>;

BcastB(Edge<Key<2>, Blk> &b_mn, Edge<Key<3>, Blk> &b_rkj, std::shared_ptr<const Plan> plan, const Keymap &keymap)
: baseT(edges(b_mn), edges(b_rkj), "SpMM::BcastB", {"b_mn"}, {"b_rkj"}, keymap), plan_(plan), keymap_(keymap) {}
Expand Down Expand Up @@ -1272,9 +1275,9 @@ class SpMM {
}; // class LStoreB

/// broadcast B[k][j] to all local tasks that belong to this step
class LBcastB : public Op<Key<4>, std::tuple<Out<Key<3>, Blk>, Out<Key<4>, Blk>>, LBcastB, Blk, Control> {
class LBcastB : public Op<Key<4>, std::tuple<Out<Key<3>, Blk>, Out<Key<4>, Blk>>, LBcastB, const Blk, const Control> {
public:
using baseT = Op<Key<4>, std::tuple<Out<Key<3>, Blk>, Out<Key<4>, Blk>>, LBcastB, Blk, Control>;
using baseT = Op<Key<4>, std::tuple<Out<Key<3>, Blk>, Out<Key<4>, Blk>>, LBcastB, const Blk, const Control>;

LBcastB(Edge<Key<4>, Blk> &b_rkjs, Edge<Key<4>, Control> &ctl_rkjs, Edge<Key<3>, Blk> &b_ijk,
std::shared_ptr<const Plan> plan, const Keymap &keymap)
Expand All @@ -1294,11 +1297,10 @@ class SpMM {
if (tracing()) ttg::print("On rank", r, "LBcastB(", r, ",", k, ",", j, ",", s, ")");
// broadcast B[k][j] to all local GEMMs in step s, then pass the data to the next step
std::vector<Key<3>> ijk_keys;
for (auto x : plan_->gemms(s)) {
for (const auto& x : plan_->local_gemms(s)) {
long gi, gj, gk;
std::tie(gi, gj, gk) = x;
if (gj != j || gk != k) continue;
if (keymap_(Key<2>{gi, gj}) != r) continue;
ijk_keys.emplace_back(Key<3>{gi, gj, gk});
if (tracing())
ttg::print("On rank", rank, "Giving B[", gk, ",", gj, "]", "to GEMM(", gi, ",", gj, ",", gk, ") during step",
Expand Down Expand Up @@ -2261,6 +2263,7 @@ static void timed_measurement(SpMatrix<> &A, SpMatrix<> &B, const std::function<
assert(connected);
TTGUNUSED(connected);

MPI_Barrier(MPI_COMM_WORLD);
struct timeval start {
0
}, end{0}, diff{0};
Expand Down

0 comments on commit b5322ec

Please sign in to comment.