Skip to content

Commit

Permalink
validator: allowing creation failure (#37456)
Browse files Browse the repository at this point in the history
Risk Level: low
Testing: updated tests
Docs Changes: n/a
Release Notes: n/a
envoyproxy/envoy-mobile#176

Signed-off-by: Alyssa Wilk <alyssar@chromium.org>
  • Loading branch information
alyssawilk authored Dec 4, 2024
1 parent 42f17fd commit 4ecda60
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ namespace Extensions {
namespace TransportSockets {
namespace Tls {

CertValidatorPtr PlatformBridgeCertValidatorFactory::createCertValidator(
absl::StatusOr<CertValidatorPtr> PlatformBridgeCertValidatorFactory::createCertValidator(
const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats,
Server::Configuration::CommonFactoryContext& /*context*/) {
return std::make_unique<PlatformBridgeCertValidator>(config, stats);
return PlatformBridgeCertValidator::create(config, stats);
}

REGISTER_FACTORY(PlatformBridgeCertValidatorFactory, CertValidatorFactory);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace Tls {
class PlatformBridgeCertValidatorFactory : public CertValidatorFactory,
public Config::TypedFactory {
public:
CertValidatorPtr
absl::StatusOr<CertValidatorPtr>
createCertValidator(const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats,
Server::Configuration::CommonFactoryContext& context) override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,19 @@ namespace Extensions {
namespace TransportSockets {
namespace Tls {

absl::StatusOr<std::unique_ptr<PlatformBridgeCertValidator>>
PlatformBridgeCertValidator::create(const Envoy::Ssl::CertificateValidationContextConfig* config,
SslStats& stats) {
absl::Status creation_status = absl::OkStatus();
auto ret = std::unique_ptr<PlatformBridgeCertValidator>(new PlatformBridgeCertValidator(
config, stats, Thread::PosixThreadFactory::create(), creation_status));
RETURN_IF_NOT_OK_REF(creation_status);
return ret;
}

PlatformBridgeCertValidator::PlatformBridgeCertValidator(
const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats,
Thread::PosixThreadFactoryPtr thread_factory)
Thread::PosixThreadFactoryPtr thread_factory, absl::Status& creation_status)
: allow_untrusted_certificate_(config != nullptr &&
config->trustChainVerification() ==
envoy::extensions::transport_sockets::tls::v3::
Expand All @@ -28,19 +38,16 @@ PlatformBridgeCertValidator::PlatformBridgeCertValidator(
"Invalid certificate validation context config.");
if (config != nullptr && config->customValidatorConfig().has_value()) {
envoy_mobile::extensions::cert_validator::platform_bridge::PlatformBridgeCertValidator cfg;
THROW_IF_NOT_OK(Envoy::Config::Utility::translateOpaqueConfig(
config->customValidatorConfig().value().typed_config(),
ProtobufMessage::getStrictValidationVisitor(), cfg));
SET_AND_RETURN_IF_NOT_OK(Envoy::Config::Utility::translateOpaqueConfig(
config->customValidatorConfig().value().typed_config(),
ProtobufMessage::getStrictValidationVisitor(), cfg),
creation_status);
if (cfg.has_thread_priority()) {
thread_priority_ = cfg.thread_priority().value();
}
}
}

PlatformBridgeCertValidator::PlatformBridgeCertValidator(
const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats)
: PlatformBridgeCertValidator(config, stats, Thread::PosixThreadFactory::create()) {}

PlatformBridgeCertValidator::~PlatformBridgeCertValidator() {
// Wait for validation threads to finish.
for (auto& [id, job] : validation_jobs_) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ namespace Tls {
// validation.
class PlatformBridgeCertValidator : public CertValidator, Logger::Loggable<Logger::Id::connection> {
public:
PlatformBridgeCertValidator(const Envoy::Ssl::CertificateValidationContextConfig* config,
SslStats& stats);
static absl::StatusOr<std::unique_ptr<PlatformBridgeCertValidator>>
create(const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats);

~PlatformBridgeCertValidator() override;

Expand Down Expand Up @@ -58,6 +58,10 @@ class PlatformBridgeCertValidator : public CertValidator, Logger::Loggable<Logge
}

protected:
PlatformBridgeCertValidator(const Envoy::Ssl::CertificateValidationContextConfig* config,
SslStats& stats, Thread::PosixThreadFactoryPtr thread_factory,
absl::Status& creation_status);

enum class ValidationFailureType {
Success,
FailVerifyError,
Expand All @@ -82,12 +86,11 @@ class PlatformBridgeCertValidator : public CertValidator, Logger::Loggable<Logge
Event::Dispatcher* dispatcher,
PlatformBridgeCertValidator* parent);

Thread::PosixThreadFactoryPtr& threadFactory() { return thread_factory_; }

private:
GTEST_FRIEND_CLASS(PlatformBridgeCertValidatorTest, ThreadCreationFailed);

PlatformBridgeCertValidator(const Envoy::Ssl::CertificateValidationContextConfig* config,
SslStats& stats, Thread::PosixThreadFactoryPtr thread_factory);

// Called when a pending verification completes. Must be invoked on the main thread.
void onVerificationComplete(const Thread::ThreadId& thread_id, const std::string& hostname,
bool success, const std::string& error_details, uint8_t tls_alert,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,23 +65,22 @@ class PlatformBridgeCertValidatorCustomValidate : public PlatformBridgeCertValid
public:
PlatformBridgeCertValidatorCustomValidate(
const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats,
Thread::PosixThreadFactory& thread_factory)
: PlatformBridgeCertValidator(config, stats), thread_factory_(thread_factory) {}
Thread::PosixThreadFactoryPtr thread_factory, absl::Status& creation_status)
: PlatformBridgeCertValidator(config, stats, std::move(thread_factory), creation_status) {}

int recordedThreadPriority() const { return recorded_thread_priority_; }

protected:
void verifyCertChainByPlatform(Event::Dispatcher* dispatcher,
std::vector<std::string> /* cert_chain */, std::string hostname,
std::vector<std::string> /* subject_alt_names */) override {
recorded_thread_priority_ = thread_factory_.currentThreadPriority();
recorded_thread_priority_ = threadFactory()->currentThreadPriority();
postVerifyResultAndCleanUp(/* success = */ true, std::move(hostname), "",
SSL_AD_CERTIFICATE_UNKNOWN, ValidationFailureType::Success,
dispatcher, this);
}

private:
Thread::PosixThreadFactory& thread_factory_;
int recorded_thread_priority_;
};

Expand Down Expand Up @@ -171,7 +170,7 @@ INSTANTIATE_TEST_SUITE_P(TrustMode, PlatformBridgeCertValidatorTest,
CertificateValidationContext::ACCEPT_UNTRUSTED}));

TEST_P(PlatformBridgeCertValidatorTest, NoConfig) {
EXPECT_ENVOY_BUG({ PlatformBridgeCertValidator validator(nullptr, stats_); },
EXPECT_ENVOY_BUG({ PlatformBridgeCertValidator::create(nullptr, stats_).IgnoreError(); },
"Invalid certificate validation context config.");
}

Expand All @@ -182,7 +181,7 @@ TEST_P(PlatformBridgeCertValidatorTest, NonEmptyCaCert) {
EXPECT_CALL(config_, trustChainVerification()).WillRepeatedly(Return(GetParam()));
EXPECT_CALL(config_, customValidatorConfig()).WillRepeatedly(ReturnRef(platform_bridge_config_));

EXPECT_ENVOY_BUG({ PlatformBridgeCertValidator validator(&config_, stats_); },
EXPECT_ENVOY_BUG({ PlatformBridgeCertValidator::create(&config_, stats_).IgnoreError(); },
"Invalid certificate validation context config.");
}

Expand All @@ -193,37 +192,39 @@ TEST_P(PlatformBridgeCertValidatorTest, NonEmptyRevocationList) {
EXPECT_CALL(config_, trustChainVerification()).WillRepeatedly(Return(GetParam()));
EXPECT_CALL(config_, customValidatorConfig()).WillRepeatedly(ReturnRef(platform_bridge_config_));

EXPECT_ENVOY_BUG({ PlatformBridgeCertValidator validator(&config_, stats_); },
EXPECT_ENVOY_BUG({ PlatformBridgeCertValidator::create(&config_, stats_).IgnoreError(); },
"Invalid certificate validation context config.");
}

TEST_P(PlatformBridgeCertValidatorTest, NoCallback) {
initializeConfig();
PlatformBridgeCertValidator validator(&config_, stats_);
std::unique_ptr<PlatformBridgeCertValidator> validator =
*PlatformBridgeCertValidator::create(&config_, stats_);

bssl::UniquePtr<STACK_OF(X509)> cert_chain = readCertChainFromFile(
TestEnvironment::substitute("{{ test_rundir }}/test/common/tls/test_data/san_dns2_cert.pem"));
std::string hostname = "www.example.com";

EXPECT_ENVOY_BUG(
{
validator.doVerifyCertChain(*cert_chain, Ssl::ValidateResultCallbackPtr(),
transport_socket_options_, *ssl_ctx_, validation_context_,
is_server_, hostname);
validator->doVerifyCertChain(*cert_chain, Ssl::ValidateResultCallbackPtr(),
transport_socket_options_, *ssl_ctx_, validation_context_,
is_server_, hostname);
},
"No callback specified");
}

TEST_P(PlatformBridgeCertValidatorTest, EmptyCertChain) {
initializeConfig();
PlatformBridgeCertValidator validator(&config_, stats_);
std::unique_ptr<PlatformBridgeCertValidator> validator =
*PlatformBridgeCertValidator::create(&config_, stats_);

bssl::UniquePtr<STACK_OF(X509)> cert_chain(sk_X509_new_null());
std::string hostname = "www.example.com";

ValidationResults results =
validator.doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_,
*ssl_ctx_, validation_context_, is_server_, hostname);
validator->doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_,
*ssl_ctx_, validation_context_, is_server_, hostname);
EXPECT_EQ(ValidationResults::ValidationStatus::Failed, results.status);
EXPECT_FALSE(results.tls_alert.has_value());
ASSERT_TRUE(results.error_details.has_value());
Expand All @@ -236,7 +237,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificate) {
EXPECT_CALL(helper_handle_->mock_helper(), validateCertificateChain(_, _));

initializeConfig();
PlatformBridgeCertValidator validator(&config_, stats_);
std::unique_ptr<PlatformBridgeCertValidator> validator =
*PlatformBridgeCertValidator::create(&config_, stats_);

std::string hostname = "server1.example.com";
bssl::UniquePtr<STACK_OF(X509)> cert_chain = readCertChainFromFile(
Expand All @@ -248,8 +250,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificate) {
EXPECT_CALL(callback_ref, dispatcher()).WillRepeatedly(ReturnRef(*dispatcher_));

ValidationResults results =
validator.doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_,
*ssl_ctx_, validation_context_, is_server_, hostname);
validator->doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_,
*ssl_ctx_, validation_context_, is_server_, hostname);
EXPECT_EQ(ValidationResults::ValidationStatus::Pending, results.status);

EXPECT_CALL(callback_ref,
Expand All @@ -266,7 +268,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateEmptySanOverrides) {
EXPECT_CALL(helper_handle_->mock_helper(), validateCertificateChain(_, _));

initializeConfig();
PlatformBridgeCertValidator validator(&config_, stats_);
std::unique_ptr<PlatformBridgeCertValidator> validator =
*PlatformBridgeCertValidator::create(&config_, stats_);

std::string hostname = "server1.example.com";
bssl::UniquePtr<STACK_OF(X509)> cert_chain = readCertChainFromFile(
Expand All @@ -283,8 +286,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateEmptySanOverrides) {
std::make_shared<Network::TransportSocketOptionsImpl>("", std::move(subject_alt_names));

ValidationResults results =
validator.doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_,
*ssl_ctx_, validation_context_, is_server_, hostname);
validator->doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_,
*ssl_ctx_, validation_context_, is_server_, hostname);
EXPECT_EQ(ValidationResults::ValidationStatus::Pending, results.status);

EXPECT_CALL(callback_ref,
Expand All @@ -301,7 +304,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateEmptyHostNoOverrides) {
EXPECT_CALL(helper_handle_->mock_helper(), validateCertificateChain(_, _));

initializeConfig();
PlatformBridgeCertValidator validator(&config_, stats_);
std::unique_ptr<PlatformBridgeCertValidator> validator =
*PlatformBridgeCertValidator::create(&config_, stats_);

std::string hostname = "";
bssl::UniquePtr<STACK_OF(X509)> cert_chain = readCertChainFromFile(
Expand All @@ -318,8 +322,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateEmptyHostNoOverrides) {
std::make_shared<Network::TransportSocketOptionsImpl>("", std::move(subject_alt_names));

ValidationResults results =
validator.doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_,
*ssl_ctx_, validation_context_, is_server_, hostname);
validator->doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_,
*ssl_ctx_, validation_context_, is_server_, hostname);
EXPECT_EQ(ValidationResults::ValidationStatus::Pending, results.status);

EXPECT_CALL(callback_ref,
Expand All @@ -336,7 +340,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateButInvalidSni) {
EXPECT_CALL(helper_handle_->mock_helper(), validateCertificateChain(_, _));

initializeConfig();
PlatformBridgeCertValidator validator(&config_, stats_);
std::unique_ptr<PlatformBridgeCertValidator> validator =
*PlatformBridgeCertValidator::create(&config_, stats_);

std::string hostname = "server2.example.com";
bssl::UniquePtr<STACK_OF(X509)> cert_chain = readCertChainFromFile(
Expand All @@ -348,8 +353,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateButInvalidSni) {
EXPECT_CALL(callback_ref, dispatcher()).WillRepeatedly(ReturnRef(*dispatcher_));

ValidationResults results =
validator.doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_,
*ssl_ctx_, validation_context_, is_server_, hostname);
validator->doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_,
*ssl_ctx_, validation_context_, is_server_, hostname);
EXPECT_EQ(ValidationResults::ValidationStatus::Pending, results.status);

EXPECT_CALL(callback_ref,
Expand All @@ -366,7 +371,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateSniOverride) {
EXPECT_CALL(helper_handle_->mock_helper(), validateCertificateChain(_, _));

initializeConfig();
PlatformBridgeCertValidator validator(&config_, stats_);
std::unique_ptr<PlatformBridgeCertValidator> validator =
*PlatformBridgeCertValidator::create(&config_, stats_);

std::vector<std::string> subject_alt_names = {"server1.example.com"};

Expand All @@ -383,8 +389,8 @@ TEST_P(PlatformBridgeCertValidatorTest, ValidCertificateSniOverride) {
std::make_shared<Network::TransportSocketOptionsImpl>("", std::move(subject_alt_names));

ValidationResults results =
validator.doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_,
*ssl_ctx_, validation_context_, is_server_, hostname);
validator->doVerifyCertChain(*cert_chain, std::move(callback_), transport_socket_options_,
*ssl_ctx_, validation_context_, is_server_, hostname);
EXPECT_EQ(ValidationResults::ValidationStatus::Pending, results.status);

// The cert will be validated against the overridden name not the invalid name "server2".
Expand All @@ -399,7 +405,7 @@ TEST_P(PlatformBridgeCertValidatorTest, DeletedWithValidationPending) {
EXPECT_CALL(helper_handle_->mock_helper(), validateCertificateChain(_, _));

initializeConfig();
auto validator = std::make_unique<PlatformBridgeCertValidator>(&config_, stats_);
auto validator = *PlatformBridgeCertValidator::create(&config_, stats_);

std::string hostname = "server1.example.com";
bssl::UniquePtr<STACK_OF(X509)> cert_chain = readCertChainFromFile(
Expand All @@ -426,7 +432,9 @@ TEST_P(PlatformBridgeCertValidatorTest, ThreadCreationFailed) {
initializeConfig();
auto thread_factory = std::make_unique<Thread::MockPosixThreadFactory>();
EXPECT_CALL(*thread_factory, createThread(_, _, false)).WillOnce(Return(ByMove(nullptr)));
PlatformBridgeCertValidator validator(&config_, stats_, std::move(thread_factory));
absl::Status creation_status = absl::OkStatus();
PlatformBridgeCertValidatorCustomValidate validator(&config_, stats_, std::move(thread_factory),
creation_status);

std::string hostname = "server1.example.com";
bssl::UniquePtr<STACK_OF(X509)> cert_chain = readCertChainFromFile(
Expand Down Expand Up @@ -455,7 +463,10 @@ TEST_P(PlatformBridgeCertValidatorTest, ThreadPriority) {
EXPECT_CALL(helper_handle_->mock_helper(), cleanupAfterCertificateValidation());

initializeConfig();
PlatformBridgeCertValidatorCustomValidate validator(&config_, stats_, *thread_factory_);
absl::Status creation_status = absl::OkStatus();
PlatformBridgeCertValidatorCustomValidate validator(
&config_, stats_, Thread::PosixThreadFactory::create(), creation_status);
ASSERT_TRUE(creation_status.ok());

std::string hostname = "server1.example.com";
bssl::UniquePtr<STACK_OF(X509)> cert_chain = readCertChainFromFile(
Expand Down
2 changes: 1 addition & 1 deletion source/common/tls/cert_validator/default_validator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ absl::optional<uint32_t> DefaultCertValidator::daysUntilFirstCertExpires() const

class DefaultCertValidatorFactory : public CertValidatorFactory {
public:
CertValidatorPtr
absl::StatusOr<CertValidatorPtr>
createCertValidator(const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats,
Server::Configuration::CommonFactoryContext& context) override {
return std::make_unique<DefaultCertValidator>(config, stats, context);
Expand Down
2 changes: 1 addition & 1 deletion source/common/tls/cert_validator/factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ std::string getCertValidatorName(const Envoy::Ssl::CertificateValidationContextC

class CertValidatorFactory : public Config::UntypedFactory {
public:
virtual CertValidatorPtr
virtual absl::StatusOr<CertValidatorPtr>
createCertValidator(const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats,
Server::Configuration::CommonFactoryContext& context) PURE;

Expand Down
4 changes: 3 additions & 1 deletion source/common/tls/context_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,10 @@ ContextImpl::ContextImpl(Stats::Scope& scope, const Envoy::Ssl::ContextConfig& c
return;
}

cert_validator_ = cert_validator_factory->createCertValidator(
auto validator_or_error = cert_validator_factory->createCertValidator(
config.certificateValidationContext(), stats_, factory_context_);
SET_AND_RETURN_IF_NOT_OK(validator_or_error.status(), creation_status);
cert_validator_ = std::move(*validator_or_error);

const auto tls_certificates = config.tlsCertificates();
tls_contexts_.resize(std::max(static_cast<size_t>(1), tls_certificates.size()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ Envoy::Ssl::CertificateDetailsPtr SPIFFEValidator::getCaCertInformation() const

class SPIFFEValidatorFactory : public CertValidatorFactory {
public:
CertValidatorPtr
absl::StatusOr<CertValidatorPtr>
createCertValidator(const Envoy::Ssl::CertificateValidationContextConfig* config, SslStats& stats,
Server::Configuration::CommonFactoryContext& context) override {
return std::make_unique<SPIFFEValidator>(config, stats, context);
Expand Down
Loading

0 comments on commit 4ecda60

Please sign in to comment.