Skip to content

Commit

Permalink
Improve cast ft error
Browse files Browse the repository at this point in the history
  • Loading branch information
wsmoses committed Dec 31, 2024
1 parent 7cf9e90 commit 0937a52
Showing 1 changed file with 24 additions and 22 deletions.
46 changes: 24 additions & 22 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -1380,31 +1380,33 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
ss << "Cannot deduce adding type (cast) of " << I;
EmitNoTypeError(str, I, gutils, Builder2);
}
assert(FT);

auto rule = [&](Value *dif) {
if (I.getOpcode() == CastInst::CastOps::FPTrunc ||
I.getOpcode() == CastInst::CastOps::FPExt) {
return Builder2.CreateFPCast(dif, op0->getType());
} else if (I.getOpcode() == CastInst::CastOps::BitCast) {
return Builder2.CreateBitCast(dif, op0->getType());
} else if (I.getOpcode() == CastInst::CastOps::Trunc) {
// TODO CHECK THIS
return Builder2.CreateZExt(dif, op0->getType());
} else {
std::string s;
llvm::raw_string_ostream ss(s);
ss << *I.getParent()->getParent() << "\n";
ss << "cannot handle above cast " << I << "\n";
EmitNoDerivativeError(ss.str(), I, gutils, Builder2);
return (llvm::Value *)UndefValue::get(op0->getType());
}
};
if (FT) {

auto rule = [&](Value *dif) {
if (I.getOpcode() == CastInst::CastOps::FPTrunc ||
I.getOpcode() == CastInst::CastOps::FPExt) {
return Builder2.CreateFPCast(dif, op0->getType());
} else if (I.getOpcode() == CastInst::CastOps::BitCast) {
return Builder2.CreateBitCast(dif, op0->getType());
} else if (I.getOpcode() == CastInst::CastOps::Trunc) {
// TODO CHECK THIS
return Builder2.CreateZExt(dif, op0->getType());
} else {
std::string s;
llvm::raw_string_ostream ss(s);
ss << *I.getParent()->getParent() << "\n";
ss << "cannot handle above cast " << I << "\n";
EmitNoDerivativeError(ss.str(), I, gutils, Builder2);
return (llvm::Value *)UndefValue::get(op0->getType());
}
};

Value *dif = diffe(&I, Builder2);
Value *diff = applyChainRule(op0->getType(), Builder2, rule, dif);
Value *dif = diffe(&I, Builder2);
Value *diff = applyChainRule(op0->getType(), Builder2, rule, dif);

addToDiffe(orig_op0, diff, Builder2, FT);
addToDiffe(orig_op0, diff, Builder2, FT);
}
}

Type *diffTy = gutils->getShadowType(I.getType());
Expand Down

0 comments on commit 0937a52

Please sign in to comment.