From b2151f8f6dc9c3a807a7216fdfab612be12ab2cc Mon Sep 17 00:00:00 2001 From: William Moses Date: Thu, 25 Jan 2024 01:19:48 -0500 Subject: [PATCH] Fix struct containing attributes (#1627) --- enzyme/Enzyme/Clang/EnzymeClang.cpp | 22 ++++++- .../Integration/ReverseMode/inactiveglob.cpp | 62 +++++++++++++++++++ 2 files changed, 82 insertions(+), 2 deletions(-) create mode 100644 enzyme/test/Integration/ReverseMode/inactiveglob.cpp diff --git a/enzyme/Enzyme/Clang/EnzymeClang.cpp b/enzyme/Enzyme/Clang/EnzymeClang.cpp index ee5794cb13e2..ed01f1bf5739 100644 --- a/enzyme/Enzyme/Clang/EnzymeClang.cpp +++ b/enzyme/Enzyme/Clang/EnzymeClang.cpp @@ -240,6 +240,11 @@ struct EnzymeFunctionLikeAttrInfo : public ParsedAttrInfo { // if (FD->isLateTemplateParsed()) return; auto &AST = S.getASTContext(); DeclContext *declCtx = FD->getDeclContext(); + for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) { + if (tmpCtx->isRecord()) { + declCtx = tmpCtx->getParent(); + } + } auto loc = FD->getLocation(); RecordDecl *RD; if (S.getLangOpts().CPlusPlus) @@ -369,6 +374,11 @@ struct EnzymeInactiveAttrInfo : public ParsedAttrInfo { auto &AST = S.getASTContext(); DeclContext *declCtx = D->getDeclContext(); + for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) { + if (tmpCtx->isRecord()) { + declCtx = tmpCtx->getParent(); + } + } auto loc = D->getLocation(); RecordDecl *RD; if (S.getLangOpts().CPlusPlus) @@ -425,7 +435,6 @@ struct EnzymeInactiveAttrInfo : public ParsedAttrInfo { return AttributeNotApplied; } V->setInit(expr); - V->dump(); S.MarkVariableReferenced(loc, V); S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V)); return AttributeApplied; @@ -479,6 +488,11 @@ struct EnzymeNoFreeAttrInfo : public ParsedAttrInfo { auto &AST = S.getASTContext(); DeclContext *declCtx = D->getDeclContext(); + for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) { + if (tmpCtx->isRecord()) { + declCtx = tmpCtx->getParent(); + } + } auto loc = D->getLocation(); RecordDecl *RD; if (S.getLangOpts().CPlusPlus) @@ -534,7 +548,6 @@ struct EnzymeNoFreeAttrInfo : public ParsedAttrInfo { return AttributeNotApplied; } V->setInit(expr); - V->dump(); S.MarkVariableReferenced(loc, V); S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V)); return AttributeApplied; @@ -584,6 +597,11 @@ struct EnzymeSparseAccumulateAttrInfo : public ParsedAttrInfo { auto &AST = S.getASTContext(); DeclContext *declCtx = D->getDeclContext(); + for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) { + if (tmpCtx->isRecord()) { + declCtx = tmpCtx->getParent(); + } + } auto loc = D->getLocation(); RecordDecl *RD; if (S.getLangOpts().CPlusPlus) diff --git a/enzyme/test/Integration/ReverseMode/inactiveglob.cpp b/enzyme/test/Integration/ReverseMode/inactiveglob.cpp new file mode 100644 index 000000000000..118ab257acd2 --- /dev/null +++ b/enzyme/test/Integration/ReverseMode/inactiveglob.cpp @@ -0,0 +1,62 @@ +// This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load +// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions... +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O0 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O0 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi +// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi + +#include +#include +#include + +#include "../test_utils.h" + +double __enzyme_autodiff(void*, ...); + +struct Temp { +private: + __attribute__((enzyme_inactive)) + static double tmp; +public: + +__attribute__((noinline)) +static void f(bool cond, double a, double *c) { + if (cond) + tmp *= a; + else + *c *= a; +} + +static double get() { return tmp; } +static void set(double v) { tmp = v; } + +}; + +double Temp::tmp = 0; + +double test(bool cond, double a) { + double dat = a; + Temp::f(cond, a, &dat); + return dat; +} + +int main(int argc, char** argv) { + Temp::set(5.5); + double out = __enzyme_autodiff((void*)test, false, 3.0); + printf("out=%f\n", out); + APPROX_EQ(out, 6.0, 1e-10); + APPROX_EQ(Temp::get(), 5.5, 1e-10); + return 0; +}