Skip to content

Commit

Permalink
Add constant evaluation for clamp() (microsoft#3581)
Browse files Browse the repository at this point in the history
  • Loading branch information
tex3d authored Mar 18, 2021
1 parent 63ce61a commit 2bda44f
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 0 deletions.
92 changes: 92 additions & 0 deletions tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1460,6 +1460,10 @@ typedef APInt(__cdecl *IntBinaryEvalFuncType)(const APInt &, const APInt &);
typedef float(__cdecl *FloatBinaryEvalFuncType)(float, float);
typedef double(__cdecl *DoubleBinaryEvalFuncType)(double, double);

typedef APInt(__cdecl *IntTernaryEvalFuncType)(const APInt &, const APInt &, const APInt &);
typedef float(__cdecl *FloatTernaryEvalFuncType)(float, float, float);
typedef double(__cdecl *DoubleTernaryEvalFuncType)(double, double, double);

Value *EvalUnaryIntrinsic(ConstantFP *fpV, FloatUnaryEvalFuncType floatEvalFunc,
DoubleUnaryEvalFuncType doubleEvalFunc) {
llvm::Type *Ty = fpV->getType();
Expand Down Expand Up @@ -1510,6 +1514,45 @@ Value *EvalBinaryIntrinsic(Constant *cV0, Constant *cV1,
return Result;
}

Value *EvalTernaryIntrinsic(Constant *cV0, Constant *cV1, Constant *cV2,
FloatTernaryEvalFuncType floatEvalFunc,
DoubleTernaryEvalFuncType doubleEvalFunc,
IntTernaryEvalFuncType intEvalFunc) {
llvm::Type *Ty = cV0->getType();
Value *Result = nullptr;
if (Ty->isDoubleTy()) {
ConstantFP *fpV0 = cast<ConstantFP>(cV0);
ConstantFP *fpV1 = cast<ConstantFP>(cV1);
ConstantFP *fpV2 = cast<ConstantFP>(cV2);
double dV0 = fpV0->getValueAPF().convertToDouble();
double dV1 = fpV1->getValueAPF().convertToDouble();
double dV2 = fpV2->getValueAPF().convertToDouble();
Value *dResult = ConstantFP::get(Ty, doubleEvalFunc(dV0, dV1, dV2));
Result = dResult;
} else if (Ty->isFloatTy()) {
ConstantFP *fpV0 = cast<ConstantFP>(cV0);
ConstantFP *fpV1 = cast<ConstantFP>(cV1);
ConstantFP *fpV2 = cast<ConstantFP>(cV2);
float fV0 = fpV0->getValueAPF().convertToFloat();
float fV1 = fpV1->getValueAPF().convertToFloat();
float fV2 = fpV2->getValueAPF().convertToFloat();
Value *dResult = ConstantFP::get(Ty, floatEvalFunc(fV0, fV1, fV2));
Result = dResult;
} else {
DXASSERT_NOMSG(Ty->isIntegerTy());
DXASSERT_NOMSG(intEvalFunc);
ConstantInt *ciV0 = cast<ConstantInt>(cV0);
ConstantInt *ciV1 = cast<ConstantInt>(cV1);
ConstantInt *ciV2 = cast<ConstantInt>(cV2);
const APInt &iV0 = ciV0->getValue();
const APInt &iV1 = ciV1->getValue();
const APInt &iV2 = ciV2->getValue();
Value *dResult = ConstantInt::get(Ty, intEvalFunc(iV0, iV1, iV2));
Result = dResult;
}
return Result;
}

Value *EvalUnaryIntrinsic(CallInst *CI, FloatUnaryEvalFuncType floatEvalFunc,
DoubleUnaryEvalFuncType doubleEvalFunc) {
Value *V = CI->getArgOperand(0);
Expand Down Expand Up @@ -1566,6 +1609,43 @@ Value *EvalBinaryIntrinsic(CallInst *CI, FloatBinaryEvalFuncType floatEvalFunc,
return Result;
}

Value *EvalTernaryIntrinsic(CallInst *CI, FloatTernaryEvalFuncType floatEvalFunc,
DoubleTernaryEvalFuncType doubleEvalFunc,
IntTernaryEvalFuncType intEvalFunc = nullptr) {
Value *V0 = CI->getArgOperand(0);
Value *V1 = CI->getArgOperand(1);
Value *V2 = CI->getArgOperand(2);
llvm::Type *Ty = CI->getType();
Value *Result = nullptr;
if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(Ty)) {
Result = UndefValue::get(Ty);
Constant *CV0 = cast<Constant>(V0);
Constant *CV1 = cast<Constant>(V1);
Constant *CV2 = cast<Constant>(V2);
IRBuilder<> Builder(CI);
for (unsigned i = 0; i < VT->getNumElements(); i++) {
Constant *cV0 = cast<Constant>(CV0->getAggregateElement(i));
Constant *cV1 = cast<Constant>(CV1->getAggregateElement(i));
Constant *cV2 = cast<Constant>(CV2->getAggregateElement(i));
Value *EltResult = EvalTernaryIntrinsic(cV0, cV1, cV2, floatEvalFunc,
doubleEvalFunc, intEvalFunc);
Result = Builder.CreateInsertElement(Result, EltResult, i);
}
} else {
Constant *cV0 = cast<Constant>(V0);
Constant *cV1 = cast<Constant>(V1);
Constant *cV2 = cast<Constant>(V2);
Result = EvalTernaryIntrinsic(cV0, cV1, cV2, floatEvalFunc, doubleEvalFunc,
intEvalFunc);
}
CI->replaceAllUsesWith(Result);
CI->eraseFromParent();
return Result;

CI->eraseFromParent();
return Result;
}

void SimpleTransformForHLDXIRInst(Instruction *I, SmallInstSet &deadInsts) {

unsigned opcode = I->getOpcode();
Expand Down Expand Up @@ -1789,6 +1869,18 @@ Value *TryEvalIntrinsic(CallInst *CI, IntrinsicOp intriOp,
CI->eraseFromParent();
return cNan;
} break;
case IntrinsicOp::IOP_clamp: {
auto clampF = [](float a, float b, float c) {
return a < b ? b : a > c ? c : a;
};
auto clampD = [](double a, double b, double c) {
return a < b ? b : a > c ? c : a;
};
auto clampI = [](const APInt &a, const APInt &b, const APInt &c) -> APInt {
return a.slt(b) ? b : a.sgt(c) ? c : a;
};
return EvalTernaryIntrinsic(CI, clampF, clampD, clampI);
} break;
default:
return nullptr;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
// RUN: %dxc -T ps_6_0 %s -E main | %FileCheck %s
// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float 1.000000e+00)
// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float -1.250000e+00)
// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2, float 3.000000e+00)
// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 3, float 2.000000e+00)

[RootSignature("")]
float4 main() : SV_Target {
return float4(
clamp(10, 0, 1),
clamp(-1.0f, -2.5f, -1.25f),
clamp((double)3, (double)-2, (double)5),
clamp(-5LL, 2LL, 5LL));
}

0 comments on commit 2bda44f

Please sign in to comment.