Skip to content

Commit

Permalink
Explicitly manage buffer / host object lifetimes in graph generation
Browse files Browse the repository at this point in the history
task_manager and distributed_graph_generator already maintain state for each
buffer and host object, but keep it around indefinitely even after the
buffer or host object in question is destroyed. Also, neither have
access to buffer debug names and thus can't include that information in
error reports (such as uninitialized-read detection).

This commit adds explicit methods for tracking the creation and
destruction of objects to task_manager, distributed_graph_generator,
scheduler (and now by necessity, runtime, which receives these requests
directly instead of via the buffer_lifetime_callback).

This also removes the recorder -> buffer_manager dependency by
replicating the buffer name (like all other metadata) in both graph
generators. This foreshadows the eventual removal of buffer_manager with
the merge of instruction graph scheduling.
  • Loading branch information
fknorr committed Feb 9, 2024
1 parent a785316 commit 33d4e65
Show file tree
Hide file tree
Showing 25 changed files with 387 additions and 257 deletions.
5 changes: 2 additions & 3 deletions include/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

#include <CL/sycl.hpp>

#include "buffer_manager.h"
#include "lifetime_extending_state.h"
#include "range_mapper.h"
#include "ranges.h"
Expand Down Expand Up @@ -98,13 +97,13 @@ class buffer final : public detail::lifetime_extending_state_wrapper {
struct impl final : public detail::lifetime_extending_state {
impl(celerity::range<Dims> rng, const DataT* host_init_ptr) : range(rng) {
if(!detail::runtime::is_initialized()) { detail::runtime::init(nullptr, nullptr); }
id = detail::runtime::get_instance().get_buffer_manager().register_buffer<DataT, Dims>(detail::range_cast<3>(range), host_init_ptr);
id = detail::runtime::get_instance().create_buffer<DataT, Dims>(detail::range_cast<3>(range), host_init_ptr);
}
impl(const impl&) = delete;
impl(impl&&) = delete;
impl& operator=(const impl&) = delete;
impl& operator=(impl&&) = delete;
~impl() override { detail::runtime::get_instance().get_buffer_manager().unregister_buffer(id); }
~impl() override { detail::runtime::get_instance().destroy_buffer(id); }
detail::buffer_id id;
celerity::range<Dims> range;
std::string debug_name;
Expand Down
9 changes: 1 addition & 8 deletions include/buffer_manager.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#pragma once

#include <cstring>
#include <functional>
#include <memory>
#include <mutex>
#include <shared_mutex>
Expand Down Expand Up @@ -84,10 +83,6 @@ namespace detail {
friend struct buffer_manager_testspy;

public:
enum class buffer_lifecycle_event { registered, unregistered };

using buffer_lifecycle_callback = std::function<void(buffer_lifecycle_event, buffer_id)>;

using device_buffer_factory = std::function<std::unique_ptr<buffer_storage>(const range<3>&, device_queue&)>;
using host_buffer_factory = std::function<std::unique_ptr<buffer_storage>(const range<3>&)>;

Expand Down Expand Up @@ -126,7 +121,7 @@ namespace detail {
using buffer_lock_id = size_t;

public:
buffer_manager(device_queue& queue, buffer_lifecycle_callback lifecycle_cb);
explicit buffer_manager(device_queue& queue);

template <typename DataT, int Dims>
buffer_id register_buffer(range<3> range, const DataT* host_init_ptr = nullptr) {
Expand Down Expand Up @@ -157,7 +152,6 @@ namespace detail {
auto info = access_host_buffer(bid, access_mode::discard_write, {{}, range});
std::memcpy(info.ptr, host_init_ptr, range.size() * sizeof(DataT));
}
m_lifecycle_cb(buffer_lifecycle_event::registered, bid);
return bid;
}

Expand Down Expand Up @@ -347,7 +341,6 @@ namespace detail {
// Leave some memory for other processes.
double m_max_device_global_mem_usage = 0.95;
device_queue& m_queue;
buffer_lifecycle_callback m_lifecycle_cb;
size_t m_buffer_count = 0;
mutable std::shared_mutex m_mutex;
std::unordered_map<buffer_id, buffer_info> m_buffer_infos;
Expand Down
2 changes: 1 addition & 1 deletion include/debug.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace debug {
template <typename DataT, int Dims>
void set_buffer_name(const celerity::buffer<DataT, Dims>& buff, const std::string& debug_name) {
detail::set_buffer_name(buff, debug_name);
detail::runtime::get_instance().get_buffer_manager().set_debug_name(detail::get_buffer_id(buff), debug_name);
detail::runtime::get_instance().set_buffer_debug_name(detail::get_buffer_id(buff), debug_name);
}
template <typename DataT, int Dims>
std::string get_buffer_name(const celerity::buffer<DataT, Dims>& buff) {
Expand Down
26 changes: 21 additions & 5 deletions include/distributed_graph_generator.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,13 @@ class distributed_graph_generator {
// but mark it as having a pending reduction. The final reduction will then be generated when the buffer
// is used in a subsequent read requirement. This avoids generating unnecessary reduction commands.
std::optional<reduction_info> pending_reduction;

std::string debug_name;
};

struct host_object_state {
// Side effects on the same host object create true dependencies between task commands, so we track the last effect per host object.
std::optional<command_id> last_side_effect;
};

public:
Expand All @@ -83,7 +90,15 @@ class distributed_graph_generator {
distributed_graph_generator(const size_t num_nodes, const node_id local_nid, command_graph& cdag, const task_manager& tm,
detail::command_recorder* recorder, const policy_set& policy = default_policy_set());

void add_buffer(const buffer_id bid, const range<3>& range, bool host_initialized);
void create_buffer(buffer_id bid, const range<3>& range, bool host_initialized);

void set_buffer_debug_name(buffer_id bid, const std::string& debug_name);

void destroy_buffer(buffer_id bid);

void create_host_object(host_object_id hoid);

void destroy_host_object(host_object_id hoid);

std::unordered_set<abstract_command*> build_task(const task& tsk);

Expand Down Expand Up @@ -131,12 +146,15 @@ class distributed_graph_generator {
// initializers, within its surrounding class (Clang)
constexpr static policy_set default_policy_set() { return {}; }

std::string print_buffer_debug_label(buffer_id bid) const;

size_t m_num_nodes;
node_id m_local_nid;
policy_set m_policy;
command_graph& m_cdag;
const task_manager& m_task_mngr;
std::unordered_map<buffer_id, buffer_state> m_buffer_states;
std::unordered_map<buffer_id, buffer_state> m_buffers;
std::unordered_map<host_object_id, host_object_state> m_host_objects;
command_id m_epoch_for_new_commands = 0;
command_id m_epoch_last_pruned_before = 0;
command_id m_current_horizon = no_command;
Expand All @@ -153,9 +171,6 @@ class distributed_graph_generator {
// they are executed in the same order on every node.
std::unordered_map<collective_group_id, command_id> m_last_collective_commands;

// Side effects on the same host object create true dependencies between task commands, so we track the last effect per host object.
side_effect_map m_host_object_last_effects;

// Generated commands will be recorded to this recorder if it is set
detail::command_recorder* m_recorder = nullptr;
};
Expand All @@ -167,4 +182,5 @@ template <>
struct hash<celerity::detail::write_command_state> {
size_t operator()(const celerity::detail::write_command_state& wcs) const { return std::hash<size_t>{}(static_cast<celerity::detail::command_id>(wcs)); }
};

} // namespace std
4 changes: 2 additions & 2 deletions include/host_object.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,15 @@ struct host_object_tracker : public lifetime_extending_state {

host_object_tracker() {
if(!detail::runtime::is_initialized()) { detail::runtime::init(nullptr, nullptr); }
id = detail::runtime::get_instance().get_host_object_manager().create_host_object();
id = detail::runtime::get_instance().create_host_object();
}

host_object_tracker(const host_object_tracker&) = delete;
host_object_tracker(host_object_tracker&&) = delete;
host_object_tracker& operator=(host_object_tracker&&) = delete;
host_object_tracker& operator=(const host_object_tracker&) = delete;

~host_object_tracker() { detail::runtime::get_instance().get_host_object_manager().destroy_host_object(id); }
~host_object_tracker() { detail::runtime::get_instance().destroy_host_object(id); }
};

// see host_object deduction guides
Expand Down
108 changes: 60 additions & 48 deletions include/recorders.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ struct access_record {
const region<3> req;
};
using access_list = std::vector<access_record>;
using buffer_name_map = std::unordered_map<buffer_id, std::string>;

struct reduction_record {
const reduction_id rid;
Expand All @@ -38,79 +39,90 @@ struct dependency_record {
using task_dependency_list = std::vector<dependency_record<task_id>>;

struct task_record {
task_record(const task& from, const buffer_manager* buff_mngr);

const task_id tid;
const std::string debug_name;
const collective_group_id cgid;
const task_type type;
const task_geometry geometry;
const reduction_list reductions;
const access_list accesses;
const detail::side_effect_map side_effect_map;
const task_dependency_list dependencies;
task_record(const task& tsk, const buffer_name_map& accessed_buffer_names);

task_id tid;
std::string debug_name;
collective_group_id cgid;
task_type type;
task_geometry geometry;
reduction_list reductions;
access_list accesses;
detail::side_effect_map side_effect_map;
task_dependency_list dependencies;
};

class task_recorder {
public:
using task_record = std::vector<detail::task_record>;
using task_records = std::vector<detail::task_record>;

task_recorder(const buffer_manager* buff_mngr = nullptr) : m_buff_mngr(buff_mngr) {}
friend task_recorder& operator<<(task_recorder& recorder, task_record&& record) {
recorder.m_recorded_tasks.push_back(std::move(record));
return recorder;
}

void record_task(const task& tsk);
const task_records& get_tasks() const { return m_recorded_tasks; }

const task_record& get_tasks() const { return m_recorded_tasks; }
const task_record& get_task(const task_id tid) const {
const auto it = std::find_if(m_recorded_tasks.begin(), m_recorded_tasks.end(), [tid](const task_record& rec) { return rec.tid == tid; });
assert(it != m_recorded_tasks.end());
return *it;
}

private:
task_record m_recorded_tasks;
const buffer_manager* m_buff_mngr;
task_records m_recorded_tasks;
};

// Command recording

using command_dependency_list = std::vector<dependency_record<command_id>>;

struct command_record {
const command_id cid;
const command_type type;

const std::optional<detail::epoch_action> epoch_action;
const std::optional<subrange<3>> execution_range;
const std::optional<detail::reduction_id> reduction_id;
const std::optional<detail::buffer_id> buffer_id;
const std::string buffer_name;
const std::optional<node_id> target;
const std::optional<region<3>> await_region;
const std::optional<subrange<3>> push_range;
const std::optional<detail::transfer_id> transfer_id;
const std::optional<detail::task_id> task_id;
const std::optional<detail::task_geometry> task_geometry;
const bool is_reduction_initializer;
const std::optional<access_list> accesses;
const std::optional<reduction_list> reductions;
const std::optional<side_effect_map> side_effects;
const command_dependency_list dependencies;
const std::string task_name;
const std::optional<detail::task_type> task_type;
const std::optional<detail::collective_group_id> collective_group_id;

command_record(const abstract_command& cmd, const task_manager* task_mngr, const buffer_manager* buff_mngr);
command_id cid;
command_type type;

std::optional<detail::epoch_action> epoch_action;
std::optional<subrange<3>> execution_range;
std::optional<detail::reduction_id> reduction_id;
std::optional<detail::buffer_id> buffer_id;
std::string buffer_name;
std::optional<node_id> target;
std::optional<region<3>> await_region;
std::optional<subrange<3>> push_range;
std::optional<detail::transfer_id> transfer_id;
std::optional<detail::task_id> task_id;
std::optional<detail::task_geometry> task_geometry;
bool is_reduction_initializer;
std::optional<access_list> accesses;
std::optional<reduction_list> reductions;
std::optional<side_effect_map> side_effects;
command_dependency_list dependencies;
std::string task_name;
std::optional<detail::task_type> task_type;
std::optional<detail::collective_group_id> collective_group_id;

command_record(const abstract_command& cmd, const task& tsk, const buffer_name_map& accessed_buffer_names);
};

class command_recorder {
public:
using command_record = std::vector<detail::command_record>;
using command_records = std::vector<detail::command_record>;

command_recorder(const task_manager* task_mngr, const buffer_manager* buff_mngr = nullptr) : m_task_mngr(task_mngr), m_buff_mngr(buff_mngr) {}
friend command_recorder& operator<<(command_recorder& recorder, command_record&& record) {
recorder.m_recorded_commands.push_back(std::move(record));
return recorder;
}

void record_command(const abstract_command& com);
const command_records& get_commands() const { return m_recorded_commands; }

const command_record& get_commands() const { return m_recorded_commands; }
const command_record& get_command(const command_id cid) const {
const auto it = std::find_if(m_recorded_commands.begin(), m_recorded_commands.end(), [cid](const command_record& rec) { return rec.cid == cid; });
assert(it != m_recorded_commands.end());
return *it;
}

private:
command_record m_recorded_commands;
const task_manager* m_task_mngr;
const buffer_manager* m_buff_mngr;
command_records m_recorded_commands;
};

} // namespace celerity::detail
22 changes: 16 additions & 6 deletions include/runtime.h
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
#pragma once

#include <deque>
#include <limits>
#include <memory>

#include "buffer_manager.h"
#include "command.h"
#include "config.h"
#include "device_queue.h"
#include "frame.h"
#include "host_queue.h"
#include "recorders.h"
#include "types.h"
Expand Down Expand Up @@ -73,7 +71,20 @@ namespace detail {

reduction_manager& get_reduction_manager() const;

host_object_manager& get_host_object_manager() const;
template <typename DataT, int Dims>
buffer_id create_buffer(const range<3>& range, const DataT* host_init_ptr) {
const auto bid = m_buffer_mngr->register_buffer<DataT, Dims>(range, host_init_ptr);
this->register_buffer(bid, range, host_init_ptr != nullptr);
return bid;
}

void set_buffer_debug_name(buffer_id bid, const std::string& debug_name);

void destroy_buffer(buffer_id bid);

host_object_id create_host_object();

void destroy_host_object(host_object_id hoid);

// returns the combined command graph of all nodes on node 0, an empty string on other nodes
std::string gather_command_graph() const;
Expand Down Expand Up @@ -118,8 +129,7 @@ namespace detail {
runtime(const runtime&) = delete;
runtime(runtime&&) = delete;

void handle_buffer_registered(buffer_id bid);
void handle_buffer_unregistered(buffer_id bid);
void register_buffer(buffer_id bid, const range<3>& range, bool host_initialized);

/**
* @brief Destroys the runtime if it is no longer active and all buffers have been unregistered.
Expand Down
Loading

0 comments on commit 33d4e65

Please sign in to comment.