Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Performance improvement in Count<T> extension #3548

Merged
6 commits merged into from
Nov 21, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,6 @@ public static nint Count<T>(ref T r0, nint length, T value)
/// Implements <see cref="Count{T}"/> with a sequential search.
/// </summary>
[Pure]
#if NETCOREAPP3_1
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
#endif
private static nint CountSequential<T>(ref T r0, nint length, T value)
where T : IEquatable<T>
{
Expand Down Expand Up @@ -132,9 +129,6 @@ private static nint CountSequential<T>(ref T r0, nint length, T value)
/// Implements <see cref="Count{T}"/> with a vectorized search.
/// </summary>
[Pure]
#if NETCOREAPP3_1
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
#endif
private static nint CountSimd<T>(ref T r0, nint length, T value)
where T : unmanaged, IEquatable<T>
{
Expand All @@ -161,6 +155,67 @@ private static nint CountSimd<T>(ref T r0, nint length, T value)

var partials = Vector<T>.Zero;

// Unrolled vectorized loop, with 8 unrolled iterations. We only run this when the
// current type T is at least 2 bytes in size, otherwise the average chunk length
// would always be too small to be able to trigger the unrolled loop, and the overall
// performance would just be slightly worse due to the additional conditional branches.
if (typeof(T) != typeof(sbyte))
{
while (chunkLength >= Vector<T>.Count * 8)
{
ref T ri0 = ref Unsafe.Add(ref r0, offset + (Vector<T>.Count * 0));
var vi0 = Unsafe.As<T, Vector<T>>(ref ri0);
var ve0 = Vector.Equals(vi0, vc);

partials -= ve0;

ref T ri1 = ref Unsafe.Add(ref r0, offset + (Vector<T>.Count * 1));
var vi1 = Unsafe.As<T, Vector<T>>(ref ri1);
var ve1 = Vector.Equals(vi1, vc);

partials -= ve1;

ref T ri2 = ref Unsafe.Add(ref r0, offset + (Vector<T>.Count * 2));
var vi2 = Unsafe.As<T, Vector<T>>(ref ri2);
var ve2 = Vector.Equals(vi2, vc);

partials -= ve2;

ref T ri3 = ref Unsafe.Add(ref r0, offset + (Vector<T>.Count * 3));
var vi3 = Unsafe.As<T, Vector<T>>(ref ri3);
var ve3 = Vector.Equals(vi3, vc);

partials -= ve3;

ref T ri4 = ref Unsafe.Add(ref r0, offset + (Vector<T>.Count * 4));
var vi4 = Unsafe.As<T, Vector<T>>(ref ri4);
var ve4 = Vector.Equals(vi4, vc);

partials -= ve4;

ref T ri5 = ref Unsafe.Add(ref r0, offset + (Vector<T>.Count * 5));
var vi5 = Unsafe.As<T, Vector<T>>(ref ri5);
var ve5 = Vector.Equals(vi5, vc);

partials -= ve5;

ref T ri6 = ref Unsafe.Add(ref r0, offset + (Vector<T>.Count * 6));
var vi6 = Unsafe.As<T, Vector<T>>(ref ri6);
var ve6 = Vector.Equals(vi6, vc);

partials -= ve6;

ref T ri7 = ref Unsafe.Add(ref r0, offset + (Vector<T>.Count * 7));
var vi7 = Unsafe.As<T, Vector<T>>(ref ri7);
var ve7 = Vector.Equals(vi7, vc);

partials -= ve7;

chunkLength -= Vector<T>.Count * 8;
offset += Vector<T>.Count * 8;
}
}

while (chunkLength >= Vector<T>.Count)
{
ref T ri = ref Unsafe.Add(ref r0, offset);
Expand Down Expand Up @@ -242,28 +297,22 @@ private static nint CountSimd<T>(ref T r0, nint length, T value)
private static unsafe nint GetUpperBound<T>()
where T : unmanaged
{
if (typeof(T) == typeof(byte) ||
typeof(T) == typeof(sbyte) ||
typeof(T) == typeof(bool))
if (typeof(T) == typeof(sbyte))
{
return sbyte.MaxValue;
}

if (typeof(T) == typeof(char) ||
typeof(T) == typeof(ushort) ||
typeof(T) == typeof(short))
if (typeof(T) == typeof(short))
{
return short.MaxValue;
}

if (typeof(T) == typeof(int) ||
typeof(T) == typeof(uint))
if (typeof(T) == typeof(int))
{
return int.MaxValue;
}

if (typeof(T) == typeof(long) ||
typeof(T) == typeof(ulong))
if (typeof(T) == typeof(long))
{
if (sizeof(nint) == sizeof(int))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,6 @@ internal static partial class SpanHelper
/// <param name="length">The number of items to hash.</param>
/// <returns>The Djb2 value for the input sequence of items.</returns>
[Pure]
#if NETCOREAPP3_1
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
#endif
public static int GetDjb2HashCode<T>(ref T r0, nint length)
where T : notnull
{
Expand Down Expand Up @@ -87,9 +84,6 @@ public static int GetDjb2HashCode<T>(ref T r0, nint length)
/// faster than <see cref="GetDjb2HashCode{T}"/>, as it can parallelize much of the workload.
/// </remarks>
[Pure]
#if NETCOREAPP3_1
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
#endif
public static unsafe int GetDjb2LikeByteHash(ref byte r0, nint length)
{
int hash = 5381;
Expand Down