diff --git a/include/cgimap/http.hpp b/include/cgimap/http.hpp index 7f7349ee..11d63024 100644 --- a/include/cgimap/http.hpp +++ b/include/cgimap/http.hpp @@ -157,7 +157,8 @@ class not_found : public exception { */ class bandwidth_limit_exceeded : public exception { public: - bandwidth_limit_exceeded(const std::string &message); + bandwidth_limit_exceeded(int retry_seconds); + int retry_seconds; }; /** diff --git a/include/cgimap/rate_limiter.hpp b/include/cgimap/rate_limiter.hpp index 12339340..9c3550a4 100644 --- a/include/cgimap/rate_limiter.hpp +++ b/include/cgimap/rate_limiter.hpp @@ -10,7 +10,7 @@ struct rate_limiter { // check if the key is below the rate limit. return true to indicate that it // is. - virtual bool check(const std::string &key, bool moderator) = 0; + virtual std::tuple check(const std::string &key, bool moderator) = 0; // update the limit for the key to say it has consumed this number of bytes. virtual void update(const std::string &key, int bytes, bool moderator) = 0; @@ -19,7 +19,7 @@ struct rate_limiter { struct null_rate_limiter : public rate_limiter { ~null_rate_limiter(); - bool check(const std::string &key, bool moderator); + std::tuple check(const std::string &key, bool moderator); void update(const std::string &key, int bytes, bool moderator); }; @@ -31,7 +31,7 @@ class memcached_rate_limiter */ memcached_rate_limiter(const boost::program_options::variables_map &options); ~memcached_rate_limiter(); - bool check(const std::string &key, bool moderator); + std::tuple check(const std::string &key, bool moderator); void update(const std::string &key, int bytes, bool moderator); private: diff --git a/src/http.cpp b/src/http.cpp index a3e6f896..6fc3d829 100644 --- a/src/http.cpp +++ b/src/http.cpp @@ -103,8 +103,8 @@ const char *precondition_failed::what() const noexcept { return fullstring.c_str payload_too_large::payload_too_large(const string &message) : exception(413, "Payload Too Large", message) {} -bandwidth_limit_exceeded::bandwidth_limit_exceeded(const string &message) - : exception(509, "Bandwidth Limit Exceeded", message) {} +bandwidth_limit_exceeded::bandwidth_limit_exceeded(int retry_seconds) + : exception(509, "Bandwidth Limit Exceeded", fmt::format("You have downloaded too much data. Please try again in {} seconds.", retry_seconds)), retry_seconds(retry_seconds) {} gone::gone(const string &message) : exception(410, "Gone", message) {} diff --git a/src/options.cpp b/src/options.cpp index 57c4db09..674dd571 100644 --- a/src/options.cpp +++ b/src/options.cpp @@ -147,8 +147,8 @@ void global_settings_via_options::set_oauth_10_support(const po::variables_map & void global_settings_via_options::set_ratelimiter_ratelimit(const po::variables_map &options) { if (options.count("ratelimit")) { auto parsed_bytes_per_sec = options["ratelimit"].as(); - if (parsed_bytes_per_sec < 0) - throw std::invalid_argument("ratelimit must be a positive number"); + if (parsed_bytes_per_sec <= 0) + throw std::invalid_argument("ratelimit must be greater than zero"); if (parsed_bytes_per_sec > 1024 * 1024 * 1024) throw std::invalid_argument("ratelimit must be 1GB or less"); m_ratelimiter_ratelimit = parsed_bytes_per_sec; @@ -156,8 +156,8 @@ void global_settings_via_options::set_ratelimiter_ratelimit(const po::variables_ if (options.count("moderator-ratelimit")) { auto parsed_bytes_per_sec = options["moderator-ratelimit"].as(); - if (parsed_bytes_per_sec < 0) - throw std::invalid_argument("moderator-ratelimit must be a positive number"); + if (parsed_bytes_per_sec <= 0) + throw std::invalid_argument("moderator-ratelimit must be greater than zero"); if (parsed_bytes_per_sec > 1024 * 1024 * 1024) throw std::invalid_argument("moderator-ratelimit must be 1GB or less"); m_moderator_ratelimiter_ratelimit = parsed_bytes_per_sec; @@ -167,8 +167,8 @@ void global_settings_via_options::set_ratelimiter_ratelimit(const po::variables_ void global_settings_via_options::set_ratelimiter_maxdebt(const po::variables_map &options) { if (options.count("maxdebt")) { auto parsed_max_bytes = options["maxdebt"].as(); - if (parsed_max_bytes < 0) - throw std::invalid_argument("maxdebt must be a positive number"); + if (parsed_max_bytes <= 0) + throw std::invalid_argument("maxdebt must be greater than zero"); if (parsed_max_bytes > 3500) throw std::invalid_argument("maxdebt (in MB) must be 3500 or less"); m_ratelimiter_maxdebt = parsed_max_bytes * 1024 * 1024; @@ -176,8 +176,8 @@ void global_settings_via_options::set_ratelimiter_maxdebt(const po::variables_ma if (options.count("moderator-maxdebt")) { auto parsed_max_bytes = options["moderator-maxdebt"].as(); - if (parsed_max_bytes < 0) - throw std::invalid_argument("moderator-maxdebt must be a positive number"); + if (parsed_max_bytes <= 0) + throw std::invalid_argument("moderator-maxdebt must be greater than zero"); if (parsed_max_bytes > 3500) throw std::invalid_argument("moderator-maxdebt (in MB) must be 3500 or less"); m_moderator_ratelimiter_maxdebt = parsed_max_bytes * 1024 * 1024; diff --git a/src/process_request.cpp b/src/process_request.cpp index 68c21ad5..ed943c93 100644 --- a/src/process_request.cpp +++ b/src/process_request.cpp @@ -69,6 +69,8 @@ void respond_404(const http::not_found &e, request &r) { .add_header("Content-Length", "0") .add_header("Cache-Control", "no-cache") .put(""); + + r.finish(); } void respond_401(const http::unauthorized &e, request &r) { @@ -140,8 +142,16 @@ void respond_error(const http::exception &e, request &r) { .add_header("Content-Type", "text/plain") .add_header("Content-Length", std::to_string(message.size())) .add_header("Error", message_error_header) - .add_header("Cache-Control", "no-cache") - .put(message); // output the message as well + .add_header("Cache-Control", "no-cache"); + + if (e.code() == 509) { + if (auto bandwidth_exception = dynamic_cast(&e)) { + r.add_header("Retry-After", + std::to_string(bandwidth_exception->retry_seconds)); + } + } + + r.put(message); // output the message as well } r.finish(); @@ -517,10 +527,13 @@ void process_request(request &req, rate_limiter &limiter, auto is_moderator = user_roles.count(osm_user_role_t::moderator) > 0; + bool exceeded_limit; + int retry_seconds; + std::tie(exceeded_limit, retry_seconds) = limiter.check(client_key, is_moderator); // check whether the client is being rate limited - if (!limiter.check(client_key, is_moderator)) { + if (exceeded_limit) { logger::message(fmt::format("Rate limiter rejected request from {}", client_key)); - throw http::bandwidth_limit_exceeded("You have downloaded too much data. Please try again later."); + throw http::bandwidth_limit_exceeded(retry_seconds); } auto start_time = std::chrono::high_resolution_clock::now(); diff --git a/src/rate_limiter.cpp b/src/rate_limiter.cpp index 7909e86b..56c46c7d 100644 --- a/src/rate_limiter.cpp +++ b/src/rate_limiter.cpp @@ -8,8 +8,8 @@ rate_limiter::~rate_limiter() = default; null_rate_limiter::~null_rate_limiter() = default; -bool null_rate_limiter::check(const std::string &, bool) { - return true; +std::tuple null_rate_limiter::check(const std::string &, bool) { + return std::make_tuple(false, 0); } void null_rate_limiter::update(const std::string &, int, bool) { @@ -45,7 +45,7 @@ memcached_rate_limiter::~memcached_rate_limiter() { memcached_free(ptr); } -bool memcached_rate_limiter::check(const std::string &key, bool moderator) { +std::tuple memcached_rate_limiter::check(const std::string &key, bool moderator) { uint32_t bytes_served = 0; std::string mc_key; state *sp; @@ -71,7 +71,12 @@ bool memcached_rate_limiter::check(const std::string &key, bool moderator) { } auto max_bytes = global_settings::get_ratelimiter_maxdebt(moderator); - return bytes_served < max_bytes; + if (bytes_served < max_bytes) { + return std::make_tuple(false, 0); + } else { + // + 1 to reverse effect of integer flooring seconds + return std::make_tuple(true, (bytes_served - max_bytes) / bytes_per_sec + 1); + } } void memcached_rate_limiter::update(const std::string &key, int bytes, bool moderator) { diff --git a/test/test_apidb_backend_oauth.cpp b/test/test_apidb_backend_oauth.cpp index 9a76b104..c5a00d74 100644 --- a/test/test_apidb_backend_oauth.cpp +++ b/test/test_apidb_backend_oauth.cpp @@ -304,9 +304,9 @@ struct recording_rate_limiter : public rate_limiter { ~recording_rate_limiter() = default; - bool check(const std::string &key, bool moderator) { + std::tuple check(const std::string &key, bool moderator) { m_keys_seen.insert(key); - return true; + return std::make_tuple(false, 0); } void update(const std::string &key, int bytes, bool moderator) { diff --git a/test/test_apidb_backend_oauth2.cpp b/test/test_apidb_backend_oauth2.cpp index 16c6fdb5..6a3812ec 100644 --- a/test/test_apidb_backend_oauth2.cpp +++ b/test/test_apidb_backend_oauth2.cpp @@ -241,9 +241,9 @@ struct recording_rate_limiter : public rate_limiter { ~recording_rate_limiter() = default; - bool check(const std::string &key, bool moderator) { + std::tuple check(const std::string &key, bool moderator) { m_keys_seen.insert(key); - return true; + return std::make_tuple(false, 0); } void update(const std::string &key, int bytes, bool moderator) {