Skip to content

Commit

Permalink
porting protocol changes btween 54455 and 54459 (#504)
Browse files Browse the repository at this point in the history
  • Loading branch information
zliang-min authored Jan 19, 2024
1 parent 1bc83df commit d7fe1d5
Show file tree
Hide file tree
Showing 62 changed files with 708 additions and 263 deletions.
28 changes: 24 additions & 4 deletions programs/benchmark/Benchmark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ class Benchmark : public Poco::Util::Application
Benchmark(unsigned concurrency_, double delay_,
Strings && hosts_, Ports && ports_, bool round_robin_,
bool cumulative_, bool secure_, const String & default_database_,
const String & user_, const String & password_, const String & stage,
const String & user_, const String & password_, const String & quota_key_, const String & stage,
bool randomize_, size_t max_iterations_, double max_time_,
const String & json_path_, size_t confidence_,
const String & query_id_, const String & query_to_execute_, bool continue_on_errors_,
Expand Down Expand Up @@ -90,7 +90,7 @@ class Benchmark : public Poco::Util::Application
connections.emplace_back(std::make_unique<ConnectionPool>(
concurrency,
cur_host, cur_port,
default_database_, user_, password_,
default_database_, user_, password_, quota_key_,
/* cluster_= */ "",
/* cluster_secret_= */ "",
/* client_name_= */ "benchmark",
Expand Down Expand Up @@ -599,6 +599,24 @@ int mainBenchmark(int argc, char ** argv)
{
using boost::program_options::value;

/// Note: according to the standard, subsequent calls to getenv can mangle previous result.
/// So we copy the results to std::string.
std::optional<std::string> env_user_str;
std::optional<std::string> env_password_str;
std::optional<std::string> env_quota_key_str;

const char * env_user = getenv("TIMEPLUS_USER"); // NOLINT(concurrency-mt-unsafe)
if (env_user != nullptr)
env_user_str.emplace(std::string(env_user));

const char * env_password = getenv("TIMEPLUS_PASSWORD"); // NOLINT(concurrency-mt-unsafe)
if (env_password != nullptr)
env_password_str.emplace(std::string(env_password));

const char * env_quota_key = getenv("TIMEPLUS_QUOTA_KEY"); // NOLINT(concurrency-mt-unsafe)
if (env_quota_key != nullptr)
env_quota_key_str.emplace(std::string(env_quota_key));

boost::program_options::options_description desc = createOptionsDescription("Allowed options", getTerminalWidth());
desc.add_options()
("help", "produce help message")
Expand All @@ -615,8 +633,9 @@ int mainBenchmark(int argc, char ** argv)
("roundrobin", "Instead of comparing queries for different --host/--port just pick one random --host/--port for every query and send query to it.")
("cumulative", "prints cumulative data instead of data per interval")
("secure,s", "Use TLS connection")
("user", value<std::string>()->default_value("default"), "")
("password", value<std::string>()->default_value(""), "")
("user,u", value<std::string>()->default_value(env_user_str.value_or("default")), "")
("password", value<std::string>()->default_value(env_password_str.value_or("")), "")
("quota_key", value<std::string>()->default_value(env_quota_key_str.value_or("")), "")
("database", value<std::string>()->default_value("default"), "")
("stacktrace", "print stack traces of exceptions")
("confidence", value<size_t>()->default_value(5), "set the level of confidence for T-test [0=80%, 1=90%, 2=95%, 3=98%, 4=99%, 5=99.5%(default)")
Expand Down Expand Up @@ -665,6 +684,7 @@ int mainBenchmark(int argc, char ** argv)
options["database"].as<std::string>(),
options["user"].as<std::string>(),
options["password"].as<std::string>(),
options["quota_key"].as<std::string>(),
options["stage"].as<std::string>(),
options["randomize"].as<bool>(),
options["iterations"].as<size_t>(),
Expand Down
46 changes: 27 additions & 19 deletions programs/client/Client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,10 +309,34 @@ bool Client::executeMultiQuery(const String & all_queries_text)
}
}

void Client::showWarnings()
{
try
{
std::vector<String> messages = loadWarningMessages();
if (!messages.empty())
{
std::cout << "Warnings:" << std::endl;
for (const auto & message : messages)
std::cout << " * " << message << std::endl;
std::cout << std::endl;
}
}
catch (...)
{
/// Ignore exception
}
}

/// Make query to get all server warnings
std::vector<String> Client::loadWarningMessages()
{
/// Older server versions cannot execute the query loading warnings.
constexpr UInt64 min_server_revision_to_load_warnings = DBMS_MIN_PROTOCOL_VERSION_WITH_VIEW_IF_PERMITTED;

if (server_revision < min_server_revision_to_load_warnings)
return {};

std::vector<String> messages;
connection->sendQuery(connection_parameters.timeouts, "SELECT message FROM system.warnings SETTINGS _tp_internal_system_open_sesame=true",
{} /* query_parameters */,
Expand Down Expand Up @@ -413,25 +437,9 @@ try

connect();

/// Load Warnings at the beginning of connection
/// Show warnings at the beginning of connection.
if (is_interactive && !config().has("no-warnings"))
{
try
{
std::vector<String> messages = loadWarningMessages();
if (!messages.empty())
{
std::cout << "Warnings:" << std::endl;
for (const auto & message : messages)
std::cout << " * " << message << std::endl;
std::cout << std::endl;
}
}
catch (...)
{
/// Ignore exception
}
}
showWarnings();

if (is_interactive && !delayed_interactive)
{
Expand Down Expand Up @@ -526,7 +534,7 @@ void Client::connect()
}

server_version = toString(server_version_major) + "." + toString(server_version_minor) + "." + toString(server_version_patch);
load_suggestions = is_interactive && (server_revision >= Suggest::MIN_SERVER_REVISION && !config().getBool("disable_suggestion", false));
load_suggestions = is_interactive && (server_revision >= Suggest::MIN_SERVER_REVISION) && !config().getBool("disable_suggestion", false);

if (server_display_name = connection->getServerDisplayName(connection_parameters.timeouts); server_display_name.empty())
server_display_name = config().getString("host", "localhost");
Expand Down
1 change: 1 addition & 0 deletions programs/client/Client.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class Client : public ClientBase

private:
void printChangedSettings() const;
void showWarnings();
std::vector<String> loadWarningMessages();
};
}
2 changes: 1 addition & 1 deletion src/Access/AccessControl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class AccessControl::ContextAccessCache
auto x = cache.get(params);
if (x)
{
if ((*x)->getUser())
if ((*x)->tryGetUser())
return *x;
/// No user, probably the user has been dropped while it was in the cache.
cache.remove(params);
Expand Down
44 changes: 44 additions & 0 deletions src/Access/CachedAccessChecking.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include <Access/CachedAccessChecking.h>
#include <Access/ContextAccess.h>


namespace DB
{
CachedAccessChecking::CachedAccessChecking(const std::shared_ptr<const ContextAccess> & access_, AccessFlags access_flags_)
: CachedAccessChecking(access_, AccessRightsElement{access_flags_})
{
}

CachedAccessChecking::CachedAccessChecking(const std::shared_ptr<const ContextAccess> & access_, const AccessRightsElement & element_)
: access(access_), element(element_)
{
}

CachedAccessChecking::~CachedAccessChecking() = default;

bool CachedAccessChecking::checkAccess(bool throw_if_denied)
{
if (checked)
return result;
if (throw_if_denied)
{
try
{
access->checkAccess(element);
result = true;
}
catch (...)
{
result = false;
throw;
}
}
else
{
result = access->isGranted(element);
}
checked = true;
return result;
}

}
29 changes: 29 additions & 0 deletions src/Access/CachedAccessChecking.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once

#include <Access/Common/AccessRightsElement.h>
#include <memory>


namespace DB
{
class ContextAccess;

/// Checks if the current user has a specified access type granted,
/// and if it's checked another time later, it will just return the first result.
class CachedAccessChecking
{
public:
CachedAccessChecking(const std::shared_ptr<const ContextAccess> & access_, AccessFlags access_flags_);
CachedAccessChecking(const std::shared_ptr<const ContextAccess> & access_, const AccessRightsElement & element_);
~CachedAccessChecking();

bool checkAccess(bool throw_if_denied = true);

private:
const std::shared_ptr<const ContextAccess> access;
const AccessRightsElement element;
bool checked = false;
bool result = false;
};

}
51 changes: 49 additions & 2 deletions src/Access/ContextAccess.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <Access/EnabledQuota.h>
#include <Access/QuotaUsage.h>
#include <Access/User.h>
#include <Access/Role.h>
#include <Access/EnabledRolesInfo.h>
#include <Access/EnabledSettings.h>
#include <Access/SettingsProfilesInfo.h>
Expand All @@ -29,6 +30,7 @@ namespace ErrorCodes
extern const int QUERY_IS_PROHIBITED;
extern const int FUNCTION_NOT_ALLOWED;
extern const int UNKNOWN_USER;
extern const int LOGICAL_ERROR;
}


Expand Down Expand Up @@ -187,6 +189,7 @@ void ContextAccess::setUser(const UserPtr & user_) const
if (!user)
{
/// User has been dropped.
user_was_dropped = true;
subscription_for_user_change = {};
subscription_for_roles_changes = {};
access = nullptr;
Expand Down Expand Up @@ -261,6 +264,20 @@ void ContextAccess::calculateAccessRights() const


UserPtr ContextAccess::getUser() const
{
auto res = tryGetUser();

if (likely(res))
return res;

if (user_was_dropped)
throw Exception(ErrorCodes::UNKNOWN_USER, "User has been dropped");

throw Exception(ErrorCodes::LOGICAL_ERROR, "No user in current context, it's a bug");
}


UserPtr ContextAccess::tryGetUser() const
{
std::lock_guard lock{mutex};
return user;
Expand Down Expand Up @@ -397,7 +414,7 @@ bool ContextAccess::checkAccessImplHelper(const AccessFlags & flags, const Args
if (!flags || is_full_access)
return access_granted();

if (!getUser())
if (!tryGetUser())
return access_denied("User has been dropped", ErrorCodes::UNKNOWN_USER);

/// Access to temporary tables is controlled in an unusual way, not like normal tables.
Expand Down Expand Up @@ -587,7 +604,7 @@ bool ContextAccess::checkAdminOptionImplHelper(const Container & role_ids, const
throw Exception(getUserName() + ": " + msg, error_code);
};

if (!getUser())
if (!tryGetUser())
{
show_error("User has been dropped", ErrorCodes::UNKNOWN_USER);
return false;
Expand Down Expand Up @@ -677,4 +694,34 @@ void ContextAccess::checkAdminOption(const std::vector<UUID> & role_ids) const {
void ContextAccess::checkAdminOption(const std::vector<UUID> & role_ids, const Strings & names_of_roles) const { checkAdminOptionImpl<true>(role_ids, names_of_roles); }
void ContextAccess::checkAdminOption(const std::vector<UUID> & role_ids, const std::unordered_map<UUID, String> & names_of_roles) const { checkAdminOptionImpl<true>(role_ids, names_of_roles); }


void ContextAccess::checkGranteeIsAllowed(const UUID & grantee_id, const IAccessEntity & grantee) const
{
if (is_full_access)
return;

auto current_user = getUser();
if (!current_user->grantees.match(grantee_id))
throw Exception(grantee.formatTypeWithName() + " is not allowed as grantee", ErrorCodes::ACCESS_DENIED);
}

void ContextAccess::checkGranteesAreAllowed(const std::vector<UUID> & grantee_ids) const
{
if (is_full_access)
return;

auto current_user = getUser();
if (current_user->grantees == RolesOrUsersSet::AllTag{})
return;

for (const auto & id : grantee_ids)
{
auto entity = access_control->tryRead(id);
if (auto role_entity = typeid_cast<RolePtr>(entity))
checkGranteeIsAllowed(id, *role_entity);
else if (auto user_entity = typeid_cast<UserPtr>(entity))
checkGranteeIsAllowed(id, *user_entity);
}
}

}
11 changes: 10 additions & 1 deletion src/Access/ContextAccess.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ struct SettingsProfilesInfo;
class SettingsChanges;
class AccessControl;
class IAST;
struct IAccessEntity;
using ASTPtr = std::shared_ptr<IAST>;


Expand Down Expand Up @@ -69,8 +70,10 @@ class ContextAccess : public std::enable_shared_from_this<ContextAccess>
using Params = ContextAccessParams;
const Params & getParams() const { return params; }

/// Returns the current user. The function can return nullptr.
/// Returns the current user. Throws if user is nullptr.
UserPtr getUser() const;
/// Same as above, but can return nullptr.
UserPtr tryGetUser() const;
String getUserName() const;
std::optional<UUID> getUserID() const { return getParams().user_id; }

Expand Down Expand Up @@ -152,6 +155,11 @@ class ContextAccess : public std::enable_shared_from_this<ContextAccess>
bool hasAdminOption(const std::vector<UUID> & role_ids, const Strings & names_of_roles) const;
bool hasAdminOption(const std::vector<UUID> & role_ids, const std::unordered_map<UUID, String> & names_of_roles) const;

/// Checks if a grantee is allowed for the current user, throws an exception if not.
void checkGranteeIsAllowed(const UUID & grantee_id, const IAccessEntity & grantee) const;
/// Checks if grantees are allowed for the current user, throws an exception if not.
void checkGranteesAreAllowed(const std::vector<UUID> & grantee_ids) const;

/// Makes an instance of ContextAccess which provides full access to everything
/// without any limitations. This is used for the global context.
static std::shared_ptr<const ContextAccess> getFullAccess();
Expand Down Expand Up @@ -214,6 +222,7 @@ class ContextAccess : public std::enable_shared_from_this<ContextAccess>
mutable Poco::Logger * trace_log = nullptr;
mutable UserPtr user;
mutable String user_name;
mutable bool user_was_dropped = false;
mutable scope_guard subscription_for_user_change;
mutable std::shared_ptr<const EnabledRoles> enabled_roles;
mutable scope_guard subscription_for_roles_changes;
Expand Down
17 changes: 17 additions & 0 deletions src/Access/User.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
#include <Access/User.h>
#include <Core/Protocol.h>


namespace DB
{
namespace ErrorCodes
{
extern const int BAD_ARGUMENTS;
}

bool User::equal(const IAccessEntity & other) const
{
Expand All @@ -14,4 +19,16 @@ bool User::equal(const IAccessEntity & other) const
&& (settings == other_user.settings) && (grantees == other_user.grantees) && (default_database == other_user.default_database);
}

void User::setName(const String & name_)
{
/// Unfortunately, there is not way to distinguish USER_INTERSERVER_MARKER from actual username in native protocol,
/// so we have to ensure that no such user will appear.
/// Also it was possible to create a user with empty name for some reason.
if (name_.empty())
throw Exception(ErrorCodes::BAD_ARGUMENTS, "User name is empty");
if (name_ == USER_INTERSERVER_MARKER)
throw Exception(ErrorCodes::BAD_ARGUMENTS, "User name '{}' is reserved", USER_INTERSERVER_MARKER);
name = name_;
}

}
Loading

0 comments on commit d7fe1d5

Please sign in to comment.