Skip to content

Commit

Permalink
[RISCV] Adjust select shuffle cost to reflect mask creation cost
Browse files Browse the repository at this point in the history
This is inspired by llvm#77342 (review), and is split off of same with some differences in style.

A select is a vmerge.vv with the additional cost of materializing the bitmask vector in a vreg.  All masks fit within a single vector register (e8 + m8 is the worst case), and thus our worst case cost should be roughly 3 (2 scalar to produce the address, one vector load op).  Given most shuffles are small, and the mask will be instead produced by LUI/ADDI + vmv.s.x or ADDI + vmv.s.x, using 2 as the default seems quite reasonable.  At worst, we're not going to be off by much.

The prior lowering scaled the cost of the bitmask with LMUL, which I don't understand.  At m1 it did use the same base cost of 2.
  • Loading branch information
preames committed Jan 12, 2024
1 parent 5ce067d commit e886bb3
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 47 deletions.
11 changes: 6 additions & 5 deletions llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,6 @@ RISCVTTIImpl::getRISCVInstructionCost(ArrayRef<unsigned> OpCodes, MVT VT,
Cost += VL;
break;
}
case RISCV::VMV_S_X:
// FIXME: VMV_S_X doesn't use LMUL, the cost should be 1
default:
Cost += LMULCost;
}
Expand Down Expand Up @@ -443,10 +441,13 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
// vsetivli zero, 8, e8, mf2, ta, ma (ignored)
// vmv.s.x v0, a0
// vmerge.vvm v8, v9, v8, v0
// We use 2 for the cost of the mask materialization as this is the true
// cost for small masks and most shuffles are small. At worst, this cost
// should be a very small constant for the constant pool load. As such,
// we may bias towards large selects slightly more than truely warranted.
return LT.first *
(TLI->getLMULCost(LT.second) + // FIXME: should be 1 for li
getRISCVInstructionCost({RISCV::VMV_S_X, RISCV::VMERGE_VVM},
LT.second, CostKind));
(2 + getRISCVInstructionCost({RISCV::VMERGE_VVM},
LT.second, CostKind));
}
case TTI::SK_Broadcast: {
bool HasScalar = (Args.size() > 0) && (Operator::getOpcode(Args[0]) ==
Expand Down
Loading

0 comments on commit e886bb3

Please sign in to comment.