From f9088fd345082b8cf3941f86b7cab7ca9eb81910 Mon Sep 17 00:00:00 2001 From: Tanner Gooding Date: Fri, 6 Dec 2024 23:09:56 -0800 Subject: [PATCH] Ensure that TYP_SIMD16 and TYP_SIMD32 don't force mask usage unnecessarily --- src/coreclr/jit/hwintrinsic.h | 11 +- src/coreclr/jit/hwintrinsicxarch.cpp | 102 +-------- src/coreclr/jit/lowerxarch.cpp | 331 +++++++++++++++++++++++++++ 3 files changed, 344 insertions(+), 100 deletions(-) diff --git a/src/coreclr/jit/hwintrinsic.h b/src/coreclr/jit/hwintrinsic.h index d8bf386eb6009d..48c7eec6691125 100644 --- a/src/coreclr/jit/hwintrinsic.h +++ b/src/coreclr/jit/hwintrinsic.h @@ -546,12 +546,11 @@ struct HWIntrinsicInfo static bool isScalarIsa(CORINFO_InstructionSet isa); #ifdef TARGET_XARCH - static bool isAVX2GatherIntrinsic(NamedIntrinsic id); - static FloatComparisonMode lookupFloatComparisonModeForSwappedArgs(FloatComparisonMode comparison); - static NamedIntrinsic lookupIdForFloatComparisonMode(NamedIntrinsic intrinsic, - FloatComparisonMode comparison, - var_types simdBaseType, - unsigned simdSize); + static bool isAVX2GatherIntrinsic(NamedIntrinsic id); + static NamedIntrinsic lookupIdForFloatComparisonMode(NamedIntrinsic intrinsic, + FloatComparisonMode comparison, + var_types simdBaseType, + unsigned simdSize); #endif // Member lookup diff --git a/src/coreclr/jit/hwintrinsicxarch.cpp b/src/coreclr/jit/hwintrinsicxarch.cpp index 94574884a16e42..80aa199285184d 100644 --- a/src/coreclr/jit/hwintrinsicxarch.cpp +++ b/src/coreclr/jit/hwintrinsicxarch.cpp @@ -473,95 +473,6 @@ bool HWIntrinsicInfo::isAVX2GatherIntrinsic(NamedIntrinsic id) } } -//------------------------------------------------------------------------ -// lookupFloatComparisonModeForSwappedArgs: Get the floating-point comparison -// mode to use when the operands are swapped. 
-// -// Arguments: -// comparison -- The comparison mode used for (op1, op2) -// -// Return Value: -// The comparison mode to use for (op2, op1) -// -FloatComparisonMode HWIntrinsicInfo::lookupFloatComparisonModeForSwappedArgs(FloatComparisonMode comparison) -{ - switch (comparison) - { - // These comparison modes are the same even if the operands are swapped - - case FloatComparisonMode::OrderedEqualNonSignaling: - return FloatComparisonMode::OrderedEqualNonSignaling; - case FloatComparisonMode::UnorderedNonSignaling: - return FloatComparisonMode::UnorderedNonSignaling; - case FloatComparisonMode::UnorderedNotEqualNonSignaling: - return FloatComparisonMode::UnorderedNotEqualNonSignaling; - case FloatComparisonMode::OrderedNonSignaling: - return FloatComparisonMode::OrderedNonSignaling; - case FloatComparisonMode::UnorderedEqualNonSignaling: - return FloatComparisonMode::UnorderedEqualNonSignaling; - case FloatComparisonMode::OrderedFalseNonSignaling: - return FloatComparisonMode::OrderedFalseNonSignaling; - case FloatComparisonMode::OrderedNotEqualNonSignaling: - return FloatComparisonMode::OrderedNotEqualNonSignaling; - case FloatComparisonMode::UnorderedTrueNonSignaling: - return FloatComparisonMode::UnorderedTrueNonSignaling; - case FloatComparisonMode::OrderedEqualSignaling: - return FloatComparisonMode::OrderedEqualSignaling; - case FloatComparisonMode::UnorderedSignaling: - return FloatComparisonMode::UnorderedSignaling; - case FloatComparisonMode::UnorderedNotEqualSignaling: - return FloatComparisonMode::UnorderedNotEqualSignaling; - case FloatComparisonMode::OrderedSignaling: - return FloatComparisonMode::OrderedSignaling; - case FloatComparisonMode::UnorderedEqualSignaling: - return FloatComparisonMode::UnorderedEqualSignaling; - case FloatComparisonMode::OrderedFalseSignaling: - return FloatComparisonMode::OrderedFalseSignaling; - case FloatComparisonMode::OrderedNotEqualSignaling: - return FloatComparisonMode::OrderedNotEqualSignaling; - case 
FloatComparisonMode::UnorderedTrueSignaling: - return FloatComparisonMode::UnorderedTrueSignaling; - - // These comparison modes need a different mode if the operands are swapped - - case FloatComparisonMode::OrderedLessThanSignaling: - return FloatComparisonMode::OrderedGreaterThanSignaling; - case FloatComparisonMode::OrderedLessThanOrEqualSignaling: - return FloatComparisonMode::OrderedGreaterThanOrEqualSignaling; - case FloatComparisonMode::UnorderedNotLessThanSignaling: - return FloatComparisonMode::UnorderedNotGreaterThanSignaling; - case FloatComparisonMode::UnorderedNotLessThanOrEqualSignaling: - return FloatComparisonMode::UnorderedNotGreaterThanOrEqualSignaling; - case FloatComparisonMode::UnorderedNotGreaterThanOrEqualSignaling: - return FloatComparisonMode::UnorderedNotLessThanOrEqualSignaling; - case FloatComparisonMode::UnorderedNotGreaterThanSignaling: - return FloatComparisonMode::UnorderedNotLessThanSignaling; - case FloatComparisonMode::OrderedGreaterThanOrEqualSignaling: - return FloatComparisonMode::OrderedLessThanOrEqualSignaling; - case FloatComparisonMode::OrderedGreaterThanSignaling: - return FloatComparisonMode::OrderedLessThanSignaling; - case FloatComparisonMode::OrderedLessThanNonSignaling: - return FloatComparisonMode::OrderedGreaterThanNonSignaling; - case FloatComparisonMode::OrderedLessThanOrEqualNonSignaling: - return FloatComparisonMode::OrderedGreaterThanOrEqualNonSignaling; - case FloatComparisonMode::UnorderedNotLessThanNonSignaling: - return FloatComparisonMode::UnorderedNotGreaterThanNonSignaling; - case FloatComparisonMode::UnorderedNotLessThanOrEqualNonSignaling: - return FloatComparisonMode::UnorderedNotGreaterThanOrEqualNonSignaling; - case FloatComparisonMode::UnorderedNotGreaterThanOrEqualNonSignaling: - return FloatComparisonMode::UnorderedNotLessThanOrEqualNonSignaling; - case FloatComparisonMode::UnorderedNotGreaterThanNonSignaling: - return FloatComparisonMode::UnorderedNotLessThanNonSignaling; - case 
FloatComparisonMode::OrderedGreaterThanOrEqualNonSignaling: - return FloatComparisonMode::OrderedLessThanOrEqualNonSignaling; - case FloatComparisonMode::OrderedGreaterThanNonSignaling: - return FloatComparisonMode::OrderedLessThanNonSignaling; - - default: - unreached(); - } -} - //------------------------------------------------------------------------ // lookupIdForFloatComparisonMode: Get the intrinsic ID to use for a given float comparison mode // @@ -4858,8 +4769,6 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic, op3 = impPopStack().val; op3 = addRangeCheckIfNeeded(intrinsic, op3, immLowerBound, immUpperBound); - op2 = impSIMDPopStack(); - op1 = impSIMDPopStack(); if (op3->IsCnsIntOrI()) { @@ -4869,20 +4778,25 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic, if (id != intrinsic) { - intrinsic = id; - op3 = nullptr; + return impSpecialIntrinsic(id, clsHnd, method, sig R2RARG(&emptyEntryPoint), + simdBaseJitType, retType, simdSize, mustExpand); } } + op2 = impSIMDPopStack(); + op1 = impSIMDPopStack(); + if (op3 == nullptr) { retNode = gtNewSimdHWIntrinsicNode(retType, op1, op2, intrinsic, simdBaseJitType, simdSize); } else { - retNode = gtNewSimdHWIntrinsicNode(retType, op1, op2, op3, intrinsic, simdBaseJitType, simdSize); + } + retNode = gtNewSimdHWIntrinsicNode(retType, op1, op2, op3, intrinsic, simdBaseJitType, simdSize); + if (retType == TYP_MASK) { retType = getSIMDTypeForSize(simdSize); diff --git a/src/coreclr/jit/lowerxarch.cpp b/src/coreclr/jit/lowerxarch.cpp index 87528fee51bdfe..92d1218829a10c 100644 --- a/src/coreclr/jit/lowerxarch.cpp +++ b/src/coreclr/jit/lowerxarch.cpp @@ -2045,6 +2045,337 @@ GenTree* Lowering::LowerHWIntrinsic(GenTreeHWIntrinsic* node) break; } + case NI_EVEX_ConvertMaskToVector: + { + NamedIntrinsic id = NI_Illegal; + + unsigned simdSize = node->GetSimdSize(); + var_types simdBaseType = node->GetSimdBaseType(); + + if (simdSize == 64) + { + // Nothing to handle for TYP_SIMD64 as they 
require masks + break; + } + + GenTree* op1 = node->Op(1); + + if (!op1->OperIsHWIntrinsic()) + { + // We can only special case certain HWINTRINSIC nodes + break; + } + + GenTreeHWIntrinsic* op1Intrin = op1->AsHWIntrinsic(); + NamedIntrinsic op1IntrinId = op1Intrin->GetHWIntrinsicId(); + + switch (op1IntrinId) + { + case NI_EVEX_CompareEqualMask: + { + if (varTypeIsFloating(simdBaseType)) + { + id = HWIntrinsicInfo::lookupIdForFloatComparisonMode(NI_AVX_Compare, + FloatComparisonMode::OrderedEqualNonSignaling, + simdBaseType, simdSize); + } + else if (simdSize == 32) + { + id = NI_AVX2_CompareEqual; + } + else if (varTypeIsLong(simdBaseType)) + { + id = NI_SSE41_CompareEqual; + } + else + { + id = NI_SSE2_CompareEqual; + } + break; + } + + case NI_EVEX_CompareGreaterThanMask: + { + if (varTypeIsFloating(simdBaseType)) + { + id = HWIntrinsicInfo::lookupIdForFloatComparisonMode(NI_AVX_Compare, + FloatComparisonMode::OrderedGreaterThanSignaling, + simdBaseType, simdSize); + } + else if (varTypeIsUnsigned(simdBaseType)) + { + // Unsigned integer comparisons must use the EVEX instruction + break; + } + else if (simdSize == 32) + { + id = NI_AVX2_CompareGreaterThan; + } + else if (varTypeIsLong(simdBaseType)) + { + id = NI_SSE42_CompareGreaterThan; + } + else + { + id = NI_SSE2_CompareGreaterThan; + } + break; + } + + case NI_EVEX_CompareGreaterThanOrEqualMask: + { + if (varTypeIsFloating(simdBaseType)) + { + id = HWIntrinsicInfo::lookupIdForFloatComparisonMode(NI_AVX_Compare, + FloatComparisonMode::OrderedGreaterThanOrEqualSignaling, + simdBaseType, simdSize); + } + else + { + // Integer comparisons must use the EVEX instruction + } + break; + } + + case NI_EVEX_CompareLessThanMask: + { + if (varTypeIsFloating(simdBaseType)) + { + id = HWIntrinsicInfo::lookupIdForFloatComparisonMode(NI_AVX_Compare, + FloatComparisonMode::OrderedLessThanSignaling, + simdBaseType, simdSize); + } + else if (varTypeIsUnsigned(simdBaseType)) + { + // Unsigned integer comparisons 
must use the EVEX instruction + break; + } + else if (simdSize == 32) + { + id = NI_AVX2_CompareLessThan; + } + else if (varTypeIsLong(simdBaseType)) + { + id = NI_SSE42_CompareLessThan; + } + else + { + id = NI_SSE2_CompareLessThan; + } + break; + } + + case NI_EVEX_CompareLessThanOrEqualMask: + { + if (varTypeIsFloating(simdBaseType)) + { + id = HWIntrinsicInfo::lookupIdForFloatComparisonMode(NI_AVX_Compare, + FloatComparisonMode::OrderedLessThanOrEqualSignaling, + simdBaseType, simdSize); + } + else + { + // Integer comparisons must use the EVEX instruction + } + break; + } + + case NI_EVEX_CompareNotEqualMask: + { + if (varTypeIsFloating(simdBaseType)) + { + id = HWIntrinsicInfo::lookupIdForFloatComparisonMode(NI_AVX_Compare, + FloatComparisonMode::UnorderedNotEqualNonSignaling, + simdBaseType, simdSize); + } + else + { + // Integer comparisons must use the EVEX instruction + } + break; + } + + case NI_EVEX_CompareNotGreaterThanMask: + { + if (varTypeIsFloating(simdBaseType)) + { + id = HWIntrinsicInfo::lookupIdForFloatComparisonMode(NI_AVX_Compare, + FloatComparisonMode::UnorderedNotGreaterThanSignaling, + simdBaseType, simdSize); + } + else + { + // Integer comparisons must use the EVEX instruction + // as this is the same as: LessThanOrEqual + } + break; + } + + case NI_EVEX_CompareNotGreaterThanOrEqualMask: + { + if (varTypeIsFloating(simdBaseType)) + { + id = HWIntrinsicInfo::lookupIdForFloatComparisonMode(NI_AVX_Compare, + FloatComparisonMode::UnorderedNotGreaterThanOrEqualSignaling, + simdBaseType, simdSize); + } + else if (varTypeIsUnsigned(simdBaseType)) + { + // Unsigned integer comparisons must use the EVEX instruction + // as this is the same as: LessThan + break; + } + else if (simdSize == 32) + { + id = NI_AVX2_CompareLessThan; + } + else if (varTypeIsLong(simdBaseType)) + { + id = NI_SSE42_CompareLessThan; + } + else + { + id = NI_SSE2_CompareLessThan; + } + break; + } + + case NI_EVEX_CompareNotLessThanMask: + { + if 
(varTypeIsFloating(simdBaseType)) + { + id = HWIntrinsicInfo::lookupIdForFloatComparisonMode(NI_AVX_Compare, + FloatComparisonMode::UnorderedNotLessThanSignaling, + simdBaseType, simdSize); + } + else + { + // Integer comparisons must use the EVEX instruction + // as this is the same as: GreaterThanOrEqual + } + break; + } + + case NI_EVEX_CompareNotLessThanOrEqualMask: + { + if (varTypeIsFloating(simdBaseType)) + { + id = HWIntrinsicInfo::lookupIdForFloatComparisonMode(NI_AVX_Compare, + FloatComparisonMode::UnorderedNotLessThanOrEqualSignaling, + simdBaseType, simdSize); + } + else if (varTypeIsUnsigned(simdBaseType)) + { + // Unsigned integer comparisons must use the EVEX instruction + // as this is the same as: GreaterThan + break; + } + else if (simdSize == 32) + { + id = NI_AVX2_CompareGreaterThan; + } + else if (varTypeIsLong(simdBaseType)) + { + id = NI_SSE42_CompareGreaterThan; + } + else + { + id = NI_SSE2_CompareGreaterThan; + } + break; + } + + case NI_EVEX_CompareOrderedMask: + { + assert(varTypeIsFloating(simdBaseType)); + id = HWIntrinsicInfo::lookupIdForFloatComparisonMode(NI_AVX_Compare, + FloatComparisonMode::OrderedNonSignaling, + simdBaseType, simdSize); + break; + } + + case NI_EVEX_CompareUnorderedMask: + { + assert(varTypeIsFloating(simdBaseType)); + id = HWIntrinsicInfo::lookupIdForFloatComparisonMode(NI_AVX_Compare, + FloatComparisonMode::UnorderedNonSignaling, + simdBaseType, simdSize); + break; + } + + default: + { + // Other cases get no special handling + break; + } + } + + if (id != NI_Illegal) + { + // We've remapped ConvertMaskToVector(Compare*Mask) to be simply + // Compare*, allowing us to avoid the additional conversion expense + + op1Intrin->gtType = node->TypeGet(); + op1Intrin->ChangeHWIntrinsicId(id); + + GenTree* nextNode = node->gtNext; + + LIR::Use use; + + if (BlockRange().TryGetUse(node, &use)) + { + use.ReplaceWith(op1Intrin); + } + else + { + op1Intrin->SetUnusedValue(); + } + + BlockRange().Remove(node); + return 
nextNode; + } + break; + } + + case NI_EVEX_BlendVariableMask: + { + unsigned simdSize = node->GetSimdSize(); + + if (simdSize == 64) + { + // Nothing to handle for TYP_SIMD64 as they require masks + break; + } + + GenTree* op3 = node->Op(3); + + if (!op3->OperIsConvertVectorToMask()) + { + // We can only special case when op3 is ConvertVectorToMask + break; + } + + // We have BlendVariableMask(op1, op2, ConvertVectorToMask(op3)) and + // so we'll rewrite it to BlendVariable(op1, op2, op3) allowing us + // to avoid the additional conversion all together + + var_types simdBaseType = node->GetSimdBaseType(); + + if (simdSize == 32) + { + intrinsicId = varTypeIsFloating(simdBaseType) ? NI_AVX_BlendVariable : NI_AVX2_BlendVariable; + } + else + { + intrinsicId = NI_SSE41_BlendVariable; + } + + node->ResetHWIntrinsicId(intrinsicId, comp, node->Op(1), node->Op(2), op3->AsHWIntrinsic()->Op(1)); + BlockRange().Remove(op3); + + return LowerNode(node); + } + case NI_EVEX_NotMask: { // We want to recognize ~(op1 ^ op2) and transform it