Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
phrygiangates committed Aug 19, 2023
1 parent d89426c commit 1eebf49
Show file tree
Hide file tree
Showing 13 changed files with 162 additions and 131 deletions.
3 changes: 2 additions & 1 deletion demos/ComputerGraphics/smallpt/SmallPT.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
// ./SmallPT 500 && xv image.ppm

// A typical invocation would be:
// ../../../../../obj/Debug+Asserts/bin/clang++ -O3 -Xclang -add-plugin -Xclang clad \
// ../../../../../obj/Debug+Asserts/bin/clang++ -O3 -Xclang -add-plugin -Xclang
// clad \
// -Xclang -load -Xclang ../../../../../obj/Debug+Asserts/lib/libclad.dylib \
// -I../../include/ -std=c++11 SmallPT.cpp -fopenmp=libiomp5 -o SmallPT
// ./SmallPT 500 && xv image.ppm
Expand Down
8 changes: 5 additions & 3 deletions demos/OpenCL/RosenbrockFunction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,13 @@

// To run the demo please type:
// path/to/clang++ -Xclang -add-plugin -Xclang clad -Xclang -load -Xclang \
// path/to/libclad.so -I../include/ -framework opencl -std=c++11 RosenbrockFunction.cpp
// path/to/libclad.so -I../include/ -framework opencl -std=c++11
// RosenbrockFunction.cpp
//
// A typical invocation would be:
// ../../../../../obj/Debug+Asserts/bin/clang++ -Xclang -add-plugin -Xclang clad \
// -Xclang -load -Xclang ../../../../../obj/Debug+Asserts/lib/libclad.dylib \
// ../../../../../obj/Debug+Asserts/bin/clang++ -Xclang -add-plugin -Xclang
// clad \
// -Xclang -load -Xclang ../../../../../obj/Debug+Asserts/lib/libclad.dylib \
// -I../../include/ -framework opencl -std=c++11 RosenbrockFunction.cpp

// Necessary for clad to work include
Expand Down
32 changes: 17 additions & 15 deletions include/clad/Differentiator/Compatibility.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,15 +75,15 @@ static inline NamespaceDecl*
NamespaceDecl_Create(ASTContext& C, DeclContext* DC, bool Inline,
SourceLocation StartLoc, SourceLocation IdLoc,
IdentifierInfo* Id, NamespaceDecl* PrevDecl) {
return NamespaceDecl::Create(C, DC, Inline, StartLoc, IdLoc, Id, PrevDecl);
return NamespaceDecl::Create(C, DC, Inline, StartLoc, IdLoc, Id, PrevDecl);
}
#else
static inline NamespaceDecl*
NamespaceDecl_Create(ASTContext& C, DeclContext* DC, bool Inline,
SourceLocation StartLoc, SourceLocation IdLoc,
IdentifierInfo* Id, NamespaceDecl* PrevDecl) {
return NamespaceDecl::Create(C, DC, Inline, StartLoc, IdLoc, Id, PrevDecl,
/*Nested=*/false);
return NamespaceDecl::Create(C, DC, Inline, StartLoc, IdLoc, Id, PrevDecl,
/*Nested=*/false);
}
#endif

Expand Down Expand Up @@ -249,17 +249,19 @@ static inline void ExprSetDeps(Expr* result, Expr* Node) {
#define CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsPar ,clang::CallExpr::ADLCallKind UsesADL
#define CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsUse ,UsesADL
#define CLAD_COMPAT_CLANG11_CXXOperatorCallExpr_Create_ExtraParamsOverride FPOptionsOverride
#if CLANG_VERSION_MAJOR >= 16
#define CLAD_COMPAT_CLANG11_LangOptions_EtraParams /**/
#else
#define CLAD_COMPAT_CLANG11_LangOptions_EtraParams Ctx.getLangOpts()
#endif
#define CLAD_COMPAT_CLANG11_Ctx_ExtraParams Ctx,
#define CLAD_COMPAT_CREATE11(CLASS, CTORARGS) (CLASS::Create CTORARGS)
#define CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Removed /**/
#define CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Moved ,Node->getComputationLHSType(),Node->getComputationResultType()
#define CLAD_COMPAT_CLANG11_ChooseExpr_EtraParams_Removed /**/
#define CLAD_COMPAT_CLANG11_WhileStmt_ExtraParams ,Node->getLParenLoc(),Node->getRParenLoc()
#if CLANG_VERSION_MAJOR >= 16
#define CLAD_COMPAT_CLANG11_LangOptions_EtraParams /**/
#else
#define CLAD_COMPAT_CLANG11_LangOptions_EtraParams Ctx.getLangOpts()
#endif
#define CLAD_COMPAT_CLANG11_Ctx_ExtraParams Ctx,
#define CLAD_COMPAT_CREATE11(CLASS, CTORARGS) (CLASS::Create CTORARGS)
#define CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Removed /**/
#define CLAD_COMPAT_CLANG11_CompoundAssignOperator_EtraParams_Moved \
, Node->getComputationLHSType(), Node->getComputationResultType()
#define CLAD_COMPAT_CLANG11_ChooseExpr_EtraParams_Removed /**/
#define CLAD_COMPAT_CLANG11_WhileStmt_ExtraParams \
, Node->getLParenLoc(), Node->getRParenLoc()
#endif

// Compatibility helper function for creation CXXOperatorCallExpr. Clang 8 and above use Create.
Expand Down Expand Up @@ -725,7 +727,7 @@ ArraySize_GetValue(const llvm::Optional<const Expr*>& opt) {
#else
static inline const Expr*
ArraySize_GetValue(const std::optional<const Expr*>& opt) {
return opt.value();
return opt.value();
}
#endif

Expand Down
35 changes: 17 additions & 18 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,10 @@
#include <cstring>

namespace clad {
template<typename T, typename U>
struct ValueAndAdjoint {
T value;
U adjoint;
};
template <typename T, typename U> struct ValueAndAdjoint {
T value;
U adjoint;
};

/// \returns the size of a c-style string
CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
Expand Down Expand Up @@ -305,9 +304,9 @@ namespace clad {
differentiate(F fn, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(fn && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(derivedFn,
code);
assert(fn && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(derivedFn,
code);
}

/// Specialization for differentiating functors.
Expand All @@ -325,8 +324,8 @@ namespace clad {
differentiate(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(derivedFn,
code, f);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>>(derivedFn,
code, f);
}

/// Generates function which computes derivative of `fn` argument w.r.t
Expand All @@ -348,9 +347,9 @@ namespace clad {
differentiate(F fn, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(fn && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn, code);
assert(fn && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn, code);
}

/// Generates function which computes gradient of the given function wrt the
Expand All @@ -370,9 +369,9 @@ namespace clad {
gradient(F f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code);
assert(f && "Must pass in a non-0 argument");
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code);
}

/// Specialization for differentiating functors.
Expand All @@ -388,8 +387,8 @@ namespace clad {
gradient(F&& f, ArgSpec args = "",
DerivedFnType derivedFn = static_cast<DerivedFnType>(nullptr),
const char* code = "") {
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code, f);
return CladFunction<DerivedFnType, ExtractFunctorTraits_t<F>, true>(
derivedFn /* will be replaced by gradient*/, code, f);
}

/// Generates function which computes hessian matrix of the given function wrt
Expand Down
2 changes: 1 addition & 1 deletion include/clad/Differentiator/ReverseModeForwPassVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ReverseModeForwPassVisitor : public ReverseModeVisitor {
public:
ReverseModeForwPassVisitor(DerivativeBuilder& builder);
DerivativeAndOverload Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);
const DiffRequest& request);

StmtDiff ProcessSingleStmt(const clang::Stmt* S);

Expand Down
2 changes: 1 addition & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace clad {
class ReverseModeVisitor
: public clang::ConstStmtVisitor<ReverseModeVisitor, StmtDiff>,
public VisitorBase {

protected:
// FIXME: We should remove friend-dependency of the plugin classes here.
// For this we will need to separate out AST related functions in
Expand Down
3 changes: 1 addition & 2 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,8 @@
#include "clad/Differentiator/ForwardModeVisitor.h"
#include "clad/Differentiator/HessianModeVisitor.h"
#include "clad/Differentiator/JacobianModeVisitor.h"
#include "clad/Differentiator/ReverseModeVisitor.h"
#include "clad/Differentiator/ReverseModeForwPassVisitor.h"
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ReverseModeVisitor.h"
#include "clad/Differentiator/StmtClone.h"
#include "clad/Differentiator/VectorForwardModeVisitor.h"

Expand Down
16 changes: 8 additions & 8 deletions lib/Differentiator/DiffPlanner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -390,20 +390,20 @@ namespace clad {
// The string is not a range just a single index
size_t index;
if (firstStr.getAsInteger(Radix, index)) {
utils::EmitDiag(semaRef, DiagnosticsEngine::Error,
diffArgs->getEndLoc(),
"Could not parse index '%0'", {diffSpec});
return;
utils::EmitDiag(semaRef, DiagnosticsEngine::Error,
diffArgs->getEndLoc(),
"Could not parse index '%0'", {diffSpec});
return;
}
dVarInfo.paramIndexInterval = IndexInterval(index);
} else {
size_t first, last;
if (firstStr.getAsInteger(Radix, first) ||
lastStr.getAsInteger(Radix, last)) {
utils::EmitDiag(semaRef, DiagnosticsEngine::Error,
diffArgs->getEndLoc(),
"Could not parse range '%0'", {diffSpec});
return;
utils::EmitDiag(semaRef, DiagnosticsEngine::Error,
diffArgs->getEndLoc(),
"Could not parse range '%0'", {diffSpec});
return;
}
if (first >= last) {
utils::EmitDiag(semaRef, DiagnosticsEngine::Error,
Expand Down
2 changes: 1 addition & 1 deletion lib/Differentiator/HessianModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -380,4 +380,4 @@ namespace clad {
return DerivativeAndOverload{result.first,
/*OverloadFunctionDecl=*/nullptr};
}
} // end namespace clad
} // end namespace clad
41 changes: 22 additions & 19 deletions lib/Differentiator/ReverseModeForwPassVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,

auto paramTypes = ComputeParamTypes(args);
auto returnType = ComputeReturnType();
const auto *sourceFnType = dyn_cast<FunctionProtoType>(m_Function->getType());
const auto* sourceFnType = dyn_cast<FunctionProtoType>(m_Function->getType());
auto fnType = m_Context.getFunctionType(returnType, paramTypes,
sourceFnType->getExtProtoInfo());

Expand Down Expand Up @@ -67,7 +67,7 @@ ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
for (Stmt* S : ReverseModeVisitor::m_Globals)
addToCurrentBlock(S);

if (auto *CS = dyn_cast<CompoundStmt>(forward))
if (auto* CS = dyn_cast<CompoundStmt>(forward))
for (Stmt* S : CS->body())
addToCurrentBlock(S);

Expand All @@ -93,20 +93,20 @@ ReverseModeForwPassVisitor::GetParameterDerivativeType(QualType yType,
QualType nonRefXValueType = xValueType.getNonReferenceType();
if (nonRefXValueType->isRealType())
return GetCladArrayRefOfType(yType);
return GetCladArrayRefOfType(nonRefXValueType);
return GetCladArrayRefOfType(nonRefXValueType);
}

llvm::SmallVector<clang::QualType, 8>
ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) {
llvm::SmallVector<clang::QualType, 8> paramTypes;
paramTypes.reserve(m_Function->getNumParams() * 2);
for (auto *PVD : m_Function->parameters())
for (auto* PVD : m_Function->parameters())
paramTypes.push_back(PVD->getType());

QualType effectiveReturnType =
m_Function->getReturnType().getNonReferenceType();

if (const auto *MD = dyn_cast<CXXMethodDecl>(m_Function)) {
if (const auto* MD = dyn_cast<CXXMethodDecl>(m_Function)) {
const CXXRecordDecl* RD = MD->getParent();
if (MD->isInstance() && !RD->isLambda()) {
QualType thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD);
Expand All @@ -115,8 +115,9 @@ ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) {
}
}

for (auto *PVD : m_Function->parameters()) {
const auto *it = std::find(std::begin(diffParams), std::end(diffParams), PVD);
for (auto* PVD : m_Function->parameters()) {
const auto* it =
std::find(std::begin(diffParams), std::end(diffParams), PVD);
if (it != std::end(diffParams)) {
paramTypes.push_back(
GetParameterDerivativeType(effectiveReturnType, PVD->getType()));
Expand All @@ -126,7 +127,8 @@ ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) {
}

clang::QualType ReverseModeForwPassVisitor::ComputeReturnType() {
auto *valAndAdjointTempDecl = LookupTemplateDeclInCladNamespace("ValueAndAdjoint");
auto* valAndAdjointTempDecl =
LookupTemplateDeclInCladNamespace("ValueAndAdjoint");
auto RT = m_Function->getReturnType();
auto T = InstantiateTemplate(valAndAdjointTempDecl, {RT, RT});
return T;
Expand All @@ -137,14 +139,15 @@ ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) {
llvm::SmallVector<clang::ParmVarDecl*, 8> params;
llvm::SmallVector<clang::ParmVarDecl*, 8> paramDerivatives;
params.reserve(m_Function->getNumParams() + diffParams.size());
const auto *derivativeFnType = cast<FunctionProtoType>(m_Derivative->getType());
const auto* derivativeFnType =
cast<FunctionProtoType>(m_Derivative->getType());

std::size_t dParamTypesIdx = m_Function->getNumParams();

if (const auto *MD = dyn_cast<CXXMethodDecl>(m_Function)) {
if (const auto* MD = dyn_cast<CXXMethodDecl>(m_Function)) {
const CXXRecordDecl* RD = MD->getParent();
if (MD->isInstance() && !RD->isLambda()) {
auto *thisDerivativePVD = utils::BuildParmVarDecl(
auto* thisDerivativePVD = utils::BuildParmVarDecl(
m_Sema, m_Derivative, CreateUniqueIdentifier("_d_this"),
derivativeFnType->getParamType(dParamTypesIdx));
paramDerivatives.push_back(thisDerivativePVD);
Expand All @@ -159,10 +162,10 @@ ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) {
++dParamTypesIdx;
}
}
for (auto *PVD : m_Function->parameters()) {
for (auto* PVD : m_Function->parameters()) {
// FIXME: Call expression may contain default arguments that we are now
// removing. This may cause issues.
auto *newPVD = utils::BuildParmVarDecl(
auto* newPVD = utils::BuildParmVarDecl(
m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(),
PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo());
params.push_back(newPVD);
Expand All @@ -171,14 +174,14 @@ ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) {
m_Sema.PushOnScopeChains(newPVD, getCurrentScope(),
/*AddToContext=*/false);

auto *it = std::find(std::begin(diffParams), std::end(diffParams), PVD);
auto* it = std::find(std::begin(diffParams), std::end(diffParams), PVD);
if (it != std::end(diffParams)) {
*it = newPVD;
QualType dType = derivativeFnType->getParamType(dParamTypesIdx);
IdentifierInfo* dII =
CreateUniqueIdentifier("_d_" + PVD->getNameAsString());
auto *dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType,
PVD->getStorageClass());
auto* dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType,
PVD->getStorageClass());
paramDerivatives.push_back(dPVD);
++dParamTypesIdx;

Expand Down Expand Up @@ -230,7 +233,7 @@ StmtDiff ReverseModeForwPassVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
DeclRefExpr* clonedDRE = nullptr;
// Check if referenced Decl was "replaced" with another identifier inside
// the derivative
if (const auto *VD = dyn_cast<VarDecl>(DRE->getDecl())) {
if (const auto* VD = dyn_cast<VarDecl>(DRE->getDecl())) {
auto it = m_DeclReplacements.find(VD);
if (it != std::end(m_DeclReplacements))
clonedDRE = BuildDeclRef(it->second);
Expand All @@ -242,13 +245,13 @@ StmtDiff ReverseModeForwPassVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
// Sema::BuildDeclRefExpr is responsible for adding captured fields
// to the underlying struct of a lambda.
if (clonedDRE->getDecl()->getDeclContext() != m_Sema.CurContext) {
auto *referencedDecl = cast<VarDecl>(clonedDRE->getDecl());
auto* referencedDecl = cast<VarDecl>(clonedDRE->getDecl());
clonedDRE = cast<DeclRefExpr>(BuildDeclRef(referencedDecl));
}
} else
clonedDRE = cast<DeclRefExpr>(Clone(DRE));

if (auto *decl = dyn_cast<VarDecl>(clonedDRE->getDecl())) {
if (auto* decl = dyn_cast<VarDecl>(clonedDRE->getDecl())) {
// Check DeclRefExpr is a reference to an independent variable.
auto it = m_Variables.find(decl);
if (it == std::end(m_Variables)) {
Expand Down
Loading

0 comments on commit 1eebf49

Please sign in to comment.