From 113d1e671c045b6ba5aad9eb84ca90bf0cfdfa6a Mon Sep 17 00:00:00 2001 From: Levi Broderick Date: Sat, 14 Nov 2020 17:27:29 -0800 Subject: [PATCH] Update internal comparers & out-of-bounds regression tests --- .../OutOfBoundsRegression.cs | 217 ++++++++++++------ .../RandomizedStringEqualityComparer.cs | 25 -- 2 files changed, 146 insertions(+), 96 deletions(-) diff --git a/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs b/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs index 934e57aafe4cc3..12f3a66f0edd68 100644 --- a/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs +++ b/src/libraries/System.Collections/tests/Generic/Dictionary/HashCollisionScenarios/OutOfBoundsRegression.cs @@ -54,56 +54,56 @@ public static void ComparerImplementations_Dictionary_WithWellKnownStringCompare RunDictionaryTest( equalityComparer: null, - expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, - expectedPublicComparerBeforeCollisionThreshold: EqualityComparer.Default.GetType(), - expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType); + expectedInternalComparerTypeBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, + expectedPublicComparerBeforeCollisionThreshold: EqualityComparer.Default, + expectedInternalComparerTypeAfterCollisionThreshold: randomizedOrdinalComparerType); // EqualityComparer.Default comparer RunDictionaryTest( equalityComparer: EqualityComparer.Default, - expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, - expectedPublicComparerBeforeCollisionThreshold: EqualityComparer.Default.GetType(), - expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType); + expectedInternalComparerTypeBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, + expectedPublicComparerBeforeCollisionThreshold: EqualityComparer.Default, + expectedInternalComparerTypeAfterCollisionThreshold: randomizedOrdinalComparerType); // Ordinal comparer RunDictionaryTest( equalityComparer: StringComparer.Ordinal, - expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, - expectedPublicComparerBeforeCollisionThreshold: StringComparer.Ordinal.GetType(), - expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType); + expectedInternalComparerTypeBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, + expectedPublicComparerBeforeCollisionThreshold: StringComparer.Ordinal, + expectedInternalComparerTypeAfterCollisionThreshold: randomizedOrdinalComparerType); // OrdinalIgnoreCase comparer RunDictionaryTest( equalityComparer: StringComparer.OrdinalIgnoreCase, - expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalIgnoreCaseComparerType, - expectedPublicComparerBeforeCollisionThreshold: StringComparer.OrdinalIgnoreCase.GetType(), - expectedComparerAfterCollisionThreshold: randomizedOrdinalIgnoreCaseComparerType); + expectedInternalComparerTypeBeforeCollisionThreshold: nonRandomizedOrdinalIgnoreCaseComparerType, + expectedPublicComparerBeforeCollisionThreshold: StringComparer.OrdinalIgnoreCase, + expectedInternalComparerTypeAfterCollisionThreshold: randomizedOrdinalIgnoreCaseComparerType); // linguistic comparer (not optimized) RunDictionaryTest( equalityComparer: StringComparer.InvariantCulture, - expectedInternalComparerBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(), - expectedPublicComparerBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(), - expectedComparerAfterCollisionThreshold: StringComparer.InvariantCulture.GetType()); + expectedInternalComparerTypeBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(), + expectedPublicComparerBeforeCollisionThreshold: StringComparer.InvariantCulture, + expectedInternalComparerTypeAfterCollisionThreshold: StringComparer.InvariantCulture.GetType()); static void RunDictionaryTest( IEqualityComparer equalityComparer, - Type expectedInternalComparerBeforeCollisionThreshold, - Type expectedPublicComparerBeforeCollisionThreshold, - Type expectedComparerAfterCollisionThreshold) + Type expectedInternalComparerTypeBeforeCollisionThreshold, + IEqualityComparer expectedPublicComparerBeforeCollisionThreshold, + Type expectedInternalComparerTypeAfterCollisionThreshold) { RunCollectionTestCommon( () => new Dictionary(equalityComparer), (dictionary, key) => dictionary.Add(key, null), (dictionary, key) => dictionary.ContainsKey(key), dictionary => dictionary.Comparer, - expectedInternalComparerBeforeCollisionThreshold, + expectedInternalComparerTypeBeforeCollisionThreshold, expectedPublicComparerBeforeCollisionThreshold, - expectedComparerAfterCollisionThreshold); + expectedInternalComparerTypeAfterCollisionThreshold); } } @@ -119,56 +119,56 @@ public static void ComparerImplementations_HashSet_WithWellKnownStringComparers( RunHashSetTest( equalityComparer: null, - expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, - expectedPublicComparerBeforeCollisionThreshold: EqualityComparer.Default.GetType(), - expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType); + expectedInternalComparerTypeBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, + expectedPublicComparerBeforeCollisionThreshold: EqualityComparer.Default, + expectedInternalComparerTypeAfterCollisionThreshold: randomizedOrdinalComparerType); // EqualityComparer.Default comparer RunHashSetTest( equalityComparer: EqualityComparer.Default, - expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, - expectedPublicComparerBeforeCollisionThreshold: EqualityComparer.Default.GetType(), - expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType); + expectedInternalComparerTypeBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, + expectedPublicComparerBeforeCollisionThreshold: EqualityComparer.Default, + expectedInternalComparerTypeAfterCollisionThreshold: randomizedOrdinalComparerType); // Ordinal comparer RunHashSetTest( equalityComparer: StringComparer.Ordinal, - expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, - expectedPublicComparerBeforeCollisionThreshold: StringComparer.Ordinal.GetType(), - expectedComparerAfterCollisionThreshold: randomizedOrdinalComparerType); + expectedInternalComparerTypeBeforeCollisionThreshold: nonRandomizedOrdinalComparerType, + expectedPublicComparerBeforeCollisionThreshold: StringComparer.Ordinal, + expectedInternalComparerTypeAfterCollisionThreshold: randomizedOrdinalComparerType); // OrdinalIgnoreCase comparer RunHashSetTest( equalityComparer: StringComparer.OrdinalIgnoreCase, - expectedInternalComparerBeforeCollisionThreshold: nonRandomizedOrdinalIgnoreCaseComparerType, - expectedPublicComparerBeforeCollisionThreshold: StringComparer.OrdinalIgnoreCase.GetType(), - expectedComparerAfterCollisionThreshold: randomizedOrdinalIgnoreCaseComparerType); + expectedInternalComparerTypeBeforeCollisionThreshold: nonRandomizedOrdinalIgnoreCaseComparerType, + expectedPublicComparerBeforeCollisionThreshold: StringComparer.OrdinalIgnoreCase, + expectedInternalComparerTypeAfterCollisionThreshold: randomizedOrdinalIgnoreCaseComparerType); // linguistic comparer (not optimized) RunHashSetTest( equalityComparer: StringComparer.InvariantCulture, - expectedInternalComparerBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(), - expectedPublicComparerBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(), - expectedComparerAfterCollisionThreshold: StringComparer.InvariantCulture.GetType()); + expectedInternalComparerTypeBeforeCollisionThreshold: StringComparer.InvariantCulture.GetType(), + expectedPublicComparerBeforeCollisionThreshold: StringComparer.InvariantCulture, + expectedInternalComparerTypeAfterCollisionThreshold: StringComparer.InvariantCulture.GetType()); static void RunHashSetTest( IEqualityComparer equalityComparer, - Type expectedInternalComparerBeforeCollisionThreshold, - Type expectedPublicComparerBeforeCollisionThreshold, - Type expectedComparerAfterCollisionThreshold) + Type expectedInternalComparerTypeBeforeCollisionThreshold, + IEqualityComparer expectedPublicComparerBeforeCollisionThreshold, + Type expectedInternalComparerTypeAfterCollisionThreshold) { RunCollectionTestCommon( () => new HashSet(equalityComparer), (set, key) => Assert.True(set.Add(key)), (set, key) => set.Contains(key), set => set.Comparer, - expectedInternalComparerBeforeCollisionThreshold, + expectedInternalComparerTypeBeforeCollisionThreshold, expectedPublicComparerBeforeCollisionThreshold, - expectedComparerAfterCollisionThreshold); + expectedInternalComparerTypeAfterCollisionThreshold); } } @@ -177,24 +177,18 @@ private static void RunCollectionTestCommon( Action addKeyCallback, Func containsKeyCallback, Func> getComparerCallback, - Type expectedInternalComparerBeforeCollisionThreshold, - Type expectedPublicComparerBeforeCollisionThreshold, - Type expectedComparerAfterCollisionThreshold) + Type expectedInternalComparerTypeBeforeCollisionThreshold, + IEqualityComparer expectedPublicComparerBeforeCollisionThreshold, + Type expectedInternalComparerTypeAfterCollisionThreshold) { TCollection collection = collectionFactory(); List allKeys = new List(); - const int StartOfRange = 0xE020; // use the Unicode Private Use range to avoid accidentally creating strings that really do compare as equal OrdinalIgnoreCase - const int Stride = 0x40; // to ensure we don't accidentally reset the 0x20 bit of the seed, which is used to negate OrdinalIgnoreCase effects - // First, go right up to the collision threshold, but don't exceed it. for (int i = 0; i < 100; i++) { - string newKey = GenerateCollidingString(i * Stride + StartOfRange); - Assert.Equal(0, _lazyGetNonRandomizedHashCodeDel.Value(newKey)); // ensure has a zero hash code Ordinal - Assert.Equal(0x24716ca0, _lazyGetNonRandomizedOrdinalIgnoreCaseHashCodeDel.Value(newKey)); // ensure has a zero hash code OrdinalIgnoreCase - + string newKey = _collidingStrings[i]; addKeyCallback(collection, newKey); allKeys.Add(newKey); } @@ -202,15 +196,18 @@ private static void RunCollectionTestCommon( FieldInfo internalComparerField = collection.GetType().GetField("_comparer", BindingFlags.NonPublic | BindingFlags.Instance); Assert.NotNull(internalComparerField); - Assert.Equal(expectedInternalComparerBeforeCollisionThreshold, internalComparerField.GetValue(collection)?.GetType()); - Assert.Equal(expectedPublicComparerBeforeCollisionThreshold, getComparerCallback(collection).GetType()); + IEqualityComparer actualInternalComparerBeforeCollisionThreshold = (IEqualityComparer)internalComparerField.GetValue(collection); + ValidateBehaviorOfInternalComparerVsPublicComparer(actualInternalComparerBeforeCollisionThreshold, expectedPublicComparerBeforeCollisionThreshold); + + Assert.Equal(expectedInternalComparerTypeBeforeCollisionThreshold, actualInternalComparerBeforeCollisionThreshold?.GetType()); + Assert.Equal(expectedPublicComparerBeforeCollisionThreshold, getComparerCallback(collection)); // Now exceed the collision threshold, which should rebucket entries. // Continue adding a few more entries to ensure we didn't corrupt internal state. for (int i = 100; i < 110; i++) { - string newKey = GenerateCollidingString(i * Stride + StartOfRange); + string newKey = _collidingStrings[i]; Assert.Equal(0, _lazyGetNonRandomizedHashCodeDel.Value(newKey)); // ensure has a zero hash code Ordinal Assert.Equal(0x24716ca0, _lazyGetNonRandomizedOrdinalIgnoreCaseHashCodeDel.Value(newKey)); // ensure has a zero hash code OrdinalIgnoreCase @@ -218,8 +215,11 @@ private static void RunCollectionTestCommon( allKeys.Add(newKey); } - Assert.Equal(expectedComparerAfterCollisionThreshold, internalComparerField.GetValue(collection)?.GetType()); - Assert.Equal(expectedPublicComparerBeforeCollisionThreshold, getComparerCallback(collection).GetType()); // shouldn't change this return value after collision threshold met + IEqualityComparer actualInternalComparerAfterCollisionThreshold = (IEqualityComparer)internalComparerField.GetValue(collection); + ValidateBehaviorOfInternalComparerVsPublicComparer(actualInternalComparerAfterCollisionThreshold, expectedPublicComparerBeforeCollisionThreshold); + + Assert.Equal(expectedInternalComparerTypeAfterCollisionThreshold, actualInternalComparerAfterCollisionThreshold?.GetType()); + Assert.Equal(expectedPublicComparerBeforeCollisionThreshold, getComparerCallback(collection)); // shouldn't change this return value after collision threshold met // And validate that all strings are present in the dictionary. @@ -235,7 +235,7 @@ private static void RunCollectionTestCommon( ((ISerializable)collection).GetObjectData(si, new StreamingContext()); object serializedComparer = si.GetValue("Comparer", typeof(IEqualityComparer)); - Assert.Equal(expectedPublicComparerBeforeCollisionThreshold, serializedComparer.GetType()); + Assert.Equal(expectedPublicComparerBeforeCollisionThreshold, serializedComparer); } private static Lazy> _lazyGetNonRandomizedHashCodeDel = new Lazy>( @@ -244,27 +244,63 @@ private static void RunCollectionTestCommon( private static Lazy> _lazyGetNonRandomizedOrdinalIgnoreCaseHashCodeDel = new Lazy>( () => GetStringHashCodeOpenDelegate("GetNonRandomizedHashCodeOrdinalIgnoreCase")); - // Generates a string with a well-known non-randomized hash code: - // - string.GetNonRandomizedHashCode returns 0. - // - string.GetNonRandomizedHashCodeOrdinalIgnoreCase returns 0x24716ca0. - // Provide a different seed to produce a different string. - private static string GenerateCollidingString(int seed) + // n.b., must be initialized *after* delegate fields above + private static readonly List _collidingStrings = GenerateCollidingStrings(110); + + private static List GenerateCollidingStrings(int count) { - return string.Create(8, seed, (span, seed) => + const int StartOfRange = 0xE020; // use the Unicode Private Use range to avoid accidentally creating strings that really do compare as equal OrdinalIgnoreCase + const int Stride = 0x40; // to ensure we don't accidentally reset the 0x20 bit of the seed, which is used to negate OrdinalIgnoreCase effects + + int currentSeed = StartOfRange; + + List collidingStrings = new List(count); + while (collidingStrings.Count < count) { - Span asBytes = MemoryMarshal.AsBytes(span); + if (currentSeed > ushort.MaxValue) + { + throw new Exception($"Couldn't create enough colliding strings? Created {collidingStrings.Count}, needed {count}."); + } - uint hash1 = (5381 << 16) + 5381; - uint hash2 = BitOperations.RotateLeft(hash1, 5) + hash1; + string candidate = GenerateCollidingStringCandidate(currentSeed); - MemoryMarshal.Write(asBytes, ref seed); - MemoryMarshal.Write(asBytes.Slice(4), ref hash2); // set hash2 := 0 (for Ordinal) + int ordinalHashCode = _lazyGetNonRandomizedHashCodeDel.Value(candidate); + Assert.Equal(0, ordinalHashCode); // ensure has a zero hash code Ordinal - hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1) ^ (uint)seed; - hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1); + int ordinalIgnoreCaseHashCode = _lazyGetNonRandomizedOrdinalIgnoreCaseHashCodeDel.Value(candidate); + if (ordinalIgnoreCaseHashCode == 0x24716ca0) // ensure has a zero hash code OrdinalIgnoreCase (might not have one) + { + collidingStrings.Add(candidate); // success! + } - MemoryMarshal.Write(asBytes.Slice(8), ref hash1); // set hash1 := 0 (for Ordinal) - }); + currentSeed += Stride; + } + + return collidingStrings; + + // Generates a possible string with a well-known non-randomized hash code: + // - string.GetNonRandomizedHashCode returns 0. + // - string.GetNonRandomizedHashCodeOrdinalIgnoreCase returns 0x24716ca0. + // Provide a different seed to produce a different string. + // Caller must check OrdinalIgnoreCase hash code to ensure correctness. + static string GenerateCollidingStringCandidate(int seed) + { + return string.Create(8, seed, (span, seed) => + { + Span asBytes = MemoryMarshal.AsBytes(span); + + uint hash1 = (5381 << 16) + 5381; + uint hash2 = BitOperations.RotateLeft(hash1, 5) + hash1; + + MemoryMarshal.Write(asBytes, ref seed); + MemoryMarshal.Write(asBytes.Slice(4), ref hash2); // set hash2 := 0 (for Ordinal) + + hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1) ^ (uint)seed; + hash1 = (BitOperations.RotateLeft(hash1, 5) + hash1); + + MemoryMarshal.Write(asBytes.Slice(8), ref hash1); // set hash1 := 0 (for Ordinal) + }); + } } private static Func GetStringHashCodeOpenDelegate(string methodName) @@ -274,5 +310,44 @@ private static Func GetStringHashCodeOpenDelegate(string methodName return method.CreateDelegate>(target: null); // create open delegate unbound to 'this' } + + private static void ValidateBehaviorOfInternalComparerVsPublicComparer(IEqualityComparer internalComparer, IEqualityComparer publicComparer) + { + // This helper ensures that when we substitute one of our internal comparers + // in place of the expected public comparer, the internal comparer's Equals + // and GetHashCode behavior are consistent with the public comparer's. + + if (internalComparer is null) + { + internalComparer = EqualityComparer.Default; + } + if (publicComparer is null) + { + publicComparer = EqualityComparer.Default; + } + + foreach (var pair in new[] { + ("Hello", "Hello"), // exactly equal + ("Hello", "Goodbye"), // not equal at all + ("Hello", "hello"), // case-insensitive equal + ("Hello", "He\u200dllo"), // equal under linguistic comparer + ("Hello", "HE\u200dLLO"), // equal under case-insensitive linguistic comparer + ("абвгдеёжзийклмнопрстуфхцчшщьыъэюя", "АБВГДЕЁЖЗИЙКЛМНОПРСТУФХЦЧШЩЬЫЪЭЮЯ"), // Cyrillic, case-insensitive equal + }) + { + bool arePairElementsExpectedEqual = publicComparer.Equals(pair.Item1, pair.Item2); + Assert.Equal(arePairElementsExpectedEqual, internalComparer.Equals(pair.Item1, pair.Item2)); + + bool areInternalHashCodesEqual = internalComparer.GetHashCode(pair.Item1) == internalComparer.GetHashCode(pair.Item2); + if (arePairElementsExpectedEqual) + { + Assert.True(areInternalHashCodesEqual); + } + else if (!areInternalHashCodesEqual) + { + Assert.False(arePairElementsExpectedEqual); + } + } + } } } diff --git a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs index 168959d83386a2..30db6049d22f77 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Collections/Generic/RandomizedStringEqualityComparer.cs @@ -80,31 +80,6 @@ internal OrdinalIgnoreCaseComparer(IEqualityComparer wrappedComparer) public override bool Equals(string? x, string? y) => string.EqualsOrdinalIgnoreCase(x, y); - public override int GetHashCode(string? obj) - { - if (obj is null) - { - return 0; - } - - // The Ordinal version of Marvin32 operates over bytes, so convert - // char count -> byte count. Guaranteed not to integer overflow. - return Marvin.ComputeHash32( - ref Unsafe.As(ref obj.GetRawStringData()), - (uint)obj.Length * sizeof(char), - _seed.p0, _seed.p1); - } - } - - private sealed class RandomizedOrdinalIgnoreCaseComparer : RandomizedStringEqualityComparer - { - internal RandomizedOrdinalIgnoreCaseComparer(IEqualityComparer underlyingComparer) - : base(underlyingComparer) - { - } - - public override bool Equals(string? x, string? y) => string.EqualsOrdinalIgnoreCase(x, y); - public override int GetHashCode(string? obj) { if (obj is null)