Skip to content

Commit

Permalink
Add support for diff of ref return types in rev mode
Browse files Browse the repository at this point in the history
Co-authored-by: Daemond <daemondzh@gmail.com>
  • Loading branch information
parth-07 authored and PhrygianGates committed Aug 28, 2023
1 parent 24a15f9 commit ea8b406
Show file tree
Hide file tree
Showing 17 changed files with 562 additions and 68 deletions.
6 changes: 5 additions & 1 deletion .clang-tidy
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,8 @@ CheckOptions:
- key: readability-identifier-naming.IgnoreMainLikeFunctions
value: 1
- key: readability-implicit-bool-conversion.AllowPointerConditions
value: 1
value: 1
- key: readability-magic-numbers.IgnorePowersOf2IntegerValues
value: 1
- key: readability-magic-numbers.IgnoredIntegerValues
value: 4;8;16;
3 changes: 2 additions & 1 deletion include/clad/Differentiator/Compatibility.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,8 @@ static inline QualType getConstantArrayType(const ASTContext &Ctx,
#if CLANG_VERSION_MAJOR < 10
#define CLAD_COMPAT_CLANG10_FunctionDecl_Create_ExtraParams(x) /**/
#elif CLANG_VERSION_MAJOR >= 10
#define CLAD_COMPAT_CLANG10_FunctionDecl_Create_ExtraParams(x) ,((x)?VD.Clone((x)):nullptr)
#define CLAD_COMPAT_CLANG10_FunctionDecl_Create_ExtraParams(x) \
, ((x) ? VB.Clone((x)) : nullptr)
#endif

// Clang 10 remove GetTemporaryExpr(). Use getSubExpr() instead
Expand Down
6 changes: 2 additions & 4 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ namespace clad {
friend class ReverseModeVisitor;
friend class HessianModeVisitor;
friend class JacobianModeVisitor;

friend class ReverseModeForwPassVisitor;
clang::Sema& m_Sema;
plugin::CladPlugin& m_CladPlugin;
clang::ASTContext& m_Context;
Expand All @@ -93,9 +93,7 @@ namespace clad {
llvm::SmallVector<std::unique_ptr<ErrorEstimationHandler>, 4>
m_ErrorEstHandler;
DeclWithContext cloneFunction(const clang::FunctionDecl* FD,
clad::VisitorBase VB, clang::DeclContext* DC,
clang::Sema& m_Sema,
clang::ASTContext& m_Context,
clad::VisitorBase& VB, clang::DeclContext* DC,
clang::SourceLocation& noLoc,
clang::DeclarationNameInfo name,
clang::QualType functionType);
Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/DiffMode.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ enum class DiffMode {
reverse,
hessian,
jacobian,
reverse_mode_forward_pass,
error_estimation
};
}
Expand Down
5 changes: 5 additions & 0 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
#include <cstring>

namespace clad {
template <typename T, typename U> struct ValueAndAdjoint {
T value;
U adjoint;
};

/// \returns the size of a c-style string
CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
unsigned int count;
Expand Down
36 changes: 36 additions & 0 deletions include/clad/Differentiator/ReverseModeForwPassVisitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#ifndef CLAD_DIFFERENTIATOR_REVERSEMODEFORWPASSVISITOR_H
#define CLAD_DIFFERENTIATOR_REVERSEMODEFORWPASSVISITOR_H

#include "clad/Differentiator/ParseDiffArgsTypes.h"
#include "clad/Differentiator/ReverseModeVisitor.h"

#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/Sema.h"

#include "llvm/ADT/SmallVector.h"

namespace clad {
class ReverseModeForwPassVisitor : public ReverseModeVisitor {
private:
Stmts m_Globals;

llvm::SmallVector<clang::QualType, 8>
ComputeParamTypes(const DiffParams& diffParams);
clang::QualType ComputeReturnType();
llvm::SmallVector<clang::ParmVarDecl*, 8> BuildParams(DiffParams& diffParams);
clang::QualType GetParameterDerivativeType(clang::QualType yType,
clang::QualType xType);

public:
ReverseModeForwPassVisitor(DerivativeBuilder& builder);
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);

StmtDiff ProcessSingleStmt(const clang::Stmt* S);
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS) override;
StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE) override;
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
};
} // namespace clad

#endif
11 changes: 5 additions & 6 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ namespace clad {
class ReverseModeVisitor
: public clang::ConstStmtVisitor<ReverseModeVisitor, StmtDiff>,
public VisitorBase {

private:
protected:
// FIXME: We should remove friend-dependency of the plugin classes here.
// For this we will need to separate out AST related functions in
// a separate namespace, as well as add getters/setters function of
Expand Down Expand Up @@ -292,7 +291,7 @@ namespace clad {

public:
ReverseModeVisitor(DerivativeBuilder& builder);
~ReverseModeVisitor();
virtual ~ReverseModeVisitor();

///\brief Produces the gradient of a given function.
///
Expand Down Expand Up @@ -321,11 +320,11 @@ namespace clad {
StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
StmtDiff VisitCallExpr(const clang::CallExpr* CE);
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO);
StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL);
StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE);
StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE);
virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE);
StmtDiff VisitDeclStmt(const clang::DeclStmt* DS);
StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL);
StmtDiff VisitForStmt(const clang::ForStmt* FS);
Expand All @@ -335,7 +334,7 @@ namespace clad {
StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL);
StmtDiff VisitMemberExpr(const clang::MemberExpr* ME);
StmtDiff VisitParenExpr(const clang::ParenExpr* PE);
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS);
virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS);
StmtDiff VisitStmt(const clang::Stmt* S);
StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp);
StmtDiff VisitExprWithCleanups(const clang::ExprWithCleanups* EWC);
Expand Down
4 changes: 2 additions & 2 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
llvm::SaveAndRestore<Scope*> SaveScope(m_CurScope);
DeclContext* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
m_Sema.CurContext = DC;
DeclWithContext result = m_Builder.cloneFunction(
FD, *this, DC, m_Sema, m_Context, loc, name, FD->getType());
DeclWithContext result =
m_Builder.cloneFunction(FD, *this, DC, loc, name, FD->getType());
FunctionDecl* derivedFD = result.first;
m_Derivative = derivedFD;

Expand Down
1 change: 1 addition & 0 deletions lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ add_llvm_library(cladDifferentiator
HessianModeVisitor.cpp
JacobianModeVisitor.cpp
MultiplexExternalRMVSource.cpp
ReverseModeForwPassVisitor.cpp
ReverseModeVisitor.cpp
StmtClone.cpp
VectorForwardModeVisitor.cpp
Expand Down
19 changes: 9 additions & 10 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "clad/Differentiator/ForwardModeVisitor.h"
#include "clad/Differentiator/HessianModeVisitor.h"
#include "clad/Differentiator/JacobianModeVisitor.h"
#include "clad/Differentiator/ReverseModeForwPassVisitor.h"
#include "clad/Differentiator/ReverseModeVisitor.h"
#include "clad/Differentiator/StmtClone.h"
#include "clad/Differentiator/VectorForwardModeVisitor.h"
Expand Down Expand Up @@ -88,15 +89,10 @@ namespace clad {
return false;
}

DeclWithContext
DerivativeBuilder::cloneFunction(const clang::FunctionDecl* FD,
clad::VisitorBase VD,
clang::DeclContext* DC,
clang::Sema& m_Sema,
clang::ASTContext& m_Context,
clang::SourceLocation& noLoc,
clang::DeclarationNameInfo name,
clang::QualType functionType) {
DeclWithContext DerivativeBuilder::cloneFunction(
const clang::FunctionDecl* FD, clad::VisitorBase& VB,
clang::DeclContext* DC, clang::SourceLocation& noLoc,
clang::DeclarationNameInfo name, clang::QualType functionType) {
FunctionDecl* returnedFD = nullptr;
NamespaceDecl* enclosingNS = nullptr;
if (isa<CXXMethodDecl>(FD)) {
Expand All @@ -115,7 +111,7 @@ namespace clad {
returnedFD->setAccess(FD->getAccess());
} else {
assert (isa<FunctionDecl>(FD) && "Unexpected!");
enclosingNS = VD.RebuildEnclosingNamespaces(DC);
enclosingNS = VB.RebuildEnclosingNamespaces(DC);
returnedFD = FunctionDecl::Create(m_Context,
m_Sema.CurContext,
noLoc,
Expand Down Expand Up @@ -230,6 +226,9 @@ namespace clad {
result = V.DerivePullback(FD, request);
if (!m_ErrorEstHandler.empty())
CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel);
} else if (request.Mode == DiffMode::reverse_mode_forward_pass) {
ReverseModeForwPassVisitor V(*this);
result = V.Derive(FD, request);
} else if (request.Mode == DiffMode::hessian) {
HessianModeVisitor H(*this);
result = H.Derive(FD, request);
Expand Down
5 changes: 2 additions & 3 deletions lib/Differentiator/ForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,8 @@ clang::QualType ForwardModeVisitor::ComputePushforwardFnReturnType() {
DeclContext* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
m_Sema.CurContext = DC;

DeclWithContext cloneFunctionResult =
m_Builder.cloneFunction(m_Function, *this, DC, m_Sema, m_Context, noLoc,
derivedFnName, derivedFnType);
DeclWithContext cloneFunctionResult = m_Builder.cloneFunction(
m_Function, *this, DC, noLoc, derivedFnName, derivedFnType);
m_Derivative = cloneFunctionResult.first;

llvm::SmallVector<ParmVarDecl*, 16> params;
Expand Down
10 changes: 2 additions & 8 deletions lib/Differentiator/HessianModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,14 +231,8 @@ namespace clad {
llvm::SaveAndRestore<Scope*> SaveScope(m_CurScope);
m_Sema.CurContext = DC;

DeclWithContext result = m_Builder.cloneFunction(m_Function,
*this,
DC,
m_Sema,
m_Context,
noLoc,
name,
hessianFunctionType);
DeclWithContext result = m_Builder.cloneFunction(
m_Function, *this, DC, noLoc, name, hessianFunctionType);
FunctionDecl* hessianFD = result.first;

beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope |
Expand Down
Loading

0 comments on commit ea8b406

Please sign in to comment.