From ea8b406dc8e52c38209cd845493c37b968c43c4f Mon Sep 17 00:00:00 2001 From: Parth Date: Sun, 10 Apr 2022 21:45:09 +0530 Subject: [PATCH] Add support for diff of ref return types in rev mode Co-authored-by: Daemond --- .clang-tidy | 6 +- include/clad/Differentiator/Compatibility.h | 3 +- .../clad/Differentiator/DerivativeBuilder.h | 6 +- include/clad/Differentiator/DiffMode.h | 1 + include/clad/Differentiator/Differentiator.h | 5 + .../ReverseModeForwPassVisitor.h | 36 +++ .../clad/Differentiator/ReverseModeVisitor.h | 11 +- lib/Differentiator/BaseForwardModeVisitor.cpp | 4 +- lib/Differentiator/CMakeLists.txt | 1 + lib/Differentiator/DerivativeBuilder.cpp | 19 +- lib/Differentiator/ForwardModeVisitor.cpp | 5 +- lib/Differentiator/HessianModeVisitor.cpp | 10 +- .../ReverseModeForwPassVisitor.cpp | 243 ++++++++++++++++++ lib/Differentiator/ReverseModeVisitor.cpp | 122 +++++++-- .../VectorForwardModeVisitor.cpp | 9 +- test/Gradient/FunctionCalls.C | 112 ++++++++ test/Gradient/MemberFunctions.C | 37 +++ 17 files changed, 562 insertions(+), 68 deletions(-) create mode 100644 include/clad/Differentiator/ReverseModeForwPassVisitor.h create mode 100644 lib/Differentiator/ReverseModeForwPassVisitor.cpp diff --git a/.clang-tidy b/.clang-tidy index 38a6a3ca3..86bd21013 100644 --- a/.clang-tidy +++ b/.clang-tidy @@ -43,4 +43,8 @@ CheckOptions: - key: readability-identifier-naming.IgnoreMainLikeFunctions value: 1 - key: readability-implicit-bool-conversion.AllowPointerConditions - value: 1 \ No newline at end of file + value: 1 + - key: readability-magic-numbers.IgnorePowersOf2IntegerValues + value: 1 + - key: readability-magic-numbers.IgnoredIntegerValues + value: 4;8;16; diff --git a/include/clad/Differentiator/Compatibility.h b/include/clad/Differentiator/Compatibility.h index fdc76a9b4..edf84134d 100644 --- a/include/clad/Differentiator/Compatibility.h +++ b/include/clad/Differentiator/Compatibility.h @@ -427,7 +427,8 @@ static inline QualType getConstantArrayType(const ASTContext &Ctx, #if CLANG_VERSION_MAJOR < 10 #define CLAD_COMPAT_CLANG10_FunctionDecl_Create_ExtraParams(x) /**/ #elif CLANG_VERSION_MAJOR >= 10 - #define CLAD_COMPAT_CLANG10_FunctionDecl_Create_ExtraParams(x) ,((x)?VD.Clone((x)):nullptr) +#define CLAD_COMPAT_CLANG10_FunctionDecl_Create_ExtraParams(x) \ + , ((x) ? VB.Clone((x)) : nullptr) #endif // Clang 10 remove GetTemporaryExpr(). Use getSubExpr() instead diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 39c31691c..c1015db68 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -77,7 +77,7 @@ namespace clad { friend class ReverseModeVisitor; friend class HessianModeVisitor; friend class JacobianModeVisitor; - + friend class ReverseModeForwPassVisitor; clang::Sema& m_Sema; plugin::CladPlugin& m_CladPlugin; clang::ASTContext& m_Context; @@ -93,9 +93,7 @@ namespace clad { llvm::SmallVector, 4> m_ErrorEstHandler; DeclWithContext cloneFunction(const clang::FunctionDecl* FD, - clad::VisitorBase VB, clang::DeclContext* DC, - clang::Sema& m_Sema, - clang::ASTContext& m_Context, + clad::VisitorBase& VB, clang::DeclContext* DC, clang::SourceLocation& noLoc, clang::DeclarationNameInfo name, clang::QualType functionType); diff --git a/include/clad/Differentiator/DiffMode.h b/include/clad/Differentiator/DiffMode.h index a9c27a935..a03e77e49 100644 --- a/include/clad/Differentiator/DiffMode.h +++ b/include/clad/Differentiator/DiffMode.h @@ -11,6 +11,7 @@ enum class DiffMode { reverse, hessian, jacobian, + reverse_mode_forward_pass, error_estimation }; } diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index b59f77189..49797a7ba 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -21,6 +21,11 @@ #include namespace clad { +template struct ValueAndAdjoint { + T value; + U adjoint; +}; + /// \returns the size of a c-style string CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { unsigned int count; diff --git a/include/clad/Differentiator/ReverseModeForwPassVisitor.h b/include/clad/Differentiator/ReverseModeForwPassVisitor.h new file mode 100644 index 000000000..5d60e6cb6 --- /dev/null +++ b/include/clad/Differentiator/ReverseModeForwPassVisitor.h @@ -0,0 +1,36 @@ +#ifndef CLAD_DIFFERENTIATOR_REVERSEMODEFORWPASSVISITOR_H +#define CLAD_DIFFERENTIATOR_REVERSEMODEFORWPASSVISITOR_H + +#include "clad/Differentiator/ParseDiffArgsTypes.h" +#include "clad/Differentiator/ReverseModeVisitor.h" + +#include "clang/AST/StmtVisitor.h" +#include "clang/Sema/Sema.h" + +#include "llvm/ADT/SmallVector.h" + +namespace clad { +class ReverseModeForwPassVisitor : public ReverseModeVisitor { +private: + Stmts m_Globals; + + llvm::SmallVector + ComputeParamTypes(const DiffParams& diffParams); + clang::QualType ComputeReturnType(); + llvm::SmallVector BuildParams(DiffParams& diffParams); + clang::QualType GetParameterDerivativeType(clang::QualType yType, + clang::QualType xType); + +public: + ReverseModeForwPassVisitor(DerivativeBuilder& builder); + DerivativeAndOverload Derive(const clang::FunctionDecl* FD, + const DiffRequest& request); + + StmtDiff ProcessSingleStmt(const clang::Stmt* S); + StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS) override; + StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE) override; + StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override; +}; +} // namespace clad + +#endif \ No newline at end of file diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 3b668ff87..3b7c4b1cb 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -30,8 +30,7 @@ namespace clad { class ReverseModeVisitor : public clang::ConstStmtVisitor, public VisitorBase { - - private: + protected: // FIXME: We should remove friend-dependency of the plugin classes here. // For this we will need to separate out AST related functions in // a separate namespace, as well as add getters/setters function of @@ -292,7 +291,7 @@ namespace clad { public: ReverseModeVisitor(DerivativeBuilder& builder); - ~ReverseModeVisitor(); + virtual ~ReverseModeVisitor(); ///\brief Produces the gradient of a given function. /// @@ -321,11 +320,11 @@ namespace clad { StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE); StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp); StmtDiff VisitCallExpr(const clang::CallExpr* CE); - StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); + virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO); StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL); StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE); - StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE); + virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE); StmtDiff VisitDeclStmt(const clang::DeclStmt* DS); StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL); StmtDiff VisitForStmt(const clang::ForStmt* FS); @@ -335,7 +334,7 @@ namespace clad { StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL); StmtDiff VisitMemberExpr(const clang::MemberExpr* ME); StmtDiff VisitParenExpr(const clang::ParenExpr* PE); - StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS); + virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS); StmtDiff VisitStmt(const clang::Stmt* S); StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp); StmtDiff VisitExprWithCleanups(const clang::ExprWithCleanups* EWC); diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index d2066e185..a4ed41f68 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -168,8 +168,8 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD, llvm::SaveAndRestore SaveScope(m_CurScope); DeclContext* DC = const_cast(m_Function->getDeclContext()); m_Sema.CurContext = DC; - DeclWithContext result = m_Builder.cloneFunction( - FD, *this, DC, m_Sema, m_Context, loc, name, FD->getType()); + DeclWithContext result = + m_Builder.cloneFunction(FD, *this, DC, loc, name, FD->getType()); FunctionDecl* derivedFD = result.first; m_Derivative = derivedFD; diff --git a/lib/Differentiator/CMakeLists.txt b/lib/Differentiator/CMakeLists.txt index 61695990e..f5eddb2c6 100644 --- a/lib/Differentiator/CMakeLists.txt +++ b/lib/Differentiator/CMakeLists.txt @@ -31,6 +31,7 @@ add_llvm_library(cladDifferentiator HessianModeVisitor.cpp JacobianModeVisitor.cpp MultiplexExternalRMVSource.cpp + ReverseModeForwPassVisitor.cpp ReverseModeVisitor.cpp StmtClone.cpp VectorForwardModeVisitor.cpp diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index d5a15580b..c07e70057 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -21,6 +21,7 @@ #include "clad/Differentiator/ForwardModeVisitor.h" #include "clad/Differentiator/HessianModeVisitor.h" #include "clad/Differentiator/JacobianModeVisitor.h" +#include "clad/Differentiator/ReverseModeForwPassVisitor.h" #include "clad/Differentiator/ReverseModeVisitor.h" #include "clad/Differentiator/StmtClone.h" #include "clad/Differentiator/VectorForwardModeVisitor.h" @@ -88,15 +89,10 @@ namespace clad { return false; } - DeclWithContext - DerivativeBuilder::cloneFunction(const clang::FunctionDecl* FD, - clad::VisitorBase VD, - clang::DeclContext* DC, - clang::Sema& m_Sema, - clang::ASTContext& m_Context, - clang::SourceLocation& noLoc, - clang::DeclarationNameInfo name, - clang::QualType functionType) { + DeclWithContext DerivativeBuilder::cloneFunction( + const clang::FunctionDecl* FD, clad::VisitorBase& VB, + clang::DeclContext* DC, clang::SourceLocation& noLoc, + clang::DeclarationNameInfo name, clang::QualType functionType) { FunctionDecl* returnedFD = nullptr; NamespaceDecl* enclosingNS = nullptr; if (isa(FD)) { @@ -115,7 +111,7 @@ namespace clad { returnedFD->setAccess(FD->getAccess()); } else { assert (isa(FD) && "Unexpected!"); - enclosingNS = VD.RebuildEnclosingNamespaces(DC); + enclosingNS = VB.RebuildEnclosingNamespaces(DC); returnedFD = FunctionDecl::Create(m_Context, m_Sema.CurContext, noLoc, @@ -230,6 +226,9 @@ namespace clad { result = V.DerivePullback(FD, request); if (!m_ErrorEstHandler.empty()) CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel); + } else if (request.Mode == DiffMode::reverse_mode_forward_pass) { + ReverseModeForwPassVisitor V(*this); + result = V.Derive(FD, request); } else if (request.Mode == DiffMode::hessian) { HessianModeVisitor H(*this); result = H.Derive(FD, request); diff --git a/lib/Differentiator/ForwardModeVisitor.cpp b/lib/Differentiator/ForwardModeVisitor.cpp index 4c7f31e03..d702dcbfc 100644 --- a/lib/Differentiator/ForwardModeVisitor.cpp +++ b/lib/Differentiator/ForwardModeVisitor.cpp @@ -81,9 +81,8 @@ clang::QualType ForwardModeVisitor::ComputePushforwardFnReturnType() { DeclContext* DC = const_cast(m_Function->getDeclContext()); m_Sema.CurContext = DC; - DeclWithContext cloneFunctionResult = - m_Builder.cloneFunction(m_Function, *this, DC, m_Sema, m_Context, noLoc, - derivedFnName, derivedFnType); + DeclWithContext cloneFunctionResult = m_Builder.cloneFunction( + m_Function, *this, DC, noLoc, derivedFnName, derivedFnType); m_Derivative = cloneFunctionResult.first; llvm::SmallVector params; diff --git a/lib/Differentiator/HessianModeVisitor.cpp b/lib/Differentiator/HessianModeVisitor.cpp index a4e5820f0..a85bfd49a 100644 --- a/lib/Differentiator/HessianModeVisitor.cpp +++ b/lib/Differentiator/HessianModeVisitor.cpp @@ -231,14 +231,8 @@ namespace clad { llvm::SaveAndRestore SaveScope(m_CurScope); m_Sema.CurContext = DC; - DeclWithContext result = m_Builder.cloneFunction(m_Function, - *this, - DC, - m_Sema, - m_Context, - noLoc, - name, - hessianFunctionType); + DeclWithContext result = m_Builder.cloneFunction( + m_Function, *this, DC, noLoc, name, hessianFunctionType); FunctionDecl* hessianFD = result.first; beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | diff --git a/lib/Differentiator/ReverseModeForwPassVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp new file mode 100644 index 000000000..c22f4fadf --- /dev/null +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -0,0 +1,243 @@ +#include "clad/Differentiator/ReverseModeForwPassVisitor.h" + +#include "clad/Differentiator/CladUtils.h" +#include "clad/Differentiator/DiffPlanner.h" +#include "clad/Differentiator/ErrorEstimator.h" + +#include "llvm/Support/SaveAndRestore.h" + +#include + +using namespace clang; + +namespace clad { + +ReverseModeForwPassVisitor::ReverseModeForwPassVisitor( + DerivativeBuilder& builder) + : ReverseModeVisitor(builder) {} + +DerivativeAndOverload +ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, + const DiffRequest& request) { + silenceDiags = !request.VerboseDiags; + m_Function = FD; + + m_Mode = DiffMode::reverse_mode_forward_pass; + + assert(m_Function && "Must not be null."); + + DiffParams args{}; + std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); + + auto fnName = m_Function->getNameAsString() + "_forw"; + auto fnDNI = utils::BuildDeclarationNameInfo(m_Sema, fnName); + + auto paramTypes = ComputeParamTypes(args); + auto returnType = ComputeReturnType(); + const auto* sourceFnType = dyn_cast(m_Function->getType()); + auto fnType = m_Context.getFunctionType(returnType, paramTypes, + sourceFnType->getExtProtoInfo()); + + llvm::SaveAndRestore saveContext(m_Sema.CurContext); + llvm::SaveAndRestore saveScope(m_CurScope); + // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast) + m_Sema.CurContext = const_cast(m_Function->getDeclContext()); + + DeclWithContext fnBuildRes = m_Builder.cloneFunction( + m_Function, *this, m_Sema.CurContext, noLoc, fnDNI, fnType); + m_Derivative = fnBuildRes.first; + + beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | + Scope::DeclScope); + m_Sema.PushFunctionScope(); + m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); + + auto params = BuildParams(args); + m_Derivative->setParams(params); + m_Derivative->setBody(nullptr); + + beginScope(Scope::FnScope | Scope::DeclScope); + m_DerivativeFnScope = getCurrentScope(); + + beginBlock(); + beginBlock(direction::reverse); + + StmtDiff bodyDiff = Visit(m_Function->getBody()); + Stmt* forward = bodyDiff.getStmt(); + + for (Stmt* S : ReverseModeVisitor::m_Globals) + addToCurrentBlock(S); + + if (auto* CS = dyn_cast(forward)) + for (Stmt* S : CS->body()) + addToCurrentBlock(S); + + Stmt* fnBody = endBlock(); + m_Derivative->setBody(fnBody); + endScope(); + m_Sema.PopFunctionScopeInfo(); + m_Sema.PopDeclContext(); + endScope(); + return DerivativeAndOverload{m_Derivative, nullptr}; +} + +// FIXME: This function is copied from ReverseModeVisitor. Find a suitable place +// for it. +QualType +ReverseModeForwPassVisitor::GetParameterDerivativeType(QualType yType, + QualType xType) { + assert(yType.getNonReferenceType()->isRealType() && + "yType should be a builtin-numerical scalar type!!"); + QualType xValueType = utils::GetValueType(xType); + // derivative variables should always be of non-const type. + xValueType.removeLocalConst(); + QualType nonRefXValueType = xValueType.getNonReferenceType(); + if (nonRefXValueType->isRealType()) + return GetCladArrayRefOfType(yType); + return GetCladArrayRefOfType(nonRefXValueType); +} + +llvm::SmallVector +ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) { + llvm::SmallVector paramTypes; + paramTypes.reserve(m_Function->getNumParams() * 2); + for (auto* PVD : m_Function->parameters()) + paramTypes.push_back(PVD->getType()); + + QualType effectiveReturnType = + m_Function->getReturnType().getNonReferenceType(); + + if (const auto* MD = dyn_cast(m_Function)) { + const CXXRecordDecl* RD = MD->getParent(); + if (MD->isInstance() && !RD->isLambda()) { + QualType thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD); + paramTypes.push_back( + GetParameterDerivativeType(effectiveReturnType, thisType)); + } + } + + for (auto* PVD : m_Function->parameters()) { + const auto* it = + std::find(std::begin(diffParams), std::end(diffParams), PVD); + if (it != std::end(diffParams)) { + paramTypes.push_back( + GetParameterDerivativeType(effectiveReturnType, PVD->getType())); + } + } + return paramTypes; +} + +clang::QualType ReverseModeForwPassVisitor::ComputeReturnType() { + auto* valAndAdjointTempDecl = + LookupTemplateDeclInCladNamespace("ValueAndAdjoint"); + auto RT = m_Function->getReturnType(); + auto T = InstantiateTemplate(valAndAdjointTempDecl, {RT, RT}); + return T; +} + +llvm::SmallVector +ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) { + llvm::SmallVector params; + llvm::SmallVector paramDerivatives; + params.reserve(m_Function->getNumParams() + diffParams.size()); + const auto* derivativeFnType = + cast(m_Derivative->getType()); + + std::size_t dParamTypesIdx = m_Function->getNumParams(); + + if (const auto* MD = dyn_cast(m_Function)) { + const CXXRecordDecl* RD = MD->getParent(); + if (MD->isInstance() && !RD->isLambda()) { + auto* thisDerivativePVD = utils::BuildParmVarDecl( + m_Sema, m_Derivative, CreateUniqueIdentifier("_d_this"), + derivativeFnType->getParamType(dParamTypesIdx)); + paramDerivatives.push_back(thisDerivativePVD); + + if (thisDerivativePVD->getIdentifier()) + m_Sema.PushOnScopeChains(thisDerivativePVD, getCurrentScope(), + /*AddToContext=*/false); + + Expr* deref = + BuildOp(UnaryOperatorKind::UO_Deref, BuildDeclRef(thisDerivativePVD)); + m_ThisExprDerivative = utils::BuildParenExpr(m_Sema, deref); + ++dParamTypesIdx; + } + } + for (auto* PVD : m_Function->parameters()) { + // FIXME: Call expression may contain default arguments that we are now + // removing. This may cause issues. + auto* newPVD = utils::BuildParmVarDecl( + m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(), + PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo()); + params.push_back(newPVD); + + if (newPVD->getIdentifier()) + m_Sema.PushOnScopeChains(newPVD, getCurrentScope(), + /*AddToContext=*/false); + + auto* it = std::find(std::begin(diffParams), std::end(diffParams), PVD); + if (it != std::end(diffParams)) { + *it = newPVD; + QualType dType = derivativeFnType->getParamType(dParamTypesIdx); + IdentifierInfo* dII = + CreateUniqueIdentifier("_d_" + PVD->getNameAsString()); + auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType, + PVD->getStorageClass()); + paramDerivatives.push_back(dPVD); + ++dParamTypesIdx; + + if (dPVD->getIdentifier()) + m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), + /*AddToContext=*/false); + m_Variables[*it] = + BuildOp(UO_Deref, BuildDeclRef(dPVD), m_Function->getLocation()); + } + } + params.insert(params.end(), paramDerivatives.begin(), paramDerivatives.end()); + return params; +} + +StmtDiff ReverseModeForwPassVisitor::ProcessSingleStmt(const clang::Stmt* S) { + StmtDiff SDiff = Visit(S); + return {SDiff.getStmt()}; +} + +StmtDiff +ReverseModeForwPassVisitor::VisitCompoundStmt(const clang::CompoundStmt* CS) { + beginScope(Scope::DeclScope); + beginBlock(); + for (Stmt* S : CS->body()) { + StmtDiff SDiff = ProcessSingleStmt(S); + addToCurrentBlock(SDiff.getStmt()); + } + CompoundStmt* forward = endBlock(); + endScope(); + return {forward}; +} + +StmtDiff ReverseModeForwPassVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { + DeclRefExpr* clonedDRE = nullptr; + // Check if referenced Decl was "replaced" with another identifier inside + // the derivative + const auto* VD = dyn_cast(DRE->getDecl()); + auto it = m_DeclReplacements.find(VD); + if (it != std::end(m_DeclReplacements)) + clonedDRE = BuildDeclRef(it->second); + else + clonedDRE = cast(Clone(DRE)); + + auto* decl = dyn_cast(clonedDRE->getDecl()); + return StmtDiff(clonedDRE, m_Variables.find(decl)->second); +} + +StmtDiff +ReverseModeForwPassVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) { + const Expr* value = RS->getRetValue(); + auto returnDiff = Visit(value); + llvm::SmallVector returnArgs = {returnDiff.getExpr(), + returnDiff.getExpr_dx()}; + Expr* returnInitList = m_Sema.ActOnInitList(noLoc, returnArgs, noLoc).get(); + Stmt* newRS = m_Sema.BuildReturnStmt(noLoc, returnInitList).get(); + return {newRS}; +} +} // namespace clad diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 4e9805b64..9e00797f2 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -161,8 +161,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, DeclContext* DC = const_cast(m_Function->getDeclContext()); m_Sema.CurContext = DC; DeclWithContext gradientOverloadFDWC = - m_Builder.cloneFunction(m_Function, *this, DC, m_Sema, m_Context, noLoc, - gradientNameInfo, gradientFunctionOverloadType); + m_Builder.cloneFunction(m_Function, *this, DC, noLoc, gradientNameInfo, + gradientFunctionOverloadType); FunctionDecl* gradientOverloadFD = gradientOverloadFDWC.first; beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | @@ -353,14 +353,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SaveAndRestore SaveScope(m_CurScope); DeclContext* DC = const_cast(m_Function->getDeclContext()); m_Sema.CurContext = DC; - DeclWithContext result = m_Builder.cloneFunction(m_Function, - *this, - DC, - m_Sema, - m_Context, - noLoc, - name, - gradientFunctionType); + DeclWithContext result = m_Builder.cloneFunction( + m_Function, *this, DC, noLoc, name, gradientFunctionType); FunctionDecl* gradientFD = result.first; m_Derivative = gradientFD; @@ -492,9 +486,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SaveAndRestore saveScope(m_CurScope); m_Sema.CurContext = const_cast(m_Function->getDeclContext()); - DeclWithContext fnBuildRes = - m_Builder.cloneFunction(m_Function, *this, m_Sema.CurContext, m_Sema, - m_Context, noLoc, DNI, pullbackFnType); + DeclWithContext fnBuildRes = m_Builder.cloneFunction( + m_Function, *this, m_Sema.CurContext, noLoc, DNI, pullbackFnType); m_Derivative = fnBuildRes.first; if (m_ExternalSource) @@ -1673,15 +1666,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, ArgDeclStmts.push_back(BuildDeclStmt(gradVarDecl)); idx++; } + Expr* pullback = dfdx(); + if ((pullback == nullptr) && FD->getReturnType()->isLValueReferenceType()) + pullback = getZeroInit(FD->getReturnType().getNonReferenceType()); + // FIXME: Remove this restriction. if (!FD->getReturnType()->isVoidType()) { - assert((dfdx() && !FD->getReturnType()->isVoidType()) && + assert((pullback && !FD->getReturnType()->isVoidType()) && "Call to function returning non-void type with no dfdx() is not " "supported!"); } if (FD->getReturnType()->isVoidType()) { - assert(dfdx() == nullptr && FD->getReturnType()->isVoidType() && + assert(pullback == nullptr && FD->getReturnType()->isVoidType() && "Call to function returning void type should not have any " "corresponding dfdx()."); } @@ -1691,9 +1688,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, DerivedCallOutputArgs.end()); pullbackCallArgs = DerivedCallArgs; - if (dfdx()) + if (pullback) pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs(), - dfdx()); + pullback); // Try to find it in builtin derivatives std::string customPullback = FD->getNameAsString() + "_pullback"; @@ -1857,15 +1854,83 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, std::end(CallArgs), std::begin(CallArgs), [this](Expr* E) { return Clone(E); }); - // Recreate the original call expression. - Expr* call = m_Sema - .ActOnCallExpr(getCurrentScope(), - Clone(CE->getCallee()), - noLoc, - CallArgs, - noLoc) - .get(); + + Expr* call = nullptr; + + if (FD->getReturnType()->isReferenceType()) { + DiffRequest calleeFnForwPassReq; + calleeFnForwPassReq.Function = FD; + calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass; + calleeFnForwPassReq.BaseFunctionName = FD->getNameAsString(); + calleeFnForwPassReq.VerboseDiags = true; + FunctionDecl* calleeFnForwPassFD = + plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq); + + assert(calleeFnForwPassFD && + "Clad failed to generate callee function forward pass function"); + + // FIXME: We are using the derivatives in forward pass here + // If `expr_dx()` is only meant to be used in reverse pass, + // (for example, `clad::pop(...)` expression and a corresponding + // `clad::push(...)` in the forward pass), then this can result in + // incorrect derivative or crash at runtime. Ideally, we should have + // a separate routine to use derivative in the forward pass. + + // We cannot reuse the derivatives previously computed because + // they might contain 'clad::pop(..)` expression. + if (isa(CE)) { + Expr* derivedBase = baseDiff.getExpr_dx(); + // FIXME: We may need this if-block once we support pointers, and + // passing pointers-by-reference if + // (isCladArrayType(derivedBase->getType())) + // CallArgs.push_back(derivedBase); + // else + CallArgs.push_back( + BuildOp(UnaryOperatorKind::UO_AddrOf, derivedBase, noLoc)); + } + + for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { + const Expr* arg = CE->getArg(i); + const ParmVarDecl* PVD = FD->getParamDecl(i); + StmtDiff argDiff = Visit(arg); + if ((argDiff.getExpr_dx() != nullptr) && + PVD->getType()->isReferenceType()) { + Expr* derivedArg = argDiff.getExpr_dx(); + // FIXME: We may need this if-block once we support pointers, and + // passing pointers-by-reference if + // (isCladArrayType(derivedArg->getType())) + // CallArgs.push_back(derivedArg); + // else + CallArgs.push_back( + BuildOp(UnaryOperatorKind::UO_AddrOf, derivedArg, noLoc)); + } else + CallArgs.push_back(m_Sema.ActOnCXXNullPtrLiteral(noLoc).get()); + } + if (isa(CE)) { + Expr* baseE = baseDiff.getExpr(); + call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(), + CallArgs, calleeFnForwPassFD); + } else { + call = m_Sema + .ActOnCallExpr(getCurrentScope(), + BuildDeclRef(calleeFnForwPassFD), noLoc, + CallArgs, noLoc) + .get(); + } + auto* callRes = StoreAndRef(call); + auto* resValue = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value"); + auto* resAdjoint = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); + return StmtDiff(resValue, nullptr, resAdjoint); + } // Recreate the original call expression. + call = m_Sema + .ActOnCallExpr(getCurrentScope(), Clone(CE->getCallee()), noLoc, + CallArgs, noLoc) + .get(); return StmtDiff(call); + + return {}; } StmtDiff ReverseModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { @@ -2312,7 +2377,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (isDerivativeOfRefType) { initDiff = Visit(VD->getInit()); - if (!initDiff.getExpr_dx()) { + if (!initDiff.getForwSweepExpr_dx()) { VDDerivedType = ComputeAdjointType(VD->getType().getNonReferenceType()); isDerivativeOfRefType = false; @@ -3136,7 +3201,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // TODO: Add DiffMode::experimental_pullback support here as well. if (m_Mode == DiffMode::reverse || m_Mode == DiffMode::experimental_pullback) { - QualType effectiveReturnType = m_Function->getReturnType(); + QualType effectiveReturnType = + m_Function->getReturnType().getNonReferenceType(); if (m_Mode == DiffMode::experimental_pullback) { // FIXME: Generally, we use the function's return type as the argument's // derivative type. We cannot follow this strategy for `void` function @@ -3150,7 +3216,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (effectiveReturnType->isVoidType()) effectiveReturnType = m_Context.DoubleTy; else - paramTypes.push_back(m_Function->getReturnType()); + paramTypes.push_back(effectiveReturnType); } if (auto MD = dyn_cast(m_Function)) { diff --git a/lib/Differentiator/VectorForwardModeVisitor.cpp b/lib/Differentiator/VectorForwardModeVisitor.cpp index 10971e9bc..9e38bc6a8 100644 --- a/lib/Differentiator/VectorForwardModeVisitor.cpp +++ b/lib/Differentiator/VectorForwardModeVisitor.cpp @@ -73,9 +73,8 @@ VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD, // Create the function declaration for the derivative. DeclContext* DC = const_cast(m_Function->getDeclContext()); m_Sema.CurContext = DC; - DeclWithContext result = - m_Builder.cloneFunction(m_Function, *this, DC, m_Sema, m_Context, loc, - name, vectorDiffFunctionType); + DeclWithContext result = m_Builder.cloneFunction( + m_Function, *this, DC, loc, name, vectorDiffFunctionType); FunctionDecl* vectorDiffFD = result.first; m_Derivative = vectorDiffFD; @@ -251,8 +250,8 @@ clang::FunctionDecl* VectorForwardModeVisitor::CreateVectorModeOverload() { auto* DC = const_cast(m_Function->getDeclContext()); m_Sema.CurContext = DC; DeclWithContext result = - m_Builder.cloneFunction(m_Function, *this, DC, m_Sema, m_Context, noLoc, - vectorModeNameInfo, vectorModeFuncOverloadType); + m_Builder.cloneFunction(m_Function, *this, DC, noLoc, vectorModeNameInfo, + vectorModeFuncOverloadType); FunctionDecl* vectorModeOverloadFD = result.first; // Function declaration scope diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 725731bc4..93d135756 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -351,6 +351,116 @@ double fn6(double i=0, double j=0) { return i*j; } +struct MyStruct { + static void myFunction() {} +}; + +double& identity(double& i) { + MyStruct::myFunction(); + double _d_i = i; + _d_i += 1; + return i; +} + +double fn7(double i, double j) { + double& k = identity(i); + double& l = identity(j); + k += 7*j; + l += 9*i; + return i + j; +} + +// CHECK: void fn6_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double _t1; +// CHECK-NEXT: _t1 = i; +// CHECK-NEXT: _t0 = j; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 1 * _t0; +// CHECK-NEXT: * _d_i += _r0; +// CHECK-NEXT: double _r1 = _t1 * 1; +// CHECK-NEXT: * _d_j += _r1; +// CHECK-NEXT: } +// CHECK-NEXT: } + +// CHECK: void identity_pullback(double &i, double _d_y, clad::array_ref _d_i) { +// CHECK-NEXT: double _d__d_i = 0; +// CHECK-NEXT: MyStruct::myFunction(); +// CHECK-NEXT: double _d_i0 = i; +// CHECK-NEXT: _d_i0 += 1; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: * _d_i += _d_y; +// CHECK-NEXT: { +// CHECK-NEXT: double _r_d0 = _d__d_i; +// CHECK-NEXT: _d__d_i += _r_d0; +// CHECK-NEXT: _d__d_i -= _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: * _d_i += _d__d_i; +// CHECK-NEXT: } + +// CHECK: clad::ValueAndAdjoint identity_forw(double &i, clad::array_ref _d_i) { +// CHECK-NEXT: double _d__d_i = 0; +// CHECK-NEXT: MyStruct::myFunction(); +// CHECK-NEXT: double _d_i0 = i; +// CHECK-NEXT: _d_i0 += 1; +// CHECK-NEXT: return {i, * _d_i}; +// CHECK-NEXT: } + +// CHECK: void fn7_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double *_d_k = 0; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: double *_d_l = 0; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: double _t5; +// CHECK-NEXT: _t0 = i; +// CHECK-NEXT: clad::ValueAndAdjoint _t1 = identity_forw(i, &* _d_i); +// CHECK-NEXT: _d_k = &_t1.adjoint; +// CHECK-NEXT: double &k = _t1.value; +// CHECK-NEXT: _t2 = j; +// CHECK-NEXT: clad::ValueAndAdjoint _t3 = identity_forw(j, &* _d_j); +// CHECK-NEXT: _d_l = &_t3.adjoint; +// CHECK-NEXT: double &l = _t3.value; +// CHECK-NEXT: _t4 = j; +// CHECK-NEXT: k += 7 * _t4; +// CHECK-NEXT: _t5 = i; +// CHECK-NEXT: l += 9 * _t5; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: * _d_i += 1; +// CHECK-NEXT: * _d_j += 1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: double _r_d1 = *_d_l; +// CHECK-NEXT: *_d_l += _r_d1; +// CHECK-NEXT: double _r4 = _r_d1 * _t5; +// CHECK-NEXT: double _r5 = 9 * _r_d1; +// CHECK-NEXT: * _d_i += _r5; +// CHECK-NEXT: *_d_l -= _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: double _r_d0 = *_d_k; +// CHECK-NEXT: *_d_k += _r_d0; +// CHECK-NEXT: double _r2 = _r_d0 * _t4; +// CHECK-NEXT: double _r3 = 7 * _r_d0; +// CHECK-NEXT: * _d_j += _r3; +// CHECK-NEXT: *_d_k -= _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: identity_pullback(_t2, 0, &* _d_j); +// CHECK-NEXT: double _r1 = * _d_j; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: identity_pullback(_t0, 0, &* _d_i); +// CHECK-NEXT: double _r0 = * _d_i; +// CHECK-NEXT: } +// CHECK-NEXT: } + + template void reset(T* arr, int n) { for (int i=0; i _d_i, clad::array_ref _d_j); void const_mem_fn_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j); void volatile_mem_fn_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j); @@ -756,6 +758,35 @@ double fn(double i,double j) { // CHECK-NEXT: } // CHECK-NEXT: } +double fn2(SimpleFunctions& sf, double i) { + return sf.ref_mem_fn(i); +} + +// CHECK: void ref_mem_fn_pullback(double i, double _d_y, clad::array_ref _d_this, clad::array_ref _d_i) { +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: (* _d_this).x += _d_y; +// CHECK-NEXT: } +// CHECK: clad::ValueAndAdjoint ref_mem_fn_forw(double i, clad::array_ref _d_this, clad::array_ref _d_i) { +// CHECK-NEXT: return {this->x, (* _d_this).x}; +// CHECK-NEXT: } +// CHECK: void fn2_grad(SimpleFunctions &sf, double i, clad::array_ref _d_sf, clad::array_ref _d_i) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: SimpleFunctions _t1; +// CHECK-NEXT: _t0 = i; +// CHECK-NEXT: _t1 = sf; +// CHECK-NEXT: clad::ValueAndAdjoint _t2 = _t1.ref_mem_fn_forw(_t0, &(* _d_sf), nullptr); +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: double _grad0 = 0.; +// CHECK-NEXT: _t1.ref_mem_fn_pullback(_t0, 1, &(* _d_sf), &_grad0); +// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: * _d_i += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + + int main() { auto d_mem_fn = clad::gradient(&SimpleFunctions::mem_fn); auto d_const_mem_fn = clad::gradient(&SimpleFunctions::const_mem_fn); @@ -790,6 +821,12 @@ int main() { printf("%.2f ",result[i]); //CHECK-EXEC: 40.00 16.00 } + SimpleFunctions sf(2, 3); + SimpleFunctions d_sf; + auto d_fn2 = clad::gradient(fn2); + d_fn2.execute(sf, 2, &d_sf, &result[0]); + printf("%.2f", result[0]); //CHECK-EXEC: 40.00 + auto d_const_volatile_lval_ref_mem_fn_i = clad::gradient(&SimpleFunctions::const_volatile_lval_ref_mem_fn, "i"); // CHECK: void const_volatile_lval_ref_mem_fn_grad_0(double i, double j, clad::array_ref _d_this, clad::array_ref _d_i) const volatile & {