Skip to content

Commit

Permalink
Add support of custom _forw functions
Browse files Browse the repository at this point in the history
  • Loading branch information
parth-07 committed Aug 10, 2024
1 parent 1b81084 commit 5d64c63
Show file tree
Hide file tree
Showing 9 changed files with 178 additions and 13 deletions.
5 changes: 5 additions & 0 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ template <typename T, typename U> struct ValueAndPushforward {
}
};

template <typename T, typename U> struct ValueAndAdjoint {
T value;
U adjoint;
};

/// It is used to identify constructor custom pushforwards. For
/// constructor custom pushforward functions, we cannot use the same
/// strategy which we use for custom pushforward for member
Expand Down
2 changes: 2 additions & 0 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,8 @@ namespace clad {

bool IsMemoryFunction(const clang::FunctionDecl* FD);
bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD);

bool isNonConstReferenceType(clang::QualType QT);
} // namespace utils
} // namespace clad

Expand Down
4 changes: 0 additions & 4 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,6 @@
#include <cstring>

namespace clad {
template <typename T, typename U> struct ValueAndAdjoint {
T value;
U adjoint;
};

/// \returns the size of a c-style string
inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
Expand Down
4 changes: 4 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ namespace clad {
// Function to Differentiate with Enzyme as Backend
void DifferentiateWithEnzyme();

clang::Expr* BuildCallToCustomForwPassFn(
const clang::FunctionDecl* FD, llvm::ArrayRef<clang::Expr*> primalArgs,
llvm::ArrayRef<clang::Expr*> derivedArgs, clang::Expr* baseExpr);

public:
using direction = rmv::direction;
clang::Expr* dfdx() {
Expand Down
29 changes: 29 additions & 0 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,35 @@ void fill_pushforward(::std::array<T, N>* a, const T& u,
d_a->fill(d_u);
}

template <typename T, typename U>
void push_back_forw(::std::vector<T>* v, U val, ::std::vector<T>* d_v,
U* d_val) {
v->push_back(val);
d_v->push_back(0);
}

template <typename T, typename U>
void push_back_pullback(::std::vector<T>* v, U val, ::std::vector<T>* d_v,
U* d_val) {
*d_val += d_v->back();
d_v->pop_back();
}

template <typename T>
clad::ValueAndAdjoint<T&, T&> operator_subscript_forw(
::std::vector<T>* vec, typename ::std::vector<T>::size_type idx,
::std::vector<T>* d_vec, typename ::std::vector<T>::size_type* d_idx) {
return {(*vec)[idx], (*d_vec)[idx]};
}

template <typename T, typename P>
void operator_subscript_pullback(::std::vector<T>* vec,
typename ::std::vector<T>::size_type idx,
P d_y, ::std::vector<T>* d_vec,
typename ::std::vector<T>::size_type* d_idx) {
(*d_vec)[idx] += d_y;
}

} // namespace class_functions
} // namespace custom_derivatives
} // namespace clad
Expand Down
5 changes: 5 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -705,5 +705,10 @@ namespace clad {
return FD->getNameAsString() == "free";
#endif
}

bool isNonConstReferenceType(clang::QualType QT) {
return QT->isReferenceType() &&
!QT.getNonReferenceType().isConstQualified();
}
} // namespace utils
} // namespace clad
36 changes: 34 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1791,6 +1791,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// Stores differentiation result of implicit `this` object, if any.
StmtDiff baseDiff;
Expr *baseExpr = nullptr;
// If it has more args or f_darg0 was not found, we look for its pullback
// function.
const auto* MD = dyn_cast<CXXMethodDecl>(FD);
Expand All @@ -1814,6 +1815,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
baseOriginalE = OCE->getArg(0);

baseDiff = Visit(baseOriginalE);
baseExpr = baseDiff.getExpr();
Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr());
baseDiff.updateStmt(baseDiffStore);
Expr* baseDerivative = baseDiff.getExpr_dx();
Expand Down Expand Up @@ -1999,8 +2001,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* call = nullptr;

QualType returnType = FD->getReturnType();
if (returnType->isReferenceType() &&
!returnType.getNonReferenceType().isConstQualified()) {
if (Expr* customForwardPassCE = BuildCallToCustomForwPassFn(
FD, CallArgs, DerivedCallOutputArgs, baseExpr)) {
if (!utils::isNonConstReferenceType(returnType))
return StmtDiff{customForwardPassCE};
auto* callRes = StoreAndRef(customForwardPassCE);
auto* resValue =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value");
auto* resAdjoint =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint");
return StmtDiff(resValue, nullptr, resAdjoint);
}
if (utils::isNonConstReferenceType(returnType)) {
DiffRequest calleeFnForwPassReq;
calleeFnForwPassReq.Function = FD;
calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass;
Expand Down Expand Up @@ -4259,4 +4271,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
diffParams.end());
return params;
}

Expr* ReverseModeVisitor::BuildCallToCustomForwPassFn(
const FunctionDecl* FD, llvm::ArrayRef<Expr*> primalArgs,
llvm::ArrayRef<clang::Expr*> derivedArgs, Expr *baseExpr) {
std::string forwPassFnName =
clad::utils::ComputeEffectiveFnName(FD) + "_forw";
llvm::SmallVector<Expr*> args;
if (baseExpr) {
baseExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, baseExpr,
m_DiffReq->getLocation());
args.push_back(baseExpr);
}
args.append(primalArgs.begin(), primalArgs.end());
args.append(derivedArgs.begin(), derivedArgs.end());
Expr* customForwPassCE =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
forwPassFnName, args, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
return customForwPassCE;
}
} // end namespace clad
6 changes: 1 addition & 5 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -260,10 +260,6 @@ double fn7(double i, double j) {

// CHECK: void custom_identity_pullback(double &i, double _d_y, double *_d_i);

// CHECK: clad::ValueAndAdjoint<double &, double &> custom_identity_forw(double &i, double *d_i) {
// CHECK-NEXT: return {i, *d_i};
// CHECK-NEXT: }

// CHECK: void fn7_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: double _t0 = i;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t1 = identity_forw(i, &*_d_i);
Expand All @@ -274,7 +270,7 @@ double fn7(double i, double j) {
// CHECK-NEXT: double &_d_l = _t3.adjoint;
// CHECK-NEXT: double &l = _t3.value;
// CHECK-NEXT: double _t4 = i;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t5 = custom_identity_forw(i, &*_d_i);
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t5 = {{.*}}custom_derivatives::custom_identity_forw(i, &*_d_i);
// CHECK-NEXT: double &_d_temp = _t5.adjoint;
// CHECK-NEXT: double &temp = _t5.value;
// CHECK-NEXT: double _t6 = k;
Expand Down
100 changes: 98 additions & 2 deletions test/Gradient/UserDefinedTypes.C
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
// RUN: %cladclang %s -I%S/../../include -oUserDefinedTypes.out 2>&1 | %filecheck %s
// RUN: %cladclang -std=c++14 %s -I%S/../../include -oUserDefinedTypes.out 2>&1 | %filecheck %s
// RUN: ./UserDefinedTypes.out | %filecheck_exec %s
// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oUserDefinedTypes.out
// RUN: %cladclang -std=c++14 -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oUserDefinedTypes.out
// RUN: ./UserDefinedTypes.out | %filecheck_exec %s
// CHECK-NOT: {{.*error|warning|note:.*}}

#include "clad/Differentiator/Differentiator.h"
#include "clad/Differentiator/STLBuiltins.h"

#include <utility>
#include <vector>
#include <complex>

#include "../TestUtils.h"
Expand Down Expand Up @@ -326,6 +328,22 @@ double fn9(Tangent t, dcomplex c) {
// CHECK-NEXT: }
// CHECK-NEXT: }

double fn10(double u, double v) {
std::vector<double> vec;
vec.push_back(u);
vec.push_back(v);
return vec[0] + vec[1];
}

double fn11(double u, double v) {
std::vector<double> vec;
vec.push_back(u);
vec.push_back(v);
double &ref = vec[0];
ref += u;
return vec[0] + vec[1];
}

void print(const Tangent& t) {
for (int i = 0; i < 5; ++i) {
printf("%.2f", t.data[i]);
Expand All @@ -334,6 +352,7 @@ void print(const Tangent& t) {
}
}


int main() {
pairdd p(3, 5), d_p;
double i = 3, d_i, d_j;
Expand All @@ -351,6 +370,8 @@ int main() {
INIT_GRADIENT(fn7);
INIT_GRADIENT(fn8);
INIT_GRADIENT(fn9);
INIT_GRADIENT(fn10);
INIT_GRADIENT(fn11);

TEST_GRADIENT(fn1, /*numOfDerivativeArgs=*/2, p, i, &d_p, &d_i); // CHECK-EXEC: {1.00, 2.00, 3.00}
TEST_GRADIENT(fn2, /*numOfDerivativeArgs=*/2, t, i, &d_t, &d_i); // CHECK-EXEC: {4.00, 2.00, 2.00, 2.00, 2.00, 1.00}
Expand All @@ -364,8 +385,83 @@ int main() {
TEST_GRADIENT(fn7, /*numOfDerivativeArgs=*/2, c1, c2, &d_c1, &d_c2);// CHECK-EXEC: {0.00, 3.00, 5.00, 1.00}
TEST_GRADIENT(fn8, /*numOfDerivativeArgs=*/2, t, c1, &d_t, &d_c1); // CHECK-EXEC: {0.00, 0.00, 0.00, 0.00, 0.00, 5.00, 0.00}
TEST_GRADIENT(fn9, /*numOfDerivativeArgs=*/2, t, c1, &d_t, &d_c1); // CHECK-EXEC: {1.00, 1.00, 1.00, 1.00, 1.00, 5.00, 10.00}
TEST_GRADIENT(fn10, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {1.00, 1.00}
TEST_GRADIENT(fn11, /*numOfDerivativeArgs=*/2, 3, 5, &d_i, &d_j); // CHECK-EXEC: {2.00, 1.00}
}

// CHECK: void fn10_grad(double u, double v, double *_d_u, double *_d_v) {
// CHECK-NEXT: std::vector<double> _d_vec({});
// CHECK-NEXT: std::vector<double> vec;
// CHECK-NEXT: double _t0 = u;
// CHECK-NEXT: std::vector<double> _t1 = vec;
// CHECK-NEXT: {{.*}}class_functions::push_back_forw(&vec, u, &_d_vec, &*_d_u);
// CHECK-NEXT: double _t2 = v;
// CHECK-NEXT: std::vector<double> _t3 = vec;
// CHECK-NEXT: {{.*}}class_functions::push_back_forw(&vec, v, &_d_vec, &*_d_v);
// CHECK-NEXT: std::vector<double> _t4 = vec;
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t5 = {{.*}}class_functions::operator_subscript_forw(&vec, 0, &_d_vec, &_r0);
// CHECK-NEXT: std::vector<double> _t6 = vec;
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t7 = {{.*}}class_functions::operator_subscript_forw(&vec, 1, &_d_vec, &_r1);
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}} _r0 = 0;
// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t4, 0, 1, &_d_vec, &_r0);
// CHECK-NEXT: {{.*}} _r1 = 0;
// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t6, 1, 1, &_d_vec, &_r1);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: v = _t2;
// CHECK-NEXT: {{.*}}class_functions::push_back_pullback(&_t3, _t2, &_d_vec, &*_d_v);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: u = _t0;
// CHECK-NEXT: {{.*}}class_functions::push_back_pullback(&_t1, _t0, &_d_vec, &*_d_u);
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK-NEXT: void fn11_grad(double u, double v, double *_d_u, double *_d_v) {
// CHECK-NEXT: std::vector<double> _d_vec({});
// CHECK-NEXT: std::vector<double> vec;
// CHECK-NEXT: double _t0 = u;
// CHECK-NEXT: std::vector<double> _t1 = vec;
// CHECK-NEXT: {{.*}}class_functions::push_back_forw(&vec, u, &_d_vec, &*_d_u);
// CHECK-NEXT: double _t2 = v;
// CHECK-NEXT: std::vector<double> _t3 = vec;
// CHECK-NEXT: {{.*}}class_functions::push_back_forw(&vec, v, &_d_vec, &*_d_v);
// CHECK-NEXT: std::vector<double> _t4 = vec;
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t5 = {{.*}}class_functions::operator_subscript_forw(&vec, 0, &_d_vec, &_r0);
// CHECK-NEXT: double &_d_ref = _t5.adjoint;
// CHECK-NEXT: double &ref = _t5.value;
// CHECK-NEXT: double _t6 = ref;
// CHECK-NEXT: ref += u;
// CHECK-NEXT: std::vector<double> _t7 = vec;
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t8 = {{.*}}class_functions::operator_subscript_forw(&vec, 0, &_d_vec, &_r1);
// CHECK-NEXT: std::vector<double> _t9 = vec;
// CHECK-NEXT: {{.*}}ValueAndAdjoint<double &, double &> _t10 = {{.*}}class_functions::operator_subscript_forw(&vec, 1, &_d_vec, &_r2);
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}} _r1 = 0;
// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t7, 0, 1, &_d_vec, &_r1);
// CHECK-NEXT: {{.*}} _r2 = 0;
// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t9, 1, 1, &_d_vec, &_r2);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: ref = _t6;
// CHECK-NEXT: double _r_d0 = _d_ref;
// CHECK-NEXT: *_d_u += _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: {{.*}} _r0 = 0;
// CHECK-NEXT: {{.*}}class_functions::operator_subscript_pullback(&_t4, 0, 0, &_d_vec, &_r0);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: v = _t2;
// CHECK-NEXT: {{.*}}class_functions::push_back_pullback(&_t3, _t2, &_d_vec, &*_d_v);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: u = _t0;
// CHECK-NEXT: {{.*}}class_functions::push_back_pullback(&_t1, _t0, &_d_vec, &*_d_u);
// CHECK-NEXT: }
// CHECK-NEXT: }

// CHECK: void sum_pullback(Tangent &t, double _d_y, Tangent *_d_t) {
// CHECK-NEXT: int _d_i = 0;
// CHECK-NEXT: int i = 0;
Expand Down

0 comments on commit 5d64c63

Please sign in to comment.