JIT ARM64-SVE: Add AbsoluteCompare* APIs #102611

Closed · wants to merge 10 commits
src/coreclr/jit/gentree.cpp (3 changes: 2 additions & 1 deletion)

@@ -18290,7 +18290,8 @@ bool GenTree::canBeContained() const
     }
     else if (OperIsHWIntrinsic() && !isContainableHWIntrinsic())
    {
-        return isEmbeddedMaskingCompatibleHWIntrinsic();
+        return isEmbeddedMaskingCompatibleHWIntrinsic() ||
+               HWIntrinsicInfo::SupportsContainment(AsHWIntrinsic()->GetHWIntrinsicId());
     }

     return true;
src/coreclr/jit/hwintrinsiccodegenarm64.cpp (15 changes: 13 additions & 2 deletions)

@@ -542,9 +542,20 @@ void CodeGen::genHWIntrinsic(GenTreeHWIntrinsic* node)

                 case 2:
                 {
-                    assert(instrIsRMW);
+                    if (!instrIsRMW)
+                    {
+                        // Either this is VectorZero or ConvertVectorToMask(TrueMask, VectorZero)
+                        assert(intrin.op3->IsVectorZero() ||
+                               (intrin.op3->OperIs(GT_HWINTRINSIC) &&
+                                intrin.op3->AsHWIntrinsic()->OperIsConvertVectorToMask() &&
+                                intrin.op3->AsHWIntrinsic()->Op(2)->IsVectorZero()));
+
+                        // Perform the actual "predicated" operation so that `embMaskOp1Reg` is the first operand
+                        // and `embMaskOp2Reg` is the second operand.
+                        GetEmitter()->emitIns_R_R_R_R(insEmbMask, emitSize, targetReg, maskReg, embMaskOp1Reg,
+                                                      embMaskOp2Reg, opt);
+                    }
-                    if (intrin.op3->IsVectorZero())
+                    else if (intrin.op3->IsVectorZero())
                     {
                         // If `falseReg` is zero, then move the first operand of `intrinEmbMask` in the
                         // destination using /Z.
src/coreclr/jit/hwintrinsiclistarm64sve.h (8 changes: 6 additions & 2 deletions)

@@ -18,6 +18,10 @@

 // Sve
 HARDWARE_INTRINSIC(Sve, Abs, -1, -1, false, {INS_sve_abs, INS_invalid, INS_sve_abs, INS_invalid, INS_sve_abs, INS_invalid, INS_sve_abs, INS_invalid, INS_sve_fabs, INS_sve_fabs}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation)
+HARDWARE_INTRINSIC(Sve, AbsoluteCompareGreaterThan, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_facgt, INS_sve_facgt}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_ReturnsPerElementMask)
+HARDWARE_INTRINSIC(Sve, AbsoluteCompareGreaterThanOrEqual, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_facge, INS_sve_facge}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_ReturnsPerElementMask)
+HARDWARE_INTRINSIC(Sve, AbsoluteCompareLessThan, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_faclt, INS_sve_faclt}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_ReturnsPerElementMask)
+HARDWARE_INTRINSIC(Sve, AbsoluteCompareLessThanOrEqual, -1, -1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_sve_facle, INS_sve_facle}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_ReturnsPerElementMask)
 HARDWARE_INTRINSIC(Sve, AbsoluteDifference, -1, -1, false, {INS_sve_sabd, INS_sve_uabd, INS_sve_sabd, INS_sve_uabd, INS_sve_sabd, INS_sve_uabd, INS_sve_sabd, INS_sve_uabd, INS_sve_fabd, INS_sve_fabd}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_EmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
 HARDWARE_INTRINSIC(Sve, Add, -1, -1, false, {INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_add, INS_sve_fadd, INS_sve_fadd}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_OptionalEmbeddedMaskedOperation|HW_Flag_HasRMWSemantics|HW_Flag_LowMaskedOperation)
 HARDWARE_INTRINSIC(Sve, AddAcross, -1, 1, true, {INS_sve_saddv, INS_sve_uaddv, INS_sve_saddv, INS_sve_uaddv, INS_sve_saddv, INS_sve_uaddv, INS_sve_uaddv, INS_sve_uaddv, INS_sve_faddv, INS_sve_faddv}, HW_Category_SIMD, HW_Flag_Scalable|HW_Flag_BaseTypeFromFirstArg|HW_Flag_EmbeddedMaskedOperation|HW_Flag_LowMaskedOperation)

@@ -210,8 +214,8 @@ HARDWARE_INTRINSIC(Sve, ZipLow,
 // ***************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************************
 // Special intrinsics that are generated during importing or lowering

-HARDWARE_INTRINSIC(Sve, ConvertMaskToVector, -1, 1, true, {INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov}, HW_Category_Helper, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation)
-HARDWARE_INTRINSIC(Sve, ConvertVectorToMask, -1, 2, true, {INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne}, HW_Category_Helper, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask|HW_Flag_LowMaskedOperation)
+HARDWARE_INTRINSIC(Sve, ConvertMaskToVector, -1, 1, true, {INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov, INS_sve_mov}, HW_Category_Helper, HW_Flag_Scalable|HW_Flag_ExplicitMaskedOperation|HW_Flag_SupportsContainment)
+HARDWARE_INTRINSIC(Sve, ConvertVectorToMask, -1, 2, true, {INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne, INS_sve_cmpne}, HW_Category_Helper, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask|HW_Flag_LowMaskedOperation|HW_Flag_SupportsContainment)
 HARDWARE_INTRINSIC(Sve, CreateTrueMaskAll, -1, -1, false, {INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue, INS_sve_ptrue}, HW_Category_Helper, HW_Flag_Scalable|HW_Flag_ReturnsPerElementMask)
 HARDWARE_INTRINSIC(Sve, StoreAndZipx2, -1, 3, true, {INS_sve_st2b, INS_sve_st2b, INS_sve_st2h, INS_sve_st2h, INS_sve_st2w, INS_sve_st2w, INS_sve_st2d, INS_sve_st2d, INS_sve_st2w, INS_sve_st2d}, HW_Category_MemoryStore, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_NeedsConsecutiveRegisters)
 HARDWARE_INTRINSIC(Sve, StoreAndZipx3, -1, 3, true, {INS_sve_st3b, INS_sve_st3b, INS_sve_st3h, INS_sve_st3h, INS_sve_st3w, INS_sve_st3w, INS_sve_st3d, INS_sve_st3d, INS_sve_st3w, INS_sve_st3d}, HW_Category_MemoryStore, HW_Flag_Scalable|HW_Flag_SpecialCodeGen|HW_Flag_ExplicitMaskedOperation|HW_Flag_LowMaskedOperation|HW_Flag_NeedsConsecutiveRegisters)
src/coreclr/jit/lowerarmarch.cpp (63 changes: 60 additions & 3 deletions)

@@ -1303,17 +1303,32 @@ GenTree* Lowering::LowerHWIntrinsic(GenTreeHWIntrinsic* node)
         {
             GenTree* user = use.User();
             // Wrap the intrinsic in ConditionalSelect only if it is not already inside another ConditionalSelect
-            if (!user->OperIsHWIntrinsic() || (user->AsHWIntrinsic()->GetHWIntrinsicId() != NI_Sve_ConditionalSelect))
+            // If it is inside ConditionalSelect, then make sure that it is the `mask` operation of it.

Review thread on lines 1305 to 1306:

Member:
Not sure I follow this entirely...

This is stating that we need to wrap in a CndSel if the user isn't a CndSel, or if it is but we're not the mask operand.

The !user->OperIsHWIntrinsic(NI_Sve_ConditionalSelect) seems fine in that regard, but the HWIntrinsicInfo::ReturnsPerElementMask(node->GetHWIntrinsicId()) doesn't seem to fit.

Most notably, we can have mask = CndSel(mask1, mask2, mask3) or vector = CndSel(mask, vector1, vector2). I would also have expected that an intrinsic like Abs(x), which must be emitted with a mask, needs a CndSel regardless of where it appears, and only that we could optimize it away in cases like vector = CndSel(mask, vector1, CndSel(mask, vector2, vector3)), where the outer and inner CndSel use identical masks, so the inner CndSel becomes redundant: its computed operands would never be selected anyway.

Member:
So I would have expected the latter condition to be written differently: ensure that intrinsics that always require a CndSel wrapper get one, and optionally try to fold it away when the mask usages can be detected as redundant.

Contributor Author:
It's better to rewrite the comment to avoid confusion, I guess:

          // Wrap the intrinsic in ConditionalSelect only if it is not already inside another ConditionalSelect
          // or if it is but it is the `mask` operand.

The goal of this block of code is to wrap HWIntrinsics that satisfy HWIntrinsicInfo::IsEmbeddedMaskedOperation(intrinsicId) in ConditionalSelect. Usually this only needs to be done when the intrinsic is not already wrapped. But that logic doesn't hold when such an intrinsic is used as the mask operand of ConditionalSelect: user->AsHWIntrinsic()->GetHWIntrinsicId() != NI_Sve_ConditionalSelect no longer holds true, but the intrinsic still has to be wrapped in ConditionalSelect (everything that satisfies HWIntrinsicInfo::IsEmbeddedMaskedOperation(intrinsicId) has to be).
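For readers following the thread, here is a hedged sketch of that wrapping invariant, expressed against the public Sve API rather than the JIT's IR (helper names such as CreateTrueMaskInt32 are from the in-progress .NET 9 surface; treat this as an illustration of the shape lowering produces, not its literal output):

```csharp
using System.Numerics;
using System.Runtime.Intrinsics.Arm;

class WrapSketch
{
    // What the user writes: Abs is an embedded-masked operation, and SVE's
    // abs instruction has no unpredicated form.
    static Vector<int> UserCode(Vector<int> v) => Sve.Abs(v);

    // Roughly the shape lowering produces: the operation is emitted under an
    // all-true governing mask, with inactive lanes zeroed by the false value.
    static Vector<int> AfterLowering(Vector<int> v) =>
        Sve.ConditionalSelect(
            Sve.CreateTrueMaskInt32(), // governing predicate: all lanes active
            Sve.Abs(v),                // the embedded-masked operation
            Vector<int>.Zero);         // false value for inactive lanes
}
```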

+            if (!user->OperIsHWIntrinsic(NI_Sve_ConditionalSelect) ||
+                (HWIntrinsicInfo::ReturnsPerElementMask(node->GetHWIntrinsicId())))
             {
                 CorInfoType simdBaseJitType = node->GetSimdBaseJitType();
                 unsigned    simdSize        = node->GetSimdSize();
                 var_types   simdType        = Compiler::getSIMDTypeForSize(simdSize);
                 GenTree*    trueMask        = comp->gtNewSimdAllTrueMaskNode(simdBaseJitType, simdSize);
-                GenTree*    trueVal         = node;
                 GenTree*    falseVal        = comp->gtNewZeroConNode(simdType);
+                GenTree*    trueVal         = node;
+                var_types   nodeType        = simdType;
+
+                if (HWIntrinsicInfo::ReturnsPerElementMask(node->GetHWIntrinsicId()))
Review thread on the ReturnsPerElementMask check:

Member:
I need to test this out to make sure this is correct. I will pull your branch once #103288 is merged.

Member:
I am seeing this failure in the test with DOTNET_TieredCompilation=0:

Assert failure(PID 18916 [0x000049e4], Thread: 28152 [0x6df8]): Assertion failed 'unreached' in 'JIT.HardwareIntrinsics.Arm._Sve.SimpleBinaryOpTest__Sve_AbsoluteCompareGreaterThan_float:ConditionalSelect_MethodMask():this' during 'Generate code' (IL size 109; hash 0x5f5a6ecd; FullOpts)

    File: D:\git\runtime\src\coreclr\jit\emitarm64sve.cpp:4396
    Image: D:\kpathak\Core_Root_absolutecompare\Core_Root\corerun.exe

This is missing a piece: if the cndSelNode is of type TYP_MASK, then the falseValue should also be TYP_MASK. Today we do not support a ZeroConNode of TYP_MASK, so I was thinking of wrapping falseVal with ConvertVectorToMask(), but that was failing in LIR validation for a reason I didn't get a chance to verify. I will look tomorrow.

Member:
@a74nh, @mikabl-arm - I have fixed the code to handle a falseValue of TYP_MASK. PTAL.

Stress tests are passing: https://gist.github.com/kunalspathak/95ddd4913a6642bcbda7ddff9c33f752

The RunLoadMask failure case is because we are not preserving p0 across a call (a known problem).

Contributor Author:
LGTM
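For context, a minimal sketch of the pattern that failing test exercises (assuming the in-progress .NET 9 Sve surface): the compare returns a per-element mask, so when lowering wraps the compare node in ConditionalSelect, the wrapper itself is TYP_MASK, and its zero false value has to be lifted to a mask via ConvertVectorToMask.

```csharp
using System.Numerics;
using System.Runtime.Intrinsics.Arm;

class MaskTypeSketch
{
    // |left| > |right| per lane; the compare's result is a per-element mask
    // that feeds ConditionalSelect directly. The JIT's internal wrapper around
    // the compare must therefore be mask-typed, including its false operand.
    static Vector<float> SelectByMagnitude(Vector<float> left, Vector<float> right)
    {
        Vector<float> mask = Sve.AbsoluteCompareGreaterThan(left, right);
        return Sve.ConditionalSelect(mask, left, right);
    }
}
```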

+                {
+                    nodeType = TYP_MASK;
+
+                    GenTree* trueMaskForOp1 = comp->gtNewSimdAllTrueMaskNode(simdBaseJitType, simdSize);
+                    BlockRange().InsertBefore(node, trueMaskForOp1);
+                    BlockRange().InsertBefore(node, falseVal);
+
+                    falseVal = comp->gtNewSimdHWIntrinsicNode(TYP_MASK, trueMaskForOp1, falseVal,
+                                                              NI_Sve_ConvertVectorToMask, simdBaseJitType, simdSize);
+                }
+
                 GenTreeHWIntrinsic* condSelNode =
-                    comp->gtNewSimdHWIntrinsicNode(simdType, trueMask, trueVal, falseVal, NI_Sve_ConditionalSelect,
+                    comp->gtNewSimdHWIntrinsicNode(nodeType, trueMask, trueVal, falseVal, NI_Sve_ConditionalSelect,
                                                    simdBaseJitType, simdSize);

                 BlockRange().InsertBefore(node, trueMask);
@@ -3365,6 +3380,32 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
                 }

                 // Handle op3
+                if (op1->IsMaskAllBitsSet())
+                {
+                    bool isZeroValue = false;
+                    if (op3->IsVectorZero())
+                    {
+                        isZeroValue = true;
+                    }
+                    else if (op3->OperIsHWIntrinsic(NI_Sve_ConvertVectorToMask))
+                    {
+                        GenTreeHWIntrinsic* gtVectorToMask = op3->AsHWIntrinsic();
+
+                        if (gtVectorToMask->Op(1)->IsMaskAllBitsSet() && gtVectorToMask->Op(2)->IsVectorZero())
+                        {
+                            isZeroValue = true;
+                        }
+                    }
+
+                    if (isZeroValue)
+                    {
+                        // When we are merging with zero, we can specialize
+                        // and avoid instantiating the vector constant.
+                        // Do this only if op1 was AllTrueMask.
+                        MakeSrcContained(node, op3);
+                    }
+                }
+
-                if (op3->IsVectorZero() && op1->IsMaskAllBitsSet())
-                {
-                    // When we are merging with zero, we can specialize

@@ -3410,6 +3451,22 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
                         MakeSrcContained(node, intrin.op3);
                     }
                     break;
+
+                case NI_Sve_ConvertVectorToMask:
+                    assert(varTypeIsMask(intrin.op1));
+                    assert(varTypeIsSIMD(intrin.op2));
+                    if (intrin.op1->IsMaskAllBitsSet() && intrin.op2->IsVectorZero())
+                    {
+                        MakeSrcContained(node, intrin.op2);
+                    }
+                    break;
+
+                case NI_Sve_ConvertMaskToVector:
+                    assert(varTypeIsMask(intrin.op1));
+                    if (intrin.op1->IsVectorZero())
+                    {
+                        MakeSrcContained(node, intrin.op1);
+                    }
+                    break;

                 default:
                     unreached();
src/coreclr/jit/lsraarm64.cpp (4 changes: 3 additions & 1 deletion)

@@ -1997,8 +1997,10 @@ int LinearScan::BuildHWIntrinsic(GenTreeHWIntrinsic* intrinsicTree, int* pDstCount)
         {
             SingleTypeRegSet candidates = lowVectorOperandNum == 2 ? lowVectorCandidates : RBM_NONE;

-            if (intrin.op2->gtType == TYP_MASK)
+            if (intrin.op2->OperIsHWIntrinsic(NI_Sve_ConvertVectorToMask))
             {
+                // Have RBM_ALLMASK candidates only if op2 is VectorToMask
+                assert(intrin.op2->gtType == TYP_MASK);
                 assert(lowVectorOperandNum != 2);
                 candidates = RBM_ALLMASK.GetPredicateRegSet();
             }
@@ -157,6 +157,66 @@ internal Arm64() { }
         /// </summary>
         public static unsafe Vector<ulong> AbsoluteDifference(Vector<ulong> left, Vector<ulong> right) { throw new PlatformNotSupportedException(); }

+        /// Absolute compare greater than
+
+        /// <summary>
+        /// svbool_t svacgt[_f32](svbool_t pg, svfloat32_t op1, svfloat32_t op2)
+        ///   FACGT Presult.S, Pg/Z, Zop1.S, Zop2.S
+        /// </summary>
+        public static unsafe Vector<float> AbsoluteCompareGreaterThan(Vector<float> left, Vector<float> right) { throw new PlatformNotSupportedException(); }
+
+        /// <summary>
+        /// svbool_t svacgt[_f64](svbool_t pg, svfloat64_t op1, svfloat64_t op2)
+        ///   FACGT Presult.D, Pg/Z, Zop1.D, Zop2.D
+        /// </summary>
+        public static unsafe Vector<double> AbsoluteCompareGreaterThan(Vector<double> left, Vector<double> right) { throw new PlatformNotSupportedException(); }
+
+
+        /// Absolute compare greater than or equal to
+
+        /// <summary>
+        /// svbool_t svacge[_f32](svbool_t pg, svfloat32_t op1, svfloat32_t op2)
+        ///   FACGE Presult.S, Pg/Z, Zop1.S, Zop2.S
+        /// </summary>
+        public static unsafe Vector<float> AbsoluteCompareGreaterThanOrEqual(Vector<float> left, Vector<float> right) { throw new PlatformNotSupportedException(); }
+
+        /// <summary>
+        /// svbool_t svacge[_f64](svbool_t pg, svfloat64_t op1, svfloat64_t op2)
+        ///   FACGE Presult.D, Pg/Z, Zop1.D, Zop2.D
+        /// </summary>
+        public static unsafe Vector<double> AbsoluteCompareGreaterThanOrEqual(Vector<double> left, Vector<double> right) { throw new PlatformNotSupportedException(); }
+
+
+        /// Absolute compare less than
+
+        /// <summary>
+        /// svbool_t svaclt[_f32](svbool_t pg, svfloat32_t op1, svfloat32_t op2)
+        ///   FACLT Presult.S, Pg/Z, Zop1.S, Zop2.S
+        /// </summary>
+        public static unsafe Vector<float> AbsoluteCompareLessThan(Vector<float> left, Vector<float> right) { throw new PlatformNotSupportedException(); }
+
+        /// <summary>
+        /// svbool_t svaclt[_f64](svbool_t pg, svfloat64_t op1, svfloat64_t op2)
+        ///   FACLT Presult.D, Pg/Z, Zop1.D, Zop2.D
+        /// </summary>
+        public static unsafe Vector<double> AbsoluteCompareLessThan(Vector<double> left, Vector<double> right) { throw new PlatformNotSupportedException(); }
+
+
+        /// Absolute compare less than or equal to
+
+        /// <summary>
+        /// svbool_t svacle[_f32](svbool_t pg, svfloat32_t op1, svfloat32_t op2)
+        ///   FACLE Presult.S, Pg/Z, Zop1.S, Zop2.S
+        /// </summary>
+        public static unsafe Vector<float> AbsoluteCompareLessThanOrEqual(Vector<float> left, Vector<float> right) { throw new PlatformNotSupportedException(); }
+
+        /// <summary>
+        /// svbool_t svacle[_f64](svbool_t pg, svfloat64_t op1, svfloat64_t op2)
+        ///   FACLE Presult.D, Pg/Z, Zop1.D, Zop2.D
+        /// </summary>
+        public static unsafe Vector<double> AbsoluteCompareLessThanOrEqual(Vector<double> left, Vector<double> right) { throw new PlatformNotSupportedException(); }
+
+
         /// Add : Add

         /// <summary>
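Finally, a short usage sketch of the new API surface (hedged: it assumes the .NET 9 preview bits and SVE-capable hardware, and production code should guard with Sve.IsSupported). The compares operate on absolute values and yield a per-element mask that can feed other mask-consuming operations such as GetActiveElementCount:

```csharp
using System;
using System.Numerics;
using System.Runtime.Intrinsics.Arm;

class AbsoluteCompareDemo
{
    // Counts the lanes where |a| <= |b| by handing the per-element mask from
    // the absolute compare to GetActiveElementCount (CNTP).
    static ulong CountNotLarger(Vector<float> a, Vector<float> b)
    {
        if (!Sve.IsSupported)
        {
            throw new PlatformNotSupportedException();
        }

        Vector<float> mask = Sve.AbsoluteCompareLessThanOrEqual(a, b); // FACLE
        return Sve.GetActiveElementCount(mask, mask);
    }
}
```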