Skip to content

Commit

Permalink
Add buffer debug names
Browse files Browse the repository at this point in the history
This adds `[set|get]_buffer_name`, which should be more significant names than just the ids.
If set, these names are then included in the GraphViz output.
  • Loading branch information
facuMH authored Sep 20, 2022
1 parent 53d7ef7 commit 1076522
Show file tree
Hide file tree
Showing 15 changed files with 163 additions and 62 deletions.
3 changes: 3 additions & 0 deletions examples/matmul/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ int main() {
celerity::buffer<float, 2> mat_b_buf(range);
celerity::buffer<float, 2> mat_c_buf(range);

celerity::debug::set_buffer_name(mat_a_buf, "mat_a");
celerity::debug::set_buffer_name(mat_b_buf, "mat_b");

set_identity(queue, mat_a_buf);
set_identity(queue, mat_b_buf);

Expand Down
52 changes: 30 additions & 22 deletions include/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,22 +18,18 @@ class buffer;

namespace detail {

struct buffer_lifetime_tracker {
buffer_lifetime_tracker() = default;
template <typename DataT, int Dims>
buffer_id initialize(celerity::range<3> range, const DataT* host_init_ptr) {
id = runtime::get_instance().get_buffer_manager().register_buffer<DataT, Dims>(range, host_init_ptr);
return id;
}
buffer_lifetime_tracker(const buffer_lifetime_tracker&) = delete;
buffer_lifetime_tracker(buffer_lifetime_tracker&&) = delete;
~buffer_lifetime_tracker() noexcept { runtime::get_instance().get_buffer_manager().unregister_buffer(id); }
buffer_id id;
};

template <typename T, int D>
buffer_id get_buffer_id(const buffer<T, D>& buff);

template <typename DataT, int Dims>
void set_buffer_name(const celerity::buffer<DataT, Dims>& buff, const std::string& debug_name) {
buff.m_impl->debug_name = debug_name;
};
template <typename DataT, int Dims>
std::string get_buffer_name(const celerity::buffer<DataT, Dims>& buff) {
return buff.m_impl->debug_name;
};

} // namespace detail

template <typename DataT, int Dims, access_mode Mode, target Target>
Expand All @@ -44,11 +40,9 @@ class buffer {
public:
static_assert(Dims > 0, "0-dimensional buffers NYI");

buffer(const DataT* host_ptr, celerity::range<Dims> range) : m_range(range) {
buffer(const DataT* host_ptr, celerity::range<Dims> range) {
if(!detail::runtime::is_initialized()) { detail::runtime::init(nullptr, nullptr); }

m_lifetime_tracker = std::make_shared<detail::buffer_lifetime_tracker>();
m_id = m_lifetime_tracker->initialize<DataT, Dims>(detail::range_cast<3>(range), host_ptr);
m_impl = std::make_shared<impl>(range, host_ptr);
}

buffer(celerity::range<Dims> range) : buffer(nullptr, range) {}
Expand All @@ -72,22 +66,36 @@ class buffer {
return accessor<DataT, Dims, Mode, Target>(*this, cgh, rmfn);
}

celerity::range<Dims> get_range() const { return m_range; }
celerity::range<Dims> get_range() const { return m_impl->range; }

private:
std::shared_ptr<detail::buffer_lifetime_tracker> m_lifetime_tracker = nullptr;
celerity::range<Dims> m_range;
detail::buffer_id m_id;
struct impl {
impl(celerity::range<Dims> rng, const DataT* host_init_ptr) : range(rng) {
id = detail::runtime::get_instance().get_buffer_manager().register_buffer<DataT, Dims>(detail::range_cast<3>(range), host_init_ptr);
}
impl(const impl&) = delete;
impl(impl&&) = delete;
~impl() noexcept { detail::runtime::get_instance().get_buffer_manager().unregister_buffer(id); }
detail::buffer_id id;
celerity::range<Dims> range;
std::string debug_name;
};

std::shared_ptr<impl> m_impl = nullptr;

template <typename T, int D>
friend detail::buffer_id detail::get_buffer_id(const buffer<T, D>& buff);
template <typename T, int D>
friend void detail::set_buffer_name(const celerity::buffer<T, D>& buff, const std::string& debug_name);
template <typename T, int D>
friend std::string detail::get_buffer_name(const celerity::buffer<T, D>& buff);
};

namespace detail {

template <typename T, int D>
buffer_id get_buffer_id(const buffer<T, D>& buff) {
return buff.m_id;
return buff.m_impl->id;
}

} // namespace detail
Expand Down
14 changes: 13 additions & 1 deletion include/buffer_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ namespace detail {
cl::sycl::range<3> range = {1, 1, 1};
size_t element_size = 0;
bool is_host_initialized;
std::string debug_name = {};
};

/**
Expand Down Expand Up @@ -163,7 +164,8 @@ namespace detail {
return !m_buffer_infos.empty();
}

const buffer_info& get_buffer_info(buffer_id bid) const {
// returning copy of struct because FOR NOW it is not called in any performance critical section.
buffer_info get_buffer_info(buffer_id bid) const {
std::shared_lock lock(m_mutex);
assert(m_buffer_infos.find(bid) != m_buffer_infos.end());
return m_buffer_infos.at(bid);
Expand Down Expand Up @@ -312,6 +314,16 @@ namespace detail {

bool is_locked(buffer_id bid) const;

void set_debug_name(const buffer_id bid, const std::string& debug_name) {
std::lock_guard lock(m_mutex);
m_buffer_infos.at(bid).debug_name = debug_name;
}

std::string get_debug_name(const buffer_id bid) const {
std::lock_guard lock(m_mutex);
return m_buffer_infos.at(bid).debug_name;
}

private:
struct backing_buffer {
std::unique_ptr<buffer_storage> storage = nullptr;
Expand Down
1 change: 1 addition & 0 deletions include/celerity.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

#include "accessor.h"
#include "buffer.h"
#include "debug.h"
#include "distr_queue.h"
#include "side_effect.h"
#include "user_bench.h"
Expand Down
3 changes: 2 additions & 1 deletion include/command_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
namespace celerity {
namespace detail {

class buffer_manager;
class reduction_manager;
class task_manager;

Expand Down Expand Up @@ -128,7 +129,7 @@ namespace detail {

auto& task_commands(task_id tid) { return m_by_task.at(tid); }

std::optional<std::string> print_graph(size_t max_nodes, const task_manager& tm, const reduction_manager& rm) const;
std::optional<std::string> print_graph(size_t max_nodes, const task_manager& tm, const reduction_manager& rm, const buffer_manager* bm) const;

// TODO unify dependency terminology to this
void add_dependency(abstract_command* depender, abstract_command* dependee, dependency_kind kind, dependency_origin origin) {
Expand Down
18 changes: 18 additions & 0 deletions include/debug.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
#include <string>

#include "buffer.h"

namespace celerity {
namespace debug {
template <typename DataT, int Dims>
void set_buffer_name(const celerity::buffer<DataT, Dims>& buff, const std::string& debug_name) {
detail::set_buffer_name(buff, debug_name);
detail::runtime::get_instance().get_buffer_manager().set_debug_name(detail::get_buffer_id(buff), debug_name);
}
template <typename DataT, int Dims>
std::string get_buffer_name(const celerity::buffer<DataT, Dims>& buff) {
return detail::get_buffer_name(buff);
}

} // namespace debug
} // namespace celerity
5 changes: 3 additions & 2 deletions include/print_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
namespace celerity {
namespace detail {

class buffer_manager;
class command_graph;
class reduction_manager;
class task_manager;

std::string print_task_graph(const task_ring_buffer& tdag, const reduction_manager& rm);
std::string print_command_graph(const command_graph& cdag, const task_manager& tm, const reduction_manager& rm);
std::string print_task_graph(const task_ring_buffer& tdag, const reduction_manager& rm, const buffer_manager* bm);
std::string print_command_graph(const command_graph& cdag, const task_manager& tm, const reduction_manager& rm, const buffer_manager* bm);

} // namespace detail
} // namespace celerity
5 changes: 3 additions & 2 deletions src/command_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ namespace detail {
}
}

std::optional<std::string> command_graph::print_graph(size_t max_nodes, const task_manager& tm, const reduction_manager& rm) const {
if(command_count() <= max_nodes) { return detail::print_command_graph(*this, tm, rm); }
std::optional<std::string> command_graph::print_graph(
size_t max_nodes, const task_manager& tm, const reduction_manager& rm, const buffer_manager* bm) const {
if(command_count() <= max_nodes) { return detail::print_command_graph(*this, tm, rm, bm); }
return std::nullopt;
}

Expand Down
43 changes: 28 additions & 15 deletions src/print_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,16 @@ namespace detail {
}
}

void format_requirements(std::string& label, const task& tsk, subrange<3> execution_range, access_mode reduction_init_mode, const reduction_manager& rm) {
std::string get_buffer_label(const buffer_manager* bm, const buffer_id bid) {
// if there is no buffer manager or no name defined, the name will be the buffer id.
// if there is a name we want "id name"
std::string name;
if(bm != nullptr) { name = bm->get_debug_name(bid); }
return !name.empty() ? fmt::format("B{} \"{}\"", bid, name) : fmt::format("B{}", bid);
}

void format_requirements(std::string& label, const task& tsk, subrange<3> execution_range, access_mode reduction_init_mode, const reduction_manager& rm,
const buffer_manager* bm) {
for(auto rid : tsk.get_reductions()) {
auto reduction = rm.get_reduction(rid);

Expand All @@ -43,15 +52,17 @@ namespace detail {

const auto bid = reduction.output_buffer_id;
const auto req = GridRegion<3>{{1, 1, 1}};
fmt::format_to(std::back_inserter(label), "<br/>(R{}) <i>{}</i> B{} {}", rid, detail::access::mode_traits::name(rmode), bid, req);
const std::string bl = get_buffer_label(bm, bid);
fmt::format_to(std::back_inserter(label), "<br/>(R{}) <i>{}</i> {} {}", rid, detail::access::mode_traits::name(rmode), bl, req);
}

const auto& bam = tsk.get_buffer_access_map();
for(const auto bid : bam.get_accessed_buffers()) {
for(const auto mode : bam.get_access_modes(bid)) {
const auto req = bam.get_requirements_for_access(bid, mode, tsk.get_dimensions(), execution_range, tsk.get_global_size());
const std::string bl = get_buffer_label(bm, bid);
// While uncommon, we do support chunks that don't require access to a particular buffer at all.
if(!req.empty()) { fmt::format_to(std::back_inserter(label), "<br/><i>{}</i> B{} {}", detail::access::mode_traits::name(mode), bid, req); }
if(!req.empty()) { fmt::format_to(std::back_inserter(label), "<br/><i>{}</i> {} {}", detail::access::mode_traits::name(mode), bl, req); }
}
}

Expand All @@ -60,7 +71,7 @@ namespace detail {
}
}

std::string get_task_label(const task& tsk, const reduction_manager& rm) {
std::string get_task_label(const task& tsk, const reduction_manager& rm, const buffer_manager* bm) {
std::string label;
fmt::format_to(std::back_inserter(label), "T{}", tsk.get_id());
if(!tsk.get_debug_name().empty()) { fmt::format_to(std::back_inserter(label), " \"{}\" ", tsk.get_debug_name()); }
Expand All @@ -74,17 +85,17 @@ namespace detail {
fmt::format_to(std::back_inserter(label), " in CG{}", tsk.get_collective_group_id());
}

format_requirements(label, tsk, execution_range, access_mode::read_write, rm);
format_requirements(label, tsk, execution_range, access_mode::read_write, rm, bm);

return label;
}

std::string print_task_graph(const task_ring_buffer& tdag, const reduction_manager& rm) {
std::string print_task_graph(const task_ring_buffer& tdag, const reduction_manager& rm, const buffer_manager* bm) {
std::string dot = "digraph G {label=\"Task Graph\" ";

for(auto tsk : tdag) {
const auto shape = tsk->get_type() == task_type::epoch || tsk->get_type() == task_type::horizon ? "ellipse" : "box style=rounded";
fmt::format_to(std::back_inserter(dot), "{}[shape={} label=<{}>];", tsk->get_id(), shape, get_task_label(*tsk, rm));
fmt::format_to(std::back_inserter(dot), "{}[shape={} label=<{}>];", tsk->get_id(), shape, get_task_label(*tsk, rm, bm));
for(auto d : tsk->get_dependencies()) {
fmt::format_to(std::back_inserter(dot), "{}->{}[{}];", d.node->get_id(), tsk->get_id(), dependency_style(d));
}
Expand All @@ -94,7 +105,7 @@ namespace detail {
return dot;
}

std::string get_command_label(const abstract_command& cmd, const task_manager& tm, const reduction_manager& rm) {
std::string get_command_label(const abstract_command& cmd, const task_manager& tm, const reduction_manager& rm, const buffer_manager* bm) {
const command_id cid = cmd.get_cid();
const node_id nid = cmd.get_nid();

Expand All @@ -108,16 +119,18 @@ namespace detail {
fmt::format_to(std::back_inserter(label), "<b>execution</b> {}", subrange_to_grid_box(xcmd->get_execution_range()));
} else if(const auto pcmd = dynamic_cast<const push_command*>(&cmd)) {
if(pcmd->get_rid()) { fmt::format_to(std::back_inserter(label), "(R{}) ", pcmd->get_rid()); }
fmt::format_to(
std::back_inserter(label), "<b>push</b> to N{}<br/>B{} {}", pcmd->get_target(), pcmd->get_bid(), subrange_to_grid_box(pcmd->get_range()));
const std::string bl = get_buffer_label(bm, pcmd->get_bid());
fmt::format_to(std::back_inserter(label), "<b>push</b> to N{}<br/> {} {}", pcmd->get_target(), bl, subrange_to_grid_box(pcmd->get_range()));
} else if(const auto apcmd = dynamic_cast<const await_push_command*>(&cmd)) {
if(apcmd->get_source()->get_rid()) { label += fmt::format("(R{}) ", apcmd->get_source()->get_rid()); }
fmt::format_to(std::back_inserter(label), "<b>await push</b> from N{}<br/>B{} {}", apcmd->get_source()->get_nid(), apcmd->get_source()->get_bid(),
const std::string bl = get_buffer_label(bm, apcmd->get_source()->get_bid());
fmt::format_to(std::back_inserter(label), "<b>await push</b> from N{}<br/> {} {}", apcmd->get_source()->get_nid(), bl,
subrange_to_grid_box(apcmd->get_source()->get_range()));
} else if(const auto rrcmd = dynamic_cast<const reduction_command*>(&cmd)) {
const auto reduction = rm.get_reduction(rrcmd->get_rid());
const auto req = GridRegion<3>{{1, 1, 1}};
fmt::format_to(std::back_inserter(label), "<b>reduction</b> R{}<br/>B{} {}", rrcmd->get_rid(), reduction.output_buffer_id, req);
const std::string bl = get_buffer_label(bm, reduction.output_buffer_id);
fmt::format_to(std::back_inserter(label), "<b>reduction</b> R{}<br/> {} {}", rrcmd->get_rid(), bl, req);
} else if(const auto hcmd = dynamic_cast<const horizon_command*>(&cmd)) {
label += "<b>horizon</b>";
} else {
Expand All @@ -135,21 +148,21 @@ namespace detail {
execution_range = ecmd->get_execution_range();
}

format_requirements(label, tsk, execution_range, reduction_init_mode, rm);
format_requirements(label, tsk, execution_range, reduction_init_mode, rm, bm);
}

return label;
}

std::string print_command_graph(const command_graph& cdag, const task_manager& tm, const reduction_manager& rm) {
std::string print_command_graph(const command_graph& cdag, const task_manager& tm, const reduction_manager& rm, const buffer_manager* bm) {
std::string main_dot;
std::unordered_map<task_id, std::string> task_subgraph_dot;

const auto print_vertex = [&](const abstract_command& cmd) {
static const char* const colors[] = {"black", "crimson", "dodgerblue4", "goldenrod", "maroon4", "springgreen2", "tan1", "chartreuse2"};

const auto name = cmd.get_cid();
const auto label = get_command_label(cmd, tm, rm);
const auto label = get_command_label(cmd, tm, rm, bm);
const auto fontcolor = colors[cmd.get_nid() % (sizeof(colors) / sizeof(char*))];
const auto shape = isa<task_command>(&cmd) ? "box" : "ellipse";
return fmt::format("{}[label=<{}> fontcolor={} shape={}];", name, label, fontcolor, shape);
Expand Down
2 changes: 1 addition & 1 deletion src/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ namespace detail {
}
}
{
const auto graph_str = m_cdag->print_graph(print_max_nodes, *m_task_mngr, *m_reduction_mngr);
const auto graph_str = m_cdag->print_graph(print_max_nodes, *m_task_mngr, *m_reduction_mngr, m_buffer_mngr.get());
if(graph_str.has_value()) {
CELERITY_TRACE("Command graph:\n\n{}\n", *graph_str);
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/task_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace detail {
const task* task_manager::get_task(task_id tid) const { return m_task_buffer.get_task(tid); }

std::optional<std::string> task_manager::print_graph(size_t max_nodes) const {
if(m_task_buffer.get_current_task_count() <= max_nodes) { return detail::print_task_graph(m_task_buffer, *m_reduction_mngr); }
if(m_task_buffer.get_current_task_count() <= max_nodes) { return detail::print_task_graph(m_task_buffer, *m_reduction_mngr, nullptr); }
return std::nullopt;
}

Expand Down
7 changes: 7 additions & 0 deletions test/buffer_manager_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1098,6 +1098,13 @@ namespace detail {
}
}

TEST_CASE_METHOD(test_utils::runtime_fixture, "buffer_manager allows to set buffer debug names on buffers", "[buffer_manager]") {
celerity::buffer<int, 1> buff_a(16);
std::string buff_name{"my_buffer"};
detail::runtime::get_instance().get_buffer_manager().set_debug_name(detail::get_buffer_id(buff_a), buff_name);
CHECK(detail::runtime::get_instance().get_buffer_manager().get_debug_name(detail::get_buffer_id(buff_a)) == buff_name);
}

#endif // CELERITY_DETAIL_IS_OLD_COMPUTECPP_COMPILER

} // namespace detail
Expand Down
Loading

0 comments on commit 1076522

Please sign in to comment.