diff --git a/api/envoy/extensions/transport_sockets/tls/v3/tls.proto b/api/envoy/extensions/transport_sockets/tls/v3/tls.proto index 7ee7920c724d..86e7bf4639c9 100644 --- a/api/envoy/extensions/transport_sockets/tls/v3/tls.proto +++ b/api/envoy/extensions/transport_sockets/tls/v3/tls.proto @@ -238,4 +238,7 @@ message CommonTlsContext { // // There is no default for this parameter. If empty, Envoy will not expose ALPN. repeated string alpn_protocols = 4; + + // Custom TLS handshaker. If empty, defaults to native TLS handshaker. + config.core.v3.TypedExtensionConfig custom_listener_handshaker = 13; } diff --git a/api/envoy/extensions/transport_sockets/tls/v4alpha/tls.proto b/api/envoy/extensions/transport_sockets/tls/v4alpha/tls.proto index a73ba6e002ba..1ee9bc1772be 100644 --- a/api/envoy/extensions/transport_sockets/tls/v4alpha/tls.proto +++ b/api/envoy/extensions/transport_sockets/tls/v4alpha/tls.proto @@ -243,4 +243,7 @@ message CommonTlsContext { // // There is no default for this parameter. If empty, Envoy will not expose ALPN. repeated string alpn_protocols = 4; + + // Custom TLS handshaker. If empty, defaults to native TLS handshaker. + config.core.v4alpha.TypedExtensionConfig custom_listener_handshaker = 13; } diff --git a/generated_api_shadow/envoy/extensions/transport_sockets/tls/v3/tls.proto b/generated_api_shadow/envoy/extensions/transport_sockets/tls/v3/tls.proto index 7ee7920c724d..86e7bf4639c9 100644 --- a/generated_api_shadow/envoy/extensions/transport_sockets/tls/v3/tls.proto +++ b/generated_api_shadow/envoy/extensions/transport_sockets/tls/v3/tls.proto @@ -238,4 +238,7 @@ message CommonTlsContext { // // There is no default for this parameter. If empty, Envoy will not expose ALPN. repeated string alpn_protocols = 4; + + // Custom TLS handshaker. If empty, defaults to native TLS handshaker. + config.core.v3.TypedExtensionConfig custom_listener_handshaker = 13; } diff --git a/generated_api_shadow/envoy/extensions/transport_sockets/tls/v4alpha/tls.proto b/generated_api_shadow/envoy/extensions/transport_sockets/tls/v4alpha/tls.proto index a73ba6e002ba..13e6b740db8f 100644 --- a/generated_api_shadow/envoy/extensions/transport_sockets/tls/v4alpha/tls.proto +++ b/generated_api_shadow/envoy/extensions/transport_sockets/tls/v4alpha/tls.proto @@ -97,6 +97,14 @@ message DownstreamTlsContext { }]; } +// Custom config for customizing listener handshaker behavior. +message CustomListenerHandshaker { + option (udpa.annotations.versioning).previous_message_type = + "envoy.extensions.transport_sockets.tls.v3.CustomListenerHandshaker"; + + config.core.v4alpha.TypedExtensionConfig typed_config = 1; +} + // TLS context shared by both client and server TLS contexts. // [#next-free-field: 13] message CommonTlsContext { @@ -243,4 +251,7 @@ message CommonTlsContext { // // There is no default for this parameter. If empty, Envoy will not expose ALPN. repeated string alpn_protocols = 4; + + // Custom TLS handshaker. If empty, defaults to native TLS handshaker. + config.core.v4alpha.TypedExtensionConfig custom_listener_handshaker = 13; } diff --git a/include/envoy/ssl/BUILD b/include/envoy/ssl/BUILD index b8e7d530174f..7005cb70006d 100644 --- a/include/envoy/ssl/BUILD +++ b/include/envoy/ssl/BUILD @@ -28,6 +28,7 @@ envoy_cc_library( hdrs = ["context_config.h"], deps = [ ":certificate_validation_context_config_interface", + ":handshaker_interface", ":tls_certificate_config_interface", ], ) @@ -68,3 +69,21 @@ envoy_cc_library( deps = [ ], ) + +envoy_cc_library( + name = "handshaker_interface", + hdrs = ["handshaker.h"], + external_deps = ["ssl"], + deps = [ + ":socket_state", + "//include/envoy/config:typed_config_interface", + "//include/envoy/network:transport_socket_interface", + "//include/envoy/protobuf:message_validator_interface", + ], +) + +envoy_cc_library( + name = "socket_state", + hdrs = ["socket_state.h"], + deps = [], +) diff --git a/include/envoy/ssl/context_config.h b/include/envoy/ssl/context_config.h index 9196a5a294a9..0db6c366487b 100644 --- a/include/envoy/ssl/context_config.h +++ b/include/envoy/ssl/context_config.h @@ -8,6 +8,7 @@ #include "envoy/common/pure.h" #include "envoy/ssl/certificate_validation_context_config.h" +#include "envoy/ssl/handshaker.h" #include "envoy/ssl/tls_certificate_config.h" #include "absl/types/optional.h" @@ -73,6 +74,17 @@ class ContextConfig { * @param callback callback that is executed by context config. */ virtual void setSecretUpdateCallback(std::function callback) PURE; + + /** + * @return the handshaker to use for TLS handshakes. + */ + virtual Ssl::HandshakerSharedPtr createHandshaker(bssl::UniquePtr ssl) const PURE; + + /** + * @return whether or not this context requires certificates for TLS + * handshakes. + */ + virtual bool requireCertificates() const PURE; }; class ClientContextConfig : public virtual ContextConfig { diff --git a/include/envoy/ssl/handshaker.h b/include/envoy/ssl/handshaker.h new file mode 100644 index 000000000000..583406851577 --- /dev/null +++ b/include/envoy/ssl/handshaker.h @@ -0,0 +1,101 @@ +#pragma once + +#include "envoy/api/api.h" +#include "envoy/common/pure.h" +#include "envoy/config/typed_config.h" +#include "envoy/network/transport_socket.h" +#include "envoy/protobuf/message_validator.h" +#include "envoy/ssl/socket_state.h" + +#include "openssl/ssl.h" + +namespace Envoy { +namespace Ssl { + +class HandshakerCallbacks { +public: + virtual ~HandshakerCallbacks() = default; + + /** + * Called when a handshake is successfully performed. + */ + virtual void onSuccessCb(SSL* ssl) PURE; + /** + * Called when a handshake fails. + */ + virtual void onFailureCb() PURE; +}; + +/* + * Interface for a Handshaker which is responsible for owning the + * `bssl::UniquePtr` and performing handshakes. + */ +class Handshaker { +public: + virtual ~Handshaker() = default; + + /** + * Do the handshake. + * + * NB: |state| is a mutable reference. + */ + virtual Network::PostIoAction doHandshake(SocketState& state) PURE; + + /** + * Set internal pointers to Network::TransportSocketCallbacks and + * Ssl::HandshakerCallbacks. + * Depending on impl, these callbacks can be invoked to access connection + * state, raise connection events, etc. + */ + virtual void setCallbacks(Network::TransportSocketCallbacks& callbacks, + Ssl::HandshakerCallbacks& handshaker_callbacks) PURE; + + /* + * Access the held SSL object as a ptr. Callsites should handle nullptr + * gracefully. + */ + virtual SSL* ssl() PURE; +}; + +using HandshakerSharedPtr = std::shared_ptr; + +class HandshakerFactoryContext { +public: + virtual ~HandshakerFactoryContext() = default; + + /** + * @return reference to the Api object + */ + virtual Api::Api& api() PURE; + + /** + * The list of supported protocols exposed via ALPN, from ContextConfig. + */ + virtual absl::string_view alpnProtocols() const PURE; +}; + +using HandshakerFactoryCb = std::function)>; + +class HandshakerFactory : public Config::TypedFactory { +public: + /** + * @returns a callback (of type HandshakerFactoryCb). Accepts the |config| and + * |validation_visitor| for early config validation. This virtual base doesn't + * perform MessageUtil::downcastAndValidate, but an implementation should. + */ + virtual HandshakerFactoryCb + createHandshakerCb(const Protobuf::Message& message, + HandshakerFactoryContext& handshaker_factory_context, + ProtobufMessage::ValidationVisitor& validation_visitor) PURE; + + std::string category() const override { return "envoy.tls_handshakers"; } + + /** + * Implementations should return true if the tls context accompanying this + * handshaker expects certificates. + */ + virtual bool requireCertificates() const PURE; +}; + +} // namespace Ssl +} // namespace Envoy diff --git a/include/envoy/ssl/socket_state.h b/include/envoy/ssl/socket_state.h new file mode 100644 index 000000000000..aa60fbc178ab --- /dev/null +++ b/include/envoy/ssl/socket_state.h @@ -0,0 +1,9 @@ +#pragma once + +namespace Envoy { +namespace Ssl { + +enum class SocketState { PreHandshake, HandshakeInProgress, HandshakeComplete, ShutdownSent }; + +} // namespace Ssl +} // namespace Envoy diff --git a/source/extensions/transport_sockets/tls/BUILD b/source/extensions/transport_sockets/tls/BUILD index b26ce0cc4d14..cad938faeee8 100644 --- a/source/extensions/transport_sockets/tls/BUILD +++ b/source/extensions/transport_sockets/tls/BUILD @@ -40,9 +40,12 @@ envoy_cc_library( deps = [ ":context_config_lib", ":context_lib", + ":handshaker_lib", ":utility_lib", "//include/envoy/network:connection_interface", "//include/envoy/network:transport_socket_interface", + "//include/envoy/ssl:handshaker_interface", + "//include/envoy/ssl:socket_state", "//include/envoy/ssl:ssl_socket_extended_info_interface", "//include/envoy/ssl/private_key:private_key_callbacks_interface", "//include/envoy/ssl/private_key:private_key_interface", @@ -63,6 +66,7 @@ envoy_cc_library( "ssl", ], deps = [ + ":handshaker_lib", "//include/envoy/secret:secret_callbacks_interface", "//include/envoy/secret:secret_provider_interface", "//include/envoy/server:transport_socket_config_interface", @@ -118,6 +122,35 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "handshaker_lib", + srcs = ["handshaker_impl.cc"], + hdrs = ["handshaker_impl.h"], + external_deps = [ + "ssl", + ], + deps = [ + "//include/envoy/ssl:context_config_interface", + "//include/envoy/ssl:context_interface", + "//include/envoy/ssl:context_manager_interface", + "//include/envoy/ssl:handshaker_interface", + "//include/envoy/ssl:socket_state", + "//include/envoy/ssl:ssl_socket_extended_info_interface", + "//include/envoy/ssl/private_key:private_key_interface", + "//include/envoy/stats:stats_interface", + "//include/envoy/stats:stats_macros", + "//source/common/common:assert_lib", + "//source/common/common:base64_lib", + "//source/common/common:hex_lib", + "//source/common/common:utility_lib", + "//source/common/network:address_lib", + "//source/common/protobuf:utility_lib", + "//source/common/stats:symbol_table_lib", + "//source/common/stats:utility_lib", + "//source/extensions/transport_sockets/tls/private_key:private_key_manager_lib", + ], +) + envoy_cc_library( name = "utility_lib", srcs = ["utility.cc"], diff --git a/source/extensions/transport_sockets/tls/context_config_impl.cc b/source/extensions/transport_sockets/tls/context_config_impl.cc index 6f20081eed80..6c2772684b60 100644 --- a/source/extensions/transport_sockets/tls/context_config_impl.cc +++ b/source/extensions/transport_sockets/tls/context_config_impl.cc @@ -213,6 +213,25 @@ ContextConfigImpl::ContextConfigImpl( } } } + + HandshakerFactoryContextImpl handshaker_factory_context(api_, alpnProtocols()); + Ssl::HandshakerFactory* handshaker_factory; + if (config.has_custom_listener_handshaker()) { + // If a custom handshaker is configured, derive the factory from the config. + const auto& handshaker_config = config.custom_listener_handshaker(); + handshaker_factory = + &Config::Utility::getAndCheckFactory(handshaker_config); + handshaker_factory_cb_ = handshaker_factory->createHandshakerCb( + handshaker_config.typed_config(), handshaker_factory_context, + factory_context.messageValidationVisitor()); + } else { + // Otherwise, derive the config from the (default) factory). + handshaker_factory = HandshakerFactoryImpl::getDefaultHandshakerFactory(); + handshaker_factory_cb_ = handshaker_factory->createHandshakerCb( + *handshaker_factory->createEmptyConfigProto(), handshaker_factory_context, + factory_context.messageValidationVisitor()); + } + require_certificates_ = handshaker_factory->requireCertificates(); } Ssl::CertificateValidationContextConfigPtr ContextConfigImpl::getCombinedValidationContextConfig( @@ -224,6 +243,10 @@ Ssl::CertificateValidationContextConfigPtr ContextConfigImpl::getCombinedValidat return std::make_unique(combined_cvc, api_); } +Ssl::HandshakerSharedPtr ContextConfigImpl::createHandshaker(bssl::UniquePtr ssl) const { + return handshaker_factory_cb_(std::move(ssl)); +} + void ContextConfigImpl::setSecretUpdateCallback(std::function callback) { if (!tls_certificate_providers_.empty()) { if (tc_update_callback_handle_) { @@ -404,7 +427,9 @@ ServerContextConfigImpl::ServerContextConfigImpl( if ((config.common_tls_context().tls_certificates().size() + config.common_tls_context().tls_certificate_sds_secret_configs().size()) == 0) { - throw EnvoyException("No TLS certificates found for server context"); + if (requireCertificates()) { + throw EnvoyException("No TLS certificates found for server context"); + } } else if (!config.common_tls_context().tls_certificates().empty() && !config.common_tls_context().tls_certificate_sds_secret_configs().empty()) { throw EnvoyException("SDS and non-SDS TLS certificates may not be mixed in server contexts"); diff --git a/source/extensions/transport_sockets/tls/context_config_impl.h b/source/extensions/transport_sockets/tls/context_config_impl.h index 9cfaff0482fb..11ce55a5005b 100644 --- a/source/extensions/transport_sockets/tls/context_config_impl.h +++ b/source/extensions/transport_sockets/tls/context_config_impl.h @@ -13,6 +13,8 @@ #include "common/json/json_loader.h" #include "common/ssl/tls_certificate_config_impl.h" +#include "extensions/transport_sockets/tls/handshaker_impl.h" + namespace Envoy { namespace Extensions { namespace TransportSockets { @@ -60,6 +62,10 @@ class ContextConfigImpl : public virtual Ssl::ContextConfig { const envoy::extensions::transport_sockets::tls::v3::CertificateValidationContext& dynamic_cvc); + Ssl::HandshakerSharedPtr createHandshaker(bssl::UniquePtr ssl) const override; + + bool requireCertificates() const override { return require_certificates_; } + protected: ContextConfigImpl(const envoy::extensions::transport_sockets::tls::v3::CommonTlsContext& config, const unsigned default_min_protocol_version, @@ -94,6 +100,9 @@ class ContextConfigImpl : public virtual Ssl::ContextConfig { Envoy::Common::CallbackHandle* cvc_validation_callback_handle_{}; const unsigned min_protocol_version_; const unsigned max_protocol_version_; + + Ssl::HandshakerFactoryCb handshaker_factory_cb_; + bool require_certificates_; }; class ClientContextConfigImpl : public ContextConfigImpl, public Envoy::Ssl::ClientContextConfig { diff --git a/source/extensions/transport_sockets/tls/context_impl.cc b/source/extensions/transport_sockets/tls/context_impl.cc index f42f9077fc42..08e1ab41d29e 100644 --- a/source/extensions/transport_sockets/tls/context_impl.cc +++ b/source/extensions/transport_sockets/tls/context_impl.cc @@ -73,7 +73,8 @@ ContextImpl::ContextImpl(Stats::Scope& scope, const Envoy::Ssl::ContextConfig& c ssl_ciphers_(stat_name_set_->add("ssl.ciphers")), ssl_versions_(stat_name_set_->add("ssl.versions")), ssl_curves_(stat_name_set_->add("ssl.curves")), - ssl_sigalgs_(stat_name_set_->add("ssl.sigalgs")) { + ssl_sigalgs_(stat_name_set_->add("ssl.sigalgs")), + require_certificates_(config.requireCertificates()) { const auto tls_certificates = config.tlsCertificates(); tls_contexts_.resize(std::max(static_cast(1), tls_certificates.size())); @@ -984,7 +985,7 @@ ServerContextImpl::ServerContextImpl(Stats::Scope& scope, const std::vector& server_names, TimeSource& time_source) : ContextImpl(scope, config, time_source), session_ticket_keys_(config.sessionTicketKeys()) { - if (config.tlsCertificates().empty()) { + if (config.tlsCertificates().empty() && config.requireCertificates()) { throw EnvoyException("Server TlsCertificates must have a certificate specified"); } @@ -1063,67 +1064,69 @@ ServerContextImpl::generateHashForSessionContextId(const std::vector= 0) { - X509_NAME_ENTRY* cn_entry = X509_NAME_get_entry(cert_subject, cn_index); - RELEASE_ASSERT(cn_entry != nullptr, "certificate subject CN should be present"); - - ASN1_STRING* cn_asn1 = X509_NAME_ENTRY_get_data(cn_entry); - if (ASN1_STRING_length(cn_asn1) <= 0) { - throw EnvoyException("Invalid TLS context has an empty subject CN"); - } + if (require_certificates_) { + for (const auto& ctx : tls_contexts_) { + X509* cert = SSL_CTX_get0_certificate(ctx.ssl_ctx_.get()); + RELEASE_ASSERT(cert != nullptr, "TLS context should have an active certificate"); + X509_NAME* cert_subject = X509_get_subject_name(cert); + RELEASE_ASSERT(cert_subject != nullptr, "TLS certificate should have a subject"); + + const int cn_index = X509_NAME_get_index_by_NID(cert_subject, NID_commonName, -1); + if (cn_index >= 0) { + X509_NAME_ENTRY* cn_entry = X509_NAME_get_entry(cert_subject, cn_index); + RELEASE_ASSERT(cn_entry != nullptr, "certificate subject CN should be present"); + + ASN1_STRING* cn_asn1 = X509_NAME_ENTRY_get_data(cn_entry); + if (ASN1_STRING_length(cn_asn1) <= 0) { + throw EnvoyException("Invalid TLS context has an empty subject CN"); + } - rc = EVP_DigestUpdate(md.get(), ASN1_STRING_data(cn_asn1), ASN1_STRING_length(cn_asn1)); - RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); - } + rc = EVP_DigestUpdate(md.get(), ASN1_STRING_data(cn_asn1), ASN1_STRING_length(cn_asn1)); + RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); + } - unsigned san_count = 0; - bssl::UniquePtr san_names(static_cast( - X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr))); - - if (san_names != nullptr) { - for (const GENERAL_NAME* san : san_names.get()) { - switch (san->type) { - case GEN_IPADD: - rc = EVP_DigestUpdate(md.get(), san->d.iPAddress->data, san->d.iPAddress->length); - RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); - ++san_count; - break; - case GEN_DNS: - rc = EVP_DigestUpdate(md.get(), ASN1_STRING_data(san->d.dNSName), - ASN1_STRING_length(san->d.dNSName)); - RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); - ++san_count; - break; - case GEN_URI: - rc = EVP_DigestUpdate(md.get(), ASN1_STRING_data(san->d.uniformResourceIdentifier), - ASN1_STRING_length(san->d.uniformResourceIdentifier)); - RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); - ++san_count; - break; + unsigned san_count = 0; + bssl::UniquePtr san_names(static_cast( + X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr))); + + if (san_names != nullptr) { + for (const GENERAL_NAME* san : san_names.get()) { + switch (san->type) { + case GEN_IPADD: + rc = EVP_DigestUpdate(md.get(), san->d.iPAddress->data, san->d.iPAddress->length); + RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); + ++san_count; + break; + case GEN_DNS: + rc = EVP_DigestUpdate(md.get(), ASN1_STRING_data(san->d.dNSName), + ASN1_STRING_length(san->d.dNSName)); + RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); + ++san_count; + break; + case GEN_URI: + rc = EVP_DigestUpdate(md.get(), ASN1_STRING_data(san->d.uniformResourceIdentifier), + ASN1_STRING_length(san->d.uniformResourceIdentifier)); + RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); + ++san_count; + break; + } } } - } - // It's possible that the certificate doesn't have a subject, but - // does have SANs. Make sure that we have one or the other. - if (cn_index < 0 && san_count == 0) { - throw EnvoyException("Invalid TLS context has neither subject CN nor SAN names"); - } + // It's possible that the certificate doesn't have a subject, but + // does have SANs. Make sure that we have one or the other. + if (cn_index < 0 && san_count == 0) { + throw EnvoyException("Invalid TLS context has neither subject CN nor SAN names"); + } - rc = X509_NAME_digest(X509_get_issuer_name(cert), EVP_sha256(), hash_buffer, &hash_length); - RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); - RELEASE_ASSERT(hash_length == SHA256_DIGEST_LENGTH, - fmt::format("invalid SHA256 hash length {}", hash_length)); + rc = X509_NAME_digest(X509_get_issuer_name(cert), EVP_sha256(), hash_buffer, &hash_length); + RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); + RELEASE_ASSERT(hash_length == SHA256_DIGEST_LENGTH, + fmt::format("invalid SHA256 hash length {}", hash_length)); - rc = EVP_DigestUpdate(md.get(), hash_buffer, hash_length); - RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); + rc = EVP_DigestUpdate(md.get(), hash_buffer, hash_length); + RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); + } } // Hash all the settings that affect whether the server will allow/accept diff --git a/source/extensions/transport_sockets/tls/context_impl.h b/source/extensions/transport_sockets/tls/context_impl.h index 407dd45f86f8..6b32900b2d3f 100644 --- a/source/extensions/transport_sockets/tls/context_impl.h +++ b/source/extensions/transport_sockets/tls/context_impl.h @@ -202,6 +202,7 @@ class ContextImpl : public virtual Envoy::Ssl::Context { const Stats::StatName ssl_versions_; const Stats::StatName ssl_curves_; const Stats::StatName ssl_sigalgs_; + const bool require_certificates_; }; using ContextImplSharedPtr = std::shared_ptr; diff --git a/source/extensions/transport_sockets/tls/handshaker_impl.cc b/source/extensions/transport_sockets/tls/handshaker_impl.cc new file mode 100644 index 000000000000..1ad370008fe9 --- /dev/null +++ b/source/extensions/transport_sockets/tls/handshaker_impl.cc @@ -0,0 +1,46 @@ +#include "extensions/transport_sockets/tls/handshaker_impl.h" + +#include "envoy/network/connection.h" +#include "envoy/ssl/handshaker.h" +#include "envoy/ssl/socket_state.h" + +#include "common/common/assert.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Tls { + +using Envoy::Ssl::SocketState; +using Network::PostIoAction; + +PostIoAction HandshakerImpl::doHandshake(SocketState& state) { + ASSERT(state != SocketState::HandshakeComplete && state != SocketState::ShutdownSent); + int rc = SSL_do_handshake(ssl_.get()); + if (rc == 1) { + state = SocketState::HandshakeComplete; + handshaker_callbacks_->onSuccessCb(ssl_.get()); + + // It's possible that we closed during the handshake callback. + return transport_socket_callbacks_->connection().state() == Network::Connection::State::Open + ? PostIoAction::KeepOpen + : PostIoAction::Close; + } else { + switch (SSL_get_error(ssl_.get(), rc)) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + return PostIoAction::KeepOpen; + case SSL_ERROR_WANT_PRIVATE_KEY_OPERATION: + state = SocketState::HandshakeInProgress; + return PostIoAction::KeepOpen; + default: + handshaker_callbacks_->onFailureCb(); + return PostIoAction::Close; + } + } +} + +} // namespace Tls +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/transport_sockets/tls/handshaker_impl.h b/source/extensions/transport_sockets/tls/handshaker_impl.h new file mode 100644 index 000000000000..3be89fefdbc8 --- /dev/null +++ b/source/extensions/transport_sockets/tls/handshaker_impl.h @@ -0,0 +1,80 @@ +#pragma once + +#include "envoy/ssl/handshaker.h" +#include "envoy/ssl/socket_state.h" + +#include "openssl/ssl.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Tls { + +// Class to abstract handshaking behavior. +// Manages translation between SSL error codes and Network::PostIoAction +// response enums, among other things. +class HandshakerImpl : public Envoy::Ssl::Handshaker { +public: + HandshakerImpl(bssl::UniquePtr ssl) : ssl_(std::move(ssl)) {} + + Network::PostIoAction doHandshake(Envoy::Ssl::SocketState& state) override; + + void setCallbacks(Network::TransportSocketCallbacks& transport_socket_callbacks, + Ssl::HandshakerCallbacks& handshaker_callbacks) override { + transport_socket_callbacks_ = &transport_socket_callbacks; + handshaker_callbacks_ = &handshaker_callbacks; + } + + SSL* ssl() override { return ssl_.get(); } + +private: + bssl::UniquePtr ssl_; + Network::TransportSocketCallbacks* transport_socket_callbacks_{}; + Ssl::HandshakerCallbacks* handshaker_callbacks_{}; +}; + +class HandshakerFactoryContextImpl : public Ssl::HandshakerFactoryContext { +public: + HandshakerFactoryContextImpl(Api::Api& api, absl::string_view alpn_protocols) + : api_(api), alpn_protocols_(alpn_protocols) {} + + // HandshakerFactoryContext + Api::Api& api() override { return api_; } + absl::string_view alpnProtocols() const override { return alpn_protocols_; } + +private: + Api::Api& api_; + const std::string alpn_protocols_; +}; + +class HandshakerFactoryImpl : public Ssl::HandshakerFactory { +public: + std::string name() const override { return "envoy.default_tls_handshaker"; } + + ProtobufTypes::MessagePtr createEmptyConfigProto() override { + return ProtobufTypes::MessagePtr{new Envoy::ProtobufWkt::Struct()}; + } + + Ssl::HandshakerFactoryCb createHandshakerCb(const Protobuf::Message&, + Ssl::HandshakerFactoryContext&, + ProtobufMessage::ValidationVisitor&) override { + // The default HandshakerImpl doesn't take a config or use the HandshakerFactoryContext. + return + [](bssl::UniquePtr ssl) { return std::make_shared(std::move(ssl)); }; + } + + bool requireCertificates() const override { + // The default HandshakerImpl does require certificates. + return true; + } + + static HandshakerFactory* getDefaultHandshakerFactory() { + static HandshakerFactoryImpl default_handshaker_factory; + return &default_handshaker_factory; + } +}; + +} // namespace Tls +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/transport_sockets/tls/ssl_socket.cc b/source/extensions/transport_sockets/tls/ssl_socket.cc index ab2644ccc808..6ce29f4a5df7 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.cc +++ b/source/extensions/transport_sockets/tls/ssl_socket.cc @@ -1,5 +1,6 @@ #include "extensions/transport_sockets/tls/ssl_socket.h" +#include "envoy/network/transport_socket.h" #include "envoy/stats/scope.h" #include "common/common/assert.h" @@ -14,6 +15,7 @@ #include "openssl/x509v3.h" using Envoy::Network::PostIoAction; +using Envoy::Ssl::SocketState; namespace Envoy { namespace Extensions { @@ -43,11 +45,13 @@ class NotReadySslSocket : public Network::TransportSocket { } // namespace SslSocket::SslSocket(Envoy::Ssl::ContextSharedPtr ctx, InitialState state, - const Network::TransportSocketOptionsSharedPtr& transport_socket_options) + const Network::TransportSocketOptionsSharedPtr& transport_socket_options, + Ssl::HandshakerFactoryCb handshaker_factory_cb) : transport_socket_options_(transport_socket_options), - ctx_(std::dynamic_pointer_cast(ctx)), state_(SocketState::PreHandshake) { - bssl::UniquePtr ssl = ctx_->newSsl(transport_socket_options_.get()); - info_ = std::make_shared(std::move(ssl), ctx_); + ctx_(std::dynamic_pointer_cast(ctx)), + handshaker_(handshaker_factory_cb(ctx_->newSsl(transport_socket_options_.get()))), + state_(SocketState::PreHandshake) { + info_ = std::make_shared(handshaker_, ctx_); if (state == InitialState::Client) { SSL_set_connect_state(rawSsl()); @@ -60,6 +64,7 @@ SslSocket::SslSocket(Envoy::Ssl::ContextSharedPtr ctx, InitialState state, void SslSocket::setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) { ASSERT(!callbacks_); callbacks_ = &callbacks; + handshaker_->setCallbacks(callbacks, *this); // Associate this SSL connection with all the certificates (with their potentially different // private key methods). @@ -168,38 +173,13 @@ void SslSocket::onPrivateKeyMethodComplete() { } } -PostIoAction SslSocket::doHandshake() { - ASSERT(state_ != SocketState::HandshakeComplete && state_ != SocketState::ShutdownSent); - int rc = SSL_do_handshake(rawSsl()); - if (rc == 1) { - ENVOY_CONN_LOG(debug, "handshake complete", callbacks_->connection()); - state_ = SocketState::HandshakeComplete; - ctx_->logHandshake(rawSsl()); - callbacks_->raiseEvent(Network::ConnectionEvent::Connected); - - // It's possible that we closed during the handshake callback. - return callbacks_->connection().state() == Network::Connection::State::Open - ? PostIoAction::KeepOpen - : PostIoAction::Close; - } else { - int err = SSL_get_error(rawSsl(), rc); - switch (err) { - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_WRITE: - ENVOY_CONN_LOG(debug, "handshake expecting {}", callbacks_->connection(), - err == SSL_ERROR_WANT_READ ? "read" : "write"); - return PostIoAction::KeepOpen; - case SSL_ERROR_WANT_PRIVATE_KEY_OPERATION: - ENVOY_CONN_LOG(debug, "handshake continued asynchronously", callbacks_->connection()); - state_ = SocketState::HandshakeInProgress; - return PostIoAction::KeepOpen; - default: - ENVOY_CONN_LOG(debug, "handshake error: {}", callbacks_->connection(), err); - drainErrorQueue(); - return PostIoAction::Close; - } - } +void SslSocket::onSuccessCb(SSL* ssl) { + ctx_->logHandshake(ssl); + callbacks_->raiseEvent(Network::ConnectionEvent::Connected); } +void SslSocket::onFailureCb() { drainErrorQueue(); } + +PostIoAction SslSocket::doHandshake() { return handshaker_->doHandshake(state_); } void SslSocket::drainErrorQueue() { bool saw_error = false; @@ -309,9 +289,9 @@ Envoy::Ssl::ClientValidationStatus SslExtendedSocketInfoImpl::certificateValidat return certificate_validation_status_; } -SslSocketInfo::SslSocketInfo(bssl::UniquePtr ssl, ContextImplSharedPtr ctx) - : ssl_(std::move(ssl)) { - SSL_set_ex_data(ssl_.get(), ctx->sslExtendedSocketInfoIndex(), &(this->extended_socket_info_)); +SslSocketInfo::SslSocketInfo(Ssl::HandshakerSharedPtr handshaker, ContextImplSharedPtr ctx) + : handshaker_(handshaker) { + SSL_set_ex_data(ssl(), ctx->sslExtendedSocketInfoIndex(), &(this->extended_socket_info_)); } bool SslSocketInfo::peerCertificatePresented() const { @@ -640,8 +620,9 @@ Network::TransportSocketPtr ClientSslSocketFactory::createTransportSocket( ssl_ctx = ssl_ctx_; } if (ssl_ctx) { - return std::make_unique(std::move(ssl_ctx), InitialState::Client, - transport_socket_options); + return std::make_unique( + std::move(ssl_ctx), InitialState::Client, transport_socket_options, + [this](bssl::UniquePtr ssl) { return config_->createHandshaker(std::move(ssl)); }); } else { ENVOY_LOG(debug, "Create NotReadySslSocket"); stats_.upstream_context_secrets_not_ready_.inc(); @@ -681,7 +662,9 @@ ServerSslSocketFactory::createTransportSocket(Network::TransportSocketOptionsSha ssl_ctx = ssl_ctx_; } if (ssl_ctx) { - return std::make_unique(std::move(ssl_ctx), InitialState::Server, nullptr); + return std::make_unique( + std::move(ssl_ctx), InitialState::Server, nullptr, + [this](bssl::UniquePtr ssl) { return config_->createHandshaker(std::move(ssl)); }); } else { ENVOY_LOG(debug, "Create NotReadySslSocket"); stats_.downstream_context_secrets_not_ready_.inc(); diff --git a/source/extensions/transport_sockets/tls/ssl_socket.h b/source/extensions/transport_sockets/tls/ssl_socket.h index 27416ce7f635..478f337cf17d 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.h +++ b/source/extensions/transport_sockets/tls/ssl_socket.h @@ -6,7 +6,9 @@ #include "envoy/network/connection.h" #include "envoy/network/transport_socket.h" #include "envoy/secret/secret_callbacks.h" +#include "envoy/ssl/handshaker.h" #include "envoy/ssl/private_key/private_key_callbacks.h" +#include "envoy/ssl/socket_state.h" #include "envoy/ssl/ssl_socket_extended_info.h" #include "envoy/stats/scope.h" #include "envoy/stats/stats_macros.h" @@ -14,6 +16,7 @@ #include "common/common/logger.h" #include "extensions/transport_sockets/tls/context_impl.h" +#include "extensions/transport_sockets/tls/handshaker_impl.h" #include "extensions/transport_sockets/tls/utility.h" #include "absl/container/node_hash_map.h" @@ -39,7 +42,6 @@ struct SslSocketFactoryStats { }; enum class InitialState { Client, Server }; -enum class SocketState { PreHandshake, HandshakeInProgress, HandshakeComplete, ShutdownSent }; class SslExtendedSocketInfoImpl : public Envoy::Ssl::SslExtendedSocketInfo { public: @@ -53,7 +55,7 @@ class SslExtendedSocketInfoImpl : public Envoy::Ssl::SslExtendedSocketInfo { class SslSocketInfo : public Envoy::Ssl::ConnectionInfo { public: - SslSocketInfo(bssl::UniquePtr ssl, ContextImplSharedPtr ctx); + SslSocketInfo(Ssl::HandshakerSharedPtr handshaker, ContextImplSharedPtr ctx); // Ssl::ConnectionInfo bool peerCertificatePresented() const override; @@ -77,9 +79,11 @@ class SslSocketInfo : public Envoy::Ssl::ConnectionInfo { std::string ciphersuiteString() const override; const std::string& tlsVersion() const override; absl::optional x509Extension(absl::string_view extension_name) const override; - SSL* ssl() const { return ssl_.get(); } + SSL* ssl() const { return handshaker_->ssl(); } - bssl::UniquePtr ssl_; + // Owns a shared ptr to the Handshaker for access to the SSL*, even after + // SslSocket is destroyed. + Ssl::HandshakerSharedPtr handshaker_; private: mutable std::vector cached_uri_san_local_certificate_; @@ -99,20 +103,22 @@ class SslSocketInfo : public Envoy::Ssl::ConnectionInfo { mutable SslExtendedSocketInfoImpl extended_socket_info_; }; -using SslSocketInfoConstSharedPtr = std::shared_ptr; +using SslSocketInfoSharedPtr = std::shared_ptr; class SslSocket : public Network::TransportSocket, public Envoy::Ssl::PrivateKeyConnectionCallbacks, + public Envoy::Ssl::HandshakerCallbacks, protected Logger::Loggable { public: SslSocket(Envoy::Ssl::ContextSharedPtr ctx, InitialState state, - const Network::TransportSocketOptionsSharedPtr& transport_socket_options); + const Network::TransportSocketOptionsSharedPtr& transport_socket_options, + Ssl::HandshakerFactoryCb handshaker_factory_cb); // Network::TransportSocket void setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) override; std::string protocol() const override; absl::string_view failureReason() const override; - bool canFlushClose() override { return state_ == SocketState::HandshakeComplete; } + bool canFlushClose() override { return state_ == Envoy::Ssl::SocketState::HandshakeComplete; } void closeSocket(Network::ConnectionEvent close_type) override; Network::IoResult doRead(Buffer::Instance& read_buffer) override; Network::IoResult doWrite(Buffer::Instance& write_buffer, bool end_stream) override; @@ -120,11 +126,14 @@ class SslSocket : public Network::TransportSocket, Ssl::ConnectionInfoConstSharedPtr ssl() const override; // Ssl::PrivateKeyConnectionCallbacks void onPrivateKeyMethodComplete() override; + // Ssl::HandshakerCallbacks + void onSuccessCb(SSL* ssl) override; + void onFailureCb() override; SSL* rawSslForTest() const { return rawSsl(); } protected: - SSL* rawSsl() const { return info_->ssl_.get(); } + SSL* rawSsl() const { return handshaker_->ssl(); } private: struct ReadResult { @@ -143,11 +152,12 @@ class SslSocket : public Network::TransportSocket, const Network::TransportSocketOptionsSharedPtr transport_socket_options_; Network::TransportSocketCallbacks* callbacks_{}; ContextImplSharedPtr ctx_; + Ssl::HandshakerSharedPtr handshaker_; uint64_t bytes_to_retry_{}; std::string failure_reason_; - SocketState state_; + Envoy::Ssl::SocketState state_; - SslSocketInfoConstSharedPtr info_; + SslSocketInfoSharedPtr info_; }; class ClientSslSocketFactory : public Network::TransportSocketFactory, diff --git a/test/extensions/transport_sockets/tls/BUILD b/test/extensions/transport_sockets/tls/BUILD index 4a7a2cd9481e..3f302585f586 100644 --- a/test/extensions/transport_sockets/tls/BUILD +++ b/test/extensions/transport_sockets/tls/BUILD @@ -96,6 +96,28 @@ envoy_cc_test( ], ) +envoy_cc_test( + name = "handshaker_test", + srcs = ["handshaker_test.cc"], + data = [ + "gen_unittest_certs.sh", + "//test/config/integration/certs", + "//test/extensions/transport_sockets/tls/test_data:certs", + ], + external_deps = ["ssl"], + deps = [ + ":ssl_socket_test", + ":ssl_test_utils", + "//source/extensions/transport_sockets/tls:handshaker_lib", + "//test/mocks/buffer:buffer_mocks", + "//test/mocks/network:network_mocks", + "//test/mocks/runtime:runtime_mocks", + "//test/mocks/server:server_mocks", + "//test/mocks/ssl:ssl_mocks", + "//test/mocks/stats:stats_mocks", + ], +) + envoy_cc_test( name = "utility_test", srcs = [ diff --git a/test/extensions/transport_sockets/tls/handshaker_test.cc b/test/extensions/transport_sockets/tls/handshaker_test.cc new file mode 100644 index 000000000000..d01b7ae43857 --- /dev/null +++ b/test/extensions/transport_sockets/tls/handshaker_test.cc @@ -0,0 +1,256 @@ +#include + +#include "envoy/network/transport_socket.h" +#include "envoy/ssl/handshaker.h" +#include "envoy/ssl/socket_state.h" + +#include "extensions/transport_sockets/tls/handshaker_impl.h" + +#include "test/extensions/transport_sockets/tls/ssl_certs_test.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "openssl/evp.h" +#include "openssl/hmac.h" +#include "openssl/ssl.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Tls { +namespace { + +using ::testing::AtLeast; +using ::testing::Invoke; +using ::testing::StrictMock; + +// A callback shaped like pem_password_cb. +// See https://www.openssl.org/docs/man1.1.0/man3/pem_password_cb.html. +int PemPasswordCallback(char* buf, int buf_size, int, void* u) { + if (u == nullptr) { + return 0; + } + std::string passphrase = *reinterpret_cast(u); + ASSERT(buf_size >= static_cast(passphrase.size())); + memcpy(buf, passphrase.data(), passphrase.size()); + return passphrase.size(); +} + +class MockHandshakerCallbacks : public Ssl::HandshakerCallbacks { +public: + ~MockHandshakerCallbacks() override{}; + MOCK_METHOD(void, onSuccessCb, (SSL*), (override)); + MOCK_METHOD(void, onFailureCb, (), (override)); +}; + +class HandshakerTest : public SslCertsTest { +protected: + HandshakerTest() + : dispatcher_(api_->allocateDispatcher("test_thread")), stream_info_(api_->timeSource()), + client_ctx_(SSL_CTX_new(TLS_method())), server_ctx_(SSL_CTX_new(TLS_method())) { + // Set up key and cert, initialize two SSL objects and a pair of BIOs for + // handshaking. + auto key = MakeKey(); + auto cert = MakeCert(); + auto chain = std::vector{cert.get()}; + + server_ssl_ = bssl::UniquePtr(SSL_new(server_ctx_.get())); + SSL_set_accept_state(server_ssl_.get()); + ASSERT( + SSL_set_chain_and_key(server_ssl_.get(), chain.data(), chain.size(), key.get(), nullptr)); + + client_ssl_ = bssl::UniquePtr(SSL_new(client_ctx_.get())); + SSL_set_connect_state(client_ssl_.get()); + + ASSERT(BIO_new_bio_pair(&client_bio_, kBufferLength, &server_bio_, kBufferLength)); + + BIO_up_ref(client_bio_); + BIO_up_ref(server_bio_); + SSL_set0_rbio(client_ssl_.get(), client_bio_); + SSL_set0_wbio(client_ssl_.get(), client_bio_); + SSL_set0_rbio(server_ssl_.get(), server_bio_); + SSL_set0_wbio(server_ssl_.get(), server_bio_); + } + + // Read in key.pem and return a new private key. + bssl::UniquePtr MakeKey() { + std::string file = TestEnvironment::readFileToStringForTest( + TestEnvironment::substitute("{{ test_tmpdir }}/unittestkey.pem")); + std::string passphrase = ""; + bssl::UniquePtr bio(BIO_new_mem_buf(file.data(), file.size())); + + bssl::UniquePtr key(EVP_PKEY_new()); + + RSA* rsa = PEM_read_bio_RSAPrivateKey(bio.get(), nullptr, &PemPasswordCallback, &passphrase); + ASSERT(rsa && EVP_PKEY_assign_RSA(key.get(), rsa)); + return key; + } + + // Read in cert.pem and return a certificate. + bssl::UniquePtr MakeCert() { + std::string file = TestEnvironment::readFileToStringForTest( + TestEnvironment::substitute("{{ test_tmpdir }}/unittestcert.pem")); + bssl::UniquePtr bio(BIO_new_mem_buf(file.data(), file.size())); + + uint8_t* data = nullptr; + long len; // NOLINT (runtime/int) + ASSERT(PEM_bytes_read_bio(&data, &len, nullptr, PEM_STRING_X509, bio.get(), nullptr, nullptr)); + bssl::UniquePtr tmp(data); // Prevents memory leak. + return bssl::UniquePtr(CRYPTO_BUFFER_new(data, len, nullptr)); + } + + const size_t kBufferLength{100}; + + Event::DispatcherPtr dispatcher_; + StreamInfo::StreamInfoImpl stream_info_; + + BIO *client_bio_, *server_bio_; + bssl::UniquePtr client_ctx_, server_ctx_; + bssl::UniquePtr client_ssl_, server_ssl_; +}; + +TEST_F(HandshakerTest, NormalOperation) { + Network::MockTransportSocketCallbacks transport_socket_callbacks; + transport_socket_callbacks.connection_.state_ = Network::Connection::State::Closed; + EXPECT_CALL(transport_socket_callbacks, connection).Times(1); + + StrictMock handshaker_callbacks; + EXPECT_CALL(handshaker_callbacks, onSuccessCb).Times(1); + + HandshakerImpl handshaker(std::move(server_ssl_)); + + handshaker.setCallbacks(transport_socket_callbacks, handshaker_callbacks); + + auto socket_state = Ssl::SocketState::PreHandshake; + auto post_io_action = Network::PostIoAction::KeepOpen; // default enum + + // Run the handshakes from the client and server until HandshakerImpl decides + // we're done and returns PostIoAction::Close. + while (post_io_action != Network::PostIoAction::Close) { + SSL_do_handshake(client_ssl_.get()); + post_io_action = handshaker.doHandshake(socket_state); + } + + EXPECT_EQ(post_io_action, Network::PostIoAction::Close); + // HandshakerImpl should have set |socket_state| accordingly. + EXPECT_EQ(socket_state, Ssl::SocketState::HandshakeComplete); +} + +// We induce some kind of BIO mismatch and force the SSL_do_handshake to +// return an error code without error handling, i.e. not SSL_ERROR_WANT_READ +// or _WRITE or _PRIVATE_KEY_OPERATION. +TEST_F(HandshakerTest, ErrorCbOnAbnormalOperation) { + // We make a new BIO, set it as the `rbio`/`wbio` for the client SSL object, + // and break the BIO pair connecting the two SSL objects. Now handshaking will + // fail, likely with SSL_ERROR_SSL. + BIO* bio = BIO_new(BIO_s_socket()); + SSL_set_bio(client_ssl_.get(), bio, bio); + + HandshakerImpl handshaker(std::move(server_ssl_)); + + StrictMock transport_socket_callbacks; + + StrictMock handshaker_callbacks; + EXPECT_CALL(handshaker_callbacks, onFailureCb).Times(1); + + handshaker.setCallbacks(transport_socket_callbacks, handshaker_callbacks); + + auto socket_state = Ssl::SocketState::PreHandshake; + auto post_io_action = Network::PostIoAction::KeepOpen; // default enum + + while (post_io_action != Network::PostIoAction::Close) { + SSL_do_handshake(client_ssl_.get()); + post_io_action = handshaker.doHandshake(socket_state); + } + + // In the error case, HandshakerImpl also closes the connection. + EXPECT_EQ(post_io_action, Network::PostIoAction::Close); +} + +// Example HandshakerImpl demonstrating special-case behavior which necessitates +// extra SSL_ERROR case handling. Here, we induce an SSL_ERROR_WANT_X509_LOOKUP, +// check for it in the handshaker, faux-trigger the lookup, and then proceed as +// normal. +class HandshakerImplForTest : public Ssl::Handshaker { +public: + HandshakerImplForTest(bssl::UniquePtr ssl, std::function requested_cert_cb) + : ssl_(std::move(ssl)), requested_cert_cb_(requested_cert_cb) { + SSL_set_cert_cb( + ssl_.get(), [](SSL*, void* arg) -> int { return *static_cast(arg) ? 1 : -1; }, + &cert_cb_ok_); + } + + Network::PostIoAction doHandshake(Ssl::SocketState& state) override { + ASSERT(state != Ssl::SocketState::HandshakeComplete && state != Ssl::SocketState::ShutdownSent); + + int rc = SSL_do_handshake(ssl()); + if (rc == 1) { + state = Ssl::SocketState::HandshakeComplete; + handshaker_callbacks_->onSuccessCb(ssl()); + return Network::PostIoAction::Close; + } else { + switch (SSL_get_error(ssl(), rc)) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + return Network::PostIoAction::KeepOpen; + case SSL_ERROR_WANT_X509_LOOKUP: + // Special case. Once this lookup is requested, we flip the bit and allow + // the handshake to proceed. + requested_cert_cb_(); + return Network::PostIoAction::KeepOpen; + default: + handshaker_callbacks_->onFailureCb(); + return Network::PostIoAction::Close; + } + } + } + + void setCallbacks(Network::TransportSocketCallbacks& transport_socket_callbacks, + Ssl::HandshakerCallbacks& handshaker_callbacks) override { + transport_socket_callbacks_ = &transport_socket_callbacks; + handshaker_callbacks_ = &handshaker_callbacks; + } + + SSL* ssl() override { return ssl_.get(); } + + void setCertCbOk() { cert_cb_ok_ = true; } + +private: + bssl::UniquePtr ssl_; + std::function requested_cert_cb_; + bool cert_cb_ok_{false}; + Network::TransportSocketCallbacks* transport_socket_callbacks_{}; + Ssl::HandshakerCallbacks* handshaker_callbacks_{}; +}; + +TEST_F(HandshakerTest, NormalOperationWithHandshakerImplForTest) { + ::testing::MockFunction requested_cert_cb; + + HandshakerImplForTest handshaker(std::move(server_ssl_), requested_cert_cb.AsStdFunction()); + + EXPECT_CALL(requested_cert_cb, Call).WillOnce([&]() { handshaker.setCertCbOk(); }); + + Network::MockTransportSocketCallbacks transport_socket_callbacks; + + StrictMock handshaker_callbacks; + EXPECT_CALL(handshaker_callbacks, onSuccessCb).Times(1); + + handshaker.setCallbacks(transport_socket_callbacks, handshaker_callbacks); + + auto socket_state = Ssl::SocketState::PreHandshake; + auto post_io_action = Network::PostIoAction::KeepOpen; // default enum + + while (post_io_action != Network::PostIoAction::Close) { + SSL_do_handshake(client_ssl_.get()); + post_io_action = handshaker.doHandshake(socket_state); + } + + EXPECT_EQ(post_io_action, Network::PostIoAction::Close); + EXPECT_EQ(socket_state, Ssl::SocketState::HandshakeComplete); +} + +} // namespace +} // namespace Tls +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/test/mocks/ssl/mocks.h b/test/mocks/ssl/mocks.h index c3bc9b2f8ecd..f9a5e16b2863 100644 --- a/test/mocks/ssl/mocks.h +++ b/test/mocks/ssl/mocks.h @@ -87,6 +87,8 @@ class MockClientContextConfig : public ClientContextConfig { MOCK_METHOD(bool, isReady, (), (const)); MOCK_METHOD(void, setSecretUpdateCallback, (std::function callback)); + MOCK_METHOD(Ssl::HandshakerSharedPtr, createHandshaker, (bssl::UniquePtr), (const)); + MOCK_METHOD(const std::string&, serverNameIndication, (), (const)); MOCK_METHOD(bool, allowRenegotiation, (), (const)); MOCK_METHOD(size_t, maxSessionKeys, (), (const)); @@ -109,6 +111,7 @@ class MockServerContextConfig : public ServerContextConfig { MOCK_METHOD(bool, isReady, (), (const)); MOCK_METHOD(absl::optional, sessionTimeout, (), (const)); MOCK_METHOD(void, setSecretUpdateCallback, (std::function callback)); + MOCK_METHOD(Ssl::HandshakerSharedPtr, createHandshaker, (bssl::UniquePtr), (const)); MOCK_METHOD(bool, requireClientCertificate, (), (const)); MOCK_METHOD(const std::vector&, sessionTicketKeys, (), (const));