Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
prunes repair QUIC connections (#33775)
Browse files Browse the repository at this point in the history
The commit implements lazy eviction for repair QUIC connections.
The cache is allowed to grow to 2 x capacity at which point at least
half of the entries with lowest stake are evicted, resulting in an
amortized O(1) performance.
  • Loading branch information
behzadnouri authored Oct 20, 2023
1 parent 96052d2 commit dc3c827
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 15 deletions.
144 changes: 129 additions & 15 deletions core/src/repair/quic_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,20 @@ use {
rustls::{Certificate, PrivateKey},
serde_bytes::ByteBuf,
solana_quic_client::nonblocking::quic_client::SkipServerVerification,
solana_runtime::bank_forks::BankForks,
solana_sdk::{packet::PACKET_DATA_SIZE, pubkey::Pubkey, signature::Keypair},
solana_streamer::{
quic::SkipClientVerification, tls_certificates::new_self_signed_tls_certificate,
},
std::{
cmp::Reverse,
collections::{hash_map::Entry, HashMap},
io::{Cursor, Error as IoError},
net::{IpAddr, SocketAddr, UdpSocket},
sync::Arc,
sync::{
atomic::{AtomicBool, Ordering},
Arc, RwLock,
},
time::Duration,
},
thiserror::Error,
Expand All @@ -40,18 +45,20 @@ const CONNECT_SERVER_NAME: &str = "solana-repair";

const CLIENT_CHANNEL_BUFFER: usize = 1 << 14;
const ROUTER_CHANNEL_BUFFER: usize = 64;
const CONNECTION_CACHE_CAPACITY: usize = 4096;
const CONNECTION_CACHE_CAPACITY: usize = 3072;
const MAX_CONCURRENT_BIDI_STREAMS: VarInt = VarInt::from_u32(512);

const CONNECTION_CLOSE_ERROR_CODE_SHUTDOWN: VarInt = VarInt::from_u32(1);
const CONNECTION_CLOSE_ERROR_CODE_DROPPED: VarInt = VarInt::from_u32(2);
const CONNECTION_CLOSE_ERROR_CODE_INVALID_IDENTITY: VarInt = VarInt::from_u32(3);
const CONNECTION_CLOSE_ERROR_CODE_REPLACED: VarInt = VarInt::from_u32(4);
const CONNECTION_CLOSE_ERROR_CODE_PRUNED: VarInt = VarInt::from_u32(5);

const CONNECTION_CLOSE_REASON_SHUTDOWN: &[u8] = b"SHUTDOWN";
const CONNECTION_CLOSE_REASON_DROPPED: &[u8] = b"DROPPED";
const CONNECTION_CLOSE_REASON_INVALID_IDENTITY: &[u8] = b"INVALID_IDENTITY";
const CONNECTION_CLOSE_REASON_REPLACED: &[u8] = b"REPLACED";
const CONNECTION_CLOSE_REASON_PRUNED: &[u8] = b"PRUNED";

pub(crate) type AsyncTryJoinHandle = TryJoin<JoinHandle<()>, JoinHandle<()>>;

Expand Down Expand Up @@ -108,6 +115,7 @@ pub(crate) fn new_quic_endpoint(
socket: UdpSocket,
address: IpAddr,
remote_request_sender: Sender<RemoteRequest>,
bank_forks: Arc<RwLock<BankForks>>,
) -> Result<(Endpoint, AsyncSender<LocalRequest>, AsyncTryJoinHandle), Error> {
let (cert, key) = new_self_signed_tls_certificate(keypair, address)?;
let server_config = new_server_config(cert.clone(), key.clone())?;
Expand All @@ -124,19 +132,24 @@ pub(crate) fn new_quic_endpoint(
)?
};
endpoint.set_default_client_config(client_config);
let prune_cache_pending = Arc::<AtomicBool>::default();
let cache = Arc::<Mutex<HashMap<Pubkey, Connection>>>::default();
let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_BUFFER);
let router = Arc::<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>::default();
let server_task = runtime.spawn(run_server(
endpoint.clone(),
remote_request_sender.clone(),
bank_forks.clone(),
prune_cache_pending.clone(),
router.clone(),
cache.clone(),
));
let client_task = runtime.spawn(run_client(
endpoint.clone(),
client_receiver,
remote_request_sender,
bank_forks,
prune_cache_pending,
router,
cache,
));
Expand Down Expand Up @@ -189,6 +202,8 @@ fn new_transport_config() -> TransportConfig {
async fn run_server(
endpoint: Endpoint,
remote_request_sender: Sender<RemoteRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
Expand All @@ -197,6 +212,8 @@ async fn run_server(
endpoint.clone(),
connecting,
remote_request_sender.clone(),
bank_forks.clone(),
prune_cache_pending.clone(),
router.clone(),
cache.clone(),
));
Expand All @@ -207,6 +224,8 @@ async fn run_client(
endpoint: Endpoint,
mut receiver: AsyncReceiver<LocalRequest>,
remote_request_sender: Sender<RemoteRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
Expand All @@ -230,6 +249,8 @@ async fn run_client(
remote_address,
remote_request_sender.clone(),
receiver,
bank_forks.clone(),
prune_cache_pending.clone(),
router.clone(),
cache.clone(),
));
Expand Down Expand Up @@ -263,11 +284,21 @@ async fn handle_connecting_error(
endpoint: Endpoint,
connecting: Connecting,
remote_request_sender: Sender<RemoteRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
if let Err(err) =
handle_connecting(endpoint, connecting, remote_request_sender, router, cache).await
if let Err(err) = handle_connecting(
endpoint,
connecting,
remote_request_sender,
bank_forks,
prune_cache_pending,
router,
cache,
)
.await
{
error!("handle_connecting: {err:?}");
}
Expand All @@ -277,6 +308,8 @@ async fn handle_connecting(
endpoint: Endpoint,
connecting: Connecting,
remote_request_sender: Sender<RemoteRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) -> Result<(), Error> {
Expand All @@ -295,24 +328,37 @@ async fn handle_connecting(
connection,
remote_request_sender,
receiver,
bank_forks,
prune_cache_pending,
router,
cache,
)
.await;
Ok(())
}

#[allow(clippy::too_many_arguments)]
async fn handle_connection(
endpoint: Endpoint,
remote_address: SocketAddr,
remote_pubkey: Pubkey,
connection: Connection,
remote_request_sender: Sender<RemoteRequest>,
receiver: AsyncReceiver<LocalRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
cache_connection(remote_pubkey, connection.clone(), &cache).await;
cache_connection(
remote_pubkey,
connection.clone(),
bank_forks,
prune_cache_pending,
router.clone(),
cache.clone(),
)
.await;
let send_requests_task = tokio::task::spawn(send_requests_task(
endpoint.clone(),
connection.clone(),
Expand Down Expand Up @@ -492,6 +538,8 @@ async fn make_connection_task(
remote_address: SocketAddr,
remote_request_sender: Sender<RemoteRequest>,
receiver: AsyncReceiver<LocalRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
Expand All @@ -500,6 +548,8 @@ async fn make_connection_task(
remote_address,
remote_request_sender,
receiver,
bank_forks,
prune_cache_pending,
router,
cache,
)
Expand All @@ -514,6 +564,8 @@ async fn make_connection(
remote_address: SocketAddr,
remote_request_sender: Sender<RemoteRequest>,
receiver: AsyncReceiver<LocalRequest>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) -> Result<(), Error> {
Expand All @@ -527,6 +579,8 @@ async fn make_connection(
connection,
remote_request_sender,
receiver,
bank_forks,
prune_cache_pending,
router,
cache,
)
Expand All @@ -550,25 +604,32 @@ fn get_remote_pubkey(connection: &Connection) -> Result<Pubkey, Error> {
async fn cache_connection(
remote_pubkey: Pubkey,
connection: Connection,
cache: &Mutex<HashMap<Pubkey, Connection>>,
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
let old = {
let (old, should_prune_cache) = {
let mut cache = cache.lock().await;
if cache.len() >= CONNECTION_CACHE_CAPACITY {
connection.close(
CONNECTION_CLOSE_ERROR_CODE_DROPPED,
CONNECTION_CLOSE_REASON_DROPPED,
);
return;
}
cache.insert(remote_pubkey, connection)
(
cache.insert(remote_pubkey, connection),
cache.len() >= CONNECTION_CACHE_CAPACITY.saturating_mul(2),
)
};
if let Some(old) = old {
old.close(
CONNECTION_CLOSE_ERROR_CODE_REPLACED,
CONNECTION_CLOSE_REASON_REPLACED,
);
}
if should_prune_cache && !prune_cache_pending.swap(true, Ordering::Relaxed) {
tokio::task::spawn(prune_connection_cache(
bank_forks,
prune_cache_pending,
router,
cache,
));
}
}

async fn drop_connection(
Expand All @@ -587,6 +648,50 @@ async fn drop_connection(
}
}

async fn prune_connection_cache(
bank_forks: Arc<RwLock<BankForks>>,
prune_cache_pending: Arc<AtomicBool>,
router: Arc<AsyncRwLock<HashMap<SocketAddr, AsyncSender<LocalRequest>>>>,
cache: Arc<Mutex<HashMap<Pubkey, Connection>>>,
) {
debug_assert!(prune_cache_pending.load(Ordering::Relaxed));
let staked_nodes = {
let root_bank = bank_forks.read().unwrap().root_bank();
root_bank.staked_nodes()
};
{
let mut cache = cache.lock().await;
if cache.len() < CONNECTION_CACHE_CAPACITY.saturating_mul(2) {
prune_cache_pending.store(false, Ordering::Relaxed);
return;
}
let mut connections: Vec<_> = cache
.drain()
.filter(|(_, connection)| connection.close_reason().is_none())
.map(|entry @ (pubkey, _)| {
let stake = staked_nodes.get(&pubkey).copied().unwrap_or_default();
(stake, entry)
})
.collect();
connections
.select_nth_unstable_by_key(CONNECTION_CACHE_CAPACITY, |&(stake, _)| Reverse(stake));
for (_, (_, connection)) in &connections[CONNECTION_CACHE_CAPACITY..] {
connection.close(
CONNECTION_CLOSE_ERROR_CODE_PRUNED,
CONNECTION_CLOSE_REASON_PRUNED,
);
}
cache.extend(
connections
.into_iter()
.take(CONNECTION_CACHE_CAPACITY)
.map(|(_, entry)| entry),
);
prune_cache_pending.store(false, Ordering::Relaxed);
}
router.write().await.retain(|_, sender| !sender.is_closed());
}

impl<T> From<crossbeam_channel::SendError<T>> for Error {
fn from(_: crossbeam_channel::SendError<T>) -> Self {
Error::ChannelSendError
Expand All @@ -598,6 +703,8 @@ mod tests {
use {
super::*,
itertools::{izip, multiunzip},
solana_ledger::genesis_utils::{create_genesis_config, GenesisConfigInfo},
solana_runtime::bank::Bank,
solana_sdk::signature::Signer,
std::{iter::repeat_with, net::Ipv4Addr, time::Duration},
};
Expand Down Expand Up @@ -625,6 +732,12 @@ mod tests {
repeat_with(crossbeam_channel::unbounded::<RemoteRequest>)
.take(NUM_ENDPOINTS)
.unzip();
let bank_forks = {
let GenesisConfigInfo { genesis_config, .. } =
create_genesis_config(/*mint_lamports:*/ 100_000);
let bank = Bank::new_for_tests(&genesis_config);
Arc::new(RwLock::new(BankForks::new(bank)))
};
let (endpoints, senders, tasks): (Vec<_>, Vec<_>, Vec<_>) = multiunzip(
keypairs
.iter()
Expand All @@ -637,6 +750,7 @@ mod tests {
socket,
IpAddr::V4(Ipv4Addr::LOCALHOST),
remote_request_sender,
bank_forks.clone(),
)
.unwrap()
}),
Expand Down
1 change: 1 addition & 0 deletions core/src/validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1208,6 +1208,7 @@ impl Validator {
.expect("Operator must spin up node with valid QUIC serve-repair address")
.ip(),
repair_quic_endpoint_sender,
bank_forks.clone(),
)
.unwrap();

Expand Down

0 comments on commit dc3c827

Please sign in to comment.