Skip to content

Commit

Permalink
[RISCV] Adjust select shuffle cost to reflect mask creation cost (#77963
Browse files Browse the repository at this point in the history
)

This is inspired by
#77342 (review),
and is split off of same with some differences in style.

A select is a vmerge.vv with the additional cost of materializing the
bitmask vector in a vreg. All masks fit within a single vector register
(e8 + m8 is the worst case), and thus our worst case cost should be
roughly 3 (2 scalar to produce the address, one vector load op). Given
most shuffles are small, and the mask will be instead produced by
LUI/ADDI + vmv.s.x or ADDI + vmv.s.x, using 2 as the default seems quite
reasonable. At worst, we're not going to be off by much.

The prior lowering scaled the cost of the bitmask with LMUL, which I
don't understand. At m1 it did use the same base cost of 2. (@lukel97
You wrote the original code here, anything I'm missing here?)
  • Loading branch information
preames authored Jan 18, 2024
1 parent 0c195e5 commit 2663d2c
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 47 deletions.
11 changes: 6 additions & 5 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ RISCVTTIImpl::getRISCVInstructionCost(ArrayRef<unsigned> OpCodes, MVT VT,
Cost += VL;
break;
}
case RISCV::VMV_S_X:
// FIXME: VMV_S_X doesn't use LMUL, the cost should be 1
default:
Cost += LMULCost;
}
Expand Down Expand Up @@ -443,10 +441,13 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
// vsetivli zero, 8, e8, mf2, ta, ma (ignored)
// vmv.s.x v0, a0
// vmerge.vvm v8, v9, v8, v0
// We use 2 for the cost of the mask materialization as this is the true
// cost for small masks and most shuffles are small. At worst, this cost
// should be a very small constant for the constant pool load. As such,
// we may bias towards large selects slightly more than truely warranted.
return LT.first *
(TLI->getLMULCost(LT.second) + // FIXME: should be 1 for li
getRISCVInstructionCost({RISCV::VMV_S_X, RISCV::VMERGE_VVM},
LT.second, CostKind));
(2 + getRISCVInstructionCost({RISCV::VMERGE_VVM},
LT.second, CostKind));
}
case TTI::SK_Broadcast: {
bool HasScalar = (Args.size() > 0) && (Operator::getOpcode(Args[0]) ==
Expand Down
Loading

0 comments on commit 2663d2c

Please sign in to comment.