Skip to content

Commit

Permalink
Improve captures buffer_data API, verify matmul result in main thread
Browse files Browse the repository at this point in the history
  • Loading branch information
fknorr committed Mar 16, 2022
1 parent b32ec2c commit ba98498
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 44 deletions.
47 changes: 17 additions & 30 deletions examples/matmul/matmul.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#include <cstdio>

#include <celerity.h>

const size_t MAT_SIZE = 1024;
Expand Down Expand Up @@ -46,29 +44,21 @@ void multiply(celerity::distr_queue& queue, celerity::buffer<T, 2> mat_a, celeri
}

template <typename T>
void verify(celerity::distr_queue& queue, celerity::buffer<T, 2> mat_a_buf, celerity::experimental::host_object<bool> verification) {
queue.submit([=](celerity::handler& cgh) {
celerity::accessor result{mat_a_buf, cgh, celerity::access::one_to_one{}, celerity::read_only_host_task};
celerity::experimental::side_effect passed{verification, cgh};

cgh.host_task(mat_a_buf.get_range(), [=](celerity::partition<2> part) {
*passed = [&] {
auto sr = part.get_subrange();
for(size_t i = sr.offset[0]; i < sr.offset[0] + sr.range[0]; ++i) {
for(size_t j = sr.offset[0]; j < sr.offset[0] + sr.range[0]; ++j) {
const float received = result[{i, j}];
const float expected = i == j;
if(expected != received) {
fprintf(stderr, "VERIFICATION FAILED for element %zu,%zu: %f (received) != %f (expected)\n", i, j, received, expected);
return false;
}
}
}
printf("VERIFICATION PASSED!\n");
return true;
}();
});
});
bool verify(celerity::experimental::buffer_data<T, 2> mat) {
const auto range = mat.get_range();
bool verification_passed = true;
for(size_t i = 0; i < range[0]; ++i) {
for(size_t j = 0; j < range[1]; ++j) {
const float received = mat[i][j];
const float expected = i == j;
if(expected != received) {
CELERITY_ERROR("Verification failed for element {},{}: {} (received) != {} (expected)", i, j, received, expected);
verification_passed = false;
}
}
}
if(verification_passed) { CELERITY_INFO("Verification passed"); }
return verification_passed;
}

int main() {
Expand All @@ -85,9 +75,6 @@ int main() {
multiply(queue, mat_a_buf, mat_b_buf, mat_c_buf);
multiply(queue, mat_b_buf, mat_c_buf, mat_a_buf);

celerity::experimental::host_object<bool> verification;
verify(queue, mat_a_buf, verification);

const auto passed = queue.drain(celerity::experimental::capture{verification});
return passed ? EXIT_SUCCESS : EXIT_FAILURE;
auto mat_a_dump = queue.drain(celerity::experimental::capture{mat_a_buf});
return verify(mat_a_dump) ? EXIT_SUCCESS : EXIT_FAILURE;
}
28 changes: 17 additions & 11 deletions include/capture.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,25 +54,31 @@ class capture;
template <typename T, int Dims>
class buffer_data {
public:
buffer_data() : range{detail::zero_range} {}
buffer_data() : sr{{}, detail::zero_range} {}

explicit operator bool() const { return !data.empty(); }

celerity::range<Dims> get_range() const { return range; }
const T* get_pointer() const { return data.data(); }
T* get_pointer() { return data.data(); }
range<Dims> get_offset() const { return sr.offset; }

// TODO accessor semantics with operator[]; into_vector()
range<Dims> get_range() const { return sr.range; }

subrange<Dims> get_subrange() const { return sr; }

const std::vector<T>& get_data() const { return data; }

std::vector<T> into_data() && { return std::move(data); }

inline const T& operator[](id<Dims> index) const { return data[detail::get_linear_index(sr.range, index)]; }

inline detail::subscript_result_t<Dims, const buffer_data> operator[](size_t index) const { return detail::subscript<Dims>(*this, index); }

private:
friend class capture<buffer<T, Dims>>;

celerity::range<Dims> range;
subrange<Dims> sr;
std::vector<T> data;

explicit buffer_data(celerity::range<Dims> range, std::vector<T> data) : range{range}, data{std::move(data)} {
assert(this->data.size() == this->range.size());
}
explicit buffer_data(subrange<Dims> sr, std::vector<T> data) : sr{sr}, data{std::move(data)} { assert(this->data.size() == this->sr.range.size()); }
};

template <typename T, int Dims>
Expand All @@ -90,7 +96,7 @@ class capture<buffer<T, Dims>> {
subrange<Dims> sr;

void record_requirements(detail::buffer_capture_map& accesses, detail::side_effect_map&) const {
accesses.add_read_access(detail::get_buffer_id(buffer), sr);
accesses.add_read_access(detail::get_buffer_id(buffer), detail::subrange_cast<3>(sr));
}

value_type exfiltrate_by_copy() const {
Expand Down Expand Up @@ -120,7 +126,7 @@ class capture<buffer<T, Dims>> {
}
}

return value_type{allocation_window.get_window_range(), std::move(data)};
return value_type{sr, std::move(data)};
}

value_type exfiltrate_by_move() const { return exfiltrate_by_copy(); }
Expand Down
2 changes: 1 addition & 1 deletion test/accessor_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ namespace detail {
buffer<int, 1> buf_a(mem_a.data(), cl::sycl::range<1>{1});
q.submit([=](handler& cgh) {
auto a = buf_a.get_access<cl::sycl::access::mode::read_write, target::host_task>(cgh, fixed<1>({0, 1}));
cgh.host_task(on_master_node, [=] { ++a[{0}]; });
cgh.host_task(on_master_node, [=] { ++a[0]; });
});
int out = 0;
q.submit(celerity::allow_by_ref, [=, &out](handler& cgh) {
Expand Down
4 changes: 2 additions & 2 deletions test/system/distr_tests.cc
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ namespace detail {
q.slow_full_sync(std::tuple{experimental::capture{buf, subrange<3>{{1, 2, 3}, {1, 1, 1}}}, experimental::capture{obj}});

REQUIRE(gathered_from_master.get_range() == range<3>{1, 1, 1});
CHECK(gathered_from_master.get_pointer()[0] == 42);
CHECK(gathered_from_master[0][0][0] == 42);

int global_rank;
MPI_Comm_rank(MPI_COMM_WORLD, &global_rank);
Expand All @@ -344,7 +344,7 @@ namespace detail {
const auto drained_from_master = q.drain(experimental::capture{buf, subrange<3>{{1, 2, 3}, {1, 1, 1}}});

REQUIRE(drained_from_master.get_range() == range<3>{1, 1, 1});
CHECK(drained_from_master.get_pointer()[0] == 84);
CHECK(drained_from_master[0][0][0] == 84);
}

} // namespace detail
Expand Down

0 comments on commit ba98498

Please sign in to comment.