Skip to content

Commit

Permalink
[AIE2P] Support Wide Load/Store
Browse files Browse the repository at this point in the history
  • Loading branch information
Abnikant Singh committed Jan 24, 2025
1 parent f9ffbd6 commit 6aac9bb
Show file tree
Hide file tree
Showing 3 changed files with 442 additions and 83 deletions.
200 changes: 132 additions & 68 deletions llvm/lib/Target/AIE/aie2p/AIE2PInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,9 +100,9 @@ class AIE2PInstructionSelector : public AIEBaseInstructionSelector {
bool selectG_LOAD(MachineInstr &I, MachineRegisterInfo &MRI);
bool selectG_STORE(MachineInstr &I, MachineRegisterInfo &MRI);
bool selectG_AIE_LOAD_STORE(MachineInstr &I, MachineRegisterInfo &MRI);
bool select1024BitG_AIE_LOAD_STORE(MachineInstr &I, LoadStoreOpcodes &LSO,
AddressingModeInfo &AMI,
MachineRegisterInfo &MRI);
bool selectWideG_AIE_LOAD_STORE(MachineInstr &I, LoadStoreOpcodes &LSO,
AddressingModeInfo &AMI,
MachineRegisterInfo &MRI);
bool selectSetI128(MachineInstr &I, MachineOperand &DstReg,
MachineOperand &SrcReg, MachineRegisterInfo &MRI);
bool selectExtractI128(MachineInstr &I, Register DstReg, Register SrcReg,
Expand Down Expand Up @@ -1731,14 +1731,25 @@ LoadStoreOpcodes AIE2PInstructionSelector::getLoadStoreOpcode(
AlwaysFitsImmediateRange,
/*OffsetOpcode=*/AIE2P::VLDA_dmx_lda_fifohl_idx_imm};
}
if (RBID == AIE2P::VRegBankID) {
llvm_unreachable("Unimplemented");
}
if (RBID == AIE2P::AccRegBankID) {
llvm_unreachable("Unimplemented");
return {/*ISelOpcode=*/AIE2P::VLDA_dmx_lda_bm_idx_imm,
AlwaysFitsImmediateRange,
/*OffsetOpcode=*/AIE2P::VLDA_dmx_lda_bm_idx_imm};
}
if (RBID == AIE2P::VRegBankID) {
return {/*ISelOpcode=*/AIE2P::VLDA_dmx_lda_x_idx_imm,
AlwaysFitsImmediateRange,
/*OffsetOpcode=*/AIE2P::VLDA_dmx_lda_x_idx_imm};
}
llvm_unreachable("1024-bit vector type must be in AccRegBank or VRegBank "
"or FifoRegBankID");
} else if (LoadStoreSize == 2048) {
if (RBID == AIE2P::AccRegBankID) {
return {/*ISelOpcode=*/AIE2P::VLDA_dmx_lda_bm_idx_imm,
AlwaysFitsImmediateRange,
/*OffsetOpcode=*/AIE2P::VLDA_dmx_lda_bm_idx_imm};
}
llvm_unreachable("2048-bit vector type must be in AccRegBank");
}
break;
}
Expand Down Expand Up @@ -2055,6 +2066,14 @@ LoadStoreOpcodes AIE2PInstructionSelector::getLoadStoreOpcode(
}
llvm_unreachable("1024-bit vector type must be in AccRegBank or "
"VRegBank or FifoRegBankID");
} else if (LoadStoreSize == 2048) {
assert(RBID == AIE2P::AccRegBankID &&
"2048-bit vectors should be in the Accumulator Register Bank");
if (RBID == AIE2P::AccRegBankID) {
return {/*ISelOpcode=*/AIE2P::VST_dmx_sts_bm_idx_imm,
/*FitsImmediateRange=*/AlwaysFitsImmediateRange,
/*OffsetOpcode=*/AIE2P::VST_dmx_sts_bm_idx_imm};
}
}
break;
}
Expand Down Expand Up @@ -2253,76 +2272,120 @@ LoadStoreOpcodes AIE2PInstructionSelector::getLoadStoreOpcode(
llvm_unreachable("Invalid instruction");
}

bool AIE2PInstructionSelector::select1024BitG_AIE_LOAD_STORE(
bool AIE2PInstructionSelector::selectWideG_AIE_LOAD_STORE(
MachineInstr &I, LoadStoreOpcodes &LSO, AddressingModeInfo &AMI,
MachineRegisterInfo &MRI) {
LLT SrcDstTy = MRI.getType(AMI.SrcDstOp.getReg());
unsigned SrcDstTySize = SrcDstTy.getSizeInBits();
unsigned SplitFactor = (SrcDstTySize == 1024) ? 2 : 4;
unsigned RBID = deriveRegBankID(I.getOperand(0).getReg(), MRI, RBI);
const TargetRegisterClass *RC512 = nullptr;
const TargetRegisterClass *RC1024 = nullptr;
const TargetRegisterClass *RC2048 = &AIE2P::ACC2048RegClass;
llvm::SmallVector<unsigned, 4> SubRegIdxes;

if (RBID == AIE2P::AccRegBankID) {
SubRegIdxes = {AIE2P::sub_512_acc_lo, AIE2P::sub_512_acc_hi,
AIE2P::sub_1024_acc_hi_then_sub_512_acc_lo,
AIE2P::sub_1024_acc_hi_then_sub_512_acc_hi};
RC512 = &AIE2P::ACC512RegClass;
RC1024 = &AIE2P::ACC1024RegClass;
} else if (RBID == AIE2P::VRegBankID) {
SubRegIdxes = {AIE2P::sub_512_lo, AIE2P::sub_512_hi};
RC512 = &AIE2P::VEC512RegClass;
RC1024 = &AIE2P::VEC1024RegClass;
} else if (RBID == AIE2P::FifoRegBankID) {
RC512 = &AIE2P::FIFO512RegClass;
SubRegIdxes = {AIE2P::sub_lo_fifo, AIE2P::sub_hi_fifo};
RC1024 = &AIE2P::FIFO1024RegClass;
} else {
llvm_unreachable("Unknown Register Bank ID!");
}

bool IsFifo = deriveRegBankID(I.getOperand(0).getReg(), MRI, RBI) ==
AIE2P::FifoRegBankID;
assert(IsFifo && "Expected FiforegBank for 1024-bit load/store. Other banks "
"are unsupported");

Register Low512 = MRI.createVirtualRegister(&AIE2P::FIFO512RegClass);
Register High512 = MRI.createVirtualRegister(&AIE2P::FIFO512RegClass);

std::vector<Register> SubRegs(SplitFactor);
for (unsigned i = 0; i < SplitFactor; ++i) {
SubRegs[i] = MRI.createVirtualRegister(RC512);
}
auto handleSplitMemOperands = [&](auto &Instrs) {
int NumSplits = SplitFactor / 2;
for (unsigned i = 0; i < NumSplits; ++i) {
unsigned Offset = (SrcDstTySize == 2048 && i == 0) ? 128 : 0;
addSplitMemOperands(
AMI.MemI, Instrs[SplitFactor - 1 - 2 * i] /*Higher MIB*/,
Instrs[SplitFactor - 2 - 2 * i] /*Lower MIB*/, Offset, SplitFactor);
}
};

auto constrainInstRegOps = [&](auto &Instrs) {
return std::all_of(Instrs.begin(), Instrs.end(), [&](const auto &Instr) {
return constrainSelectedInstRegOperands(*Instr, TII, TRI, RBI);
});
};

llvm::SmallVector<MachineInstrBuilder, 4> Instrs;
switch (AMI.MemI.getOpcode()) {
case AIE2P::G_STORE: {
auto LowerBits = MIB.buildInstr(TargetOpcode::COPY, {Low512}, {})
.addReg(AMI.SrcDstOp.getReg(), 0, AIE2P::sub_lo_fifo);
auto HigherBits = MIB.buildInstr(TargetOpcode::COPY, {High512}, {})
.addReg(AMI.SrcDstOp.getReg(), 0, AIE2P::sub_hi_fifo);

auto StoreHigher = MIB.buildInstr(*LSO.OffsetOpcode, {}, {})
.addReg(HigherBits.getReg(0))
.addReg(AMI.PtrOp.getReg())
.addImm(64); // Offset
auto StoreLower = MIB.buildInstr(LSO.ISelOpcode, {}, {});

for (auto Def : AMI.MemI.defs())
StoreLower.addDef(Def.getReg());

StoreLower.addReg(LowerBits.getReg(0));

addAddressingMode(StoreLower, AMI, LSO.FitsImmediateRange, false, MRI);

addSplitMemOperands(AMI.MemI, StoreHigher, StoreLower, 0, 2);
for (unsigned SubReg = 0; SubReg < SplitFactor; ++SubReg) {
int Offset = SubReg * 64;
auto Copy = MIB.buildInstr(TargetOpcode::COPY, {SubRegs[SubReg]}, {})
.addReg(AMI.SrcDstOp.getReg(), 0,
SubRegIdxes[SubReg % SubRegIdxes.size()]);

auto Store = (SubReg == 0) ? MIB.buildInstr(LSO.ISelOpcode, {}, {})
: MIB.buildInstr(*LSO.OffsetOpcode, {}, {})
.addReg(Copy.getReg(0))
.addReg(AMI.PtrOp.getReg())
.addImm(Offset);

if (SubReg == 0) {
for (auto Def : AMI.MemI.defs())
Store.addDef(Def.getReg());
Store.addReg(Copy.getReg(0));
addAddressingMode(Store, AMI, LSO.FitsImmediateRange, false, MRI);
}
Instrs.push_back(Store);
}

AMI.MemI.eraseFromParent();
return constrainSelectedInstRegOperands(*StoreLower, TII, TRI, RBI) &&
constrainSelectedInstRegOperands(*StoreHigher, TII, TRI, RBI);
handleSplitMemOperands(Instrs);
break;
}
case AIE2P::G_LOAD: {
auto LoadHigher = MIB.buildInstr(*LSO.OffsetOpcode, {}, {})
.addDef(High512)
.addUse(AMI.PtrOp.getReg())
.addImm(64); // Offset

auto LoadLower = MIB.buildInstr(LSO.ISelOpcode, {Low512}, {});
// We have to skip the first Def (the 1024-bit Dst-Reg)
for (auto *Def = AMI.MemI.defs().begin() + 1; Def != AMI.MemI.defs().end();
Def++)
LoadLower.addDef(Def->getReg());

addAddressingMode(LoadLower, AMI, LSO.FitsImmediateRange, false, MRI);

addSplitMemOperands(AMI.MemI, LoadHigher, LoadLower, 0, 2);

MIB.buildInstr(AIE2P::REG_SEQUENCE, {AMI.SrcDstOp.getReg()}, {})
.addReg(Low512)
.addImm(AIE2P::sub_lo_fifo)
.addReg(High512)
.addImm(AIE2P::sub_hi_fifo);
for (unsigned SubReg = 0; SubReg < SplitFactor; ++SubReg) {
auto Load = (SubReg == 0)
? MIB.buildInstr(LSO.ISelOpcode, {SubRegs[SubReg]}, {})
: MIB.buildInstr(*LSO.OffsetOpcode, {}, {})
.addDef(SubRegs[SubReg])
.addUse(AMI.PtrOp.getReg())
.addImm(SubReg * 64);
if (SubReg == 0) {
for (auto *Def = AMI.MemI.defs().begin() + 1;
Def != AMI.MemI.defs().end(); Def++) {
Load.addDef(Def->getReg());
}
addAddressingMode(Load, AMI, LSO.FitsImmediateRange, false, MRI);
}

Instrs.push_back(Load);
}
auto RegSeq =
MIB.buildInstr(AIE2P::REG_SEQUENCE, {AMI.SrcDstOp.getReg()}, {});
for (unsigned SubReg = 0; SubReg < SplitFactor; ++SubReg) {
RegSeq.addReg(SubRegs[SubReg]).addImm(SubRegIdxes[SubReg]);
}
Register SrcDstReg = AMI.SrcDstOp.getReg();
AMI.MemI.eraseFromParent();
return constrainSelectedInstRegOperands(*LoadLower, TII, TRI, RBI) &&
constrainSelectedInstRegOperands(*LoadHigher, TII, TRI, RBI) &&
RBI.constrainGenericRegister(SrcDstReg, *&AIE2P::FIFO1024RegClass,
MRI);
if (!RBI.constrainGenericRegister(
SrcDstReg, *(SrcDstTySize == 2048 ? RC2048 : RC1024), MRI))
return false;

handleSplitMemOperands(Instrs);
break;
}
default:
return false;
}

AMI.MemI.eraseFromParent();
return constrainInstRegOps(Instrs);
}

bool AIE2PInstructionSelector::selectG_LOAD(MachineInstr &I,
Expand Down Expand Up @@ -2384,12 +2447,13 @@ bool AIE2PInstructionSelector::selectG_AIE_LOAD_STORE(

LoadStoreOpcodes LSO =
getLoadStoreOpcode(AMI->MemI, MRI, RBI, AMI->ImmediateOffset);
auto StoreSize = MRI.getType(AMI->SrcDstOp.getReg()).getSizeInBits();
if (StoreSize == 1024) {
return select1024BitG_AIE_LOAD_STORE(I, LSO, *AMI, MRI);
}
MachineInstrBuilder NewInstr = MIB.buildInstr(LSO.ISelOpcode);

LLT SrcDstTy = MRI.getType(AMI->SrcDstOp.getReg());
auto SrcDstTySize = SrcDstTy.getSizeInBits();
if ((SrcDstTySize == 1024) || (SrcDstTySize == 2048))
return selectWideG_AIE_LOAD_STORE(I, LSO, *AMI, MRI);

MachineInstrBuilder NewInstr = MIB.buildInstr(LSO.ISelOpcode);
for (auto Def : AMI->MemI.defs())
NewInstr.addDef(Def.getReg());

Expand Down
Loading

0 comments on commit 6aac9bb

Please sign in to comment.