Skip to content

Commit

Permalink
fix: prevent setDatabase starvation which leads to no updates of curr…
Browse files Browse the repository at this point in the history
…ently loaded database
  • Loading branch information
Taepper committed Oct 14, 2024
1 parent 13f4362 commit 1898c73
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 44 deletions.
20 changes: 7 additions & 13 deletions include/silo_api/database_mutex.h
Original file line number Diff line number Diff line change
@@ -1,36 +1,30 @@
#pragma once

#include <shared_mutex>
#include <memory>

#include "silo/database.h"

namespace silo_api {

class FixedDatabase {
std::shared_lock<std::shared_mutex> lock;

public:
FixedDatabase(const silo::Database& database, std::shared_lock<std::shared_mutex>&& mutex);

const silo::Database& database;
};

class UninitializedDatabaseException : public std::runtime_error {
public:
UninitializedDatabaseException()
: std::runtime_error("Database not initialized yet") {}
};

class DatabaseMutex {
std::shared_mutex mutex;
silo::Database database;
bool is_initialized = false;
std::shared_ptr<silo::Database> database;

public:
DatabaseMutex() = default;
DatabaseMutex(const DatabaseMutex& other) = delete;
DatabaseMutex(DatabaseMutex&& other) = delete;
DatabaseMutex& operator=(const DatabaseMutex& other) = delete;
DatabaseMutex& operator=(DatabaseMutex&& other) = delete;

void setDatabase(silo::Database&& new_database);

virtual FixedDatabase getDatabase();
virtual std::shared_ptr<silo::Database> getDatabase();
};
} // namespace silo_api
2 changes: 1 addition & 1 deletion src/silo_api/database_directory_watcher.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ void silo_api::DatabaseDirectoryWatcher::checkDirectoryForData(Poco::Timer& /*ti
{
try {
const auto current_data_version_timestamp =
database_mutex.getDatabase().database.getDataVersionTimestamp();
database_mutex.getDatabase()->getDataVersionTimestamp();
const auto most_recent_data_version_timestamp_found =
most_recent_database_state->second.getTimestamp();
if (current_data_version_timestamp >= most_recent_data_version_timestamp_found) {
Expand Down
19 changes: 6 additions & 13 deletions src/silo_api/database_mutex.cpp
Original file line number Diff line number Diff line change
@@ -1,31 +1,24 @@
#include "silo_api/database_mutex.h"

#include <mutex>
#include <atomic>
#include <utility>

#include "silo/database.h"

namespace silo_api {

silo_api::FixedDatabase::FixedDatabase(
const silo::Database& database,
std::shared_lock<std::shared_mutex>&& mutex
)
: lock(std::move(mutex)),
database(database) {}

void silo_api::DatabaseMutex::setDatabase(silo::Database&& new_database) {
const std::unique_lock lock(mutex);
database = std::move(new_database);
auto new_database_pointer = std::make_shared<silo::Database>(std::move(new_database));

std::atomic_store(&database, new_database_pointer);
is_initialized = true;
}

silo_api::FixedDatabase silo_api::DatabaseMutex::getDatabase() {
std::shared_ptr<silo::Database> silo_api::DatabaseMutex::getDatabase() {
if (!is_initialized) {
throw silo_api::UninitializedDatabaseException();
}
std::shared_lock<std::shared_mutex> lock(mutex);
return {database, std::move(lock)};
return std::atomic_load(&database);
}

} // namespace silo_api
8 changes: 4 additions & 4 deletions src/silo_api/info_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,13 +99,13 @@ void InfoHandler::get(

const auto fixed_database = database.getDatabase();

response.set("data-version", fixed_database.database.getDataVersionTimestamp().value);
response.set("data-version", fixed_database->getDataVersionTimestamp().value);

const bool return_detailed_info = request_parameter.find("details") != request_parameter.end() &&
request_parameter.at("details") == "true";
const nlohmann::json database_info =
return_detailed_info ? nlohmann::json(database.getDatabase().database.detailedDatabaseInfo())
: nlohmann::json(database.getDatabase().database.getDatabaseInfo());
const nlohmann::json database_info = return_detailed_info
? nlohmann::json(fixed_database->detailedDatabaseInfo())
: nlohmann::json(fixed_database->getDatabaseInfo());
response.setContentType("application/json");
std::ostream& out_stream = response.send();
out_stream << database_info;
Expand Down
4 changes: 2 additions & 2 deletions src/silo_api/query_handler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ void QueryHandler::post(
try {
const auto fixed_database = database_mutex.getDatabase();

auto query_result = fixed_database.database.executeQuery(query);
auto query_result = fixed_database->executeQuery(query);

response.set("data-version", fixed_database.database.getDataVersionTimestamp().value);
response.set("data-version", fixed_database->getDataVersionTimestamp().value);

response.setContentType("application/x-ndjson");
std::ostream& out_stream = response.send();
Expand Down
20 changes: 9 additions & 11 deletions src/silo_api/request_handler_factory.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,10 @@ class MockDatabase : public silo::Database {

class MockDatabaseMutex : public silo_api::DatabaseMutex {
public:
std::shared_mutex mutex;
MockDatabase mock_database;
std::shared_ptr<MockDatabase> mock_database = std::make_shared<MockDatabase>();

silo_api::FixedDatabase getDatabase() override {
std::shared_lock<std::shared_mutex> lock(mutex);
return {mock_database, std::move(lock)};
std::shared_ptr<silo::Database> getDatabase() override {
return mock_database;
}
};

Expand Down Expand Up @@ -68,11 +66,11 @@ const int FOUR_MINUTES_IN_SECONDS = 240;
} // namespace

TEST_F(RequestHandlerTestFixture, handlesGetInfoRequest) {
EXPECT_CALL(database_mutex.mock_database, getDatabaseInfo)
EXPECT_CALL(*database_mutex.mock_database, getDatabaseInfo)
.WillRepeatedly(testing::Return(
silo::DatabaseInfo{.sequence_count = 1, .total_size = 2, .n_bitmaps_size = 3}
));
EXPECT_CALL(database_mutex.mock_database, getDataVersionTimestamp)
EXPECT_CALL(*database_mutex.mock_database, getDataVersionTimestamp)
.WillRepeatedly(testing::Return(silo::DataVersion::Timestamp::fromString("1234").value()));

request.setURI("/info");
Expand Down Expand Up @@ -103,9 +101,9 @@ TEST_F(RequestHandlerTestFixture, handlesGetInfoRequestDetails) {

const silo::DetailedDatabaseInfo detailed_database_info = {{{"main", stats}}};

EXPECT_CALL(database_mutex.mock_database, detailedDatabaseInfo)
EXPECT_CALL(*database_mutex.mock_database, detailedDatabaseInfo)
.WillRepeatedly(testing::Return(detailed_database_info));
EXPECT_CALL(database_mutex.mock_database, getDataVersionTimestamp)
EXPECT_CALL(*database_mutex.mock_database, getDataVersionTimestamp)
.WillRepeatedly(testing::Return(silo::DataVersion::Timestamp::fromString("1234").value()));

request.setURI("/info?details=true");
Expand Down Expand Up @@ -144,8 +142,8 @@ TEST_F(RequestHandlerTestFixture, handlesPostQueryRequest) {
};
std::vector<silo::query_engine::QueryResultEntry> tmp{{fields1}, {fields2}};
auto query_result = silo::query_engine::QueryResult::fromVector(std::move(tmp));
EXPECT_CALL(database_mutex.mock_database, executeQuery).WillOnce(testing::Return(query_result));
EXPECT_CALL(database_mutex.mock_database, getDataVersionTimestamp)
EXPECT_CALL(*database_mutex.mock_database, executeQuery).WillOnce(testing::Return(query_result));
EXPECT_CALL(*database_mutex.mock_database, getDataVersionTimestamp)
.WillOnce(testing::Return(silo::DataVersion::Timestamp::fromString("1234").value()));

request.setMethod("POST");
Expand Down

0 comments on commit 1898c73

Please sign in to comment.