From 474c9343ad41d607d7f41314769d0a2081dbe751 Mon Sep 17 00:00:00 2001 From: Mariya Podchishchaeva Date: Fri, 29 Mar 2019 18:25:41 +0300 Subject: [PATCH] [SYCL] Implement buffer::set_write_back Signed-off-by: Mariya Podchishchaeva --- sycl/include/CL/sycl/buffer.hpp | 3 +- sycl/include/CL/sycl/detail/buffer_impl.hpp | 5 +- sycl/test/basic_tests/buffer/buffer.cpp | 51 +++++++++++++++++++++ 3 files changed, 56 insertions(+), 3 deletions(-) diff --git a/sycl/include/CL/sycl/buffer.hpp b/sycl/include/CL/sycl/buffer.hpp index 91782ca2e5c51..9a43504d5de72 100644 --- a/sycl/include/CL/sycl/buffer.hpp +++ b/sycl/include/CL/sycl/buffer.hpp @@ -175,8 +175,7 @@ class buffer { impl->set_final_data(finalData); } - // void set_write_back(bool flag = true) { return impl->set_write_back(flag); - // } + void set_write_back(bool flag = true) { return impl->set_write_back(flag); } // bool is_sub_buffer() const { return impl->is_sub_buffer(); } diff --git a/sycl/include/CL/sycl/detail/buffer_impl.hpp b/sycl/include/CL/sycl/detail/buffer_impl.hpp index ac2f844686684..321bacb2f5aae 100644 --- a/sycl/include/CL/sycl/detail/buffer_impl.hpp +++ b/sycl/include/CL/sycl/detail/buffer_impl.hpp @@ -145,7 +145,7 @@ template class buffer_impl { .copyBack( *this); - if (uploadData != nullptr) { + if (uploadData != nullptr && NeedWriteBack) { uploadData(); } @@ -187,6 +187,8 @@ template class buffer_impl { }; } + void set_write_back(bool flag) { NeedWriteBack = flag; } + AllocatorT get_allocator() const { return MAllocator; } template class buffer_impl { AllocatorT MAllocator; OpenCLMemState OCLState; bool OpenCLInterop = false; + bool NeedWriteBack = true; event AvailableEvent; cl_context OpenCLContext = nullptr; void *BufPtr = nullptr; diff --git a/sycl/test/basic_tests/buffer/buffer.cpp b/sycl/test/basic_tests/buffer/buffer.cpp index 02450b91abbfc..574d938fee1d5 100644 --- a/sycl/test/basic_tests/buffer/buffer.cpp +++ b/sycl/test/basic_tests/buffer/buffer.cpp @@ -19,6 +19,7 @@ using namespace cl::sycl; int main() { int data = 5; + bool failed = false; buffer buf(&data, range<1>(1)); { int data1[10] = {-1, -1, -1, -1, -1, -1, -1, -1, -1, -1}; @@ -453,5 +454,55 @@ int main() { for (int i = 5; i < 10; i++) assert(data1[i] == -1); } + + // Check that data is copied back after forcing write-back using + // set_write_back + { + std::vector data1(10, -1); + { + buffer b(range<1>(10)); + b.set_final_data(data1.data()); + b.set_write_back(true); + queue myQueue; + myQueue.submit([&](handler &cgh) { + auto B = b.get_access(cgh); + cgh.parallel_for(range<1>{10}, + [=](id<1> index) { B[index] = 0; }); + }); + + } + // Data is copied back because there is a user side ptr and write-back is + // enabled + for (int i = 0; i < 10; i++) + if (data1[i] != 0) { + assert(false); + failed = true; + } + } + + // Check that data is not copied back after canceling write-back using + // set_write_back + { + std::vector data1(10, -1); + { + buffer b(range<1>(10)); + b.set_final_data(data1.data()); + b.set_write_back(false); + queue myQueue; + myQueue.submit([&](handler &cgh) { + auto B = b.get_access(cgh); + cgh.parallel_for(range<1>{10}, + [=](id<1> index) { B[index] = 0; }); + }); + + } + // Data is not copied back because write-back is canceled + for (int i = 0; i < 10; i++) + if (data1[i] != -1) { + assert(false); + failed = true; + } + } // TODO tests with mutex property + return failed; }