diff --git a/src/libraries/System.Private.CoreLib/src/System/Threading/ThreadLocal.cs b/src/libraries/System.Private.CoreLib/src/System/Threading/ThreadLocal.cs index 539ac4f9e9f5a..7b73e8d64e675 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Threading/ThreadLocal.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Threading/ThreadLocal.cs @@ -123,7 +123,7 @@ private void Initialize(Func? valueFactory, bool trackAllValues) _trackAllValues = trackAllValues; // Assign the ID and mark the instance as initialized. - _idComplement = ~s_idManager.GetId(); + _idComplement = ~s_idManager.GetId(trackAllValues); // As the last step, mark the instance as fully initialized. (Otherwise, if _initialized=false, we know that an exception // occurred in the constructor.) @@ -201,7 +201,7 @@ protected virtual void Dispose(bool disposing) } } _linkedSlot = null; - s_idManager.ReturnId(id); + s_idManager.ReturnId(id, _trackAllValues); } #endregion @@ -346,7 +346,7 @@ private void SetValueSlow(T value, LinkedSlotVolatile[]? slotArray) if (slotArray == null) { slotArray = new LinkedSlotVolatile[GetNewTableSize(id + 1)]; - ts_finalizationHelper = new FinalizationHelper(slotArray, _trackAllValues); + ts_finalizationHelper = new FinalizationHelper(slotArray); ts_slotArray = slotArray; } @@ -675,42 +675,66 @@ private sealed class IdManager { // The next ID to try private int _nextIdToTry; + // Keep track of the count of non-TrackAllValues ids in use. A count of 0 leads to more efficient thread cleanup + private volatile int _idsThatDoNotTrackAllValues; - // Stores whether each ID is free or not. Additionally, the object is also used as a lock for the IdManager. - private readonly List _freeIds = new List(); + private const byte IdFree = 0; + private const byte TrackAllValuesAllocated = 1; + private const byte DoNotTrackAllValuesAllocated = 2; - internal int GetId() + // Stores whether each ID is free or not, and if it tracksAllValues or not. Additionally, the object is also used as a lock for the IdManager. + private readonly List _ids = new List(); + + internal int GetId(bool trackAllValues) { - lock (_freeIds) + lock (_ids) { int availableId = _nextIdToTry; - while (availableId < _freeIds.Count) + while (availableId < _ids.Count) { - if (_freeIds[availableId]) { break; } + if (_ids[availableId] == IdFree) { break; } availableId++; } - if (availableId == _freeIds.Count) + byte allocatedFlag = trackAllValues ? TrackAllValuesAllocated : DoNotTrackAllValuesAllocated; + if (availableId == _ids.Count) { - _freeIds.Add(false); + _ids.Add(allocatedFlag); } else { - _freeIds[availableId] = false; + _ids[availableId] = allocatedFlag; } + if (!trackAllValues) + _idsThatDoNotTrackAllValues++; + _nextIdToTry = availableId + 1; return availableId; } } + // Identify if an allocated id tracks all values or not + internal bool IdTracksAllValues(int id) + { + lock (_ids) + { + return _ids[id] == TrackAllValuesAllocated; + } + } + + internal int IdsThatDoNotTrackValuesCount => _idsThatDoNotTrackAllValues; + // Return an ID to the pool - internal void ReturnId(int id) + internal void ReturnId(int id, bool idTracksAllValues) { - lock (_freeIds) + lock (_ids) { - _freeIds[id] = true; + if (!idTracksAllValues) + _idsThatDoNotTrackAllValues--; + + _ids[id] = IdFree; if (id < _nextIdToTry) _nextIdToTry = id; } } @@ -731,18 +755,17 @@ internal void ReturnId(int id) private sealed class FinalizationHelper { internal LinkedSlotVolatile[] SlotArray; - private readonly bool _trackAllValues; - internal FinalizationHelper(LinkedSlotVolatile[] slotArray, bool trackAllValues) + internal FinalizationHelper(LinkedSlotVolatile[] slotArray) { SlotArray = slotArray; - _trackAllValues = trackAllValues; } ~FinalizationHelper() { LinkedSlotVolatile[] slotArray = SlotArray; Debug.Assert(slotArray != null); + int idsThatDoNotTrackAllValuesCountRemaining = s_idManager.IdsThatDoNotTrackValuesCount; for (int i = 0; i < slotArray.Length; i++) { @@ -753,7 +776,10 @@ internal FinalizationHelper(LinkedSlotVolatile[] slotArray, bool trackAllValues) continue; } - if (_trackAllValues) + // If there are no ids that do not TrackAllValues, we don't need to call the IdTracksAllValues function. + // This is an improvement as that function requires taking a lock. + if (idsThatDoNotTrackAllValuesCountRemaining == 0 || + s_idManager.IdTracksAllValues(i)) { // Set the SlotArray field to null to release the slot array. linkedSlot._slotArray = null; @@ -764,6 +790,13 @@ internal FinalizationHelper(LinkedSlotVolatile[] slotArray, bool trackAllValues) // the table will be have been removed, and so the table can get GC'd. lock (s_idManager) { + // If the slot wasn't disposed between reading it above and entering the lock + // decrement idsThatDoNotTrackAllValuesCountRemaining + if (slotArray[i].Value != null) + { + idsThatDoNotTrackAllValuesCountRemaining--; + } + if (linkedSlot._next != null) { linkedSlot._next._previous = linkedSlot._previous; diff --git a/src/libraries/System.Threading/tests/ThreadLocalTests.cs b/src/libraries/System.Threading/tests/ThreadLocalTests.cs index 54cdf6d46b4b4..7a1ff7743a237 100644 --- a/src/libraries/System.Threading/tests/ThreadLocalTests.cs +++ b/src/libraries/System.Threading/tests/ThreadLocalTests.cs @@ -435,6 +435,37 @@ public static void ValuesGetterDoesNotThrowUnexpectedExceptionWhenDisposed() Assert.False(failed); } + private enum UniqueEnumUsedOnlyWithNonInterferenceTest { True, False } + + [ConditionalFact(typeof(PlatformDetection), nameof(PlatformDetection.IsThreadingSupported))] + public static void TestUnrelatedThreadLocalDoesNotInterfereWithTrackAllValues() + { + ThreadLocal localThatDoesNotTrackValues = new ThreadLocal(false); + ThreadLocal localThatDoesTrackValues = new ThreadLocal(true); + + for (int i = 0; i < 10; i++) + { + Thread t = new Thread(Work); + t.Start(); + t.Join(); + } + GC.Collect(); + GC.WaitForPendingFinalizers(); + int count = 0; + foreach (var x in localThatDoesTrackValues.Values) + { + if (x == UniqueEnumUsedOnlyWithNonInterferenceTest.True) + count++; + } + + Assert.Equal(10, count); + void Work() + { + localThatDoesNotTrackValues.Value = UniqueEnumUsedOnlyWithNonInterferenceTest.True; + localThatDoesTrackValues.Value = UniqueEnumUsedOnlyWithNonInterferenceTest.True; + } + } + private class SetMreOnFinalize { private ManualResetEventSlim _mres;