Skip to content

Commit

Permalink
Address review comments, add task test context
Browse files Browse the repository at this point in the history
 * Recorders no longer owned by task manager / graph gen
 * As a consequence, task graph tests would have been even more
   cumbersome to write -> add context utility
   - this also obviates the need for explicit maybe_print!
 * Also fix member var naming scheme and namespaces according to review
   suggestions
  • Loading branch information
PeterTh committed Aug 9, 2023
1 parent c60e073 commit f846555
Show file tree
Hide file tree
Showing 20 changed files with 764 additions and 824 deletions.
7 changes: 3 additions & 4 deletions include/distributed_graph_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,16 +75,14 @@ class distributed_graph_generator {

public:
distributed_graph_generator(
const size_t num_nodes, const node_id local_nid, command_graph& cdag, const task_manager& tm, const std::optional<detail::command_recorder>& recorder);
const size_t num_nodes, const node_id local_nid, command_graph& cdag, const task_manager& tm, detail::command_recorder* recorder);

void add_buffer(const buffer_id bid, const int dims, const range<3>& range);

std::unordered_set<abstract_command*> build_task(const task& tsk);

command_graph& get_command_graph() { return m_cdag; }

std::string print_command_graph() const;

private:
// Wrapper around command_graph::create that adds commands to current batch set.
template <typename T, typename... Args>
Expand Down Expand Up @@ -145,7 +143,8 @@ class distributed_graph_generator {
// Side effects on the same host object create true dependencies between task commands, so we track the last effect per host object.
side_effect_map m_host_object_last_effects;

std::optional<detail::command_recorder> m_recorder;
// Non-owning; generated commands are recorded to this recorder if it is set (i.e. not nullptr)
detail::command_recorder* m_recorder = nullptr;
};

} // namespace celerity::detail
Expand Down
66 changes: 31 additions & 35 deletions include/print_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,15 @@ namespace detail {
struct task_printing_information {
task_printing_information(const task& from, const buffer_manager* buff_man);

const task_id m_tid;
const std::string m_debug_name;
const collective_group_id m_cgid;
const task_type m_type;
const task_geometry m_geometry;
const reduction_list m_reductions;
const access_list m_accesses;
const side_effect_map m_side_effect_map;
const task_dependency_list m_dependencies;
const task_id tid;
const std::string debug_name;
const collective_group_id cgid;
const task_type type;
const task_geometry geometry;
const reduction_list reductions;
const access_list accesses;
const side_effect_map side_effect_map;
const task_dependency_list dependencies;
};

class task_recorder {
Expand All @@ -73,35 +73,33 @@ namespace detail {
const buffer_manager* m_buff_man;
};

const std::optional<task_recorder> no_task_recorder = {};

// Command recording

using command_dependency_list = std::vector<dependency_record<command_id>>;

struct command_printing_information {
const command_id m_cid;
const command_type m_type;

const std::optional<epoch_action> m_epoch_action;
const std::optional<subrange<3>> m_execution_range;
const std::optional<reduction_id> m_reduction_id;
const std::optional<buffer_id> m_buffer_id;
const std::string m_buffer_name;
const std::optional<node_id> m_target;
const std::optional<GridRegion<3>> m_await_region;
const std::optional<subrange<3>> m_push_range;
const std::optional<transfer_id> m_transfer_id;
const std::optional<task_id> m_task_id;
const std::optional<task_geometry> m_task_geometry;
const bool m_is_reduction_initializer;
const std::optional<access_list> m_accesses;
const std::optional<reduction_list> m_reductions;
const std::optional<side_effect_map> m_side_effects;
const command_dependency_list m_dependencies;
const std::string m_task_name;
const std::optional<task_type> m_task_type;
const std::optional<collective_group_id> m_collective_group_id;
const command_id cid;
const command_type type;

const std::optional<epoch_action> epoch_action;
const std::optional<subrange<3>> execution_range;
const std::optional<reduction_id> reduction_id;
const std::optional<buffer_id> buffer_id;
const std::string buffer_name;
const std::optional<node_id> target;
const std::optional<GridRegion<3>> await_region;
const std::optional<subrange<3>> push_range;
const std::optional<transfer_id> transfer_id;
const std::optional<task_id> task_id;
const std::optional<task_geometry> task_geometry;
const bool is_reduction_initializer;
const std::optional<access_list> accesses;
const std::optional<reduction_list> reductions;
const std::optional<side_effect_map> side_effects;
const command_dependency_list dependencies;
const std::string task_name;
const std::optional<task_type> task_type;
const std::optional<collective_group_id> collective_group_id;

command_printing_information(const abstract_command& cmd, const task_manager* task_man, const buffer_manager* buff_man);
};
Expand All @@ -122,8 +120,6 @@ namespace detail {
const buffer_manager* m_buff_man;
};

const std::optional<command_recorder> no_command_recorder = {};

// Printing interface

std::string print_task_graph(const task_recorder& recorder);
Expand Down
4 changes: 4 additions & 0 deletions include/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
#include "device_queue.h"
#include "frame.h"
#include "host_queue.h"
#include "print_graph.h"
#include "types.h"

namespace celerity {
Expand Down Expand Up @@ -110,6 +111,9 @@ namespace detail {
std::unique_ptr<task_manager> m_task_mngr;
std::unique_ptr<executor> m_exec;

std::unique_ptr<detail::task_recorder> m_task_recorder;
std::unique_ptr<detail::command_recorder> m_command_recorder;

runtime(int* argc, char** argv[], device_or_selector user_device_or_selector);
runtime(const runtime&) = delete;
runtime(runtime&&) = delete;
Expand Down
2 changes: 0 additions & 2 deletions include/scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ namespace detail {

void notify_buffer_registered(const buffer_id bid, const int dims, const range<3>& range) { notify(event_buffer_registered{bid, dims, range}); }

std::string print_command_graph() const;

protected:
/**
* This is called by the worker thread.
Expand Down
6 changes: 2 additions & 4 deletions include/task_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ namespace detail {
public:
constexpr inline static task_id initial_epoch_task = 0;

task_manager(size_t num_collective_nodes, host_queue* queue, std::optional<detail::task_recorder> recorder);
task_manager(size_t num_collective_nodes, host_queue* queue, detail::task_recorder* recorder);

virtual ~task_manager() = default;

Expand Down Expand Up @@ -130,8 +130,6 @@ namespace detail {
*/
const task* get_task(task_id tid) const;

std::string print_task_graph() const;

/**
* Blocks until an epoch task has executed on this node (or all nodes, if the epoch_for_new_tasks was created with `epoch_action::barrier`).
*/
Expand Down Expand Up @@ -222,7 +220,7 @@ namespace detail {
std::unordered_set<task*> m_execution_front;

// Optional, non-owning pointer to a task_recorder which records information about tasks, e.g. for printing graphs.
mutable std::optional<detail::task_recorder> m_task_recorder;
mutable detail::task_recorder* m_task_recorder;

task& register_task_internal(task_ring_buffer::reservation&& reserve, std::unique_ptr<task> task);

Expand Down
12 changes: 3 additions & 9 deletions src/distributed_graph_generator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
namespace celerity::detail {

distributed_graph_generator::distributed_graph_generator(
const size_t num_nodes, const node_id local_nid, command_graph& cdag, const task_manager& tm, const std::optional<detail::command_recorder>& recorder)
const size_t num_nodes, const node_id local_nid, command_graph& cdag, const task_manager& tm, detail::command_recorder* recorder)
: m_num_nodes(num_nodes), m_local_nid(local_nid), m_cdag(cdag), m_task_mngr(tm), m_recorder(recorder) {
if(m_num_nodes > max_num_nodes) {
throw std::runtime_error(fmt::format("Number of nodes requested ({}) exceeds compile-time maximum of {}", m_num_nodes, max_num_nodes));
Expand All @@ -20,7 +20,7 @@ distributed_graph_generator::distributed_graph_generator(
// set_epoch_for_new_commands).
auto* const epoch_cmd = cdag.create<epoch_command>(task_manager::initial_epoch_task, epoch_action::none);
epoch_cmd->mark_as_flushed(); // there is no point in flushing the initial epoch command
if(m_recorder.has_value()) m_recorder->record_command(*epoch_cmd);
if(m_recorder != nullptr) m_recorder->record_command(*epoch_cmd);
m_epoch_for_new_commands = epoch_cmd->get_cid();
}

Expand Down Expand Up @@ -135,7 +135,7 @@ std::unordered_set<abstract_command*> distributed_graph_generator::build_task(co
prune_commands_before(epoch_to_prune_before);

// If we have a command_recorder, record the current batch of commands
if(m_recorder) {
if(m_recorder != nullptr) {
for(const auto& cmd : m_current_cmd_batch) {
m_recorder->record_command(*cmd);
}
Expand All @@ -144,12 +144,6 @@ std::unordered_set<abstract_command*> distributed_graph_generator::build_task(co
return std::move(m_current_cmd_batch);
}

std::string distributed_graph_generator::print_command_graph() const {
if(m_recorder.has_value()) { return detail::print_command_graph(m_local_nid, m_recorder.value()); }
CELERITY_ERROR("Trying to print command graph, but no recorder available");
return "";
}

void distributed_graph_generator::generate_distributed_commands(const task& tsk) {
chunk<3> full_chunk{tsk.get_global_offset(), tsk.get_global_size(), tsk.get_global_size()};
const size_t num_chunks = m_num_nodes * 1; // TODO Make configurable
Expand Down
Loading

0 comments on commit f846555

Please sign in to comment.