diff --git a/include/envoy/thread_local/thread_local.h b/include/envoy/thread_local/thread_local.h
index 5c4374aff8ea..b4828d6e6550 100644
--- a/include/envoy/thread_local/thread_local.h
+++ b/include/envoy/thread_local/thread_local.h
@@ -76,11 +76,17 @@ class Slot {
   virtual void set(InitializeCb cb) PURE;

   /**
-   * UpdateCb takes the current stored data, and returns an updated/new version data.
-   * TLS will run the callback and replace the stored data with the returned value *in each thread*.
+   * UpdateCb takes the current stored data, and must return the same data. The
+   * API was designed to allow replacement of the stored object via this
+   * callback, but that capability is not currently used, and thus we are
+   * removing it. In a future PR, the API will be removed.
    *
-   * NOTE: The update callback is not supposed to capture the Slot, or its owner. As the owner may
-   * be destructed in main thread before the update_cb gets called in a worker thread.
+   * TLS will run the callback and assert that the returned value matches
+   * the current value.
+   *
+   * NOTE: The update callback is not supposed to capture the Slot, or its
+   * owner, as the owner may be destructed in the main thread before the
+   * update_cb gets called in a worker thread.
   **/
   using UpdateCb = std::function<ThreadLocalObjectSharedPtr(ThreadLocalObjectSharedPtr)>;
   virtual void runOnAllThreads(const UpdateCb& update_cb) PURE;
@@ -102,6 +108,83 @@ class SlotAllocator {
   virtual SlotPtr allocateSlot() PURE;
 };

+// Provides a typesafe API for slots.
+//
+// TODO(jmarantz): Rename the Slot class to something like RawSlot, where the
+// only reference is from TypedSlot, which we can then rename to Slot.
+template <class T> class TypedSlot {
+public:
+  /**
+   * Helper method to create a unique_ptr for a typed slot. This helper
+   * reduces some verbose parameterization at call-sites.
+   *
+   * @param allocator factory to allocate untyped Slot objects.
+   * @return a TypedSlotPtr<T> (the type is defined below).
+   */
+  static std::unique_ptr<TypedSlot> makeUnique(SlotAllocator& allocator) {
+    return std::make_unique<TypedSlot>(allocator);
+  }
+
+  explicit TypedSlot(SlotAllocator& allocator) : slot_(allocator.allocateSlot()) {}
+
+  /**
+   * Returns whether there is thread local data for this thread.
+   *
+   * This should return true for Envoy worker threads and false for threads
+   * which do not have thread local storage allocated.
+   *
+   * @return true if registerThread has been called for this thread, false otherwise.
+   */
+  bool currentThreadRegistered() { return slot_->currentThreadRegistered(); }
+
+  /**
+   * Set thread local data on all threads previously registered via registerThread().
+   * @param initializeCb supplies the functor that will be called *on each thread*. The functor
+   *                     returns the thread local object which is then stored. The storage is via
+   *                     a shared_ptr. Thus, this is a flexible mechanism that can be used to share
+   *                     the same data across all threads or to share different data on each thread.
+   *
+   * NOTE: The initialize callback is not supposed to capture the Slot, or its owner, as the owner
+   * may be destructed in the main thread before the update_cb gets called in a worker thread.
+   */
+  using InitializeCb = std::function<std::shared_ptr<T>(Event::Dispatcher& dispatcher)>;
+  void set(InitializeCb cb) { slot_->set(cb); }
+
+  /**
+   * @return a reference to the thread local object.
+   */
+  T& get() { return slot_->getTyped<T>(); }
+
+  /**
+   * @return a pointer to the thread local object.
+   */
+  T* operator->() { return &get(); }
+
+  /**
+   * UpdateCb is passed a mutable reference to the current stored data.
+   *
+   * NOTE: The update callback is not supposed to capture the TypedSlot, or its owner, as the
+   * owner may be destructed in the main thread before the update_cb gets called in a worker
+   * thread.
+   */
+  using UpdateCb = std::function<void(T& obj)>;
+  void runOnAllThreads(const UpdateCb& cb) { slot_->runOnAllThreads(makeSlotUpdateCb(cb)); }
+  void runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) {
+    slot_->runOnAllThreads(makeSlotUpdateCb(cb), complete_cb);
+  }
+
+private:
+  Slot::UpdateCb makeSlotUpdateCb(UpdateCb cb) {
+    return [cb](ThreadLocalObjectSharedPtr obj) -> ThreadLocalObjectSharedPtr {
+      cb(obj->asType<T>());
+      return obj;
+    };
+  }
+
+  const SlotPtr slot_;
+};
+
+template <class T> using TypedSlotPtr = std::unique_ptr<TypedSlot<T>>;
+
 /**
  * Interface for getting and setting thread local data as well as registering a thread
  */
diff --git a/source/common/stats/thread_local_store.cc b/source/common/stats/thread_local_store.cc
index 4bd4ec6a9d6a..b0704ff97c19 100644
--- a/source/common/stats/thread_local_store.cc
+++ b/source/common/stats/thread_local_store.cc
@@ -185,10 +185,9 @@ void ThreadLocalStoreImpl::initializeThreading(Event::Dispatcher& main_thread_di
                                                ThreadLocal::Instance& tls) {
   threading_ever_initialized_ = true;
   main_thread_dispatcher_ = &main_thread_dispatcher;
-  tls_ = tls.allocateSlot();
-  tls_->set([](Event::Dispatcher&) -> ThreadLocal::ThreadLocalObjectSharedPtr {
-    return std::make_shared<TlsCache>();
-  });
+  tls_cache_ = ThreadLocal::TypedSlot<TlsCache>::makeUnique(tls);
+  tls_cache_->set(
+      [](Event::Dispatcher&) -> std::shared_ptr<TlsCache> { return std::make_shared<TlsCache>(); });
 }

 void ThreadLocalStoreImpl::shutdownThreading() {
@@ -205,14 +204,12 @@ void ThreadLocalStoreImpl::mergeHistograms(PostMergeCb merge_complete_cb) {
   if (!shutting_down_) {
     ASSERT(!merge_in_progress_);
     merge_in_progress_ = true;
-    tls_->runOnAllThreads(
-        [](ThreadLocal::ThreadLocalObjectSharedPtr object)
-            -> ThreadLocal::ThreadLocalObjectSharedPtr {
-          for (const auto& id_hist : object->asType<TlsCache>().tls_histogram_cache_) {
+    tls_cache_->runOnAllThreads(
+        [](TlsCache& tls_cache) {
+          for (const auto& id_hist : tls_cache.tls_histogram_cache_) {
             const TlsHistogramSharedPtr& tls_hist = id_hist.second;
             tls_hist->beginMerge();
           }
-          return object;
         },
         [this, merge_complete_cb]() -> void { mergeInternal(merge_complete_cb); });
   } else {
@@ -305,12 +302,8 @@ void ThreadLocalStoreImpl::clearScopeFromCaches(uint64_t scope_id,
   // at the same time.
   if (!shutting_down_) {
     // Perform a cache flush on all threads.
-    tls_->runOnAllThreads(
-        [scope_id](ThreadLocal::ThreadLocalObjectSharedPtr object)
-            -> ThreadLocal::ThreadLocalObjectSharedPtr {
-          object->asType<TlsCache>().eraseScope(scope_id);
-          return object;
-        },
+    tls_cache_->runOnAllThreads(
+        [scope_id](TlsCache& tls_cache) { tls_cache.eraseScope(scope_id); },
         [central_cache]() { /* Holds onto central_cache until all tls caches are clear */ });
   }
 }
@@ -326,11 +319,8 @@ void ThreadLocalStoreImpl::clearHistogramFromCaches(uint64_t histogram_id) {
   // https://gist.github.com/jmarantz/838cb6de7e74c0970ea6b63eded0139a
   // contains a patch that will implement batching together to clear multiple
   // histograms.
-    tls_->runOnAllThreads([histogram_id](ThreadLocal::ThreadLocalObjectSharedPtr object)
-                              -> ThreadLocal::ThreadLocalObjectSharedPtr {
-      object->asType<TlsCache>().eraseHistogram(histogram_id);
-      return object;
-    });
+    tls_cache_->runOnAllThreads(
+        [histogram_id](TlsCache& tls_cache) { tls_cache.eraseHistogram(histogram_id); });
   }
 }

@@ -498,8 +488,8 @@ Counter& ThreadLocalStoreImpl::ScopeImpl::counterFromStatNameWithTags(
   // initialized currently.
   StatRefMap<Counter>* tls_cache = nullptr;
   StatNameHashSet* tls_rejected_stats = nullptr;
-  if (!parent_.shutting_down_ && parent_.tls_) {
-    TlsCacheEntry& entry = parent_.tls_->getTyped<TlsCache>().insertScope(this->scope_id_);
+  if (!parent_.shutting_down_ && parent_.tls_cache_) {
+    TlsCacheEntry& entry = parent_.tls_cache_->get().insertScope(this->scope_id_);
     tls_cache = &entry.counters_;
     tls_rejected_stats = &entry.rejected_stats_;
   }
@@ -550,8 +540,8 @@ Gauge& ThreadLocalStoreImpl::ScopeImpl::gaugeFromStatNameWithTags(

   StatRefMap<Gauge>* tls_cache = nullptr;
   StatNameHashSet* tls_rejected_stats = nullptr;
-  if (!parent_.shutting_down_ && parent_.tls_) {
-    TlsCacheEntry& entry = parent_.tls_->getTyped<TlsCache>().scope_cache_[this->scope_id_];
+  if (!parent_.shutting_down_ && parent_.tls_cache_) {
+    TlsCacheEntry& entry = parent_.tls_cache_->get().scope_cache_[this->scope_id_];
     tls_cache = &entry.gauges_;
     tls_rejected_stats = &entry.rejected_stats_;
   }
@@ -588,8 +578,8 @@ Histogram& ThreadLocalStoreImpl::ScopeImpl::histogramFromStatNameWithTags(

   StatNameHashMap<ParentHistogramSharedPtr>* tls_cache = nullptr;
   StatNameHashSet* tls_rejected_stats = nullptr;
-  if (!parent_.shutting_down_ && parent_.tls_) {
-    TlsCacheEntry& entry = parent_.tls_->getTyped<TlsCache>().scope_cache_[this->scope_id_];
+  if (!parent_.shutting_down_ && parent_.tls_cache_) {
+    TlsCacheEntry& entry = parent_.tls_cache_->get().scope_cache_[this->scope_id_];
     tls_cache = &entry.parent_histograms_;
     auto iter = tls_cache->find(final_stat_name);
     if (iter != tls_cache->end()) {
@@ -666,8 +656,8 @@ TextReadout& ThreadLocalStoreImpl::ScopeImpl::textReadoutFromStatNameWithTags(
   // initialized currently.
   StatRefMap<TextReadout>* tls_cache = nullptr;
   StatNameHashSet* tls_rejected_stats = nullptr;
-  if (!parent_.shutting_down_ && parent_.tls_) {
-    TlsCacheEntry& entry = parent_.tls_->getTyped<TlsCache>().insertScope(this->scope_id_);
+  if (!parent_.shutting_down_ && parent_.tls_cache_) {
+    TlsCacheEntry& entry = parent_.tls_cache_->get().insertScope(this->scope_id_);
     tls_cache = &entry.text_readouts_;
     tls_rejected_stats = &entry.rejected_stats_;
   }
@@ -712,9 +702,8 @@ Histogram& ThreadLocalStoreImpl::tlsHistogram(ParentHistogramImpl& parent, uint6

   // See comments in counterFromStatName() which explains the logic here.
   TlsHistogramSharedPtr* tls_histogram = nullptr;
-  if (!shutting_down_ && tls_ != nullptr) {
-    TlsCache& tls_cache = tls_->getTyped<TlsCache>();
-    tls_histogram = &tls_cache.tls_histogram_cache_[id];
+  if (!shutting_down_ && tls_cache_) {
+    tls_histogram = &tls_cache_->get().tls_histogram_cache_[id];
     if (*tls_histogram != nullptr) {
       return **tls_histogram;
     }
diff --git a/source/common/stats/thread_local_store.h b/source/common/stats/thread_local_store.h
index 8ef60df207ba..410254afeb33 100644
--- a/source/common/stats/thread_local_store.h
+++ b/source/common/stats/thread_local_store.h
@@ -480,7 +480,8 @@ class ThreadLocalStoreImpl : Logger::Loggable<Logger::Id::stats>, public StoreRo
   Allocator& alloc_;
   Event::Dispatcher* main_thread_dispatcher_{};
-  ThreadLocal::SlotPtr tls_;
+  using TlsCacheSlot = ThreadLocal::TypedSlotPtr<TlsCache>;
+  ThreadLocal::TypedSlotPtr<TlsCache> tls_cache_;
   mutable Thread::MutexBasicLockable lock_;
   absl::flat_hash_set<ScopeImpl*> scopes_ ABSL_GUARDED_BY(lock_);
   ScopePtr default_scope_;
diff --git a/source/common/thread_local/thread_local_impl.cc b/source/common/thread_local/thread_local_impl.cc
index 7ed9eeca7942..a6c924ba9f9f 100644
--- a/source/common/thread_local/thread_local_impl.cc
+++ b/source/common/thread_local/thread_local_impl.cc
@@ -44,6 +44,9 @@ InstanceImpl::SlotImpl::SlotImpl(InstanceImpl& parent, uint32_t index)
 Event::PostCb InstanceImpl::SlotImpl::wrapCallback(Event::PostCb&& cb) {
   // See the header file comments for still_alive_guard_ for the purpose of this capture and the
   // expired check below.
+  //
+  // Note also that this logic is duplicated below in dataCallback(), rather
+  // than incurring another level of lambda indirection.
   return [still_alive_guard = std::weak_ptr<bool>(still_alive_guard_), cb] {
     if (!still_alive_guard.expired()) {
       cb();
@@ -66,17 +69,35 @@ ThreadLocalObjectSharedPtr InstanceImpl::SlotImpl::getWorker(uint32_t index) {

 ThreadLocalObjectSharedPtr InstanceImpl::SlotImpl::get() { return getWorker(index_); }

-void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) {
+Event::PostCb InstanceImpl::SlotImpl::dataCallback(const UpdateCb& cb) {
   // See the header file comments for still_alive_guard_ for why we capture index_.
-  parent_.runOnAllThreads(
-      wrapCallback([cb, index = index_]() { setThreadLocal(index, cb(getWorker(index))); }),
-      complete_cb);
+  return [still_alive_guard = std::weak_ptr<bool>(still_alive_guard_), cb, index = index_] {
+    // This duplicates logic in wrapCallback() (above). Using wrapCallback also
+    // works, but incurs another level of lambda indirection at runtime. As the
+    // duplicated logic is only an if-statement and a bool function, it doesn't
+    // seem worth factoring it out into a helper function.
+    if (still_alive_guard.expired()) {
+      return;
+    }
+    auto obj = getWorker(index);
+    auto new_obj = cb(obj);
+    // The API definition for runOnAllThreads allows for replacing the object
+    // via the callback return value. However, this never occurs in the codebase
+    // as of Oct 2020, and we plan to remove this API. To avoid PR races, we
+    // add an assert to ensure such a dependency does not emerge.
+    //
+    // TODO(jmarantz): remove this once we phase out use of the untyped slot
+    // API, rename it, and change all call-sites to use TypedSlot.
+    ASSERT(obj.get() == new_obj.get());
+  };
+}
+
+void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) {
+  parent_.runOnAllThreads(dataCallback(cb), complete_cb);
 }

 void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb) {
-  // See the header file comments for still_alive_guard_ for why we capture index_.
-  parent_.runOnAllThreads(
-      wrapCallback([cb, index = index_]() { setThreadLocal(index, cb(getWorker(index))); }));
+  parent_.runOnAllThreads(dataCallback(cb));
 }

 void InstanceImpl::SlotImpl::set(InitializeCb cb) {
diff --git a/source/common/thread_local/thread_local_impl.h b/source/common/thread_local/thread_local_impl.h
index 2b83a2aebf47..888657ad5e81 100644
--- a/source/common/thread_local/thread_local_impl.h
+++ b/source/common/thread_local/thread_local_impl.h
@@ -37,6 +37,7 @@ class InstanceImpl : Logger::Loggable<Logger::Id::main>, public NonCopyable, pub
     SlotImpl(InstanceImpl& parent, uint32_t index);
     ~SlotImpl() override { parent_.removeSlot(index_); }
     Event::PostCb wrapCallback(Event::PostCb&& cb);
+    Event::PostCb dataCallback(const UpdateCb& cb);
     static bool currentThreadRegisteredWorker(uint32_t index);
     static ThreadLocalObjectSharedPtr getWorker(uint32_t index);
diff --git a/test/common/stats/thread_local_store_test.cc b/test/common/stats/thread_local_store_test.cc
index 395a84cf32e6..9e97d323d43d 100644
--- a/test/common/stats/thread_local_store_test.cc
+++ b/test/common/stats/thread_local_store_test.cc
@@ -54,12 +54,9 @@ class ThreadLocalStoreTestingPeer {
   static void numTlsHistograms(ThreadLocalStoreImpl& thread_local_store_impl,
                                const std::function<void(uint32_t)>& num_tls_hist_cb) {
     auto num_tls_histograms = std::make_shared<std::atomic<uint32_t>>(0);
-    thread_local_store_impl.tls_->runOnAllThreads(
-        [num_tls_histograms](ThreadLocal::ThreadLocalObjectSharedPtr object)
-            -> ThreadLocal::ThreadLocalObjectSharedPtr {
-          auto& tls_cache = object->asType<ThreadLocalStoreImpl::TlsCache>();
+    thread_local_store_impl.tls_cache_->runOnAllThreads(
+        [num_tls_histograms](ThreadLocalStoreImpl::TlsCache& tls_cache) {
           *num_tls_histograms += tls_cache.tls_histogram_cache_.size();
-          return object;
         },
         [num_tls_hist_cb, num_tls_histograms]() { num_tls_hist_cb(*num_tls_histograms); });
   }
diff --git a/test/common/thread_local/thread_local_impl_test.cc b/test/common/thread_local/thread_local_impl_test.cc
index a625d57002a7..0e203e77a94f 100644
--- a/test/common/thread_local/thread_local_impl_test.cc
+++ b/test/common/thread_local/thread_local_impl_test.cc
@@ -149,35 +149,56 @@ TEST_F(ThreadLocalInstanceImplTest, CallbackNotInvokedAfterDeletion) {
   tls_.shutdownGlobalThreading();
 }

-// Test that the config passed into the update callback is the previous version stored in the slot.
+// Test that the update callback is called as expected for the worker and main threads.
 TEST_F(ThreadLocalInstanceImplTest, UpdateCallback) {
   InSequence s;

   SlotPtr slot = tls_.allocateSlot();

-  auto newer_version = std::make_shared<TestThreadLocalObject>();
-  bool update_called = false;
+  uint32_t update_called = 0;

   TestThreadLocalObject& object_ref = setObject(*slot);
-  auto update_cb = [&object_ref, &update_called,
-                    newer_version](ThreadLocalObjectSharedPtr obj) -> ThreadLocalObjectSharedPtr {
-    // The unit test setup have two dispatchers registered, but only one thread, this lambda will be
-    // called twice in the same thread.
-    if (!update_called) {
-      EXPECT_EQ(obj.get(), &object_ref);
-      update_called = true;
-    } else {
-      EXPECT_EQ(obj.get(), newer_version.get());
-    }
-
-    return newer_version;
+  auto update_cb = [&update_called](ThreadLocalObjectSharedPtr obj) -> ThreadLocalObjectSharedPtr {
+    ++update_called;
+    return obj;
   };
   EXPECT_CALL(thread_dispatcher_, post(_));
   EXPECT_CALL(object_ref, onDestroy());
-  EXPECT_CALL(*newer_version, onDestroy());
   slot->runOnAllThreads(update_cb);
-  EXPECT_EQ(newer_version.get(), &slot->getTyped<TestThreadLocalObject>());
+  EXPECT_EQ(2, update_called); // 1 worker, 1 main thread.
+
+  tls_.shutdownGlobalThreading();
+  tls_.shutdownThread();
+}
+
+struct StringSlotObject : public ThreadLocalObject {
+  std::string str_;
+};
+
+TEST_F(ThreadLocalInstanceImplTest, TypedUpdateCallback) {
+  InSequence s;
+
+  TypedSlot<StringSlotObject> slot(tls_);
+
+  uint32_t update_called = 0;
+  EXPECT_CALL(thread_dispatcher_, post(_));
+  slot.set([](Event::Dispatcher&) -> std::shared_ptr<StringSlotObject> {
+    auto s = std::make_shared<StringSlotObject>();
+    s->str_ = "hello";
+    return s;
+  });
+  EXPECT_EQ("hello", slot.get().str_);
+
+  auto update_cb = [&update_called](StringSlotObject& s) {
+    ++update_called;
+    s.str_ = "goodbye";
+  };
+  EXPECT_CALL(thread_dispatcher_, post(_));
+  slot.runOnAllThreads(update_cb);
+
+  EXPECT_EQ("goodbye", slot.get().str_);
+  EXPECT_EQ(2, update_called); // 1 worker, 1 main thread.

   tls_.shutdownGlobalThreading();
   tls_.shutdownThread();
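
For reference, here is a minimal usage sketch of the TypedSlot API introduced in thread_local.h above. It assumes an Envoy build environment; PerThreadCounter and example() are hypothetical names used only for illustration, not part of this change:

#include <cstdint>
#include <memory>

#include "envoy/thread_local/thread_local.h"

namespace Envoy {

// Hypothetical per-thread state; any ThreadLocal::ThreadLocalObject subclass works.
struct PerThreadCounter : public ThreadLocal::ThreadLocalObject {
  uint64_t hits_{0};
};

void example(ThreadLocal::SlotAllocator& tls) {
  // One typed slot; each registered thread gets its own PerThreadCounter.
  auto slot = ThreadLocal::TypedSlot<PerThreadCounter>::makeUnique(tls);
  slot->set([](Event::Dispatcher&) -> std::shared_ptr<PerThreadCounter> {
    return std::make_shared<PerThreadCounter>();
  });

  // The typed UpdateCb receives T& directly: no asType<T>() casting or
  // ThreadLocalObjectSharedPtr plumbing at the call-site, and no object
  // replacement is possible.
  slot->runOnAllThreads([](PerThreadCounter& counter) { ++counter.hits_; });
}

} // namespace Envoy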
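
The ASSERT added in dataCallback() narrows the contract of the untyped Slot::runOnAllThreads: a remaining caller must mutate the stored object in place and return the same shared_ptr it was given. A sketch of a conforming callback, using a hypothetical MyState object:

#include <cstdint>

#include "envoy/thread_local/thread_local.h"

namespace Envoy {

// Hypothetical state object, for illustration only.
struct MyState : public ThreadLocal::ThreadLocalObject {
  uint64_t counter_{0};
};

void bumpOnAllThreads(ThreadLocal::Slot& slot) {
  slot.runOnAllThreads([](ThreadLocal::ThreadLocalObjectSharedPtr obj)
                           -> ThreadLocal::ThreadLocalObjectSharedPtr {
    obj->asType<MyState>().counter_++; // mutate the stored object in place
    return obj;                        // must return the same object, or the ASSERT fires
  });
}

} // namespace Envoy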
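
The still_alive_guard_ pattern referenced in wrapCallback() and dataCallback() can also be shown standalone. The following self-contained sketch (not Envoy code) demonstrates how a weak_ptr capture turns a posted callback into a no-op once its owner is destroyed:

#include <functional>
#include <iostream>
#include <memory>

// A callback posted to another thread captures a weak_ptr to a guard object
// owned by the slot. If the slot is destroyed before the callback runs, the
// weak_ptr has expired and the callback degrades to a no-op.
class Guarded {
public:
  std::function<void()> wrap(std::function<void()> cb) {
    return [guard = std::weak_ptr<bool>(still_alive_guard_), cb] {
      if (!guard.expired()) {
        cb();
      }
    };
  }

private:
  std::shared_ptr<bool> still_alive_guard_{std::make_shared<bool>(true)};
};

int main() {
  std::function<void()> posted;
  {
    Guarded g;
    posted = g.wrap([] { std::cout << "ran\n"; });
  } // g destroyed here; the guard expires.
  posted(); // Safe no-op: expiry is checked before invoking cb.
  return 0;
}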