Skip to content

Commit

Permalink
Handle maxpd (rust-lang#690)
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses authored Jun 16, 2022
1 parent c68257e commit 4eaa5ef
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 0 deletions.
37 changes: 37 additions & 0 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -10494,6 +10494,43 @@ class AdjointGenerator
}
}
}
// Special-case calls whose callee is inline assembly with the exact template
// string "maxpd $1, $0": the adjoint is produced through the rule used for
// Intrinsic::maxnum instead of treating the assembly as an opaque call.
#if LLVM_VERSION_MAJOR >= 11
if (auto assembly = dyn_cast<InlineAsm>(orig->getCalledOperand()))
#else
if (auto assembly = dyn_cast<InlineAsm>(orig->getCalledValue()))
#endif
{
  if (assembly->getAsmString() == "maxpd $1, $0") {
    // No adjoint is emitted when only the primal is being generated or when
    // the call has been classified as a constant instruction.
    const bool skipAdjoint = Mode == DerivativeMode::ReverseModePrimal ||
                             gutils->isConstantInstruction(orig);

    if (!skipAdjoint) {
      // Hand every operand of the original call to the intrinsic handler.
      const unsigned operandCount = orig->getNumOperands();
      SmallVector<Value *, 2> callOperands;
      callOperands.reserve(operandCount);
      for (unsigned idx = 0; idx != operandCount; ++idx)
        callOperands.push_back(orig->getOperand(idx));
      handleAdjointForIntrinsic(Intrinsic::maxnum, *orig, callOperands);
    }

    // Both paths end the same way: cache the new call for the reverse pass
    // only when the recompute heuristic has an entry for this instruction
    // and that entry is false, then drop the original if it is unused.
    auto heuristicEntry = gutils->knownRecomputeHeuristic.find(orig);
    if (heuristicEntry != gutils->knownRecomputeHeuristic.end() &&
        !heuristicEntry->second) {
      gutils->cacheForReverse(BuilderZ, newCall,
                              getIndex(orig, CacheType::Self));
    }
    eraseIfUnused(*orig);
    return;
  }
}

if (called && isAllocationFunction(*called, gutils->TLI)) {

Expand Down
23 changes: 23 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/maxpd.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
; RUN: %opt < %s %loadEnzyme -enzyme -enzyme-preopt=false -mem2reg -early-cse -simplifycfg -S | FileCheck %s

; Primal function under test: returns the result of the x86 inline-assembly
; template "maxpd $1, $0" applied to its two <2 x double> arguments.
define <2 x double> @pmax(<2 x double> %a, <2 x double> %b) {
; Constraint string "=x,x,0": the output is operand $0, %a is input operand $1,
; and the "0" constraint ties %b to the output operand. The asm template text
; is matched verbatim by the expectations at the bottom of this file.
%r = call <2 x double> asm "maxpd $1, $0", "=x,x,0,~{dirflag},~{fpsr},~{flags}"(<2 x double> %a, <2 x double> %b)
ret <2 x double> %r
}

declare { <2 x double>, <2 x double> } @__enzyme_autodiff(...)

; Driver: passes @pmax together with %x and %y to the variadic
; @__enzyme_autodiff entry point and returns the resulting pair unchanged.
; The expectations below show the generated @diffepmax returning the same
; { <2 x double>, <2 x double> } struct type; presumably member 0 is the
; derivative with respect to %x and member 1 with respect to %y (to confirm
; against the Enzyme calling convention).
define { <2 x double>, <2 x double> } @test_derivative(<2 x double> %x, <2 x double> %y) {
entry:
%0 = tail call { <2 x double>, <2 x double> } (...) @__enzyme_autodiff(<2 x double> (<2 x double>, <2 x double>)* @pmax, <2 x double> %x, <2 x double> %y)
ret { <2 x double>, <2 x double> } %0
}

; CHECK: define internal { <2 x double>, <2 x double> } @diffepmax(<2 x double> %a, <2 x double> %b, <2 x double> %differeturn)
; CHECK: %r = call <2 x double> asm "maxpd $1, $0", "=x,x,0,~{dirflag},~{fpsr},~{flags}"(<2 x double> %a, <2 x double> %b)
; CHECK-NEXT: %[[i0:.+]] = fcmp fast olt <2 x double> %a, %b
; CHECK-NEXT: %[[i1:.+]] = select {{(fast )?}}<2 x i1> %[[i0]], <2 x double> zeroinitializer, <2 x double> %differeturn
; CHECK-NEXT: %[[i2:.+]] = select {{(fast )?}}<2 x i1> %[[i0]], <2 x double> %differeturn, <2 x double> zeroinitializer
; CHECK-NEXT: %[[i3:.+]] = insertvalue { <2 x double>, <2 x double> } undef, <2 x double> %[[i1]], 0
; CHECK-NEXT: %[[i4:.+]] = insertvalue { <2 x double>, <2 x double> } %[[i3]], <2 x double> %[[i2]], 1
; CHECK-NEXT: ret { <2 x double>, <2 x double> } %[[i4]]

0 comments on commit 4eaa5ef

Please sign in to comment.