Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify std::regex #436

Merged
merged 2 commits into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 35 additions & 37 deletions src/http.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -225,57 +225,58 @@ vector<pair<string, string> > parse_params(const string &p) {
return queryKVPairs;
}

std::unique_ptr<encoding> choose_encoding(const string &accept_encoding) {
vector<string> encodings;
std::unique_ptr<encoding> choose_encoding(const std::string &accept_encoding) {

static const std::regex regex1("\\s*([^()<>@,;:\\\\\"/[\\]\\\\?={} "
"\\t]+)\\s*;\\s*q\\s*=(\\d+(\\.\\d+)?)\\s*");
std::vector<std::string> encodings;

static const std::regex regex2( R"(\s*([^()<>@,;:\\"/[\]\\?={} \t]+)\s*)");
al::iter_split(encodings, accept_encoding, al::first_finder(", "));

al::split(encodings, accept_encoding, al::is_any_of(","));

float identity_quality = 0.001;
float identity_quality = 0.000;
float deflate_quality = 0.000;
float gzip_quality = 0.000;
float brotli_quality = 0.000;

for (const string &encoding : encodings) {
std::smatch what;
string name;
float quality;

if (std::regex_match(
encoding, what,
regex1)) {
name = what[1];
quality = std::atof(string(what[2]).c_str());
} else if (std::regex_match(
encoding, what,
regex2)) {
name = what[1];
// set default if header empty
if (encodings.empty())
encodings.push_back("*");

for (const auto &encoding : encodings) {

std::string name;
float quality = 0.0;

std::vector<std::string> what;

al::iter_split(what, encoding, al::first_finder(";q="));

if (what.size() == 2) {
float q = std::stof(what[1]);
if (q >= 0 && q <= 1) {
name = what[0];
quality = q;
}
}
else if (what.size() == 1) {
name = what[0];
quality = 1.0;
} else {
name = "";
quality = 0.0;
}

if (al::iequals(name, "identity")) {
if (name == "identity") {
identity_quality = quality;
} else if (al::iequals(name, "deflate")) {
} else if (name == "deflate") {
deflate_quality = quality;
} else if (al::iequals(name, "gzip")) {
} else if (name == "gzip") {
gzip_quality = quality;
} else if (al::iequals(name, "br")) {
} else if (name == "br") {
brotli_quality = quality;
} else if (al::iequals(name, "*")) {
} else if (name == "*") {
if (identity_quality == 0.000)
identity_quality = quality;
if (deflate_quality == 0.000)
deflate_quality = quality;
if (gzip_quality == 0.001)
if (gzip_quality == 0.000)
gzip_quality = quality;
if (brotli_quality == 0.001)
if (brotli_quality == 0.000)
brotli_quality = quality;
}
}
Expand All @@ -288,13 +289,10 @@ std::unique_ptr<encoding> choose_encoding(const string &accept_encoding) {
}
#endif
#ifdef HAVE_LIBZ
#ifdef ENABLE_DEFLATE
if (deflate_quality > 0.0 && deflate_quality >= gzip_quality &&
deflate_quality >= identity_quality) {
return std:make_unique<deflate>();
} else
#endif /* ENABLE_DEFLATE */
if (gzip_quality > 0.0 && gzip_quality >= identity_quality) {
return std::make_unique<deflate>();
} else if (gzip_quality > 0.0 && gzip_quality >= identity_quality) {
return std::make_unique<gzip>();
}
#endif /* HAVE_LIBZ */
Expand Down
55 changes: 42 additions & 13 deletions src/oauth2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,31 +33,60 @@ inline std::string sha256_hash(const std::string& s) {

namespace oauth2 {

[[nodiscard]] std::optional<osm_user_id_t> validate_bearer_token(const request &req, data_selection& selection, bool& allow_api_write)
{
bool is_valid_bearer_token_char(unsigned char c) {
// according to RFC 6750, section 2.1

switch (c) {
case 'a' ... 'z':
return true;
case 'A' ... 'Z':
return true;
case '0' ... '9':
return true;
case '-':
return true;
case '.':
return true;
case '_':
return true;
case '~':
return true;
case '+':
return true;
case '/':
return true;
case '=':
return true; // we ignore that this char should only occur at end
}

return false;
}

static const std::regex r(R"(Bearer ([A-Za-z0-9~_\-\.\+\/]+=*))"); // according to RFC 6750, section 2.1
bool has_forbidden_char(std::string_view str) {
return std::find_if(str.begin(), str.end(), [](unsigned char ch) {
return !is_valid_bearer_token_char(ch);
}) != str.end();
}

[[nodiscard]] std::optional<osm_user_id_t> validate_bearer_token(const request &req, data_selection& selection, bool& allow_api_write)
{
const char * auth_hdr = req.get_param ("HTTP_AUTHORIZATION");
if (auth_hdr == nullptr)
return std::nullopt;

const auto auth_header = std::string(auth_hdr);

std::smatch sm;

try {
if (!std::regex_match(auth_header, sm, r))
return std::nullopt;
// Auth header starts with Bearer?
if (auth_header.rfind("Bearer ", 0) == std::string::npos)
return std::nullopt;

if (sm.size() != 2)
return std::nullopt;
const auto bearer_token = auth_header.substr(7);

} catch (std::regex_error&) {
if (bearer_token.empty())
return std::nullopt;
}

const auto& bearer_token = sm[1];
if (has_forbidden_char(bearer_token))
return std::nullopt;

bool expired;
bool revoked;
Expand Down
2 changes: 1 addition & 1 deletion test/test_apidb_backend_oauth2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ struct recording_rate_limiter : public rate_limiter {
void add_common_headers(test_request& req)
{
req.set_header("HTTP_HOST", "www.openstreetmap.org");
req.set_header("HTTP_ACCEPT_ENCODING", "gzip;q=1.0,deflate;q=0.6,identity;q=0.3");
req.set_header("HTTP_ACCEPT_ENCODING", "gzip;q=1.0, deflate;q=0.6, identity;q=0.3");
req.set_header("HTTP_ACCEPT", "*/*");
req.set_header("HTTP_USER_AGENT", "OAuth gem v0.4.7");
req.set_header("HTTP_X_REQUEST_ID", "V-eaKX8AAQEAAF4UzHwAAAHt");
Expand Down
9 changes: 6 additions & 3 deletions test/test_http.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,14 +71,17 @@ TEST_CASE("http_check_parse_methods", "[http]") {
}

TEST_CASE("http_check_choose_encoding", "[http]") {
CHECK(http::choose_encoding("deflate, gzip;q=1.0, *;q=0.5")->name() == "gzip");
CHECK(http::choose_encoding("deflate, gzip;q=1.0, *;q=0.5")->name() == "deflate");
CHECK(http::choose_encoding("gzip;q=1.0, identity;q=0.8, *;q=0.1")->name() == "gzip");
CHECK(http::choose_encoding("identity;q=0.8, gzip;q=1.0, *;q=0.1")->name() == "gzip");
CHECK(http::choose_encoding("gzip")->name() == "gzip");
CHECK(http::choose_encoding("identity")->name() == "identity");
CHECK(http::choose_encoding("*")->name() == "identity");
CHECK(http::choose_encoding("deflate")->name() == "identity");
CHECK(http::choose_encoding("*")->name() == "br");
CHECK(http::choose_encoding("deflate")->name() == "deflate");
#if HAVE_BROTLI
CHECK(http::choose_encoding("gzip, deflate, br")->name() == "br");
CHECK(http::choose_encoding("zstd;q=1.0, deflate;q=0.8, br;q=0.9")->name() == "br");
CHECK(http::choose_encoding("zstd;q=1.0, unknown;q=0.8, br;q=0.9")->name() == "br");
CHECK(http::choose_encoding("gzip, deflate, br")->name() == "br");
#endif
}
27 changes: 25 additions & 2 deletions test/test_oauth2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,37 @@ TEST_CASE("test_validate_bearer_token", "[oauth2]") {
CHECK(allow_api_write);
}

SECTION("Test bearer token invalid format") {
SECTION("Test bearer token invalid format (invalid chars)") {
req.set_header("HTTP_AUTHORIZATION","Bearer 6!#c23.-;<<>>");
auto res = oauth2::validate_bearer_token(req, *sel, allow_api_write);
CHECK(res == std::optional<osm_user_id_t>{});
}

SECTION("Test invalid bearer token") {
SECTION("Test bearer token invalid format (extra space after bearer)") {
req.set_header("HTTP_AUTHORIZATION","Bearer abc");
auto res = oauth2::validate_bearer_token(req, *sel, allow_api_write);
CHECK(res == std::optional<osm_user_id_t>{});
}

SECTION("Test bearer token invalid format (lowercase Bearer)") {
req.set_header("HTTP_AUTHORIZATION","bearer abc");
auto res = oauth2::validate_bearer_token(req, *sel, allow_api_write);
CHECK(res == std::optional<osm_user_id_t>{});
}

SECTION("Test bearer token invalid format (trailing space after token)") {
req.set_header("HTTP_AUTHORIZATION","Bearer abcdefghijklm ");
auto res = oauth2::validate_bearer_token(req, *sel, allow_api_write);
CHECK(res == std::optional<osm_user_id_t>{});
}

SECTION("Test bearer token invalid format (missing tokan)") {
req.set_header("HTTP_AUTHORIZATION","Bearer ");
auto res = oauth2::validate_bearer_token(req, *sel, allow_api_write);
CHECK(res == std::optional<osm_user_id_t>{});
}

SECTION("Test invalid bearer token") {
req.set_header("HTTP_AUTHORIZATION","Bearer nFRBLFyNXPKY1fiTHAIfVsjQYkCD2KoRuH66upvueaQ");
REQUIRE_THROWS_MATCHES(static_cast<void>(oauth2::validate_bearer_token(req, *sel, allow_api_write)), http::unauthorized,
Catch::Message("invalid_token"));
Expand Down