Skip to content

Commit

Permalink
Buffer Debug Names
Browse files Browse the repository at this point in the history
Test that print graph correctly shows the buffer debug names.
  • Loading branch information
facuMH committed Sep 7, 2022
1 parent 299ebbf commit 5f5a3a3
Show file tree
Hide file tree
Showing 15 changed files with 142 additions and 40 deletions.
3 changes: 3 additions & 0 deletions examples/matmul/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ int main() {
celerity::buffer<float, 2> mat_b_buf(range);
celerity::buffer<float, 2> mat_c_buf(range);

celerity::debug::set_buffer_name(mat_a_buf, "mat_a");
celerity::debug::set_buffer_name(mat_b_buf, "mat_b");

set_identity(queue, mat_a_buf);
set_identity(queue, mat_b_buf);

Expand Down
11 changes: 11 additions & 0 deletions include/buffer_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ namespace detail {
cl::sycl::range<3> range = {1, 1, 1};
size_t element_size = 0;
bool is_host_initialized;
std::string debug_name = {};
};

/**
Expand Down Expand Up @@ -312,6 +313,16 @@ namespace detail {

bool is_locked(buffer_id bid) const;

void set_debug_name(const buffer_id bid, const std::string& debug_name) {
std::lock_guard lock(m_mutex);
m_buffer_infos.at(bid).debug_name = debug_name;
}

std::string get_debug_name(const buffer_id bid) const {
std::lock_guard lock(m_mutex);
return m_buffer_infos.at(bid).debug_name;
}

private:
struct backing_buffer {
std::unique_ptr<buffer_storage> storage = nullptr;
Expand Down
1 change: 1 addition & 0 deletions include/celerity.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "accessor.h"
#include "buffer.h"
#include "debug.h"
#include "distr_queue.h"
#include "side_effect.h"
#include "user_bench.h"
Expand Down
3 changes: 2 additions & 1 deletion include/command_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
namespace celerity {
namespace detail {

class buffer_manager;
class reduction_manager;
class task_manager;

Expand Down Expand Up @@ -128,7 +129,7 @@ namespace detail {

auto& task_commands(task_id tid) { return m_by_task.at(tid); }

std::optional<std::string> print_graph(size_t max_nodes, const task_manager& tm, const reduction_manager& rm) const;
std::optional<std::string> print_graph(size_t max_nodes, const task_manager& tm, const reduction_manager& rm, const buffer_manager* bm) const;

// TODO unify dependency terminology to this
void add_dependency(abstract_command* depender, abstract_command* dependee, dependency_kind kind, dependency_origin origin) {
Expand Down
17 changes: 17 additions & 0 deletions include/debug.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
#include <string>

#include "buffer.h"

namespace celerity {
namespace debug {
template <typename DataT, int Dims>
void set_buffer_name(const celerity::buffer<DataT, Dims>& buff, const std::string& debug_name) {
detail::runtime::get_instance().set_buffer_debug_name(detail::get_buffer_id(buff), debug_name);
}
template <typename DataT, int Dims>
std::string get_buffer_name(const celerity::buffer<DataT, Dims>& buff) {
return detail::runtime::get_instance().get_buffer_manager().get_debug_name(detail::get_buffer_id(buff));
}

} // namespace debug
} // namespace celerity
5 changes: 3 additions & 2 deletions include/print_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
namespace celerity {
namespace detail {

class buffer_manager;
class command_graph;
class reduction_manager;
class task_manager;

std::string print_task_graph(const task_ring_buffer& tdag, const reduction_manager& rm);
std::string print_command_graph(const command_graph& cdag, const task_manager& tm, const reduction_manager& rm);
std::string print_task_graph(const task_ring_buffer& tdag, const reduction_manager& rm, const buffer_manager* bm);
std::string print_command_graph(const command_graph& cdag, const task_manager& tm, const reduction_manager& rm, const buffer_manager* bm);

} // namespace detail
} // namespace celerity
2 changes: 2 additions & 0 deletions include/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ namespace detail {
m_test_active = false;
}

void set_buffer_debug_name(const buffer_id bid, const std::string& debug_name);

private:
inline static bool m_test_mode = false;
inline static bool m_test_active = false;
Expand Down
5 changes: 3 additions & 2 deletions src/command_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ namespace detail {
}
}

std::optional<std::string> command_graph::print_graph(size_t max_nodes, const task_manager& tm, const reduction_manager& rm) const {
if(command_count() <= max_nodes) { return detail::print_command_graph(*this, tm, rm); }
std::optional<std::string> command_graph::print_graph(
size_t max_nodes, const task_manager& tm, const reduction_manager& rm, const buffer_manager* bm) const {
if(command_count() <= max_nodes) { return detail::print_command_graph(*this, tm, rm, bm); }
return std::nullopt;
}

Expand Down
48 changes: 32 additions & 16 deletions src/print_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,19 @@ namespace detail {
}
}

void format_requirements(std::string& label, const task& tsk, subrange<3> execution_range, access_mode reduction_init_mode, const reduction_manager& rm) {
std::string get_buffer_debug_name(const buffer_manager* bm, const buffer_id bid) {
// if there is no buffer manager or no name defined, the name will be the buffer id.
// if there is a name we want "id name"
std::string name{};
const std::string bid_str = fmt::format("B{}", bid);
if(bm != nullptr) { name = bm->get_debug_name(bid); }
if(!name.empty()) { name = fmt::format(" \"{}\"", name); }
name = fmt::format("{}{}", bid_str, name);
return name;
}

void format_requirements(std::string& label, const task& tsk, subrange<3> execution_range, access_mode reduction_init_mode, const reduction_manager& rm,
const buffer_manager* bm) {
for(auto rid : tsk.get_reductions()) {
auto reduction = rm.get_reduction(rid);

Expand All @@ -43,15 +55,17 @@ namespace detail {

const auto bid = reduction.output_buffer_id;
const auto req = GridRegion<3>{{1, 1, 1}};
fmt::format_to(std::back_inserter(label), "<br/>(R{}) <i>{}</i> B{} {}", rid, detail::access::mode_traits::name(rmode), bid, req);
const std::string bdn = get_buffer_debug_name(bm, bid);
fmt::format_to(std::back_inserter(label), "<br/>(R{}) <i>{}</i> {} {}", rid, detail::access::mode_traits::name(rmode), bdn, req);
}

const auto& bam = tsk.get_buffer_access_map();
for(const auto bid : bam.get_accessed_buffers()) {
for(const auto mode : bam.get_access_modes(bid)) {
const auto req = bam.get_requirements_for_access(bid, mode, tsk.get_dimensions(), execution_range, tsk.get_global_size());
const std::string bdn = get_buffer_debug_name(bm, bid);
// While uncommon, we do support chunks that don't require access to a particular buffer at all.
if(!req.empty()) { fmt::format_to(std::back_inserter(label), "<br/><i>{}</i> B{} {}", detail::access::mode_traits::name(mode), bid, req); }
if(!req.empty()) { fmt::format_to(std::back_inserter(label), "<br/><i>{}</i> {} {}", detail::access::mode_traits::name(mode), bdn, req); }
}
}

Expand All @@ -60,7 +74,7 @@ namespace detail {
}
}

std::string get_task_label(const task& tsk, const reduction_manager& rm) {
std::string get_task_label(const task& tsk, const reduction_manager& rm, const buffer_manager* bm) {
std::string label;
fmt::format_to(std::back_inserter(label), "T{}", tsk.get_id());
if(!tsk.get_debug_name().empty()) { fmt::format_to(std::back_inserter(label), " \"{}\" ", tsk.get_debug_name()); }
Expand All @@ -74,17 +88,17 @@ namespace detail {
fmt::format_to(std::back_inserter(label), " in CG{}", tsk.get_collective_group_id());
}

format_requirements(label, tsk, execution_range, access_mode::read_write, rm);
format_requirements(label, tsk, execution_range, access_mode::read_write, rm, bm);

return label;
}

std::string print_task_graph(const task_ring_buffer& tdag, const reduction_manager& rm) {
std::string print_task_graph(const task_ring_buffer& tdag, const reduction_manager& rm, const buffer_manager* bm) {
std::string dot = "digraph G {label=\"Task Graph\" ";

for(auto tsk : tdag) {
const auto shape = tsk->get_type() == task_type::epoch || tsk->get_type() == task_type::horizon ? "ellipse" : "box style=rounded";
fmt::format_to(std::back_inserter(dot), "{}[shape={} label=<{}>];", tsk->get_id(), shape, get_task_label(*tsk, rm));
fmt::format_to(std::back_inserter(dot), "{}[shape={} label=<{}>];", tsk->get_id(), shape, get_task_label(*tsk, rm, bm));
for(auto d : tsk->get_dependencies()) {
fmt::format_to(std::back_inserter(dot), "{}->{}[{}];", d.node->get_id(), tsk->get_id(), dependency_style(d));
}
Expand All @@ -94,7 +108,7 @@ namespace detail {
return dot;
}

std::string get_command_label(const abstract_command& cmd, const task_manager& tm, const reduction_manager& rm) {
std::string get_command_label(const abstract_command& cmd, const task_manager& tm, const reduction_manager& rm, const buffer_manager* bm) {
const command_id cid = cmd.get_cid();
const node_id nid = cmd.get_nid();

Expand All @@ -108,21 +122,23 @@ namespace detail {
fmt::format_to(std::back_inserter(label), "<b>execution</b> {}", subrange_to_grid_box(xcmd->get_execution_range()));
} else if(const auto pcmd = dynamic_cast<const push_command*>(&cmd)) {
if(pcmd->get_rid()) { fmt::format_to(std::back_inserter(label), "(R{}) ", pcmd->get_rid()); }
fmt::format_to(
std::back_inserter(label), "<b>push</b> to N{}<br/>B{} {}", pcmd->get_target(), pcmd->get_bid(), subrange_to_grid_box(pcmd->get_range()));
const std::string bdn = get_buffer_debug_name(bm, pcmd->get_bid());
fmt::format_to(std::back_inserter(label), "<b>push</b> to N{}<br/> {} {}", pcmd->get_target(), bdn, subrange_to_grid_box(pcmd->get_range()));
} else if(const auto apcmd = dynamic_cast<const await_push_command*>(&cmd)) {
if(apcmd->get_source()->get_rid()) { label += fmt::format("(R{}) ", apcmd->get_source()->get_rid()); }
fmt::format_to(std::back_inserter(label), "<b>await push</b> from N{}<br/>B{} {}", apcmd->get_source()->get_nid(), apcmd->get_source()->get_bid(),
const std::string bdn = get_buffer_debug_name(bm, apcmd->get_source()->get_bid());
fmt::format_to(std::back_inserter(label), "<b>await push</b> from N{}<br/> {} {}", apcmd->get_source()->get_nid(), bdn,
subrange_to_grid_box(apcmd->get_source()->get_range()));
} else if(const auto rrcmd = dynamic_cast<const reduction_command*>(&cmd)) {
const auto reduction = rm.get_reduction(rrcmd->get_rid());
const auto req = GridRegion<3>{{1, 1, 1}};
fmt::format_to(std::back_inserter(label), "<b>reduction</b> R{}<br/>B{} {}", rrcmd->get_rid(), reduction.output_buffer_id, req);
const std::string bdn = get_buffer_debug_name(bm, reduction.output_buffer_id);
fmt::format_to(std::back_inserter(label), "<b>reduction</b> R{}<br/> {} {}", rrcmd->get_rid(), bdn, req);
} else if(const auto hcmd = dynamic_cast<const horizon_command*>(&cmd)) {
label += "<b>horizon</b>";
} else {
assert(!"Unkown command");
label += "<b>unknown</b>";
label += "<b>unknown<b>";
}

if(const auto tcmd = dynamic_cast<const task_command*>(&cmd)) {
Expand All @@ -135,21 +151,21 @@ namespace detail {
execution_range = ecmd->get_execution_range();
}

format_requirements(label, tsk, execution_range, reduction_init_mode, rm);
format_requirements(label, tsk, execution_range, reduction_init_mode, rm, bm);
}

return label;
}

std::string print_command_graph(const command_graph& cdag, const task_manager& tm, const reduction_manager& rm) {
std::string print_command_graph(const command_graph& cdag, const task_manager& tm, const reduction_manager& rm, const buffer_manager* bm) {
std::string main_dot;
std::unordered_map<task_id, std::string> task_subgraph_dot;

const auto print_vertex = [&](const abstract_command& cmd) {
static const char* const colors[] = {"black", "crimson", "dodgerblue4", "goldenrod", "maroon4", "springgreen2", "tan1", "chartreuse2"};

const auto name = cmd.get_cid();
const auto label = get_command_label(cmd, tm, rm);
const auto label = get_command_label(cmd, tm, rm, bm);
const auto fontcolor = colors[cmd.get_nid() % (sizeof(colors) / sizeof(char*))];
const auto shape = isa<task_command>(&cmd) ? "box" : "ellipse";
return fmt::format("{}[label=<{}> fontcolor={} shape={}];", name, label, fontcolor, shape);
Expand Down
4 changes: 3 additions & 1 deletion src/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ namespace detail {
}
}
{
const auto graph_str = m_cdag->print_graph(print_max_nodes, *m_task_mngr, *m_reduction_mngr);
const auto graph_str = m_cdag->print_graph(print_max_nodes, *m_task_mngr, *m_reduction_mngr, m_buffer_mngr.get());
if(graph_str.has_value()) {
CELERITY_TRACE("Command graph:\n\n{}\n", *graph_str);
} else {
Expand Down Expand Up @@ -281,5 +281,7 @@ namespace detail {
if(done) { m_active_flushes.pop_front(); }
}

void runtime::set_buffer_debug_name(const buffer_id bid, const std::string& debug_name) { m_buffer_mngr->set_debug_name(bid, debug_name); }

} // namespace detail
} // namespace celerity
2 changes: 1 addition & 1 deletion src/task_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace detail {
const task* task_manager::get_task(task_id tid) const { return m_task_buffer.get_task(tid); }

std::optional<std::string> task_manager::print_graph(size_t max_nodes) const {
if(m_task_buffer.get_current_task_count() <= max_nodes) { return detail::print_task_graph(m_task_buffer, *m_reduction_mngr); }
if(m_task_buffer.get_current_task_count() <= max_nodes) { return detail::print_task_graph(m_task_buffer, *m_reduction_mngr, nullptr); }
return std::nullopt;
}

Expand Down
8 changes: 8 additions & 0 deletions test/buffer_manager_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,14 @@ namespace detail {
}
}

TEST_CASE_METHOD(test_utils::runtime_fixture, "set buffer debug name", "[buffer_debug_name]") {
distr_queue q;
celerity::buffer<int, 1> buff_a(16);
std::string buff_name{"my_buffer"};
celerity::debug::set_buffer_name(buff_a, buff_name);
CHECK(celerity::debug::get_buffer_name(buff_a) == buff_name);
}

#endif // CELERITY_DETAIL_IS_OLD_COMPUTECPP_COMPILER

} // namespace detail
Expand Down
50 changes: 42 additions & 8 deletions test/print_graph_tests.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#include <catch2/catch_test_macros.hpp>

#include "test_utils.h"


namespace celerity::detail {

using celerity::access::fixed;
Expand Down Expand Up @@ -72,7 +75,8 @@ TEST_CASE("command graph printing is unchanged", "[print_graph][command-graph]")
// replace the `expected` value with the new dot graph.
const auto expected =
"digraph G{label=\"Command Graph\" subgraph cluster_2{label=<<font color=\"#606060\">T2 (master-node host)</font>>;color=darkgray;9[label=<C9 on "
"N0<br/><b>execution</b> [[0,0,0] - [0,0,0]]<br/><i>read</i> B0 {[[0,0,0] - [1,1,1]]}<br/><i>read_write</i> B0 {[[0,0,0] - [1,1,1]]}<br/><i>write</i> "
"N0<br/><b>execution</b> [[0,0,0] - [0,0,0]]<br/><i>read</i> B0 {[[0,0,0] - [1,1,1]]}<br/><i>read_write</i> B0 {[[0,0,0] - "
"[1,1,1]]}<br/><i>write</i> "
"B0 {[[0,0,0] - [1,1,1]]}> fontcolor=black shape=box];}subgraph cluster_1{label=<<font color=\"#606060\">T1 \"task_reduction_8\" "
"(device-compute)</font>>;color=darkgray;5[label=<C5 on N0<br/><b>execution</b> [[0,0,0] - [1,1,1]]<br/>(R1) <i>discard_write</i> B0 {[[0,0,0] - "
"[1,1,1]]}> fontcolor=black shape=box];6[label=<C6 on N1<br/><b>execution</b> [[1,0,0] - [2,1,1]]<br/>(R1) <i>discard_write</i> B0 {[[0,0,0] - "
Expand All @@ -81,16 +85,46 @@ TEST_CASE("command graph printing is unchanged", "[print_graph][command-graph]")
"[1,1,1]]}> fontcolor=goldenrod shape=box];}subgraph cluster_0{label=<<font color=\"#606060\">T0 (epoch)</font>>;color=darkgray;0[label=<C0 on "
"N0<br/><b>epoch</b>> fontcolor=black shape=box];1[label=<C1 on N1<br/><b>epoch</b>> fontcolor=crimson shape=box];2[label=<C2 on N2<br/><b>epoch</b>> "
"fontcolor=dodgerblue4 shape=box];3[label=<C3 on N3<br/><b>epoch</b>> fontcolor=goldenrod shape=box];}16[label=<C16 on N0<br/>(R1) <b>await push</b> "
"from N3<br/>B0 [[0,0,0] - [1,1,1]]> fontcolor=black shape=ellipse];0->16[color=orchid];15->16[style=dashed color=gray40];15[label=<C15 on N3<br/>(R1) "
"<b>push</b> to N0<br/>B0 [[0,0,0] - [1,1,1]]> fontcolor=goldenrod shape=ellipse];8->15[];14[label=<C14 on N0<br/>(R1) <b>await push</b> from "
"N2<br/>B0 [[0,0,0] - [1,1,1]]> fontcolor=black shape=ellipse];0->14[color=orchid];13->14[style=dashed color=gray40];13[label=<C13 on N2<br/>(R1) "
"<b>push</b> to N0<br/>B0 [[0,0,0] - [1,1,1]]> fontcolor=dodgerblue4 "
"from N3<br/> B0 [[0,0,0] - [1,1,1]]> fontcolor=black shape=ellipse];0->16[color=orchid];15->16[style=dashed color=gray40];15[label=<C15 on "
"N3<br/>(R1) "
"<b>push</b> to N0<br/> B0 [[0,0,0] - [1,1,1]]> fontcolor=goldenrod shape=ellipse];8->15[];14[label=<C14 on N0<br/>(R1) <b>await push</b> from "
"N2<br/> B0 [[0,0,0] - [1,1,1]]> fontcolor=black shape=ellipse];0->14[color=orchid];13->14[style=dashed color=gray40];13[label=<C13 on N2<br/>(R1) "
"<b>push</b> to N0<br/> B0 [[0,0,0] - [1,1,1]]> fontcolor=dodgerblue4 "
"shape=ellipse];7->13[];0->5[color=orchid];1->6[color=orchid];2->7[color=orchid];3->8[color=orchid];10->9[];10[label=<C10 on N0<br/><b>reduction</b> "
"R1<br/>B0 {[[0,0,0] - [1,1,1]]}> fontcolor=black shape=ellipse];5->10[];12->10[];14->10[];16->10[];11[label=<C11 on N1<br/>(R1) <b>push</b> to "
"N0<br/>B0 [[0,0,0] - [1,1,1]]> fontcolor=crimson shape=ellipse];6->11[];12[label=<C12 on N0<br/>(R1) <b>await push</b> from N1<br/>B0 [[0,0,0] - "
"R1<br/> B0 {[[0,0,0] - [1,1,1]]}> fontcolor=black shape=ellipse];5->10[];12->10[];14->10[];16->10[];11[label=<C11 on N1<br/>(R1) <b>push</b> to "
"N0<br/> B0 [[0,0,0] - [1,1,1]]> fontcolor=crimson shape=ellipse];6->11[];12[label=<C12 on N0<br/>(R1) <b>await push</b> from N1<br/> B0 [[0,0,0] - "
"[1,1,1]]> fontcolor=black shape=ellipse];0->12[color=orchid];11->12[style=dashed color=gray40];}";

const auto dot = ctx.get_command_graph().print_graph(std::numeric_limits<size_t>::max(), tm, rm).value();
const auto dot = ctx.get_command_graph().print_graph(std::numeric_limits<size_t>::max(), tm, rm, {}).value();
CHECK(dot == expected);
}

TEST_CASE_METHOD(test_utils::runtime_fixture, "Buffer debug names show up in the generated graph", "[print_graph][buffer_names]") {
distr_queue q;
celerity::range<1> range(16);
celerity::buffer<int, 1> buff_a(range);
std::string buff_name{"my_buffer"};
celerity::debug::set_buffer_name(buff_a, buff_name);
CHECK(celerity::debug::get_buffer_name(buff_a) == buff_name);

q.submit([=](handler& cgh) {
celerity::accessor acc{buff_a, cgh, celerity::access::all{}, celerity::write_only};
cgh.parallel_for<class UKN(print_graph_buffer_name)>(range, [=](item<1> item) {});
});

// slow full sync needed to avoid race condition on the check.
q.slow_full_sync();

// Smoke test: It is valid for the dot output to change with updates to graph generation. If this test fails, verify that the printed graph is sane and
// replace the `expected` value with the new dot graph.
const auto expected =
"digraph G{label=\"Command Graph\" subgraph cluster_0{label=<<font color=\"#606060\">T0 (epoch)</font>>;color=darkgray;0[label=<C0 on "
"N0<br/><b>epoch</b>> fontcolor=black shape=box];}subgraph cluster_1{label=<<font color=\"#606060\">T1 \"print_graph_buffer_name_11\" "
"(device-compute)</font>>;color=darkgray;2[label=<C2 on N0<br/><b>execution</b> [[0,0,0] - [16,1,1]]<br/><i>write</i> B0 \"my_buffer\" {[[0,0,0] - "
"[16,1,1]]}> fontcolor=black shape=box];}subgraph cluster_2{label=<<font color=\"#606060\">T2 (epoch)</font>>;color=darkgray;3[label=<C3 on "
"N0<br/><b>epoch</b> (barrier)> fontcolor=black shape=box];}2->3[color=orange];0->2[];}";

const auto dot = runtime_testspy::print_graph(runtime::get_instance());
CHECK(dot == expected);
}
} // namespace celerity::detail
8 changes: 0 additions & 8 deletions test/runtime_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,6 @@ namespace detail {
}
};

struct runtime_testspy {
static scheduler& get_schdlr(runtime& rt) { return *rt.m_schdlr; }

static executor& get_exec(runtime& rt) { return *rt.m_exec; }

static size_t get_command_count(runtime& rt) { return rt.m_cdag->command_count(); }
};

struct scheduler_testspy {
static std::thread& get_worker_thread(scheduler& schdlr) { return schdlr.m_worker_thread; }
};
Expand Down
Loading

0 comments on commit 5f5a3a3

Please sign in to comment.