From 4ecda60a0c9e811811827b56589e09e237e40ef2 Mon Sep 17 00:00:00 2001 From: alyssawilk Date: Wed, 4 Dec 2024 11:45:05 -0500 Subject: [PATCH] validator: allowing creation failure (#37456) Risk Level: low Testing: updated tests Docs Changes: n/a Release Notes: n/a https://github.com/envoyproxy/envoy-mobile/issues/176 Signed-off-by: Alyssa Wilk --- .../cert_validator/platform_bridge/config.cc | 4 +- .../cert_validator/platform_bridge/config.h | 2 +- .../platform_bridge_cert_validator.cc | 23 ++++-- .../platform_bridge_cert_validator.h | 13 ++-- .../platform_bridge_cert_validator_test.cc | 75 +++++++++++-------- .../tls/cert_validator/default_validator.cc | 2 +- source/common/tls/cert_validator/factory.h | 2 +- source/common/tls/context_impl.cc | 4 +- .../cert_validator/spiffe/spiffe_validator.cc | 2 +- .../tls/cert_validator/timed_cert_validator.h | 2 +- 10 files changed, 76 insertions(+), 53 deletions(-) diff --git a/mobile/library/common/extensions/cert_validator/platform_bridge/config.cc b/mobile/library/common/extensions/cert_validator/platform_bridge/config.cc index 807ec2dcfef7..03794524a051 100644 --- a/mobile/library/common/extensions/cert_validator/platform_bridge/config.cc +++ b/mobile/library/common/extensions/cert_validator/platform_bridge/config.cc @@ -7,10 +7,10 @@ namespace Extensions { namespace TransportSockets { namespace Tls { -CertValidatorPtr PlatformBridgeCertValidatorFactory::createCertValidator( +absl::StatusOr PlatformBridgeCertValidatorFactory::createCertValidator( const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats, Server::Configuration::CommonFactoryContext& /*context*/) { - return std::make_unique(config, stats); + return PlatformBridgeCertValidator::create(config, stats); } REGISTER_FACTORY(PlatformBridgeCertValidatorFactory, CertValidatorFactory); diff --git a/mobile/library/common/extensions/cert_validator/platform_bridge/config.h b/mobile/library/common/extensions/cert_validator/platform_bridge/config.h index d070c3fc99af..40c075d5a5fd 100644 --- a/mobile/library/common/extensions/cert_validator/platform_bridge/config.h +++ b/mobile/library/common/extensions/cert_validator/platform_bridge/config.h @@ -15,7 +15,7 @@ namespace Tls { class PlatformBridgeCertValidatorFactory : public CertValidatorFactory, public Config::TypedFactory { public: - CertValidatorPtr + absl::StatusOr createCertValidator(const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats, Server::Configuration::CommonFactoryContext& context) override; diff --git a/mobile/library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.cc b/mobile/library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.cc index e7295159644f..e3151f4a3564 100644 --- a/mobile/library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.cc +++ b/mobile/library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.cc @@ -15,9 +15,19 @@ namespace Extensions { namespace TransportSockets { namespace Tls { +absl::StatusOr> +PlatformBridgeCertValidator::create(const Envoy::Ssl::CertificateValidationContextConfig* config, + SslStats& stats) { + absl::Status creation_status = absl::OkStatus(); + auto ret = std::unique_ptr(new PlatformBridgeCertValidator( + config, stats, Thread::PosixThreadFactory::create(), creation_status)); + RETURN_IF_NOT_OK_REF(creation_status); + return ret; +} + PlatformBridgeCertValidator::PlatformBridgeCertValidator( const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats, - Thread::PosixThreadFactoryPtr thread_factory) + Thread::PosixThreadFactoryPtr thread_factory, absl::Status& creation_status) : allow_untrusted_certificate_(config != nullptr && config->trustChainVerification() == envoy::extensions::transport_sockets::tls::v3:: @@ -28,19 +38,16 @@ PlatformBridgeCertValidator::PlatformBridgeCertValidator( "Invalid certificate validation context config."); if (config != nullptr && config->customValidatorConfig().has_value()) { envoy_mobile::extensions::cert_validator::platform_bridge::PlatformBridgeCertValidator cfg; - THROW_IF_NOT_OK(Envoy::Config::Utility::translateOpaqueConfig( - config->customValidatorConfig().value().typed_config(), - ProtobufMessage::getStrictValidationVisitor(), cfg)); + SET_AND_RETURN_IF_NOT_OK(Envoy::Config::Utility::translateOpaqueConfig( + config->customValidatorConfig().value().typed_config(), + ProtobufMessage::getStrictValidationVisitor(), cfg), + creation_status); if (cfg.has_thread_priority()) { thread_priority_ = cfg.thread_priority().value(); } } } -PlatformBridgeCertValidator::PlatformBridgeCertValidator( - const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats) - : PlatformBridgeCertValidator(config, stats, Thread::PosixThreadFactory::create()) {} - PlatformBridgeCertValidator::~PlatformBridgeCertValidator() { // Wait for validation threads to finish. for (auto& [id, job] : validation_jobs_) { diff --git a/mobile/library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.h b/mobile/library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.h index 54f187b6fc7f..472ff0b6650e 100644 --- a/mobile/library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.h +++ b/mobile/library/common/extensions/cert_validator/platform_bridge/platform_bridge_cert_validator.h @@ -19,8 +19,8 @@ namespace Tls { // validation. class PlatformBridgeCertValidator : public CertValidator, Logger::Loggable { public: - PlatformBridgeCertValidator(const Envoy::Ssl::CertificateValidationContextConfig* config, - SslStats& stats); + static absl::StatusOr> + create(const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats); ~PlatformBridgeCertValidator() override; @@ -58,6 +58,10 @@ class PlatformBridgeCertValidator : public CertValidator, Logger::Loggable /* cert_chain */, std::string hostname, std::vector /* subject_alt_names */) override { - recorded_thread_priority_ = thread_factory_.currentThreadPriority(); + recorded_thread_priority_ = threadFactory()->currentThreadPriority(); postVerifyResultAndCleanUp(/* success = */ true, std::move(hostname), "", SSL_AD_CERTIFICATE_UNKNOWN, ValidationFailureType::Success, dispatcher, this); } private: - Thread::PosixThreadFactory& thread_factory_; int recorded_thread_priority_; }; @@ -171,7 +170,7 @@ INSTANTIATE_TEST_SUITE_P(TrustMode, PlatformBridgeCertValidatorTest, CertificateValidationContext::ACCEPT_UNTRUSTED})); TEST_P(PlatformBridgeCertValidatorTest, NoConfig) { - EXPECT_ENVOY_BUG({ PlatformBridgeCertValidator validator(nullptr, stats_); }, + EXPECT_ENVOY_BUG({ PlatformBridgeCertValidator::create(nullptr, stats_).IgnoreError(); }, "Invalid certificate validation context config."); } @@ -182,7 +181,7 @@ TEST_P(PlatformBridgeCertValidatorTest, NonEmptyCaCert) { EXPECT_CALL(config_, trustChainVerification()).WillRepeatedly(Return(GetParam())); EXPECT_CALL(config_, customValidatorConfig()).WillRepeatedly(ReturnRef(platform_bridge_config_)); - EXPECT_ENVOY_BUG({ PlatformBridgeCertValidator validator(&config_, stats_); }, + EXPECT_ENVOY_BUG({ PlatformBridgeCertValidator::create(&config_, stats_).IgnoreError(); }, "Invalid certificate validation context config."); } @@ -193,13 +192,14 @@ TEST_P(PlatformBridgeCertValidatorTest, NonEmptyRevocationList) { EXPECT_CALL(config_, trustChainVerification()).WillRepeatedly(Return(GetParam())); EXPECT_CALL(config_, customValidatorConfig()).WillRepeatedly(ReturnRef(platform_bridge_config_)); - EXPECT_ENVOY_BUG({ PlatformBridgeCertValidator validator(&config_, stats_); }, + EXPECT_ENVOY_BUG({ PlatformBridgeCertValidator::create(&config_, stats_).IgnoreError(); }, "Invalid certificate validation context config."); } TEST_P(PlatformBridgeCertValidatorTest, NoCallback) { initializeConfig(); - PlatformBridgeCertValidator validator(&config_, stats_); + std::unique_ptr validator = + *PlatformBridgeCertValidator::create(&config_, stats_); bssl::UniquePtr cert_chain = readCertChainFromFile( TestEnvironment::substitute("{{ test_rundir }}/test/common/tls/test_data/san_dns2_cert.pem")); @@ -207,23 +207,24 @@ TEST_P(PlatformBridgeCertValidatorTest, NoCallback) { EXPECT_ENVOY_BUG( { - validator.doVerifyCertChain(*cert_chain, Ssl::ValidateResultCallbackPtr(), - transport_socket_options_, *ssl_ctx_, validation_context_, - is_server_, hostname); + validator->doVerifyCertChain(*cert_chain, Ssl::ValidateResultCallbackPtr(), + transport_socket_options_, *ssl_ctx_, validation_context_, + is_server_, hostname); }, "No callback specified"); } TEST_P(PlatformBridgeCertValidatorTest, EmptyCertChain) { initializeConfig(); - PlatformBridgeCertValidator validator(&config_, stats_); + std::unique_ptr validator = + *PlatformBridgeCertValidator::create(&config_, stats_); bssl::UniquePtr cert_chain(sk_X509_new_null()); std::string hostname = "www.example.com"; ValidationResults results = - validator.doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_, - *ssl_ctx_, validation_context_, is_server_, hostname); + validator->doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_, + *ssl_ctx_, validation_context_, is_server_, hostname); EXPECT_EQ(ValidationResults::ValidationStatus::Failed, results.status); EXPECT_FALSE(results.tls_alert.has_value()); ASSERT_TRUE(results.error_details.has_value()); @@ -236,7 +237,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificate) { EXPECT_CALL(helper_handle_->mock_helper(), validateCertificateChain(_, _)); initializeConfig(); - PlatformBridgeCertValidator validator(&config_, stats_); + std::unique_ptr validator = + *PlatformBridgeCertValidator::create(&config_, stats_); std::string hostname = "server1.example.com"; bssl::UniquePtr cert_chain = readCertChainFromFile( @@ -248,8 +250,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificate) { EXPECT_CALL(callback_ref, dispatcher()).WillRepeatedly(ReturnRef(*dispatcher_)); ValidationResults results = - validator.doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_, - *ssl_ctx_, validation_context_, is_server_, hostname); + validator->doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_, + *ssl_ctx_, validation_context_, is_server_, hostname); EXPECT_EQ(ValidationResults::ValidationStatus::Pending, results.status); EXPECT_CALL(callback_ref, @@ -266,7 +268,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateEmptySanOverrides) { EXPECT_CALL(helper_handle_->mock_helper(), validateCertificateChain(_, _)); initializeConfig(); - PlatformBridgeCertValidator validator(&config_, stats_); + std::unique_ptr validator = + *PlatformBridgeCertValidator::create(&config_, stats_); std::string hostname = "server1.example.com"; bssl::UniquePtr cert_chain = readCertChainFromFile( @@ -283,8 +286,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateEmptySanOverrides) { std::make_shared("", std::move(subject_alt_names)); ValidationResults results = - validator.doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_, - *ssl_ctx_, validation_context_, is_server_, hostname); + validator->doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_, + *ssl_ctx_, validation_context_, is_server_, hostname); EXPECT_EQ(ValidationResults::ValidationStatus::Pending, results.status); EXPECT_CALL(callback_ref, @@ -301,7 +304,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateEmptyHostNoOverrides) { EXPECT_CALL(helper_handle_->mock_helper(), validateCertificateChain(_, _)); initializeConfig(); - PlatformBridgeCertValidator validator(&config_, stats_); + std::unique_ptr validator = + *PlatformBridgeCertValidator::create(&config_, stats_); std::string hostname = ""; bssl::UniquePtr cert_chain = readCertChainFromFile( @@ -318,8 +322,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateEmptyHostNoOverrides) { std::make_shared("", std::move(subject_alt_names)); ValidationResults results = - validator.doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_, - *ssl_ctx_, validation_context_, is_server_, hostname); + validator->doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_, + *ssl_ctx_, validation_context_, is_server_, hostname); EXPECT_EQ(ValidationResults::ValidationStatus::Pending, results.status); EXPECT_CALL(callback_ref, @@ -336,7 +340,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateButInvalidSni) { EXPECT_CALL(helper_handle_->mock_helper(), validateCertificateChain(_, _)); initializeConfig(); - PlatformBridgeCertValidator validator(&config_, stats_); + std::unique_ptr validator = + *PlatformBridgeCertValidator::create(&config_, stats_); std::string hostname = "server2.example.com"; bssl::UniquePtr cert_chain = readCertChainFromFile( @@ -348,8 +353,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateButInvalidSni) { EXPECT_CALL(callback_ref, dispatcher()).WillRepeatedly(ReturnRef(*dispatcher_)); ValidationResults results = - validator.doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_, - *ssl_ctx_, validation_context_, is_server_, hostname); + validator->doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_, + *ssl_ctx_, validation_context_, is_server_, hostname); EXPECT_EQ(ValidationResults::ValidationStatus::Pending, results.status); EXPECT_CALL(callback_ref, @@ -366,7 +371,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateSniOverride) { EXPECT_CALL(helper_handle_->mock_helper(), validateCertificateChain(_, _)); initializeConfig(); - PlatformBridgeCertValidator validator(&config_, stats_); + std::unique_ptr validator = + *PlatformBridgeCertValidator::create(&config_, stats_); std::vector subject_alt_names = {"server1.example.com"}; @@ -383,8 +389,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateSniOverride) { std::make_shared("", std::move(subject_alt_names)); ValidationResults results = - validator.doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_, - *ssl_ctx_, validation_context_, is_server_, hostname); + validator->doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_, + *ssl_ctx_, validation_context_, is_server_, hostname); EXPECT_EQ(ValidationResults::ValidationStatus::Pending, results.status); // The cert will be validated against the overridden name not the invalid name "server2". @@ -399,7 +405,7 @@ TEST_P(PlatformBridgeCertValidatorTest, DeletedWithValidationPending) { EXPECT_CALL(helper_handle_->mock_helper(), validateCertificateChain(_, _)); initializeConfig(); - auto validator = std::make_unique(&config_, stats_); + auto validator = *PlatformBridgeCertValidator::create(&config_, stats_); std::string hostname = "server1.example.com"; bssl::UniquePtr cert_chain = readCertChainFromFile( @@ -426,7 +432,9 @@ TEST_P(PlatformBridgeCertValidatorTest, ThreadCreationFailed) { initializeConfig(); auto thread_factory = std::make_unique(); EXPECT_CALL(*thread_factory, createThread(_, _, false)).WillOnce(Return(ByMove(nullptr))); - PlatformBridgeCertValidator validator(&config_, stats_, std::move(thread_factory)); + absl::Status creation_status = absl::OkStatus(); + PlatformBridgeCertValidatorCustomValidate validator(&config_, stats_, std::move(thread_factory), + creation_status); std::string hostname = "server1.example.com"; bssl::UniquePtr cert_chain = readCertChainFromFile( @@ -455,7 +463,10 @@ TEST_P(PlatformBridgeCertValidatorTest, ThreadPriority) { EXPECT_CALL(helper_handle_->mock_helper(), cleanupAfterCertificateValidation()); initializeConfig(); - PlatformBridgeCertValidatorCustomValidate validator(&config_, stats_, *thread_factory_); + absl::Status creation_status = absl::OkStatus(); + PlatformBridgeCertValidatorCustomValidate validator( + &config_, stats_, Thread::PosixThreadFactory::create(), creation_status); + ASSERT_TRUE(creation_status.ok()); std::string hostname = "server1.example.com"; bssl::UniquePtr cert_chain = readCertChainFromFile( diff --git a/source/common/tls/cert_validator/default_validator.cc b/source/common/tls/cert_validator/default_validator.cc index 1225a54c9f4e..56bc90946661 100644 --- a/source/common/tls/cert_validator/default_validator.cc +++ b/source/common/tls/cert_validator/default_validator.cc @@ -585,7 +585,7 @@ absl::optional DefaultCertValidator::daysUntilFirstCertExpires() const class DefaultCertValidatorFactory : public CertValidatorFactory { public: - CertValidatorPtr + absl::StatusOr createCertValidator(const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats, Server::Configuration::CommonFactoryContext& context) override { return std::make_unique(config, stats, context); diff --git a/source/common/tls/cert_validator/factory.h b/source/common/tls/cert_validator/factory.h index 8f6aebbd6b4c..535a95eabef5 100644 --- a/source/common/tls/cert_validator/factory.h +++ b/source/common/tls/cert_validator/factory.h @@ -19,7 +19,7 @@ std::string getCertValidatorName(const Envoy::Ssl::CertificateValidationContextC class CertValidatorFactory : public Config::UntypedFactory { public: - virtual CertValidatorPtr + virtual absl::StatusOr createCertValidator(const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats, Server::Configuration::CommonFactoryContext& context) PURE; diff --git a/source/common/tls/context_impl.cc b/source/common/tls/context_impl.cc index 1f9a2b5a3d9f..1986a2936916 100644 --- a/source/common/tls/context_impl.cc +++ b/source/common/tls/context_impl.cc @@ -92,8 +92,10 @@ ContextImpl::ContextImpl(Stats::Scope& scope, const Envoy::Ssl::ContextConfig& c return; } - cert_validator_ = cert_validator_factory->createCertValidator( + auto validator_or_error = cert_validator_factory->createCertValidator( config.certificateValidationContext(), stats_, factory_context_); + SET_AND_RETURN_IF_NOT_OK(validator_or_error.status(), creation_status); + cert_validator_ = std::move(*validator_or_error); const auto tls_certificates = config.tlsCertificates(); tls_contexts_.resize(std::max(static_cast(1), tls_certificates.size())); diff --git a/source/extensions/transport_sockets/tls/cert_validator/spiffe/spiffe_validator.cc b/source/extensions/transport_sockets/tls/cert_validator/spiffe/spiffe_validator.cc index e53fa9f7e53b..0dc0b2941be3 100644 --- a/source/extensions/transport_sockets/tls/cert_validator/spiffe/spiffe_validator.cc +++ b/source/extensions/transport_sockets/tls/cert_validator/spiffe/spiffe_validator.cc @@ -313,7 +313,7 @@ Envoy::Ssl::CertificateDetailsPtr SPIFFEValidator::getCaCertInformation() const class SPIFFEValidatorFactory : public CertValidatorFactory { public: - CertValidatorPtr + absl::StatusOr createCertValidator(const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats, Server::Configuration::CommonFactoryContext& context) override { return std::make_unique(config, stats, context); diff --git a/test/common/tls/cert_validator/timed_cert_validator.h b/test/common/tls/cert_validator/timed_cert_validator.h index 2ea320d87af2..2ba0bb1dcba8 100644 --- a/test/common/tls/cert_validator/timed_cert_validator.h +++ b/test/common/tls/cert_validator/timed_cert_validator.h @@ -50,7 +50,7 @@ class TimedCertValidator : public DefaultCertValidator { class TimedCertValidatorFactory : public CertValidatorFactory { public: - CertValidatorPtr + absl::StatusOr createCertValidator(const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats, Server::Configuration::CommonFactoryContext& context) override { auto validator = std::make_unique(validation_time_out_ms_, config, stats,