[NVPTX] Improve lowering of v4i8 #67866

Conversation
Make it a legal type and plumb through lowering of relevant instructions.
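For context, a legal v4i8 value is carried in a single 32-bit register, with lane 0 in the least significant byte. A minimal sketch of that packing (the helper name packV4I8 is hypothetical, not part of the patch):

```cpp
#include <cstdint>
#include <cstdio>

// Hypothetical helper: packs four i8 lanes into the one .b32 register
// a legal v4i8 occupies, lane 0 in the least significant byte.
static uint32_t packV4I8(uint8_t e0, uint8_t e1, uint8_t e2, uint8_t e3) {
  return (uint32_t)e0 | ((uint32_t)e1 << 8) | ((uint32_t)e2 << 16) |
         ((uint32_t)e3 << 24);
}

int main() {
  printf("0x%08x\n", packV4I8(1, 2, 3, 4)); // prints 0x04030201
}
```

This mirrors what LowerBUILD_VECTOR in the patch folds a constant vector into: a single mov.b32 immediate.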
LGTM
Actually, trying out the patch in Triton causes some invalid PTX to be emitted (see comment).
Verified that the NVPTX tests pass, with ptxas able to compile the PTX produced by the llc tests.
To make things work consistently for v4i8, we need to implement other vector ops.
✅ With the latest revision this PR passed the C/C++ code formatter.

I still need to test it on live code, though we do not have much code that would end up using v4i8. The generated PTX checked in llvm/test/CodeGen/NVPTX/i8x4-instructions.ll could use extra scrutiny.

I ran the patch on our Triton kernels and I don't see any functional problems left.
Looks like it required quite a lot of cases to be handled :(
Thanks for doing this, it solves some of the problems Triton had with the latest LLVM. Changes look good to me.
I see one suspicious failure in TensorFlow tests. I suspect I've messed something up in the v4i8 comparison.
Yup, there is a problem:

Resolved by 9821e90
Found another issue. We merge four independent byte loads with
The clang-format failure on GitHub is weird -- it just silently exits with an error. The Buildkite failure somewhere in RISC-V appears to be unrelated.
I believe this may be causing failures for u/srem; see #69124. Edit: also causing failures when you sign-extend the result of a <4 x i8> comparison.
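For reference, element extract/insert on v4i8 is lowered to PTX bfe/bfi in this patch. A rough C++ model of those bit-field operations (a sketch for illustration only; it ignores PTX edge cases such as pos/len clamping, and the function names are mine, not LLVM's):

```cpp
#include <cstdint>

// Rough model of PTX bfe.u32: extract `len` bits starting at bit `pos`.
// For v4i8 lanes, pos = index * 8 and len = 8.
static uint32_t bfe_u32(uint32_t a, uint32_t pos, uint32_t len) {
  return (a >> pos) & ((1u << len) - 1);
}

// Rough model of PTX bfi.b32: insert the low `len` bits of `ins` into
// `base` at bit `pos`, leaving the other bits of `base` untouched.
static uint32_t bfi_b32(uint32_t ins, uint32_t base, uint32_t pos,
                        uint32_t len) {
  uint32_t mask = ((1u << len) - 1) << pos;
  return (base & ~mask) | ((ins << pos) & mask);
}

int main() {
  uint32_t v = 0x44332211;
  // extractelement <4 x i8> %v, i32 2        ->  bfe_u32(v, 16, 8)
  // insertelement  <4 x i8> %v, i8 0xAA, i32 1 -> bfi_b32(0xAA, v, 8, 8)
  return (bfe_u32(v, 16, 8) == 0x33 &&
          bfi_b32(0xAA, v, 8, 8) == 0x4433AA11) ? 0 : 1;
}
```

These correspond to the pos/len operands constructed in LowerEXTRACT_VECTOR_ELT and LowerINSERT_VECTOR_ELT in the diff below.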
In [[NVPTX] Improve lowering of v4i8](cbafb6f), @Artem-B added the ability to lower ISD::BUILD_VECTOR with bfi PTX instructions. @Artem-B did this because ([source](#67866 (comment))):

> Under the hood byte extraction/insertion ends up as BFI/BFE instructions, so we may as well do that in PTX, too. https://godbolt.org/z/Tb3zWbj9b

However, the example that @Artem-B linked was targeting sm_52. On modern architectures, ptxas uses prmt.b32 ([example](https://godbolt.org/z/Ye4W1n84o)). Thus, remove uses of NVPTXISD::BFI in favor of NVPTXISD::PRMT.
Fix [failure](#110766 (comment)) identified by @akuegel; otherwise the same change as above, removing uses of NVPTXISD::BFI in favor of NVPTXISD::PRMT.
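To make the BFI-versus-PRMT tradeoff concrete, here is a small C++ model of prmt.b32 in its default mode, following the PTX ISA description (a sketch under that assumption, not ptxas-verified):

```cpp
#include <cstdint>
#include <cstdio>

// Model of prmt.b32 (default mode): the 64-bit source is {b, a}, with
// a holding bytes 0..3 and b holding bytes 4..7. Selector nibble i
// picks the source byte for result byte i; if the nibble's high bit
// (0x8) is set, the selected byte is replaced by its replicated sign.
static uint32_t prmt(uint32_t a, uint32_t b, uint32_t selector) {
  uint64_t src = ((uint64_t)b << 32) | a;
  uint32_t result = 0;
  for (int i = 0; i < 4; ++i) {
    uint32_t sel = (selector >> (4 * i)) & 0xF;
    uint8_t byte = (uint8_t)(src >> (8 * (sel & 7)));
    if (sel & 8) // sign-replication mode for this byte
      byte = (byte & 0x80) ? 0xFF : 0x00;
    result |= (uint32_t)byte << (8 * i);
  }
  return result;
}

int main() {
  // A v4i8 shuffle mask <3,2,1,0> becomes selector 0x0123 in
  // LowerVECTOR_SHUFFLE below (4 selector bits per lane).
  printf("0x%08x\n", prmt(0x44332211, 0, 0x0123)); // prints 0x11223344
}
```

A single prmt covers byte permutes of two 32-bit sources, which is why it can replace chains of BFI/BFE for shuffles and inserts.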
@llvm/pr-subscribers-backend-nvptx

Author: Artem Belevich (Artem-B)

Changes: Make it a legal type and plumb through lowering of relevant instructions.

Patch is 122.18 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/67866.diff

15 Files Affected:
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
index 5d27accdc198c..b7a20c351f5ff 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.cpp
@@ -309,3 +309,34 @@ void NVPTXInstPrinter::printProtoIdent(const MCInst *MI, int OpNum,
const MCSymbol &Sym = cast<MCSymbolRefExpr>(Expr)->getSymbol();
O << Sym.getName();
}
+
+void NVPTXInstPrinter::printPrmtMode(const MCInst *MI, int OpNum,
+ raw_ostream &O, const char *Modifier) {
+ const MCOperand &MO = MI->getOperand(OpNum);
+ int64_t Imm = MO.getImm();
+
+ switch (Imm) {
+ default:
+ return;
+ case NVPTX::PTXPrmtMode::NONE:
+ break;
+ case NVPTX::PTXPrmtMode::F4E:
+ O << ".f4e";
+ break;
+ case NVPTX::PTXPrmtMode::B4E:
+ O << ".b4e";
+ break;
+ case NVPTX::PTXPrmtMode::RC8:
+ O << ".rc8";
+ break;
+ case NVPTX::PTXPrmtMode::ECL:
+ O << ".ecl";
+ break;
+ case NVPTX::PTXPrmtMode::ECR:
+ O << ".ecr";
+ break;
+ case NVPTX::PTXPrmtMode::RC16:
+ O << ".rc16";
+ break;
+ }
+}
diff --git a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
index 49ad3f269229d..e6954f861cd10 100644
--- a/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
+++ b/llvm/lib/Target/NVPTX/MCTargetDesc/NVPTXInstPrinter.h
@@ -47,6 +47,8 @@ class NVPTXInstPrinter : public MCInstPrinter {
raw_ostream &O, const char *Modifier = nullptr);
void printProtoIdent(const MCInst *MI, int OpNum,
raw_ostream &O, const char *Modifier = nullptr);
+ void printPrmtMode(const MCInst *MI, int OpNum, raw_ostream &O,
+ const char *Modifier = nullptr);
};
}
diff --git a/llvm/lib/Target/NVPTX/NVPTX.h b/llvm/lib/Target/NVPTX/NVPTX.h
index 8dc68911fff0c..07ee34968b023 100644
--- a/llvm/lib/Target/NVPTX/NVPTX.h
+++ b/llvm/lib/Target/NVPTX/NVPTX.h
@@ -181,6 +181,18 @@ enum CmpMode {
FTZ_FLAG = 0x100
};
}
+
+namespace PTXPrmtMode {
+enum PrmtMode {
+ NONE,
+ F4E,
+ B4E,
+ RC8,
+ ECL,
+ ECR,
+ RC16,
+};
+}
}
void initializeNVPTXDAGToDAGISelPass(PassRegistry &);
} // namespace llvm
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
index 0aef2591c6e23..68391cdb6ff17 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp
@@ -14,6 +14,7 @@
#include "MCTargetDesc/NVPTXBaseInfo.h"
#include "NVPTXUtilities.h"
#include "llvm/Analysis/ValueTracking.h"
+#include "llvm/CodeGen/ISDOpcodes.h"
#include "llvm/IR/GlobalValue.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/IntrinsicsNVPTX.h"
@@ -829,6 +830,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
case MVT::v2f16:
case MVT::v2bf16:
case MVT::v2i16:
+ case MVT::v4i8:
return Opcode_i32;
case MVT::f32:
return Opcode_f32;
@@ -910,7 +912,8 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
// Vector Setting
unsigned vecType = NVPTX::PTXLdStInstCode::Scalar;
if (SimpleVT.isVector()) {
- assert(Isv2x16VT(LoadedVT) && "Unexpected vector type");
+ assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
+ "Unexpected vector type");
// v2f16/v2bf16/v2i16 is loaded using ld.b32
fromTypeWidth = 32;
}
@@ -1254,6 +1257,7 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
SDLoc DL(N);
SDNode *LD;
SDValue Base, Offset, Addr;
+ EVT OrigType = N->getValueType(0);
EVT EltVT = Mem->getMemoryVT();
unsigned NumElts = 1;
@@ -1261,12 +1265,15 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
NumElts = EltVT.getVectorNumElements();
EltVT = EltVT.getVectorElementType();
// vectors of 16bits type are loaded/stored as multiples of v2x16 elements.
- if ((EltVT == MVT::f16 && N->getValueType(0) == MVT::v2f16) ||
- (EltVT == MVT::bf16 && N->getValueType(0) == MVT::v2bf16) ||
- (EltVT == MVT::i16 && N->getValueType(0) == MVT::v2i16)) {
+ if ((EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
+ (EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
+ (EltVT == MVT::i16 && OrigType == MVT::v2i16)) {
assert(NumElts % 2 == 0 && "Vector must have even number of elements");
- EltVT = N->getValueType(0);
+ EltVT = OrigType;
NumElts /= 2;
+ } else if (OrigType == MVT::v4i8) {
+ EltVT = OrigType;
+ NumElts = 1;
}
}
@@ -1601,7 +1608,6 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
// concept of sign-/zero-extension, so emulate it here by adding an explicit
// CVT instruction. Ptxas should clean up any redundancies here.
- EVT OrigType = N->getValueType(0);
LoadSDNode *LdNode = dyn_cast<LoadSDNode>(N);
if (OrigType != EltVT &&
@@ -1679,7 +1685,8 @@ bool NVPTXDAGToDAGISel::tryStore(SDNode *N) {
MVT ScalarVT = SimpleVT.getScalarType();
unsigned toTypeWidth = ScalarVT.getSizeInBits();
if (SimpleVT.isVector()) {
- assert(Isv2x16VT(StoreVT) && "Unexpected vector type");
+ assert((Isv2x16VT(StoreVT) || StoreVT == MVT::v4i8) &&
+ "Unexpected vector type");
// v2x16 is stored using st.b32
toTypeWidth = 32;
}
diff --git a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
index b24aae4792ce6..36da2e7b40efa 100644
--- a/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
+++ b/llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
@@ -221,6 +221,11 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
llvm_unreachable("Unexpected type");
}
NumElts /= 2;
+ } else if (EltVT.getSimpleVT() == MVT::i8 &&
+ (NumElts % 4 == 0 || NumElts == 3)) {
+ // v*i8 are formally lowered as v4i8
+ EltVT = MVT::v4i8;
+ NumElts = (NumElts + 3) / 4;
}
for (unsigned j = 0; j != NumElts; ++j) {
ValueVTs.push_back(EltVT);
@@ -458,6 +463,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
addRegisterClass(MVT::i1, &NVPTX::Int1RegsRegClass);
addRegisterClass(MVT::i16, &NVPTX::Int16RegsRegClass);
addRegisterClass(MVT::v2i16, &NVPTX::Int32RegsRegClass);
+ addRegisterClass(MVT::v4i8, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::i32, &NVPTX::Int32RegsRegClass);
addRegisterClass(MVT::i64, &NVPTX::Int64RegsRegClass);
addRegisterClass(MVT::f32, &NVPTX::Float32RegsRegClass);
@@ -491,10 +497,26 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v2i16, Expand);
setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v2i16, Expand);
+ setOperationAction(ISD::BUILD_VECTOR, MVT::v4i8, Custom);
+ setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v4i8, Custom);
+ setOperationAction(ISD::INSERT_VECTOR_ELT, MVT::v4i8, Custom);
+ setOperationAction(ISD::VECTOR_SHUFFLE, MVT::v4i8, Custom);
+ // Only logical ops can be done on v4i8 directly, others must be done
+ // elementwise.
+ setOperationAction(
+ {ISD::ADD, ISD::MUL, ISD::ABS, ISD::SMIN,
+ ISD::SMAX, ISD::UMIN, ISD::UMAX, ISD::CTPOP,
+ ISD::CTLZ, ISD::ADD, ISD::SUB, ISD::MUL,
+ ISD::SHL, ISD::SREM, ISD::UREM, ISD::SDIV,
+ ISD::UDIV, ISD::SRA, ISD::SRL, ISD::MULHS,
+ ISD::MULHU, ISD::FP_TO_SINT, ISD::FP_TO_UINT, ISD::SINT_TO_FP,
+ ISD::UINT_TO_FP},
+ MVT::v4i8, Expand);
+
// Operations not directly supported by NVPTX.
- for (MVT VT :
- {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32, MVT::f64,
- MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::i32, MVT::i64}) {
+ for (MVT VT : {MVT::bf16, MVT::f16, MVT::v2bf16, MVT::v2f16, MVT::f32,
+ MVT::f64, MVT::i1, MVT::i8, MVT::i16, MVT::v2i16, MVT::v4i8,
+ MVT::i32, MVT::i64}) {
setOperationAction(ISD::SELECT_CC, VT, Expand);
setOperationAction(ISD::BR_CC, VT, Expand);
}
@@ -672,7 +694,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
// We have some custom DAG combine patterns for these nodes
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::FADD, ISD::MUL, ISD::SHL,
- ISD::SREM, ISD::UREM, ISD::EXTRACT_VECTOR_ELT});
+ ISD::SREM, ISD::UREM, ISD::EXTRACT_VECTOR_ELT,
+ ISD::VSELECT});
// setcc for f16x2 and bf16x2 needs special handling to prevent
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -881,6 +904,12 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
return "NVPTXISD::FUN_SHFR_CLAMP";
case NVPTXISD::IMAD:
return "NVPTXISD::IMAD";
+ case NVPTXISD::BFE:
+ return "NVPTXISD::BFE";
+ case NVPTXISD::BFI:
+ return "NVPTXISD::BFI";
+ case NVPTXISD::PRMT:
+ return "NVPTXISD::PRMT";
case NVPTXISD::SETP_F16X2:
return "NVPTXISD::SETP_F16X2";
case NVPTXISD::Dummy:
@@ -2150,58 +2179,98 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
}
-// We can init constant f16x2 with a single .b32 move. Normally it
+// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
// would get lowered as two constant loads and vector-packing move.
-// mov.b16 %h1, 0x4000;
-// mov.b16 %h2, 0x3C00;
-// mov.b32 %hh2, {%h2, %h1};
// Instead we want just a constant move:
-// mov.b32 %hh2, 0x40003C00
-//
-// This results in better SASS code with CUDA 7.x. Ptxas in CUDA 8.0
-// generates good SASS in both cases.
+// mov.b32 %r2, 0x40003C00
SDValue NVPTXTargetLowering::LowerBUILD_VECTOR(SDValue Op,
SelectionDAG &DAG) const {
EVT VT = Op->getValueType(0);
- if (!(Isv2x16VT(VT)))
+ if (!(Isv2x16VT(VT) || VT == MVT::v4i8))
return Op;
- APInt E0;
- APInt E1;
- if (VT == MVT::v2f16 || VT == MVT::v2bf16) {
- if (!(isa<ConstantFPSDNode>(Op->getOperand(0)) &&
- isa<ConstantFPSDNode>(Op->getOperand(1))))
- return Op;
-
- E0 = cast<ConstantFPSDNode>(Op->getOperand(0))
- ->getValueAPF()
- .bitcastToAPInt();
- E1 = cast<ConstantFPSDNode>(Op->getOperand(1))
- ->getValueAPF()
- .bitcastToAPInt();
- } else {
- assert(VT == MVT::v2i16);
- if (!(isa<ConstantSDNode>(Op->getOperand(0)) &&
- isa<ConstantSDNode>(Op->getOperand(1))))
- return Op;
- E0 = cast<ConstantSDNode>(Op->getOperand(0))->getAPIntValue();
- E1 = cast<ConstantSDNode>(Op->getOperand(1))->getAPIntValue();
+ SDLoc DL(Op);
+
+ if (!llvm::all_of(Op->ops(), [](SDValue Operand) {
+ return Operand->isUndef() || isa<ConstantSDNode>(Operand) ||
+ isa<ConstantFPSDNode>(Operand);
+ })) {
+ // Lower non-const v4i8 vector as byte-wise constructed i32, which allows us
+ // to optimize calculation of constant parts.
+ if (VT == MVT::v4i8) {
+ SDValue C8 = DAG.getConstant(8, DL, MVT::i32);
+ SDValue E01 = DAG.getNode(
+ NVPTXISD::BFI, DL, MVT::i32,
+ DAG.getAnyExtOrTrunc(Op->getOperand(1), DL, MVT::i32),
+ DAG.getAnyExtOrTrunc(Op->getOperand(0), DL, MVT::i32), C8, C8);
+ SDValue E012 =
+ DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
+ DAG.getAnyExtOrTrunc(Op->getOperand(2), DL, MVT::i32),
+ E01, DAG.getConstant(16, DL, MVT::i32), C8);
+ SDValue E0123 =
+ DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
+ DAG.getAnyExtOrTrunc(Op->getOperand(3), DL, MVT::i32),
+ E012, DAG.getConstant(24, DL, MVT::i32), C8);
+ return DAG.getNode(ISD::BITCAST, DL, VT, E0123);
+ }
+ return Op;
}
- SDValue Const =
- DAG.getConstant(E1.zext(32).shl(16) | E0.zext(32), SDLoc(Op), MVT::i32);
+
+ // Get value or the Nth operand as an APInt(32). Undef values treated as 0.
+ auto GetOperand = [](SDValue Op, int N) -> APInt {
+ const SDValue &Operand = Op->getOperand(N);
+ EVT VT = Op->getValueType(0);
+ if (Operand->isUndef())
+ return APInt(32, 0);
+ APInt Value;
+ if (VT == MVT::v2f16 || VT == MVT::v2bf16)
+ Value = cast<ConstantFPSDNode>(Operand)->getValueAPF().bitcastToAPInt();
+ else if (VT == MVT::v2i16 || VT == MVT::v4i8)
+ Value = cast<ConstantSDNode>(Operand)->getAPIntValue();
+ else
+ llvm_unreachable("Unsupported type");
+ // i8 values are carried around as i16, so we need to zero out upper bits,
+ // so they do not get in the way of combining individual byte values
+ if (VT == MVT::v4i8)
+ Value = Value.trunc(8);
+ return Value.zext(32);
+ };
+ APInt Value;
+ if (Isv2x16VT(VT)) {
+ Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(16);
+ } else if (VT == MVT::v4i8) {
+ Value = GetOperand(Op, 0) | GetOperand(Op, 1).shl(8) |
+ GetOperand(Op, 2).shl(16) | GetOperand(Op, 3).shl(24);
+ } else {
+ llvm_unreachable("Unsupported type");
+ }
+ SDValue Const = DAG.getConstant(Value, SDLoc(Op), MVT::i32);
return DAG.getNode(ISD::BITCAST, SDLoc(Op), Op->getValueType(0), Const);
}
SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
SelectionDAG &DAG) const {
SDValue Index = Op->getOperand(1);
+ SDValue Vector = Op->getOperand(0);
+ SDLoc DL(Op);
+ EVT VectorVT = Vector.getValueType();
+
+ if (VectorVT == MVT::v4i8) {
+ SDValue BFE =
+ DAG.getNode(NVPTXISD::BFE, DL, MVT::i32,
+ {Vector,
+ DAG.getNode(ISD::MUL, DL, MVT::i32,
+ DAG.getZExtOrTrunc(Index, DL, MVT::i32),
+ DAG.getConstant(8, DL, MVT::i32)),
+ DAG.getConstant(8, DL, MVT::i32)});
+ return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0));
+ }
+
// Constant index will be matched by tablegen.
if (isa<ConstantSDNode>(Index.getNode()))
return Op;
// Extract individual elements and select one of them.
- SDValue Vector = Op->getOperand(0);
- EVT VectorVT = Vector.getValueType();
assert(Isv2x16VT(VectorVT) && "Unexpected vector type.");
EVT EltVT = VectorVT.getVectorElementType();
@@ -2214,6 +2283,49 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
ISD::CondCode::SETEQ);
}
+SDValue NVPTXTargetLowering::LowerINSERT_VECTOR_ELT(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDValue Vector = Op->getOperand(0);
+ EVT VectorVT = Vector.getValueType();
+
+ if (VectorVT != MVT::v4i8)
+ return Op;
+ SDLoc DL(Op);
+ SDValue Value = Op->getOperand(1);
+ if (Value->isUndef())
+ return Vector;
+
+ SDValue Index = Op->getOperand(2);
+
+ SDValue BFI =
+ DAG.getNode(NVPTXISD::BFI, DL, MVT::i32,
+ {DAG.getZExtOrTrunc(Value, DL, MVT::i32), Vector,
+ DAG.getNode(ISD::MUL, DL, MVT::i32,
+ DAG.getZExtOrTrunc(Index, DL, MVT::i32),
+ DAG.getConstant(8, DL, MVT::i32)),
+ DAG.getConstant(8, DL, MVT::i32)});
+ return DAG.getNode(ISD::BITCAST, DL, Op->getValueType(0), BFI);
+}
+
+SDValue NVPTXTargetLowering::LowerVECTOR_SHUFFLE(SDValue Op,
+ SelectionDAG &DAG) const {
+ SDValue V1 = Op.getOperand(0);
+ EVT VectorVT = V1.getValueType();
+ if (VectorVT != MVT::v4i8 || Op.getValueType() != MVT::v4i8)
+ return Op;
+
+ // Lower shuffle to PRMT instruction.
+ const ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
+ SDValue V2 = Op.getOperand(1);
+ uint32_t Selector = 0;
+ for (auto I : llvm::enumerate(SVN->getMask()))
+ Selector |= (I.value() << (I.index() * 4));
+
+ SDLoc DL(Op);
+ return DAG.getNode(NVPTXISD::PRMT, DL, MVT::v4i8, V1, V2,
+ DAG.getConstant(Selector, DL, MVT::i32),
+ DAG.getConstant(NVPTX::PTXPrmtMode::NONE, DL, MVT::i32));
+}
/// LowerShiftRightParts - Lower SRL_PARTS, SRA_PARTS, which
/// 1) returns two i32 values and take a 2 x i32 value to shift plus a shift
/// amount, or
@@ -2464,6 +2576,10 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
return Op;
case ISD::EXTRACT_VECTOR_ELT:
return LowerEXTRACT_VECTOR_ELT(Op, DAG);
+ case ISD::INSERT_VECTOR_ELT:
+ return LowerINSERT_VECTOR_ELT(Op, DAG);
+ case ISD::VECTOR_SHUFFLE:
+ return LowerVECTOR_SHUFFLE(Op, DAG);
case ISD::CONCAT_VECTORS:
return LowerCONCAT_VECTORS(Op, DAG);
case ISD::STORE:
@@ -2578,9 +2694,10 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
if (Op.getValueType() == MVT::i1)
return LowerLOADi1(Op, DAG);
- // v2f16/v2bf16/v2i16 are legal, so we can't rely on legalizer to handle
+ // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
// unaligned loads and have to handle it here.
- if (Isv2x16VT(Op.getValueType())) {
+ EVT VT = Op.getValueType();
+ if (Isv2x16VT(VT) || VT == MVT::v4i8) {
LoadSDNode *Load = cast<LoadSDNode>(Op);
EVT MemVT = Load->getMemoryVT();
if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
@@ -2625,13 +2742,13 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
// v2f16 is legal, so we can't rely on legalizer to handle unaligned
// stores and have to handle it here.
- if (Isv2x16VT(VT) &&
+ if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
VT, *Store->getMemOperand()))
return expandUnalignedStore(Store, DAG);
// v2f16, v2bf16 and v2i16 don't need special handling.
- if (Isv2x16VT(VT))
+ if (Isv2x16VT(VT) || VT == MVT::v4i8)
return SDValue();
if (VT.isVector())
@@ -2903,7 +3020,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
EVT LoadVT = EltVT;
if (EltVT == MVT::i1)
LoadVT = MVT::i8;
- else if (Isv2x16VT(EltVT))
+ else if (Isv2x16VT(EltVT) || EltVT == MVT::v4i8)
// getLoad needs a vector type, but it can't handle
// vectors which contain v2f16 or v2bf16 elements. So we must load
// using i32 here and then bitcast back.
@@ -2929,7 +3046,7 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
if (EltVT == MVT::i1)
Elt = DAG.getNode(ISD::TRUNCATE, dl, MVT::i1, Elt);
// v2f16 was loaded as an i32. Now we must bitcast it back.
- else if (Isv2x16VT(EltVT))
+ else if (EltVT != LoadVT)
Elt = DAG.getNode(ISD::BITCAST, dl, EltVT, Elt);
// If a promoted integer type is used, truncate down to the original
@@ -4975,6 +5092,32 @@ static SDValue PerformANDCombine(SDNode *N,
}
SDValue AExt;
+
+ // Convert BFE-> truncate i16 -> and 255
+ // To just BFE-> truncate i16, as the value already has all the bits in the
+ // right places.
+ if (Val.getOpcode() == ISD::TRUNCATE) {
+ SDValue BFE = Val.getOperand(0);
+ if (BFE.getOpcode() != NVPTXISD::BFE)
+ return SDValue();
+
+ ConstantSDNode *BFEBits = dyn_cast<ConstantSDNode>(BFE.getOperand(0));
+ if (!BFEBits)
+ return SDValue();
+ uint64_t BFEBitsVal = BFEBits->getZExtValue();
+
+ ConstantSDNode *MaskCnst = dyn_cast<ConstantSDNode>(Mask);
+ if (!MaskCnst) {
+ // Not an AND with a constant
+ return SDValue();
+ }
+ uint64_t MaskVal = MaskCnst->getZExtValue();
+
+ if (MaskVal != (uint64_t(1) << BFEBitsVal) - 1)
+ return SDValue();
+ // If we get here, the AND is unnecessary. Just replace it with the trunc
+ DCI.CombineTo(N, Val, false);
+ }
// Generally, we will see zextload -> IMOV16rr -> ANY_EXTEND -> and
if (Val.getOpcode() == ISD::ANY_EXTEND) {
AExt = Val;
@@ -5254,13 +5397,15 @@ static SDValue PerformSETCCCombine(SDNode *N,
static SDValue PerformEXTRACTCombine(SDNode *N,
Tar...
[truncated]