Skip to content

Commit

Permalink
Fix braceless if differentiation in reverse mode (#1058)
Browse files Browse the repository at this point in the history
As in #1049, the statement inside the braceless if is not included causing an error.
Fixes: #1049
  • Loading branch information
ovdiiuv authored Aug 25, 2024
1 parent 4368c04 commit 6f4b081
Show file tree
Hide file tree
Showing 2 changed files with 155 additions and 6 deletions.
19 changes: 15 additions & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -878,10 +878,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
elseDiff.getStmt());
addToCurrentBlock(Forward, direction::forward);

Stmt* Reverse = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), /*Init=*/nullptr, /*Var=*/nullptr,
condDiffStored, noLoc, noLoc, thenDiff.getStmt_dx(), noLoc,
elseDiff.getStmt_dx());
Stmt* Reverse = nullptr;
// thenDiff.getStmt_dx() might be empty if TBR is on leadinf to a crash in
// case of the braceless if.
if (thenDiff.getStmt_dx())
Reverse = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), /*Init=*/nullptr,
/*Var=*/nullptr, condDiffStored, noLoc, noLoc, thenDiff.getStmt_dx(),
noLoc, elseDiff.getStmt_dx());
else if (elseDiff.getStmt_dx())
Reverse = clad_compat::IfStmt_Create(
m_Context, noLoc, If->isConstexpr(), /*Init=*/nullptr,
/*Var=*/nullptr,
BuildOp(clang::UnaryOperatorKind::UO_LNot,
BuildParens(condDiffStored)),
noLoc, noLoc, elseDiff.getStmt_dx(), noLoc, {});
addToCurrentBlock(Reverse, direction::reverse);
CompoundStmt* ForwardBlock = endBlock(direction::forward);
CompoundStmt* ReverseBlock = endBlock(direction::reverse);
Expand Down
142 changes: 140 additions & 2 deletions test/Analyses/TBR.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// RUN: %cladclang -mllvm -debug-only=clad-tbr %s -I%S/../../include -oReverseLoops.out 2>&1 | %filecheck %s
// REQUIRES: asserts
// RUN: %cladclang %s -I%S/../../include -oTBR.out | %filecheck %s
// RUN: ./TBR.out | %filecheck_exec %s
// RUN: %cladclang -Xclang -plugin-arg-clad -Xclang -enable-tbr %s -I%S/../../include -oTBR.out
// RUN: ./TBR.out | %filecheck_exec %s
//CHECK-NOT: {{.*error|warning|note:.*}}

#include "clad/Differentiator/Differentiator.h"
Expand All @@ -11,6 +13,36 @@ double f1(double x) {
return t;
} // == x^3

//CHECK: void f1_grad(double x, double *_d_x) {
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double _d_t = 0;
//CHECK-NEXT: double t = 1;
//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 0; ; i++) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!(i < 3))
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, t);
//CHECK-NEXT: t *= x;
//CHECK-NEXT: }
//CHECK-NEXT: _d_t += 1;
//CHECK-NEXT: for (;; _t0--) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!_t0)
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: t = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_t;
//CHECK-NEXT: _d_t = 0;
//CHECK-NEXT: _d_t += _r_d0 * x;
//CHECK-NEXT: *_d_x += t * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }

double f2(double val) {
double res = 0;
for (int i=1; i<5; ++i) {
Expand All @@ -21,6 +53,111 @@ double f2(double val) {
return res;
}

//CHECK: void f2_grad(double val, double *_d_val) {
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: int i = 0;
//CHECK-NEXT: clad::tape<bool> _cond0 = {};
//CHECK-NEXT: clad::tape<unsigned {{int|long}}> _t1 = {};
//CHECK-NEXT: double _d_res = 0;
//CHECK-NEXT: double res = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}};
//CHECK-NEXT: for (i = 1; ; ++i) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!(i < 5))
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: _t0++;
//CHECK-NEXT: {
//CHECK-NEXT: clad::push(_cond0, i == 3);
//CHECK-NEXT: if (clad::back(_cond0)) {
//CHECK-NEXT: clad::push(_t1, {{1U|1UL}});
//CHECK-NEXT: continue;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: res += i * val;
//CHECK-NEXT: clad::push(_t1, {{2U|2UL}});
//CHECK-NEXT: }
//CHECK-NEXT: _d_res += 1;
//CHECK-NEXT: for (;; _t0--) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!_t0)
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: switch (clad::pop(_t1)) {
//CHECK-NEXT: case {{2U|2UL}}:
//CHECK-NEXT: ;
//CHECK-NEXT: --i;
//CHECK-NEXT: {
//CHECK-NEXT: double _r_d0 = _d_res;
//CHECK-NEXT: _d_i += _r_d0 * val;
//CHECK-NEXT: *_d_val += i * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: if (clad::back(_cond0))
//CHECK-NEXT: case {{1U|1UL}}:
//CHECK-NEXT: ;
//CHECK-NEXT: clad::pop(_cond0);
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }

double f3 (double x){
double i = 1;
double j = 0;
double res = 0;
res += i*x;
if(j)
j++;
else if(i)
res += i*x;
else if(j)
i++;
return res;
}

//CHECK: void f3_grad(double x, double *_d_x) {
//CHECK-NEXT: bool _cond0;
//CHECK-NEXT: bool _cond1;
//CHECK-NEXT: bool _cond2;
//CHECK-NEXT: double _d_i = 0;
//CHECK-NEXT: double i = 1;
//CHECK-NEXT: double _d_j = 0;
//CHECK-NEXT: double j = 0;
//CHECK-NEXT: double _d_res = 0;
//CHECK-NEXT: double res = 0;
//CHECK-NEXT: res += i * x;
//CHECK-NEXT: {
//CHECK-NEXT: _cond0 = j;
//CHECK-NEXT: if (_cond0)
//CHECK-NEXT: j++;
//CHECK-NEXT: else {
//CHECK-NEXT: _cond1 = i;
//CHECK-NEXT: if (_cond1)
//CHECK-NEXT: res += i * x;
//CHECK-NEXT: else {
//CHECK-NEXT: _cond2 = j;
//CHECK-NEXT: if (_cond2)
//CHECK-NEXT: i++;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: _d_res += 1;
//CHECK-NEXT: if (!_cond0)
//CHECK-NEXT: if (_cond1) {
//CHECK-NEXT: double _r_d1 = _d_res;
//CHECK-NEXT: _d_i += _r_d1 * x;
//CHECK-NEXT: *_d_x += i * _r_d1;
//CHECK-NEXT: } else if (_cond2)
//CHECK-NEXT: i--;
//CHECK-NEXT: {
//CHECK-NEXT: double _r_d0 = _d_res;
//CHECK-NEXT: _d_i += _r_d0 * x;
//CHECK-NEXT: *_d_x += i * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }


#define TEST(F, x) { \
result[0] = 0; \
auto F##grad = clad::gradient<clad::opts::enable_tbr>(F);\
Expand All @@ -32,4 +169,5 @@ int main() {
double result[3] = {};
TEST(f1, 3); // CHECK-EXEC: {27.00}
TEST(f2, 3); // CHECK-EXEC: {9.00}
TEST(f3, 3); // CHECK-EXEC: {2.00}
}

0 comments on commit 6f4b081

Please sign in to comment.