Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DAG] Add legalization handling for ABDS/ABDU #92576

Merged
merged 4 commits into from
Aug 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4089,13 +4089,13 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
}

// smax(a,b) - smin(a,b) --> abds(a,b)
if (hasOperation(ISD::ABDS, VT) &&
if ((!LegalOperations || hasOperation(ISD::ABDS, VT)) &&
sd_match(N0, m_SMax(m_Value(A), m_Value(B))) &&
sd_match(N1, m_SMin(m_Specific(A), m_Specific(B))))
return DAG.getNode(ISD::ABDS, DL, VT, A, B);

// umax(a,b) - umin(a,b) --> abdu(a,b)
if (hasOperation(ISD::ABDU, VT) &&
if ((!LegalOperations || hasOperation(ISD::ABDU, VT)) &&
sd_match(N0, m_UMax(m_Value(A), m_Value(B))) &&
sd_match(N1, m_UMin(m_Specific(A), m_Specific(B))))
return DAG.getNode(ISD::ABDU, DL, VT, A, B);
Expand Down Expand Up @@ -10922,6 +10922,7 @@ SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
(Opc0 != ISD::ZERO_EXTEND && Opc0 != ISD::SIGN_EXTEND &&
Opc0 != ISD::SIGN_EXTEND_INREG)) {
// fold (abs (sub nsw x, y)) -> abds(x, y)
// Don't fold this for unsupported types as we lose the NSW handling.
if (AbsOp1->getFlags().hasNoSignedWrap() && hasOperation(ISD::ABDS, VT) &&
TLI.preferABDSToABSWithNSW(VT)) {
SDValue ABD = DAG.getNode(ISD::ABDS, DL, VT, Op0, Op1);
Expand All @@ -10944,7 +10945,8 @@ SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {
// fold abs(zext(x) - zext(y)) -> zext(abdu(x, y))
EVT MaxVT = VT0.bitsGT(VT1) ? VT0 : VT1;
if ((VT0 == MaxVT || Op0->hasOneUse()) &&
(VT1 == MaxVT || Op1->hasOneUse()) && hasOperation(ABDOpcode, MaxVT)) {
(VT1 == MaxVT || Op1->hasOneUse()) &&
(!LegalOperations || hasOperation(ABDOpcode, MaxVT))) {
SDValue ABD = DAG.getNode(ABDOpcode, DL, MaxVT,
DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op0),
DAG.getNode(ISD::TRUNCATE, DL, MaxVT, Op1));
Expand All @@ -10954,7 +10956,7 @@ SDValue DAGCombiner::foldABSToABD(SDNode *N, const SDLoc &DL) {

// fold abs(sext(x) - sext(y)) -> abds(sext(x), sext(y))
// fold abs(zext(x) - zext(y)) -> abdu(zext(x), zext(y))
if (hasOperation(ABDOpcode, VT)) {
if (!LegalOperations || hasOperation(ABDOpcode, VT)) {
SDValue ABD = DAG.getNode(ABDOpcode, DL, VT, Op0, Op1);
return DAG.getZExtOrTrunc(ABD, DL, SrcVT);
}
Expand Down Expand Up @@ -12376,7 +12378,7 @@ SDValue DAGCombiner::visitVSELECT(SDNode *N) {
N1.getOperand(1) == N2.getOperand(0)) {
bool IsSigned = isSignedIntSetCC(CC);
unsigned ABDOpc = IsSigned ? ISD::ABDS : ISD::ABDU;
if (hasOperation(ABDOpc, VT)) {
if (!LegalOperations || hasOperation(ABDOpc, VT)) {
switch (CC) {
case ISD::SETGT:
case ISD::SETGE:
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,17 +192,21 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::VP_SUB:
case ISD::VP_MUL: Res = PromoteIntRes_SimpleIntBinOp(N); break;

case ISD::ABDS:
case ISD::AVGCEILS:
case ISD::AVGFLOORS:

case ISD::VP_SMIN:
case ISD::VP_SMAX:
case ISD::SDIV:
case ISD::SREM:
case ISD::VP_SDIV:
case ISD::VP_SREM: Res = PromoteIntRes_SExtIntBinOp(N); break;

case ISD::ABDU:
case ISD::AVGCEILU:
case ISD::AVGFLOORU:

case ISD::VP_UMIN:
case ISD::VP_UMAX:
case ISD::UDIV:
Expand Down Expand Up @@ -2791,6 +2795,8 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::PARITY: ExpandIntRes_PARITY(N, Lo, Hi); break;
case ISD::Constant: ExpandIntRes_Constant(N, Lo, Hi); break;
case ISD::ABS: ExpandIntRes_ABS(N, Lo, Hi); break;
case ISD::ABDS:
case ISD::ABDU: ExpandIntRes_ABD(N, Lo, Hi); break;
case ISD::CTLZ_ZERO_UNDEF:
case ISD::CTLZ: ExpandIntRes_CTLZ(N, Lo, Hi); break;
case ISD::CTPOP: ExpandIntRes_CTPOP(N, Lo, Hi); break;
Expand Down Expand Up @@ -3850,6 +3856,11 @@ void DAGTypeLegalizer::ExpandIntRes_CTLZ(SDNode *N,
Hi = DAG.getConstant(0, dl, NVT);
}

void DAGTypeLegalizer::ExpandIntRes_ABD(SDNode *N, SDValue &Lo, SDValue &Hi) {
SDValue Result = TLI.expandABD(N, DAG);
SplitInteger(Result, Lo, Hi);
}

void DAGTypeLegalizer::ExpandIntRes_CTPOP(SDNode *N,
SDValue &Lo, SDValue &Hi) {
SDLoc dl(N);
Expand Down
1 change: 1 addition & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void ExpandIntRes_AssertZext (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_Constant (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_ABS (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_ABD (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_CTLZ (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_CTPOP (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_CTTZ (SDNode *N, SDValue &Lo, SDValue &Hi);
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
case ISD::FMINIMUM:
case ISD::FMAXIMUM:
case ISD::FLDEXP:
case ISD::ABDS:
case ISD::ABDU:
case ISD::SMIN:
case ISD::SMAX:
case ISD::UMIN:
Expand Down Expand Up @@ -1233,6 +1235,8 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::MUL: case ISD::VP_MUL:
case ISD::MULHS:
case ISD::MULHU:
case ISD::ABDS:
case ISD::ABDU:
case ISD::AVGCEILS:
case ISD::AVGCEILU:
case ISD::AVGFLOORS:
Expand Down Expand Up @@ -4368,6 +4372,8 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
case ISD::MUL: case ISD::VP_MUL:
case ISD::MULHS:
case ISD::MULHU:
case ISD::ABDS:
case ISD::ABDU:
case ISD::OR: case ISD::VP_OR:
case ISD::SUB: case ISD::VP_SUB:
case ISD::XOR: case ISD::VP_XOR:
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7024,6 +7024,8 @@ SDValue SelectionDAG::getNode(unsigned Opcode, const SDLoc &DL, EVT VT,
assert(VT.isInteger() && "This operator does not apply to FP types!");
assert(N1.getValueType() == N2.getValueType() &&
N1.getValueType() == VT && "Binary operator types must match!");
if (VT.isVector() && VT.getVectorElementType() == MVT::i1)
return getNode(ISD::XOR, DL, VT, N1, N2);
break;
case ISD::SMIN:
case ISD::UMAX:
Expand Down
32 changes: 32 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9311,6 +9311,21 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
DAG.getNode(ISD::USUBSAT, dl, VT, LHS, RHS),
DAG.getNode(ISD::USUBSAT, dl, VT, RHS, LHS));

// If the subtract doesn't overflow then just use abs(sub())
// NOTE: don't use frozen operands for value tracking.
bool IsNonNegative = DAG.SignBitIsZero(N->getOperand(1)) &&
DAG.SignBitIsZero(N->getOperand(0));

if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, N->getOperand(0),
N->getOperand(1)))
return DAG.getNode(ISD::ABS, dl, VT,
DAG.getNode(ISD::SUB, dl, VT, LHS, RHS));

if (DAG.willNotOverflowSub(IsSigned || IsNonNegative, N->getOperand(1),
N->getOperand(0)))
return DAG.getNode(ISD::ABS, dl, VT,
DAG.getNode(ISD::SUB, dl, VT, RHS, LHS));

EVT CCVT = getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), VT);
ISD::CondCode CC = IsSigned ? ISD::CondCode::SETGT : ISD::CondCode::SETUGT;
SDValue Cmp = DAG.getSetCC(dl, CCVT, LHS, RHS, CC);
Expand All @@ -9324,6 +9339,23 @@ SDValue TargetLowering::expandABD(SDNode *N, SelectionDAG &DAG) const {
return DAG.getNode(ISD::SUB, dl, VT, Cmp, Xor);
}

// Similar to the branchless expansion, use the (sign-extended) usubo overflow
// flag if the (scalar) type is illegal as this is more likely to legalize
// cleanly:
// abdu(lhs, rhs) -> sub(xor(sub(lhs, rhs), uof(lhs, rhs)), uof(lhs, rhs))
if (!IsSigned && VT.isScalarInteger() && !isTypeLegal(VT)) {
SDValue USubO =
DAG.getNode(ISD::USUBO, dl, DAG.getVTList(VT, MVT::i1), {LHS, RHS});
SDValue Cmp = DAG.getNode(ISD::SIGN_EXTEND, dl, VT, USubO.getValue(1));
SDValue Xor = DAG.getNode(ISD::XOR, dl, VT, USubO.getValue(0), Cmp);
return DAG.getNode(ISD::SUB, dl, VT, Xor, Cmp);
}

// FIXME: Should really try to split the vector in case it's legal on a
// subvector.
if (VT.isVector() && !isOperationLegalOrCustom(ISD::VSELECT, VT))
return DAG.UnrollVectorOp(N);

// abds(lhs, rhs) -> select(sgt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
// abdu(lhs, rhs) -> select(ugt(lhs,rhs), sub(lhs,rhs), sub(rhs,lhs))
return DAG.getSelect(dl, VT, Cmp, DAG.getNode(ISD::SUB, dl, VT, LHS, RHS),
Expand Down
Loading
Loading