Skip to content

Commit

Permalink
Handle non-strict select / phi (rust-lang#763)
Browse files Browse the repository at this point in the history
* Handle non-strict select

* Fix non-strict phi
  • Loading branch information
wsmoses authored Jul 30, 2022
1 parent a959be9 commit 146643f
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 13 deletions.
51 changes: 39 additions & 12 deletions enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1491,7 +1491,8 @@ void TypeAnalyzer::visitGetElementPtrInst(GetElementPtrInst &gep) {
#else
APInt ai(DL.getPointerSize(gep.getPointerAddressSpace()) * 8, 0);
#endif
g2->accumulateConstantOffset(DL, ai);
bool valid = g2->accumulateConstantOffset(DL, ai);
assert(valid);
// Using destructor rather than eraseFromParent
// as g2 has no parent
delete g2;
Expand Down Expand Up @@ -1539,14 +1540,31 @@ void TypeAnalyzer::visitPHINode(PHINode &phi) {
TypeTree upVal = getAnalysis(&phi);
// only propagate anything's up if there is one
// incoming value
if (phi.getNumIncomingValues() >= 2) {
Value *seen = phi.getIncomingValue(0);
for (size_t i = 0, end = phi.getNumIncomingValues(); i < end; ++i) {
if (seen != phi.getIncomingValue(i)) {
seen = nullptr;
break;
}
}

if (!seen) {
upVal = upVal.PurgeAnything();
}
auto L = LI.getLoopFor(phi.getParent());
bool isHeader = L && L->getHeader() == phi.getParent();
for (size_t i = 0, end = phi.getNumIncomingValues(); i < end; ++i) {
if (!isHeader || !L->contains(phi.getIncomingBlock(i))) {
updateAnalysis(phi.getIncomingValue(i), upVal, &phi);

if (EnzymeStrictAliasing || seen) {
auto L = LI.getLoopFor(phi.getParent());
bool isHeader = L && L->getHeader() == phi.getParent();
for (size_t i = 0, end = phi.getNumIncomingValues(); i < end; ++i) {
if (!isHeader || !L->contains(phi.getIncomingBlock(i))) {
updateAnalysis(phi.getIncomingValue(i), upVal, &phi);
}
}
} else {
if (EnzymePrintType) {
for (size_t i = 0, end = phi.getNumIncomingValues(); i < end; ++i)
llvm::errs() << " skipping update into " << *phi.getIncomingValue(i)
<< " of " << upVal.str() << " from " << phi << "\n";
}
}
}
Expand Down Expand Up @@ -1840,11 +1858,20 @@ void TypeAnalyzer::visitBitCastInst(BitCastInst &I) {
}

void TypeAnalyzer::visitSelectInst(SelectInst &I) {
if (direction & UP)
updateAnalysis(I.getTrueValue(), getAnalysis(&I).PurgeAnything(), &I);
if (direction & UP)
updateAnalysis(I.getFalseValue(), getAnalysis(&I).PurgeAnything(), &I);

if (direction & UP) {
auto Data = getAnalysis(&I).PurgeAnything();
if (EnzymeStrictAliasing || (I.getTrueValue() == I.getFalseValue())) {
updateAnalysis(I.getTrueValue(), Data, &I);
updateAnalysis(I.getFalseValue(), Data, &I);
} else {
if (EnzymePrintType) {
llvm::errs() << " skipping update into " << *I.getTrueValue() << " of "
<< Data.str() << " from " << I << "\n";
llvm::errs() << " skipping update into " << *I.getFalseValue() << " of "
<< Data.str() << " from " << I << "\n";
}
}
}
if (direction & DOWN) {
// special case for min/max result is still that operand [even if something
// is 0]
Expand Down
2 changes: 1 addition & 1 deletion enzyme/test/TypeAnalysis/strictalphi.ll
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ bb153: ; preds = %bb216

; CHECK: f - {} |
; CHECK-NEXT: e
; CHECK-NEXT: %i78 = call noalias nonnull i8* @_Znwm(i64 8): {[-1]:Pointer, [-1,0]:Integer}
; CHECK-NEXT: %i78 = call noalias nonnull i8* @_Znwm(i64 8): {[-1]:Pointer}
; CHECK-NEXT: br label %bb155: {}
; CHECK-NEXT: bb155
; CHECK-NEXT: %i159 = phi i8* [ %i78, %e ], [ %i220, %bb216 ]: {[-1]:Pointer, [-1,0]:Integer}
Expand Down
97 changes: 97 additions & 0 deletions enzyme/test/TypeAnalysis/strictphi.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
; RUN: %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=f -enzyme-strict-aliasing=0 -o /dev/null | FileCheck %s

source_filename = "<source>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

%class.Testing = type { %struct.Header, %struct.Header }
%struct.Header = type { %struct.Base, i32 }
%struct.Base = type { %struct.Base*, %struct.Base* }

define dso_local void @f(%class.Testing* %arg) {
bb:
%i = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 0, i32 0
%i1 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 0, i32 0, i32 0
%i13 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 1, i32 0
%i14 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 1, i32 0, i32 0
br label %bb2

bb2: ; preds = %bb2, %bb
%i3 = phi %struct.Base** [ %i1, %bb ], [ %i7, %bb2 ]
%i4 = phi %struct.Base* [ %i, %bb ], [ %i5, %bb2 ]
%i5 = load %struct.Base*, %struct.Base** %i3, align 8, !tbaa !3
%i6 = icmp eq %struct.Base* %i5, null
%i7 = getelementptr inbounds %struct.Base, %struct.Base* %i5, i64 0, i32 1
br i1 %i6, label %bb8, label %bb2, !llvm.loop !7

bb8: ; preds = %bb2
%i9 = getelementptr inbounds %struct.Base, %struct.Base* %i4, i64 1, i32 1
%i10 = bitcast %struct.Base** %i9 to double*
%i11 = load double, double* %i10, align 8, !tbaa !9
br label %bb15

bb15: ; preds = %bb15, %bb8
%i16 = phi %struct.Base** [ %i14, %bb8 ], [ %i20, %bb15 ]
%i17 = phi %struct.Base* [ %i13, %bb8 ], [ %i18, %bb15 ]
%i18 = load %struct.Base*, %struct.Base** %i16, align 8, !tbaa !3
%i19 = icmp eq %struct.Base* %i18, null
%i20 = getelementptr inbounds %struct.Base, %struct.Base* %i18, i64 0, i32 1
br i1 %i19, label %bb21, label %bb15, !llvm.loop !7

bb21: ; preds = %bb15
%i22 = getelementptr inbounds %struct.Base, %struct.Base* %i17, i64 1, i32 1
%i23 = bitcast %struct.Base** %i22 to double*
%i24 = load double, double* %i23, align 8, !tbaa !9
tail call void @_Z5printdd(double %i11, double %i24)
ret void
}

declare void @_Z5printdd(double, double)

!llvm.module.flags = !{!0, !1}
!llvm.ident = !{!2}

!0 = !{i32 7, !"Dwarf Version", i32 4}
!1 = !{i32 1, !"wchar_size", i32 4}
!2 = !{!"clang version 12.0.1 (https://github.com/llvm/llvm-project.git fed41342a82f5a3a9201819a82bf7a48313e296b)"}
!3 = !{!4, !4, i64 0}
!4 = !{!"any pointer", !5, i64 0}
!5 = !{!"omnipotent char", !6, i64 0}
!6 = !{!"Simple C++ TBAA"}
!7 = distinct !{!7, !8}
!8 = !{!"llvm.loop.mustprogress"}
!9 = !{!10, !10, i64 0}
!10 = !{!"double", !5, i64 0}

; CHECK: %class.Testing* %arg: {[-1]:Pointer}
; CHECK-NEXT: bb
; CHECK-NEXT: %i = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 0, i32 0: {[-1]:Pointer}
; CHECK-NEXT: %i1 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 0, i32 0, i32 0: {[-1]:Pointer}
; CHECK-NEXT: %i13 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 1, i32 0: {[-1]:Pointer}
; CHECK-NEXT: %i14 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 1, i32 0, i32 0: {[-1]:Pointer}
; CHECK-NEXT: br label %bb2: {}
; CHECK-NEXT: bb2
; CHECK-NEXT: %i3 = phi %struct.Base** [ %i1, %bb ], [ %i7, %bb2 ]: {[-1]:Pointer, [-1,0]:Pointer}
; CHECK-NEXT: %i4 = phi %struct.Base* [ %i, %bb ], [ %i5, %bb2 ]: {[-1]:Pointer, [-1,24]:Float@double}
; CHECK-NEXT: %i5 = load %struct.Base*, %struct.Base** %i3, align 8, !tbaa !3: {[-1]:Pointer}
; CHECK-NEXT: %i6 = icmp eq %struct.Base* %i5, null: {[-1]:Integer}
; CHECK-NEXT: %i7 = getelementptr inbounds %struct.Base, %struct.Base* %i5, i64 0, i32 1: {[-1]:Pointer}
; CHECK-NEXT: br i1 %i6, label %bb8, label %bb2, !llvm.loop !7: {}
; CHECK-NEXT: bb8
; CHECK-NEXT: %i9 = getelementptr inbounds %struct.Base, %struct.Base* %i4, i64 1, i32 1: {[-1]:Pointer, [-1,0]:Float@double}
; CHECK-NEXT: %i10 = bitcast %struct.Base** %i9 to double*: {[-1]:Pointer, [-1,0]:Float@double}
; CHECK-NEXT: %i11 = load double, double* %i10, align 8, !tbaa !9: {[-1]:Float@double}
; CHECK-NEXT: br label %bb15: {}
; CHECK-NEXT: bb15
; CHECK-NEXT: %i16 = phi %struct.Base** [ %i14, %bb8 ], [ %i20, %bb15 ]: {[-1]:Pointer, [-1,0]:Pointer}
; CHECK-NEXT: %i17 = phi %struct.Base* [ %i13, %bb8 ], [ %i18, %bb15 ]: {[-1]:Pointer, [-1,24]:Float@double}
; CHECK-NEXT: %i18 = load %struct.Base*, %struct.Base** %i16, align 8, !tbaa !3: {[-1]:Pointer}
; CHECK-NEXT: %i19 = icmp eq %struct.Base* %i18, null: {[-1]:Integer}
; CHECK-NEXT: %i20 = getelementptr inbounds %struct.Base, %struct.Base* %i18, i64 0, i32 1: {[-1]:Pointer}
; CHECK-NEXT: br i1 %i19, label %bb21, label %bb15, !llvm.loop !7: {}
; CHECK-NEXT: bb21
; CHECK-NEXT: %i22 = getelementptr inbounds %struct.Base, %struct.Base* %i17, i64 1, i32 1: {[-1]:Pointer, [-1,0]:Float@double}
; CHECK-NEXT: %i23 = bitcast %struct.Base** %i22 to double*: {[-1]:Pointer, [-1,0]:Float@double}
; CHECK-NEXT: %i24 = load double, double* %i23, align 8, !tbaa !9: {[-1]:Float@double}
; CHECK-NEXT: tail call void @_Z5printdd(double %i11, double %i24): {}
; CHECK-NEXT: ret void: {}
73 changes: 73 additions & 0 deletions enzyme/test/TypeAnalysis/strictselect.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
; RUN: %opt < %s %loadEnzyme -print-type-analysis -type-analysis-func=f -enzyme-strict-aliasing=0 -o /dev/null | FileCheck %s

source_filename = "<source>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

%class.Testing = type { %struct.Header, %struct.Header }
%struct.Header = type { %struct.Base, i32 }
%struct.Base = type { %struct.Base* }

define dso_local void @f(%class.Testing* nocapture nonnull readonly %arg) {
bb:
%i = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 0, i32 0
%i1 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 0, i32 0, i32 0
%i2 = load %struct.Base*, %struct.Base** %i1, align 8, !tbaa !3
%i3 = icmp eq %struct.Base* %i2, null
%i4 = select i1 %i3, %struct.Base* %i, %struct.Base* %i2
%i5 = getelementptr inbounds %struct.Base, %struct.Base* %i4, i64 2
%i6 = bitcast %struct.Base* %i5 to double*
%i7 = load double, double* %i6, align 8, !tbaa !10
%i8 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 1
%i9 = getelementptr inbounds %struct.Header, %struct.Header* %i8, i64 0, i32 0
%i10 = getelementptr inbounds %struct.Header, %struct.Header* %i8, i64 0, i32 0, i32 0
%i11 = load %struct.Base*, %struct.Base** %i10, align 8, !tbaa !3
%i12 = icmp eq %struct.Base* %i11, null
%i13 = select i1 %i12, %struct.Base* %i9, %struct.Base* %i11
%i14 = getelementptr inbounds %struct.Base, %struct.Base* %i13, i64 2
%i15 = bitcast %struct.Base* %i14 to double*
%i16 = load double, double* %i15, align 8, !tbaa !10
tail call void (...) @_Z6printfPKcz(double %i7, double %i16)
ret void
}

declare void @_Z6printfPKcz(...)

!llvm.module.flags = !{!0, !1}
!llvm.ident = !{!2}

!0 = !{i32 7, !"Dwarf Version", i32 4}
!1 = !{i32 1, !"wchar_size", i32 4}
!2 = !{!"clang version 12.0.1 (https://github.com/llvm/llvm-project.git fed41342a82f5a3a9201819a82bf7a48313e296b)"}
!3 = !{!4, !6, i64 0}
!4 = !{!"_ZTS6Header", !5, i64 0, !9, i64 8}
!5 = !{!"_ZTS4Base", !6, i64 0}
!6 = !{!"any pointer", !7, i64 0}
!7 = !{!"omnipotent char", !8, i64 0}
!8 = !{!"Simple C++ TBAA"}
!9 = !{!"int", !7, i64 0}
!10 = !{!11, !11, i64 0}
!11 = !{!"double", !7, i64 0}

; CHECK: f - {} |{[-1]:Pointer}:{}
; CHECK-NEXT: %class.Testing* %arg: {[-1]:Pointer, [-1,0]:Pointer, [-1,16]:Pointer}
; CHECK-NEXT: bb
; CHECK-NEXT: %i = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 0, i32 0: {[-1]:Pointer, [-1,0]:Pointer}
; CHECK-NEXT: %i1 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 0, i32 0, i32 0: {[-1]:Pointer, [-1,0]:Pointer}
; CHECK-NEXT: %i2 = load %struct.Base*, %struct.Base** %i1, align 8, !tbaa !3: {[-1]:Pointer}
; CHECK-NEXT: %i3 = icmp eq %struct.Base* %i2, null: {[-1]:Integer}
; CHECK-NEXT: %i4 = select i1 %i3, %struct.Base* %i, %struct.Base* %i2: {[-1]:Pointer, [-1,16]:Float@double}
; CHECK-NEXT: %i5 = getelementptr inbounds %struct.Base, %struct.Base* %i4, i64 2: {[-1]:Pointer, [-1,0]:Float@double}
; CHECK-NEXT: %i6 = bitcast %struct.Base* %i5 to double*: {[-1]:Pointer, [-1,0]:Float@double}
; CHECK-NEXT: %i7 = load double, double* %i6, align 8, !tbaa !10: {[-1]:Float@double}
; CHECK-NEXT: %i8 = getelementptr inbounds %class.Testing, %class.Testing* %arg, i64 0, i32 1: {[-1]:Pointer, [-1,0]:Pointer}
; CHECK-NEXT: %i9 = getelementptr inbounds %struct.Header, %struct.Header* %i8, i64 0, i32 0: {[-1]:Pointer, [-1,0]:Pointer}
; CHECK-NEXT: %i10 = getelementptr inbounds %struct.Header, %struct.Header* %i8, i64 0, i32 0, i32 0: {[-1]:Pointer, [-1,0]:Pointer}
; CHECK-NEXT: %i11 = load %struct.Base*, %struct.Base** %i10, align 8, !tbaa !3: {[-1]:Pointer}
; CHECK-NEXT: %i12 = icmp eq %struct.Base* %i11, null: {[-1]:Integer}
; CHECK-NEXT: %i13 = select i1 %i12, %struct.Base* %i9, %struct.Base* %i11: {[-1]:Pointer, [-1,16]:Float@double}
; CHECK-NEXT: %i14 = getelementptr inbounds %struct.Base, %struct.Base* %i13, i64 2: {[-1]:Pointer, [-1,0]:Float@double}
; CHECK-NEXT: %i15 = bitcast %struct.Base* %i14 to double*: {[-1]:Pointer, [-1,0]:Float@double}
; CHECK-NEXT: %i16 = load double, double* %i15, align 8, !tbaa !10: {[-1]:Float@double}
; CHECK-NEXT: tail call void (...) @_Z6printfPKcz(double %i7, double %i16): {}
; CHECK-NEXT: ret void: {}

0 comments on commit 146643f

Please sign in to comment.