[RISCV] Add a combine to form masked.load from unit strided load (llvm#65674)

Add a DAG combine to form a masked.load from a masked_strided_load
intrinsic whose stride equals the element size. This covers a couple of
extra test cases, and lets us simplify and common up some existing code
in the concat_vector(load, ...) to strided-load transform.

This is the first in a mini-patch series that tries to generalize our
strided-load and gather matching to handle more cases, and to common up
the different approaches to the same problems in different places.
preames authored and ZijunZhaoCCK committed Sep 19, 2023
1 parent 45ba533 commit 84f9d7f
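For context, the shape of input this combine targets is a masked strided load whose constant stride equals the element size in bytes, as exercised by the strided-load-store-intrinsics.ll test below. A minimal IR sketch (the function name is illustrative, mirroring that test):

; The stride of 1 byte matches the i8 element size, so the load is contiguous
; and can be emitted as a unit-strided (masked) load instead of a strided one.
declare <32 x i8> @llvm.riscv.masked.strided.load.v32i8.p0.i64(<32 x i8>, ptr, i64, <32 x i1>)

define <32 x i8> @unit_stride_masked_load(ptr %p, <32 x i1> %m) {
  %res = call <32 x i8> @llvm.riscv.masked.strided.load.v32i8.p0.i64(<32 x i8> undef, ptr %p, i64 1, <32 x i1> %m)
  ret <32 x i8> %res
}

With this combine the strided intrinsic is rewritten to a masked.load, which then selects to a masked vle8.v rather than vlse8.v, as the updated CHECK lines below show.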
Showing 3 changed files with 32 additions and 39 deletions.
62 changes: 29 additions & 33 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -13371,27 +13371,6 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
       return SDValue();
   }
 
-  // A special case is if the stride is exactly the width of one of the loads,
-  // in which case it's contiguous and can be combined into a regular vle
-  // without changing the element size
-  if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
-      ConstStride && !Reversed &&
-      ConstStride->getZExtValue() == BaseLdVT.getFixedSizeInBits() / 8) {
-    MachineMemOperand *MMO = DAG.getMachineFunction().getMachineMemOperand(
-        BaseLd->getPointerInfo(), BaseLd->getMemOperand()->getFlags(),
-        VT.getStoreSize(), Align);
-    // Can't do the combine if the load isn't naturally aligned with the element
-    // type
-    if (!TLI.allowsMemoryAccessForAlignment(*DAG.getContext(),
-                                            DAG.getDataLayout(), VT, *MMO))
-      return SDValue();
-
-    SDValue WideLoad = DAG.getLoad(VT, DL, BaseLd->getChain(), BasePtr, MMO);
-    for (SDValue Ld : N->ops())
-      DAG.makeEquivalentMemoryOrdering(cast<LoadSDNode>(Ld), WideLoad);
-    return WideLoad;
-  }
-
   // Get the widened scalar type, e.g. v4i8 -> i64
   unsigned WideScalarBitWidth =
       BaseLdVT.getScalarSizeInBits() * BaseLdVT.getVectorNumElements();
@@ -13406,20 +13385,22 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   if (!TLI.isLegalStridedLoadStore(WideVecVT, Align))
     return SDValue();
 
-  MVT ContainerVT = TLI.getContainerForFixedLengthVector(WideVecVT);
-  SDValue VL =
-      getDefaultVLOps(WideVecVT, ContainerVT, DL, DAG, Subtarget).second;
-  SDVTList VTs = DAG.getVTList({ContainerVT, MVT::Other});
+  SDVTList VTs = DAG.getVTList({WideVecVT, MVT::Other});
   SDValue IntID =
-      DAG.getTargetConstant(Intrinsic::riscv_vlse, DL, Subtarget.getXLenVT());
+      DAG.getTargetConstant(Intrinsic::riscv_masked_strided_load, DL,
+                            Subtarget.getXLenVT());
   if (Reversed)
     Stride = DAG.getNegative(Stride, DL, Stride->getValueType(0));
+  SDValue AllOneMask =
+      DAG.getSplat(WideVecVT.changeVectorElementType(MVT::i1), DL,
+                   DAG.getConstant(1, DL, MVT::i1));
 
   SDValue Ops[] = {BaseLd->getChain(),
                    IntID,
-                   DAG.getUNDEF(ContainerVT),
+                   DAG.getUNDEF(WideVecVT),
                    BasePtr,
                    Stride,
-                   VL};
+                   AllOneMask};
 
   uint64_t MemSize;
   if (auto *ConstStride = dyn_cast<ConstantSDNode>(Stride);
@@ -13441,11 +13422,7 @@ static SDValue performCONCAT_VECTORSCombine(SDNode *N, SelectionDAG &DAG,
   for (SDValue Ld : N->ops())
     DAG.makeEquivalentMemoryOrdering(cast<LoadSDNode>(Ld), StridedLoad);
 
-  // Note: Perform the bitcast before the convertFromScalableVector so we have
-  // balanced pairs of convertFromScalable/convertToScalable
-  SDValue Res = DAG.getBitcast(
-      TLI.getContainerForFixedLengthVector(VT.getSimpleVT()), StridedLoad);
-  return convertFromScalableVector(VT, Res, DAG, Subtarget);
+  return DAG.getBitcast(VT.getSimpleVT(), StridedLoad);
 }
 
 static SDValue combineToVWMACC(SDNode *N, SelectionDAG &DAG,
@@ -14184,6 +14161,25 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
       // By default we do not combine any intrinsic.
     default:
       return SDValue();
+    case Intrinsic::riscv_masked_strided_load: {
+      MVT VT = N->getSimpleValueType(0);
+      auto *Load = cast<MemIntrinsicSDNode>(N);
+      SDValue PassThru = N->getOperand(2);
+      SDValue Base = N->getOperand(3);
+      SDValue Stride = N->getOperand(4);
+      SDValue Mask = N->getOperand(5);
+
+      // If the stride is equal to the element size in bytes, we can use
+      // a masked.load.
+      const unsigned ElementSize = VT.getScalarStoreSize();
+      if (auto *StrideC = dyn_cast<ConstantSDNode>(Stride);
+          StrideC && StrideC->getZExtValue() == ElementSize)
+        return DAG.getMaskedLoad(VT, DL, Load->getChain(), Base,
+                                 DAG.getUNDEF(XLenVT), Mask, PassThru,
+                                 Load->getMemoryVT(), Load->getMemOperand(),
+                                 ISD::UNINDEXED, ISD::NON_EXTLOAD);
+      return SDValue();
+    }
     case Intrinsic::riscv_vcpop:
     case Intrinsic::riscv_vcpop_mask:
     case Intrinsic::riscv_vfirst:
6 changes: 2 additions & 4 deletions llvm/test/CodeGen/RISCV/rvv/fixed-vectors-masked-gather.ll
@@ -13010,9 +13010,8 @@ define <4 x i32> @mgather_broadcast_load_masked(ptr %base, <4 x i1> %m) {
 define <4 x i32> @mgather_unit_stride_load(ptr %base) {
 ; RV32-LABEL: mgather_unit_stride_load:
 ; RV32:       # %bb.0:
-; RV32-NEXT:    li a1, 4
 ; RV32-NEXT:    vsetivli zero, 4, e32, m1, ta, ma
-; RV32-NEXT:    vlse32.v v8, (a0), a1
+; RV32-NEXT:    vle32.v v8, (a0)
 ; RV32-NEXT:    ret
 ;
 ; RV64V-LABEL: mgather_unit_stride_load:
Expand Down Expand Up @@ -13082,9 +13081,8 @@ define <4 x i32> @mgather_unit_stride_load_with_offset(ptr %base) {
; RV32-LABEL: mgather_unit_stride_load_with_offset:
; RV32: # %bb.0:
; RV32-NEXT: addi a0, a0, 16
; RV32-NEXT: li a1, 4
; RV32-NEXT: vsetivli zero, 4, e32, m1, ta, ma
; RV32-NEXT: vlse32.v v8, (a0), a1
; RV32-NEXT: vle32.v v8, (a0)
; RV32-NEXT: ret
;
; RV64V-LABEL: mgather_unit_stride_load_with_offset:
3 changes: 1 addition & 2 deletions llvm/test/CodeGen/RISCV/rvv/strided-load-store-intrinsics.ll
@@ -55,9 +55,8 @@ define <32 x i8> @strided_load_i8_nostride(ptr %p, <32 x i1> %m) {
 ; CHECK-LABEL: strided_load_i8_nostride:
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    li a1, 32
-; CHECK-NEXT:    li a2, 1
 ; CHECK-NEXT:    vsetvli zero, a1, e8, m2, ta, ma
-; CHECK-NEXT:    vlse8.v v8, (a0), a2, v0.t
+; CHECK-NEXT:    vle8.v v8, (a0), v0.t
 ; CHECK-NEXT:    ret
   %res = call <32 x i8> @llvm.riscv.masked.strided.load.v32i8.p0.i64(<32 x i8> undef, ptr %p, i64 1, <32 x i1> %m)
   ret <32 x i8> %res
