diff --git a/include/distributed_graph_generator.h b/include/distributed_graph_generator.h index e4206cfcb..13d2f7ae4 100644 --- a/include/distributed_graph_generator.h +++ b/include/distributed_graph_generator.h @@ -75,7 +75,7 @@ class distributed_graph_generator { public: distributed_graph_generator( - const size_t num_nodes, const node_id local_nid, command_graph& cdag, const task_manager& tm, const std::optional& recorder); + const size_t num_nodes, const node_id local_nid, command_graph& cdag, const task_manager& tm, detail::command_recorder* recorder); void add_buffer(const buffer_id bid, const int dims, const range<3>& range); @@ -83,8 +83,6 @@ class distributed_graph_generator { command_graph& get_command_graph() { return m_cdag; } - std::string print_command_graph() const; - private: // Wrapper around command_graph::create that adds commands to current batch set. template @@ -145,7 +143,8 @@ class distributed_graph_generator { // Side effects on the same host object create true dependencies between task commands, so we track the last effect per host object. side_effect_map m_host_object_last_effects; - std::optional m_recorder; + // Generated commands will be recorded to this recorder if it is set + detail::command_recorder* m_recorder = nullptr; }; } // namespace celerity::detail diff --git a/include/print_graph.h b/include/print_graph.h index 3396e6059..dfc42094d 100644 --- a/include/print_graph.h +++ b/include/print_graph.h @@ -47,15 +47,15 @@ namespace detail { struct task_printing_information { task_printing_information(const task& from, const buffer_manager* buff_man); - const task_id m_tid; - const std::string m_debug_name; - const collective_group_id m_cgid; - const task_type m_type; - const task_geometry m_geometry; - const reduction_list m_reductions; - const access_list m_accesses; - const side_effect_map m_side_effect_map; - const task_dependency_list m_dependencies; + const task_id tid; + const std::string debug_name; + const collective_group_id cgid; + const task_type type; + const task_geometry geometry; + const reduction_list reductions; + const access_list accesses; + const side_effect_map side_effect_map; + const task_dependency_list dependencies; }; class task_recorder { @@ -73,35 +73,33 @@ namespace detail { const buffer_manager* m_buff_man; }; - const std::optional no_task_recorder = {}; - // Command recording using command_dependency_list = std::vector>; struct command_printing_information { - const command_id m_cid; - const command_type m_type; - - const std::optional m_epoch_action; - const std::optional> m_execution_range; - const std::optional m_reduction_id; - const std::optional m_buffer_id; - const std::string m_buffer_name; - const std::optional m_target; - const std::optional> m_await_region; - const std::optional> m_push_range; - const std::optional m_transfer_id; - const std::optional m_task_id; - const std::optional m_task_geometry; - const bool m_is_reduction_initializer; - const std::optional m_accesses; - const std::optional m_reductions; - const std::optional m_side_effects; - const command_dependency_list m_dependencies; - const std::string m_task_name; - const std::optional m_task_type; - const std::optional m_collective_group_id; + const command_id cid; + const command_type type; + + const std::optional epoch_action; + const std::optional> execution_range; + const std::optional reduction_id; + const std::optional buffer_id; + const std::string buffer_name; + const std::optional target; + const std::optional> await_region; + const std::optional> push_range; + const std::optional transfer_id; + const std::optional task_id; + const std::optional task_geometry; + const bool is_reduction_initializer; + const std::optional accesses; + const std::optional reductions; + const std::optional side_effects; + const command_dependency_list dependencies; + const std::string task_name; + const std::optional task_type; + const std::optional collective_group_id; command_printing_information(const abstract_command& cmd, const task_manager* task_man, const buffer_manager* buff_man); }; @@ -122,8 +120,6 @@ namespace detail { const buffer_manager* m_buff_man; }; - const std::optional no_command_recorder = {}; - // Printing interface std::string print_task_graph(const task_recorder& recorder); diff --git a/include/runtime.h b/include/runtime.h index be5a552f9..18f5ef9a0 100644 --- a/include/runtime.h +++ b/include/runtime.h @@ -9,6 +9,7 @@ #include "device_queue.h" #include "frame.h" #include "host_queue.h" +#include "print_graph.h" #include "types.h" namespace celerity { @@ -110,6 +111,9 @@ namespace detail { std::unique_ptr m_task_mngr; std::unique_ptr m_exec; + std::unique_ptr m_task_recorder; + std::unique_ptr m_command_recorder; + runtime(int* argc, char** argv[], device_or_selector user_device_or_selector); runtime(const runtime&) = delete; runtime(runtime&&) = delete; diff --git a/include/scheduler.h b/include/scheduler.h index 33bd7f512..77723b3a4 100644 --- a/include/scheduler.h +++ b/include/scheduler.h @@ -34,8 +34,6 @@ namespace detail { void notify_buffer_registered(const buffer_id bid, const int dims, const range<3>& range) { notify(event_buffer_registered{bid, dims, range}); } - std::string print_command_graph() const; - protected: /** * This is called by the worker thread. diff --git a/include/task_manager.h b/include/task_manager.h index 2f9031c04..5c36b4966 100644 --- a/include/task_manager.h +++ b/include/task_manager.h @@ -59,7 +59,7 @@ namespace detail { public: constexpr inline static task_id initial_epoch_task = 0; - task_manager(size_t num_collective_nodes, host_queue* queue, std::optional recorder); + task_manager(size_t num_collective_nodes, host_queue* queue, detail::task_recorder* recorder); virtual ~task_manager() = default; @@ -130,8 +130,6 @@ namespace detail { */ const task* get_task(task_id tid) const; - std::string print_task_graph() const; - /** * Blocks until an epoch task has executed on this node (or all nodes, if the epoch_for_new_tasks was created with `epoch_action::barrier`). */ @@ -222,7 +220,7 @@ namespace detail { std::unordered_set m_execution_front; // An optional task_recorder which records information about tasks for e.g. printing graphs. - mutable std::optional m_task_recorder; + mutable detail::task_recorder* m_task_recorder; task& register_task_internal(task_ring_buffer::reservation&& reserve, std::unique_ptr task); diff --git a/src/distributed_graph_generator.cc b/src/distributed_graph_generator.cc index 8681929a9..c488ccd48 100644 --- a/src/distributed_graph_generator.cc +++ b/src/distributed_graph_generator.cc @@ -9,7 +9,7 @@ namespace celerity::detail { distributed_graph_generator::distributed_graph_generator( - const size_t num_nodes, const node_id local_nid, command_graph& cdag, const task_manager& tm, const std::optional& recorder) + const size_t num_nodes, const node_id local_nid, command_graph& cdag, const task_manager& tm, detail::command_recorder* recorder) : m_num_nodes(num_nodes), m_local_nid(local_nid), m_cdag(cdag), m_task_mngr(tm), m_recorder(recorder) { if(m_num_nodes > max_num_nodes) { throw std::runtime_error(fmt::format("Number of nodes requested ({}) exceeds compile-time maximum of {}", m_num_nodes, max_num_nodes)); @@ -20,7 +20,7 @@ distributed_graph_generator::distributed_graph_generator( // set_epoch_for_new_commands). auto* const epoch_cmd = cdag.create(task_manager::initial_epoch_task, epoch_action::none); epoch_cmd->mark_as_flushed(); // there is no point in flushing the initial epoch command - if(m_recorder.has_value()) m_recorder->record_command(*epoch_cmd); + if(m_recorder != nullptr) m_recorder->record_command(*epoch_cmd); m_epoch_for_new_commands = epoch_cmd->get_cid(); } @@ -135,7 +135,7 @@ std::unordered_set distributed_graph_generator::build_task(co prune_commands_before(epoch_to_prune_before); // If we have a command_recorder, record the current batch of commands - if(m_recorder) { + if(m_recorder != nullptr) { for(const auto& cmd : m_current_cmd_batch) { m_recorder->record_command(*cmd); } @@ -144,12 +144,6 @@ std::unordered_set distributed_graph_generator::build_task(co return std::move(m_current_cmd_batch); } -std::string distributed_graph_generator::print_command_graph() const { - if(m_recorder.has_value()) { return detail::print_command_graph(m_local_nid, m_recorder.value()); } - CELERITY_ERROR("Trying to print command graph, but no recorder available"); - return ""; -} - void distributed_graph_generator::generate_distributed_commands(const task& tsk) { chunk<3> full_chunk{tsk.get_global_offset(), tsk.get_global_size(), tsk.get_global_size()}; const size_t num_chunks = m_num_nodes * 1; // TODO Make configurable diff --git a/src/print_graph.cc b/src/print_graph.cc index aaeca0b89..b1daaeff8 100644 --- a/src/print_graph.cc +++ b/src/print_graph.cc @@ -8,420 +8,412 @@ #include "grid.h" #include "task_manager.h" -namespace celerity { -namespace detail { +namespace celerity::detail { - namespace { - std::string get_buffer_name(const buffer_id bid, const buffer_manager* buff_man) { - return buff_man != nullptr ? buff_man->get_buffer_info(bid).debug_name : ""; - } - - access_list build_access_list(const task& tsk, const buffer_manager* buff_man, const std::optional> execution_range = {}) { - access_list ret; - const auto exec_range = execution_range.value_or(subrange<3>{tsk.get_global_offset(), tsk.get_global_size()}); - const auto& bam = tsk.get_buffer_access_map(); - for(const auto bid : bam.get_accessed_buffers()) { - for(const auto mode : bam.get_access_modes(bid)) { - const auto req = bam.get_mode_requirements(bid, mode, tsk.get_dimensions(), exec_range, tsk.get_global_size()); - ret.push_back({bid, get_buffer_name(bid, buff_man), mode, req}); - } - } - return ret; - } +std::string get_buffer_name(const buffer_id bid, const buffer_manager* buff_man) { + return buff_man != nullptr ? buff_man->get_buffer_info(bid).debug_name : ""; +} - reduction_list build_reduction_list(const task& tsk, const buffer_manager* buff_man) { - reduction_list ret; - for(const auto& reduction : tsk.get_reductions()) { - ret.push_back({reduction.rid, reduction.bid, get_buffer_name(reduction.bid, buff_man), reduction.init_from_buffer}); - } - return ret; +access_list build_access_list(const task& tsk, const buffer_manager* buff_man, const std::optional> execution_range = {}) { + access_list ret; + const auto exec_range = execution_range.value_or(subrange<3>{tsk.get_global_offset(), tsk.get_global_size()}); + const auto& bam = tsk.get_buffer_access_map(); + for(const auto bid : bam.get_accessed_buffers()) { + for(const auto mode : bam.get_access_modes(bid)) { + const auto req = bam.get_mode_requirements(bid, mode, tsk.get_dimensions(), exec_range, tsk.get_global_size()); + ret.push_back({bid, get_buffer_name(bid, buff_man), mode, req}); } - - task_dependency_list build_task_dependency_list(const task& tsk) { - task_dependency_list ret; - for(const auto& dep : tsk.get_dependencies()) { - ret.push_back({dep.node->get_id(), dep.kind, dep.origin}); - } - return ret; - } - - // removes initial template qualifiers to simplify, and escapes '<' and '>' in the given name, - // so that it can be successfully used in a dot graph label that uses HTML, and is hopefully readable - std::string simplify_and_escape_name(const std::string& name) { - // simplify - auto first_opening_pos = name.find('<'); - auto namespace_qual_end_pos = name.rfind(':', first_opening_pos); - auto simplified = namespace_qual_end_pos != std::string::npos ? name.substr(namespace_qual_end_pos + 1) : name; - // escape - simplified = std::regex_replace(simplified, std::regex("<"), "<"); - return std::regex_replace(simplified, std::regex(">"), ">"); - } - } // namespace - - task_printing_information::task_printing_information(const task& from, const buffer_manager* buff_man) - : m_tid(from.get_id()), m_debug_name(simplify_and_escape_name(from.get_debug_name())), m_cgid(from.get_collective_group_id()), m_type(from.get_type()), - m_geometry(from.get_geometry()), m_reductions(build_reduction_list(from, buff_man)), m_accesses(build_access_list(from, buff_man)), - m_side_effect_map(from.get_side_effect_map()), m_dependencies(build_task_dependency_list(from)) {} - - void task_recorder::record_task(const task& tsk) { - CELERITY_TRACE("Recording task {}", tsk.get_id()); - m_recorded_tasks.emplace_back(tsk, m_buff_man); } + return ret; +} - namespace { - command_type get_command_type(const abstract_command& cmd) { - if(isa(&cmd)) return command_type::epoch; - if(isa(&cmd)) return command_type::horizon; - if(isa(&cmd)) return command_type::execution; - if(isa(&cmd)) return command_type::push; - if(isa(&cmd)) return command_type::await_push; - if(isa(&cmd)) return command_type::reduction; - if(isa(&cmd)) return command_type::fence; - CELERITY_CRITICAL("Unexpected command type"); - std::terminate(); - } - - std::optional get_epoch_action(const abstract_command& cmd) { - const auto* epoch_cmd = dynamic_cast(&cmd); - return epoch_cmd != nullptr ? epoch_cmd->get_epoch_action() : std::optional{}; - } - - std::optional> get_execution_range(const abstract_command& cmd) { - const auto* execution_cmd = dynamic_cast(&cmd); - return execution_cmd != nullptr ? execution_cmd->get_execution_range() : std::optional>{}; - } - - std::optional get_reduction_id(const abstract_command& cmd) { - if(const auto* push_cmd = dynamic_cast(&cmd)) return push_cmd->get_reduction_id(); - if(const auto* await_push_cmd = dynamic_cast(&cmd)) return await_push_cmd->get_reduction_id(); - if(const auto* reduction_cmd = dynamic_cast(&cmd)) return reduction_cmd->get_reduction_info().rid; - return {}; - } - - std::optional get_buffer_id(const abstract_command& cmd) { - if(const auto* push_cmd = dynamic_cast(&cmd)) return push_cmd->get_bid(); - if(const auto* await_push_cmd = dynamic_cast(&cmd)) return await_push_cmd->get_bid(); - if(const auto* reduction_cmd = dynamic_cast(&cmd)) return reduction_cmd->get_reduction_info().bid; - return {}; - } - - std::string get_cmd_buffer_name(const std::optional& bid, const buffer_manager* buff_man) { - if(buff_man == nullptr || !bid.has_value()) return ""; - return get_buffer_name(bid.value(), buff_man); - } - - std::optional get_target(const abstract_command& cmd) { - if(const auto* push_cmd = dynamic_cast(&cmd)) return push_cmd->get_target(); - return {}; - } - - std::optional> get_await_region(const abstract_command& cmd) { - if(const auto* await_push_cmd = dynamic_cast(&cmd)) return await_push_cmd->get_region(); - return {}; - } - - std::optional> get_push_range(const abstract_command& cmd) { - if(const auto* push_cmd = dynamic_cast(&cmd)) return push_cmd->get_range(); - return {}; - } - - std::optional get_transfer_id(const abstract_command& cmd) { - if(const auto* push_cmd = dynamic_cast(&cmd)) return push_cmd->get_transfer_id(); - if(const auto* await_push_cmd = dynamic_cast(&cmd)) return await_push_cmd->get_transfer_id(); - return {}; - } - - std::optional get_task_id(const abstract_command& cmd) { - if(const auto* task_cmd = dynamic_cast(&cmd)) return task_cmd->get_tid(); - return {}; - } - - const task* get_task_for(const abstract_command& cmd, const task_manager* task_man) { - if(const auto* task_cmd = dynamic_cast(&cmd)) { - if(task_man != nullptr) { - assert(task_man->has_task(task_cmd->get_tid())); - return task_man->get_task(task_cmd->get_tid()); - } - } - return nullptr; - } - - std::optional get_task_geometry(const abstract_command& cmd, const task_manager* task_man) { - if(const auto* tsk = get_task_for(cmd, task_man)) return tsk->get_geometry(); - return {}; - } - - bool get_is_reduction_initializer(const abstract_command& cmd) { - if(const auto* execution_cmd = dynamic_cast(&cmd)) return execution_cmd->is_reduction_initializer(); - return false; - } - - access_list build_cmd_access_list(const abstract_command& cmd, const task_manager* task_man, const buffer_manager* buff_man) { - if(const auto* tsk = get_task_for(cmd, task_man)) { - const auto execution_range = get_execution_range(cmd).value_or(subrange<3>{tsk->get_global_offset(), tsk->get_global_size()}); - return build_access_list(*tsk, buff_man, execution_range); - } - return {}; - } - - reduction_list build_cmd_reduction_list(const abstract_command& cmd, const task_manager* task_man, const buffer_manager* buff_man) { - if(const auto* tsk = get_task_for(cmd, task_man)) return build_reduction_list(*tsk, buff_man); - return {}; - } - - side_effect_map get_side_effects(const abstract_command& cmd, const task_manager* task_man) { - if(const auto* tsk = get_task_for(cmd, task_man)) return tsk->get_side_effect_map(); - return {}; - } - - command_dependency_list build_command_dependency_list(const abstract_command& cmd) { - command_dependency_list ret; - for(const auto& dep : cmd.get_dependencies()) { - ret.push_back({dep.node->get_cid(), dep.kind, dep.origin}); - } - return ret; - } - - std::string get_task_name(const abstract_command& cmd, const task_manager* task_man) { - if(const auto* tsk = get_task_for(cmd, task_man)) return simplify_and_escape_name(tsk->get_debug_name()); - return {}; - } - - std::optional get_task_type(const abstract_command& cmd, const task_manager* task_man) { - if(const auto* tsk = get_task_for(cmd, task_man)) return tsk->get_type(); - return {}; - } - - std::optional get_collective_group_id(const abstract_command& cmd, const task_manager* task_man) { - if(const auto* tsk = get_task_for(cmd, task_man)) return tsk->get_collective_group_id(); - return {}; - } - } // namespace - - command_printing_information::command_printing_information(const abstract_command& cmd, const task_manager* task_man, const buffer_manager* buff_man) - : m_cid(cmd.get_cid()), m_type(get_command_type(cmd)), m_epoch_action(get_epoch_action(cmd)), m_execution_range(get_execution_range(cmd)), - m_reduction_id(get_reduction_id(cmd)), m_buffer_id(get_buffer_id(cmd)), m_buffer_name(get_cmd_buffer_name(m_buffer_id, buff_man)), - m_target(get_target(cmd)), m_await_region(get_await_region(cmd)), m_push_range(get_push_range(cmd)), m_transfer_id(get_transfer_id(cmd)), - m_task_id(get_task_id(cmd)), m_task_geometry(get_task_geometry(cmd, task_man)), m_is_reduction_initializer(get_is_reduction_initializer(cmd)), - m_accesses(build_cmd_access_list(cmd, task_man, buff_man)), m_reductions(build_cmd_reduction_list(cmd, task_man, buff_man)), - m_side_effects(get_side_effects(cmd, task_man)), m_dependencies(build_command_dependency_list(cmd)), m_task_name(get_task_name(cmd, task_man)), - m_task_type(get_task_type(cmd, task_man)), m_collective_group_id(get_collective_group_id(cmd, task_man)) {} - - void command_recorder::record_command(const abstract_command& com) { - CELERITY_TRACE("Recording command {}", com.get_cid()); - m_recorded_commands.emplace_back(com, m_task_man, m_buff_man); +reduction_list build_reduction_list(const task& tsk, const buffer_manager* buff_man) { + reduction_list ret; + for(const auto& reduction : tsk.get_reductions()) { + ret.push_back({reduction.rid, reduction.bid, get_buffer_name(reduction.bid, buff_man), reduction.init_from_buffer}); } + return ret; +} - - template - const char* dependency_style(const Dependency& dep) { - if(dep.kind == dependency_kind::anti_dep) return "color=limegreen"; - switch(dep.origin) { - case dependency_origin::collective_group_serialization: return "color=blue"; - case dependency_origin::execution_front: return "color=orange"; - case dependency_origin::last_epoch: return "color=orchid"; - default: return ""; - } +task_dependency_list build_task_dependency_list(const task& tsk) { + task_dependency_list ret; + for(const auto& dep : tsk.get_dependencies()) { + ret.push_back({dep.node->get_id(), dep.kind, dep.origin}); } - - const char* task_type_string(const task_type tt) { - switch(tt) { - case task_type::epoch: return "epoch"; - case task_type::host_compute: return "host-compute"; - case task_type::device_compute: return "device-compute"; - case task_type::collective: return "collective host"; - case task_type::master_node: return "master-node host"; - case task_type::horizon: return "horizon"; - case task_type::fence: return "fence"; - default: return "unknown"; + return ret; +} + +// removes initial template qualifiers to simplify, and escapes '<' and '>' in the given name, +// so that it can be successfully used in a dot graph label that uses HTML, and is hopefully readable +std::string simplify_and_escape_name(const std::string& name) { + // simplify + auto first_opening_pos = name.find('<'); + auto namespace_qual_end_pos = name.rfind(':', first_opening_pos); + auto simplified = namespace_qual_end_pos != std::string::npos ? name.substr(namespace_qual_end_pos + 1) : name; + // escape + simplified = std::regex_replace(simplified, std::regex("<"), "<"); + return std::regex_replace(simplified, std::regex(">"), ">"); +} + +task_printing_information::task_printing_information(const task& from, const buffer_manager* buff_man) + : tid(from.get_id()), debug_name(simplify_and_escape_name(from.get_debug_name())), cgid(from.get_collective_group_id()), type(from.get_type()), + geometry(from.get_geometry()), reductions(build_reduction_list(from, buff_man)), accesses(build_access_list(from, buff_man)), + side_effect_map(from.get_side_effect_map()), dependencies(build_task_dependency_list(from)) {} + +void task_recorder::record_task(const task& tsk) { + CELERITY_TRACE("Recording task {}", tsk.get_id()); + m_recorded_tasks.emplace_back(tsk, m_buff_man); +} + +command_type get_command_type(const abstract_command& cmd) { + if(isa(&cmd)) return command_type::epoch; + if(isa(&cmd)) return command_type::horizon; + if(isa(&cmd)) return command_type::execution; + if(isa(&cmd)) return command_type::push; + if(isa(&cmd)) return command_type::await_push; + if(isa(&cmd)) return command_type::reduction; + if(isa(&cmd)) return command_type::fence; + CELERITY_CRITICAL("Unexpected command type"); + std::terminate(); +} + +std::optional get_epoch_action(const abstract_command& cmd) { + const auto* epoch_cmd = dynamic_cast(&cmd); + return epoch_cmd != nullptr ? epoch_cmd->get_epoch_action() : std::optional{}; +} + +std::optional> get_execution_range(const abstract_command& cmd) { + const auto* execution_cmd = dynamic_cast(&cmd); + return execution_cmd != nullptr ? execution_cmd->get_execution_range() : std::optional>{}; +} + +std::optional get_reduction_id(const abstract_command& cmd) { + if(const auto* push_cmd = dynamic_cast(&cmd)) return push_cmd->get_reduction_id(); + if(const auto* await_push_cmd = dynamic_cast(&cmd)) return await_push_cmd->get_reduction_id(); + if(const auto* reduction_cmd = dynamic_cast(&cmd)) return reduction_cmd->get_reduction_info().rid; + return {}; +} + +std::optional get_buffer_id(const abstract_command& cmd) { + if(const auto* push_cmd = dynamic_cast(&cmd)) return push_cmd->get_bid(); + if(const auto* await_push_cmd = dynamic_cast(&cmd)) return await_push_cmd->get_bid(); + if(const auto* reduction_cmd = dynamic_cast(&cmd)) return reduction_cmd->get_reduction_info().bid; + return {}; +} + +std::string get_cmd_buffer_name(const std::optional& bid, const buffer_manager* buff_man) { + if(buff_man == nullptr || !bid.has_value()) return ""; + return get_buffer_name(bid.value(), buff_man); +} + +std::optional get_target(const abstract_command& cmd) { + if(const auto* push_cmd = dynamic_cast(&cmd)) return push_cmd->get_target(); + return {}; +} + +std::optional> get_await_region(const abstract_command& cmd) { + if(const auto* await_push_cmd = dynamic_cast(&cmd)) return await_push_cmd->get_region(); + return {}; +} + +std::optional> get_push_range(const abstract_command& cmd) { + if(const auto* push_cmd = dynamic_cast(&cmd)) return push_cmd->get_range(); + return {}; +} + +std::optional get_transfer_id(const abstract_command& cmd) { + if(const auto* push_cmd = dynamic_cast(&cmd)) return push_cmd->get_transfer_id(); + if(const auto* await_push_cmd = dynamic_cast(&cmd)) return await_push_cmd->get_transfer_id(); + return {}; +} + +std::optional get_task_id(const abstract_command& cmd) { + if(const auto* task_cmd = dynamic_cast(&cmd)) return task_cmd->get_tid(); + return {}; +} + +const task* get_task_for(const abstract_command& cmd, const task_manager* task_man) { + if(const auto* task_cmd = dynamic_cast(&cmd)) { + if(task_man != nullptr) { + assert(task_man->has_task(task_cmd->get_tid())); + return task_man->get_task(task_cmd->get_tid()); } } - - std::string get_buffer_label(const buffer_id bid, const std::string& name = "") { - // if there is no name defined, the name will be the buffer id. - // if there is a name we want "id name" - return !name.empty() ? fmt::format("B{} \"{}\"", bid, name) : fmt::format("B{}", bid); + return nullptr; +} + +std::optional get_task_geometry(const abstract_command& cmd, const task_manager* task_man) { + if(const auto* tsk = get_task_for(cmd, task_man)) return tsk->get_geometry(); + return {}; +} + +bool get_is_reduction_initializer(const abstract_command& cmd) { + if(const auto* execution_cmd = dynamic_cast(&cmd)) return execution_cmd->is_reduction_initializer(); + return false; +} + +access_list build_cmd_access_list(const abstract_command& cmd, const task_manager* task_man, const buffer_manager* buff_man) { + if(const auto* tsk = get_task_for(cmd, task_man)) { + const auto execution_range = get_execution_range(cmd).value_or(subrange<3>{tsk->get_global_offset(), tsk->get_global_size()}); + return build_access_list(*tsk, buff_man, execution_range); } - - void format_requirements(std::string& label, const reduction_list& reductions, const access_list& accesses, const side_effect_map& side_effects, - const access_mode reduction_init_mode) { - for(const auto& [rid, bid, buffer_name, init_from_buffer] : reductions) { - auto rmode = init_from_buffer ? reduction_init_mode : cl::sycl::access::mode::discard_write; - const auto req = GridRegion<3>{{1, 1, 1}}; - const std::string bl = get_buffer_label(bid, buffer_name); - fmt::format_to(std::back_inserter(label), "
(R{}) {} {} {}", rid, detail::access::mode_traits::name(rmode), bl, req); - } - - for(const auto& [bid, buffer_name, mode, req] : accesses) { - const std::string bl = get_buffer_label(bid, buffer_name); - // While uncommon, we do support chunks that don't require access to a particular buffer at all. - if(!req.empty()) { fmt::format_to(std::back_inserter(label), "
{} {} {}", detail::access::mode_traits::name(mode), bl, req); } - } - - for(const auto& [hoid, order] : side_effects) { - fmt::format_to(std::back_inserter(label), "
affect H{}", hoid); - } + return {}; +} + +reduction_list build_cmd_reduction_list(const abstract_command& cmd, const task_manager* task_man, const buffer_manager* buff_man) { + if(const auto* tsk = get_task_for(cmd, task_man)) return build_reduction_list(*tsk, buff_man); + return {}; +} + +side_effect_map get_side_effects(const abstract_command& cmd, const task_manager* task_man) { + if(const auto* tsk = get_task_for(cmd, task_man)) return tsk->get_side_effect_map(); + return {}; +} + +command_dependency_list build_command_dependency_list(const abstract_command& cmd) { + command_dependency_list ret; + for(const auto& dep : cmd.get_dependencies()) { + ret.push_back({dep.node->get_cid(), dep.kind, dep.origin}); + } + return ret; +} + +std::string get_task_name(const abstract_command& cmd, const task_manager* task_man) { + if(const auto* tsk = get_task_for(cmd, task_man)) return simplify_and_escape_name(tsk->get_debug_name()); + return {}; +} + +std::optional get_task_type(const abstract_command& cmd, const task_manager* task_man) { + if(const auto* tsk = get_task_for(cmd, task_man)) return tsk->get_type(); + return {}; +} + +std::optional get_collective_group_id(const abstract_command& cmd, const task_manager* task_man) { + if(const auto* tsk = get_task_for(cmd, task_man)) return tsk->get_collective_group_id(); + return {}; +} + +command_printing_information::command_printing_information(const abstract_command& cmd, const task_manager* task_man, const buffer_manager* buff_man) + : cid(cmd.get_cid()), type(get_command_type(cmd)), epoch_action(get_epoch_action(cmd)), execution_range(get_execution_range(cmd)), + reduction_id(get_reduction_id(cmd)), buffer_id(get_buffer_id(cmd)), buffer_name(get_cmd_buffer_name(buffer_id, buff_man)), target(get_target(cmd)), + await_region(get_await_region(cmd)), push_range(get_push_range(cmd)), transfer_id(get_transfer_id(cmd)), task_id(get_task_id(cmd)), + task_geometry(get_task_geometry(cmd, task_man)), is_reduction_initializer(get_is_reduction_initializer(cmd)), + accesses(build_cmd_access_list(cmd, task_man, buff_man)), reductions(build_cmd_reduction_list(cmd, task_man, buff_man)), + side_effects(get_side_effects(cmd, task_man)), dependencies(build_command_dependency_list(cmd)), task_name(get_task_name(cmd, task_man)), + task_type(get_task_type(cmd, task_man)), collective_group_id(get_collective_group_id(cmd, task_man)) {} + +void command_recorder::record_command(const abstract_command& com) { + CELERITY_TRACE("Recording command {}", com.get_cid()); + m_recorded_commands.emplace_back(com, m_task_man, m_buff_man); +} + + +template +const char* dependency_style(const Dependency& dep) { + if(dep.kind == dependency_kind::anti_dep) return "color=limegreen"; + switch(dep.origin) { + case dependency_origin::collective_group_serialization: return "color=blue"; + case dependency_origin::execution_front: return "color=orange"; + case dependency_origin::last_epoch: return "color=orchid"; + default: return ""; + } +} + +const char* task_type_string(const task_type tt) { + switch(tt) { + case task_type::epoch: return "epoch"; + case task_type::host_compute: return "host-compute"; + case task_type::device_compute: return "device-compute"; + case task_type::collective: return "collective host"; + case task_type::master_node: return "master-node host"; + case task_type::horizon: return "horizon"; + case task_type::fence: return "fence"; + default: return "unknown"; + } +} + +std::string get_buffer_label(const buffer_id bid, const std::string& name = "") { + // if there is no name defined, the name will be the buffer id. + // if there is a name we want "id name" + return !name.empty() ? fmt::format("B{} \"{}\"", bid, name) : fmt::format("B{}", bid); +} + +void format_requirements(std::string& label, const reduction_list& reductions, const access_list& accesses, const side_effect_map& side_effects, + const access_mode reduction_init_mode) { + for(const auto& [rid, bid, buffer_name, init_from_buffer] : reductions) { + auto rmode = init_from_buffer ? reduction_init_mode : cl::sycl::access::mode::discard_write; + const auto req = GridRegion<3>{{1, 1, 1}}; + const std::string bl = get_buffer_label(bid, buffer_name); + fmt::format_to(std::back_inserter(label), "
(R{}) {} {} {}", rid, detail::access::mode_traits::name(rmode), bl, req); } - std::string get_task_label(const task_printing_information& tsk) { - std::string label; - fmt::format_to(std::back_inserter(label), "T{}", tsk.m_tid); - if(!tsk.m_debug_name.empty()) { fmt::format_to(std::back_inserter(label), " \"{}\" ", tsk.m_debug_name); } - - fmt::format_to(std::back_inserter(label), "
{}", task_type_string(tsk.m_type)); - if(tsk.m_type == task_type::host_compute || tsk.m_type == task_type::device_compute) { - fmt::format_to(std::back_inserter(label), " {}", subrange<3>{tsk.m_geometry.global_offset, tsk.m_geometry.global_size}); - } else if(tsk.m_type == task_type::collective) { - fmt::format_to(std::back_inserter(label), " in CG{}", tsk.m_cgid); - } - - format_requirements(label, tsk.m_reductions, tsk.m_accesses, tsk.m_side_effect_map, access_mode::read_write); + for(const auto& [bid, buffer_name, mode, req] : accesses) { + const std::string bl = get_buffer_label(bid, buffer_name); + // While uncommon, we do support chunks that don't require access to a particular buffer at all. + if(!req.empty()) { fmt::format_to(std::back_inserter(label), "
{} {} {}", detail::access::mode_traits::name(mode), bl, req); } + } - return label; + for(const auto& [hoid, order] : side_effects) { + fmt::format_to(std::back_inserter(label), "
affect H{}", hoid); + } +} + +std::string get_task_label(const task_printing_information& tsk) { + std::string label; + fmt::format_to(std::back_inserter(label), "T{}", tsk.tid); + if(!tsk.debug_name.empty()) { fmt::format_to(std::back_inserter(label), " \"{}\" ", tsk.debug_name); } + + fmt::format_to(std::back_inserter(label), "
{}", task_type_string(tsk.type)); + if(tsk.type == task_type::host_compute || tsk.type == task_type::device_compute) { + fmt::format_to(std::back_inserter(label), " {}", subrange<3>{tsk.geometry.global_offset, tsk.geometry.global_size}); + } else if(tsk.type == task_type::collective) { + fmt::format_to(std::back_inserter(label), " in CG{}", tsk.cgid); } - std::string print_task_graph(const task_recorder& recorder) { - std::string dot = "digraph G {label=\"Task Graph\" "; + format_requirements(label, tsk.reductions, tsk.accesses, tsk.side_effect_map, access_mode::read_write); - CELERITY_DEBUG("print_task_graph, {} entries", recorder.get_tasks().size()); + return label; +} - for(const auto& tsk : recorder.get_tasks()) { - const char* shape = tsk.m_type == task_type::epoch || tsk.m_type == task_type::horizon ? "ellipse" : "box style=rounded"; - fmt::format_to(std::back_inserter(dot), "{}[shape={} label=<{}>];", tsk.m_tid, shape, get_task_label(tsk)); - for(auto d : tsk.m_dependencies) { - fmt::format_to(std::back_inserter(dot), "{}->{}[{}];", d.node, tsk.m_tid, dependency_style(d)); - } - } +std::string print_task_graph(const task_recorder& recorder) { + std::string dot = "digraph G {label=\"Task Graph\" "; - dot += "}"; - return dot; - } + CELERITY_DEBUG("print_task_graph, {} entries", recorder.get_tasks().size()); - std::string get_command_label(const node_id local_nid, const command_printing_information& cmd) { - const command_id cid = cmd.m_cid; - - std::string label = fmt::format("C{} on N{}
", cid, local_nid); - - auto add_reduction_id_if_reduction = [&]() { - if(cmd.m_reduction_id.has_value() && cmd.m_reduction_id != 0) { fmt::format_to(std::back_inserter(label), "(R{}) ", cmd.m_reduction_id.value()); } - }; - const std::string buffer_label = cmd.m_buffer_id.has_value() ? get_buffer_label(cmd.m_buffer_id.value(), cmd.m_buffer_name) : ""; - - switch(cmd.m_type) { - case command_type::epoch: { - label += "epoch"; - if(cmd.m_epoch_action == epoch_action::barrier) { label += " (barrier)"; } - if(cmd.m_epoch_action == epoch_action::shutdown) { label += " (shutdown)"; } - } break; - case command_type::execution: { - fmt::format_to(std::back_inserter(label), "execution {}", subrange_to_grid_box(cmd.m_execution_range.value())); - } break; - case command_type::push: { - add_reduction_id_if_reduction(); - fmt::format_to(std::back_inserter(label), "push transfer {} to N{}
B{} {}", // - cmd.m_transfer_id.value(), cmd.m_target.value(), buffer_label, subrange_to_grid_box(cmd.m_push_range.value())); - } break; - case command_type::await_push: { - add_reduction_id_if_reduction(); - fmt::format_to(std::back_inserter(label), "await push transfer {}
B{} {}", // - cmd.m_transfer_id.value(), buffer_label, cmd.m_await_region.value()); - } break; - case command_type::reduction: { - fmt::format_to(std::back_inserter(label), "reduction R{}
{} {}", cmd.m_reduction_id.value(), buffer_label, GridRegion<3>{{1, 1, 1}}); - } break; - case command_type::horizon: { - label += "horizon"; - } break; - case command_type::fence: { - label += "fence"; - } break; - default: assert(!"Unkown command"); label += "unknown"; + for(const auto& tsk : recorder.get_tasks()) { + const char* shape = tsk.type == task_type::epoch || tsk.type == task_type::horizon ? "ellipse" : "box style=rounded"; + fmt::format_to(std::back_inserter(dot), "{}[shape={} label=<{}>];", tsk.tid, shape, get_task_label(tsk)); + for(auto d : tsk.dependencies) { + fmt::format_to(std::back_inserter(dot), "{}->{}[{}];", d.node, tsk.tid, dependency_style(d)); } + } - if(cmd.m_task_id.has_value() && cmd.m_task_geometry.has_value()) { - auto reduction_init_mode = cmd.m_is_reduction_initializer ? cl::sycl::access::mode::read_write : access_mode::discard_write; + dot += "}"; + return dot; +} + +std::string get_command_label(const node_id local_nid, const command_printing_information& cmd) { + const command_id cid = cmd.cid; + + std::string label = fmt::format("C{} on N{}
", cid, local_nid); + + auto add_reduction_id_if_reduction = [&]() { + if(cmd.reduction_id.has_value() && cmd.reduction_id != 0) { fmt::format_to(std::back_inserter(label), "(R{}) ", cmd.reduction_id.value()); } + }; + const std::string buffer_label = cmd.buffer_id.has_value() ? get_buffer_label(cmd.buffer_id.value(), cmd.buffer_name) : ""; + + switch(cmd.type) { + case command_type::epoch: { + label += "epoch"; + if(cmd.epoch_action == epoch_action::barrier) { label += " (barrier)"; } + if(cmd.epoch_action == epoch_action::shutdown) { label += " (shutdown)"; } + } break; + case command_type::execution: { + fmt::format_to(std::back_inserter(label), "execution {}", subrange_to_grid_box(cmd.execution_range.value())); + } break; + case command_type::push: { + add_reduction_id_if_reduction(); + fmt::format_to(std::back_inserter(label), "push transfer {} to N{}
B{} {}", // + cmd.transfer_id.value(), cmd.target.value(), buffer_label, subrange_to_grid_box(cmd.push_range.value())); + } break; + case command_type::await_push: { + add_reduction_id_if_reduction(); + fmt::format_to(std::back_inserter(label), "await push transfer {}
B{} {}", // + cmd.transfer_id.value(), buffer_label, cmd.await_region.value()); + } break; + case command_type::reduction: { + fmt::format_to(std::back_inserter(label), "reduction R{}
{} {}", cmd.reduction_id.value(), buffer_label, GridRegion<3>{{1, 1, 1}}); + } break; + case command_type::horizon: { + label += "horizon"; + } break; + case command_type::fence: { + label += "fence"; + } break; + default: assert(!"Unkown command"); label += "unknown"; + } - format_requirements(label, cmd.m_reductions.value_or(reduction_list{}), cmd.m_accesses.value_or(access_list{}), - cmd.m_side_effects.value_or(side_effect_map{}), reduction_init_mode); - } + if(cmd.task_id.has_value() && cmd.task_geometry.has_value()) { + auto reduction_init_mode = cmd.is_reduction_initializer ? cl::sycl::access::mode::read_write : access_mode::discard_write; - return label; + format_requirements(label, cmd.reductions.value_or(reduction_list{}), cmd.accesses.value_or(access_list{}), + cmd.side_effects.value_or(side_effect_map{}), reduction_init_mode); } - const std::string command_graph_preamble = "digraph G{label=\"Command Graph\" "; + return label; +} - std::string print_command_graph(const node_id local_nid, const command_recorder& recorder) { - std::string main_dot; - std::map task_subgraph_dot; // this map must be ordered! +const std::string command_graph_preamble = "digraph G{label=\"Command Graph\" "; - const auto local_to_global_id = [local_nid](uint64_t id) { - // IDs in the DOT language may not start with a digit (unless the whole thing is a numeral) - return fmt::format("id_{}_{}", local_nid, id); - }; +std::string print_command_graph(const node_id local_nid, const command_recorder& recorder) { + std::string main_dot; + std::map task_subgraph_dot; // this map must be ordered! - const auto print_vertex = [&](const command_printing_information& cmd) { - static const char* const colors[] = {"black", "crimson", "dodgerblue4", "goldenrod", "maroon4", "springgreen2", "tan1", "chartreuse2"}; + const auto local_to_global_id = [local_nid](uint64_t id) { + // IDs in the DOT language may not start with a digit (unless the whole thing is a numeral) + return fmt::format("id_{}_{}", local_nid, id); + }; - const auto id = local_to_global_id(cmd.m_cid); - const auto label = get_command_label(local_nid, cmd); - const auto* const fontcolor = colors[local_nid % (sizeof(colors) / sizeof(char*))]; - const auto* const shape = cmd.m_task_id.has_value() ? "box" : "ellipse"; - return fmt::format("{}[label=<{}> fontcolor={} shape={}];", id, label, fontcolor, shape); - }; + const auto print_vertex = [&](const command_printing_information& cmd) { + static const char* const colors[] = {"black", "crimson", "dodgerblue4", "goldenrod", "maroon4", "springgreen2", "tan1", "chartreuse2"}; - // we want to iterate over our command records in a sorted order, without moving everything around, and we aren't in C++20 (yet) - std::vector sorted_cmd_pointers; - for(const auto& cmd : recorder.get_commands()) { - sorted_cmd_pointers.push_back(&cmd); - } - std::sort(sorted_cmd_pointers.begin(), sorted_cmd_pointers.end(), [](const auto* a, const auto* b) { return a->m_cid < b->m_cid; }); - - for(const auto& cmd : sorted_cmd_pointers) { - if(cmd->m_task_id.has_value()) { - const auto tid = cmd->m_task_id.value(); - // Add to subgraph as well - if(task_subgraph_dot.count(tid) == 0) { - std::string task_label; - fmt::format_to(std::back_inserter(task_label), "T{} ", tid); - if(!cmd->m_task_name.empty()) { fmt::format_to(std::back_inserter(task_label), "\"{}\" ", cmd->m_task_name); } - task_label += "("; - task_label += task_type_string(cmd->m_task_type.value()); - if(cmd->m_task_type == task_type::collective) { - fmt::format_to(std::back_inserter(task_label), " on CG{}", cmd->m_collective_group_id.value()); - } - task_label += ")"; - - task_subgraph_dot.emplace(tid, - fmt::format("subgraph cluster_{}{{label=<{}>;color=darkgray;", local_to_global_id(tid), task_label)); - } - task_subgraph_dot[tid] += print_vertex(*cmd); - } else { - main_dot += print_vertex(*cmd); - } + const auto id = local_to_global_id(cmd.cid); + const auto label = get_command_label(local_nid, cmd); + const auto* const fontcolor = colors[local_nid % (sizeof(colors) / sizeof(char*))]; + const auto* const shape = cmd.task_id.has_value() ? "box" : "ellipse"; + return fmt::format("{}[label=<{}> fontcolor={} shape={}];", id, label, fontcolor, shape); + }; - for(const auto& d : cmd->m_dependencies) { - fmt::format_to(std::back_inserter(main_dot), "{}->{}[{}];", local_to_global_id(d.node), local_to_global_id(cmd->m_cid), dependency_style(d)); + // we want to iterate over our command records in a sorted order, without moving everything around, and we aren't in C++20 (yet) + std::vector sorted_cmd_pointers; + for(const auto& cmd : recorder.get_commands()) { + sorted_cmd_pointers.push_back(&cmd); + } + std::sort(sorted_cmd_pointers.begin(), sorted_cmd_pointers.end(), [](const auto* a, const auto* b) { return a->cid < b->cid; }); + + for(const auto& cmd : sorted_cmd_pointers) { + if(cmd->task_id.has_value()) { + const auto tid = cmd->task_id.value(); + // Add to subgraph as well + if(task_subgraph_dot.count(tid) == 0) { + std::string task_label; + fmt::format_to(std::back_inserter(task_label), "T{} ", tid); + if(!cmd->task_name.empty()) { fmt::format_to(std::back_inserter(task_label), "\"{}\" ", cmd->task_name); } + task_label += "("; + task_label += task_type_string(cmd->task_type.value()); + if(cmd->task_type == task_type::collective) { fmt::format_to(std::back_inserter(task_label), " on CG{}", cmd->collective_group_id.value()); } + task_label += ")"; + + task_subgraph_dot.emplace( + tid, fmt::format("subgraph cluster_{}{{label=<{}>;color=darkgray;", local_to_global_id(tid), task_label)); } - }; - - std::string result_dot = command_graph_preamble; - for(auto& [_, sg_dot] : task_subgraph_dot) { - result_dot += sg_dot; - result_dot += "}"; + task_subgraph_dot[tid] += print_vertex(*cmd); + } else { + main_dot += print_vertex(*cmd); } - result_dot += main_dot; - result_dot += "}"; - return result_dot; - } - std::string combine_command_graphs(const std::vector& graphs) { - std::string result_dot = command_graph_preamble; - for(const auto& g : graphs) { - result_dot += g.substr(command_graph_preamble.size(), g.size() - command_graph_preamble.size() - 1); + for(const auto& d : cmd->dependencies) { + fmt::format_to(std::back_inserter(main_dot), "{}->{}[{}];", local_to_global_id(d.node), local_to_global_id(cmd->cid), dependency_style(d)); } + }; + + std::string result_dot = command_graph_preamble; + for(auto& [_, sg_dot] : task_subgraph_dot) { + result_dot += sg_dot; result_dot += "}"; - return result_dot; } + result_dot += main_dot; + result_dot += "}"; + return result_dot; +} + +std::string combine_command_graphs(const std::vector& graphs) { + std::string result_dot = command_graph_preamble; + for(const auto& g : graphs) { + result_dot += g.substr(command_graph_preamble.size(), g.size() - command_graph_preamble.size() - 1); + } + result_dot += "}"; + return result_dot; +} -} // namespace detail -} // namespace celerity +} // namespace celerity::detail diff --git a/src/runtime.cc b/src/runtime.cc index 036d8b0a7..5ba78b3e0 100644 --- a/src/runtime.cc +++ b/src/runtime.cc @@ -150,14 +150,12 @@ namespace detail { m_reduction_mngr = std::make_unique(); m_host_object_mngr = std::make_unique(); - std::optional t_rec; - if(m_cfg->is_recording()) t_rec = task_recorder{m_buffer_mngr.get()}; - m_task_mngr = std::make_unique(m_num_nodes, m_h_queue.get(), t_rec); + if(m_cfg->is_recording()) m_task_recorder = std::make_unique(m_buffer_mngr.get()); + m_task_mngr = std::make_unique(m_num_nodes, m_h_queue.get(), m_task_recorder.get()); m_exec = std::make_unique(m_num_nodes, m_local_nid, *m_h_queue, *m_d_queue, *m_task_mngr, *m_buffer_mngr, *m_reduction_mngr); m_cdag = std::make_unique(); - std::optional c_rec; - if(m_cfg->is_recording()) c_rec = command_recorder{m_task_mngr.get(), m_buffer_mngr.get()}; - auto dggen = std::make_unique(m_num_nodes, m_local_nid, *m_cdag, *m_task_mngr, c_rec); + if(m_cfg->is_recording()) m_command_recorder = std::make_unique(m_task_mngr.get(), m_buffer_mngr.get()); + auto dggen = std::make_unique(m_num_nodes, m_local_nid, *m_cdag, *m_task_mngr, m_command_recorder.get()); m_schdlr = std::make_unique(is_dry_run(), std::move(dggen), *m_exec); m_task_mngr->register_task_callback([this](const task* tsk) { m_schdlr->notify_task_created(tsk); }); @@ -178,6 +176,8 @@ namespace detail { m_buffer_mngr.reset(); m_d_queue.reset(); m_h_queue.reset(); + m_command_recorder.reset(); + m_task_recorder.reset(); cgf_diagnostics::teardown(); @@ -209,7 +209,8 @@ namespace detail { if(spdlog::should_log(log_level::trace) && m_cfg->is_recording()) { if(m_local_nid == 0) { // It's the same across all nodes - const auto graph_str = m_task_mngr->print_task_graph(); + assert(m_task_recorder.get() != nullptr); + const auto graph_str = detail::print_task_graph(*m_task_recorder); CELERITY_TRACE("Task graph:\n\n{}\n", graph_str); } // must be called on all nodes @@ -242,25 +243,26 @@ namespace detail { host_object_manager& runtime::get_host_object_manager() const { return *m_host_object_mngr; } std::string runtime::gather_command_graph() const { - const auto graph_str = m_schdlr->print_command_graph(); + assert(m_command_recorder.get() != nullptr); + const auto graph_str = print_command_graph(m_local_nid, *m_command_recorder); // Send local graph to rank 0 if(m_local_nid != 0) { const uint64_t size = graph_str.size(); - MPI_Send(&size, 1, MPI_UINT64_T, 0, mpi_support::TAG_PRINT_GRAPH, MPI_COMM_WORLD); - if(size > 0) MPI_Send(graph_str.data(), static_cast(size), MPI_BYTE, 0, mpi_support::TAG_PRINT_GRAPH, MPI_COMM_WORLD); + assert(size < std::numeric_limits::max()); + MPI_Send(&size, 1, MPI_INT32_T, 0, mpi_support::TAG_PRINT_GRAPH, MPI_COMM_WORLD); + if(size > 0) MPI_Send(graph_str.data(), static_cast(size), MPI_BYTE, 0, mpi_support::TAG_PRINT_GRAPH, MPI_COMM_WORLD); return ""; } else { std::vector graphs; graphs.push_back(graph_str); for(size_t i = 1; i < m_num_nodes; ++i) { - uint64_t size = 0; - MPI_Recv(&size, 1, MPI_UINT64_T, static_cast(i), mpi_support::TAG_PRINT_GRAPH, MPI_COMM_WORLD, MPI_STATUS_IGNORE); + int32_t size = 0; + MPI_Recv(&size, 1, MPI_INT32_T, static_cast(i), mpi_support::TAG_PRINT_GRAPH, MPI_COMM_WORLD, MPI_STATUS_IGNORE); if(size > 0) { std::string graph; graph.resize(size); - MPI_Recv( - graph.data(), static_cast(size), MPI_BYTE, static_cast(i), mpi_support::TAG_PRINT_GRAPH, MPI_COMM_WORLD, MPI_STATUS_IGNORE); + MPI_Recv(graph.data(), size, MPI_BYTE, static_cast(i), mpi_support::TAG_PRINT_GRAPH, MPI_COMM_WORLD, MPI_STATUS_IGNORE); graphs.push_back(std::move(graph)); } } diff --git a/src/scheduler.cc b/src/scheduler.cc index e90abc23e..2796354b2 100644 --- a/src/scheduler.cc +++ b/src/scheduler.cc @@ -17,8 +17,6 @@ namespace detail { void abstract_scheduler::shutdown() { notify(event_shutdown{}); } - std::string abstract_scheduler::print_command_graph() const { return m_dggen->print_command_graph(); } - void abstract_scheduler::schedule() { graph_serializer serializer([this](command_pkg&& pkg) { if(m_is_dry_run && pkg.get_command_type() != command_type::epoch) { return; } diff --git a/src/task_manager.cc b/src/task_manager.cc index ec0dabaac..de146ee39 100644 --- a/src/task_manager.cc +++ b/src/task_manager.cc @@ -6,12 +6,12 @@ namespace celerity { namespace detail { - task_manager::task_manager(size_t num_collective_nodes, host_queue* queue, std::optional recorder) // - : m_num_collective_nodes(num_collective_nodes), m_queue(queue), m_task_recorder(std::move(recorder)) { + task_manager::task_manager(size_t num_collective_nodes, host_queue* queue, detail::task_recorder* recorder) // + : m_num_collective_nodes(num_collective_nodes), m_queue(queue), m_task_recorder(recorder) { // We manually generate the initial epoch task, which we treat as if it has been reached immediately. auto reserve = m_task_buffer.reserve_task_entry(await_free_task_slot_callback()); auto initial_epoch = task::make_epoch(initial_epoch_task, epoch_action::none); - if(m_task_recorder) m_task_recorder->record_task(*initial_epoch); + if(m_task_recorder != nullptr) m_task_recorder->record_task(*initial_epoch); m_task_buffer.put(std::move(reserve), std::move(initial_epoch)); } @@ -28,12 +28,6 @@ namespace detail { // we don't need to worry about thread-safety after returning the task pointer. const task* task_manager::get_task(task_id tid) const { return m_task_buffer.get_task(tid); } - std::string task_manager::print_task_graph() const { - if(m_task_recorder) { return detail::print_task_graph(*m_task_recorder); } - CELERITY_ERROR("Trying to print task graph, but no recorder available"); - return ""; - } - void task_manager::notify_horizon_reached(task_id horizon_tid) { // m_latest_horizon_reached does not need synchronization (see definition), all other accesses are implicitly synchronized. @@ -189,7 +183,7 @@ namespace detail { for(const auto& cb : m_task_callbacks) { cb(tsk); } - if(m_task_recorder) m_task_recorder->record_task(*tsk); + if(m_task_recorder != nullptr) m_task_recorder->record_task(*tsk); } void task_manager::add_dependency(task& depender, task& dependee, dependency_kind kind, dependency_origin origin) { diff --git a/test/accessor_tests.cc b/test/accessor_tests.cc index 62e509b45..7cd7f4435 100644 --- a/test/accessor_tests.cc +++ b/test/accessor_tests.cc @@ -678,8 +678,9 @@ namespace detail { } const auto attempted_sr = subrange<3>{id_cast<3>(oob_idx_lo), range_cast<3>(oob_idx_hi - oob_idx_lo + id_cast(range(unit_range)))}; - const auto error_message = fmt::format("Out-of-bounds access in kernel 'acc_out_of_bounds_kernel<{}>' detected: Accessor 0 for buffer 0 attempted to " - "access indices between {} which are outside of mapped subrange {}", + const auto error_message = fmt::format("Out-of-bounds access in kernel 'celerity::detail::acc_out_of_bounds_kernel<{}>' " + "detected: Accessor 0 for buffer 0 attempted to access indices " + "between {} which are outside of mapped subrange {}", Dims, attempted_sr, subrange_cast<3>(accessible_sr)); CHECK_THAT(lc->get_log(), Catch::Matchers::ContainsSubstring(error_message)); } diff --git a/test/benchmarks.cc b/test/benchmarks.cc index 29966a833..6e6420870 100644 --- a/test/benchmarks.cc +++ b/test/benchmarks.cc @@ -72,7 +72,7 @@ TEST_CASE("benchmark task handling", "[benchmark][task]") { auto initialization_lambda = [&] { highest_tid = 0; - tm = std::make_unique(1, nullptr, no_task_recorder); + tm = std::make_unique(1, nullptr, nullptr); // we use this trick to force horizon creation without introducing dependency overhead in this microbenchmark tm->set_horizon_step(0); }; @@ -137,7 +137,8 @@ TEST_CASE("benchmark task handling", "[benchmark][task]") { struct task_manager_benchmark_context { const size_t num_nodes = 1; - task_manager tm{1, nullptr, {}}; + task_recorder trec; + task_manager tm{1, nullptr, test_utils::print_graphs ? &trec : nullptr}; test_utils::mock_buffer_factory mbf{tm}; ~task_manager_benchmark_context() { tm.generate_epoch_task(celerity::detail::epoch_action::shutdown); } @@ -155,12 +156,14 @@ struct graph_generator_benchmark_context { const size_t num_nodes; command_graph cdag; graph_serializer gser{[](command_pkg&&) {}}; - task_manager tm{num_nodes, nullptr, {}}; + task_recorder trec; + task_manager tm{num_nodes, nullptr, test_utils::print_graphs ? &trec : nullptr}; + command_recorder crec; distributed_graph_generator dggen; test_utils::mock_buffer_factory mbf; explicit graph_generator_benchmark_context(size_t num_nodes) - : num_nodes{num_nodes}, dggen{num_nodes, 0 /* local_nid */, cdag, tm, no_command_recorder}, mbf{tm, dggen} { + : num_nodes{num_nodes}, crec(&tm), dggen{num_nodes, 0 /* local_nid */, cdag, tm, test_utils::print_graphs ? &crec : nullptr}, mbf{tm, dggen} { tm.register_task_callback([this](const task* tsk) { const auto cmds = dggen.build_task(*tsk); gser.flush(cmds); @@ -260,8 +263,8 @@ struct scheduler_benchmark_context { test_utils::mock_buffer_factory mbf; explicit scheduler_benchmark_context(restartable_thread& thrd, size_t num_nodes) - : num_nodes{num_nodes}, // - schdlr{thrd, std::make_unique(num_nodes, 0 /* local_nid */, cdag, tm, no_command_recorder)}, // + : num_nodes{num_nodes}, // + schdlr{thrd, std::make_unique(num_nodes, 0 /* local_nid */, cdag, tm, nullptr)}, // mbf{tm, schdlr} { tm.register_task_callback([this](const task* tsk) { schdlr.notify_task_created(tsk); }); schdlr.startup(); @@ -466,9 +469,9 @@ void debug_graphs(BenchmarkContextFactory&& make_ctx, BenchmarkContextConsumer&& } TEST_CASE("printing benchmark task graphs", "[.][debug-graphs][task-graph]") { - debug_graphs([] { return task_manager_benchmark_context{}; }, [](auto&& ctx) { test_utils::maybe_print_graph(ctx.tm); }); + debug_graphs([] { return task_manager_benchmark_context{}; }, [](auto&& ctx) { test_utils::maybe_print_task_graph(ctx.trec); }); } TEST_CASE("printing benchmark command graphs", "[.][debug-graphs][command-graph]") { - debug_graphs([] { return graph_generator_benchmark_context{2}; }, [](auto&& ctx) { test_utils::maybe_print_command_graph(ctx.dggen); }); + debug_graphs([] { return graph_generator_benchmark_context{2}; }, [](auto&& ctx) { test_utils::maybe_print_command_graph(0, ctx.crec); }); } diff --git a/test/buffer_manager_tests.cc b/test/buffer_manager_tests.cc index 85715f138..6d585d806 100644 --- a/test/buffer_manager_tests.cc +++ b/test/buffer_manager_tests.cc @@ -48,7 +48,6 @@ namespace detail { REQUIRE_FALSE(bm.has_buffer(b_id)); // TODO: check whether error was printed or not - test_utils::maybe_print_graph(celerity::detail::runtime::get_instance().get_task_manager()); } // ComputeCPP based on Clang 8 segfaults in these tests diff --git a/test/distributed_graph_generator_test_utils.h b/test/distributed_graph_generator_test_utils.h index bc7d6a24f..83847a442 100644 --- a/test/distributed_graph_generator_test_utils.h +++ b/test/distributed_graph_generator_test_utils.h @@ -375,10 +375,12 @@ class dist_cdag_test_context { public: dist_cdag_test_context(size_t num_nodes) : m_num_nodes(num_nodes) { m_rm = std::make_unique(); - m_tm = std::make_unique(num_nodes, nullptr /* host_queue */, task_recorder{}); + m_task_recorder = std::make_unique(); + m_tm = std::make_unique(num_nodes, nullptr /* host_queue */, m_task_recorder.get()); for(node_id nid = 0; nid < num_nodes; ++nid) { m_cdags.emplace_back(std::make_unique()); - m_dggens.emplace_back(std::make_unique(num_nodes, nid, *m_cdags[nid], *m_tm, command_recorder{m_tm.get(), nullptr})); + m_cmd_recorders.emplace_back(std::make_unique(m_tm.get(), nullptr)); + m_dggens.emplace_back(std::make_unique(num_nodes, nid, *m_cdags[nid], *m_tm, m_cmd_recorders[nid].get())); } } @@ -461,6 +463,9 @@ class dist_cdag_test_context { distributed_graph_generator& get_graph_generator(node_id nid) { return *m_dggens.at(nid); } + std::string print_task_graph() { return detail::print_task_graph(*m_task_recorder); } + std::string print_command_graph(node_id nid) { return detail::print_command_graph(nid, *m_cmd_recorders[nid]); } + private: size_t m_num_nodes; buffer_id m_next_buffer_id = 0; @@ -469,8 +474,10 @@ class dist_cdag_test_context { std::optional m_most_recently_built_horizon; std::unique_ptr m_rm; std::unique_ptr m_tm; + std::unique_ptr m_task_recorder; std::vector> m_cdags; std::vector> m_dggens; + std::vector> m_cmd_recorders; reduction_info create_reduction(const buffer_id bid, const bool include_current_buffer_value) { return reduction_info{m_next_reduction_id++, bid, include_current_buffer_value}; @@ -493,11 +500,10 @@ class dist_cdag_test_context { void maybe_print_graphs() { if(test_utils::print_graphs) { - test_utils::maybe_print_graph(*m_tm); - + print_task_graph(); std::vector graphs; for(node_id nid = 0; nid < m_num_nodes; ++nid) { - graphs.push_back(m_dggens[nid]->print_command_graph()); + graphs.push_back(print_command_graph(nid)); } CELERITY_INFO("Command graph:\n\n{}\n", combine_command_graphs(graphs)); } diff --git a/test/graph_compaction_tests.cc b/test/graph_compaction_tests.cc index 40950631b..241137316 100644 --- a/test/graph_compaction_tests.cc +++ b/test/graph_compaction_tests.cc @@ -222,14 +222,13 @@ TEST_CASE("side-effect dependencies are correctly subsumed by horizons", "[distr TEST_CASE("reaching an epoch will prune all nodes of the preceding task graph", "[task_manager][task-graph][epoch]") { constexpr int num_nodes = 2; - task_manager tm{num_nodes, nullptr, {}}; - test_utils::mock_buffer_factory mbf(tm); + auto tt = test_utils::task_test_context{}; const auto check_task_has_exact_dependencies = [&](const char* info, const task_id dependent, const std::initializer_list> dependencies) { INFO(info); CAPTURE(dependent); - const auto actual = tm.get_task(dependent)->get_dependencies(); + const auto actual = tt.tm.get_task(dependent)->get_dependencies(); CHECK(static_cast(std::distance(actual.begin(), actual.end())) == dependencies.size()); for(const auto& [tid, kind, origin] : dependencies) { CAPTURE(tid); @@ -248,35 +247,33 @@ TEST_CASE("reaching an epoch will prune all nodes of the preceding task graph", const auto node_range = range<1>{num_nodes}; const auto init_tid = task_manager::initial_epoch_task; - auto early_host_initialized_buf = mbf.create_buffer(node_range, true); - auto buf_written_from_kernel = mbf.create_buffer(node_range, false); + auto early_host_initialized_buf = tt.mbf.create_buffer(node_range, true); + auto buf_written_from_kernel = tt.mbf.create_buffer(node_range, false); const auto writer_tid = test_utils::add_compute_task( - tm, [&](handler& cgh) { buf_written_from_kernel.get_access(cgh, acc::one_to_one{}); }, node_range); + tt.tm, [&](handler& cgh) { buf_written_from_kernel.get_access(cgh, acc::one_to_one{}); }, node_range); - const auto epoch_tid = tm.generate_epoch_task(epoch_action::none); + const auto epoch_tid = tt.tm.generate_epoch_task(epoch_action::none); const auto reader_writer_tid = test_utils::add_compute_task( - tm, [&](handler& cgh) { early_host_initialized_buf.get_access(cgh, acc::one_to_one{}); }, node_range); + tt.tm, [&](handler& cgh) { early_host_initialized_buf.get_access(cgh, acc::one_to_one{}); }, node_range); - auto late_host_initialized_buf = mbf.create_buffer(node_range, true); + auto late_host_initialized_buf = tt.mbf.create_buffer(node_range, true); const auto late_writer_tid = test_utils::add_compute_task( - tm, [&](handler& cgh) { late_host_initialized_buf.get_access(cgh, acc::one_to_one{}); }, node_range); + tt.tm, [&](handler& cgh) { late_host_initialized_buf.get_access(cgh, acc::one_to_one{}); }, node_range); - test_utils::maybe_print_graph(tm); - - REQUIRE(tm.has_task(init_tid)); + REQUIRE(tt.tm.has_task(init_tid)); check_task_has_exact_dependencies("initial epoch task", init_tid, {}); - REQUIRE(tm.has_task(writer_tid)); + REQUIRE(tt.tm.has_task(writer_tid)); check_task_has_exact_dependencies("writer", writer_tid, {{init_tid, dependency_kind::true_dep, dependency_origin::last_epoch}}); - REQUIRE(tm.has_task(epoch_tid)); + REQUIRE(tt.tm.has_task(epoch_tid)); check_task_has_exact_dependencies("epoch before", epoch_tid, {{writer_tid, dependency_kind::true_dep, dependency_origin::execution_front}}); - tm.notify_epoch_reached(epoch_tid); + tt.tm.notify_epoch_reached(epoch_tid); const auto reader_tid = test_utils::add_compute_task( - tm, + tt.tm, [&](handler& cgh) { early_host_initialized_buf.get_access(cgh, acc::one_to_one{}); late_host_initialized_buf.get_access(cgh, acc::one_to_one{}); @@ -284,21 +281,19 @@ TEST_CASE("reaching an epoch will prune all nodes of the preceding task graph", }, node_range); - CHECK(!tm.has_task(init_tid)); - CHECK(!tm.has_task(writer_tid)); - REQUIRE(tm.has_task(epoch_tid)); + CHECK(!tt.tm.has_task(init_tid)); + CHECK(!tt.tm.has_task(writer_tid)); + REQUIRE(tt.tm.has_task(epoch_tid)); check_task_has_exact_dependencies("epoch after", epoch_tid, {}); - REQUIRE(tm.has_task(reader_writer_tid)); + REQUIRE(tt.tm.has_task(reader_writer_tid)); check_task_has_exact_dependencies("reader-writer", reader_writer_tid, {{epoch_tid, dependency_kind::true_dep, dependency_origin::dataflow}}); - REQUIRE(tm.has_task(late_writer_tid)); + REQUIRE(tt.tm.has_task(late_writer_tid)); check_task_has_exact_dependencies("late writer", late_writer_tid, {{epoch_tid, dependency_kind::true_dep, dependency_origin::last_epoch}}); - REQUIRE(tm.has_task(reader_tid)); + REQUIRE(tt.tm.has_task(reader_tid)); check_task_has_exact_dependencies("reader", reader_tid, { {epoch_tid, dependency_kind::anti_dep, dependency_origin::dataflow}, {reader_writer_tid, dependency_kind::true_dep, dependency_origin::dataflow}, {late_writer_tid, dependency_kind::true_dep, dependency_origin::dataflow}, }); - - test_utils::maybe_print_graph(tm); } \ No newline at end of file diff --git a/test/print_graph_tests.cc b/test/print_graph_tests.cc index 99e2f00c6..1089b9229 100644 --- a/test/print_graph_tests.cc +++ b/test/print_graph_tests.cc @@ -13,30 +13,27 @@ using namespace celerity::test_utils; namespace acc = celerity::access; TEST_CASE("task-graph printing is unchanged", "[print_graph][task-graph]") { - task_recorder tr; - task_manager tm{1, nullptr, tr}; - test_utils::mock_buffer_factory mbf(tm); - test_utils::mock_reduction_factory mrf; + auto tt = test_utils::task_test_context{}; auto range = celerity::range<1>(64); - auto buf_0 = mbf.create_buffer(range); - auto buf_1 = mbf.create_buffer(celerity::range<1>(1)); + auto buf_0 = tt.mbf.create_buffer(range); + auto buf_1 = tt.mbf.create_buffer(celerity::range<1>(1)); // graph copied from graph_gen_reduction_tests "distributed_graph_generator generates reduction command trees" test_utils::add_compute_task( - tm, [&](handler& cgh) { buf_1.get_access(cgh, acc::one_to_one{}); }, range); + tt.tm, [&](handler& cgh) { buf_1.get_access(cgh, acc::one_to_one{}); }, range); test_utils::add_compute_task( - tm, [&](handler& cgh) { buf_0.get_access(cgh, acc::one_to_one{}); }, range); + tt.tm, [&](handler& cgh) { buf_0.get_access(cgh, acc::one_to_one{}); }, range); test_utils::add_compute_task( - tm, + tt.tm, [&](handler& cgh) { buf_0.get_access(cgh, acc::one_to_one{}); - test_utils::add_reduction(cgh, mrf, buf_1, true /* include_current_buffer_value */); + test_utils::add_reduction(cgh, tt.mrf, buf_1, true /* include_current_buffer_value */); }, range); test_utils::add_compute_task( - tm, + tt.tm, [&](handler& cgh) { buf_1.get_access(cgh, acc::fixed<1>({0, 1})); }, @@ -44,7 +41,7 @@ TEST_CASE("task-graph printing is unchanged", "[print_graph][task-graph]") { // Smoke test: It is valid for the dot output to change with updates to graph generation. If this test fails, verify that the printed graph is sane and // replace the `expected` value with the new dot graph. - const auto expected = + const std::string expected = "digraph G {label=\"Task Graph\" 0[shape=ellipse label=epoch>];1[shape=box style=rounded label=device-compute [0,0,0] - [64,1,1]
discard_write B1 {[[0,0,0] - [1,1,1]]}>];0->1[color=orchid];2[shape=box style=rounded " "label=device-compute [0,0,0] - [64,1,1]
discard_write B0 {[[0,0,0] - " @@ -52,7 +49,7 @@ TEST_CASE("task-graph printing is unchanged", "[print_graph][task-graph]") { "read_write B1 {[[0,0,0] - [1,1,1]]}
read B0 {[[0,0,0] - [64,1,1]]}>];1->3[];2->3[];4[shape=box style=rounded label=device-compute [0,0,0] - [64,1,1]
read B1 {[[0,0,0] - [1,1,1]]}>];3->4[];}"; - CHECK(tm.print_task_graph() == expected); + CHECK(print_task_graph(tt.trec) == expected); } namespace { @@ -94,13 +91,13 @@ TEST_CASE("command graph printing is unchanged", "[print_graph][command-graph]") "[[0,0,0] - [1,1,1]]> fontcolor=black shape=ellipse];id_0_1->id_0_7[];}"; // fully check node 0 - const auto dot0 = dctx.get_graph_generator(0).print_command_graph(); + const auto dot0 = dctx.print_command_graph(0); CHECK(dot0 == expected); // only check the rough string length and occurence count of N1/N2... for other nodes const int expected_occurences = count_occurences(expected, "N0"); for(size_t i = 1; i < num_nodes; ++i) { - const auto dot_n = dctx.get_graph_generator(i).print_command_graph(); + const auto dot_n = dctx.print_command_graph(i); REQUIRE_THAT(dot_n.size(), Catch::Matchers::WithinAbs(expected.size(), 50)); CHECK(count_occurences(dot_n, fmt::format("N{}", i)) == expected_occurences); } @@ -125,13 +122,13 @@ TEST_CASE_METHOD(test_utils::runtime_fixture, "buffer debug names show up in the q.slow_full_sync(); using Catch::Matchers::ContainsSubstring; - const auto expected_substring = "B0 \"my_buffer\""; + const std::string expected_substring = "B0 \"my_buffer\""; SECTION("in the task graph") { - const auto dot = celerity::detail::runtime::get_instance().get_task_manager().print_task_graph(); + const auto dot = runtime_testspy::print_task_graph(celerity::detail::runtime::get_instance()); REQUIRE_THAT(dot, ContainsSubstring(expected_substring)); } SECTION("in the command graph") { - const auto dot = runtime_testspy::print_graph(celerity::detail::runtime::get_instance()); + const auto dot = runtime_testspy::print_command_graph(0, celerity::detail::runtime::get_instance()); REQUIRE_THAT(dot, ContainsSubstring(expected_substring)); } } @@ -169,7 +166,7 @@ TEST_CASE_METHOD(test_utils::runtime_fixture, "full graph is printed if CELERITY "
device-compute [0,0,0] - [16,1,1]
read_write B0 {[[0,0,0] - [16,1,1]]}>];3->5[];6[shape=ellipse " "label=horizon>];5->6[color=orange];4->6[color=orange];7[shape=ellipse label=epoch>];6->7[color=orange];}"; - CHECK(tm.print_task_graph() == expected); + CHECK(runtime_testspy::print_task_graph(celerity::detail::runtime::get_instance()) == expected); } SECTION("command graph") { @@ -189,7 +186,7 @@ TEST_CASE_METHOD(test_utils::runtime_fixture, "full graph is printed if CELERITY "shape=box];}id_0_0->id_0_1[];id_0_1->id_0_2[color=orange];id_0_1->id_0_3[];id_0_3->id_0_4[color=orange];id_0_2->id_0_4[color=orange];id_0_3->id_0_" "5[];id_0_5->id_0_6[color=orange];id_0_4->id_0_6[color=orange];id_0_6->id_0_7[color=orange];}"; - CHECK(runtime_testspy::print_graph(celerity::detail::runtime::get_instance()) == expected); + CHECK(runtime_testspy::print_command_graph(0, celerity::detail::runtime::get_instance()) == expected); } } @@ -208,16 +205,13 @@ void compute(task_manager& tm, mock_buffer<1> buf, const celerity::range<1> rang } // namespace test_ns TEST_CASE("task-graph names are escaped", "[print_graph][task-graph][task-name]") { - task_recorder tr; - task_manager tm{1, nullptr, tr}; - test_utils::mock_buffer_factory mbf(tm); - test_utils::mock_reduction_factory mrf; + auto tt = test_utils::task_test_context{}; auto range = celerity::range<1>(64); - auto buf = mbf.create_buffer(range); + auto buf = tt.mbf.create_buffer(range); - test_ns::compute(tm, buf, range); + test_ns::compute(tt.tm, buf, range); const auto* escaped_name = "\"name_class<(test_ns::x::ec)0>\""; - REQUIRE_THAT(tm.print_task_graph(), Catch::Matchers::ContainsSubstring(escaped_name)); + REQUIRE_THAT(print_task_graph(tt.trec), Catch::Matchers::ContainsSubstring(escaped_name)); } diff --git a/test/runtime_tests.cc b/test/runtime_tests.cc index 61e343d88..c76b3692f 100644 --- a/test/runtime_tests.cc +++ b/test/runtime_tests.cc @@ -1220,7 +1220,6 @@ namespace detail { // intial epoch + master-node task + 1 push per node + host task + sync epoch // (dry runs currently always simulate node 0, hence the master-node task) CHECK(runtime_testspy::get_command_count(rt) == 4 + num_nodes); - test_utils::maybe_print_graph(tm); } TEST_CASE_METHOD(test_utils::runtime_fixture, "Dry run generates commands for an arbitrary number of simulated worker nodes", "[dryrun]") { diff --git a/test/system/distr_tests.cc b/test/system/distr_tests.cc index 418524c20..75525d557 100644 --- a/test/system/distr_tests.cc +++ b/test/system/distr_tests.cc @@ -310,7 +310,7 @@ namespace detail { MPI_Comm test_communicator; MPI_Comm_create(MPI_COMM_WORLD, world_group, &test_communicator); - const auto graph_str = runtime::get_instance().get_task_manager().print_task_graph(); + const auto graph_str = runtime_testspy::print_task_graph(runtime::get_instance()); const int graph_str_length = static_cast(graph_str.length()); REQUIRE(graph_str_length > 0); diff --git a/test/task_graph_tests.cc b/test/task_graph_tests.cc index fdb51f93e..9f828dcb8 100644 --- a/test/task_graph_tests.cc +++ b/test/task_graph_tests.cc @@ -20,81 +20,72 @@ namespace detail { TEST_CASE("task_manager does not create multiple dependencies between the same tasks", "[task_manager][task-graph]") { using namespace cl::sycl::access; - task_manager tm{1, nullptr, task_recorder{}}; - test_utils::mock_buffer_factory mbf(tm); - auto buf_a = mbf.create_buffer(range<1>(128)); - auto buf_b = mbf.create_buffer(range<1>(128)); + auto tt = test_utils::task_test_context{}; + auto buf_a = tt.mbf.create_buffer(range<1>(128)); + auto buf_b = tt.mbf.create_buffer(range<1>(128)); SECTION("true dependencies") { - const auto tid_a = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + const auto tid_a = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 128})); buf_b.get_access(cgh, fixed<1>({0, 128})); }); - const auto tid_b = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + const auto tid_b = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 128})); buf_b.get_access(cgh, fixed<1>({0, 128})); }); - CHECK(has_dependency(tm, tid_b, tid_a)); + CHECK(has_dependency(tt.tm, tid_b, tid_a)); - const auto its = tm.get_task(tid_a)->get_dependents(); + const auto its = tt.tm.get_task(tid_a)->get_dependents(); REQUIRE(std::distance(its.begin(), its.end()) == 1); - - test_utils::maybe_print_graph(tm); } SECTION("anti-dependencies") { - const auto tid_a = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + const auto tid_a = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 128})); buf_b.get_access(cgh, fixed<1>({0, 128})); }); - const auto tid_b = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + const auto tid_b = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 128})); buf_b.get_access(cgh, fixed<1>({0, 128})); }); - CHECK(has_dependency(tm, tid_b, tid_a, dependency_kind::anti_dep)); + CHECK(has_dependency(tt.tm, tid_b, tid_a, dependency_kind::anti_dep)); - const auto its = tm.get_task(tid_a)->get_dependents(); + const auto its = tt.tm.get_task(tid_a)->get_dependents(); REQUIRE(std::distance(its.begin(), its.end()) == 1); - - test_utils::maybe_print_graph(tm); } // Here we also check that true dependencies always take precedence SECTION("true and anti-dependencies combined") { SECTION("if true is declared first") { - const auto tid_a = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + const auto tid_a = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 128})); buf_b.get_access(cgh, fixed<1>({0, 128})); }); - const auto tid_b = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + const auto tid_b = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 128})); buf_b.get_access(cgh, fixed<1>({0, 128})); }); - CHECK(has_dependency(tm, tid_b, tid_a)); - CHECK_FALSE(has_dependency(tm, tid_b, tid_a, dependency_kind::anti_dep)); + CHECK(has_dependency(tt.tm, tid_b, tid_a)); + CHECK_FALSE(has_dependency(tt.tm, tid_b, tid_a, dependency_kind::anti_dep)); - const auto its = tm.get_task(tid_a)->get_dependents(); + const auto its = tt.tm.get_task(tid_a)->get_dependents(); REQUIRE(std::distance(its.begin(), its.end()) == 1); - - test_utils::maybe_print_graph(tm); } SECTION("if anti is declared first") { - const auto tid_a = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + const auto tid_a = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 128})); buf_b.get_access(cgh, fixed<1>({0, 128})); }); - const auto tid_b = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + const auto tid_b = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 128})); buf_b.get_access(cgh, fixed<1>({0, 128})); }); - CHECK(has_dependency(tm, tid_b, tid_a)); - CHECK_FALSE(has_dependency(tm, tid_b, tid_a, dependency_kind::anti_dep)); + CHECK(has_dependency(tt.tm, tid_b, tid_a)); + CHECK_FALSE(has_dependency(tt.tm, tid_b, tid_a, dependency_kind::anti_dep)); - const auto its = tm.get_task(tid_a)->get_dependents(); + const auto its = tt.tm.get_task(tid_a)->get_dependents(); REQUIRE(std::distance(its.begin(), its.end()) == 1); - - test_utils::maybe_print_graph(tm); } } } @@ -102,83 +93,76 @@ namespace detail { TEST_CASE("task_manager respects range mapper results for finding dependencies", "[task_manager][task-graph]") { using namespace cl::sycl::access; - task_manager tm{1, nullptr, task_recorder{}}; - test_utils::mock_buffer_factory mbf(tm); - auto buf = mbf.create_buffer(range<1>(128)); + auto tt = test_utils::task_test_context{}; + auto buf = tt.mbf.create_buffer(range<1>(128)); - const auto tid_a = test_utils::add_compute_task(tm, [&](handler& cgh) { + const auto tid_a = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf.get_access(cgh, fixed<1>{{0, 64}}); }); - const auto tid_b = test_utils::add_compute_task(tm, [&](handler& cgh) { buf.get_access(cgh, fixed<1>{{0, 128}}); }); - REQUIRE(has_dependency(tm, tid_b, tid_a)); - - const auto tid_c = test_utils::add_compute_task(tm, [&](handler& cgh) { buf.get_access(cgh, fixed<1>{{64, 128}}); }); - REQUIRE_FALSE(has_dependency(tm, tid_c, tid_a)); + const auto tid_b = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf.get_access(cgh, fixed<1>{{0, 128}}); }); + REQUIRE(has_dependency(tt.tm, tid_b, tid_a)); - test_utils::maybe_print_graph(tm); + const auto tid_c = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf.get_access(cgh, fixed<1>{{64, 128}}); }); + REQUIRE_FALSE(has_dependency(tt.tm, tid_c, tid_a)); } TEST_CASE("task_manager correctly generates anti-dependencies", "[task_manager][task-graph]") { using namespace cl::sycl::access; - task_manager tm{1, nullptr, task_recorder{}}; - test_utils::mock_buffer_factory mbf(tm); - auto buf = mbf.create_buffer(range<1>(128)); + + auto tt = test_utils::task_test_context{}; + auto buf = tt.mbf.create_buffer(range<1>(128)); // Write to the full buffer - const auto tid_a = test_utils::add_compute_task(tm, [&](handler& cgh) { + const auto tid_a = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf.get_access(cgh, fixed<1>{{0, 128}}); }); // Read the first half of the buffer - const auto tid_b = test_utils::add_compute_task(tm, [&](handler& cgh) { buf.get_access(cgh, fixed<1>{{0, 64}}); }); - CHECK(has_dependency(tm, tid_b, tid_a)); + const auto tid_b = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf.get_access(cgh, fixed<1>{{0, 64}}); }); + CHECK(has_dependency(tt.tm, tid_b, tid_a)); // Overwrite the second half - no anti-dependency onto task_b should exist (but onto task_a) - const auto tid_c = test_utils::add_compute_task(tm, [&](handler& cgh) { + const auto tid_c = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf.get_access(cgh, fixed<1>{{64, 64}}); }); - REQUIRE(has_dependency(tm, tid_c, tid_a, dependency_kind::anti_dep)); - REQUIRE_FALSE(has_dependency(tm, tid_c, tid_b, dependency_kind::anti_dep)); + REQUIRE(has_dependency(tt.tm, tid_c, tid_a, dependency_kind::anti_dep)); + REQUIRE_FALSE(has_dependency(tt.tm, tid_c, tid_b, dependency_kind::anti_dep)); // Overwrite the first half - now only an anti-dependency onto task_b should exist - const auto tid_d = test_utils::add_compute_task(tm, [&](handler& cgh) { + const auto tid_d = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf.get_access(cgh, fixed<1>{{0, 64}}); }); - REQUIRE_FALSE(has_dependency(tm, tid_d, tid_a, dependency_kind::anti_dep)); - REQUIRE(has_dependency(tm, tid_d, tid_b, dependency_kind::anti_dep)); - - test_utils::maybe_print_graph(tm); + REQUIRE_FALSE(has_dependency(tt.tm, tid_d, tid_a, dependency_kind::anti_dep)); + REQUIRE(has_dependency(tt.tm, tid_d, tid_b, dependency_kind::anti_dep)); } TEST_CASE("task_manager correctly handles host-initialized buffers", "[task_manager][task-graph]") { using namespace cl::sycl::access; - task_manager tm{1, nullptr, task_recorder{}}; - test_utils::mock_buffer_factory mbf(tm); - auto host_init_buf = mbf.create_buffer(range<1>(128), true); - auto non_host_init_buf = mbf.create_buffer(range<1>(128), false); - auto artificial_dependency_buf = mbf.create_buffer(range<1>(1), false); - const auto tid_a = test_utils::add_compute_task(tm, [&](handler& cgh) { + auto tt = test_utils::task_test_context{}; + auto host_init_buf = tt.mbf.create_buffer(range<1>(128), true); + auto non_host_init_buf = tt.mbf.create_buffer(range<1>(128), false); + auto artificial_dependency_buf = tt.mbf.create_buffer(range<1>(1), false); + + const auto tid_a = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { host_init_buf.get_access(cgh, fixed<1>{{0, 128}}); artificial_dependency_buf.get_access(cgh, all{}); }); - CHECK(has_dependency(tm, tid_a, task_manager::initial_epoch_task)); + CHECK(has_dependency(tt.tm, tid_a, task_manager::initial_epoch_task)); - const auto tid_b = test_utils::add_compute_task(tm, [&](handler& cgh) { + const auto tid_b = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { non_host_init_buf.get_access(cgh, fixed<1>{{0, 128}}); // introduce an arbitrary true-dependency to avoid the fallback epoch dependency that is generated for tasks without other true-dependencies artificial_dependency_buf.get_access(cgh, all{}); }); - CHECK_FALSE(has_dependency(tm, tid_b, task_manager::initial_epoch_task)); + CHECK_FALSE(has_dependency(tt.tm, tid_b, task_manager::initial_epoch_task)); - const auto tid_c = test_utils::add_compute_task(tm, [&](handler& cgh) { + const auto tid_c = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { host_init_buf.get_access(cgh, fixed<1>{{0, 128}}); }); - CHECK(has_dependency(tm, tid_c, tid_a, dependency_kind::anti_dep)); - const auto tid_d = test_utils::add_compute_task(tm, [&](handler& cgh) { + CHECK(has_dependency(tt.tm, tid_c, tid_a, dependency_kind::anti_dep)); + const auto tid_d = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { non_host_init_buf.get_access(cgh, fixed<1>{{0, 128}}); }); // Since task b is essentially reading uninitialized garbage, it doesn't make a difference if we write into it concurrently - CHECK_FALSE(has_dependency(tm, tid_d, tid_b, dependency_kind::anti_dep)); - - test_utils::maybe_print_graph(tm); + CHECK_FALSE(has_dependency(tt.tm, tid_d, tid_b, dependency_kind::anti_dep)); } template @@ -201,19 +185,18 @@ namespace detail { const std::vector> rw_mode_sets = {{mode::discard_read_write}, {mode::read_write}, {mode::atomic}, {mode::discard_write, mode::read}}; for(const auto& mode_set : rw_mode_sets) { - task_manager tm{1, nullptr, task_recorder{}}; - test_utils::mock_buffer_factory mbf(tm); - auto buf = mbf.create_buffer(range<1>(128), true); + auto tt = test_utils::task_test_context{}; + auto buf = tt.mbf.create_buffer(range<1>(128), true); - const auto tid_a = test_utils::add_compute_task(tm, [&](handler& cgh) { + const auto tid_a = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { for(const auto& m : mode_set) { dispatch_get_access(buf, cgh, m, fixed<1>{{0, 128}}); } }); - const auto tid_b = test_utils::add_compute_task(tm, [&](handler& cgh) { + const auto tid_b = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf.get_access(cgh, fixed<1>{{0, 128}}); }); - REQUIRE(has_dependency(tm, tid_b, tid_a, dependency_kind::anti_dep)); + REQUIRE(has_dependency(tt.tm, tid_b, tid_a, dependency_kind::anti_dep)); } } @@ -223,65 +206,65 @@ namespace detail { for(const auto& producer_mode : detail::access::producer_modes) { CAPTURE(consumer_mode); CAPTURE(producer_mode); - task_manager tm{1, nullptr, task_recorder{}}; - test_utils::mock_buffer_factory mbf(tm); - auto buf = mbf.create_buffer(range<1>(128), false); - const task_id tid_a = test_utils::add_compute_task(tm, [&](handler& cgh) { + auto tt = test_utils::task_test_context{}; + auto buf = tt.mbf.create_buffer(range<1>(128), false); + + const task_id tid_a = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { dispatch_get_access(buf, cgh, producer_mode, fixed<1>{{0, 128}}); }); - const task_id tid_b = test_utils::add_compute_task(tm, [&](handler& cgh) { + const task_id tid_b = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { dispatch_get_access(buf, cgh, consumer_mode, fixed<1>{{0, 128}}); }); - REQUIRE(has_dependency(tm, tid_b, tid_a)); + REQUIRE(has_dependency(tt.tm, tid_b, tid_a)); - const task_id tid_c = test_utils::add_compute_task(tm, [&](handler& cgh) { + const task_id tid_c = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { dispatch_get_access(buf, cgh, producer_mode, fixed<1>{{0, 128}}); }); const bool pure_consumer = consumer_mode == mode::read; const bool pure_producer = producer_mode == mode::discard_read_write || producer_mode == mode::discard_write; - REQUIRE(has_dependency(tm, tid_c, tid_b, pure_consumer || pure_producer ? dependency_kind::anti_dep : dependency_kind::true_dep)); + REQUIRE(has_dependency(tt.tm, tid_c, tid_b, pure_consumer || pure_producer ? dependency_kind::anti_dep : dependency_kind::true_dep)); } } } TEST_CASE("task_manager generates pseudo-dependencies for collective host tasks", "[task_manager][task-graph]") { - task_manager tm{1, nullptr, task_recorder{}}; + auto tt = test_utils::task_test_context{}; experimental::collective_group group; - auto tid_master = test_utils::add_host_task(tm, on_master_node, [](handler&) {}); - auto tid_collective_implicit_1 = test_utils::add_host_task(tm, experimental::collective, [](handler&) {}); - auto tid_collective_implicit_2 = test_utils::add_host_task(tm, experimental::collective, [](handler&) {}); - auto tid_collective_explicit_1 = test_utils::add_host_task(tm, experimental::collective(group), [](handler&) {}); - auto tid_collective_explicit_2 = test_utils::add_host_task(tm, experimental::collective(group), [](handler&) {}); - - CHECK_FALSE(has_any_dependency(tm, tid_master, tid_collective_implicit_1)); - CHECK_FALSE(has_any_dependency(tm, tid_master, tid_collective_implicit_2)); - CHECK_FALSE(has_any_dependency(tm, tid_master, tid_collective_explicit_1)); - CHECK_FALSE(has_any_dependency(tm, tid_master, tid_collective_explicit_2)); - - CHECK_FALSE(has_any_dependency(tm, tid_collective_implicit_1, tid_master)); - CHECK_FALSE(has_any_dependency(tm, tid_collective_implicit_1, tid_collective_implicit_2)); - CHECK_FALSE(has_any_dependency(tm, tid_collective_implicit_1, tid_collective_explicit_1)); - CHECK_FALSE(has_any_dependency(tm, tid_collective_implicit_1, tid_collective_explicit_2)); - - CHECK_FALSE(has_any_dependency(tm, tid_collective_implicit_2, tid_master)); - CHECK(has_dependency(tm, tid_collective_implicit_2, tid_collective_implicit_1, dependency_kind::true_dep)); - CHECK_FALSE(has_any_dependency(tm, tid_collective_implicit_2, tid_collective_explicit_1)); - CHECK_FALSE(has_any_dependency(tm, tid_collective_implicit_2, tid_collective_explicit_2)); - - CHECK_FALSE(has_any_dependency(tm, tid_collective_explicit_1, tid_master)); - CHECK_FALSE(has_any_dependency(tm, tid_collective_explicit_1, tid_collective_implicit_1)); - CHECK_FALSE(has_any_dependency(tm, tid_collective_explicit_1, tid_collective_implicit_2)); - CHECK_FALSE(has_any_dependency(tm, tid_collective_explicit_1, tid_collective_explicit_2)); - - CHECK_FALSE(has_any_dependency(tm, tid_collective_explicit_2, tid_master)); - CHECK_FALSE(has_any_dependency(tm, tid_collective_explicit_2, tid_collective_implicit_1)); - CHECK_FALSE(has_any_dependency(tm, tid_collective_explicit_2, tid_collective_implicit_2)); - CHECK(has_dependency(tm, tid_collective_explicit_2, tid_collective_explicit_1, dependency_kind::true_dep)); + auto tid_master = test_utils::add_host_task(tt.tm, on_master_node, [](handler&) {}); + auto tid_collective_implicit_1 = test_utils::add_host_task(tt.tm, experimental::collective, [](handler&) {}); + auto tid_collective_implicit_2 = test_utils::add_host_task(tt.tm, experimental::collective, [](handler&) {}); + auto tid_collective_explicit_1 = test_utils::add_host_task(tt.tm, experimental::collective(group), [](handler&) {}); + auto tid_collective_explicit_2 = test_utils::add_host_task(tt.tm, experimental::collective(group), [](handler&) {}); + + CHECK_FALSE(has_any_dependency(tt.tm, tid_master, tid_collective_implicit_1)); + CHECK_FALSE(has_any_dependency(tt.tm, tid_master, tid_collective_implicit_2)); + CHECK_FALSE(has_any_dependency(tt.tm, tid_master, tid_collective_explicit_1)); + CHECK_FALSE(has_any_dependency(tt.tm, tid_master, tid_collective_explicit_2)); + + CHECK_FALSE(has_any_dependency(tt.tm, tid_collective_implicit_1, tid_master)); + CHECK_FALSE(has_any_dependency(tt.tm, tid_collective_implicit_1, tid_collective_implicit_2)); + CHECK_FALSE(has_any_dependency(tt.tm, tid_collective_implicit_1, tid_collective_explicit_1)); + CHECK_FALSE(has_any_dependency(tt.tm, tid_collective_implicit_1, tid_collective_explicit_2)); + + CHECK_FALSE(has_any_dependency(tt.tm, tid_collective_implicit_2, tid_master)); + CHECK(has_dependency(tt.tm, tid_collective_implicit_2, tid_collective_implicit_1, dependency_kind::true_dep)); + CHECK_FALSE(has_any_dependency(tt.tm, tid_collective_implicit_2, tid_collective_explicit_1)); + CHECK_FALSE(has_any_dependency(tt.tm, tid_collective_implicit_2, tid_collective_explicit_2)); + + CHECK_FALSE(has_any_dependency(tt.tm, tid_collective_explicit_1, tid_master)); + CHECK_FALSE(has_any_dependency(tt.tm, tid_collective_explicit_1, tid_collective_implicit_1)); + CHECK_FALSE(has_any_dependency(tt.tm, tid_collective_explicit_1, tid_collective_implicit_2)); + CHECK_FALSE(has_any_dependency(tt.tm, tid_collective_explicit_1, tid_collective_explicit_2)); + + CHECK_FALSE(has_any_dependency(tt.tm, tid_collective_explicit_2, tid_master)); + CHECK_FALSE(has_any_dependency(tt.tm, tid_collective_explicit_2, tid_collective_implicit_1)); + CHECK_FALSE(has_any_dependency(tt.tm, tid_collective_explicit_2, tid_collective_implicit_2)); + CHECK(has_dependency(tt.tm, tid_collective_explicit_2, tid_collective_explicit_1, dependency_kind::true_dep)); } - void check_path_length_and_front(task_manager& tm, int path_length, std::unordered_set exec_front) { + void check_path_length_and_front(task_manager& tm, int path_length, const std::unordered_set& exec_front) { { INFO("path length"); CHECK(task_manager_testspy::get_max_pseudo_critical_path_length(tm) == path_length); @@ -296,49 +279,49 @@ namespace detail { } TEST_CASE("task_manager keeps track of max pseudo critical path length and task front", "[task_manager][task-graph][task-front]") { - task_manager tm{1, nullptr, task_recorder{}}; - test_utils::mock_buffer_factory mbf(tm); - auto buf_a = mbf.create_buffer(range<1>(128)); + auto tt = test_utils::task_test_context{}; + auto buf_a = tt.mbf.create_buffer(range<1>(128)); - const auto tid_a = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + const auto tid_a = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 128})); }); - check_path_length_and_front(tm, 1, {tid_a}); // 1: we always depend on the initial epoch task + check_path_length_and_front(tt.tm, 1, {tid_a}); // 1: we always depend on the initial epoch task - const auto tid_b = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + const auto tid_b = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 128})); }); - check_path_length_and_front(tm, 2, {tid_b}); - - const auto tid_c = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 128})); }); - check_path_length_and_front(tm, 3, {tid_c}); + check_path_length_and_front(tt.tm, 2, {tid_b}); - const auto tid_d = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) {}); - check_path_length_and_front(tm, 3, {tid_c, tid_d}); + const auto tid_c = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { + buf_a.get_access(cgh, fixed<1>({0, 128})); + }); + check_path_length_and_front(tt.tm, 3, {tid_c}); - test_utils::maybe_print_graph(tm); + const auto tid_d = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) {}); + check_path_length_and_front(tt.tm, 3, {tid_c, tid_d}); } TEST_CASE("task horizons are being generated with correct dependencies", "[task_manager][task-graph][task-horizon]") { - task_manager tm{1, nullptr, task_recorder{}}; - tm.set_horizon_step(2); + auto tt = test_utils::task_test_context{}; - test_utils::mock_buffer_factory mbf(tm); - auto buf_a = mbf.create_buffer(range<1>(128)); + tt.tm.set_horizon_step(2); + auto buf_a = tt.mbf.create_buffer(range<1>(128)); - test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 128})); }); + test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 128})); }); - auto current_horizon = task_manager_testspy::get_current_horizon(tm); + auto current_horizon = task_manager_testspy::get_current_horizon(tt.tm); CHECK_FALSE(current_horizon.has_value()); - const auto tid_c = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 128})); }); + const auto tid_c = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { + buf_a.get_access(cgh, fixed<1>({0, 128})); + }); - current_horizon = task_manager_testspy::get_current_horizon(tm); + current_horizon = task_manager_testspy::get_current_horizon(tt.tm); REQUIRE(current_horizon.has_value()); CHECK(*current_horizon == tid_c + 1); - CHECK(task_manager_testspy::get_num_horizons(tm) == 1); + CHECK(task_manager_testspy::get_num_horizons(tt.tm) == 1); - auto horizon_dependencies = tm.get_task(*current_horizon)->get_dependencies(); + auto horizon_dependencies = tt.tm.get_task(*current_horizon)->get_dependencies(); CHECK(std::distance(horizon_dependencies.begin(), horizon_dependencies.end()) == 1); CHECK(horizon_dependencies.begin()->node->get_id() == tid_c); @@ -347,23 +330,23 @@ namespace detail { // current horizon is always part of the active task front expected_dependency_ids.insert(*current_horizon); - expected_dependency_ids.insert(test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) {})); - expected_dependency_ids.insert(test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) {})); - expected_dependency_ids.insert(test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) {})); - CHECK(task_manager_testspy::get_num_horizons(tm) == 1); + expected_dependency_ids.insert(test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) {})); + expected_dependency_ids.insert(test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) {})); + expected_dependency_ids.insert(test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) {})); + CHECK(task_manager_testspy::get_num_horizons(tt.tm) == 1); - test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 128})); }); - const auto tid_d = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 128})); }); + const auto tid_d = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 128})); }); expected_dependency_ids.insert(tid_d); - current_horizon = task_manager_testspy::get_current_horizon(tm); + current_horizon = task_manager_testspy::get_current_horizon(tt.tm); REQUIRE(current_horizon.has_value()); CHECK(*current_horizon == tid_d + 1); - CHECK(task_manager_testspy::get_num_horizons(tm) == 2); + CHECK(task_manager_testspy::get_num_horizons(tt.tm) == 2); - horizon_dependencies = tm.get_task(*current_horizon)->get_dependencies(); + horizon_dependencies = tt.tm.get_task(*current_horizon)->get_dependencies(); CHECK(std::distance(horizon_dependencies.begin(), horizon_dependencies.end()) == 5); std::set actual_dependecy_ids; @@ -376,143 +359,136 @@ namespace detail { static inline GridRegion<3> make_region(int min, int max) { return GridRegion<3>(GridPoint<3>(min, 0, 0), GridPoint<3>(max, 1, 1)); } TEST_CASE("task horizons update previous writer data structure", "[task_manager][task-graph][task-horizon]") { - task_manager tm{1, nullptr, task_recorder{}}; - tm.set_horizon_step(2); + auto tt = test_utils::task_test_context{}; - test_utils::mock_buffer_factory mbf(tm); - auto buf_a = mbf.create_buffer(range<1>(128)); - auto buf_b = mbf.create_buffer(range<1>(128)); + tt.tm.set_horizon_step(2); + auto buf_a = tt.mbf.create_buffer(range<1>(128)); + auto buf_b = tt.mbf.create_buffer(range<1>(128)); - task_id tid_1 = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + task_id tid_1 = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({0, 64})); buf_b.get_access(cgh, fixed<1>({0, 128})); }); - task_id tid_2 = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + task_id tid_2 = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({64, 64})); }); - task_id tid_3 = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + task_id tid_3 = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({32, 64})); }); - task_id tid_4 = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + task_id tid_4 = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({32, 64})); }); - auto horizon = task_manager_testspy::get_current_horizon(tm); - CHECK(task_manager_testspy::get_num_horizons(tm) == 1); + auto horizon = task_manager_testspy::get_current_horizon(tt.tm); + CHECK(task_manager_testspy::get_num_horizons(tt.tm) == 1); CHECK(horizon.has_value()); - task_id tid_6 = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + task_id tid_6 = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_b.get_access(cgh, fixed<1>({0, 128})); }); - task_id tid_7 = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + task_id tid_7 = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_b.get_access(cgh, fixed<1>({0, 128})); }); { INFO("check that previous tasks are still last writers before the first horizon is applied"); - const auto& region_map_a = task_manager_testspy::get_last_writer(tm, buf_a.get_id()); + const auto& region_map_a = task_manager_testspy::get_last_writer(tt.tm, buf_a.get_id()); CHECK(region_map_a.get_region_values(make_region(0, 32)).front().second.value() == tid_1); CHECK(region_map_a.get_region_values(make_region(96, 128)).front().second.value() == tid_2); CHECK(region_map_a.get_region_values(make_region(32, 96)).front().second.value() == tid_4); } - task_id tid_8 = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + task_id tid_8 = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_b.get_access(cgh, fixed<1>({0, 128})); }); - CHECK(task_manager_testspy::get_num_horizons(tm) == 2); + CHECK(task_manager_testspy::get_num_horizons(tt.tm) == 2); { INFO("check that only the previous horizon is the last writer of buff_a"); - const auto& region_map_a = task_manager_testspy::get_last_writer(tm, buf_a.get_id()); + const auto& region_map_a = task_manager_testspy::get_last_writer(tt.tm, buf_a.get_id()); CHECK(region_map_a.get_region_values(make_region(0, 128)).front().second.value() == *horizon); } - task_id tid_9 = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + task_id tid_9 = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf_a.get_access(cgh, fixed<1>({64, 64})); }); { INFO("check that the previous horizon and task 11 are last writers of buff_a"); - const auto& region_map_a = task_manager_testspy::get_last_writer(tm, buf_a.get_id()); + const auto& region_map_a = task_manager_testspy::get_last_writer(tt.tm, buf_a.get_id()); CHECK(region_map_a.get_region_values(make_region(0, 64)).front().second.value() == *horizon); CHECK(region_map_a.get_region_values(make_region(64, 128)).front().second.value() == tid_9); } - - test_utils::maybe_print_graph(tm); } TEST_CASE("previous task horizon is used as last writer for host-initialized buffers", "[task_manager][task-graph][task-horizon]") { - task_manager tm{1, nullptr, task_recorder{}}; - tm.set_horizon_step(2); - - test_utils::mock_buffer_factory mbf(tm); + auto tt = test_utils::task_test_context{}; + tt.tm.set_horizon_step(2); task_id initial_last_writer_id = -1; { - auto buf = mbf.create_buffer(range<1>(1), true); - const auto tid = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, all{}); }); - const auto& deps = tm.get_task(tid)->get_dependencies(); + auto buf = tt.mbf.create_buffer(range<1>(1), true); + const auto tid = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, all{}); }); + const auto& deps = tt.tm.get_task(tid)->get_dependencies(); CHECK(std::distance(deps.begin(), deps.end()) == 1); initial_last_writer_id = deps.begin()->node->get_id(); } - CHECK(tm.has_task(initial_last_writer_id)); + CHECK(tt.tm.has_task(initial_last_writer_id)); // Create a bunch of tasks to trigger horizon cleanup { - auto buf = mbf.create_buffer(range<1>(1)); + auto buf = tt.mbf.create_buffer(range<1>(1)); task_id last_executed_horizon = 0; // We need 7 tasks to generate a pseudo-critical path length of 6 (3x2 horizon step size), // and another one that triggers the actual deferred deletion. for(int i = 0; i < 8; ++i) { - const auto tid = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, all{}); }); - const auto current_horizon = task_manager_testspy::get_current_horizon(tm); + const auto tid = + test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, all{}); }); + const auto current_horizon = task_manager_testspy::get_current_horizon(tt.tm); if(current_horizon && *current_horizon > last_executed_horizon) { last_executed_horizon = *current_horizon; - tm.notify_horizon_reached(last_executed_horizon); + tt.tm.notify_horizon_reached(last_executed_horizon); } } } INFO("initial last writer with id " << initial_last_writer_id << " has been deleted"); - CHECK_FALSE(tm.has_task(initial_last_writer_id)); + CHECK_FALSE(tt.tm.has_task(initial_last_writer_id)); - auto buf = mbf.create_buffer(range<1>(1), true); - const auto tid = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, all{}); }); - const auto& deps = tm.get_task(tid)->get_dependencies(); + auto buf = tt.mbf.create_buffer(range<1>(1), true); + const auto tid = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, all{}); }); + const auto& deps = tt.tm.get_task(tid)->get_dependencies(); CHECK(std::distance(deps.begin(), deps.end()) == 1); const auto* new_last_writer = deps.begin()->node; CHECK(new_last_writer->get_type() == task_type::horizon); - const auto current_horizon = task_manager_testspy::get_current_horizon(tm); + const auto current_horizon = task_manager_testspy::get_current_horizon(tt.tm); REQUIRE(current_horizon); INFO("previous horizon is being used"); CHECK(new_last_writer->get_id() < *current_horizon); - - test_utils::maybe_print_graph(tm); } TEST_CASE("collective host tasks do not order-depend on their predecessor if it is shadowed by a horizon", "[task_manager][task-graph][task-horizon]") { // Regression test: the order-dependencies between host tasks in the same collective group are built by tracking the last task in each collective group. // Once a horizon is inserted, new collective host tasks must order-depend on that horizon instead. - task_manager tm{1, nullptr, task_recorder{}}; - tm.set_horizon_step(2); + auto tt = test_utils::task_test_context{}; + tt.tm.set_horizon_step(2); + auto buf = tt.mbf.create_buffer(range<1>(1)); - const auto first_collective = test_utils::add_host_task(tm, experimental::collective, [&](handler& cgh) {}); + const auto first_collective = test_utils::add_host_task(tt.tm, experimental::collective, [&](handler& cgh) {}); // generate exactly two horizons - test_utils::mock_buffer_factory mbf(tm); - auto buf = mbf.create_buffer(range<1>(1)); for(int i = 0; i < 4; ++i) { - test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, all{}); }); + test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, all{}); }); } // This must depend on the first horizon, not first_collective const auto second_collective = - test_utils::add_host_task(tm, experimental::collective, [&](handler& cgh) { buf.get_access(cgh, all{}); }); + test_utils::add_host_task(tt.tm, experimental::collective, [&](handler& cgh) { buf.get_access(cgh, all{}); }); - const auto second_collective_deps = tm.get_task(second_collective)->get_dependencies(); + const auto second_collective_deps = tt.tm.get_task(second_collective)->get_dependencies(); const auto master_node_dep = std::find_if(second_collective_deps.begin(), second_collective_deps.end(), [](const task::dependency d) { return d.node->get_type() == task_type::master_node; }); const auto horizon_dep = std::find_if(second_collective_deps.begin(), second_collective_deps.end(), // @@ -523,14 +499,11 @@ namespace detail { CHECK(master_node_dep->kind == dependency_kind::true_dep); REQUIRE(horizon_dep != second_collective_deps.end()); CHECK(horizon_dep->kind == dependency_kind::true_dep); - - test_utils::maybe_print_graph(tm); } TEST_CASE("buffer accesses with empty ranges do not generate data-flow dependencies", "[task_manager][task-graph]") { - task_manager tm{1, nullptr, task_recorder{}}; - test_utils::mock_buffer_factory mbf(tm); - auto buf = mbf.create_buffer(range<2>(32, 32)); + auto tt = test_utils::task_test_context{}; + auto buf = tt.mbf.create_buffer(range<2>(32, 32)); const auto write_sr = GENERATE(values({subrange<2>{{16, 16}, {0, 0}}, subrange<2>{{16, 16}, {8, 8}}})); const auto read_sr = GENERATE(values({subrange<2>{{1, 1}, {0, 0}}, subrange<2>{{8, 8}, {16, 16}}})); @@ -541,11 +514,11 @@ namespace detail { CAPTURE(write_empty); const auto write_tid = - test_utils::add_compute_task(tm, [&](handler& cgh) { buf.get_access(cgh, fixed<2>{write_sr}); }); + test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf.get_access(cgh, fixed<2>{write_sr}); }); const auto read_tid = - test_utils::add_compute_task(tm, [&](handler& cgh) { buf.get_access(cgh, fixed<2>{read_sr}); }); + test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf.get_access(cgh, fixed<2>{read_sr}); }); - CHECK(has_any_dependency(tm, read_tid, write_tid) == (!write_empty && !read_empty)); + CHECK(has_any_dependency(tt.tm, read_tid, write_tid) == (!write_empty && !read_empty)); } TEST_CASE("side effects generate appropriate task-dependencies", "[task_manager][task-graph][side-effect]") { @@ -562,122 +535,114 @@ namespace detail { CAPTURE(order_a); CAPTURE(order_b); - task_manager tm{1, nullptr, task_recorder{}}; - test_utils::mock_host_object_factory mhof; - - auto ho_common = mhof.create_host_object(); // should generate dependencies - auto ho_a = mhof.create_host_object(); // should NOT generate dependencies - auto ho_b = mhof.create_host_object(); // -"- - const auto tid_a = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + auto tt = test_utils::task_test_context{}; + auto ho_common = tt.mhof.create_host_object(); // should generate dependencies + auto ho_a = tt.mhof.create_host_object(); // should NOT generate dependencies + auto ho_b = tt.mhof.create_host_object(); // -"- + const auto tid_a = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { ho_common.add_side_effect(cgh, order_a); ho_a.add_side_effect(cgh, order_a); }); - const auto tid_b = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { + const auto tid_b = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { ho_common.add_side_effect(cgh, order_b); ho_b.add_side_effect(cgh, order_b); }); - const auto deps_a = tm.get_task(tid_a)->get_dependencies(); + const auto deps_a = tt.tm.get_task(tid_a)->get_dependencies(); REQUIRE(std::distance(deps_a.begin(), deps_a.end()) == 1); CHECK(deps_a.front().node->get_id() == task_manager::initial_epoch_task); - const auto deps_b = tm.get_task(tid_b)->get_dependencies(); + const auto deps_b = tt.tm.get_task(tid_b)->get_dependencies(); const auto expected_b = expected_dependencies.at({order_a, order_b}); CHECK(std::distance(deps_b.begin(), deps_b.end()) == expected_b.has_value()); if(expected_b) { - CHECK(deps_b.front().node == tm.get_task(tid_a)); + CHECK(deps_b.front().node == tt.tm.get_task(tid_a)); CHECK(deps_b.front().kind == *expected_b); } } TEST_CASE("side-effect dependencies are correctly subsumed by horizons", "[task_manager][task-graph][task-horizon]") { - task_manager tm{1, nullptr, task_recorder{}}; - tm.set_horizon_step(2); + auto tt = test_utils::task_test_context{}; + tt.tm.set_horizon_step(2); + auto ho = tt.mhof.create_host_object(); - test_utils::mock_host_object_factory mhof; - auto ho = mhof.create_host_object(); const auto first_task = - test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { ho.add_side_effect(cgh, experimental::side_effect_order::sequential); }); + test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { ho.add_side_effect(cgh, experimental::side_effect_order::sequential); }); // generate exactly two horizons - test_utils::mock_buffer_factory mbf(tm); - auto buf = mbf.create_buffer(range<1>(1)); + auto buf = tt.mbf.create_buffer(range<1>(1)); for(int i = 0; i < 5; ++i) { - test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, all{}); }); + test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, all{}); }); } // This must depend on the first horizon, not first_task const auto second_task = - test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { ho.add_side_effect(cgh, experimental::side_effect_order::sequential); }); + test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { ho.add_side_effect(cgh, experimental::side_effect_order::sequential); }); - const auto& second_deps = tm.get_task(second_task)->get_dependencies(); + const auto& second_deps = tt.tm.get_task(second_task)->get_dependencies(); CHECK(std::distance(second_deps.begin(), second_deps.end()) == 1); for(const auto& dep : second_deps) { const auto type = dep.node->get_type(); CHECK(type == task_type::horizon); CHECK(dep.kind == dependency_kind::true_dep); } - - test_utils::maybe_print_graph(tm); } TEST_CASE("epochs create appropriate dependencies to predecessors and successors", "[task_manager][task-graph][epoch]") { - task_manager tm{1, nullptr, task_recorder{}}; - test_utils::mock_buffer_factory mbf(tm); - - auto buf_a = mbf.create_buffer(range<1>(1)); - const auto tid_a = test_utils::add_compute_task(tm, [&](handler& cgh) { buf_a.get_access(cgh, all{}); }); - - auto buf_b = mbf.create_buffer(range<1>(1)); - const auto tid_b = test_utils::add_compute_task(tm, [&](handler& cgh) { buf_b.get_access(cgh, all{}); }); - - const auto tid_epoch = tm.generate_epoch_task(epoch_action::none); - - const auto tid_c = test_utils::add_compute_task(tm, [&](handler& cgh) { buf_a.get_access(cgh, all{}); }); - const auto tid_d = test_utils::add_compute_task(tm, [&](handler& cgh) { buf_b.get_access(cgh, all{}); }); - const auto tid_e = test_utils::add_compute_task(tm, [&](handler& cgh) {}); - const auto tid_f = test_utils::add_compute_task(tm, [&](handler& cgh) { buf_b.get_access(cgh, all{}); }); - const auto tid_g = test_utils::add_compute_task(tm, [&](handler& cgh) { buf_b.get_access(cgh, all{}); }); - - CHECK(has_dependency(tm, tid_epoch, tid_a)); - CHECK(has_dependency(tm, tid_epoch, tid_b)); - CHECK(has_dependency(tm, tid_c, tid_epoch)); - CHECK_FALSE(has_any_dependency(tm, tid_c, tid_a)); - CHECK(has_dependency(tm, tid_d, tid_epoch)); // needs a true_dep on barrier since it only has anti_deps otherwise - CHECK_FALSE(has_any_dependency(tm, tid_d, tid_b)); - CHECK(has_dependency(tm, tid_e, tid_epoch)); - CHECK(has_dependency(tm, tid_f, tid_d)); - CHECK_FALSE(has_any_dependency(tm, tid_f, tid_epoch)); - CHECK(has_dependency(tm, tid_g, tid_f, dependency_kind::anti_dep)); - CHECK(has_dependency(tm, tid_g, tid_epoch)); // needs a true_dep on barrier since it only has anti_deps otherwise - - test_utils::maybe_print_graph(tm); + auto tt = test_utils::task_test_context{}; + + auto buf_a = tt.mbf.create_buffer(range<1>(1)); + const auto tid_a = + test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf_a.get_access(cgh, all{}); }); + + auto buf_b = tt.mbf.create_buffer(range<1>(1)); + const auto tid_b = + test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf_b.get_access(cgh, all{}); }); + + const auto tid_epoch = tt.tm.generate_epoch_task(epoch_action::none); + + const auto tid_c = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf_a.get_access(cgh, all{}); }); + const auto tid_d = + test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf_b.get_access(cgh, all{}); }); + const auto tid_e = test_utils::add_compute_task(tt.tm, [&](handler& cgh) {}); + const auto tid_f = test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf_b.get_access(cgh, all{}); }); + const auto tid_g = + test_utils::add_compute_task(tt.tm, [&](handler& cgh) { buf_b.get_access(cgh, all{}); }); + + CHECK(has_dependency(tt.tm, tid_epoch, tid_a)); + CHECK(has_dependency(tt.tm, tid_epoch, tid_b)); + CHECK(has_dependency(tt.tm, tid_c, tid_epoch)); + CHECK_FALSE(has_any_dependency(tt.tm, tid_c, tid_a)); + CHECK(has_dependency(tt.tm, tid_d, tid_epoch)); // needs a true_dep on barrier since it only has anti_deps otherwise + CHECK_FALSE(has_any_dependency(tt.tm, tid_d, tid_b)); + CHECK(has_dependency(tt.tm, tid_e, tid_epoch)); + CHECK(has_dependency(tt.tm, tid_f, tid_d)); + CHECK_FALSE(has_any_dependency(tt.tm, tid_f, tid_epoch)); + CHECK(has_dependency(tt.tm, tid_g, tid_f, dependency_kind::anti_dep)); + CHECK(has_dependency(tt.tm, tid_g, tid_epoch)); // needs a true_dep on barrier since it only has anti_deps otherwise } TEST_CASE("inserting epochs resets the need for horizons", "[task_manager][task-graph][task-horizon][epoch]") { - task_manager tm{1, nullptr, task_recorder{}}; - tm.set_horizon_step(2); + auto tt = test_utils::task_test_context{}; + tt.tm.set_horizon_step(2); + auto buf = tt.mbf.create_buffer(range<1>(1)); - test_utils::mock_buffer_factory mbf(tm); - auto buf = mbf.create_buffer(range<1>(1)); for(int i = 0; i < 3; ++i) { - test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, all{}); }); - tm.generate_epoch_task(epoch_action::none); + test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, all{}); }); + tt.tm.generate_epoch_task(epoch_action::none); } - CHECK(task_manager_testspy::get_num_horizons(tm) == 0); - - test_utils::maybe_print_graph(tm); + CHECK(task_manager_testspy::get_num_horizons(tt.tm) == 0); } TEST_CASE("a sequence of epochs without intermediate tasks has defined behavior", "[task_manager][task-graph][epoch]") { - task_manager tm{1, nullptr, task_recorder{}}; + auto tt = test_utils::task_test_context{}; auto tid_before = task_manager::initial_epoch_task; for(const auto action : {epoch_action::barrier, epoch_action::shutdown}) { - const auto tid = tm.generate_epoch_task(action); + const auto tid = tt.tm.generate_epoch_task(action); CAPTURE(tid_before, tid); - const auto deps = tm.get_task(tid)->get_dependencies(); + const auto deps = tt.tm.get_task(tid)->get_dependencies(); CHECK(std::distance(deps.begin(), deps.end()) == 1); for(const auto& d : deps) { CHECK(d.kind == dependency_kind::true_dep); @@ -685,40 +650,32 @@ namespace detail { } tid_before = tid; } - - test_utils::maybe_print_graph(tm); } TEST_CASE("fences introduce dependencies on host objects", "[task_manager][task-graph][fence]") { - task_manager tm{1, nullptr, task_recorder{}}; - test_utils::mock_host_object_factory mhof; - auto ho = mhof.create_host_object(); + auto tt = test_utils::task_test_context{}; + auto ho = tt.mhof.create_host_object(); const auto tid_a = test_utils::add_host_task( - tm, celerity::experimental::collective, [&](handler& cgh) { ho.add_side_effect(cgh, experimental::side_effect_order::sequential); }); - const auto tid_fence = test_utils::add_fence_task(tm, ho); + tt.tm, celerity::experimental::collective, [&](handler& cgh) { ho.add_side_effect(cgh, experimental::side_effect_order::sequential); }); + const auto tid_fence = test_utils::add_fence_task(tt.tm, ho); const auto tid_b = test_utils::add_host_task( - tm, celerity::experimental::collective, [&](handler& cgh) { ho.add_side_effect(cgh, experimental::side_effect_order::sequential); }); + tt.tm, celerity::experimental::collective, [&](handler& cgh) { ho.add_side_effect(cgh, experimental::side_effect_order::sequential); }); - CHECK(has_dependency(tm, tid_fence, tid_a)); - CHECK(has_dependency(tm, tid_b, tid_fence)); - - test_utils::maybe_print_graph(tm); + CHECK(has_dependency(tt.tm, tid_fence, tid_a)); + CHECK(has_dependency(tt.tm, tid_b, tid_fence)); } TEST_CASE("fences introduce data dependencies", "[task_manager][task-graph][fence]") { - task_manager tm{1, nullptr, task_recorder{}}; - test_utils::mock_buffer_factory mbf(tm); - auto buf = mbf.create_buffer<1>({1}); - - const auto tid_a = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, all{}); }); - const auto tid_fence = test_utils::add_fence_task(tm, buf); - const auto tid_b = test_utils::add_host_task(tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, all{}); }); + auto tt = test_utils::task_test_context{}; + auto buf = tt.mbf.create_buffer<1>({1}); - CHECK(has_dependency(tm, tid_fence, tid_a)); - CHECK(has_dependency(tm, tid_b, tid_fence, dependency_kind::anti_dep)); + const auto tid_a = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, all{}); }); + const auto tid_fence = test_utils::add_fence_task(tt.tm, buf); + const auto tid_b = test_utils::add_host_task(tt.tm, on_master_node, [&](handler& cgh) { buf.get_access(cgh, all{}); }); - test_utils::maybe_print_graph(tm); + CHECK(has_dependency(tt.tm, tid_fence, tid_a)); + CHECK(has_dependency(tt.tm, tid_b, tid_fence, dependency_kind::anti_dep)); } } // namespace detail diff --git a/test/test_utils.h b/test/test_utils.h index 6b3baf868..d3f91f9d7 100644 --- a/test/test_utils.h +++ b/test/test_utils.h @@ -56,7 +56,8 @@ namespace detail { static executor& get_exec(runtime& rt) { return *rt.m_exec; } static size_t get_command_count(runtime& rt) { return rt.m_cdag->command_count(); } static command_graph& get_cdag(runtime& rt) { return *rt.m_cdag; } - static std::string print_graph(runtime& rt) { return rt.m_schdlr->print_command_graph(); } + static std::string print_task_graph(runtime& rt) { return detail::print_task_graph(*rt.m_task_recorder); } + static std::string print_command_graph(const node_id local_nid, runtime& rt) { return detail::print_command_graph(local_nid, *rt.m_command_recorder); } }; struct task_ring_buffer_testspy { @@ -337,15 +338,25 @@ namespace test_utils { // Printing of graphs can be enabled using the "--print-graphs" command line flag inline bool print_graphs = false; - inline void maybe_print_graph(celerity::detail::task_manager& tm) { - if(print_graphs) { CELERITY_INFO("Task graph:\n\n{}\n", tm.print_task_graph()); } + inline void maybe_print_task_graph(const detail::task_recorder& trec) { + if(print_graphs) { CELERITY_INFO("Task graph:\n\n{}\n", detail::print_task_graph(trec)); } } - template - inline void maybe_print_command_graph(const CommandPrinter& cmdp) { - if(print_graphs) { CELERITY_INFO("Command graph:\n\n{}\n", cmdp.print_command_graph()); } + inline void maybe_print_command_graph(const detail::node_id local_nid, const detail::command_recorder& crec) { + if(print_graphs) { CELERITY_INFO("Command graph:\n\n{}\n", detail::print_command_graph(local_nid, crec)); } } + struct task_test_context { + detail::task_recorder trec; + detail::task_manager tm; + mock_buffer_factory mbf; + mock_host_object_factory mhof; + mock_reduction_factory mrf; + + task_test_context() : tm(1, nullptr, &trec), mbf(tm) {} + ~task_test_context() { maybe_print_task_graph(trec); } + }; + class set_test_env { public: #ifdef _WIN32