Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

tls: Typesafe tls slots #13789

Merged
merged 10 commits into from
Oct 29, 2020
91 changes: 87 additions & 4 deletions include/envoy/thread_local/thread_local.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,17 @@ class Slot {
virtual void set(InitializeCb cb) PURE;

/**
* UpdateCb takes the current stored data, and returns an updated/new version data.
* TLS will run the callback and replace the stored data with the returned value *in each thread*.
* UpdateCb takes the current stored data, and must return the same data. The
* API was designed to allow replacement of the object via this API, but this
* is not currently used, and thus we are removing the functionality. In a future
* PR, the API will be removed.
*
* NOTE: The update callback is not supposed to capture the Slot, or its owner. As the owner may
* be destructed in main thread before the update_cb gets called in a worker thread.
 * TLS will run the callback and assert the returned value matches
* the current value.
*
* NOTE: The update callback is not supposed to capture the Slot, or its
* owner. As the owner may be destructed in main thread before the update_cb
* gets called in a worker thread.
jmarantz marked this conversation as resolved.
Show resolved Hide resolved
**/
using UpdateCb = std::function<ThreadLocalObjectSharedPtr(ThreadLocalObjectSharedPtr)>;
virtual void runOnAllThreads(const UpdateCb& update_cb) PURE;
Expand All @@ -102,6 +108,83 @@ class SlotAllocator {
virtual SlotPtr allocateSlot() PURE;
};

// Provides a typesafe API for slots.
//
// TODO(jmarantz): Rename the Slot class to something like RawSlot, where the
// only reference is from TypedSlot, which we can then rename to Slot.
template <class T> class TypedSlot {
public:
  /**
   * Helper method to create a unique_ptr for a typed slot. This helper
   * reduces some verbose parameterization at call-sites.
   *
   * @param allocator factory to allocate untyped Slot objects.
   * @return a TypedSlotPtr<T> (the type is defined below).
   */
  static std::unique_ptr<TypedSlot> makeUnique(SlotAllocator& allocator) {
    return std::make_unique<TypedSlot>(allocator);
  }

  explicit TypedSlot(SlotAllocator& allocator) : slot_(allocator.allocateSlot()) {}

  /**
   * Returns if there is thread local data for this thread.
   *
   * This should return true for Envoy worker threads and false for threads which do not have thread
   * local storage allocated.
   *
   * @return true if registerThread has been called for this thread, false otherwise.
   */
  bool currentThreadRegistered() const { return slot_->currentThreadRegistered(); }

  /**
   * Set thread local data on all threads previously registered via registerThread().
   * @param cb supplies the functor that will be called *on each thread*. The functor
   *           returns the thread local object which is then stored. The storage is via
   *           a shared_ptr. Thus, this is a flexible mechanism that can be used to share
   *           the same data across all threads or to share different data on each thread.
   *
   * NOTE: The initialize callback is not supposed to capture the Slot, or its owner. As the owner
   * may be destructed in main thread before the update_cb gets called in a worker thread.
   */
  using InitializeCb = std::function<std::shared_ptr<T>(Event::Dispatcher& dispatcher)>;
  // Move the callback into the untyped slot to avoid an extra std::function copy.
  void set(InitializeCb cb) { slot_->set(std::move(cb)); }

  /**
   * @return a reference to the thread local object.
   */
  T& get() { return slot_->getTyped<T>(); }

  /**
   * @return a pointer to the thread local object.
   */
  T* operator->() { return &get(); }

  /**
   * UpdateCb is passed a mutable reference to the current stored data.
   *
   * NOTE: The update callback is not supposed to capture the TypedSlot, or its owner. As the owner
   * may be destructed in main thread before the update_cb gets called in a worker thread.
   */
  using UpdateCb = std::function<void(T& obj)>;
  void runOnAllThreads(const UpdateCb& cb) { slot_->runOnAllThreads(makeSlotUpdateCb(cb)); }
  void runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) {
    slot_->runOnAllThreads(makeSlotUpdateCb(cb), complete_cb);
  }

private:
  // Adapts a typed UpdateCb to the untyped Slot::UpdateCb: runs the typed
  // callback against the stored object and returns the object unchanged.
  // Takes the callback by value so it can be moved into the returned lambda
  // (the caller's std::function is copied exactly once, at the call site).
  Slot::UpdateCb makeSlotUpdateCb(UpdateCb cb) const {
    return [cb = std::move(cb)](ThreadLocalObjectSharedPtr obj) -> ThreadLocalObjectSharedPtr {
      cb(obj->asType<T>());
      return obj;
    };
  }

  const SlotPtr slot_;
};

template <class T> using TypedSlotPtr = std::unique_ptr<TypedSlot<T>>;

/**
* Interface for getting and setting thread local data as well as registering a thread
*/
Expand Down
51 changes: 20 additions & 31 deletions source/common/stats/thread_local_store.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,10 +185,9 @@ void ThreadLocalStoreImpl::initializeThreading(Event::Dispatcher& main_thread_di
ThreadLocal::Instance& tls) {
threading_ever_initialized_ = true;
main_thread_dispatcher_ = &main_thread_dispatcher;
tls_ = tls.allocateSlot();
tls_->set([](Event::Dispatcher&) -> ThreadLocal::ThreadLocalObjectSharedPtr {
return std::make_shared<TlsCache>();
});
tls_cache_ = ThreadLocal::TypedSlot<TlsCache>::makeUnique(tls);
tls_cache_->set(
[](Event::Dispatcher&) -> std::shared_ptr<TlsCache> { return std::make_shared<TlsCache>(); });
}

void ThreadLocalStoreImpl::shutdownThreading() {
Expand All @@ -205,14 +204,12 @@ void ThreadLocalStoreImpl::mergeHistograms(PostMergeCb merge_complete_cb) {
if (!shutting_down_) {
ASSERT(!merge_in_progress_);
merge_in_progress_ = true;
tls_->runOnAllThreads(
[](ThreadLocal::ThreadLocalObjectSharedPtr object)
-> ThreadLocal::ThreadLocalObjectSharedPtr {
for (const auto& id_hist : object->asType<TlsCache>().tls_histogram_cache_) {
tls_cache_->runOnAllThreads(
[](TlsCache& tls_cache) {
for (const auto& id_hist : tls_cache.tls_histogram_cache_) {
const TlsHistogramSharedPtr& tls_hist = id_hist.second;
tls_hist->beginMerge();
}
return object;
},
[this, merge_complete_cb]() -> void { mergeInternal(merge_complete_cb); });
} else {
Expand Down Expand Up @@ -305,12 +302,8 @@ void ThreadLocalStoreImpl::clearScopeFromCaches(uint64_t scope_id,
// at the same time.
if (!shutting_down_) {
// Perform a cache flush on all threads.
tls_->runOnAllThreads(
[scope_id](ThreadLocal::ThreadLocalObjectSharedPtr object)
-> ThreadLocal::ThreadLocalObjectSharedPtr {
object->asType<TlsCache>().eraseScope(scope_id);
return object;
},
tls_cache_->runOnAllThreads(
[scope_id](TlsCache& tls_cache) { tls_cache.eraseScope(scope_id); },
[central_cache]() { /* Holds onto central_cache until all tls caches are clear */ });
}
}
Expand All @@ -326,11 +319,8 @@ void ThreadLocalStoreImpl::clearHistogramFromCaches(uint64_t histogram_id) {
// https://gist.github.com/jmarantz/838cb6de7e74c0970ea6b63eded0139a
// contains a patch that will implement batching together to clear multiple
// histograms.
tls_->runOnAllThreads([histogram_id](ThreadLocal::ThreadLocalObjectSharedPtr object)
-> ThreadLocal::ThreadLocalObjectSharedPtr {
object->asType<TlsCache>().eraseHistogram(histogram_id);
return object;
});
tls_cache_->runOnAllThreads(
[histogram_id](TlsCache& tls_cache) { tls_cache.eraseHistogram(histogram_id); });
}
}

Expand Down Expand Up @@ -498,8 +488,8 @@ Counter& ThreadLocalStoreImpl::ScopeImpl::counterFromStatNameWithTags(
// initialized currently.
StatRefMap<Counter>* tls_cache = nullptr;
StatNameHashSet* tls_rejected_stats = nullptr;
if (!parent_.shutting_down_ && parent_.tls_) {
TlsCacheEntry& entry = parent_.tls_->getTyped<TlsCache>().insertScope(this->scope_id_);
if (!parent_.shutting_down_ && parent_.tls_cache_) {
TlsCacheEntry& entry = parent_.tls_cache_->get().insertScope(this->scope_id_);
tls_cache = &entry.counters_;
tls_rejected_stats = &entry.rejected_stats_;
}
Expand Down Expand Up @@ -550,8 +540,8 @@ Gauge& ThreadLocalStoreImpl::ScopeImpl::gaugeFromStatNameWithTags(

StatRefMap<Gauge>* tls_cache = nullptr;
StatNameHashSet* tls_rejected_stats = nullptr;
if (!parent_.shutting_down_ && parent_.tls_) {
TlsCacheEntry& entry = parent_.tls_->getTyped<TlsCache>().scope_cache_[this->scope_id_];
if (!parent_.shutting_down_ && parent_.tls_cache_) {
TlsCacheEntry& entry = parent_.tls_cache_->get().scope_cache_[this->scope_id_];
tls_cache = &entry.gauges_;
tls_rejected_stats = &entry.rejected_stats_;
}
Expand Down Expand Up @@ -588,8 +578,8 @@ Histogram& ThreadLocalStoreImpl::ScopeImpl::histogramFromStatNameWithTags(

StatNameHashMap<ParentHistogramSharedPtr>* tls_cache = nullptr;
StatNameHashSet* tls_rejected_stats = nullptr;
if (!parent_.shutting_down_ && parent_.tls_) {
TlsCacheEntry& entry = parent_.tls_->getTyped<TlsCache>().scope_cache_[this->scope_id_];
if (!parent_.shutting_down_ && parent_.tls_cache_) {
TlsCacheEntry& entry = parent_.tls_cache_->get().scope_cache_[this->scope_id_];
tls_cache = &entry.parent_histograms_;
auto iter = tls_cache->find(final_stat_name);
if (iter != tls_cache->end()) {
Expand Down Expand Up @@ -666,8 +656,8 @@ TextReadout& ThreadLocalStoreImpl::ScopeImpl::textReadoutFromStatNameWithTags(
// initialized currently.
StatRefMap<TextReadout>* tls_cache = nullptr;
StatNameHashSet* tls_rejected_stats = nullptr;
if (!parent_.shutting_down_ && parent_.tls_) {
TlsCacheEntry& entry = parent_.tls_->getTyped<TlsCache>().insertScope(this->scope_id_);
if (!parent_.shutting_down_ && parent_.tls_cache_) {
TlsCacheEntry& entry = parent_.tls_cache_->get().insertScope(this->scope_id_);
tls_cache = &entry.text_readouts_;
tls_rejected_stats = &entry.rejected_stats_;
}
Expand Down Expand Up @@ -712,9 +702,8 @@ Histogram& ThreadLocalStoreImpl::tlsHistogram(ParentHistogramImpl& parent, uint6
// See comments in counterFromStatName() which explains the logic here.

TlsHistogramSharedPtr* tls_histogram = nullptr;
if (!shutting_down_ && tls_ != nullptr) {
TlsCache& tls_cache = tls_->getTyped<TlsCache>();
tls_histogram = &tls_cache.tls_histogram_cache_[id];
if (!shutting_down_ && tls_cache_) {
tls_histogram = &tls_cache_->get().tls_histogram_cache_[id];
if (*tls_histogram != nullptr) {
return **tls_histogram;
}
Expand Down
3 changes: 2 additions & 1 deletion source/common/stats/thread_local_store.h
Original file line number Diff line number Diff line change
Expand Up @@ -480,7 +480,8 @@ class ThreadLocalStoreImpl : Logger::Loggable<Logger::Id::stats>, public StoreRo

Allocator& alloc_;
Event::Dispatcher* main_thread_dispatcher_{};
ThreadLocal::SlotPtr tls_;
using TlsCacheSlot = ThreadLocal::TypedSlotPtr<TlsCache>;
ThreadLocal::TypedSlotPtr<TlsCache> tls_cache_;
mutable Thread::MutexBasicLockable lock_;
absl::flat_hash_set<ScopeImpl*> scopes_ ABSL_GUARDED_BY(lock_);
ScopePtr default_scope_;
Expand Down
28 changes: 21 additions & 7 deletions source/common/thread_local/thread_local_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,17 +66,31 @@ ThreadLocalObjectSharedPtr InstanceImpl::SlotImpl::getWorker(uint32_t index) {

ThreadLocalObjectSharedPtr InstanceImpl::SlotImpl::get() { return getWorker(index_); }

void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) {
Event::PostCb InstanceImpl::SlotImpl::dataCallback(const UpdateCb& cb) {
// See the header file comments for still_alive_guard_ for why we capture index_.
parent_.runOnAllThreads(
wrapCallback([cb, index = index_]() { setThreadLocal(index, cb(getWorker(index))); }),
complete_cb);
return [still_alive_guard = std::weak_ptr<bool>(still_alive_guard_), cb, index = index_] {
if (still_alive_guard.expired()) {
return;
}
jmarantz marked this conversation as resolved.
Show resolved Hide resolved
auto obj = getWorker(index);
auto new_obj = cb(obj);
// The API definition for runOnAllThreads allows for replacing the object
// via the callback return value. However, this never occurs in the codebase
// as of Oct 2020, and we plan to remove this API. To avoid PR races, we
// will add an assert to ensure such a dependency does not emerge.
//
// TODO(jmarantz): remove this once we phase out use of the untyped slot
// API, rename it, and change all call-sites to use TypedSlot.
ASSERT(obj.get() == new_obj.get());
};
}

void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) {
parent_.runOnAllThreads(dataCallback(cb), complete_cb);
}

void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb) {
// See the header file comments for still_alive_guard_ for why we capture index_.
parent_.runOnAllThreads(
wrapCallback([cb, index = index_]() { setThreadLocal(index, cb(getWorker(index))); }));
parent_.runOnAllThreads(dataCallback(cb));
}

void InstanceImpl::SlotImpl::set(InitializeCb cb) {
Expand Down
1 change: 1 addition & 0 deletions source/common/thread_local/thread_local_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class InstanceImpl : Logger::Loggable<Logger::Id::main>, public NonCopyable, pub
SlotImpl(InstanceImpl& parent, uint32_t index);
~SlotImpl() override { parent_.removeSlot(index_); }
Event::PostCb wrapCallback(Event::PostCb&& cb);
Event::PostCb dataCallback(const UpdateCb& cb);
static bool currentThreadRegisteredWorker(uint32_t index);
static ThreadLocalObjectSharedPtr getWorker(uint32_t index);

Expand Down
7 changes: 2 additions & 5 deletions test/common/stats/thread_local_store_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,9 @@ class ThreadLocalStoreTestingPeer {
static void numTlsHistograms(ThreadLocalStoreImpl& thread_local_store_impl,
const std::function<void(uint32_t)>& num_tls_hist_cb) {
auto num_tls_histograms = std::make_shared<std::atomic<uint32_t>>(0);
thread_local_store_impl.tls_->runOnAllThreads(
[num_tls_histograms](ThreadLocal::ThreadLocalObjectSharedPtr object)
-> ThreadLocal::ThreadLocalObjectSharedPtr {
auto& tls_cache = object->asType<ThreadLocalStoreImpl::TlsCache>();
thread_local_store_impl.tls_cache_->runOnAllThreads(
[num_tls_histograms](ThreadLocalStoreImpl::TlsCache& tls_cache) {
*num_tls_histograms += tls_cache.tls_histogram_cache_.size();
return object;
},
[num_tls_hist_cb, num_tls_histograms]() { num_tls_hist_cb(*num_tls_histograms); });
}
Expand Down
55 changes: 38 additions & 17 deletions test/common/thread_local/thread_local_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,35 +149,56 @@ TEST_F(ThreadLocalInstanceImplTest, CallbackNotInvokedAfterDeletion) {
tls_.shutdownGlobalThreading();
}

// Test that the config passed into the update callback is the previous version stored in the slot.
// Test that the update callback is called as expected, for the worker and main threads.
// Test that the update callback is called as expected, for the worker and main threads.
TEST_F(ThreadLocalInstanceImplTest, UpdateCallback) {
  InSequence s;

  SlotPtr slot = tls_.allocateSlot();

  uint32_t update_called = 0;

  TestThreadLocalObject& object_ref = setObject(*slot);
  // The callback must return the same object it was handed; replacement via
  // the return value is being deprecated (asserted in SlotImpl::dataCallback).
  auto update_cb = [&update_called](ThreadLocalObjectSharedPtr obj) -> ThreadLocalObjectSharedPtr {
    ++update_called;
    return obj;
  };
  EXPECT_CALL(thread_dispatcher_, post(_));
  EXPECT_CALL(object_ref, onDestroy());
  slot->runOnAllThreads(update_cb);

  EXPECT_EQ(2, update_called); // 1 worker, 1 main thread.

  tls_.shutdownGlobalThreading();
  tls_.shutdownThread();
}

// Minimal ThreadLocalObject holding a single string; used to exercise the
// TypedSlot API in the tests below.
struct StringSlotObject : public ThreadLocalObject {
  std::string str_;
};

TEST_F(ThreadLocalInstanceImplTest, TypedUpdateCallback) {
InSequence s;

TypedSlot<StringSlotObject> slot(tls_);

uint32_t update_called = 0;
EXPECT_CALL(thread_dispatcher_, post(_));
slot.set([](Event::Dispatcher&) -> std::shared_ptr<StringSlotObject> {
auto s = std::make_shared<StringSlotObject>();
s->str_ = "hello";
return s;
});
EXPECT_EQ("hello", slot.get().str_);

auto update_cb = [&update_called](StringSlotObject& s) {
++update_called;
s.str_ = "goodbye";
};
EXPECT_CALL(thread_dispatcher_, post(_));
slot.runOnAllThreads(update_cb);

EXPECT_EQ("goodbye", slot.get().str_);
EXPECT_EQ(2, update_called); // 1 worker, 1 main thread.

tls_.shutdownGlobalThreading();
tls_.shutdownThread();
Expand Down