diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index e6bb7ef0f..427859dc2 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -400,6 +400,8 @@ namespace clad { virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS); StmtDiff VisitStmt(const clang::Stmt* S); virtual StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp); + StmtDiff + VisitUnaryExprOrTypeTraitExpr(const clang::UnaryExprOrTypeTraitExpr* UE); StmtDiff VisitExprWithCleanups(const clang::ExprWithCleanups* EWC); /// Decl is not Stmt, so it cannot be visited directly. StmtDiff VisitWhileStmt(const clang::WhileStmt* WS); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index db0a5c2a6..6306399ea 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -4107,6 +4107,11 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return Visit(NTTP->getReplacement()); } + StmtDiff ReverseModeVisitor::VisitUnaryExprOrTypeTraitExpr( + const clang::UnaryExprOrTypeTraitExpr* UE) { + return {Clone(UE), Clone(UE)}; + } + DeclDiff ReverseModeVisitor::DifferentiateStaticAssertDecl( const clang::StaticAssertDecl* SAD) { return DeclDiff(nullptr, nullptr);