diff --git a/api/envoy/extensions/transport_sockets/tls/v3/tls.proto b/api/envoy/extensions/transport_sockets/tls/v3/tls.proto index 7ee7920c724d..f746f3d2f1cf 100644 --- a/api/envoy/extensions/transport_sockets/tls/v3/tls.proto +++ b/api/envoy/extensions/transport_sockets/tls/v3/tls.proto @@ -99,7 +99,7 @@ message DownstreamTlsContext { } // TLS context shared by both client and server TLS contexts. -// [#next-free-field: 13] +// [#next-free-field: 14] message CommonTlsContext { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.CommonTlsContext"; @@ -238,4 +238,8 @@ message CommonTlsContext { // // There is no default for this parameter. If empty, Envoy will not expose ALPN. repeated string alpn_protocols = 4; + + // Custom TLS handshaker. If empty, defaults to native TLS handshaking + // behavior. + config.core.v3.TypedExtensionConfig custom_handshaker = 13; } diff --git a/api/envoy/extensions/transport_sockets/tls/v4alpha/tls.proto b/api/envoy/extensions/transport_sockets/tls/v4alpha/tls.proto index a73ba6e002ba..44963f687073 100644 --- a/api/envoy/extensions/transport_sockets/tls/v4alpha/tls.proto +++ b/api/envoy/extensions/transport_sockets/tls/v4alpha/tls.proto @@ -98,7 +98,7 @@ message DownstreamTlsContext { } // TLS context shared by both client and server TLS contexts. -// [#next-free-field: 13] +// [#next-free-field: 14] message CommonTlsContext { option (udpa.annotations.versioning).previous_message_type = "envoy.extensions.transport_sockets.tls.v3.CommonTlsContext"; @@ -243,4 +243,8 @@ message CommonTlsContext { // // There is no default for this parameter. If empty, Envoy will not expose ALPN. repeated string alpn_protocols = 4; + + // Custom TLS handshaker. If empty, defaults to native TLS handshaking + // behavior. + config.core.v4alpha.TypedExtensionConfig custom_handshaker = 13; } diff --git a/generated_api_shadow/envoy/extensions/transport_sockets/tls/v3/tls.proto b/generated_api_shadow/envoy/extensions/transport_sockets/tls/v3/tls.proto index 7ee7920c724d..f746f3d2f1cf 100644 --- a/generated_api_shadow/envoy/extensions/transport_sockets/tls/v3/tls.proto +++ b/generated_api_shadow/envoy/extensions/transport_sockets/tls/v3/tls.proto @@ -99,7 +99,7 @@ message DownstreamTlsContext { } // TLS context shared by both client and server TLS contexts. -// [#next-free-field: 13] +// [#next-free-field: 14] message CommonTlsContext { option (udpa.annotations.versioning).previous_message_type = "envoy.api.v2.auth.CommonTlsContext"; @@ -238,4 +238,8 @@ message CommonTlsContext { // // There is no default for this parameter. If empty, Envoy will not expose ALPN. repeated string alpn_protocols = 4; + + // Custom TLS handshaker. If empty, defaults to native TLS handshaking + // behavior. + config.core.v3.TypedExtensionConfig custom_handshaker = 13; } diff --git a/generated_api_shadow/envoy/extensions/transport_sockets/tls/v4alpha/tls.proto b/generated_api_shadow/envoy/extensions/transport_sockets/tls/v4alpha/tls.proto index a73ba6e002ba..44963f687073 100644 --- a/generated_api_shadow/envoy/extensions/transport_sockets/tls/v4alpha/tls.proto +++ b/generated_api_shadow/envoy/extensions/transport_sockets/tls/v4alpha/tls.proto @@ -98,7 +98,7 @@ message DownstreamTlsContext { } // TLS context shared by both client and server TLS contexts. -// [#next-free-field: 13] +// [#next-free-field: 14] message CommonTlsContext { option (udpa.annotations.versioning).previous_message_type = "envoy.extensions.transport_sockets.tls.v3.CommonTlsContext"; @@ -243,4 +243,8 @@ message CommonTlsContext { // // There is no default for this parameter. If empty, Envoy will not expose ALPN. repeated string alpn_protocols = 4; + + // Custom TLS handshaker. If empty, defaults to native TLS handshaking + // behavior. + config.core.v4alpha.TypedExtensionConfig custom_handshaker = 13; } diff --git a/include/envoy/ssl/BUILD b/include/envoy/ssl/BUILD index b295a20e2a1a..08266cf1376f 100644 --- a/include/envoy/ssl/BUILD +++ b/include/envoy/ssl/BUILD @@ -29,6 +29,7 @@ envoy_cc_library( hdrs = ["context_config.h"], deps = [ ":certificate_validation_context_config_interface", + ":handshaker_interface", ":tls_certificate_config_interface", ], ) @@ -81,7 +82,10 @@ envoy_cc_library( hdrs = ["handshaker.h"], external_deps = ["ssl"], deps = [ + "//include/envoy/api:api_interface", + "//include/envoy/config:typed_config_interface", "//include/envoy/network:connection_interface", "//include/envoy/network:post_io_action_interface", + "//include/envoy/protobuf:message_validator_interface", ], ) diff --git a/include/envoy/ssl/context_config.h b/include/envoy/ssl/context_config.h index 9196a5a294a9..487e0b619f92 100644 --- a/include/envoy/ssl/context_config.h +++ b/include/envoy/ssl/context_config.h @@ -8,6 +8,7 @@ #include "envoy/common/pure.h" #include "envoy/ssl/certificate_validation_context_config.h" +#include "envoy/ssl/handshaker.h" #include "envoy/ssl/tls_certificate_config.h" #include "absl/types/optional.h" @@ -73,6 +74,16 @@ class ContextConfig { * @param callback callback that is executed by context config. */ virtual void setSecretUpdateCallback(std::function callback) PURE; + + /** + * @return a callback which can be used to create Handshaker instances. + */ + virtual HandshakerFactoryCb createHandshaker() const PURE; + + /** + * @return the set of capabilities for handshaker instances created by this context. + */ + virtual HandshakerCapabilities capabilities() const PURE; }; class ClientContextConfig : public virtual ContextConfig { diff --git a/include/envoy/ssl/handshaker.h b/include/envoy/ssl/handshaker.h index de11fc85f41f..42d20601071b 100644 --- a/include/envoy/ssl/handshaker.h +++ b/include/envoy/ssl/handshaker.h @@ -1,7 +1,10 @@ #pragma once +#include "envoy/api/api.h" +#include "envoy/config/typed_config.h" #include "envoy/network/connection.h" #include "envoy/network/post_io_action.h" +#include "envoy/protobuf/message_validator.h" #include "openssl/ssl.h" @@ -13,9 +16,9 @@ class HandshakeCallbacks { virtual ~HandshakeCallbacks() = default; /** - * @return the connection state. + * @return the connection. */ - virtual Network::Connection::State connectionState() const PURE; + virtual Network::Connection& connection() const PURE; /** * A callback which will be executed at most once upon successful completion @@ -43,5 +46,70 @@ class Handshaker { virtual Network::PostIoAction doHandshake() PURE; }; +using HandshakerSharedPtr = std::shared_ptr; +using HandshakerFactoryCb = + std::function, int, HandshakeCallbacks*)>; + +class HandshakerFactoryContext { +public: + virtual ~HandshakerFactoryContext() = default; + + /** + * @return reference to the Api object + */ + virtual Api::Api& api() PURE; + + /** + * The list of supported protocols exposed via ALPN, from ContextConfig. + */ + virtual absl::string_view alpnProtocols() const PURE; +}; + +struct HandshakerCapabilities { + // Whether or not a handshaker implementation provides certificates itself. + bool provides_certificates = false; + + // Whether or not a handshaker implementation verifies certificates itself. + bool verifies_peer_certificates = false; + + // Whether or not a handshaker implementation handles session resumption + // itself. + bool handles_session_resumption = false; + + // Whether or not a handshaker implementation provides its own list of ciphers + // and curves. + bool provides_ciphers_and_curves = false; + + // Whether or not a handshaker implementation handles ALPN selection. + bool handles_alpn_selection = false; + + // Should return true if this handshaker is FIPS-compliant. + // Envoy will fail to compile if this returns true and `--define=boringssl=fips`. + bool is_fips_compliant = true; +}; + +class HandshakerFactory : public Config::TypedFactory { +public: + /** + * @returns a callback to create a Handshaker. Accepts the |config| and + * |validation_visitor| for early validation. This virtual base doesn't + * perform MessageUtil::downcastAndValidate, but an implementation should. + */ + virtual HandshakerFactoryCb + createHandshakerCb(const Protobuf::Message& message, + HandshakerFactoryContext& handshaker_factory_context, + ProtobufMessage::ValidationVisitor& validation_visitor) PURE; + + std::string category() const override { return "envoy.tls_handshakers"; } + + /** + * Implementations should return a struct with their capabilities. See + * HandshakerCapabilities above. For any capability a Handshaker + * implementation explicitly declares, Envoy will not also configure that SSL + * capability. + */ + virtual HandshakerCapabilities capabilities() const PURE; +}; + } // namespace Ssl } // namespace Envoy diff --git a/source/extensions/transport_sockets/tls/BUILD b/source/extensions/transport_sockets/tls/BUILD index 8ce5828b43ed..aabca8fd0581 100644 --- a/source/extensions/transport_sockets/tls/BUILD +++ b/source/extensions/transport_sockets/tls/BUILD @@ -28,6 +28,30 @@ envoy_cc_extension( ], ) +envoy_cc_library( + name = "ssl_handshaker_lib", + srcs = ["ssl_handshaker.cc"], + hdrs = ["ssl_handshaker.h"], + external_deps = ["ssl"], + deps = [ + ":context_lib", + ":utility_lib", + "//include/envoy/network:connection_interface", + "//include/envoy/network:transport_socket_interface", + "//include/envoy/ssl:handshaker_interface", + "//include/envoy/ssl:ssl_socket_extended_info_interface", + "//include/envoy/ssl:ssl_socket_state", + "//include/envoy/ssl/private_key:private_key_callbacks_interface", + "//include/envoy/ssl/private_key:private_key_interface", + "//include/envoy/stats:stats_macros", + "//source/common/common:assert_lib", + "//source/common/common:empty_string", + "//source/common/common:minimal_logger_lib", + "//source/common/common:thread_annotations", + "//source/common/http:headers_lib", + ], +) + envoy_cc_library( name = "io_handle_bio_lib", srcs = ["io_handle_bio.cc"], @@ -56,6 +80,7 @@ envoy_cc_library( ":context_config_lib", ":context_lib", ":io_handle_bio_lib", + ":ssl_handshaker_lib", ":utility_lib", "//include/envoy/network:connection_interface", "//include/envoy/network:transport_socket_interface", @@ -83,6 +108,7 @@ envoy_cc_library( # TLS is core functionality. visibility = ["//visibility:public"], deps = [ + ":ssl_handshaker_lib", "//include/envoy/secret:secret_callbacks_interface", "//include/envoy/secret:secret_provider_interface", "//include/envoy/server:transport_socket_config_interface", diff --git a/source/extensions/transport_sockets/tls/context_config_impl.cc b/source/extensions/transport_sockets/tls/context_config_impl.cc index 8c873564020b..56fbece90c73 100644 --- a/source/extensions/transport_sockets/tls/context_config_impl.cc +++ b/source/extensions/transport_sockets/tls/context_config_impl.cc @@ -12,6 +12,8 @@ #include "common/secret/sds_api.h" #include "common/ssl/certificate_validation_context_config_impl.h" +#include "extensions/transport_sockets/tls/ssl_handshaker.h" + #include "openssl/ssl.h" namespace Envoy { @@ -213,6 +215,25 @@ ContextConfigImpl::ContextConfigImpl( } } } + + HandshakerFactoryContextImpl handshaker_factory_context(api_, alpn_protocols_); + Ssl::HandshakerFactory* handshaker_factory; + if (config.has_custom_handshaker()) { + // If a custom handshaker is configured, derive the factory from the config. + const auto& handshaker_config = config.custom_handshaker(); + handshaker_factory = + &Config::Utility::getAndCheckFactory(handshaker_config); + handshaker_factory_cb_ = handshaker_factory->createHandshakerCb( + handshaker_config.typed_config(), handshaker_factory_context, + factory_context.messageValidationVisitor()); + } else { + // Otherwise, derive the config from the default factory. + handshaker_factory = HandshakerFactoryImpl::getDefaultHandshakerFactory(); + handshaker_factory_cb_ = handshaker_factory->createHandshakerCb( + *handshaker_factory->createEmptyConfigProto(), handshaker_factory_context, + factory_context.messageValidationVisitor()); + } + capabilities_ = handshaker_factory->capabilities(); } Ssl::CertificateValidationContextConfigPtr ContextConfigImpl::getCombinedValidationContextConfig( @@ -270,6 +291,10 @@ void ContextConfigImpl::setSecretUpdateCallback(std::function callback) } } +Ssl::HandshakerFactoryCb ContextConfigImpl::createHandshaker() const { + return handshaker_factory_cb_; +} + ContextConfigImpl::~ContextConfigImpl() { if (tc_update_callback_handle_) { tc_update_callback_handle_->remove(); @@ -400,12 +425,14 @@ ServerContextConfigImpl::ServerContextConfigImpl( } } - if ((config.common_tls_context().tls_certificates().size() + - config.common_tls_context().tls_certificate_sds_secret_configs().size()) == 0) { - throw EnvoyException("No TLS certificates found for server context"); - } else if (!config.common_tls_context().tls_certificates().empty() && - !config.common_tls_context().tls_certificate_sds_secret_configs().empty()) { - throw EnvoyException("SDS and non-SDS TLS certificates may not be mixed in server contexts"); + if (!capabilities().provides_certificates) { + if ((config.common_tls_context().tls_certificates().size() + + config.common_tls_context().tls_certificate_sds_secret_configs().size()) == 0) { + throw EnvoyException("No TLS certificates found for server context"); + } else if (!config.common_tls_context().tls_certificates().empty() && + !config.common_tls_context().tls_certificate_sds_secret_configs().empty()) { + throw EnvoyException("SDS and non-SDS TLS certificates may not be mixed in server contexts"); + } } if (config.has_session_timeout()) { diff --git a/source/extensions/transport_sockets/tls/context_config_impl.h b/source/extensions/transport_sockets/tls/context_config_impl.h index ad2d927d8231..9d8048c3b32a 100644 --- a/source/extensions/transport_sockets/tls/context_config_impl.h +++ b/source/extensions/transport_sockets/tls/context_config_impl.h @@ -55,6 +55,8 @@ class ContextConfigImpl : public virtual Ssl::ContextConfig { } void setSecretUpdateCallback(std::function callback) override; + Ssl::HandshakerFactoryCb createHandshaker() const override; + Ssl::HandshakerCapabilities capabilities() const override { return capabilities_; } Ssl::CertificateValidationContextConfigPtr getCombinedValidationContextConfig( const envoy::extensions::transport_sockets::tls::v3::CertificateValidationContext& @@ -94,6 +96,9 @@ class ContextConfigImpl : public virtual Ssl::ContextConfig { Envoy::Common::CallbackHandle* cvc_validation_callback_handle_{}; const unsigned min_protocol_version_; const unsigned max_protocol_version_; + + Ssl::HandshakerFactoryCb handshaker_factory_cb_; + Ssl::HandshakerCapabilities capabilities_; }; class ClientContextConfigImpl : public ContextConfigImpl, public Envoy::Ssl::ClientContextConfig { diff --git a/source/extensions/transport_sockets/tls/context_impl.cc b/source/extensions/transport_sockets/tls/context_impl.cc index 502739958e50..f461bb3c5d7f 100644 --- a/source/extensions/transport_sockets/tls/context_impl.cc +++ b/source/extensions/transport_sockets/tls/context_impl.cc @@ -74,7 +74,7 @@ ContextImpl::ContextImpl(Stats::Scope& scope, const Envoy::Ssl::ContextConfig& c ssl_ciphers_(stat_name_set_->add("ssl.ciphers")), ssl_versions_(stat_name_set_->add("ssl.versions")), ssl_curves_(stat_name_set_->add("ssl.curves")), - ssl_sigalgs_(stat_name_set_->add("ssl.sigalgs")) { + ssl_sigalgs_(stat_name_set_->add("ssl.sigalgs")), capabilities_(config.capabilities()) { const auto tls_certificates = config.tlsCertificates(); tls_contexts_.resize(std::max(static_cast(1), tls_certificates.size())); @@ -90,7 +90,8 @@ ContextImpl::ContextImpl(Stats::Scope& scope, const Envoy::Ssl::ContextConfig& c rc = SSL_CTX_set_max_proto_version(ctx.ssl_ctx_.get(), config.maxProtocolVersion()); RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); - if (!SSL_CTX_set_strict_cipher_list(ctx.ssl_ctx_.get(), config.cipherSuites().c_str())) { + if (!capabilities_.provides_ciphers_and_curves && + !SSL_CTX_set_strict_cipher_list(ctx.ssl_ctx_.get(), config.cipherSuites().c_str())) { // Break up a set of ciphers into each individual cipher and try them each individually in // order to attempt to log which specific one failed. Example of config.cipherSuites(): // "-ALL:[ECDHE-ECDSA-AES128-GCM-SHA256|ECDHE-ECDSA-CHACHA20-POLY1305]:ECDHE-ECDSA-AES128-SHA". @@ -118,7 +119,8 @@ ContextImpl::ContextImpl(Stats::Scope& scope, const Envoy::Ssl::ContextConfig& c config.cipherSuites(), absl::StrJoin(bad_ciphers, ", "))); } - if (!SSL_CTX_set1_curves_list(ctx.ssl_ctx_.get(), config.ecdhCurves().c_str())) { + if (!capabilities_.provides_ciphers_and_curves && + !SSL_CTX_set1_curves_list(ctx.ssl_ctx_.get(), config.ecdhCurves().c_str())) { throw EnvoyException(absl::StrCat("Failed to initialize ECDH curves ", config.ecdhCurves())); } } @@ -138,8 +140,16 @@ ContextImpl::ContextImpl(Stats::Scope& scope, const Envoy::Ssl::ContextConfig& c } } +#ifdef BORINGSSL_FIPS + if (!capabilities_.is_fips_compliant) { + throw EnvoyException( + "Can't load a FIPS noncompliant custom handshaker while running in FIPS compliant mode."); + } +#endif + if (config.certificateValidationContext() != nullptr && - !config.certificateValidationContext()->caCert().empty()) { + !config.certificateValidationContext()->caCert().empty() && + !config.capabilities().provides_certificates) { ca_file_path_ = config.certificateValidationContext()->caCertPath(); bssl::UniquePtr bio( BIO_new_mem_buf(const_cast(config.certificateValidationContext()->caCert().data()), @@ -262,162 +272,168 @@ ContextImpl::ContextImpl(Stats::Scope& scope, const Envoy::Ssl::ContextConfig& c } } - for (auto& ctx : tls_contexts_) { - if (verify_mode != SSL_VERIFY_NONE) { - SSL_CTX_set_verify(ctx.ssl_ctx_.get(), verify_mode, nullptr); - SSL_CTX_set_cert_verify_callback(ctx.ssl_ctx_.get(), ContextImpl::verifyCallback, this); + if (!capabilities_.verifies_peer_certificates) { + for (auto& ctx : tls_contexts_) { + if (verify_mode != SSL_VERIFY_NONE) { + SSL_CTX_set_verify(ctx.ssl_ctx_.get(), verify_mode, nullptr); + SSL_CTX_set_cert_verify_callback(ctx.ssl_ctx_.get(), ContextImpl::verifyCallback, this); + } } } absl::node_hash_set cert_pkey_ids; - for (uint32_t i = 0; i < tls_certificates.size(); ++i) { - auto& ctx = tls_contexts_[i]; - // Load certificate chain. - const auto& tls_certificate = tls_certificates[i].get(); - ctx.cert_chain_file_path_ = tls_certificate.certificateChainPath(); - bssl::UniquePtr bio( - BIO_new_mem_buf(const_cast(tls_certificate.certificateChain().data()), - tls_certificate.certificateChain().size())); - RELEASE_ASSERT(bio != nullptr, ""); - ctx.cert_chain_.reset(PEM_read_bio_X509_AUX(bio.get(), nullptr, nullptr, nullptr)); - if (ctx.cert_chain_ == nullptr || - !SSL_CTX_use_certificate(ctx.ssl_ctx_.get(), ctx.cert_chain_.get())) { - while (uint64_t err = ERR_get_error()) { - ENVOY_LOG_MISC(debug, "SSL error: {}:{}:{}:{}", err, ERR_lib_error_string(err), - ERR_func_error_string(err), ERR_GET_REASON(err), - ERR_reason_error_string(err)); + if (!capabilities_.provides_certificates) { + for (uint32_t i = 0; i < tls_certificates.size(); ++i) { + auto& ctx = tls_contexts_[i]; + // Load certificate chain. + const auto& tls_certificate = tls_certificates[i].get(); + ctx.cert_chain_file_path_ = tls_certificate.certificateChainPath(); + bssl::UniquePtr bio( + BIO_new_mem_buf(const_cast(tls_certificate.certificateChain().data()), + tls_certificate.certificateChain().size())); + RELEASE_ASSERT(bio != nullptr, ""); + ctx.cert_chain_.reset(PEM_read_bio_X509_AUX(bio.get(), nullptr, nullptr, nullptr)); + if (ctx.cert_chain_ == nullptr || + !SSL_CTX_use_certificate(ctx.ssl_ctx_.get(), ctx.cert_chain_.get())) { + while (uint64_t err = ERR_get_error()) { + ENVOY_LOG_MISC(debug, "SSL error: {}:{}:{}:{}", err, ERR_lib_error_string(err), + ERR_func_error_string(err), ERR_GET_REASON(err), + ERR_reason_error_string(err)); + } + throw EnvoyException( + absl::StrCat("Failed to load certificate chain from ", ctx.cert_chain_file_path_)); } - throw EnvoyException( - absl::StrCat("Failed to load certificate chain from ", ctx.cert_chain_file_path_)); - } - // Read rest of the certificate chain. - while (true) { - bssl::UniquePtr cert(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr)); - if (cert == nullptr) { - break; + // Read rest of the certificate chain. + while (true) { + bssl::UniquePtr cert(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr)); + if (cert == nullptr) { + break; + } + if (!SSL_CTX_add_extra_chain_cert(ctx.ssl_ctx_.get(), cert.get())) { + throw EnvoyException( + absl::StrCat("Failed to load certificate chain from ", ctx.cert_chain_file_path_)); + } + // SSL_CTX_add_extra_chain_cert() takes ownership. + cert.release(); } - if (!SSL_CTX_add_extra_chain_cert(ctx.ssl_ctx_.get(), cert.get())) { + // Check for EOF. + const uint32_t err = ERR_peek_last_error(); + if (ERR_GET_LIB(err) == ERR_LIB_PEM && ERR_GET_REASON(err) == PEM_R_NO_START_LINE) { + ERR_clear_error(); + } else { throw EnvoyException( absl::StrCat("Failed to load certificate chain from ", ctx.cert_chain_file_path_)); } - // SSL_CTX_add_extra_chain_cert() takes ownership. - cert.release(); - } - // Check for EOF. - const uint32_t err = ERR_peek_last_error(); - if (ERR_GET_LIB(err) == ERR_LIB_PEM && ERR_GET_REASON(err) == PEM_R_NO_START_LINE) { - ERR_clear_error(); - } else { - throw EnvoyException( - absl::StrCat("Failed to load certificate chain from ", ctx.cert_chain_file_path_)); - } - bssl::UniquePtr public_key(X509_get_pubkey(ctx.cert_chain_.get())); - const int pkey_id = EVP_PKEY_id(public_key.get()); - if (!cert_pkey_ids.insert(pkey_id).second) { - throw EnvoyException(fmt::format("Failed to load certificate chain from {}, at most one " - "certificate of a given type may be specified", - ctx.cert_chain_file_path_)); - } - ctx.is_ecdsa_ = pkey_id == EVP_PKEY_EC; - switch (pkey_id) { - case EVP_PKEY_EC: { - // We only support P-256 ECDSA today. - const EC_KEY* ecdsa_public_key = EVP_PKEY_get0_EC_KEY(public_key.get()); - // Since we checked the key type above, this should be valid. - ASSERT(ecdsa_public_key != nullptr); - const EC_GROUP* ecdsa_group = EC_KEY_get0_group(ecdsa_public_key); - if (ecdsa_group == nullptr || EC_GROUP_get_curve_name(ecdsa_group) != NID_X9_62_prime256v1) { - throw EnvoyException(fmt::format("Failed to load certificate chain from {}, only P-256 " - "ECDSA certificates are supported", + bssl::UniquePtr public_key(X509_get_pubkey(ctx.cert_chain_.get())); + const int pkey_id = EVP_PKEY_id(public_key.get()); + if (!cert_pkey_ids.insert(pkey_id).second) { + throw EnvoyException(fmt::format("Failed to load certificate chain from {}, at most one " + "certificate of a given type may be specified", ctx.cert_chain_file_path_)); } - ctx.is_ecdsa_ = true; - } break; - case EVP_PKEY_RSA: { - // We require RSA certificates with 2048-bit or larger keys. - const RSA* rsa_public_key = EVP_PKEY_get0_RSA(public_key.get()); - // Since we checked the key type above, this should be valid. - ASSERT(rsa_public_key != nullptr); - const unsigned rsa_key_length = RSA_size(rsa_public_key); + ctx.is_ecdsa_ = pkey_id == EVP_PKEY_EC; + switch (pkey_id) { + case EVP_PKEY_EC: { + // We only support P-256 ECDSA today. + const EC_KEY* ecdsa_public_key = EVP_PKEY_get0_EC_KEY(public_key.get()); + // Since we checked the key type above, this should be valid. + ASSERT(ecdsa_public_key != nullptr); + const EC_GROUP* ecdsa_group = EC_KEY_get0_group(ecdsa_public_key); + if (ecdsa_group == nullptr || + EC_GROUP_get_curve_name(ecdsa_group) != NID_X9_62_prime256v1) { + throw EnvoyException(fmt::format("Failed to load certificate chain from {}, only P-256 " + "ECDSA certificates are supported", + ctx.cert_chain_file_path_)); + } + ctx.is_ecdsa_ = true; + } break; + case EVP_PKEY_RSA: { + // We require RSA certificates with 2048-bit or larger keys. + const RSA* rsa_public_key = EVP_PKEY_get0_RSA(public_key.get()); + // Since we checked the key type above, this should be valid. + ASSERT(rsa_public_key != nullptr); + const unsigned rsa_key_length = RSA_size(rsa_public_key); #ifdef BORINGSSL_FIPS - if (rsa_key_length != 2048 / 8 && rsa_key_length != 3072 / 8) { - throw EnvoyException( - fmt::format("Failed to load certificate chain from {}, only RSA certificates with " - "2048-bit or 3072-bit keys are supported in FIPS mode", - ctx.cert_chain_file_path_)); - } + if (rsa_key_length != 2048 / 8 && rsa_key_length != 3072 / 8) { + throw EnvoyException( + fmt::format("Failed to load certificate chain from {}, only RSA certificates with " + "2048-bit or 3072-bit keys are supported in FIPS mode", + ctx.cert_chain_file_path_)); + } #else - if (rsa_key_length < 2048 / 8) { - throw EnvoyException(fmt::format("Failed to load certificate chain from {}, only RSA " - "certificates with 2048-bit or larger keys are supported", - ctx.cert_chain_file_path_)); - } + if (rsa_key_length < 2048 / 8) { + throw EnvoyException( + fmt::format("Failed to load certificate chain from {}, only RSA " + "certificates with 2048-bit or larger keys are supported", + ctx.cert_chain_file_path_)); + } #endif - } break; + } break; #ifdef BORINGSSL_FIPS - default: - throw EnvoyException(fmt::format("Failed to load certificate chain from {}, only RSA and " - "ECDSA certificates are supported in FIPS mode", - ctx.cert_chain_file_path_)); + default: + throw EnvoyException(fmt::format("Failed to load certificate chain from {}, only RSA and " + "ECDSA certificates are supported in FIPS mode", + ctx.cert_chain_file_path_)); #endif - } - - Envoy::Ssl::PrivateKeyMethodProviderSharedPtr private_key_method_provider = - tls_certificate.privateKeyMethod(); - // We either have a private key or a BoringSSL private key method provider. - if (private_key_method_provider) { - ctx.private_key_method_provider_ = private_key_method_provider; - // The provider has a reference to the private key method for the context lifetime. - Ssl::BoringSslPrivateKeyMethodSharedPtr private_key_method = - private_key_method_provider->getBoringSslPrivateKeyMethod(); - if (private_key_method == nullptr) { - throw EnvoyException( - fmt::format("Failed to get BoringSSL private key method from provider")); } + + Envoy::Ssl::PrivateKeyMethodProviderSharedPtr private_key_method_provider = + tls_certificate.privateKeyMethod(); + // We either have a private key or a BoringSSL private key method provider. + if (private_key_method_provider) { + ctx.private_key_method_provider_ = private_key_method_provider; + // The provider has a reference to the private key method for the context lifetime. + Ssl::BoringSslPrivateKeyMethodSharedPtr private_key_method = + private_key_method_provider->getBoringSslPrivateKeyMethod(); + if (private_key_method == nullptr) { + throw EnvoyException( + fmt::format("Failed to get BoringSSL private key method from provider")); + } #ifdef BORINGSSL_FIPS - if (!ctx.private_key_method_provider_->checkFips()) { - throw EnvoyException( - fmt::format("Private key method doesn't support FIPS mode with current parameters")); - } + if (!ctx.private_key_method_provider_->checkFips()) { + throw EnvoyException( + fmt::format("Private key method doesn't support FIPS mode with current parameters")); + } #endif - SSL_CTX_set_private_key_method(ctx.ssl_ctx_.get(), private_key_method.get()); - } else { - // Load private key. - bio.reset(BIO_new_mem_buf(const_cast(tls_certificate.privateKey().data()), - tls_certificate.privateKey().size())); - RELEASE_ASSERT(bio != nullptr, ""); - bssl::UniquePtr pkey( - PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, - !tls_certificate.password().empty() - ? const_cast(tls_certificate.password().c_str()) - : nullptr)); - if (pkey == nullptr || !SSL_CTX_use_PrivateKey(ctx.ssl_ctx_.get(), pkey.get())) { - throw EnvoyException( - absl::StrCat("Failed to load private key from ", tls_certificate.privateKeyPath())); - } + SSL_CTX_set_private_key_method(ctx.ssl_ctx_.get(), private_key_method.get()); + } else { + // Load private key. + bio.reset(BIO_new_mem_buf(const_cast(tls_certificate.privateKey().data()), + tls_certificate.privateKey().size())); + RELEASE_ASSERT(bio != nullptr, ""); + bssl::UniquePtr pkey( + PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, + !tls_certificate.password().empty() + ? const_cast(tls_certificate.password().c_str()) + : nullptr)); + if (pkey == nullptr || !SSL_CTX_use_PrivateKey(ctx.ssl_ctx_.get(), pkey.get())) { + throw EnvoyException( + absl::StrCat("Failed to load private key from ", tls_certificate.privateKeyPath())); + } #ifdef BORINGSSL_FIPS - // Verify that private keys are passing FIPS pairwise consistency tests. - switch (pkey_id) { - case EVP_PKEY_EC: { - const EC_KEY* ecdsa_private_key = EVP_PKEY_get0_EC_KEY(pkey.get()); - if (!EC_KEY_check_fips(ecdsa_private_key)) { - throw EnvoyException(fmt::format("Failed to load private key from {}, ECDSA key failed " - "pairwise consistency test required in FIPS mode", - tls_certificate.privateKeyPath())); - } - } break; - case EVP_PKEY_RSA: { - RSA* rsa_private_key = EVP_PKEY_get0_RSA(pkey.get()); - if (!RSA_check_fips(rsa_private_key)) { - throw EnvoyException(fmt::format("Failed to load private key from {}, RSA key failed " - "pairwise consistency test required in FIPS mode", - tls_certificate.privateKeyPath())); + // Verify that private keys are passing FIPS pairwise consistency tests. + switch (pkey_id) { + case EVP_PKEY_EC: { + const EC_KEY* ecdsa_private_key = EVP_PKEY_get0_EC_KEY(pkey.get()); + if (!EC_KEY_check_fips(ecdsa_private_key)) { + throw EnvoyException(fmt::format("Failed to load private key from {}, ECDSA key failed " + "pairwise consistency test required in FIPS mode", + tls_certificate.privateKeyPath())); + } + } break; + case EVP_PKEY_RSA: { + RSA* rsa_private_key = EVP_PKEY_get0_RSA(pkey.get()); + if (!RSA_check_fips(rsa_private_key)) { + throw EnvoyException(fmt::format("Failed to load private key from {}, RSA key failed " + "pairwise consistency test required in FIPS mode", + tls_certificate.privateKeyPath())); + } + } break; } - } break; - } #endif + } } } @@ -986,7 +1002,7 @@ ServerContextImpl::ServerContextImpl(Stats::Scope& scope, const std::vector& server_names, TimeSource& time_source) : ContextImpl(scope, config, time_source), session_ticket_keys_(config.sessionTicketKeys()) { - if (config.tlsCertificates().empty()) { + if (config.tlsCertificates().empty() && !config.capabilities().provides_certificates) { throw EnvoyException("Server TlsCertificates must have a certificate specified"); } @@ -998,22 +1014,25 @@ ServerContextImpl::ServerContextImpl(Stats::Scope& scope, // First, configure the base context for ClientHello interception. // TODO(htuch): replace with SSL_IDENTITY when we have this as a means to do multi-cert in // BoringSSL. - SSL_CTX_set_select_certificate_cb( - tls_contexts_[0].ssl_ctx_.get(), - [](const SSL_CLIENT_HELLO* client_hello) -> ssl_select_cert_result_t { - return static_cast( - SSL_CTX_get_app_data(SSL_get_SSL_CTX(client_hello->ssl))) - ->selectTlsContext(client_hello); - }); + if (!config.capabilities().provides_certificates) { + SSL_CTX_set_select_certificate_cb( + tls_contexts_[0].ssl_ctx_.get(), + [](const SSL_CLIENT_HELLO* client_hello) -> ssl_select_cert_result_t { + return static_cast( + SSL_CTX_get_app_data(SSL_get_SSL_CTX(client_hello->ssl))) + ->selectTlsContext(client_hello); + }); + } for (auto& ctx : tls_contexts_) { - if (config.certificateValidationContext() != nullptr && + if (!config.capabilities().verifies_peer_certificates && + config.certificateValidationContext() != nullptr && !config.certificateValidationContext()->caCert().empty()) { ctx.addClientValidationContext(*config.certificateValidationContext(), config.requireClientCertificate()); } - if (!parsed_alpn_protocols_.empty()) { + if (!parsed_alpn_protocols_.empty() && !config.capabilities().handles_alpn_selection) { SSL_CTX_set_alpn_select_cb( ctx.ssl_ctx_.get(), [](SSL*, const unsigned char** out, unsigned char* outlen, const unsigned char* in, @@ -1023,9 +1042,11 @@ ServerContextImpl::ServerContextImpl(Stats::Scope& scope, this); } + // If the handshaker handles session tickets natively, don't call + // `SSL_CTX_set_tlsext_ticket_key_cb`. if (config.disableStatelessSessionResumption()) { SSL_CTX_set_options(ctx.ssl_ctx_.get(), SSL_OP_NO_TICKET); - } else if (!session_ticket_keys_.empty()) { + } else if (!session_ticket_keys_.empty() && !config.capabilities().handles_session_resumption) { SSL_CTX_set_tlsext_ticket_key_cb( ctx.ssl_ctx_.get(), [](SSL* ssl, uint8_t* key_name, uint8_t* iv, EVP_CIPHER_CTX* ctx, HMAC_CTX* hmac_ctx, @@ -1039,7 +1060,7 @@ ServerContextImpl::ServerContextImpl(Stats::Scope& scope, }); } - if (config.sessionTimeout()) { + if (config.sessionTimeout() && !config.capabilities().handles_session_resumption) { auto timeout = config.sessionTimeout().value().count(); SSL_CTX_set_timeout(ctx.ssl_ctx_.get(), uint32_t(timeout)); } @@ -1065,67 +1086,69 @@ ServerContextImpl::generateHashForSessionContextId(const std::vector= 0) { - X509_NAME_ENTRY* cn_entry = X509_NAME_get_entry(cert_subject, cn_index); - RELEASE_ASSERT(cn_entry != nullptr, "certificate subject CN should be present"); - - ASN1_STRING* cn_asn1 = X509_NAME_ENTRY_get_data(cn_entry); - if (ASN1_STRING_length(cn_asn1) <= 0) { - throw EnvoyException("Invalid TLS context has an empty subject CN"); - } + if (!capabilities_.provides_certificates) { + for (const auto& ctx : tls_contexts_) { + X509* cert = SSL_CTX_get0_certificate(ctx.ssl_ctx_.get()); + RELEASE_ASSERT(cert != nullptr, "TLS context should have an active certificate"); + X509_NAME* cert_subject = X509_get_subject_name(cert); + RELEASE_ASSERT(cert_subject != nullptr, "TLS certificate should have a subject"); + + const int cn_index = X509_NAME_get_index_by_NID(cert_subject, NID_commonName, -1); + if (cn_index >= 0) { + X509_NAME_ENTRY* cn_entry = X509_NAME_get_entry(cert_subject, cn_index); + RELEASE_ASSERT(cn_entry != nullptr, "certificate subject CN should be present"); + + ASN1_STRING* cn_asn1 = X509_NAME_ENTRY_get_data(cn_entry); + if (ASN1_STRING_length(cn_asn1) <= 0) { + throw EnvoyException("Invalid TLS context has an empty subject CN"); + } - rc = EVP_DigestUpdate(md.get(), ASN1_STRING_data(cn_asn1), ASN1_STRING_length(cn_asn1)); - RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); - } + rc = EVP_DigestUpdate(md.get(), ASN1_STRING_data(cn_asn1), ASN1_STRING_length(cn_asn1)); + RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); + } - unsigned san_count = 0; - bssl::UniquePtr san_names(static_cast( - X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr))); - - if (san_names != nullptr) { - for (const GENERAL_NAME* san : san_names.get()) { - switch (san->type) { - case GEN_IPADD: - rc = EVP_DigestUpdate(md.get(), san->d.iPAddress->data, san->d.iPAddress->length); - RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); - ++san_count; - break; - case GEN_DNS: - rc = EVP_DigestUpdate(md.get(), ASN1_STRING_data(san->d.dNSName), - ASN1_STRING_length(san->d.dNSName)); - RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); - ++san_count; - break; - case GEN_URI: - rc = EVP_DigestUpdate(md.get(), ASN1_STRING_data(san->d.uniformResourceIdentifier), - ASN1_STRING_length(san->d.uniformResourceIdentifier)); - RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); - ++san_count; - break; + unsigned san_count = 0; + bssl::UniquePtr san_names(static_cast( + X509_get_ext_d2i(cert, NID_subject_alt_name, nullptr, nullptr))); + + if (san_names != nullptr) { + for (const GENERAL_NAME* san : san_names.get()) { + switch (san->type) { + case GEN_IPADD: + rc = EVP_DigestUpdate(md.get(), san->d.iPAddress->data, san->d.iPAddress->length); + RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); + ++san_count; + break; + case GEN_DNS: + rc = EVP_DigestUpdate(md.get(), ASN1_STRING_data(san->d.dNSName), + ASN1_STRING_length(san->d.dNSName)); + RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); + ++san_count; + break; + case GEN_URI: + rc = EVP_DigestUpdate(md.get(), ASN1_STRING_data(san->d.uniformResourceIdentifier), + ASN1_STRING_length(san->d.uniformResourceIdentifier)); + RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); + ++san_count; + break; + } } } - } - // It's possible that the certificate doesn't have a subject, but - // does have SANs. Make sure that we have one or the other. - if (cn_index < 0 && san_count == 0) { - throw EnvoyException("Invalid TLS context has neither subject CN nor SAN names"); - } + // It's possible that the certificate doesn't have a subject, but + // does have SANs. Make sure that we have one or the other. + if (cn_index < 0 && san_count == 0) { + throw EnvoyException("Invalid TLS context has neither subject CN nor SAN names"); + } - rc = X509_NAME_digest(X509_get_issuer_name(cert), EVP_sha256(), hash_buffer, &hash_length); - RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); - RELEASE_ASSERT(hash_length == SHA256_DIGEST_LENGTH, - fmt::format("invalid SHA256 hash length {}", hash_length)); + rc = X509_NAME_digest(X509_get_issuer_name(cert), EVP_sha256(), hash_buffer, &hash_length); + RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); + RELEASE_ASSERT(hash_length == SHA256_DIGEST_LENGTH, + fmt::format("invalid SHA256 hash length {}", hash_length)); - rc = EVP_DigestUpdate(md.get(), hash_buffer, hash_length); - RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); + rc = EVP_DigestUpdate(md.get(), hash_buffer, hash_length); + RELEASE_ASSERT(rc == 1, Utility::getLastCryptoError().value_or("")); + } } // Hash all the settings that affect whether the server will allow/accept diff --git a/source/extensions/transport_sockets/tls/context_impl.h b/source/extensions/transport_sockets/tls/context_impl.h index 5ea35a48228e..d40c4c88881c 100644 --- a/source/extensions/transport_sockets/tls/context_impl.h +++ b/source/extensions/transport_sockets/tls/context_impl.h @@ -209,6 +209,7 @@ class ContextImpl : public virtual Envoy::Ssl::Context { const Stats::StatName ssl_versions_; const Stats::StatName ssl_curves_; const Stats::StatName ssl_sigalgs_; + const Ssl::HandshakerCapabilities capabilities_; }; using ContextImplSharedPtr = std::shared_ptr; diff --git a/source/extensions/transport_sockets/tls/ssl_handshaker.cc b/source/extensions/transport_sockets/tls/ssl_handshaker.cc new file mode 100644 index 000000000000..7d8def3ecfd1 --- /dev/null +++ b/source/extensions/transport_sockets/tls/ssl_handshaker.cc @@ -0,0 +1,337 @@ +#include "extensions/transport_sockets/tls/ssl_handshaker.h" + +#include "envoy/stats/scope.h" + +#include "common/common/assert.h" +#include "common/common/empty_string.h" +#include "common/common/hex.h" +#include "common/http/headers.h" + +#include "extensions/transport_sockets/tls/utility.h" + +#include "absl/strings/str_replace.h" +#include "openssl/err.h" +#include "openssl/x509v3.h" + +using Envoy::Network::PostIoAction; + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Tls { + +void SslExtendedSocketInfoImpl::setCertificateValidationStatus( + Envoy::Ssl::ClientValidationStatus validated) { + certificate_validation_status_ = validated; +} + +Envoy::Ssl::ClientValidationStatus SslExtendedSocketInfoImpl::certificateValidationStatus() const { + return certificate_validation_status_; +} + +SslHandshakerImpl::SslHandshakerImpl(bssl::UniquePtr ssl, int ssl_extended_socket_info_index, + Ssl::HandshakeCallbacks* handshake_callbacks) + : ssl_(std::move(ssl)), handshake_callbacks_(handshake_callbacks), + state_(Ssl::SocketState::PreHandshake) { + SSL_set_ex_data(ssl_.get(), ssl_extended_socket_info_index, &(this->extended_socket_info_)); +} + +bool SslHandshakerImpl::peerCertificatePresented() const { + bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); + return cert != nullptr; +} + +bool SslHandshakerImpl::peerCertificateValidated() const { + return extended_socket_info_.certificateValidationStatus() == + Envoy::Ssl::ClientValidationStatus::Validated; +} + +absl::Span SslHandshakerImpl::uriSanLocalCertificate() const { + if (!cached_uri_san_local_certificate_.empty()) { + return cached_uri_san_local_certificate_; + } + + // The cert object is not owned. + X509* cert = SSL_get_certificate(ssl()); + if (!cert) { + ASSERT(cached_uri_san_local_certificate_.empty()); + return cached_uri_san_local_certificate_; + } + cached_uri_san_local_certificate_ = Utility::getSubjectAltNames(*cert, GEN_URI); + return cached_uri_san_local_certificate_; +} + +absl::Span SslHandshakerImpl::dnsSansLocalCertificate() const { + if (!cached_dns_san_local_certificate_.empty()) { + return cached_dns_san_local_certificate_; + } + + X509* cert = SSL_get_certificate(ssl()); + if (!cert) { + ASSERT(cached_dns_san_local_certificate_.empty()); + return cached_dns_san_local_certificate_; + } + cached_dns_san_local_certificate_ = Utility::getSubjectAltNames(*cert, GEN_DNS); + return cached_dns_san_local_certificate_; +} + +const std::string& SslHandshakerImpl::sha256PeerCertificateDigest() const { + if (!cached_sha_256_peer_certificate_digest_.empty()) { + return cached_sha_256_peer_certificate_digest_; + } + bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); + if (!cert) { + ASSERT(cached_sha_256_peer_certificate_digest_.empty()); + return cached_sha_256_peer_certificate_digest_; + } + + std::vector computed_hash(SHA256_DIGEST_LENGTH); + unsigned int n; + X509_digest(cert.get(), EVP_sha256(), computed_hash.data(), &n); + RELEASE_ASSERT(n == computed_hash.size(), ""); + cached_sha_256_peer_certificate_digest_ = Hex::encode(computed_hash); + return cached_sha_256_peer_certificate_digest_; +} + +const std::string& SslHandshakerImpl::sha1PeerCertificateDigest() const { + if (!cached_sha_1_peer_certificate_digest_.empty()) { + return cached_sha_1_peer_certificate_digest_; + } + bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); + if (!cert) { + ASSERT(cached_sha_1_peer_certificate_digest_.empty()); + return cached_sha_1_peer_certificate_digest_; + } + + std::vector computed_hash(SHA_DIGEST_LENGTH); + unsigned int n; + X509_digest(cert.get(), EVP_sha1(), computed_hash.data(), &n); + RELEASE_ASSERT(n == computed_hash.size(), ""); + cached_sha_1_peer_certificate_digest_ = Hex::encode(computed_hash); + return cached_sha_1_peer_certificate_digest_; +} + +const std::string& SslHandshakerImpl::urlEncodedPemEncodedPeerCertificate() const { + if (!cached_url_encoded_pem_encoded_peer_certificate_.empty()) { + return cached_url_encoded_pem_encoded_peer_certificate_; + } + bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); + if (!cert) { + ASSERT(cached_url_encoded_pem_encoded_peer_certificate_.empty()); + return cached_url_encoded_pem_encoded_peer_certificate_; + } + + bssl::UniquePtr buf(BIO_new(BIO_s_mem())); + RELEASE_ASSERT(buf != nullptr, ""); + RELEASE_ASSERT(PEM_write_bio_X509(buf.get(), cert.get()) == 1, ""); + const uint8_t* output; + size_t length; + RELEASE_ASSERT(BIO_mem_contents(buf.get(), &output, &length) == 1, ""); + absl::string_view pem(reinterpret_cast(output), length); + cached_url_encoded_pem_encoded_peer_certificate_ = absl::StrReplaceAll( + pem, {{"\n", "%0A"}, {" ", "%20"}, {"+", "%2B"}, {"/", "%2F"}, {"=", "%3D"}}); + return cached_url_encoded_pem_encoded_peer_certificate_; +} + +const std::string& SslHandshakerImpl::urlEncodedPemEncodedPeerCertificateChain() const { + if (!cached_url_encoded_pem_encoded_peer_cert_chain_.empty()) { + return cached_url_encoded_pem_encoded_peer_cert_chain_; + } + + STACK_OF(X509)* cert_chain = SSL_get_peer_full_cert_chain(ssl()); + if (cert_chain == nullptr) { + ASSERT(cached_url_encoded_pem_encoded_peer_cert_chain_.empty()); + return cached_url_encoded_pem_encoded_peer_cert_chain_; + } + + for (uint64_t i = 0; i < sk_X509_num(cert_chain); i++) { + X509* cert = sk_X509_value(cert_chain, i); + + bssl::UniquePtr buf(BIO_new(BIO_s_mem())); + RELEASE_ASSERT(buf != nullptr, ""); + RELEASE_ASSERT(PEM_write_bio_X509(buf.get(), cert) == 1, ""); + const uint8_t* output; + size_t length; + RELEASE_ASSERT(BIO_mem_contents(buf.get(), &output, &length) == 1, ""); + + absl::string_view pem(reinterpret_cast(output), length); + cached_url_encoded_pem_encoded_peer_cert_chain_ = absl::StrCat( + cached_url_encoded_pem_encoded_peer_cert_chain_, + absl::StrReplaceAll( + pem, {{"\n", "%0A"}, {" ", "%20"}, {"+", "%2B"}, {"/", "%2F"}, {"=", "%3D"}})); + } + return cached_url_encoded_pem_encoded_peer_cert_chain_; +} + +absl::Span SslHandshakerImpl::uriSanPeerCertificate() const { + if (!cached_uri_san_peer_certificate_.empty()) { + return cached_uri_san_peer_certificate_; + } + + bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); + if (!cert) { + ASSERT(cached_uri_san_peer_certificate_.empty()); + return cached_uri_san_peer_certificate_; + } + cached_uri_san_peer_certificate_ = Utility::getSubjectAltNames(*cert, GEN_URI); + return cached_uri_san_peer_certificate_; +} + +absl::Span SslHandshakerImpl::dnsSansPeerCertificate() const { + if (!cached_dns_san_peer_certificate_.empty()) { + return cached_dns_san_peer_certificate_; + } + + bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); + if (!cert) { + ASSERT(cached_dns_san_peer_certificate_.empty()); + return cached_dns_san_peer_certificate_; + } + cached_dns_san_peer_certificate_ = Utility::getSubjectAltNames(*cert, GEN_DNS); + return cached_dns_san_peer_certificate_; +} + +uint16_t SslHandshakerImpl::ciphersuiteId() const { + const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl()); + if (cipher == nullptr) { + return 0xffff; + } + + // From the OpenSSL docs: + // SSL_CIPHER_get_id returns |cipher|'s id. It may be cast to a |uint16_t| to + // get the cipher suite value. + return static_cast(SSL_CIPHER_get_id(cipher)); +} + +std::string SslHandshakerImpl::ciphersuiteString() const { + const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl()); + if (cipher == nullptr) { + return {}; + } + + return SSL_CIPHER_get_name(cipher); +} + +const std::string& SslHandshakerImpl::tlsVersion() const { + if (!cached_tls_version_.empty()) { + return cached_tls_version_; + } + cached_tls_version_ = SSL_get_version(ssl()); + return cached_tls_version_; +} + +Network::PostIoAction SslHandshakerImpl::doHandshake() { + ASSERT(state_ != Ssl::SocketState::HandshakeComplete && state_ != Ssl::SocketState::ShutdownSent); + int rc = SSL_do_handshake(ssl()); + if (rc == 1) { + state_ = Ssl::SocketState::HandshakeComplete; + handshake_callbacks_->onSuccess(ssl()); + + // It's possible that we closed during the handshake callback. + return handshake_callbacks_->connection().state() == Network::Connection::State::Open + ? PostIoAction::KeepOpen + : PostIoAction::Close; + } else { + int err = SSL_get_error(ssl(), rc); + switch (err) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + return PostIoAction::KeepOpen; + case SSL_ERROR_WANT_PRIVATE_KEY_OPERATION: + state_ = Ssl::SocketState::HandshakeInProgress; + return PostIoAction::KeepOpen; + default: + handshake_callbacks_->onFailure(); + return PostIoAction::Close; + } + } +} + +const std::string& SslHandshakerImpl::serialNumberPeerCertificate() const { + if (!cached_serial_number_peer_certificate_.empty()) { + return cached_serial_number_peer_certificate_; + } + bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); + if (!cert) { + ASSERT(cached_serial_number_peer_certificate_.empty()); + return cached_serial_number_peer_certificate_; + } + cached_serial_number_peer_certificate_ = Utility::getSerialNumberFromCertificate(*cert.get()); + return cached_serial_number_peer_certificate_; +} + +const std::string& SslHandshakerImpl::issuerPeerCertificate() const { + if (!cached_issuer_peer_certificate_.empty()) { + return cached_issuer_peer_certificate_; + } + bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); + if (!cert) { + ASSERT(cached_issuer_peer_certificate_.empty()); + return cached_issuer_peer_certificate_; + } + cached_issuer_peer_certificate_ = Utility::getIssuerFromCertificate(*cert); + return cached_issuer_peer_certificate_; +} + +const std::string& SslHandshakerImpl::subjectPeerCertificate() const { + if (!cached_subject_peer_certificate_.empty()) { + return cached_subject_peer_certificate_; + } + bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); + if (!cert) { + ASSERT(cached_subject_peer_certificate_.empty()); + return cached_subject_peer_certificate_; + } + cached_subject_peer_certificate_ = Utility::getSubjectFromCertificate(*cert); + return cached_subject_peer_certificate_; +} + +const std::string& SslHandshakerImpl::subjectLocalCertificate() const { + if (!cached_subject_local_certificate_.empty()) { + return cached_subject_local_certificate_; + } + X509* cert = SSL_get_certificate(ssl()); + if (!cert) { + ASSERT(cached_subject_local_certificate_.empty()); + return cached_subject_local_certificate_; + } + cached_subject_local_certificate_ = Utility::getSubjectFromCertificate(*cert); + return cached_subject_local_certificate_; +} + +absl::optional SslHandshakerImpl::validFromPeerCertificate() const { + bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); + if (!cert) { + return absl::nullopt; + } + return Utility::getValidFrom(*cert); +} + +absl::optional SslHandshakerImpl::expirationPeerCertificate() const { + bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); + if (!cert) { + return absl::nullopt; + } + return Utility::getExpirationTime(*cert); +} + +const std::string& SslHandshakerImpl::sessionId() const { + if (!cached_session_id_.empty()) { + return cached_session_id_; + } + SSL_SESSION* session = SSL_get_session(ssl()); + if (session == nullptr) { + ASSERT(cached_session_id_.empty()); + return cached_session_id_; + } + + unsigned int session_id_length = 0; + const uint8_t* session_id = SSL_SESSION_get_id(session, &session_id_length); + cached_session_id_ = Hex::encode(session_id, session_id_length); + return cached_session_id_; +} + +} // namespace Tls +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/transport_sockets/tls/ssl_handshaker.h b/source/extensions/transport_sockets/tls/ssl_handshaker.h new file mode 100644 index 000000000000..8eaec861a8f1 --- /dev/null +++ b/source/extensions/transport_sockets/tls/ssl_handshaker.h @@ -0,0 +1,147 @@ +#pragma once + +#include +#include + +#include "envoy/network/connection.h" +#include "envoy/network/transport_socket.h" +#include "envoy/secret/secret_callbacks.h" +#include "envoy/ssl/handshaker.h" +#include "envoy/ssl/private_key/private_key_callbacks.h" +#include "envoy/ssl/ssl_socket_extended_info.h" +#include "envoy/ssl/ssl_socket_state.h" +#include "envoy/stats/scope.h" +#include "envoy/stats/stats_macros.h" + +#include "common/common/logger.h" + +#include "extensions/transport_sockets/tls/utility.h" + +#include "absl/container/node_hash_map.h" +#include "absl/synchronization/mutex.h" +#include "absl/types/optional.h" +#include "openssl/ssl.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Tls { + +class SslExtendedSocketInfoImpl : public Envoy::Ssl::SslExtendedSocketInfo { +public: + void setCertificateValidationStatus(Envoy::Ssl::ClientValidationStatus validated) override; + Envoy::Ssl::ClientValidationStatus certificateValidationStatus() const override; + +private: + Envoy::Ssl::ClientValidationStatus certificate_validation_status_{ + Envoy::Ssl::ClientValidationStatus::NotValidated}; +}; + +class SslHandshakerImpl : public Ssl::ConnectionInfo, public Ssl::Handshaker { +public: + SslHandshakerImpl(bssl::UniquePtr ssl, int ssl_extended_socket_info_index, + Ssl::HandshakeCallbacks* handshake_callbacks); + + // Ssl::ConnectionInfo + bool peerCertificatePresented() const override; + bool peerCertificateValidated() const override; + absl::Span uriSanLocalCertificate() const override; + const std::string& sha256PeerCertificateDigest() const override; + const std::string& sha1PeerCertificateDigest() const override; + const std::string& serialNumberPeerCertificate() const override; + const std::string& issuerPeerCertificate() const override; + const std::string& subjectPeerCertificate() const override; + const std::string& subjectLocalCertificate() const override; + absl::Span uriSanPeerCertificate() const override; + const std::string& urlEncodedPemEncodedPeerCertificate() const override; + const std::string& urlEncodedPemEncodedPeerCertificateChain() const override; + absl::Span dnsSansPeerCertificate() const override; + absl::Span dnsSansLocalCertificate() const override; + absl::optional validFromPeerCertificate() const override; + absl::optional expirationPeerCertificate() const override; + const std::string& sessionId() const override; + uint16_t ciphersuiteId() const override; + std::string ciphersuiteString() const override; + const std::string& tlsVersion() const override; + + // Ssl::Handshaker + Network::PostIoAction doHandshake() override; + + Ssl::SocketState state() { return state_; } + void setState(Ssl::SocketState state) { state_ = state; } + SSL* ssl() const { return ssl_.get(); } + Ssl::HandshakeCallbacks* handshakeCallbacks() { return handshake_callbacks_; } + + bssl::UniquePtr ssl_; + +private: + Ssl::HandshakeCallbacks* handshake_callbacks_; + + Ssl::SocketState state_; + mutable std::vector cached_uri_san_local_certificate_; + mutable std::string cached_sha_256_peer_certificate_digest_; + mutable std::string cached_sha_1_peer_certificate_digest_; + mutable std::string cached_serial_number_peer_certificate_; + mutable std::string cached_issuer_peer_certificate_; + mutable std::string cached_subject_peer_certificate_; + mutable std::string cached_subject_local_certificate_; + mutable std::vector cached_uri_san_peer_certificate_; + mutable std::string cached_url_encoded_pem_encoded_peer_certificate_; + mutable std::string cached_url_encoded_pem_encoded_peer_cert_chain_; + mutable std::vector cached_dns_san_peer_certificate_; + mutable std::vector cached_dns_san_local_certificate_; + mutable std::string cached_session_id_; + mutable std::string cached_tls_version_; + mutable SslExtendedSocketInfoImpl extended_socket_info_; +}; + +using SslHandshakerImplSharedPtr = std::shared_ptr; + +class HandshakerFactoryContextImpl : public Ssl::HandshakerFactoryContext { +public: + HandshakerFactoryContextImpl(Api::Api& api, absl::string_view alpn_protocols) + : api_(api), alpn_protocols_(alpn_protocols) {} + + // HandshakerFactoryContext + Api::Api& api() override { return api_; } + absl::string_view alpnProtocols() const override { return alpn_protocols_; } + +private: + Api::Api& api_; + const std::string alpn_protocols_; +}; + +class HandshakerFactoryImpl : public Ssl::HandshakerFactory { +public: + std::string name() const override { return "envoy.default_tls_handshaker"; } + + ProtobufTypes::MessagePtr createEmptyConfigProto() override { + return ProtobufTypes::MessagePtr{new Envoy::ProtobufWkt::Struct()}; + } + + Ssl::HandshakerFactoryCb createHandshakerCb(const Protobuf::Message&, + Ssl::HandshakerFactoryContext&, + ProtobufMessage::ValidationVisitor&) override { + // The default HandshakerImpl doesn't take a config or use the HandshakerFactoryContext. + return [](bssl::UniquePtr ssl, int ssl_extended_socket_info_index, + Ssl::HandshakeCallbacks* handshake_callbacks) { + return std::make_shared(std::move(ssl), ssl_extended_socket_info_index, + handshake_callbacks); + }; + } + + Ssl::HandshakerCapabilities capabilities() const override { + // The default handshaker impl requires Envoy to handle all enumerated behaviors. + return Ssl::HandshakerCapabilities{}; + } + + static HandshakerFactory* getDefaultHandshakerFactory() { + static HandshakerFactoryImpl default_handshaker_factory; + return &default_handshaker_factory; + } +}; + +} // namespace Tls +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/transport_sockets/tls/ssl_socket.cc b/source/extensions/transport_sockets/tls/ssl_socket.cc index 1b11683b7eee..485468443096 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.cc +++ b/source/extensions/transport_sockets/tls/ssl_socket.cc @@ -9,6 +9,7 @@ #include "common/runtime/runtime_features.h" #include "extensions/transport_sockets/tls/io_handle_bio.h" +#include "extensions/transport_sockets/tls/ssl_handshaker.h" #include "extensions/transport_sockets/tls/utility.h" #include "absl/strings/str_replace.h" @@ -45,12 +46,13 @@ class NotReadySslSocket : public Network::TransportSocket { } // namespace SslSocket::SslSocket(Envoy::Ssl::ContextSharedPtr ctx, InitialState state, - const Network::TransportSocketOptionsSharedPtr& transport_socket_options) + const Network::TransportSocketOptionsSharedPtr& transport_socket_options, + Ssl::HandshakerFactoryCb handshaker_factory_cb) : transport_socket_options_(transport_socket_options), - ctx_(std::dynamic_pointer_cast(ctx)) { - bssl::UniquePtr ssl = ctx_->newSsl(transport_socket_options_.get()); - info_ = std::make_shared(std::move(ssl), ctx_, this); - + ctx_(std::dynamic_pointer_cast(ctx)), + info_(std::dynamic_pointer_cast( + handshaker_factory_cb(ctx_->newSsl(transport_socket_options_.get()), + ctx_->sslExtendedSocketInfoIndex(), this))) { if (state == InitialState::Client) { SSL_set_connect_state(rawSsl()); } else { @@ -178,9 +180,7 @@ void SslSocket::onPrivateKeyMethodComplete() { } } -Network::Connection::State SslSocket::connectionState() const { - return callbacks_->connection().state(); -} +Network::Connection& SslSocket::connection() const { return callbacks_->connection(); } void SslSocket::onSuccess(SSL* ssl) { ctx_->logHandshake(ssl); @@ -291,177 +291,6 @@ void SslSocket::shutdownSsl() { } } -void SslExtendedSocketInfoImpl::setCertificateValidationStatus( - Envoy::Ssl::ClientValidationStatus validated) { - certificate_validation_status_ = validated; -} - -Envoy::Ssl::ClientValidationStatus SslExtendedSocketInfoImpl::certificateValidationStatus() const { - return certificate_validation_status_; -} - -SslHandshakerImpl::SslHandshakerImpl(bssl::UniquePtr ssl, ContextImplSharedPtr ctx, - Ssl::HandshakeCallbacks* handshake_callbacks) - : ssl_(std::move(ssl)), handshake_callbacks_(handshake_callbacks), - state_(Ssl::SocketState::PreHandshake) { - SSL_set_ex_data(ssl_.get(), ctx->sslExtendedSocketInfoIndex(), &(this->extended_socket_info_)); -} - -bool SslHandshakerImpl::peerCertificatePresented() const { - bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); - return cert != nullptr; -} - -bool SslHandshakerImpl::peerCertificateValidated() const { - return extended_socket_info_.certificateValidationStatus() == - Envoy::Ssl::ClientValidationStatus::Validated; -} - -absl::Span SslHandshakerImpl::uriSanLocalCertificate() const { - if (!cached_uri_san_local_certificate_.empty()) { - return cached_uri_san_local_certificate_; - } - - // The cert object is not owned. - X509* cert = SSL_get_certificate(ssl()); - if (!cert) { - ASSERT(cached_uri_san_local_certificate_.empty()); - return cached_uri_san_local_certificate_; - } - cached_uri_san_local_certificate_ = Utility::getSubjectAltNames(*cert, GEN_URI); - return cached_uri_san_local_certificate_; -} - -absl::Span SslHandshakerImpl::dnsSansLocalCertificate() const { - if (!cached_dns_san_local_certificate_.empty()) { - return cached_dns_san_local_certificate_; - } - - X509* cert = SSL_get_certificate(ssl()); - if (!cert) { - ASSERT(cached_dns_san_local_certificate_.empty()); - return cached_dns_san_local_certificate_; - } - cached_dns_san_local_certificate_ = Utility::getSubjectAltNames(*cert, GEN_DNS); - return cached_dns_san_local_certificate_; -} - -const std::string& SslHandshakerImpl::sha256PeerCertificateDigest() const { - if (!cached_sha_256_peer_certificate_digest_.empty()) { - return cached_sha_256_peer_certificate_digest_; - } - bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); - if (!cert) { - ASSERT(cached_sha_256_peer_certificate_digest_.empty()); - return cached_sha_256_peer_certificate_digest_; - } - - std::vector computed_hash(SHA256_DIGEST_LENGTH); - unsigned int n; - X509_digest(cert.get(), EVP_sha256(), computed_hash.data(), &n); - RELEASE_ASSERT(n == computed_hash.size(), ""); - cached_sha_256_peer_certificate_digest_ = Hex::encode(computed_hash); - return cached_sha_256_peer_certificate_digest_; -} - -const std::string& SslHandshakerImpl::sha1PeerCertificateDigest() const { - if (!cached_sha_1_peer_certificate_digest_.empty()) { - return cached_sha_1_peer_certificate_digest_; - } - bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); - if (!cert) { - ASSERT(cached_sha_1_peer_certificate_digest_.empty()); - return cached_sha_1_peer_certificate_digest_; - } - - std::vector computed_hash(SHA_DIGEST_LENGTH); - unsigned int n; - X509_digest(cert.get(), EVP_sha1(), computed_hash.data(), &n); - RELEASE_ASSERT(n == computed_hash.size(), ""); - cached_sha_1_peer_certificate_digest_ = Hex::encode(computed_hash); - return cached_sha_1_peer_certificate_digest_; -} - -const std::string& SslHandshakerImpl::urlEncodedPemEncodedPeerCertificate() const { - if (!cached_url_encoded_pem_encoded_peer_certificate_.empty()) { - return cached_url_encoded_pem_encoded_peer_certificate_; - } - bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); - if (!cert) { - ASSERT(cached_url_encoded_pem_encoded_peer_certificate_.empty()); - return cached_url_encoded_pem_encoded_peer_certificate_; - } - - bssl::UniquePtr buf(BIO_new(BIO_s_mem())); - RELEASE_ASSERT(buf != nullptr, ""); - RELEASE_ASSERT(PEM_write_bio_X509(buf.get(), cert.get()) == 1, ""); - const uint8_t* output; - size_t length; - RELEASE_ASSERT(BIO_mem_contents(buf.get(), &output, &length) == 1, ""); - absl::string_view pem(reinterpret_cast(output), length); - cached_url_encoded_pem_encoded_peer_certificate_ = absl::StrReplaceAll( - pem, {{"\n", "%0A"}, {" ", "%20"}, {"+", "%2B"}, {"/", "%2F"}, {"=", "%3D"}}); - return cached_url_encoded_pem_encoded_peer_certificate_; -} - -const std::string& SslHandshakerImpl::urlEncodedPemEncodedPeerCertificateChain() const { - if (!cached_url_encoded_pem_encoded_peer_cert_chain_.empty()) { - return cached_url_encoded_pem_encoded_peer_cert_chain_; - } - - STACK_OF(X509)* cert_chain = SSL_get_peer_full_cert_chain(ssl()); - if (cert_chain == nullptr) { - ASSERT(cached_url_encoded_pem_encoded_peer_cert_chain_.empty()); - return cached_url_encoded_pem_encoded_peer_cert_chain_; - } - - for (uint64_t i = 0; i < sk_X509_num(cert_chain); i++) { - X509* cert = sk_X509_value(cert_chain, i); - - bssl::UniquePtr buf(BIO_new(BIO_s_mem())); - RELEASE_ASSERT(buf != nullptr, ""); - RELEASE_ASSERT(PEM_write_bio_X509(buf.get(), cert) == 1, ""); - const uint8_t* output; - size_t length; - RELEASE_ASSERT(BIO_mem_contents(buf.get(), &output, &length) == 1, ""); - - absl::string_view pem(reinterpret_cast(output), length); - cached_url_encoded_pem_encoded_peer_cert_chain_ = absl::StrCat( - cached_url_encoded_pem_encoded_peer_cert_chain_, - absl::StrReplaceAll( - pem, {{"\n", "%0A"}, {" ", "%20"}, {"+", "%2B"}, {"/", "%2F"}, {"=", "%3D"}})); - } - return cached_url_encoded_pem_encoded_peer_cert_chain_; -} - -absl::Span SslHandshakerImpl::uriSanPeerCertificate() const { - if (!cached_uri_san_peer_certificate_.empty()) { - return cached_uri_san_peer_certificate_; - } - - bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); - if (!cert) { - ASSERT(cached_uri_san_peer_certificate_.empty()); - return cached_uri_san_peer_certificate_; - } - cached_uri_san_peer_certificate_ = Utility::getSubjectAltNames(*cert, GEN_URI); - return cached_uri_san_peer_certificate_; -} - -absl::Span SslHandshakerImpl::dnsSansPeerCertificate() const { - if (!cached_dns_san_peer_certificate_.empty()) { - return cached_dns_san_peer_certificate_; - } - - bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); - if (!cert) { - ASSERT(cached_dns_san_peer_certificate_.empty()); - return cached_dns_san_peer_certificate_; - } - cached_dns_san_peer_certificate_ = Utility::getSubjectAltNames(*cert, GEN_DNS); - return cached_dns_san_peer_certificate_; -} - void SslSocket::closeSocket(Network::ConnectionEvent) { // Unregister the SSL connection object from private key method providers. for (auto const& provider : ctx_->getPrivateKeyMethodProviders()) { @@ -484,148 +313,8 @@ std::string SslSocket::protocol() const { return std::string(reinterpret_cast(proto), proto_len); } -uint16_t SslHandshakerImpl::ciphersuiteId() const { - const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl()); - if (cipher == nullptr) { - return 0xffff; - } - - // From the OpenSSL docs: - // SSL_CIPHER_get_id returns |cipher|'s id. It may be cast to a |uint16_t| to - // get the cipher suite value. - return static_cast(SSL_CIPHER_get_id(cipher)); -} - -std::string SslHandshakerImpl::ciphersuiteString() const { - const SSL_CIPHER* cipher = SSL_get_current_cipher(ssl()); - if (cipher == nullptr) { - return {}; - } - - return SSL_CIPHER_get_name(cipher); -} - -const std::string& SslHandshakerImpl::tlsVersion() const { - if (!cached_tls_version_.empty()) { - return cached_tls_version_; - } - cached_tls_version_ = SSL_get_version(ssl()); - return cached_tls_version_; -} - -Network::PostIoAction SslHandshakerImpl::doHandshake() { - ASSERT(state_ != Ssl::SocketState::HandshakeComplete && state_ != Ssl::SocketState::ShutdownSent); - int rc = SSL_do_handshake(ssl()); - if (rc == 1) { - state_ = Ssl::SocketState::HandshakeComplete; - handshake_callbacks_->onSuccess(ssl()); - - // It's possible that we closed during the handshake callback. - return handshake_callbacks_->connectionState() == Network::Connection::State::Open - ? PostIoAction::KeepOpen - : PostIoAction::Close; - } else { - int err = SSL_get_error(ssl(), rc); - switch (err) { - case SSL_ERROR_WANT_READ: - case SSL_ERROR_WANT_WRITE: - return PostIoAction::KeepOpen; - case SSL_ERROR_WANT_PRIVATE_KEY_OPERATION: - state_ = Ssl::SocketState::HandshakeInProgress; - return PostIoAction::KeepOpen; - default: - handshake_callbacks_->onFailure(); - return PostIoAction::Close; - } - } -} - absl::string_view SslSocket::failureReason() const { return failure_reason_; } -const std::string& SslHandshakerImpl::serialNumberPeerCertificate() const { - if (!cached_serial_number_peer_certificate_.empty()) { - return cached_serial_number_peer_certificate_; - } - bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); - if (!cert) { - ASSERT(cached_serial_number_peer_certificate_.empty()); - return cached_serial_number_peer_certificate_; - } - cached_serial_number_peer_certificate_ = Utility::getSerialNumberFromCertificate(*cert.get()); - return cached_serial_number_peer_certificate_; -} - -const std::string& SslHandshakerImpl::issuerPeerCertificate() const { - if (!cached_issuer_peer_certificate_.empty()) { - return cached_issuer_peer_certificate_; - } - bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); - if (!cert) { - ASSERT(cached_issuer_peer_certificate_.empty()); - return cached_issuer_peer_certificate_; - } - cached_issuer_peer_certificate_ = Utility::getIssuerFromCertificate(*cert); - return cached_issuer_peer_certificate_; -} - -const std::string& SslHandshakerImpl::subjectPeerCertificate() const { - if (!cached_subject_peer_certificate_.empty()) { - return cached_subject_peer_certificate_; - } - bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); - if (!cert) { - ASSERT(cached_subject_peer_certificate_.empty()); - return cached_subject_peer_certificate_; - } - cached_subject_peer_certificate_ = Utility::getSubjectFromCertificate(*cert); - return cached_subject_peer_certificate_; -} - -const std::string& SslHandshakerImpl::subjectLocalCertificate() const { - if (!cached_subject_local_certificate_.empty()) { - return cached_subject_local_certificate_; - } - X509* cert = SSL_get_certificate(ssl()); - if (!cert) { - ASSERT(cached_subject_local_certificate_.empty()); - return cached_subject_local_certificate_; - } - cached_subject_local_certificate_ = Utility::getSubjectFromCertificate(*cert); - return cached_subject_local_certificate_; -} - -absl::optional SslHandshakerImpl::validFromPeerCertificate() const { - bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); - if (!cert) { - return absl::nullopt; - } - return Utility::getValidFrom(*cert); -} - -absl::optional SslHandshakerImpl::expirationPeerCertificate() const { - bssl::UniquePtr cert(SSL_get_peer_certificate(ssl())); - if (!cert) { - return absl::nullopt; - } - return Utility::getExpirationTime(*cert); -} - -const std::string& SslHandshakerImpl::sessionId() const { - if (!cached_session_id_.empty()) { - return cached_session_id_; - } - SSL_SESSION* session = SSL_get_session(ssl()); - if (session == nullptr) { - ASSERT(cached_session_id_.empty()); - return cached_session_id_; - } - - unsigned int session_id_length = 0; - const uint8_t* session_id = SSL_SESSION_get_id(session, &session_id_length); - cached_session_id_ = Hex::encode(session_id, session_id_length); - return cached_session_id_; -} - namespace { SslSocketFactoryStats generateStats(const std::string& prefix, Stats::Scope& store) { return { @@ -654,7 +343,7 @@ Network::TransportSocketPtr ClientSslSocketFactory::createTransportSocket( } if (ssl_ctx) { return std::make_unique(std::move(ssl_ctx), InitialState::Client, - transport_socket_options); + transport_socket_options, config_->createHandshaker()); } else { ENVOY_LOG(debug, "Create NotReadySslSocket"); stats_.upstream_context_secrets_not_ready_.inc(); @@ -694,7 +383,8 @@ ServerSslSocketFactory::createTransportSocket(Network::TransportSocketOptionsSha ssl_ctx = ssl_ctx_; } if (ssl_ctx) { - return std::make_unique(std::move(ssl_ctx), InitialState::Server, nullptr); + return std::make_unique(std::move(ssl_ctx), InitialState::Server, nullptr, + config_->createHandshaker()); } else { ENVOY_LOG(debug, "Create NotReadySslSocket"); stats_.downstream_context_secrets_not_ready_.inc(); diff --git a/source/extensions/transport_sockets/tls/ssl_socket.h b/source/extensions/transport_sockets/tls/ssl_socket.h index 2fcc78445cd2..ba73cc5d6ac6 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.h +++ b/source/extensions/transport_sockets/tls/ssl_socket.h @@ -16,6 +16,7 @@ #include "common/common/logger.h" #include "extensions/transport_sockets/tls/context_impl.h" +#include "extensions/transport_sockets/tls/ssl_handshaker.h" #include "extensions/transport_sockets/tls/utility.h" #include "absl/container/node_hash_map.h" @@ -42,82 +43,14 @@ struct SslSocketFactoryStats { enum class InitialState { Client, Server }; -class SslExtendedSocketInfoImpl : public Envoy::Ssl::SslExtendedSocketInfo { -public: - void setCertificateValidationStatus(Envoy::Ssl::ClientValidationStatus validated) override; - Envoy::Ssl::ClientValidationStatus certificateValidationStatus() const override; - -private: - Envoy::Ssl::ClientValidationStatus certificate_validation_status_{ - Envoy::Ssl::ClientValidationStatus::NotValidated}; -}; - -class SslHandshakerImpl : public Envoy::Ssl::ConnectionInfo, public Envoy::Ssl::Handshaker { -public: - SslHandshakerImpl(bssl::UniquePtr ssl, ContextImplSharedPtr ctx, - Ssl::HandshakeCallbacks* handshake_callbacks); - - // Ssl::ConnectionInfo - bool peerCertificatePresented() const override; - bool peerCertificateValidated() const override; - absl::Span uriSanLocalCertificate() const override; - const std::string& sha256PeerCertificateDigest() const override; - const std::string& sha1PeerCertificateDigest() const override; - const std::string& serialNumberPeerCertificate() const override; - const std::string& issuerPeerCertificate() const override; - const std::string& subjectPeerCertificate() const override; - const std::string& subjectLocalCertificate() const override; - absl::Span uriSanPeerCertificate() const override; - const std::string& urlEncodedPemEncodedPeerCertificate() const override; - const std::string& urlEncodedPemEncodedPeerCertificateChain() const override; - absl::Span dnsSansPeerCertificate() const override; - absl::Span dnsSansLocalCertificate() const override; - absl::optional validFromPeerCertificate() const override; - absl::optional expirationPeerCertificate() const override; - const std::string& sessionId() const override; - uint16_t ciphersuiteId() const override; - std::string ciphersuiteString() const override; - const std::string& tlsVersion() const override; - - // Ssl::Handshaker - Network::PostIoAction doHandshake() override; - - Ssl::SocketState state() { return state_; } - void setState(Ssl::SocketState state) { state_ = state; } - SSL* ssl() const { return ssl_.get(); } - - bssl::UniquePtr ssl_; - -private: - Ssl::HandshakeCallbacks* handshake_callbacks_; - - Ssl::SocketState state_; - mutable std::vector cached_uri_san_local_certificate_; - mutable std::string cached_sha_256_peer_certificate_digest_; - mutable std::string cached_sha_1_peer_certificate_digest_; - mutable std::string cached_serial_number_peer_certificate_; - mutable std::string cached_issuer_peer_certificate_; - mutable std::string cached_subject_peer_certificate_; - mutable std::string cached_subject_local_certificate_; - mutable std::vector cached_uri_san_peer_certificate_; - mutable std::string cached_url_encoded_pem_encoded_peer_certificate_; - mutable std::string cached_url_encoded_pem_encoded_peer_cert_chain_; - mutable std::vector cached_dns_san_peer_certificate_; - mutable std::vector cached_dns_san_local_certificate_; - mutable std::string cached_session_id_; - mutable std::string cached_tls_version_; - mutable SslExtendedSocketInfoImpl extended_socket_info_; -}; - -using SslHandshakerImplSharedPtr = std::shared_ptr; - class SslSocket : public Network::TransportSocket, public Envoy::Ssl::PrivateKeyConnectionCallbacks, public Ssl::HandshakeCallbacks, protected Logger::Loggable { public: SslSocket(Envoy::Ssl::ContextSharedPtr ctx, InitialState state, - const Network::TransportSocketOptionsSharedPtr& transport_socket_options); + const Network::TransportSocketOptionsSharedPtr& transport_socket_options, + Ssl::HandshakerFactoryCb handshaker_factory_cb); // Network::TransportSocket void setTransportSocketCallbacks(Network::TransportSocketCallbacks& callbacks) override; @@ -132,7 +65,7 @@ class SslSocket : public Network::TransportSocket, // Ssl::PrivateKeyConnectionCallbacks void onPrivateKeyMethodComplete() override; // Ssl::HandshakeCallbacks - Network::Connection::State connectionState() const override; + Network::Connection& connection() const override; void onSuccess(SSL* ssl) override; void onFailure() override; diff --git a/test/extensions/transport_sockets/tls/BUILD b/test/extensions/transport_sockets/tls/BUILD index 58b897c06f15..cdd5963d7ee4 100644 --- a/test/extensions/transport_sockets/tls/BUILD +++ b/test/extensions/transport_sockets/tls/BUILD @@ -160,3 +160,26 @@ envoy_cc_test_library( "@envoy_api//envoy/extensions/transport_sockets/tls/v3:pkg_cc_proto", ], ) + +envoy_cc_test( + name = "handshaker_test", + srcs = ["handshaker_test.cc"], + data = [ + "gen_unittest_certs.sh", + "//test/config/integration/certs", + "//test/extensions/transport_sockets/tls/test_data:certs", + ], + external_deps = ["ssl"], + deps = [ + ":ssl_socket_test", + ":ssl_test_utils", + "//source/common/stream_info:stream_info_lib", + "//source/extensions/transport_sockets/tls:ssl_handshaker_lib", + "//test/mocks/buffer:buffer_mocks", + "//test/mocks/network:network_mocks", + "//test/mocks/runtime:runtime_mocks", + "//test/mocks/server:server_mocks", + "//test/mocks/ssl:ssl_mocks", + "//test/mocks/stats:stats_mocks", + ], +) diff --git a/test/extensions/transport_sockets/tls/handshaker_test.cc b/test/extensions/transport_sockets/tls/handshaker_test.cc new file mode 100644 index 000000000000..77623f9e13ae --- /dev/null +++ b/test/extensions/transport_sockets/tls/handshaker_test.cc @@ -0,0 +1,245 @@ +#include + +#include "envoy/network/transport_socket.h" +#include "envoy/ssl/handshaker.h" + +#include "common/stream_info/stream_info_impl.h" + +#include "extensions/transport_sockets/tls/ssl_handshaker.h" + +#include "test/extensions/transport_sockets/tls/ssl_certs_test.h" +#include "test/mocks/network/connection.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" +#include "openssl/evp.h" +#include "openssl/hmac.h" +#include "openssl/ssl.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Tls { +namespace { + +using ::testing::NiceMock; +using ::testing::Return; +using ::testing::StrictMock; + +// A callback shaped like pem_password_cb. +// See https://www.openssl.org/docs/man1.1.0/man3/pem_password_cb.html. +int pemPasswordCallback(char* buf, int buf_size, int, void* u) { + if (u == nullptr) { + return 0; + } + std::string passphrase = *reinterpret_cast(u); + RELEASE_ASSERT(buf_size >= static_cast(passphrase.size()), + "Passphrase was larger than buffer."); + memcpy(buf, passphrase.data(), passphrase.size()); + return passphrase.size(); +} + +class MockHandshakeCallbacks : public Ssl::HandshakeCallbacks { +public: + ~MockHandshakeCallbacks() override = default; + MOCK_METHOD(Network::Connection&, connection, (), (const, override)); + MOCK_METHOD(void, onSuccess, (SSL*), (override)); + MOCK_METHOD(void, onFailure, (), (override)); +}; + +class HandshakerTest : public SslCertsTest { +protected: + HandshakerTest() + : dispatcher_(api_->allocateDispatcher("test_thread")), stream_info_(api_->timeSource()), + client_ctx_(SSL_CTX_new(TLS_method())), server_ctx_(SSL_CTX_new(TLS_method())) {} + + void SetUp() override { + // Set up key and cert, initialize two SSL objects and a pair of BIOs for + // handshaking. + auto key = makeKey(); + auto cert = makeCert(); + auto chain = std::vector{cert.get()}; + + server_ssl_ = bssl::UniquePtr(SSL_new(server_ctx_.get())); + SSL_set_accept_state(server_ssl_.get()); + ASSERT_NE(key, nullptr); + ASSERT_EQ(1, SSL_set_chain_and_key(server_ssl_.get(), chain.data(), chain.size(), key.get(), + nullptr)); + + client_ssl_ = bssl::UniquePtr(SSL_new(client_ctx_.get())); + SSL_set_connect_state(client_ssl_.get()); + + ASSERT_EQ(1, BIO_new_bio_pair(&client_bio_, kBufferLength, &server_bio_, kBufferLength)); + + BIO_up_ref(client_bio_); + BIO_up_ref(server_bio_); + SSL_set0_rbio(client_ssl_.get(), client_bio_); + SSL_set0_wbio(client_ssl_.get(), client_bio_); + SSL_set0_rbio(server_ssl_.get(), server_bio_); + SSL_set0_wbio(server_ssl_.get(), server_bio_); + } + + // Read in key.pem and return a new private key. + bssl::UniquePtr makeKey() { + std::string file = TestEnvironment::readFileToStringForTest( + TestEnvironment::substitute("{{ test_tmpdir }}/unittestkey.pem")); + std::string passphrase = ""; + bssl::UniquePtr bio(BIO_new_mem_buf(file.data(), file.size())); + + bssl::UniquePtr key(EVP_PKEY_new()); + + RSA* rsa = PEM_read_bio_RSAPrivateKey(bio.get(), nullptr, &pemPasswordCallback, &passphrase); + RELEASE_ASSERT(rsa != nullptr, "PEM_read_bio_RSAPrivateKey failed."); + RELEASE_ASSERT(1 == EVP_PKEY_assign_RSA(key.get(), rsa), "EVP_PKEY_assign_RSA failed."); + return key; + } + + // Read in cert.pem and return a certificate. + bssl::UniquePtr makeCert() { + std::string file = TestEnvironment::readFileToStringForTest( + TestEnvironment::substitute("{{ test_tmpdir }}/unittestcert.pem")); + bssl::UniquePtr bio(BIO_new_mem_buf(file.data(), file.size())); + + uint8_t* data = nullptr; + long len = 0; + RELEASE_ASSERT( + PEM_bytes_read_bio(&data, &len, nullptr, PEM_STRING_X509, bio.get(), nullptr, nullptr), + "PEM_bytes_read_bio failed"); + bssl::UniquePtr tmp(data); // Prevents memory leak. + return bssl::UniquePtr(CRYPTO_BUFFER_new(data, len, nullptr)); + } + + const size_t kBufferLength{100}; + + Event::DispatcherPtr dispatcher_; + StreamInfo::StreamInfoImpl stream_info_; + + BIO *client_bio_, *server_bio_; + bssl::UniquePtr client_ctx_, server_ctx_; + bssl::UniquePtr client_ssl_, server_ssl_; +}; + +TEST_F(HandshakerTest, NormalOperation) { + NiceMock mock_connection; + ON_CALL(mock_connection, state).WillByDefault(Return(Network::Connection::State::Closed)); + + NiceMock handshake_callbacks; + EXPECT_CALL(handshake_callbacks, onSuccess).Times(1); + ON_CALL(handshake_callbacks, connection()).WillByDefault(ReturnRef(mock_connection)); + + SslHandshakerImpl handshaker(std::move(server_ssl_), 0, &handshake_callbacks); + + auto post_io_action = Network::PostIoAction::KeepOpen; // default enum + + // Run the handshakes from the client and server until SslHandshakerImpl decides + // we're done and returns PostIoAction::Close. + while (post_io_action != Network::PostIoAction::Close) { + SSL_do_handshake(client_ssl_.get()); + post_io_action = handshaker.doHandshake(); + } + + EXPECT_EQ(post_io_action, Network::PostIoAction::Close); +} + +// We induce some kind of BIO mismatch and force the SSL_do_handshake to +// return an error code without error handling, i.e. not SSL_ERROR_WANT_READ +// or _WRITE or _PRIVATE_KEY_OPERATION. +TEST_F(HandshakerTest, ErrorCbOnAbnormalOperation) { + // We make a new BIO, set it as the `rbio`/`wbio` for the client SSL object, + // and break the BIO pair connecting the two SSL objects. Now handshaking will + // fail, likely with SSL_ERROR_SSL. + BIO* bio = BIO_new(BIO_s_socket()); + SSL_set_bio(client_ssl_.get(), bio, bio); + + StrictMock handshake_callbacks; + EXPECT_CALL(handshake_callbacks, onFailure).Times(1); + + SslHandshakerImpl handshaker(std::move(server_ssl_), 0, &handshake_callbacks); + + auto post_io_action = Network::PostIoAction::KeepOpen; // default enum + + while (post_io_action != Network::PostIoAction::Close) { + SSL_do_handshake(client_ssl_.get()); + post_io_action = handshaker.doHandshake(); + } + + // In the error case, SslHandshakerImpl also closes the connection. + EXPECT_EQ(post_io_action, Network::PostIoAction::Close); +} + +// Example SslHandshakerImpl demonstrating special-case behavior which necessitates +// extra SSL_ERROR case handling. Here, we induce an SSL_ERROR_WANT_X509_LOOKUP, +// check for it in the handshaker, faux-trigger the lookup, and then proceed as +// normal. +class SslHandshakerImplForTest : public SslHandshakerImpl { +public: + SslHandshakerImplForTest(bssl::UniquePtr ssl_ptr, + Ssl::HandshakeCallbacks* handshake_callbacks, + std::function requested_cert_cb) + : SslHandshakerImpl(std::move(ssl_ptr), 0, handshake_callbacks), + requested_cert_cb_(requested_cert_cb) { + SSL_set_cert_cb( + ssl(), [](SSL*, void* arg) -> int { return *static_cast(arg) ? 1 : -1; }, + &cert_cb_ok_); + } + + Network::PostIoAction doHandshake() override { + RELEASE_ASSERT(state() != Ssl::SocketState::HandshakeComplete && + state() != Ssl::SocketState::ShutdownSent, + "Handshaker state was either complete or sent."); + + int rc = SSL_do_handshake(ssl()); + if (rc == 1) { + setState(Ssl::SocketState::HandshakeComplete); + handshakeCallbacks()->onSuccess(ssl()); + return Network::PostIoAction::Close; + } else { + switch (SSL_get_error(ssl(), rc)) { + case SSL_ERROR_WANT_READ: + case SSL_ERROR_WANT_WRITE: + return Network::PostIoAction::KeepOpen; + case SSL_ERROR_WANT_X509_LOOKUP: + // Special case. Once this lookup is requested, we flip the bit and allow + // the handshake to proceed. + requested_cert_cb_(); + return Network::PostIoAction::KeepOpen; + default: + handshakeCallbacks()->onFailure(); + return Network::PostIoAction::Close; + } + } + } + + void setCertCbOk() { cert_cb_ok_ = true; } + +private: + std::function requested_cert_cb_; + bool cert_cb_ok_{false}; +}; + +TEST_F(HandshakerTest, NormalOperationWithSslHandshakerImplForTest) { + ::testing::MockFunction requested_cert_cb; + + StrictMock handshake_callbacks; + EXPECT_CALL(handshake_callbacks, onSuccess).Times(1); + + SslHandshakerImplForTest handshaker(std::move(server_ssl_), &handshake_callbacks, + requested_cert_cb.AsStdFunction()); + + EXPECT_CALL(requested_cert_cb, Call).WillOnce([&]() { handshaker.setCertCbOk(); }); + + auto post_io_action = Network::PostIoAction::KeepOpen; // default enum + + while (post_io_action != Network::PostIoAction::Close) { + SSL_do_handshake(client_ssl_.get()); + post_io_action = handshaker.doHandshake(); + } + + EXPECT_EQ(post_io_action, Network::PostIoAction::Close); +} + +} // namespace +} // namespace Tls +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/test/mocks/ssl/mocks.h b/test/mocks/ssl/mocks.h index 5f25bd9a18cd..ea2a0edd6000 100644 --- a/test/mocks/ssl/mocks.h +++ b/test/mocks/ssl/mocks.h @@ -86,6 +86,9 @@ class MockClientContextConfig : public ClientContextConfig { MOCK_METHOD(bool, isReady, (), (const)); MOCK_METHOD(void, setSecretUpdateCallback, (std::function callback)); + MOCK_METHOD(Ssl::HandshakerFactoryCb, createHandshaker, (), (const, override)); + MOCK_METHOD(Ssl::HandshakerCapabilities, capabilities, (), (const, override)); + MOCK_METHOD(const std::string&, serverNameIndication, (), (const)); MOCK_METHOD(bool, allowRenegotiation, (), (const)); MOCK_METHOD(size_t, maxSessionKeys, (), (const)); @@ -109,6 +112,9 @@ class MockServerContextConfig : public ServerContextConfig { MOCK_METHOD(absl::optional, sessionTimeout, (), (const)); MOCK_METHOD(void, setSecretUpdateCallback, (std::function callback)); + MOCK_METHOD(Ssl::HandshakerFactoryCb, createHandshaker, (), (const, override)); + MOCK_METHOD(Ssl::HandshakerCapabilities, capabilities, (), (const, override)); + MOCK_METHOD(bool, requireClientCertificate, (), (const)); MOCK_METHOD(const std::vector&, sessionTicketKeys, (), (const)); MOCK_METHOD(bool, disableStatelessSessionResumption, (), (const));