diff --git a/demos/ComputerGraphics/smallpt/SmallPT.cpp b/demos/ComputerGraphics/smallpt/SmallPT.cpp index 54c21c62a..468bca694 100644 --- a/demos/ComputerGraphics/smallpt/SmallPT.cpp +++ b/demos/ComputerGraphics/smallpt/SmallPT.cpp @@ -16,7 +16,8 @@ // ./SmallPT 500 && xv image.ppm // A typical invocation would be: -// ../../../../../obj/Debug+Asserts/bin/clang++ -O3 -Xclang -add-plugin -Xclang clad \ +// ../../../../../obj/Debug+Asserts/bin/clang++ -O3 -Xclang -add-plugin -Xclang +// clad \ // -Xclang -load -Xclang ../../../../../obj/Debug+Asserts/lib/libclad.dylib \ // -I../../include/ -std=c++11 SmallPT.cpp -fopenmp=libiomp5 -o SmallPT // ./SmallPT 500 && xv image.ppm diff --git a/demos/OpenCL/RosenbrockFunction.cpp b/demos/OpenCL/RosenbrockFunction.cpp index efa97c133..9b7ef154d 100644 --- a/demos/OpenCL/RosenbrockFunction.cpp +++ b/demos/OpenCL/RosenbrockFunction.cpp @@ -8,11 +8,13 @@ // To run the demo please type: // path/to/clang++ -Xclang -add-plugin -Xclang clad -Xclang -load -Xclang \ -// path/to/libclad.so -I../include/ -framework opencl -std=c++11 RosenbrockFunction.cpp +// path/to/libclad.so -I../include/ -framework opencl -std=c++11 +// RosenbrockFunction.cpp // // A typical invocation would be: -// ../../../../../obj/Debug+Asserts/bin/clang++ -Xclang -add-plugin -Xclang clad \ -// -Xclang -load -Xclang ../../../../../obj/Debug+Asserts/lib/libclad.dylib \ +// ../../../../../obj/Debug+Asserts/bin/clang++ -Xclang -add-plugin -Xclang +// clad \ +// -Xclang -load -Xclang ../../../../../obj/Debug+Asserts/lib/libclad.dylib \ // -I../../include/ -framework opencl -std=c++11 RosenbrockFunction.cpp // Necessary for clad to work include diff --git a/include/clad/Differentiator/Compatibility.h b/include/clad/Differentiator/Compatibility.h index fdc76a9b4..fc43f6554 100644 --- a/include/clad/Differentiator/Compatibility.h +++ b/include/clad/Differentiator/Compatibility.h @@ -75,15 +75,15 @@ static inline NamespaceDecl* NamespaceDecl_Create(ASTContext& C, DeclContext* DC, bool Inline, SourceLocation StartLoc, SourceLocation IdLoc, IdentifierInfo* Id, NamespaceDecl* PrevDecl) { - return NamespaceDecl::Create(C, DC, Inline, StartLoc, IdLoc, Id, PrevDecl); + return NamespaceDecl::Create(C, DC, Inline, StartLoc, IdLoc, Id, PrevDecl); } #else static inline NamespaceDecl* NamespaceDecl_Create(ASTContext& C, DeclContext* DC, bool Inline, SourceLocation StartLoc, SourceLocation IdLoc, IdentifierInfo* Id, NamespaceDecl* PrevDecl) { - return NamespaceDecl::Create(C, DC, Inline, StartLoc, IdLoc, Id, PrevDecl, - /*Nested=*/false); + return NamespaceDecl::Create(C, DC, Inline, StartLoc, IdLoc, Id, PrevDecl, + /*Nested=*/false); } #endif @@ -249,17 +249,19 @@ static inline void ExprSetDeps(Expr* result, Expr* Node) { #define CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsPar ,clang::CallExpr::ADLCallKind UsesADL #define CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsUse ,UsesADL #define CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsOverride FPOptionsOverride - #if CLANG_VERSION_MAJOR >= 16 - #define CLAD_COMPAT_CLANG11_LangOptions_EtraParams /**/ - #else - #define CLAD_COMPAT_CLANG11_LangOptions_EtraParams Ctx.getLangOpts() - #endif - #define CLAD_COMPAT_CLANG11_Ctx_ExtraParams Ctx, - #define CLAD_COMPAT_CREATE11(CLASS, CTORARGS) (CLASS::Create CTORARGS) - #define CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Removed /**/ - #define CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Moved ,Node->getComputationLHSType(),Node->getComputationResultType() - #define CLAD_COMPAT_CLANG11_ChooseExpr_EtraParams_Removed /**/ - #define CLAD_COMPAT_CLANG11_WhileStmt_ExtraParams ,Node->getLParenLoc(),Node->getRParenLoc() +#if CLANG_VERSION_MAJOR >= 16 +#define CLAD_COMPAT_CLANG11_LangOptions_EtraParams /**/ +#else +#define CLAD_COMPAT_CLANG11_LangOptions_EtraParams Ctx.getLangOpts() +#endif +#define CLAD_COMPAT_CLANG11_Ctx_ExtraParams Ctx, +#define CLAD_COMPAT_CREATE11(CLASS, CTORARGS) (CLASS::Create CTORARGS) +#define CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Removed /**/ +#define CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Moved \ + , Node->getComputationLHSType(), Node->getComputationResultType() +#define CLAD_COMPAT_CLANG11_ChooseExpr_EtraParams_Removed /**/ +#define CLAD_COMPAT_CLANG11_WhileStmt_ExtraParams \ + , Node->getLParenLoc(), Node->getRParenLoc() #endif // Compatibility helper function for creation CXXOperatorCallExpr. Clang 8 and above use Create. @@ -725,7 +727,7 @@ ArraySize_GetValue(const llvm::Optional& opt) { #else static inline const Expr* ArraySize_GetValue(const std::optional& opt) { - return opt.value(); + return opt.value(); } #endif diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index 9a5d7e3d6..f75692803 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -21,11 +21,10 @@ #include namespace clad { - template - struct ValueAndAdjoint { - T value; - U adjoint; - }; +template struct ValueAndAdjoint { + T value; + U adjoint; +}; /// \returns the size of a c-style string CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { @@ -305,9 +304,9 @@ namespace clad { differentiate(F fn, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(fn && "Must pass in a non-0 argument"); - return CladFunction>(derivedFn, - code); + assert(fn && "Must pass in a non-0 argument"); + return CladFunction>(derivedFn, + code); } /// Specialization for differentiating functors. @@ -325,8 +324,8 @@ namespace clad { differentiate(F&& f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - return CladFunction>(derivedFn, - code, f); + return CladFunction>(derivedFn, + code, f); } /// Generates function which computes derivative of `fn` argument w.r.t @@ -348,9 +347,9 @@ namespace clad { differentiate(F fn, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(fn && "Must pass in a non-0 argument"); - return CladFunction, true>( - derivedFn, code); + assert(fn && "Must pass in a non-0 argument"); + return CladFunction, true>( + derivedFn, code); } /// Generates function which computes gradient of the given function wrt the @@ -370,9 +369,9 @@ namespace clad { gradient(F f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - assert(f && "Must pass in a non-0 argument"); - return CladFunction, true>( - derivedFn /* will be replaced by gradient*/, code); + assert(f && "Must pass in a non-0 argument"); + return CladFunction, true>( + derivedFn /* will be replaced by gradient*/, code); } /// Specialization for differentiating functors. @@ -388,8 +387,8 @@ namespace clad { gradient(F&& f, ArgSpec args = "", DerivedFnType derivedFn = static_cast(nullptr), const char* code = "") { - return CladFunction, true>( - derivedFn /* will be replaced by gradient*/, code, f); + return CladFunction, true>( + derivedFn /* will be replaced by gradient*/, code, f); } /// Generates function which computes hessian matrix of the given function wrt diff --git a/include/clad/Differentiator/ReverseModeForwPassVisitor.h b/include/clad/Differentiator/ReverseModeForwPassVisitor.h index 32edf3e6d..7070a73f4 100644 --- a/include/clad/Differentiator/ReverseModeForwPassVisitor.h +++ b/include/clad/Differentiator/ReverseModeForwPassVisitor.h @@ -24,7 +24,7 @@ class ReverseModeForwPassVisitor : public ReverseModeVisitor { public: ReverseModeForwPassVisitor(DerivativeBuilder& builder); DerivativeAndOverload Derive(const clang::FunctionDecl* FD, - const DiffRequest& request); + const DiffRequest& request); StmtDiff ProcessSingleStmt(const clang::Stmt* S); diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 2b05c5bd5..b0855af49 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -30,7 +30,7 @@ namespace clad { class ReverseModeVisitor : public clang::ConstStmtVisitor, public VisitorBase { - + protected: // FIXME: We should remove friend-dependency of the plugin classes here. // For this we will need to separate out AST related functions in diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index a67f6ed44..1eb32f525 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -21,9 +21,8 @@ #include "clad/Differentiator/ForwardModeVisitor.h" #include "clad/Differentiator/HessianModeVisitor.h" #include "clad/Differentiator/JacobianModeVisitor.h" -#include "clad/Differentiator/ReverseModeVisitor.h" #include "clad/Differentiator/ReverseModeForwPassVisitor.h" -#include "clad/Differentiator/DiffPlanner.h" +#include "clad/Differentiator/ReverseModeVisitor.h" #include "clad/Differentiator/StmtClone.h" #include "clad/Differentiator/VectorForwardModeVisitor.h" diff --git a/lib/Differentiator/DiffPlanner.cpp b/lib/Differentiator/DiffPlanner.cpp index 804671dd5..2097fb676 100644 --- a/lib/Differentiator/DiffPlanner.cpp +++ b/lib/Differentiator/DiffPlanner.cpp @@ -390,20 +390,20 @@ namespace clad { // The string is not a range just a single index size_t index; if (firstStr.getAsInteger(Radix, index)) { - utils::EmitDiag(semaRef, DiagnosticsEngine::Error, - diffArgs->getEndLoc(), - "Could not parse index '%0'", {diffSpec}); - return; + utils::EmitDiag(semaRef, DiagnosticsEngine::Error, + diffArgs->getEndLoc(), + "Could not parse index '%0'", {diffSpec}); + return; } dVarInfo.paramIndexInterval = IndexInterval(index); } else { size_t first, last; if (firstStr.getAsInteger(Radix, first) || lastStr.getAsInteger(Radix, last)) { - utils::EmitDiag(semaRef, DiagnosticsEngine::Error, - diffArgs->getEndLoc(), - "Could not parse range '%0'", {diffSpec}); - return; + utils::EmitDiag(semaRef, DiagnosticsEngine::Error, + diffArgs->getEndLoc(), + "Could not parse range '%0'", {diffSpec}); + return; } if (first >= last) { utils::EmitDiag(semaRef, DiagnosticsEngine::Error, diff --git a/lib/Differentiator/HessianModeVisitor.cpp b/lib/Differentiator/HessianModeVisitor.cpp index a4e5820f0..330dc4191 100644 --- a/lib/Differentiator/HessianModeVisitor.cpp +++ b/lib/Differentiator/HessianModeVisitor.cpp @@ -380,4 +380,4 @@ namespace clad { return DerivativeAndOverload{result.first, /*OverloadFunctionDecl=*/nullptr}; } -} // end namespace clad + } // end namespace clad diff --git a/lib/Differentiator/ReverseModeForwPassVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp index 87eb965e3..8089ffc88 100644 --- a/lib/Differentiator/ReverseModeForwPassVisitor.cpp +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -34,7 +34,7 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, auto paramTypes = ComputeParamTypes(args); auto returnType = ComputeReturnType(); - const auto *sourceFnType = dyn_cast(m_Function->getType()); + const auto* sourceFnType = dyn_cast(m_Function->getType()); auto fnType = m_Context.getFunctionType(returnType, paramTypes, sourceFnType->getExtProtoInfo()); @@ -67,7 +67,7 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, for (Stmt* S : ReverseModeVisitor::m_Globals) addToCurrentBlock(S); - if (auto *CS = dyn_cast(forward)) + if (auto* CS = dyn_cast(forward)) for (Stmt* S : CS->body()) addToCurrentBlock(S); @@ -93,20 +93,20 @@ ReverseModeForwPassVisitor::GetParameterDerivativeType(QualType yType, QualType nonRefXValueType = xValueType.getNonReferenceType(); if (nonRefXValueType->isRealType()) return GetCladArrayRefOfType(yType); - return GetCladArrayRefOfType(nonRefXValueType); + return GetCladArrayRefOfType(nonRefXValueType); } llvm::SmallVector ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) { llvm::SmallVector paramTypes; paramTypes.reserve(m_Function->getNumParams() * 2); - for (auto *PVD : m_Function->parameters()) + for (auto* PVD : m_Function->parameters()) paramTypes.push_back(PVD->getType()); QualType effectiveReturnType = m_Function->getReturnType().getNonReferenceType(); - if (const auto *MD = dyn_cast(m_Function)) { + if (const auto* MD = dyn_cast(m_Function)) { const CXXRecordDecl* RD = MD->getParent(); if (MD->isInstance() && !RD->isLambda()) { QualType thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD); @@ -115,8 +115,9 @@ ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) { } } - for (auto *PVD : m_Function->parameters()) { - const auto *it = std::find(std::begin(diffParams), std::end(diffParams), PVD); + for (auto* PVD : m_Function->parameters()) { + const auto* it = + std::find(std::begin(diffParams), std::end(diffParams), PVD); if (it != std::end(diffParams)) { paramTypes.push_back( GetParameterDerivativeType(effectiveReturnType, PVD->getType())); @@ -126,7 +127,8 @@ ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) { } clang::QualType ReverseModeForwPassVisitor::ComputeReturnType() { - auto *valAndAdjointTempDecl = LookupTemplateDeclInCladNamespace("ValueAndAdjoint"); + auto* valAndAdjointTempDecl = + LookupTemplateDeclInCladNamespace("ValueAndAdjoint"); auto RT = m_Function->getReturnType(); auto T = InstantiateTemplate(valAndAdjointTempDecl, {RT, RT}); return T; @@ -137,14 +139,15 @@ ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) { llvm::SmallVector params; llvm::SmallVector paramDerivatives; params.reserve(m_Function->getNumParams() + diffParams.size()); - const auto *derivativeFnType = cast(m_Derivative->getType()); + const auto* derivativeFnType = + cast(m_Derivative->getType()); std::size_t dParamTypesIdx = m_Function->getNumParams(); - if (const auto *MD = dyn_cast(m_Function)) { + if (const auto* MD = dyn_cast(m_Function)) { const CXXRecordDecl* RD = MD->getParent(); if (MD->isInstance() && !RD->isLambda()) { - auto *thisDerivativePVD = utils::BuildParmVarDecl( + auto* thisDerivativePVD = utils::BuildParmVarDecl( m_Sema, m_Derivative, CreateUniqueIdentifier("_d_this"), derivativeFnType->getParamType(dParamTypesIdx)); paramDerivatives.push_back(thisDerivativePVD); @@ -159,10 +162,10 @@ ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) { ++dParamTypesIdx; } } - for (auto *PVD : m_Function->parameters()) { + for (auto* PVD : m_Function->parameters()) { // FIXME: Call expression may contain default arguments that we are now // removing. This may cause issues. - auto *newPVD = utils::BuildParmVarDecl( + auto* newPVD = utils::BuildParmVarDecl( m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(), PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo()); params.push_back(newPVD); @@ -171,14 +174,14 @@ ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) { m_Sema.PushOnScopeChains(newPVD, getCurrentScope(), /*AddToContext=*/false); - auto *it = std::find(std::begin(diffParams), std::end(diffParams), PVD); + auto* it = std::find(std::begin(diffParams), std::end(diffParams), PVD); if (it != std::end(diffParams)) { *it = newPVD; QualType dType = derivativeFnType->getParamType(dParamTypesIdx); IdentifierInfo* dII = CreateUniqueIdentifier("_d_" + PVD->getNameAsString()); - auto *dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType, - PVD->getStorageClass()); + auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType, + PVD->getStorageClass()); paramDerivatives.push_back(dPVD); ++dParamTypesIdx; @@ -230,7 +233,7 @@ StmtDiff ReverseModeForwPassVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { DeclRefExpr* clonedDRE = nullptr; // Check if referenced Decl was "replaced" with another identifier inside // the derivative - if (const auto *VD = dyn_cast(DRE->getDecl())) { + if (const auto* VD = dyn_cast(DRE->getDecl())) { auto it = m_DeclReplacements.find(VD); if (it != std::end(m_DeclReplacements)) clonedDRE = BuildDeclRef(it->second); @@ -242,13 +245,13 @@ StmtDiff ReverseModeForwPassVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { // Sema::BuildDeclRefExpr is responsible for adding captured fields // to the underlying struct of a lambda. if (clonedDRE->getDecl()->getDeclContext() != m_Sema.CurContext) { - auto *referencedDecl = cast(clonedDRE->getDecl()); + auto* referencedDecl = cast(clonedDRE->getDecl()); clonedDRE = cast(BuildDeclRef(referencedDecl)); } } else clonedDRE = cast(Clone(DRE)); - if (auto *decl = dyn_cast(clonedDRE->getDecl())) { + if (auto* decl = dyn_cast(clonedDRE->getDecl())) { // Check DeclRefExpr is a reference to an independent variable. auto it = m_Variables.find(decl); if (it == std::end(m_Variables)) { diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index a5463fb93..a2a95bdd6 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1877,18 +1877,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, "Clad failed to generate callee function forward pass function"); // FIXME: We are using the derivatives in forward pass here - // If `expr_dx()` is only meant to be used in reverse pass, + // If `expr_dx()` is only meant to be used in reverse pass, // (for example, `clad::pop(...)` expression and a corresponding // `clad::push(...)` in the forward pass), then this can result in // incorrect derivative or crash at runtime. Ideally, we should have - // a separate routine to use derivative in the forward pass. + // a separate routine to use derivative in the forward pass. // We cannot reuse the derivatives previously computed because // they might contain 'clad::pop(..)` expression. if (isa(CE)) { Expr* derivedBase = baseDiff.getExpr_dx(); - // FIXME: We may need this if-block once we support pointers, and passing pointers-by-reference - // if (isCladArrayType(derivedBase->getType())) + // FIXME: We may need this if-block once we support pointers, and + // passing pointers-by-reference if + // (isCladArrayType(derivedBase->getType())) // CallArgs.push_back(derivedBase); // else CallArgs.push_back( @@ -1899,10 +1900,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, const Expr* arg = CE->getArg(i); const ParmVarDecl* PVD = FD->getParamDecl(i); StmtDiff argDiff = Visit(arg); - if ((argDiff.getExpr_dx() != nullptr) && PVD->getType()->isReferenceType()) { + if ((argDiff.getExpr_dx() != nullptr) && + PVD->getType()->isReferenceType()) { Expr* derivedArg = argDiff.getExpr_dx(); - // FIXME: We may need this if-block once we support pointers, and passing pointers-by-reference - // if (isCladArrayType(derivedArg->getType())) + // FIXME: We may need this if-block once we support pointers, and + // passing pointers-by-reference if + // (isCladArrayType(derivedArg->getType())) // CallArgs.push_back(derivedArg); // else CallArgs.push_back( @@ -1912,31 +1915,28 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } if (isa(CE)) { Expr* baseE = baseDiff.getExpr(); - call = BuildCallExprToMemFn( - baseE, calleeFnForwPassFD->getName(), CallArgs, calleeFnForwPassFD); + call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(), + CallArgs, calleeFnForwPassFD); } else { call = m_Sema - .ActOnCallExpr(getCurrentScope(), + .ActOnCallExpr(getCurrentScope(), BuildDeclRef(calleeFnForwPassFD), noLoc, CallArgs, noLoc) - .get(); + .get(); } - auto *callRes = StoreAndRef(call); - auto *resValue = + auto* callRes = StoreAndRef(call); + auto* resValue = utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value"); - auto *resAdjoint = + auto* resAdjoint = utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); return StmtDiff(resValue, nullptr, resAdjoint); - } // Recreate the original call expression. - call = m_Sema - .ActOnCallExpr(getCurrentScope(), - Clone(CE->getCallee()), - noLoc, - CallArgs, - noLoc) - .get(); - return StmtDiff(call); - + } // Recreate the original call expression. + call = m_Sema + .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc, + CallArgs, noLoc) + .get(); + return StmtDiff(call); + return {}; } @@ -3208,7 +3208,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // TODO: Add DiffMode::experimental_pullback support here as well. if (m_Mode == DiffMode::reverse || m_Mode == DiffMode::experimental_pullback) { - QualType effectiveReturnType = m_Function->getReturnType().getNonReferenceType(); + QualType effectiveReturnType = + m_Function->getReturnType().getNonReferenceType(); if (m_Mode == DiffMode::experimental_pullback) { // FIXME: Generally, we use the function's return type as the argument's // derivative type. We cannot follow this strategy for `void` function diff --git a/lib/Differentiator/StmtClone.cpp b/lib/Differentiator/StmtClone.cpp index 14c49c894..f897d49f3 100644 --- a/lib/Differentiator/StmtClone.cpp +++ b/lib/Differentiator/StmtClone.cpp @@ -76,7 +76,12 @@ DEFINE_CLONE_EXPR(CharacterLiteral, (Node->getValue(), Node->getKind(), Node->ge DEFINE_CLONE_EXPR(ImaginaryLiteral, (Clone(Node->getSubExpr()), Node->getType())) DEFINE_CLONE_EXPR(ParenExpr, (Node->getLParen(), Node->getRParen(), Clone(Node->getSubExpr()))) DEFINE_CLONE_EXPR(ArraySubscriptExpr, (Clone(Node->getLHS()), Clone(Node->getRHS()), Node->getType(), Node->getValueKind(), Node->getObjectKind(), Node->getRBracketLoc())) -DEFINE_CREATE_EXPR(CXXDefaultArgExpr, (Ctx, SourceLocation(), Node->getParam() CLAD_COMPAT_CLANG16_CXXDefaultArgExpr_getRewrittenExpr_Param(Node) CLAD_COMPAT_CLANG9_CXXDefaultArgExpr_getUsedContext_Param(Node))) +DEFINE_CREATE_EXPR( + CXXDefaultArgExpr, + (Ctx, SourceLocation(), + Node->getParam() + CLAD_COMPAT_CLANG16_CXXDefaultArgExpr_getRewrittenExpr_Param(Node) + CLAD_COMPAT_CLANG9_CXXDefaultArgExpr_getUsedContext_Param(Node))) Stmt* StmtClone::VisitMemberExpr(MemberExpr* Node) { TemplateArgumentListInfo TemplateArgs; @@ -108,7 +113,14 @@ DEFINE_CREATE_EXPR(CXXStaticCastExpr, (Ctx, Node->getType(), Node->getValueKind( DEFINE_CREATE_EXPR(CXXDynamicCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Node->getCastKind(), Clone(Node->getSubExpr()), 0, Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) DEFINE_CREATE_EXPR(CXXReinterpretCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Node->getCastKind(), Clone(Node->getSubExpr()), 0, Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) DEFINE_CREATE_EXPR(CXXConstCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Clone(Node->getSubExpr()), Node->getTypeInfoAsWritten(), Node->getOperatorLoc(), Node->getRParenLoc(), Node->getAngleBrackets())) -DEFINE_CREATE_EXPR(CXXConstructExpr, (Ctx, Node->getType(), Node->getLocation(), Node->getConstructor(), Node->isElidable(), clad_compat::makeArrayRef(Node->getArgs(), Node->getNumArgs()), Node->hadMultipleCandidates(), Node->isListInitialization(), Node->isStdInitListInitialization(), Node->requiresZeroInitialization(), Node->getConstructionKind(), Node->getParenOrBraceRange())) +DEFINE_CREATE_EXPR( + CXXConstructExpr, + (Ctx, Node->getType(), Node->getLocation(), Node->getConstructor(), + Node->isElidable(), + clad_compat::makeArrayRef(Node->getArgs(), Node->getNumArgs()), + Node->hadMultipleCandidates(), Node->isListInitialization(), + Node->isStdInitListInitialization(), Node->requiresZeroInitialization(), + Node->getConstructionKind(), Node->getParenOrBraceRange())) DEFINE_CREATE_EXPR(CXXFunctionalCastExpr, (Ctx, Node->getType(), Node->getValueKind(), Node->getTypeInfoAsWritten(), Node->getCastKind(), Clone(Node->getSubExpr()), 0 /*EP*/CLAD_COMPAT_CLANG12_CastExpr_GetFPO(Node), Node->getLParenLoc(), Node->getRParenLoc())) DEFINE_CREATE_EXPR(ExprWithCleanups, (Ctx, Node->getSubExpr(), Node->cleanupsHaveSideEffects(), {})) @@ -117,7 +129,13 @@ DEFINE_CREATE_EXPR(ExprWithCleanups, (Ctx, Node->getSubExpr(), DEFINE_CREATE_EXPR(ConstantExpr, (Ctx, Clone(Node->getSubExpr()) CLAD_COMPAT_ConstantExpr_Create_ExtraParams)) #endif -DEFINE_CLONE_EXPR_CO(CXXTemporaryObjectExpr, (Ctx, Node->getConstructor(), Node->getType(), Node->getTypeSourceInfo(), clad_compat::makeArrayRef(Node->getArgs(), Node->getNumArgs()), Node->getSourceRange(), Node->hadMultipleCandidates(), Node->isListInitialization(), Node->isStdInitListInitialization(), Node->requiresZeroInitialization())) +DEFINE_CLONE_EXPR_CO( + CXXTemporaryObjectExpr, + (Ctx, Node->getConstructor(), Node->getType(), Node->getTypeSourceInfo(), + clad_compat::makeArrayRef(Node->getArgs(), Node->getNumArgs()), + Node->getSourceRange(), Node->hadMultipleCandidates(), + Node->isListInitialization(), Node->isStdInitListInitialization(), + Node->requiresZeroInitialization())) DEFINE_CLONE_EXPR(MaterializeTemporaryExpr, (Node->getType(), CLAD_COMPAT_CLANG10_GetTemporaryExpr(Node), Node->isBoundToLvalueReference())) DEFINE_CLONE_EXPR_CO11(CompoundAssignOperator, (CLAD_COMPAT_CLANG11_Ctx_ExtraParams Clone(Node->getLHS()), Clone(Node->getRHS()), Node->getOpcode(), Node->getType(), @@ -137,7 +155,11 @@ DEFINE_CLONE_EXPR(CXXThrowExpr, (Clone(Node->getSubExpr()), Node->getType(), Nod #if CLANG_VERSION_MAJOR < 16 DEFINE_CLONE_EXPR(SubstNonTypeTemplateParmExpr, (Node->getType(), Node->getValueKind(), Node->getBeginLoc(), Node->getParameter(), CLAD_COMPAT_SubstNonTypeTemplateParmExpr_isReferenceParameter_ExtraParam(Node) Node->getReplacement())) #else -DEFINE_CLONE_EXPR(SubstNonTypeTemplateParmExpr, (Node->getType(), Node->getValueKind(), Node->getBeginLoc(), Node->getReplacement(), Node->getAssociatedDecl(), Node->getIndex(), Node->getPackIndex(), Node->isReferenceParameter())); +DEFINE_CLONE_EXPR(SubstNonTypeTemplateParmExpr, + (Node->getType(), Node->getValueKind(), Node->getBeginLoc(), + Node->getReplacement(), Node->getAssociatedDecl(), + Node->getIndex(), Node->getPackIndex(), + Node->isReferenceParameter())); #endif DEFINE_CREATE_EXPR(PseudoObjectExpr, (Ctx, Node->getSyntacticForm(), llvm::SmallVector(Node->semantics_begin(), Node->semantics_end()), Node->getResultExprIndex())) //BlockExpr diff --git a/tools/DerivedFnInfo.h b/tools/DerivedFnInfo.h index dc07c3cf2..160cbc1af 100644 --- a/tools/DerivedFnInfo.h +++ b/tools/DerivedFnInfo.h @@ -6,39 +6,41 @@ #include "clad/Differentiator/ParseDiffArgsTypes.h" namespace clad { - struct DiffRequest; - - /// `DerivedFnInfo` is designed to effectively store information about a - /// derived function. - struct DerivedFnInfo { - const clang::FunctionDecl* m_OriginalFn = nullptr; - clang::FunctionDecl* m_DerivedFn = nullptr; - clang::FunctionDecl* m_OverloadedDerivedFn = nullptr; - DiffMode m_Mode = DiffMode::unknown; - unsigned m_DerivativeOrder = 0; - DiffInputVarsInfo m_DiffVarsInfo; - bool m_UsesEnzyme = false; - - DerivedFnInfo() {} - DerivedFnInfo(const DiffRequest& request, clang::FunctionDecl* derivedFn, - clang::FunctionDecl* overloadedDerivedFn); - - /// Returns true if the derived function represented by the object, - /// satisfies the requirements of the given differentiation request. - bool SatisfiesRequest(const DiffRequest& request) const; - - /// Returns true if the object represents any derived function; otherwise - /// returns false. - bool IsValid() const; - - const clang::FunctionDecl* OriginalFn() const { return m_OriginalFn; } - clang::FunctionDecl* DerivedFn() const { return m_DerivedFn; } - clang::FunctionDecl* OverloadedDerivedFn() const { return m_OverloadedDerivedFn; } - - /// Returns true if `lhs` and `rhs` represents same derivative. - /// Here derivative is any function derived by clad. - static bool RepresentsSameDerivative(const DerivedFnInfo& lhs, - const DerivedFnInfo& rhs); +struct DiffRequest; + +/// `DerivedFnInfo` is designed to effectively store information about a +/// derived function. +struct DerivedFnInfo { + const clang::FunctionDecl* m_OriginalFn = nullptr; + clang::FunctionDecl* m_DerivedFn = nullptr; + clang::FunctionDecl* m_OverloadedDerivedFn = nullptr; + DiffMode m_Mode = DiffMode::unknown; + unsigned m_DerivativeOrder = 0; + DiffInputVarsInfo m_DiffVarsInfo; + bool m_UsesEnzyme = false; + + DerivedFnInfo() {} + DerivedFnInfo(const DiffRequest& request, clang::FunctionDecl* derivedFn, + clang::FunctionDecl* overloadedDerivedFn); + + /// Returns true if the derived function represented by the object, + /// satisfies the requirements of the given differentiation request. + bool SatisfiesRequest(const DiffRequest& request) const; + + /// Returns true if the object represents any derived function; otherwise + /// returns false. + bool IsValid() const; + + const clang::FunctionDecl* OriginalFn() const { return m_OriginalFn; } + clang::FunctionDecl* DerivedFn() const { return m_DerivedFn; } + clang::FunctionDecl* OverloadedDerivedFn() const { + return m_OverloadedDerivedFn; + } + + /// Returns true if `lhs` and `rhs` represents same derivative. + /// Here derivative is any function derived by clad. + static bool RepresentsSameDerivative(const DerivedFnInfo& lhs, + const DerivedFnInfo& rhs); }; } // namespace clad