Skip to content

Commit

Permalink
Improve Random{NumberGenerator}.GetItems/String for non-power of 2 ch…
Browse files Browse the repository at this point in the history
…oices

In .NET 9, we added an optimization to Random.GetItems and RandomNumberGenerator.GetItems/GetString that special-cases a power-of-2 number of choices that's <= 256. In such a case, we can avoid many trips to the RNG by requesting bytes in bulk, rather than requesting an Int32 per element. Each byte is masked to produce the index into the choices.

This PR extends that optimization to also cover non-power-of-2 choices. It can't just mask off the bits as in the power-of-2 case, but it can mask off some bits and then do rejection sampling, which on average still yields big wins.
  • Loading branch information
stephentoub committed Sep 19, 2024
1 parent 58f431b commit c193357
Show file tree
Hide file tree
Showing 3 changed files with 230 additions and 52 deletions.
100 changes: 74 additions & 26 deletions src/libraries/System.Private.CoreLib/src/System/Random.cs
Original file line number Diff line number Diff line change
Expand Up @@ -197,46 +197,94 @@ public void GetItems<T>(ReadOnlySpan<T> choices, Span<T> destination)
throw new ArgumentException(SR.Arg_EmptySpan, nameof(choices));
}

// The most expensive part of this operation is the call to get random data. We can
// do so potentially many fewer times if:
// - the instance was constructed as `new Random()` or is `Random.Shared`, such that it's not seeded nor is it
// a custom derived type. We don't want to observably change the deterministically-produced sequence from previous releases.
// - the number of choices is <= 256. This let's us get a single byte per choice.
// - the number of choices is a power of two. This let's us use a byte and simply mask off
// unnecessary bits cheaply rather than needing to use rejection sampling.
// In such a case, we can grab a bunch of random bytes in one call.
// The most expensive part of this operation is the call to get random data. If the number of
// choices is <= 256 (which is the majority use case), we can use a single byte per element,
// which means we can ammortize the cost of getting random data by getting random bytes in bulk.
// However, we can only do that if this instance is Random.Shared or an instance created with
// `new Random()`. If it was created with a seed, changing which members we call and how many
// times may result in a visible difference in the sequence of output items. Similarly if it's
// a derived instance, which overrides get called and when is observable.
ImplBase impl = _impl;
if ((impl is null || impl.GetType() == typeof(XoshiroImpl)) &&
BitOperations.IsPow2(choices.Length) &&
choices.Length <= 256)
{
Span<byte> randomBytes = stackalloc byte[512]; // arbitrary size, a balance between stack consumed and number of random calls required
while (!destination.IsEmpty)
// Get stack space to store random bytes. This size was chosen to balance between
// stack consumed and number of random calls required.
Span<byte> randomBytes = stackalloc byte[512];

if (BitOperations.IsPow2(choices.Length))
{
if (destination.Length < randomBytes.Length)
// To avoid bias, we can't just % all bytes to get them into range; that would cause
// the lower values to be more likely than the higher values. If the number of choices
// is a power of 2, though, we can just mask off the extraneous bits.

int mask = choices.Length - 1;

while (!destination.IsEmpty)
{
randomBytes = randomBytes.Slice(0, destination.Length);
// If this will be the last iteration, avoid over-requesting randomness.
if (destination.Length < randomBytes.Length)
{
randomBytes = randomBytes.Slice(0, destination.Length);
}

NextBytes(randomBytes);

for (int i = 0; i < randomBytes.Length; i++)
{
destination[i] = choices[randomBytes[i] & mask];
}

destination = destination.Slice(randomBytes.Length);
}
}
else
{
// As the length isn't a power of two, we can't just mask off all extraneous bits, and
// instead need to do rejection sampling. However, we can mask off the irrelevant bits, which
// then reduces the chances of needing to reject a value.

NextBytes(randomBytes);
int mask = (int)BitOperations.RoundUpToPowerOf2((uint)choices.Length) - 1;

int mask = choices.Length - 1;
for (int i = 0; i < randomBytes.Length; i++)
while (!destination.IsEmpty)
{
destination[i] = choices[randomBytes[i] & mask];
// Unlike in the IsPow2 case, where every byte will be used, some bytes here may
// be rejected. On average, half the bytes may be rejected, so we heuristically
// choose to shrink to twice the destination length.
if (destination.Length * 2 < randomBytes.Length)
{
randomBytes = randomBytes.Slice(0, destination.Length * 2);
}

NextBytes(randomBytes);

int i = 0;
foreach (byte b in randomBytes)
{
if ((uint)i >= (uint)destination.Length)
{
break;
}

byte masked = (byte)(b & mask);
if (masked < (uint)choices.Length)
{
destination[i++] = choices[masked];
}
}

destination = destination.Slice(i);
}

destination = destination.Slice(randomBytes.Length);
}

return;
}

// Simple fallback: get each item individually, generating a new random Int32 for each
// item. This is slower than the above, but it works for all types and sizes of choices.
for (int i = 0; i < destination.Length; i++)
else
{
destination[i] = choices[Next(choices.Length)];
// Simple fallback: get each item individually, generating a new random Int32 for each
// item. This is slower than the above, but it works for all types and sizes of choices.
for (int i = 0; i < destination.Length; i++)
{
destination[i] = choices[Next(choices.Length)];
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -792,7 +792,7 @@ public static void GetItems_Buffer_ArgValidation()
}

[Fact]
public static void GetItems_Allocating_Array_Seeded()
public static void GetItems_Allocating_Array_Seeded_NonPower2()
{
Random random = new Random(0x70636A61);
byte[] items = new byte[] { 1, 2, 3 };
Expand All @@ -808,7 +808,7 @@ public static void GetItems_Allocating_Array_Seeded()
}

[Fact]
public static void GetItems_Allocating_Span_Seeded()
public static void GetItems_Allocating_Span_Seeded_NonPower2()
{
Random random = new Random(0x70636A61);
ReadOnlySpan<byte> items = new byte[] { 1, 2, 3 };
Expand All @@ -824,7 +824,7 @@ public static void GetItems_Allocating_Span_Seeded()
}

[Fact]
public static void GetItems_Buffer_Seeded()
public static void GetItems_Buffer_Seeded_NonPower2()
{
Random random = new Random(0x70636A61);
ReadOnlySpan<byte> items = new byte[] { 1, 2, 3 };
Expand All @@ -840,6 +840,90 @@ public static void GetItems_Buffer_Seeded()
AssertExtensions.SequenceEqual(new byte[] { 1, 1, 3, 1, 3, 2, 2 }, buffer);
}

[Fact]
public static void GetItems_Allocating_Array_Seeded_Power2()
{
Random random = new Random(0x70636A61);
byte[] items = new byte[] { 1, 2, 3, 4 };

byte[] result = random.GetItems(items, length: 7);
Assert.Equal(new byte[] { 4, 1, 4, 2, 4, 4, 4 }, result);

result = random.GetItems(items, length: 7);
Assert.Equal(new byte[] { 2, 2, 3, 1, 3, 3, 1 }, result);

result = random.GetItems(items, length: 7);
Assert.Equal(new byte[] { 2, 1, 4, 2, 4, 2, 2 }, result);
}

[Fact]
public static void GetItems_Allocating_Span_Seeded_Power2()
{
Random random = new Random(0x70636A61);
ReadOnlySpan<byte> items = new byte[] { 1, 2, 3, 4 };

byte[] result = random.GetItems(items, length: 7);
Assert.Equal(new byte[] { 4, 1, 4, 2, 4, 4, 4 }, result);

result = random.GetItems(items, length: 7);
Assert.Equal(new byte[] { 2, 2, 3, 1, 3, 3, 1 }, result);

result = random.GetItems(items, length: 7);
Assert.Equal(new byte[] { 2, 1, 4, 2, 4, 2, 2 }, result);
}

[Fact]
public static void GetItems_Buffer_Seeded_Power2()
{
Random random = new Random(0x70636A61);
ReadOnlySpan<byte> items = new byte[] { 1, 2, 3, 4 };

Span<byte> buffer = stackalloc byte[7];
random.GetItems(items, buffer);
AssertExtensions.SequenceEqual(new byte[] { 4, 1, 4, 2, 4, 4, 4 }, buffer);

random.GetItems(items, buffer);
AssertExtensions.SequenceEqual(new byte[] { 2, 2, 3, 1, 3, 3, 1 }, buffer);

random.GetItems(items, buffer);
AssertExtensions.SequenceEqual(new byte[] { 2, 1, 4, 2, 4, 2, 2 }, buffer);
}

[Theory]
[InlineData(0)]
[InlineData(1)]
[InlineData(2)]
[InlineData(3)]
[InlineData(4)]
public static void GetItems_AllValuesInRange(int mode)
{
Random random = mode switch
{
0 => new Random(),
1 => new Random(42),
2 => new SubRandom(),
3 => new SubRandom(42),
_ => Random.Shared,
};

foreach (int numItems in Enumerable.Range(1, 8).Append(300))
{
int[] items = Enumerable.Range(42, numItems).ToArray();
for (int length = 1; length <= 16; length++)
{
int[] result = random.GetItems(items, length: length);
Assert.All(result, b => Assert.InRange(b, 42, 42 + numItems - 1));

result = random.GetItems((ReadOnlySpan<int>)items, length: length);
Assert.All(result, b => Assert.InRange(b, 42, 42 + numItems - 1));

Array.Clear(result);
random.GetItems(items, (Span<int>)result);
Assert.All(result, b => Assert.InRange(b, 42, 42 + numItems - 1));
}
}
}

private static Random Create(bool derived, bool seeded) =>
(derived, seeded) switch
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,44 +346,90 @@ private static void GetHexStringCore(Span<char> destination, bool lowercase)

private static void GetItemsCore<T>(ReadOnlySpan<T> choices, Span<T> destination)
{
// The most expensive part of this operation is the call to get random data. We can
// do so potentially many fewer times if:
// - the number of choices is <= 256. This let's us get a single byte per choice.
// - the number of choices is a power of two. This let's us use a byte and simply mask off
// unnecessary bits cheaply rather than needing to use rejection sampling.
// In such a case, we can grab a bunch of random bytes in one call.
if (BitOperations.IsPow2(choices.Length) && choices.Length <= 256)
Debug.Assert(choices.Length > 0);

// The most expensive part of this operation is the call to get random data. If the number of
// choices is <= 256 (which is the majority use case), we can use a single byte per element,
// which means we can ammortize the cost of getting random data by getting random bytes in bulk.
if (choices.Length <= 256)
{
// Get stack space to store random bytes. This size was chosen to balance between
// stack consumed and number of random calls required.
Span<byte> randomBytes = stackalloc byte[512];

while (!destination.IsEmpty)
if (BitOperations.IsPow2(choices.Length))
{
if (destination.Length < randomBytes.Length)
// To avoid bias, we can't just % all bytes to get them into range; that would cause
// the lower values to be more likely than the higher values. If the number of choices
// is a power of 2, though, we can just mask off the extraneous bits.

int mask = choices.Length - 1;

while (!destination.IsEmpty)
{
randomBytes = randomBytes.Slice(0, destination.Length);
// If this will be the last iteration, avoid over-requesting randomness.
if (destination.Length < randomBytes.Length)
{
randomBytes = randomBytes.Slice(0, destination.Length);
}

RandomNumberGeneratorImplementation.FillSpan(randomBytes);

for (int i = 0; i < randomBytes.Length; i++)
{
destination[i] = choices[randomBytes[i] & mask];
}

destination = destination.Slice(randomBytes.Length);
}
}
else
{
// As the length isn't a power of two, we can't just mask off all extraneous bits, and
// instead need to do rejection sampling. However, we can mask off the irrelevant bits, which
// then reduces the chances of needing to reject a value.

RandomNumberGeneratorImplementation.FillSpan(randomBytes);
int mask = (int)BitOperations.RoundUpToPowerOf2((uint)choices.Length) - 1;

int mask = choices.Length - 1;
for (int i = 0; i < randomBytes.Length; i++)
while (!destination.IsEmpty)
{
destination[i] = choices[randomBytes[i] & mask];
// Unlike in the IsPow2 case, where every byte will be used, some bytes here may
// be rejected. On average, half the bytes may be rejected, so we heuristically
// choose to shrink to twice the destination length.
if (destination.Length * 2 < randomBytes.Length)
{
randomBytes = randomBytes.Slice(0, destination.Length * 2);
}

RandomNumberGeneratorImplementation.FillSpan(randomBytes);

int i = 0;
foreach (byte b in randomBytes)
{
if ((uint)i >= (uint)destination.Length)
{
break;
}

byte masked = (byte)(b & mask);
if (masked < (uint)choices.Length)
{
destination[i++] = choices[masked];
}
}

destination = destination.Slice(i);
}

destination = destination.Slice(randomBytes.Length);
}

return;
}

// Simple fallback: get each item individually, generating a new random Int32 for each
// item. This is slower than the above, but it works for all types and sizes of choices.
for (int i = 0; i < destination.Length; i++)
else
{
destination[i] = choices[GetInt32(choices.Length)];
// Simple fallback: get each item individually, generating a new random Int32 for each
// item. This is slower than the above, but it works for all types and sizes of choices.
for (int i = 0; i < destination.Length; i++)
{
destination[i] = choices[GetInt32(choices.Length)];
}
}
}

Expand Down

0 comments on commit c193357

Please sign in to comment.