Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix braceless if differentiation in reverse mode #1058

Merged
merged 2 commits into from
Aug 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -878,10 +878,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
elseDiff.getStmt());
addToCurrentBlock(Forward, direction::forward);

Stmt* Reverse = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), /*Init=*/nullptr, /*Var=*/nullptr,
condDiffStored, noLoc, noLoc, thenDiff.getStmt_dx(), noLoc,
elseDiff.getStmt_dx());
Stmt* Reverse = nullptr;
// thenDiff.getStmt_dx() might be empty if TBR is on leadinf to a crash in
// case of the braceless if.
if (thenDiff.getStmt_dx())
Reverse = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), /*Init=*/nullptr,
/*Var=*/nullptr, condDiffStored, noLoc, noLoc, thenDiff.getStmt_dx(),
noLoc, elseDiff.getStmt_dx());
else if (elseDiff.getStmt_dx())
Reverse = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), /*Init=*/nullptr,
/*Var=*/nullptr,
BuildOp(clang::UnaryOperatorKind::UO_LNot,
BuildParens(condDiffStored)),
noLoc, noLoc, elseDiff.getStmt_dx(), noLoc, {});
addToCurrentBlock(Reverse, direction::reverse);
CompoundStmt* ForwardBlock = endBlock(direction::forward);
CompoundStmt* ReverseBlock = endBlock(direction::reverse);
Expand Down
142 changes: 140 additions & 2 deletions test/Analyses/TBR.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// RUN: %cladclang -mllvm -debug-only=clad-tbr %s -I%S/../../include -oReverseLoops.out 2>&1 | %filecheck %s
// REQUIRES: asserts
// RUN: %cladclang %s -I%S/../../include -oTBR.out | %filecheck %s
// RUN: ./TBR.out | %filecheck_exec %s
// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oTBR.out
// RUN: ./TBR.out | %filecheck_exec %s
//CHECK-NOT: {{.*error|warning|note:.*}}

#include "clad/Differentiator/Differentiator.h"
Expand All @@ -11,6 +13,36 @@ double f1(double x) {
return t;
} // == x^3

//CHECK: void f1_grad(double x, double *_d_x) {
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double _d_t = 0;
//CHECK-NEXT: double t = 1;
//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; ; i++) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!(i < 3))
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, t);
//CHECK-NEXT: t *= x;
//CHECK-NEXT: }
//CHECK-NEXT: _d_t += 1;
//CHECK-NEXT: for (;; _t0--) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!_t0)
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: t = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_t;
//CHECK-NEXT: _d_t = 0;
//CHECK-NEXT: _d_t += _r_d0 * x;
//CHECK-NEXT: *_d_x += t * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }

double f2(double val) {
double res = 0;
for (int i=1; i<5; ++i) {
Expand All @@ -21,6 +53,111 @@ double f2(double val) {
return res;
}

//CHECK: void f2_grad(double val, double *_d_val) {
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<bool> _cond0 = {};
//CHECK-NEXT: clad::tape<unsigned {{int|long}}> _t1 = {};
//CHECK-NEXT: double _d_res = 0;
//CHECK-NEXT: double res = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 1; ; ++i) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!(i < 5))
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: _t0++;
//CHECK-NEXT: {
//CHECK-NEXT: clad::push(_cond0, i == 3);
//CHECK-NEXT: if (clad::back(_cond0)) {
//CHECK-NEXT: clad::push(_t1, {{1U|1UL}});
//CHECK-NEXT: continue;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: res += i * val;
//CHECK-NEXT: clad::push(_t1, {{2U|2UL}});
//CHECK-NEXT: }
//CHECK-NEXT: _d_res += 1;
//CHECK-NEXT: for (;; _t0--) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!_t0)
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: switch (clad::pop(_t1)) {
//CHECK-NEXT: case {{2U|2UL}}:
//CHECK-NEXT: ;
//CHECK-NEXT: --i;
//CHECK-NEXT: {
//CHECK-NEXT: double _r_d0 = _d_res;
//CHECK-NEXT: _d_i += _r_d0 * val;
//CHECK-NEXT: *_d_val += i * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: if (clad::back(_cond0))
//CHECK-NEXT: case {{1U|1UL}}:
//CHECK-NEXT: ;
//CHECK-NEXT: clad::pop(_cond0);
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }

double f3 (double x){
double i = 1;
double j = 0;
double res = 0;
res += i*x;
if(j)
j++;
else if(i)
res += i*x;
else if(j)
i++;
return res;
}

//CHECK: void f3_grad(double x, double *_d_x) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: bool _cond1;
//CHECK-NEXT: bool _cond2;
//CHECK-NEXT: double _d_i = 0;
//CHECK-NEXT: double i = 1;
//CHECK-NEXT: double _d_j = 0;
//CHECK-NEXT: double j = 0;
//CHECK-NEXT: double _d_res = 0;
//CHECK-NEXT: double res = 0;
//CHECK-NEXT: res += i * x;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = j;
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: j++;
//CHECK-NEXT: else {
//CHECK-NEXT: _cond1 = i;
//CHECK-NEXT: if (_cond1)
//CHECK-NEXT: res += i * x;
//CHECK-NEXT: else {
//CHECK-NEXT: _cond2 = j;
//CHECK-NEXT: if (_cond2)
//CHECK-NEXT: i++;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: _d_res += 1;
//CHECK-NEXT: if (!_cond0)
//CHECK-NEXT: if (_cond1) {
//CHECK-NEXT: double _r_d1 = _d_res;
//CHECK-NEXT: _d_i += _r_d1 * x;
//CHECK-NEXT: *_d_x += i * _r_d1;
//CHECK-NEXT: } else if (_cond2)
//CHECK-NEXT: i--;
//CHECK-NEXT: {
//CHECK-NEXT: double _r_d0 = _d_res;
//CHECK-NEXT: _d_i += _r_d0 * x;
//CHECK-NEXT: *_d_x += i * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }


#define TEST(F, x) { \
result[0] = 0; \
auto F##grad = clad::gradient<clad::opts::enable_tbr>(F);\
Expand All @@ -32,4 +169,5 @@ int main() {
double result[3] = {};
TEST(f1, 3); // CHECK-EXEC: {27.00}
TEST(f2, 3); // CHECK-EXEC: {9.00}
TEST(f3, 3); // CHECK-EXEC: {2.00}
}
Loading