
tls: Typesafe tls slots #13789

Merged · 10 commits · Oct 29, 2020
75 changes: 75 additions & 0 deletions include/envoy/thread_local/thread_local.h
@@ -89,6 +89,81 @@ class Slot {

using SlotPtr = std::unique_ptr<Slot>;

// Provides a typesafe API for slots.
//
// TODO(jmarantz): Rename the Slot class to something like RawSlot, where the
// only reference is from TypedSlot, which we can then rename to Slot.
template <class T> class TypedSlot {
public:
/**
* @return true if the TypedSlot object has been allocated a slot.
*/
bool hasSlot() const { return slot_ != nullptr; }

/**
* @param slot The allocated slot object. Ownership transferred into the TypedSlot.
*/
void setSlot(SlotPtr&& slot) { slot_ = std::move(slot); }

/**
* Returns whether there is thread local data for this thread.
*
* This should return true for Envoy worker threads and false for threads which do not have thread
* local storage allocated.
*
* @return true if registerThread has been called for this thread, false otherwise.
*/
bool currentThreadRegistered() { return slot_->currentThreadRegistered(); }

/**
* Set thread local data on all threads previously registered via registerThread().
* @param initializeCb supplies the functor that will be called *on each thread*. The functor
* returns the thread local object which is then stored. The storage is via
* a shared_ptr. Thus, this is a flexible mechanism that can be used to share
* the same data across all threads or to share different data on each thread.
*
* NOTE: The initialize callback must not capture the Slot or its owner, as the owner may be
* destructed on the main thread before the callback runs on a worker thread.
*/
using SharedT = std::shared_ptr<T>;
using InitializeCb = std::function<SharedT(Event::Dispatcher& dispatcher)>;
void set(InitializeCb cb) { slot_->set(cb); }

/**
* @return a reference to the thread local object.
*/
T& get() { return slot_->getTyped<T>(); }

/**
* @return a pointer to the thread local object.
*/
T* operator->() { return &get(); }

/**
* UpdateCb is passed a mutable reference to the current stored data.
*
* NOTE: The update callback must not capture the TypedSlot or its owner, as the owner may be
* destructed on the main thread before the update_cb runs on a worker thread.
*/
using UpdateCb = std::function<void(T& obj)>;
void runOnAllThreads(const UpdateCb& cb) { slot_->runOnAllThreads(makeSlotUpdateCb(cb)); }
void runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) {
slot_->runOnAllThreads(makeSlotUpdateCb(cb), complete_cb);
}

private:
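// Wraps a typed UpdateCb in the untyped Slot::UpdateCb expected by the underlying raw slot.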
Slot::UpdateCb makeSlotUpdateCb(UpdateCb cb) {
return [cb](ThreadLocalObjectSharedPtr obj) -> ThreadLocalObjectSharedPtr {
T& typed_obj = obj->asType<T>();
cb(typed_obj);
// Note: the callback mutates the stored object in place and the same shared_ptr is
// returned; this behavior should be covered by a test.
return obj;
};
}

SlotPtr slot_;
};

/**
* Interface used to allocate thread local slots.
*/
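For illustration only, a minimal usage sketch of the new TypedSlot API under the assumption that the stored type derives from ThreadLocal::ThreadLocalObject; the ThreadLocalCounter type and the tls variable below are hypothetical and not part of this change:

struct ThreadLocalCounter : public ThreadLocal::ThreadLocalObject {
  uint64_t value_{0};
};

// Allocate and populate the typed slot; each registered worker thread gets its own object.
ThreadLocal::TypedSlot<ThreadLocalCounter> slot;
slot.setSlot(tls.allocateSlot());
slot.set([](Event::Dispatcher&) -> std::shared_ptr<ThreadLocalCounter> {
  return std::make_shared<ThreadLocalCounter>();
});

// Typed access on the current thread, and a typed update on all threads.
uint64_t current = slot->value_;
slot.runOnAllThreads([](ThreadLocalCounter& counter) { ++counter.value_; });

The point of the change, visible in the diffs below, is that call sites no longer need getTyped<T>()/asType<T>() casts or raw ThreadLocalObjectSharedPtr lambdas.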
50 changes: 20 additions & 30 deletions source/common/stats/thread_local_store.cc
@@ -185,10 +185,9 @@ void ThreadLocalStoreImpl::initializeThreading(Event::Dispatcher& main_thread_di
ThreadLocal::Instance& tls) {
threading_ever_initialized_ = true;
main_thread_dispatcher_ = &main_thread_dispatcher;
tls_ = tls.allocateSlot();
tls_->set([](Event::Dispatcher&) -> ThreadLocal::ThreadLocalObjectSharedPtr {
return std::make_shared<TlsCache>();
});
tls_cache_.setSlot(tls.allocateSlot());
tls_cache_.set(
[](Event::Dispatcher&) -> std::shared_ptr<TlsCache> { return std::make_shared<TlsCache>(); });
}

void ThreadLocalStoreImpl::shutdownThreading() {
@@ -205,14 +204,12 @@ void ThreadLocalStoreImpl::mergeHistograms(PostMergeCb merge_complete_cb) {
if (!shutting_down_) {
ASSERT(!merge_in_progress_);
merge_in_progress_ = true;
tls_->runOnAllThreads(
[](ThreadLocal::ThreadLocalObjectSharedPtr object)
-> ThreadLocal::ThreadLocalObjectSharedPtr {
for (const auto& id_hist : object->asType<TlsCache>().tls_histogram_cache_) {
tls_cache_.runOnAllThreads(
[](TlsCache& tls_cache) {
for (const auto& id_hist : tls_cache.tls_histogram_cache_) {
const TlsHistogramSharedPtr& tls_hist = id_hist.second;
tls_hist->beginMerge();
}
return object;
},
[this, merge_complete_cb]() -> void { mergeInternal(merge_complete_cb); });
} else {
@@ -305,12 +302,8 @@ void ThreadLocalStoreImpl::clearScopeFromCaches(uint64_t scope_id,
// at the same time.
if (!shutting_down_) {
// Perform a cache flush on all threads.
tls_->runOnAllThreads(
[scope_id](ThreadLocal::ThreadLocalObjectSharedPtr object)
-> ThreadLocal::ThreadLocalObjectSharedPtr {
object->asType<TlsCache>().eraseScope(scope_id);
return object;
},
tls_cache_.runOnAllThreads(
[scope_id](TlsCache& tls_cache) { tls_cache.eraseScope(scope_id); },
[central_cache]() { /* Holds onto central_cache until all tls caches are clear */ });
}
}
@@ -326,11 +319,8 @@ void ThreadLocalStoreImpl::clearHistogramFromCaches(uint64_t histogram_id) {
// https://gist.github.com/jmarantz/838cb6de7e74c0970ea6b63eded0139a
// contains a patch that will implement batching together to clear multiple
// histograms.
tls_->runOnAllThreads([histogram_id](ThreadLocal::ThreadLocalObjectSharedPtr object)
-> ThreadLocal::ThreadLocalObjectSharedPtr {
object->asType<TlsCache>().eraseHistogram(histogram_id);
return object;
});
tls_cache_.runOnAllThreads(
[histogram_id](TlsCache& tls_cache) { tls_cache.eraseHistogram(histogram_id); });
}
}

@@ -498,8 +488,8 @@ Counter& ThreadLocalStoreImpl::ScopeImpl::counterFromStatNameWithTags(
// initialized currently.
StatRefMap<Counter>* tls_cache = nullptr;
StatNameHashSet* tls_rejected_stats = nullptr;
if (!parent_.shutting_down_ && parent_.tls_) {
TlsCacheEntry& entry = parent_.tls_->getTyped<TlsCache>().insertScope(this->scope_id_);
if (!parent_.shutting_down_ && parent_.tls_cache_.hasSlot()) {
TlsCacheEntry& entry = parent_.tls_cache_->insertScope(this->scope_id_);
tls_cache = &entry.counters_;
tls_rejected_stats = &entry.rejected_stats_;
}
@@ -550,8 +540,8 @@ Gauge& ThreadLocalStoreImpl::ScopeImpl::gaugeFromStatNameWithTags(

StatRefMap<Gauge>* tls_cache = nullptr;
StatNameHashSet* tls_rejected_stats = nullptr;
if (!parent_.shutting_down_ && parent_.tls_) {
TlsCacheEntry& entry = parent_.tls_->getTyped<TlsCache>().scope_cache_[this->scope_id_];
if (!parent_.shutting_down_ && parent_.tls_cache_.hasSlot()) {
TlsCacheEntry& entry = parent_.tls_cache_->scope_cache_[this->scope_id_];
tls_cache = &entry.gauges_;
tls_rejected_stats = &entry.rejected_stats_;
}
@@ -588,8 +578,8 @@ Histogram& ThreadLocalStoreImpl::ScopeImpl::histogramFromStatNameWithTags(

StatNameHashMap<ParentHistogramSharedPtr>* tls_cache = nullptr;
StatNameHashSet* tls_rejected_stats = nullptr;
if (!parent_.shutting_down_ && parent_.tls_) {
TlsCacheEntry& entry = parent_.tls_->getTyped<TlsCache>().scope_cache_[this->scope_id_];
if (!parent_.shutting_down_ && parent_.tls_cache_.hasSlot()) {
TlsCacheEntry& entry = parent_.tls_cache_->scope_cache_[this->scope_id_];
tls_cache = &entry.parent_histograms_;
auto iter = tls_cache->find(final_stat_name);
if (iter != tls_cache->end()) {
@@ -666,8 +656,8 @@ TextReadout& ThreadLocalStoreImpl::ScopeImpl::textReadoutFromStatNameWithTags(
// initialized currently.
StatRefMap<TextReadout>* tls_cache = nullptr;
StatNameHashSet* tls_rejected_stats = nullptr;
if (!parent_.shutting_down_ && parent_.tls_) {
TlsCacheEntry& entry = parent_.tls_->getTyped<TlsCache>().insertScope(this->scope_id_);
if (!parent_.shutting_down_ && parent_.tls_cache_.hasSlot()) {
TlsCacheEntry& entry = parent_.tls_cache_->insertScope(this->scope_id_);
tls_cache = &entry.text_readouts_;
tls_rejected_stats = &entry.rejected_stats_;
}
@@ -712,8 +702,8 @@ Histogram& ThreadLocalStoreImpl::tlsHistogram(ParentHistogramImpl& parent, uint6
// See comments in counterFromStatName() which explains the logic here.

TlsHistogramSharedPtr* tls_histogram = nullptr;
if (!shutting_down_ && tls_ != nullptr) {
TlsCache& tls_cache = tls_->getTyped<TlsCache>();
if (!shutting_down_ && tls_cache_.hasSlot()) {
TlsCache& tls_cache = tls_cache_.get();
tls_histogram = &tls_cache.tls_histogram_cache_[id];
if (*tls_histogram != nullptr) {
return **tls_histogram;
2 changes: 1 addition & 1 deletion source/common/stats/thread_local_store.h
@@ -480,7 +480,7 @@ class ThreadLocalStoreImpl : Logger::Loggable<Logger::Id::stats>, public StoreRo

Allocator& alloc_;
Event::Dispatcher* main_thread_dispatcher_{};
ThreadLocal::SlotPtr tls_;
ThreadLocal::TypedSlot<TlsCache> tls_cache_;
mutable Thread::MutexBasicLockable lock_;
absl::flat_hash_set<ScopeImpl*> scopes_ ABSL_GUARDED_BY(lock_);
ScopePtr default_scope_;
7 changes: 2 additions & 5 deletions test/common/stats/thread_local_store_test.cc
@@ -54,12 +54,9 @@ class ThreadLocalStoreTestingPeer {
static void numTlsHistograms(ThreadLocalStoreImpl& thread_local_store_impl,
const std::function<void(uint32_t)>& num_tls_hist_cb) {
auto num_tls_histograms = std::make_shared<std::atomic<uint32_t>>(0);
thread_local_store_impl.tls_->runOnAllThreads(
[num_tls_histograms](ThreadLocal::ThreadLocalObjectSharedPtr object)
-> ThreadLocal::ThreadLocalObjectSharedPtr {
auto& tls_cache = object->asType<ThreadLocalStoreImpl::TlsCache>();
thread_local_store_impl.tls_cache_.runOnAllThreads(
[num_tls_histograms](ThreadLocalStoreImpl::TlsCache& tls_cache) {
*num_tls_histograms += tls_cache.tls_histogram_cache_.size();
return object;
},
[num_tls_hist_cb, num_tls_histograms]() { num_tls_hist_cb(*num_tls_histograms); });
}