Skip to content

Commit

Permalink
Activity Analysis: strengthen recursive hyp (#2197)
Browse files Browse the repository at this point in the history
* Activity Analysis: strengthen recursive hyp

* fix test

* fix
  • Loading branch information
wsmoses authored Dec 14, 2024
1 parent 068ad9c commit 5f1d332
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 6 deletions.
25 changes: 19 additions & 6 deletions enzyme/Enzyme/ActivityAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -980,12 +980,10 @@ bool ActivityAnalyzer::isConstantInstruction(TypeResults const &TR,
insertConstantsFrom(TR, *DownHypothesis);
return true;
} else if (directions == 3) {
if (isa<LoadInst>(I) || isa<StoreInst>(I) || isa<BinaryOperator>(I)) {
for (auto &op : I->operands()) {
if (!UpHypothesis->isConstantValue(TR, op) &&
EnzymeEnableRecursiveHypotheses) {
ReEvaluateInstIfInactiveValue[op].insert(I);
}
for (auto &op : I->operands()) {
if (!UpHypothesis->isConstantValue(TR, op) &&
EnzymeEnableRecursiveHypotheses) {
ReEvaluateInstIfInactiveValue[op].insert(I);
}
}
}
Expand Down Expand Up @@ -1785,6 +1783,13 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
InsertConstantValue(TR, Val);
insertConstantsFrom(TR, *UpHypothesis);
return true;
} else if (directions == 3) {
for (auto &op : inst->operands()) {
if (!UpHypothesis->isConstantValue(TR, op) &&
EnzymeEnableRecursiveHypotheses) {
ReEvaluateValueIfInactiveValue[op].insert(Val);
}
}
}
}
}
Expand Down Expand Up @@ -1826,6 +1831,14 @@ bool ActivityAnalyzer::isConstantValue(TypeResults const &TR, Value *Val) {
if (EnzymePrintActivity)
llvm::errs() << " cannot show constant instruction hypothesis: "
<< *VI << "\n";
if (directions == 3) {
for (auto &op : VI->operands()) {
if (!UpHypothesis->isConstantValue(TR, op) &&
EnzymeEnableRecursiveHypotheses) {
ReEvaluateValueIfInactiveValue[op].insert(Val);
}
}
}
}
}

Expand Down
43 changes: 43 additions & 0 deletions enzyme/test/ActivityAnalysis/mallocuse.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
; RUN: %opt < %s %newLoadEnzyme -passes="print-activity-analysis" -activity-analysis-func=_take -opaque-pointers -S -o /dev/null | FileCheck %s

declare ptr @malloc(i64)

define double @_take(ptr %a0, i1 %a1) {
entry:
%a3 = tail call ptr @malloc(i64 10)
%a4 = tail call ptr @malloc(i64 10)
%a5 = ptrtoint ptr %a4 to i64
%a6 = or i64 %a5, 1
%a7 = inttoptr i64 %a6 to ptr
%a8 = load double, ptr %a7, align 8
store double %a8, ptr %a0, align 8
br i1 %a1, label %.lr.ph, label %.lr.ph1.peel.next

.lr.ph1.peel.next: ; preds = %2
%.pre = load double, ptr %a4, align 8
ret double %.pre

.lr.ph: ; preds = %.lr.ph, %2
%a9 = load double, ptr %a3, align 4
store double %a9, ptr %a4, align 8
br label %.lr.ph
}

; CHECK: ptr %a0: icv:0
; CHECK-NEXT: i1 %a1: icv:1
; CHECK-NEXT: entry
; CHECK-NEXT: %a3 = tail call ptr @malloc(i64 10): icv:1 ici:1
; CHECK-NEXT: %a4 = tail call ptr @malloc(i64 10): icv:1 ici:1
; CHECK-NEXT: %a5 = ptrtoint ptr %a4 to i64: icv:1 ici:1
; CHECK-NEXT: %a6 = or i64 %a5, 1: icv:1 ici:1
; CHECK-NEXT: %a7 = inttoptr i64 %a6 to ptr: icv:1 ici:1
; CHECK-NEXT: %a8 = load double, ptr %a7, align 8: icv:1 ici:1
; CHECK-NEXT: store double %a8, ptr %a0, align 8: icv:1 ici:1
; CHECK-NEXT: br i1 %a1, label %.lr.ph, label %.lr.ph1.peel.next: icv:1 ici:1
; CHECK-NEXT: .lr.ph1.peel.next
; CHECK-NEXT: %.pre = load double, ptr %a4, align 8: icv:1 ici:1
; CHECK-NEXT: ret double %.pre: icv:1 ici:1
; CHECK-NEXT: .lr.ph
; CHECK-NEXT: %a9 = load double, ptr %a3, align 4: icv:1 ici:1
; CHECK-NEXT: store double %a9, ptr %a4, align 8: icv:1 ici:1
; CHECK-NEXT: br label %.lr.ph: icv:1 ici:1
66 changes: 66 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/mallocuse.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
; RUN: %opt < %s %newLoadEnzyme -passes="enzyme,function(mem2reg,early-cse,sroa,instsimplify,%simplifycfg,adce)" -enzyme-preopt=false -opaque-pointers -S | FileCheck %s

declare ptr @__enzyme_virtualreverse(...)

declare ptr @malloc(i64)

define void @my_model.fullgrad1() {
%z = call ptr (...) @__enzyme_virtualreverse(ptr nonnull @_take)
ret void
}

define double @_take(ptr %a0, i1 %a1) {
%a3 = tail call ptr @malloc(i64 10)
%a4 = tail call ptr @malloc(i64 10)
%a5 = ptrtoint ptr %a4 to i64
%a6 = or i64 %a5, 1
%a7 = inttoptr i64 %a6 to ptr
%a8 = load double, ptr %a7, align 8
store double %a8, ptr %a0, align 8
br i1 %a1, label %.lr.ph, label %.lr.ph1.peel.next

.lr.ph1.peel.next: ; preds = %2
%.pre = load double, ptr %a4, align 8
ret double %.pre

.lr.ph: ; preds = %.lr.ph, %2
%a9 = load double, ptr %a3, align 4
store double %a9, ptr %a4, align 8
br label %.lr.ph
}

; CHECK: define internal { ptr, double } @augmented__take(ptr %a0, ptr %"a0'", i1 %a1)
; CHECK-NEXT: %malloccall = tail call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) ptr @malloc(i64 8)
; CHECK-NEXT: %a3 = tail call ptr @malloc(i64 10)
; CHECK-NEXT: %a4 = tail call ptr @malloc(i64 10)
; CHECK-NEXT: store ptr %a4, ptr %malloccall, align 8
; CHECK-NEXT: %a5 = ptrtoint ptr %a4 to i64
; CHECK-NEXT: %a6 = or i64 %a5, 1
; CHECK-NEXT: %a7 = inttoptr i64 %a6 to ptr
; CHECK-NEXT: %a8 = load double, ptr %a7, align 8
; CHECK-NEXT: store double %a8, ptr %a0, align 8
; CHECK-NEXT: br i1 %a1, label %.lr.ph, label %.lr.ph1.peel.next

; CHECK: .lr.ph1.peel.next: ; preds = %0
; CHECK-NEXT: %.pre = load double, ptr %a4, align 8, !alias.scope !10, !noalias !13
; CHECK-NEXT: %.fca.0.insert = insertvalue { ptr, double } poison, ptr %malloccall, 0
; CHECK-NEXT: %.fca.1.insert = insertvalue { ptr, double } %.fca.0.insert, double %.pre, 1
; CHECK-NEXT: ret { ptr, double } %.fca.1.insert

; CHECK: .lr.ph: ; preds = %0, %.lr.ph
; CHECK-NEXT: %a9 = load double, ptr %a3, align 4
; CHECK-NEXT: store double %a9, ptr %a4, align 8
; CHECK-NEXT: br label %.lr.ph
; CHECK-NEXT: }

; CHECK: define internal void @diffe_take(ptr %a0, ptr %"a0'", i1 %a1, double %differeturn, ptr %tapeArg)
; CHECK-NEXT: tail call void @free(ptr nonnull %tapeArg)
; CHECK-NEXT: br i1 %a1, label %.lr.ph, label %invert.lr.ph1.peel.next

; CHECK: .lr.ph: ; preds = %0, %.lr.ph
; CHECK-NEXT: br label %.lr.ph

; CHECK: invert.lr.ph1.peel.next: ; preds = %0
; CHECK-NEXT: store double 0.000000e+00, ptr %"a0'", align 8
; CHECK-NEXT: ret void
; CHECK-NEXT: }

0 comments on commit 5f1d332

Please sign in to comment.