diff --git a/lib/internal/quic/core.js b/lib/internal/quic/core.js index cb9a0c6eb9..5dfbf810f4 100644 --- a/lib/internal/quic/core.js +++ b/lib/internal/quic/core.js @@ -1063,6 +1063,9 @@ class QuicSocket extends EventEmitter { if (typeof callback === 'function') session.on('ready', callback); + if (this.bound) + session[kReady](); + this[kMaybeBind](connectAfterBind.bind( this, session, diff --git a/src/node_quic_session.cc b/src/node_quic_session.cc index 8fe1451a67..d8bdd89fe9 100644 --- a/src/node_quic_session.cc +++ b/src/node_quic_session.cc @@ -2050,7 +2050,7 @@ void QuicSession::RemoveFromSocket() { socket_->DisassociateCID(QuicCID(&cid)); Debug(this, "Removed from the QuicSocket."); - socket_->RemoveSession(QuicCID(scid_), **GetRemoteAddress()); + socket_->RemoveSession(QuicCID(scid_), GetRemoteAddress()->GetSockaddrStorage()); socket_.reset(); } diff --git a/src/node_quic_socket.cc b/src/node_quic_socket.cc index c6de5f33f6..8773eb13a6 100644 --- a/src/node_quic_socket.cc +++ b/src/node_quic_socket.cc @@ -247,7 +247,7 @@ void QuicSocket::AddSession( const QuicCID& cid, BaseObjectPtr session) { sessions_[cid.ToStr()] = session; - IncrementSocketAddressCounter(**session->GetRemoteAddress()); + IncrementSocketAddressCounter(session->GetRemoteAddress()->GetSockaddrStorage()); IncrementSocketStat( 1, &socket_stats_, session->IsServer() ? @@ -485,7 +485,7 @@ int QuicSocket::ReceiveStop() { return udp_->RecvStop(); } -void QuicSocket::RemoveSession(const QuicCID& cid, const sockaddr* addr) { +void QuicSocket::RemoveSession(const QuicCID& cid, const sockaddr_storage* addr) { sessions_.erase(cid.ToStr()); DecrementSocketAddressCounter(addr); } @@ -659,7 +659,7 @@ namespace { void QuicSocket::SetValidatedAddress(const sockaddr* addr) { if (IsOptionSet(QUICSOCKET_OPTIONS_VALIDATE_ADDRESS_LRU)) { // Remove the oldest item if we've hit the LRU limit - validated_addrs_.push_back(addr_hash(addr)); + validated_addrs_.push_back(addr_hash(*addr)); if (validated_addrs_.size() > MAX_VALIDATE_ADDRESS_LRU) validated_addrs_.pop_front(); } @@ -669,7 +669,7 @@ bool QuicSocket::IsValidatedAddress(const sockaddr* addr) const { if (IsOptionSet(QUICSOCKET_OPTIONS_VALIDATE_ADDRESS_LRU)) { auto res = std::find(std::begin(validated_addrs_), std::end(validated_addrs_), - addr_hash(addr)); + addr_hash(*addr)); return res != std::end(validated_addrs_); } return false; @@ -721,9 +721,13 @@ BaseObjectPtr QuicSocket::AcceptInitialPacket( // Check to see if the number of connections for this peer has been exceeded. // If the count has been exceeded, shutdown the connection immediately // after the initial keys are installed. - if (GetCurrentSocketAddressCounter(addr) >= max_connections_per_host_) { - Debug(this, "Connection count for address exceeded"); - initial_connection_close = NGTCP2_SERVER_BUSY; + { + sockaddr_storage storage; + memcpy(&storage, addr, SocketAddress::GetLength(addr)); + if (GetCurrentSocketAddressCounter(&storage) >= max_connections_per_host_) { + Debug(this, "Connection count for address exceeded"); + initial_connection_close = NGTCP2_SERVER_BUSY; + } } // QUIC has address validation built in to the handshake but allows for @@ -782,22 +786,22 @@ BaseObjectPtr QuicSocket::AcceptInitialPacket( return session; } -void QuicSocket::IncrementSocketAddressCounter(const sockaddr* addr) { - addr_counts_[addr]++; +void QuicSocket::IncrementSocketAddressCounter(const sockaddr_storage* addr) { + addr_counts_[*addr]++; } -void QuicSocket::DecrementSocketAddressCounter(const sockaddr* addr) { - auto it = addr_counts_.find(addr); +void QuicSocket::DecrementSocketAddressCounter(const sockaddr_storage* addr) { + auto it = addr_counts_.find(*addr); if (it == std::end(addr_counts_)) return; it->second--; // Remove the address if the counter reaches zero again. if (it->second == 0) - addr_counts_.erase(addr); + addr_counts_.erase(*addr); } -size_t QuicSocket::GetCurrentSocketAddressCounter(const sockaddr* addr) { - auto it = addr_counts_.find(addr); +size_t QuicSocket::GetCurrentSocketAddressCounter(const sockaddr_storage* addr) { + auto it = addr_counts_.find(*addr); if (it == std::end(addr_counts_)) return 0; return it->second; diff --git a/src/node_quic_socket.h b/src/node_quic_socket.h index 4550490101..cd24c36d4d 100644 --- a/src/node_quic_socket.h +++ b/src/node_quic_socket.h @@ -133,7 +133,7 @@ class QuicSocket : public AsyncWrap, int ReceiveStop(); void RemoveSession( const QuicCID& cid, - const sockaddr* addr); + const sockaddr_storage* addr); void ReportSendError( int error); int SendPacket( @@ -233,9 +233,9 @@ class QuicSocket : public AsyncWrap, const struct sockaddr* addr, unsigned int flags); - void IncrementSocketAddressCounter(const sockaddr* addr); - void DecrementSocketAddressCounter(const sockaddr* addr); - size_t GetCurrentSocketAddressCounter(const sockaddr* addr); + void IncrementSocketAddressCounter(const sockaddr_storage* addr); + void DecrementSocketAddressCounter(const sockaddr_storage* addr); + size_t GetCurrentSocketAddressCounter(const sockaddr_storage* addr); void IncrementPendingCallbacks() { pending_callbacks_++; } void DecrementPendingCallbacks() { pending_callbacks_--; } @@ -315,7 +315,7 @@ class QuicSocket : public AsyncWrap, // value reaches the value of max_connections_per_host_, // attempts to create new connections will be ignored // until the value falls back below the limit. - std::unordered_map addr_counts_; // The validated_addrs_ vector is used as an LRU cache for diff --git a/src/node_sockaddr-inl.h b/src/node_sockaddr-inl.h index 24dd9a9e9a..e680c0cf49 100644 --- a/src/node_sockaddr-inl.h +++ b/src/node_sockaddr-inl.h @@ -25,7 +25,7 @@ inline void hash_combine(size_t* seed, const T& value, Args... rest) { } } // namespace -size_t SocketAddress::Hash::operator()(const sockaddr* addr) const { +static size_t GetHash(const sockaddr* addr) { size_t hash = 0; switch (addr->sa_family) { case AF_INET: { @@ -48,11 +48,20 @@ size_t SocketAddress::Hash::operator()(const sockaddr* addr) const { return hash; } +size_t SocketAddress::Hash::operator()(const sockaddr& addr) const { + return GetHash(&addr); +} + +size_t SocketAddress::Hash::operator()(const sockaddr_storage& addr_storage) const { + const sockaddr* addr = reinterpret_cast(&addr_storage); + return GetHash(addr); +} + bool SocketAddress::Compare::operator()( - const sockaddr* laddr, - const sockaddr* raddr) const { - CHECK(laddr->sa_family == AF_INET || laddr->sa_family == AF_INET6); - return memcmp(laddr, raddr, GetLength(laddr)) == 0; + const sockaddr_storage& laddr, + const sockaddr_storage& raddr) const { + CHECK(laddr.ss_family == AF_INET || laddr.ss_family == AF_INET6); + return memcmp(&laddr, &raddr, GetLength(&laddr)) == 0; } bool SocketAddress::is_numeric_host(const char* hostname) { @@ -146,6 +155,10 @@ const sockaddr* SocketAddress::operator*() const { return reinterpret_cast(&address_); } +const sockaddr_storage* SocketAddress::GetSockaddrStorage() const { + return &address_; +} + size_t SocketAddress::GetLength() const { return GetLength(&address_); } diff --git a/src/node_sockaddr.h b/src/node_sockaddr.h index 1ac5cbbdd4..86e13f5eb0 100644 --- a/src/node_sockaddr.h +++ b/src/node_sockaddr.h @@ -15,11 +15,12 @@ namespace node { class SocketAddress { public: struct Hash { - inline size_t operator()(const sockaddr* addr) const; + inline size_t operator()(const sockaddr& addr) const; + inline size_t operator()(const sockaddr_storage& addr_storage) const; }; struct Compare { - inline bool operator()(const sockaddr* laddr, const sockaddr* raddr) const; + inline bool operator()(const sockaddr_storage& laddr, const sockaddr_storage& raddr) const; }; inline static bool is_numeric_host(const char* hostname); @@ -56,6 +57,8 @@ class SocketAddress { inline const sockaddr* operator*() const; + inline const sockaddr_storage* GetSockaddrStorage() const; + inline size_t GetLength() const; inline int GetFamily() const; diff --git a/test/parallel/test-quic-client-connect-callback.js b/test/parallel/test-quic-client-connect-callback.js new file mode 100644 index 0000000000..ed2f7e877c --- /dev/null +++ b/test/parallel/test-quic-client-connect-callback.js @@ -0,0 +1,72 @@ +'use strict'; + +const common = require('../common'); +if (!common.hasQuic) + common.skip('missing quic'); + +const { createSocket } = require('quic'); +const fixtures = require('../common/fixtures'); +const Countdown = require('../common/countdown'); +const key = fixtures.readKey('agent1-key.pem', 'binary'); +const cert = fixtures.readKey('agent1-cert.pem', 'binary'); +const ca = fixtures.readKey('ca1-cert.pem', 'binary'); + +const kServerName = 'agent2'; +const kALPN = 'zzz'; +const kIdleTimeout = 0; +const kConnections = 5; + +// After QuicSocket bound, the callback of QuicSocket.connect() +// should still get called. +{ + let client; + const server = createSocket({ + port: 0, + }); + + server.listen({ + key, + cert, + ca, + alpn: kALPN, + idleTimeout: kIdleTimeout, + }); + + const countdown = new Countdown(kConnections, () => { + client.close(); + server.close(); + }); + + server.on('ready', common.mustCall(() => { + const options = { + key, + cert, + ca, + address: common.localhostIPv4, + port: server.address.port, + servername: kServerName, + alpn: kALPN, + idleTimeout: kIdleTimeout, + }; + + client = createSocket({ + port: 0, + }); + + const session = client.connect(options, common.mustCall(() => { + session.close(common.mustCall(() => { + // After a session being ready, the socket should have bound + // and we could start the test. + testConnections(); + })); + })); + + const testConnections = common.mustCall(() => { + for (let i = 0; i < kConnections; i += 1) { + client.connect(options, common.mustCall(() => { + countdown.dec(); + })); + } + }); + })); +}