Skip to content

Commit

Permalink
[llvm][AArch64][Assembly]: Add FP8FMA assembly and disassembly. (#70134)
Browse files Browse the repository at this point in the history
This patch adds the FP8FMA and SSVE_FP8FMA feature flags, together with
assembler/disassembler support for the following NEON and SVE2 instructions:
  * NEON: 
    - FMLALBlane
    - FMLALTlane
    - FMLALLBBlane
    - FMLALLBTlane
    - FMLALLTBlane
    - FMLALLTTlane
    - FMLALB
    - FMLALT
    - FMLALLB
    - FMLALLBT
    - FMLALLTB
    - FMLALLTT
  * SVE2:
    - FMLALB_ZZZI
    - FMLALT_ZZZI
    - FMLALB_ZZZ 
    - FMLALT_ZZZ 
    - FMLALLBB_ZZZI 
    - FMLALLBT_ZZZI 
    - FMLALLTB_ZZZI 
    - FMLALLTT_ZZZI 
    - FMLALLBB_ZZZ 
    - FMLALLBT_ZZZ 
    - FMLALLTB_ZZZ 
    - FMLALLTT_ZZZ

The instructions are implemented according to the following Arm documentation:
https://developer.arm.com/documentation/ddi0602/2023-09
  • Loading branch information
hassnaaHamdi authored Nov 1, 2023
1 parent 47d9fbc commit 6477b41
Show file tree
Hide file tree
Showing 22 changed files with 944 additions and 53 deletions.
4 changes: 4 additions & 0 deletions llvm/include/llvm/TargetParser/AArch64TargetParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ enum ArchExtKind : unsigned {
AEK_FPMR = 58, // FEAT_FPMR
AEK_FP8 = 59, // FEAT_FP8
AEK_FAMINMAX = 60, // FEAT_FAMINMAX
AEK_FP8FMA = 61, // FEAT_FP8FMA
AEK_SSVE_FP8FMA = 62, // FEAT_SSVE_FP8FMA
AEK_NUM_EXTENSIONS
};
using ExtensionBitset = Bitset<AEK_NUM_EXTENSIONS>;
Expand Down Expand Up @@ -273,6 +275,8 @@ inline constexpr ExtensionInfo Extensions[] = {
{"fpmr", AArch64::AEK_FPMR, "+fpmr", "-fpmr", FEAT_INIT, "", 0},
{"fp8", AArch64::AEK_FP8, "+fp8", "-fp8", FEAT_INIT, "+fpmr", 0},
{"faminmax", AArch64::AEK_FAMINMAX, "+faminmax", "-faminmax", FEAT_INIT, "", 0},
{"fp8fma", AArch64::AEK_FP8FMA, "+fp8fma", "-fp8fma", FEAT_INIT, "+fpmr", 0},
{"ssve-fp8fma", AArch64::AEK_SSVE_FP8FMA, "+ssve-fp8fma", "-ssve-fp8fma", FEAT_INIT, "+sme2", 0},
// Special cases
{"none", AArch64::AEK_NONE, {}, {}, FEAT_INIT, "", ExtensionInfo::MaxFMVPriority},
};
Expand Down
2 changes: 1 addition & 1 deletion llvm/include/llvm/TargetParser/SubtargetFeature.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ namespace llvm {
class raw_ostream;
class Triple;

const unsigned MAX_SUBTARGET_WORDS = 4;
const unsigned MAX_SUBTARGET_WORDS = 5;
const unsigned MAX_SUBTARGET_FEATURES = MAX_SUBTARGET_WORDS * 64;

/// Container class for subtarget features.
Expand Down
10 changes: 8 additions & 2 deletions llvm/lib/Target/AArch64/AArch64.td
Original file line number Diff line number Diff line change
Expand Up @@ -517,6 +517,12 @@ def FeatureSME2p1 : SubtargetFeature<"sme2p1", "HasSME2p1", "true",
def FeatureFAMINMAX: SubtargetFeature<"faminmax", "HasFAMINMAX", "true",
"Enable FAMIN and FAMAX instructions (FEAT_FAMINMAX)">;

// Subtarget feature "fp8fma": sets the HasFP8FMA subtarget field to "true".
// No implied features are listed for it.
def FeatureFP8FMA : SubtargetFeature<"fp8fma", "HasFP8FMA", "true",
"Enable fp8 multiply-add instructions (FEAT_FP8FMA)">;

// Subtarget feature "ssve-fp8fma": sets the HasSSVE_FP8FMA subtarget field to
// "true". Enabling it also enables FeatureSME2 (listed as an implied feature).
def FeatureSSVE_FP8FMA : SubtargetFeature<"ssve-fp8fma", "HasSSVE_FP8FMA", "true",
"Enable SVE2 fp8 multiply-add instructions (FEAT_SSVE_FP8FMA)", [FeatureSME2]>;

def FeatureAppleA7SysReg : SubtargetFeature<"apple-a7-sysreg", "HasAppleA7SysReg", "true",
"Apple A7 (the CPU formerly known as Cyclone)">;

Expand Down Expand Up @@ -747,7 +753,7 @@ let F = [HasSVE2p1, HasSVE2p1_or_HasSME2, HasSVE2p1_or_HasSME2p1] in
def SVE2p1Unsupported : AArch64Unsupported;

def SVE2Unsupported : AArch64Unsupported {
let F = !listconcat([HasSVE2, HasSVE2orSME,
let F = !listconcat([HasSVE2, HasSVE2orSME, HasSSVE_FP8FMA,
HasSVE2AES, HasSVE2SHA3, HasSVE2SM4, HasSVE2BitPerm],
SVE2p1Unsupported.F);
}
Expand All @@ -761,7 +767,7 @@ let F = [HasSME2p1, HasSVE2p1_or_HasSME2p1] in
def SME2p1Unsupported : AArch64Unsupported;

def SME2Unsupported : AArch64Unsupported {
let F = !listconcat([HasSME2, HasSVE2p1_or_HasSME2],
let F = !listconcat([HasSME2, HasSVE2p1_or_HasSME2, HasSSVE_FP8FMA],
SME2p1Unsupported.F);
}

Expand Down
34 changes: 34 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrFormats.td
Original file line number Diff line number Diff line change
Expand Up @@ -6055,6 +6055,15 @@ multiclass SIMDThreeSameVectorFML<bit U, bit b13, bits<3> size, string asm,
v4f32, v8f16, OpNode>;
}

// Three-register vector multiply-add with a ".8h" destination and ".16b"
// sources, built on BaseSIMDThreeSameVectorDot. Q and the mnemonic are supplied
// by the instantiation; null_frag is passed as the operator, so no selection
// pattern is provided for the instruction.
multiclass SIMDThreeSameVectorMLA<bit Q, string asm>{
def v8f16 : BaseSIMDThreeSameVectorDot<Q, 0b0, 0b11, 0b1111, asm, ".8h", ".16b",
V128, v8f16, v16i8, null_frag>;
}

// Three-register vector multiply-add with a ".4s" destination and ".16b"
// sources, built on BaseSIMDThreeSameVectorDot. Q, the two-bit sz value and the
// mnemonic are supplied by the instantiation; null_frag is passed as the
// operator, so no selection pattern is provided for the instruction.
multiclass SIMDThreeSameVectorMLAL<bit Q, bits<2> sz, string asm>{
def v4f32 : BaseSIMDThreeSameVectorDot<Q, 0b0, sz, 0b1000, asm, ".4s", ".16b",
V128, v4f32, v16i8, null_frag>;
}

// FP8 assembly/disassembly classes

Expand Down Expand Up @@ -8521,6 +8530,31 @@ class BF16ToSinglePrecision<string asm>
}
} // End of let mayStore = 0, mayLoad = 0, hasSideEffects = 0

//----------------------------------------------------------------------------
// Base class for the by-element (indexed) forms that select a single byte lane
// of the last source operand. It derives from BaseSIMDIndexedTied, passing
// RegType twice, RegType_lo, and VectorIndexB as the index operand; the source
// arrangement is fixed to ".16b", the lane suffix to ".b", and the pattern list
// is empty.
class BaseSIMDThreeSameVectorIndexB<bit Q, bit U, bits<2> sz, bits<4> opc,
string asm, string dst_kind,
RegisterOperand RegType,
RegisterOperand RegType_lo>
: BaseSIMDIndexedTied<Q, U, 0b0, sz, opc,
RegType, RegType, RegType_lo, VectorIndexB,
asm, "", dst_kind, ".16b", ".b", []> {

// 4-bit lane index: the top bit is encoded in Inst{11} and the low three bits
// in Inst{21-19}.
// NOTE(review): this was previously annotated "idx = H:L:M", which names only
// three fields for a four-bit index - confirm the field names against the
// architecture encoding.
bits<4> idx;
let Inst{11} = idx{3};
let Inst{21-19} = idx{2-0};
}

// By-element multiply-add with a ".8h" destination. Q and the mnemonic come
// from the instantiation; the indexed source operand uses V128_0to7.
multiclass SIMDThreeSameVectorMLAIndex<bit Q, string asm> {
def v8f16 : BaseSIMDThreeSameVectorIndexB<Q, 0b0, 0b11, 0b0000, asm, ".8h",
V128, V128_0to7>;
}

// By-element multiply-add with a ".4s" destination. Q, the two-bit sz value and
// the mnemonic come from the instantiation; the indexed source operand uses
// V128_0to7.
multiclass SIMDThreeSameVectorMLALIndex<bit Q, bits<2> sz, string asm> {
def v4f32 : BaseSIMDThreeSameVectorIndexB<Q, 0b1, sz, 0b1000, asm, ".4s",
V128, V128_0to7>;
}

//----------------------------------------------------------------------------
// Armv8.6 Matrix Multiply Extension
//----------------------------------------------------------------------------
Expand Down
22 changes: 22 additions & 0 deletions llvm/lib/Target/AArch64/AArch64InstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ def HasFP8 : Predicate<"Subtarget->hasFP8()">,
AssemblerPredicateWithAll<(all_of FeatureFP8), "fp8">;
def HasFAMINMAX : Predicate<"Subtarget->hasFAMINMAX()">,
AssemblerPredicateWithAll<(all_of FeatureFAMINMAX), "faminmax">;
// Predicate for instructions that require FeatureFP8FMA; the assembler reports
// the requirement as "fp8fma".
def HasFP8FMA : Predicate<"Subtarget->hasFP8FMA()">,
AssemblerPredicateWithAll<(all_of FeatureFP8FMA), "fp8fma">;
// Predicate for the SVE2 FP8 multiply-add instructions. It is satisfied either
// by FeatureSSVE_FP8FMA alone, or by FeatureSVE2 together with FeatureFP8FMA.
// The C++ condition queries hasSSVE_FP8FMA(), matching the has<Feature>()
// accessor naming used by the other predicates (hasFP8FMA(), hasSVE2()).
def HasSSVE_FP8FMA : Predicate<"Subtarget->hasSSVE_FP8FMA() || "
"(Subtarget->hasSVE2() && Subtarget->hasFP8FMA())">,
AssemblerPredicateWithAll<(any_of FeatureSSVE_FP8FMA,
(all_of FeatureSVE2, FeatureFP8FMA)),
"ssve-fp8fma or (sve2 and fp8fma)">;

// A subset of SVE(2) instructions are legal in Streaming SVE execution mode,
// they should be enabled if either has been specified.
Expand Down Expand Up @@ -9286,6 +9293,21 @@ let Predicates = [HasFAMINMAX] in {
defm FAMIN : SIMDThreeSameVectorFP<0b1, 0b1, 0b011, "famin", null_frag>;
} // End let Predicates = [HasFAMINMAX]

// NEON FP8 multiply-add instructions, all gated on HasFP8FMA.
let Predicates = [HasFP8FMA] in {
// By-element (indexed) forms. The first template argument is Q; the *LL*
// variants additionally pass the two-bit sz value.
defm FMLALBlane : SIMDThreeSameVectorMLAIndex<0b0, "fmlalb">;
defm FMLALTlane : SIMDThreeSameVectorMLAIndex<0b1, "fmlalt">;
defm FMLALLBBlane : SIMDThreeSameVectorMLALIndex<0b0, 0b00, "fmlallbb">;
defm FMLALLBTlane : SIMDThreeSameVectorMLALIndex<0b0, 0b01, "fmlallbt">;
defm FMLALLTBlane : SIMDThreeSameVectorMLALIndex<0b1, 0b00, "fmlalltb">;
defm FMLALLTTlane : SIMDThreeSameVectorMLALIndex<0b1, 0b01, "fmlalltt">;

// Vector (non-indexed) forms, using the same Q / sz assignments as above.
defm FMLALB : SIMDThreeSameVectorMLA<0b0, "fmlalb">;
defm FMLALT : SIMDThreeSameVectorMLA<0b1, "fmlalt">;
defm FMLALLBB : SIMDThreeSameVectorMLAL<0b0, 0b00, "fmlallbb">;
defm FMLALLBT : SIMDThreeSameVectorMLAL<0b0, 0b01, "fmlallbt">;
defm FMLALLTB : SIMDThreeSameVectorMLAL<0b1, 0b00, "fmlalltb">;
defm FMLALLTT : SIMDThreeSameVectorMLAL<0b1, 0b01, "fmlalltt">;
} // End let Predicates = [HasFP8FMA]

include "AArch64InstrAtomics.td"
include "AArch64SVEInstrInfo.td"
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/AArch64/AArch64RegisterInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -981,6 +981,8 @@ unsigned AArch64RegisterInfo::getRegPressureLimit(const TargetRegisterClass *RC,
case AArch64::FPR64_loRegClassID:
case AArch64::FPR16_loRegClassID:
return 16;
case AArch64::FPR128_0to7RegClassID:
return 8;
}
}

Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Target/AArch64/AArch64RegisterInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,13 @@ def FPR128_lo : RegisterClass<"AArch64",
v8bf16],
128, (trunc FPR128, 16)>;

// The lower 8 vector registers: the first 8 entries of FPR128, obtained with
// (trunc FPR128, 8). Some instructions can only take registers in this range.
def FPR128_0to7 : RegisterClass<"AArch64",
[v16i8, v8i16, v4i32, v2i64, v4f32, v2f64, v8f16,
v8bf16],
128, (trunc FPR128, 8)>;

// Pairs, triples, and quads of 64-bit vector registers.
def DSeqPairs : RegisterTuples<[dsub0, dsub1], [(rotl FPR64, 0), (rotl FPR64, 1)]>;
def DSeqTriples : RegisterTuples<[dsub0, dsub1, dsub2],
Expand Down Expand Up @@ -534,6 +541,15 @@ def V128_lo : RegisterOperand<FPR128_lo, "printVRegOperand"> {
let ParserMatchClass = VectorRegLoAsmOperand;
}

// Assembly operand class named "VectorReg0to7". Parsed operands are accepted
// for this class when the operand's isNeonVectorReg0to7 predicate method
// returns true.
def VectorReg0to7AsmOperand : AsmOperandClass {
let Name = "VectorReg0to7";
let PredicateMethod = "isNeonVectorReg0to7";
}

// Register operand over the FPR128_0to7 class, printed with printVRegOperand
// and matched by the assembler through VectorReg0to7AsmOperand.
def V128_0to7 : RegisterOperand<FPR128_0to7, "printVRegOperand"> {
let ParserMatchClass = VectorReg0to7AsmOperand;
}

class TypedVecListAsmOperand<int count, string vecty, int lanes, int eltsize>
: AsmOperandClass {
let Name = "TypedVectorList" # count # "_" # lanes # eltsize;
Expand Down
19 changes: 19 additions & 0 deletions llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -4045,3 +4045,22 @@ let Predicates = [HasSVE2orSME2, HasFAMINMAX] in {
defm FAMIN_ZPmZ : sve_fp_2op_p_zds<0b1111, "famin", "", null_frag, DestructiveOther>;
defm FAMAX_ZPmZ : sve_fp_2op_p_zds<0b1110, "famax", "", null_frag, DestructiveOther>;
} // End HasSVE2orSME2, HasFAMINMAX

// SVE FP8 widening multiply-add instructions, all gated on HasSSVE_FP8FMA.
// The "long" forms are instantiated with ZPR16 and the "long long" forms with
// ZPR32.
let Predicates = [HasSSVE_FP8FMA] in {
// FP8 Widening Multiply-Add Long - Indexed Group
def FMLALB_ZZZI : sve2_fp8_mla_long_by_indexed_elem<0b0, "fmlalb">;
def FMLALT_ZZZI : sve2_fp8_mla_long_by_indexed_elem<0b1, "fmlalt">;
// FP8 Widening Multiply-Add Long Group
def FMLALB_ZZZ : sve2_fp8_mla<0b100, ZPR16, "fmlalb">;
def FMLALT_ZZZ : sve2_fp8_mla<0b101, ZPR16, "fmlalt">;
// FP8 Widening Multiply-Add Long Long - Indexed Group
def FMLALLBB_ZZZI : sve2_fp8_mla_long_long_by_indexed_elem<0b00, "fmlallbb">;
def FMLALLBT_ZZZI : sve2_fp8_mla_long_long_by_indexed_elem<0b01, "fmlallbt">;
def FMLALLTB_ZZZI : sve2_fp8_mla_long_long_by_indexed_elem<0b10, "fmlalltb">;
def FMLALLTT_ZZZI : sve2_fp8_mla_long_long_by_indexed_elem<0b11, "fmlalltt">;
// FP8 Widening Multiply-Add Long Long Group
def FMLALLBB_ZZZ : sve2_fp8_mla<0b000, ZPR32, "fmlallbb">;
def FMLALLBT_ZZZ : sve2_fp8_mla<0b001, ZPR32, "fmlallbt">;
def FMLALLTB_ZZZ : sve2_fp8_mla<0b010, ZPR32, "fmlalltb">;
def FMLALLTT_ZZZ : sve2_fp8_mla<0b011, ZPR32, "fmlalltt">;
} // End HasSSVE_FP8FMA
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64SchedA64FX.td
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def A64FXModel : SchedMachineModel {
list<Predicate> UnsupportedFeatures =
[HasSVE2, HasSVE2AES, HasSVE2SM4, HasSVE2SHA3, HasSVE2BitPerm, HasPAuth,
HasSVE2orSME, HasMTE, HasMatMulInt8, HasBF16, HasSME2, HasSME2p1, HasSVE2p1,
HasSVE2p1_or_HasSME2p1, HasSMEF16F16];
HasSVE2p1_or_HasSME2p1, HasSMEF16F16, HasSSVE_FP8FMA];

let FullInstRWOverlapCheck = 0;
}
Expand Down
63 changes: 38 additions & 25 deletions llvm/lib/Target/AArch64/AsmParser/AArch64AsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1223,6 +1223,12 @@ class AArch64Operand : public MCParsedAsmOperand {
Reg.RegNum));
}

/// Returns true when this operand is a register of kind NeonVector whose
/// register number is a member of the FPR128_0to7 register class.
bool isNeonVectorReg0to7() const {
  // Reject anything that is not a register operand before touching Reg.
  if (Kind != k_Register)
    return false;
  if (Reg.Kind != RegKind::NeonVector)
    return false;
  const auto &Low8Class =
      AArch64MCRegisterClasses[AArch64::FPR128_0to7RegClassID];
  return Low8Class.contains(Reg.RegNum);
}

bool isMatrix() const { return Kind == k_MatrixRegister; }
bool isMatrixTileList() const { return Kind == k_MatrixTileList; }

Expand Down Expand Up @@ -1766,6 +1772,11 @@ class AArch64Operand : public MCParsedAsmOperand {
Inst.addOperand(MCOperand::createReg(getReg()));
}

/// Appends this operand's register to Inst as a single register operand.
/// N is the operand count requested by the matcher and must be 1.
void addVectorReg0to7Operands(MCInst &Inst, unsigned N) const {
  assert(N == 1 && "Invalid number of operands!");
  const MCOperand RegOperand = MCOperand::createReg(getReg());
  Inst.addOperand(RegOperand);
}

enum VecListIndexType {
VecListIdx_DReg = 0,
VecListIdx_QReg = 1,
Expand Down Expand Up @@ -2598,31 +2609,31 @@ static std::optional<std::pair<int, int>> parseVectorKind(StringRef Suffix,

switch (VectorKind) {
case RegKind::NeonVector:
Res =
StringSwitch<std::pair<int, int>>(Suffix.lower())
.Case("", {0, 0})
.Case(".1d", {1, 64})
.Case(".1q", {1, 128})
// '.2h' needed for fp16 scalar pairwise reductions
.Case(".2h", {2, 16})
.Case(".2s", {2, 32})
.Case(".2d", {2, 64})
// '.4b' is another special case for the ARMv8.2a dot product
// operand
.Case(".4b", {4, 8})
.Case(".4h", {4, 16})
.Case(".4s", {4, 32})
.Case(".8b", {8, 8})
.Case(".8h", {8, 16})
.Case(".16b", {16, 8})
// Accept the width neutral ones, too, for verbose syntax. If those
// aren't used in the right places, the token operand won't match so
// all will work out.
.Case(".b", {0, 8})
.Case(".h", {0, 16})
.Case(".s", {0, 32})
.Case(".d", {0, 64})
.Default({-1, -1});
Res = StringSwitch<std::pair<int, int>>(Suffix.lower())
.Case("", {0, 0})
.Case(".1d", {1, 64})
.Case(".1q", {1, 128})
// '.2h' needed for fp16 scalar pairwise reductions
.Case(".2h", {2, 16})
.Case(".2b", {2, 8})
.Case(".2s", {2, 32})
.Case(".2d", {2, 64})
// '.4b' is another special case for the ARMv8.2a dot product
// operand
.Case(".4b", {4, 8})
.Case(".4h", {4, 16})
.Case(".4s", {4, 32})
.Case(".8b", {8, 8})
.Case(".8h", {8, 16})
.Case(".16b", {16, 8})
// Accept the width neutral ones, too, for verbose syntax. If
// those aren't used in the right places, the token operand won't
// match so all will work out.
.Case(".b", {0, 8})
.Case(".h", {0, 16})
.Case(".s", {0, 32})
.Case(".d", {0, 64})
.Default({-1, -1});
break;
case RegKind::SVEPredicateAsCounter:
case RegKind::SVEPredicateVector:
Expand Down Expand Up @@ -3641,6 +3652,8 @@ static const struct Extension {
{"fpmr", {AArch64::FeatureFPMR}},
{"fp8", {AArch64::FeatureFP8}},
{"faminmax", {AArch64::FeatureFAMINMAX}},
{"fp8fma", {AArch64::FeatureFP8FMA}},
{"ssve-fp8fma", {AArch64::FeatureSSVE_FP8FMA}},
};

static void setRequiredFeatureString(FeatureBitset FBS, std::string &Str) {
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/AArch64/Disassembler/AArch64Disassembler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ static DecodeStatus DecodeFPR128RegisterClass(MCInst &Inst, unsigned RegNo,
static DecodeStatus DecodeFPR128_loRegisterClass(MCInst &Inst, unsigned RegNo,
uint64_t Address,
const MCDisassembler *Decoder);
// Decoder hook for register operands of the FPR128_0to7 class.
static DecodeStatus
DecodeFPR128_0to7RegisterClass(MCInst &Inst, unsigned RegNo, uint64_t Address,
const MCDisassembler *Decoder);
static DecodeStatus DecodeFPR64RegisterClass(MCInst &Inst, unsigned RegNo,
uint64_t Address,
const MCDisassembler *Decoder);
Expand Down Expand Up @@ -437,6 +440,14 @@ DecodeFPR128_loRegisterClass(MCInst &Inst, unsigned RegNo, uint64_t Addr,
return DecodeFPR128RegisterClass(Inst, RegNo, Addr, Decoder);
}

/// Decodes a register operand of the FPR128_0to7 class. Encodings above 7 are
/// rejected with Fail; all others are forwarded to DecodeFPR128RegisterClass.
static DecodeStatus
DecodeFPR128_0to7RegisterClass(MCInst &Inst, unsigned RegNo, uint64_t Addr,
                               const MCDisassembler *Decoder) {
  const bool InRange = RegNo <= 7;
  if (InRange)
    return DecodeFPR128RegisterClass(Inst, RegNo, Addr, Decoder);
  return Fail;
}

static DecodeStatus DecodeFPR64RegisterClass(MCInst &Inst, unsigned RegNo,
uint64_t Addr,
const MCDisassembler *Decoder) {
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/AArch64/GISel/AArch64RegisterBankInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,10 @@ AArch64RegisterBankInfo::getRegBankFromRegClass(const TargetRegisterClass &RC,
case AArch64::FPR32_with_hsub_in_FPR16_loRegClassID:
case AArch64::FPR32RegClassID:
case AArch64::FPR64RegClassID:
case AArch64::FPR64_loRegClassID:
case AArch64::FPR128RegClassID:
case AArch64::FPR64_loRegClassID:
case AArch64::FPR128_loRegClassID:
case AArch64::FPR128_0to7RegClassID:
case AArch64::DDRegClassID:
case AArch64::DDDRegClassID:
case AArch64::DDDDRegClassID:
Expand Down
Loading

0 comments on commit 6477b41

Please sign in to comment.