Apply fix for use-after-free in Envoy ThreadLocal Slot. #111

Merged · 3 commits · Oct 10, 2019
11 changes: 11 additions & 0 deletions include/envoy/thread_local/thread_local.h
@@ -74,6 +74,17 @@ class Slot {
*/
using InitializeCb = std::function<ThreadLocalObjectSharedPtr(Event::Dispatcher& dispatcher)>;
virtual void set(InitializeCb cb) PURE;

/**
* UpdateCb takes the currently stored data and returns an updated/new version of the data.
* TLS will run the callback and replace the stored data with the returned value *in each thread*.
*
* NOTE: The update callback must not capture the Slot or its owner, as the owner may be
* destructed on the main thread before the update_cb runs on a worker thread.
*/
using UpdateCb = std::function<ThreadLocalObjectSharedPtr(ThreadLocalObjectSharedPtr)>;
virtual void runOnAllThreads(const UpdateCb& update_cb) PURE;
virtual void runOnAllThreads(const UpdateCb& update_cb, Event::PostCb complete_cb) PURE;
};

using SlotPtr = std::unique_ptr<Slot>;
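To make the new API concrete, here is a minimal caller-side sketch of the `UpdateCb` overload added above. It relies only on the types visible in this header (`ThreadLocal::Slot`, `ThreadLocalObject`, `ThreadLocalObjectSharedPtr`); `MyState` and `bumpCounters` are illustrative names, not part of this PR:

```cpp
#include <cstdint>
#include <memory>

#include "envoy/thread_local/thread_local.h"

namespace Envoy {

// Illustrative per-thread object; not part of this PR.
struct MyState : public ThreadLocal::ThreadLocalObject {
  uint64_t counter_{0};
};

// The update callback receives the object currently stored on the calling worker and
// returns the object that should replace it. Per the NOTE above, it must not capture
// the Slot or its owner.
void bumpCounters(ThreadLocal::Slot& slot) {
  slot.runOnAllThreads([](ThreadLocal::ThreadLocalObjectSharedPtr previous)
                           -> ThreadLocal::ThreadLocalObjectSharedPtr {
    auto state = std::dynamic_pointer_cast<MyState>(previous);
    state->counter_++;
    return previous; // keep the same object, now mutated, on this thread
  });
}

} // namespace Envoy
```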
13 changes: 9 additions & 4 deletions source/common/common/non_copyable.h
@@ -2,14 +2,19 @@

namespace Envoy {
/**
* Mixin class that makes derived classes not copyable. Like boost::noncopyable without boost.
* Mixin class that makes derived classes not copyable and not moveable. Like boost::noncopyable
* without boost.
*/
class NonCopyable {
protected:
NonCopyable() = default;

private:
NonCopyable(const NonCopyable&);
NonCopyable& operator=(const NonCopyable&);
// Non-moveable.
NonCopyable(NonCopyable&&) noexcept = delete;
NonCopyable& operator=(NonCopyable&&) noexcept = delete;

// Non-copyable.
NonCopyable(const NonCopyable&) = delete;
NonCopyable& operator=(const NonCopyable&) = delete;
};
} // namespace Envoy
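As a quick compile-time illustration of the strengthened mixin, a type deriving from `NonCopyable` now rejects both copying and moving; the `Widget` type below is illustrative, not part of this PR:

```cpp
#include <utility>

#include "common/common/non_copyable.h"

// Illustrative type, not part of this PR.
class Widget : Envoy::NonCopyable {
public:
  Widget() = default;
};

void example() {
  Widget a;
  // Widget b(a);            // error: copy constructor is deleted
  // Widget c(std::move(a)); // error: move constructor is deleted (new with this change)
  (void)a;
}
```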
10 changes: 10 additions & 0 deletions source/common/config/config_provider_impl.cc
@@ -23,6 +23,16 @@ ConfigSubscriptionCommonBase::~ConfigSubscriptionCommonBase() {
init_target_.ready();
config_provider_manager_.unbindSubscription(manager_identifier_);
}

void ConfigSubscriptionCommonBase::applyConfigUpdate(const ConfigUpdateCb& update_fn) {
tls_->runOnAllThreads([update_fn](ThreadLocal::ThreadLocalObjectSharedPtr previous)
-> ThreadLocal::ThreadLocalObjectSharedPtr {
auto prev_thread_local_config = std::dynamic_pointer_cast<ThreadLocalConfig>(previous);
prev_thread_local_config->config_ = update_fn(prev_thread_local_config->config_);
return previous;
});
}

bool ConfigSubscriptionInstance::checkAndApplyConfigUpdate(const Protobuf::Message& config_proto,
const std::string& config_name,
const std::string& version_info) {
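For context, a sketch of how a subscription subclass might now call the refactored `applyConfigUpdate()`. It assumes `ConfigUpdateCb` is a `std::function<ConfigProvider::ConfigConstSharedPtr(ConfigProvider::ConfigConstSharedPtr)>` as the header comment describes, and uses hypothetical names (`MySubscription`, `onNewConfig`):

```cpp
#include <memory>

#include "common/config/config_provider_impl.h"

namespace Envoy {
namespace Config {

// Hypothetical subclass, for illustration only.
class MySubscription : public ConfigSubscriptionInstance {
public:
  void onNewConfig(ConfigProvider::ConfigConstSharedPtr new_config) {
    // The base class now owns all of the TLS plumbing; the caller supplies only the
    // Config-to-Config transform to run on every worker.
    applyConfigUpdate([new_config](ConfigProvider::ConfigConstSharedPtr)
                          -> ConfigProvider::ConfigConstSharedPtr {
      // Ignore the previous per-thread config and swap in the freshly built one.
      return new_config;
    });
  }
};

} // namespace Config
} // namespace Envoy
```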
20 changes: 1 addition & 19 deletions source/common/config/config_provider_impl.h
@@ -220,26 +220,8 @@ class ConfigSubscriptionCommonBase
*
* @param update_fn the callback to run on each thread; it takes the previous version of the
* Config and returns an updated/new version of the Config.
* @param complete_cb the callback to run when the update propagation is done.
*/
void applyConfigUpdate(
const ConfigUpdateCb& update_fn, const Event::PostCb& complete_cb = []() {}) {
// It is safe to call shared_from_this here as this is in main thread, and destruction of a
// ConfigSubscriptionCommonBase owner (i.e., a provider) happens in main thread as well.
auto shared_this = shared_from_this();
tls_->runOnAllThreads(
[this, update_fn]() {
tls_->getTyped<ThreadLocalConfig>().config_ = update_fn(this->getConfig());
},
// During the update propagation, a subscription may get teared down in main thread due to
// all owners/providers destructed in a xDS update (e.g. LDS demolishes a
// RouteConfigProvider and its subscription).
// If such a race condition happens, holding a reference to the "*this" subscription
// instance in this cb will ensure the shared "*this" gets posted back to main thread, after
// all the workers finish calling the update_fn, at which point it's safe to destruct
// "*this" instance.
[shared_this, complete_cb]() { complete_cb(); });
}
void applyConfigUpdate(const ConfigUpdateCb& update_fn);

void setLastUpdated() { last_updated_ = time_source_.systemTime(); }

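The comments in the deleted block describe the lifetime trick the old overload depended on: capturing `shared_from_this()` in the completion callback so the subscription outlives every in-flight post. A standalone sketch of that pattern, with illustrative names only (not Envoy types):

```cpp
#include <functional>
#include <iostream>
#include <memory>
#include <vector>

// Capturing shared_from_this() in a queued completion callback keeps the owner alive
// until the callback has run, even if every other reference is dropped in the meantime.
class Subscription : public std::enable_shared_from_this<Subscription> {
public:
  ~Subscription() { std::cout << "Subscription destroyed\n"; }
  std::function<void()> makeCompletion() {
    auto shared_this = shared_from_this(); // safe: the object is owned by a shared_ptr below
    return [shared_this]() { std::cout << "update propagated\n"; };
  }
};

int main() {
  std::vector<std::function<void()>> post_queue; // stands in for the main-thread post queue
  {
    auto sub = std::make_shared<Subscription>();
    post_queue.push_back(sub->makeCompletion());
  } // the last external reference is gone, but the subscription is still alive...
  post_queue.front()(); // ...when the queued completion runs
  post_queue.clear();   // "Subscription destroyed" prints only now
}
```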
8 changes: 6 additions & 2 deletions source/common/router/rds_impl.cc
@@ -194,8 +194,12 @@ Router::ConfigConstSharedPtr RdsRouteConfigProviderImpl::config() {
void RdsRouteConfigProviderImpl::onConfigUpdate() {
ConfigConstSharedPtr new_config(
new ConfigImpl(config_update_info_->routeConfiguration(), factory_context_, false));
tls_->runOnAllThreads(
[this, new_config]() -> void { tls_->getTyped<ThreadLocalConfig>().config_ = new_config; });
tls_->runOnAllThreads([new_config](ThreadLocal::ThreadLocalObjectSharedPtr previous)
-> ThreadLocal::ThreadLocalObjectSharedPtr {
auto prev_config = std::dynamic_pointer_cast<ThreadLocalConfig>(previous);
prev_config->config_ = new_config;
return previous;
});
}

RouteConfigProviderManagerImpl::RouteConfigProviderManagerImpl(Server::Admin& admin) {
111 changes: 101 additions & 10 deletions source/common/thread_local/thread_local_impl.cc
@@ -1,5 +1,6 @@
#include "common/thread_local/thread_local_impl.h"

#include <algorithm>
#include <atomic>
#include <cstdint>
#include <list>
@@ -24,28 +25,82 @@ SlotPtr InstanceImpl::allocateSlot() {
ASSERT(std::this_thread::get_id() == main_thread_id_);
ASSERT(!shutdown_);

for (uint64_t i = 0; i < slots_.size(); i++) {
if (slots_[i] == nullptr) {
std::unique_ptr<SlotImpl> slot(new SlotImpl(*this, i));
slots_[i] = slot.get();
return slot;
}
if (free_slot_indexes_.empty()) {
std::unique_ptr<SlotImpl> slot(new SlotImpl(*this, slots_.size()));
auto wrapper = std::make_unique<Bookkeeper>(*this, std::move(slot));
slots_.push_back(wrapper->slot_.get());
return wrapper;
}

std::unique_ptr<SlotImpl> slot(new SlotImpl(*this, slots_.size()));
slots_.push_back(slot.get());
return slot;
const uint32_t idx = free_slot_indexes_.front();
free_slot_indexes_.pop_front();
ASSERT(idx < slots_.size());
std::unique_ptr<SlotImpl> slot(new SlotImpl(*this, idx));
slots_[idx] = slot.get();
return std::make_unique<Bookkeeper>(*this, std::move(slot));
}

bool InstanceImpl::SlotImpl::currentThreadRegistered() {
return thread_local_data_.data_.size() > index_;
}

void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb) {
parent_.runOnAllThreads([this, cb]() { setThreadLocal(index_, cb(get())); });
}

void InstanceImpl::SlotImpl::runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) {
parent_.runOnAllThreads([this, cb]() { setThreadLocal(index_, cb(get())); }, complete_cb);
}

ThreadLocalObjectSharedPtr InstanceImpl::SlotImpl::get() {
ASSERT(currentThreadRegistered());
return thread_local_data_.data_[index_];
}

InstanceImpl::Bookkeeper::Bookkeeper(InstanceImpl& parent, std::unique_ptr<SlotImpl>&& slot)
: parent_(parent), slot_(std::move(slot)),
ref_count_(/*not used.*/ nullptr,
[slot = slot_.get(), &parent = this->parent_](uint32_t* /* not used */) {
// On destruction, post a cleanup callback to the main thread. Note that this destructor
// may run on any thread.
parent.scheduleCleanup(slot);
}) {}

ThreadLocalObjectSharedPtr InstanceImpl::Bookkeeper::get() { return slot_->get(); }

void InstanceImpl::Bookkeeper::runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) {
slot_->runOnAllThreads(
[cb, ref_count = this->ref_count_](ThreadLocalObjectSharedPtr previous) {
return cb(std::move(previous));
},
complete_cb);
}

void InstanceImpl::Bookkeeper::runOnAllThreads(const UpdateCb& cb) {
slot_->runOnAllThreads([cb, ref_count = this->ref_count_](ThreadLocalObjectSharedPtr previous) {
return cb(std::move(previous));
});
}

bool InstanceImpl::Bookkeeper::currentThreadRegistered() {
return slot_->currentThreadRegistered();
}

void InstanceImpl::Bookkeeper::runOnAllThreads(Event::PostCb cb) {
// Use ref_count_ to track how many in-flight callbacks are still outstanding.
slot_->runOnAllThreads([cb, ref_count = this->ref_count_]() { cb(); });
}

void InstanceImpl::Bookkeeper::runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback) {
// Use ref_count_ to track how many in-flight callbacks are still outstanding.
slot_->runOnAllThreads([cb, main_callback, ref_count = this->ref_count_]() { cb(); },
main_callback);
}

void InstanceImpl::Bookkeeper::set(InitializeCb cb) {
slot_->set([cb, ref_count = this->ref_count_](Event::Dispatcher& dispatcher)
-> ThreadLocalObjectSharedPtr { return cb(dispatcher); });
}

void InstanceImpl::registerThread(Event::Dispatcher& dispatcher, bool main_thread) {
ASSERT(std::this_thread::get_id() == main_thread_id_);
ASSERT(!shutdown_);
@@ -60,6 +115,38 @@ void InstanceImpl::registerThread(Event::Dispatcher& dispatcher, bool main_thread) {
}
}

// Puts the slot into a deferred-delete container; the slot will be destructed once its
// outstanding callback reference count goes to 0.
void InstanceImpl::recycle(std::unique_ptr<SlotImpl>&& slot) {
ASSERT(std::this_thread::get_id() == main_thread_id_);
ASSERT(slot != nullptr);
auto* slot_addr = slot.get();
deferred_deletes_.insert({slot_addr, std::move(slot)});
}

// Called by the deleter of the Bookkeeper's ref_count_; the SlotImpl in the deferred deletes
// map can be destructed now.
void InstanceImpl::scheduleCleanup(SlotImpl* slot) {
if (shutdown_) {
// If the server is shutting down, do nothing here.
// The destruction of the Bookkeeper has already transferred the SlotImpl to the
// deferred_deletes_ queue. Whether or not this method is called from a worker thread, the
// SlotImpl will be destructed on the main thread when the InstanceImpl destructs.
return;
}
if (std::this_thread::get_id() == main_thread_id_) {
// If called from the main thread, erase the slot directly and save a cross-thread post.
ASSERT(deferred_deletes_.contains(slot));
deferred_deletes_.erase(slot);
return;
}
main_thread_dispatcher_->post([slot, this]() {
ASSERT(deferred_deletes_.contains(slot));
// The slot is guaranteed to have been put into the deferred_deletes_ map by the Bookkeeper
// destructor.
deferred_deletes_.erase(slot);
});
}

void InstanceImpl::removeSlot(SlotImpl& slot) {
ASSERT(std::this_thread::get_id() == main_thread_id_);

@@ -73,6 +160,10 @@ void InstanceImpl::removeSlot(SlotImpl& slot) {

const uint64_t index = slot.index_;
slots_[index] = nullptr;
ASSERT(std::find(free_slot_indexes_.begin(), free_slot_indexes_.end(), index) ==
free_slot_indexes_.end(),
fmt::format("slot index {} already in free slot set!", index));
free_slot_indexes_.push_back(index);
runOnAllThreads([index]() -> void {
// This runs on each thread and clears the slot, making it available for new allocations.
// This is safe even if a new allocation comes in, because everything happens with post() and
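The core of the fix is the `Bookkeeper`'s `ref_count_`: an empty `shared_ptr` whose custom deleter fires once the last in-flight callback releases its copy, at which point the `SlotImpl` can be reclaimed on the main thread. A standalone sketch of that counting trick, with illustrative names (not Envoy code):

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <memory>
#include <vector>

int main() {
  // Stands in for callbacks that have been posted to worker dispatchers.
  std::vector<std::function<void()>> in_flight;

  {
    // The pointee is never used; only the shared_ptr's reference count matters. The custom
    // deleter runs exactly once, when the last copy is destroyed, on whichever thread that is.
    std::shared_ptr<uint32_t> ref_count(
        nullptr, [](uint32_t*) { std::cout << "last reference gone: slot can be deleted\n"; });

    // Each posted callback captures a copy of ref_count, pinning the count while it is queued.
    for (int i = 0; i < 3; ++i) {
      in_flight.push_back([ref_count]() { /* per-thread work would run here */ });
    }
  } // the Bookkeeper-side reference is dropped here; three copies remain in the queue

  in_flight.clear(); // the last copy is released, so the deleter fires now
}
```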
43 changes: 42 additions & 1 deletion source/common/thread_local/thread_local_impl.h
@@ -8,14 +8,17 @@
#include "envoy/thread_local/thread_local.h"

#include "common/common/logger.h"
#include "common/common/non_copyable.h"

#include "absl/container/flat_hash_map.h"

namespace Envoy {
namespace ThreadLocal {

/**
* Implementation of ThreadLocal that relies on static thread_local objects.
*/
class InstanceImpl : Logger::Loggable<Logger::Id::main>, public Instance {
class InstanceImpl : Logger::Loggable<Logger::Id::main>, public NonCopyable, public Instance {
public:
InstanceImpl() : main_thread_id_(std::this_thread::get_id()) {}
~InstanceImpl() override;
@@ -35,6 +38,8 @@ class InstanceImpl : Logger::Loggable<Logger::Id::main>, public Instance {
// ThreadLocal::Slot
ThreadLocalObjectSharedPtr get() override;
bool currentThreadRegistered() override;
void runOnAllThreads(const UpdateCb& cb) override;
void runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) override;
void runOnAllThreads(Event::PostCb cb) override { parent_.runOnAllThreads(cb); }
void runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback) override {
parent_.runOnAllThreads(cb, main_callback);
@@ -45,22 +50,58 @@ class InstanceImpl : Logger::Loggable<Logger::Id::main>, public Instance {
const uint64_t index_;
};

// A wrapper around SlotImpl which, on destruction, returns the SlotImpl to the deferred delete
// queue (detaches it).
struct Bookkeeper : public Slot {
Bookkeeper(InstanceImpl& parent, std::unique_ptr<SlotImpl>&& slot);
~Bookkeeper() override { parent_.recycle(std::move(slot_)); }

// ThreadLocal::Slot
ThreadLocalObjectSharedPtr get() override;
void runOnAllThreads(const UpdateCb& cb) override;
void runOnAllThreads(const UpdateCb& cb, Event::PostCb complete_cb) override;
bool currentThreadRegistered() override;
void runOnAllThreads(Event::PostCb cb) override;
void runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback) override;
void set(InitializeCb cb) override;

InstanceImpl& parent_;
std::unique_ptr<SlotImpl> slot_;
std::shared_ptr<uint32_t> ref_count_;
};

struct ThreadLocalData {
Event::Dispatcher* dispatcher_{};
std::vector<ThreadLocalObjectSharedPtr> data_;
};

void recycle(std::unique_ptr<SlotImpl>&& slot);
// Schedules cleanup (deletion) of a slot from the deferred deletes queue.
void scheduleCleanup(SlotImpl* slot);

void removeSlot(SlotImpl& slot);
void runOnAllThreads(Event::PostCb cb);
void runOnAllThreads(Event::PostCb cb, Event::PostCb main_callback);
static void setThreadLocal(uint32_t index, ThreadLocalObjectSharedPtr object);

static thread_local ThreadLocalData thread_local_data_;

// An indexed container for Slots whose deletion has to be deferred because of outstanding
// callbacks pointing to the Slot. To let the ref_count_ deleter find the SlotImpl by address,
// the container is defined as a map from SlotImpl address to unique_ptr<SlotImpl>.
absl::flat_hash_map<SlotImpl*, std::unique_ptr<SlotImpl>> deferred_deletes_;

std::vector<SlotImpl*> slots_;
// A list of indexes of freed slots.
std::list<uint32_t> free_slot_indexes_;

std::list<std::reference_wrapper<Event::Dispatcher>> registered_threads_;
std::thread::id main_thread_id_;
Event::Dispatcher* main_thread_dispatcher_{};
std::atomic<bool> shutdown_{};

// Test only.
friend class ThreadLocalInstanceImplTest;
};

} // namespace ThreadLocal
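`free_slot_indexes_` makes slot allocation a simple free list: freed indexes are recycled before `slots_` grows. A tiny standalone sketch of that bookkeeping, with illustrative names (not Envoy code):

```cpp
#include <cassert>
#include <cstdint>
#include <list>
#include <vector>

int main() {
  std::vector<bool> occupied;            // stands in for slots_ (true = slot in use)
  std::list<uint32_t> free_slot_indexes; // stands in for free_slot_indexes_

  auto allocate = [&]() -> uint32_t {
    if (free_slot_indexes.empty()) {
      occupied.push_back(true); // no free index available: grow the vector
      return static_cast<uint32_t>(occupied.size() - 1);
    }
    const uint32_t idx = free_slot_indexes.front(); // reuse the oldest freed index
    free_slot_indexes.pop_front();
    occupied[idx] = true;
    return idx;
  };
  auto release = [&](uint32_t idx) {
    occupied[idx] = false;
    free_slot_indexes.push_back(idx);
  };

  const uint32_t a = allocate(); // index 0
  const uint32_t b = allocate(); // index 1
  release(a);
  assert(allocate() == a); // index 0 is reused instead of growing the vector
  (void)b;
}
```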