From 7ca6bc5940321c18f5634bb960fa795366097e45 Mon Sep 17 00:00:00 2001 From: Greg Roth Date: Sat, 10 Aug 2024 17:07:19 -0600 Subject: [PATCH] Update DX intrinsic expansion for new llvm intrinsics The new LLVM integer intrinsics replace the previous dx intrinsics. This updates the DXIL intrinsic expansion code and tests to use and expect the new integer intrinsics and the flattened DX floating vector size variants only after op lowering. Part of #88056 --- llvm/include/llvm/IR/IntrinsicsDirectX.td | 20 +-- llvm/lib/Target/DirectX/DXIL.td | 6 +- .../Target/DirectX/DXILIntrinsicExpansion.cpp | 15 +-- llvm/test/CodeGen/DirectX/fdot.ll | 117 ++++++++++-------- llvm/test/CodeGen/DirectX/idot.ll | 34 ++--- 5 files changed, 96 insertions(+), 96 deletions(-) diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index 312c3862f240d8..8ce79eb7cbaafa 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -25,26 +25,18 @@ def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>; def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>; -def int_dx_dot2 : - Intrinsic<[LLVMVectorElementType<0>], +def int_dx_dot2 : + Intrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], [IntrNoMem, IntrWillReturn, Commutative] >; -def int_dx_dot3 : - Intrinsic<[LLVMVectorElementType<0>], +def int_dx_dot3 : + Intrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], [IntrNoMem, IntrWillReturn, Commutative] >; -def int_dx_dot4 : - Intrinsic<[LLVMVectorElementType<0>], +def int_dx_dot4 : + Intrinsic<[LLVMVectorElementType<0>], [llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], [IntrNoMem, IntrWillReturn, Commutative] >; -def int_dx_sdot : - Intrinsic<[LLVMVectorElementType<0>], - [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], - [IntrNoMem, IntrWillReturn, Commutative] >; -def int_dx_udot : - Intrinsic<[LLVMVectorElementType<0>], - [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], - [IntrNoMem, IntrWillReturn, Commutative] >; def int_dx_frac : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>; diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index 67015cff78a79a..ac79b84a1e9100 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -637,7 +637,7 @@ def UMad : DXILOp<49, tertiary> { def Dot2 : DXILOp<54, dot2> { let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + " - "a[n]*b[n] where n is between 0 and 1"; + "a[n]*b[n] where n is 0 to 1 inclusive"; let LLVMIntrinsic = int_dx_dot2; let arguments = !listsplat(overloadTy, 4); let result = overloadTy; @@ -648,7 +648,7 @@ def Dot2 : DXILOp<54, dot2> { def Dot3 : DXILOp<55, dot3> { let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + " - "a[n]*b[n] where n is between 0 and 2"; + "a[n]*b[n] where n is 0 to 2 inclusive"; let LLVMIntrinsic = int_dx_dot3; let arguments = !listsplat(overloadTy, 6); let result = overloadTy; @@ -659,7 +659,7 @@ def Dot3 : DXILOp<55, dot3> { def Dot4 : DXILOp<56, dot4> { let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + " - "a[n]*b[n] where n is between 0 and 3"; + "a[n]*b[n] where n is 0 to 3 inclusive"; let LLVMIntrinsic = int_dx_dot4; let arguments = !listsplat(overloadTy, 8); let result = overloadTy; diff --git a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp index 307ed453b9eb6d..5adb62ed54e81b 100644 --- a/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp +++ b/llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp @@ -43,8 +43,8 @@ static bool isIntrinsicExpansion(Function &F) { case Intrinsic::dx_uclamp: case Intrinsic::dx_lerp: case Intrinsic::dx_length: - case Intrinsic::dx_sdot: - case Intrinsic::dx_udot: + case Intrinsic::sdot: + case Intrinsic::udot: case Intrinsic::fdot: return true; } @@ -72,10 +72,11 @@ static bool expandAbs(CallInst *Orig) { } static bool expandDotIntrinsic(CallInst *Orig, Intrinsic::ID DotIntrinsic) { - assert(DotIntrinsic == Intrinsic::dx_sdot || - DotIntrinsic == Intrinsic::dx_udot || DotIntrinsic == Intrinsic::fdot); + assert(DotIntrinsic == Intrinsic::sdot || DotIntrinsic == Intrinsic::udot || + DotIntrinsic == Intrinsic::fdot); Value *A = Orig->getOperand(0); Value *B = Orig->getOperand(1); + [[maybe_unused]] Type *ATy = A->getType(); [[maybe_unused]] Type *BTy = B->getType(); assert(ATy->isVectorTy() && BTy->isVectorTy()); @@ -88,7 +89,7 @@ static bool expandDotIntrinsic(CallInst *Orig, Intrinsic::ID DotIntrinsic) { Value *Result; if (EltTy->isIntegerTy()) { // Expand integer dot product to multiply and add ops - Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot + Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::sdot ? Intrinsic::dx_imad : Intrinsic::dx_umad; Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0); @@ -340,9 +341,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) { return expandLerpIntrinsic(Orig); case Intrinsic::dx_length: return expandLengthIntrinsic(Orig); - case Intrinsic::dx_sdot: - case Intrinsic::dx_udot: case Intrinsic::fdot: + case Intrinsic::sdot: + case Intrinsic::udot: return expandDotIntrinsic(Orig, F.getIntrinsicID()); } return false; diff --git a/llvm/test/CodeGen/DirectX/fdot.ll b/llvm/test/CodeGen/DirectX/fdot.ll index 56817a172ff9e3..3eb39fd5d4bb74 100644 --- a/llvm/test/CodeGen/DirectX/fdot.ll +++ b/llvm/test/CodeGen/DirectX/fdot.ll @@ -1,94 +1,101 @@ +; RUN: opt -S -dxil-intrinsic-expansion -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK ; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s -; Make sure dxil operation function calls for dot are generated for int/uint vectors. +; Make sure dxil operation function calls for dot are generated for float type vectors. ; CHECK-LABEL: dot_half2 define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) { entry: -; CHECK: extractelement <2 x half> %a, i32 0 -; CHECK: extractelement <2 x half> %a, i32 1 -; CHECK: extractelement <2 x half> %b, i32 0 -; CHECK: extractelement <2 x half> %b, i32 1 -; CHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) - %dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b) +; DOPCHECK: extractelement <2 x half> %a, i32 0 +; DOPCHECK: extractelement <2 x half> %a, i32 1 +; DOPCHECK: extractelement <2 x half> %b, i32 0 +; DOPCHECK: extractelement <2 x half> %b, i32 1 +; DOPCHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) +; EXPCHECK: call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b) + %dx.dot = call half @llvm.fdot.v2f16(<2 x half> %a, <2 x half> %b) ret half %dx.dot } ; CHECK-LABEL: dot_half3 define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) { entry: -; CHECK: extractelement <3 x half> %a, i32 0 -; CHECK: extractelement <3 x half> %a, i32 1 -; CHECK: extractelement <3 x half> %a, i32 2 -; CHECK: extractelement <3 x half> %b, i32 0 -; CHECK: extractelement <3 x half> %b, i32 1 -; CHECK: extractelement <3 x half> %b, i32 2 -; CHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) - %dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b) +; DOPCHECK: extractelement <3 x half> %a, i32 0 +; DOPCHECK: extractelement <3 x half> %a, i32 1 +; DOPCHECK: extractelement <3 x half> %a, i32 2 +; DOPCHECK: extractelement <3 x half> %b, i32 0 +; DOPCHECK: extractelement <3 x half> %b, i32 1 +; DOPCHECK: extractelement <3 x half> %b, i32 2 +; DOPCHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) +; EXPCHECK: call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b) + %dx.dot = call half @llvm.fdot.v3f16(<3 x half> %a, <3 x half> %b) ret half %dx.dot } ; CHECK-LABEL: dot_half4 define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) { entry: -; CHECK: extractelement <4 x half> %a, i32 0 -; CHECK: extractelement <4 x half> %a, i32 1 -; CHECK: extractelement <4 x half> %a, i32 2 -; CHECK: extractelement <4 x half> %a, i32 3 -; CHECK: extractelement <4 x half> %b, i32 0 -; CHECK: extractelement <4 x half> %b, i32 1 -; CHECK: extractelement <4 x half> %b, i32 2 -; CHECK: extractelement <4 x half> %b, i32 3 -; CHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) - %dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b) +; DOPCHECK: extractelement <4 x half> %a, i32 0 +; DOPCHECK: extractelement <4 x half> %a, i32 1 +; DOPCHECK: extractelement <4 x half> %a, i32 2 +; DOPCHECK: extractelement <4 x half> %a, i32 3 +; DOPCHECK: extractelement <4 x half> %b, i32 0 +; DOPCHECK: extractelement <4 x half> %b, i32 1 +; DOPCHECK: extractelement <4 x half> %b, i32 2 +; DOPCHECK: extractelement <4 x half> %b, i32 3 +; DOPCHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}) +; EXPCHECK: call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b) + %dx.dot = call half @llvm.fdot.v4f16(<4 x half> %a, <4 x half> %b) ret half %dx.dot } ; CHECK-LABEL: dot_float2 define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) { entry: -; CHECK: extractelement <2 x float> %a, i32 0 -; CHECK: extractelement <2 x float> %a, i32 1 -; CHECK: extractelement <2 x float> %b, i32 0 -; CHECK: extractelement <2 x float> %b, i32 1 -; CHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) - %dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b) +; DOPCHECK: extractelement <2 x float> %a, i32 0 +; DOPCHECK: extractelement <2 x float> %a, i32 1 +; DOPCHECK: extractelement <2 x float> %b, i32 0 +; DOPCHECK: extractelement <2 x float> %b, i32 1 +; DOPCHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) +; EXPCHECK: call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b) + %dx.dot = call float @llvm.fdot.v2f32(<2 x float> %a, <2 x float> %b) ret float %dx.dot } ; CHECK-LABEL: dot_float3 define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) { entry: -; CHECK: extractelement <3 x float> %a, i32 0 -; CHECK: extractelement <3 x float> %a, i32 1 -; CHECK: extractelement <3 x float> %a, i32 2 -; CHECK: extractelement <3 x float> %b, i32 0 -; CHECK: extractelement <3 x float> %b, i32 1 -; CHECK: extractelement <3 x float> %b, i32 2 -; CHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) - %dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b) +; DOPCHECK: extractelement <3 x float> %a, i32 0 +; DOPCHECK: extractelement <3 x float> %a, i32 1 +; DOPCHECK: extractelement <3 x float> %a, i32 2 +; DOPCHECK: extractelement <3 x float> %b, i32 0 +; DOPCHECK: extractelement <3 x float> %b, i32 1 +; DOPCHECK: extractelement <3 x float> %b, i32 2 +; DOPCHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) +; EXPCHECK: call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b) + %dx.dot = call float @llvm.fdot.v3f32(<3 x float> %a, <3 x float> %b) ret float %dx.dot } ; CHECK-LABEL: dot_float4 define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) { entry: -; CHECK: extractelement <4 x float> %a, i32 0 -; CHECK: extractelement <4 x float> %a, i32 1 -; CHECK: extractelement <4 x float> %a, i32 2 -; CHECK: extractelement <4 x float> %a, i32 3 -; CHECK: extractelement <4 x float> %b, i32 0 -; CHECK: extractelement <4 x float> %b, i32 1 -; CHECK: extractelement <4 x float> %b, i32 2 -; CHECK: extractelement <4 x float> %b, i32 3 -; CHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) - %dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %a, <4 x float> %b) +; DOPCHECK: extractelement <4 x float> %a, i32 0 +; DOPCHECK: extractelement <4 x float> %a, i32 1 +; DOPCHECK: extractelement <4 x float> %a, i32 2 +; DOPCHECK: extractelement <4 x float> %a, i32 3 +; DOPCHECK: extractelement <4 x float> %b, i32 0 +; DOPCHECK: extractelement <4 x float> %b, i32 1 +; DOPCHECK: extractelement <4 x float> %b, i32 2 +; DOPCHECK: extractelement <4 x float> %b, i32 3 +; DOPCHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}) +; EXPCHECK: call float @llvm.dx.dot4.v4f32(<4 x float> %a, <4 x float> %b) + %dx.dot = call float @llvm.fdot.v4f32(<4 x float> %a, <4 x float> %b) ret float %dx.dot } -declare half @llvm.dx.dot.v2f16(<2 x half> , <2 x half> ) -declare half @llvm.dx.dot.v3f16(<3 x half> , <3 x half> ) -declare half @llvm.dx.dot.v4f16(<4 x half> , <4 x half> ) -declare float @llvm.dx.dot.v2f32(<2 x float>, <2 x float>) -declare float @llvm.dx.dot.v3f32(<3 x float>, <3 x float>) -declare float @llvm.dx.dot.v4f32(<4 x float>, <4 x float>) +declare half @llvm.fdot.v2f16(<2 x half> , <2 x half> ) +declare half @llvm.fdot.v3f16(<3 x half> , <3 x half> ) +declare half @llvm.fdot.v4f16(<4 x half> , <4 x half> ) +declare float @llvm.fdot.v2f32(<2 x float>, <2 x float>) +declare float @llvm.fdot.v3f32(<3 x float>, <3 x float>) +declare float @llvm.fdot.v4f32(<4 x float>, <4 x float>) diff --git a/llvm/test/CodeGen/DirectX/idot.ll b/llvm/test/CodeGen/DirectX/idot.ll index eac1b91106ddef..94822f92b41351 100644 --- a/llvm/test/CodeGen/DirectX/idot.ll +++ b/llvm/test/CodeGen/DirectX/idot.ll @@ -13,12 +13,12 @@ entry: ; CHECK: extractelement <2 x i16> %b, i64 1 ; EXPCHECK: call i16 @llvm.dx.imad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}) ; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 48, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}) - %dx.dot = call i16 @llvm.dx.sdot.v3i16(<2 x i16> %a, <2 x i16> %b) - ret i16 %dx.dot + %dot = call i16 @llvm.sdot.v3i16(<2 x i16> %a, <2 x i16> %b) + ret i16 %dot } -; CHECK-LABEL: sdot_int4 -define noundef i32 @sdot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) { +; CHECK-LABEL: dot_int4 +define noundef i32 @dot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) { entry: ; CHECK: extractelement <4 x i32> %a, i64 0 ; CHECK: extractelement <4 x i32> %b, i64 0 @@ -35,8 +35,8 @@ entry: ; CHECK: extractelement <4 x i32> %b, i64 3 ; EXPCHECK: call i32 @llvm.dx.imad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) ; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - %dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %a, <4 x i32> %b) - ret i32 %dx.dot + %dot = call i32 @llvm.sdot.v4i32(<4 x i32> %a, <4 x i32> %b) + ret i32 %dot } ; CHECK-LABEL: dot_uint16_t3 @@ -53,8 +53,8 @@ entry: ; CHECK: extractelement <3 x i16> %b, i64 2 ; EXPCHECK: call i16 @llvm.dx.umad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}) ; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 49, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}}) - %dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %a, <3 x i16> %b) - ret i16 %dx.dot + %dot = call i16 @llvm.udot.v3i16(<3 x i16> %a, <3 x i16> %b) + ret i16 %dot } ; CHECK-LABEL: dot_uint4 @@ -75,8 +75,8 @@ entry: ; CHECK: extractelement <4 x i32> %b, i64 3 ; EXPCHECK: call i32 @llvm.dx.umad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) ; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}}) - %dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %a, <4 x i32> %b) - ret i32 %dx.dot + %dot = call i32 @llvm.udot.v4i32(<4 x i32> %a, <4 x i32> %b) + ret i32 %dot } ; CHECK-LABEL: dot_uint64_t4 @@ -89,12 +89,12 @@ entry: ; CHECK: extractelement <2 x i64> %b, i64 1 ; EXPCHECK: call i64 @llvm.dx.umad.i64(i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}}) ; DOPCHECK: call i64 @dx.op.tertiary.i64(i32 49, i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}}) - %dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %a, <2 x i64> %b) - ret i64 %dx.dot + %dot = call i64 @llvm.udot.v2i64(<2 x i64> %a, <2 x i64> %b) + ret i64 %dot } -declare i16 @llvm.dx.sdot.v2i16(<2 x i16>, <2 x i16>) -declare i32 @llvm.dx.sdot.v4i32(<4 x i32>, <4 x i32>) -declare i16 @llvm.dx.udot.v3i32(<3 x i16>, <3 x i16>) -declare i32 @llvm.dx.udot.v4i32(<4 x i32>, <4 x i32>) -declare i64 @llvm.dx.udot.v2i64(<2 x i64>, <2 x i64>) +declare i16 @llvm.sdot.v2i16(<2 x i16>, <2 x i16>) +declare i32 @llvm.sdot.v4i32(<4 x i32>, <4 x i32>) +declare i16 @llvm.udot.v3i32(<3 x i16>, <3 x i16>) +declare i32 @llvm.udot.v4i32(<4 x i32>, <4 x i32>) +declare i64 @llvm.udot.v2i64(<2 x i64>, <2 x i64>)