Skip to content

Commit

Permalink
Spend more effort to make type-derived task labels readable
Browse files Browse the repository at this point in the history
  • Loading branch information
PeterTh committed Sep 7, 2023
1 parent 926cf24 commit 727de29
Show file tree
Hide file tree
Showing 10 changed files with 93 additions and 39 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ set(SOURCES
src/task.cc
src/task_manager.cc
src/user_bench.cc
src/utils.cc
src/worker_job.cc
"${CMAKE_CURRENT_BINARY_DIR}/src/version.cc"
)
Expand Down
6 changes: 2 additions & 4 deletions include/handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,13 @@ namespace detail {

template <typename Name>
std::string kernel_debug_name() {
// we need to typeid a pointer, since the name is often undefined
std::string name = typeid(Name*).name();
#if !defined(_MSC_VER)
const std::unique_ptr<char, void (*)(void*)> demangled(abi::__cxa_demangle(name.c_str(), nullptr, nullptr, nullptr), std::free);
name = demangled.get();
#elif defined(_MSC_VER)
if(size_t lastc, id_end; (lastc = name.rfind(":")) != std::string::npos && (id_end = name.find(" ", lastc)) != std::string::npos) {
name = name.substr(lastc + 1, id_end - lastc);
}
#endif
// get rid of the pointer "*"
return name.substr(0, name.length() - 1);
}

Expand Down
6 changes: 6 additions & 0 deletions include/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include <cstdint>
#include <functional>
#include <string>
#include <type_traits>
#include <variant>

Expand Down Expand Up @@ -83,5 +84,10 @@ static auto tuple_without(const std::tuple<Ts...>& tuple) {
}
}

// fiddles out the base name of a task from a full, demangled input type name
std::string simplify_task_name(const std::string& demangled_type_name);

// escapes "<", ">", and "&" with their corresponding HTML escape sequences
std::string escape_for_dot_label(std::string str);

} // namespace celerity::detail::utils
4 changes: 2 additions & 2 deletions src/print_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ void format_requirements(std::string& label, const reduction_list& reductions, c
std::string get_task_label(const task_record& tsk) {
std::string label;
fmt::format_to(std::back_inserter(label), "T{}", tsk.tid);
if(!tsk.debug_name.empty()) { fmt::format_to(std::back_inserter(label), " \"{}\" ", tsk.debug_name); }
if(!tsk.debug_name.empty()) { fmt::format_to(std::back_inserter(label), " \"{}\" ", utils::escape_for_dot_label(tsk.debug_name)); }

fmt::format_to(std::back_inserter(label), "<br/><b>{}</b>", task_type_string(tsk.type));
if(tsk.type == task_type::host_compute || tsk.type == task_type::device_compute) {
Expand Down Expand Up @@ -180,7 +180,7 @@ std::string print_command_graph(const node_id local_nid, const command_recorder&
if(task_subgraph_dot.count(tid) == 0) {
std::string task_label;
fmt::format_to(std::back_inserter(task_label), "T{} ", tid);
if(!cmd->task_name.empty()) { fmt::format_to(std::back_inserter(task_label), "\"{}\" ", cmd->task_name); }
if(!cmd->task_name.empty()) { fmt::format_to(std::back_inserter(task_label), "\"{}\" ", utils::escape_for_dot_label(cmd->task_name)); }
task_label += "(";
task_label += task_type_string(cmd->task_type.value());
if(cmd->task_type == task_type::collective) { fmt::format_to(std::back_inserter(task_label), " on CG{}", cmd->collective_group_id.value()); }
Expand Down
17 changes: 2 additions & 15 deletions src/recorders.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,23 +8,10 @@ namespace celerity::detail {

// Naming

// removes initial template qualifiers to simplify, and escapes '<' and '>' in the given name,
// so that it can be successfully used in a dot graph label that uses HTML, and is hopefully readable
std::string simplify_and_escape_name(const std::string& name) {
// simplify
auto first_opening_pos = name.find('<');
auto namespace_qual_end_pos = name.rfind(':', first_opening_pos);
auto simplified = namespace_qual_end_pos != std::string::npos ? name.substr(namespace_qual_end_pos + 1) : name;
// escape
simplified = std::regex_replace(simplified, std::regex("<"), "&lt;");
return std::regex_replace(simplified, std::regex(">"), "&gt;");
}

std::string get_buffer_name(const buffer_id bid, const buffer_manager* buff_man) {
return buff_man != nullptr ? buff_man->get_buffer_info(bid).debug_name : "";
}


// Tasks

access_list build_access_list(const task& tsk, const buffer_manager* buff_man, const std::optional<subrange<3>> execution_range = {}) {
Expand Down Expand Up @@ -57,7 +44,7 @@ task_dependency_list build_task_dependency_list(const task& tsk) {
}

task_record::task_record(const task& from, const buffer_manager* buff_mngr)
: tid(from.get_id()), debug_name(simplify_and_escape_name(from.get_debug_name())), cgid(from.get_collective_group_id()), type(from.get_type()),
: tid(from.get_id()), debug_name(utils::simplify_task_name(from.get_debug_name())), cgid(from.get_collective_group_id()), type(from.get_type()),
geometry(from.get_geometry()), reductions(build_reduction_list(from, buff_mngr)), accesses(build_access_list(from, buff_mngr)),
side_effect_map(from.get_side_effect_map()), dependencies(build_task_dependency_list(from)) {}

Expand Down Expand Up @@ -181,7 +168,7 @@ command_dependency_list build_command_dependency_list(const abstract_command& cm
}

std::string get_task_name(const abstract_command& cmd, const task_manager* task_mngr) {
if(const auto* tsk = get_task_for(cmd, task_mngr)) return simplify_and_escape_name(tsk->get_debug_name());
if(const auto* tsk = get_task_for(cmd, task_mngr)) return utils::simplify_task_name(tsk->get_debug_name());
return {};
}

Expand Down
45 changes: 45 additions & 0 deletions src/utils.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#include <utils.h>

#include <regex>

namespace celerity::detail::utils {

std::string simplify_task_name(const std::string& demangled_type_name) {
if(demangled_type_name.length() < 2) return demangled_type_name;
bool templated = false;
// there are two options:
// 1. the type is templated; in this case, the last character is ">" and we go back to the matching "<"
std::string::size_type last_idx = demangled_type_name.length() - 1;
if(demangled_type_name[last_idx] == '>') {
templated = true;
int open = 0;
while(last_idx > 1) {
last_idx--;
if(demangled_type_name[last_idx] == '>') { open++; }
if(demangled_type_name[last_idx] == '<') {
if(open > 0) {
open--;
} else {
last_idx--;
break;
}
}
}
}
// 2. the type isn't templated (or we just removed the template); in this case, we are interested in the part from the end to the last ":" (or the start)
std::string::size_type start_idx = last_idx - 1;
while(start_idx > 0 && demangled_type_name[start_idx - 1] != ':') {
start_idx--;
}
// if the type was templated, we add a "<...>" to indicate that
return demangled_type_name.substr(start_idx, last_idx - start_idx + 1) + (templated ? "<...>" : "");
}

std::string escape_for_dot_label(std::string str) {
str = std::regex_replace(str, std::regex("&"), "&amp;");
str = std::regex_replace(str, std::regex("<"), "&lt;");
str = std::regex_replace(str, std::regex(">"), "&gt;");
return str;
}

} // namespace celerity::detail::utils
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ set(TEST_TARGETS
task_graph_tests
task_ring_buffer_tests
test_utils_tests
utils_tests
device_selection_tests
)

Expand Down
18 changes: 4 additions & 14 deletions test/print_graph_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -190,28 +190,18 @@ TEST_CASE_METHOD(test_utils::runtime_fixture, "full graph is printed if CELERITY
}
}

namespace test_ns {
namespace x {
enum class ec { a, b };
}
template <test_ns::x::ec X>
template <int X>
class name_class {};

template <test_ns::x::ec X>
void compute(task_manager& tm, mock_buffer<1> buf, const celerity::range<1> range) {
test_utils::add_compute_task<name_class<X>>(
tm, [&](handler& cgh) { buf.get_access<access_mode::discard_write>(cgh, acc::one_to_one{}); }, range);
}
} // namespace test_ns

TEST_CASE("task-graph names are escaped", "[print_graph][task-graph][task-name]") {
auto tt = test_utils::task_test_context{};

auto range = celerity::range<1>(64);
auto buf = tt.mbf.create_buffer(range);

test_ns::compute<test_ns::x::ec::a>(tt.tm, buf, range);
test_utils::add_compute_task<name_class<5>>(
tt.tm, [&](handler& cgh) { buf.get_access<access_mode::discard_write>(cgh, acc::one_to_one{}); }, range);

const auto* escaped_name = "\"name_class&lt;(test_ns::x::ec)0&gt;\"";
const auto* escaped_name = "\"name_class&lt;...&gt;\"";
REQUIRE_THAT(print_task_graph(tt.trec), Catch::Matchers::ContainsSubstring(escaped_name));
}
8 changes: 4 additions & 4 deletions test/runtime_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -295,14 +295,14 @@ namespace detail {
template <typename T>
class MyThirdKernel;

TEST_CASE("device_compute tasks derive debug name from kernel name", "[task][!mayfail]") {
TEST_CASE("device_compute tasks derive debug name from kernel name", "[task]") {
auto tm = detail::task_manager(1, nullptr, {});
const auto t1 = tm.get_task(tm.submit_command_group([](handler& cgh) { cgh.parallel_for<class MyFirstKernel>(range<1>{1}, [](id<1>) {}); }));
const auto t2 = tm.get_task(tm.submit_command_group([](handler& cgh) { cgh.parallel_for<foo::MySecondKernel>(range<1>{1}, [](id<1>) {}); }));
const auto t3 = tm.get_task(tm.submit_command_group([](handler& cgh) { cgh.parallel_for<MyThirdKernel<int>>(range<1>{1}, [](id<1>) {}); }));
REQUIRE(t1->get_debug_name() == "MyFirstKernel");
REQUIRE(t2->get_debug_name() == "MySecondKernel");
REQUIRE(t3->get_debug_name() == "MyThirdKernel<int>");
CHECK(utils::simplify_task_name(t1->get_debug_name()) == "MyFirstKernel");
CHECK(utils::simplify_task_name(t2->get_debug_name()) == "MySecondKernel");
CHECK(utils::simplify_task_name(t3->get_debug_name()) == "MyThirdKernel<...>");
}

TEST_CASE_METHOD(test_utils::runtime_fixture, "basic SYNC command functionality", "[distr_queue][sync][control-flow]") {
Expand Down
26 changes: 26 additions & 0 deletions test/utils_tests.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#include <utils.h>

#include <catch2/catch_test_macros.hpp>

using namespace celerity::detail::utils;
using std::string;

TEST_CASE("name strings are correctly extracted from types", "[utils][simplify_task_name]") {
const string simple = "name";
CHECK(simplify_task_name(simple) == "name");

const string namespaced = "ns::another::name2";
CHECK(simplify_task_name(namespaced) == "name2");

const string templated = "name3<class, int>";
CHECK(simplify_task_name(templated) == "name3<...>");

const string real = "set_identity<float>(celerity::distr_queue, celerity::buffer<float, 2>, "
"bool)::{lambda(celerity::handler&)#1}::operator()(celerity::handler&) const::set_identity_kernel<const celerity::buffer&>";
CHECK(simplify_task_name(real) == "set_identity_kernel<...>");
}

TEST_CASE("escaping of invalid characters for dot labels", "[utils][escape_for_dot_label]") {
const string test = "hello<bla&>";
CHECK(escape_for_dot_label(test) == "hello&lt;bla&amp;&gt;");
}

0 comments on commit 727de29

Please sign in to comment.