Skip to content

Commit

Permalink
[SelectionDAG] Better legalization for FSHL and FSHR
Browse files Browse the repository at this point in the history
In SelectionDAGBuilder always translate the fshl and fshr intrinsics to
FSHL and FSHR (or ROTL and ROTR) instead of lowering them to shifts and
ORs. Improve the legalization of FSHL and FSHR to avoid code quality
regressions.

Differential Revision: https://reviews.llvm.org/D77152
  • Loading branch information
jayfoad committed Aug 21, 2020
1 parent c6863a4 commit 0819a64
Show file tree
Hide file tree
Showing 32 changed files with 8,323 additions and 9,164 deletions.
73 changes: 73 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,16 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::FREEZE:
Res = PromoteIntRes_FREEZE(N);
break;

case ISD::ROTL:
case ISD::ROTR:
Res = PromoteIntRes_Rotate(N);
break;

case ISD::FSHL:
case ISD::FSHR:
Res = PromoteIntRes_FunnelShift(N);
break;
}

// If the result is null then the sub-method took care of registering it.
Expand Down Expand Up @@ -1105,6 +1115,43 @@ SDValue DAGTypeLegalizer::PromoteIntRes_SRL(SDNode *N) {
return DAG.getNode(ISD::SRL, SDLoc(N), LHS.getValueType(), LHS, RHS);
}

// Promote the result of an ISD::ROTL / ISD::ROTR node.
//
// The rotate itself is not re-issued in the promoted type. Instead
// TLI.expandROT is asked for a replacement value and result 0 of N is
// replaced with it. A null SDValue is returned because the replacement has
// already been registered here.
SDValue DAGTypeLegalizer::PromoteIntRes_Rotate(SDNode *N) {
  // Rewrite the rotate as shifts and ORs, which can be promoted.
  // NOTE(review): whatever status expandROT reports is not inspected, so
  // Expanded is used as-is -- confirm the expansion cannot fail on this path.
  SDValue Expanded;
  TLI.expandROT(N, Expanded, DAG);

  ReplaceValueWith(SDValue(N, 0), Expanded);
  return SDValue();
}

// Promote the result of an ISD::FSHL / ISD::FSHR node by re-issuing the same
// funnel shift in the promoted type.
//
// Lo is shifted up by (NewBits - OldBits) so that it occupies the upper bits
// of the promoted type, and the shift amount is reduced modulo the old bit
// width. For FSHR the amount is additionally increased by (NewBits - OldBits)
// so that the result lands in the lower bits of the promoted type.
//
// Returns the new funnel-shift node, built in the promoted type.
SDValue DAGTypeLegalizer::PromoteIntRes_FunnelShift(SDNode *N) {
  SDValue Hi = GetPromotedInteger(N->getOperand(0));
  SDValue Lo = GetPromotedInteger(N->getOperand(1));
  // The amount feeds a UREM by the old bit width below. GetPromotedInteger
  // makes no promise about the bits above the original width, and those bits
  // would change the remainder whenever OldBits is not a power of two (e.g.
  // an i7 amount promoted to i8: 128 % 7 != 0). Zero-extend the amount so the
  // remainder is computed from the original value only.
  SDValue Amount = ZExtPromotedInteger(N->getOperand(2));

  unsigned OldBits = N->getOperand(0).getScalarValueSizeInBits();
  unsigned NewBits = Hi.getScalarValueSizeInBits();

  SDLoc DL(N);
  EVT VT = Lo.getValueType();

  // Number of bits gained by the promotion; used both to reposition Lo and to
  // bias the FSHR amount.
  SDValue ExtraBits = DAG.getConstant(NewBits - OldBits, DL, VT);

  // Shift Lo up to occupy the upper bits of the promoted type.
  Lo = DAG.getNode(ISD::SHL, DL, VT, Lo, ExtraBits);

  // Amount has to be interpreted modulo the old bit width.
  Amount =
      DAG.getNode(ISD::UREM, DL, VT, Amount, DAG.getConstant(OldBits, DL, VT));

  unsigned Opcode = N->getOpcode();
  if (Opcode == ISD::FSHR) {
    // Increase Amount to shift the result into the lower bits of the promoted
    // type.
    Amount = DAG.getNode(ISD::ADD, DL, VT, Amount, ExtraBits);
  }

  return DAG.getNode(Opcode, DL, VT, Hi, Lo, Amount);
}

SDValue DAGTypeLegalizer::PromoteIntRes_TRUNCATE(SDNode *N) {
EVT NVT = TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0));
SDValue Res;
Expand Down Expand Up @@ -2059,6 +2106,16 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
case ISD::VECREDUCE_SMIN:
case ISD::VECREDUCE_UMAX:
case ISD::VECREDUCE_UMIN: ExpandIntRes_VECREDUCE(N, Lo, Hi); break;

case ISD::ROTL:
case ISD::ROTR:
ExpandIntRes_Rotate(N, Lo, Hi);
break;

case ISD::FSHL:
case ISD::FSHR:
ExpandIntRes_FunnelShift(N, Lo, Hi);
break;
}

// If Lo/Hi is null, the sub-method took care of registering results etc.
Expand Down Expand Up @@ -3895,6 +3952,22 @@ void DAGTypeLegalizer::ExpandIntRes_VECREDUCE(SDNode *N,
SplitInteger(Res, Lo, Hi);
}

// Expand the result of an ISD::ROTL / ISD::ROTR node.
//
// TLI.expandROT is asked for a replacement value, which is then split into
// the Lo and Hi output halves with SplitInteger.
void DAGTypeLegalizer::ExpandIntRes_Rotate(SDNode *N,
                                           SDValue &Lo, SDValue &Hi) {
  // Rewrite the rotate as shifts and ORs, which can be expanded.
  // NOTE(review): whatever status expandROT reports is not inspected --
  // confirm the expansion cannot fail on this path.
  SDValue Lowered;
  TLI.expandROT(N, Lowered, DAG);

  SplitInteger(Lowered, Lo, Hi);
}

// Expand the result of an ISD::FSHL / ISD::FSHR node.
//
// TLI.expandFunnelShift is asked for a replacement value, which is then split
// into the Lo and Hi output halves with SplitInteger.
void DAGTypeLegalizer::ExpandIntRes_FunnelShift(SDNode *N,
                                                SDValue &Lo, SDValue &Hi) {
  // Rewrite the funnel shift as shifts and ORs, which can be expanded.
  // NOTE(review): the result of expandFunnelShift is not inspected -- confirm
  // the expansion cannot fail on this path.
  SDValue Lowered;
  TLI.expandFunnelShift(N, Lowered, DAG);

  SplitInteger(Lowered, Lo, Hi);
}

//===----------------------------------------------------------------------===//
// Integer Operand Expansion
//===----------------------------------------------------------------------===//
Expand Down
5 changes: 5 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,8 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
SDValue PromoteIntRes_FLT_ROUNDS(SDNode *N);
SDValue PromoteIntRes_VECREDUCE(SDNode *N);
SDValue PromoteIntRes_ABS(SDNode *N);
SDValue PromoteIntRes_Rotate(SDNode *N);
SDValue PromoteIntRes_FunnelShift(SDNode *N);

// Integer Operand Promotion.
bool PromoteIntegerOperand(SDNode *N, unsigned OpNo);
Expand Down Expand Up @@ -449,6 +451,9 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
void ExpandIntRes_ATOMIC_LOAD (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_VECREDUCE (SDNode *N, SDValue &Lo, SDValue &Hi);

void ExpandIntRes_Rotate (SDNode *N, SDValue &Lo, SDValue &Hi);
void ExpandIntRes_FunnelShift (SDNode *N, SDValue &Lo, SDValue &Hi);

void ExpandShiftByConstant(SDNode *N, const APInt &Amt,
SDValue &Lo, SDValue &Hi);
bool ExpandShiftWithKnownAmountBit(SDNode *N, SDValue &Lo, SDValue &Hi);
Expand Down
8 changes: 8 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,8 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
R = ScalarizeVecRes_BinOp(N);
break;
case ISD::FMA:
case ISD::FSHL:
case ISD::FSHR:
R = ScalarizeVecRes_TernaryOp(N);
break;

Expand Down Expand Up @@ -946,9 +948,13 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
case ISD::USUBSAT:
case ISD::SSHLSAT:
case ISD::USHLSAT:
case ISD::ROTL:
case ISD::ROTR:
SplitVecRes_BinOp(N, Lo, Hi);
break;
case ISD::FMA:
case ISD::FSHL:
case ISD::FSHR:
SplitVecRes_TernaryOp(N, Lo, Hi);
break;

Expand Down Expand Up @@ -2926,6 +2932,8 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
Res = WidenVecRes_Unary(N);
break;
case ISD::FMA:
case ISD::FSHL:
case ISD::FSHR:
Res = WidenVecRes_Ternary(N);
break;
}
Expand Down
58 changes: 5 additions & 53 deletions llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6252,62 +6252,14 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
SDValue Y = getValue(I.getArgOperand(1));
SDValue Z = getValue(I.getArgOperand(2));
EVT VT = X.getValueType();
SDValue BitWidthC = DAG.getConstant(VT.getScalarSizeInBits(), sdl, VT);
SDValue Zero = DAG.getConstant(0, sdl, VT);
SDValue ShAmt = DAG.getNode(ISD::UREM, sdl, VT, Z, BitWidthC);

// When X == Y, this is rotate. If the data type has a power-of-2 size, we
// avoid the select that is necessary in the general case to filter out
// the 0-shift possibility that leads to UB.
if (X == Y && isPowerOf2_32(VT.getScalarSizeInBits())) {
auto RotateOpcode = IsFSHL ? ISD::ROTL : ISD::ROTR;
if (TLI.isOperationLegalOrCustom(RotateOpcode, VT)) {
setValue(&I, DAG.getNode(RotateOpcode, sdl, VT, X, Z));
return;
}

// Some targets only rotate one way. Try the opposite direction.
RotateOpcode = IsFSHL ? ISD::ROTR : ISD::ROTL;
if (TLI.isOperationLegalOrCustom(RotateOpcode, VT)) {
// Negate the shift amount because it is safe to ignore the high bits.
SDValue NegShAmt = DAG.getNode(ISD::SUB, sdl, VT, Zero, Z);
setValue(&I, DAG.getNode(RotateOpcode, sdl, VT, X, NegShAmt));
return;
}

// fshl (rotl): (X << (Z % BW)) | (X >> ((0 - Z) % BW))
// fshr (rotr): (X << ((0 - Z) % BW)) | (X >> (Z % BW))
SDValue NegZ = DAG.getNode(ISD::SUB, sdl, VT, Zero, Z);
SDValue NShAmt = DAG.getNode(ISD::UREM, sdl, VT, NegZ, BitWidthC);
SDValue ShX = DAG.getNode(ISD::SHL, sdl, VT, X, IsFSHL ? ShAmt : NShAmt);
SDValue ShY = DAG.getNode(ISD::SRL, sdl, VT, X, IsFSHL ? NShAmt : ShAmt);
setValue(&I, DAG.getNode(ISD::OR, sdl, VT, ShX, ShY));
return;
}

auto FunnelOpcode = IsFSHL ? ISD::FSHL : ISD::FSHR;
if (TLI.isOperationLegalOrCustom(FunnelOpcode, VT)) {
if (X == Y) {
auto RotateOpcode = IsFSHL ? ISD::ROTL : ISD::ROTR;
setValue(&I, DAG.getNode(RotateOpcode, sdl, VT, X, Z));
} else {
auto FunnelOpcode = IsFSHL ? ISD::FSHL : ISD::FSHR;
setValue(&I, DAG.getNode(FunnelOpcode, sdl, VT, X, Y, Z));
return;
}

// fshl: (X << (Z % BW)) | (Y >> (BW - (Z % BW)))
// fshr: (X << (BW - (Z % BW))) | (Y >> (Z % BW))
SDValue InvShAmt = DAG.getNode(ISD::SUB, sdl, VT, BitWidthC, ShAmt);
SDValue ShX = DAG.getNode(ISD::SHL, sdl, VT, X, IsFSHL ? ShAmt : InvShAmt);
SDValue ShY = DAG.getNode(ISD::SRL, sdl, VT, Y, IsFSHL ? InvShAmt : ShAmt);
SDValue Or = DAG.getNode(ISD::OR, sdl, VT, ShX, ShY);

// If (Z % BW == 0), then the opposite direction shift is shift-by-bitwidth,
// and that is undefined. We must compare and select to avoid UB.
EVT CCVT = MVT::i1;
if (VT.isVector())
CCVT = EVT::getVectorVT(*Context, CCVT, VT.getVectorNumElements());

// For fshl, 0-shift returns the 1st arg (X).
// For fshr, 0-shift returns the 2nd arg (Y).
SDValue IsZeroShift = DAG.getSetCC(sdl, CCVT, ShAmt, Zero, ISD::SETEQ);
setValue(&I, DAG.getSelect(sdl, VT, IsZeroShift, IsFSHL ? X : Y, Or));
return;
}
case Intrinsic::sadd_sat: {
Expand Down
12 changes: 12 additions & 0 deletions llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6156,6 +6156,18 @@ bool TargetLowering::expandFunnelShift(SDNode *Node, SDValue &Result,

EVT ShVT = Z.getValueType();

assert(isPowerOf2_32(BW) && "Expecting the type bitwidth to be a power of 2");

// If a funnel shift in the other direction is more supported, use it.
unsigned RevOpcode = IsFSHL ? ISD::FSHR : ISD::FSHL;
if (!isOperationLegalOrCustom(Node->getOpcode(), VT) &&
isOperationLegalOrCustom(RevOpcode, VT)) {
SDValue Zero = DAG.getConstant(0, DL, ShVT);
SDValue Sub = DAG.getNode(ISD::SUB, DL, ShVT, Zero, Z);
Result = DAG.getNode(RevOpcode, DL, VT, X, Y, Sub);
return true;
}

SDValue ShX, ShY;
SDValue ShAmt, InvShAmt;
if (isNonZeroModBitWidth(Z, BW)) {
Expand Down
73 changes: 18 additions & 55 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -124,25 +124,37 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
case ISD::SRL: {
if (!Subtarget->is64Bit())
break;
SDValue Op0 = Node->getOperand(0);
SDValue Op1 = Node->getOperand(1);
SDNode *Op0 = Node->getOperand(0).getNode();
uint64_t Mask;
// Match (srl (and val, mask), imm) where the result would be a
// zero-extended 32-bit integer. i.e. the mask is 0xffffffff or the result
// is equivalent to this (SimplifyDemandedBits may have removed lower bits
// from the mask that aren't necessary due to the right-shifting).
if (Op1.getOpcode() == ISD::Constant &&
isConstantMask(Op0.getNode(), Mask)) {
uint64_t ShAmt = cast<ConstantSDNode>(Op1.getNode())->getZExtValue();
if (isa<ConstantSDNode>(Node->getOperand(1)) && isConstantMask(Op0, Mask)) {
uint64_t ShAmt = Node->getConstantOperandVal(1);

if ((Mask | maskTrailingOnes<uint64_t>(ShAmt)) == 0xffffffff) {
SDValue ShAmtVal =
CurDAG->getTargetConstant(ShAmt, SDLoc(Node), XLenVT);
CurDAG->SelectNodeTo(Node, RISCV::SRLIW, XLenVT, Op0.getOperand(0),
CurDAG->SelectNodeTo(Node, RISCV::SRLIW, XLenVT, Op0->getOperand(0),
ShAmtVal);
return;
}
}
// Match (srl (shl val, 32), imm).
if (Op0->getOpcode() == ISD::SHL &&
isa<ConstantSDNode>(Op0->getOperand(1)) &&
isa<ConstantSDNode>(Node->getOperand(1))) {
uint64_t ShlAmt = Op0->getConstantOperandVal(1);
uint64_t SrlAmt = Node->getConstantOperandVal(1);
if (ShlAmt == 32 && SrlAmt > 32) {
SDValue SrlAmtSub32Val =
CurDAG->getTargetConstant(SrlAmt - 32, SDLoc(Node), XLenVT);
CurDAG->SelectNodeTo(Node, RISCV::SRLIW, XLenVT, Op0->getOperand(0),
SrlAmtSub32Val);
return;
}
}
break;
}
case RISCVISD::READ_CYCLE_WIDE:
Expand Down Expand Up @@ -459,55 +471,6 @@ bool RISCVDAGToDAGISel::SelectRORIW(SDValue N, SDValue &RS1, SDValue &Shamt) {
return false;
}

// Match an FSRIW (i32 Funnel Shift Right Immediate on RV64).
//
// The node tree is expected to look like this:
//
//   (SIGN_EXTEND_INREG (OR (SHL (AssertSext RS1, i32), VC2),
//                          (SRL (AND (AssertSext RS2, i32), VC3), VC1)))
//
// and the constant operands must satisfy:
//
//   VC2 == 32 - VC1
//   VC3 == maskLeadingOnes<uint32_t>(VC2)
//
// VC1 is the Shamt that is reported back, VC2 is its complement over 32, and
// VC3 is a 32 bit mask of (32 - VC1) leading ones. The operands of the SHL and
// AND are handed back as RS1 and RS2 without being inspected any further.
//
// Returns true and fills RS1, RS2 and Shamt on a match. Returns false, leaving
// the outputs untouched, otherwise.
bool RISCVDAGToDAGISel::SelectFSRIW(SDValue N, SDValue &RS1, SDValue &RS2,
                                    SDValue &Shamt) {
  // The root must be a sign extension from i32, and XLEN must be 64 bits.
  if (N.getOpcode() != ISD::SIGN_EXTEND_INREG ||
      Subtarget->getXLenVT() != MVT::i64 ||
      cast<VTSDNode>(N.getOperand(1))->getVT() != MVT::i32)
    return false;

  // Below it there has to be an OR of a SHL (first) and a SRL (second).
  SDValue Or = N.getOperand(0);
  if (Or.getOpcode() != ISD::OR)
    return false;

  SDValue Shl = Or.getOperand(0);
  if (Shl.getOpcode() != ISD::SHL)
    return false;

  SDValue Srl = Or.getOperand(1);
  if (Srl.getOpcode() != ISD::SRL)
    return false;

  // The value being shifted right must be masked by an AND.
  SDValue And = Srl.getOperand(0);
  if (And.getOpcode() != ISD::AND)
    return false;

  // Both shift amounts and the mask have to be constants.
  if (!isa<ConstantSDNode>(Srl.getOperand(1)) ||
      !isa<ConstantSDNode>(Shl.getOperand(1)) ||
      !isa<ConstantSDNode>(And.getOperand(1)))
    return false;

  uint32_t VC1 = Srl.getConstantOperandVal(1);
  uint32_t VC2 = Shl.getConstantOperandVal(1);
  uint32_t VC3 = And.getConstantOperandVal(1);

  // The mask is only computed once the two shift amounts are known to be
  // complementary.
  if (VC2 != (32 - VC1) || VC3 != maskLeadingOnes<uint32_t>(VC2))
    return false;

  RS1 = Shl.getOperand(0);
  RS2 = And.getOperand(0);
  Shamt = CurDAG->getTargetConstant(VC1, SDLoc(N),
                                    Srl.getOperand(1).getValueType());
  return true;
}

// Merge an ADDI into the offset of a load/store instruction where possible.
// (load (addi base, off1), off2) -> (load base, off1+off2)
// (store val, (addi base, off1), off2) -> (store val, base, off1+off2)
Expand Down
1 change: 0 additions & 1 deletion llvm/lib/Target/RISCV/RISCVISelDAGToDAG.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ class RISCVDAGToDAGISel : public SelectionDAGISel {
bool SelectSLOIW(SDValue N, SDValue &RS1, SDValue &Shamt);
bool SelectSROIW(SDValue N, SDValue &RS1, SDValue &Shamt);
bool SelectRORIW(SDValue N, SDValue &RS1, SDValue &Shamt);
bool SelectFSRIW(SDValue N, SDValue &RS1, SDValue &RS2, SDValue &Shamt);

// Include the pieces autogenerated from the target description.
#include "RISCVGenDAGISel.inc"
Expand Down
Loading

0 comments on commit 0819a64

Please sign in to comment.