Skip to content

Commit

Permalink
[SYCL] Populate the body of OutlinedFunctionDecl AST nodes inserted f…
Browse files Browse the repository at this point in the history
…or sycl_kernel_entry_point attributed functions.

The function body associated with a OutlinedFunctionDecl AST node is a clone
of the original body of the sycl_kernel_entry_point attributed function
modified to substitute references to the original function parameters with
references to replacement variables that stand in for the parameters of
the SYCL kernel caller function that will be emitted during code generation.
This change implements the necessary transforms.
  • Loading branch information
tahonermann committed Jun 17, 2024
1 parent 680076d commit 59c5aa2
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 34 deletions.
77 changes: 54 additions & 23 deletions clang/lib/Sema/SemaSYCL.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
// This implements Semantic Analysis for SYCL constructs.
//===----------------------------------------------------------------------===//

#include "TreeTransform.h"
#include "clang/Sema/SemaSYCL.h"
#include "clang/AST/Mangle.h"
#include "clang/AST/StmtSYCL.h"
Expand Down Expand Up @@ -208,6 +209,51 @@ void SemaSYCL::handleKernelEntryPointAttr(Decl *D, const ParsedAttr &AL) {
AL, TSI));
}

namespace {

// The body of a function declared with the [[sycl_kernel_entry_point]]
// attribute is cloned and transformed to substitute references to the original
// function parameters with references to replacement variables that stand in
// for SYCL kernel parameters or local variables that reconstitute a decomposed
// SYCL kernel argument.
class OutlinedFunctionDeclBodyInstantiator
: public TreeTransform<OutlinedFunctionDeclBodyInstantiator> {
public:
using ParmDeclMap = llvm::DenseMap<ParmVarDecl*, VarDecl*>;

OutlinedFunctionDeclBodyInstantiator(Sema &S, ParmDeclMap &M)
: TreeTransform<OutlinedFunctionDeclBodyInstantiator>(S),
SemaRef(S), MapRef(M) {}

// A new set of AST nodes is always required.
bool AlwaysRebuild() {
return true;
}

// Transform ParmVarDecl references to the supplied replacement variables.
ExprResult TransformDeclRefExpr(DeclRefExpr *DRE) {
const ParmVarDecl *PVD = dyn_cast<ParmVarDecl>(DRE->getDecl());
if (PVD) {
ParmDeclMap::iterator I = MapRef.find(PVD);
if (I != MapRef.end()) {
VarDecl *VD = I->second;
VD->setIsUsed();
return DeclRefExpr::Create(
SemaRef.getASTContext(), DRE->getQualifierLoc(),
DRE->getTemplateKeywordLoc(), VD, false, DRE->getNameInfo(),
VD->getType(), DRE->getValueKind());
}
}
return DRE;
}

private:
Sema &SemaRef;
ParmDeclMap &MapRef;
};

} // unnamed namespace

StmtResult SemaSYCL::BuildSYCLKernelCallStmt(FunctionDecl *FD, Stmt *Body) {
// FIXME: Issue proper diagnostics for all of these scenarios.
if (auto *MD = dyn_cast<CXXMethodDecl>(FD))
Expand All @@ -222,40 +268,25 @@ StmtResult SemaSYCL::BuildSYCLKernelCallStmt(FunctionDecl *FD, Stmt *Body) {
assert(!FD->isNoReturn());
assert(FD->getReturnType()->isVoidType());

using ParmDeclMap = OutlinedFunctionDeclBodyInstantiator::ParmDeclMap;
ParmDeclMap ParmMap;

assert(SemaRef.CurContext == FD);
OutlinedFunctionDecl *OFD =
OutlinedFunctionDecl::Create(getASTContext(), FD, FD->getNumParams());
unsigned i = 0;
for (const auto &p : FD->parameters()) {
for (ParmVarDecl *PVD : FD->parameters()) {
ImplicitParamDecl *IPD =
ImplicitParamDecl::Create(getASTContext(), OFD, SourceLocation(),
p->getIdentifier(), p->getType(),
PVD->getIdentifier(), PVD->getType(),
ImplicitParamKind::Other);
OFD->setParam(i, IPD);
ParmMap[PVD] = IPD;
++i;
}

// FIXME: For short-term testing purposes, a sequence of statements that
// FIXME: references each of the implicit parameter declarations is
// FIXME: generated.
SmallVector<Stmt *, 8> OFDBodyStmts;
i = 0;
for (const auto &p : FD->parameters()) {
QualType QT = p->getType().getNonReferenceType();
DeclRefExpr *DRE = new (getASTContext()) DeclRefExpr(getASTContext(),
OFD->getParam(i),
/* RefersToEnclosingVariableOrCapture */ false,
QT,
VK_LValue,
SourceLocation());
assert(DRE);
OFDBodyStmts.push_back(DRE);
++i;
}

Stmt *OFDBody =
CompoundStmt::Create(getASTContext(), OFDBodyStmts, FPOptionsOverride(),
SourceLocation(), SourceLocation());
OutlinedFunctionDeclBodyInstantiator OFDBodyInstantiator(SemaRef, ParmMap);
Stmt *OFDBody = OFDBodyInstantiator.TransformStmt(Body).get();
OFD->setBody(OFDBody);
OFD->setNothrow();
Stmt *NewBody = new (getASTContext()) SYCLKernelCallStmt(Body, OFD);
Expand Down
80 changes: 69 additions & 11 deletions clang/test/AST/ast-dump-sycl-kernel-call-stmt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ template<int> struct KN;
// A unique invocable type for use with each declared kernel entry point.
template<int> struct K {
template<typename... Ts>
void operator()(Ts...) const;
void operator()(Ts...) const {}
};


Expand Down Expand Up @@ -69,9 +69,13 @@ void skep2<KN<2>>(K<2>);
// CHECK-NEXT: | | | `-ImplicitCastExpr {{.*}} 'const K<2>' lvalue <NoOp>
// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'K<2>' lvalue ParmVar {{.*}} 'k' 'K<2>'
// CHECK-NEXT: | | `-OutlinedFunctionDecl {{.*}}
// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit k 'K<2>'
// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used k 'K<2>'
// CHECK-NEXT: | | `-CompoundStmt {{.*}}
// CHECK-NEXT: | | `-DeclRefExpr {{.*}} 'K<2>' lvalue ImplicitParam {{.*}} 'k' 'K<2>'
// CHECK-NEXT: | | `-CXXOperatorCallExpr {{.*}} 'void' '()'
// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'void (*)() const' <FunctionToPointerDecay>
// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'void () const' lvalue CXXMethod {{.*}} 'operator()' 'void () const'
// CHECK-NEXT: | | `-ImplicitCastExpr {{.*}} 'const K<2>' lvalue <NoOp>
// CHECK-NEXT: | | `-DeclRefExpr {{.*}} 'K<2>' lvalue ImplicitParam {{.*}} 'k' 'K<2>'
// CHECK-NEXT: | `-SYCLKernelEntryPointAttr {{.*}} KN<2>

template<typename KNT, typename KT>
Expand Down Expand Up @@ -111,9 +115,13 @@ void skep3<KN<3>>(K<3> k) {
// CHECK-NEXT: | | | `-ImplicitCastExpr {{.*}} 'const K<3>' lvalue <NoOp>
// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'K<3>' lvalue ParmVar {{.*}} 'k' 'K<3>'
// CHECK-NEXT: | | `-OutlinedFunctionDecl {{.*}}
// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit k 'K<3>'
// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used k 'K<3>'
// CHECK-NEXT: | | `-CompoundStmt {{.*}}
// CHECK-NEXT: | | `-DeclRefExpr {{.*}} 'K<3>' lvalue ImplicitParam {{.*}} 'k' 'K<3>'
// CHECK-NEXT: | | `-CXXOperatorCallExpr {{.*}} 'void' '()'
// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'void (*)() const' <FunctionToPointerDecay>
// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'void () const' lvalue CXXMethod {{.*}} 'operator()' 'void () const'
// CHECK-NEXT: | | `-ImplicitCastExpr {{.*}} 'const K<3>' lvalue <NoOp>
// CHECK-NEXT: | | `-DeclRefExpr {{.*}} 'K<3>' lvalue ImplicitParam {{.*}} 'k' 'K<3>'
// CHECK-NEXT: | `-SYCLKernelEntryPointAttr {{.*}} KN<3>

[[clang::sycl_kernel_entry_point(KN<4>)]]
Expand All @@ -136,14 +144,64 @@ void skep4(K<4> k, int p1, int p2) {
// CHECK-NEXT: | | | `-ImplicitCastExpr {{.*}} 'int' <LValueToRValue>
// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'int' lvalue ParmVar {{.*}} 'p2' 'int'
// CHECK-NEXT: | | `-OutlinedFunctionDecl {{.*}}
// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit k 'K<4>'
// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit p1 'int'
// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit p2 'int'
// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used k 'K<4>'
// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used p1 'int'
// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used p2 'int'
// CHECK-NEXT: | | `-CompoundStmt {{.*}}
// CHECK-NEXT: | | |-DeclRefExpr {{.*}} 'K<4>' lvalue ImplicitParam {{.*}} 'k' 'K<4>'
// CHECK-NEXT: | | |-DeclRefExpr {{.*}} 'int' lvalue ImplicitParam {{.*}} 'p1' 'int'
// CHECK-NEXT: | | `-DeclRefExpr {{.*}} 'int' lvalue ImplicitParam {{.*}} 'p2' 'int'
// CHECK-NEXT: | | `-CXXOperatorCallExpr {{.*}} 'void' '()'
// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'void (*)(int, int) const' <FunctionToPointerDecay>
// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'void (int, int) const' lvalue CXXMethod {{.*}} 'operator()' 'void (int, int) const'
// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'const K<4>' lvalue <NoOp>
// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'K<4>' lvalue ImplicitParam {{.*}} 'k' 'K<4>'
// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'int' <LValueToRValue>
// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'int' lvalue ImplicitParam {{.*}} 'p1' 'int'
// CHECK-NEXT: | | `-ImplicitCastExpr {{.*}} 'int' <LValueToRValue>
// CHECK-NEXT: | | `-DeclRefExpr {{.*}} 'int' lvalue ImplicitParam {{.*}} 'p2' 'int'
// CHECK-NEXT: | `-SYCLKernelEntryPointAttr {{.*}} KN<4>

[[clang::sycl_kernel_entry_point(KN<5>)]]
void skep5(int unused1, K<5> k, int unused2, int p, int unused3) {
static int slv = 0;
int lv = 4;
k(slv, 1, p, 3, lv, 5, []{ return 6; });
}
// CHECK: |-FunctionDecl {{.*}} skep5 'void (int, K<5>, int, int, int)'
// CHECK-NEXT: | |-ParmVarDecl {{.*}} unused1 'int'
// CHECK-NEXT: | |-ParmVarDecl {{.*}} used k 'K<5>'
// CHECK-NEXT: | |-ParmVarDecl {{.*}} unused2 'int'
// CHECK-NEXT: | |-ParmVarDecl {{.*}} used p 'int'
// CHECK-NEXT: | |-ParmVarDecl {{.*}} unused3 'int'
// CHECK-NEXT: | |-SYCLKernelCallStmt {{.*}}
// CHECK-NEXT: | | |-CompoundStmt {{.*}}
// CHECK: | | `-OutlinedFunctionDecl {{.*}}
// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit unused1 'int'
// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used k 'K<5>'
// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit unused2 'int'
// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit used p 'int'
// CHECK-NEXT: | | |-ImplicitParamDecl {{.*}} implicit unused3 'int'
// CHECK-NEXT: | | `-CompoundStmt {{.*}}
// CHECK-NEXT: | | |-DeclStmt {{.*}}
// CHECK-NEXT: | | | `-VarDecl {{.*}} used slv 'int' static cinit
// CHECK-NEXT: | | | `-IntegerLiteral {{.*}} 'int' 0
// CHECK-NEXT: | | |-DeclStmt {{.*}}
// CHECK-NEXT: | | | `-VarDecl {{.*}} used lv 'int' cinit
// CHECK-NEXT: | | | `-IntegerLiteral {{.*}} 'int' 4
// CHECK-NEXT: | | `-CXXOperatorCallExpr {{.*}} 'void' '()'
// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'void (*)(int, int, int, int, int, int, (lambda {{.*}}) const' <FunctionToPointerDecay>
// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'void (int, int, int, int, int, int, (lambda {{.*}})) const' lvalue CXXMethod {{.*}} 'operator()' 'void (int, int, int, int, int, int, (lambda {{.*}})) const'
// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'const K<5>' lvalue <NoOp>
// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'K<5>' lvalue ImplicitParam {{.*}} 'k' 'K<5>'
// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'int' <LValueToRValue>
// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'int' lvalue Var {{.*}} 'slv' 'int'
// CHECK-NEXT: | | |-IntegerLiteral {{.*}} 'int' 1
// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'int' <LValueToRValue>
// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'int' lvalue ImplicitParam {{.*}} 'p' 'int'
// CHECK-NEXT: | | |-IntegerLiteral {{.*}} 'int' 3
// CHECK-NEXT: | | |-ImplicitCastExpr {{.*}} 'int' <LValueToRValue>
// CHECK-NEXT: | | | `-DeclRefExpr {{.*}} 'int' lvalue Var {{.*}} 'lv' 'int'
// CHECK-NEXT: | | |-IntegerLiteral {{.*}} 'int' 5
// CHECK-NEXT: | | `-LambdaExpr {{.*}} '(lambda {{.*}})'
// CHECK: | `-SYCLKernelEntryPointAttr {{.*}} KN<5>

void the_end() {}
// CHECK: `-FunctionDecl {{.*}} the_end 'void ()'

0 comments on commit 59c5aa2

Please sign in to comment.