From e20649fac6580c0930eac701f27facf9a06d4c8e Mon Sep 17 00:00:00 2001 From: FacuMH Date: Wed, 3 Aug 2022 14:25:54 +0200 Subject: [PATCH] Add buffer debug names This adds `[set|get]_buffer_name`, which should be more significant names than just the ids. If set, these names are then included in the GraphViz output. --- examples/matmul/matmul.cc | 3 +++ include/buffer.h | 52 +++++++++++++++++++++--------------- include/buffer_manager.h | 14 +++++++++- include/celerity.h | 1 + include/command_graph.h | 3 ++- include/debug.h | 18 +++++++++++++ include/print_graph.h | 5 ++-- src/command_graph.cc | 5 ++-- src/print_graph.cc | 43 ++++++++++++++++++----------- src/runtime.cc | 2 +- src/task_manager.cc | 2 +- test/buffer_manager_tests.cc | 7 +++++ test/print_graph_tests.cc | 50 ++++++++++++++++++++++++++++------ test/runtime_tests.cc | 8 ------ test/test_utils.h | 12 ++++++++- 15 files changed, 163 insertions(+), 62 deletions(-) create mode 100644 include/debug.h diff --git a/examples/matmul/matmul.cc b/examples/matmul/matmul.cc index 742d3185d..fae87a1d3 100644 --- a/examples/matmul/matmul.cc +++ b/examples/matmul/matmul.cc @@ -78,6 +78,9 @@ int main() { celerity::buffer mat_b_buf(range); celerity::buffer mat_c_buf(range); + celerity::debug::set_buffer_name(mat_a_buf, "mat_a"); + celerity::debug::set_buffer_name(mat_b_buf, "mat_b"); + set_identity(queue, mat_a_buf); set_identity(queue, mat_b_buf); diff --git a/include/buffer.h b/include/buffer.h index 616988d8d..f2f210c82 100644 --- a/include/buffer.h +++ b/include/buffer.h @@ -18,22 +18,18 @@ class buffer; namespace detail { - struct buffer_lifetime_tracker { - buffer_lifetime_tracker() = default; - template - buffer_id initialize(celerity::range<3> range, const DataT* host_init_ptr) { - id = runtime::get_instance().get_buffer_manager().register_buffer(range, host_init_ptr); - return id; - } - buffer_lifetime_tracker(const buffer_lifetime_tracker&) = delete; - buffer_lifetime_tracker(buffer_lifetime_tracker&&) = delete; - ~buffer_lifetime_tracker() noexcept { runtime::get_instance().get_buffer_manager().unregister_buffer(id); } - buffer_id id; - }; - template buffer_id get_buffer_id(const buffer& buff); + template + void set_buffer_name(const celerity::buffer& buff, const std::string& debug_name) { + buff.m_impl->debug_name = debug_name; + }; + template + std::string get_buffer_name(const celerity::buffer& buff) { + return buff.m_impl->debug_name; + }; + } // namespace detail template @@ -44,11 +40,9 @@ class buffer { public: static_assert(Dims > 0, "0-dimensional buffers NYI"); - buffer(const DataT* host_ptr, celerity::range range) : m_range(range) { + buffer(const DataT* host_ptr, celerity::range range) { if(!detail::runtime::is_initialized()) { detail::runtime::init(nullptr, nullptr); } - - m_lifetime_tracker = std::make_shared(); - m_id = m_lifetime_tracker->initialize(detail::range_cast<3>(range), host_ptr); + m_impl = std::make_shared(range, host_ptr); } buffer(celerity::range range) : buffer(nullptr, range) {} @@ -72,22 +66,36 @@ class buffer { return accessor(*this, cgh, rmfn); } - celerity::range get_range() const { return m_range; } + celerity::range get_range() const { return m_impl->range; } private: - std::shared_ptr m_lifetime_tracker = nullptr; - celerity::range m_range; - detail::buffer_id m_id; + struct impl { + impl(celerity::range rng, const DataT* host_init_ptr) : range(rng) { + id = detail::runtime::get_instance().get_buffer_manager().register_buffer(detail::range_cast<3>(range), host_init_ptr); + } + impl(const impl&) = delete; + impl(impl&&) = delete; + ~impl() noexcept { detail::runtime::get_instance().get_buffer_manager().unregister_buffer(id); } + detail::buffer_id id; + celerity::range range; + std::string debug_name; + }; + + std::shared_ptr m_impl = nullptr; template friend detail::buffer_id detail::get_buffer_id(const buffer& buff); + template + friend void detail::set_buffer_name(const celerity::buffer& buff, const std::string& debug_name); + template + friend std::string detail::get_buffer_name(const celerity::buffer& buff); }; namespace detail { template buffer_id get_buffer_id(const buffer& buff) { - return buff.m_id; + return buff.m_impl->id; } } // namespace detail diff --git a/include/buffer_manager.h b/include/buffer_manager.h index ebaaec13d..908ffcf8f 100644 --- a/include/buffer_manager.h +++ b/include/buffer_manager.h @@ -93,6 +93,7 @@ namespace detail { cl::sycl::range<3> range = {1, 1, 1}; size_t element_size = 0; bool is_host_initialized; + std::string debug_name = {}; }; /** @@ -163,7 +164,8 @@ namespace detail { return !m_buffer_infos.empty(); } - const buffer_info& get_buffer_info(buffer_id bid) const { + // returning copy of struct because FOR NOW it is not called in any performance critical section. + buffer_info get_buffer_info(buffer_id bid) const { std::shared_lock lock(m_mutex); assert(m_buffer_infos.find(bid) != m_buffer_infos.end()); return m_buffer_infos.at(bid); @@ -312,6 +314,16 @@ namespace detail { bool is_locked(buffer_id bid) const; + void set_debug_name(const buffer_id bid, const std::string& debug_name) { + std::lock_guard lock(m_mutex); + m_buffer_infos.at(bid).debug_name = debug_name; + } + + std::string get_debug_name(const buffer_id bid) const { + std::lock_guard lock(m_mutex); + return m_buffer_infos.at(bid).debug_name; + } + private: struct backing_buffer { std::unique_ptr storage = nullptr; diff --git a/include/celerity.h b/include/celerity.h index 199da6051..4e19440d2 100644 --- a/include/celerity.h +++ b/include/celerity.h @@ -6,6 +6,7 @@ #include "accessor.h" #include "buffer.h" +#include "debug.h" #include "distr_queue.h" #include "side_effect.h" #include "user_bench.h" diff --git a/include/command_graph.h b/include/command_graph.h index cdadf43db..839ac56b7 100644 --- a/include/command_graph.h +++ b/include/command_graph.h @@ -14,6 +14,7 @@ namespace celerity { namespace detail { + class buffer_manager; class reduction_manager; class task_manager; @@ -128,7 +129,7 @@ namespace detail { auto& task_commands(task_id tid) { return m_by_task.at(tid); } - std::optional print_graph(size_t max_nodes, const task_manager& tm, const reduction_manager& rm) const; + std::optional print_graph(size_t max_nodes, const task_manager& tm, const reduction_manager& rm, const buffer_manager* bm) const; // TODO unify dependency terminology to this void add_dependency(abstract_command* depender, abstract_command* dependee, dependency_kind kind, dependency_origin origin) { diff --git a/include/debug.h b/include/debug.h new file mode 100644 index 000000000..45b1e6f7b --- /dev/null +++ b/include/debug.h @@ -0,0 +1,18 @@ +#include + +#include "buffer.h" + +namespace celerity { +namespace debug { + template + void set_buffer_name(const celerity::buffer& buff, const std::string& debug_name) { + detail::set_buffer_name(buff, debug_name); + detail::runtime::get_instance().get_buffer_manager().set_debug_name(detail::get_buffer_id(buff), debug_name); + } + template + std::string get_buffer_name(const celerity::buffer& buff) { + return detail::get_buffer_name(buff); + } + +} // namespace debug +} // namespace celerity \ No newline at end of file diff --git a/include/print_graph.h b/include/print_graph.h index 4397e0c2e..3e98e68ba 100644 --- a/include/print_graph.h +++ b/include/print_graph.h @@ -9,12 +9,13 @@ namespace celerity { namespace detail { + class buffer_manager; class command_graph; class reduction_manager; class task_manager; - std::string print_task_graph(const task_ring_buffer& tdag, const reduction_manager& rm); - std::string print_command_graph(const command_graph& cdag, const task_manager& tm, const reduction_manager& rm); + std::string print_task_graph(const task_ring_buffer& tdag, const reduction_manager& rm, const buffer_manager* bm); + std::string print_command_graph(const command_graph& cdag, const task_manager& tm, const reduction_manager& rm, const buffer_manager* bm); } // namespace detail } // namespace celerity diff --git a/src/command_graph.cc b/src/command_graph.cc index 82a9c306e..130cebe95 100644 --- a/src/command_graph.cc +++ b/src/command_graph.cc @@ -25,8 +25,9 @@ namespace detail { } } - std::optional command_graph::print_graph(size_t max_nodes, const task_manager& tm, const reduction_manager& rm) const { - if(command_count() <= max_nodes) { return detail::print_command_graph(*this, tm, rm); } + std::optional command_graph::print_graph( + size_t max_nodes, const task_manager& tm, const reduction_manager& rm, const buffer_manager* bm) const { + if(command_count() <= max_nodes) { return detail::print_command_graph(*this, tm, rm, bm); } return std::nullopt; } diff --git a/src/print_graph.cc b/src/print_graph.cc index 6f82b537d..f612441f0 100644 --- a/src/print_graph.cc +++ b/src/print_graph.cc @@ -34,7 +34,16 @@ namespace detail { } } - void format_requirements(std::string& label, const task& tsk, subrange<3> execution_range, access_mode reduction_init_mode, const reduction_manager& rm) { + std::string get_buffer_label(const buffer_manager* bm, const buffer_id bid) { + // if there is no buffer manager or no name defined, the name will be the buffer id. + // if there is a name we want "id name" + std::string name; + if(bm != nullptr) { name = bm->get_debug_name(bid); } + return !name.empty() ? fmt::format("B{} \"{}\"", bid, name) : fmt::format("B{}", bid); + } + + void format_requirements(std::string& label, const task& tsk, subrange<3> execution_range, access_mode reduction_init_mode, const reduction_manager& rm, + const buffer_manager* bm) { for(auto rid : tsk.get_reductions()) { auto reduction = rm.get_reduction(rid); @@ -43,15 +52,17 @@ namespace detail { const auto bid = reduction.output_buffer_id; const auto req = GridRegion<3>{{1, 1, 1}}; - fmt::format_to(std::back_inserter(label), "
(R{}) {} B{} {}", rid, detail::access::mode_traits::name(rmode), bid, req); + const std::string bl = get_buffer_label(bm, bid); + fmt::format_to(std::back_inserter(label), "
(R{}) {} {} {}", rid, detail::access::mode_traits::name(rmode), bl, req); } const auto& bam = tsk.get_buffer_access_map(); for(const auto bid : bam.get_accessed_buffers()) { for(const auto mode : bam.get_access_modes(bid)) { const auto req = bam.get_requirements_for_access(bid, mode, tsk.get_dimensions(), execution_range, tsk.get_global_size()); + const std::string bl = get_buffer_label(bm, bid); // While uncommon, we do support chunks that don't require access to a particular buffer at all. - if(!req.empty()) { fmt::format_to(std::back_inserter(label), "
{} B{} {}", detail::access::mode_traits::name(mode), bid, req); } + if(!req.empty()) { fmt::format_to(std::back_inserter(label), "
{} {} {}", detail::access::mode_traits::name(mode), bl, req); } } } @@ -60,7 +71,7 @@ namespace detail { } } - std::string get_task_label(const task& tsk, const reduction_manager& rm) { + std::string get_task_label(const task& tsk, const reduction_manager& rm, const buffer_manager* bm) { std::string label; fmt::format_to(std::back_inserter(label), "T{}", tsk.get_id()); if(!tsk.get_debug_name().empty()) { fmt::format_to(std::back_inserter(label), " \"{}\" ", tsk.get_debug_name()); } @@ -74,17 +85,17 @@ namespace detail { fmt::format_to(std::back_inserter(label), " in CG{}", tsk.get_collective_group_id()); } - format_requirements(label, tsk, execution_range, access_mode::read_write, rm); + format_requirements(label, tsk, execution_range, access_mode::read_write, rm, bm); return label; } - std::string print_task_graph(const task_ring_buffer& tdag, const reduction_manager& rm) { + std::string print_task_graph(const task_ring_buffer& tdag, const reduction_manager& rm, const buffer_manager* bm) { std::string dot = "digraph G {label=\"Task Graph\" "; for(auto tsk : tdag) { const auto shape = tsk->get_type() == task_type::epoch || tsk->get_type() == task_type::horizon ? "ellipse" : "box style=rounded"; - fmt::format_to(std::back_inserter(dot), "{}[shape={} label=<{}>];", tsk->get_id(), shape, get_task_label(*tsk, rm)); + fmt::format_to(std::back_inserter(dot), "{}[shape={} label=<{}>];", tsk->get_id(), shape, get_task_label(*tsk, rm, bm)); for(auto d : tsk->get_dependencies()) { fmt::format_to(std::back_inserter(dot), "{}->{}[{}];", d.node->get_id(), tsk->get_id(), dependency_style(d)); } @@ -94,7 +105,7 @@ namespace detail { return dot; } - std::string get_command_label(const abstract_command& cmd, const task_manager& tm, const reduction_manager& rm) { + std::string get_command_label(const abstract_command& cmd, const task_manager& tm, const reduction_manager& rm, const buffer_manager* bm) { const command_id cid = cmd.get_cid(); const node_id nid = cmd.get_nid(); @@ -108,16 +119,18 @@ namespace detail { fmt::format_to(std::back_inserter(label), "execution {}", subrange_to_grid_box(xcmd->get_execution_range())); } else if(const auto pcmd = dynamic_cast(&cmd)) { if(pcmd->get_rid()) { fmt::format_to(std::back_inserter(label), "(R{}) ", pcmd->get_rid()); } - fmt::format_to( - std::back_inserter(label), "push to N{}
B{} {}", pcmd->get_target(), pcmd->get_bid(), subrange_to_grid_box(pcmd->get_range())); + const std::string bl = get_buffer_label(bm, pcmd->get_bid()); + fmt::format_to(std::back_inserter(label), "push to N{}
{} {}", pcmd->get_target(), bl, subrange_to_grid_box(pcmd->get_range())); } else if(const auto apcmd = dynamic_cast(&cmd)) { if(apcmd->get_source()->get_rid()) { label += fmt::format("(R{}) ", apcmd->get_source()->get_rid()); } - fmt::format_to(std::back_inserter(label), "await push from N{}
B{} {}", apcmd->get_source()->get_nid(), apcmd->get_source()->get_bid(), + const std::string bl = get_buffer_label(bm, apcmd->get_source()->get_bid()); + fmt::format_to(std::back_inserter(label), "await push from N{}
{} {}", apcmd->get_source()->get_nid(), bl, subrange_to_grid_box(apcmd->get_source()->get_range())); } else if(const auto rrcmd = dynamic_cast(&cmd)) { const auto reduction = rm.get_reduction(rrcmd->get_rid()); const auto req = GridRegion<3>{{1, 1, 1}}; - fmt::format_to(std::back_inserter(label), "reduction R{}
B{} {}", rrcmd->get_rid(), reduction.output_buffer_id, req); + const std::string bl = get_buffer_label(bm, reduction.output_buffer_id); + fmt::format_to(std::back_inserter(label), "reduction R{}
{} {}", rrcmd->get_rid(), bl, req); } else if(const auto hcmd = dynamic_cast(&cmd)) { label += "horizon"; } else { @@ -135,13 +148,13 @@ namespace detail { execution_range = ecmd->get_execution_range(); } - format_requirements(label, tsk, execution_range, reduction_init_mode, rm); + format_requirements(label, tsk, execution_range, reduction_init_mode, rm, bm); } return label; } - std::string print_command_graph(const command_graph& cdag, const task_manager& tm, const reduction_manager& rm) { + std::string print_command_graph(const command_graph& cdag, const task_manager& tm, const reduction_manager& rm, const buffer_manager* bm) { std::string main_dot; std::unordered_map task_subgraph_dot; @@ -149,7 +162,7 @@ namespace detail { static const char* const colors[] = {"black", "crimson", "dodgerblue4", "goldenrod", "maroon4", "springgreen2", "tan1", "chartreuse2"}; const auto name = cmd.get_cid(); - const auto label = get_command_label(cmd, tm, rm); + const auto label = get_command_label(cmd, tm, rm, bm); const auto fontcolor = colors[cmd.get_nid() % (sizeof(colors) / sizeof(char*))]; const auto shape = isa(&cmd) ? "box" : "ellipse"; return fmt::format("{}[label=<{}> fontcolor={} shape={}];", name, label, fontcolor, shape); diff --git a/src/runtime.cc b/src/runtime.cc index d5419d355..f0073e17b 100644 --- a/src/runtime.cc +++ b/src/runtime.cc @@ -214,7 +214,7 @@ namespace detail { } } { - const auto graph_str = m_cdag->print_graph(print_max_nodes, *m_task_mngr, *m_reduction_mngr); + const auto graph_str = m_cdag->print_graph(print_max_nodes, *m_task_mngr, *m_reduction_mngr, m_buffer_mngr.get()); if(graph_str.has_value()) { CELERITY_TRACE("Command graph:\n\n{}\n", *graph_str); } else { diff --git a/src/task_manager.cc b/src/task_manager.cc index b32e437ff..80486daee 100644 --- a/src/task_manager.cc +++ b/src/task_manager.cc @@ -27,7 +27,7 @@ namespace detail { const task* task_manager::get_task(task_id tid) const { return m_task_buffer.get_task(tid); } std::optional task_manager::print_graph(size_t max_nodes) const { - if(m_task_buffer.get_current_task_count() <= max_nodes) { return detail::print_task_graph(m_task_buffer, *m_reduction_mngr); } + if(m_task_buffer.get_current_task_count() <= max_nodes) { return detail::print_task_graph(m_task_buffer, *m_reduction_mngr, nullptr); } return std::nullopt; } diff --git a/test/buffer_manager_tests.cc b/test/buffer_manager_tests.cc index e2e0bd675..19528e449 100644 --- a/test/buffer_manager_tests.cc +++ b/test/buffer_manager_tests.cc @@ -1098,6 +1098,13 @@ namespace detail { } } + TEST_CASE_METHOD(test_utils::runtime_fixture, "buffer_manager allows to set buffer debug names on buffers", "[buffer_manager]") { + celerity::buffer buff_a(16); + std::string buff_name{"my_buffer"}; + detail::runtime::get_instance().get_buffer_manager().set_debug_name(detail::get_buffer_id(buff_a), buff_name); + CHECK(detail::runtime::get_instance().get_buffer_manager().get_debug_name(detail::get_buffer_id(buff_a)) == buff_name); + } + #endif // CELERITY_DETAIL_IS_OLD_COMPUTECPP_COMPILER } // namespace detail diff --git a/test/print_graph_tests.cc b/test/print_graph_tests.cc index 1f209c211..26adc09aa 100644 --- a/test/print_graph_tests.cc +++ b/test/print_graph_tests.cc @@ -1,5 +1,8 @@ +#include + #include "test_utils.h" + namespace celerity::detail { using celerity::access::fixed; @@ -72,7 +75,8 @@ TEST_CASE("command graph printing is unchanged", "[print_graph][command-graph]") // replace the `expected` value with the new dot graph. const auto expected = "digraph G{label=\"Command Graph\" subgraph cluster_2{label=<T2 (master-node host)>;color=darkgray;9[label=execution [[0,0,0] - [0,0,0]]
read B0 {[[0,0,0] - [1,1,1]]}
read_write B0 {[[0,0,0] - [1,1,1]]}
write " + "N0
execution [[0,0,0] - [0,0,0]]
read B0 {[[0,0,0] - [1,1,1]]}
read_write B0 {[[0,0,0] - " + "[1,1,1]]}
write " "B0 {[[0,0,0] - [1,1,1]]}> fontcolor=black shape=box];}subgraph cluster_1{label=<T1 \"task_reduction_8\" " "(device-compute)>;color=darkgray;5[label=execution [[0,0,0] - [1,1,1]]
(R1) discard_write B0 {[[0,0,0] - " "[1,1,1]]}> fontcolor=black shape=box];6[label=execution [[1,0,0] - [2,1,1]]
(R1) discard_write B0 {[[0,0,0] - " @@ -81,16 +85,46 @@ TEST_CASE("command graph printing is unchanged", "[print_graph][command-graph]") "[1,1,1]]}> fontcolor=goldenrod shape=box];}subgraph cluster_0{label=<T0 (epoch)>;color=darkgray;0[label=epoch> fontcolor=black shape=box];1[label=epoch> fontcolor=crimson shape=box];2[label=epoch> " "fontcolor=dodgerblue4 shape=box];3[label=epoch> fontcolor=goldenrod shape=box];}16[label=(R1) await push " - "from N3
B0 [[0,0,0] - [1,1,1]]> fontcolor=black shape=ellipse];0->16[color=orchid];15->16[style=dashed color=gray40];15[label=(R1) " - "push to N0
B0 [[0,0,0] - [1,1,1]]> fontcolor=goldenrod shape=ellipse];8->15[];14[label=(R1) await push from " - "N2
B0 [[0,0,0] - [1,1,1]]> fontcolor=black shape=ellipse];0->14[color=orchid];13->14[style=dashed color=gray40];13[label=(R1) " - "push to N0
B0 [[0,0,0] - [1,1,1]]> fontcolor=dodgerblue4 " + "from N3
B0 [[0,0,0] - [1,1,1]]> fontcolor=black shape=ellipse];0->16[color=orchid];15->16[style=dashed color=gray40];15[label=(R1) " + "push to N0
B0 [[0,0,0] - [1,1,1]]> fontcolor=goldenrod shape=ellipse];8->15[];14[label=(R1) await push from " + "N2
B0 [[0,0,0] - [1,1,1]]> fontcolor=black shape=ellipse];0->14[color=orchid];13->14[style=dashed color=gray40];13[label=(R1) " + "push to N0
B0 [[0,0,0] - [1,1,1]]> fontcolor=dodgerblue4 " "shape=ellipse];7->13[];0->5[color=orchid];1->6[color=orchid];2->7[color=orchid];3->8[color=orchid];10->9[];10[label=reduction " - "R1
B0 {[[0,0,0] - [1,1,1]]}> fontcolor=black shape=ellipse];5->10[];12->10[];14->10[];16->10[];11[label=(R1) push to " - "N0
B0 [[0,0,0] - [1,1,1]]> fontcolor=crimson shape=ellipse];6->11[];12[label=(R1) await push from N1
B0 [[0,0,0] - " + "R1
B0 {[[0,0,0] - [1,1,1]]}> fontcolor=black shape=ellipse];5->10[];12->10[];14->10[];16->10[];11[label=(R1) push to " + "N0
B0 [[0,0,0] - [1,1,1]]> fontcolor=crimson shape=ellipse];6->11[];12[label=(R1) await push from N1
B0 [[0,0,0] - " "[1,1,1]]> fontcolor=black shape=ellipse];0->12[color=orchid];11->12[style=dashed color=gray40];}"; - const auto dot = ctx.get_command_graph().print_graph(std::numeric_limits::max(), tm, rm).value(); + const auto dot = ctx.get_command_graph().print_graph(std::numeric_limits::max(), tm, rm, {}).value(); + CHECK(dot == expected); +} + +TEST_CASE_METHOD(test_utils::runtime_fixture, "Buffer debug names show up in the generated graph", "[print_graph][buffer_names]") { + distr_queue q; + celerity::range<1> range(16); + celerity::buffer buff_a(range); + std::string buff_name{"my_buffer"}; + celerity::debug::set_buffer_name(buff_a, buff_name); + CHECK(celerity::debug::get_buffer_name(buff_a) == buff_name); + + q.submit([=](handler& cgh) { + celerity::accessor acc{buff_a, cgh, celerity::access::all{}, celerity::write_only}; + cgh.parallel_for(range, [=](item<1> item) {}); + }); + + // wait for commands to be generated in the scheduler thread + q.slow_full_sync(); + + // Smoke test: It is valid for the dot output to change with updates to graph generation. If this test fails, verify that the printed graph is sane and + // replace the `expected` value with the new dot graph. + const auto expected = + "digraph G{label=\"Command Graph\" subgraph cluster_0{label=<T0 (epoch)>;color=darkgray;0[label=epoch> fontcolor=black shape=box];}subgraph cluster_1{label=<T1 \"print_graph_buffer_name_11\" " + "(device-compute)>;color=darkgray;2[label=execution [[0,0,0] - [16,1,1]]
write B0 \"my_buffer\" {[[0,0,0] - " + "[16,1,1]]}> fontcolor=black shape=box];}subgraph cluster_2{label=<T2 (epoch)>;color=darkgray;3[label=epoch (barrier)> fontcolor=black shape=box];}2->3[color=orange];0->2[];}"; + + const auto dot = runtime_testspy::print_graph(runtime::get_instance()); CHECK(dot == expected); } } // namespace celerity::detail \ No newline at end of file diff --git a/test/runtime_tests.cc b/test/runtime_tests.cc index 28c3ff0f1..7660d76f9 100644 --- a/test/runtime_tests.cc +++ b/test/runtime_tests.cc @@ -44,14 +44,6 @@ namespace detail { } }; - struct runtime_testspy { - static scheduler& get_schdlr(runtime& rt) { return *rt.m_schdlr; } - - static executor& get_exec(runtime& rt) { return *rt.m_exec; } - - static size_t get_command_count(runtime& rt) { return rt.m_cdag->command_count(); } - }; - struct scheduler_testspy { static std::thread& get_worker_thread(scheduler& schdlr) { return schdlr.m_worker_thread; } }; diff --git a/test/test_utils.h b/test/test_utils.h index 56f6dfc2f..6c2bc3a10 100644 --- a/test/test_utils.h +++ b/test/test_utils.h @@ -46,6 +46,16 @@ namespace celerity { namespace detail { + struct runtime_testspy { + static scheduler& get_schdlr(runtime& rt) { return *rt.m_schdlr; } + static executor& get_exec(runtime& rt) { return *rt.m_exec; } + static size_t get_command_count(runtime& rt) { return rt.m_cdag->command_count(); } + static command_graph& get_cdag(runtime& rt) { return *rt.m_cdag; } + static std::string print_graph(runtime& rt) { + return rt.m_cdag.get()->print_graph(std::numeric_limits::max(), *rt.m_task_mngr, *rt.m_reduction_mngr, rt.m_buffer_mngr.get()).value(); + } + }; + struct task_ring_buffer_testspy { static void create_task_slot(task_ring_buffer& trb) { trb.m_number_of_deleted_tasks += 1; } }; @@ -426,7 +436,7 @@ namespace test_utils { inline void maybe_print_graph( celerity::detail::command_graph& cdag, const celerity::detail::task_manager& tm, const celerity::detail::reduction_manager& rm) { if(print_graphs) { - const auto graph_str = cdag.print_graph(std::numeric_limits::max(), tm, rm); + const auto graph_str = cdag.print_graph(std::numeric_limits::max(), tm, rm, {}); assert(graph_str.has_value()); CELERITY_INFO("Command graph:\n\n{}\n", *graph_str); }