From f3572dc9ce2bc7a391a33f5bc8464f8e42c019e7 Mon Sep 17 00:00:00 2001 From: behzad nouri Date: Fri, 6 Oct 2023 10:31:04 -0400 Subject: [PATCH] separates out routing shreds from establishing connections Currently each outgoing shred will attempt to establish a connection if one does not already exist. This is very wasteful and consumes many tokio tasks if the remote node is down or unresponsive. The commit decouples routing packets from establishing connections by adding a buffering channel for each remote address. Outgoing packets are always sent down this channel to be processed once the connection is established. If connecting attempt fails, all packets already pushed to the channel are dropped at once, reducing the number of attempts to make a connection if the remote node is down or unresponsive. --- turbine/src/cluster_nodes.rs | 2 +- turbine/src/quic_endpoint.rs | 227 ++++++++++++++++++++--------------- 2 files changed, 128 insertions(+), 101 deletions(-) diff --git a/turbine/src/cluster_nodes.rs b/turbine/src/cluster_nodes.rs index 57676a34b75eff..ddbd13cc1dcb74 100644 --- a/turbine/src/cluster_nodes.rs +++ b/turbine/src/cluster_nodes.rs @@ -423,7 +423,7 @@ impl From for NodeId { #[inline] pub(crate) fn get_broadcast_protocol(_: &ShredId) -> Protocol { - Protocol::UDP + Protocol::QUIC } pub fn make_test_cluster( diff --git a/turbine/src/quic_endpoint.rs b/turbine/src/quic_endpoint.rs index 0f93391e042b47..4a7bc711b5856b 100644 --- a/turbine/src/quic_endpoint.rs +++ b/turbine/src/quic_endpoint.rs @@ -24,14 +24,15 @@ use { thiserror::Error, tokio::{ sync::{ - mpsc::{Receiver as AsyncReceiver, Sender as AsyncSender}, + mpsc::{error::TrySendError, Receiver as AsyncReceiver, Sender as AsyncSender}, RwLock, }, task::JoinHandle, }, }; -const CLIENT_CHANNEL_CAPACITY: usize = 1 << 20; +const CLIENT_CHANNEL_BUFFER: usize = 1 << 14; +const ROUTER_CHANNEL_BUFFER: usize = 64; const INITIAL_MAXIMUM_TRANSMISSION_UNIT: u16 = 1280; const ALPN_TURBINE_PROTOCOL_ID: &[u8] = b"solana-turbine"; const CONNECT_SERVER_NAME: &str = "solana-turbine"; @@ -47,7 +48,7 @@ const CONNECTION_CLOSE_REASON_INVALID_IDENTITY: &[u8] = b"INVALID_IDENTITY"; const CONNECTION_CLOSE_REASON_REPLACED: &[u8] = b"REPLACED"; pub type AsyncTryJoinHandle = TryJoin, JoinHandle<()>>; -type ConnectionCache = HashMap<(SocketAddr, Option), Arc>>>; +type ConnectionCache = HashMap>>>; #[derive(Error, Debug)] pub enum Error { @@ -100,9 +101,21 @@ pub fn new_quic_endpoint( }; endpoint.set_default_client_config(client_config); let cache = Arc::>::default(); - let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_CAPACITY); - let server_task = runtime.spawn(run_server(endpoint.clone(), sender.clone(), cache.clone())); - let client_task = runtime.spawn(run_client(endpoint.clone(), client_receiver, sender, cache)); + let router = Arc::>>>::default(); + let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_BUFFER); + let server_task = runtime.spawn(run_server( + endpoint.clone(), + sender.clone(), + router.clone(), + cache.clone(), + )); + let client_task = runtime.spawn(run_client( + endpoint.clone(), + client_receiver, + sender, + router, + cache, + )); let task = futures::future::try_join(server_task, client_task); Ok((endpoint, client_sender, task)) } @@ -152,6 +165,7 @@ fn new_transport_config() -> TransportConfig { async fn run_server( endpoint: Endpoint, sender: Sender<(Pubkey, SocketAddr, Bytes)>, + router: Arc>>>, cache: Arc>, ) { while let Some(connecting) = endpoint.accept().await { @@ -159,6 +173,7 @@ async fn run_server( endpoint.clone(), connecting, sender.clone(), + router.clone(), cache.clone(), )); } @@ -168,27 +183,60 @@ async fn run_client( endpoint: Endpoint, mut receiver: AsyncReceiver<(SocketAddr, Bytes)>, sender: Sender<(Pubkey, SocketAddr, Bytes)>, + router: Arc>>>, cache: Arc>, ) { while let Some((remote_address, bytes)) = receiver.recv().await { - tokio::task::spawn(send_datagram_task( + let bytes = match router.read().await.get(&remote_address) { + None => bytes, + Some(sender) => match sender.try_send(bytes) { + Ok(()) => continue, + Err(TrySendError::Full(_)) => { + error!("TrySendError::Full {remote_address}"); + continue; + } + Err(TrySendError::Closed(bytes)) => bytes, + }, + }; + let receiver = { + let mut router = router.write().await; + let bytes = match router.get(&remote_address) { + None => bytes, + Some(sender) => match sender.try_send(bytes) { + Ok(()) => continue, + Err(TrySendError::Full(_)) => { + error!("TrySendError::Full {remote_address}"); + continue; + } + Err(TrySendError::Closed(bytes)) => bytes, + }, + }; + let (sender, receiver) = tokio::sync::mpsc::channel(ROUTER_CHANNEL_BUFFER); + sender.try_send(bytes).unwrap(); + router.insert(remote_address, sender); + receiver + }; + tokio::task::spawn(make_connection_task( endpoint.clone(), remote_address, - bytes, sender.clone(), + receiver, cache.clone(), )); } close_quic_endpoint(&endpoint); + // Drop sender channels to unblock threads waiting on the receiving end. + router.write().await.clear(); } async fn handle_connecting_error( endpoint: Endpoint, connecting: Connecting, sender: Sender<(Pubkey, SocketAddr, Bytes)>, + router: Arc>>>, cache: Arc>, ) { - if let Err(err) = handle_connecting(endpoint, connecting, sender, cache).await { + if let Err(err) = handle_connecting(endpoint, connecting, sender, router, cache).await { error!("handle_connecting: {err:?}"); } } @@ -197,52 +245,68 @@ async fn handle_connecting( endpoint: Endpoint, connecting: Connecting, sender: Sender<(Pubkey, SocketAddr, Bytes)>, + router: Arc>>>, cache: Arc>, ) -> Result<(), Error> { let connection = connecting.await?; let remote_address = connection.remote_address(); let remote_pubkey = get_remote_pubkey(&connection)?; - handle_connection_error( + let receiver = { + let (sender, receiver) = tokio::sync::mpsc::channel(ROUTER_CHANNEL_BUFFER); + router.write().await.insert(remote_address, sender); + receiver + }; + handle_connection( endpoint, remote_address, remote_pubkey, connection, sender, + receiver, cache, ) .await; Ok(()) } -async fn handle_connection_error( +async fn handle_connection( endpoint: Endpoint, remote_address: SocketAddr, remote_pubkey: Pubkey, connection: Connection, sender: Sender<(Pubkey, SocketAddr, Bytes)>, + receiver: AsyncReceiver, cache: Arc>, ) { - cache_connection(remote_address, remote_pubkey, connection.clone(), &cache).await; - if let Err(err) = handle_connection( - &endpoint, + cache_connection(remote_pubkey, connection.clone(), &cache).await; + let send_datagram_task = tokio::task::spawn(send_datagram_task(connection.clone(), receiver)); + let read_datagram_task = tokio::task::spawn(read_datagram_task( + endpoint, remote_address, remote_pubkey, - &connection, - &sender, - ) - .await - { - drop_connection(remote_address, remote_pubkey, &connection, &cache).await; - error!("handle_connection: {remote_pubkey}, {remote_address}, {err:?}"); + connection.clone(), + sender, + )); + match futures::future::try_join(send_datagram_task, read_datagram_task).await { + Err(err) => error!("handle_connection: {remote_pubkey}, {remote_address}, {err:?}"), + Ok(out) => { + if let (Err(ref err), _) = out { + error!("send_datagram_task: {remote_pubkey}, {remote_address}, {err:?}"); + } + if let (_, Err(ref err)) = out { + error!("read_datagram_task: {remote_pubkey}, {remote_address}, {err:?}"); + } + } } + drop_connection(remote_pubkey, &connection, &cache).await; } -async fn handle_connection( - endpoint: &Endpoint, +async fn read_datagram_task( + endpoint: Endpoint, remote_address: SocketAddr, remote_pubkey: Pubkey, - connection: &Connection, - sender: &Sender<(Pubkey, SocketAddr, Bytes)>, + connection: Connection, + sender: Sender<(Pubkey, SocketAddr, Bytes)>, ) -> Result<(), Error> { // Assert that send won't block. debug_assert_eq!(sender.capacity(), None); @@ -250,7 +314,7 @@ async fn handle_connection( match connection.read_datagram().await { Ok(bytes) => { if let Err(err) = sender.send((remote_pubkey, remote_address, bytes)) { - close_quic_endpoint(endpoint); + close_quic_endpoint(&endpoint); return Err(Error::from(err)); } } @@ -265,67 +329,48 @@ async fn handle_connection( } async fn send_datagram_task( + connection: Connection, + mut receiver: AsyncReceiver, +) -> Result<(), Error> { + while let Some(bytes) = receiver.recv().await { + connection.send_datagram(bytes)?; + } + Ok(()) +} + +async fn make_connection_task( endpoint: Endpoint, remote_address: SocketAddr, - bytes: Bytes, sender: Sender<(Pubkey, SocketAddr, Bytes)>, + receiver: AsyncReceiver, cache: Arc>, ) { - if let Err(err) = send_datagram(&endpoint, remote_address, bytes, sender, cache).await { - error!("send_datagram: {remote_address}, {err:?}"); + if let Err(err) = make_connection(endpoint, remote_address, sender, receiver, cache).await { + error!("make_connection: {remote_address}, {err:?}"); } } -async fn send_datagram( - endpoint: &Endpoint, +async fn make_connection( + endpoint: Endpoint, remote_address: SocketAddr, - bytes: Bytes, sender: Sender<(Pubkey, SocketAddr, Bytes)>, + receiver: AsyncReceiver, cache: Arc>, ) -> Result<(), Error> { - let connection = get_connection(endpoint, remote_address, sender, cache).await?; - connection.send_datagram(bytes)?; - Ok(()) -} - -async fn get_connection( - endpoint: &Endpoint, - remote_address: SocketAddr, - sender: Sender<(Pubkey, SocketAddr, Bytes)>, - cache: Arc>, -) -> Result { - let entry = get_cache_entry(remote_address, &cache).await; - { - let connection: Option = entry.read().await.clone(); - if let Some(connection) = connection { - if connection.close_reason().is_none() { - return Ok(connection); - } - } - } - let connection = { - // Need to write lock here so that only one task initiates - // a new connection to the same remote_address. - let mut entry = entry.write().await; - if let Some(connection) = entry.deref() { - if connection.close_reason().is_none() { - return Ok(connection.clone()); - } - } - let connection = endpoint - .connect(remote_address, CONNECT_SERVER_NAME)? - .await?; - entry.insert(connection).clone() - }; - tokio::task::spawn(handle_connection_error( - endpoint.clone(), + let connection = endpoint + .connect(remote_address, CONNECT_SERVER_NAME)? + .await?; + handle_connection( + endpoint, connection.remote_address(), get_remote_pubkey(&connection)?, - connection.clone(), + connection, sender, + receiver, cache, - )); - Ok(connection) + ) + .await; + Ok(()) } fn get_remote_pubkey(connection: &Connection) -> Result { @@ -341,43 +386,25 @@ fn get_remote_pubkey(connection: &Connection) -> Result { } } -async fn get_cache_entry( - remote_address: SocketAddr, - cache: &RwLock, -) -> Arc>> { - let key = (remote_address, /*remote_pubkey:*/ None); - if let Some(entry) = cache.read().await.get(&key) { - return entry.clone(); - } - cache.write().await.entry(key).or_default().clone() -} - async fn cache_connection( - remote_address: SocketAddr, remote_pubkey: Pubkey, connection: Connection, cache: &RwLock, ) { - let entries: [Arc>>; 2] = { + let entry = { let mut cache = cache.write().await; - [Some(remote_pubkey), None].map(|remote_pubkey| { - let key = (remote_address, remote_pubkey); - cache.entry(key).or_default().clone() - }) + cache.entry(remote_pubkey).or_default().clone() }; - let mut entry = entries[0].write().await; - *entries[1].write().await = Some(connection.clone()); - if let Some(old) = entry.replace(connection) { - drop(entry); - old.close( - CONNECTION_CLOSE_ERROR_CODE_REPLACED, - CONNECTION_CLOSE_REASON_REPLACED, - ); - } + let Some(old) = entry.write().await.replace(connection) else { + return; + }; + old.close( + CONNECTION_CLOSE_ERROR_CODE_REPLACED, + CONNECTION_CLOSE_REASON_REPLACED, + ); } async fn drop_connection( - remote_address: SocketAddr, remote_pubkey: Pubkey, connection: &Connection, cache: &RwLock, @@ -388,8 +415,7 @@ async fn drop_connection( CONNECTION_CLOSE_REASON_DROPPED, ); } - let key = (remote_address, Some(remote_pubkey)); - if let Entry::Occupied(entry) = cache.write().await.entry(key) { + if let Entry::Occupied(entry) = cache.write().await.entry(remote_pubkey) { if matches!(entry.get().read().await.deref(), Some(entry) if entry.stable_id() == connection.stable_id()) { @@ -416,6 +442,7 @@ mod tests { #[test] fn test_quic_endpoint() { + solana_logger::setup(); const NUM_ENDPOINTS: usize = 3; const RECV_TIMEOUT: Duration = Duration::from_secs(60); let runtime = tokio::runtime::Builder::new_multi_thread()