From d7fe1d5a92dc5f86053ff92da0f4ef536bc5a85e Mon Sep 17 00:00:00 2001 From: Gimi Liang Date: Thu, 18 Jan 2024 19:54:44 -0800 Subject: [PATCH] porting protocol changes btween 54455 and 54459 (#504) --- programs/benchmark/Benchmark.cpp | 28 +++- programs/client/Client.cpp | 46 +++--- programs/client/Client.h | 1 + src/Access/AccessControl.cpp | 2 +- src/Access/CachedAccessChecking.cpp | 44 ++++++ src/Access/CachedAccessChecking.h | 29 ++++ src/Access/ContextAccess.cpp | 51 +++++- src/Access/ContextAccess.h | 11 +- src/Access/User.cpp | 17 ++ src/Access/User.h | 1 + src/Client/ClientBase.cpp | 12 +- src/Client/ClientBase.h | 2 +- src/Client/Connection.cpp | 20 ++- src/Client/Connection.h | 3 + src/Client/ConnectionParameters.cpp | 1 + src/Client/ConnectionParameters.h | 1 + src/Client/ConnectionPool.cpp | 4 +- src/Client/ConnectionPool.h | 7 +- src/Client/LocalConnection.cpp | 30 ++-- src/Client/LocalConnection.h | 2 +- src/Client/Suggest.h | 4 +- src/Core/ProtocolDefines.h | 9 +- .../ClickHouseDictionarySource.cpp | 6 +- src/Dictionaries/ClickHouseDictionarySource.h | 1 + .../Access/InterpreterCreateUserQuery.cpp | 5 +- .../Access/InterpreterGrantQuery.cpp | 36 +---- ...InterpreterShowCreateAccessEntityQuery.cpp | 3 +- .../Access/InterpreterShowGrantsQuery.cpp | 26 ++- src/Interpreters/Cluster.cpp | 7 +- src/Interpreters/Cluster.h | 1 + src/Interpreters/Context.cpp | 7 +- src/Interpreters/Context.h | 2 +- src/Interpreters/Session.cpp | 71 +++++---- src/Interpreters/Session.h | 4 +- src/Interpreters/SessionLog.cpp | 35 +++-- src/Interpreters/SessionLog.h | 12 +- src/Interpreters/SystemLog.cpp | 3 + src/Parsers/Access/ASTCreateUserQuery.cpp | 4 +- src/Parsers/Access/ASTCreateUserQuery.h | 2 +- src/Parsers/Access/ParserCreateUserQuery.cpp | 13 +- src/Parsers/Access/ParserUserNameWithHost.cpp | 3 +- src/QueryPipeline/RemoteInserter.cpp | 8 + src/Server/TCPHandler.cpp | 148 +++++++++++------- src/Server/TCPHandler.h | 9 +- src/Storages/Distributed/DirectoryMonitor.cpp | 1 + .../ExternalDataSourceConfiguration.cpp | 9 +- .../ExternalDataSourceConfiguration.h | 1 + src/Storages/StorageS3Cluster.cpp | 2 +- .../System/StorageSystemCurrentRoles.cpp | 2 - .../System/StorageSystemEnabledRoles.cpp | 2 - tests/config/config.d/clusters.xml | 15 ++ tests/config/config.d/session_log.xml | 7 + tests/config/install.sh | 1 + tests/integration/helpers/client.py | 78 +++++---- .../test_insert_profile_events.py | 42 +++++ .../test.py | 23 ++- .../integration/test_grant_and_revoke/test.py | 9 ++ tests/integration/test_role/test.py | 6 + ...t_INSERT_progress_profile_events.reference | 2 + ...e_client_INSERT_progress_profile_events.sh | 19 +++ ...l_INSERT_progress_profile_events.reference | 2 + ...se_local_INSERT_progress_profile_events.sh | 19 +++ 62 files changed, 708 insertions(+), 263 deletions(-) create mode 100644 src/Access/CachedAccessChecking.cpp create mode 100644 src/Access/CachedAccessChecking.h create mode 100644 tests/config/config.d/session_log.xml create mode 100644 tests/integration/test_backward_compatibility/test_insert_profile_events.py create mode 100755 tests/queries/0_stateless/02310_clickhouse_client_INSERT_progress_profile_events.sh create mode 100755 tests/queries/0_stateless/02310_clickhouse_local_INSERT_progress_profile_events.sh diff --git a/programs/benchmark/Benchmark.cpp b/programs/benchmark/Benchmark.cpp index a6bef7bbd23..710474fb1ac 100644 --- a/programs/benchmark/Benchmark.cpp +++ b/programs/benchmark/Benchmark.cpp @@ -61,7 +61,7 @@ class Benchmark : public Poco::Util::Application Benchmark(unsigned concurrency_, double delay_, Strings && hosts_, Ports && ports_, bool round_robin_, bool cumulative_, bool secure_, const String & default_database_, - const String & user_, const String & password_, const String & stage, + const String & user_, const String & password_, const String & quota_key_, const String & stage, bool randomize_, size_t max_iterations_, double max_time_, const String & json_path_, size_t confidence_, const String & query_id_, const String & query_to_execute_, bool continue_on_errors_, @@ -90,7 +90,7 @@ class Benchmark : public Poco::Util::Application connections.emplace_back(std::make_unique( concurrency, cur_host, cur_port, - default_database_, user_, password_, + default_database_, user_, password_, quota_key_, /* cluster_= */ "", /* cluster_secret_= */ "", /* client_name_= */ "benchmark", @@ -599,6 +599,24 @@ int mainBenchmark(int argc, char ** argv) { using boost::program_options::value; + /// Note: according to the standard, subsequent calls to getenv can mangle previous result. + /// So we copy the results to std::string. + std::optional env_user_str; + std::optional env_password_str; + std::optional env_quota_key_str; + + const char * env_user = getenv("TIMEPLUS_USER"); // NOLINT(concurrency-mt-unsafe) + if (env_user != nullptr) + env_user_str.emplace(std::string(env_user)); + + const char * env_password = getenv("TIMEPLUS_PASSWORD"); // NOLINT(concurrency-mt-unsafe) + if (env_password != nullptr) + env_password_str.emplace(std::string(env_password)); + + const char * env_quota_key = getenv("TIMEPLUS_QUOTA_KEY"); // NOLINT(concurrency-mt-unsafe) + if (env_quota_key != nullptr) + env_quota_key_str.emplace(std::string(env_quota_key)); + boost::program_options::options_description desc = createOptionsDescription("Allowed options", getTerminalWidth()); desc.add_options() ("help", "produce help message") @@ -615,8 +633,9 @@ int mainBenchmark(int argc, char ** argv) ("roundrobin", "Instead of comparing queries for different --host/--port just pick one random --host/--port for every query and send query to it.") ("cumulative", "prints cumulative data instead of data per interval") ("secure,s", "Use TLS connection") - ("user", value()->default_value("default"), "") - ("password", value()->default_value(""), "") + ("user,u", value()->default_value(env_user_str.value_or("default")), "") + ("password", value()->default_value(env_password_str.value_or("")), "") + ("quota_key", value()->default_value(env_quota_key_str.value_or("")), "") ("database", value()->default_value("default"), "") ("stacktrace", "print stack traces of exceptions") ("confidence", value()->default_value(5), "set the level of confidence for T-test [0=80%, 1=90%, 2=95%, 3=98%, 4=99%, 5=99.5%(default)") @@ -665,6 +684,7 @@ int mainBenchmark(int argc, char ** argv) options["database"].as(), options["user"].as(), options["password"].as(), + options["quota_key"].as(), options["stage"].as(), options["randomize"].as(), options["iterations"].as(), diff --git a/programs/client/Client.cpp b/programs/client/Client.cpp index aae909c022c..7358dec9285 100644 --- a/programs/client/Client.cpp +++ b/programs/client/Client.cpp @@ -309,10 +309,34 @@ bool Client::executeMultiQuery(const String & all_queries_text) } } +void Client::showWarnings() +{ + try + { + std::vector messages = loadWarningMessages(); + if (!messages.empty()) + { + std::cout << "Warnings:" << std::endl; + for (const auto & message : messages) + std::cout << " * " << message << std::endl; + std::cout << std::endl; + } + } + catch (...) + { + /// Ignore exception + } +} /// Make query to get all server warnings std::vector Client::loadWarningMessages() { + /// Older server versions cannot execute the query loading warnings. + constexpr UInt64 min_server_revision_to_load_warnings = DBMS_MIN_PROTOCOL_VERSION_WITH_VIEW_IF_PERMITTED; + + if (server_revision < min_server_revision_to_load_warnings) + return {}; + std::vector messages; connection->sendQuery(connection_parameters.timeouts, "SELECT message FROM system.warnings SETTINGS _tp_internal_system_open_sesame=true", {} /* query_parameters */, @@ -413,25 +437,9 @@ try connect(); - /// Load Warnings at the beginning of connection + /// Show warnings at the beginning of connection. if (is_interactive && !config().has("no-warnings")) - { - try - { - std::vector messages = loadWarningMessages(); - if (!messages.empty()) - { - std::cout << "Warnings:" << std::endl; - for (const auto & message : messages) - std::cout << " * " << message << std::endl; - std::cout << std::endl; - } - } - catch (...) - { - /// Ignore exception - } - } + showWarnings(); if (is_interactive && !delayed_interactive) { @@ -526,7 +534,7 @@ void Client::connect() } server_version = toString(server_version_major) + "." + toString(server_version_minor) + "." + toString(server_version_patch); - load_suggestions = is_interactive && (server_revision >= Suggest::MIN_SERVER_REVISION && !config().getBool("disable_suggestion", false)); + load_suggestions = is_interactive && (server_revision >= Suggest::MIN_SERVER_REVISION) && !config().getBool("disable_suggestion", false); if (server_display_name = connection->getServerDisplayName(connection_parameters.timeouts); server_display_name.empty()) server_display_name = config().getString("host", "localhost"); diff --git a/programs/client/Client.h b/programs/client/Client.h index 2def74ef3fc..8c710f623e6 100644 --- a/programs/client/Client.h +++ b/programs/client/Client.h @@ -31,6 +31,7 @@ class Client : public ClientBase private: void printChangedSettings() const; + void showWarnings(); std::vector loadWarningMessages(); }; } diff --git a/src/Access/AccessControl.cpp b/src/Access/AccessControl.cpp index 095da3f8cc8..445724d6964 100644 --- a/src/Access/AccessControl.cpp +++ b/src/Access/AccessControl.cpp @@ -71,7 +71,7 @@ class AccessControl::ContextAccessCache auto x = cache.get(params); if (x) { - if ((*x)->getUser()) + if ((*x)->tryGetUser()) return *x; /// No user, probably the user has been dropped while it was in the cache. cache.remove(params); diff --git a/src/Access/CachedAccessChecking.cpp b/src/Access/CachedAccessChecking.cpp new file mode 100644 index 00000000000..aa8ef6073d3 --- /dev/null +++ b/src/Access/CachedAccessChecking.cpp @@ -0,0 +1,44 @@ +#include +#include + + +namespace DB +{ +CachedAccessChecking::CachedAccessChecking(const std::shared_ptr & access_, AccessFlags access_flags_) + : CachedAccessChecking(access_, AccessRightsElement{access_flags_}) +{ +} + +CachedAccessChecking::CachedAccessChecking(const std::shared_ptr & access_, const AccessRightsElement & element_) + : access(access_), element(element_) +{ +} + +CachedAccessChecking::~CachedAccessChecking() = default; + +bool CachedAccessChecking::checkAccess(bool throw_if_denied) +{ + if (checked) + return result; + if (throw_if_denied) + { + try + { + access->checkAccess(element); + result = true; + } + catch (...) + { + result = false; + throw; + } + } + else + { + result = access->isGranted(element); + } + checked = true; + return result; +} + +} diff --git a/src/Access/CachedAccessChecking.h b/src/Access/CachedAccessChecking.h new file mode 100644 index 00000000000..e87c28dd823 --- /dev/null +++ b/src/Access/CachedAccessChecking.h @@ -0,0 +1,29 @@ +#pragma once + +#include +#include + + +namespace DB +{ +class ContextAccess; + +/// Checks if the current user has a specified access type granted, +/// and if it's checked another time later, it will just return the first result. +class CachedAccessChecking +{ +public: + CachedAccessChecking(const std::shared_ptr & access_, AccessFlags access_flags_); + CachedAccessChecking(const std::shared_ptr & access_, const AccessRightsElement & element_); + ~CachedAccessChecking(); + + bool checkAccess(bool throw_if_denied = true); + +private: + const std::shared_ptr access; + const AccessRightsElement element; + bool checked = false; + bool result = false; +}; + +} diff --git a/src/Access/ContextAccess.cpp b/src/Access/ContextAccess.cpp index 206d7d79fed..0ebb8f73cd4 100644 --- a/src/Access/ContextAccess.cpp +++ b/src/Access/ContextAccess.cpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include #include @@ -29,6 +30,7 @@ namespace ErrorCodes extern const int QUERY_IS_PROHIBITED; extern const int FUNCTION_NOT_ALLOWED; extern const int UNKNOWN_USER; + extern const int LOGICAL_ERROR; } @@ -187,6 +189,7 @@ void ContextAccess::setUser(const UserPtr & user_) const if (!user) { /// User has been dropped. + user_was_dropped = true; subscription_for_user_change = {}; subscription_for_roles_changes = {}; access = nullptr; @@ -261,6 +264,20 @@ void ContextAccess::calculateAccessRights() const UserPtr ContextAccess::getUser() const +{ + auto res = tryGetUser(); + + if (likely(res)) + return res; + + if (user_was_dropped) + throw Exception(ErrorCodes::UNKNOWN_USER, "User has been dropped"); + + throw Exception(ErrorCodes::LOGICAL_ERROR, "No user in current context, it's a bug"); +} + + +UserPtr ContextAccess::tryGetUser() const { std::lock_guard lock{mutex}; return user; @@ -397,7 +414,7 @@ bool ContextAccess::checkAccessImplHelper(const AccessFlags & flags, const Args if (!flags || is_full_access) return access_granted(); - if (!getUser()) + if (!tryGetUser()) return access_denied("User has been dropped", ErrorCodes::UNKNOWN_USER); /// Access to temporary tables is controlled in an unusual way, not like normal tables. @@ -587,7 +604,7 @@ bool ContextAccess::checkAdminOptionImplHelper(const Container & role_ids, const throw Exception(getUserName() + ": " + msg, error_code); }; - if (!getUser()) + if (!tryGetUser()) { show_error("User has been dropped", ErrorCodes::UNKNOWN_USER); return false; @@ -677,4 +694,34 @@ void ContextAccess::checkAdminOption(const std::vector & role_ids) const { void ContextAccess::checkAdminOption(const std::vector & role_ids, const Strings & names_of_roles) const { checkAdminOptionImpl(role_ids, names_of_roles); } void ContextAccess::checkAdminOption(const std::vector & role_ids, const std::unordered_map & names_of_roles) const { checkAdminOptionImpl(role_ids, names_of_roles); } + +void ContextAccess::checkGranteeIsAllowed(const UUID & grantee_id, const IAccessEntity & grantee) const +{ + if (is_full_access) + return; + + auto current_user = getUser(); + if (!current_user->grantees.match(grantee_id)) + throw Exception(grantee.formatTypeWithName() + " is not allowed as grantee", ErrorCodes::ACCESS_DENIED); +} + +void ContextAccess::checkGranteesAreAllowed(const std::vector & grantee_ids) const +{ + if (is_full_access) + return; + + auto current_user = getUser(); + if (current_user->grantees == RolesOrUsersSet::AllTag{}) + return; + + for (const auto & id : grantee_ids) + { + auto entity = access_control->tryRead(id); + if (auto role_entity = typeid_cast(entity)) + checkGranteeIsAllowed(id, *role_entity); + else if (auto user_entity = typeid_cast(entity)) + checkGranteeIsAllowed(id, *user_entity); + } +} + } diff --git a/src/Access/ContextAccess.h b/src/Access/ContextAccess.h index 1ee8890ac13..3160cc619a5 100644 --- a/src/Access/ContextAccess.h +++ b/src/Access/ContextAccess.h @@ -29,6 +29,7 @@ struct SettingsProfilesInfo; class SettingsChanges; class AccessControl; class IAST; +struct IAccessEntity; using ASTPtr = std::shared_ptr; @@ -69,8 +70,10 @@ class ContextAccess : public std::enable_shared_from_this using Params = ContextAccessParams; const Params & getParams() const { return params; } - /// Returns the current user. The function can return nullptr. + /// Returns the current user. Throws if user is nullptr. UserPtr getUser() const; + /// Same as above, but can return nullptr. + UserPtr tryGetUser() const; String getUserName() const; std::optional getUserID() const { return getParams().user_id; } @@ -152,6 +155,11 @@ class ContextAccess : public std::enable_shared_from_this bool hasAdminOption(const std::vector & role_ids, const Strings & names_of_roles) const; bool hasAdminOption(const std::vector & role_ids, const std::unordered_map & names_of_roles) const; + /// Checks if a grantee is allowed for the current user, throws an exception if not. + void checkGranteeIsAllowed(const UUID & grantee_id, const IAccessEntity & grantee) const; + /// Checks if grantees are allowed for the current user, throws an exception if not. + void checkGranteesAreAllowed(const std::vector & grantee_ids) const; + /// Makes an instance of ContextAccess which provides full access to everything /// without any limitations. This is used for the global context. static std::shared_ptr getFullAccess(); @@ -214,6 +222,7 @@ class ContextAccess : public std::enable_shared_from_this mutable Poco::Logger * trace_log = nullptr; mutable UserPtr user; mutable String user_name; + mutable bool user_was_dropped = false; mutable scope_guard subscription_for_user_change; mutable std::shared_ptr enabled_roles; mutable scope_guard subscription_for_roles_changes; diff --git a/src/Access/User.cpp b/src/Access/User.cpp index d7c7f5c7ada..0476242c504 100644 --- a/src/Access/User.cpp +++ b/src/Access/User.cpp @@ -1,8 +1,13 @@ #include +#include namespace DB { +namespace ErrorCodes +{ + extern const int BAD_ARGUMENTS; +} bool User::equal(const IAccessEntity & other) const { @@ -14,4 +19,16 @@ bool User::equal(const IAccessEntity & other) const && (settings == other_user.settings) && (grantees == other_user.grantees) && (default_database == other_user.default_database); } +void User::setName(const String & name_) +{ + /// Unfortunately, there is not way to distinguish USER_INTERSERVER_MARKER from actual username in native protocol, + /// so we have to ensure that no such user will appear. + /// Also it was possible to create a user with empty name for some reason. + if (name_.empty()) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "User name is empty"); + if (name_ == USER_INTERSERVER_MARKER) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "User name '{}' is reserved", USER_INTERSERVER_MARKER); + name = name_; +} + } diff --git a/src/Access/User.h b/src/Access/User.h index b9167d68f15..57a3b178acf 100644 --- a/src/Access/User.h +++ b/src/Access/User.h @@ -28,6 +28,7 @@ struct User : public IAccessEntity std::shared_ptr clone() const override { return cloneImpl(); } static constexpr const auto TYPE = AccessEntityType::USER; AccessEntityType getType() const override { return TYPE; } + void setName(const String & name_) override; }; using UserPtr = std::shared_ptr; diff --git a/src/Client/ClientBase.cpp b/src/Client/ClientBase.cpp index c8b3eac54a8..d649a193685 100644 --- a/src/Client/ClientBase.cpp +++ b/src/Client/ClientBase.cpp @@ -1340,7 +1340,7 @@ void ClientBase::sendDataFromPipe(Pipe&& pipe, ASTPtr parsed_query, bool have_mo } /// Check if server send Log packet - receiveLogs(parsed_query); + receiveLogsAndProfileEvents(parsed_query); /// Check if server send Exception packet auto packet_type = connection->checkPacket(0); @@ -1387,11 +1387,11 @@ void ClientBase::sendDataFromStdin(Block & sample, const ColumnsDescription & co } /// Process Log packets, used when inserting data by blocks -void ClientBase::receiveLogs(ASTPtr parsed_query) +void ClientBase::receiveLogsAndProfileEvents(ASTPtr parsed_query) { auto packet_type = connection->checkPacket(0); - while (packet_type && *packet_type == Protocol::Server::Log) + while (packet_type && (*packet_type == Protocol::Server::Log || *packet_type == Protocol::Server::ProfileEvents)) { receiveAndProcessPacket(parsed_query, false); packet_type = connection->checkPacket(0); @@ -1424,9 +1424,13 @@ bool ClientBase::receiveEndOfQuery() onProgress(packet.progress); break; + case Protocol::Server::ProfileEvents: + onProfileEvents(packet.block); + break; + default: throw NetException( - "Unexpected packet from server (expected Exception, EndOfStream or Log, got " + "Unexpected packet from server (expected Exception, EndOfStream, Log, Progress or ProfileEvents. Got " + String(Protocol::Server::toString(packet.type)) + ")", ErrorCodes::UNEXPECTED_PACKET_FROM_SERVER); } diff --git a/src/Client/ClientBase.h b/src/Client/ClientBase.h index 5557ab15963..60246d59f7f 100644 --- a/src/Client/ClientBase.h +++ b/src/Client/ClientBase.h @@ -119,7 +119,7 @@ class ClientBase : public Poco::Util::Application, public IHints<2, ClientBase> private: void receiveResult(ASTPtr parsed_query); bool receiveAndProcessPacket(ASTPtr parsed_query, bool cancelled_); - void receiveLogs(ASTPtr parsed_query); + void receiveLogsAndProfileEvents(ASTPtr parsed_query); bool receiveSampleBlock(Block & out, ColumnsDescription & columns_description, ASTPtr parsed_query); bool receiveEndOfQuery(); void cancelQuery(); diff --git a/src/Client/Connection.cpp b/src/Client/Connection.cpp index 0102c10cf44..c7d5d441203 100644 --- a/src/Client/Connection.cpp +++ b/src/Client/Connection.cpp @@ -62,6 +62,7 @@ Connection::~Connection() = default; Connection::Connection(const String & host_, UInt16 port_, const String & default_database_, const String & user_, const String & password_, + const String & quota_key_, const String & cluster_, const String & cluster_secret_, const String & client_name_, @@ -69,7 +70,7 @@ Connection::Connection(const String & host_, UInt16 port_, Protocol::Secure secure_, Poco::Timespan sync_request_timeout_) : host(host_), port(port_), default_database(default_database_) - , user(user_), password(password_) + , user(user_), password(password_), quota_key(quota_key_) , cluster(cluster_) , cluster_secret(cluster_secret_) , client_name(client_name_) @@ -146,6 +147,8 @@ void Connection::connect(const ConnectionTimeouts & timeouts) sendHello(); receiveHello(); + if (server_revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM) + sendAddendum(); LOG_TRACE(log_wrapper.get(), "Connected to {} server version {}.{}.{}.", server_name, server_version_major, server_version_minor, server_version_patch); @@ -242,6 +245,14 @@ void Connection::sendHello() } +void Connection::sendAddendum() +{ + if (server_revision >= DBMS_MIN_PROTOCOL_VERSION_WITH_QUOTA_KEY) + writeStringBinary(quota_key, *out); + out->next(); +} + + void Connection::receiveHello() { /// Receive hello packet. @@ -373,11 +384,10 @@ void Connection::sendClusterNameAndSalt() bool Connection::ping() { - // LOG_TRACE(log_wrapper.get(), "Ping"); - - TimeoutSetter timeout_setter(*socket, sync_request_timeout, true); try { + TimeoutSetter timeout_setter(*socket, sync_request_timeout, true); + UInt64 pong = 0; writeVarUInt(Protocol::Client::Ping, *out); out->next(); @@ -822,7 +832,6 @@ std::optional Connection::checkPacket(size_t timeout_microseconds) if (hasReadPendingData() || poll(timeout_microseconds)) { - // LOG_TRACE(log_wrapper.get(), "Receiving packet type"); UInt64 packet_type; readVarUInt(packet_type, *in); @@ -1068,6 +1077,7 @@ ServerConnectionPtr Connection::createConnection(const ConnectionParameters & pa parameters.default_database, parameters.user, parameters.password, + parameters.quota_key, "", /* cluster */ "", /* cluster_secret */ "client", diff --git a/src/Client/Connection.h b/src/Client/Connection.h index 7949fcd4e04..243420b0593 100644 --- a/src/Client/Connection.h +++ b/src/Client/Connection.h @@ -51,6 +51,7 @@ class Connection : public IServerConnection Connection(const String & host_, UInt16 port_, const String & default_database_, const String & user_, const String & password_, + const String & quota_key_, const String & cluster_, const String & cluster_secret_, const String & client_name_, @@ -160,6 +161,7 @@ class Connection : public IServerConnection String default_database; String user; String password; + String quota_key; /// For inter-server authorization String cluster; @@ -246,6 +248,7 @@ class Connection : public IServerConnection void connect(const ConnectionTimeouts & timeouts); void sendHello(); + void sendAddendum(); void receiveHello(); #if USE_SSL diff --git a/src/Client/ConnectionParameters.cpp b/src/Client/ConnectionParameters.cpp index ec5fa407ebd..175270eb6d9 100644 --- a/src/Client/ConnectionParameters.cpp +++ b/src/Client/ConnectionParameters.cpp @@ -57,6 +57,7 @@ ConnectionParameters::ConnectionParameters(const Poco::Util::AbstractConfigurati if (auto * result = readpassphrase(prompt.c_str(), buf, sizeof(buf), 0)) password = result; } + quota_key = config.getString("quota_key", ""); /// By default compression is disabled if address looks like localhost. compression = config.getBool("compression", !isLocalAddress(DNSResolver::instance().resolveHost(host))) diff --git a/src/Client/ConnectionParameters.h b/src/Client/ConnectionParameters.h index a169df8390a..1d3bf41e6f6 100644 --- a/src/Client/ConnectionParameters.h +++ b/src/Client/ConnectionParameters.h @@ -18,6 +18,7 @@ struct ConnectionParameters std::string default_database; std::string user; std::string password; + std::string quota_key; Protocol::Secure security = Protocol::Secure::Disable; Protocol::Compression compression = Protocol::Compression::Enable; ConnectionTimeouts timeouts; diff --git a/src/Client/ConnectionPool.cpp b/src/Client/ConnectionPool.cpp index 4ec87127318..8433b0833fa 100644 --- a/src/Client/ConnectionPool.cpp +++ b/src/Client/ConnectionPool.cpp @@ -12,6 +12,7 @@ ConnectionPoolPtr ConnectionPoolFactory::get( String default_database, String user, String password, + String quota_key, String cluster, String cluster_secret, String client_name, @@ -20,7 +21,7 @@ ConnectionPoolPtr ConnectionPoolFactory::get( Int64 priority) { Key key{ - max_connections, host, port, default_database, user, password, cluster, cluster_secret, client_name, compression, secure, priority}; + max_connections, host, port, default_database, user, password, quota_key, cluster, cluster_secret, client_name, compression, secure, priority}; std::lock_guard lock(mutex); auto [it, inserted] = pools.emplace(key, ConnectionPoolPtr{}); @@ -37,6 +38,7 @@ ConnectionPoolPtr ConnectionPoolFactory::get( default_database, user, password, + quota_key, cluster, cluster_secret, client_name, diff --git a/src/Client/ConnectionPool.h b/src/Client/ConnectionPool.h index e6296a94619..a5c4d2448b3 100644 --- a/src/Client/ConnectionPool.h +++ b/src/Client/ConnectionPool.h @@ -55,6 +55,7 @@ class ConnectionPool : public IConnectionPool, private PoolBase const String & default_database_, const String & user_, const String & password_, + const String & quota_key_, const String & cluster_, const String & cluster_secret_, const String & client_name_, @@ -68,6 +69,7 @@ class ConnectionPool : public IConnectionPool, private PoolBase default_database(default_database_), user(user_), password(password_), + quota_key(quota_key_), cluster(cluster_), cluster_secret(cluster_secret_), client_name(client_name_), @@ -113,7 +115,7 @@ class ConnectionPool : public IConnectionPool, private PoolBase { return std::make_shared( host, port, - default_database, user, password, + default_database, user, password, quota_key, cluster, cluster_secret, client_name, compression, secure); } @@ -124,6 +126,7 @@ class ConnectionPool : public IConnectionPool, private PoolBase String default_database; String user; String password; + String quota_key; /// For inter-server authorization String cluster; @@ -150,6 +153,7 @@ class ConnectionPoolFactory final : private boost::noncopyable String default_database; String user; String password; + String quota_key; String cluster; String cluster_secret; String client_name; @@ -172,6 +176,7 @@ class ConnectionPoolFactory final : private boost::noncopyable String default_database, String user, String password, + String quota_key, String cluster, String cluster_secret, String client_name, diff --git a/src/Client/LocalConnection.cpp b/src/Client/LocalConnection.cpp index 70e660676d0..c1a61a59650 100644 --- a/src/Client/LocalConnection.cpp +++ b/src/Client/LocalConnection.cpp @@ -17,6 +17,7 @@ namespace ErrorCodes extern const int UNKNOWN_PACKET_FROM_SERVER; extern const int UNKNOWN_EXCEPTION; extern const int NOT_IMPLEMENTED; + extern const int LOGICAL_ERROR; } LocalConnection::LocalConnection(ContextPtr context_, bool send_progress_, bool send_profile_events_, const String & server_display_name_) @@ -58,9 +59,13 @@ void LocalConnection::updateProgress(const Progress & value) state->progress.incrementPiecewiseAtomically(value); } -void LocalConnection::getProfileEvents(Block & block) +void LocalConnection::sendProfileEvents() { - ProfileEvents::getProfileEvents(server_display_name, state->profile_queue, block, last_sent_snapshots); + Block profile_block; + state->after_send_profile_events.restart(); + next_packet_type = Protocol::Server::ProfileEvents; + ProfileEvents::getProfileEvents(server_display_name, state->profile_queue, profile_block, last_sent_snapshots); + state->block.emplace(std::move(profile_block)); } void LocalConnection::sendQuery( @@ -188,13 +193,14 @@ void LocalConnection::sendData(const Block & block, const String &, bool) return; if (state->pushing_async_executor) - { state->pushing_async_executor->push(block); - } else if (state->pushing_executor) - { state->pushing_executor->push(block); - } + else + throw Exception("Unknown executor", ErrorCodes::LOGICAL_ERROR); + + if (send_profile_events) + sendProfileEvents(); } void LocalConnection::sendCancel() @@ -260,11 +266,7 @@ bool LocalConnection::poll(size_t) if (send_profile_events && (state->after_send_profile_events.elapsedMicroseconds() >= query_context->getSettingsRef().interactive_delay)) { - Block block; - state->after_send_profile_events.restart(); - next_packet_type = Protocol::Server::ProfileEvents; - getProfileEvents(block); - state->block.emplace(std::move(block)); + sendProfileEvents(); return true; } @@ -345,11 +347,7 @@ bool LocalConnection::poll(size_t) if (send_profile_events && state->executor) { - Block block; - state->after_send_profile_events.restart(); - next_packet_type = Protocol::Server::ProfileEvents; - getProfileEvents(block); - state->block.emplace(std::move(block)); + sendProfileEvents(); return true; } } diff --git a/src/Client/LocalConnection.h b/src/Client/LocalConnection.h index 6282af5a139..eeafd1f882f 100644 --- a/src/Client/LocalConnection.h +++ b/src/Client/LocalConnection.h @@ -143,7 +143,7 @@ class LocalConnection : public IServerConnection, WithContext void updateProgress(const Progress & value); - void getProfileEvents(Block & block); + void sendProfileEvents(); bool pollImpl(); diff --git a/src/Client/Suggest.h b/src/Client/Suggest.h index 52fcc5f849e..a1a929180eb 100644 --- a/src/Client/Suggest.h +++ b/src/Client/Suggest.h @@ -28,8 +28,8 @@ class Suggest : public LineReader::Suggest, boost::noncopyable template void load(ContextPtr context, const ConnectionParameters & connection_parameters, Int32 suggestion_limit); - /// Older server versions cannot execute the query above. - static constexpr int MIN_SERVER_REVISION = 54406; + /// Older server versions cannot execute the query loading suggestions. + static constexpr int MIN_SERVER_REVISION = DBMS_MIN_PROTOCOL_VERSION_WITH_VIEW_IF_PERMITTED; private: void fetch(IServerConnection & connection, const ConnectionTimeouts & timeouts, const std::string & query); diff --git a/src/Core/ProtocolDefines.h b/src/Core/ProtocolDefines.h index 63c990e8c5d..63f2cabadd3 100644 --- a/src/Core/ProtocolDefines.h +++ b/src/Core/ProtocolDefines.h @@ -8,7 +8,6 @@ #define DBMS_MIN_REVISION_WITH_SERVER_DISPLAY_NAME 54372 #define DBMS_MIN_REVISION_WITH_VERSION_PATCH 54401 //#define DBMS_MIN_REVISION_WITH_SERVER_LOGS 54406 -//#define DBMS_MIN_REVISION_WITH_CLIENT_SUPPORT_EMBEDDED_DATA 54415 /// Minimum revision with exactly the same set of aggregation methods and rules to select them. /// Two-level (bucketed) aggregation is incompatible if servers are inconsistent in these rules /// (keys will be placed in different buckets and result will not be fully aggregated). @@ -60,3 +59,11 @@ //#define DBMS_MIN_PROTOCOL_VERSION_WITH_INITIAL_QUERY_START_TIME 54449 #define DBMS_MIN_PROTOCOL_VERSION_WITH_PARAMETERS 54459 + +#define DBMS_MIN_PROTOCOL_VERSION_WITH_PROFILE_EVENTS_IN_INSERT 54456 + +#define DBMS_MIN_PROTOCOL_VERSION_WITH_VIEW_IF_PERMITTED 54457 + +#define DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM 54458 + +#define DBMS_MIN_PROTOCOL_VERSION_WITH_QUOTA_KEY 54458 diff --git a/src/Dictionaries/ClickHouseDictionarySource.cpp b/src/Dictionaries/ClickHouseDictionarySource.cpp index 0d5605d6333..abc01b1b379 100644 --- a/src/Dictionaries/ClickHouseDictionarySource.cpp +++ b/src/Dictionaries/ClickHouseDictionarySource.cpp @@ -29,7 +29,7 @@ namespace ErrorCodes } static const std::unordered_set dictionary_allowed_keys = { - "host", "port", "user", "password", "db", "database", "table", + "host", "port", "user", "password", "quota_key", "db", "database", "table", "update_field", "update_lag", "invalidate_query", "query", "where", "name", "secure"}; namespace @@ -54,6 +54,7 @@ namespace configuration.db, configuration.user, configuration.password, + configuration.quota_key, "", /* cluster */ "", /* cluster_secret */ "ClickHouseDictionarySource", @@ -237,6 +238,7 @@ void registerDictionarySourceClickHouse(DictionarySourceFactory & factory) std::string host = config.getString(settings_config_prefix + ".host", "localhost"); std::string user = config.getString(settings_config_prefix + ".user", "default"); std::string password = config.getString(settings_config_prefix + ".password", ""); + std::string quota_key = config.getString(settings_config_prefix + ".quota_key", ""); std::string db = config.getString(settings_config_prefix + ".db", default_database); std::string table = config.getString(settings_config_prefix + ".table", ""); UInt16 port = static_cast(config.getUInt(settings_config_prefix + ".port", default_port)); @@ -252,6 +254,7 @@ void registerDictionarySourceClickHouse(DictionarySourceFactory & factory) host = configuration.host; user = configuration.username; password = configuration.password; + quota_key = configuration.quota_key; db = configuration.database; table = configuration.table; port = configuration.port; @@ -261,6 +264,7 @@ void registerDictionarySourceClickHouse(DictionarySourceFactory & factory) .host = host, .user = user, .password = password, + .quota_key = quota_key, .db = db, .table = table, .query = config.getString(settings_config_prefix + ".query", ""), diff --git a/src/Dictionaries/ClickHouseDictionarySource.h b/src/Dictionaries/ClickHouseDictionarySource.h index 6f972793cc3..0b131819cba 100644 --- a/src/Dictionaries/ClickHouseDictionarySource.h +++ b/src/Dictionaries/ClickHouseDictionarySource.h @@ -23,6 +23,7 @@ class ClickHouseDictionarySource final : public IDictionarySource const std::string host; const std::string user; const std::string password; + const std::string quota_key; const std::string db; const std::string table; const std::string query; diff --git a/src/Interpreters/Access/InterpreterCreateUserQuery.cpp b/src/Interpreters/Access/InterpreterCreateUserQuery.cpp index cd534989261..a47dfa943d1 100644 --- a/src/Interpreters/Access/InterpreterCreateUserQuery.cpp +++ b/src/Interpreters/Access/InterpreterCreateUserQuery.cpp @@ -14,7 +14,6 @@ namespace DB namespace ErrorCodes { extern const int BAD_ARGUMENTS; - } namespace { @@ -30,8 +29,8 @@ namespace { if (override_name) user.setName(override_name->toString()); - else if (!query.new_name.empty()) - user.setName(query.new_name); + else if (query.new_name) + user.setName(*query.new_name); else if (query.names->size() == 1) user.setName(query.names->front()->toString()); diff --git a/src/Interpreters/Access/InterpreterGrantQuery.cpp b/src/Interpreters/Access/InterpreterGrantQuery.cpp index a63b278566d..9d9fe9cf729 100644 --- a/src/Interpreters/Access/InterpreterGrantQuery.cpp +++ b/src/Interpreters/Access/InterpreterGrantQuery.cpp @@ -16,7 +16,6 @@ namespace DB { namespace ErrorCodes { - extern const int ACCESS_DENIED; extern const int LOGICAL_ERROR; } @@ -112,31 +111,6 @@ namespace } } - /// Checks if a grantee is allowed for the current user, throws an exception if not. - void checkGranteeIsAllowed(const ContextAccess & current_user_access, const UUID & grantee_id, const IAccessEntity & grantee) - { - auto current_user = current_user_access.getUser(); - if (current_user && !current_user->grantees.match(grantee_id)) - throw Exception(grantee.formatTypeWithName() + " is not allowed as grantee", ErrorCodes::ACCESS_DENIED); - } - - /// Checks if grantees are allowed for the current user, throws an exception if not. - void checkGranteesAreAllowed(const AccessControl & access_control, const ContextAccess & current_user_access, const std::vector & grantee_ids) - { - auto current_user = current_user_access.getUser(); - if (!current_user || (current_user->grantees == RolesOrUsersSet::AllTag{})) - return; - - for (const auto & id : grantee_ids) - { - auto entity = access_control.tryRead(id); - if (auto role = typeid_cast(entity)) - checkGranteeIsAllowed(current_user_access, id, *role); - else if (auto user = typeid_cast(entity)) - checkGranteeIsAllowed(current_user_access, id, *user); - } - } - /// Checks if the current user has enough access rights granted with grant option to grant or revoke specified access rights. void checkGrantOption( const AccessControl & access_control, @@ -171,13 +145,13 @@ namespace if (auto role = typeid_cast(entity)) { if (need_check_grantees_are_allowed) - checkGranteeIsAllowed(current_user_access, id, *role); + current_user_access.checkGranteeIsAllowed(id, *role); all_granted_access.makeUnion(role->access); } else if (auto user = typeid_cast(entity)) { if (need_check_grantees_are_allowed) - checkGranteeIsAllowed(current_user_access, id, *user); + current_user_access.checkGranteeIsAllowed(id, *user); all_granted_access.makeUnion(user->access); } } @@ -244,13 +218,13 @@ namespace if (auto role = typeid_cast(entity)) { if (need_check_grantees_are_allowed) - checkGranteeIsAllowed(current_user_access, id, *role); + current_user_access.checkGranteeIsAllowed(id, *role); all_granted_roles.makeUnion(role->granted_roles); } else if (auto user = typeid_cast(entity)) { if (need_check_grantees_are_allowed) - checkGranteeIsAllowed(current_user_access, id, *user); + current_user_access.checkGranteeIsAllowed(id, *user); all_granted_roles.makeUnion(user->granted_roles); } } @@ -365,7 +339,7 @@ BlockIO InterpreterGrantQuery::execute() checkAdminOption(access_control, *current_user_access, grantees, need_check_grantees_are_allowed, roles_to_grant, roles_to_revoke, query.admin_option); if (need_check_grantees_are_allowed) - checkGranteesAreAllowed(access_control, *current_user_access, grantees); + current_user_access->checkGranteesAreAllowed(grantees); /// Update roles and users listed in `grantees`. auto update_func = [&](const AccessEntityPtr & entity) -> AccessEntityPtr diff --git a/src/Interpreters/Access/InterpreterShowCreateAccessEntityQuery.cpp b/src/Interpreters/Access/InterpreterShowCreateAccessEntityQuery.cpp index 163cb57cab5..662d0595397 100644 --- a/src/Interpreters/Access/InterpreterShowCreateAccessEntityQuery.cpp +++ b/src/Interpreters/Access/InterpreterShowCreateAccessEntityQuery.cpp @@ -296,8 +296,7 @@ std::vector InterpreterShowCreateAccessEntityQuery::getEntities } else if (show_query.current_user) { - if (auto user = getContext()->getUser()) - entities.push_back(user); + entities.push_back(getContext()->getUser()); } else if (show_query.current_quota) { diff --git a/src/Interpreters/Access/InterpreterShowGrantsQuery.cpp b/src/Interpreters/Access/InterpreterShowGrantsQuery.cpp index 1c9d6b08a4c..056e599e40e 100644 --- a/src/Interpreters/Access/InterpreterShowGrantsQuery.cpp +++ b/src/Interpreters/Access/InterpreterShowGrantsQuery.cpp @@ -4,6 +4,9 @@ #include #include #include +#include +#include +#include #include #include #include @@ -135,15 +138,34 @@ QueryPipeline InterpreterShowGrantsQuery::executeImpl() std::vector InterpreterShowGrantsQuery::getEntities() const { - const auto & show_query = query_ptr->as(); + const auto & access = getContext()->getAccess(); const auto & access_control = getContext()->getAccessControl(); + + const auto & show_query = query_ptr->as(); auto ids = RolesOrUsersSet{*show_query.for_roles, access_control, getContext()->getUserID()}.getMatchingIDs(access_control); + CachedAccessChecking show_users(access, AccessType::SHOW_USERS); + CachedAccessChecking show_roles(access, AccessType::SHOW_ROLES); + bool throw_if_access_denied = !show_query.for_roles->all; + + auto current_user = access->getUser(); + auto roles_info = access->getRolesInfo(); + std::vector entities; for (const auto & id : ids) { auto entity = access_control.tryRead(id); - if (entity) + if (!entity) + continue; + + bool is_current_user = (id == access->getUserID()); + bool is_enabled_or_granted_role = entity->isTypeOf() + && (current_user->granted_roles.isGranted(id) || roles_info->enabled_roles.contains(id)); + + if ((is_current_user /* Any user can see his own grants */) + || (is_enabled_or_granted_role /* and grants from the granted roles */) + || (entity->isTypeOf() && show_users.checkAccess(throw_if_access_denied)) + || (entity->isTypeOf() && show_roles.checkAccess(throw_if_access_denied))) entities.push_back(entity); } diff --git a/src/Interpreters/Cluster.cpp b/src/Interpreters/Cluster.cpp index 37479d0a187..ea8b39b0004 100644 --- a/src/Interpreters/Cluster.cpp +++ b/src/Interpreters/Cluster.cpp @@ -408,7 +408,7 @@ Cluster::Cluster(const Poco::Util::AbstractConfiguration & config, auto pool = ConnectionPoolFactory::instance().get( static_cast(settings.distributed_connections_pool_size), address.host_name, address.port, - address.default_database, address.user, address.password, + address.default_database, address.user, address.password, address.quota_key, address.cluster, address.cluster_secret, "server", address.compression, address.secure, address.priority); @@ -481,7 +481,7 @@ Cluster::Cluster(const Poco::Util::AbstractConfiguration & config, auto replica_pool = ConnectionPoolFactory::instance().get( static_cast(settings.distributed_connections_pool_size), replica.host_name, replica.port, - replica.default_database, replica.user, replica.password, + replica.default_database, replica.user, replica.password, replica.quota_key, replica.cluster, replica.cluster_secret, "server", replica.compression, replica.secure, replica.priority); @@ -560,7 +560,7 @@ Cluster::Cluster( auto replica_pool = ConnectionPoolFactory::instance().get( static_cast(settings.distributed_connections_pool_size), replica.host_name, replica.port, - replica.default_database, replica.user, replica.password, + replica.default_database, replica.user, replica.password, replica.quota_key, replica.cluster, replica.cluster_secret, "server", replica.compression, replica.secure, replica.priority); all_replicas.emplace_back(replica_pool); @@ -668,6 +668,7 @@ Cluster::Cluster(Cluster::ReplicasAsShardsTag, const Cluster & from, const Setti address.default_database, address.user, address.password, + address.quota_key, address.cluster, address.cluster_secret, "server", diff --git a/src/Interpreters/Cluster.h b/src/Interpreters/Cluster.h index 3773dadaf13..435b7b56c0e 100644 --- a/src/Interpreters/Cluster.h +++ b/src/Interpreters/Cluster.h @@ -90,6 +90,7 @@ class Cluster UInt16 port{0}; String user; String password; + String quota_key; /// For inter-server authorization String cluster; diff --git a/src/Interpreters/Context.cpp b/src/Interpreters/Context.cpp index 547e46d267c..0dca7726771 100644 --- a/src/Interpreters/Context.cpp +++ b/src/Interpreters/Context.cpp @@ -841,7 +841,7 @@ void Context::setUserScriptsPath(const String & path) shared->user_scripts_path = path; } -void Context::addWarningMessage(const String & msg) +void Context::addWarningMessage(const String & msg) const { std::lock_guard lock(shared->mutex); shared->addWarningMessage(msg); @@ -905,6 +905,7 @@ void Context::setUser(const UUID & user_id_) user_id_, /* current_roles = */ {}, /* use_default_roles = */ true, settings, current_database, client_info); auto user = access->getUser(); + current_roles = std::make_shared>(user->granted_roles.findGranted(user->default_roles)); auto default_profile_info = access->getDefaultProfileInfo(); @@ -1021,7 +1022,7 @@ void Context::calculateAccessRightsWithLock(const std::lock_guardasynchronous_remote_fs_reader = std::make_unique(pool_size, queue_size); - + queue_size = config.getUInt(".threadpool_local_fs_reader_queue_size", 1000000); shared->asynchronous_local_fs_reader = std::make_unique(pool_size, queue_size); diff --git a/src/Interpreters/Context.h b/src/Interpreters/Context.h index 409cdd1b792..c39adc00c22 100644 --- a/src/Interpreters/Context.h +++ b/src/Interpreters/Context.h @@ -503,7 +503,7 @@ class Context: public ContextData, public std::enable_shared_from_this void setDictionariesLibPath(const String & path); void setUserScriptsPath(const String & path); - void addWarningMessage(const String & msg); + void addWarningMessage(const String & msg) const; VolumePtr setTemporaryStorage(const String & path, const String & policy_name = ""); diff --git a/src/Interpreters/Session.cpp b/src/Interpreters/Session.cpp index 2b06c3c79b6..d0fd3dfbcf5 100644 --- a/src/Interpreters/Session.cpp +++ b/src/Interpreters/Session.cpp @@ -10,6 +10,7 @@ #include #include #include +#include #include @@ -246,7 +247,6 @@ void Session::shutdownNamedSessions() Session::Session(const ContextPtr & global_context_, ClientInfo::Interface interface_) : auth_id(UUIDHelpers::generateV4()), global_context(global_context_), - interface(interface_), log(&Poco::Logger::get(String{magic_enum::enum_name(interface_)} + "-Session")) { prepared_client_info.emplace(); @@ -255,10 +255,9 @@ Session::Session(const ContextPtr & global_context_, ClientInfo::Interface inter Session::~Session() { - LOG_DEBUG(log, "{} Destroying {} of user {}", + LOG_DEBUG(log, "{} Destroying {}", toString(auth_id), - (named_session ? "named session '" + named_session->key.second + "'" : "unnamed session"), - (user_id ? toString(*user_id) : "") + (named_session ? "named session '" + named_session->key.second + "'" : "unnamed session") ); /// Early release a NamedSessionData. @@ -267,8 +266,8 @@ Session::~Session() if (notified_session_log_about_login) { - if (auto session_log = getSessionLog(); session_log && user) - session_log->addLogOut(auth_id, user->getName(), getClientInfo()); + if (auto session_log = getSessionLog()) + session_log->addLogOut(auth_id, user, getClientInfo()); } } @@ -313,13 +312,11 @@ void Session::authenticate(const Credentials & credentials_, const Poco::Net::So { user_id = global_context->getAccessControl().authenticate(credentials_, address.host()); LOG_DEBUG(log, "{} Authenticated with global context as user {}", - toString(auth_id), user_id ? toString(*user_id) : ""); + toString(auth_id), toString(*user_id)); } catch (const Exception & e) { - LOG_DEBUG(log, "{} Authentication failed with error: {}", toString(auth_id), e.what()); - if (auto session_log = getSessionLog()) - session_log->addLoginFailure(auth_id, *prepared_client_info, credentials_.getUserName(), e); + onAuthenticationFailure(credentials_.getUserName(), address, e); throw; } @@ -327,8 +324,21 @@ void Session::authenticate(const Credentials & credentials_, const Poco::Net::So prepared_client_info->current_address = address; } +void Session::onAuthenticationFailure(const std::optional & user_name, const Poco::Net::SocketAddress & address_, const Exception & e) +{ + LOG_DEBUG(log, "{} Authentication failed with error: {}", toString(auth_id), e.what()); + if (auto session_log = getSessionLog()) + { + /// Add source address to the log + auto info_for_log = *prepared_client_info; + info_for_log.current_address = address_; + session_log->addLoginFailure(auth_id, info_for_log, user_name, e); + } +} + ClientInfo & Session::getClientInfo() { + /// FIXME it may produce different info for LoginSuccess and the corresponding Logout entries in the session log return session_context ? session_context->getClientInfo() : *prepared_client_info; } @@ -343,9 +353,11 @@ ContextMutablePtr Session::makeSessionContext() throw Exception("Session context already exists", ErrorCodes::LOGICAL_ERROR); if (query_context_created) throw Exception("Session context must be created before any query context", ErrorCodes::LOGICAL_ERROR); + if (!user_id) + throw Exception("Session context must be created after authentication", ErrorCodes::LOGICAL_ERROR); LOG_DEBUG(log, "{} Creating session context with user_id: {}", - toString(auth_id), user_id ? toString(*user_id) : ""); + toString(auth_id), toString(*user_id)); /// Make a new session context. ContextMutablePtr new_session_context; new_session_context = Context::createCopy(global_context); @@ -373,9 +385,11 @@ ContextMutablePtr Session::makeSessionContext(const String & session_name_, std: throw Exception("Session context already exists", ErrorCodes::LOGICAL_ERROR); if (query_context_created) throw Exception("Session context must be created before any query context", ErrorCodes::LOGICAL_ERROR); + if (!user_id) + throw Exception("Session context must be created after authentication", ErrorCodes::LOGICAL_ERROR); LOG_DEBUG(log, "{} Creating named session context with name: {}, user_id: {}", - toString(auth_id), session_name_, user_id ? toString(*user_id) : ""); + toString(auth_id), session_name_, toString(*user_id)); /// Make a new session context OR /// if the `session_id` and `user_id` were used before then just get a previously created session context. @@ -395,7 +409,7 @@ ContextMutablePtr Session::makeSessionContext(const String & session_name_, std: prepared_client_info.reset(); /// Set user information for the new context: current profiles, roles, access rights. - if (user_id && !new_session_context->getUser()) + if (user_id && !new_session_context->getAccess()->tryGetUser()) new_session_context->setUser(*user_id); /// Session context is ready. @@ -419,11 +433,6 @@ ContextMutablePtr Session::makeQueryContext(ClientInfo && query_client_info) con std::shared_ptr Session::getSessionLog() const { - /// For the LOCAL interface we don't send events to the session log - /// because the LOCAL interface is internal, it does nothing with networking. - if (interface == ClientInfo::Interface::LOCAL) - return nullptr; - // take it from global context, since it outlives the Session and always available. // please note that server may have session_log disabled, hence this may return nullptr. return global_context->getSessionLog(); @@ -431,6 +440,9 @@ std::shared_ptr Session::getSessionLog() const ContextMutablePtr Session::makeQueryContextImpl(const ClientInfo * client_info_to_copy, ClientInfo * client_info_to_move) const { + if (!user_id && getClientInfo().interface != ClientInfo::Interface::TCP_INTERSERVER) + throw Exception("Session context must be created after authentication", ErrorCodes::LOGICAL_ERROR); + /// We can create a query context either from a session context or from a global context. bool from_session_context = static_cast(session_context); @@ -438,11 +450,14 @@ ContextMutablePtr Session::makeQueryContextImpl(const ClientInfo * client_info_t ContextMutablePtr query_context = Context::createCopy(from_session_context ? session_context : global_context); query_context->makeQueryContext(); - LOG_DEBUG(log, "{} Creating query context from {} context, user_id: {}, parent context user: {}", - toString(auth_id), - from_session_context ? "session" : "global", - user_id ? toString(*user_id) : "", - query_context->getUser() ? query_context->getUser()->getName() : ""); + if (auto query_context_user = query_context->getAccess()->tryGetUser()) + { + LOG_DEBUG(log, "{} Creating query context from {} context, user_id: {}, parent context user: {}", + toString(auth_id), + from_session_context ? "session" : "global", + toString(*user_id), + query_context_user->getName()); + } /// Copy the specified client info to the new query context. auto & res_client_info = query_context->getClientInfo(); @@ -472,21 +487,23 @@ ContextMutablePtr Session::makeQueryContextImpl(const ClientInfo * client_info_t query_context->enableRowPoliciesOfInitialUser(); /// Set user information for the new context: current profiles, roles, access rights. - if (user_id && !query_context->getUser()) + if (user_id && !query_context->getAccess()->tryGetUser()) query_context->setUser(*user_id); /// Query context is ready. query_context_created = true; - user = query_context->getUser(); + if (user_id) + user = query_context->getUser(); if (!notified_session_log_about_login) { - if (auto session_log = getSessionLog(); user && user_id && session_log) + if (auto session_log = getSessionLog()) { session_log->addLoginSuccess( auth_id, named_session ? std::optional(named_session->key.second) : std::nullopt, - *query_context); + *query_context, + user); notified_session_log_about_login = true; } diff --git a/src/Interpreters/Session.h b/src/Interpreters/Session.h index f937c73d1a8..e0cf7ed7cdf 100644 --- a/src/Interpreters/Session.h +++ b/src/Interpreters/Session.h @@ -51,6 +51,9 @@ class Session void authenticate(const String & user_name, const String & password, const Poco::Net::SocketAddress & address); void authenticate(const Credentials & credentials_, const Poco::Net::SocketAddress & address_); + /// Writes a row about login failure into session log (if enabled) + void onAuthenticationFailure(const std::optional & user_name, const Poco::Net::SocketAddress & address_, const Exception & e); + /// Returns a reference to session ClientInfo. ClientInfo & getClientInfo(); const ClientInfo & getClientInfo() const; @@ -79,7 +82,6 @@ class Session mutable bool notified_session_log_about_login = false; const UUID auth_id; const ContextPtr global_context; - const ClientInfo::Interface interface; /// ClientInfo that will be copied to a session context when it's created. std::optional prepared_client_info; diff --git a/src/Interpreters/SessionLog.cpp b/src/Interpreters/SessionLog.cpp index 9b852a6fa41..6bfb98b9b32 100644 --- a/src/Interpreters/SessionLog.cpp +++ b/src/Interpreters/SessionLog.cpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include #include @@ -92,7 +93,7 @@ NamesAndTypesList SessionLogElement::getNamesAndTypes() AUTH_TYPE_NAME_AND_VALUE(AuthType::SHA256_PASSWORD), AUTH_TYPE_NAME_AND_VALUE(AuthType::DOUBLE_SHA1_PASSWORD), AUTH_TYPE_NAME_AND_VALUE(AuthType::LDAP), - AUTH_TYPE_NAME_AND_VALUE(AuthType::KERBEROS) + AUTH_TYPE_NAME_AND_VALUE(AuthType::KERBEROS), }); #undef AUTH_TYPE_NAME_AND_VALUE @@ -126,8 +127,8 @@ NamesAndTypesList SessionLogElement::getNamesAndTypes() {"event_time", std::make_shared()}, {"event_time_microseconds", std::make_shared(6)}, - {"user", std::make_shared()}, - {"auth_type", std::move(identified_with_column)}, + {"user", std::make_shared(std::make_shared())}, + {"auth_type", std::make_shared(std::move(identified_with_column))}, {"profiles", std::make_shared(lc_string_datatype)}, {"roles", std::make_shared(lc_string_datatype)}, @@ -162,8 +163,9 @@ void SessionLogElement::appendToBlock(MutableColumns & columns) const columns[i++]->insert(event_time); columns[i++]->insert(event_time_microseconds); - columns[i++]->insert(user); - columns[i++]->insert(user_identified_with); + assert((user && user_identified_with) || client_info.interface == ClientInfo::Interface::TCP_INTERSERVER); + columns[i++]->insert(user ? Field(*user) : Field()); + columns[i++]->insert(user_identified_with ? Field(*user_identified_with) : Field()); fillColumnArray(profiles, *columns[i++]); fillColumnArray(roles, *columns[i++]); @@ -202,7 +204,7 @@ void SessionLogElement::appendToBlock(MutableColumns & columns) const columns[i++]->insertData(auth_failure_reason.data(), auth_failure_reason.length()); } -void SessionLog::addLoginSuccess(const UUID & auth_id, std::optional session_id, const Context & login_context) +void SessionLog::addLoginSuccess(const UUID & auth_id, std::optional session_id, const Context & login_context, const UserPtr & login_user) { const auto access = login_context.getAccess(); const auto & settings = login_context.getSettingsRef(); @@ -211,12 +213,12 @@ void SessionLog::addLoginSuccess(const UUID & auth_id, std::optional ses DB::SessionLogElement log_entry(auth_id, SESSION_LOGIN_SUCCESS); log_entry.client_info = client_info; + if (login_user) { - const auto user = access->getUser(); - log_entry.user = user->getName(); - log_entry.user_identified_with = user->auth_data.getType(); - log_entry.external_auth_server = user->auth_data.getLDAPServerName(); + log_entry.user = login_user->getName(); + log_entry.user_identified_with = login_user->auth_data.getType(); } + log_entry.external_auth_server = login_user ? login_user->auth_data.getLDAPServerName() : ""; if (session_id) log_entry.session_id = *session_id; @@ -225,7 +227,7 @@ void SessionLog::addLoginSuccess(const UUID & auth_id, std::optional ses log_entry.roles = roles_info->getCurrentRolesNames(); if (const auto profile_info = access->getDefaultProfileInfo()) - log_entry.profiles = profile_info->getProfileNames(); + log_entry.profiles = profile_info->getProfileNames(); for (const auto & s : settings.allChanged()) log_entry.settings.emplace_back(s.getName(), s.getValueString()); @@ -236,7 +238,7 @@ void SessionLog::addLoginSuccess(const UUID & auth_id, std::optional ses void SessionLog::addLoginFailure( const UUID & auth_id, const ClientInfo & info, - const String & user, + const std::optional & user, const Exception & reason) { SessionLogElement log_entry(auth_id, SESSION_LOGIN_FAILURE); @@ -249,10 +251,15 @@ void SessionLog::addLoginFailure( add(log_entry); } -void SessionLog::addLogOut(const UUID & auth_id, const String & user, const ClientInfo & client_info) +void SessionLog::addLogOut(const UUID & auth_id, const UserPtr & login_user, const ClientInfo & client_info) { auto log_entry = SessionLogElement(auth_id, SESSION_LOGOUT); - log_entry.user = user; + if (login_user) + { + log_entry.user = login_user->getName(); + log_entry.user_identified_with = login_user->auth_data.getType(); + } + log_entry.external_auth_server = login_user ? login_user->auth_data.getLDAPServerName() : ""; log_entry.client_info = client_info; add(log_entry); diff --git a/src/Interpreters/SessionLog.h b/src/Interpreters/SessionLog.h index 26f137565cb..9c671b2d04d 100644 --- a/src/Interpreters/SessionLog.h +++ b/src/Interpreters/SessionLog.h @@ -18,6 +18,8 @@ enum SessionLogElementType : int8_t }; class ContextAccess; +struct User; +using UserPtr = std::shared_ptr; /** A struct which will be inserted as row into session_log table. * @@ -44,8 +46,8 @@ struct SessionLogElement time_t event_time{}; Decimal64 event_time_microseconds{}; - String user; - AuthenticationType user_identified_with = AuthenticationType::NO_PASSWORD; + std::optional user; + std::optional user_identified_with; String external_auth_server; Strings roles; Strings profiles; @@ -69,9 +71,9 @@ class SessionLog : public SystemLog using SystemLog::SystemLog; public: - void addLoginSuccess(const UUID & auth_id, std::optional session_id, const Context & login_context); - void addLoginFailure(const UUID & auth_id, const ClientInfo & info, const String & user, const Exception & reason); - void addLogOut(const UUID & auth_id, const String & user, const ClientInfo & client_info); + void addLoginSuccess(const UUID & auth_id, std::optional session_id, const Context & login_context, const UserPtr & login_user); + void addLoginFailure(const UUID & auth_id, const ClientInfo & info, const std::optional & user, const Exception & reason); + void addLogOut(const UUID & auth_id, const UserPtr & login_user, const ClientInfo & client_info); }; } diff --git a/src/Interpreters/SystemLog.cpp b/src/Interpreters/SystemLog.cpp index 32784b86df1..2db50853345 100644 --- a/src/Interpreters/SystemLog.cpp +++ b/src/Interpreters/SystemLog.cpp @@ -201,7 +201,10 @@ SystemLogs::SystemLogs(ContextPtr global_context, const Poco::Util::AbstractConf if (zookeeper_log) logs.emplace_back(zookeeper_log.get()); if (session_log) + { logs.emplace_back(session_log.get()); + global_context->addWarningMessage("Table system.session_log is enabled. It's unreliable and may contain garbage. Do not use it for any kind of security monitoring."); + } if (processors_profile_log) logs.emplace_back(processors_profile_log.get()); if (cache_log) diff --git a/src/Parsers/Access/ASTCreateUserQuery.cpp b/src/Parsers/Access/ASTCreateUserQuery.cpp index 18030a5ed80..70122e50858 100644 --- a/src/Parsers/Access/ASTCreateUserQuery.cpp +++ b/src/Parsers/Access/ASTCreateUserQuery.cpp @@ -255,8 +255,8 @@ void ASTCreateUserQuery::formatImpl(const FormatSettings & format, FormatState & formatOnCluster(format); - if (!new_name.empty()) - formatRenameTo(new_name, format); + if (new_name) + formatRenameTo(*new_name, format); if (auth_data) formatAuthenticationData(*auth_data, show_password, format); diff --git a/src/Parsers/Access/ASTCreateUserQuery.h b/src/Parsers/Access/ASTCreateUserQuery.h index 92db71e8581..631d6c5eea9 100644 --- a/src/Parsers/Access/ASTCreateUserQuery.h +++ b/src/Parsers/Access/ASTCreateUserQuery.h @@ -42,7 +42,7 @@ class ASTCreateUserQuery : public IAST, public ASTQueryWithOnCluster bool or_replace = false; std::shared_ptr names; - String new_name; + std::optional new_name; std::optional auth_data; bool show_password = true; /// formatImpl() will show the password or hash. diff --git a/src/Parsers/Access/ParserCreateUserQuery.cpp b/src/Parsers/Access/ParserCreateUserQuery.cpp index 145364fb917..57e4ff06fab 100644 --- a/src/Parsers/Access/ParserCreateUserQuery.cpp +++ b/src/Parsers/Access/ParserCreateUserQuery.cpp @@ -22,14 +22,19 @@ namespace DB { namespace { - bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, String & new_name) + bool parseRenameTo(IParserBase::Pos & pos, Expected & expected, std::optional & new_name) { return IParserBase::wrapParseImpl(pos, [&] { if (!ParserKeyword{"RENAME TO"}.ignore(pos, expected)) return false; - return parseUserName(pos, expected, new_name); + String maybe_new_name; + if (!parseUserName(pos, expected, maybe_new_name)) + return false; + + new_name.emplace(std::move(maybe_new_name)); + return true; }); } @@ -359,7 +364,7 @@ bool ParserCreateUserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expec auto names = typeid_cast>(names_ast); auto names_ref = names->names; - String new_name; + std::optional new_name; std::optional auth_data; std::optional hosts; std::optional add_hosts; @@ -415,7 +420,7 @@ bool ParserCreateUserQuery::parseImpl(Pos & pos, ASTPtr & node, Expected & expec if (alter) { - if (new_name.empty() && (names->size() == 1) && parseRenameTo(pos, expected, new_name)) + if (!new_name && (names->size() == 1) && parseRenameTo(pos, expected, new_name)) continue; if (parseHosts(pos, expected, "ADD", new_hosts)) diff --git a/src/Parsers/Access/ParserUserNameWithHost.cpp b/src/Parsers/Access/ParserUserNameWithHost.cpp index 457629f0f76..ea5566becb9 100644 --- a/src/Parsers/Access/ParserUserNameWithHost.cpp +++ b/src/Parsers/Access/ParserUserNameWithHost.cpp @@ -8,6 +8,7 @@ namespace DB { + namespace { bool parseUserNameWithHost(IParserBase::Pos & pos, Expected & expected, std::shared_ptr & ast) @@ -18,8 +19,6 @@ namespace if (!parseIdentifierOrStringLiteral(pos, expected, base_name)) return false; - boost::algorithm::trim(base_name); - String host_pattern; if (ParserToken{TokenType::At}.ignore(pos, expected)) { diff --git a/src/QueryPipeline/RemoteInserter.cpp b/src/QueryPipeline/RemoteInserter.cpp index fee98beac06..1c00751b29a 100644 --- a/src/QueryPipeline/RemoteInserter.cpp +++ b/src/QueryPipeline/RemoteInserter.cpp @@ -60,6 +60,10 @@ RemoteInserter::RemoteInserter( if (auto log_queue = CurrentThread::getInternalTextLogsQueue()) log_queue->pushBlock(std::move(packet.block)); } + else if (Protocol::Server::ProfileEvents == packet.type) + { + // Do nothing + } else if (Protocol::Server::TableColumns == packet.type) { /// Server could attach ColumnsDescription in front of stream for column defaults. There's no need to pass it through cause @@ -120,6 +124,10 @@ void RemoteInserter::onFinish() { // Do nothing } + else if (Protocol::Server::ProfileEvents == packet.type) + { + // Do nothing + } else throw NetException( ErrorCodes::UNEXPECTED_PACKET_FROM_SERVER, diff --git a/src/Server/TCPHandler.cpp b/src/Server/TCPHandler.cpp index 07b1e53d245..1edb98d2788 100644 --- a/src/Server/TCPHandler.cpp +++ b/src/Server/TCPHandler.cpp @@ -91,9 +91,9 @@ namespace ErrorCodes extern const int POCO_EXCEPTION; extern const int SOCKET_TIMEOUT; extern const int UNEXPECTED_PACKET_FROM_CLIENT; - extern const int SUPPORT_IS_DISABLED; extern const int UNKNOWN_PROTOCOL; extern const int QUERY_WAS_CANCELLED; + extern const int AUTHENTICATION_FAILED; } TCPHandler::TCPHandler(IServer & server_, TCPServer & tcp_server_, const Poco::Net::StreamSocket & socket_, bool parse_proxy_protocol_, std::string server_display_name_, bool snapshot_mode_) @@ -126,7 +126,6 @@ void TCPHandler::runImpl() setThreadName("TCPHandler"); ThreadStatus thread_status; - session = std::make_unique(server.context(), ClientInfo::Interface::TCP); extractConnectionSettingsFromContext(server.context()); socket().setReceiveTimeout(receive_timeout); @@ -151,6 +150,8 @@ void TCPHandler::runImpl() { receiveHello(); sendHello(); + if (client_tcp_protocol_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_ADDENDUM) + receiveAddendum(); if (!is_interserver_mode) /// In interserver mode queries are executed without a session context. { @@ -377,7 +378,7 @@ void TCPHandler::runImpl() return true; sendProgress(); - sendProfileEvents(); + sendSelectProfileEvents(); sendLogs(); return false; @@ -540,6 +541,12 @@ void TCPHandler::runImpl() /// It is important to destroy query context here. We do not want it to live arbitrarily longer than the query. query_context.reset(); + if (is_interserver_mode) + { + /// We don't really have session in interserver mode, new one is created for each query. It's better to reset it now. + session.reset(); + } + if (network_error) break; } @@ -607,7 +614,10 @@ bool TCPHandler::readDataNext() } if (read_ok) + { sendLogs(); + sendInsertProfileEvents(); + } else state.read_all_data = true; @@ -680,6 +690,8 @@ void TCPHandler::processInsertQuery() PushingPipelineExecutor executor(state.io.pipeline); run_executor(executor); } + + sendInsertProfileEvents(); } @@ -728,7 +740,7 @@ void TCPHandler::processOrdinaryQueryWithProcessors() /// Some time passed and there is a progress. after_send_progress.restart(); sendProgress(); - sendProfileEvents(); + sendSelectProfileEvents(); } sendLogs(); @@ -754,7 +766,7 @@ void TCPHandler::processOrdinaryQueryWithProcessors() sendProfileInfo(executor.getProfileInfo()); sendProgress(); sendLogs(); - sendProfileEvents(); + sendSelectProfileEvents(); } if (state.is_connection_closed) @@ -773,7 +785,7 @@ void TCPHandler::processTablesStatusRequest() TablesStatusRequest request; request.read(*in, client_tcp_protocol_version); - ContextPtr context_to_resolve_table_names = session->sessionContext() ? session->sessionContext() : server.context(); + ContextPtr context_to_resolve_table_names = (session && session->sessionContext()) ? session->sessionContext() : server.context(); TablesStatusResponse response; for (const QualifiedTableName & table_name: request.tables) @@ -897,9 +909,6 @@ void TCPHandler::sendExtremes(const Block & extremes) void TCPHandler::sendProfileEvents() { - if (client_tcp_protocol_version < DBMS_MIN_PROTOCOL_VERSION_WITH_INCREMENTAL_PROFILE_EVENTS) - return; - Block block; ProfileEvents::getProfileEvents(server_display_name, state.profile_queue, block, last_sent_snapshots); if (block.rows() != 0) @@ -913,6 +922,21 @@ void TCPHandler::sendProfileEvents() } } +void TCPHandler::sendSelectProfileEvents() +{ + if (client_tcp_protocol_version < DBMS_MIN_PROTOCOL_VERSION_WITH_INCREMENTAL_PROFILE_EVENTS) + return; + + sendProfileEvents(); +} + +void TCPHandler::sendInsertProfileEvents() +{ + if (client_tcp_protocol_version < DBMS_MIN_PROTOCOL_VERSION_WITH_PROFILE_EVENTS_IN_INSERT) + return; + + sendProfileEvents(); +} bool TCPHandler::receiveProxyHeader() { @@ -981,7 +1005,7 @@ bool TCPHandler::receiveProxyHeader() } LOG_TRACE(log, "Forwarded client address from PROXY header: {}", forwarded_address); - session->getClientInfo().forwarded_for = forwarded_address; + forwarded_for = std::move(forwarded_address); return true; } @@ -1008,6 +1032,30 @@ std::string formatHTTPErrorResponseWhenUserIsConnectedToWrongPort(const Poco::Ut } +std::unique_ptr TCPHandler::makeSession() +{ + auto interface = is_interserver_mode ? ClientInfo::Interface::TCP_INTERSERVER : ClientInfo::Interface::TCP; + + auto res = std::make_unique(server.context(), interface); + + auto & client_info = res->getClientInfo(); + client_info.forwarded_for = forwarded_for; + client_info.client_name = client_name; + client_info.client_version_major = client_version_major; + client_info.client_version_minor = client_version_minor; + client_info.client_version_patch = client_version_patch; + client_info.client_tcp_protocol_version = client_tcp_protocol_version; + + client_info.connection_client_version_major = client_version_major; + client_info.connection_client_version_minor = client_version_minor; + client_info.connection_client_version_patch = client_version_patch; + client_info.connection_tcp_protocol_version = client_tcp_protocol_version; + + client_info.quota_key = quota_key; + client_info.interface = interface; + + return res; +} void TCPHandler::receiveHello() { @@ -1051,29 +1099,27 @@ void TCPHandler::receiveHello() (!user.empty() ? ", user: " + user : "") ); - auto & client_info = session->getClientInfo(); - client_info.client_name = client_name; - client_info.client_version_major = client_version_major; - client_info.client_version_minor = client_version_minor; - client_info.client_version_patch = client_version_patch; - client_info.client_tcp_protocol_version = client_tcp_protocol_version; - - client_info.connection_client_version_major = client_version_major; - client_info.connection_client_version_minor = client_version_minor; - client_info.connection_client_version_patch = client_version_patch; - client_info.connection_tcp_protocol_version = client_tcp_protocol_version; - - is_interserver_mode = (user == USER_INTERSERVER_MARKER); + is_interserver_mode = (user == USER_INTERSERVER_MARKER) && password.empty(); if (is_interserver_mode) { - client_info.interface = ClientInfo::Interface::TCP_INTERSERVER; receiveClusterNameAndSalt(); return; } + session = makeSession(); session->authenticate(user, password, socket().peerAddress()); } +void TCPHandler::receiveAddendum() +{ + if (client_tcp_protocol_version >= DBMS_MIN_PROTOCOL_VERSION_WITH_QUOTA_KEY) + { + readStringBinary(quota_key, *in); + if (!is_interserver_mode) + session->getClientInfo().quota_key = quota_key; + } +} + void TCPHandler::receiveUnexpectedHello() { @@ -1253,25 +1299,6 @@ void TCPHandler::receiveClusterNameAndSalt() { readStringBinary(cluster, *in); readStringBinary(salt, *in, 32); - - try - { - if (salt.empty()) - throw NetException("Empty salt is not allowed", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); - - cluster_secret = server.context()->getCluster(cluster)->getSecret(); - } - catch (const Exception & e) - { - try - { - /// We try to send error information to the client. - sendException(e, send_exception_with_stack_trace); - } - catch (...) {} - - throw; - } } void TCPHandler::receiveQuery() @@ -1282,15 +1309,13 @@ void TCPHandler::receiveQuery() state.is_empty = false; readStringBinary(state.query_id, *in); - /// In interserer mode, + /// In interserver mode, /// initial_user can be empty in case of Distributed INSERT via Buffer/Kafka, /// (i.e. when the INSERT is done with the global context w/o user), /// so it is better to reset session to avoid using old user. if (is_interserver_mode) { - ClientInfo original_session_client_info = session->getClientInfo(); - session = std::make_unique(server.context(), ClientInfo::Interface::TCP_INTERSERVER); - session->getClientInfo() = original_session_client_info; + session = makeSession(); } /// Read client info. @@ -1329,22 +1354,35 @@ void TCPHandler::receiveQuery() if (is_interserver_mode) { + client_info.interface = ClientInfo::Interface::TCP_INTERSERVER; #if USE_SSL + String cluster_secret = server.context()->getCluster(cluster)->getSecret(); + if (salt.empty() || cluster_secret.empty()) + { + auto exception = Exception(ErrorCodes::AUTHENTICATION_FAILED, "Interserver authentication failed"); + session->onAuthenticationFailure(/* user_name */ std::nullopt, socket().peerAddress(), exception); + throw exception; /// NOLINT + } + std::string data(salt); data += cluster_secret; data += state.query; data += state.query_id; data += client_info.initial_user; - if (received_hash.size() != 32) - throw NetException("Unexpected hash received from client", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); - std::string calculated_hash = encodeSHA256(data); + assert(calculated_hash.size() == 32); + /// TODO maybe also check that peer address actually belongs to the cluster? if (calculated_hash != received_hash) - throw NetException("Hash mismatch", ErrorCodes::UNEXPECTED_PACKET_FROM_CLIENT); - /// TODO: change error code? + { + auto exception = Exception(ErrorCodes::AUTHENTICATION_FAILED, "Interserver authentication failed"); + session->onAuthenticationFailure(/* user_name */ std::nullopt, socket().peerAddress(), exception); + throw exception; /// NOLINT + } + /// NOTE Usually we get some fields of client_info (including initial_address and initial_user) from user input, + /// so we should not rely on that. However, in this particular case we got client_info from other clickhouse-server, so it's ok. if (client_info.initial_user.empty()) { LOG_DEBUG(log, "User (no user, interserver mode)"); @@ -1355,9 +1393,11 @@ void TCPHandler::receiveQuery() session->authenticate(AlwaysAllowCredentials{client_info.initial_user}, client_info.initial_address); } #else - throw Exception( + auto exception = Exception( "Inter-server secret support is disabled, because proton was built without SSL library", - ErrorCodes::SUPPORT_IS_DISABLED); + ErrorCodes::AUTHENTICATION_FAILED); + session->onAuthenticationFailure(/* user_name */ std::nullopt, socket().peerAddress(), exception); + throw exception; /// NOLINT #endif } diff --git a/src/Server/TCPHandler.h b/src/Server/TCPHandler.h index 38c9fef4713..4fb3dfc0b8f 100644 --- a/src/Server/TCPHandler.h +++ b/src/Server/TCPHandler.h @@ -148,11 +148,14 @@ class TCPHandler : public Poco::Net::TCPServerConnection bool snapshot_mode = false; Poco::Logger * log; + String forwarded_for; + String client_name; UInt64 client_version_major = 0; UInt64 client_version_minor = 0; UInt64 client_version_patch = 0; UInt32 client_tcp_protocol_version = 0; + String quota_key; /// Connection settings, which are extracted from a context. bool send_exception_with_stack_trace = true; @@ -183,7 +186,6 @@ class TCPHandler : public Poco::Net::TCPServerConnection bool is_interserver_mode = false; String salt; String cluster; - String cluster_secret; std::mutex task_callback_mutex; std::mutex fatal_error_mutex; @@ -205,8 +207,11 @@ class TCPHandler : public Poco::Net::TCPServerConnection void extractConnectionSettingsFromContext(const ContextPtr & context); + std::unique_ptr makeSession(); + bool receiveProxyHeader(); void receiveHello(); + void receiveAddendum(); bool receivePacket(); void receiveQuery(); void receiveIgnoredPartUUIDs(); @@ -249,6 +254,8 @@ class TCPHandler : public Poco::Net::TCPServerConnection void sendTotals(const Block & totals); void sendExtremes(const Block & extremes); void sendProfileEvents(); + void sendSelectProfileEvents(); + void sendInsertProfileEvents(); /// Creates state.block_in/block_out for blocks read/write, depending on whether compression is enabled. void initBlockInput(); diff --git a/src/Storages/Distributed/DirectoryMonitor.cpp b/src/Storages/Distributed/DirectoryMonitor.cpp index 951fea13be6..522f9aae768 100644 --- a/src/Storages/Distributed/DirectoryMonitor.cpp +++ b/src/Storages/Distributed/DirectoryMonitor.cpp @@ -531,6 +531,7 @@ ConnectionPoolPtr StorageDistributedDirectoryMonitor::createPool(const std::stri address.default_database, address.user, address.password, + address.quota_key, address.cluster, address.cluster_secret, storage.getName() + '_' + address.user, /* client */ diff --git a/src/Storages/ExternalDataSourceConfiguration.cpp b/src/Storages/ExternalDataSourceConfiguration.cpp index 0fbd6225465..32196e1ceaf 100644 --- a/src/Storages/ExternalDataSourceConfiguration.cpp +++ b/src/Storages/ExternalDataSourceConfiguration.cpp @@ -21,7 +21,7 @@ namespace ErrorCodes IMPLEMENT_SETTINGS_TRAITS(EmptySettingsTraits, EMPTY_SETTINGS) static const std::unordered_set dictionary_allowed_keys = { - "host", "port", "user", "password", "db", + "host", "port", "user", "password", "quota_key", "db", "database", "table", "schema", "replica", "update_field", "update_tag", "invalidate_query", "query", "where", "name", "secure", "uri", "collection"}; @@ -70,6 +70,7 @@ void ExternalDataSourceConfiguration::set(const ExternalDataSourceConfiguration port = conf.port; username = conf.username; password = conf.password; + quota_key = conf.quota_key; database = conf.database; table = conf.table; schema = conf.schema; @@ -109,6 +110,7 @@ std::optional getExternalDataSourceConfiguration( configuration.port = config.getInt(collection_prefix + ".port", 0); configuration.username = config.getString(collection_prefix + ".user", ""); configuration.password = config.getString(collection_prefix + ".password", ""); + configuration.quota_key = config.getString(collection_prefix + ".quota_key", ""); configuration.database = config.getString(collection_prefix + ".database", ""); configuration.table = config.getString(collection_prefix + ".table", config.getString(collection_prefix + ".collection", "")); configuration.schema = config.getString(collection_prefix + ".schema", ""); @@ -155,6 +157,8 @@ std::optional getExternalDataSourceConfiguration( configuration.username = arg_value.safeGet(); else if (arg_name == "password") configuration.password = arg_value.safeGet(); + else if (arg_name == "quota_key") + configuration.quota_key = arg_value.safeGet(); else if (arg_name == "database") configuration.database = arg_value.safeGet(); else if (arg_name == "table") @@ -222,6 +226,7 @@ std::optional getExternalDataSourceConfiguration( configuration.port = dict_config.getInt(dict_config_prefix + ".port", config.getUInt(collection_prefix + ".port", 0)); configuration.username = dict_config.getString(dict_config_prefix + ".user", config.getString(collection_prefix + ".user", "")); configuration.password = dict_config.getString(dict_config_prefix + ".password", config.getString(collection_prefix + ".password", "")); + configuration.quota_key = dict_config.getString(dict_config_prefix + ".quota_key", config.getString(collection_prefix + ".quota_key", "")); configuration.database = dict_config.getString(dict_config_prefix + ".db", config.getString(dict_config_prefix + ".database", config.getString(collection_prefix + ".db", config.getString(collection_prefix + ".database", "")))); configuration.table = dict_config.getString(dict_config_prefix + ".table", config.getString(collection_prefix + ".table", "")); @@ -315,6 +320,7 @@ ExternalDataSourcesByPriority getExternalDataSourceConfigurationByPriority( common_configuration.port = dict_config.getUInt(dict_config_prefix + ".port", 0); common_configuration.username = dict_config.getString(dict_config_prefix + ".user", ""); common_configuration.password = dict_config.getString(dict_config_prefix + ".password", ""); + common_configuration.quota_key = dict_config.getString(dict_config_prefix + ".quota_key", ""); common_configuration.database = dict_config.getString(dict_config_prefix + ".db", dict_config.getString(dict_config_prefix + ".database", "")); common_configuration.table = dict_config.getString(fmt::format("{}.table", dict_config_prefix), ""); common_configuration.schema = dict_config.getString(fmt::format("{}.schema", dict_config_prefix), ""); @@ -346,6 +352,7 @@ ExternalDataSourcesByPriority getExternalDataSourceConfigurationByPriority( replica_configuration.port = dict_config.getUInt(replica_name + ".port", common_configuration.port); replica_configuration.username = dict_config.getString(replica_name + ".user", common_configuration.username); replica_configuration.password = dict_config.getString(replica_name + ".password", common_configuration.password); + replica_configuration.quota_key = dict_config.getString(replica_name + ".quota_key", common_configuration.quota_key); if (replica_configuration.host.empty() || replica_configuration.port == 0 || replica_configuration.username.empty() || replica_configuration.password.empty()) diff --git a/src/Storages/ExternalDataSourceConfiguration.h b/src/Storages/ExternalDataSourceConfiguration.h index 2d0258de77c..8c8405533eb 100644 --- a/src/Storages/ExternalDataSourceConfiguration.h +++ b/src/Storages/ExternalDataSourceConfiguration.h @@ -20,6 +20,7 @@ struct ExternalDataSourceConfiguration UInt16 port = 0; String username; String password; + String quota_key; String database; String table; String schema; diff --git a/src/Storages/StorageS3Cluster.cpp b/src/Storages/StorageS3Cluster.cpp index 11ce4168f70..67be8e84b86 100644 --- a/src/Storages/StorageS3Cluster.cpp +++ b/src/Storages/StorageS3Cluster.cpp @@ -125,7 +125,7 @@ Pipe StorageS3Cluster::read( { auto connection = std::make_shared( node.host_name, node.port, context->getGlobalContext()->getCurrentDatabase(), - node.user, node.password, node.cluster, node.cluster_secret, + node.user, node.password, node.quota_key, node.cluster, node.cluster_secret, "S3ClusterInititiator", node.compression, node.secure diff --git a/src/Storages/System/StorageSystemCurrentRoles.cpp b/src/Storages/System/StorageSystemCurrentRoles.cpp index a5b3566f5f7..cf7df0b8b99 100644 --- a/src/Storages/System/StorageSystemCurrentRoles.cpp +++ b/src/Storages/System/StorageSystemCurrentRoles.cpp @@ -26,8 +26,6 @@ void StorageSystemCurrentRoles::fillData(MutableColumns & res_columns, ContextPt { auto roles_info = context->getRolesInfo(); auto user = context->getUser(); - if (!roles_info || !user) - return; size_t column_index = 0; auto & column_role_name = assert_cast(*res_columns[column_index++]); diff --git a/src/Storages/System/StorageSystemEnabledRoles.cpp b/src/Storages/System/StorageSystemEnabledRoles.cpp index 99370dd647d..eec2f24c5b2 100644 --- a/src/Storages/System/StorageSystemEnabledRoles.cpp +++ b/src/Storages/System/StorageSystemEnabledRoles.cpp @@ -27,8 +27,6 @@ void StorageSystemEnabledRoles::fillData(MutableColumns & res_columns, ContextPt { auto roles_info = context->getRolesInfo(); auto user = context->getUser(); - if (!roles_info || !user) - return; size_t column_index = 0; auto & column_role_name = assert_cast(*res_columns[column_index++]); diff --git a/tests/config/config.d/clusters.xml b/tests/config/config.d/clusters.xml index 4941c5a00bd..ccefc9573c4 100644 --- a/tests/config/config.d/clusters.xml +++ b/tests/config/config.d/clusters.xml @@ -42,5 +42,20 @@ + + 123457 + + + 127.0.0.1 + 9000 + + + + + 127.0.0.2 + 9000 + + + diff --git a/tests/config/config.d/session_log.xml b/tests/config/config.d/session_log.xml new file mode 100644 index 00000000000..febad5b0c7c --- /dev/null +++ b/tests/config/config.d/session_log.xml @@ -0,0 +1,7 @@ + + + system + session_log
+ 100000 +
+
diff --git a/tests/config/install.sh b/tests/config/install.sh index 6335c951c31..1d6992a9999 100755 --- a/tests/config/install.sh +++ b/tests/config/install.sh @@ -43,6 +43,7 @@ ln -sf $SRC_PATH/config.d/zookeeper_log.xml $DEST_SERVER_PATH/config.d/ ln -sf $SRC_PATH/config.d/logger_test.xml $DEST_SERVER_PATH/config.d/ ln -sf $SRC_PATH/config.d/named_collection.xml $DEST_SERVER_PATH/config.d/ ln -sf $SRC_PATH/config.d/filesystem_cache_log.xml $DEST_SERVER_PATH/config.d/ +ln -sf $SRC_PATH/config.d/session_log.xml $DEST_SERVER_PATH/config.d/ ln -sf $SRC_PATH/users.d/log_queries.xml $DEST_SERVER_PATH/users.d/ ln -sf $SRC_PATH/users.d/readonly.xml $DEST_SERVER_PATH/users.d/ diff --git a/tests/integration/helpers/client.py b/tests/integration/helpers/client.py index b0e764bf174..e68f60389ea 100644 --- a/tests/integration/helpers/client.py +++ b/tests/integration/helpers/client.py @@ -16,34 +16,45 @@ def __init__(self, host, port=9000, command='/usr/bin/clickhouse-client'): self.command += ['--host', self.host, '--port', str(self.port), '--stacktrace'] - def query(self, sql, - stdin=None, - timeout=None, - settings=None, - user=None, - password=None, - database=None, - ignore_error=False, - query_id=None): - return self.get_query_request(sql, - stdin=stdin, - timeout=timeout, - settings=settings, - user=user, - password=password, - database=database, - ignore_error=ignore_error, - query_id=query_id).get_answer() - - def get_query_request(self, sql, - stdin=None, - timeout=None, - settings=None, - user=None, - password=None, - database=None, - ignore_error=False, - query_id=None): + def query( + self, + sql, + stdin=None, + timeout=None, + settings=None, + user=None, + password=None, + database=None, + host=None, + ignore_error=False, + query_id=None, + ): + return self.get_query_request( + sql, + stdin=stdin, + timeout=timeout, + settings=settings, + user=user, + password=password, + database=database, + host=host, + ignore_error=ignore_error, + query_id=query_id, + ).get_answer() + + def get_query_request( + self, + sql, + stdin=None, + timeout=None, + settings=None, + user=None, + password=None, + database=None, + host=None, + ignore_error=False, + query_id=None, + ): command = self.command[:] if stdin is None: @@ -57,14 +68,13 @@ def get_query_request(self, sql, command += ['--' + setting, str(value)] if user is not None: - command += ['--user', user] - + command += ["--user", user] if password is not None: - command += ['--password', password] - + command += ["--password", password] if database is not None: - command += ['--database', database] - + command += ["--database", database] + if host is not None: + command += ["--host", host] if query_id is not None: command += ['--query_id', query_id] diff --git a/tests/integration/test_backward_compatibility/test_insert_profile_events.py b/tests/integration/test_backward_compatibility/test_insert_profile_events.py new file mode 100644 index 00000000000..8047c088e4c --- /dev/null +++ b/tests/integration/test_backward_compatibility/test_insert_profile_events.py @@ -0,0 +1,42 @@ +# pylint: disable=line-too-long +# pylint: disable=unused-argument +# pylint: disable=redefined-outer-name + +import pytest + +from helpers.cluster import ClickHouseCluster + +cluster = ClickHouseCluster(__file__, name="insert_profile_events") +upstream_node = cluster.add_instance("upstream_node") +old_node = cluster.add_instance( + "old_node", + image="clickhouse/clickhouse-server", + tag="22.5.1.2079", + with_installed_binary=True, +) + + +@pytest.fixture(scope="module") +def start_cluster(): + try: + cluster.start() + yield cluster + + finally: + cluster.shutdown() + + +def test_old_client_compatible(start_cluster): + old_node.query("INSERT INTO FUNCTION null('foo String') VALUES ('foo')('bar')") + old_node.query( + "INSERT INTO FUNCTION null('foo String') VALUES ('foo')('bar')", + host=upstream_node.ip_address, + ) + + +def test_new_client_compatible(start_cluster): + upstream_node.query( + "INSERT INTO FUNCTION null('foo String') VALUES ('foo')('bar')", + host=old_node.ip_address, + ) + upstream_node.query("INSERT INTO FUNCTION null('foo String') VALUES ('foo')('bar')") diff --git a/tests/integration/test_distributed_inter_server_secret/test.py b/tests/integration/test_distributed_inter_server_secret/test.py index 73d338ba870..b4e3a373226 100644 --- a/tests/integration/test_distributed_inter_server_secret/test.py +++ b/tests/integration/test_distributed_inter_server_secret/test.py @@ -195,18 +195,27 @@ def test_secure_insert_buffer_async(): n1.query('TRUNCATE TABLE data_from_buffer ON CLUSTER secure') def test_secure_disagree(): - with pytest.raises(QueryRuntimeException, match='.*Hash mismatch.*'): - n1.query('SELECT * FROM dist_secure_disagree') + with pytest.raises( + QueryRuntimeException, match=".*Interserver authentication failed.*" + ): + n1.query("SELECT * FROM dist_secure_disagree") + def test_secure_disagree_insert(): n1.query("TRUNCATE TABLE data") - n1.query('INSERT INTO dist_secure_disagree SELECT * FROM numbers(2)') - with pytest.raises(QueryRuntimeException, match='.*Hash mismatch.*'): - n1.query('SYSTEM FLUSH DISTRIBUTED ON CLUSTER secure_disagree dist_secure_disagree') + n1.query("INSERT INTO dist_secure_disagree SELECT * FROM numbers(2)") + with pytest.raises( + QueryRuntimeException, match=".*Interserver authentication failed.*" + ): + n1.query( + "SYSTEM FLUSH DISTRIBUTED ON CLUSTER secure_disagree dist_secure_disagree" + ) # check the the connection will be re-established # IOW that we will not get "Unknown BlockInfo field" - with pytest.raises(QueryRuntimeException, match='.*Hash mismatch.*'): - assert int(n1.query('SELECT count() FROM dist_secure_disagree')) == 0 + with pytest.raises( + QueryRuntimeException, match=".*Interserver authentication failed.*" + ): + assert int(n1.query("SELECT count() FROM dist_secure_disagree")) == 0 @users def test_user_insecure_cluster(user, password): diff --git a/tests/integration/test_grant_and_revoke/test.py b/tests/integration/test_grant_and_revoke/test.py index b905e4df219..89e07fecb0a 100644 --- a/tests/integration/test_grant_and_revoke/test.py +++ b/tests/integration/test_grant_and_revoke/test.py @@ -250,6 +250,15 @@ def test_introspection(): assert instance.query("SHOW GRANTS", user='A') == TSV(["GRANT SELECT ON test.table TO A"]) assert instance.query("SHOW GRANTS", user='B') == TSV(["GRANT CREATE ON *.* TO B WITH GRANT OPTION"]) + assert instance.query("SHOW GRANTS FOR ALL", user='A') == TSV(["GRANT SELECT ON test.table TO A"]) + assert instance.query("SHOW GRANTS FOR ALL", user='B') == TSV(["GRANT CREATE ON *.* TO B WITH GRANT OPTION"]) + assert instance.query("SHOW GRANTS FOR ALL") == TSV(["GRANT SELECT ON test.table TO A", + "GRANT CREATE ON *.* TO B WITH GRANT OPTION", + "GRANT ALL ON *.* TO default WITH GRANT OPTION"]) + + expected_error = "necessary to have grant SHOW USERS" + assert expected_error in instance.query_and_get_error("SHOW GRANTS FOR B", user='A') + expected_access1 = "CREATE USER A\n" \ "CREATE USER B\n" \ "CREATE USER default IDENTIFIED WITH plaintext_password SETTINGS PROFILE default" diff --git a/tests/integration/test_role/test.py b/tests/integration/test_role/test.py index 7600bc73b16..54ed7e92ce8 100644 --- a/tests/integration/test_role/test.py +++ b/tests/integration/test_role/test.py @@ -1,4 +1,5 @@ import pytest +from helpers.client import QueryRuntimeException from helpers.cluster import ClickHouseCluster from helpers.test_tools import TSV @@ -205,6 +206,11 @@ def test_introspection(): ["GRANT SELECT ON test.table TO R2", "REVOKE SELECT(x) ON test.table FROM R2"]) assert instance.query("SHOW GRANTS", user='A') == TSV(["GRANT SELECT ON test.table TO A", "GRANT R1 TO A"]) + + assert instance.query("SHOW GRANTS FOR R1", user="A") == TSV([]) + with pytest.raises(QueryRuntimeException, match="Not enough privileges"): + assert instance.query("SHOW GRANTS FOR R2", user="A") + assert instance.query("SHOW GRANTS", user='B') == TSV( ["GRANT CREATE ON *.* TO B WITH GRANT OPTION", "GRANT R2 TO B WITH ADMIN OPTION"]) assert instance.query("SHOW CURRENT ROLES", user='A') == TSV([["R1", 0, 1]]) diff --git a/tests/queries/0_stateless/02310_clickhouse_client_INSERT_progress_profile_events.reference b/tests/queries/0_stateless/02310_clickhouse_client_INSERT_progress_profile_events.reference index e69de29bb2d..64ab61e6765 100644 --- a/tests/queries/0_stateless/02310_clickhouse_client_INSERT_progress_profile_events.reference +++ b/tests/queries/0_stateless/02310_clickhouse_client_INSERT_progress_profile_events.reference @@ -0,0 +1,2 @@ +0 +--progress produce some rows diff --git a/tests/queries/0_stateless/02310_clickhouse_client_INSERT_progress_profile_events.sh b/tests/queries/0_stateless/02310_clickhouse_client_INSERT_progress_profile_events.sh new file mode 100755 index 00000000000..6c37d870652 --- /dev/null +++ b/tests/queries/0_stateless/02310_clickhouse_client_INSERT_progress_profile_events.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash +# Tags: long + +# This is the regression for the concurrent access in ProgressIndication, +# so it is important to read enough rows here (10e6). +# +# Initially there was 100e6, but under thread fuzzer 10min may be not enough sometimes, +# but I believe that CI will catch possible issues even with less rows anyway. + +CUR_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. "$CUR_DIR"/../shell_config.sh + +tmp_file_progress="$(mktemp "$CUR_DIR/$CLICKHOUSE_TEST_UNIQUE_NAME.XXXXXX.progress")" +trap 'rm $tmp_file_progress' EXIT + +yes | head -n10000000 | $CLICKHOUSE_CLIENT -q "insert into function null('foo String') format TSV" --progress 2> "$tmp_file_progress" +echo $? +test -s "$tmp_file_progress" && echo "--progress produce some rows" || echo "FAIL: no rows with --progress" diff --git a/tests/queries/0_stateless/02310_clickhouse_local_INSERT_progress_profile_events.reference b/tests/queries/0_stateless/02310_clickhouse_local_INSERT_progress_profile_events.reference index e69de29bb2d..64ab61e6765 100644 --- a/tests/queries/0_stateless/02310_clickhouse_local_INSERT_progress_profile_events.reference +++ b/tests/queries/0_stateless/02310_clickhouse_local_INSERT_progress_profile_events.reference @@ -0,0 +1,2 @@ +0 +--progress produce some rows diff --git a/tests/queries/0_stateless/02310_clickhouse_local_INSERT_progress_profile_events.sh b/tests/queries/0_stateless/02310_clickhouse_local_INSERT_progress_profile_events.sh new file mode 100755 index 00000000000..00a8b7a2a90 --- /dev/null +++ b/tests/queries/0_stateless/02310_clickhouse_local_INSERT_progress_profile_events.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env bash +# Tags: long + +# This is the regression for the concurrent access in ProgressIndication, +# so it is important to read enough rows here (10e6). +# +# Initially there was 100e6, but under thread fuzzer 10min may be not enough sometimes, +# but I believe that CI will catch possible issues even with less rows anyway. + +CUR_DIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd) +# shellcheck source=../shell_config.sh +. "$CUR_DIR"/../shell_config.sh + +tmp_file_progress="$(mktemp "$CUR_DIR/$CLICKHOUSE_TEST_UNIQUE_NAME.XXXXXX.progress")" +trap 'rm $tmp_file_progress' EXIT + +yes | head -n10000000 | $CLICKHOUSE_LOCAL -q "insert into function null('foo String') format TSV" --progress 2> "$tmp_file_progress" +echo $? +test -s "$tmp_file_progress" && echo "--progress produce some rows" || echo "FAIL: no rows with --progress"