Skip to content

Commit

Permalink
Port examples to use fences where appropriate, add fence distr_test
Browse files Browse the repository at this point in the history
  • Loading branch information
fknorr committed Nov 7, 2022
1 parent 69820f1 commit 1031625
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 40 deletions.
46 changes: 24 additions & 22 deletions examples/distr_io/distr_io.cc
Original file line number Diff line number Diff line change
Expand Up @@ -130,30 +130,32 @@ int main(int argc, char* argv[]) {
}

if(argc == 4 && strcmp(argv[1], "--compare") == 0) {
bool equal = true;
{
celerity::distr_queue q;

celerity::buffer<float, 2> left(celerity::range<2>{N, N});
celerity::buffer<float, 2> right(celerity::range<2>{N, N});

read_hdf5_file(q, left, argv[2]);
read_hdf5_file(q, right, argv[3]);

q.submit(celerity::allow_by_ref, [=, &equal](celerity::handler& cgh) {
celerity::accessor a{left, cgh, celerity::access::all{}, celerity::read_only_host_task};
celerity::accessor b{right, cgh, celerity::access::all{}, celerity::read_only_host_task};
cgh.host_task(celerity::on_master_node, [=, &equal] {
for(size_t i = 0; i < N; ++i) {
for(size_t j = 0; j < N; ++j) {
equal &= a[{i, j}] == b[{i, j}];
}
celerity::distr_queue q;

celerity::buffer<float, 2> left(celerity::range<2>{N, N});
celerity::buffer<float, 2> right(celerity::range<2>{N, N});
celerity::buffer<bool> equal(1);

read_hdf5_file(q, left, argv[2]);
read_hdf5_file(q, right, argv[3]);

q.submit([=](celerity::handler& cgh) {
celerity::accessor a{left, cgh, celerity::access::all{}, celerity::read_only_host_task};
celerity::accessor b{right, cgh, celerity::access::all{}, celerity::read_only_host_task};
celerity::accessor e{equal, cgh, celerity::access::all{}, celerity::write_only_host_task, celerity::no_init};
cgh.host_task(celerity::on_master_node, [=] {
e[0] = true;
for(size_t i = 0; i < N; ++i) {
for(size_t j = 0; j < N; ++j) {
e[0] &= a[{i, j}] == b[{i, j}];
}
fprintf(stderr, "=> Files are %sequal\n", equal ? "" : "NOT ");
});
}
});
}
return equal ? EXIT_SUCCESS : EXIT_FAILURE;
});

const auto files_equal = celerity::experimental::fence(q, equal).get_data()[0];
fprintf(stderr, "=> Files are %sequal\n", files_equal ? "" : "NOT ");
return files_equal ? EXIT_SUCCESS : EXIT_FAILURE;
}

fprintf(stderr,
Expand Down
35 changes: 18 additions & 17 deletions examples/matmul/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,27 +45,27 @@ void multiply(celerity::distr_queue queue, celerity::buffer<T, 2> mat_a, celerit
});
}

// TODO this should really reduce into a buffer<bool> on the device, but not all backends currently support reductions
template <typename T>
void verify(celerity::distr_queue queue, celerity::buffer<T, 2> mat_a_buf, bool& verification_passed) {
// allow_by_ref is safe here as long as the caller of verify() ensures that verification_passed lives until the next synchronization point
queue.submit(celerity::allow_by_ref, [=, &verification_passed](celerity::handler& cgh) {
celerity::accessor result{mat_a_buf, cgh, celerity::access::one_to_one{}, celerity::read_only_host_task};
void verify(celerity::distr_queue& queue, celerity::buffer<T, 2> mat_c, celerity::buffer<bool> passed_buf) {
queue.submit([=](celerity::handler& cgh) {
celerity::accessor c{mat_c, cgh, celerity::access::one_to_one{}, celerity::read_only_host_task};
celerity::accessor passed{passed_buf, cgh, celerity::access::all{}, celerity::write_only_host_task, celerity::no_init};

cgh.host_task(mat_a_buf.get_range(), [=, &verification_passed](celerity::partition<2> part) {
auto sr = part.get_subrange();
cgh.host_task(mat_c.get_range(), [=](celerity::partition<2> part) {
passed[0] = true;
const auto& sr = part.get_subrange();
for(size_t i = sr.offset[0]; i < sr.offset[0] + sr.range[0]; ++i) {
for(size_t j = sr.offset[0]; j < sr.offset[0] + sr.range[0]; ++j) {
const float received = result[{i, j}];
for(size_t j = sr.offset[1]; j < sr.offset[1] + sr.range[1]; ++j) {
const float received = c[i][j];
const float expected = i == j;
if(expected != received) {
fprintf(stderr, "VERIFICATION FAILED for element %zu,%zu: %f (received) != %f (expected)\n", i, j, received, expected);
verification_passed = false;
break;
CELERITY_ERROR("Verification failed for element {},{}: {} (received) != {} (expected)", i, j, received, expected);
passed[0] = false;
}
}
if(!verification_passed) { break; }
}
if(verification_passed) { printf("VERIFICATION PASSED!\n"); }
if(passed[0]) { CELERITY_INFO("Verification passed for {}", part.get_subrange()); }
});
});
}
Expand All @@ -87,9 +87,10 @@ int main() {
multiply(queue, mat_a_buf, mat_b_buf, mat_c_buf);
multiply(queue, mat_b_buf, mat_c_buf, mat_a_buf);

bool verification_passed = true;
verify(queue, mat_a_buf, verification_passed);
queue.slow_full_sync(); // Wait for verification_passed to become available
celerity::buffer<bool> passed_buf(1);
verify(queue, mat_c_buf, passed_buf);

return verification_passed ? EXIT_SUCCESS : EXIT_FAILURE;
// The value of `passed` can differ between hosts if only part of the verification failed.
const auto passed = celerity::experimental::fence(queue, passed_buf);
return passed ? EXIT_SUCCESS : EXIT_FAILURE;
}
2 changes: 1 addition & 1 deletion include/distr_queue.h
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ namespace detail {

/// Produces a snapshot of the complete buffer `buf`.
///
/// Wraps `buf` in a single `experimental::buffer_subrange` spanning a default-constructed offset and
/// the buffer's full range, then forwards to another `extract` overload (not visible in this excerpt).
template <typename T, int Dims>
experimental::buffer_snapshot<T, Dims> extract(const buffer<T, Dims>& buf) const {
	// Wrap exactly once and qualify the name: `buffer_subrange` lives in `experimental`.
	return extract(experimental::buffer_subrange(buf, subrange({}, buf.get_range())));
}
};

Expand Down
29 changes: 29 additions & 0 deletions test/system/distr_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -343,5 +343,34 @@ namespace detail {
});
}

TEST_CASE_METHOD(test_utils::runtime_fixture, "fences transfer data correctly between nodes", "[fence]") {
	// Deliberately uneven extents so that the buffer subrange extraction logic is exercised.
	buffer<int, 3> grid{{5, 4, 7}};
	experimental::host_object<int> rank_obj;
	distr_queue queue;

	// Collective host task: each node stores its own node index in the host object.
	queue.submit([=](handler& cgh) {
		experimental::side_effect rank_effect{rank_obj, cgh};
		cgh.host_task(experimental::collective, [=](experimental::collective_partition part) {
			const auto node_index = static_cast<int>(part.get_node_index());
			*rank_effect = node_index;
		});
	});

	// Master-only host task: write a marker value into a single cell of the buffer.
	queue.submit([=](handler& cgh) {
		accessor writer{grid, cgh, celerity::access::all{}, write_only_host_task, no_init};
		cgh.host_task(on_master_node, [=] { writer[{1, 2, 3}] = 42; });
	});

	// The one cell written above, expressed as a 1x1x1 subrange.
	const subrange<3> marker_cell({1, 2, 3}, {1, 1, 1});

	// Fence on both objects at once, then on each of them separately.
	const auto [cell_snapshot, node_rank] = experimental::fence(queue, std::tuple{experimental::buffer_subrange(grid, marker_cell), rank_obj});
	const auto cell_snapshot_single = experimental::fence(queue, experimental::buffer_subrange(grid, marker_cell));
	const auto node_rank_single = experimental::fence(queue, rank_obj);

	REQUIRE(cell_snapshot.get_range() == range<3>{1, 1, 1});
	CHECK(cell_snapshot[0][0][0] == 42);
	CHECK(cell_snapshot_single == cell_snapshot);

	// Compare the value stored in the host object against this process' rank in MPI_COMM_WORLD.
	int global_rank;
	MPI_Comm_rank(MPI_COMM_WORLD, &global_rank);
	CHECK(node_rank == global_rank);
	CHECK(node_rank_single == global_rank);
}

} // namespace detail
} // namespace celerity

0 comments on commit 1031625

Please sign in to comment.