Skip to content

Commit

Permalink
Fix race around reference-capture in matmul example introduced with #60
Browse files Browse the repository at this point in the history
  • Loading branch information
fknorr committed Jan 5, 2022
1 parent 3014146 commit 76f49c9
Showing 1 changed file with 25 additions and 19 deletions.
44 changes: 25 additions & 19 deletions examples/matmul/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,31 +60,18 @@ void multiply(celerity::distr_queue queue, celerity::buffer<T, 2> mat_a, celerit
});
}

int main(int argc, char* argv[]) {
bool verification_passed = true;

celerity::distr_queue queue;

auto range = celerity::range<2>(MAT_SIZE, MAT_SIZE);
celerity::buffer<float, 2> mat_a_buf(range);
celerity::buffer<float, 2> mat_b_buf(range);
celerity::buffer<float, 2> mat_c_buf(range);

set_identity(queue, mat_a_buf);
set_identity(queue, mat_b_buf);

multiply(queue, mat_a_buf, mat_b_buf, mat_c_buf);
multiply(queue, mat_b_buf, mat_c_buf, mat_a_buf);

queue.submit(celerity::allow_by_ref, [&](celerity::handler& cgh) {
template <typename T>
void verify(celerity::distr_queue queue, celerity::buffer<T, 2> mat_a_buf, bool& verification_passed) {
// allow_by_ref is safe here as long as the caller of verify() ensures that verification_passed lives until the next synchronization point
queue.submit(celerity::allow_by_ref, [=, &verification_passed](celerity::handler& cgh) {
celerity::accessor result{mat_a_buf, cgh, celerity::access::one_to_one{}, celerity::read_only_host_task};

cgh.host_task(range, [=, &verification_passed](celerity::partition<2> part) {
cgh.host_task(mat_a_buf.get_range(), [=, &verification_passed](celerity::partition<2> part) {
auto sr = part.get_subrange();
for(size_t i = sr.offset[0]; i < sr.offset[0] + sr.range[0]; ++i) {
for(size_t j = sr.offset[0]; j < sr.offset[0] + sr.range[0]; ++j) {
const float received = result[{i, j}];
const float expected = float(i == j);
const float expected = i == j;
if(expected != received) {
fprintf(stderr, "VERIFICATION FAILED for element %zu,%zu: %f (received) != %f (expected)\n", i, j, received, expected);
verification_passed = false;
Expand All @@ -96,6 +83,25 @@ int main(int argc, char* argv[]) {
if(verification_passed) { printf("VERIFICATION PASSED!\n"); }
});
});
}

int main() {
celerity::distr_queue queue;

const auto range = celerity::range<2>(MAT_SIZE, MAT_SIZE);
celerity::buffer<float, 2> mat_a_buf(range);
celerity::buffer<float, 2> mat_b_buf(range);
celerity::buffer<float, 2> mat_c_buf(range);

set_identity(queue, mat_a_buf);
set_identity(queue, mat_b_buf);

multiply(queue, mat_a_buf, mat_b_buf, mat_c_buf);
multiply(queue, mat_b_buf, mat_c_buf, mat_a_buf);

bool verification_passed = true;
verify(queue, mat_a_buf, verification_passed);
queue.slow_full_sync(); // Wait for verification_passed to become available

return verification_passed ? EXIT_SUCCESS : EXIT_FAILURE;
}

0 comments on commit 76f49c9

Please sign in to comment.