Skip to content

Commit

Permalink
[AArch64] Peephole rule to remove redundant cmp after cset.
Browse files Browse the repository at this point in the history
Comparisons to zero or one after cset instructions can be safely
removed in examples like:

cset w9, eq          cset w9, eq
cmp  w9, #1   --->   <removed>
b.ne    .L1          b.ne    .L1

cset w9, eq          cset w9, eq
cmp  w9, #0   --->   <removed>
b.ne    .L1          b.eq    .L1

Add a peephole optimization that detects suitable cases and removes
those comparisons.

Differential Revision: https://reviews.llvm.org/D98564
  • Loading branch information
ilinpv committed Apr 19, 2021
1 parent d880557 commit 2ec1610
Show file tree
Hide file tree
Showing 4 changed files with 219 additions and 76 deletions.
277 changes: 213 additions & 64 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1463,14 +1463,16 @@ bool AArch64InstrInfo::optimizeCompareInstr(
// FIXME:CmpValue has already been converted to 0 or 1 in analyzeCompare
// function.
assert((CmpValue == 0 || CmpValue == 1) && "CmpValue must be 0 or 1!");
if (CmpValue != 0 || SrcReg2 != 0)
if (SrcReg2 != 0)
return false;

// CmpInstr is a Compare instruction if destination register is not used.
if (!MRI->use_nodbg_empty(CmpInstr.getOperand(0).getReg()))
return false;

return substituteCmpToZero(CmpInstr, SrcReg, MRI);
if (!CmpValue && substituteCmpToZero(CmpInstr, SrcReg, *MRI))
return true;
return removeCmpToZeroOrOne(CmpInstr, SrcReg, CmpValue, *MRI);
}

/// Get opcode of S version of Instr.
Expand Down Expand Up @@ -1524,13 +1526,44 @@ static unsigned sForm(MachineInstr &Instr) {
}

/// Check if AArch64::NZCV should be alive in successors of MBB.
///
/// \returns true as soon as one successor of \p MBB reports AArch64::NZCV as
/// a live-in register, and false when none of them does.
static bool areCFlagsAliveInSuccessors(MachineBasicBlock *MBB) {
static bool areCFlagsAliveInSuccessors(const MachineBasicBlock *MBB) {
for (auto *BB : MBB->successors())
if (BB->isLiveIn(AArch64::NZCV))
return true;
return false;
}

/// Locate the condition-code operand of a conditional branch or select.
///
/// \returns The index of the condition code operand of \p Instr when it is a
/// Bcc or one of the handled CSEL/CSINC/CSINV/CSNEG/FCSEL opcodes, and -1 for
/// any other opcode.
static int
findCondCodeUseOperandIdxForBranchOrSelect(const MachineInstr &Instr) {
  // Number of operands separating the condition code from the NZCV use that
  // follows it; depends on the instruction family.
  int CCDistance = 0;
  switch (Instr.getOpcode()) {
  case AArch64::Bcc:
    CCDistance = 2;
    break;

  case AArch64::CSELWr:
  case AArch64::CSELXr:
  case AArch64::CSINCWr:
  case AArch64::CSINCXr:
  case AArch64::CSINVWr:
  case AArch64::CSINVXr:
  case AArch64::CSNEGWr:
  case AArch64::CSNEGXr:
  case AArch64::FCSELSrrr:
  case AArch64::FCSELDrrr:
    CCDistance = 1;
    break;

  default:
    // Not a branch or select this helper knows about.
    return -1;
  }

  const int NZCVUseIdx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
  assert(NZCVUseIdx >= CCDistance);
  return NZCVUseIdx - CCDistance;
}

namespace {

struct UsedNZCV {
Expand All @@ -1556,31 +1589,10 @@ struct UsedNZCV {
/// Returns AArch64CC::Invalid if either the instruction does not use condition
/// codes or we don't optimize CmpInstr in the presence of such instructions.
static AArch64CC::CondCode findCondCodeUsedByInstr(const MachineInstr &Instr) {
// Switch-based lookup: the condition code is read from an operand located at
// a fixed distance before the NZCV use operand.
switch (Instr.getOpcode()) {
default:
return AArch64CC::Invalid;

case AArch64::Bcc: {
// For Bcc the condition code sits two operands before the NZCV use.
int Idx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
assert(Idx >= 2);
return static_cast<AArch64CC::CondCode>(Instr.getOperand(Idx - 2).getImm());
}

case AArch64::CSINVWr:
case AArch64::CSINVXr:
case AArch64::CSINCWr:
case AArch64::CSINCXr:
case AArch64::CSELWr:
case AArch64::CSELXr:
case AArch64::CSNEGWr:
case AArch64::CSNEGXr:
case AArch64::FCSELSrrr:
case AArch64::FCSELDrrr: {
// For the select family the condition code immediately precedes the NZCV use.
int Idx = Instr.findRegisterUseOperandIdx(AArch64::NZCV);
assert(Idx >= 1);
return static_cast<AArch64CC::CondCode>(Instr.getOperand(Idx - 1).getImm());
}
}
// Helper-based lookup: ask findCondCodeUseOperandIdxForBranchOrSelect for the
// operand index; a negative index means the opcode is not handled and maps to
// AArch64CC::Invalid.
int CCIdx = findCondCodeUseOperandIdxForBranchOrSelect(Instr);
return CCIdx >= 0 ? static_cast<AArch64CC::CondCode>(
Instr.getOperand(CCIdx).getImm())
: AArch64CC::Invalid;
}

static UsedNZCV getUsedNZCV(AArch64CC::CondCode CC) {
Expand Down Expand Up @@ -1627,6 +1639,41 @@ static UsedNZCV getUsedNZCV(AArch64CC::CondCode CC) {
return UsedFlags;
}

/// Summarise how NZCV is consumed after \p CmpInstr inside its basic block.
///
/// \returns The set of condition flags read after \p CmpInstr, provided that
/// \p MI and \p CmpInstr live in the same basic block, NZCV is not live into
/// any successor of that block, every reader is a supported branch or select,
/// and neither the C nor the V flag is read. \returns None otherwise.
///
/// When \p CCUseInstrs is non-null, each instruction found reading the flags
/// is appended to it (entries may have been appended even when None is
/// returned).
static Optional<UsedNZCV>
examineCFlagsUse(MachineInstr &MI, MachineInstr &CmpInstr,
                 const TargetRegisterInfo &TRI,
                 SmallVectorImpl<MachineInstr *> *CCUseInstrs = nullptr) {
  MachineBasicBlock *CmpParent = CmpInstr.getParent();
  if (MI.getParent() != CmpParent || areCFlagsAliveInSuccessors(CmpParent))
    return None;

  UsedNZCV FlagsRead;
  auto AfterCmp = instructionsWithoutDebug(std::next(CmpInstr.getIterator()),
                                           CmpParent->instr_end());
  for (MachineInstr &Candidate : AfterCmp) {
    if (Candidate.readsRegister(AArch64::NZCV, &TRI)) {
      const AArch64CC::CondCode CC = findCondCodeUsedByInstr(Candidate);
      // A flag reader whose condition code cannot be decoded is unsupported.
      if (CC == AArch64CC::Invalid)
        return None;
      FlagsRead |= getUsedNZCV(CC);
      if (CCUseInstrs)
        CCUseInstrs->push_back(&Candidate);
    }
    // Flags written here are a new value; later readers do not see CmpInstr.
    if (Candidate.modifiesRegister(AArch64::NZCV, &TRI))
      break;
  }

  if (FlagsRead.C || FlagsRead.V)
    return None;
  return FlagsRead;
}

/// \returns true if \p Opcode is the 32-bit or 64-bit ADDS register-immediate
/// opcode (ADDSWri / ADDSXri).
static bool isADDSRegImm(unsigned Opcode) {
  switch (Opcode) {
  case AArch64::ADDSWri:
  case AArch64::ADDSXri:
    return true;
  default:
    return false;
  }
}
Expand All @@ -1646,44 +1693,21 @@ static bool isSUBSRegImm(unsigned Opcode) {
/// or if MI opcode is not the S form there must be neither defs of flags
/// nor uses of flags between MI and CmpInstr.
/// - and C/V flags are not used after CmpInstr
static bool canInstrSubstituteCmpInstr(MachineInstr *MI, MachineInstr *CmpInstr,
const TargetRegisterInfo *TRI) {
assert(MI);
assert(sForm(*MI) != AArch64::INSTRUCTION_LIST_END);
assert(CmpInstr);
static bool canInstrSubstituteCmpInstr(MachineInstr &MI, MachineInstr &CmpInstr,
const TargetRegisterInfo &TRI) {
assert(sForm(MI) != AArch64::INSTRUCTION_LIST_END);

// Only ADDS/SUBS register-immediate compares are candidates.
const unsigned CmpOpcode = CmpInstr->getOpcode();
const unsigned CmpOpcode = CmpInstr.getOpcode();
if (!isADDSRegImm(CmpOpcode) && !isSUBSRegImm(CmpOpcode))
return false;

// MI and CmpInstr must share a basic block.
if (MI->getParent() != CmpInstr->getParent())
return false;

// NZCV must not be live into any successor block.
if (areCFlagsAliveInSuccessors(CmpInstr->getParent()))
// examineCFlagsUse yields None when the same-block / successor-liveness
// requirements fail, or when a C or V flag (or an unsupported flag reader)
// is found after CmpInstr.
if (!examineCFlagsUse(MI, CmpInstr, TRI))
return false;

// When MI is already its own S form only flag writes between MI and CmpInstr
// are checked; otherwise every flag access in between is checked.
AccessKind AccessToCheck = AK_Write;
if (sForm(*MI) != MI->getOpcode())
if (sForm(MI) != MI.getOpcode())
AccessToCheck = AK_All;
if (areCFlagsAccessedBetweenInstrs(MI, CmpInstr, TRI, AccessToCheck))
return false;

// Accumulate the flags read after CmpInstr, stopping at the next NZCV writer.
UsedNZCV NZCVUsedAfterCmp;
for (const MachineInstr &Instr :
instructionsWithoutDebug(std::next(CmpInstr->getIterator()),
CmpInstr->getParent()->instr_end())) {
if (Instr.readsRegister(AArch64::NZCV, TRI)) {
AArch64CC::CondCode CC = findCondCodeUsedByInstr(Instr);
if (CC == AArch64CC::Invalid) // Unsupported conditional instruction
return false;
NZCVUsedAfterCmp |= getUsedNZCV(CC);
}

if (Instr.modifiesRegister(AArch64::NZCV, TRI))
break;
}

// Substitution is only allowed when neither C nor V is read after CmpInstr.
return !NZCVUsedAfterCmp.C && !NZCVUsedAfterCmp.V;
return !areCFlagsAccessedBetweenInstrs(&MI, &CmpInstr, &TRI, AccessToCheck);
}

/// Substitute an instruction comparing to zero with another instruction
Expand All @@ -1692,20 +1716,19 @@ static bool canInstrSubstituteCmpInstr(MachineInstr *MI, MachineInstr *CmpInstr,
/// Return true on success.
bool AArch64InstrInfo::substituteCmpToZero(
MachineInstr &CmpInstr, unsigned SrcReg,
const MachineRegisterInfo *MRI) const {
assert(MRI);
const MachineRegisterInfo &MRI) const {
// Get the unique definition of SrcReg.
MachineInstr *MI = MRI->getUniqueVRegDef(SrcReg);
MachineInstr *MI = MRI.getUniqueVRegDef(SrcReg);
if (!MI)
return false;

const TargetRegisterInfo *TRI = &getRegisterInfo();
const TargetRegisterInfo &TRI = getRegisterInfo();

unsigned NewOpc = sForm(*MI);
if (NewOpc == AArch64::INSTRUCTION_LIST_END)
return false;

if (!canInstrSubstituteCmpInstr(MI, &CmpInstr, TRI))
if (!canInstrSubstituteCmpInstr(*MI, CmpInstr, TRI))
return false;

// Update the instruction to set NZCV.
Expand All @@ -1714,7 +1737,133 @@ bool AArch64InstrInfo::substituteCmpToZero(
bool succeeded = UpdateOperandRegClass(*MI);
(void)succeeded;
assert(succeeded && "Some operands reg class are incompatible!");
MI->addRegisterDefined(AArch64::NZCV, TRI);
MI->addRegisterDefined(AArch64::NZCV, &TRI);
return true;
}

/// Decide whether the compare \p CmpInstr of the value defined by \p MI
/// against \p CmpValue (0 or 1) is redundant.
///
/// \param CCUseInstrs receives the instructions reading the flags after
/// \p CmpInstr (it may be partially filled even when false is returned).
/// \param IsInvertCC is written only on success; true means the condition
/// codes of the instructions in \p CCUseInstrs must be inverted once
/// \p CmpInstr is gone.
/// \returns True if \p CmpInstr can be removed.
static bool canCmpInstrBeRemoved(MachineInstr &MI, MachineInstr &CmpInstr,
                                 int CmpValue, const TargetRegisterInfo &TRI,
                                 SmallVectorImpl<MachineInstr *> &CCUseInstrs,
                                 bool &IsInvertCC) {
  assert((CmpValue == 0 || CmpValue == 1) &&
         "Only comparisons to 0 or 1 considered for removal!");

  // MI must be 'CSINCWr %vreg, wzr, wzr, <cc>' or
  // 'CSINCXr %vreg, xzr, xzr, <cc>': pick the zero register matching the
  // opcode width and require it in both source positions.
  unsigned ZeroReg;
  switch (MI.getOpcode()) {
  case AArch64::CSINCWr:
    ZeroReg = AArch64::WZR;
    break;
  case AArch64::CSINCXr:
    ZeroReg = AArch64::XZR;
    break;
  default:
    return false;
  }
  if (MI.getOperand(1).getReg() != ZeroReg ||
      MI.getOperand(2).getReg() != ZeroReg)
    return false;

  const AArch64CC::CondCode MICC = findCondCodeUsedByInstr(MI);
  if (MICC == AArch64CC::Invalid)
    return false;

  // Reject MI when the lookup (called with isDead = true) finds an NZCV def
  // operand on it.
  if (MI.findRegisterDefOperandIdx(AArch64::NZCV, true) != -1)
    return false;

  // CmpInstr must be 'SUBS %vreg, 1' when comparing with one, and either
  // 'ADDS %vreg, 0' or 'SUBS %vreg, 0' when comparing with zero.
  const unsigned CmpOpcode = CmpInstr.getOpcode();
  const bool IsSubs = isSUBSRegImm(CmpOpcode);
  const bool IsCmpOpcodeSupported =
      CmpValue ? IsSubs : (IsSubs || isADDSRegImm(CmpOpcode));
  if (!IsCmpOpcodeSupported)
    return false;

  // MI may only depend on the Z or N flag (conditions eq, ne, mi, pl).
  UsedNZCV MIUsedNZCV = getUsedNZCV(MICC);
  if (MIUsedNZCV.C || MIUsedNZCV.V)
    return false;

  // Flags must be dead in the successors of CmpInstr's block, and only Z or N
  // may be read after CmpInstr inside that block.
  Optional<UsedNZCV> NZCVUsedAfterCmp =
      examineCFlagsUse(MI, CmpInstr, TRI, &CCUseInstrs);
  if (!NZCVUsedAfterCmp)
    return false;

  // The flag read after CmpInstr has to be the same one MI depends on.
  const bool IsFlagMismatch = (MIUsedNZCV.Z && NZCVUsedAfterCmp->N) ||
                              (MIUsedNZCV.N && NZCVUsedAfterCmp->Z);
  if (IsFlagMismatch)
    return false;

  // A comparison with zero restricts MI to the eq / ne conditions.
  if (MIUsedNZCV.N && !CmpValue)
    return false;

  // Nothing between MI and CmpInstr may redefine the flags.
  if (areCFlagsAccessedBetweenInstrs(&MI, &CmpInstr, &TRI, AK_Write))
    return false;

  // Users need their condition code inverted when:
  //   - comparing with one and MI's condition is eq or pl;
  //   - comparing with zero and MI's condition is ne.
  if (CmpValue)
    IsInvertCC = MICC == AArch64CC::EQ || MICC == AArch64CC::PL;
  else
    IsInvertCC = MICC == AArch64CC::NE;
  return true;
}

/// Remove the comparison in a csinc-cmp sequence.
///
/// Examples:
/// 1. \code
///   csinc w9, wzr, wzr, ne
///   cmp   w9, #0
///   b.eq
///    \endcode
/// becomes
///    \code
///   csinc w9, wzr, wzr, ne
///   b.ne
///    \endcode
///
/// 2. \code
///   csinc x2, xzr, xzr, mi
///   cmp   x2, #1
///   b.pl
///    \endcode
/// becomes
///    \code
///   csinc x2, xzr, xzr, mi
///   b.pl
///    \endcode
///
/// \param CmpInstr comparison instruction; erased from its block on success
/// \param SrcReg register compared by \p CmpInstr; its unique definition is
///        the candidate csinc
/// \param CmpValue immediate the register is compared with (0 or 1)
/// \param MRI register info used to look up the definition of \p SrcReg
/// \return True when comparison removed
bool AArch64InstrInfo::removeCmpToZeroOrOne(
    MachineInstr &CmpInstr, unsigned SrcReg, int CmpValue,
    const MachineRegisterInfo &MRI) const {
  MachineInstr *MI = MRI.getUniqueVRegDef(SrcReg);
  if (!MI)
    return false;

  SmallVector<MachineInstr *, 4> CCUseInstrs;
  bool IsInvertCC = false;
  const bool CanRemove = canCmpInstrBeRemoved(
      *MI, CmpInstr, CmpValue, getRegisterInfo(), CCUseInstrs, IsInvertCC);
  if (!CanRemove)
    return false;

  // Drop the compare; its flag users now observe the flags MI's condition
  // was evaluated on.
  CmpInstr.eraseFromParent();
  if (!IsInvertCC)
    return true;

  // Flip the condition code of every instruction that consumed the flags
  // produced by the removed compare.
  for (MachineInstr *CCUseInstr : CCUseInstrs) {
    const int Idx = findCondCodeUseOperandIdxForBranchOrSelect(*CCUseInstr);
    assert(Idx >= 0 && "Unexpected instruction using CC.");
    MachineOperand &CCOperand = CCUseInstr->getOperand(Idx);
    const auto CurrentCC =
        static_cast<AArch64CC::CondCode>(CCOperand.getImm());
    CCOperand.setImm(AArch64CC::getInvertedCondCode(CurrentCC));
  }
  return true;
}

Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Target/AArch64/AArch64InstrInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,9 @@ class AArch64InstrInfo final : public AArch64GenInstrInfo {
MachineBasicBlock *TBB,
ArrayRef<MachineOperand> Cond) const;
bool substituteCmpToZero(MachineInstr &CmpInstr, unsigned SrcReg,
const MachineRegisterInfo *MRI) const;
const MachineRegisterInfo &MRI) const;
bool removeCmpToZeroOrOne(MachineInstr &CmpInstr, unsigned SrcReg,
int CmpValue, const MachineRegisterInfo &MRI) const;

/// Returns an unused general-purpose register which can be used for
/// constructing an outlined call if one exists. Returns 0 otherwise.
Expand Down
Loading

0 comments on commit 2ec1610

Please sign in to comment.