diff --git a/tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp b/tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp
index 1daa569e6c..b6d096fb23 100644
--- a/tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp
+++ b/tools/clang/lib/CodeGen/CGHLSLMSFinishCodeGen.cpp
@@ -1460,6 +1460,12 @@ typedef APInt(__cdecl *IntBinaryEvalFuncType)(const APInt &, const APInt &);
 typedef float(__cdecl *FloatBinaryEvalFuncType)(float, float);
 typedef double(__cdecl *DoubleBinaryEvalFuncType)(double, double);
 
+// Evaluator signatures for constant-folding three-operand intrinsics.
+typedef APInt(__cdecl *IntTernaryEvalFuncType)(const APInt &, const APInt &,
+                                               const APInt &);
+typedef float(__cdecl *FloatTernaryEvalFuncType)(float, float, float);
+typedef double(__cdecl *DoubleTernaryEvalFuncType)(double, double, double);
+
 Value *EvalUnaryIntrinsic(ConstantFP *fpV, FloatUnaryEvalFuncType floatEvalFunc,
                           DoubleUnaryEvalFuncType doubleEvalFunc) {
   llvm::Type *Ty = fpV->getType();
@@ -1510,6 +1516,38 @@ Value *EvalBinaryIntrinsic(Constant *cV0, Constant *cV1,
   return Result;
 }
 
+// Constant-folds one scalar application of a three-operand intrinsic.
+//
+// The type of cV0 picks the evaluator: doubleEvalFunc for double,
+// floatEvalFunc for float, and intEvalFunc otherwise (the type is then
+// asserted to be an integer). The operands are narrowed with cast<>, so all
+// three must be constants of that same scalar kind.
+// Returns the folded constant, created with the type of cV0.
+Value *EvalTernaryIntrinsic(Constant *cV0, Constant *cV1, Constant *cV2,
+                            FloatTernaryEvalFuncType floatEvalFunc,
+                            DoubleTernaryEvalFuncType doubleEvalFunc,
+                            IntTernaryEvalFuncType intEvalFunc) {
+  llvm::Type *Ty = cV0->getType();
+  if (Ty->isDoubleTy()) {
+    double dV0 = cast<ConstantFP>(cV0)->getValueAPF().convertToDouble();
+    double dV1 = cast<ConstantFP>(cV1)->getValueAPF().convertToDouble();
+    double dV2 = cast<ConstantFP>(cV2)->getValueAPF().convertToDouble();
+    return ConstantFP::get(Ty, doubleEvalFunc(dV0, dV1, dV2));
+  }
+  if (Ty->isFloatTy()) {
+    float fV0 = cast<ConstantFP>(cV0)->getValueAPF().convertToFloat();
+    float fV1 = cast<ConstantFP>(cV1)->getValueAPF().convertToFloat();
+    float fV2 = cast<ConstantFP>(cV2)->getValueAPF().convertToFloat();
+    return ConstantFP::get(Ty, floatEvalFunc(fV0, fV1, fV2));
+  }
+  DXASSERT_NOMSG(Ty->isIntegerTy());
+  DXASSERT_NOMSG(intEvalFunc);
+  const APInt &iV0 = cast<ConstantInt>(cV0)->getValue();
+  const APInt &iV1 = cast<ConstantInt>(cV1)->getValue();
+  const APInt &iV2 = cast<ConstantInt>(cV2)->getValue();
+  return ConstantInt::get(Ty, intEvalFunc(iV0, iV1, iV2));
+}
+
 Value *EvalUnaryIntrinsic(CallInst *CI, FloatUnaryEvalFuncType floatEvalFunc,
                           DoubleUnaryEvalFuncType doubleEvalFunc) {
   Value *V = CI->getArgOperand(0);
@@ -1566,6 +1604,43 @@ Value *EvalBinaryIntrinsic(CallInst *CI, FloatBinaryEvalFuncType floatEvalFunc,
   return Result;
 }
 
+// Constant-folds a call to a three-operand intrinsic and replaces the call.
+//
+// All three arguments of CI are cast<> to Constant, so this must only be used
+// on calls whose operands are all constants. Vector results are folded one
+// element at a time. Every use of CI is redirected to the folded value and CI
+// is erased, so CI must not be touched after this returns.
+// intEvalFunc defaults to null; it is only required (and asserted) when the
+// operands are integers.
+Value *EvalTernaryIntrinsic(CallInst *CI,
+                            FloatTernaryEvalFuncType floatEvalFunc,
+                            DoubleTernaryEvalFuncType doubleEvalFunc,
+                            IntTernaryEvalFuncType intEvalFunc = nullptr) {
+  Constant *CV0 = cast<Constant>(CI->getArgOperand(0));
+  Constant *CV1 = cast<Constant>(CI->getArgOperand(1));
+  Constant *CV2 = cast<Constant>(CI->getArgOperand(2));
+  llvm::Type *Ty = CI->getType();
+  Value *Result = nullptr;
+  if (llvm::VectorType *VT = dyn_cast<llvm::VectorType>(Ty)) {
+    Result = UndefValue::get(Ty);
+    IRBuilder<> Builder(CI);
+    for (unsigned i = 0; i < VT->getNumElements(); i++) {
+      Constant *Elt0 = cast<Constant>(CV0->getAggregateElement(i));
+      Constant *Elt1 = cast<Constant>(CV1->getAggregateElement(i));
+      Constant *Elt2 = cast<Constant>(CV2->getAggregateElement(i));
+      Value *EltResult = EvalTernaryIntrinsic(Elt0, Elt1, Elt2, floatEvalFunc,
+                                              doubleEvalFunc, intEvalFunc);
+      Result = Builder.CreateInsertElement(Result, EltResult, i);
+    }
+  } else {
+    Result = EvalTernaryIntrinsic(CV0, CV1, CV2, floatEvalFunc, doubleEvalFunc,
+                                  intEvalFunc);
+  }
+  CI->replaceAllUsesWith(Result);
+  CI->eraseFromParent();
+  return Result;
+}
+
 void SimpleTransformForHLDXIRInst(Instruction *I, SmallInstSet &deadInsts) {
 
   unsigned opcode = I->getOpcode();
@@ -1789,6 +1864,21 @@ Value *TryEvalIntrinsic(CallInst *CI, IntrinsicOp intriOp,
     CI->eraseFromParent();
     return cNan;
   } break;
+  case IntrinsicOp::IOP_clamp: {
+    // clamp(a, b, c): b if a < b, else c if a > c, else a (signed for ints).
+    // NOTE(review): a NaN `a` fails both comparisons and is returned as-is;
+    // confirm that matches the non-folded lowering of clamp.
+    auto clampF = [](float a, float b, float c) {
+      return a < b ? b : a > c ? c : a;
+    };
+    auto clampD = [](double a, double b, double c) {
+      return a < b ? b : a > c ? c : a;
+    };
+    auto clampI = [](const APInt &a, const APInt &b, const APInt &c) -> APInt {
+      return a.slt(b) ? b : a.sgt(c) ? c : a;
+    };
+    return EvalTernaryIntrinsic(CI, clampF, clampD, clampI);
+  } break;
   default:
     return nullptr;
   }
diff --git a/tools/clang/test/HLSLFileCheck/hlsl/intrinsics/basic/clamp_const_prop.hlsl b/tools/clang/test/HLSLFileCheck/hlsl/intrinsics/basic/clamp_const_prop.hlsl
new file mode 100644
index 0000000000..6f26c83ba8
--- /dev/null
+++ b/tools/clang/test/HLSLFileCheck/hlsl/intrinsics/basic/clamp_const_prop.hlsl
@@ -0,0 +1,14 @@
+// RUN: %dxc -T ps_6_0 %s -E main | %FileCheck %s
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 0, float 1.000000e+00)
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 1, float -1.250000e+00)
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 2, float 3.000000e+00)
+// CHECK: call void @dx.op.storeOutput.f32(i32 5, i32 0, i32 0, i8 3, float 2.000000e+00)
+
+[RootSignature("")]
+float4 main() : SV_Target {
+  return float4(
+    clamp(10, 0, 1),
+    clamp(-1.0f, -2.5f, -1.25f),
+    clamp((double)3, (double)-2, (double)5),
+    clamp(-5LL, 2LL, 5LL));
+}