Skip to content

Commit

Permalink
Select ECH config from the ECHPolicy on connect
Browse files Browse the repository at this point in the history
Summary: Uses the ECHPolicy from the fizzContext to select an ECH config based on the given sni

Reviewed By: mingtaoy

Differential Revision: D51045365

fbshipit-source-id: 7ffd74c9d349747918cee3e6c23585eddc515e97
  • Loading branch information
Nick Richardson authored and facebook-github-bot committed Nov 28, 2023
1 parent 52c3be7 commit 63e3306
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 6 deletions.
6 changes: 6 additions & 0 deletions fizz/client/AsyncFizzClient-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ void AsyncFizzClientT<SM>::connect(
if (pskIdentity) {
cachedPsk = fizzContext_->getPsk(*pskIdentity);
}

auto echPolicy = fizzContext_->getECHPolicy();
if (!echConfigs && echPolicy && sni.hasValue()) {
echConfigs = echPolicy->getConfig(sni.value());
}

fizzClient_.connect(
fizzContext_,
std::move(verifier),
Expand Down
88 changes: 83 additions & 5 deletions fizz/client/test/AsyncFizzClientTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,7 @@ class AsyncFizzClientTest : public Test {
.WillOnce(InvokeWithoutArgs([]() { return Actions(); }));
const auto sni = std::string("www.example.com");
client_->connect(
&handshakeCallback_,
nullptr,
sni,
pskIdentity_,
folly::Optional<std::vector<ech::ECHConfig>>(folly::none));
&handshakeCallback_, nullptr, sni, pskIdentity_, echConfigs_);
}

enum class ECHMode { NotRequested, Accepted, Rejected };
Expand Down Expand Up @@ -217,6 +213,7 @@ class AsyncFizzClientTest : public Test {
EventBase evb_;
MockReplaySafetyCallback mockReplayCallback_;
folly::Optional<std::string> pskIdentity_{"pskIdentity"};
folly::Optional<std::vector<ech::ECHConfig>> echConfigs_;
};

MATCHER_P(BufMatches, expected, "") {
Expand Down Expand Up @@ -597,6 +594,87 @@ TEST_F(AsyncFizzClientTest, TestNoPskResumption) {
EXPECT_FALSE(client_->pskResumed());
}

TEST_F(AsyncFizzClientTest, TestNoECHPolicy) {
auto echPolicy = std::make_shared<MockECHPolicy>();
// Sanity check: ECHPolicy::getConfig should not be called if no ECH policy on
// context
EXPECT_CALL(*echPolicy, getConfig(_)).Times(0);
completeHandshake();
}

TEST_F(AsyncFizzClientTest, TestECHPolicyNoSNI) {
auto echPolicy = std::make_shared<MockECHPolicy>();
context_->setECHPolicy(echPolicy);
// Sanity check: ECHPolicy::getConfig should not be called if no ECH policy on
// context
EXPECT_CALL(*echPolicy, getConfig(_)).Times(0);
EXPECT_CALL(*machine_, _processConnect(_, _, _, _, _, _, _))
.WillOnce(InvokeWithoutArgs([]() { return Actions(); }));
client_->connect(
&handshakeCallback_, nullptr, folly::none, pskIdentity_, folly::none);
}

TEST_F(AsyncFizzClientTest, TestOverrideECHPolicy) {
auto echPolicy = std::make_shared<MockECHPolicy>();
context_->setECHPolicy(echPolicy);
// When an ECH config vector is passed to FizzClient::connect() the ECHPolicy
// lookup should be overridden.
echConfigs_ = std::vector<ech::ECHConfig>{};
EXPECT_CALL(*echPolicy, getConfig("www.example.com")).Times(0);
completeHandshake();
}

TEST_F(AsyncFizzClientTest, TestECHPolicyGet) {
auto echPolicy = std::make_shared<MockECHPolicy>();
context_->setECHPolicy(echPolicy);

std::vector<ech::ECHConfig> expectedEchConfigList;
ech::ECHConfig echConfig;
ech::ECHConfigContentDraft echConfigContent;
echConfigContent.key_config.kem_id = hpke::KEMId::x25519;
echConfigContent.key_config.config_id = 1;
echConfigContent.public_name =
folly::IOBuf::copyBuffer("www.super.secret.sni.com");
echConfigContent.maximum_name_length = 100;
echConfigContent.key_config.public_key = folly::IOBuf::copyBuffer(
"1d77eb1c522d08605b179d4214ee4a3635df7e17c336ea9006655a73fcaad63e");
auto kdfId = hpke::KDFId::Sha256;
auto aeadId = hpke::AeadId::TLS_AES_128_GCM_SHA256;
ech::HpkeSymmetricCipherSuite suite{kdfId, aeadId};
echConfigContent.key_config.cipher_suites.push_back(suite);
echConfig.version = ech::ECHVersion::Draft15;
echConfig.ech_config_content = encode(std::move(echConfigContent));
expectedEchConfigList.push_back(std::move(echConfig));

EXPECT_CALL(*echPolicy, getConfig("www.example.com"))
.WillOnce(Return(expectedEchConfigList));

// processConnect() should be called with the ECH config list returned from
// ECHPolicy::getConfig()
EXPECT_CALL(
*machine_,
_processConnect(
_,
_,
_,
_,
_,
_,
Truly([&expectedEchConfigList](
const folly::Optional<std::vector<ech::ECHConfig>>&
configList) {
return configList.hasValue() &&
configList->at(0).ech_config_content->coalesce() ==
expectedEchConfigList[0].ech_config_content->coalesce();
})))
.WillOnce(InvokeWithoutArgs([]() {
return detail::actions(ReportHandshakeSuccess(), WaitForData());
}));
const auto sni = std::string("www.example.com");
client_->connect(
&handshakeCallback_, nullptr, sni, pskIdentity_, folly::none);
}

TEST_F(AsyncFizzClientTest, TestECHAccepted) {
connect();
EXPECT_CALL(handshakeCallback_, _fizzHandshakeSuccess());
Expand Down
12 changes: 11 additions & 1 deletion fizz/client/test/Mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include <fizz/client/AsyncFizzClient.h>
#include <fizz/client/ClientExtensions.h>
#include <fizz/client/ECHPolicy.h>
#include <fizz/client/PskCache.h>
#include <folly/io/async/test/MockAsyncTransport.h>

Expand All @@ -28,7 +29,7 @@ class MockClientStateMachine : public ClientStateMachine {
folly::Optional<std::string> host,
folly::Optional<CachedPsk> cachedPsk,
const std::shared_ptr<ClientExtensions>& extensions,
const folly::Optional<std::vector<ech::ECHConfig>>& echConfigs));
folly::Optional<std::vector<ech::ECHConfig>> echConfigs));
Actions processConnect(
const State& state,
std::shared_ptr<const FizzClientContext> context,
Expand Down Expand Up @@ -132,6 +133,15 @@ class MockPskCache : public PskCache {
MOCK_METHOD(void, removePsk, (const std::string& identity));
};

class MockECHPolicy : public fizz::client::ECHPolicy {
public:
MOCK_METHOD(
folly::Optional<std::vector<fizz::ech::ECHConfig>>,
getConfig,
(const std::string& hostname),
(const));
};

class MockClientExtensions : public ClientExtensions {
public:
MOCK_METHOD(std::vector<Extension>, getClientHelloExtensions, (), (const));
Expand Down

0 comments on commit 63e3306

Please sign in to comment.