Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV][GISel] RegBank select and instruction select for vector G_ADD, G_SUB #74114

Merged
merged 12 commits into from
Feb 1, 2024
Merged
3 changes: 2 additions & 1 deletion llvm/lib/CodeGen/GlobalISel/InstructionSelect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,8 @@ bool InstructionSelect::runOnMachineFunction(MachineFunction &MF) {
}

const LLT Ty = MRI.getType(VReg);
if (Ty.isValid() && Ty.getSizeInBits() > TRI.getRegSizeInBits(*RC)) {
if (Ty.isValid() &&
TypeSize::isKnownGT(Ty.getSizeInBits(), TRI.getRegSizeInBits(*RC))) {
reportGISelFailure(
MF, TPC, MORE, "gisel-select",
"VReg's low-level type and register class have different sizes", *MI);
Expand Down
15 changes: 14 additions & 1 deletion llvm/lib/Target/RISCV/GISel/RISCVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,20 @@ const TargetRegisterClass *RISCVInstructionSelector::getRegClassForTypeOnBank(
return &RISCV::FPR64RegClass;
}

// TODO: Non-GPR register classes.
if (RB.getID() == RISCV::VRBRegBankID) {
if (Ty.getSizeInBits().getKnownMinValue() <= 64)
return &RISCV::VRRegClass;

if (Ty.getSizeInBits().getKnownMinValue() == 128)
return &RISCV::VRM2RegClass;

if (Ty.getSizeInBits().getKnownMinValue() == 256)
return &RISCV::VRM4RegClass;

if (Ty.getSizeInBits().getKnownMinValue() == 512)
return &RISCV::VRM8RegClass;
}

return nullptr;
}

Expand Down
58 changes: 57 additions & 1 deletion llvm/lib/Target/RISCV/GISel/RISCVRegisterBankInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,27 @@ namespace llvm {
namespace RISCV {

const RegisterBankInfo::PartialMapping PartMappings[] = {
// clang-format off
{0, 32, GPRBRegBank},
{0, 64, GPRBRegBank},
{0, 32, FPRBRegBank},
{0, 64, FPRBRegBank},
{0, 64, VRBRegBank},
{0, 128, VRBRegBank},
{0, 256, VRBRegBank},
{0, 512, VRBRegBank},
// clang-format on
};

enum PartialMappingIdx {
PMI_GPRB32 = 0,
PMI_GPRB64 = 1,
PMI_FPRB32 = 2,
PMI_FPRB64 = 3,
PMI_VRB64 = 4,
PMI_VRB128 = 5,
PMI_VRB256 = 6,
PMI_VRB512 = 7,
};

const RegisterBankInfo::ValueMapping ValueMappings[] = {
Expand All @@ -57,6 +67,22 @@ const RegisterBankInfo::ValueMapping ValueMappings[] = {
{&PartMappings[PMI_FPRB64], 1},
{&PartMappings[PMI_FPRB64], 1},
{&PartMappings[PMI_FPRB64], 1},
// Maximum 3 VR LMUL={1, MF2, MF4, MF8} operands.
{&PartMappings[PMI_VRB64], 1},
{&PartMappings[PMI_VRB64], 1},
{&PartMappings[PMI_VRB64], 1},
// Maximum 3 VR LMUL=2 operands.
{&PartMappings[PMI_VRB128], 1},
{&PartMappings[PMI_VRB128], 1},
{&PartMappings[PMI_VRB128], 1},
// Maximum 3 VR LMUL=4 operands.
{&PartMappings[PMI_VRB256], 1},
{&PartMappings[PMI_VRB256], 1},
{&PartMappings[PMI_VRB256], 1},
// Maximum 3 VR LMUL=8 operands.
{&PartMappings[PMI_VRB512], 1},
{&PartMappings[PMI_VRB512], 1},
{&PartMappings[PMI_VRB512], 1},
};

enum ValueMappingIdx {
Expand All @@ -65,6 +91,10 @@ enum ValueMappingIdx {
GPRB64Idx = 4,
FPRB32Idx = 7,
FPRB64Idx = 10,
VRB64Idx = 13,
VRB128Idx = 16,
VRB256Idx = 19,
VRB512Idx = 22,
};
} // namespace RISCV
} // namespace llvm
Expand Down Expand Up @@ -215,6 +245,23 @@ bool RISCVRegisterBankInfo::anyUseOnlyUseFP(
[&](const MachineInstr &UseMI) { return onlyUsesFP(UseMI, MRI, TRI); });
}

static const RegisterBankInfo::ValueMapping *getVRBValueMapping(unsigned Size) {
unsigned Idx;

if (Size <= 64)
Idx = RISCV::VRB64Idx;
else if (Size == 128)
Idx = RISCV::VRB128Idx;
else if (Size == 256)
Idx = RISCV::VRB256Idx;
else if (Size == 512)
Idx = RISCV::VRB512Idx;
else
llvm::report_fatal_error("Invalid Size");

return &RISCV::ValueMappings[Idx];
}

const RegisterBankInfo::InstructionMapping &
RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {
const unsigned Opc = MI.getOpcode();
Expand Down Expand Up @@ -242,7 +289,16 @@ RISCVRegisterBankInfo::getInstrMapping(const MachineInstr &MI) const {

switch (Opc) {
case TargetOpcode::G_ADD:
case TargetOpcode::G_SUB:
case TargetOpcode::G_SUB: {
if (MRI.getType(MI.getOperand(0).getReg()).isVector()) {
LLT Ty = MRI.getType(MI.getOperand(0).getReg());
return getInstructionMapping(
DefaultMappingID, /*Cost=*/1,
getVRBValueMapping(Ty.getSizeInBits().getKnownMinValue()),
NumOperands);
}
}
LLVM_FALLTHROUGH;
case TargetOpcode::G_SHL:
case TargetOpcode::G_ASHR:
case TargetOpcode::G_LSHR:
Expand Down
Loading
Loading