diff --git a/sycl/include/CL/sycl/detail/accessor_impl.hpp b/sycl/include/CL/sycl/detail/accessor_impl.hpp index b561e123baef..14baf3338e23 100644 --- a/sycl/include/CL/sycl/detail/accessor_impl.hpp +++ b/sycl/include/CL/sycl/detail/accessor_impl.hpp @@ -160,6 +160,8 @@ class LocalAccessorImplHost { } }; +using LocalAccessorImplPtr = std::shared_ptr; + class LocalAccessorBaseHost { public: LocalAccessorBaseHost(sycl::range<3> Size, int Dims, int ElemSize) { diff --git a/sycl/include/CL/sycl/handler.hpp b/sycl/include/CL/sycl/handler.hpp index 6c054a6c3f38..ddd8dd8b88ad 100644 --- a/sycl/include/CL/sycl/handler.hpp +++ b/sycl/include/CL/sycl/handler.hpp @@ -143,6 +143,7 @@ class handler { // we exit the method they are passed in. std::vector> MArgsStorage; std::vector MAccStorage; + std::vector MLocalAccStorage; std::vector> MStreamStorage; std::vector> MSharedPtrStorage; // The list of arguments for the kernel. @@ -221,6 +222,10 @@ class handler { detail::AccessorBaseHost *AccBase = static_cast(Ptr); Ptr = detail::getSyclObjImpl(*AccBase).get(); + } else if (AccTarget == access::target::local) { + detail::LocalAccessorBaseHost *LocalAccBase = + static_cast(Ptr); + Ptr = detail::getSyclObjImpl(*LocalAccBase).get(); } } processArg(Ptr, Kind, Size, I, IndexShift, IsKernelCreatedFromSource); @@ -292,20 +297,17 @@ class handler { break; } case access::target::local: { - detail::LocalAccessorBaseHost *LAcc = - static_cast(Ptr); + detail::LocalAccessorImplHost *LAcc = + static_cast(Ptr); // Stream implementation creates local accessor with size per work item // in work group. Number of work items is not available during stream // construction, that is why size of the accessor is updated here using // information about number of work items in the work group. - if (detail::getSyclObjImpl(*LAcc)->PerWI) { - auto LocalAccImpl = detail::getSyclObjImpl(*LAcc); - LocalAccImpl->resize(MNDRDesc.LocalSize.size(), - MNDRDesc.GlobalSize.size()); - } - range<3> &Size = LAcc->getSize(); - const int Dims = LAcc->getNumOfDims(); - int SizeInBytes = LAcc->getElementSize(); + if (LAcc->PerWI) + LAcc->resize(MNDRDesc.LocalSize.size(), MNDRDesc.GlobalSize.size()); + range<3> &Size = LAcc->MSize; + const int Dims = LAcc->MDims; + int SizeInBytes = LAcc->MElemSize; for (int I = 0; I < Dims; ++I) SizeInBytes *= Size[I]; MArgs.emplace_back(kind_std_layout, nullptr, SizeInBytes, @@ -480,7 +482,11 @@ class handler { IsPlaceholder> &&Arg) { detail::LocalAccessorBaseHost *LocalAccBase = (detail::LocalAccessorBaseHost *)&Arg; - MArgs.emplace_back(detail::kernel_param_kind_t::kind_accessor, LocalAccBase, + detail::LocalAccessorImplPtr LocalAccImpl = + detail::getSyclObjImpl(*LocalAccBase); + detail::LocalAccessorImplHost *Req = LocalAccImpl.get(); + MLocalAccStorage.push_back(std::move(LocalAccImpl)); + MArgs.emplace_back(detail::kernel_param_kind_t::kind_accessor, Req, static_cast(access::target::local), ArgIndex); } diff --git a/sycl/test/basic_tests/set_arg_interop.cpp b/sycl/test/basic_tests/set_arg_interop.cpp index 5e97675be342..3e8362720e54 100644 --- a/sycl/test/basic_tests/set_arg_interop.cpp +++ b/sycl/test/basic_tests/set_arg_interop.cpp @@ -16,12 +16,14 @@ int main() { cl_context ClContext = Context.get(); - const size_t CountSources = 2; + const size_t CountSources = 3; const char *Sources[CountSources] = { "kernel void foo1(global float* Array, global int* Value) { *Array = " "42; *Value = 1; }\n", "kernel void foo2(global float* Array) { int id = get_global_id(0); " "Array[id] = id; }\n", + "kernel void foo3(global float* Array, local float* LocalArray) { " + "(void)LocalArray; (void)Array; }\n", }; cl_int Err; @@ -38,11 +40,15 @@ int main() { cl_kernel SecondCLKernel = clCreateKernel(ClProgram, "foo2", &Err); assert(Err == CL_SUCCESS); + cl_kernel ThirdCLKernel = clCreateKernel(ClProgram, "foo3", &Err); + assert(Err == CL_SUCCESS); + const size_t Count = 100; float Array[Count]; kernel FirstKernel(FirstCLKernel, Context); kernel SecondKernel(SecondCLKernel, Context); + kernel ThirdKernel(ThirdCLKernel, Context); int Value; { buffer FirstBuffer(Array, range<1>(1)); @@ -92,9 +98,24 @@ int main() { } } + { + buffer FirstBuffer(Array, range<1>(Count)); + Queue.submit([&](handler &CGH) { + auto Acc = FirstBuffer.get_access(CGH); + CGH.set_arg(0, FirstBuffer.get_access(CGH)); + CGH.set_arg( + 1, cl::sycl::accessor( + cl::sycl::range<1>(Count), CGH)); + CGH.parallel_for(range<1>{Count}, ThirdKernel); + }); + } + Queue.wait_and_throw(); + clReleaseContext(ClContext); clReleaseKernel(FirstCLKernel); clReleaseKernel(SecondCLKernel); + clReleaseKernel(ThirdCLKernel); clReleaseProgram(ClProgram); } return 0;