From a4c1d258c8f13ff42f05363e08be5a404c114a33 Mon Sep 17 00:00:00 2001 From: Max Andriychuk Date: Sat, 24 Aug 2024 16:25:04 +0200 Subject: [PATCH] Fix braceless if differentiation in reverse mode As in #1049, the statement inside the braceless if is not included causing an error. Fixes:#1049 --- lib/Differentiator/ReverseModeVisitor.cpp | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index b8baae4f9..637b027cc 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -878,10 +878,21 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, elseDiff.getStmt()); addToCurrentBlock(Forward, direction::forward); - Stmt* Reverse = clad_compat::IfStmt_Create( - m_Context, noLoc, If->isConstexpr(), /*Init=*/nullptr, /*Var=*/nullptr, - condDiffStored, noLoc, noLoc, thenDiff.getStmt_dx(), noLoc, - elseDiff.getStmt_dx()); + Stmt* Reverse = nullptr; + // thenDiff.getStmt_dx() might be empty if TBR is on leadinf to a crash in + // case of the braceless if. + if (thenDiff.getStmt_dx()) + Reverse = clad_compat::IfStmt_Create( + m_Context, noLoc, If->isConstexpr(), /*Init=*/nullptr, + /*Var=*/nullptr, condDiffStored, noLoc, noLoc, thenDiff.getStmt_dx(), + noLoc, elseDiff.getStmt_dx()); + else if (elseDiff.getStmt_dx()) + Reverse = clad_compat::IfStmt_Create( + m_Context, noLoc, If->isConstexpr(), /*Init=*/nullptr, + /*Var=*/nullptr, + BuildOp(clang::UnaryOperatorKind::UO_LNot, + BuildParens(condDiffStored)), + noLoc, noLoc, elseDiff.getStmt_dx(), noLoc, {}); addToCurrentBlock(Reverse, direction::reverse); CompoundStmt* ForwardBlock = endBlock(direction::forward); CompoundStmt* ReverseBlock = endBlock(direction::reverse);