Skip to content

Commit

Permalink
Fix ThreadLocal tracking behavior (#56956)
Browse files Browse the repository at this point in the history
- Before this change the trackAllValues behavior for ThreadLocal<SomeParticularType> was defined by the first instance of thread local to have its value set on the thread
  - This could lead to unpredictable memory leaks (where the value was improperly tracked even though it wasn't supposed to be) This reproduces as a memory leak with no other observable behavior
  - Or data loss, where the Values collection was missing entries.
- Change the model so that ThreadLocal<T> trackAllValues behavior is properly defined by the exact ThreadLocal<T> instance in use
- Implement by keeping track of the track all changes behavior within the IdManager

Fixes #55796
  • Loading branch information
davidwrighton committed Aug 10, 2021
1 parent 4d82932 commit fc2c089
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ private void Initialize(Func<T>? valueFactory, bool trackAllValues)
_trackAllValues = trackAllValues;

// Assign the ID and mark the instance as initialized.
_idComplement = ~s_idManager.GetId();
_idComplement = ~s_idManager.GetId(trackAllValues);

// As the last step, mark the instance as fully initialized. (Otherwise, if _initialized=false, we know that an exception
// occurred in the constructor.)
Expand Down Expand Up @@ -201,7 +201,7 @@ protected virtual void Dispose(bool disposing)
}
}
_linkedSlot = null;
s_idManager.ReturnId(id);
s_idManager.ReturnId(id, _trackAllValues);
}

#endregion
Expand Down Expand Up @@ -346,7 +346,7 @@ private void SetValueSlow(T value, LinkedSlotVolatile[]? slotArray)
if (slotArray == null)
{
slotArray = new LinkedSlotVolatile[GetNewTableSize(id + 1)];
ts_finalizationHelper = new FinalizationHelper(slotArray, _trackAllValues);
ts_finalizationHelper = new FinalizationHelper(slotArray);
ts_slotArray = slotArray;
}

Expand Down Expand Up @@ -675,42 +675,66 @@ private sealed class IdManager
{
// The next ID to try
private int _nextIdToTry;
// Keep track of the count of non-TrackAllValues ids in use. A count of 0 leads to more efficient thread cleanup
private volatile int _idsThatDoNotTrackAllValues;

// Stores whether each ID is free or not. Additionally, the object is also used as a lock for the IdManager.
private readonly List<bool> _freeIds = new List<bool>();
private const byte IdFree = 0;
private const byte TrackAllValuesAllocated = 1;
private const byte DoNotTrackAllValuesAllocated = 2;

internal int GetId()
// Stores whether each ID is free or not, and if it tracksAllValues or not. Additionally, the object is also used as a lock for the IdManager.
private readonly List<byte> _ids = new List<byte>();

internal int GetId(bool trackAllValues)
{
lock (_freeIds)
lock (_ids)
{
int availableId = _nextIdToTry;
while (availableId < _freeIds.Count)
while (availableId < _ids.Count)
{
if (_freeIds[availableId]) { break; }
if (_ids[availableId] == IdFree) { break; }
availableId++;
}

if (availableId == _freeIds.Count)
byte allocatedFlag = trackAllValues ? TrackAllValuesAllocated : DoNotTrackAllValuesAllocated;
if (availableId == _ids.Count)
{
_freeIds.Add(false);
_ids.Add(allocatedFlag);
}
else
{
_freeIds[availableId] = false;
_ids[availableId] = allocatedFlag;
}

if (!trackAllValues)
_idsThatDoNotTrackAllValues++;

_nextIdToTry = availableId + 1;

return availableId;
}
}

// Identify if an allocated id tracks all values or not
internal bool IdTracksAllValues(int id)
{
lock (_ids)
{
return _ids[id] == TrackAllValuesAllocated;
}
}

internal int IdsThatDoNotTrackValuesCount => _idsThatDoNotTrackAllValues;

// Return an ID to the pool
internal void ReturnId(int id)
internal void ReturnId(int id, bool idTracksAllValues)
{
lock (_freeIds)
lock (_ids)
{
_freeIds[id] = true;
if (!idTracksAllValues)
_idsThatDoNotTrackAllValues--;

_ids[id] = IdFree;
if (id < _nextIdToTry) _nextIdToTry = id;
}
}
Expand All @@ -731,18 +755,17 @@ internal void ReturnId(int id)
private sealed class FinalizationHelper
{
internal LinkedSlotVolatile[] SlotArray;
private readonly bool _trackAllValues;

internal FinalizationHelper(LinkedSlotVolatile[] slotArray, bool trackAllValues)
internal FinalizationHelper(LinkedSlotVolatile[] slotArray)
{
SlotArray = slotArray;
_trackAllValues = trackAllValues;
}

~FinalizationHelper()
{
LinkedSlotVolatile[] slotArray = SlotArray;
Debug.Assert(slotArray != null);
int idsThatDoNotTrackAllValuesCountRemaining = s_idManager.IdsThatDoNotTrackValuesCount;

for (int i = 0; i < slotArray.Length; i++)
{
Expand All @@ -753,7 +776,10 @@ internal FinalizationHelper(LinkedSlotVolatile[] slotArray, bool trackAllValues)
continue;
}

if (_trackAllValues)
// If there are no ids that do not TrackAllValues, we don't need to call the IdTracksAllValues function.
// This is an improvement as that function requires taking a lock.
if (idsThatDoNotTrackAllValuesCountRemaining == 0 ||
s_idManager.IdTracksAllValues(i))
{
// Set the SlotArray field to null to release the slot array.
linkedSlot._slotArray = null;
Expand All @@ -764,6 +790,13 @@ internal FinalizationHelper(LinkedSlotVolatile[] slotArray, bool trackAllValues)
// the table will be have been removed, and so the table can get GC'd.
lock (s_idManager)
{
// If the slot wasn't disposed between reading it above and entering the lock
// decrement idsThatDoNotTrackAllValuesCountRemaining
if (slotArray[i].Value != null)
{
idsThatDoNotTrackAllValuesCountRemaining--;
}

if (linkedSlot._next != null)
{
linkedSlot._next._previous = linkedSlot._previous;
Expand Down
31 changes: 31 additions & 0 deletions src/libraries/System.Threading/tests/ThreadLocalTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,37 @@ public static void ValuesGetterDoesNotThrowUnexpectedExceptionWhenDisposed()
Assert.False(failed);
}

private enum UniqueEnumUsedOnlyWithNonInterferenceTest { True, False }

[ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))]
public static void TestUnrelatedThreadLocalDoesNotInterfereWithTrackAllValues()
{
ThreadLocal<UniqueEnumUsedOnlyWithNonInterferenceTest> localThatDoesNotTrackValues = new ThreadLocal<UniqueEnumUsedOnlyWithNonInterferenceTest>(false);
ThreadLocal<UniqueEnumUsedOnlyWithNonInterferenceTest> localThatDoesTrackValues = new ThreadLocal<UniqueEnumUsedOnlyWithNonInterferenceTest>(true);

for (int i = 0; i < 10; i++)
{
Thread t = new Thread(Work);
t.Start();
t.Join();
}
GC.Collect();
GC.WaitForPendingFinalizers();
int count = 0;
foreach (var x in localThatDoesTrackValues.Values)
{
if (x == UniqueEnumUsedOnlyWithNonInterferenceTest.True)
count++;
}

Assert.Equal(10, count);
void Work()
{
localThatDoesNotTrackValues.Value = UniqueEnumUsedOnlyWithNonInterferenceTest.True;
localThatDoesTrackValues.Value = UniqueEnumUsedOnlyWithNonInterferenceTest.True;
}
}

private class SetMreOnFinalize
{
private ManualResetEventSlim _mres;
Expand Down

0 comments on commit fc2c089

Please sign in to comment.