diff --git a/clang/lib/Sema/SemaSYCL.cpp b/clang/lib/Sema/SemaSYCL.cpp index 6ea0d5f6e2827..7593311262149 100644 --- a/clang/lib/Sema/SemaSYCL.cpp +++ b/clang/lib/Sema/SemaSYCL.cpp @@ -240,10 +240,8 @@ void BuildArgTys(ASTContext &Context, for (auto V : ArgDecls) { QualType ArgTy = V->getType(); QualType ActualArgType = ArgTy; - StringRef Name = ArgTy.getBaseTypeIdentifier()->getName(); - // TODO: harden this check with additional validation that this class is - // declared in cl::sycl namespace - if (std::string(Name) == "accessor") { + std::string Name = ArgTy.getCanonicalType().getAsString(); + if (Name.find("class cl::sycl::accessor") != std::string::npos) { if (const auto *RecordDecl = ArgTy->getAsCXXRecordDecl()) { const auto *TemplateDecl = dyn_cast(RecordDecl); @@ -251,22 +249,23 @@ void BuildArgTys(ASTContext &Context, // First parameter - data type QualType PointeeType = TemplateDecl->getTemplateArgs()[0].getAsType(); // Fourth parameter - access target - auto AccessQualifier = TemplateDecl->getTemplateArgs()[3].getAsIntegral(); + auto AccessQualifier = + TemplateDecl->getTemplateArgs()[3].getAsIntegral(); int64_t AccessTarget = AccessQualifier.getExtValue(); Qualifiers Quals = PointeeType.getQualifiers(); // TODO: Support all access targets switch (AccessTarget) { - case target::global_buffer: + case target::global_buffer: Quals.setAddressSpace(LangAS::opencl_global); - break; - case target::constant_buffer: + break; + case target::constant_buffer: Quals.setAddressSpace(LangAS::opencl_constant); - break; - case target::local: + break; + case target::local: Quals.setAddressSpace(LangAS::opencl_local); - break; - default: - llvm_unreachable("Unsupported access target"); + break; + default: + llvm_unreachable("Unsupported access target"); } // TODO: get address space from accessor template parameter. PointeeType = diff --git a/clang/test/SemaSYCL/built-in-type-kernel-arg.cpp b/clang/test/SemaSYCL/built-in-type-kernel-arg.cpp new file mode 100644 index 0000000000000..39c4a0e946f77 --- /dev/null +++ b/clang/test/SemaSYCL/built-in-type-kernel-arg.cpp @@ -0,0 +1,19 @@ +// RUN: %clang -S --sycl -Xclang -ast-dump %s | FileCheck %s +// XFAIL: * +#include + +int main() { + int data = 5; + cl::sycl::queue deviceQueue; + cl::sycl::buffer bufferA(&data, cl::sycl::range<1>(1)); + + deviceQueue.submit([&](cl::sycl::handler &cgh) { + auto accessorA = bufferA.template get_access(cgh); + cgh.single_task( + [=]() { + accessorA[0] += data; + }); + }); + return 0; +} +// CHECK: kernel_function 'void (__global int *__global, int) diff --git a/clang/test/SemaSYCL/fake-accessors.cpp b/clang/test/SemaSYCL/fake-accessors.cpp new file mode 100644 index 0000000000000..bbc6a62bc4c41 --- /dev/null +++ b/clang/test/SemaSYCL/fake-accessors.cpp @@ -0,0 +1,56 @@ +// RUN: %clang -S --sycl -Xclang -ast-dump %s | FileCheck %s +// XFAIL: * +#include + +namespace foo { +namespace cl { +namespace sycl { +class accessor { +public: + int field; +}; +} // namespace sycl +} // namespace cl +} // namespace foo + +class accessor { +public: + int field; +}; + +typedef cl::sycl::accessor + MyAccessorTD; + +using MyAccessorA = cl::sycl::accessor; + +int main() { + int data = 5; + cl::sycl::queue deviceQueue; + cl::sycl::buffer bufferA(&data, cl::sycl::range<1>(1)); + foo::cl::sycl::accessor acc = {1}; + accessor acc1 = {1}; + + deviceQueue.submit([&](cl::sycl::handler &cgh) { + auto accessorA = bufferA.template get_access(cgh); + MyAccessorTD accessorB = bufferA.template get_access(cgh); + MyAccessorA accessorC = bufferA.template get_access(cgh); + cgh.single_task( + [=]() { + accessorA[0] = acc.field + acc1.field; + }); + cgh.single_task( + [=]() { + accessorB[0] = acc.field + acc1.field; + }); + cgh.single_task( + [=]() { + accessorC[0] = acc.field + acc1.field; + }); + }); + return 0; +} +// CHECK: fake_accessors 'void (__global int *__global, foo::cl::sycl::accessor, accessor) +// CHECK: accessor_typedef 'void (__global int *__global, foo::cl::sycl::accessor, accessor) +// CHECK: accessor_alias 'void (__global int *__global, foo::cl::sycl::accessor, accessor)