diff --git a/examples/spmm/spmm.cc b/examples/spmm/spmm.cc index 47e612345..eebab4549 100644 --- a/examples/spmm/spmm.cc +++ b/examples/spmm/spmm.cc @@ -178,9 +178,9 @@ inline int tile2rank(int i, int j, int P, int Q) { // flow (move?) data into an existing SpMatrix on rank 0 template -class Write_SpMatrix : public Op, std::tuple<>, Write_SpMatrix, Blk> { +class Write_SpMatrix : public Op, std::tuple<>, Write_SpMatrix, const Blk> { public: - using baseT = Op, std::tuple<>, Write_SpMatrix, Blk>; + using baseT = Op, std::tuple<>, Write_SpMatrix, const Blk>; template Write_SpMatrix(SpMatrix &matrix, Edge, Blk> &in, Keymap &&keymap) @@ -308,7 +308,7 @@ class SpMM { }; using gemmset_t = std::set>; - using step_vector_t = std::vector>; + using step_vector_t = std::vector>; using step_per_tile_t = std::unordered_map, std::set, long_tuple_hash>; using bcastset_t = std::set>; using comm_plan_t = std::vector>; @@ -491,6 +491,7 @@ class SpMM { for (long nn = 0; nn < nns; nn++) { for (long kk = 0; kk < kns; kk++) { gemmset_t gemms; + gemmset_t local_gemms; long nb_local_gemms = 0; for (long m = mm * cube_dim; m < (mm + 1) * cube_dim && m < mt; m++) { if (m >= a_rowidx_to_colidx_.size() || a_rowidx_to_colidx_[m].empty()) continue; @@ -505,7 +506,10 @@ class SpMM { b_colidx_to_rowidx_[n].end()) continue; auto r = keymap_(Key<2>({m, n})); - if (r == rank) nb_local_gemms++; + if (r == rank) { + local_gemms.insert({m, n, k}); + nb_local_gemms++; + } gemms.insert({m, n, k}); auto it = steps_per_tile_A.find(std::make_tuple(r, m, k)); if (it == steps_per_tile_A.end()) { @@ -545,7 +549,7 @@ class SpMM { } } } - steps.emplace_back(std::make_tuple(gemms, nb_local_gemms)); + steps.emplace_back(std::make_tuple(gemms, nb_local_gemms, local_gemms)); } } } @@ -762,6 +766,7 @@ class SpMM { }; const gemmset_t &gemms(long s) const { return std::get<0>(std::get<0>(steps_)[s]); } + const gemmset_t &local_gemms(long s) const { return std::get<2>(std::get<0>(steps_)[s]); } long nb_local_gemms(long s) const { return std::get<1>(std::get<0>(steps_)[s]); } @@ -882,10 +887,10 @@ class SpMM { /// Central coordinator: ensures that all progress according to the plan class Coordinator : public Op, std::tuple, Control>, Out, Control>, Out, Control>>, - Coordinator, Control> { + Coordinator, const Control> { public: using baseT = - Op, std::tuple, Control>, Out, Control>, Out, Control>>, Coordinator, Control>; + Op, std::tuple, Control>, Out, Control>, Out, Control>>, Coordinator, const Control>; Coordinator(Edge, Control> progress_ctl, Edge, Control> &a_ctl, Edge, Control> &b_ctl, Edge, Control> &c2c_ctl, std::shared_ptr plan, const Keymap &keymap) @@ -893,7 +898,7 @@ class SpMM { {"ctl_rs"}, {"a_ctl_riks", "b_ctl_rkjs", "ctl_rs"}, [](const Key<2> &key) { return (int)key[0]; }) , plan_(plan) , keymap_(keymap) { - baseT::template set_input_reducer<0>([](Control &&a, const Control &&b) { return a; }); + baseT::template set_input_reducer<0>([](Control &&a, Control &&b) { return a; }); auto r = ttg_default_execution_context().rank(); for (long l = 0; l < plan_->lookahead_ && l < plan_->nb_steps(); l++) { if (tracing()) @@ -918,7 +923,7 @@ class SpMM { ttg::print("Coordinator(", r, ",", s, "): there are 0 local GEMMS in step", s, "; triggering next coordinator step", s + plan_->lookahead_); baseT::template set_argstream_size<0>(Key<2>({r, s + plan_->lookahead_}), 1); - ::send<2>(Key<2>({r, s + plan_->lookahead_}), Control{}, out); + ::send<2>(Key<2>({r, s + plan_->lookahead_}), std::get<0>(input), out); } } @@ -930,22 +935,21 @@ class SpMM { std::unordered_set, tuple_hash> seen_a; std::unordered_set, tuple_hash> seen_b; - for (auto x : plan_->gemms(s)) { + for (auto x : plan_->local_gemms(s)) { long gi, gj, gk; std::tie(gi, gj, gk) = x; - if (keymap_(Key<2>{gi, gj}) != r) continue; if (seen_a.find(std::make_tuple(gi, gk)) == seen_a.end()) { if (tracing()) ttg::print("On rank", r, "Coordinator(", r, ", ", s, "): Sending control to LBCastA(", r, ",", gi, ",", gk, ",", s, ")"); - ::send<0>(Key<4>({r, gi, gk, s}), Control{}, out); + ::send<0>(Key<4>({r, gi, gk, s}), std::get<0>(input), out); seen_a.insert(std::make_tuple(gi, gk)); } if (seen_b.find(std::make_tuple(gk, gj)) == seen_b.end()) { if (tracing()) ttg::print("On rank", r, "Coordinator(", r, ", ", s, "): Sending control to LBCastB(", r, ",", gk, ",", gj, ",", s, ")"); - ::send<1>(Key<4>({r, gk, gj, s}), Control{}, out); + ::send<1>(Key<4>({r, gk, gj, s}), std::get<0>(input), out); seen_b.insert(std::make_tuple(gk, gj)); } } @@ -969,7 +973,7 @@ class SpMM { , plan_(plan) , is_a_(is_label_a(label)) , keymap_(keymap) { - baseT::template set_input_reducer<0>([](Control &&a, const Control &&b) { return a; }); + baseT::template set_input_reducer<0>([](Control &&a, Control &&b) { return a; }); auto r = ttg_default_execution_context().rank(); if (tracing()) ttg::print("On rank ", r, ": at bootstrap, setting the number of local bcast on ", is_a_ ? "A" : "B", @@ -1064,9 +1068,9 @@ class SpMM { }; /// broadcast A[i][k] to all procs where B[j][k] - class BcastA : public Op, std::tuple, Blk>>, BcastA, Blk> { + class BcastA : public Op, std::tuple, Blk>>, BcastA, const Blk> { public: - using baseT = Op, std::tuple, Blk>>, BcastA, Blk>; + using baseT = Op, std::tuple, Blk>>, BcastA, const Blk>; BcastA(Edge, Blk> &a_mn, Edge, Blk> &a_rik, std::shared_ptr plan, const Keymap &keymap) : baseT(edges(a_mn), edges(a_rik), "SpMM::BcastA", {"a_mn"}, {"a_rik"}, keymap), plan_(plan), keymap_(keymap) {} @@ -1142,9 +1146,9 @@ class SpMM { }; // class LStoreA /// broadcast A[i][k] to all local tasks that belong to this step - class LBcastA : public Op, std::tuple, Blk>, Out, Blk>>, LBcastA, Blk, Control> { + class LBcastA : public Op, std::tuple, Blk>, Out, Blk>>, LBcastA, const Blk, const Control> { public: - using baseT = Op, std::tuple, Blk>, Out, Blk>>, LBcastA, Blk, Control>; + using baseT = Op, std::tuple, Blk>, Out, Blk>>, LBcastA, const Blk, const Control>; LBcastA(Edge, Blk> &a_riks, Edge, Control> &ctl_riks, Edge, Blk> &a_ijk, std::shared_ptr plan, const Keymap &keymap) @@ -1164,11 +1168,10 @@ class SpMM { if (tracing()) ttg::print("On rank", rank, "LBcastA(", r, ",", i, ",", k, ",", s, ")"); // broadcast A[i][k] to all local GEMMs in step s, then pass the data to the next step std::vector> ijk_keys; - for (auto x : plan_->gemms(s)) { + for (const auto& x : plan_->local_gemms(s)) { long gi, gj, gk; std::tie(gi, gj, gk) = x; if (gi != i || gk != k) continue; - if (keymap_(Key<2>{gi, gj}) != r) continue; ijk_keys.emplace_back(Key<3>{gi, gj, gk}); if (tracing()) ttg::print("On rank", rank, "Giving A[", gi, ",", gk, "]", "to GEMM(", gi, ",", gj, ",", gk, ") during step", @@ -1194,9 +1197,9 @@ class SpMM { }; // class LBcastA /// broadcast B[k][j] to all procs where A[i][k] - class BcastB : public Op, std::tuple, Blk>>, BcastB, Blk> { + class BcastB : public Op, std::tuple, Blk>>, BcastB, const Blk> { public: - using baseT = Op, std::tuple, Blk>>, BcastB, Blk>; + using baseT = Op, std::tuple, Blk>>, BcastB, const Blk>; BcastB(Edge, Blk> &b_mn, Edge, Blk> &b_rkj, std::shared_ptr plan, const Keymap &keymap) : baseT(edges(b_mn), edges(b_rkj), "SpMM::BcastB", {"b_mn"}, {"b_rkj"}, keymap), plan_(plan), keymap_(keymap) {} @@ -1272,9 +1275,9 @@ class SpMM { }; // class LStoreB /// broadcast B[k][j] to all local tasks that belong to this step - class LBcastB : public Op, std::tuple, Blk>, Out, Blk>>, LBcastB, Blk, Control> { + class LBcastB : public Op, std::tuple, Blk>, Out, Blk>>, LBcastB, const Blk, const Control> { public: - using baseT = Op, std::tuple, Blk>, Out, Blk>>, LBcastB, Blk, Control>; + using baseT = Op, std::tuple, Blk>, Out, Blk>>, LBcastB, const Blk, const Control>; LBcastB(Edge, Blk> &b_rkjs, Edge, Control> &ctl_rkjs, Edge, Blk> &b_ijk, std::shared_ptr plan, const Keymap &keymap) @@ -1294,11 +1297,10 @@ class SpMM { if (tracing()) ttg::print("On rank", r, "LBcastB(", r, ",", k, ",", j, ",", s, ")"); // broadcast B[k][j] to all local GEMMs in step s, then pass the data to the next step std::vector> ijk_keys; - for (auto x : plan_->gemms(s)) { + for (const auto& x : plan_->local_gemms(s)) { long gi, gj, gk; std::tie(gi, gj, gk) = x; if (gj != j || gk != k) continue; - if (keymap_(Key<2>{gi, gj}) != r) continue; ijk_keys.emplace_back(Key<3>{gi, gj, gk}); if (tracing()) ttg::print("On rank", rank, "Giving B[", gk, ",", gj, "]", "to GEMM(", gi, ",", gj, ",", gk, ") during step", @@ -2261,6 +2263,7 @@ static void timed_measurement(SpMatrix<> &A, SpMatrix<> &B, const std::function< assert(connected); TTGUNUSED(connected); + MPI_Barrier(MPI_COMM_WORLD); struct timeval start { 0 }, end{0}, diff{0};