Skip to content

Commit

Permalink
[DAG] Add legalization handling for ABDS/ABDU (llvm#92576) (REAPPLIED)
Browse files Browse the repository at this point in the history
Always match ABD patterns pre-legalization, and use TargetLowering::expandABD to expand again during legalization.

abdu(lhs, rhs) -> sub(xor(sub(lhs, rhs), usub_overflow(lhs, rhs)), usub_overflow(lhs, rhs))
Alive2: https://alive2.llvm.org/ce/z/dVdMyv

REAPPLIED: Fix regression issue with "abs(ext(x) - ext(y)) -> zext(abd(x, y))" fold failing after type legalization
  • Loading branch information
RKSimon authored and kstoimenov committed Aug 15, 2024
1 parent 674623e commit e91a6a2
Show file tree
Hide file tree
Showing 29 changed files with 3,206 additions and 4,084 deletions.
16 changes: 11 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4091,13 +4091,13 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
}

// smax(a,b) - smin(a,b) --> abds(a,b)
if (hasOperation(ISD::ABDS, VT) &&
if ((!LegalOperations || hasOperation(ISD::ABDS, VT)) &&
sd_match(N0, m_SMax(m_Value(A), m_Value(B))) &&
sd_match(N1, m_SMin(m_Specific(A), m_Specific(B))))
return DAG.getNode(ISD::ABDS, DL, VT, A, B);

// umax(a,b) - umin(a,b) --> abdu(a,b)
if (hasOperation(ISD::ABDU, VT) &&
if ((!LegalOperations || hasOperation(ISD::ABDU, VT)) &&
sd_match(N0, m_UMax(m_Value(A), m_Value(B))) &&
sd_match(N1, m_UMin(m_Specific(A), m_Specific(B))))
return DAG.getNode(ISD::ABDU, DL, VT, A, B);
Expand Down Expand Up @@ -5263,6 +5263,10 @@ SDValue DAGCombiner::visitABD(SDNode *N) {
if (N0.isUndef() || N1.isUndef())
return DAG.getConstant(0, DL, VT);

// fold (abd x, x) -> 0
if (N0 == N1)
return DAG.getConstant(0, DL, VT);

SDValue X;

// fold (abds x, 0) -> abs x
Expand Down Expand Up @@ -10924,6 +10928,7 @@ SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
(Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND &&
Opc0 != ISD::SIGN_EXTEND_INREG)) {
// fold (abs (sub nsw x, y)) -> abds(x, y)
// Don't fold this for unsupported types as we lose the NSW handling.
if (AbsOp1->getFlags().hasNoSignedWrap() && hasOperation(ISD::ABDS, VT) &&
TLI.preferABDSToABSWithNSW(VT)) {
SDValue ABD = DAG.getNode(ISD::ABDS, DL, VT, Op0, Op1);
Expand All @@ -10946,7 +10951,8 @@ SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
// fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
EVT MaxVT = VT0.bitsGT(VT1) ? VT0 : VT1;
if ((VT0 == MaxVT || Op0->hasOneUse()) &&
(VT1 == MaxVT || Op1->hasOneUse()) && hasOperation(ABDOpcode, MaxVT)) {
(VT1 == MaxVT || Op1->hasOneUse()) &&
(!LegalTypes || hasOperation(ABDOpcode, MaxVT))) {
SDValue ABD = DAG.getNode(ABDOpcode, DL, MaxVT,
DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op0),
DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op1));
Expand All @@ -10956,7 +10962,7 @@ SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {

// fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
// fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
if (hasOperation(ABDOpcode, VT)) {
if (!LegalOperations || hasOperation(ABDOpcode, VT)) {
SDValue ABD = DAG.getNode(ABDOpcode, DL, VT, Op0, Op1);
return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
}
Expand Down Expand Up @@ -11580,7 +11586,7 @@ SDValue DAGCombiner::foldSelectToABD(SDValue LHS, SDValue RHS, SDValue True,
unsigned ABDOpc = IsSigned ? ISD::ABDS : ISD::ABDU;
EVT VT = LHS.getValueType();

if (!hasOperation(ABDOpc, VT))
if (LegalOperations && !hasOperation(ABDOpc, VT))
return SDValue();

switch (CC) {
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::VP_SUB:
case ISD::VP_MUL: Res = PromoteIntRes_SimpleIntBinOp(N); break;

case ISD::ABDS:
case ISD::AVGCEILS:
case ISD::AVGFLOORS:
case ISD::VP_SMIN:
Expand All @@ -201,6 +202,7 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::VP_SDIV:
case ISD::VP_SREM: Res = PromoteIntRes_SExtIntBinOp(N); break;

case ISD::ABDU:
case ISD::AVGCEILU:
case ISD::AVGFLOORU:
case ISD::VP_UMIN:
Expand Down Expand Up @@ -2791,6 +2793,8 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::PARITY: ExpandIntRes_PARITY(N, Lo, Hi); break;
case ISD::Constant: ExpandIntRes_Constant(N, Lo, Hi); break;
case ISD::ABS: ExpandIntRes_ABS(N, Lo, Hi); break;
case ISD::ABDS:
case ISD::ABDU: ExpandIntRes_ABD(N, Lo, Hi); break;
case ISD::CTLZ_ZERO_UNDEF:
case ISD::CTLZ: ExpandIntRes_CTLZ(N, Lo, Hi); break;
case ISD::CTPOP: ExpandIntRes_CTPOP(N, Lo, Hi); break;
Expand Down Expand Up @@ -3850,6 +3854,11 @@ void DAGTypeLegalizer::ExpandIntRes_CTLZ(SDNode *N,
Hi = DAG.getConstant(0, dl, NVT);
}

void DAGTypeLegalizer::ExpandIntRes_ABD(SDNode *N, SDValue &Lo, SDValue &Hi) {
SDValue Result = TLI.expandABD(N, DAG);
SplitInteger(Result, Lo, Hi);
}

void DAGTypeLegalizer::ExpandIntRes_CTPOP(SDNode *N,
SDValue &Lo, SDValue &Hi) {
SDLoc dl(N);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void ExpandIntRes_AssertZext (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_Constant (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_ABS (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_ABD (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_CTLZ (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_CTPOP (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_CTTZ (SDNode *N, SDValue &Lo, SDValue &Hi);
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
case ISD::FMINIMUM:
case ISD::FMAXIMUM:
case ISD::FLDEXP:
case ISD::ABDS:
case ISD::ABDU:
case ISD::SMIN:
case ISD::SMAX:
case ISD::UMIN:
Expand Down Expand Up @@ -1233,6 +1235,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::MUL: case ISD::VP_MUL:
case ISD::MULHS:
case ISD::MULHU:
case ISD::ABDS:
case ISD::ABDU:
case ISD::AVGCEILS:
case ISD::AVGCEILU:
case ISD::AVGFLOORS:
Expand Down Expand Up @@ -4368,6 +4372,8 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
case ISD::MUL: case ISD::VP_MUL:
case ISD::MULHS:
case ISD::MULHU:
case ISD::ABDS:
case ISD::ABDU:
case ISD::OR: case ISD::VP_OR:
case ISD::SUB: case ISD::VP_SUB:
case ISD::XOR: case ISD::VP_XOR:
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7024,6 +7024,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
assert(VT.isInteger() && "This operator does not apply to FP types!");
assert(N1.getValueType() == N2.getValueType() &&
N1.getValueType() == VT && "Binary operator types must match!");
if (VT.isVector() && VT.getVectorElementType() == MVT::i1)
return getNode(ISD::XOR, DL, VT, N1, N2);
break;
case ISD::SMIN:
case ISD::UMAX:
Expand Down
32 changes: 32 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9311,6 +9311,21 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
DAG.getNode(ISD::USUBSAT, dl, VT, LHS, RHS),
DAG.getNode(ISD::USUBSAT, dl, VT, RHS, LHS));

// If the subtract doesn't overflow then just use abs(sub())
// NOTE: don't use frozen operands for value tracking.
bool IsNonNegative = DAG.SignBitIsZero(N->getOperand(1)) &&
DAG.SignBitIsZero(N->getOperand(0));

if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, N->getOperand(0),
N->getOperand(1)))
return DAG.getNode(ISD::ABS, dl, VT,
DAG.getNode(ISD::SUB, dl, VT, LHS, RHS));

if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, N->getOperand(1),
N->getOperand(0)))
return DAG.getNode(ISD::ABS, dl, VT,
DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));

EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
ISD::CondCode CC = IsSigned ? ISD::CondCode::SETGT : ISD::CondCode::SETUGT;
SDValue Cmp = DAG.getSetCC(dl, CCVT, LHS, RHS, CC);
Expand All @@ -9324,6 +9339,23 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
return DAG.getNode(ISD::SUB, dl, VT, Cmp, Xor);
}

// Similar to the branchless expansion, use the (sign-extended) usubo overflow
// flag if the (scalar) type is illegal as this is more likely to legalize
// cleanly:
// abdu(lhs, rhs) -> sub(xor(sub(lhs, rhs), uof(lhs, rhs)), uof(lhs, rhs))
if (!IsSigned && VT.isScalarInteger() && !isTypeLegal(VT)) {
SDValue USubO =
DAG.getNode(ISD::USUBO, dl, DAG.getVTList(VT, MVT::i1), {LHS, RHS});
SDValue Cmp = DAG.getNode(ISD::SIGN_EXTEND, dl, VT, USubO.getValue(1));
SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, USubO.getValue(0), Cmp);
return DAG.getNode(ISD::SUB, dl, VT, Xor, Cmp);
}

// FIXME: Should really try to split the vector in case it's legal on a
// subvector.
if (VT.isVector() && !isOperationLegalOrCustom(ISD::VSELECT, VT))
return DAG.UnrollVectorOp(N);

// abds(lhs, rhs) -> select(sgt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
// abdu(lhs, rhs) -> select(ugt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
return DAG.getSelect(dl, VT, Cmp, DAG.getNode(ISD::SUB, dl, VT, LHS, RHS),
Expand Down
Loading

0 comments on commit e91a6a2

Please sign in to comment.