Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CpuMath Enhancement: Make bound checking of loops in hardware intrinsics more efficient #2939

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 20 additions & 20 deletions src/Microsoft.ML.CpuMath/AvxIntrinsics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ public static unsafe void AddScalarU(float scalar, Span<float> dst)

Vector256<float> scalarVector256 = Vector256.Create(scalar);

while (pDstCurrent + 8 <= pDstEnd)
while (pDstCurrent <= pDstEnd - 8)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Explicitly cache the loop-end in a local to ensure it also isn't recomputed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If the destination span is empty, wouldn't pDstEnd - 8 integer underflow back to a very large number, which could lead to an AV?

Copy link
Member

@eerhardt eerhardt Mar 27, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The spans passed into these functions are guaranteed not to be empty in the calling code.

Contracts.AssertNonEmpty(destination);

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is also guaranteed in release builds? The contracts routines look like they're surrounded by [Conditional("DEBUG")].

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, it is an invariant that these methods are not called with empty spans. Even higher in the callstack (in many places) are the real checks. For example:

int count = srcValues.Length;
int length = src.Length;
ectx.Assert(divisor >= 0);
if (count == 0)
{
VBufferUtils.Resize(ref dst, length, 0);
return;
}

Here, when count == 0, the code doesn't even get into CpuMathUtils.Add below - it exits early.

On the "other side" of these CpuMathUtils (when are are compiling for netstandard where C# intrinsics don't exist), we make similar assumptions:

Contracts.AssertNonEmpty(dst);
unsafe
{
fixed (float* pdst = &MemoryMarshal.GetReference(dst))
Thunk.AddScalarU(a, pdst, dst.Length);
}

MemoryMarshal.GetReference is going to fail if you pass in an empty span.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also note - the CpuMathUtils methods are not public.

{
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tannergooding, do you know if there's any instruction with corresponds roughly to addfloats xmm ptr [rax], xmm0 and if this trio of instructions collapses to that? It seems like it'd be more efficient and would avoid the xmm register spill.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, there is only load-forms, so this will compile down to two instructions (ideally):

addps tmp, scalarVector128, [pDstCurrent]
movps [pDstCurrent], tmp

dstVector = Avx.Add(dstVector, scalarVector256);
Expand Down Expand Up @@ -577,7 +577,7 @@ public static unsafe void ScaleSrcU(float scale, ReadOnlySpan<float> src, Span<f

Vector256<float> scaleVector256 = Vector256.Create(scale);

while (pDstCurrent + 8 <= pDstEnd)
while (pDstCurrent <= pDstEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
srcVector = Avx.Multiply(srcVector, scaleVector256);
Expand Down Expand Up @@ -623,7 +623,7 @@ public static unsafe void ScaleAddU(float a, float b, Span<float> dst)
Vector256<float> a256 = Vector256.Create(a);
Vector256<float> b256 = Vector256.Create(b);

while (pDstCurrent + 8 <= pDstEnd)
while (pDstCurrent <= pDstEnd - 8)
{
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
dstVector = Avx.Add(dstVector, b256);
Expand Down Expand Up @@ -671,7 +671,7 @@ public static unsafe void AddScaleU(float scale, ReadOnlySpan<float> src, Span<f

Vector256<float> scaleVector256 = Vector256.Create(scale);

while (pDstCurrent + 8 <= pEnd)
while (pDstCurrent <= pEnd - 8)
{
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);

Expand Down Expand Up @@ -728,7 +728,7 @@ public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan<float> src, Re

Vector256<float> scaleVector256 = Vector256.Create(scale);

while (pResCurrent + 8 <= pResEnd)
while (pResCurrent <= pResEnd - 8)
{
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
dstVector = MultiplyAdd(pSrcCurrent, scaleVector256, dstVector);
Expand Down Expand Up @@ -785,7 +785,7 @@ public static unsafe void AddScaleSU(float scale, ReadOnlySpan<float> src, ReadO

Vector256<float> scaleVector256 = Vector256.Create(scale);

while (pIdxCurrent + 8 <= pEnd)
while (pIdxCurrent <= pEnd - 8)
{
Vector256<float> dstVector = Load8(pDstCurrent, pIdxCurrent);
dstVector = MultiplyAdd(pSrcCurrent, scaleVector256, dstVector);
Expand Down Expand Up @@ -831,7 +831,7 @@ public static unsafe void AddU(ReadOnlySpan<float> src, Span<float> dst, int cou
float* pDstCurrent = pdst;
float* pEnd = psrc + count;

while (pSrcCurrent + 8 <= pEnd)
while (pSrcCurrent <= pEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
Expand Down Expand Up @@ -883,7 +883,7 @@ public static unsafe void AddSU(ReadOnlySpan<float> src, ReadOnlySpan<int> idx,
float* pDstCurrent = pdst;
int* pEnd = pidx + count;

while (pIdxCurrent + 8 <= pEnd)
while (pIdxCurrent <= pEnd - 8)
{
Vector256<float> dstVector = Load8(pDstCurrent, pIdxCurrent);
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
Expand Down Expand Up @@ -931,7 +931,7 @@ public static unsafe void MulElementWiseU(ReadOnlySpan<float> src1, ReadOnlySpan
float* pDstCurrent = pdst;
float* pEnd = pdst + count;

while (pDstCurrent + 8 <= pEnd)
while (pDstCurrent <= pEnd - 8)
{
Vector256<float> src1Vector = Avx.LoadVector256(pSrc1Current);
Vector256<float> src2Vector = Avx.LoadVector256(pSrc2Current);
Expand Down Expand Up @@ -1066,7 +1066,7 @@ public static unsafe float SumSqU(ReadOnlySpan<float> src)

Vector256<float> result256 = Vector256<float>.Zero;

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
result256 = MultiplyAdd(srcVector, srcVector, result256);
Expand Down Expand Up @@ -1111,7 +1111,7 @@ public static unsafe float SumSqDiffU(float mean, ReadOnlySpan<float> src)
Vector256<float> result256 = Vector256<float>.Zero;
Vector256<float> meanVector256 = Vector256.Create(mean);

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
srcVector = Avx.Subtract(srcVector, meanVector256);
Expand Down Expand Up @@ -1158,7 +1158,7 @@ public static unsafe float SumAbsU(ReadOnlySpan<float> src)

Vector256<float> result256 = Vector256<float>.Zero;

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
result256 = Avx.Add(result256, Avx.And(srcVector, _absMask256));
Expand Down Expand Up @@ -1203,7 +1203,7 @@ public static unsafe float SumAbsDiffU(float mean, ReadOnlySpan<float> src)
Vector256<float> result256 = Vector256<float>.Zero;
Vector256<float> meanVector256 = Vector256.Create(mean);

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
srcVector = Avx.Subtract(srcVector, meanVector256);
Expand Down Expand Up @@ -1251,7 +1251,7 @@ public static unsafe float MaxAbsU(ReadOnlySpan<float> src)

Vector256<float> result256 = Vector256<float>.Zero;

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
result256 = Avx.Max(result256, Avx.And(srcVector, _absMask256));
Expand Down Expand Up @@ -1296,7 +1296,7 @@ public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan<float> src)
Vector256<float> result256 = Vector256<float>.Zero;
Vector256<float> meanVector256 = Vector256.Create(mean);

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent);
srcVector = Avx.Subtract(srcVector, meanVector256);
Expand Down Expand Up @@ -1348,7 +1348,7 @@ public static unsafe float DotU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst

Vector256<float> result256 = Vector256<float>.Zero;

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent);
result256 = MultiplyAdd(pSrcCurrent, dstVector, result256);
Expand Down Expand Up @@ -1405,7 +1405,7 @@ public static unsafe float DotSU(ReadOnlySpan<float> src, ReadOnlySpan<float> ds

Vector256<float> result256 = Vector256<float>.Zero;

while (pIdxCurrent + 8 <= pIdxEnd)
while (pIdxCurrent <= pIdxEnd - 8)
{
Vector256<float> srcVector = Load8(pSrcCurrent, pIdxCurrent);
result256 = MultiplyAdd(pDstCurrent, srcVector, result256);
Expand Down Expand Up @@ -1459,7 +1459,7 @@ public static unsafe float Dist2(ReadOnlySpan<float> src, ReadOnlySpan<float> ds

Vector256<float> sqDistanceVector256 = Vector256<float>.Zero;

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> distanceVector = Avx.Subtract(Avx.LoadVector256(pSrcCurrent),
Avx.LoadVector256(pDstCurrent));
Expand Down Expand Up @@ -1514,7 +1514,7 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlyS
Vector256<float> xPrimal256 = Vector256.Create(primalUpdate);
Vector256<float> xThreshold256 = Vector256.Create(threshold);

while (pSrcCurrent + 8 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 8)
{
Vector256<float> xDst1 = Avx.LoadVector256(pDst1Current);
xDst1 = MultiplyAdd(pSrcCurrent, xPrimal256, xDst1);
Expand Down Expand Up @@ -1574,7 +1574,7 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, int count, ReadOnly
Vector256<float> xPrimal256 = Vector256.Create(primalUpdate);
Vector256<float> xThreshold = Vector256.Create(threshold);

while (pIdxCurrent + 8 <= pIdxEnd)
while (pIdxCurrent <= pIdxEnd - 8)
{
Vector256<float> xDst1 = Load8(pdst1, pIdxCurrent);
xDst1 = MultiplyAdd(pSrcCurrent, xPrimal256, xDst1);
Expand Down
34 changes: 17 additions & 17 deletions src/Microsoft.ML.CpuMath/SseIntrinsics.cs
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,7 @@ public static unsafe void AddScaleU(float scale, ReadOnlySpan<float> src, Span<f

Vector128<float> scaleVector = Vector128.Create(scale);

while (pDstCurrent + 4 <= pEnd)
while (pDstCurrent <= pEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
Expand Down Expand Up @@ -609,7 +609,7 @@ public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan<float> src, Re

Vector128<float> scaleVector = Vector128.Create(scale);

while (pResCurrent + 4 <= pResEnd)
while (pResCurrent <= pResEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
Expand Down Expand Up @@ -653,7 +653,7 @@ public static unsafe void AddScaleSU(float scale, ReadOnlySpan<float> src, ReadO

Vector128<float> scaleVector = Vector128.Create(scale);

while (pIdxCurrent + 4 <= pEnd)
while (pIdxCurrent <= pEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
Vector128<float> dstVector = Load4(pDstCurrent, pIdxCurrent);
Expand Down Expand Up @@ -687,7 +687,7 @@ public static unsafe void AddU(ReadOnlySpan<float> src, Span<float> dst, int cou
float* pDstCurrent = pdst;
float* pEnd = psrc + count;

while (pSrcCurrent + 4 <= pEnd)
while (pSrcCurrent <= pEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
Expand Down Expand Up @@ -727,7 +727,7 @@ public static unsafe void AddSU(ReadOnlySpan<float> src, ReadOnlySpan<int> idx,
float* pDstCurrent = pdst;
int* pEnd = pidx + count;

while (pIdxCurrent + 4 <= pEnd)
while (pIdxCurrent <= pEnd - 4)
{
Vector128<float> dstVector = Load4(pDstCurrent, pIdxCurrent);
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
Expand Down Expand Up @@ -763,7 +763,7 @@ public static unsafe void MulElementWiseU(ReadOnlySpan<float> src1, ReadOnlySpan
float* pDstCurrent = pdst;
float* pEnd = pdst + count;

while (pDstCurrent + 4 <= pEnd)
while (pDstCurrent <= pEnd - 4)
{
Vector128<float> src1Vector = Sse.LoadVector128(pSrc1Current);
Vector128<float> src2Vector = Sse.LoadVector128(pSrc2Current);
Expand Down Expand Up @@ -883,7 +883,7 @@ public static unsafe float SumSqU(ReadOnlySpan<float> src)

Vector128<float> result = Vector128<float>.Zero;

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
result = Sse.Add(result, Sse.Multiply(srcVector, srcVector));
Expand Down Expand Up @@ -915,7 +915,7 @@ public static unsafe float SumSqDiffU(float mean, ReadOnlySpan<float> src)
Vector128<float> result = Vector128<float>.Zero;
Vector128<float> meanVector = Vector128.Create(mean);

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
srcVector = Sse.Subtract(srcVector, meanVector);
Expand Down Expand Up @@ -948,7 +948,7 @@ public static unsafe float SumAbsU(ReadOnlySpan<float> src)

Vector128<float> result = Vector128<float>.Zero;

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
result = Sse.Add(result, Sse.And(srcVector, AbsMask128));
Expand Down Expand Up @@ -980,7 +980,7 @@ public static unsafe float SumAbsDiffU(float mean, ReadOnlySpan<float> src)
Vector128<float> result = Vector128<float>.Zero;
Vector128<float> meanVector = Vector128.Create(mean);

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
srcVector = Sse.Subtract(srcVector, meanVector);
Expand Down Expand Up @@ -1013,7 +1013,7 @@ public static unsafe float MaxAbsU(ReadOnlySpan<float> src)

Vector128<float> result = Vector128<float>.Zero;

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
result = Sse.Max(result, Sse.And(srcVector, AbsMask128));
Expand Down Expand Up @@ -1045,7 +1045,7 @@ public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan<float> src)
Vector128<float> result = Vector128<float>.Zero;
Vector128<float> meanVector = Vector128.Create(mean);

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
srcVector = Sse.Subtract(srcVector, meanVector);
Expand Down Expand Up @@ -1082,7 +1082,7 @@ public static unsafe float DotU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst

Vector128<float> result = Vector128<float>.Zero;

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> srcVector = Sse.LoadVector128(pSrcCurrent);
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
Expand Down Expand Up @@ -1126,7 +1126,7 @@ public static unsafe float DotSU(ReadOnlySpan<float> src, ReadOnlySpan<float> ds

Vector128<float> result = Vector128<float>.Zero;

while (pIdxCurrent + 4 <= pIdxEnd)
while (pIdxCurrent <= pIdxEnd - 4)
{
Vector128<float> srcVector = Load4(pSrcCurrent, pIdxCurrent);
Vector128<float> dstVector = Sse.LoadVector128(pDstCurrent);
Expand Down Expand Up @@ -1167,7 +1167,7 @@ public static unsafe float Dist2(ReadOnlySpan<float> src, ReadOnlySpan<float> ds

Vector128<float> sqDistanceVector = Vector128<float>.Zero;

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> distanceVector = Sse.Subtract(Sse.LoadVector128(pSrcCurrent),
Sse.LoadVector128(pDstCurrent));
Expand Down Expand Up @@ -1210,7 +1210,7 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlyS
Vector128<float> signMask = Vector128.Create(-0.0f); // 0x8000 0000
Vector128<float> xThreshold = Vector128.Create(threshold);

while (pSrcCurrent + 4 <= pSrcEnd)
while (pSrcCurrent <= pSrcEnd - 4)
{
Vector128<float> xSrc = Sse.LoadVector128(pSrcCurrent);

Expand Down Expand Up @@ -1255,7 +1255,7 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, int count, ReadOnly
Vector128<float> signMask = Vector128.Create(-0.0f); // 0x8000 0000
Vector128<float> xThreshold = Vector128.Create(threshold);

while (pIdxCurrent + 4 <= pIdxEnd)
while (pIdxCurrent <= pIdxEnd - 4)
{
Vector128<float> xSrc = Sse.LoadVector128(pSrcCurrent);

Expand Down