Skip to content

Commit

Permalink
[AArch64] Implement promotion type legalisation for histogram intrinsic (#101017)
Browse files Browse the repository at this point in the history

Currently the histogram intrinsic
(llvm.experimental.vector.histogram.add) only allows i32 and i64 types
for the memory locations to be updated, matching the restrictions of the
histcnt instruction. This patch adds support for the legalisation of
smaller types (i8 and i16) via promotion.
  • Loading branch information
DevM-uk authored Aug 12, 2024
1 parent 908c89e commit 670d208
Show file tree
Hide file tree
Showing 2 changed files with 142 additions and 12 deletions.
35 changes: 23 additions & 12 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1776,9 +1776,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom);

// Histcnt is SVE2 only
if (Subtarget->hasSVE2())
if (Subtarget->hasSVE2()) {
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other,
Custom);
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i8, Custom);
setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i16, Custom);
}
}


Expand Down Expand Up @@ -28175,11 +28178,18 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,

EVT IncVT = Inc.getValueType();
EVT IndexVT = Index.getValueType();
EVT MemVT = EVT::getVectorVT(*DAG.getContext(), IncVT,
IndexVT.getVectorElementCount());
LLVMContext &Ctx = *DAG.getContext();
ElementCount EC = IndexVT.getVectorElementCount();
EVT MemVT = EVT::getVectorVT(Ctx, IncVT, EC);
EVT IncExtVT =
EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / EC.getKnownMinValue());
EVT IncSplatVT = EVT::getVectorVT(Ctx, IncExtVT, EC);
bool ExtTrunc = IncSplatVT != MemVT;

SDValue Zero = DAG.getConstant(0, DL, MVT::i64);
SDValue PassThru = DAG.getSplatVector(MemVT, DL, Zero);
SDValue IncSplat = DAG.getSplatVector(MemVT, DL, Inc);
SDValue PassThru = DAG.getSplatVector(IncSplatVT, DL, Zero);
SDValue IncSplat = DAG.getSplatVector(
IncSplatVT, DL, DAG.getAnyExtOrTrunc(Inc, DL, IncExtVT));
SDValue Ops[] = {Chain, PassThru, Mask, Ptr, Index, Scale};

MachineMemOperand *MMO = HG->getMemOperand();
Expand All @@ -28188,18 +28198,19 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
MMO->getPointerInfo(), MachineMemOperand::MOLoad, MMO->getSize(),
MMO->getAlign(), MMO->getAAInfo());
ISD::MemIndexType IndexType = HG->getIndexType();
SDValue Gather =
DAG.getMaskedGather(DAG.getVTList(MemVT, MVT::Other), MemVT, DL, Ops,
GMMO, IndexType, ISD::NON_EXTLOAD);
SDValue Gather = DAG.getMaskedGather(
DAG.getVTList(IncSplatVT, MVT::Other), MemVT, DL, Ops, GMMO, IndexType,
ExtTrunc ? ISD::EXTLOAD : ISD::NON_EXTLOAD);

SDValue GChain = Gather.getValue(1);

// Perform the histcnt, multiply by inc, add to bucket data.
SDValue ID = DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncVT);
SDValue ID =
DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncExtVT);
SDValue HistCnt =
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT, ID, Mask, Index, Index);
SDValue Mul = DAG.getNode(ISD::MUL, DL, MemVT, HistCnt, IncSplat);
SDValue Add = DAG.getNode(ISD::ADD, DL, MemVT, Gather, Mul);
SDValue Mul = DAG.getNode(ISD::MUL, DL, IncSplatVT, HistCnt, IncSplat);
SDValue Add = DAG.getNode(ISD::ADD, DL, IncSplatVT, Gather, Mul);

// Create an MMO for the scatter, without load|store flags.
MachineMemOperand *SMMO = DAG.getMachineFunction().getMachineMemOperand(
Expand All @@ -28208,7 +28219,7 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,

SDValue ScatterOps[] = {GChain, Add, Mask, Ptr, Index, Scale};
SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MemVT, DL,
ScatterOps, SMMO, IndexType, false);
ScatterOps, SMMO, IndexType, ExtTrunc);
return Scatter;
}

Expand Down
119 changes: 119 additions & 0 deletions llvm/test/CodeGen/AArch64/sve2-histcnt.ll
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,123 @@ define void @histogram_i32_literal_noscale(ptr %base, <vscale x 4 x i32> %indice
ret void
}

; Histogram update of i32 buckets addressed through <vscale x 2 x i64> indices.
; The expected assembly below performs the histcnt and the multiply-add in
; 64-bit (.d) lanes while the bucket accesses use ld1w/st1w on .d elements,
; i.e. the i32 data is handled in wider lanes than its memory type.
define void @histogram_i32_promote(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i32 %inc) #0 {
; CHECK-LABEL: histogram_i32_promote:
; CHECK: // %bb.0:
; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
; CHECK-NEXT: mov z3.d, x1
; CHECK-NEXT: ld1w { z2.d }, p0/z, [x0, z0.d, lsl #2]
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
; CHECK-NEXT: st1w { z1.d }, p0, [x0, z0.d, lsl #2]
; CHECK-NEXT: ret
%buckets = getelementptr i32, ptr %base, <vscale x 2 x i64> %indices
call void @llvm.experimental.vector.histogram.add.nxv2p0.i32(<vscale x 2 x ptr> %buckets, i32 %inc, <vscale x 2 x i1> %mask)
ret void
}

; Histogram update of i16 buckets addressed through <vscale x 4 x i32> indices
; with a runtime increment. The expected assembly computes in 32-bit (.s)
; lanes and accesses the buckets with ld1h/st1h, using sign-extended indices
; scaled by 2 (sxtw #1).
define void @histogram_i16(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i16 %inc) #0 {
; CHECK-LABEL: histogram_i16:
; CHECK: // %bb.0:
; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
; CHECK-NEXT: mov z3.s, w1
; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
; CHECK-NEXT: ret
%buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 %inc, <vscale x 4 x i1> %mask)
ret void
}

; Histogram update of i8 buckets addressed through <vscale x 4 x i32> indices
; with a runtime increment. The expected assembly computes in 32-bit (.s)
; lanes and accesses the buckets with ld1b/st1b, using sign-extended,
; unscaled indices.
define void @histogram_i8(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask, i8 %inc) #0 {
; CHECK-LABEL: histogram_i8:
; CHECK: // %bb.0:
; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
; CHECK-NEXT: mov z3.s, w1
; CHECK-NEXT: ld1b { z2.s }, p0/z, [x0, z0.s, sxtw]
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
; CHECK-NEXT: st1b { z1.s }, p0, [x0, z0.s, sxtw]
; CHECK-NEXT: ret
%buckets = getelementptr i8, ptr %base, <vscale x 4 x i32> %indices
call void @llvm.experimental.vector.histogram.add.nxv4p0.i8(<vscale x 4 x ptr> %buckets, i8 %inc, <vscale x 4 x i1> %mask)
ret void
}

; Histogram update of i16 buckets addressed through <vscale x 2 x i64> indices.
; The expected assembly computes in 64-bit (.d) lanes and accesses the buckets
; with ld1h/st1h on .d elements, with indices scaled by 2 (lsl #1).
define void @histogram_i16_2_lane(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i16 %inc) #0 {
; CHECK-LABEL: histogram_i16_2_lane:
; CHECK: // %bb.0:
; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
; CHECK-NEXT: mov z3.d, x1
; CHECK-NEXT: ld1h { z2.d }, p0/z, [x0, z0.d, lsl #1]
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
; CHECK-NEXT: st1h { z1.d }, p0, [x0, z0.d, lsl #1]
; CHECK-NEXT: ret
%buckets = getelementptr i16, ptr %base, <vscale x 2 x i64> %indices
call void @llvm.experimental.vector.histogram.add.nxv2p0.i16(<vscale x 2 x ptr> %buckets, i16 %inc, <vscale x 2 x i1> %mask)
ret void
}

; Histogram update of i8 buckets addressed through <vscale x 2 x i64> indices.
; The expected assembly computes in 64-bit (.d) lanes and accesses the buckets
; with ld1b/st1b on .d elements, using unscaled indices.
define void @histogram_i8_2_lane(ptr %base, <vscale x 2 x i64> %indices, <vscale x 2 x i1> %mask, i8 %inc) #0 {
; CHECK-LABEL: histogram_i8_2_lane:
; CHECK: // %bb.0:
; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d
; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1
; CHECK-NEXT: mov z3.d, x1
; CHECK-NEXT: ld1b { z2.d }, p0/z, [x0, z0.d]
; CHECK-NEXT: ptrue p1.d
; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d
; CHECK-NEXT: st1b { z1.d }, p0, [x0, z0.d]
; CHECK-NEXT: ret
%buckets = getelementptr i8, ptr %base, <vscale x 2 x i64> %indices
call void @llvm.experimental.vector.histogram.add.nxv2p0.i8(<vscale x 2 x ptr> %buckets, i8 %inc, <vscale x 2 x i1> %mask)
ret void
}

; i16 buckets with a constant increment of 1. The expected assembly adds the
; histcnt result directly to the loaded bucket values; no multiply is emitted.
define void @histogram_i16_literal_1(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
; CHECK-LABEL: histogram_i16_literal_1:
; CHECK: // %bb.0:
; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
; CHECK-NEXT: add z1.s, z2.s, z1.s
; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
; CHECK-NEXT: ret
%buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 1, <vscale x 4 x i1> %mask)
ret void
}

; i16 buckets with a constant increment of 2. The expected assembly folds the
; multiply-by-2 and the add into a single adr with an lsl #1 shifted operand.
define void @histogram_i16_literal_2(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
; CHECK-LABEL: histogram_i16_literal_2:
; CHECK: // %bb.0:
; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
; CHECK-NEXT: adr z1.s, [z2.s, z1.s, lsl #1]
; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
; CHECK-NEXT: ret
%buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 2, <vscale x 4 x i1> %mask)
ret void
}

; i16 buckets with a constant increment of 3. The expected assembly
; materialises a splat of 3 (mov z3.s, #3) and combines it with the histcnt
; result and the loaded bucket values using a predicated mad.
define void @histogram_i16_literal_3(ptr %base, <vscale x 4 x i32> %indices, <vscale x 4 x i1> %mask) #0 {
; CHECK-LABEL: histogram_i16_literal_3:
; CHECK: // %bb.0:
; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s
; CHECK-NEXT: mov z3.s, #3 // =0x3
; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1]
; CHECK-NEXT: ptrue p1.s
; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s
; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1]
; CHECK-NEXT: ret
%buckets = getelementptr i16, ptr %base, <vscale x 4 x i32> %indices
call void @llvm.experimental.vector.histogram.add.nxv4p0.i16(<vscale x 4 x ptr> %buckets, i16 3, <vscale x 4 x i1> %mask)
ret void
}

attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }

0 comments on commit 670d208

Please sign in to comment.