Task ringbuffer #112
Merged · 8 commits · Jul 18, 2022

170 changes: 86 additions & 84 deletions ci/perf/gpuc2_bench.csv

Large diffs are not rendered by default.

176 changes: 89 additions & 87 deletions ci/perf/gpuc2_bench.md

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions include/print_graph.h
@@ -2,17 +2,17 @@

#include <memory>
#include <string>
-#include <unordered_map>

#include "task.h"
#include "task_ring_buffer.h"

namespace celerity {
namespace detail {

class command_graph;
class task_manager;

-std::string print_task_graph(const std::unordered_map<task_id, std::unique_ptr<task>>& tdag);
+std::string print_task_graph(const task_ring_buffer& tdag);
std::string print_command_graph(const command_graph& cdag, const task_manager& tm);

} // namespace detail
4 changes: 2 additions & 2 deletions include/runtime.h
@@ -55,9 +55,9 @@ namespace detail {
*/
void startup();

-void shutdown() noexcept;
+void shutdown();

-void sync() noexcept;
+void sync();

bool is_master_node() const { return local_nid == 0; }

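Note: dropping noexcept here follows from the new ring buffer. sync() generates and awaits an epoch task, and task submission can now throw when the buffer is exhausted with no epoch in flight, so these entry points must be allowed to propagate exceptions. A hypothetical caller-side sketch (illustrative only; the call site and error text are assumptions, not part of this PR):

try {
	celerity::detail::runtime::get_instance().sync();
} catch(const std::exception& e) {
	// e.g. a task-slot exhaustion error surfacing instead of std::terminate
	std::cerr << "sync failed: " << e.what() << std::endl;
}
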
48 changes: 26 additions & 22 deletions include/task_manager.h
@@ -11,6 +11,7 @@
#include "host_queue.h"
#include "region_map.h"
#include "task.h"
#include "task_ring_buffer.h"
#include "types.h"

namespace celerity {
@@ -30,9 +31,10 @@ namespace detail {
return this_epoch;
}

-void await(const task_id epoch) const {
+task_id await(const task_id min_tid_reached) const {
std::unique_lock lock{mutex};
-epoch_changed.wait(lock, [=] { return this_epoch >= epoch; });
+epoch_changed.wait(lock, [=] { return this_epoch >= min_tid_reached; });
+return this_epoch;
}

void set(const task_id epoch) {
@@ -66,16 +68,20 @@
task_id tid;
{
std::lock_guard lock(task_mutex);
-tid = get_new_tid();
+auto reservation = task_buffer.reserve_task_entry(await_free_task_slot_callback());
+tid = reservation.get_tid();

prepass_handler cgh(tid, std::make_unique<command_group_storage<CGF>>(cgf), num_collective_nodes);
cgf(cgh);

-task& task_ref = register_task_internal(std::move(cgh).into_task());
+task& task_ref = register_task_internal(std::move(reservation), std::move(cgh).into_task());

compute_dependencies(tid);
if(queue) queue->require_collective_group(task_ref.get_collective_group_id());
-prune_tasks_before_latest_epoch_reached();

+// the following deletion is intentionally redundant with the one happening when waiting for free task slots
+// we want to free tasks earlier than just when running out of slots,
+// so that we can potentially reclaim additional resources such as buffers earlier
+task_buffer.delete_up_to(latest_epoch_reached.get());
}
invoke_callbacks(tid);
if(need_new_horizon()) { generate_horizon_task(); }
@@ -127,7 +133,7 @@
/**
* @brief Shuts down the task_manager, freeing all stored tasks.
*/
-void shutdown() { task_map.clear(); }
+void shutdown() { task_buffer.clear(); }

void set_horizon_step(const int step) {
assert(step >= 0);
@@ -154,21 +160,20 @@
* Returns the number of tasks created during the lifetime of the task_manager,
* including tasks that have already been deleted.
*/
-task_id get_total_task_count() const { return next_task_id; }
+size_t get_total_task_count() const { return task_buffer.get_total_task_count(); }

/**
* Returns the number of tasks currently being managed by the task_manager.
*/
-task_id get_current_task_count() const { return task_map.size(); }
+size_t get_current_task_count() const { return task_buffer.get_current_task_count(); }

private:
const size_t num_collective_nodes;
host_queue* queue;

reduction_manager* reduction_mngr;

-task_id next_task_id = 1;
-std::unordered_map<task_id, std::unique_ptr<task>> task_map;
+task_ring_buffer task_buffer;

// The active epoch is used as the last writer for host-initialized buffers.
// This is useful so we can correctly generate anti-dependencies onto tasks that read host-initialized buffers.
@@ -184,7 +189,7 @@
// Stores which host object was last affected by which task.
std::unordered_map<host_object_id, task_id> host_object_last_effects;

-// For simplicity we use a single mutex to control access to all task-related (i.e. the task graph, task_map, ...) data structures.
+// For simplicity we use a single mutex to control access to all task-related (i.e. the task graph, ...) data structures.
mutable std::mutex task_mutex;

std::vector<task_callback> task_callbacks;
@@ -207,15 +212,10 @@
// The last epoch task that has been processed by the executor. Behind a monitor to allow awaiting this change from the main thread.
epoch_monitor latest_epoch_reached{initial_epoch_task};

-// The last epoch that was used in task pruning after being reached. This allows skipping the pruning step if no new epoch was completed since.
-task_id last_pruned_before{initial_epoch_task};

// Set of tasks with no dependents
std::unordered_set<task*> execution_front;

-inline task_id get_new_tid() { return next_task_id++; }

-task& register_task_internal(std::unique_ptr<task> task);
+task& register_task_internal(task_ring_buffer::reservation&& reserve, std::unique_ptr<task> task);

void invoke_callbacks(task_id tid);

@@ -225,18 +225,22 @@

int get_max_pseudo_critical_path_length() const { return max_pseudo_critical_path_length; }

-task_id reduce_execution_front(std::unique_ptr<task> new_front);
+task_id reduce_execution_front(task_ring_buffer::reservation&& reserve, std::unique_ptr<task> new_front);

void set_epoch_for_new_tasks(task_id epoch);

const std::unordered_set<task*>& get_execution_front() { return execution_front; }

task_id generate_horizon_task();

-// Needs to be called while task map accesses are safe (ie. mutex is locked)
-void prune_tasks_before_latest_epoch_reached();

void compute_dependencies(task_id tid);

+// Finds the first in-flight epoch, or returns the currently reached one if there are none in-flight
+// Used in await_free_task_slot_callback to check for hangs
+task_id get_first_in_flight_epoch() const;

+// Returns a callback which blocks until any epoch task has executed, freeing new task slots
+task_ring_buffer::wait_callback await_free_task_slot_callback();
};

} // namespace detail
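The two new private members connect the ring buffer to the epoch mechanism: get_first_in_flight_epoch() is the hang check, and await_free_task_slot_callback() produces the blocking callback handed to reserve_task_entry(). Their definitions live in src/task_manager.cc, which this page does not render; the following is only a sketch of the intended shape, reconstructed from the declarations and comments above, not the PR's verbatim code:

// Sketch under the assumptions stated above.
task_ring_buffer::wait_callback task_manager::await_free_task_slot_callback() {
	return [this](const task_id lowest_in_use_tid) {
		if(get_first_in_flight_epoch() <= lowest_in_use_tid) {
			// Even the earliest in-flight epoch cannot free the oldest in-use
			// slot, so blocking would deadlock; fail loudly instead.
			throw std::runtime_error("Exhausted task slots with no epoch in flight to free them");
		}
		// Block until an epoch past the oldest in-use task id has executed,
		// then reclaim every slot below it.
		const task_id reached = latest_epoch_reached.await(lowest_in_use_tid + 1);
		task_buffer.delete_up_to(reached);
	};
}
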
140 changes: 140 additions & 0 deletions include/task_ring_buffer.h
@@ -0,0 +1,140 @@
#pragma once

#include <array>
#include <atomic>
#include <functional> // std::function, needed for wait_callback below
#include <memory>

#include "log.h"
#include "task.h"
#include "types.h"

namespace celerity::detail {

constexpr unsigned long task_ringbuffer_size = 1024;

class task_ring_buffer {
	friend struct task_ring_buffer_testspy;

  public:
	// This is an RAII type for ensuring correct handling of task id reservations
	// in the presence of exceptions (i.e. revoking the reservation on stack unwinding)
	class reservation {
		friend class task_ring_buffer;

	  public:
		reservation(task_id tid, task_ring_buffer& buffer) : tid(tid), buffer(buffer) {}
		~reservation() {
			if(!consumed) {
CELERITY_WARN("Consumed reservation for tid {} in destructor", tid);
				buffer.revoke_reservation(std::move(*this));
			}
		}
		reservation(const reservation&) = delete;            // non copyable
		reservation& operator=(const reservation&) = delete; // non assignable
		reservation(reservation&&) = default;                // movable

		task_id get_tid() const { return tid; }

	  private:
		void consume() {
			assert(consumed == false);
			consumed = true;
		}

		bool consumed = false;
		task_id tid;
		task_ring_buffer& buffer;
	};

	bool has_task(task_id tid) const {
		return tid >= number_of_deleted_tasks.load(std::memory_order_relaxed) // best effort, only reliable from application thread
		       && tid < next_active_tid.load(std::memory_order_acquire);      // synchronizes access to data with put(...)
	}

	size_t get_total_task_count() const { return next_active_tid.load(std::memory_order_relaxed); }

	task* find_task(task_id tid) const { return has_task(tid) ? data[tid % task_ringbuffer_size].get() : nullptr; }

	task* get_task(task_id tid) const {
		assert(has_task(tid));
		return data[tid % task_ringbuffer_size].get();
	}

	// all member functions beyond this point may *only* be called by the main application thread

	size_t get_current_task_count() const { //
		return next_active_tid.load(std::memory_order_relaxed) - number_of_deleted_tasks.load(std::memory_order_relaxed);
	}

	// the task id passed to the wait callback identifies the lowest in-use TID that the ring buffer is aware of
	using wait_callback = std::function<void(task_id)>;

	reservation reserve_task_entry(const wait_callback& wc) {
		wait_for_available_slot(wc);
		reservation ret(next_task_id, *this);
		next_task_id++;
		return ret;
	}

	void revoke_reservation(reservation&& reserve) {
		reserve.consume();
		assert(reserve.tid == next_task_id - 1); // this is the only allowed (and extant) pattern
		next_task_id--;
	}

	void put(reservation&& reserve, std::unique_ptr<task> task) {
		reserve.consume();
		assert(next_active_tid.load(std::memory_order_relaxed) == reserve.tid);
		data[reserve.tid % task_ringbuffer_size] = std::move(task);
		next_active_tid.store(reserve.tid + 1, std::memory_order_release);
	}

	void delete_up_to(task_id target_tid) {
		assert(target_tid >= number_of_deleted_tasks.load(std::memory_order_relaxed));
		for(task_id tid = number_of_deleted_tasks.load(std::memory_order_relaxed); tid < target_tid; ++tid) {
			data[tid % task_ringbuffer_size].reset();
		}
		number_of_deleted_tasks.store(target_tid, std::memory_order_relaxed);
	}

	void clear() {
		for(auto&& d : data) {
			d.reset();
		}
		number_of_deleted_tasks.store(next_task_id, std::memory_order_relaxed);
	}

	class task_buffer_iterator {
		unsigned long id;
		const task_ring_buffer& buffer;

	  public:
		task_buffer_iterator(unsigned long id, const task_ring_buffer& buffer) : id(id), buffer(buffer) {}
		task* operator*() { return buffer.get_task(id); }
		void operator++() { id++; }
		bool operator<(task_buffer_iterator other) { return id < other.id; }
		bool operator!=(task_buffer_iterator other) { return &buffer != &other.buffer || id != other.id; }
	};

	task_buffer_iterator begin() const { //
		return task_buffer_iterator(number_of_deleted_tasks.load(std::memory_order_relaxed), *this);
	}
	task_buffer_iterator end() const { return task_buffer_iterator(next_task_id, *this); }

  private:
	// the id of the next task that will be reserved
	task_id next_task_id = 0;
	// the next task id that will actually be emplaced
	std::atomic<task_id> next_active_tid = task_id(0);
	// the number of deleted tasks (which is implicitly the start of the active range of the ringbuffer)
	std::atomic<size_t> number_of_deleted_tasks = 0;
	std::array<std::unique_ptr<task>, task_ringbuffer_size> data;

	void wait_for_available_slot(const wait_callback& wc) const {
		if(next_task_id - number_of_deleted_tasks.load(std::memory_order_relaxed) >= task_ringbuffer_size) {
			wc(static_cast<task_id>(number_of_deleted_tasks.load(std::memory_order_relaxed)));
		}
	}
};

} // namespace celerity::detail
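
Taken together, the intended lifecycle is reserve → put (or revoke on error) → delete_up_to. A minimal usage sketch, assuming a hypothetical make_task() helper that yields a std::unique_ptr<task> (task construction is outside this header):

// Usage sketch only; make_task() is a hypothetical stand-in.
task_ring_buffer buffer;

// Reserve the next task id. The callback fires only once all 1024 slots are
// occupied and must block until delete_up_to() has freed some of them.
auto res = buffer.reserve_task_entry([](task_id lowest_in_use) {
	// e.g. await an epoch past lowest_in_use, then call delete_up_to(...)
});
const task_id tid = res.get_tid();

// put() consumes the reservation and publishes the task: the release store on
// next_active_tid pairs with the acquire load in has_task().
buffer.put(std::move(res), make_task());
assert(buffer.has_task(tid));

// Once an epoch guarantees all tasks below some id are dead, reuse their slots.
buffer.delete_up_to(tid + 1);

If task construction throws between reserve_task_entry() and put(), the reservation's destructor revokes the id, which is exactly the exception-safety case the RAII wrapper exists for.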
6 changes: 2 additions & 4 deletions src/print_graph.cc
@@ -36,13 +36,11 @@ namespace detail {
}
}

-std::string print_task_graph(const std::unordered_map<task_id, std::unique_ptr<task>>& tdag) {
+std::string print_task_graph(const task_ring_buffer& tdag) {
std::ostringstream ss;
ss << "digraph G { label=\"Task Graph\" ";

-for(auto& it : tdag) {
-	const auto tsk = it.second.get();
-
+for(auto tsk : tdag) {
std::unordered_map<std::string, std::string> props;
props["label"] = "\"" + get_task_label(tsk) + "\"";

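A side effect of the new loop: because begin()/end() walk the live id range, tasks are now visited in ascending task-id order, whereas the unordered_map version iterated in unspecified order, so the node order of the emitted DOT output becomes deterministic. Spelled out with explicit iterators (illustrative only, equivalent to the range-for above):

for(auto it = tdag.begin(); it != tdag.end(); ++it) {
	const task* tsk = *it; // task_buffer_iterator dereferences to task*
	// ... emit one graph node per task ...
}
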
4 changes: 2 additions & 2 deletions src/runtime.cc
@@ -185,7 +185,7 @@ namespace detail {
set_thread_name(get_current_thread_handle(), "cy-main");
}

-void runtime::shutdown() noexcept {
+void runtime::shutdown() {
assert(is_active);
is_shutting_down = true;

@@ -229,7 +229,7 @@
maybe_destroy_runtime();
}

-void runtime::sync() noexcept {
+void runtime::sync() {
const auto epoch = task_mngr->generate_epoch_task(epoch_action::barrier);
task_mngr->await_epoch(epoch);
}