Skip to content

Commit

Permalink
[RISCV] Update implies for subtarget feature. (llvm#75824)
Browse files Browse the repository at this point in the history
PR llvm#75576 and llvm#75735 update some implies in
llvm/lib/Support/RISCVISAInfo.cpp, but both of them miss the subtarget
feature part.
This patch still preserve predicate HasStdExtZfhOrZfhmin and
HasStdExtZhinxOrZhinxmin, since they could make error message more
readable. ( Users might not know that zfh implies zfhmin.)
  • Loading branch information
yetingk authored Dec 19, 2023
1 parent c587171 commit cdc0392
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 40 deletions.
22 changes: 11 additions & 11 deletions llvm/lib/Target/RISCV/RISCVFeatures.td
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,15 @@ def HasStdExtZfhmin : Predicate<"Subtarget->hasStdExtZfhmin()">,
def FeatureStdExtZfh
: SubtargetFeature<"zfh", "HasStdExtZfh", "true",
"'Zfh' (Half-Precision Floating-Point)",
[FeatureStdExtF]>;
[FeatureStdExtZfhmin]>;
def HasStdExtZfh : Predicate<"Subtarget->hasStdExtZfh()">,
AssemblerPredicate<(all_of FeatureStdExtZfh),
"'Zfh' (Half-Precision Floating-Point)">;
def NoStdExtZfh : Predicate<"!Subtarget->hasStdExtZfh()">;

def HasStdExtZfhOrZfhmin
: Predicate<"Subtarget->hasStdExtZfhOrZfhmin()">,
AssemblerPredicate<(any_of FeatureStdExtZfh, FeatureStdExtZfhmin),
: Predicate<"Subtarget->hasStdExtZfhmin()">,
AssemblerPredicate<(all_of FeatureStdExtZfhmin),
"'Zfh' (Half-Precision Floating-Point) or "
"'Zfhmin' (Half-Precision Floating-Point Minimal)">;

Expand Down Expand Up @@ -146,15 +146,15 @@ def HasStdExtZhinxmin : Predicate<"Subtarget->hasStdExtZhinxmin()">,
def FeatureStdExtZhinx
: SubtargetFeature<"zhinx", "HasStdExtZhinx", "true",
"'Zhinx' (Half Float in Integer)",
[FeatureStdExtZfinx]>;
[FeatureStdExtZhinxmin]>;
def HasStdExtZhinx : Predicate<"Subtarget->hasStdExtZhinx()">,
AssemblerPredicate<(all_of FeatureStdExtZhinx),
"'Zhinx' (Half Float in Integer)">;
def NoStdExtZhinx : Predicate<"!Subtarget->hasStdExtZhinx()">;

def HasStdExtZhinxOrZhinxmin
: Predicate<"Subtarget->hasStdExtZhinx() || Subtarget->hasStdExtZhinxmin()">,
AssemblerPredicate<(any_of FeatureStdExtZhinx, FeatureStdExtZhinxmin),
: Predicate<"Subtarget->hasStdExtZhinxmin()">,
AssemblerPredicate<(all_of FeatureStdExtZhinxmin),
"'Zhinx' (Half Float in Integer) or "
"'Zhinxmin' (Half Float in Integer Minimal)">;

Expand Down Expand Up @@ -487,16 +487,16 @@ def HasStdExtZvfbfwma : Predicate<"Subtarget->hasStdExtZvfbfwma()">,

def HasVInstructionsBF16 : Predicate<"Subtarget->hasVInstructionsBF16()">;

def FeatureStdExtZvfh
: SubtargetFeature<"zvfh", "HasStdExtZvfh", "true",
"'Zvfh' (Vector Half-Precision Floating-Point)",
[FeatureStdExtZve32f, FeatureStdExtZfhmin]>;

def FeatureStdExtZvfhmin
: SubtargetFeature<"zvfhmin", "HasStdExtZvfhmin", "true",
"'Zvfhmin' (Vector Half-Precision Floating-Point Minimal)",
[FeatureStdExtZve32f]>;

def FeatureStdExtZvfh
: SubtargetFeature<"zvfh", "HasStdExtZvfh", "true",
"'Zvfh' (Vector Half-Precision Floating-Point)",
[FeatureStdExtZvfhmin, FeatureStdExtZfhmin]>;

def HasVInstructionsF16 : Predicate<"Subtarget->hasVInstructionsF16()">;

def HasVInstructionsF16Minimal : Predicate<"Subtarget->hasVInstructionsF16Minimal()">,
Expand Down
3 changes: 1 addition & 2 deletions llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -915,8 +915,7 @@ void RISCVDAGToDAGISel::Select(SDNode *Node) {
Opc = RISCV::FMV_H_X;
break;
case MVT::f16:
Opc =
Subtarget->hasStdExtZhinxOrZhinxmin() ? RISCV::COPY : RISCV::FMV_H_X;
Opc = Subtarget->hasStdExtZhinxmin() ? RISCV::COPY : RISCV::FMV_H_X;
break;
case MVT::f32:
Opc = Subtarget->hasStdExtZfinx() ? RISCV::COPY : RISCV::FMV_W_X;
Expand Down
32 changes: 16 additions & 16 deletions llvm/lib/Target/RISCV/RISCVISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,15 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
if (Subtarget.is64Bit() && RV64LegalI32)
addRegisterClass(MVT::i32, &RISCV::GPRRegClass);

if (Subtarget.hasStdExtZfhOrZfhmin())
if (Subtarget.hasStdExtZfhmin())
addRegisterClass(MVT::f16, &RISCV::FPR16RegClass);
if (Subtarget.hasStdExtZfbfmin())
addRegisterClass(MVT::bf16, &RISCV::FPR16RegClass);
if (Subtarget.hasStdExtF())
addRegisterClass(MVT::f32, &RISCV::FPR32RegClass);
if (Subtarget.hasStdExtD())
addRegisterClass(MVT::f64, &RISCV::FPR64RegClass);
if (Subtarget.hasStdExtZhinxOrZhinxmin())
if (Subtarget.hasStdExtZhinxmin())
addRegisterClass(MVT::f16, &RISCV::GPRF16RegClass);
if (Subtarget.hasStdExtZfinx())
addRegisterClass(MVT::f32, &RISCV::GPRF32RegClass);
Expand Down Expand Up @@ -439,7 +439,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
ISD::FCEIL, ISD::FFLOOR, ISD::FTRUNC, ISD::FRINT, ISD::FROUND,
ISD::FROUNDEVEN};

if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin())
if (Subtarget.hasStdExtZfhminOrZhinxmin())
setOperationAction(ISD::BITCAST, MVT::i16, Custom);

static const unsigned ZfhminZfbfminPromoteOps[] = {
Expand Down Expand Up @@ -469,7 +469,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
setOperationAction(ISD::FCOPYSIGN, MVT::bf16, Expand);
}

if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) {
if (Subtarget.hasStdExtZfhminOrZhinxmin()) {
if (Subtarget.hasStdExtZfhOrZhinx()) {
setOperationAction(FPLegalNodeTypes, MVT::f16, Legal);
setOperationAction(FPRndMode, MVT::f16,
Expand Down Expand Up @@ -1322,7 +1322,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
// Custom-legalize bitcasts from fixed-length vectors to scalar types.
setOperationAction(ISD::BITCAST, {MVT::i8, MVT::i16, MVT::i32, MVT::i64},
Custom);
if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin())
if (Subtarget.hasStdExtZfhminOrZhinxmin())
setOperationAction(ISD::BITCAST, MVT::f16, Custom);
if (Subtarget.hasStdExtFOrZfinx())
setOperationAction(ISD::BITCAST, MVT::f32, Custom);
Expand Down Expand Up @@ -1388,7 +1388,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,

if (Subtarget.hasStdExtZbkb())
setTargetDAGCombine(ISD::BITREVERSE);
if (Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin())
if (Subtarget.hasStdExtZfhminOrZhinxmin())
setTargetDAGCombine(ISD::SIGN_EXTEND_INREG);
if (Subtarget.hasStdExtFOrZfinx())
setTargetDAGCombine({ISD::ZERO_EXTEND, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
Expand Down Expand Up @@ -2099,7 +2099,7 @@ bool RISCVTargetLowering::isFPImmLegal(const APFloat &Imm, EVT VT,
bool ForCodeSize) const {
bool IsLegalVT = false;
if (VT == MVT::f16)
IsLegalVT = Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin();
IsLegalVT = Subtarget.hasStdExtZfhminOrZhinxmin();
else if (VT == MVT::f32)
IsLegalVT = Subtarget.hasStdExtFOrZfinx();
else if (VT == MVT::f64)
Expand Down Expand Up @@ -2171,7 +2171,7 @@ MVT RISCVTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
// Use f32 to pass f16 if it is legal and Zfh/Zfhmin is not enabled.
// We might still end up using a GPR but that will be decided based on ABI.
if (VT == MVT::f16 && Subtarget.hasStdExtFOrZfinx() &&
!Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin())
!Subtarget.hasStdExtZfhminOrZhinxmin())
return MVT::f32;

MVT PartVT = TargetLowering::getRegisterTypeForCallingConv(Context, CC, VT);
Expand All @@ -2188,7 +2188,7 @@ unsigned RISCVTargetLowering::getNumRegistersForCallingConv(LLVMContext &Context
// Use f32 to pass f16 if it is legal and Zfh/Zfhmin is not enabled.
// We might still end up using a GPR but that will be decided based on ABI.
if (VT == MVT::f16 && Subtarget.hasStdExtFOrZfinx() &&
!Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin())
!Subtarget.hasStdExtZfhminOrZhinxmin())
return 1;

return TargetLowering::getNumRegistersForCallingConv(Context, CC, VT);
Expand Down Expand Up @@ -5761,7 +5761,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
EVT Op0VT = Op0.getValueType();
MVT XLenVT = Subtarget.getXLenVT();
if (VT == MVT::f16 && Op0VT == MVT::i16 &&
Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) {
Subtarget.hasStdExtZfhminOrZhinxmin()) {
SDValue NewOp0 = DAG.getNode(ISD::ANY_EXTEND, DL, XLenVT, Op0);
SDValue FPConv = DAG.getNode(RISCVISD::FMV_H_X, DL, MVT::f16, NewOp0);
return FPConv;
Expand Down Expand Up @@ -11527,11 +11527,11 @@ void RISCVTargetLowering::ReplaceNodeResults(SDNode *N,
EVT Op0VT = Op0.getValueType();
MVT XLenVT = Subtarget.getXLenVT();
if (VT == MVT::i16 && Op0VT == MVT::f16 &&
Subtarget.hasStdExtZfhOrZfhminOrZhinxOrZhinxmin()) {
Subtarget.hasStdExtZfhminOrZhinxmin()) {
SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0);
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv));
} else if (VT == MVT::i16 && Op0VT == MVT::bf16 &&
Subtarget.hasStdExtZfbfmin()) {
Subtarget.hasStdExtZfbfmin()) {
SDValue FPConv = DAG.getNode(RISCVISD::FMV_X_ANYEXTH, DL, XLenVT, Op0);
Results.push_back(DAG.getNode(ISD::TRUNCATE, DL, MVT::i16, FPConv));
} else if (VT == MVT::i32 && Op0VT == MVT::f32 && Subtarget.is64Bit() &&
Expand Down Expand Up @@ -18632,15 +18632,15 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
// TODO: Support fixed vectors up to XLen for P extension?
if (VT.isVector())
break;
if (VT == MVT::f16 && Subtarget.hasStdExtZhinxOrZhinxmin())
if (VT == MVT::f16 && Subtarget.hasStdExtZhinxmin())
return std::make_pair(0U, &RISCV::GPRF16RegClass);
if (VT == MVT::f32 && Subtarget.hasStdExtZfinx())
return std::make_pair(0U, &RISCV::GPRF32RegClass);
if (VT == MVT::f64 && Subtarget.hasStdExtZdinx() && !Subtarget.is64Bit())
return std::make_pair(0U, &RISCV::GPRPF64RegClass);
return std::make_pair(0U, &RISCV::GPRNoX0RegClass);
case 'f':
if (Subtarget.hasStdExtZfhOrZfhmin() && VT == MVT::f16)
if (Subtarget.hasStdExtZfhmin() && VT == MVT::f16)
return std::make_pair(0U, &RISCV::FPR16RegClass);
if (Subtarget.hasStdExtF() && VT == MVT::f32)
return std::make_pair(0U, &RISCV::FPR32RegClass);
Expand Down Expand Up @@ -18753,7 +18753,7 @@ RISCVTargetLowering::getRegForInlineAsmConstraint(const TargetRegisterInfo *TRI,
}
if (VT == MVT::f32 || VT == MVT::Other)
return std::make_pair(FReg, &RISCV::FPR32RegClass);
if (Subtarget.hasStdExtZfhOrZfhmin() && VT == MVT::f16) {
if (Subtarget.hasStdExtZfhmin() && VT == MVT::f16) {
unsigned RegNo = FReg - RISCV::F0_F;
unsigned HReg = RISCV::F0_H + RegNo;
return std::make_pair(HReg, &RISCV::FPR16RegClass);
Expand Down Expand Up @@ -19100,7 +19100,7 @@ bool RISCVTargetLowering::shouldConvertFpToSat(unsigned Op, EVT FPVT,

switch (FPVT.getSimpleVT().SimpleTy) {
case MVT::f16:
return Subtarget.hasStdExtZfhOrZfhmin();
return Subtarget.hasStdExtZfhmin();
case MVT::f32:
return Subtarget.hasStdExtF();
case MVT::f64:
Expand Down
14 changes: 4 additions & 10 deletions llvm/lib/Target/RISCV/RISCVSubtarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,12 @@ class RISCVSubtarget : public RISCVGenSubtargetInfo {
bool hasStdExtZvl() const { return ZvlLen != 0; }
bool hasStdExtFOrZfinx() const { return HasStdExtF || HasStdExtZfinx; }
bool hasStdExtDOrZdinx() const { return HasStdExtD || HasStdExtZdinx; }
bool hasStdExtZfhOrZfhmin() const { return HasStdExtZfh || HasStdExtZfhmin; }
bool hasStdExtZfhOrZhinx() const { return HasStdExtZfh || HasStdExtZhinx; }
bool hasStdExtZhinxOrZhinxmin() const {
return HasStdExtZhinx || HasStdExtZhinxmin;
}
bool hasStdExtZfhOrZfhminOrZhinxOrZhinxmin() const {
return hasStdExtZfhOrZfhmin() || hasStdExtZhinxOrZhinxmin();
bool hasStdExtZfhminOrZhinxmin() const {
return HasStdExtZfhmin || HasStdExtZhinxmin;
}
bool hasHalfFPLoadStoreMove() const {
return hasStdExtZfhOrZfhmin() || HasStdExtZfbfmin;
return HasStdExtZfhmin || HasStdExtZfbfmin;
}
bool is64Bit() const { return IsRV64; }
MVT getXLenVT() const {
Expand Down Expand Up @@ -201,9 +197,7 @@ class RISCVSubtarget : public RISCVGenSubtargetInfo {
// Vector codegen related methods.
bool hasVInstructions() const { return HasStdExtZve32x; }
bool hasVInstructionsI64() const { return HasStdExtZve64x; }
bool hasVInstructionsF16Minimal() const {
return HasStdExtZvfhmin || HasStdExtZvfh;
}
bool hasVInstructionsF16Minimal() const { return HasStdExtZvfhmin; }
bool hasVInstructionsF16() const { return HasStdExtZvfh; }
bool hasVInstructionsBF16() const { return HasStdExtZvfbfmin; }
bool hasVInstructionsF32() const { return HasStdExtZve32f; }
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,7 @@ class RISCVTTIImpl : public BasicTTIImplBase<RISCVTTIImpl> {
return RISCVRegisterClass::GPRRC;

Type *ScalarTy = Ty->getScalarType();
if ((ScalarTy->isHalfTy() && ST->hasStdExtZfhOrZfhmin()) ||
if ((ScalarTy->isHalfTy() && ST->hasStdExtZfhmin()) ||
(ScalarTy->isFloatTy() && ST->hasStdExtF()) ||
(ScalarTy->isDoubleTy() && ST->hasStdExtD())) {
return RISCVRegisterClass::FPRRC;
Expand Down

0 comments on commit cdc0392

Please sign in to comment.