Skip to content

Commit

Permalink
[AArch64] Push mul into extend operands
Browse files Browse the repository at this point in the history
In a similar way to how we push vector adds into extends, this pushed
'mul(zext,zext)' into 'zext(mul(zext,zext))' if the extend can be done
in two or more steps.

https://alive2.llvm.org/ce/z/WjU7Kr
  • Loading branch information
davemgreen committed Jun 10, 2024
1 parent fe0dee4 commit fbc6669
Show file tree
Hide file tree
Showing 5 changed files with 740 additions and 892 deletions.
80 changes: 44 additions & 36 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17720,6 +17720,47 @@ static SDValue performMulVectorCmpZeroCombine(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(AArch64ISD::NVCAST, DL, VT, CM);
}

// Transform vector add(zext i8 to i32, zext i8 to i32)
// into sext(add(zext(i8 to i16), zext(i8 to i16)) to i32)
// This allows extra uses of saddl/uaddl at the lower vector widths, and less
// extends.
static SDValue performVectorExtCombine(SDNode *N, SelectionDAG &DAG) {
EVT VT = N->getValueType(0);
if (!VT.isFixedLengthVector() || VT.getSizeInBits() <= 128 ||
(N->getOperand(0).getOpcode() != ISD::ZERO_EXTEND &&
N->getOperand(0).getOpcode() != ISD::SIGN_EXTEND) ||
(N->getOperand(1).getOpcode() != ISD::ZERO_EXTEND &&
N->getOperand(1).getOpcode() != ISD::SIGN_EXTEND) ||
N->getOperand(0).getOperand(0).getValueType() !=
N->getOperand(1).getOperand(0).getValueType())
return SDValue();

if (N->getOpcode() == ISD::MUL &&
N->getOperand(0).getOpcode() != N->getOperand(1).getOpcode())
return SDValue();

SDValue N0 = N->getOperand(0).getOperand(0);
SDValue N1 = N->getOperand(1).getOperand(0);
EVT InVT = N0.getValueType();

EVT S1 = InVT.getScalarType();
EVT S2 = VT.getScalarType();
if ((S2 == MVT::i32 && S1 == MVT::i8) ||
(S2 == MVT::i64 && (S1 == MVT::i8 || S1 == MVT::i16))) {
SDLoc DL(N);
EVT HalfVT = EVT::getVectorVT(*DAG.getContext(),
S2.getHalfSizedIntegerVT(*DAG.getContext()),
VT.getVectorElementCount());
SDValue NewN0 = DAG.getNode(N->getOperand(0).getOpcode(), DL, HalfVT, N0);
SDValue NewN1 = DAG.getNode(N->getOperand(1).getOpcode(), DL, HalfVT, N1);
SDValue NewOp = DAG.getNode(N->getOpcode(), DL, HalfVT, NewN0, NewN1);
return DAG.getNode(N->getOpcode() == ISD::MUL ? N->getOperand(0).getOpcode()
: (unsigned)ISD::SIGN_EXTEND,
DL, VT, NewOp);
}
return SDValue();
}

static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
Expand All @@ -17728,6 +17769,8 @@ static SDValue performMulCombine(SDNode *N, SelectionDAG &DAG,
return Ext;
if (SDValue Ext = performMulVectorCmpZeroCombine(N, DAG))
return Ext;
if (SDValue Ext = performVectorExtCombine(N, DAG))
return Ext;

if (DCI.isBeforeLegalizeOps())
return SDValue();
Expand Down Expand Up @@ -19604,41 +19647,6 @@ static SDValue foldADCToCINC(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(AArch64ISD::CSINC, DL, VT, LHS, LHS, CC, Cond);
}

// Transform vector add(zext i8 to i32, zext i8 to i32)
// into sext(add(zext(i8 to i16), zext(i8 to i16)) to i32)
// This allows extra uses of saddl/uaddl at the lower vector widths, and less
// extends.
static SDValue performVectorAddSubExtCombine(SDNode *N, SelectionDAG &DAG) {
EVT VT = N->getValueType(0);
if (!VT.isFixedLengthVector() || VT.getSizeInBits() <= 128 ||
(N->getOperand(0).getOpcode() != ISD::ZERO_EXTEND &&
N->getOperand(0).getOpcode() != ISD::SIGN_EXTEND) ||
(N->getOperand(1).getOpcode() != ISD::ZERO_EXTEND &&
N->getOperand(1).getOpcode() != ISD::SIGN_EXTEND) ||
N->getOperand(0).getOperand(0).getValueType() !=
N->getOperand(1).getOperand(0).getValueType())
return SDValue();

SDValue N0 = N->getOperand(0).getOperand(0);
SDValue N1 = N->getOperand(1).getOperand(0);
EVT InVT = N0.getValueType();

EVT S1 = InVT.getScalarType();
EVT S2 = VT.getScalarType();
if ((S2 == MVT::i32 && S1 == MVT::i8) ||
(S2 == MVT::i64 && (S1 == MVT::i8 || S1 == MVT::i16))) {
SDLoc DL(N);
EVT HalfVT = EVT::getVectorVT(*DAG.getContext(),
S2.getHalfSizedIntegerVT(*DAG.getContext()),
VT.getVectorElementCount());
SDValue NewN0 = DAG.getNode(N->getOperand(0).getOpcode(), DL, HalfVT, N0);
SDValue NewN1 = DAG.getNode(N->getOperand(1).getOpcode(), DL, HalfVT, N1);
SDValue NewOp = DAG.getNode(N->getOpcode(), DL, HalfVT, NewN0, NewN1);
return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, NewOp);
}
return SDValue();
}

static SDValue performBuildVectorCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
Expand Down Expand Up @@ -20260,7 +20268,7 @@ static SDValue performAddSubCombine(SDNode *N,
return Val;
if (SDValue Val = performNegCSelCombine(N, DCI.DAG))
return Val;
if (SDValue Val = performVectorAddSubExtCombine(N, DCI.DAG))
if (SDValue Val = performVectorExtCombine(N, DCI.DAG))
return Val;
if (SDValue Val = performAddCombineForShiftedOperands(N, DCI.DAG))
return Val;
Expand Down
108 changes: 41 additions & 67 deletions llvm/test/CodeGen/AArch64/aarch64-wide-mul.ll
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,12 @@ entry:
define <16 x i32> @mul_i32(<16 x i8> %a, <16 x i8> %b) {
; CHECK-SD-LABEL: mul_i32:
; CHECK-SD: // %bb.0: // %entry
; CHECK-SD-NEXT: ushll v2.8h, v0.8b, #0
; CHECK-SD-NEXT: ushll v4.8h, v1.8b, #0
; CHECK-SD-NEXT: ushll2 v5.8h, v0.16b, #0
; CHECK-SD-NEXT: ushll2 v6.8h, v1.16b, #0
; CHECK-SD-NEXT: umull v0.4s, v2.4h, v4.4h
; CHECK-SD-NEXT: umull2 v1.4s, v2.8h, v4.8h
; CHECK-SD-NEXT: umull2 v3.4s, v5.8h, v6.8h
; CHECK-SD-NEXT: umull v2.4s, v5.4h, v6.4h
; CHECK-SD-NEXT: umull v2.8h, v0.8b, v1.8b
; CHECK-SD-NEXT: umull2 v4.8h, v0.16b, v1.16b
; CHECK-SD-NEXT: ushll v0.4s, v2.4h, #0
; CHECK-SD-NEXT: ushll2 v3.4s, v4.8h, #0
; CHECK-SD-NEXT: ushll2 v1.4s, v2.8h, #0
; CHECK-SD-NEXT: ushll v2.4s, v4.4h, #0
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: mul_i32:
Expand All @@ -59,26 +57,20 @@ entry:
define <16 x i64> @mul_i64(<16 x i8> %a, <16 x i8> %b) {
; CHECK-SD-LABEL: mul_i64:
; CHECK-SD: // %bb.0: // %entry
; CHECK-SD-NEXT: ushll v2.8h, v0.8b, #0
; CHECK-SD-NEXT: ushll2 v0.8h, v0.16b, #0
; CHECK-SD-NEXT: ushll v3.8h, v1.8b, #0
; CHECK-SD-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-SD-NEXT: ushll v4.4s, v2.4h, #0
; CHECK-SD-NEXT: ushll v5.4s, v0.4h, #0
; CHECK-SD-NEXT: ushll v6.4s, v3.4h, #0
; CHECK-SD-NEXT: umull v2.8h, v0.8b, v1.8b
; CHECK-SD-NEXT: umull2 v0.8h, v0.16b, v1.16b
; CHECK-SD-NEXT: ushll v3.4s, v2.4h, #0
; CHECK-SD-NEXT: ushll2 v2.4s, v2.8h, #0
; CHECK-SD-NEXT: ushll v16.4s, v1.4h, #0
; CHECK-SD-NEXT: ushll2 v7.4s, v3.8h, #0
; CHECK-SD-NEXT: ushll2 v17.4s, v0.8h, #0
; CHECK-SD-NEXT: ushll2 v18.4s, v1.8h, #0
; CHECK-SD-NEXT: umull2 v1.2d, v4.4s, v6.4s
; CHECK-SD-NEXT: umull v0.2d, v4.2s, v6.2s
; CHECK-SD-NEXT: umull2 v3.2d, v2.4s, v7.4s
; CHECK-SD-NEXT: umull v2.2d, v2.2s, v7.2s
; CHECK-SD-NEXT: umull v4.2d, v5.2s, v16.2s
; CHECK-SD-NEXT: umull2 v7.2d, v17.4s, v18.4s
; CHECK-SD-NEXT: umull2 v5.2d, v5.4s, v16.4s
; CHECK-SD-NEXT: umull v6.2d, v17.2s, v18.2s
; CHECK-SD-NEXT: ushll v5.4s, v0.4h, #0
; CHECK-SD-NEXT: ushll2 v6.4s, v0.8h, #0
; CHECK-SD-NEXT: ushll2 v1.2d, v3.4s, #0
; CHECK-SD-NEXT: ushll v0.2d, v3.2s, #0
; CHECK-SD-NEXT: ushll2 v3.2d, v2.4s, #0
; CHECK-SD-NEXT: ushll v2.2d, v2.2s, #0
; CHECK-SD-NEXT: ushll v4.2d, v5.2s, #0
; CHECK-SD-NEXT: ushll2 v7.2d, v6.4s, #0
; CHECK-SD-NEXT: ushll2 v5.2d, v5.4s, #0
; CHECK-SD-NEXT: ushll v6.2d, v6.2s, #0
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: mul_i64:
Expand Down Expand Up @@ -139,17 +131,12 @@ entry:
define <16 x i32> @mla_i32(<16 x i8> %a, <16 x i8> %b, <16 x i32> %c) {
; CHECK-SD-LABEL: mla_i32:
; CHECK-SD: // %bb.0: // %entry
; CHECK-SD-NEXT: ushll v6.8h, v0.8b, #0
; CHECK-SD-NEXT: ushll v7.8h, v1.8b, #0
; CHECK-SD-NEXT: ushll2 v0.8h, v0.16b, #0
; CHECK-SD-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-SD-NEXT: umlal v2.4s, v6.4h, v7.4h
; CHECK-SD-NEXT: umlal2 v3.4s, v6.8h, v7.8h
; CHECK-SD-NEXT: umlal2 v5.4s, v0.8h, v1.8h
; CHECK-SD-NEXT: umlal v4.4s, v0.4h, v1.4h
; CHECK-SD-NEXT: mov v0.16b, v2.16b
; CHECK-SD-NEXT: mov v1.16b, v3.16b
; CHECK-SD-NEXT: mov v2.16b, v4.16b
; CHECK-SD-NEXT: umull2 v7.8h, v0.16b, v1.16b
; CHECK-SD-NEXT: umull v6.8h, v0.8b, v1.8b
; CHECK-SD-NEXT: uaddw2 v5.4s, v5.4s, v7.8h
; CHECK-SD-NEXT: uaddw v0.4s, v2.4s, v6.4h
; CHECK-SD-NEXT: uaddw2 v1.4s, v3.4s, v6.8h
; CHECK-SD-NEXT: uaddw v2.4s, v4.4s, v7.4h
; CHECK-SD-NEXT: mov v3.16b, v5.16b
; CHECK-SD-NEXT: ret
;
Expand Down Expand Up @@ -179,35 +166,22 @@ entry:
define <16 x i64> @mla_i64(<16 x i8> %a, <16 x i8> %b, <16 x i64> %c) {
; CHECK-SD-LABEL: mla_i64:
; CHECK-SD: // %bb.0: // %entry
; CHECK-SD-NEXT: mov v17.16b, v7.16b
; CHECK-SD-NEXT: mov v16.16b, v6.16b
; CHECK-SD-NEXT: ushll v6.8h, v0.8b, #0
; CHECK-SD-NEXT: ushll2 v0.8h, v0.16b, #0
; CHECK-SD-NEXT: ushll v7.8h, v1.8b, #0
; CHECK-SD-NEXT: ushll2 v1.8h, v1.16b, #0
; CHECK-SD-NEXT: ushll v18.4s, v6.4h, #0
; CHECK-SD-NEXT: ushll2 v21.4s, v6.8h, #0
; CHECK-SD-NEXT: ushll v19.4s, v0.4h, #0
; CHECK-SD-NEXT: ushll v20.4s, v7.4h, #0
; CHECK-SD-NEXT: ushll v22.4s, v1.4h, #0
; CHECK-SD-NEXT: ushll2 v23.4s, v7.8h, #0
; CHECK-SD-NEXT: ldp q6, q7, [sp]
; CHECK-SD-NEXT: ushll2 v0.4s, v0.8h, #0
; CHECK-SD-NEXT: ushll2 v1.4s, v1.8h, #0
; CHECK-SD-NEXT: umlal2 v3.2d, v18.4s, v20.4s
; CHECK-SD-NEXT: umlal v2.2d, v18.2s, v20.2s
; CHECK-SD-NEXT: umlal v16.2d, v19.2s, v22.2s
; CHECK-SD-NEXT: umlal2 v5.2d, v21.4s, v23.4s
; CHECK-SD-NEXT: umlal v4.2d, v21.2s, v23.2s
; CHECK-SD-NEXT: umlal2 v17.2d, v19.4s, v22.4s
; CHECK-SD-NEXT: umlal2 v7.2d, v0.4s, v1.4s
; CHECK-SD-NEXT: umlal v6.2d, v0.2s, v1.2s
; CHECK-SD-NEXT: mov v0.16b, v2.16b
; CHECK-SD-NEXT: mov v1.16b, v3.16b
; CHECK-SD-NEXT: mov v2.16b, v4.16b
; CHECK-SD-NEXT: mov v3.16b, v5.16b
; CHECK-SD-NEXT: mov v4.16b, v16.16b
; CHECK-SD-NEXT: mov v5.16b, v17.16b
; CHECK-SD-NEXT: umull v16.8h, v0.8b, v1.8b
; CHECK-SD-NEXT: umull2 v0.8h, v0.16b, v1.16b
; CHECK-SD-NEXT: ldp q20, q21, [sp]
; CHECK-SD-NEXT: ushll v17.4s, v16.4h, #0
; CHECK-SD-NEXT: ushll2 v16.4s, v16.8h, #0
; CHECK-SD-NEXT: ushll2 v19.4s, v0.8h, #0
; CHECK-SD-NEXT: ushll v18.4s, v0.4h, #0
; CHECK-SD-NEXT: uaddw2 v1.2d, v3.2d, v17.4s
; CHECK-SD-NEXT: uaddw v0.2d, v2.2d, v17.2s
; CHECK-SD-NEXT: uaddw2 v3.2d, v5.2d, v16.4s
; CHECK-SD-NEXT: uaddw v2.2d, v4.2d, v16.2s
; CHECK-SD-NEXT: uaddw2 v16.2d, v21.2d, v19.4s
; CHECK-SD-NEXT: uaddw v4.2d, v6.2d, v18.2s
; CHECK-SD-NEXT: uaddw2 v5.2d, v7.2d, v18.4s
; CHECK-SD-NEXT: uaddw v6.2d, v20.2d, v19.2s
; CHECK-SD-NEXT: mov v7.16b, v16.16b
; CHECK-SD-NEXT: ret
;
; CHECK-GI-LABEL: mla_i64:
Expand Down
Loading

0 comments on commit fbc6669

Please sign in to comment.