Skip to content

Commit

Permalink
Add naming capability to tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
GagaLP committed Sep 14, 2023
1 parent f190da3 commit 0f7145c
Show file tree
Hide file tree
Showing 6 changed files with 84 additions and 3 deletions.
4 changes: 4 additions & 0 deletions examples/matmul/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ void set_identity(celerity::distr_queue queue, celerity::buffer<T, 2> mat, bool
queue.submit([&](celerity::handler& cgh) {
celerity::accessor dw{mat, cgh, celerity::access::one_to_one{}, celerity::write_only, celerity::no_init};
const auto range = mat.get_range();

celerity::debug::set_task_name(cgh, "set identity");
cgh.parallel_for<class set_identity_kernel>(range, [=](celerity::item<2> item) {
if(!reverse) {
dw[item] = item[0] == item[1];
Expand All @@ -36,6 +38,7 @@ void multiply(celerity::distr_queue queue, celerity::buffer<T, 2> mat_a, celerit
celerity::local_accessor<T, 2> scratch_a{{group_size, group_size}, cgh};
celerity::local_accessor<T, 2> scratch_b{{group_size, group_size}, cgh};

celerity::debug::set_task_name(cgh, "matrix multiplication");
cgh.parallel_for<class mat_mul>(celerity::nd_range<2>{{MAT_SIZE, MAT_SIZE}, {group_size, group_size}}, [=](celerity::nd_item<2> item) {
T sum{};
const auto lid = item.get_local_id();
Expand Down Expand Up @@ -63,6 +66,7 @@ void verify(celerity::distr_queue& queue, celerity::buffer<T, 2> mat_c, celerity
celerity::accessor c{mat_c, cgh, celerity::access::one_to_one{}, celerity::read_only_host_task};
celerity::experimental::side_effect passed{passed_obj, cgh};

celerity::debug::set_task_name(cgh, "verification");
cgh.host_task(mat_c.get_range(), [=](celerity::partition<2> part) {
*passed = true;
const auto& sr = part.get_subrange();
Expand Down
4 changes: 4 additions & 0 deletions include/debug.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <string>

#include "buffer.h"
#include "handler.h"

namespace celerity {
namespace debug {
Expand All @@ -14,5 +15,8 @@ namespace debug {
return detail::get_buffer_name(buff);
}

inline void set_task_name(celerity::handler& cgh, const std::string& debug_name) {
detail::set_task_name(cgh, debug_name);
}
} // namespace debug
} // namespace celerity
20 changes: 19 additions & 1 deletion include/handler.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@ namespace detail {
void add_reduction(handler& cgh, const reduction_info& rinfo);
void extend_lifetime(handler& cgh, std::shared_ptr<detail::lifetime_extending_state> state);

void set_task_name(handler& cgh, const std::string& debug_name);

template <typename Name>
std::string kernel_debug_name() {
// we need to typeid a pointer, since the name is often undefined
Expand Down Expand Up @@ -374,6 +376,8 @@ class handler {
friend void detail::add_reduction(handler& cgh, const detail::reduction_info& rinfo);
friend void detail::extend_lifetime(handler& cgh, std::shared_ptr<detail::lifetime_extending_state> state);

friend void detail::set_task_name(handler &cgh, const std::string& debug_name);

detail::task_id m_tid;
detail::buffer_access_map m_access_map;
detail::side_effect_map m_side_effects;
Expand All @@ -383,6 +387,7 @@ class handler {
size_t m_num_collective_nodes;
detail::hydration_id m_next_accessor_hydration_id = 1;
std::vector<std::shared_ptr<detail::lifetime_extending_state>> m_attached_state;
std::optional<std::string> m_usr_def_task_name;

handler(detail::task_id tid, size_t num_collective_nodes) : m_tid(tid), m_num_collective_nodes(num_collective_nodes) {}

Expand Down Expand Up @@ -442,6 +447,8 @@ class handler {
}
m_task =
detail::task::make_host_compute(m_tid, geometry, std::move(launcher), std::move(m_access_map), std::move(m_side_effects), std::move(m_reductions));

m_task->set_debug_name(m_usr_def_task_name.value_or(""));
}

void create_device_compute_task(detail::task_geometry geometry, std::string debug_name, std::unique_ptr<detail::command_launcher_storage_base> launcher) {
Expand All @@ -456,17 +463,24 @@ class handler {
// Note that cgf_diagnostics has a similar check, but we don't catch void side effects there.
if(!m_side_effects.empty()) { throw std::runtime_error{"Side effects cannot be used in device kernels"}; }
m_task =
detail::task::make_device_compute(m_tid, geometry, std::move(launcher), std::move(m_access_map), std::move(m_reductions), std::move(debug_name));
detail::task::make_device_compute(m_tid, geometry, std::move(launcher), std::move(m_access_map), std::move(m_reductions));

m_task->set_debug_name(m_usr_def_task_name.value_or(debug_name));

}

void create_collective_task(detail::collective_group_id cgid, std::unique_ptr<detail::command_launcher_storage_base> launcher) {
assert(m_task == nullptr);
m_task = detail::task::make_collective(m_tid, cgid, m_num_collective_nodes, std::move(launcher), std::move(m_access_map), std::move(m_side_effects));

m_task->set_debug_name(m_usr_def_task_name.value_or(""));
}

void create_master_node_task(std::unique_ptr<detail::command_launcher_storage_base> launcher) {
assert(m_task == nullptr);
m_task = detail::task::make_master_node(m_tid, std::move(launcher), std::move(m_access_map), std::move(m_side_effects));

m_task->set_debug_name(m_usr_def_task_name.value_or(""));
}

template <typename KernelFlavor, typename KernelName, int Dims, typename Kernel, size_t... ReductionIndices, typename... Reductions>
Expand Down Expand Up @@ -574,6 +588,10 @@ namespace detail {

inline void extend_lifetime(handler& cgh, std::shared_ptr<detail::lifetime_extending_state> state) { cgh.extend_lifetime(std::move(state)); }

inline void set_task_name(handler& cgh, const std::string& debug_name) {
cgh.m_usr_def_task_name = {debug_name};
}

// TODO: The _impl functions in detail only exist during the grace period for deprecated reductions on const buffers; move outside again afterwards.
template <typename DataT, int Dims, typename BinaryOperation>
auto reduction_impl(const buffer<DataT, Dims>& vars, handler& cgh, BinaryOperation combiner, const cl::sycl::property_list& prop_list = {}) {
Expand Down
5 changes: 3 additions & 2 deletions include/task.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ namespace detail {

range<3> get_granularity() const { return m_geometry.granularity; }

void set_debug_name(const std::string& debug_name) { m_debug_name = debug_name; }
const std::string& get_debug_name() const { return m_debug_name; }

bool has_variable_split() const { return m_type == task_type::host_compute || m_type == task_type::device_compute; }
Expand Down Expand Up @@ -215,9 +216,9 @@ namespace detail {
}

static std::unique_ptr<task> make_device_compute(task_id tid, task_geometry geometry, std::unique_ptr<command_launcher_storage_base> launcher,
buffer_access_map access_map, reduction_set reductions, std::string debug_name) {
buffer_access_map access_map, reduction_set reductions) {
return std::unique_ptr<task>(new task(tid, task_type::device_compute, collective_group_id{}, geometry, std::move(launcher), std::move(access_map),
{}, std::move(reductions), std::move(debug_name), {}, nullptr));
{}, std::move(reductions), {}, {}, nullptr));
}

static std::unique_ptr<task> make_collective(task_id tid, collective_group_id cgid, size_t num_collective_nodes,
Expand Down
1 change: 1 addition & 0 deletions test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ set(TEST_TARGETS
accessor_tests
backend_tests
buffer_manager_tests
debug_naming_tests
graph_generation_tests
graph_gen_granularity_tests
graph_gen_reduction_tests
Expand Down
53 changes: 53 additions & 0 deletions test/debug_naming_tests.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
#include "task.h"
#include "task_manager.h"
#include "types.h"

#include <catch2/catch_template_test_macros.hpp>
#include <catch2/catch_test_macros.hpp>
#include <catch2/matchers/catch_matchers_string.hpp>

#include <celerity.h>

#include "test_utils.h"

using namespace celerity;
using namespace celerity::detail;

TEST_CASE("debug names can be set and retrieved from tasks", "[debug]") {
const std::string task_name = "sample task";

auto tt = test_utils::task_test_context{};

SECTION("Host Task") {
const auto tid_a = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) {
celerity::debug::set_task_name(cgh, task_name);
});

const auto tid_b = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) {});

CHECK(tt.tm.get_task(tid_a)->get_debug_name() == task_name);
CHECK(tt.tm.get_task(tid_b)->get_debug_name().empty());
}

SECTION("Compute Task") {
const auto tid_a = test_utils::add_compute_task<class compute_task>(tt.tm, [&](handler& cgh) {
celerity::debug::set_task_name(cgh, task_name);
});

const auto tid_b = test_utils::add_compute_task<class compute_task_unnamed>(tt.tm, [&](handler& cgh) {});

CHECK(tt.tm.get_task(tid_a)->get_debug_name() == task_name);
CHECK_THAT(tt.tm.get_task(tid_b)->get_debug_name(), Catch::Matchers::ContainsSubstring("compute_task_unnamed"));
}

SECTION("ND Range Task") {
const auto tid_a = test_utils::add_nd_range_compute_task<class nd_range_task>(tt.tm, [&](handler& cgh) {
celerity::debug::set_task_name(cgh, task_name);
});

const auto tid_b = test_utils::add_compute_task<class nd_range_task_unnamed>(tt.tm, [&](handler& cgh) {});

CHECK(tt.tm.get_task(tid_a)->get_debug_name() == task_name);
CHECK_THAT(tt.tm.get_task(tid_b)->get_debug_name(), Catch::Matchers::ContainsSubstring("nd_range_task_unnamed"));
}
}

0 comments on commit 0f7145c

Please sign in to comment.