diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp index 98ec2c7f529ecd..7777aa4b50a370 100644 --- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp +++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp @@ -1776,9 +1776,12 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM, setOperationAction(ISD::VECREDUCE_SEQ_FADD, VT, Custom); // Histcnt is SVE2 only - if (Subtarget->hasSVE2()) + if (Subtarget->hasSVE2()) { setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::Other, Custom); + setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i8, Custom); + setOperationAction(ISD::EXPERIMENTAL_VECTOR_HISTOGRAM, MVT::i16, Custom); + } } @@ -28175,11 +28178,18 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op, EVT IncVT = Inc.getValueType(); EVT IndexVT = Index.getValueType(); - EVT MemVT = EVT::getVectorVT(*DAG.getContext(), IncVT, - IndexVT.getVectorElementCount()); + LLVMContext &Ctx = *DAG.getContext(); + ElementCount EC = IndexVT.getVectorElementCount(); + EVT MemVT = EVT::getVectorVT(Ctx, IncVT, EC); + EVT IncExtVT = + EVT::getIntegerVT(Ctx, AArch64::SVEBitsPerBlock / EC.getKnownMinValue()); + EVT IncSplatVT = EVT::getVectorVT(Ctx, IncExtVT, EC); + bool ExtTrunc = IncSplatVT != MemVT; + SDValue Zero = DAG.getConstant(0, DL, MVT::i64); - SDValue PassThru = DAG.getSplatVector(MemVT, DL, Zero); - SDValue IncSplat = DAG.getSplatVector(MemVT, DL, Inc); + SDValue PassThru = DAG.getSplatVector(IncSplatVT, DL, Zero); + SDValue IncSplat = DAG.getSplatVector( + IncSplatVT, DL, DAG.getAnyExtOrTrunc(Inc, DL, IncExtVT)); SDValue Ops[] = {Chain, PassThru, Mask, Ptr, Index, Scale}; MachineMemOperand *MMO = HG->getMemOperand(); @@ -28188,18 +28198,19 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op, MMO->getPointerInfo(), MachineMemOperand::MOLoad, MMO->getSize(), MMO->getAlign(), MMO->getAAInfo()); ISD::MemIndexType IndexType = HG->getIndexType(); - SDValue Gather = - DAG.getMaskedGather(DAG.getVTList(MemVT, MVT::Other), MemVT, DL, Ops, - GMMO, IndexType, ISD::NON_EXTLOAD); + SDValue Gather = DAG.getMaskedGather( + DAG.getVTList(IncSplatVT, MVT::Other), MemVT, DL, Ops, GMMO, IndexType, + ExtTrunc ? ISD::EXTLOAD : ISD::NON_EXTLOAD); SDValue GChain = Gather.getValue(1); // Perform the histcnt, multiply by inc, add to bucket data. - SDValue ID = DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncVT); + SDValue ID = + DAG.getTargetConstant(Intrinsic::aarch64_sve_histcnt, DL, IncExtVT); SDValue HistCnt = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, IndexVT, ID, Mask, Index, Index); - SDValue Mul = DAG.getNode(ISD::MUL, DL, MemVT, HistCnt, IncSplat); - SDValue Add = DAG.getNode(ISD::ADD, DL, MemVT, Gather, Mul); + SDValue Mul = DAG.getNode(ISD::MUL, DL, IncSplatVT, HistCnt, IncSplat); + SDValue Add = DAG.getNode(ISD::ADD, DL, IncSplatVT, Gather, Mul); // Create an MMO for the scatter, without load|store flags. MachineMemOperand *SMMO = DAG.getMachineFunction().getMachineMemOperand( @@ -28208,7 +28219,7 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op, SDValue ScatterOps[] = {GChain, Add, Mask, Ptr, Index, Scale}; SDValue Scatter = DAG.getMaskedScatter(DAG.getVTList(MVT::Other), MemVT, DL, - ScatterOps, SMMO, IndexType, false); + ScatterOps, SMMO, IndexType, ExtTrunc); return Scatter; } diff --git a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll index db164e288abde3..2874e47511e12f 100644 --- a/llvm/test/CodeGen/AArch64/sve2-histcnt.ll +++ b/llvm/test/CodeGen/AArch64/sve2-histcnt.ll @@ -50,4 +50,123 @@ define void @histogram_i32_literal_noscale(ptr %base, %indice ret void } +define void @histogram_i32_promote(ptr %base, %indices, %mask, i32 %inc) #0 { +; CHECK-LABEL: histogram_i32_promote: +; CHECK: // %bb.0: +; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d +; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1 +; CHECK-NEXT: mov z3.d, x1 +; CHECK-NEXT: ld1w { z2.d }, p0/z, [x0, z0.d, lsl #2] +; CHECK-NEXT: ptrue p1.d +; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d +; CHECK-NEXT: st1w { z1.d }, p0, [x0, z0.d, lsl #2] +; CHECK-NEXT: ret + %buckets = getelementptr i32, ptr %base, %indices + call void @llvm.experimental.vector.histogram.add.nxv2p0.i32( %buckets, i32 %inc, %mask) + ret void +} + +define void @histogram_i16(ptr %base, %indices, %mask, i16 %inc) #0 { +; CHECK-LABEL: histogram_i16: +; CHECK: // %bb.0: +; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s +; CHECK-NEXT: mov z3.s, w1 +; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1] +; CHECK-NEXT: ptrue p1.s +; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s +; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1] +; CHECK-NEXT: ret + %buckets = getelementptr i16, ptr %base, %indices + call void @llvm.experimental.vector.histogram.add.nxv4p0.i16( %buckets, i16 %inc, %mask) + ret void +} + +define void @histogram_i8(ptr %base, %indices, %mask, i8 %inc) #0 { +; CHECK-LABEL: histogram_i8: +; CHECK: // %bb.0: +; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s +; CHECK-NEXT: mov z3.s, w1 +; CHECK-NEXT: ld1b { z2.s }, p0/z, [x0, z0.s, sxtw] +; CHECK-NEXT: ptrue p1.s +; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s +; CHECK-NEXT: st1b { z1.s }, p0, [x0, z0.s, sxtw] +; CHECK-NEXT: ret + %buckets = getelementptr i8, ptr %base, %indices + call void @llvm.experimental.vector.histogram.add.nxv4p0.i8( %buckets, i8 %inc, %mask) + ret void +} + +define void @histogram_i16_2_lane(ptr %base, %indices, %mask, i16 %inc) #0 { +; CHECK-LABEL: histogram_i16_2_lane: +; CHECK: // %bb.0: +; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d +; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1 +; CHECK-NEXT: mov z3.d, x1 +; CHECK-NEXT: ld1h { z2.d }, p0/z, [x0, z0.d, lsl #1] +; CHECK-NEXT: ptrue p1.d +; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d +; CHECK-NEXT: st1h { z1.d }, p0, [x0, z0.d, lsl #1] +; CHECK-NEXT: ret + %buckets = getelementptr i16, ptr %base, %indices + call void @llvm.experimental.vector.histogram.add.nxv2p0.i16( %buckets, i16 %inc, %mask) + ret void +} + +define void @histogram_i8_2_lane(ptr %base, %indices, %mask, i8 %inc) #0 { +; CHECK-LABEL: histogram_i8_2_lane: +; CHECK: // %bb.0: +; CHECK-NEXT: histcnt z1.d, p0/z, z0.d, z0.d +; CHECK-NEXT: // kill: def $w1 killed $w1 def $x1 +; CHECK-NEXT: mov z3.d, x1 +; CHECK-NEXT: ld1b { z2.d }, p0/z, [x0, z0.d] +; CHECK-NEXT: ptrue p1.d +; CHECK-NEXT: mad z1.d, p1/m, z3.d, z2.d +; CHECK-NEXT: st1b { z1.d }, p0, [x0, z0.d] +; CHECK-NEXT: ret + %buckets = getelementptr i8, ptr %base, %indices + call void @llvm.experimental.vector.histogram.add.nxv2p0.i8( %buckets, i8 %inc, %mask) + ret void +} + +define void @histogram_i16_literal_1(ptr %base, %indices, %mask) #0 { +; CHECK-LABEL: histogram_i16_literal_1: +; CHECK: // %bb.0: +; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s +; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1] +; CHECK-NEXT: add z1.s, z2.s, z1.s +; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1] +; CHECK-NEXT: ret + %buckets = getelementptr i16, ptr %base, %indices + call void @llvm.experimental.vector.histogram.add.nxv4p0.i16( %buckets, i16 1, %mask) + ret void +} + +define void @histogram_i16_literal_2(ptr %base, %indices, %mask) #0 { +; CHECK-LABEL: histogram_i16_literal_2: +; CHECK: // %bb.0: +; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s +; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1] +; CHECK-NEXT: adr z1.s, [z2.s, z1.s, lsl #1] +; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1] +; CHECK-NEXT: ret + %buckets = getelementptr i16, ptr %base, %indices + call void @llvm.experimental.vector.histogram.add.nxv4p0.i16( %buckets, i16 2, %mask) + ret void +} + +define void @histogram_i16_literal_3(ptr %base, %indices, %mask) #0 { +; CHECK-LABEL: histogram_i16_literal_3: +; CHECK: // %bb.0: +; CHECK-NEXT: histcnt z1.s, p0/z, z0.s, z0.s +; CHECK-NEXT: mov z3.s, #3 // =0x3 +; CHECK-NEXT: ld1h { z2.s }, p0/z, [x0, z0.s, sxtw #1] +; CHECK-NEXT: ptrue p1.s +; CHECK-NEXT: mad z1.s, p1/m, z3.s, z2.s +; CHECK-NEXT: st1h { z1.s }, p0, [x0, z0.s, sxtw #1] +; CHECK-NEXT: ret + %buckets = getelementptr i16, ptr %base, %indices + call void @llvm.experimental.vector.histogram.add.nxv4p0.i16( %buckets, i16 3, %mask) + ret void +} + attributes #0 = { "target-features"="+sve2" vscale_range(1, 16) }