Skip to content

Commit

Permalink
Update DX intrinsic expansion for new llvm intrinsics
Browse files Browse the repository at this point in the history
The new LLVM integer intrinsics replace the previous dx intrinsics.
This updates the DXIL intrinsic expansion code and tests to use and
expect the new integer intrinsics and the flattened DX floating
vector size variants only after op lowering.

Part of llvm#88056
  • Loading branch information
pow2clk committed Aug 12, 2024
1 parent 6fde4bc commit 7ca6bc5
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 96 deletions.
20 changes: 6 additions & 14 deletions llvm/include/llvm/IR/IntrinsicsDirectX.td
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,18 @@ def int_dx_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty]>;
def int_dx_clamp : DefaultAttrsIntrinsic<[llvm_any_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
def int_dx_uclamp : DefaultAttrsIntrinsic<[llvm_anyint_ty], [LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;

def int_dx_dot2 :
Intrinsic<[LLVMVectorElementType<0>],
def int_dx_dot2 :
Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
def int_dx_dot3 :
Intrinsic<[LLVMVectorElementType<0>],
def int_dx_dot3 :
Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
def int_dx_dot4 :
Intrinsic<[LLVMVectorElementType<0>],
def int_dx_dot4 :
Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyfloat_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
def int_dx_sdot :
Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;
def int_dx_udot :
Intrinsic<[LLVMVectorElementType<0>],
[llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>],
[IntrNoMem, IntrWillReturn, Commutative] >;

def int_dx_frac : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>]>;

Expand Down
6 changes: 3 additions & 3 deletions llvm/lib/Target/DirectX/DXIL.td
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,7 @@ def UMad : DXILOp<49, tertiary> {

def Dot2 : DXILOp<54, dot2> {
let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + "
"a[n]*b[n] where n is between 0 and 1";
"a[n]*b[n] where n is 0 to 1 inclusive";
let LLVMIntrinsic = int_dx_dot2;
let arguments = !listsplat(overloadTy, 4);
let result = overloadTy;
Expand All @@ -648,7 +648,7 @@ def Dot2 : DXILOp<54, dot2> {

def Dot3 : DXILOp<55, dot3> {
let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + "
"a[n]*b[n] where n is between 0 and 2";
"a[n]*b[n] where n is 0 to 2 inclusive";
let LLVMIntrinsic = int_dx_dot3;
let arguments = !listsplat(overloadTy, 6);
let result = overloadTy;
Expand All @@ -659,7 +659,7 @@ def Dot3 : DXILOp<55, dot3> {

def Dot4 : DXILOp<56, dot4> {
let Doc = "dot product of two float vectors Dot(a,b) = a[0]*b[0] + ... + "
"a[n]*b[n] where n is between 0 and 3";
"a[n]*b[n] where n is 0 to 3 inclusive";
let LLVMIntrinsic = int_dx_dot4;
let arguments = !listsplat(overloadTy, 8);
let result = overloadTy;
Expand Down
15 changes: 8 additions & 7 deletions llvm/lib/Target/DirectX/DXILIntrinsicExpansion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ static bool isIntrinsicExpansion(Function &F) {
case Intrinsic::dx_uclamp:
case Intrinsic::dx_lerp:
case Intrinsic::dx_length:
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
case Intrinsic::sdot:
case Intrinsic::udot:
case Intrinsic::fdot:
return true;
}
Expand Down Expand Up @@ -72,10 +72,11 @@ static bool expandAbs(CallInst *Orig) {
}

static bool expandDotIntrinsic(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
assert(DotIntrinsic == Intrinsic::dx_sdot ||
DotIntrinsic == Intrinsic::dx_udot || DotIntrinsic == Intrinsic::fdot);
assert(DotIntrinsic == Intrinsic::sdot || DotIntrinsic == Intrinsic::udot ||
DotIntrinsic == Intrinsic::fdot);
Value *A = Orig->getOperand(0);
Value *B = Orig->getOperand(1);

[[maybe_unused]] Type *ATy = A->getType();
[[maybe_unused]] Type *BTy = B->getType();
assert(ATy->isVectorTy() && BTy->isVectorTy());
Expand All @@ -88,7 +89,7 @@ static bool expandDotIntrinsic(CallInst *Orig, Intrinsic::ID DotIntrinsic) {
Value *Result;
if (EltTy->isIntegerTy()) {
// Expand integer dot product to multiply and add ops
Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::dx_sdot
Intrinsic::ID MadIntrinsic = DotIntrinsic == Intrinsic::sdot
? Intrinsic::dx_imad
: Intrinsic::dx_umad;
Value *Elt0 = Builder.CreateExtractElement(A, (uint64_t)0);
Expand Down Expand Up @@ -340,9 +341,9 @@ static bool expandIntrinsic(Function &F, CallInst *Orig) {
return expandLerpIntrinsic(Orig);
case Intrinsic::dx_length:
return expandLengthIntrinsic(Orig);
case Intrinsic::dx_sdot:
case Intrinsic::dx_udot:
case Intrinsic::fdot:
case Intrinsic::sdot:
case Intrinsic::udot:
return expandDotIntrinsic(Orig, F.getIntrinsicID());
}
return false;
Expand Down
117 changes: 62 additions & 55 deletions llvm/test/CodeGen/DirectX/fdot.ll
Original file line number Diff line number Diff line change
@@ -1,94 +1,101 @@
; RUN: opt -S -dxil-intrinsic-expansion -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s --check-prefixes=CHECK,EXPCHECK
; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s

; Make sure dxil operation function calls for dot are generated for int/uint vectors.
; Make sure dxil operation function calls for dot are generated for float type vectors.

; CHECK-LABEL: dot_half2
define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) {
entry:
; CHECK: extractelement <2 x half> %a, i32 0
; CHECK: extractelement <2 x half> %a, i32 1
; CHECK: extractelement <2 x half> %b, i32 0
; CHECK: extractelement <2 x half> %b, i32 1
; CHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
%dx.dot = call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b)
; DOPCHECK: extractelement <2 x half> %a, i32 0
; DOPCHECK: extractelement <2 x half> %a, i32 1
; DOPCHECK: extractelement <2 x half> %b, i32 0
; DOPCHECK: extractelement <2 x half> %b, i32 1
; DOPCHECK: call half @dx.op.dot2.f16(i32 54, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
; EXPCHECK: call half @llvm.dx.dot2.v2f16(<2 x half> %a, <2 x half> %b)
%dx.dot = call half @llvm.fdot.v2f16(<2 x half> %a, <2 x half> %b)
ret half %dx.dot
}

; CHECK-LABEL: dot_half3
define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) {
entry:
; CHECK: extractelement <3 x half> %a, i32 0
; CHECK: extractelement <3 x half> %a, i32 1
; CHECK: extractelement <3 x half> %a, i32 2
; CHECK: extractelement <3 x half> %b, i32 0
; CHECK: extractelement <3 x half> %b, i32 1
; CHECK: extractelement <3 x half> %b, i32 2
; CHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
%dx.dot = call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b)
; DOPCHECK: extractelement <3 x half> %a, i32 0
; DOPCHECK: extractelement <3 x half> %a, i32 1
; DOPCHECK: extractelement <3 x half> %a, i32 2
; DOPCHECK: extractelement <3 x half> %b, i32 0
; DOPCHECK: extractelement <3 x half> %b, i32 1
; DOPCHECK: extractelement <3 x half> %b, i32 2
; DOPCHECK: call half @dx.op.dot3.f16(i32 55, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
; EXPCHECK: call half @llvm.dx.dot3.v3f16(<3 x half> %a, <3 x half> %b)
%dx.dot = call half @llvm.fdot.v3f16(<3 x half> %a, <3 x half> %b)
ret half %dx.dot
}

; CHECK-LABEL: dot_half4
define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) {
entry:
; CHECK: extractelement <4 x half> %a, i32 0
; CHECK: extractelement <4 x half> %a, i32 1
; CHECK: extractelement <4 x half> %a, i32 2
; CHECK: extractelement <4 x half> %a, i32 3
; CHECK: extractelement <4 x half> %b, i32 0
; CHECK: extractelement <4 x half> %b, i32 1
; CHECK: extractelement <4 x half> %b, i32 2
; CHECK: extractelement <4 x half> %b, i32 3
; CHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
%dx.dot = call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b)
; DOPCHECK: extractelement <4 x half> %a, i32 0
; DOPCHECK: extractelement <4 x half> %a, i32 1
; DOPCHECK: extractelement <4 x half> %a, i32 2
; DOPCHECK: extractelement <4 x half> %a, i32 3
; DOPCHECK: extractelement <4 x half> %b, i32 0
; DOPCHECK: extractelement <4 x half> %b, i32 1
; DOPCHECK: extractelement <4 x half> %b, i32 2
; DOPCHECK: extractelement <4 x half> %b, i32 3
; DOPCHECK: call half @dx.op.dot4.f16(i32 56, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}}, half %{{.*}})
; EXPCHECK: call half @llvm.dx.dot4.v4f16(<4 x half> %a, <4 x half> %b)
%dx.dot = call half @llvm.fdot.v4f16(<4 x half> %a, <4 x half> %b)
ret half %dx.dot
}

; CHECK-LABEL: dot_float2
define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) {
entry:
; CHECK: extractelement <2 x float> %a, i32 0
; CHECK: extractelement <2 x float> %a, i32 1
; CHECK: extractelement <2 x float> %b, i32 0
; CHECK: extractelement <2 x float> %b, i32 1
; CHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
%dx.dot = call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b)
; DOPCHECK: extractelement <2 x float> %a, i32 0
; DOPCHECK: extractelement <2 x float> %a, i32 1
; DOPCHECK: extractelement <2 x float> %b, i32 0
; DOPCHECK: extractelement <2 x float> %b, i32 1
; DOPCHECK: call float @dx.op.dot2.f32(i32 54, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
; EXPCHECK: call float @llvm.dx.dot2.v2f32(<2 x float> %a, <2 x float> %b)
%dx.dot = call float @llvm.fdot.v2f32(<2 x float> %a, <2 x float> %b)
ret float %dx.dot
}

; CHECK-LABEL: dot_float3
define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) {
entry:
; CHECK: extractelement <3 x float> %a, i32 0
; CHECK: extractelement <3 x float> %a, i32 1
; CHECK: extractelement <3 x float> %a, i32 2
; CHECK: extractelement <3 x float> %b, i32 0
; CHECK: extractelement <3 x float> %b, i32 1
; CHECK: extractelement <3 x float> %b, i32 2
; CHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
%dx.dot = call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b)
; DOPCHECK: extractelement <3 x float> %a, i32 0
; DOPCHECK: extractelement <3 x float> %a, i32 1
; DOPCHECK: extractelement <3 x float> %a, i32 2
; DOPCHECK: extractelement <3 x float> %b, i32 0
; DOPCHECK: extractelement <3 x float> %b, i32 1
; DOPCHECK: extractelement <3 x float> %b, i32 2
; DOPCHECK: call float @dx.op.dot3.f32(i32 55, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
; EXPCHECK: call float @llvm.dx.dot3.v3f32(<3 x float> %a, <3 x float> %b)
%dx.dot = call float @llvm.fdot.v3f32(<3 x float> %a, <3 x float> %b)
ret float %dx.dot
}

; CHECK-LABEL: dot_float4
define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
entry:
; CHECK: extractelement <4 x float> %a, i32 0
; CHECK: extractelement <4 x float> %a, i32 1
; CHECK: extractelement <4 x float> %a, i32 2
; CHECK: extractelement <4 x float> %a, i32 3
; CHECK: extractelement <4 x float> %b, i32 0
; CHECK: extractelement <4 x float> %b, i32 1
; CHECK: extractelement <4 x float> %b, i32 2
; CHECK: extractelement <4 x float> %b, i32 3
; CHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
%dx.dot = call float @llvm.dx.dot4.v4f32(<4 x float> %a, <4 x float> %b)
; DOPCHECK: extractelement <4 x float> %a, i32 0
; DOPCHECK: extractelement <4 x float> %a, i32 1
; DOPCHECK: extractelement <4 x float> %a, i32 2
; DOPCHECK: extractelement <4 x float> %a, i32 3
; DOPCHECK: extractelement <4 x float> %b, i32 0
; DOPCHECK: extractelement <4 x float> %b, i32 1
; DOPCHECK: extractelement <4 x float> %b, i32 2
; DOPCHECK: extractelement <4 x float> %b, i32 3
; DOPCHECK: call float @dx.op.dot4.f32(i32 56, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}}, float %{{.*}})
; EXPCHECK: call float @llvm.dx.dot4.v4f32(<4 x float> %a, <4 x float> %b)
%dx.dot = call float @llvm.fdot.v4f32(<4 x float> %a, <4 x float> %b)
ret float %dx.dot
}

declare half @llvm.dx.dot.v2f16(<2 x half> , <2 x half> )
declare half @llvm.dx.dot.v3f16(<3 x half> , <3 x half> )
declare half @llvm.dx.dot.v4f16(<4 x half> , <4 x half> )
declare float @llvm.dx.dot.v2f32(<2 x float>, <2 x float>)
declare float @llvm.dx.dot.v3f32(<3 x float>, <3 x float>)
declare float @llvm.dx.dot.v4f32(<4 x float>, <4 x float>)
declare half @llvm.fdot.v2f16(<2 x half> , <2 x half> )
declare half @llvm.fdot.v3f16(<3 x half> , <3 x half> )
declare half @llvm.fdot.v4f16(<4 x half> , <4 x half> )
declare float @llvm.fdot.v2f32(<2 x float>, <2 x float>)
declare float @llvm.fdot.v3f32(<3 x float>, <3 x float>)
declare float @llvm.fdot.v4f32(<4 x float>, <4 x float>)
34 changes: 17 additions & 17 deletions llvm/test/CodeGen/DirectX/idot.ll
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@ entry:
; CHECK: extractelement <2 x i16> %b, i64 1
; EXPCHECK: call i16 @llvm.dx.imad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 48, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
%dx.dot = call i16 @llvm.dx.sdot.v3i16(<2 x i16> %a, <2 x i16> %b)
ret i16 %dx.dot
%dot = call i16 @llvm.sdot.v3i16(<2 x i16> %a, <2 x i16> %b)
ret i16 %dot
}

; CHECK-LABEL: sdot_int4
define noundef i32 @sdot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
; CHECK-LABEL: dot_int4
define noundef i32 @dot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
entry:
; CHECK: extractelement <4 x i32> %a, i64 0
; CHECK: extractelement <4 x i32> %b, i64 0
Expand All @@ -35,8 +35,8 @@ entry:
; CHECK: extractelement <4 x i32> %b, i64 3
; EXPCHECK: call i32 @llvm.dx.imad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 48, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%dx.dot = call i32 @llvm.dx.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
ret i32 %dx.dot
%dot = call i32 @llvm.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
ret i32 %dot
}

; CHECK-LABEL: dot_uint16_t3
Expand All @@ -53,8 +53,8 @@ entry:
; CHECK: extractelement <3 x i16> %b, i64 2
; EXPCHECK: call i16 @llvm.dx.umad.i16(i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
; DOPCHECK: call i16 @dx.op.tertiary.i16(i32 49, i16 %{{.*}}, i16 %{{.*}}, i16 %{{.*}})
%dx.dot = call i16 @llvm.dx.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
ret i16 %dx.dot
%dot = call i16 @llvm.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
ret i16 %dot
}

; CHECK-LABEL: dot_uint4
Expand All @@ -75,8 +75,8 @@ entry:
; CHECK: extractelement <4 x i32> %b, i64 3
; EXPCHECK: call i32 @llvm.dx.umad.i32(i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
; DOPCHECK: call i32 @dx.op.tertiary.i32(i32 49, i32 %{{.*}}, i32 %{{.*}}, i32 %{{.*}})
%dx.dot = call i32 @llvm.dx.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
ret i32 %dx.dot
%dot = call i32 @llvm.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
ret i32 %dot
}

; CHECK-LABEL: dot_uint64_t4
Expand All @@ -89,12 +89,12 @@ entry:
; CHECK: extractelement <2 x i64> %b, i64 1
; EXPCHECK: call i64 @llvm.dx.umad.i64(i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
; DOPCHECK: call i64 @dx.op.tertiary.i64(i32 49, i64 %{{.*}}, i64 %{{.*}}, i64 %{{.*}})
%dx.dot = call i64 @llvm.dx.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
ret i64 %dx.dot
%dot = call i64 @llvm.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
ret i64 %dot
}

declare i16 @llvm.dx.sdot.v2i16(<2 x i16>, <2 x i16>)
declare i32 @llvm.dx.sdot.v4i32(<4 x i32>, <4 x i32>)
declare i16 @llvm.dx.udot.v3i32(<3 x i16>, <3 x i16>)
declare i32 @llvm.dx.udot.v4i32(<4 x i32>, <4 x i32>)
declare i64 @llvm.dx.udot.v2i64(<2 x i64>, <2 x i64>)
declare i16 @llvm.sdot.v2i16(<2 x i16>, <2 x i16>)
declare i32 @llvm.sdot.v4i32(<4 x i32>, <4 x i32>)
declare i16 @llvm.udot.v3i32(<3 x i16>, <3 x i16>)
declare i32 @llvm.udot.v4i32(<4 x i32>, <4 x i32>)
declare i64 @llvm.udot.v2i64(<2 x i64>, <2 x i64>)

0 comments on commit 7ca6bc5

Please sign in to comment.