Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rev ref returns #601

Merged
merged 1 commit into from
Aug 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion .clang-tidy
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,8 @@ CheckOptions:
- key: readability-identifier-naming.IgnoreMainLikeFunctions
value: 1
- key: readability-implicit-bool-conversion.AllowPointerConditions
value: 1
value: 1
- key: readability-magic-numbers.IgnorePowersOf2IntegerValues
value: 1
- key: readability-magic-numbers.IgnoredIntegerValues
value: 4;8;16;
3 changes: 2 additions & 1 deletion include/clad/Differentiator/Compatibility.h
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,8 @@ static inline QualType getConstantArrayType(const ASTContext &Ctx,
#if CLANG_VERSION_MAJOR < 10
#define CLAD_COMPAT_CLANG10_FunctionDecl_Create_ExtraParams(x) /**/
#elif CLANG_VERSION_MAJOR >= 10
#define CLAD_COMPAT_CLANG10_FunctionDecl_Create_ExtraParams(x) ,((x)?VD.Clone((x)):nullptr)
#define CLAD_COMPAT_CLANG10_FunctionDecl_Create_ExtraParams(x) \
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: function-like macro 'CLAD_COMPAT_CLANG10_FunctionDecl_Create_ExtraParams' used; consider a 'constexpr' template function [cppcoreguidelines-macro-usage]

#define CLAD_COMPAT_CLANG10_FunctionDecl_Create_ExtraParams(x)                 \
        ^

, ((x) ? VB.Clone((x)) : nullptr)
#endif

// Clang 10 remove GetTemporaryExpr(). Use getSubExpr() instead
Expand Down
6 changes: 2 additions & 4 deletions include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ namespace clad {
friend class ReverseModeVisitor;
friend class HessianModeVisitor;
friend class JacobianModeVisitor;

friend class ReverseModeForwPassVisitor;
clang::Sema& m_Sema;
plugin::CladPlugin& m_CladPlugin;
clang::ASTContext& m_Context;
Expand All @@ -93,9 +93,7 @@ namespace clad {
llvm::SmallVector<std::unique_ptr<ErrorEstimationHandler>, 4>
m_ErrorEstHandler;
DeclWithContext cloneFunction(const clang::FunctionDecl* FD,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: function 'clad::DerivativeBuilder::cloneFunction' has a definition with different parameter names [readability-inconsistent-declaration-parameter-name]

    DeclWithContext cloneFunction(const clang::FunctionDecl* FD,
                    ^
Additional context

lib/Differentiator/DerivativeBuilder.cpp:93: the definition seen here

  DeclWithContext DerivativeBuilder::cloneFunction(
                                     ^

include/clad/Differentiator/DerivativeBuilder.h:94: differing parameters are named here: ('VB'), in definition: ('VD')

    DeclWithContext cloneFunction(const clang::FunctionDecl* FD,
                    ^

clad::VisitorBase VB, clang::DeclContext* DC,
clang::Sema& m_Sema,
clang::ASTContext& m_Context,
clad::VisitorBase& VB, clang::DeclContext* DC,
clang::SourceLocation& noLoc,
clang::DeclarationNameInfo name,
clang::QualType functionType);
Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/DiffMode.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ enum class DiffMode {
reverse,
hessian,
jacobian,
reverse_mode_forward_pass,
error_estimation
};
}
Expand Down
5 changes: 5 additions & 0 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@
#include <cstring>

namespace clad {
template <typename T, typename U> struct ValueAndAdjoint {
T value;
U adjoint;
};

/// \returns the size of a c-style string
CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
unsigned int count;
Expand Down
36 changes: 36 additions & 0 deletions include/clad/Differentiator/ReverseModeForwPassVisitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#ifndef CLAD_DIFFERENTIATOR_REVERSEMODEFORWPASSVISITOR_H
#define CLAD_DIFFERENTIATOR_REVERSEMODEFORWPASSVISITOR_H

#include "clad/Differentiator/ParseDiffArgsTypes.h"
#include "clad/Differentiator/ReverseModeVisitor.h"

#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/Sema.h"

#include "llvm/ADT/SmallVector.h"

namespace clad {
class ReverseModeForwPassVisitor : public ReverseModeVisitor {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: destructor of 'ReverseModeForwPassVisitor' is public and non-virtual [cppcoreguidelines-virtual-class-destructor]

class ReverseModeForwPassVisitor : public ReverseModeVisitor {
      ^
Additional context

include/clad/Differentiator/ReverseModeForwPassVisitor.h:12: make it public and virtual

class ReverseModeForwPassVisitor : public ReverseModeVisitor {
      ^

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Likewise.

private:
Stmts m_Globals;

llvm::SmallVector<clang::QualType, 8>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 8 is a magic number; consider replacing it with a named constant [cppcoreguidelines-avoid-magic-numbers]

  llvm::SmallVector<clang::QualType, 8>
                                     ^

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should be able to suppress these by adding IgnorePowersOf2IntegerValues=true in the clang-tidy config: https://clang.llvm.org/extra/clang-tidy/checks/readability/magic-numbers.html

ComputeParamTypes(const DiffParams& diffParams);
clang::QualType ComputeReturnType();
llvm::SmallVector<clang::ParmVarDecl*, 8> BuildParams(DiffParams& diffParams);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 8 is a magic number; consider replacing it with a named constant [cppcoreguidelines-avoid-magic-numbers]

  llvm::SmallVector<clang::ParmVarDecl*, 8> BuildParams(DiffParams& diffParams);
                                         ^

clang::QualType GetParameterDerivativeType(clang::QualType yType,
clang::QualType xType);

public:
ReverseModeForwPassVisitor(DerivativeBuilder& builder);
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);

StmtDiff ProcessSingleStmt(const clang::Stmt* S);
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS) override;
StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE) override;
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
};
} // namespace clad

#endif
11 changes: 5 additions & 6 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,7 @@ namespace clad {
class ReverseModeVisitor
: public clang::ConstStmtVisitor<ReverseModeVisitor, StmtDiff>,
public VisitorBase {

private:
protected:
// FIXME: We should remove friend-dependency of the plugin classes here.
// For this we will need to separate out AST related functions in
// a separate namespace, as well as add getters/setters function of
Expand Down Expand Up @@ -292,7 +291,7 @@ namespace clad {

public:
ReverseModeVisitor(DerivativeBuilder& builder);
~ReverseModeVisitor();
virtual ~ReverseModeVisitor();

///\brief Produces the gradient of a given function.
///
Expand Down Expand Up @@ -321,11 +320,11 @@ namespace clad {
StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE);
StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp);
StmtDiff VisitCallExpr(const clang::CallExpr* CE);
StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS);
StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO);
StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL);
StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE);
StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE);
virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE);
StmtDiff VisitDeclStmt(const clang::DeclStmt* DS);
StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL);
StmtDiff VisitForStmt(const clang::ForStmt* FS);
Expand All @@ -335,7 +334,7 @@ namespace clad {
StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL);
StmtDiff VisitMemberExpr(const clang::MemberExpr* ME);
StmtDiff VisitParenExpr(const clang::ParenExpr* PE);
StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS);
virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS);
StmtDiff VisitStmt(const clang::Stmt* S);
StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp);
StmtDiff VisitExprWithCleanups(const clang::ExprWithCleanups* EWC);
Expand Down
4 changes: 2 additions & 2 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,8 @@ BaseForwardModeVisitor::Derive(const FunctionDecl* FD,
llvm::SaveAndRestore<Scope*> SaveScope(m_CurScope);
DeclContext* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
m_Sema.CurContext = DC;
DeclWithContext result = m_Builder.cloneFunction(
FD, *this, DC, m_Sema, m_Context, loc, name, FD->getType());
DeclWithContext result =
m_Builder.cloneFunction(FD, *this, DC, loc, name, FD->getType());
FunctionDecl* derivedFD = result.first;
m_Derivative = derivedFD;

Expand Down
1 change: 1 addition & 0 deletions lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ add_llvm_library(cladDifferentiator
HessianModeVisitor.cpp
JacobianModeVisitor.cpp
MultiplexExternalRMVSource.cpp
ReverseModeForwPassVisitor.cpp
ReverseModeVisitor.cpp
StmtClone.cpp
VectorForwardModeVisitor.cpp
Expand Down
19 changes: 9 additions & 10 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "clad/Differentiator/ForwardModeVisitor.h"
#include "clad/Differentiator/HessianModeVisitor.h"
#include "clad/Differentiator/JacobianModeVisitor.h"
#include "clad/Differentiator/ReverseModeForwPassVisitor.h"
#include "clad/Differentiator/ReverseModeVisitor.h"
#include "clad/Differentiator/StmtClone.h"
#include "clad/Differentiator/VectorForwardModeVisitor.h"
Expand Down Expand Up @@ -88,15 +89,10 @@ namespace clad {
return false;
}

DeclWithContext
DerivativeBuilder::cloneFunction(const clang::FunctionDecl* FD,
clad::VisitorBase VD,
clang::DeclContext* DC,
clang::Sema& m_Sema,
clang::ASTContext& m_Context,
clang::SourceLocation& noLoc,
clang::DeclarationNameInfo name,
clang::QualType functionType) {
DeclWithContext DerivativeBuilder::cloneFunction(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: method 'cloneFunction' can be made static [readability-convert-member-functions-to-static]

include/clad/Differentiator/DerivativeBuilder.h:94:

-     DeclWithContext cloneFunction(const clang::FunctionDecl* FD,
+     static DeclWithContext cloneFunction(const clang::FunctionDecl* FD,

const clang::FunctionDecl* FD, clad::VisitorBase& VB,
clang::DeclContext* DC, clang::SourceLocation& noLoc,
clang::DeclarationNameInfo name, clang::QualType functionType) {
FunctionDecl* returnedFD = nullptr;
NamespaceDecl* enclosingNS = nullptr;
if (isa<CXXMethodDecl>(FD)) {
Expand All @@ -115,7 +111,7 @@ namespace clad {
returnedFD->setAccess(FD->getAccess());
} else {
assert (isa<FunctionDecl>(FD) && "Unexpected!");
enclosingNS = VD.RebuildEnclosingNamespaces(DC);
enclosingNS = VB.RebuildEnclosingNamespaces(DC);
returnedFD = FunctionDecl::Create(m_Context,
m_Sema.CurContext,
noLoc,
Expand Down Expand Up @@ -230,6 +226,9 @@ namespace clad {
result = V.DerivePullback(FD, request);
if (!m_ErrorEstHandler.empty())
CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel);
} else if (request.Mode == DiffMode::reverse_mode_forward_pass) {
ReverseModeForwPassVisitor V(*this);
result = V.Derive(FD, request);
} else if (request.Mode == DiffMode::hessian) {
HessianModeVisitor H(*this);
result = H.Derive(FD, request);
Expand Down
5 changes: 2 additions & 3 deletions lib/Differentiator/ForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,8 @@ clang::QualType ForwardModeVisitor::ComputePushforwardFnReturnType() {
DeclContext* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
m_Sema.CurContext = DC;

DeclWithContext cloneFunctionResult =
m_Builder.cloneFunction(m_Function, *this, DC, m_Sema, m_Context, noLoc,
derivedFnName, derivedFnType);
DeclWithContext cloneFunctionResult = m_Builder.cloneFunction(
m_Function, *this, DC, noLoc, derivedFnName, derivedFnType);
m_Derivative = cloneFunctionResult.first;

llvm::SmallVector<ParmVarDecl*, 16> params;
Expand Down
10 changes: 2 additions & 8 deletions lib/Differentiator/HessianModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,14 +231,8 @@ namespace clad {
llvm::SaveAndRestore<Scope*> SaveScope(m_CurScope);
m_Sema.CurContext = DC;

DeclWithContext result = m_Builder.cloneFunction(m_Function,
*this,
DC,
m_Sema,
m_Context,
noLoc,
name,
hessianFunctionType);
DeclWithContext result = m_Builder.cloneFunction(
m_Function, *this, DC, noLoc, name, hessianFunctionType);
FunctionDecl* hessianFD = result.first;

beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope |
Expand Down
Loading