Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ThreadLocal tracking behavior #56956

Merged
merged 1 commit into from
Aug 10, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ private void Initialize(Func<T>? valueFactory, bool trackAllValues)
_trackAllValues = trackAllValues;

// Assign the ID and mark the instance as initialized.
_idComplement = ~s_idManager.GetId();
_idComplement = ~s_idManager.GetId(trackAllValues);

// As the last step, mark the instance as fully initialized. (Otherwise, if _initialized=false, we know that an exception
// occurred in the constructor.)
Expand Down Expand Up @@ -201,7 +201,7 @@ protected virtual void Dispose(bool disposing)
}
}
_linkedSlot = null;
s_idManager.ReturnId(id);
s_idManager.ReturnId(id, _trackAllValues);
}

#endregion
Expand Down Expand Up @@ -346,7 +346,7 @@ private void SetValueSlow(T value, LinkedSlotVolatile[]? slotArray)
if (slotArray == null)
{
slotArray = new LinkedSlotVolatile[GetNewTableSize(id + 1)];
ts_finalizationHelper = new FinalizationHelper(slotArray, _trackAllValues);
ts_finalizationHelper = new FinalizationHelper(slotArray);
ts_slotArray = slotArray;
}

Expand Down Expand Up @@ -675,42 +675,66 @@ private sealed class IdManager
{
// The next ID to try
private int _nextIdToTry;
// Keep track of the count of non-TrackAllValues ids in use. A count of 0 leads to more efficient thread cleanup
private volatile int _idsThatDoNotTrackAllValues;

// Stores whether each ID is free or not. Additionally, the object is also used as a lock for the IdManager.
private readonly List<bool> _freeIds = new List<bool>();
private const byte IdFree = 0;
private const byte TrackAllValuesAllocated = 1;
private const byte DoNotTrackAllValuesAllocated = 2;

internal int GetId()
// Stores whether each ID is free or not, and if it tracksAllValues or not. Additionally, the object is also used as a lock for the IdManager.
private readonly List<byte> _ids = new List<byte>();

internal int GetId(bool trackAllValues)
{
lock (_freeIds)
lock (_ids)
{
int availableId = _nextIdToTry;
while (availableId < _freeIds.Count)
while (availableId < _ids.Count)
{
if (_freeIds[availableId]) { break; }
if (_ids[availableId] == IdFree) { break; }
availableId++;
}

if (availableId == _freeIds.Count)
byte allocatedFlag = trackAllValues ? TrackAllValuesAllocated : DoNotTrackAllValuesAllocated;
if (availableId == _ids.Count)
{
_freeIds.Add(false);
_ids.Add(allocatedFlag);
}
else
{
_freeIds[availableId] = false;
_ids[availableId] = allocatedFlag;
}

if (!trackAllValues)
_idsThatDoNotTrackAllValues++;

_nextIdToTry = availableId + 1;

return availableId;
}
}

// Identify if an allocated id tracks all values or not
internal bool IdTracksAllValues(int id)
{
lock (_ids)
{
return _ids[id] == TrackAllValuesAllocated;
}
}

internal int IdsThatDoNotTrackValuesCount => _idsThatDoNotTrackAllValues;

// Return an ID to the pool
internal void ReturnId(int id)
internal void ReturnId(int id, bool idTracksAllValues)
{
lock (_freeIds)
lock (_ids)
{
_freeIds[id] = true;
if (!idTracksAllValues)
_idsThatDoNotTrackAllValues--;

_ids[id] = IdFree;
if (id < _nextIdToTry) _nextIdToTry = id;
}
}
Expand All @@ -731,18 +755,17 @@ internal void ReturnId(int id)
private sealed class FinalizationHelper
{
internal LinkedSlotVolatile[] SlotArray;
private readonly bool _trackAllValues;

internal FinalizationHelper(LinkedSlotVolatile[] slotArray, bool trackAllValues)
internal FinalizationHelper(LinkedSlotVolatile[] slotArray)
{
SlotArray = slotArray;
_trackAllValues = trackAllValues;
}

~FinalizationHelper()
{
LinkedSlotVolatile[] slotArray = SlotArray;
Debug.Assert(slotArray != null);
int idsThatDoNotTrackAllValuesCountRemaining = s_idManager.IdsThatDoNotTrackValuesCount;

for (int i = 0; i < slotArray.Length; i++)
{
Expand All @@ -753,7 +776,10 @@ internal FinalizationHelper(LinkedSlotVolatile[] slotArray, bool trackAllValues)
continue;
}

if (_trackAllValues)
// If there are no ids that do not TrackAllValues, we don't need to call the IdTracksAllValues function.
// This is an improvement as that function requires taking a lock.
if (idsThatDoNotTrackAllValuesCountRemaining == 0 ||
s_idManager.IdTracksAllValues(i))
{
// Set the SlotArray field to null to release the slot array.
linkedSlot._slotArray = null;
Expand All @@ -764,6 +790,13 @@ internal FinalizationHelper(LinkedSlotVolatile[] slotArray, bool trackAllValues)
// the table will be have been removed, and so the table can get GC'd.
lock (s_idManager)
{
// If the slot wasn't disposed between reading it above and entering the lock
// decrement idsThatDoNotTrackAllValuesCountRemaining
if (slotArray[i].Value != null)
{
idsThatDoNotTrackAllValuesCountRemaining--;
}

if (linkedSlot._next != null)
{
linkedSlot._next._previous = linkedSlot._previous;
Expand Down
31 changes: 31 additions & 0 deletions src/libraries/System.Threading/tests/ThreadLocalTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,37 @@ public static void ValuesGetterDoesNotThrowUnexpectedExceptionWhenDisposed()
Assert.False(failed);
}

private enum UniqueEnumUsedOnlyWithNonInterferenceTest { True, False }

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
public static void TestUnrelatedThreadLocalDoesNotInterfereWithTrackAllValues()
{
ThreadLocal<UniqueEnumUsedOnlyWithNonInterferenceTest> localThatDoesNotTrackValues = new ThreadLocal<UniqueEnumUsedOnlyWithNonInterferenceTest>(false);
ThreadLocal<UniqueEnumUsedOnlyWithNonInterferenceTest> localThatDoesTrackValues = new ThreadLocal<UniqueEnumUsedOnlyWithNonInterferenceTest>(true);

for (int i = 0; i < 10; i++)
{
Thread t = new Thread(Work);
t.Start();
t.Join();
}
GC.Collect();
GC.WaitForPendingFinalizers();
int count = 0;
foreach (var x in localThatDoesTrackValues.Values)
{
if (x == UniqueEnumUsedOnlyWithNonInterferenceTest.True)
count++;
}

Assert.Equal(10, count);
void Work()
{
localThatDoesNotTrackValues.Value = UniqueEnumUsedOnlyWithNonInterferenceTest.True;
localThatDoesTrackValues.Value = UniqueEnumUsedOnlyWithNonInterferenceTest.True;
}
}

private class SetMreOnFinalize
{
private ManualResetEventSlim _mres;
Expand Down