Skip to content

Commit

Permalink
[BASE] celerity#198 Improve command graph testing infrastructure
Browse files Browse the repository at this point in the history
  • Loading branch information
psalz authored and fknorr committed Aug 23, 2023
1 parent 0716c23 commit 0683c67
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 126 deletions.
146 changes: 119 additions & 27 deletions test/distributed_graph_generator_test_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,16 @@ class command_query {
};

public:
// -------------------------------------------------------------------------------------------------------------------------------------------------------
// ------------------------------------------------------------------- Query functions -------------------------------------------------------------------
// -------------------------------------------------------------------------------------------------------------------------------------------------------

/**
* Finds all commands within the current set that match a given list of filters.
* Currently supported filters are node_id, task_id and command_type.
* Filters are applied conjunctively (AND), hence each type can be specified at most once.
*/
template <typename... Filters>
command_query find_all(Filters... filters) const {
assert_not_empty(__FUNCTION__);
Expand Down Expand Up @@ -170,25 +180,72 @@ class command_query {
return command_query{std::move(filtered)};
}

/**
* Returns a new command_query that contains all commands that precede the current set of commands.
*/
template <typename... Filters>
command_query find_predecessors(Filters... filters) const {
assert_not_empty(__FUNCTION__);
return find_adjacent(true, filters...);
}

/**
* Returns a new command_query that contains all commands that succeed the current set of commands.
*/
template <typename... Filters>
command_query find_successors(Filters... filters) const {
assert_not_empty(__FUNCTION__);
return find_adjacent(false, filters...);
}

/**
* Returns the total number of commands across all nodes.
*/
size_t count() const {
return std::accumulate(
m_commands_by_node.begin(), m_commands_by_node.end(), size_t(0), [](size_t current, auto& cmds) { return current + cmds.size(); });
}

/**
* Returns the number of commands per node, if it is the same, throws otherwise.
*/
size_t count_per_node() const {
if(m_commands_by_node.empty()) return 0;
const size_t count = m_commands_by_node[0].size();
for(size_t i = 1; i < m_commands_by_node.size(); ++i) {
if(m_commands_by_node[i].size() != count) {
throw query_exception(
fmt::format("Different number of commands across nodes (node 0: {}, node {}: {})", count, i, m_commands_by_node[i].size()));
}
}
return count;
}

/**
* Chainable variant of count(), for use as part of larger query expressions.
*/
command_query assert_count(const size_t expected) const {
if(count() != expected) { throw query_exception(fmt::format("Expected {} total command(s), found {}", expected, count())); }
return *this;
}

/**
* Chainable variant of count_per_node(), for use as part of larger query expressions.
*/
command_query assert_count_per_node(const size_t expected) const {
if(count_per_node() != expected) { throw query_exception(fmt::format("Expected {} command(s) per node, found {}", expected, count_per_node())); }
return *this;
}

bool empty() const { return count() == 0; }

// --------------------------------------------------------------------------------------------------------------------------------------------------------
// -------------------------------------------------------------------- Set operations --------------------------------------------------------------------
// --------------------------------------------------------------------------------------------------------------------------------------------------------

friend command_query operator-(const command_query& lhs, const command_query& rhs) { return lhs.subtract(rhs); }
friend command_query operator+(const command_query& lhs, const command_query& rhs) { return lhs.merge(rhs); }

command_query subtract(const command_query& other) const {
assert_not_empty(__FUNCTION__);
assert(m_commands_by_node.size() == other.m_commands_by_node.size());
Expand All @@ -211,16 +268,13 @@ class command_query {
return command_query{std::move(result)};
}

// Call the provided function once for each node, with a subquery containing commands only for that node.
template <typename PerNodeCallback>
void for_each_node(PerNodeCallback&& cb) const {
assert_not_empty(__FUNCTION__);
for(node_id nid = 0; nid < m_commands_by_node.size(); ++nid) {
UNSCOPED_INFO(fmt::format("On node {}", nid));
cb(find_all(nid));
}
}
// --------------------------------------------------------------------------------------------------------------------------------------------------------
// ---------------------------------------------------------------------- Predicates ----------------------------------------------------------------------
// --------------------------------------------------------------------------------------------------------------------------------------------------------

/**
* Returns whether all commands on all nodes are of the given type.
*/
bool have_type(const command_type expected) const {
assert_not_empty(__FUNCTION__);
return for_all_commands([expected](const node_id nid, const abstract_command* cmd) {
Expand All @@ -234,10 +288,29 @@ class command_query {
});
}

/**
* Returns whether the current set of commands is succeeded by ALL commands in successors on each node.
*
* Throws if successors is empty or contains commands for nodes not present in the current query.
*
* NOTE: Care has to be taken when using this function in negative assertions. For example the check
* `CHECK_FALSE(q.find_all(task_a).have_successors(q.find_all(task_b))` can NOT be used to check
* whether there are no true dependencies between tasks a and b: If there are multiple nodes and
* for only one of them there is no dependency, the assertion will pass.
*/
bool have_successors(const command_query& successors, const std::optional<dependency_kind>& kind = std::nullopt,
const std::optional<dependency_origin>& origin = std::nullopt) const {
assert_not_empty(__FUNCTION__);

if(successors.count() == 0) { throw query_exception("Successor set is empty"); }

assert(m_commands_by_node.size() == successors.m_commands_by_node.size());
for(node_id nid = 0; nid < m_commands_by_node.size(); ++nid) {
if(m_commands_by_node[nid].empty() && !successors.m_commands_by_node[nid].empty()) {
throw query_exception(fmt::format("A.have_successors(B): B contains commands for node {}, whereas A does not", nid));
}
}

return for_all_commands([&successors, &kind, &origin](const node_id nid, const abstract_command* cmd) {
for(const auto* expected : successors.m_commands_by_node[nid]) {
bool found = false;
Expand Down Expand Up @@ -265,9 +338,31 @@ class command_query {
});
}

// --------------------------------------------------------------------------------------------------------------------------------------------------------
// ------------------------------------------------------------------------ Other -------------------------------------------------------------------------
// --------------------------------------------------------------------------------------------------------------------------------------------------------

/**
* Call the provided function once for each node, with a subquery containing commands only for that node.
*
* Using this function is usually not necessary, as all predicates (have_successors, have_types, ...) apply
* simultaneously on all nodes.
*/
template <typename PerNodeCallback>
void for_each_node(PerNodeCallback&& cb) const {
assert_not_empty(__FUNCTION__);
for(node_id nid = 0; nid < m_commands_by_node.size(); ++nid) {
UNSCOPED_INFO(fmt::format("On node {}", nid));
cb(find_all(nid));
}
}

/**
* Returns the raw command pointers contained within the query, optionally limited to a given node.
*/
std::vector<const abstract_command*> get_raw(const std::optional<node_id>& nid = std::nullopt) const {
std::vector<const abstract_command*> result;
for_all_commands([&result, &nid](const node_id n, const abstract_command* cmd) {
for_all_commands([&result, &nid](const node_id n, const abstract_command* const cmd) {
if(nid.has_value() && n != nid) return;
result.push_back(cmd);
});
Expand Down Expand Up @@ -373,14 +468,11 @@ class dist_cdag_test_context {
friend class task_builder;

public:
dist_cdag_test_context(size_t num_nodes) : m_num_nodes(num_nodes) {
m_rm = std::make_unique<reduction_manager>();
m_task_recorder = std::make_unique<task_recorder>();
m_tm = std::make_unique<task_manager>(num_nodes, nullptr /* host_queue */, m_task_recorder.get());
dist_cdag_test_context(size_t num_nodes) : m_num_nodes(num_nodes), m_rm(), m_tm(num_nodes, nullptr /* host_queue */, &m_task_recorder) {
for(node_id nid = 0; nid < num_nodes; ++nid) {
m_cdags.emplace_back(std::make_unique<command_graph>());
m_cmd_recorders.emplace_back(std::make_unique<command_recorder>(m_tm.get(), nullptr));
m_dggens.emplace_back(std::make_unique<distributed_graph_generator>(num_nodes, nid, *m_cdags[nid], *m_tm, m_cmd_recorders[nid].get()));
m_cmd_recorders.emplace_back(std::make_unique<command_recorder>(&m_tm, nullptr));
m_dggens.emplace_back(std::make_unique<distributed_graph_generator>(num_nodes, nid, *m_cdags[nid], m_tm, m_cmd_recorders[nid].get()));
}
}

Expand All @@ -395,7 +487,7 @@ class dist_cdag_test_context {
test_utils::mock_buffer<Dims> create_buffer(range<Dims> size, bool mark_as_host_initialized = false) {
const buffer_id bid = m_next_buffer_id++;
const auto buf = test_utils::mock_buffer<Dims>(bid, size);
m_tm->add_buffer(bid, Dims, range_cast<3>(size), mark_as_host_initialized);
m_tm.add_buffer(bid, Dims, range_cast<3>(size), mark_as_host_initialized);
for(auto& dggen : m_dggens) {
dggen->add_buffer(bid, Dims, range_cast<3>(size));
}
Expand Down Expand Up @@ -447,7 +539,7 @@ class dist_cdag_test_context {
}

task_id epoch(epoch_action action) {
const auto tid = m_tm->generate_epoch_task(action);
const auto tid = m_tm.generate_epoch_task(action);
build_task(tid);
return tid;
}
Expand All @@ -457,13 +549,13 @@ class dist_cdag_test_context {
return command_query(m_cdags).find_all(filters...);
}

void set_horizon_step(const int step) { m_tm->set_horizon_step(step); }
void set_horizon_step(const int step) { m_tm.set_horizon_step(step); }

task_manager& get_task_manager() { return *m_tm; }
task_manager& get_task_manager() { return m_tm; }

distributed_graph_generator& get_graph_generator(node_id nid) { return *m_dggens.at(nid); }

std::string print_task_graph() { return detail::print_task_graph(*m_task_recorder); }
std::string print_task_graph() { return detail::print_task_graph(m_task_recorder); }
std::string print_command_graph(node_id nid) { return detail::print_command_graph(nid, *m_cmd_recorders[nid]); }

private:
Expand All @@ -472,9 +564,9 @@ class dist_cdag_test_context {
host_object_id m_next_host_object_id = 0;
reduction_id m_next_reduction_id = 1; // Start from 1 as rid 0 designates "no reduction" in push commands
std::optional<task_id> m_most_recently_built_horizon;
std::unique_ptr<reduction_manager> m_rm;
std::unique_ptr<task_manager> m_tm;
std::unique_ptr<task_recorder> m_task_recorder;
reduction_manager m_rm;
task_recorder m_task_recorder;
task_manager m_tm;
std::vector<std::unique_ptr<command_graph>> m_cdags;
std::vector<std::unique_ptr<distributed_graph_generator>> m_dggens;
std::vector<std::unique_ptr<command_recorder>> m_cmd_recorders;
Expand All @@ -485,12 +577,12 @@ class dist_cdag_test_context {

void build_task(const task_id tid) {
for(auto& dggen : m_dggens) {
dggen->build_task(*m_tm->get_task(tid));
dggen->build_task(*m_tm.get_task(tid));
}
}

void maybe_build_horizon() {
const auto current_horizon = task_manager_testspy::get_current_horizon(*m_tm);
const auto current_horizon = task_manager_testspy::get_current_horizon(m_tm);
if(m_most_recently_built_horizon != current_horizon) {
assert(current_horizon.has_value());
build_task(*current_horizon);
Expand All @@ -510,7 +602,7 @@ class dist_cdag_test_context {
}

task_id fence(buffer_access_map access_map, side_effect_map side_effects) {
const auto tid = m_tm->generate_fence_task(std::move(access_map), std::move(side_effects), nullptr);
const auto tid = m_tm.generate_fence_task(std::move(access_map), std::move(side_effects), nullptr);
build_task(tid);
maybe_build_horizon();
return tid;
Expand Down
4 changes: 1 addition & 3 deletions test/graph_compaction_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -214,9 +214,7 @@ TEST_CASE("side-effect dependencies are correctly subsumed by horizons", "[distr

// This must depend on the first horizon, not first_task
const auto second_task = dctx.master_node_host_task().affect(ho, experimental::side_effect_order::sequential).submit();
const auto predecessors = dctx.query(second_task).find_predecessors();
CHECK(predecessors.count() == 1);
CHECK(predecessors.have_type(command_type::horizon));
CHECK(dctx.query(second_task).find_predecessors().assert_count(1).have_type(command_type::horizon));
}

TEST_CASE("reaching an epoch will prune all nodes of the preceding task graph", "[task_manager][task-graph][epoch]") {
Expand Down
6 changes: 2 additions & 4 deletions test/graph_gen_reduction_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ TEST_CASE("multiple chained reductions produce appropriate data transfers", "[di
const auto reduction1 = dctx.query(command_type::reduction);
CHECK(reduction1.count() == 1);
dctx.master_node_host_task().read(buf0, acc::all{}).submit();
const auto reduction2 = dctx.query(command_type::reduction).subtract(reduction1);
const auto reduction2 = dctx.query(command_type::reduction) - reduction1;
CHECK(reduction2.count() == 1);

// Both reductions are preceeded by await_pushes
Expand Down Expand Up @@ -199,9 +199,7 @@ TEST_CASE("nodes that do not own pending reduction don't include it in final red
SECTION("local reductions don't have a dependency on the last writer") {
// The last writer in this case is the initial epoch
dctx.device_compute<class UKN(consume)>(range<1>(96)).read(buf0, acc::all{}).submit();
const auto predecessors = dctx.query(node_id(2), command_type::reduction).find_predecessors(dependency_kind::true_dep);
CHECK(predecessors.count() == 1);
CHECK(predecessors.have_type(command_type::await_push));
CHECK(dctx.query(node_id(2), command_type::reduction).find_predecessors(dependency_kind::true_dep).assert_count(1).have_type(command_type::await_push));
}
}

Expand Down
22 changes: 2 additions & 20 deletions test/graph_gen_transfer_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ TEST_CASE("distributed_graph_generator generates required data transfer commands
return {};
};
const auto tid_a = dctx.device_compute<class UKN(task_a)>(test_range).discard_write(buf, rm).submit();
dctx.query(tid_a).for_each_node([](const auto& q) { CHECK(q.find_all(command_type::execution).count() == 1); });
CHECK(dctx.query(tid_a, command_type::execution).count_per_node() == 1);

dctx.device_compute<class UKN(task_b)>(test_range).read(buf, acc::one_to_one{}).submit();
CHECK(dctx.query(command_type::push).count() == 2);
Expand Down Expand Up @@ -147,24 +147,6 @@ TEST_CASE("distributed_graph_generator consolidates push commands for adjacent s
CHECK(dctx.query(tid_b).have_successors(dctx.query(command_type::push)));
}

TEST_CASE("distributed_graph_generator builds dependencies to all local commands if a given range is produced by multiple",
"[distributed_graph_generator][command-graph]") {
dist_cdag_test_context dctx(1);

const range<1> test_range = {96};
const range<1> one_third = {test_range / 3};
auto buf = dctx.create_buffer(test_range);

const auto tid_a = dctx.device_compute<class UKN(task_a)>(one_third, id<1>{0 * one_third}).discard_write(buf, acc::one_to_one{}).submit();
const auto tid_b = dctx.device_compute<class UKN(task_b)>(one_third, id<1>{1 * one_third}).discard_write(buf, acc::one_to_one{}).submit();
const auto tid_c = dctx.device_compute<class UKN(task_c)>(one_third, id<1>{2 * one_third}).discard_write(buf, acc::one_to_one{}).submit();

const auto tid_d = dctx.device_compute<class UKN(task_d)>(test_range).read(buf, acc::one_to_one{}).submit();
CHECK(dctx.query(tid_a).have_successors(dctx.query(tid_d)));
CHECK(dctx.query(tid_b).have_successors(dctx.query(tid_d)));
CHECK(dctx.query(tid_c).have_successors(dctx.query(tid_d)));
}

TEST_CASE("distributed_graph_generator generates dependencies for push commands", "[distributed_graph_generator][command-graph]") {
dist_cdag_test_context dctx(2);

Expand Down Expand Up @@ -217,7 +199,7 @@ TEST_CASE("distributed_graph_generator generates anti-dependencies for await_pus
dctx.device_compute<class UKN(task_c)>(test_range).discard_write(buf, acc::one_to_one{}).submit();
// Node 0 reads it again
dctx.master_node_host_task().read(buf, acc::all{}).submit();
const auto second_await_push = dctx.query(command_type::await_push).subtract(first_await_push);
const auto second_await_push = dctx.query(command_type::await_push) - first_await_push;
// The first await push last wrote the data, but the anti-dependency is delegated to the reading successor task
CHECK(dctx.query(tid_b).have_successors(second_await_push));
}
Expand Down
Loading

0 comments on commit 0683c67

Please sign in to comment.