Skip to content

Commit

Permalink
Allow AspiredVersionsManager to handle aspire->unaspire->reaspire of …
Browse files Browse the repository at this point in the history
…a given servable:

 (1) have BasicManager always forget about a servable that goes to state kEnd;
 (2) when AspiredVersionsManager gets a re-aspire request for a servable not in state kEnd, it blocks until that servable reaches kEnd (and everything gets reset)

Regarding (1): independent of the re-aspire issue, I think that's a cleaner contract, and it also ensures the manager won't accumulate state proportional to the number of failed servables it has ever dealt with.
Change: 132115088
  • Loading branch information
Christopher Olston authored and pythonner committed Sep 8, 2016
1 parent 4647230 commit a66f9df
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 52 deletions.
57 changes: 46 additions & 11 deletions tensorflow_serving/core/aspired_versions_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,23 @@ struct CompareActions {
}
};

// Reports whether a re-aspire is pending: returns true iff at least one entry
// of 'state_snapshots' is currently marked is_aspired==false and its version
// number appears in 'next_aspired_versions'.
//
// Only version numbers are compared; servable names are ignored, so the caller
// is expected to pass information pertaining to a single servable name.
bool ContainsAnyReaspiredVersions(
    const std::vector<ServableStateSnapshot<Aspired>>& state_snapshots,
    const std::set<int64>& next_aspired_versions) {
  return std::any_of(
      state_snapshots.begin(), state_snapshots.end(),
      [&next_aspired_versions](
          const ServableStateSnapshot<Aspired>& snapshot) {
        // A version that is still aspired cannot be "re-aspired".
        if (snapshot.additional_state->is_aspired) {
          return false;
        }
        return next_aspired_versions.count(snapshot.id.version) > 0;
      });
}

} // namespace

namespace internal {
Expand Down Expand Up @@ -183,8 +200,7 @@ void AspiredVersionsManager::SetAspiredVersions(
std::vector<ServableData<std::unique_ptr<Loader>>> versions) {
// We go through the aspired_servable_versions and collect the next aspired
// version numbers into a set (which keeps them sorted and de-duplicated).
std::vector<int64> next_aspired_versions;
next_aspired_versions.reserve(versions.size());
std::set<int64> next_aspired_versions;
for (const auto& version : versions) {
if (servable_name != version.id().name) {
LOG(ERROR) << "Servable name: " << servable_name
Expand All @@ -193,23 +209,44 @@ void AspiredVersionsManager::SetAspiredVersions(
DCHECK(false) << "See previous servable name mismatch error message.";
return;
}
next_aspired_versions.push_back(version.id().version);
next_aspired_versions.insert(version.id().version);
}
std::sort(next_aspired_versions.begin(), next_aspired_versions.end());

{
mutex_lock l(basic_manager_read_modify_write_mu_);

// We wait for any re-aspired versions (versions currently not aspired, but
// present in 'next_aspired_versions') to quiesce and be removed from
// BasicManager. Doing so ensures that re-aspired versions start with a
// clean slate by being re-inserted from scratch into BasicManager, below.
//
// TODO(b/31269483): Make SetAspiredVersions() asynchronous to avoid
// blocking the calling thread in situations like this.
std::vector<ServableStateSnapshot<Aspired>> state_snapshots;
while (true) {
state_snapshots =
basic_manager_->GetManagedServableStateSnapshots<Aspired>(
servable_name.ToString());
if (!ContainsAnyReaspiredVersions(state_snapshots,
next_aspired_versions)) {
break;
}
const auto kWaitTime = std::chrono::milliseconds(10);
// (We use this condition variable in a degenerate way -- it never gets
// notified -- to sleep without holding the mutex.)
condition_variable cv;
cv.wait_for(l, kWaitTime);
}

// We gather all the servables with the servable_name and
// 1. Add the current aspired version numbers to a (sorted) set,
// 2. Set the aspired bool to false for all current servable harnesses which
// are not aspired.
std::vector<int64> current_aspired_versions;
std::set<int64> current_aspired_versions;
for (const ServableStateSnapshot<Aspired> state_snapshot :
basic_manager_->GetManagedServableStateSnapshots<Aspired>(
servable_name.ToString())) {
state_snapshots) {
if (state_snapshot.additional_state->is_aspired) {
current_aspired_versions.push_back(state_snapshot.id.version);
current_aspired_versions.insert(state_snapshot.id.version);
}
// If this version is not part of the aspired versions.
if (std::find(next_aspired_versions.begin(), next_aspired_versions.end(),
Expand All @@ -219,13 +256,11 @@ void AspiredVersionsManager::SetAspiredVersions(
basic_manager_->CancelLoadServableRetry(state_snapshot.id);
}
}
std::sort(current_aspired_versions.begin(), current_aspired_versions.end());

// We do a set_difference (A - B), on the next aspired versions and the
// current aspired versions to find the version numbers which need to be
// added to the harness map.
std::vector<int64> additions;
additions.reserve(next_aspired_versions.size());
std::set<int64> additions;
std::set_difference(
next_aspired_versions.begin(), next_aspired_versions.end(),
current_aspired_versions.begin(), current_aspired_versions.end(),
Expand Down
74 changes: 74 additions & 0 deletions tensorflow_serving/core/aspired_versions_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -739,6 +739,80 @@ TEST_P(AspiredVersionsManagerTest, RetryOnLoadErrorFinallyFails) {
EqualsServableState(error_state));
}

// Aspires version 7 of kServableName, pins it with a servable handle,
// un-aspires it, and then re-aspires the same id with a second loader from a
// separate thread. The test passes if the second loader's Load() has been
// invoked by the end of the test body.
TEST_P(AspiredVersionsManagerTest, UnaspireThenImmediatelyReaspire) {
// This test exercises a scenario in which a servable has been unaspired, and
// while it is still being managed (e.g. loading, serving or unloading) it
// gets reaspired (with a new loader). The manager should wait for the
// original loader to get taken down via the normal process for unaspired
// loaders, and then proceed to bring up the new loader.

const ServableId id = {kServableName, 7};

// Aspire the servable for the first time. Ownership of 'first_loader' is
// handed to the unique_ptr placed in 'first_aspired_versions'; the raw pointer
// is kept only so that expectations can be set on the mock.
std::vector<ServableData<std::unique_ptr<Loader>>> first_aspired_versions;
test_util::MockLoader* first_loader = new NiceMock<test_util::MockLoader>();
first_aspired_versions.push_back({id, std::unique_ptr<Loader>(first_loader)});
EXPECT_CALL(*first_loader, Load(_)).WillOnce(Return(Status::OK()));
manager_->GetAspiredVersionsCallback()(kServableName,
std::move(first_aspired_versions));

// Invoke RunManageState() once on a background thread so the first loader can
// be brought up.
std::unique_ptr<Thread> load_thread(Env::Default()->StartThread(
ThreadOptions(), "LoadThread", [&]() { RunManageState(); }));

// Pin 'first_loader' in the manager by holding a handle to its servable.
WaitUntilServableManagerStateIsOneOf(
servable_state_monitor_, id, {ServableState::ManagerState::kAvailable});
// Dummy servable object that the mock loader exposes through servable().
int servable = 42;
EXPECT_CALL(*first_loader, servable()).WillOnce(InvokeWithoutArgs([&]() {
return AnyPtr{&servable};
}));
auto first_loader_handle =
std::unique_ptr<ServableHandle<int>>(new ServableHandle<int>);
TF_ASSERT_OK(manager_->GetServableHandle(ServableRequest::FromId(id),
first_loader_handle.get()));

// Now, we'll un-aspire the servable, and then re-aspire it with a new loader.
// The manager should wait until it is able to unload the first loader, then
// bring up the second loader.

// An empty aspired-versions list un-aspires every version of kServableName.
std::vector<ServableData<std::unique_ptr<Loader>>> empty_aspired_versions;
manager_->GetAspiredVersionsCallback()(kServableName,
std::move(empty_aspired_versions));

// Re-aspire the same id, this time backed by 'second_loader'. Its Load()
// signals 'second_load_called' so the test can observe when it runs.
std::vector<ServableData<std::unique_ptr<Loader>>> second_aspired_versions;
test_util::MockLoader* second_loader = new NiceMock<test_util::MockLoader>();
second_aspired_versions.push_back(
{id, std::unique_ptr<Loader>(second_loader)});
Notification second_load_called;
EXPECT_CALL(*second_loader, Load(_)).WillOnce(InvokeWithoutArgs([&]() {
second_load_called.Notify();
return Status::OK();
}));
// TODO(b/31269483): Once we make SetAspiredVersions() non-blocking we won't
// need to run this in a separate thread.
std::unique_ptr<Thread> reaspire_thread(
Env::Default()->StartThread(ThreadOptions(), "ReaspireThread", [&]() {
manager_->GetAspiredVersionsCallback()(
kServableName, std::move(second_aspired_versions));
}));

// Give the re-aspire call some time to sit on the request, and make sure it
// doesn't process it prematurely or do any other bad things.
Env::Default()->SleepForMicroseconds(50 * 1000 /* 50 ms */);

// Unpin the first loader. Eventually the manager should bring up the second
// loader.
first_loader_handle = nullptr;
{
// 'reload_thread' is scoped so that it is destroyed before the assertion
// below. NOTE(review): presumably destroying the Thread joins it, which is
// what makes the assertion deterministic -- confirm against Env::StartThread.
std::unique_ptr<Thread> reload_thread(
Env::Default()->StartThread(ThreadOptions(), "ReloadThread", [&]() {
// Keep driving the manager until the second loader's Load() has run.
while (!second_load_called.HasBeenNotified()) {
RunManageState();
}
}));
}
ASSERT_TRUE(second_load_called.HasBeenNotified());
}

} // namespace
} // namespace serving
} // namespace tensorflow
22 changes: 18 additions & 4 deletions tensorflow_serving/core/basic_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,17 @@ BasicManager::ManagedMap::iterator BasicManager::FindHarnessInMap(
return managed_map_.end();
}

// Drops the entry for 'id' from 'managed_map_'.
//
// A missing entry is treated as a programming error: it trips the DCHECK in
// debug builds, and otherwise an error is logged and the map is left
// untouched. The header declares this method as requiring 'mu_' to be held.
void BasicManager::DeleteHarness(const ServableId& id) {
  const auto it = FindHarnessInMap(id);
  DCHECK(it != managed_map_.end());
  if (it != managed_map_.end()) {
    managed_map_.erase(it);
    return;
  }
  // Reached only when no harness is registered under 'id'.
  LOG(ERROR) << "Request to delete harness for " << id
             << ", but no such harness found in managed_map_";
}

void BasicManager::ManageServableInternal(
ServableData<std::unique_ptr<Loader>> servable,
std::function<std::shared_ptr<LoaderHarness>(const ServableId&,
Expand Down Expand Up @@ -307,8 +318,8 @@ void BasicManager::ManageServableInternal(
} else {
PublishOnEventBus({harness->id(), ServableState::ManagerState::kStart,
harness->status()});
managed_map_.emplace(servable.id().name, harness);
}
managed_map_.emplace(servable.id().name, harness);
}

void BasicManager::ManageServable(
Expand Down Expand Up @@ -405,6 +416,7 @@ Status BasicManager::ExecuteLoad(LoaderHarness* harness) {
if (!load_status.ok()) {
PublishOnEventBus({harness->id(), ServableState::ManagerState::kEnd,
harness->status()});
DeleteHarness(harness->id());
return load_status;
}

Expand Down Expand Up @@ -583,12 +595,14 @@ Status BasicManager::ApproveLoad(LoaderHarness* harness, mutex_lock* mu_lock) {
// for this servable.
LOG(WARNING) << "Unable to reserve resources to load servable "
<< harness->id().DebugString();
harness->Error(errors::ResourceExhausted(
const Status error = errors::ResourceExhausted(
"Insufficient resources to load servable ",
harness->id().DebugString()));
harness->id().DebugString());
harness->Error(error);
PublishOnEventBus({harness->id(), ServableState::ManagerState::kEnd,
harness->status()});
return harness->status();
DeleteHarness(harness->id());
return error;
} else {
// Wait until at least one load/unload request finishes, then retry.
num_ongoing_load_unload_executions_cv_.wait(*mu_lock);
Expand Down
20 changes: 17 additions & 3 deletions tensorflow_serving/core/basic_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -291,9 +291,14 @@ class BasicManager : public Manager {
Status ApproveLoadOrUnload(const LoadOrUnloadRequest& request,
LoaderHarness** harness) LOCKS_EXCLUDED(mu_);

// The decision phase of whether to approve a load request. If it succeeds,
// places the servable into state kApprovedForLoad. Among other things, that
// prevents a subsequent load request from proceeding concurrently.
// The decision phase of whether to approve a load request.
//
// If it succeeds, places the servable into state kApprovedForLoad. Among
// other things, that prevents a subsequent load request from proceeding
// concurrently.
//
// If it fails, removes 'harness' from 'managed_map_' (which causes 'harness'
// to be deleted).
//
// Argument 'mu_lock' is a lock held on 'mu_'. It is released temporarily via
// 'num_ongoing_load_unload_executions_cv_'.
Expand All @@ -310,6 +315,9 @@ class BasicManager : public Manager {
//
// Upon completion (and regardless of the outcome), signals exit of the
// execution phase by decrementing 'num_ongoing_load_unload_executions_'.
//
// If it fails, removes 'harness' from 'managed_map_' (which causes 'harness'
// to be deleted).
Status ExecuteLoadOrUnload(const LoadOrUnloadRequest& request,
LoaderHarness* harness);

Expand Down Expand Up @@ -341,6 +349,12 @@ class BasicManager : public Manager {
ManagedMap::iterator FindHarnessInMap(const ServableId& id)
EXCLUSIVE_LOCKS_REQUIRED(mu_);

// Removes the harness associated with 'id' from 'managed_map_' and deletes
// the harness.
//
// If no matching harness is found, DCHECK-fails and logs an error.
void DeleteHarness(const ServableId& id) EXCLUSIVE_LOCKS_REQUIRED(mu_);

// Publishes the state on the event bus, if an event bus was part of the
// options, if not we ignore it.
void PublishOnEventBus(const ServableState& state);
Expand Down
41 changes: 7 additions & 34 deletions tensorflow_serving/core/basic_manager_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -348,14 +348,10 @@ TEST_P(BasicManagerTest, GetManagedServableNames) {
}

TEST_P(BasicManagerTest,
GetManagedServableStateSnapshotsWithoutAdditionalState) {
const ServableId id = {kServableName, 7};
basic_manager_->ManageServable(
ServableData<std::unique_ptr<Loader>>(id, errors::Internal("An error.")));
GetManagedServableStateSnapshotWithoutAdditionalState) {
const std::vector<ServableStateSnapshot<>> expected = {
{{kServableName, 1}, LoaderHarness::State::kReady, {}},
{{kServableName, 2}, LoaderHarness::State::kReady, {}},
{{kServableName, 7}, LoaderHarness::State::kError, {}}};
{{kServableName, 2}, LoaderHarness::State::kReady, {}}};
EXPECT_THAT(basic_manager_->GetManagedServableStateSnapshots(kServableName),
UnorderedElementsAreArray(expected));
}
Expand All @@ -371,18 +367,6 @@ TEST_P(BasicManagerTest, GetManagedServableStateSnapshot) {
id_ready, LoaderHarness::State::kReady, {}};
EXPECT_EQ(actual_ready_snapshot, expected_ready_snapshot);

// Check servable state snapshot corresponding to a servable-id that is in
// error state.
const ServableId id_error = {kServableName, 7};
basic_manager_->ManageServable(ServableData<std::unique_ptr<Loader>>(
id_error, errors::Internal("An error.")));
const optional<ServableStateSnapshot<>> actual_error_snapshot =
basic_manager_->GetManagedServableStateSnapshot(id_error);
EXPECT_TRUE(actual_error_snapshot);
const ServableStateSnapshot<> expected_error_snapshot = {
id_error, LoaderHarness::State::kError, {}};
EXPECT_EQ(actual_error_snapshot, expected_error_snapshot);

// Check servable state snapshot corresponding to a servable-id that is not
// managed by the basic-manager.
const ServableId id_notmanaged = {kServableName, 8};
Expand All @@ -394,14 +378,9 @@ TEST_P(BasicManagerTest, GetManagedServableStateSnapshotsWithAdditionalState) {
CreateServable({kServableName3, 0}), std::unique_ptr<int>(new int(0)));
basic_manager_->ManageServableWithAdditionalState(
CreateServable({kServableName3, 1}), std::unique_ptr<int>(new int(1)));
basic_manager_->ManageServableWithAdditionalState(
ServableData<std::unique_ptr<Loader>>({kServableName3, 2},
errors::Internal("An error.")),
std::unique_ptr<int>(new int(2)));
const std::vector<ServableStateSnapshot<int>> expected = {
{{kServableName3, 0}, LoaderHarness::State::kNew, {0}},
{{kServableName3, 1}, LoaderHarness::State::kNew, {1}},
{{kServableName3, 2}, LoaderHarness::State::kError, {2}}};
{{kServableName3, 1}, LoaderHarness::State::kNew, {1}}};
EXPECT_THAT(
basic_manager_->GetManagedServableStateSnapshots<int>(kServableName3),
UnorderedElementsAreArray(expected));
Expand Down Expand Up @@ -437,9 +416,8 @@ TEST_P(BasicManagerTest, ErroneousServable) {
Status status = basic_manager_->GetServableHandle(
ServableRequest::Specific(kServableName, 3), &handle);
EXPECT_FALSE(status.ok()) << status;
basic_manager_->LoadServable(id, [](const Status& status) {
EXPECT_EQ(errors::Unknown("error"), status);
});
basic_manager_->LoadServable(
id, [](const Status& status) { EXPECT_FALSE(status.ok()) << status; });

status = basic_manager_->GetServableHandle(
ServableRequest::Specific(kServableName, 3), &handle);
Expand Down Expand Up @@ -917,9 +895,8 @@ TEST_P(BasicManagerTest, LoadAfterCancelledLoad) {
WaitUntilServableManagerStateIsOneOf(servable_state_monitor_, id,
{ServableState::ManagerState::kEnd});

basic_manager_->LoadServable(id, [](const Status& status) {
EXPECT_EQ(errors::Internal("Load error."), status);
});
basic_manager_->LoadServable(
id, [](const Status& status) { EXPECT_FALSE(status.ok()) << status; });
}

// Creates a ResourceAllocation proto with 'quantity' units of RAM.
Expand Down Expand Up @@ -1051,10 +1028,6 @@ TEST_F(ResourceConstrainedBasicManagerTest, InsufficientResources) {
rejection_received.Notify();
});
rejection_received.WaitForNotification();
EXPECT_THAT(
basic_manager_->GetManagedServableStateSnapshots(rejected_id.name),
UnorderedElementsAre(ServableStateSnapshot<>{
rejected_id, LoaderHarness::State::kError, {}}));
const ServableState expected_error_state = {
rejected_id, ServableState::ManagerState::kEnd, rejected_status};
EXPECT_THAT(*servable_state_monitor_.GetState(rejected_id),
Expand Down

0 comments on commit a66f9df

Please sign in to comment.