Skip to content

Commit

Permalink
Add SPIRV generation for HLSL dot
Browse files Browse the repository at this point in the history
Use the new LLVM dot intrinsics to build SPIRV instructions.
This involves generating multiply and add operations for integers
and the existing OpDot operation for floating point. This includes
adding some generic opcodes for signed, unsigned and floats.
These require updating an existing test for all such opcodes.

New tests for generating SPIRV float and integer dot intrinsics are
added as well.

Fixes llvm#88056
  • Loading branch information
pow2clk committed Aug 12, 2024
1 parent 7ca6bc5 commit edbb80c
Show file tree
Hide file tree
Showing 8 changed files with 289 additions and 0 deletions.
9 changes: 9 additions & 0 deletions llvm/include/llvm/Support/TargetOpcodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -814,6 +814,15 @@ HANDLE_TARGET_OPCODE(G_FSINH)
/// Floating point hyperbolic tangent.
HANDLE_TARGET_OPCODE(G_FTANH)

/// Floating point vector dot product
HANDLE_TARGET_OPCODE(G_FDOTPROD)

/// Unsigned integer vector dot product
HANDLE_TARGET_OPCODE(G_UDOTPROD)

/// Signed integer vector dot product
HANDLE_TARGET_OPCODE(G_SDOTPROD)

/// Floating point square root.
HANDLE_TARGET_OPCODE(G_FSQRT)

Expand Down
21 changes: 21 additions & 0 deletions llvm/include/llvm/Target/GenericOpcodes.td
Original file line number Diff line number Diff line change
Expand Up @@ -1057,6 +1057,27 @@ def G_FTANH : GenericInstruction {
let hasSideEffects = false;
}

/// Floating point vector dot product
def G_FDOTPROD : GenericInstruction {
let OutOperandList = (outs type0:$dst);
let InOperandList = (ins type0:$src1, type0:$src2);
let hasSideEffects = false;
}

/// Signed integer vector dot product
def G_SDOTPROD : GenericInstruction {
let OutOperandList = (outs type0:$dst);
let InOperandList = (ins type0:$src1, type0:$src2);
let hasSideEffects = false;
}

/// Unsigned integer vector dot product
def G_UDOTPROD : GenericInstruction {
let OutOperandList = (outs type0:$dst);
let InOperandList = (ins type0:$src1, type0:$src2);
let hasSideEffects = false;
}

// Floating point square root of a value.
// This returns NaN for negative nonzero values.
// NOTE: Unlike libm sqrt(), this never sets errno. In all other respects it's
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/CodeGen/GlobalISel/IRTranslator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1903,6 +1903,12 @@ unsigned IRTranslator::getSimpleIntrinsicOpcode(Intrinsic::ID ID) {
return TargetOpcode::G_CTPOP;
case Intrinsic::exp:
return TargetOpcode::G_FEXP;
case Intrinsic::fdot:
return TargetOpcode::G_FDOTPROD;
case Intrinsic::sdot:
return TargetOpcode::G_SDOTPROD;
case Intrinsic::udot:
return TargetOpcode::G_UDOTPROD;
case Intrinsic::exp2:
return TargetOpcode::G_FEXP2;
case Intrinsic::exp10:
Expand Down
78 changes: 78 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ class SPIRVInstructionSelector : public InstructionSelector {
bool selectRsqrt(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

bool selectIntegerDot(Register ResVReg, const SPIRVType *ResType,
MachineInstr &I) const;

void renderImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
int OpIdx) const;
void renderFImm32(MachineInstrBuilder &MIB, const MachineInstr &I,
Expand Down Expand Up @@ -380,6 +383,20 @@ bool SPIRVInstructionSelector::spvSelect(Register ResVReg,
MIB.addImm(V);
return MIB.constrainAllUses(TII, TRI, RBI);
}

case TargetOpcode::G_FDOTPROD: {
MachineBasicBlock &BB = *I.getParent();
return BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpDot))
.addDef(ResVReg)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(I.getOperand(1).getReg())
.addUse(I.getOperand(2).getReg())
.constrainAllUses(TII, TRI, RBI);
}
case TargetOpcode::G_SDOTPROD:
case TargetOpcode::G_UDOTPROD:
return selectIntegerDot(ResVReg, ResType, I);

case TargetOpcode::G_MEMMOVE:
case TargetOpcode::G_MEMCPY:
case TargetOpcode::G_MEMSET:
Expand Down Expand Up @@ -1366,6 +1383,67 @@ bool SPIRVInstructionSelector::selectRsqrt(Register ResVReg,
.constrainAllUses(TII, TRI, RBI);
}

// Since there is no integer dot implementation, expand by piecewise multiplying
// and adding the results, making use of FMA operations where possible.
bool SPIRVInstructionSelector::selectIntegerDot(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
assert(I.getNumOperands() == 3);
assert(I.getOperand(1).isReg());
assert(I.getOperand(2).isReg());
MachineBasicBlock &BB = *I.getParent();

// Multiply the vectors, then sum the results
Register Vec0 = I.getOperand(1).getReg();
Register Vec1 = I.getOperand(2).getReg();
Register TmpVec = MRI->createVirtualRegister(&SPIRV::IDRegClass);
SPIRVType *VecType = GR.getSPIRVTypeForVReg(Vec0);

bool Result = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIMulV))
.addDef(TmpVec)
.addUse(GR.getSPIRVTypeID(VecType))
.addUse(Vec0)
.addUse(Vec1)
.constrainAllUses(TII, TRI, RBI);

assert(GR.getScalarOrVectorComponentCount(VecType) > 1 &&
"dot product requires a vector of at least 2 components");

Register Res = MRI->createVirtualRegister(&SPIRV::IDRegClass);
Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
.addDef(Res)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(TmpVec)
.addImm(0)
.constrainAllUses(TII, TRI, RBI);

for (unsigned i = 1; i < GR.getScalarOrVectorComponentCount(VecType); i++) {
Register Elt = MRI->createVirtualRegister(&SPIRV::IDRegClass);

Result |=
BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract))
.addDef(Elt)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(TmpVec)
.addImm(i)
.constrainAllUses(TII, TRI, RBI);

Register Sum = i < GR.getScalarOrVectorComponentCount(VecType) - 1
? MRI->createVirtualRegister(&SPIRV::IDRegClass)
: ResVReg;

Result |= BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpIAddS))
.addDef(Sum)
.addUse(GR.getSPIRVTypeID(ResType))
.addUse(Res)
.addUse(Elt)
.constrainAllUses(TII, TRI, RBI);
Res = Sum;
}

return Result;
}

bool SPIRVInstructionSelector::selectBitreverse(Register ResVReg,
const SPIRVType *ResType,
MachineInstr &I) const {
Expand Down
3 changes: 3 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVLegalizerInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,9 @@ SPIRVLegalizerInfo::SPIRVLegalizerInfo(const SPIRVSubtarget &ST) {
G_FCOSH,
G_FSINH,
G_FTANH,
G_FDOTPROD,
G_SDOTPROD,
G_UDOTPROD,
G_FSQRT,
G_FFLOOR,
G_FRINT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -716,6 +716,15 @@
# DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
# DEBUG-NEXT: .. the first uncovered type index: 1, OK
# DEBUG-NEXT: .. the first uncovered imm index: 0, OK
# DEBUG-NEXT: G_FDOTPROD (opcode {{[0-9]+}}): 1 type index, 0 imm indices
# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: G_UDOTPROD (opcode {{[0-9]+}}): 1 type index, 0 imm indices
# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: G_SDOTPROD (opcode {{[0-9]+}}): 1 type index, 0 imm indices
# DEBUG-NEXT: .. type index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: .. imm index coverage check SKIPPED: no rules defined
# DEBUG-NEXT: G_FSQRT (opcode {{[0-9]+}}): 1 type index, 0 imm indices
# DEBUG-NEXT: .. opcode {{[0-9]+}} is aliased to {{[0-9]+}}
# DEBUG-NEXT: .. type index coverage check SKIPPED: user-defined predicate detected
Expand Down
75 changes: 75 additions & 0 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/fdot.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; Make sure dxil operation function calls for dot are generated for float type vectors.

; CHECK-DAG: %[[#float_16:]] = OpTypeFloat 16
; CHECK-DAG: %[[#vec2_float_16:]] = OpTypeVector %[[#float_16]] 2
; CHECK-DAG: %[[#vec3_float_16:]] = OpTypeVector %[[#float_16]] 3
; CHECK-DAG: %[[#vec4_float_16:]] = OpTypeVector %[[#float_16]] 4
; CHECK-DAG: %[[#float_32:]] = OpTypeFloat 32
; CHECK-DAG: %[[#vec2_float_32:]] = OpTypeVector %[[#float_32]] 2
; CHECK-DAG: %[[#vec3_float_32:]] = OpTypeVector %[[#float_32]] 3
; CHECK-DAG: %[[#vec4_float_32:]] = OpTypeVector %[[#float_32]] 4


define noundef half @dot_half2(<2 x half> noundef %a, <2 x half> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_float_16]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_float_16]]
; CHECK: OpDot %[[#float_16]] %[[#arg0:]] %[[#arg1:]]
%dx.dot = call half @llvm.fdot.v2f16(<2 x half> %a, <2 x half> %b)
ret half %dx.dot
}

define noundef half @dot_half3(<3 x half> noundef %a, <3 x half> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_float_16]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_float_16]]
; CHECK: OpDot %[[#float_16]] %[[#arg0:]] %[[#arg1:]]
%dx.dot = call half @llvm.fdot.v3f16(<3 x half> %a, <3 x half> %b)
ret half %dx.dot
}

define noundef half @dot_half4(<4 x half> noundef %a, <4 x half> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_16]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_16]]
; CHECK: OpDot %[[#float_16]] %[[#arg0:]] %[[#arg1:]]
%dx.dot = call half @llvm.fdot.v4f16(<4 x half> %a, <4 x half> %b)
ret half %dx.dot
}

define noundef float @dot_float2(<2 x float> noundef %a, <2 x float> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_float_32]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_float_32]]
; CHECK: OpDot %[[#float_32]] %[[#arg0:]] %[[#arg1:]]
%dx.dot = call float @llvm.fdot.v2f32(<2 x float> %a, <2 x float> %b)
ret float %dx.dot
}

define noundef float @dot_float3(<3 x float> noundef %a, <3 x float> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_float_32]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_float_32]]
; CHECK: OpDot %[[#float_32]] %[[#arg0:]] %[[#arg1:]]
%dx.dot = call float @llvm.fdot.v3f32(<3 x float> %a, <3 x float> %b)
ret float %dx.dot
}

define noundef float @dot_float4(<4 x float> noundef %a, <4 x float> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
; CHECK: OpDot %[[#float_32]] %[[#arg0:]] %[[#arg1:]]
%dx.dot = call float @llvm.fdot.v4f32(<4 x float> %a, <4 x float> %b)
ret float %dx.dot
}

declare half @llvm.fdot.v2f16(<2 x half> , <2 x half> )
declare half @llvm.fdot.v3f16(<3 x half> , <3 x half> )
declare half @llvm.fdot.v4f16(<4 x half> , <4 x half> )
declare float @llvm.fdot.v2f32(<2 x float>, <2 x float>)
declare float @llvm.fdot.v3f32(<3 x float>, <3 x float>)
declare float @llvm.fdot.v4f32(<4 x float>, <4 x float>)
88 changes: 88 additions & 0 deletions llvm/test/CodeGen/SPIRV/hlsl-intrinsics/idot.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-unknown %s -o - -filetype=obj | spirv-val %}

; Make sure dxil operation function calls for dot are generated for int/uint vectors.

; CHECK-DAG: %[[#int_16:]] = OpTypeInt 16
; CHECK-DAG: %[[#vec2_int_16:]] = OpTypeVector %[[#int_16]] 2
; CHECK-DAG: %[[#vec3_int_16:]] = OpTypeVector %[[#int_16]] 3
; CHECK-DAG: %[[#int_32:]] = OpTypeInt 32
; CHECK-DAG: %[[#vec4_int_32:]] = OpTypeVector %[[#int_32]] 4
; CHECK-DAG: %[[#int_64:]] = OpTypeInt 64
; CHECK-DAG: %[[#vec2_int_64:]] = OpTypeVector %[[#int_64]] 2

define noundef i16 @dot_int16_t2(<2 x i16> noundef %a, <2 x i16> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_16]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_16]]
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_16]] %[[#arg0]] %[[#arg1]]
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
; CHECK: %[[#sum:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
%dot = call i16 @llvm.sdot.v3i16(<2 x i16> %a, <2 x i16> %b)
ret i16 %dot
}

define noundef i32 @dot_int4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
%dot = call i32 @llvm.sdot.v4i32(<4 x i32> %a, <4 x i32> %b)
ret i32 %dot
}

define noundef i16 @dot_uint16_t3(<3 x i16> noundef %a, <3 x i16> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec3_int_16]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec3_int_16]]
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec3_int_16]] %[[#arg0]] %[[#arg1]]
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 0
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 1
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_16]] %[[#elt0]] %[[#elt1]]
; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_16]] %[[#mul_vec]] 2
; CHECK: %[[#sum1:]] = OpIAdd %[[#int_16]] %[[#sum0]] %[[#elt2]]
%dot = call i16 @llvm.udot.v3i16(<3 x i16> %a, <3 x i16> %b)
ret i16 %dot
}

define noundef i32 @dot_uint4(<4 x i32> noundef %a, <4 x i32> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_int_32]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_int_32]]
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec4_int_32]] %[[#arg0]] %[[#arg1]]
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 0
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 1
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_32]] %[[#elt0]] %[[#elt1]]
; CHECK: %[[#elt2:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 2
; CHECK: %[[#sum1:]] = OpIAdd %[[#int_32]] %[[#sum0]] %[[#elt2]]
; CHECK: %[[#elt3:]] = OpCompositeExtract %[[#int_32]] %[[#mul_vec]] 3
; CHECK: %[[#sum2:]] = OpIAdd %[[#int_32]] %[[#sum1]] %[[#elt3]]
%dot = call i32 @llvm.udot.v4i32(<4 x i32> %a, <4 x i32> %b)
ret i32 %dot
}

define noundef i64 @dot_uint64_t4(<2 x i64> noundef %a, <2 x i64> noundef %b) {
entry:
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec2_int_64]]
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec2_int_64]]
; CHECK: %[[#mul_vec:]] = OpIMul %[[#vec2_int_64]] %[[#arg0]] %[[#arg1]]
; CHECK: %[[#elt0:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 0
; CHECK: %[[#elt1:]] = OpCompositeExtract %[[#int_64]] %[[#mul_vec]] 1
; CHECK: %[[#sum0:]] = OpIAdd %[[#int_64]] %[[#elt0]] %[[#elt1]]
%dot = call i64 @llvm.udot.v2i64(<2 x i64> %a, <2 x i64> %b)
ret i64 %dot
}

declare i16 @llvm.sdot.v2i16(<2 x i16>, <2 x i16>)
declare i32 @llvm.sdot.v4i32(<4 x i32>, <4 x i32>)
declare i16 @llvm.udot.v3i32(<3 x i16>, <3 x i16>)
declare i32 @llvm.udot.v4i32(<4 x i32>, <4 x i32>)
declare i64 @llvm.udot.v2i64(<2 x i64>, <2 x i64>)

0 comments on commit edbb80c

Please sign in to comment.