Skip to content

Commit

Permalink
[SYCL] Support the case when local accessor is a temporary object
Browse files Browse the repository at this point in the history
Signed-off-by: Artur Gainullin <artur.gainullin@intel.com>
  • Loading branch information
againull committed Jan 14, 2020
1 parent 8803f62 commit 1eed329
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 12 deletions.
2 changes: 2 additions & 0 deletions sycl/include/CL/sycl/detail/accessor_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ class LocalAccessorImplHost {
}
};

// Shared-ownership handle to a LocalAccessorImplHost: any holder of a copy
// keeps the implementation object alive.
using LocalAccessorImplPtr = std::shared_ptr<LocalAccessorImplHost>;

class LocalAccessorBaseHost {
public:
LocalAccessorBaseHost(sycl::range<3> Size, int Dims, int ElemSize) {
Expand Down
28 changes: 17 additions & 11 deletions sycl/include/CL/sycl/handler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,7 @@ class handler {
// we exit the method they are passed in.
std::vector<std::vector<char>> MArgsStorage;
std::vector<detail::AccessorImplPtr> MAccStorage;
std::vector<detail::LocalAccessorImplPtr> MLocalAccStorage;
std::vector<std::shared_ptr<detail::stream_impl>> MStreamStorage;
std::vector<std::shared_ptr<const void>> MSharedPtrStorage;
// The list of arguments for the kernel.
Expand Down Expand Up @@ -221,6 +222,10 @@ class handler {
detail::AccessorBaseHost *AccBase =
static_cast<detail::AccessorBaseHost *>(Ptr);
Ptr = detail::getSyclObjImpl(*AccBase).get();
} else if (AccTarget == access::target::local) {
detail::LocalAccessorBaseHost *LocalAccBase =
static_cast<detail::LocalAccessorBaseHost *>(Ptr);
Ptr = detail::getSyclObjImpl(*LocalAccBase).get();
}
}
processArg(Ptr, Kind, Size, I, IndexShift, IsKernelCreatedFromSource);
Expand Down Expand Up @@ -292,20 +297,17 @@ class handler {
break;
}
case access::target::local: {
detail::LocalAccessorBaseHost *LAcc =
static_cast<detail::LocalAccessorBaseHost *>(Ptr);
detail::LocalAccessorImplHost *LAcc =
static_cast<detail::LocalAccessorImplHost *>(Ptr);
// The stream implementation creates a local accessor whose size is given
// per work item in the work group. The number of work items is not known
// when the stream is constructed, so the accessor's size is updated here
// using the number of work items in the work group.
if (detail::getSyclObjImpl(*LAcc)->PerWI) {
auto LocalAccImpl = detail::getSyclObjImpl(*LAcc);
LocalAccImpl->resize(MNDRDesc.LocalSize.size(),
MNDRDesc.GlobalSize.size());
}
range<3> &Size = LAcc->getSize();
const int Dims = LAcc->getNumOfDims();
int SizeInBytes = LAcc->getElementSize();
if (LAcc->PerWI)
LAcc->resize(MNDRDesc.LocalSize.size(), MNDRDesc.GlobalSize.size());
range<3> &Size = LAcc->MSize;
const int Dims = LAcc->MDims;
int SizeInBytes = LAcc->MElemSize;
for (int I = 0; I < Dims; ++I)
SizeInBytes *= Size[I];
MArgs.emplace_back(kind_std_layout, nullptr, SizeInBytes,
Expand Down Expand Up @@ -480,7 +482,11 @@ class handler {
IsPlaceholder> &&Arg) {
detail::LocalAccessorBaseHost *LocalAccBase =
(detail::LocalAccessorBaseHost *)&Arg;
MArgs.emplace_back(detail::kernel_param_kind_t::kind_accessor, LocalAccBase,
detail::LocalAccessorImplPtr LocalAccImpl =
detail::getSyclObjImpl(*LocalAccBase);
detail::LocalAccessorImplHost *Req = LocalAccImpl.get();
MLocalAccStorage.push_back(std::move(LocalAccImpl));
MArgs.emplace_back(detail::kernel_param_kind_t::kind_accessor, Req,
static_cast<int>(access::target::local), ArgIndex);
}

Expand Down
23 changes: 22 additions & 1 deletion sycl/test/basic_tests/set_arg_interop.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ int main() {

cl_context ClContext = Context.get();

const size_t CountSources = 2;
const size_t CountSources = 3;
const char *Sources[CountSources] = {
"kernel void foo1(global float* Array, global int* Value) { *Array = "
"42; *Value = 1; }\n",
"kernel void foo2(global float* Array) { int id = get_global_id(0); "
"Array[id] = id; }\n",
"kernel void foo3(global float* Array, local float* LocalArray) { "
"(void)LocalArray; (void)Array; }\n",
};

cl_int Err;
Expand All @@ -38,11 +40,15 @@ int main() {
cl_kernel SecondCLKernel = clCreateKernel(ClProgram, "foo2", &Err);
assert(Err == CL_SUCCESS);

cl_kernel ThirdCLKernel = clCreateKernel(ClProgram, "foo3", &Err);
assert(Err == CL_SUCCESS);

const size_t Count = 100;
float Array[Count];

kernel FirstKernel(FirstCLKernel, Context);
kernel SecondKernel(SecondCLKernel, Context);
kernel ThirdKernel(ThirdCLKernel, Context);
int Value;
{
buffer<float, 1> FirstBuffer(Array, range<1>(1));
Expand Down Expand Up @@ -92,9 +98,24 @@ int main() {
}
}

{
buffer<float, 1> FirstBuffer(Array, range<1>(Count));
Queue.submit([&](handler &CGH) {
auto Acc = FirstBuffer.get_access<access::mode::read_write>(CGH);
CGH.set_arg(0, FirstBuffer.get_access<access::mode::read_write>(CGH));
CGH.set_arg(
1, cl::sycl::accessor<float, 1, cl::sycl::access::mode::read_write,
cl::sycl::access::target::local>(
cl::sycl::range<1>(Count), CGH));
CGH.parallel_for(range<1>{Count}, ThirdKernel);
});
}
Queue.wait_and_throw();

clReleaseContext(ClContext);
clReleaseKernel(FirstCLKernel);
clReleaseKernel(SecondCLKernel);
clReleaseKernel(ThirdCLKernel);
clReleaseProgram(ClProgram);
}
return 0;
Expand Down

0 comments on commit 1eed329

Please sign in to comment.