Skip to content

Commit

Permalink
Fix struct containing attributes (#1627)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jan 25, 2024
1 parent 5caf99a commit b2151f8
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 2 deletions.
22 changes: 20 additions & 2 deletions enzyme/Enzyme/Clang/EnzymeClang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,11 @@ struct EnzymeFunctionLikeAttrInfo : public ParsedAttrInfo {
// if (FD->isLateTemplateParsed()) return;
auto &AST = S.getASTContext();
DeclContext *declCtx = FD->getDeclContext();
for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) {
if (tmpCtx->isRecord()) {
declCtx = tmpCtx->getParent();
}
}
auto loc = FD->getLocation();
RecordDecl *RD;
if (S.getLangOpts().CPlusPlus)
Expand Down Expand Up @@ -369,6 +374,11 @@ struct EnzymeInactiveAttrInfo : public ParsedAttrInfo {

auto &AST = S.getASTContext();
DeclContext *declCtx = D->getDeclContext();
for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) {
if (tmpCtx->isRecord()) {
declCtx = tmpCtx->getParent();
}
}
auto loc = D->getLocation();
RecordDecl *RD;
if (S.getLangOpts().CPlusPlus)
Expand Down Expand Up @@ -425,7 +435,6 @@ struct EnzymeInactiveAttrInfo : public ParsedAttrInfo {
return AttributeNotApplied;
}
V->setInit(expr);
V->dump();
S.MarkVariableReferenced(loc, V);
S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V));
return AttributeApplied;
Expand Down Expand Up @@ -479,6 +488,11 @@ struct EnzymeNoFreeAttrInfo : public ParsedAttrInfo {

auto &AST = S.getASTContext();
DeclContext *declCtx = D->getDeclContext();
for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) {
if (tmpCtx->isRecord()) {
declCtx = tmpCtx->getParent();
}
}
auto loc = D->getLocation();
RecordDecl *RD;
if (S.getLangOpts().CPlusPlus)
Expand Down Expand Up @@ -534,7 +548,6 @@ struct EnzymeNoFreeAttrInfo : public ParsedAttrInfo {
return AttributeNotApplied;
}
V->setInit(expr);
V->dump();
S.MarkVariableReferenced(loc, V);
S.getASTConsumer().HandleTopLevelDecl(DeclGroupRef(V));
return AttributeApplied;
Expand Down Expand Up @@ -584,6 +597,11 @@ struct EnzymeSparseAccumulateAttrInfo : public ParsedAttrInfo {

auto &AST = S.getASTContext();
DeclContext *declCtx = D->getDeclContext();
for (auto tmpCtx = declCtx; tmpCtx; tmpCtx = tmpCtx->getParent()) {
if (tmpCtx->isRecord()) {
declCtx = tmpCtx->getParent();
}
}
auto loc = D->getLocation();
RecordDecl *RD;
if (S.getLangOpts().CPlusPlus)
Expand Down
62 changes: 62 additions & 0 deletions enzyme/test/Integration/ReverseMode/inactiveglob.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// This should work on LLVM 7, 8, 9, however in CI the version of clang installed on Ubuntu 18.04 cannot load
// a clang plugin properly without segfaulting on exit. This is fine on Ubuntu 20.04 or later LLVM versions...
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O0 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O0 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O1 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O2 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
// RUN: if [ %llvmver -ge 11 ]; then %clang++ -std=c++11 -O3 %s -S -emit-llvm -o - %loadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O0 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O1 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O2 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi
// RUN: if [ %llvmver -ge 12 ]; then %clang++ -std=c++11 -O3 %s -S -emit-llvm -o - %newLoadClangEnzyme -mllvm -enzyme-inline=1 -S | %lli - ; fi

#include <stdio.h>
#include <math.h>
#include <assert.h>

#include "../test_utils.h"

double __enzyme_autodiff(void*, ...);

struct Temp {
private:
__attribute__((enzyme_inactive))
static double tmp;
public:

__attribute__((noinline))
static void f(bool cond, double a, double *c) {
if (cond)
tmp *= a;
else
*c *= a;
}

static double get() { return tmp; }
static void set(double v) { tmp = v; }

};

double Temp::tmp = 0;

double test(bool cond, double a) {
double dat = a;
Temp::f(cond, a, &dat);
return dat;
}

int main(int argc, char** argv) {
Temp::set(5.5);
double out = __enzyme_autodiff((void*)test, false, 3.0);
printf("out=%f\n", out);
APPROX_EQ(out, 6.0, 1e-10);
APPROX_EQ(Temp::get(), 5.5, 1e-10);
return 0;
}

0 comments on commit b2151f8

Please sign in to comment.