Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an additional length check to FrozenDictionary and FrozenSet #92546

Merged
merged 9 commits into from
Dec 11, 2023
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,12 @@ private static FrozenDictionary<TKey, TValue> CreateFromDictionary<TKey, TValue>

// Calculate the minimum and maximum lengths of the strings in the dictionary. Several of the analyses need this.
int minLength = int.MaxValue, maxLength = 0;
ulong lengthFilter = 0;
foreach (string key in keys)
{
if (key.Length < minLength) minLength = key.Length;
if (key.Length > maxLength) maxLength = key.Length;
lengthFilter |= (1UL << (key.Length % 64));
}
Debug.Assert(minLength >= 0 && maxLength >= minLength);

Expand Down Expand Up @@ -215,12 +217,12 @@ private static FrozenDictionary<TKey, TValue> CreateFromDictionary<TKey, TValue>
if (analysis.IgnoreCase)
{
frozenDictionary = analysis.AllAsciiIfIgnoreCase
? new OrdinalStringFrozenDictionary_FullCaseInsensitiveAscii<TValue>(keys, values, stringComparer, analysis.MinimumLength, analysis.MaximumLengthDiff)
: new OrdinalStringFrozenDictionary_FullCaseInsensitive<TValue>(keys, values, stringComparer, analysis.MinimumLength, analysis.MaximumLengthDiff);
? new OrdinalStringFrozenDictionary_FullCaseInsensitiveAscii<TValue>(keys, values, stringComparer, analysis.MinimumLength, analysis.MaximumLengthDiff, lengthFilter)
: new OrdinalStringFrozenDictionary_FullCaseInsensitive<TValue>(keys, values, stringComparer, analysis.MinimumLength, analysis.MaximumLengthDiff, lengthFilter);
}
else
{
frozenDictionary = new OrdinalStringFrozenDictionary_Full<TValue>(keys, values, stringComparer, analysis.MinimumLength, analysis.MaximumLengthDiff);
frozenDictionary = new OrdinalStringFrozenDictionary_Full<TValue>(keys, values, stringComparer, analysis.MinimumLength, analysis.MaximumLengthDiff, lengthFilter);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,10 +109,12 @@ private static FrozenSet<T> CreateFromSet<T>(HashSet<T> source)

// Calculate the minimum and maximum lengths of the strings in the set. Several of the analyses need this.
int minLength = int.MaxValue, maxLength = 0;
ulong lengthFilter = 0;
foreach (string s in entries)
{
if (s.Length < minLength) minLength = s.Length;
if (s.Length > maxLength) maxLength = s.Length;
lengthFilter |= (1UL << (s.Length % 64));
}
Debug.Assert(minLength >= 0 && maxLength >= minLength);

Expand Down Expand Up @@ -163,12 +165,12 @@ private static FrozenSet<T> CreateFromSet<T>(HashSet<T> source)
if (analysis.IgnoreCase)
{
frozenSet = analysis.AllAsciiIfIgnoreCase
? new OrdinalStringFrozenSet_FullCaseInsensitiveAscii(entries, stringComparer, analysis.MinimumLength, analysis.MaximumLengthDiff)
: new OrdinalStringFrozenSet_FullCaseInsensitive(entries, stringComparer, analysis.MinimumLength, analysis.MaximumLengthDiff);
? new OrdinalStringFrozenSet_FullCaseInsensitiveAscii(entries, stringComparer, analysis.MinimumLength, analysis.MaximumLengthDiff, lengthFilter)
: new OrdinalStringFrozenSet_FullCaseInsensitive(entries, stringComparer, analysis.MinimumLength, analysis.MaximumLengthDiff, lengthFilter);
}
else
{
frozenSet = new OrdinalStringFrozenSet_Full(entries, stringComparer, analysis.MinimumLength, analysis.MaximumLengthDiff);
frozenSet = new OrdinalStringFrozenSet_Full(entries, stringComparer, analysis.MinimumLength, analysis.MaximumLengthDiff, lengthFilter);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ internal OrdinalStringFrozenDictionary(
private protected int HashCount { get; }
private protected abstract bool Equals(string? x, string? y);
private protected abstract int GetHashCode(string s);
private protected virtual bool CheckLengthQuick(string key) => true;
private protected override string[] KeysCore => _keys;
private protected override TValue[] ValuesCore => _values;
private protected override Enumerator GetEnumeratorCore() => new Enumerator(_keys, _values);
Expand All @@ -74,20 +75,23 @@ private protected override ref readonly TValue GetValueRefOrNullRefCore(string k
{
if ((uint)(key.Length - _minimumLength) <= (uint)_maximumLengthDiff)
{
int hashCode = GetHashCode(key);
_hashTable.FindMatchingEntries(hashCode, out int index, out int endIndex);

while (index <= endIndex)
if (CheckLengthQuick(key))
{
if (hashCode == _hashTable.HashCodes[index])
int hashCode = GetHashCode(key);
_hashTable.FindMatchingEntries(hashCode, out int index, out int endIndex);

while (index <= endIndex)
{
if (Equals(key, _keys[index]))
if (hashCode == _hashTable.HashCodes[index])
{
return ref _values[index];
if (Equals(key, _keys[index]))
{
return ref _values[index];
}
}
}

index++;
index++;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,18 @@ namespace System.Collections.Frozen
{
internal sealed class OrdinalStringFrozenDictionary_Full<TValue> : OrdinalStringFrozenDictionary<TValue>
{
private readonly ulong _lengthFilter;

internal OrdinalStringFrozenDictionary_Full(
string[] keys,
TValue[] values,
IEqualityComparer<string> comparer,
int minimumLength,
int maximumLengthDiff)
int maximumLengthDiff,
ulong lengthFilter)
: base(keys, values, comparer, minimumLength, maximumLengthDiff)
{
_lengthFilter = lengthFilter;
}

// This override is necessary to force the jit to emit the code in such a way that it
Expand All @@ -24,5 +28,6 @@ internal OrdinalStringFrozenDictionary_Full(

private protected override bool Equals(string? x, string? y) => string.Equals(x, y);
private protected override int GetHashCode(string s) => Hashing.GetHashCodeOrdinal(s.AsSpan());
private protected override bool CheckLengthQuick(string key) => (_lengthFilter & (1UL << (key.Length % 64))) > 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,18 @@ namespace System.Collections.Frozen
{
internal sealed class OrdinalStringFrozenDictionary_FullCaseInsensitive<TValue> : OrdinalStringFrozenDictionary<TValue>
{
private readonly ulong _lengthFilter;

internal OrdinalStringFrozenDictionary_FullCaseInsensitive(
string[] keys,
TValue[] values,
IEqualityComparer<string> comparer,
int minimumLength,
int maximumLengthDiff)
int maximumLengthDiff,
ulong lengthFilter)
: base(keys, values, comparer, minimumLength, maximumLengthDiff)
{
_lengthFilter = lengthFilter;
}

// This override is necessary to force the jit to emit the code in such a way that it
Expand All @@ -24,5 +28,6 @@ internal OrdinalStringFrozenDictionary_FullCaseInsensitive(

private protected override bool Equals(string? x, string? y) => StringComparer.OrdinalIgnoreCase.Equals(x, y);
private protected override int GetHashCode(string s) => Hashing.GetHashCodeOrdinalIgnoreCase(s.AsSpan());
private protected override bool CheckLengthQuick(string key) => (_lengthFilter & (1UL << (key.Length % 64))) > 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,18 @@ namespace System.Collections.Frozen
{
internal sealed class OrdinalStringFrozenDictionary_FullCaseInsensitiveAscii<TValue> : OrdinalStringFrozenDictionary<TValue>
{
private readonly ulong _lengthFilter;

internal OrdinalStringFrozenDictionary_FullCaseInsensitiveAscii(
string[] keys,
TValue[] values,
IEqualityComparer<string> comparer,
int minimumLength,
int maximumLengthDiff)
int maximumLengthDiff,
ulong lengthFilter)
: base(keys, values, comparer, minimumLength, maximumLengthDiff)
{
_lengthFilter = lengthFilter;
}

// This override is necessary to force the jit to emit the code in such a way that it
Expand All @@ -24,5 +28,6 @@ internal OrdinalStringFrozenDictionary_FullCaseInsensitiveAscii(

private protected override bool Equals(string? x, string? y) => StringComparer.OrdinalIgnoreCase.Equals(x, y);
private protected override int GetHashCode(string s) => Hashing.GetHashCodeOrdinalIgnoreCaseAscii(s.AsSpan());
private protected override bool CheckLengthQuick(string key) => (_lengthFilter & (1UL << (key.Length % 64))) > 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ internal OrdinalStringFrozenSet(
private protected int HashCount { get; }
private protected abstract bool Equals(string? x, string? y);
private protected abstract int GetHashCode(string s);
private protected virtual bool CheckLengthQuick(string key) => true;
private protected override string[] ItemsCore => _items;
private protected override Enumerator GetEnumeratorCore() => new Enumerator(_items);
private protected override int CountCore => _hashTable.Count;
Expand All @@ -64,20 +65,23 @@ private protected override int FindItemIndex(string item)
if (item is not null && // this implementation won't be used for null values
(uint)(item.Length - _minimumLength) <= (uint)_maximumLengthDiff)
{
int hashCode = GetHashCode(item);
_hashTable.FindMatchingEntries(hashCode, out int index, out int endIndex);

while (index <= endIndex)
if (CheckLengthQuick(item))
{
if (hashCode == _hashTable.HashCodes[index])
int hashCode = GetHashCode(item);
_hashTable.FindMatchingEntries(hashCode, out int index, out int endIndex);

while (index <= endIndex)
{
if (Equals(item, _items[index]))
if (hashCode == _hashTable.HashCodes[index])
{
return index;
if (Equals(item, _items[index]))
{
return index;
}
}
}

index++;
index++;
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@ namespace System.Collections.Frozen
{
internal sealed class OrdinalStringFrozenSet_Full : OrdinalStringFrozenSet
{
private readonly ulong _lengthFilter;

internal OrdinalStringFrozenSet_Full(
string[] entries,
IEqualityComparer<string> comparer,
int minimumLength,
int maximumLengthDiff)
int maximumLengthDiff,
ulong lengthFilter)
: base(entries, comparer, minimumLength, maximumLengthDiff)
{
_lengthFilter = lengthFilter;
}

// This override is necessary to force the jit to emit the code in such a way that it
Expand All @@ -23,5 +27,6 @@ internal OrdinalStringFrozenSet_Full(

private protected override bool Equals(string? x, string? y) => string.Equals(x, y);
private protected override int GetHashCode(string s) => Hashing.GetHashCodeOrdinal(s.AsSpan());
private protected override bool CheckLengthQuick(string key) => (_lengthFilter & (1UL << (key.Length % 64))) > 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@ namespace System.Collections.Frozen
{
internal sealed class OrdinalStringFrozenSet_FullCaseInsensitive : OrdinalStringFrozenSet
{
private readonly ulong _lengthFilter;

internal OrdinalStringFrozenSet_FullCaseInsensitive(
string[] entries,
IEqualityComparer<string> comparer,
int minimumLength,
int maximumLengthDiff)
int maximumLengthDiff,
ulong lengthFilter)
: base(entries, comparer, minimumLength, maximumLengthDiff)
{
_lengthFilter = lengthFilter;
}

// This override is necessary to force the jit to emit the code in such a way that it
Expand All @@ -23,5 +27,6 @@ internal OrdinalStringFrozenSet_FullCaseInsensitive(

private protected override bool Equals(string? x, string? y) => StringComparer.OrdinalIgnoreCase.Equals(x, y);
private protected override int GetHashCode(string s) => Hashing.GetHashCodeOrdinalIgnoreCase(s.AsSpan());
private protected override bool CheckLengthQuick(string key) => (_lengthFilter & (1UL << (key.Length % 64))) > 0;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,17 @@ namespace System.Collections.Frozen
{
internal sealed class OrdinalStringFrozenSet_FullCaseInsensitiveAscii : OrdinalStringFrozenSet
{
private readonly ulong _lengthFilter;

internal OrdinalStringFrozenSet_FullCaseInsensitiveAscii(
string[] entries,
IEqualityComparer<string> comparer,
int minimumLength,
int maximumLengthDiff)
int maximumLengthDiff,
ulong lengthFilter)
: base(entries, comparer, minimumLength, maximumLengthDiff)
{
_lengthFilter = lengthFilter;
}

// This override is necessary to force the jit to emit the code in such a way that it
Expand All @@ -23,5 +27,6 @@ internal OrdinalStringFrozenSet_FullCaseInsensitiveAscii(

private protected override bool Equals(string? x, string? y) => StringComparer.OrdinalIgnoreCase.Equals(x, y);
private protected override int GetHashCode(string s) => Hashing.GetHashCodeOrdinalIgnoreCaseAscii(s.AsSpan());
private protected override bool CheckLengthQuick(string key) => (_lengthFilter & (1UL << (key.Length % 64))) > 0;
}
}
Loading