From c193357e8cf6365b51d8e4e8fe3bfc23543e30af Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Thu, 19 Sep 2024 15:34:31 -0400 Subject: [PATCH] Improve Random{NumberGenerator}.GetItems/String for non-power of 2 choices In .NET 9, we added an optimization to Random.GetItems and RandomNumberGenerator.GetItems/GetString that special-cases a power-of-2 number of choices that's <= 256. In such a case, we can avoid many trips to the RNG by requesting bytes in bulk, rather than requesting an Int32 per element. Each byte is masked to produce the index into the choices. This PR extends that optimization to also cover non-power-of-2 choices. It can't just mask off the bits as in the power-of-2 case, but it can mask off some bits and then do rejection sampling, which on average still yields big wins. --- .../src/System/Random.cs | 100 +++++++++++++----- .../System/Random.cs | 90 +++++++++++++++- .../Cryptography/RandomNumberGenerator.cs | 92 ++++++++++++---- 3 files changed, 230 insertions(+), 52 deletions(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/Random.cs b/src/libraries/System.Private.CoreLib/src/System/Random.cs index 39438be55c8f0..ac32aba7a68cd 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Random.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Random.cs @@ -197,46 +197,94 @@ public void GetItems(ReadOnlySpan choices, Span destination) throw new ArgumentException(SR.Arg_EmptySpan, nameof(choices)); } - // The most expensive part of this operation is the call to get random data. We can - // do so potentially many fewer times if: - // - the instance was constructed as `new Random()` or is `Random.Shared`, such that it's not seeded nor is it - // a custom derived type. We don't want to observably change the deterministically-produced sequence from previous releases. - // - the number of choices is <= 256. This let's us get a single byte per choice. - // - the number of choices is a power of two. 
This let's us use a byte and simply mask off - // unnecessary bits cheaply rather than needing to use rejection sampling. - // In such a case, we can grab a bunch of random bytes in one call. + // The most expensive part of this operation is the call to get random data. If the number of + // choices is <= 256 (which is the majority use case), we can use a single byte per element, + // which means we can amortize the cost of getting random data by getting random bytes in bulk. + // However, we can only do that if this instance is Random.Shared or an instance created with + // `new Random()`. If it was created with a seed, changing which members we call and how many + // times may result in a visible difference in the sequence of output items. Similarly if it's + // a derived instance, which overrides get called and when is observable. ImplBase impl = _impl; if ((impl is null || impl.GetType() == typeof(XoshiroImpl)) && - BitOperations.IsPow2(choices.Length) && choices.Length <= 256) { - Span randomBytes = stackalloc byte[512]; // arbitrary size, a balance between stack consumed and number of random calls required - while (!destination.IsEmpty) + // Get stack space to store random bytes. This size was chosen to balance between + // stack consumed and number of random calls required. + Span randomBytes = stackalloc byte[512]; + + if (BitOperations.IsPow2(choices.Length)) { - if (destination.Length < randomBytes.Length) + // To avoid bias, we can't just % all bytes to get them into range; that would cause + // the lower values to be more likely than the higher values. If the number of choices + // is a power of 2, though, we can just mask off the extraneous bits. + + int mask = choices.Length - 1; + + while (!destination.IsEmpty) { - randomBytes = randomBytes.Slice(0, destination.Length); + // If this will be the last iteration, avoid over-requesting randomness. 
+ if (destination.Length < randomBytes.Length) + { + randomBytes = randomBytes.Slice(0, destination.Length); + } + + NextBytes(randomBytes); + + for (int i = 0; i < randomBytes.Length; i++) + { + destination[i] = choices[randomBytes[i] & mask]; + } + + destination = destination.Slice(randomBytes.Length); } + } + else + { + // As the length isn't a power of two, we can't just mask off all extraneous bits, and + // instead need to do rejection sampling. However, we can mask off the irrelevant bits, which + // then reduces the chances of needing to reject a value. - NextBytes(randomBytes); + int mask = (int)BitOperations.RoundUpToPowerOf2((uint)choices.Length) - 1; - int mask = choices.Length - 1; - for (int i = 0; i < randomBytes.Length; i++) + while (!destination.IsEmpty) { - destination[i] = choices[randomBytes[i] & mask]; + // Unlike in the IsPow2 case, where every byte will be used, some bytes here may + // be rejected. In the worst case, just under half the bytes are rejected, so we heuristically + // choose to shrink to twice the destination length. + if (destination.Length * 2 < randomBytes.Length) + { + randomBytes = randomBytes.Slice(0, destination.Length * 2); + } + + NextBytes(randomBytes); + + int i = 0; + foreach (byte b in randomBytes) + { + if ((uint)i >= (uint)destination.Length) + { + break; + } + + byte masked = (byte)(b & mask); + if (masked < (uint)choices.Length) + { + destination[i++] = choices[masked]; + } + } + + destination = destination.Slice(i); } - - destination = destination.Slice(randomBytes.Length); } - - return; } - - // Simple fallback: get each item individually, generating a new random Int32 for each - // item. This is slower than the above, but it works for all types and sizes of choices. - for (int i = 0; i < destination.Length; i++) + else { - destination[i] = choices[Next(choices.Length)]; + // Simple fallback: get each item individually, generating a new random Int32 for each + // item. 
This is slower than the above, but it works for all types and sizes of choices. + for (int i = 0; i < destination.Length; i++) + { + destination[i] = choices[Next(choices.Length)]; + } } } diff --git a/src/libraries/System.Runtime/tests/System.Runtime.Extensions.Tests/System/Random.cs b/src/libraries/System.Runtime/tests/System.Runtime.Extensions.Tests/System/Random.cs index 27cfa26920e06..f6cb50a52fc90 100644 --- a/src/libraries/System.Runtime/tests/System.Runtime.Extensions.Tests/System/Random.cs +++ b/src/libraries/System.Runtime/tests/System.Runtime.Extensions.Tests/System/Random.cs @@ -792,7 +792,7 @@ public static void GetItems_Buffer_ArgValidation() } [Fact] - public static void GetItems_Allocating_Array_Seeded() + public static void GetItems_Allocating_Array_Seeded_NonPower2() { Random random = new Random(0x70636A61); byte[] items = new byte[] { 1, 2, 3 }; @@ -808,7 +808,7 @@ public static void GetItems_Allocating_Array_Seeded() } [Fact] - public static void GetItems_Allocating_Span_Seeded() + public static void GetItems_Allocating_Span_Seeded_NonPower2() { Random random = new Random(0x70636A61); ReadOnlySpan items = new byte[] { 1, 2, 3 }; @@ -824,7 +824,7 @@ public static void GetItems_Allocating_Span_Seeded() } [Fact] - public static void GetItems_Buffer_Seeded() + public static void GetItems_Buffer_Seeded_NonPower2() { Random random = new Random(0x70636A61); ReadOnlySpan items = new byte[] { 1, 2, 3 }; @@ -840,6 +840,90 @@ public static void GetItems_Buffer_Seeded() AssertExtensions.SequenceEqual(new byte[] { 1, 1, 3, 1, 3, 2, 2 }, buffer); } + [Fact] + public static void GetItems_Allocating_Array_Seeded_Power2() + { + Random random = new Random(0x70636A61); + byte[] items = new byte[] { 1, 2, 3, 4 }; + + byte[] result = random.GetItems(items, length: 7); + Assert.Equal(new byte[] { 4, 1, 4, 2, 4, 4, 4 }, result); + + result = random.GetItems(items, length: 7); + Assert.Equal(new byte[] { 2, 2, 3, 1, 3, 3, 1 }, result); + + result = 
random.GetItems(items, length: 7); + Assert.Equal(new byte[] { 2, 1, 4, 2, 4, 2, 2 }, result); + } + + [Fact] + public static void GetItems_Allocating_Span_Seeded_Power2() + { + Random random = new Random(0x70636A61); + ReadOnlySpan items = new byte[] { 1, 2, 3, 4 }; + + byte[] result = random.GetItems(items, length: 7); + Assert.Equal(new byte[] { 4, 1, 4, 2, 4, 4, 4 }, result); + + result = random.GetItems(items, length: 7); + Assert.Equal(new byte[] { 2, 2, 3, 1, 3, 3, 1 }, result); + + result = random.GetItems(items, length: 7); + Assert.Equal(new byte[] { 2, 1, 4, 2, 4, 2, 2 }, result); + } + + [Fact] + public static void GetItems_Buffer_Seeded_Power2() + { + Random random = new Random(0x70636A61); + ReadOnlySpan items = new byte[] { 1, 2, 3, 4 }; + + Span buffer = stackalloc byte[7]; + random.GetItems(items, buffer); + AssertExtensions.SequenceEqual(new byte[] { 4, 1, 4, 2, 4, 4, 4 }, buffer); + + random.GetItems(items, buffer); + AssertExtensions.SequenceEqual(new byte[] { 2, 2, 3, 1, 3, 3, 1 }, buffer); + + random.GetItems(items, buffer); + AssertExtensions.SequenceEqual(new byte[] { 2, 1, 4, 2, 4, 2, 2 }, buffer); + } + + [Theory] + [InlineData(0)] + [InlineData(1)] + [InlineData(2)] + [InlineData(3)] + [InlineData(4)] + public static void GetItems_AllValuesInRange(int mode) + { + Random random = mode switch + { + 0 => new Random(), + 1 => new Random(42), + 2 => new SubRandom(), + 3 => new SubRandom(42), + _ => Random.Shared, + }; + + foreach (int numItems in Enumerable.Range(1, 8).Append(300)) + { + int[] items = Enumerable.Range(42, numItems).ToArray(); + for (int length = 1; length <= 16; length++) + { + int[] result = random.GetItems(items, length: length); + Assert.All(result, b => Assert.InRange(b, 42, 42 + numItems - 1)); + + result = random.GetItems((ReadOnlySpan)items, length: length); + Assert.All(result, b => Assert.InRange(b, 42, 42 + numItems - 1)); + + Array.Clear(result); + random.GetItems(items, (Span)result); + Assert.All(result, b => 
Assert.InRange(b, 42, 42 + numItems - 1)); + } + } + } + private static Random Create(bool derived, bool seeded) => (derived, seeded) switch { diff --git a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/RandomNumberGenerator.cs b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/RandomNumberGenerator.cs index 7baed36c87f80..14d8635fd70fb 100644 --- a/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/RandomNumberGenerator.cs +++ b/src/libraries/System.Security.Cryptography/src/System/Security/Cryptography/RandomNumberGenerator.cs @@ -346,44 +346,90 @@ private static void GetHexStringCore(Span destination, bool lowercase) private static void GetItemsCore(ReadOnlySpan choices, Span destination) { - // The most expensive part of this operation is the call to get random data. We can - // do so potentially many fewer times if: - // - the number of choices is <= 256. This let's us get a single byte per choice. - // - the number of choices is a power of two. This let's us use a byte and simply mask off - // unnecessary bits cheaply rather than needing to use rejection sampling. - // In such a case, we can grab a bunch of random bytes in one call. - if (BitOperations.IsPow2(choices.Length) && choices.Length <= 256) + Debug.Assert(choices.Length > 0); + + // The most expensive part of this operation is the call to get random data. If the number of + // choices is <= 256 (which is the majority use case), we can use a single byte per element, + // which means we can amortize the cost of getting random data by getting random bytes in bulk. + if (choices.Length <= 256) { // Get stack space to store random bytes. This size was chosen to balance between // stack consumed and number of random calls required. 
Span randomBytes = stackalloc byte[512]; - while (!destination.IsEmpty) + if (BitOperations.IsPow2(choices.Length)) { - if (destination.Length < randomBytes.Length) + // To avoid bias, we can't just % all bytes to get them into range; that would cause + // the lower values to be more likely than the higher values. If the number of choices + // is a power of 2, though, we can just mask off the extraneous bits. + + int mask = choices.Length - 1; + + while (!destination.IsEmpty) { - randomBytes = randomBytes.Slice(0, destination.Length); + // If this will be the last iteration, avoid over-requesting randomness. + if (destination.Length < randomBytes.Length) + { + randomBytes = randomBytes.Slice(0, destination.Length); + } + + RandomNumberGeneratorImplementation.FillSpan(randomBytes); + + for (int i = 0; i < randomBytes.Length; i++) + { + destination[i] = choices[randomBytes[i] & mask]; + } + + destination = destination.Slice(randomBytes.Length); } + } + else + { + // As the length isn't a power of two, we can't just mask off all extraneous bits, and + // instead need to do rejection sampling. However, we can mask off the irrelevant bits, which + // then reduces the chances of needing to reject a value. - RandomNumberGeneratorImplementation.FillSpan(randomBytes); + int mask = (int)BitOperations.RoundUpToPowerOf2((uint)choices.Length) - 1; - int mask = choices.Length - 1; - for (int i = 0; i < randomBytes.Length; i++) + while (!destination.IsEmpty) { - destination[i] = choices[randomBytes[i] & mask]; + // Unlike in the IsPow2 case, where every byte will be used, some bytes here may + // be rejected. In the worst case, just under half the bytes are rejected, so we heuristically + // choose to shrink to twice the destination length. 
+ if (destination.Length * 2 < randomBytes.Length) + { + randomBytes = randomBytes.Slice(0, destination.Length * 2); + } + + RandomNumberGeneratorImplementation.FillSpan(randomBytes); + + int i = 0; + foreach (byte b in randomBytes) + { + if ((uint)i >= (uint)destination.Length) + { + break; + } + + byte masked = (byte)(b & mask); + if (masked < (uint)choices.Length) + { + destination[i++] = choices[masked]; + } + } + + destination = destination.Slice(i); } - - destination = destination.Slice(randomBytes.Length); } - - return; } - - // Simple fallback: get each item individually, generating a new random Int32 for each - // item. This is slower than the above, but it works for all types and sizes of choices. - for (int i = 0; i < destination.Length; i++) + else { - destination[i] = choices[GetInt32(choices.Length)]; + // Simple fallback: get each item individually, generating a new random Int32 for each + // item. This is slower than the above, but it works for all types and sizes of choices. + for (int i = 0; i < destination.Length; i++) + { + destination[i] = choices[GetInt32(choices.Length)]; + } } }