diff --git a/include/nebula/client/Config.h b/include/nebula/client/Config.h index f504c826..9118726c 100644 --- a/include/nebula/client/Config.h +++ b/include/nebula/client/Config.h @@ -15,8 +15,19 @@ struct Config { std::uint32_t idleTime_{0}; // in ms std::uint32_t maxConnectionPoolSize_{10}; std::uint32_t minConnectionPoolSize_{0}; - std::string CAPath_; + // Whether to enable SSL encryption bool enableSSL_{false}; + // Whether to enable mTLS + bool enableMTLS_{false}; + // Whether to check peer CN or SAN + bool checkPeerName_{false}; + std::string peerName_; + // Path to cert of CA + std::string CAPath_; + // Path to cert of client + std::string certPath_; + // path to private key of client + std::string keyPath_; }; } // namespace nebula diff --git a/include/nebula/client/Connection.h b/include/nebula/client/Connection.h index a11822bf..1eb4e688 100644 --- a/include/nebula/client/Connection.h +++ b/include/nebula/client/Connection.h @@ -11,6 +11,7 @@ #include "common/datatypes/Value.h" #include "common/graph/Response.h" +#include "nebula/client/Config.h" namespace folly { class ScopedEventBaseThread; @@ -50,8 +51,7 @@ class Connection { bool open(const std::string &address, int32_t port, uint32_t timeout, - bool enableSSL, - const std::string &CAPath); + const Config &cfg = Config{}); AuthResponse authenticate(const std::string &user, const std::string &password); diff --git a/include/nebula/mclient/MConfig.h b/include/nebula/mclient/MConfig.h index 66c981ae..3c461343 100644 --- a/include/nebula/mclient/MConfig.h +++ b/include/nebula/mclient/MConfig.h @@ -15,8 +15,19 @@ struct MConfig { int32_t connTimeoutInMs_{1000}; // It's as same as FLAG_meta_client_timeout_ms in nebula int32_t clientTimeoutInMs_{60 * 1000}; + // Whether to enable SSL encryption bool enableSSL_{false}; + // Whether to enable mTLS + bool enableMTLS_{false}; + // Whether to check peer CN or SAN + bool checkPeerName_{false}; + std::string peerName_; + // Path to cert of CA std::string CAPath_; + // Path to cert of client + std::string certPath_; + // path to private key of client + std::string keyPath_; }; } // namespace nebula diff --git a/include/nebula/sclient/SConfig.h b/include/nebula/sclient/SConfig.h index 669d2a1f..cb9cb40e 100644 --- a/include/nebula/sclient/SConfig.h +++ b/include/nebula/sclient/SConfig.h @@ -15,8 +15,19 @@ struct SConfig { int32_t connTimeoutInMs_{1000}; // It's as same as FLAG_meta_client_timeout_ms in nebula int32_t clientTimeoutInMs_{60 * 1000}; + // Whether to enable SSL encryption bool enableSSL_{false}; + // Whether to enable mTLS + bool enableMTLS_{false}; + // Whether to check peer CN or SAN + bool checkPeerName_{false}; + std::string peerName_; + // Path to cert of CA std::string CAPath_; + // Path to cert of client + std::string certPath_; + // path to private key of client + std::string keyPath_; }; } // namespace nebula diff --git a/src/SSLConfig.cpp b/src/SSLConfig.cpp index 76f3cb39..df62756d 100644 --- a/src/SSLConfig.cpp +++ b/src/SSLConfig.cpp @@ -9,17 +9,34 @@ namespace nebula { -std::shared_ptr createSSLContext(const std::string &CAPath) { - auto context = std::make_shared(); - if (!CAPath.empty()) { - context->loadTrustedCertificates(CAPath.c_str()); - // don't do peer name validation - context->authenticate(true, false); - // verify the server cert - context->setVerificationOption(folly::SSLContext::SSLVerifyPeerEnum::VERIFY); - } - folly::ssl::setSignatureAlgorithms(*context); - return context; +std::shared_ptr createSSLContext(const SSLConfig &cfg) { + if (cfg.check_peer_name && cfg.peer_name.empty()) { + throw std::runtime_error("peer name checking enabled but not provied"); + } + + if (cfg.enable_mtls && (cfg.cert_path.empty() || cfg.key_path.empty())) { + throw std::runtime_error("mTLS enabled but cert/key not provided"); + } + + auto context = std::make_shared(); + + if (!cfg.ca_path.empty()) { + context->loadTrustedCertificates(cfg.ca_path.c_str()); + if (cfg.check_peer_name) { + context->authenticate(true, true, cfg.peer_name); + } else { + context->authenticate(true, false); + } + context->setVerificationOption(folly::SSLContext::SSLVerifyPeerEnum::VERIFY); + } + + if (cfg.enable_mtls) { + context->loadCertKeyPairFromFiles(cfg.cert_path.c_str(), cfg.key_path.c_str()); + } + + folly::ssl::setSignatureAlgorithms(*context); + + return context; } } // namespace nebula diff --git a/src/SSLConfig.h b/src/SSLConfig.h index 18c8834e..e8201fcb 100644 --- a/src/SSLConfig.h +++ b/src/SSLConfig.h @@ -9,6 +9,20 @@ namespace nebula { -std::shared_ptr createSSLContext(const std::string &CAPath); +struct SSLConfig final { + // Whether to enable mTLS(mutual TLS authentication) + bool enable_mtls{false}; + // Check whether the given peername matches the CN or SAN in the certificate + bool check_peer_name{false}; + std::string peer_name; + // Path to certificate(s) of the CA used to authenticate the cert of server + std::string ca_path; + // Path to the client cert, must be present if mTLS enabled + std::string cert_path; + // Path to the client private key, must be present if mTLS enabled + std::string key_path; +}; + +std::shared_ptr createSSLContext(const SSLConfig &cfg); } // namespace nebula diff --git a/src/client/Connection.cpp b/src/client/Connection.cpp index d62f0f65..5a7b54d2 100644 --- a/src/client/Connection.cpp +++ b/src/client/Connection.cpp @@ -76,8 +76,7 @@ Connection &Connection::operator=(Connection &&c) { bool Connection::open(const std::string &address, int32_t port, uint32_t timeout, - bool enableSSL, - const std::string &CAPath) { + const Config &cfg) { if (address.empty()) { return false; } @@ -91,10 +90,17 @@ bool Connection::open(const std::string &address, return false; } clientLoopThread_->getEventBase()->runImmediatelyOrRunInEventBaseThreadAndWait( - [this, &complete, &socket, timeout, &socketAddr, enableSSL, &CAPath]() { + [this, &complete, &socket, timeout, &socketAddr, &cfg]() { try { - if (enableSSL) { - socket = folly::AsyncSSLSocket::newSocket(nebula::createSSLContext(CAPath), + if (cfg.enableSSL_) { + SSLConfig sslcfg; + sslcfg.enable_mtls = cfg.enableMTLS_; + sslcfg.check_peer_name = cfg.checkPeerName_; + sslcfg.peer_name = cfg.peerName_; + sslcfg.ca_path = cfg.CAPath_; + sslcfg.cert_path = cfg.certPath_; + sslcfg.key_path = cfg.keyPath_; + socket = folly::AsyncSSLSocket::newSocket(nebula::createSSLContext(sslcfg), clientLoopThread_->getEventBase()); socket->connect(nullptr, std::move(socketAddr), timeout); } else { diff --git a/src/client/ConnectionPool.cpp b/src/client/ConnectionPool.cpp index ed3b44e1..dcb009ad 100644 --- a/src/client/ConnectionPool.cpp +++ b/src/client/ConnectionPool.cpp @@ -96,8 +96,7 @@ void ConnectionPool::newConnection(std::size_t cursor, std::size_t count) { if (conn.open(address_[addrCursor].first, address_[addrCursor].second, config_.timeout_, - config_.enableSSL_, - config_.CAPath_)) { + config_)) { ++connectionCount; conns_.emplace_back(std::move(conn)); } diff --git a/src/client/tests/ConnectionSSLTest.cpp b/src/client/tests/ConnectionSSLTest.cpp index c8e1ce3c..44cb3236 100644 --- a/src/client/tests/ConnectionSSLTest.cpp +++ b/src/client/tests/ConnectionSSLTest.cpp @@ -20,8 +20,10 @@ class ConnectionTest : public ClientTest {}; TEST_F(ConnectionTest, SSL) { nebula::Connection c; + nebula::Config cfg; + cfg.enableSSL_ = true; - ASSERT_TRUE(c.open(kServerHost, 9669, 10, true, "")); + ASSERT_TRUE(c.open(kServerHost, 9669, 10, cfg)); // auth auto authResp = c.authenticate("root", "nebula"); @@ -38,7 +40,10 @@ TEST_F(ConnectionTest, SSL) { TEST_F(ConnectionTest, SSCA) { { nebula::Connection c; - ASSERT_TRUE(c.open(kServerHost, 9669, 10, true, "./test.ca.pem")); + nebula::Config cfg; + cfg.enableSSL_ = true; + cfg.CAPath_ = "./test.ca.pem"; + ASSERT_TRUE(c.open(kServerHost, 9669, 10, cfg)); // auth auto authResp = c.authenticate("root", "nebula"); @@ -55,7 +60,10 @@ TEST_F(ConnectionTest, SSCA) { { // mismatch nebula::Connection c; - ASSERT_FALSE(c.open(kServerHost, 9669, 10, true, "./test.2.crt")); + nebula::Config cfg; + cfg.enableSSL_ = true; + cfg.CAPath_ = "./test.2.crt"; + ASSERT_FALSE(c.open(kServerHost, 9669, 10, cfg)); } } diff --git a/src/client/tests/ConnectionTest.cpp b/src/client/tests/ConnectionTest.cpp index ea89fb83..b243afdd 100644 --- a/src/client/tests/ConnectionTest.cpp +++ b/src/client/tests/ConnectionTest.cpp @@ -38,7 +38,7 @@ class ConnectionTest : public ClientTest { }); // open - ASSERT_TRUE(c.open(kServerHost, 9669, 0, false, "")); + ASSERT_TRUE(c.open(kServerHost, 9669, 0)); // ping EXPECT_TRUE(c.ping()); @@ -128,7 +128,7 @@ TEST_F(ConnectionTest, Basic) { TEST_F(ConnectionTest, Timeout) { nebula::Connection c; - ASSERT_TRUE(c.open(kServerHost, 9669, 100, false, "")); + ASSERT_TRUE(c.open(kServerHost, 9669, 100)); // auth auto authResp = c.authenticate("root", "nebula"); @@ -167,7 +167,7 @@ TEST_F(ConnectionTest, Timeout) { TEST_F(ConnectionTest, JsonResult) { nebula::Connection c; - ASSERT_TRUE(c.open(kServerHost, 9669, 10, false, "")); + ASSERT_TRUE(c.open(kServerHost, 9669, 10)); // auth auto authResp = c.authenticate("root", "nebula"); @@ -187,7 +187,7 @@ TEST_F(ConnectionTest, JsonResult) { TEST_F(ConnectionTest, DurationResult) { nebula::Connection c; - ASSERT_TRUE(c.open(kServerHost, 9669, 10, false, "")); + ASSERT_TRUE(c.open(kServerHost, 9669, 10)); // auth auto authResp = c.authenticate("root", "nebula"); @@ -204,7 +204,7 @@ TEST_F(ConnectionTest, DurationResult) { TEST_F(ConnectionTest, ExecuteParameter) { nebula::Connection c; - ASSERT_TRUE(c.open(kServerHost, 9669, 10, false, "")); + ASSERT_TRUE(c.open(kServerHost, 9669, 10)); // auth auto authResp = c.authenticate("root", "nebula"); @@ -232,13 +232,13 @@ TEST_F(ConnectionTest, ExecuteParameter) { TEST_F(ConnectionTest, InvalidPort) { nebula::Connection c; - ASSERT_FALSE(c.open(kServerHost, 2333, 10, false, "")); + ASSERT_FALSE(c.open(kServerHost, 2333, 10)); } TEST_F(ConnectionTest, InvalidHost) { nebula::Connection c; - ASSERT_FALSE(c.open("Invalid Host", 9669, 10, false, "")); + ASSERT_FALSE(c.open("Invalid Host", 9669, 10)); } int main(int argc, char **argv) { diff --git a/src/client/tests/RegistHost.cpp b/src/client/tests/RegistHost.cpp index d72fa294..d616dd5f 100644 --- a/src/client/tests/RegistHost.cpp +++ b/src/client/tests/RegistHost.cpp @@ -25,7 +25,7 @@ int main(int argc, char** argv) { google::SetStderrLogging(google::INFO); nebula::ConnectionPool pool; - nebula::Config c{10, 0, 300, 0, "", FLAGS_enable_ssl}; + nebula::Config c{10, 0, 300, 0, FLAGS_enable_ssl, false, false, "", "", "", ""}; pool.init({FLAGS_server}, c); auto session = pool.getSession("root", "nebula"); CHECK(session.valid()); diff --git a/src/client/tests/SessionPoolTest.cpp b/src/client/tests/SessionPoolTest.cpp index 39c76fb6..fe2105ce 100644 --- a/src/client/tests/SessionPoolTest.cpp +++ b/src/client/tests/SessionPoolTest.cpp @@ -25,7 +25,8 @@ class SessionPoolTest : public ClientTest { protected: void SetUp() { nebula::ConnectionPool pool; - pool.init({kServerHost ":9669"}, nebula::Config{0, 0, 1, 0, "", false}); + nebula::Config cfg{0, 0, 1, 0, false, false, false, "", "", "", ""}; + pool.init({kServerHost ":9669"}, cfg); auto session = pool.getSession("root", "nebula"); ASSERT_TRUE(session.valid()); @@ -41,7 +42,8 @@ class SessionPoolTest : public ClientTest { void TearDown() { nebula::ConnectionPool pool; - pool.init({kServerHost ":9669"}, nebula::Config{0, 0, 1, 0, "", false}); + nebula::Config cfg{0, 0, 1, 0, false, false, false, "", "", "", ""}; + pool.init({kServerHost ":9669"}, cfg); auto session = pool.getSession("root", "nebula"); ASSERT_TRUE(session.valid()); diff --git a/src/client/tests/SessionSSLTest.cpp b/src/client/tests/SessionSSLTest.cpp index ef036cff..73cd3f35 100644 --- a/src/client/tests/SessionSSLTest.cpp +++ b/src/client/tests/SessionSSLTest.cpp @@ -28,7 +28,7 @@ class SessionTest : public ClientTest {}; TEST_F(SessionTest, SSL) { nebula::ConnectionPool pool; - nebula::Config c{10, 0, 10, 0, "", true}; + nebula::Config c{10, 0, 10, 0, true, false, false, "", "", "", ""}; pool.init({kServerHost ":9669"}, c); auto session = pool.getSession("root", "nebula"); ASSERT_TRUE(session.valid()); diff --git a/src/client/tests/SessionTest.cpp b/src/client/tests/SessionTest.cpp index 1b350f99..27edb055 100644 --- a/src/client/tests/SessionTest.cpp +++ b/src/client/tests/SessionTest.cpp @@ -144,7 +144,7 @@ TEST_F(SessionTest, InvalidAddress) { TEST_F(SessionTest, Data) { nebula::ConnectionPool pool; - nebula::Config c{10, 0, 300, 0, "", false}; + nebula::Config c{10, 0, 300, 0, false, false, false, "", "", "", ""}; pool.init({kServerHost ":9669"}, c); auto session = pool.getSession("root", "nebula"); ASSERT_TRUE(session.valid()); @@ -192,7 +192,7 @@ TEST_F(SessionTest, Data) { TEST_F(SessionTest, Timeout) { nebula::ConnectionPool pool; - nebula::Config c{10, 0, 100, 0, "", false}; + nebula::Config c{10, 0, 100, 0, false, false, false, "", "", "", ""}; pool.init({kServerHost ":9669"}, c); auto session = pool.getSession("root", "nebula"); ASSERT_TRUE(session.valid()); @@ -228,7 +228,7 @@ TEST_F(SessionTest, Timeout) { TEST_F(SessionTest, JsonResult) { nebula::ConnectionPool pool; - nebula::Config c{10, 0, 10, 0, "", false}; + nebula::Config c{10, 0, 10, 0, false, false, false, "", "", "", ""}; pool.init({kServerHost ":9669"}, c); auto session = pool.getSession("root", "nebula"); ASSERT_TRUE(session.valid()); @@ -246,7 +246,7 @@ TEST_F(SessionTest, JsonResult) { TEST_F(SessionTest, DurationResult) { nebula::ConnectionPool pool; - nebula::Config c{10, 0, 10, 0, "", false}; + nebula::Config c{10, 0, 10, 0, false, false, false, "", "", "", ""}; pool.init({kServerHost ":9669"}, c); auto session = pool.getSession("root", "nebula"); ASSERT_TRUE(session.valid()); @@ -261,7 +261,7 @@ TEST_F(SessionTest, DurationResult) { TEST_F(SessionTest, ExecuteParameter) { nebula::ConnectionPool pool; - nebula::Config c{10, 0, 10, 0, "", false}; + nebula::Config c{10, 0, 10, 0, false, false, false, "", "", "", ""}; pool.init({kServerHost ":9669"}, c); auto session = pool.getSession("root", "nebula"); ASSERT_TRUE(session.valid()); diff --git a/src/mclient/MetaClient.cpp b/src/mclient/MetaClient.cpp index 6b748d6f..7715519c 100644 --- a/src/mclient/MetaClient.cpp +++ b/src/mclient/MetaClient.cpp @@ -24,10 +24,17 @@ MetaClient::MetaClient(const std::vector& metaAddrs, const MConfig& } CHECK(!metaAddrs_.empty()) << "metaAddrs_ is empty"; mConfig_ = mConfig; + SSLConfig sslcfg; + sslcfg.enable_mtls = mConfig_.enableMTLS_; + sslcfg.check_peer_name = mConfig_.checkPeerName_; + sslcfg.peer_name = mConfig_.peerName_; + sslcfg.ca_path = mConfig_.CAPath_; + sslcfg.cert_path = mConfig_.certPath_; + sslcfg.key_path = mConfig_.keyPath_; ioExecutor_ = std::make_shared(std::thread::hardware_concurrency()); clientsMan_ = std::make_shared>( - mConfig_.connTimeoutInMs_, mConfig_.enableSSL_, mConfig_.CAPath_); + mConfig_.connTimeoutInMs_, mConfig_.enableSSL_, sslcfg); bool b = loadData(); // load data into cache if (!b) { LOG(ERROR) << "load data failed"; diff --git a/src/mclient/tests/MetaClientSSLTest.cpp b/src/mclient/tests/MetaClientSSLTest.cpp index 85357cec..9f2c4bff 100644 --- a/src/mclient/tests/MetaClientSSLTest.cpp +++ b/src/mclient/tests/MetaClientSSLTest.cpp @@ -24,7 +24,9 @@ class MetaClientTest : public MClientTest { protected: static void prepare() { nebula::ConnectionPool pool; - pool.init({kServerHost ":9669"}, nebula::Config{0, 0, 10, 0, "", true}); + nebula::Config cfg; + cfg.enableSSL_ = true; + pool.init({kServerHost ":9669"}, cfg); auto session = pool.getSession("root", "nebula"); ASSERT_TRUE(session.valid()); EXPECT_TRUE(session.ping()); @@ -74,7 +76,7 @@ TEST_F(MetaClientTest, SSL) { prepare(); LOG(INFO) << "Run once."; - nebula::MConfig mConfig{1000, 60 * 1000, true, ""}; + nebula::MConfig mConfig{1000, 60 * 1000, true, false, false, "", "", "", ""}; nebula::MetaClient c({kServerHost ":9559"}, mConfig); runOnce(c); } diff --git a/src/sclient/StorageClient.cpp b/src/sclient/StorageClient.cpp index 5a70e577..869c6fa6 100644 --- a/src/sclient/StorageClient.cpp +++ b/src/sclient/StorageClient.cpp @@ -21,10 +21,17 @@ StorageClient::StorageClient(const std::vector& metaAddrs, : user_(user), password_(password) { mClient_ = std::make_unique(metaAddrs, mConfig); sConfig_ = sConfig; + SSLConfig sslcfg; + sslcfg.enable_mtls = sConfig_.enableMTLS_; + sslcfg.check_peer_name = sConfig_.checkPeerName_; + sslcfg.peer_name = sConfig_.peerName_; + sslcfg.ca_path = sConfig_.CAPath_; + sslcfg.cert_path = sConfig_.certPath_; + sslcfg.key_path = sConfig_.keyPath_; ioExecutor_ = std::make_shared(std::thread::hardware_concurrency()); clientsMan_ = std::make_shared>( - sConfig.connTimeoutInMs_, sConfig.enableSSL_, sConfig.CAPath_); + sConfig.connTimeoutInMs_, sConfig.enableSSL_, sslcfg); } StorageClient::~StorageClient() = default; diff --git a/src/sclient/tests/StorageClientSSLTest.cpp b/src/sclient/tests/StorageClientSSLTest.cpp index 88631124..c09b6933 100644 --- a/src/sclient/tests/StorageClientSSLTest.cpp +++ b/src/sclient/tests/StorageClientSSLTest.cpp @@ -25,7 +25,9 @@ class StorageClientTest : public SClientTest { protected: static void prepare() { nebula::ConnectionPool pool; - pool.init({kServerHost ":9669"}, nebula::Config{0, 0, 10, 0, "", true}); + nebula::Config c; + c.enableSSL_ = true; + pool.init({kServerHost ":9669"}, c); auto session = pool.getSession("root", "nebula"); ASSERT_TRUE(session.valid()); EXPECT_TRUE(session.ping()); @@ -198,8 +200,8 @@ class StorageClientTest : public SClientTest { TEST_F(StorageClientTest, SSL) { LOG(INFO) << "Prepare data."; prepare(); - nebula::MConfig mConfig{1000, 60 * 1000, true, ""}; - nebula::SConfig sConfig{1000, 60 * 1000, true, ""}; + nebula::MConfig mConfig{1000, 60 * 1000, true, false, false, "", "", "", ""}; + nebula::SConfig sConfig{1000, 60 * 1000, true, false, false, "", "", "", ""}; nebula::StorageClient c({kServerHost ":9559"}, "root", "nebula", mConfig, sConfig); auto *m = c.getMetaClient(); LOG(INFO) << "Testing run once of meta client"; diff --git a/src/thrift/ThriftClientManager-inl.h b/src/thrift/ThriftClientManager-inl.h index 0860ea53..b07a0256 100644 --- a/src/thrift/ThriftClientManager-inl.h +++ b/src/thrift/ThriftClientManager-inl.h @@ -71,7 +71,7 @@ std::shared_ptr ThriftClientManager::client(const HostAd std::shared_ptr socket; evb->runImmediatelyOrRunInEventBaseThreadAndWait([this, &socket, evb, resolved]() { if (enableSSL_) { - socket = folly::AsyncSSLSocket::newSocket(nebula::createSSLContext(CAPath_), evb); + socket = folly::AsyncSSLSocket::newSocket(nebula::createSSLContext(sslcfg_), evb); socket->connect(nullptr, resolved.host, resolved.port, connTimeoutInMs_); } else { socket = folly::AsyncSocket::newSocket(evb, resolved.host, resolved.port, connTimeoutInMs_); diff --git a/src/thrift/ThriftClientManager.h b/src/thrift/ThriftClientManager.h index 3f2f20dc..5c45b28d 100644 --- a/src/thrift/ThriftClientManager.h +++ b/src/thrift/ThriftClientManager.h @@ -9,6 +9,7 @@ #include #include "common/datatypes/HostAddr.h" +#include "../SSLConfig.h" namespace nebula { namespace thrift { @@ -25,8 +26,8 @@ class ThriftClientManager final { VLOG(3) << "~ThriftClientManager"; } - explicit ThriftClientManager(int32_t connTimeoutInMs, bool enableSSL, const std::string& CAPath) - : connTimeoutInMs_(connTimeoutInMs), enableSSL_(enableSSL), CAPath_(CAPath) { + explicit ThriftClientManager(int32_t connTimeoutInMs, bool enableSSL, SSLConfig cfg = SSLConfig()) + : connTimeoutInMs_(connTimeoutInMs), enableSSL_(enableSSL), sslcfg_(std::move(cfg)) { VLOG(3) << "ThriftClientManager"; } @@ -38,7 +39,7 @@ class ThriftClientManager final { int32_t connTimeoutInMs_; // whether enable ssl bool enableSSL_; - std::string CAPath_; + SSLConfig sslcfg_; }; } // namespace thrift