Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimized string.Replace(char, char) #67049

Merged
merged 13 commits into from
Aug 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions src/libraries/Common/tests/Tests/System/StringTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4697,14 +4697,22 @@ public static void Remove_Invalid()
[InlineData("Aaaaaaaa", 'A', 'a', "aaaaaaaa")] // Single iteration of vectorised path; no remainders through non-vectorised path
// Three leading 'a's before a match (copyLength > 0), Single iteration of vectorised path; no remainders through non-vectorised path
[InlineData("aaaAaaaaaaa", 'A', 'a', "aaaaaaaaaaa")]
// Single iteration of vectorised path; 3 remainders through non-vectorised path
// Single iteration of vectorised path; 3 remainders handled by vectorized path
[InlineData("AaaaaaaaaAa", 'A', 'a', "aaaaaaaaaaa")]
// Single iteration of vectorized path; 0 remainders handled by vectorized path
[InlineData("aaaaaaaaaAa", 'A', 'a', "aaaaaaaaaaa")]
// Eight chars before a match (copyLength > 0), single iteration of vectorized path for the remainder
[InlineData("12345678AAAAAAA", 'A', 'a', "12345678aaaaaaa")]
// ------------------------- For Vector<ushort>.Count == 16 (AVX2) -------------------------
[InlineData("AaaaaaaaAaaaaaaa", 'A', 'a', "aaaaaaaaaaaaaaaa")] // Single iteration of vectorised path; no remainders through non-vectorised path
// Three leading 'a's before a match (copyLength > 0), Single iteration of vectorised path; no remainders through non-vectorised path
[InlineData("aaaAaaaaaaaAaaaaaaa", 'A', 'a', "aaaaaaaaaaaaaaaaaaa")]
// Single iteration of vectorised path; 3 remainders through non-vectorised path
// Single iteration of vectorised path; 3 remainders handled by vectorized path
[InlineData("AaaaaaaaAaaaaaaaaAa", 'A', 'a', "aaaaaaaaaaaaaaaaaaa")]
// Single iteration of vectorized path; 0 remainders handled by vectorized path
[InlineData("aaaaaaaaaaaaaaaaaAa", 'A', 'a', "aaaaaaaaaaaaaaaaaaa")]
// Sixteen chars before a match (copyLength > 0), single iteration of vectorized path for the remainder
[InlineData("1234567890123456AAAAAAAAAAAAAAA", 'A', 'a', "1234567890123456aaaaaaaaaaaaaaa")]
// ----------------------------------- General test data -----------------------------------
[InlineData("Hello", 'l', '!', "He!!o")] // 2 match, non-vectorised path
[InlineData("Hello", 'e', 'e', "Hello")] // oldChar and newChar are same; nothing to replace
Expand Down
16 changes: 16 additions & 0 deletions src/libraries/System.Private.CoreLib/src/System/Numerics/Vector.cs
Original file line number Diff line number Diff line change
Expand Up @@ -895,6 +895,14 @@ public static bool LessThanOrEqualAll<T>(Vector<T> left, Vector<T> right)
public static bool LessThanOrEqualAny<T>(Vector<T> left, Vector<T> right)
where T : struct => LessThanOrEqual(left, right).As<T, nuint>() != Vector<nuint>.Zero;

[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static Vector<T> LoadUnsafe<T>(ref T source, nuint elementOffset)
where T : struct
{
source = ref Unsafe.Add(ref source, elementOffset);
return Unsafe.ReadUnaligned<Vector<T>>(ref Unsafe.As<T, byte>(ref source));
}

/// <summary>Computes the maximum of two vectors on a per-element basis.</summary>
/// <param name="left">The vector to compare with <paramref name="right" />.</param>
/// <param name="right">The vector to compare with <paramref name="left" />.</param>
Expand Down Expand Up @@ -1658,6 +1666,14 @@ public static Vector<T> SquareRoot<T>(Vector<T> value)
return result;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static void StoreUnsafe<T>(this Vector<T> source, ref T destination, nuint elementOffset)
where T : struct
{
destination = ref Unsafe.Add(ref destination, elementOffset);
Unsafe.WriteUnaligned(ref Unsafe.As<T, byte>(ref destination), source);
}

/// <summary>Subtracts two vectors to compute their difference.</summary>
/// <param name="left">The vector from which <paramref name="right" /> will be subtracted.</param>
/// <param name="right">The vector to subtract from <paramref name="left" />.</param>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,7 @@ public string Replace(char oldChar, char newChar)
if (firstIndex < 0)
return this;

int remainingLength = Length - firstIndex;
nuint remainingLength = (uint)(Length - firstIndex);
string result = FastAllocateString(Length);

int copyLength = firstIndex;
Expand All @@ -1006,35 +1006,56 @@ public string Replace(char oldChar, char newChar)
}

// Copy the remaining characters, doing the replacement as we go.
ref ushort pSrc = ref Unsafe.Add(ref Unsafe.As<char, ushort>(ref _firstChar), copyLength);
ref ushort pDst = ref Unsafe.Add(ref Unsafe.As<char, ushort>(ref result._firstChar), copyLength);
ref ushort pSrc = ref Unsafe.Add(ref GetRawStringDataAsUInt16(), (uint)copyLength);
ref ushort pDst = ref Unsafe.Add(ref result.GetRawStringDataAsUInt16(), (uint)copyLength);
nuint i = 0;

if (Vector.IsHardwareAccelerated && remainingLength >= Vector<ushort>.Count)
if (Vector.IsHardwareAccelerated && Length >= Vector<ushort>.Count)
{
Vector<ushort> oldChars = new Vector<ushort>(oldChar);
Vector<ushort> newChars = new Vector<ushort>(newChar);
Vector<ushort> oldChars = new(oldChar);
Vector<ushort> newChars = new(newChar);

do
Vector<ushort> original;
Vector<ushort> equals;
Vector<ushort> results;

if (remainingLength > (nuint)Vector<ushort>.Count)
{
Vector<ushort> original = Unsafe.ReadUnaligned<Vector<ushort>>(ref Unsafe.As<ushort, byte>(ref pSrc));
Vector<ushort> equals = Vector.Equals(original, oldChars);
Vector<ushort> results = Vector.ConditionalSelect(equals, newChars, original);
Unsafe.WriteUnaligned(ref Unsafe.As<ushort, byte>(ref pDst), results);

pSrc = ref Unsafe.Add(ref pSrc, Vector<ushort>.Count);
pDst = ref Unsafe.Add(ref pDst, Vector<ushort>.Count);
remainingLength -= Vector<ushort>.Count;
nuint lengthToExamine = remainingLength - (nuint)Vector<ushort>.Count;

do
{
original = Vector.LoadUnsafe(ref pSrc, i);
equals = Vector.Equals(original, oldChars);
results = Vector.ConditionalSelect(equals, newChars, original);
results.StoreUnsafe(ref pDst, i);

i += (nuint)Vector<ushort>.Count;
}
while (i < lengthToExamine);
}
while (remainingLength >= Vector<ushort>.Count);
}

for (; remainingLength > 0; remainingLength--)
{
ushort currentChar = pSrc;
pDst = currentChar == oldChar ? newChar : currentChar;
// There are [0, Vector<ushort>.Count) elements remaining now.
// As the operation is idempotent, and we know that in total there are at least Vector<ushort>.Count
// elements available, we read a vector from the very end of the string, perform the replace
// and write to the destination at the very end.
// Thus we can eliminate the scalar processing of the remaining elements.
// We perform this operation even if there are 0 elements remaining, as it is cheaper than the
// additional check which would introduce a branch here.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as it is cheaper than the additional check which would introduce a branch here.

Can you quantify this? Even with good branch prediction it's still more expensive?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's hard to pour this statement into numbers, as with a BDN-benchmark the branch predictor will very likely do a great job (they got really smart over the last generation of cpus).

In contrast to real-world usage I assume that it is more likely to have $&gt; 0$ elements remaining than having a remainder of $= 0$. In that case, and with the assumption that the branch predictor predictis $&gt; 0$ elements, the additional check (would be a test-instruction on x86) costs more than just executing the code (which needs to be done anyway).
So we penalize the case of having 0 elements remaining (which is assumed to be less likely), but all the data should be in the cache and cpu's memory system's store buffer should help to minimize that penalty.

When I start working on Vector128/256 support for string.Replace I'll try to examine that further, as there may be a code-path that starts with Vector256 where remainders will be processed by Vector128.


Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Perhaps worth adding an assert that current Debug.Assert(this.Length - i <= Vector<ushort>.Count) to make sure we won't skip any data?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hm, I think in this case a test should fail?
I'll re-check the tests and make sure that case is covered.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests cover these cases, so I don't see a need for the Debug.Assert -- but I'll add it of course if you want.

// -------------------- For Vector<ushort>.Count == 8 (SSE2 / ARM NEON) --------------------
[InlineData("Aaaaaaaa", 'A', 'a', "aaaaaaaa")] // Single iteration of vectorised path; no remainders through non-vectorised path
// Three leading 'a's before a match (copyLength > 0), Single iteration of vectorised path; no remainders through non-vectorised path
[InlineData("aaaAaaaaaaa", 'A', 'a', "aaaaaaaaaaa")]
// Single iteration of vectorised path; 3 remainders through non-vectorised path
[InlineData("AaaaaaaaaAa", 'A', 'a', "aaaaaaaaaaa")]
// ------------------------- For Vector<ushort>.Count == 16 (AVX2) -------------------------
[InlineData("AaaaaaaaAaaaaaaa", 'A', 'a', "aaaaaaaaaaaaaaaa")] // Single iteration of vectorised path; no remainders through non-vectorised path
// Three leading 'a's before a match (copyLength > 0), Single iteration of vectorised path; no remainders through non-vectorised path
[InlineData("aaaAaaaaaaaAaaaaaaa", 'A', 'a', "aaaaaaaaaaaaaaaaaaa")]
// Single iteration of vectorised path; 3 remainders through non-vectorised path
[InlineData("AaaaaaaaAaaaaaaaaAa", 'A', 'a', "aaaaaaaaaaaaaaaaaaa")]
// ----------------------------------- General test data -----------------------------------

pSrc = ref Unsafe.Add(ref pSrc, 1);
pDst = ref Unsafe.Add(ref pDst, 1);
i = (uint)(Length - Vector<ushort>.Count);
original = Vector.LoadUnsafe(ref GetRawStringDataAsUInt16(), i);
equals = Vector.Equals(original, oldChars);
results = Vector.ConditionalSelect(equals, newChars, original);
results.StoreUnsafe(ref result.GetRawStringDataAsUInt16(), i);
}
else
{
for (; i < remainingLength; ++i)
{
ushort currentChar = Unsafe.Add(ref pSrc, i);
Unsafe.Add(ref pDst, i) = currentChar == oldChar ? newChar : currentChar;
}
}

return result;
Expand Down
1 change: 1 addition & 0 deletions src/libraries/System.Private.CoreLib/src/System/String.cs
Original file line number Diff line number Diff line change
Expand Up @@ -508,6 +508,7 @@ public static bool IsNullOrWhiteSpace([NotNullWhen(false)] string? value)
public ref readonly char GetPinnableReference() => ref _firstChar;

internal ref char GetRawStringData() => ref _firstChar;
internal ref ushort GetRawStringDataAsUInt16() => ref Unsafe.As<char, ushort>(ref _firstChar);

// Helper for encodings so they can talk to our buffer directly
// stringLength must be the exact size we'll expect
Expand Down