From f3308f733fe3c18bb3df019a5797c98fdfa13c2c Mon Sep 17 00:00:00 2001 From: "mergify[bot]" <37929162+mergify[bot]@users.noreply.github.com> Date: Fri, 20 Oct 2023 21:19:22 +0000 Subject: [PATCH] v1.17: prunes repair QUIC connections (backport of #33775) (#33792) prunes repair QUIC connections (#33775) The commit implements lazy eviction for repair QUIC connections. The cache is allowed to grow to 2 x capacity at which point at least half of the entries with lowest stake are evicted, resulting in an amortized O(1) performance. (cherry picked from commit dc3c827299f0139b681dc79fa0a71622bce2feeb) Co-authored-by: behzad nouri --- core/src/repair/quic_endpoint.rs | 144 +++++++++++++++++++++++++++---- core/src/validator.rs | 1 + 2 files changed, 130 insertions(+), 15 deletions(-) diff --git a/core/src/repair/quic_endpoint.rs b/core/src/repair/quic_endpoint.rs index bf3a1802144a42..7d1cd29a32589f 100644 --- a/core/src/repair/quic_endpoint.rs +++ b/core/src/repair/quic_endpoint.rs @@ -13,15 +13,20 @@ use { rustls::{Certificate, PrivateKey}, serde_bytes::ByteBuf, solana_quic_client::nonblocking::quic_client::SkipServerVerification, + solana_runtime::bank_forks::BankForks, solana_sdk::{packet::PACKET_DATA_SIZE, pubkey::Pubkey, signature::Keypair}, solana_streamer::{ quic::SkipClientVerification, tls_certificates::new_self_signed_tls_certificate, }, std::{ + cmp::Reverse, collections::{hash_map::Entry, HashMap}, io::{Cursor, Error as IoError}, net::{IpAddr, SocketAddr, UdpSocket}, - sync::Arc, + sync::{ + atomic::{AtomicBool, Ordering}, + Arc, RwLock, + }, time::Duration, }, thiserror::Error, @@ -40,18 +45,20 @@ const CONNECT_SERVER_NAME: &str = "solana-repair"; const CLIENT_CHANNEL_BUFFER: usize = 1 << 14; const ROUTER_CHANNEL_BUFFER: usize = 64; -const CONNECTION_CACHE_CAPACITY: usize = 4096; +const CONNECTION_CACHE_CAPACITY: usize = 3072; const MAX_CONCURRENT_BIDI_STREAMS: VarInt = VarInt::from_u32(512); const CONNECTION_CLOSE_ERROR_CODE_SHUTDOWN: VarInt = VarInt::from_u32(1); const CONNECTION_CLOSE_ERROR_CODE_DROPPED: VarInt = VarInt::from_u32(2); const CONNECTION_CLOSE_ERROR_CODE_INVALID_IDENTITY: VarInt = VarInt::from_u32(3); const CONNECTION_CLOSE_ERROR_CODE_REPLACED: VarInt = VarInt::from_u32(4); +const CONNECTION_CLOSE_ERROR_CODE_PRUNED: VarInt = VarInt::from_u32(5); const CONNECTION_CLOSE_REASON_SHUTDOWN: &[u8] = b"SHUTDOWN"; const CONNECTION_CLOSE_REASON_DROPPED: &[u8] = b"DROPPED"; const CONNECTION_CLOSE_REASON_INVALID_IDENTITY: &[u8] = b"INVALID_IDENTITY"; const CONNECTION_CLOSE_REASON_REPLACED: &[u8] = b"REPLACED"; +const CONNECTION_CLOSE_REASON_PRUNED: &[u8] = b"PRUNED"; pub(crate) type AsyncTryJoinHandle = TryJoin, JoinHandle<()>>; @@ -108,6 +115,7 @@ pub(crate) fn new_quic_endpoint( socket: UdpSocket, address: IpAddr, remote_request_sender: Sender, + bank_forks: Arc>, ) -> Result<(Endpoint, AsyncSender, AsyncTryJoinHandle), Error> { let (cert, key) = new_self_signed_tls_certificate(keypair, address)?; let server_config = new_server_config(cert.clone(), key.clone())?; @@ -124,12 +132,15 @@ pub(crate) fn new_quic_endpoint( )? }; endpoint.set_default_client_config(client_config); + let prune_cache_pending = Arc::::default(); let cache = Arc::>>::default(); let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_BUFFER); let router = Arc::>>>::default(); let server_task = runtime.spawn(run_server( endpoint.clone(), remote_request_sender.clone(), + bank_forks.clone(), + prune_cache_pending.clone(), router.clone(), cache.clone(), )); @@ -137,6 +148,8 @@ pub(crate) fn new_quic_endpoint( endpoint.clone(), client_receiver, remote_request_sender, + bank_forks, + prune_cache_pending, router, cache, )); @@ -189,6 +202,8 @@ fn new_transport_config() -> TransportConfig { async fn run_server( endpoint: Endpoint, remote_request_sender: Sender, + bank_forks: Arc>, + prune_cache_pending: Arc, router: Arc>>>, cache: Arc>>, ) { @@ -197,6 +212,8 @@ async fn run_server( endpoint.clone(), connecting, remote_request_sender.clone(), + bank_forks.clone(), + prune_cache_pending.clone(), router.clone(), cache.clone(), )); @@ -207,6 +224,8 @@ async fn run_client( endpoint: Endpoint, mut receiver: AsyncReceiver, remote_request_sender: Sender, + bank_forks: Arc>, + prune_cache_pending: Arc, router: Arc>>>, cache: Arc>>, ) { @@ -230,6 +249,8 @@ async fn run_client( remote_address, remote_request_sender.clone(), receiver, + bank_forks.clone(), + prune_cache_pending.clone(), router.clone(), cache.clone(), )); @@ -263,11 +284,21 @@ async fn handle_connecting_error( endpoint: Endpoint, connecting: Connecting, remote_request_sender: Sender, + bank_forks: Arc>, + prune_cache_pending: Arc, router: Arc>>>, cache: Arc>>, ) { - if let Err(err) = - handle_connecting(endpoint, connecting, remote_request_sender, router, cache).await + if let Err(err) = handle_connecting( + endpoint, + connecting, + remote_request_sender, + bank_forks, + prune_cache_pending, + router, + cache, + ) + .await { error!("handle_connecting: {err:?}"); } @@ -277,6 +308,8 @@ async fn handle_connecting( endpoint: Endpoint, connecting: Connecting, remote_request_sender: Sender, + bank_forks: Arc>, + prune_cache_pending: Arc, router: Arc>>>, cache: Arc>>, ) -> Result<(), Error> { @@ -295,6 +328,8 @@ async fn handle_connecting( connection, remote_request_sender, receiver, + bank_forks, + prune_cache_pending, router, cache, ) @@ -302,6 +337,7 @@ async fn handle_connecting( Ok(()) } +#[allow(clippy::too_many_arguments)] async fn handle_connection( endpoint: Endpoint, remote_address: SocketAddr, @@ -309,10 +345,20 @@ async fn handle_connection( connection: Connection, remote_request_sender: Sender, receiver: AsyncReceiver, + bank_forks: Arc>, + prune_cache_pending: Arc, router: Arc>>>, cache: Arc>>, ) { - cache_connection(remote_pubkey, connection.clone(), &cache).await; + cache_connection( + remote_pubkey, + connection.clone(), + bank_forks, + prune_cache_pending, + router.clone(), + cache.clone(), + ) + .await; let send_requests_task = tokio::task::spawn(send_requests_task( endpoint.clone(), connection.clone(), @@ -492,6 +538,8 @@ async fn make_connection_task( remote_address: SocketAddr, remote_request_sender: Sender, receiver: AsyncReceiver, + bank_forks: Arc>, + prune_cache_pending: Arc, router: Arc>>>, cache: Arc>>, ) { @@ -500,6 +548,8 @@ async fn make_connection_task( remote_address, remote_request_sender, receiver, + bank_forks, + prune_cache_pending, router, cache, ) @@ -514,6 +564,8 @@ async fn make_connection( remote_address: SocketAddr, remote_request_sender: Sender, receiver: AsyncReceiver, + bank_forks: Arc>, + prune_cache_pending: Arc, router: Arc>>>, cache: Arc>>, ) -> Result<(), Error> { @@ -527,6 +579,8 @@ async fn make_connection( connection, remote_request_sender, receiver, + bank_forks, + prune_cache_pending, router, cache, ) @@ -550,18 +604,17 @@ fn get_remote_pubkey(connection: &Connection) -> Result { async fn cache_connection( remote_pubkey: Pubkey, connection: Connection, - cache: &Mutex>, + bank_forks: Arc>, + prune_cache_pending: Arc, + router: Arc>>>, + cache: Arc>>, ) { - let old = { + let (old, should_prune_cache) = { let mut cache = cache.lock().await; - if cache.len() >= CONNECTION_CACHE_CAPACITY { - connection.close( - CONNECTION_CLOSE_ERROR_CODE_DROPPED, - CONNECTION_CLOSE_REASON_DROPPED, - ); - return; - } - cache.insert(remote_pubkey, connection) + ( + cache.insert(remote_pubkey, connection), + cache.len() >= CONNECTION_CACHE_CAPACITY.saturating_mul(2), + ) }; if let Some(old) = old { old.close( @@ -569,6 +622,14 @@ async fn cache_connection( CONNECTION_CLOSE_REASON_REPLACED, ); } + if should_prune_cache && !prune_cache_pending.swap(true, Ordering::Relaxed) { + tokio::task::spawn(prune_connection_cache( + bank_forks, + prune_cache_pending, + router, + cache, + )); + } } async fn drop_connection( @@ -587,6 +648,50 @@ async fn drop_connection( } } +async fn prune_connection_cache( + bank_forks: Arc>, + prune_cache_pending: Arc, + router: Arc>>>, + cache: Arc>>, +) { + debug_assert!(prune_cache_pending.load(Ordering::Relaxed)); + let staked_nodes = { + let root_bank = bank_forks.read().unwrap().root_bank(); + root_bank.staked_nodes() + }; + { + let mut cache = cache.lock().await; + if cache.len() < CONNECTION_CACHE_CAPACITY.saturating_mul(2) { + prune_cache_pending.store(false, Ordering::Relaxed); + return; + } + let mut connections: Vec<_> = cache + .drain() + .filter(|(_, connection)| connection.close_reason().is_none()) + .map(|entry @ (pubkey, _)| { + let stake = staked_nodes.get(&pubkey).copied().unwrap_or_default(); + (stake, entry) + }) + .collect(); + connections + .select_nth_unstable_by_key(CONNECTION_CACHE_CAPACITY, |&(stake, _)| Reverse(stake)); + for (_, (_, connection)) in &connections[CONNECTION_CACHE_CAPACITY..] { + connection.close( + CONNECTION_CLOSE_ERROR_CODE_PRUNED, + CONNECTION_CLOSE_REASON_PRUNED, + ); + } + cache.extend( + connections + .into_iter() + .take(CONNECTION_CACHE_CAPACITY) + .map(|(_, entry)| entry), + ); + prune_cache_pending.store(false, Ordering::Relaxed); + } + router.write().await.retain(|_, sender| !sender.is_closed()); +} + impl From> for Error { fn from(_: crossbeam_channel::SendError) -> Self { Error::ChannelSendError @@ -598,6 +703,8 @@ mod tests { use { super::*, itertools::{izip, multiunzip}, + solana_ledger::genesis_utils::{create_genesis_config, GenesisConfigInfo}, + solana_runtime::bank::Bank, solana_sdk::signature::Signer, std::{iter::repeat_with, net::Ipv4Addr, time::Duration}, }; @@ -625,6 +732,12 @@ mod tests { repeat_with(crossbeam_channel::unbounded::) .take(NUM_ENDPOINTS) .unzip(); + let bank_forks = { + let GenesisConfigInfo { genesis_config, .. } = + create_genesis_config(/*mint_lamports:*/ 100_000); + let bank = Bank::new_for_tests(&genesis_config); + Arc::new(RwLock::new(BankForks::new(bank))) + }; let (endpoints, senders, tasks): (Vec<_>, Vec<_>, Vec<_>) = multiunzip( keypairs .iter() @@ -637,6 +750,7 @@ mod tests { socket, IpAddr::V4(Ipv4Addr::LOCALHOST), remote_request_sender, + bank_forks.clone(), ) .unwrap() }), diff --git a/core/src/validator.rs b/core/src/validator.rs index 3c19756f47d054..c23c724d5ec81f 100644 --- a/core/src/validator.rs +++ b/core/src/validator.rs @@ -1201,6 +1201,7 @@ impl Validator { .expect("Operator must spin up node with valid QUIC serve-repair address") .ip(), repair_quic_endpoint_sender, + bank_forks.clone(), ) .unwrap();