diff --git a/source/common/common/BUILD b/source/common/common/BUILD index f5632b8fee27..bf7701a21fa8 100644 --- a/source/common/common/BUILD +++ b/source/common/common/BUILD @@ -320,6 +320,7 @@ envoy_cc_library( external_deps = ["abseil_synchronization"], deps = envoy_cc_platform_dep("thread_impl_lib") + [ ":non_copyable", + "//source/common/singleton:threadsafe_singleton", ], ) diff --git a/source/common/common/thread.h b/source/common/common/thread.h index 4808d391dfbd..8c9d991268aa 100644 --- a/source/common/common/thread.h +++ b/source/common/common/thread.h @@ -8,6 +8,7 @@ #include "envoy/thread/thread.h" #include "common/common/non_copyable.h" +#include "common/singleton/threadsafe_singleton.h" #include "absl/synchronization/mutex.h" @@ -168,5 +169,19 @@ class AtomicPtr : private AtomicPtrArray { T* get(const MakeObject& make_object) { return BaseClass::get(0, make_object); } }; +struct MainThread { + using MainThreadSingleton = InjectableSingleton; + bool inMainThread() const { return main_thread_id_ == std::this_thread::get_id(); } + static void init() { MainThreadSingleton::initialize(new MainThread()); } + static void clear() { + free(MainThreadSingleton::getExisting()); + MainThreadSingleton::clear(); + } + static bool isMainThread() { return MainThreadSingleton::get().inMainThread(); } + +private: + std::thread::id main_thread_id_{std::this_thread::get_id()}; +}; + } // namespace Thread } // namespace Envoy diff --git a/source/common/thread_local/thread_local_impl.cc b/source/common/thread_local/thread_local_impl.cc index 0815236a3195..9bf485419745 100644 --- a/source/common/thread_local/thread_local_impl.cc +++ b/source/common/thread_local/thread_local_impl.cc @@ -16,13 +16,14 @@ namespace ThreadLocal { thread_local InstanceImpl::ThreadLocalData InstanceImpl::thread_local_data_; InstanceImpl::~InstanceImpl() { - ASSERT(std::this_thread::get_id() == main_thread_id_); + ASSERT(Thread::MainThread::isMainThread()); ASSERT(shutdown_); thread_local_data_.data_.clear(); + Thread::MainThread::clear(); } SlotPtr InstanceImpl::allocateSlot() { - ASSERT(std::this_thread::get_id() == main_thread_id_); + ASSERT(Thread::MainThread::isMainThread()); ASSERT(!shutdown_); if (free_slot_indexes_.empty()) { @@ -91,7 +92,7 @@ void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb) { } void InstanceImpl::SlotImpl::set(InitializeCb cb) { - ASSERT(std::this_thread::get_id() == parent_.main_thread_id_); + ASSERT(Thread::MainThread::isMainThread()); ASSERT(!parent_.shutdown_); for (Event::Dispatcher& dispatcher : parent_.registered_threads_) { @@ -105,7 +106,7 @@ void InstanceImpl::SlotImpl::set(InitializeCb cb) { } void InstanceImpl::registerThread(Event::Dispatcher& dispatcher, bool main_thread) { - ASSERT(std::this_thread::get_id() == main_thread_id_); + ASSERT(Thread::MainThread::isMainThread()); ASSERT(!shutdown_); if (main_thread) { @@ -119,7 +120,7 @@ void InstanceImpl::registerThread(Event::Dispatcher& dispatcher, bool main_threa } void InstanceImpl::removeSlot(uint32_t slot) { - ASSERT(std::this_thread::get_id() == main_thread_id_); + ASSERT(Thread::MainThread::isMainThread()); // When shutting down, we do not post slot removals to other threads. This is because the other // threads have already shut down and the dispatcher is no longer alive. There is also no reason @@ -146,7 +147,7 @@ void InstanceImpl::removeSlot(uint32_t slot) { } void InstanceImpl::runOnAllThreads(Event::PostCb cb) { - ASSERT(std::this_thread::get_id() == main_thread_id_); + ASSERT(Thread::MainThread::isMainThread()); ASSERT(!shutdown_); for (Event::Dispatcher& dispatcher : registered_threads_) { @@ -158,7 +159,7 @@ void InstanceImpl::runOnAllThreads(Event::PostCb cb) { } void InstanceImpl::runOnAllThreads(Event::PostCb cb, Event::PostCb all_threads_complete_cb) { - ASSERT(std::this_thread::get_id() == main_thread_id_); + ASSERT(Thread::MainThread::isMainThread()); ASSERT(!shutdown_); // Handle main thread first so that when the last worker thread wins, we could just call the // all_threads_complete_cb method. Parallelism of main thread execution is being traded off @@ -185,7 +186,7 @@ void InstanceImpl::setThreadLocal(uint32_t index, ThreadLocalObjectSharedPtr obj } void InstanceImpl::shutdownGlobalThreading() { - ASSERT(std::this_thread::get_id() == main_thread_id_); + ASSERT(Thread::MainThread::isMainThread()); ASSERT(!shutdown_); shutdown_ = true; } diff --git a/source/common/thread_local/thread_local_impl.h b/source/common/thread_local/thread_local_impl.h index 7abed0499166..e098723d7232 100644 --- a/source/common/thread_local/thread_local_impl.h +++ b/source/common/thread_local/thread_local_impl.h @@ -19,7 +19,7 @@ namespace ThreadLocal { */ class InstanceImpl : Logger::Loggable, public NonCopyable, public Instance { public: - InstanceImpl() : main_thread_id_(std::this_thread::get_id()) {} + InstanceImpl() { Thread::MainThread::init(); } ~InstanceImpl() override; // ThreadLocal::Instance @@ -81,7 +81,6 @@ class InstanceImpl : Logger::Loggable, public NonCopyable, pub // A list of index of freed slots. std::list free_slot_indexes_; std::list> registered_threads_; - std::thread::id main_thread_id_; Event::Dispatcher* main_thread_dispatcher_{}; std::atomic shutdown_{}; diff --git a/source/server/server.cc b/source/server/server.cc index 1d5f95378f28..26795a8b2aed 100644 --- a/source/server/server.cc +++ b/source/server/server.cc @@ -89,7 +89,7 @@ InstanceImpl::InstanceImpl( : nullptr), grpc_context_(store.symbolTable()), http_context_(store.symbolTable()), router_context_(store.symbolTable()), process_context_(std::move(process_context)), - main_thread_id_(std::this_thread::get_id()), hooks_(hooks), server_contexts_(*this) { + hooks_(hooks), server_contexts_(*this) { try { if (!options.logPath().empty()) { try { @@ -819,7 +819,7 @@ InstanceImpl::registerCallback(Stage stage, StageCallbackWithCompletion callback } void InstanceImpl::notifyCallbacksForStage(Stage stage, Event::PostCb completion_cb) { - ASSERT(std::this_thread::get_id() == main_thread_id_); + ASSERT(Thread::MainThread::isMainThread()); const auto it = stage_callbacks_.find(stage); if (it != stage_callbacks_.end()) { for (const StageCallback& callback : it->second) { diff --git a/source/server/server.h b/source/server/server.h index 117107524827..cf4d24eadd56 100644 --- a/source/server/server.h +++ b/source/server/server.h @@ -362,7 +362,6 @@ class InstanceImpl final : Logger::Loggable, Router::ContextImpl router_context_; std::unique_ptr process_context_; std::unique_ptr heap_shrinker_; - const std::thread::id main_thread_id_; // initialization_time is a histogram for tracking the initialization time across hot restarts // whenever we have support for histogram merge across hot restarts. Stats::TimespanPtr initialization_timer_; diff --git a/test/common/thread_local/thread_local_impl_test.cc b/test/common/thread_local/thread_local_impl_test.cc index 59bdd6d0080b..de6c67318c6a 100644 --- a/test/common/thread_local/thread_local_impl_test.cc +++ b/test/common/thread_local/thread_local_impl_test.cc @@ -15,6 +15,26 @@ using testing::ReturnPointee; namespace Envoy { namespace ThreadLocal { +TEST(MainThreadVerificationTest, All) { + // Main thread singleton is initialized in the constructor of tls instance. Call to main thread + // verification will fail before that. + EXPECT_DEATH(Thread::MainThread::isMainThread(), + "InjectableSingleton used prior to initialization"); + { + EXPECT_DEATH(Thread::MainThread::isMainThread(), + "InjectableSingleton used prior to initialization"); + InstanceImpl tls; + // Call to main thread verification should succeed after tls instance has been initialized. + ASSERT(Thread::MainThread::isMainThread()); + tls.shutdownGlobalThreading(); + tls.shutdownThread(); + } + // Main thread singleton is cleared in the destructor of tls instance. Call to main thread + // verification will fail after that. + EXPECT_DEATH(Thread::MainThread::isMainThread(), + "InjectableSingleton used prior to initialization"); +} + class TestThreadLocalObject : public ThreadLocalObject { public: ~TestThreadLocalObject() override { onDestroy(); }