Skip to content

Commit

Permalink
Expand shadow_alloc_rewrite capabilities
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 22, 2024
1 parent a03c1d4 commit 514ab72
Showing 1 changed file with 15 additions and 3 deletions.
18 changes: 15 additions & 3 deletions enzyme/Enzyme/CallDerivatives.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
using namespace llvm;

extern "C" {
void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *) = nullptr;
void (*EnzymeShadowAllocRewrite)(LLVMValueRef, void *, LLVMValueRef, uint64_t,
LLVMValueRef) = nullptr;
}

void AdjointGenerator::handleMPI(llvm::CallInst &call, llvm::Function *called,
Expand Down Expand Up @@ -3014,6 +3015,9 @@ bool AdjointGenerator::handleKnownCallDerivatives(
bb, anti, getIndex(&call, CacheType::Shadow, BuilderZ));
} else {
bool zeroed = false;
uint64_t idx = 0;
Value *prev = nullptr;
;
auto rule = [&]() {
Value *anti =
bb.CreateCall(call.getFunctionType(), call.getCalledOperand(),
Expand Down Expand Up @@ -3059,7 +3063,8 @@ bool AdjointGenerator::handleKnownCallDerivatives(
funcName == "jl_gc_alloc_typed" ||
funcName == "ijl_gc_alloc_typed") {
if (EnzymeShadowAllocRewrite)
EnzymeShadowAllocRewrite(wrap(anti), gutils);
EnzymeShadowAllocRewrite(wrap(anti), gutils, wrap(&call),
idx, wrap(prev));
}
}
if (Mode == DerivativeMode::ReverseModeCombined ||
Expand All @@ -3075,6 +3080,8 @@ bool AdjointGenerator::handleKnownCallDerivatives(
zeroed = true;
}
}
idx++;
prev = anti;
return anti;
};

Expand Down Expand Up @@ -3224,6 +3231,8 @@ bool AdjointGenerator::handleKnownCallDerivatives(
args.push_back(gutils->getNewFromOriginal(arg));
}

uint64_t idx = 0;
Value *prev = gutils->getNewFromOriginal(&call);
auto rule = [&]() {
SmallVector<ValueType, 2> BundleTypes(args.size(), ValueType::Primal);

Expand All @@ -3241,8 +3250,11 @@ bool AdjointGenerator::handleKnownCallDerivatives(
funcName == "jl_gc_alloc_typed" ||
funcName == "ijl_gc_alloc_typed") {
if (EnzymeShadowAllocRewrite)
EnzymeShadowAllocRewrite(wrap(CI), gutils);
EnzymeShadowAllocRewrite(wrap(CI), gutils, wrap(&call), idx,
wrap(prev));
}
idx++;
prev = CI;
return CI;
};

Expand Down

0 comments on commit 514ab72

Please sign in to comment.