-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
CpuMath Enhancement: Make bound checking of loops in hardware intrinsics more efficient #2939
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -425,7 +425,7 @@ public static unsafe void AddScalarU(float scalar, Span<float> dst) | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Vector256<float> scalarVector256 = Vector256.Create(scalar); | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pDstCurrent + 8 <= pDstEnd) | ||||||||||||||||||||||||||||||||||||
while (pDstCurrent <= pDstEnd - 8) | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If the destination span is empty, wouldn't There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The spans passed into these functions are guaranteed not to be empty in the calling code.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is also guaranteed in release builds? The contracts routines look like they're surrounded by There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, it is an invariant that these methods are not called with empty spans. Even higher in the callstack (in many places) are the real checks. For example: machinelearning/src/Microsoft.ML.Transforms/GcnTransform.cs Lines 479 to 487 in d38a35e
Here, when On the "other side" of these CpuMathUtils (when are are compiling for machinelearning/src/Microsoft.ML.CpuMath/CpuMathUtils.netstandard.cs Lines 95 to 101 in a570da1
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also note - the CpuMathUtils methods are not public. |
||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent); | ||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @tannergooding, do you know if there's any instruction with corresponds roughly to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, there is only load-forms, so this will compile down to two instructions (ideally):
|
||||||||||||||||||||||||||||||||||||
dstVector = Avx.Add(dstVector, scalarVector256); | ||||||||||||||||||||||||||||||||||||
|
@@ -577,7 +577,7 @@ public static unsafe void ScaleSrcU(float scale, ReadOnlySpan<float> src, Span<f | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Vector256<float> scaleVector256 = Vector256.Create(scale); | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pDstCurrent + 8 <= pDstEnd) | ||||||||||||||||||||||||||||||||||||
while (pDstCurrent <= pDstEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent); | ||||||||||||||||||||||||||||||||||||
srcVector = Avx.Multiply(srcVector, scaleVector256); | ||||||||||||||||||||||||||||||||||||
|
@@ -623,7 +623,7 @@ public static unsafe void ScaleAddU(float a, float b, Span<float> dst) | |||||||||||||||||||||||||||||||||||
Vector256<float> a256 = Vector256.Create(a); | ||||||||||||||||||||||||||||||||||||
Vector256<float> b256 = Vector256.Create(b); | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pDstCurrent + 8 <= pDstEnd) | ||||||||||||||||||||||||||||||||||||
while (pDstCurrent <= pDstEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent); | ||||||||||||||||||||||||||||||||||||
dstVector = Avx.Add(dstVector, b256); | ||||||||||||||||||||||||||||||||||||
|
@@ -671,7 +671,7 @@ public static unsafe void AddScaleU(float scale, ReadOnlySpan<float> src, Span<f | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Vector256<float> scaleVector256 = Vector256.Create(scale); | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pDstCurrent + 8 <= pEnd) | ||||||||||||||||||||||||||||||||||||
while (pDstCurrent <= pEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent); | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
|
@@ -728,7 +728,7 @@ public static unsafe void AddScaleCopyU(float scale, ReadOnlySpan<float> src, Re | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Vector256<float> scaleVector256 = Vector256.Create(scale); | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pResCurrent + 8 <= pResEnd) | ||||||||||||||||||||||||||||||||||||
while (pResCurrent <= pResEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent); | ||||||||||||||||||||||||||||||||||||
dstVector = MultiplyAdd(pSrcCurrent, scaleVector256, dstVector); | ||||||||||||||||||||||||||||||||||||
|
@@ -785,7 +785,7 @@ public static unsafe void AddScaleSU(float scale, ReadOnlySpan<float> src, ReadO | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Vector256<float> scaleVector256 = Vector256.Create(scale); | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pIdxCurrent + 8 <= pEnd) | ||||||||||||||||||||||||||||||||||||
while (pIdxCurrent <= pEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> dstVector = Load8(pDstCurrent, pIdxCurrent); | ||||||||||||||||||||||||||||||||||||
dstVector = MultiplyAdd(pSrcCurrent, scaleVector256, dstVector); | ||||||||||||||||||||||||||||||||||||
|
@@ -831,7 +831,7 @@ public static unsafe void AddU(ReadOnlySpan<float> src, Span<float> dst, int cou | |||||||||||||||||||||||||||||||||||
float* pDstCurrent = pdst; | ||||||||||||||||||||||||||||||||||||
float* pEnd = psrc + count; | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pSrcCurrent + 8 <= pEnd) | ||||||||||||||||||||||||||||||||||||
while (pSrcCurrent <= pEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent); | ||||||||||||||||||||||||||||||||||||
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent); | ||||||||||||||||||||||||||||||||||||
|
@@ -883,7 +883,7 @@ public static unsafe void AddSU(ReadOnlySpan<float> src, ReadOnlySpan<int> idx, | |||||||||||||||||||||||||||||||||||
float* pDstCurrent = pdst; | ||||||||||||||||||||||||||||||||||||
int* pEnd = pidx + count; | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pIdxCurrent + 8 <= pEnd) | ||||||||||||||||||||||||||||||||||||
while (pIdxCurrent <= pEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> dstVector = Load8(pDstCurrent, pIdxCurrent); | ||||||||||||||||||||||||||||||||||||
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent); | ||||||||||||||||||||||||||||||||||||
|
@@ -931,7 +931,7 @@ public static unsafe void MulElementWiseU(ReadOnlySpan<float> src1, ReadOnlySpan | |||||||||||||||||||||||||||||||||||
float* pDstCurrent = pdst; | ||||||||||||||||||||||||||||||||||||
float* pEnd = pdst + count; | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pDstCurrent + 8 <= pEnd) | ||||||||||||||||||||||||||||||||||||
while (pDstCurrent <= pEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> src1Vector = Avx.LoadVector256(pSrc1Current); | ||||||||||||||||||||||||||||||||||||
Vector256<float> src2Vector = Avx.LoadVector256(pSrc2Current); | ||||||||||||||||||||||||||||||||||||
|
@@ -1066,7 +1066,7 @@ public static unsafe float SumSqU(ReadOnlySpan<float> src) | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Vector256<float> result256 = Vector256<float>.Zero; | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pSrcCurrent + 8 <= pSrcEnd) | ||||||||||||||||||||||||||||||||||||
while (pSrcCurrent <= pSrcEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent); | ||||||||||||||||||||||||||||||||||||
result256 = MultiplyAdd(srcVector, srcVector, result256); | ||||||||||||||||||||||||||||||||||||
|
@@ -1111,7 +1111,7 @@ public static unsafe float SumSqDiffU(float mean, ReadOnlySpan<float> src) | |||||||||||||||||||||||||||||||||||
Vector256<float> result256 = Vector256<float>.Zero; | ||||||||||||||||||||||||||||||||||||
Vector256<float> meanVector256 = Vector256.Create(mean); | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pSrcCurrent + 8 <= pSrcEnd) | ||||||||||||||||||||||||||||||||||||
while (pSrcCurrent <= pSrcEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent); | ||||||||||||||||||||||||||||||||||||
srcVector = Avx.Subtract(srcVector, meanVector256); | ||||||||||||||||||||||||||||||||||||
|
@@ -1158,7 +1158,7 @@ public static unsafe float SumAbsU(ReadOnlySpan<float> src) | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Vector256<float> result256 = Vector256<float>.Zero; | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pSrcCurrent + 8 <= pSrcEnd) | ||||||||||||||||||||||||||||||||||||
while (pSrcCurrent <= pSrcEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent); | ||||||||||||||||||||||||||||||||||||
result256 = Avx.Add(result256, Avx.And(srcVector, _absMask256)); | ||||||||||||||||||||||||||||||||||||
|
@@ -1203,7 +1203,7 @@ public static unsafe float SumAbsDiffU(float mean, ReadOnlySpan<float> src) | |||||||||||||||||||||||||||||||||||
Vector256<float> result256 = Vector256<float>.Zero; | ||||||||||||||||||||||||||||||||||||
Vector256<float> meanVector256 = Vector256.Create(mean); | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pSrcCurrent + 8 <= pSrcEnd) | ||||||||||||||||||||||||||||||||||||
while (pSrcCurrent <= pSrcEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent); | ||||||||||||||||||||||||||||||||||||
srcVector = Avx.Subtract(srcVector, meanVector256); | ||||||||||||||||||||||||||||||||||||
|
@@ -1251,7 +1251,7 @@ public static unsafe float MaxAbsU(ReadOnlySpan<float> src) | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Vector256<float> result256 = Vector256<float>.Zero; | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pSrcCurrent + 8 <= pSrcEnd) | ||||||||||||||||||||||||||||||||||||
while (pSrcCurrent <= pSrcEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent); | ||||||||||||||||||||||||||||||||||||
result256 = Avx.Max(result256, Avx.And(srcVector, _absMask256)); | ||||||||||||||||||||||||||||||||||||
|
@@ -1296,7 +1296,7 @@ public static unsafe float MaxAbsDiffU(float mean, ReadOnlySpan<float> src) | |||||||||||||||||||||||||||||||||||
Vector256<float> result256 = Vector256<float>.Zero; | ||||||||||||||||||||||||||||||||||||
Vector256<float> meanVector256 = Vector256.Create(mean); | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pSrcCurrent + 8 <= pSrcEnd) | ||||||||||||||||||||||||||||||||||||
while (pSrcCurrent <= pSrcEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> srcVector = Avx.LoadVector256(pSrcCurrent); | ||||||||||||||||||||||||||||||||||||
srcVector = Avx.Subtract(srcVector, meanVector256); | ||||||||||||||||||||||||||||||||||||
|
@@ -1348,7 +1348,7 @@ public static unsafe float DotU(ReadOnlySpan<float> src, ReadOnlySpan<float> dst | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Vector256<float> result256 = Vector256<float>.Zero; | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pSrcCurrent + 8 <= pSrcEnd) | ||||||||||||||||||||||||||||||||||||
while (pSrcCurrent <= pSrcEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> dstVector = Avx.LoadVector256(pDstCurrent); | ||||||||||||||||||||||||||||||||||||
result256 = MultiplyAdd(pSrcCurrent, dstVector, result256); | ||||||||||||||||||||||||||||||||||||
|
@@ -1405,7 +1405,7 @@ public static unsafe float DotSU(ReadOnlySpan<float> src, ReadOnlySpan<float> ds | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Vector256<float> result256 = Vector256<float>.Zero; | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pIdxCurrent + 8 <= pIdxEnd) | ||||||||||||||||||||||||||||||||||||
while (pIdxCurrent <= pIdxEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> srcVector = Load8(pSrcCurrent, pIdxCurrent); | ||||||||||||||||||||||||||||||||||||
result256 = MultiplyAdd(pDstCurrent, srcVector, result256); | ||||||||||||||||||||||||||||||||||||
|
@@ -1459,7 +1459,7 @@ public static unsafe float Dist2(ReadOnlySpan<float> src, ReadOnlySpan<float> ds | |||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
Vector256<float> sqDistanceVector256 = Vector256<float>.Zero; | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pSrcCurrent + 8 <= pSrcEnd) | ||||||||||||||||||||||||||||||||||||
while (pSrcCurrent <= pSrcEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> distanceVector = Avx.Subtract(Avx.LoadVector256(pSrcCurrent), | ||||||||||||||||||||||||||||||||||||
Avx.LoadVector256(pDstCurrent)); | ||||||||||||||||||||||||||||||||||||
|
@@ -1514,7 +1514,7 @@ public static unsafe void SdcaL1UpdateU(float primalUpdate, int count, ReadOnlyS | |||||||||||||||||||||||||||||||||||
Vector256<float> xPrimal256 = Vector256.Create(primalUpdate); | ||||||||||||||||||||||||||||||||||||
Vector256<float> xThreshold256 = Vector256.Create(threshold); | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pSrcCurrent + 8 <= pSrcEnd) | ||||||||||||||||||||||||||||||||||||
while (pSrcCurrent <= pSrcEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> xDst1 = Avx.LoadVector256(pDst1Current); | ||||||||||||||||||||||||||||||||||||
xDst1 = MultiplyAdd(pSrcCurrent, xPrimal256, xDst1); | ||||||||||||||||||||||||||||||||||||
|
@@ -1574,7 +1574,7 @@ public static unsafe void SdcaL1UpdateSU(float primalUpdate, int count, ReadOnly | |||||||||||||||||||||||||||||||||||
Vector256<float> xPrimal256 = Vector256.Create(primalUpdate); | ||||||||||||||||||||||||||||||||||||
Vector256<float> xThreshold = Vector256.Create(threshold); | ||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||
while (pIdxCurrent + 8 <= pIdxEnd) | ||||||||||||||||||||||||||||||||||||
while (pIdxCurrent <= pIdxEnd - 8) | ||||||||||||||||||||||||||||||||||||
{ | ||||||||||||||||||||||||||||||||||||
Vector256<float> xDst1 = Load8(pdst1, pIdxCurrent); | ||||||||||||||||||||||||||||||||||||
xDst1 = MultiplyAdd(pSrcCurrent, xPrimal256, xDst1); | ||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Explicitly cache the loop-end in a local to ensure it also isn't recomputed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed