Skip to content

Commit

Permalink
[SYCL] Language restrictions for SYCL kernel functions from 6.3 section
Browse files Browse the repository at this point in the history
 - disallow allocation in kernel functions (Overloaded 'new' operations are allowed if no storage is allocated)
 - disallow recursion in kernel functions

Signed-off-by: Blower, Melanie <melanie.blower@intel.com>
Signed-off-by: Vladimir Lazarev <vladimir.lazarev@intel.com>
  • Loading branch information
Blower, Melanie authored and vladimirlaz committed Feb 12, 2019
1 parent 6a70b70 commit 4efe9fc
Show file tree
Hide file tree
Showing 7 changed files with 357 additions and 24 deletions.
2 changes: 2 additions & 0 deletions clang/include/clang/Basic/DiagnosticSemaKinds.td
Original file line number Diff line number Diff line change
Expand Up @@ -9532,13 +9532,15 @@ def err_sycl_restrict : Error<
"|use rtti"
"|use a non-const static data variable"
"|call a virtual function"
"|call a recursive function"
"|call through a function pointer"
"|allocate storage"
"|use exceptions"
"|use inline assembly}0">;
def err_sycl_virtual_types : Error<
"No class with a vtable can be used in a SYCL kernel or any code included in the kernel">;
def note_sycl_used_here : Note<"used here">;
def note_sycl_recursive_function_declared_here: Note<"function implemented using recursion declared here">;
def err_sycl_non_std_layout_type : Error<
"kernel parameter has non-standard layout class/struct type">;
} // end of sema component.
90 changes: 70 additions & 20 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/raw_ostream.h"
#include "clang/Analysis/CallGraph.h"

#include <array>

Expand All @@ -45,6 +46,7 @@ enum RestrictKind {
KernelRTTI,
KernelNonConstStaticDataVariable,
KernelCallVirtualFunction,
KernelCallRecursiveFunction,
KernelCallFunctionPointer,
KernelAllocateStorage,
KernelUseExceptions,
Expand Down Expand Up @@ -85,20 +87,25 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {

bool VisitCallExpr(CallExpr *e) {
for (const auto &Arg : e->arguments())
CheckTypeForVirtual(Arg->getType(), Arg->getSourceRange());
CheckSYCLType(Arg->getType(), Arg->getSourceRange());

if (FunctionDecl *Callee = e->getDirectCallee()) {
Callee = Callee->getCanonicalDecl();
// Remember that all SYCL kernel functions have deferred
// instantiation as template functions. It means that
// all functions used by kernel have already been parsed and have
// definitions.
llvm::SmallPtrSet<FunctionDecl *, 10> VisitedSet;
if (IsRecursive(Callee, Callee, VisitedSet))
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict) <<
KernelCallRecursiveFunction;

if (const CXXMethodDecl *Method = dyn_cast<CXXMethodDecl>(Callee))
if (Method->isVirtual())
SemaRef.Diag(e->getExprLoc(), diag::err_sycl_restrict) <<
KernelCallVirtualFunction;

CheckTypeForVirtual(Callee->getReturnType(), Callee->getSourceRange());
CheckSYCLType(Callee->getReturnType(), Callee->getSourceRange());

if (FunctionDecl *Def = Callee->getDefinition()) {
if (!Def->hasAttr<SYCLDeviceAttr>()) {
Expand All @@ -116,7 +123,7 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {

bool VisitCXXConstructExpr(CXXConstructExpr *E) {
for (const auto &Arg : E->arguments())
CheckTypeForVirtual(Arg->getType(), Arg->getSourceRange());
CheckSYCLType(Arg->getType(), Arg->getSourceRange());

CXXConstructorDecl *Ctor = E->getConstructor();

Expand Down Expand Up @@ -150,22 +157,22 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
}

bool VisitTypedefNameDecl(TypedefNameDecl *TD) {
CheckTypeForVirtual(TD->getUnderlyingType(), TD->getLocation());
CheckSYCLType(TD->getUnderlyingType(), TD->getLocation());
return true;
}

bool VisitRecordDecl(RecordDecl *RD) {
CheckTypeForVirtual(QualType{RD->getTypeForDecl(), 0}, RD->getLocation());
CheckSYCLType(QualType{RD->getTypeForDecl(), 0}, RD->getLocation());
return true;
}

bool VisitParmVarDecl(VarDecl *VD) {
CheckTypeForVirtual(VD->getType(), VD->getLocation());
CheckSYCLType(VD->getType(), VD->getLocation());
return true;
}

bool VisitVarDecl(VarDecl *VD) {
CheckTypeForVirtual(VD->getType(), VD->getLocation());
CheckSYCLType(VD->getType(), VD->getLocation());
return true;
}

Expand All @@ -180,7 +187,7 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
}

bool VisitDeclRefExpr(DeclRefExpr *E) {
CheckTypeForVirtual(E->getType(), E->getSourceRange());
CheckSYCLType(E->getType(), E->getSourceRange());
if (VarDecl *VD = dyn_cast<VarDecl>(E->getDecl())) {
bool IsConst = VD->getType().getNonReferenceType().isConstQualified();
if (!IsConst && VD->hasGlobalStorage() && !VD->isStaticLocal() &&
Expand All @@ -199,12 +206,17 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
// storage are disallowed in a SYCL kernel. The placement
// new operator and any user-defined overloads that
// do not allocate storage are permitted.
const FunctionDecl *FD = E->getOperatorNew();
if (FD && !FD->isReservedGlobalPlacementOperator()) {
OverloadedOperatorKind Kind = FD->getOverloadedOperator();
if (Kind == OO_New || Kind == OO_Array_New)
SemaRef.Diag(E->getExprLoc(), diag::err_sycl_restrict) <<
KernelAllocateStorage;
if (FunctionDecl *FD = E->getOperatorNew()) {
if (FD->isReplaceableGlobalAllocationFunction()) {
SemaRef.Diag(E->getExprLoc(), diag::err_sycl_restrict) <<
KernelAllocateStorage;
} else if (FunctionDecl *Def = FD->getDefinition()) {
if (!Def->hasAttr<SYCLDeviceAttr>()) {
Def->addAttr(SYCLDeviceAttr::CreateImplicit(SemaRef.Context));
this->TraverseStmt(Def->getBody());
SemaRef.AddSyclKernel(Def);
}
}
}
return true;
}
Expand Down Expand Up @@ -245,8 +257,42 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
return true;
}

// The call graph for this translation unit.
CallGraph SYCLCG;
private:
bool CheckTypeForVirtual(QualType Ty, SourceRange Loc) {
// Determines whether the function FD is recursive.
// CalleeNode is a function which is called either directly
// or indirectly from FD. If recursion is detected then create
// diagnostic notes on each function as the callstack is unwound.
bool IsRecursive(FunctionDecl *CalleeNode, FunctionDecl *FD,
llvm::SmallPtrSet<FunctionDecl *, 10> VisitedSet) {
// We're currently checking CalleeNode on a different
// trace through the CallGraph, we avoid infinite recursion
// by using VisitedSet to keep track of this.
if (!VisitedSet.insert(CalleeNode).second)
return false;
if (CallGraphNode *N = SYCLCG.getNode(CalleeNode)) {
for (const CallGraphNode *CI : *N) {
if (FunctionDecl *Callee = dyn_cast<FunctionDecl>(CI->getDecl())) {
Callee = Callee->getCanonicalDecl();
if (Callee == FD)
return SemaRef.Diag(FD->getSourceRange().getBegin(),
diag::note_sycl_recursive_function_declared_here)
<< KernelCallRecursiveFunction;
else if (IsRecursive(Callee, FD, VisitedSet))
return true;
}
}
}
return false;
}

bool CheckSYCLType(QualType Ty, SourceRange Loc) {
if (Ty->isVariableArrayType()) {
SemaRef.Diag(Loc.getBegin(), diag::err_vla_unsupported);
return false;
}

while (Ty->isAnyPointerType() || Ty->isArrayType())
Ty = QualType{Ty->getPointeeOrArrayElementType(), 0};

Expand All @@ -264,25 +310,25 @@ class MarkDeviceFunction : public RecursiveASTVisitor<MarkDeviceFunction> {
}

for (const auto &Field : CRD->fields()) {
if (!CheckTypeForVirtual(Field->getType(), Field->getSourceRange())) {
if (!CheckSYCLType(Field->getType(), Field->getSourceRange())) {
SemaRef.Diag(Loc.getBegin(), diag::note_sycl_used_here);
return false;
}
}
} else if (const auto *RD = Ty->getAsRecordDecl()) {
for (const auto &Field : RD->fields()) {
if (!CheckTypeForVirtual(Field->getType(), Field->getSourceRange())) {
if (!CheckSYCLType(Field->getType(), Field->getSourceRange())) {
SemaRef.Diag(Loc.getBegin(), diag::note_sycl_used_here);
return false;
}
}
} else if (const auto *FPTy = dyn_cast<FunctionProtoType>(Ty)) {
for (const auto &ParamTy : FPTy->param_types())
if (!CheckTypeForVirtual(ParamTy, Loc))
if (!CheckSYCLType(ParamTy, Loc))
return false;
return CheckTypeForVirtual(FPTy->getReturnType(), Loc);
return CheckSYCLType(FPTy->getReturnType(), Loc);
} else if (const auto *FTy = dyn_cast<FunctionType>(Ty)) {
return CheckTypeForVirtual(FTy->getReturnType(), Loc);
return CheckSYCLType(FTy->getReturnType(), Loc);
}
return true;
}
Expand Down Expand Up @@ -726,6 +772,10 @@ void Sema::ConstructSYCLKernel(FunctionDecl *KernelCallerFunc) {
AddSyclKernel(SYCLKernel);
// Let's mark all called functions with SYCL Device attribute.
MarkDeviceFunction Marker(*this);
// Create the call graph so we can detect recursion and check the validity
// of new operator overrides. Add the kernel function itself in case
// it is recursive.
Marker.SYCLCG.addToCallGraph(getASTContext().getTranslationUnitDecl());
Marker.TraverseStmt(SYCLKernelBody);
}

Expand Down
107 changes: 107 additions & 0 deletions clang/test/SemaSYCL/restrict-recursion.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// RUN: %clang_cc1 -fcxx-exceptions -fsycl-is-device -Wno-return-type -verify -fsyntax-only -x c++ -emit-llvm-only -std=c++17 %s

// This recursive function is not called from sycl kernel,
// so it should not be diagnosed.
int fib(int n)
{
if (n <= 1)
return n;
return fib(n-1) + fib(n-2);
}

typedef struct S {
template <typename T>
// expected-note@+1 2{{function implemented using recursion declared here}}
T factT(T i, T j)
{
// expected-error@+1 1{{SYCL kernel cannot call a recursive function}}
return factT(j,i);
}

int fact(unsigned i)
{
if (i==0) return 1;
// expected-error@+1 1{{SYCL kernel cannot call a recursive function}}
else return factT<unsigned>(i-1, i);
}
} S_type;


// expected-note@+1 2{{function implemented using recursion declared here}}
int fact(unsigned i);
// expected-note@+1 2{{function implemented using recursion declared here}}
int fact1(unsigned i)
{
if (i==0) return 1;
// expected-error@+1 {{SYCL kernel cannot call a recursive function}}
else return fact(i-1) * i;
}
int fact(unsigned i)
{
if (i==0) return 1;
// expected-error@+1 {{SYCL kernel cannot call a recursive function}}
else return fact1(i-1) * i;
}

bool isa_B(void) {
S_type s;

unsigned f = s.fact(3);
// expected-error@+1 1{{SYCL kernel cannot call a recursive function}}
unsigned f1 = s.factT<unsigned>(3,4);
// expected-error@+1 {{SYCL kernel cannot call a recursive function}}
unsigned g = fact(3);
// expected-error@+1 {{SYCL kernel cannot call a recursive function}}
unsigned g1 = fact1(3);
return 0;
}

__attribute__((sycl_kernel)) void kernel1(void) {
isa_B();
}
// expected-note@+1 2{{function implemented using recursion declared here}}
__attribute__((sycl_kernel)) void kernel2(void) {
// expected-error@+1 {{SYCL kernel cannot call a recursive function}}
kernel2();
}
__attribute__((sycl_kernel)) void kernel3(void) {
;
}

using myFuncDef = int(int,int);

void usage( myFuncDef functionPtr ) {
kernel1();
}
void usage2( myFuncDef functionPtr ) {
// expected-error@+1 {{SYCL kernel cannot call a recursive function}}
kernel2();
}
void usage3( myFuncDef functionPtr ) {
kernel3();
}

int addInt(int n, int m) {
return n+m;
}

template <typename name, typename Func>
__attribute__((sycl_kernel)) void kernel_single_task(Func kernelFunc) {
kernelFunc();
}

template <typename name, typename Func>
// expected-note@+1 2{{function implemented using recursion declared here}}
__attribute__((sycl_kernel)) void kernel_single_task2(Func kernelFunc) {
kernelFunc();
// expected-error@+1 2{{SYCL kernel cannot call a recursive function}}
kernel_single_task2<name, Func>(kernelFunc);
}

int main() {
kernel_single_task<class fake_kernel>([]() { usage( &addInt ); });
kernel_single_task<class fake_kernel>([]() { usage2( &addInt ); });
kernel_single_task2<class fake_kernel>([]() { usage3( &addInt ); });
return fib(5);
}

Loading

0 comments on commit 4efe9fc

Please sign in to comment.