From ba98498cc127284ba944ea3572503a03cc891a02 Mon Sep 17 00:00:00 2001 From: Fabian Knorr Date: Wed, 16 Mar 2022 19:58:17 +0100 Subject: [PATCH] Improve captures buffer_data API, verify matmul result in main thread --- examples/matmul/matmul.cc | 47 ++++++++++++++------------------------ include/capture.h | 28 ++++++++++++++--------- test/accessor_tests.cc | 2 +- test/system/distr_tests.cc | 4 ++-- 4 files changed, 37 insertions(+), 44 deletions(-) diff --git a/examples/matmul/matmul.cc b/examples/matmul/matmul.cc index 0f8b73343..f3b0c0dc6 100644 --- a/examples/matmul/matmul.cc +++ b/examples/matmul/matmul.cc @@ -1,5 +1,3 @@ -#include - #include const size_t MAT_SIZE = 1024; @@ -46,29 +44,21 @@ void multiply(celerity::distr_queue& queue, celerity::buffer mat_a, celeri } template -void verify(celerity::distr_queue& queue, celerity::buffer mat_a_buf, celerity::experimental::host_object verification) { - queue.submit([=](celerity::handler& cgh) { - celerity::accessor result{mat_a_buf, cgh, celerity::access::one_to_one{}, celerity::read_only_host_task}; - celerity::experimental::side_effect passed{verification, cgh}; - - cgh.host_task(mat_a_buf.get_range(), [=](celerity::partition<2> part) { - *passed = [&] { - auto sr = part.get_subrange(); - for(size_t i = sr.offset[0]; i < sr.offset[0] + sr.range[0]; ++i) { - for(size_t j = sr.offset[0]; j < sr.offset[0] + sr.range[0]; ++j) { - const float received = result[{i, j}]; - const float expected = i == j; - if(expected != received) { - fprintf(stderr, "VERIFICATION FAILED for element %zu,%zu: %f (received) != %f (expected)\n", i, j, received, expected); - return false; - } - } - } - printf("VERIFICATION PASSED!\n"); - return true; - }(); - }); - }); +bool verify(celerity::experimental::buffer_data mat) { + const auto range = mat.get_range(); + bool verification_passed = true; + for(size_t i = 0; i < range[0]; ++i) { + for(size_t j = 0; j < range[1]; ++j) { + const float received = mat[i][j]; + const float expected = i == j; + if(expected != received) { + CELERITY_ERROR("Verification failed for element {},{}: {} (received) != {} (expected)", i, j, received, expected); + verification_passed = false; + } + } + } + if(verification_passed) { CELERITY_INFO("Verification passed"); } + return verification_passed; } int main() { @@ -85,9 +75,6 @@ int main() { multiply(queue, mat_a_buf, mat_b_buf, mat_c_buf); multiply(queue, mat_b_buf, mat_c_buf, mat_a_buf); - celerity::experimental::host_object verification; - verify(queue, mat_a_buf, verification); - - const auto passed = queue.drain(celerity::experimental::capture{verification}); - return passed ? EXIT_SUCCESS : EXIT_FAILURE; + auto mat_a_dump = queue.drain(celerity::experimental::capture{mat_a_buf}); + return verify(mat_a_dump) ? EXIT_SUCCESS : EXIT_FAILURE; } diff --git a/include/capture.h b/include/capture.h index 3d4197cd7..a7f76d41e 100644 --- a/include/capture.h +++ b/include/capture.h @@ -54,25 +54,31 @@ class capture; template class buffer_data { public: - buffer_data() : range{detail::zero_range} {} + buffer_data() : sr{{}, detail::zero_range} {} explicit operator bool() const { return !data.empty(); } - celerity::range get_range() const { return range; } - const T* get_pointer() const { return data.data(); } - T* get_pointer() { return data.data(); } + range get_offset() const { return sr.offset; } - // TODO accessor semantics with operator[]; into_vector() + range get_range() const { return sr.range; } + + subrange get_subrange() const { return sr; } + + const std::vector& get_data() const { return data; } + + std::vector into_data() && { return std::move(data); } + + inline const T& operator[](id index) const { return data[detail::get_linear_index(sr.range, index)]; } + + inline detail::subscript_result_t operator[](size_t index) const { return detail::subscript(*this, index); } private: friend class capture>; - celerity::range range; + subrange sr; std::vector data; - explicit buffer_data(celerity::range range, std::vector data) : range{range}, data{std::move(data)} { - assert(this->data.size() == this->range.size()); - } + explicit buffer_data(subrange sr, std::vector data) : sr{sr}, data{std::move(data)} { assert(this->data.size() == this->sr.range.size()); } }; template @@ -90,7 +96,7 @@ class capture> { subrange sr; void record_requirements(detail::buffer_capture_map& accesses, detail::side_effect_map&) const { - accesses.add_read_access(detail::get_buffer_id(buffer), sr); + accesses.add_read_access(detail::get_buffer_id(buffer), detail::subrange_cast<3>(sr)); } value_type exfiltrate_by_copy() const { @@ -120,7 +126,7 @@ class capture> { } } - return value_type{allocation_window.get_window_range(), std::move(data)}; + return value_type{sr, std::move(data)}; } value_type exfiltrate_by_move() const { return exfiltrate_by_copy(); } diff --git a/test/accessor_tests.cc b/test/accessor_tests.cc index c64cb4aef..f01b64e95 100644 --- a/test/accessor_tests.cc +++ b/test/accessor_tests.cc @@ -56,7 +56,7 @@ namespace detail { buffer buf_a(mem_a.data(), cl::sycl::range<1>{1}); q.submit([=](handler& cgh) { auto a = buf_a.get_access(cgh, fixed<1>({0, 1})); - cgh.host_task(on_master_node, [=] { ++a[{0}]; }); + cgh.host_task(on_master_node, [=] { ++a[0]; }); }); int out = 0; q.submit(celerity::allow_by_ref, [=, &out](handler& cgh) { diff --git a/test/system/distr_tests.cc b/test/system/distr_tests.cc index 347f9a38d..8eff8b449 100644 --- a/test/system/distr_tests.cc +++ b/test/system/distr_tests.cc @@ -330,7 +330,7 @@ namespace detail { q.slow_full_sync(std::tuple{experimental::capture{buf, subrange<3>{{1, 2, 3}, {1, 1, 1}}}, experimental::capture{obj}}); REQUIRE(gathered_from_master.get_range() == range<3>{1, 1, 1}); - CHECK(gathered_from_master.get_pointer()[0] == 42); + CHECK(gathered_from_master[0][0][0] == 42); int global_rank; MPI_Comm_rank(MPI_COMM_WORLD, &global_rank); @@ -344,7 +344,7 @@ namespace detail { const auto drained_from_master = q.drain(experimental::capture{buf, subrange<3>{{1, 2, 3}, {1, 1, 1}}}); REQUIRE(drained_from_master.get_range() == range<3>{1, 1, 1}); - CHECK(drained_from_master.get_pointer()[0] == 84); + CHECK(drained_from_master[0][0][0] == 84); } } // namespace detail