Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYCL] Add support for discard access modes #89

Merged
merged 3 commits into from
Apr 18, 2019
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions sycl/include/CL/sycl/detail/buffer_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ template <typename AllocatorT> class buffer_impl {

public:
void moveMemoryTo(QueueImplPtr Queue, std::vector<cl::sycl::event> DepEvents,
EventImplPtr Event);
EventImplPtr Event, cl::sycl::access::mode Mode);

void fill(QueueImplPtr Queue, std::vector<cl::sycl::event> DepEvents,
EventImplPtr Event, const void *Pattern, size_t PatternSize,
Expand All @@ -281,7 +281,7 @@ template <typename AllocatorT> class buffer_impl {
bool isValidAccessToMem(cl::sycl::access::mode AccessMode);

void allocate(QueueImplPtr Queue, std::vector<cl::sycl::event> DepEvents,
EventImplPtr Event, cl::sycl::access::mode mode);
EventImplPtr Event);

cl_mem getOpenCLMem() const;

Expand Down Expand Up @@ -409,7 +409,7 @@ void buffer_impl<AllocatorT>::copy(
template <typename AllocatorT>
void buffer_impl<AllocatorT>::moveMemoryTo(
QueueImplPtr Queue, std::vector<cl::sycl::event> DepEvents,
EventImplPtr Event) {
EventImplPtr Event, cl::sycl::access::mode Mode) {

ContextImplPtr Context = detail::getSyclObjImpl(Queue->get_context());

Expand All @@ -431,6 +431,10 @@ void buffer_impl<AllocatorT>::moveMemoryTo(

// Copy from OCL device to host device.
if (!OCLState.Queue->is_host() && Queue->is_host()) {
if (Mode == cl::sycl::access::mode::discard_write &&
Mode == cl::sycl::access::mode::discard_read_write)
return;

const size_t ByteSize = get_size();

std::vector<cl_event> CLEvents =
Expand Down Expand Up @@ -484,13 +488,18 @@ void buffer_impl<AllocatorT>::moveMemoryTo(

std::vector<cl_event> CLEvents =
detail::getOrWaitEvents(std::move(DepEvents), Context);
if (Mode == cl::sycl::access::mode::discard_write &&
Mode == cl::sycl::access::mode::discard_read_write)
return;

cl_event &WriteBufEvent = Event->getHandleRef();
// Enqueue copying from host to new OCL buffer.
Error =
clEnqueueWriteBuffer(OCLState.Queue->getHandleRef(), OCLState.Mem,
/*blocking_write=*/CL_FALSE, /*offset=*/0,
ByteSize, BufPtr, CLEvents.size(), CLEvents.data(),
&WriteBufEvent); // replace &WriteBufEvent to NULL
ByteSize, BufPtr, CLEvents.size(),
CLEvents.data(), /*replace &WriteBufEvent
to NULL*/ &WriteBufEvent);
CHECK_OCL_CODE(Error);
Event->setContextImpl(Context);

Expand All @@ -507,8 +516,10 @@ buffer_impl<AllocatorT>::convertSycl2OCLMode(cl::sycl::access::mode mode) {
case cl::sycl::access::mode::read:
return CL_MEM_READ_ONLY;
case cl::sycl::access::mode::write:
case cl::sycl::access::mode::discard_write:
return CL_MEM_WRITE_ONLY;
case cl::sycl::access::mode::read_write:
case cl::sycl::access::mode::discard_read_write:
case cl::sycl::access::mode::atomic:
return CL_MEM_READ_WRITE;
default:
Expand All @@ -534,8 +545,7 @@ bool buffer_impl<AllocatorT>::isValidAccessToMem(
template <typename AllocatorT>
void buffer_impl<AllocatorT>::allocate(QueueImplPtr Queue,
std::vector<cl::sycl::event> DepEvents,
EventImplPtr Event,
cl::sycl::access::mode mode) {
EventImplPtr Event) {

detail::waitEvents(DepEvents);

Expand All @@ -556,7 +566,7 @@ void buffer_impl<AllocatorT>::allocate(QueueImplPtr Queue,
cl_int Error;

cl_mem Mem =
clCreateBuffer(Context->getHandleRef(), convertSycl2OCLMode(mode),
clCreateBuffer(Context->getHandleRef(), CL_MEM_READ_WRITE,
ByteSize, nullptr, &Error);
CHECK_OCL_CODE(Error);

Expand Down
6 changes: 3 additions & 3 deletions sycl/include/CL/sycl/detail/scheduler/requirements.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,15 @@ class BufferStorage : public BufferRequirement {
void allocate(QueueImplPtr Queue, std::vector<cl::sycl::event> DepEvents,
EventImplPtr Event) override {
assert(m_Buffer != nullptr && "BufferStorage::m_Buffer is nullptr");
m_Buffer->allocate(std::move(Queue), std::move(DepEvents), std::move(Event),
Mode);
m_Buffer->allocate(std::move(Queue), std::move(DepEvents),
std::move(Event));
}

void moveMemoryTo(QueueImplPtr Queue, std::vector<cl::sycl::event> DepEvents,
EventImplPtr Event) override {
assert(m_Buffer != nullptr && "BufferStorage::m_Buffer is nullptr");
m_Buffer->moveMemoryTo(std::move(Queue), std::move(DepEvents),
std::move(Event));
std::move(Event), Mode);
}

void fill(QueueImplPtr Queue, std::vector<cl::sycl::event> DepEvents,
Expand Down
46 changes: 46 additions & 0 deletions sycl/test/basic_tests/accessor/accessor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -174,4 +174,50 @@ int main() {
}
}
}

// Discard write accessor.
{
try {
sycl::queue Queue;
sycl::buffer<int, 1> buf(sycl::range<1>(3));

Queue.submit([&](sycl::handler& cgh) {
auto dev_acc = buf.get_access<sycl::access::mode::discard_write>(cgh);

cgh.parallel_for<class test_discard_write>(
sycl::range<1>{3},
[=](sycl::id<1> index) { dev_acc[index] = 42; });
});

auto host_acc = buf.get_access<sycl::access::mode::read>();
for (int i = 0; i != 3; ++i)
assert(host_acc[i] == 42);

} catch (cl::sycl::exception e) {
std::cout << "SYCL exception caught: " << e.what();
return 1;
}
}

// Discard read-write accessor.
{
try {
sycl::queue Queue;
sycl::buffer<int, 1> buf(sycl::range<1>(3));

Queue.submit([&](sycl::handler& cgh) {
auto dev_acc = buf.get_access<sycl::access::mode::write>(cgh);

cgh.parallel_for<class test_discard_read_write>(
sycl::range<1>{3},
[=](sycl::id<1> index) { dev_acc[index] = 42; });
});

auto host_acc =
buf.get_access<sycl::access::mode::discard_read_write>();
} catch (cl::sycl::exception e) {
std::cout << "SYCL exception caught: " << e.what();
return 1;
}
}
}