Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RISCV] Use masked pseudo peephole for reduction pseudos #71508

Merged
merged 1 commit into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVPseudos.td
Original file line number Diff line number Diff line change
Expand Up @@ -3213,7 +3213,8 @@ multiclass VPseudoTernaryWithTailPolicy<VReg RetClass,
defvar mx = MInfo.MX;
let isCommutable = Commutable in
def "_" # mx # "_E" # sew : VPseudoTernaryNoMaskWithPolicy<RetClass, Op1Class, Op2Class, Constraint>;
def "_" # mx # "_E" # sew # "_MASK" : VPseudoTernaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>;
def "_" # mx # "_E" # sew # "_MASK" : VPseudoTernaryMaskPolicy<RetClass, Op1Class, Op2Class, Constraint>,
RISCVMaskedPseudo<MaskIdx=3, MaskAffectsRes=true>;
}
}

Expand All @@ -3232,7 +3233,8 @@ multiclass VPseudoTernaryWithTailPolicyRoundingMode<VReg RetClass,
Op2Class, Constraint>;
def "_" # mx # "_E" # sew # "_MASK"
: VPseudoTernaryMaskPolicyRoundingMode<RetClass, Op1Class,
Op2Class, Constraint>;
Op2Class, Constraint>,
RISCVMaskedPseudo<MaskIdx=3, MaskAffectsRes=true>;
}
}

Expand Down
63 changes: 0 additions & 63 deletions llvm/lib/Target/RISCV/RISCVInstrInfoVVLPatterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -1381,16 +1381,6 @@ multiclass VPatReductionVL<SDNode vop, string instruction_name, bit is_float> {
foreach vti = !if(is_float, AllFloatVectors, AllIntegerVectors) in {
defvar vti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # vti.SEW # "M1");
let Predicates = GetVTypePredicates<vti>.Predicates in {
def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
(vti.Vector vti.RegClass:$rs1), VR:$rs2,
(vti.Mask true_mask), VLOpFrag,
(XLenVT timm:$policy))),
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
(vti_m1.Vector VR:$merge),
(vti.Vector vti.RegClass:$rs1),
(vti_m1.Vector VR:$rs2),
GPR:$vl, vti.Log2SEW, (XLenVT timm:$policy))>;

def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
(vti.Vector vti.RegClass:$rs1), VR:$rs2,
(vti.Mask V0), VLOpFrag,
Expand All @@ -1408,19 +1398,6 @@ multiclass VPatReductionVL_RM<SDNode vop, string instruction_name, bit is_float>
foreach vti = !if(is_float, AllFloatVectors, AllIntegerVectors) in {
defvar vti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # vti.SEW # "M1");
let Predicates = GetVTypePredicates<vti>.Predicates in {
def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
(vti.Vector vti.RegClass:$rs1), VR:$rs2,
(vti.Mask true_mask), VLOpFrag,
(XLenVT timm:$policy))),
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
(vti_m1.Vector VR:$merge),
(vti.Vector vti.RegClass:$rs1),
(vti_m1.Vector VR:$rs2),
// Value to indicate no rounding mode change in
// RISCVInsertReadWriteCSR
FRM_DYN,
GPR:$vl, vti.Log2SEW, (XLenVT timm:$policy))>;

def: Pat<(vti_m1.Vector (vop (vti_m1.Vector VR:$merge),
(vti.Vector vti.RegClass:$rs1), VR:$rs2,
(vti.Mask V0), VLOpFrag,
Expand Down Expand Up @@ -1486,14 +1463,6 @@ multiclass VPatWidenReductionVL<SDNode vop, PatFrags extop, string instruction_n
defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
GetVTypePredicates<wti>.Predicates) in {
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
VR:$rs2, (vti.Mask true_mask), VLOpFrag,
(XLenVT timm:$policy))),
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
(wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
(wti_m1.Vector VR:$rs2), GPR:$vl, vti.Log2SEW,
(XLenVT timm:$policy))>;
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
VR:$rs2, (vti.Mask V0), VLOpFrag,
Expand All @@ -1513,18 +1482,6 @@ multiclass VPatWidenReductionVL_RM<SDNode vop, PatFrags extop, string instructio
defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
GetVTypePredicates<wti>.Predicates) in {
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
VR:$rs2, (vti.Mask true_mask), VLOpFrag,
(XLenVT timm:$policy))),
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
(wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
(wti_m1.Vector VR:$rs2),
// Value to indicate no rounding mode change in
// RISCVInsertReadWriteCSR
FRM_DYN,
GPR:$vl, vti.Log2SEW,
(XLenVT timm:$policy))>;
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1))),
VR:$rs2, (vti.Mask V0), VLOpFrag,
Expand All @@ -1548,14 +1505,6 @@ multiclass VPatWidenReductionVL_Ext_VL<SDNode vop, PatFrags extop, string instru
defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
GetVTypePredicates<wti>.Predicates) in {
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
VR:$rs2, (vti.Mask true_mask), VLOpFrag,
(XLenVT timm:$policy))),
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
(wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
(wti_m1.Vector VR:$rs2), GPR:$vl, vti.Log2SEW,
(XLenVT timm:$policy))>;
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
VR:$rs2, (vti.Mask V0), VLOpFrag,
Expand All @@ -1575,18 +1524,6 @@ multiclass VPatWidenReductionVL_Ext_VL_RM<SDNode vop, PatFrags extop, string ins
defvar wti_m1 = !cast<VTypeInfo>(!if(is_float, "VF", "VI") # wti.SEW # "M1");
let Predicates = !listconcat(GetVTypePredicates<vti>.Predicates,
GetVTypePredicates<wti>.Predicates) in {
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
VR:$rs2, (vti.Mask true_mask), VLOpFrag,
(XLenVT timm:$policy))),
(!cast<Instruction>(instruction_name#"_VS_"#vti.LMul.MX#"_E"#vti.SEW)
(wti_m1.Vector VR:$merge), (vti.Vector vti.RegClass:$rs1),
(wti_m1.Vector VR:$rs2),
// Value to indicate no rounding mode change in
// RISCVInsertReadWriteCSR
FRM_DYN,
GPR:$vl, vti.Log2SEW,
(XLenVT timm:$policy))>;
def: Pat<(wti_m1.Vector (vop (wti_m1.Vector VR:$merge),
(wti.Vector (extop (vti.Vector vti.RegClass:$rs1), (vti.Mask true_mask), VLOpFrag)),
VR:$rs2, (vti.Mask V0), VLOpFrag,
Expand Down
14 changes: 4 additions & 10 deletions llvm/test/CodeGen/RISCV/rvv/rvv-peephole-vmerge-vops.ll
Original file line number Diff line number Diff line change
Expand Up @@ -1049,11 +1049,8 @@ define <vscale x 2 x float> @vfredusum(<vscale x 2 x float> %passthru, <vscale x
define <vscale x 2 x i32> @vredsum_allones_mask(<vscale x 2 x i32> %passthru, <vscale x 2 x i32> %x, <vscale x 2 x i32> %y, i64 %vl) {
; CHECK-LABEL: vredsum_allones_mask:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e32, m1, ta, ma
; CHECK-NEXT: vmv1r.v v11, v8
; CHECK-NEXT: vredsum.vs v11, v9, v10
; CHECK-NEXT: vsetvli zero, zero, e32, m1, tu, ma
; CHECK-NEXT: vmv.v.v v8, v11
; CHECK-NEXT: vsetvli zero, a0, e32, m1, tu, ma
; CHECK-NEXT: vredsum.vs v8, v9, v10
; CHECK-NEXT: ret
%splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0
%mask = shufflevector <vscale x 2 x i1> %splat, <vscale x 2 x i1> poison, <vscale x 2 x i32> zeroinitializer
Expand All @@ -1070,12 +1067,9 @@ define <vscale x 2 x i32> @vredsum_allones_mask(<vscale x 2 x i32> %passthru, <v
define <vscale x 2 x float> @vfredusum_allones_mask(<vscale x 2 x float> %passthru, <vscale x 2 x float> %x, <vscale x 2 x float> %y, i64 %vl) {
; CHECK-LABEL: vfredusum_allones_mask:
; CHECK: # %bb.0:
; CHECK-NEXT: vsetvli zero, a0, e32, m1, ta, ma
; CHECK-NEXT: vsetvli zero, a0, e32, m1, tu, ma
; CHECK-NEXT: fsrmi a0, 0
; CHECK-NEXT: vmv1r.v v11, v8
; CHECK-NEXT: vfredusum.vs v11, v9, v10
; CHECK-NEXT: vsetvli zero, zero, e32, m1, tu, ma
; CHECK-NEXT: vmv.v.v v8, v11
; CHECK-NEXT: vfredusum.vs v8, v9, v10
; CHECK-NEXT: fsrm a0
; CHECK-NEXT: ret
%splat = insertelement <vscale x 2 x i1> poison, i1 -1, i32 0
Expand Down