Skip to content

Commit

Permalink
Make ThreadLocalOverloadState virtual
Browse files Browse the repository at this point in the history
Make ThreadLocalOverloadState an interface that the Overload Manager can
return an implementation of. This decouples the behavior specified by
the header from the details of the implementation.

Signed-off-by: Alex Konradi <akonradi@google.com>
  • Loading branch information
akonradi committed Jun 26, 2020
1 parent 09a4d7c commit 76f494b
Show file tree
Hide file tree
Showing 7 changed files with 60 additions and 39 deletions.
25 changes: 5 additions & 20 deletions include/envoy/server/overload_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,27 +31,12 @@ using OverloadActionCb = std::function<void(OverloadActionState)>;
/**
* Thread-local copy of the state of each configured overload action.
*/
class ThreadLocalOverloadState : public ThreadLocal::ThreadLocalObject {
class ThreadLocalOverloadState {
public:
const OverloadActionState& getState(const std::string& action) {
auto it = actions_.find(action);
if (it == actions_.end()) {
it = actions_.insert(std::make_pair(action, OverloadActionState::Inactive)).first;
}
return it->second;
}

void setState(const std::string& action, OverloadActionState state) {
auto it = actions_.find(action);
if (it == actions_.end()) {
actions_[action] = state;
} else {
it->second = state;
}
}

private:
std::unordered_map<std::string, OverloadActionState> actions_;
virtual ~ThreadLocalOverloadState() = default;

// Get a thread-local reference to the value for the given action key.
virtual const OverloadActionState& getState(const std::string& action) PURE;
};

/**
Expand Down
6 changes: 0 additions & 6 deletions source/server/admin/admin.cc
Original file line number Diff line number Diff line change
Expand Up @@ -650,12 +650,6 @@ ConfigTracker& AdminImpl::getConfigTracker() { return config_tracker_; }
AdminImpl::NullRouteConfigProvider::NullRouteConfigProvider(TimeSource& time_source)
: config_(new Router::NullConfigImpl()), time_source_(time_source) {}

OverloadTimerFactory AdminImpl::NullOverloadManager::NullThreadOverloadState::getTimerFactory() {
return [this](absl::string_view, Event::TimerCb callback) {
return dispatcher_.createTimer(callback);
};
}

void AdminImpl::startHttpListener(const std::string& access_log_path,
const std::string& address_out_path,
Network::Address::InstanceConstSharedPtr address,
Expand Down
2 changes: 0 additions & 2 deletions source/server/admin/admin.h
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,6 @@ class AdminImpl : public Admin,
struct NullThreadOverloadState : public ThreadLocalOverloadState {
NullThreadOverloadState(Event::Dispatcher& dispatcher) : dispatcher_(dispatcher) {}
const OverloadActionState& getState(const std::string&) override { return inactive_; }
void setState(const std::string&, OverloadActionState) override {}
OverloadTimerFactory getTimerFactory() override;

const OverloadActionState inactive_ = OverloadActionState::Inactive;
Event::Dispatcher& dispatcher_;
Expand Down
33 changes: 30 additions & 3 deletions source/server/overload_manager_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,33 @@ class ThresholdTriggerImpl : public OverloadAction::Trigger {
absl::optional<double> value_;
};

/**
* Thread-local copy of the state of each configured overload action.
*/
class ThreadLocalOverloadStateImpl : public ThreadLocalOverloadState,
public ThreadLocal::ThreadLocalObject {
public:
const OverloadActionState& getState(const std::string& action) override {
auto it = actions_.find(action);
if (it == actions_.end()) {
it = actions_.insert(std::make_pair(action, OverloadActionState::Inactive)).first;
}
return it->second;
}

void setState(const std::string& action, OverloadActionState state) {
auto it = actions_.find(action);
if (it == actions_.end()) {
actions_[action] = state;
} else {
it->second = state;
}
}

private:
std::unordered_map<std::string, OverloadActionState> actions_;
};

Stats::Counter& makeCounter(Stats::Scope& scope, absl::string_view a, absl::string_view b) {
Stats::StatNameManagedStorage stat_name(absl::StrCat("overload.", a, ".", b),
scope.symbolTable());
Expand Down Expand Up @@ -148,7 +175,7 @@ void OverloadManagerImpl::start() {
started_ = true;

tls_->set([](Event::Dispatcher&) -> ThreadLocal::ThreadLocalObjectSharedPtr {
return std::make_shared<ThreadLocalOverloadState>();
return std::make_shared<ThreadLocalOverloadStateImpl>();
});

if (resources_.empty()) {
Expand Down Expand Up @@ -191,7 +218,7 @@ bool OverloadManagerImpl::registerForAction(const std::string& action,
}

ThreadLocalOverloadState& OverloadManagerImpl::getThreadLocalOverloadState() {
return tls_->getTyped<ThreadLocalOverloadState>();
return tls_->getTyped<ThreadLocalOverloadStateImpl>();
}

void OverloadManagerImpl::updateResourcePressure(const std::string& resource, double pressure) {
Expand All @@ -208,7 +235,7 @@ void OverloadManagerImpl::updateResourcePressure(const std::string& resource, do
ENVOY_LOG(info, "Overload action {} became {}", action,
is_active ? "active" : "inactive");
tls_->runOnAllThreads([this, action, state] {
tls_->getTyped<ThreadLocalOverloadState>().setState(action, state);
tls_->getTyped<ThreadLocalOverloadStateImpl>().setState(action, state);
});
auto callback_range = action_to_callbacks_.equal_range(action);
std::for_each(callback_range.first, callback_range.second,
Expand Down
17 changes: 10 additions & 7 deletions test/common/http/conn_manager_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5600,11 +5600,12 @@ TEST(HttpConnectionManagerTracingStatsTest, verifyTracingStats) {
}

TEST_F(HttpConnectionManagerImplTest, NoNewStreamWhenOverloaded) {
setup(false, "");
Server::OverloadActionState stop_accepting_requests = Server::OverloadActionState::Active;
ON_CALL(overload_manager_.overload_state_,
getState(Server::OverloadActionNames::get().StopAcceptingRequests))
.WillByDefault(ReturnRef(stop_accepting_requests));

overload_manager_.overload_state_.setState(
Server::OverloadActionNames::get().StopAcceptingRequests,
Server::OverloadActionState::Active);
setup(false, "");

EXPECT_CALL(*codec_, dispatch(_)).WillRepeatedly(Invoke([&](Buffer::Instance&) -> Http::Status {
RequestDecoder* decoder = &conn_manager_->newStream(response_encoder_);
Expand All @@ -5630,10 +5631,12 @@ TEST_F(HttpConnectionManagerImplTest, NoNewStreamWhenOverloaded) {
}

TEST_F(HttpConnectionManagerImplTest, DisableKeepAliveWhenOverloaded) {
setup(false, "");
Server::OverloadActionState disable_http_keep_alive = Server::OverloadActionState::Active;
ON_CALL(overload_manager_.overload_state_,
getState(Server::OverloadActionNames::get().DisableHttpKeepAlive))
.WillByDefault(ReturnRef(disable_http_keep_alive));

overload_manager_.overload_state_.setState(
Server::OverloadActionNames::get().DisableHttpKeepAlive, Server::OverloadActionState::Active);
setup(false, "");

std::shared_ptr<MockStreamDecoderFilter> filter(new NiceMock<MockStreamDecoderFilter>());
EXPECT_CALL(filter_factory_, createFilterChain(_))
Expand Down
5 changes: 5 additions & 0 deletions test/mocks/server/mocks.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,11 @@ MockHotRestart::MockHotRestart() : stats_allocator_(*symbol_table_) {
}
MockHotRestart::~MockHotRestart() = default;

MockThreadLocalOverloadState::MockThreadLocalOverloadState()
: disabled_state_(OverloadActionState::Inactive) {
ON_CALL(*this, getState).WillByDefault(ReturnRef(disabled_state_));
}

MockOverloadManager::MockOverloadManager() {
ON_CALL(*this, getThreadLocalOverloadState()).WillByDefault(ReturnRef(overload_state_));
}
Expand Down
11 changes: 10 additions & 1 deletion test/mocks/server/mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,15 @@ class MockWorker : public Worker {
std::function<void()> remove_filter_chains_completion_;
};

class MockThreadLocalOverloadState : public ThreadLocalOverloadState {
public:
MockThreadLocalOverloadState();
MOCK_METHOD(const OverloadActionState&, getState, (const std::string&), (override));

private:
const OverloadActionState disabled_state_;
};

class MockOverloadManager : public OverloadManager {
public:
MockOverloadManager();
Expand All @@ -343,7 +352,7 @@ class MockOverloadManager : public OverloadManager {
OverloadActionCb callback));
MOCK_METHOD(ThreadLocalOverloadState&, getThreadLocalOverloadState, ());

ThreadLocalOverloadState overload_state_;
NiceMock<MockThreadLocalOverloadState> overload_state_;
};

class MockInstance : public Instance {
Expand Down

0 comments on commit 76f494b

Please sign in to comment.