diff --git a/turbine/src/quic_endpoint.rs b/turbine/src/quic_endpoint.rs index 0f93391e042b47..df8d56437084c4 100644 --- a/turbine/src/quic_endpoint.rs +++ b/turbine/src/quic_endpoint.rs @@ -18,20 +18,20 @@ use { collections::{hash_map::Entry, HashMap}, io::Error as IoError, net::{IpAddr, SocketAddr, UdpSocket}, - ops::Deref, sync::Arc, }, thiserror::Error, tokio::{ sync::{ - mpsc::{Receiver as AsyncReceiver, Sender as AsyncSender}, - RwLock, + mpsc::{error::TrySendError, Receiver as AsyncReceiver, Sender as AsyncSender}, + Mutex, RwLock as AsyncRwLock, }, task::JoinHandle, }, }; -const CLIENT_CHANNEL_CAPACITY: usize = 1 << 20; +const CLIENT_CHANNEL_BUFFER: usize = 1 << 14; +const ROUTER_CHANNEL_BUFFER: usize = 64; const INITIAL_MAXIMUM_TRANSMISSION_UNIT: u16 = 1280; const ALPN_TURBINE_PROTOCOL_ID: &[u8] = b"solana-turbine"; const CONNECT_SERVER_NAME: &str = "solana-turbine"; @@ -47,7 +47,6 @@ const CONNECTION_CLOSE_REASON_INVALID_IDENTITY: &[u8] = b"INVALID_IDENTITY"; const CONNECTION_CLOSE_REASON_REPLACED: &[u8] = b"REPLACED"; pub type AsyncTryJoinHandle = TryJoin, JoinHandle<()>>; -type ConnectionCache = HashMap<(SocketAddr, Option), Arc>>>; #[derive(Error, Debug)] pub enum Error { @@ -99,10 +98,22 @@ pub fn new_quic_endpoint( )? }; endpoint.set_default_client_config(client_config); - let cache = Arc::>::default(); - let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_CAPACITY); - let server_task = runtime.spawn(run_server(endpoint.clone(), sender.clone(), cache.clone())); - let client_task = runtime.spawn(run_client(endpoint.clone(), client_receiver, sender, cache)); + let cache = Arc::>>::default(); + let router = Arc::>>>::default(); + let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_BUFFER); + let server_task = runtime.spawn(run_server( + endpoint.clone(), + sender.clone(), + router.clone(), + cache.clone(), + )); + let client_task = runtime.spawn(run_client( + endpoint.clone(), + client_receiver, + sender, + router, + cache, + )); let task = futures::future::try_join(server_task, client_task); Ok((endpoint, client_sender, task)) } @@ -152,13 +163,15 @@ fn new_transport_config() -> TransportConfig { async fn run_server( endpoint: Endpoint, sender: Sender<(Pubkey, SocketAddr, Bytes)>, - cache: Arc>, + router: Arc>>>, + cache: Arc>>, ) { while let Some(connecting) = endpoint.accept().await { tokio::task::spawn(handle_connecting_error( endpoint.clone(), connecting, sender.clone(), + router.clone(), cache.clone(), )); } @@ -168,27 +181,61 @@ async fn run_client( endpoint: Endpoint, mut receiver: AsyncReceiver<(SocketAddr, Bytes)>, sender: Sender<(Pubkey, SocketAddr, Bytes)>, - cache: Arc>, + router: Arc>>>, + cache: Arc>>, ) { while let Some((remote_address, bytes)) = receiver.recv().await { - tokio::task::spawn(send_datagram_task( + let bytes = match router.read().await.get(&remote_address) { + None => bytes, + Some(sender) => match sender.try_send(bytes) { + Ok(()) => continue, + Err(TrySendError::Full(_)) => { + error!("TrySendError::Full {remote_address}"); + continue; + } + Err(TrySendError::Closed(bytes)) => bytes, + }, + }; + let receiver = { + let mut router = router.write().await; + let bytes = match router.get(&remote_address) { + None => bytes, + Some(sender) => match sender.try_send(bytes) { + Ok(()) => continue, + Err(TrySendError::Full(_)) => { + error!("TrySendError::Full {remote_address}"); + continue; + } + Err(TrySendError::Closed(bytes)) => bytes, + }, + }; + let (sender, receiver) = tokio::sync::mpsc::channel(ROUTER_CHANNEL_BUFFER); + sender.try_send(bytes).unwrap(); + router.insert(remote_address, sender); + receiver + }; + tokio::task::spawn(make_connection_task( endpoint.clone(), remote_address, - bytes, sender.clone(), + receiver, + router.clone(), cache.clone(), )); } close_quic_endpoint(&endpoint); + // Drop sender channels to unblock threads waiting on the receiving end. + router.write().await.clear(); } async fn handle_connecting_error( endpoint: Endpoint, connecting: Connecting, sender: Sender<(Pubkey, SocketAddr, Bytes)>, - cache: Arc>, + router: Arc>>>, + cache: Arc>>, ) { - if let Err(err) = handle_connecting(endpoint, connecting, sender, cache).await { + if let Err(err) = handle_connecting(endpoint, connecting, sender, router, cache).await { error!("handle_connecting: {err:?}"); } } @@ -197,52 +244,75 @@ async fn handle_connecting( endpoint: Endpoint, connecting: Connecting, sender: Sender<(Pubkey, SocketAddr, Bytes)>, - cache: Arc>, + router: Arc>>>, + cache: Arc>>, ) -> Result<(), Error> { let connection = connecting.await?; let remote_address = connection.remote_address(); let remote_pubkey = get_remote_pubkey(&connection)?; - handle_connection_error( + let receiver = { + let (sender, receiver) = tokio::sync::mpsc::channel(ROUTER_CHANNEL_BUFFER); + router.write().await.insert(remote_address, sender); + receiver + }; + handle_connection( endpoint, remote_address, remote_pubkey, connection, sender, + receiver, + router, cache, ) .await; Ok(()) } -async fn handle_connection_error( +async fn handle_connection( endpoint: Endpoint, remote_address: SocketAddr, remote_pubkey: Pubkey, connection: Connection, sender: Sender<(Pubkey, SocketAddr, Bytes)>, - cache: Arc>, + receiver: AsyncReceiver, + router: Arc>>>, + cache: Arc>>, ) { - cache_connection(remote_address, remote_pubkey, connection.clone(), &cache).await; - if let Err(err) = handle_connection( - &endpoint, + cache_connection(remote_pubkey, connection.clone(), &cache).await; + let send_datagram_task = tokio::task::spawn(send_datagram_task(connection.clone(), receiver)); + let read_datagram_task = tokio::task::spawn(read_datagram_task( + endpoint, remote_address, remote_pubkey, - &connection, - &sender, - ) - .await - { - drop_connection(remote_address, remote_pubkey, &connection, &cache).await; - error!("handle_connection: {remote_pubkey}, {remote_address}, {err:?}"); + connection.clone(), + sender, + )); + match futures::future::try_join(send_datagram_task, read_datagram_task).await { + Err(err) => error!("handle_connection: {remote_pubkey}, {remote_address}, {err:?}"), + Ok(out) => { + if let (Err(ref err), _) = out { + error!("send_datagram_task: {remote_pubkey}, {remote_address}, {err:?}"); + } + if let (_, Err(ref err)) = out { + error!("read_datagram_task: {remote_pubkey}, {remote_address}, {err:?}"); + } + } + } + drop_connection(remote_pubkey, &connection, &cache).await; + if let Entry::Occupied(entry) = router.write().await.entry(remote_address) { + if entry.get().is_closed() { + entry.remove(); + } } } -async fn handle_connection( - endpoint: &Endpoint, +async fn read_datagram_task( + endpoint: Endpoint, remote_address: SocketAddr, remote_pubkey: Pubkey, - connection: &Connection, - sender: &Sender<(Pubkey, SocketAddr, Bytes)>, + connection: Connection, + sender: Sender<(Pubkey, SocketAddr, Bytes)>, ) -> Result<(), Error> { // Assert that send won't block. debug_assert_eq!(sender.capacity(), None); @@ -250,7 +320,7 @@ async fn handle_connection( match connection.read_datagram().await { Ok(bytes) => { if let Err(err) = sender.send((remote_pubkey, remote_address, bytes)) { - close_quic_endpoint(endpoint); + close_quic_endpoint(&endpoint); return Err(Error::from(err)); } } @@ -265,67 +335,53 @@ async fn handle_connection( } async fn send_datagram_task( + connection: Connection, + mut receiver: AsyncReceiver, +) -> Result<(), Error> { + while let Some(bytes) = receiver.recv().await { + connection.send_datagram(bytes)?; + } + Ok(()) +} + +async fn make_connection_task( endpoint: Endpoint, remote_address: SocketAddr, - bytes: Bytes, sender: Sender<(Pubkey, SocketAddr, Bytes)>, - cache: Arc>, + receiver: AsyncReceiver, + router: Arc>>>, + cache: Arc>>, ) { - if let Err(err) = send_datagram(&endpoint, remote_address, bytes, sender, cache).await { - error!("send_datagram: {remote_address}, {err:?}"); + if let Err(err) = + make_connection(endpoint, remote_address, sender, receiver, router, cache).await + { + error!("make_connection: {remote_address}, {err:?}"); } } -async fn send_datagram( - endpoint: &Endpoint, +async fn make_connection( + endpoint: Endpoint, remote_address: SocketAddr, - bytes: Bytes, sender: Sender<(Pubkey, SocketAddr, Bytes)>, - cache: Arc>, + receiver: AsyncReceiver, + router: Arc>>>, + cache: Arc>>, ) -> Result<(), Error> { - let connection = get_connection(endpoint, remote_address, sender, cache).await?; - connection.send_datagram(bytes)?; - Ok(()) -} - -async fn get_connection( - endpoint: &Endpoint, - remote_address: SocketAddr, - sender: Sender<(Pubkey, SocketAddr, Bytes)>, - cache: Arc>, -) -> Result { - let entry = get_cache_entry(remote_address, &cache).await; - { - let connection: Option = entry.read().await.clone(); - if let Some(connection) = connection { - if connection.close_reason().is_none() { - return Ok(connection); - } - } - } - let connection = { - // Need to write lock here so that only one task initiates - // a new connection to the same remote_address. - let mut entry = entry.write().await; - if let Some(connection) = entry.deref() { - if connection.close_reason().is_none() { - return Ok(connection.clone()); - } - } - let connection = endpoint - .connect(remote_address, CONNECT_SERVER_NAME)? - .await?; - entry.insert(connection).clone() - }; - tokio::task::spawn(handle_connection_error( - endpoint.clone(), + let connection = endpoint + .connect(remote_address, CONNECT_SERVER_NAME)? + .await?; + handle_connection( + endpoint, connection.remote_address(), get_remote_pubkey(&connection)?, - connection.clone(), + connection, sender, + receiver, + router, cache, - )); - Ok(connection) + ) + .await; + Ok(()) } fn get_remote_pubkey(connection: &Connection) -> Result { @@ -341,62 +397,34 @@ fn get_remote_pubkey(connection: &Connection) -> Result { } } -async fn get_cache_entry( - remote_address: SocketAddr, - cache: &RwLock, -) -> Arc>> { - let key = (remote_address, /*remote_pubkey:*/ None); - if let Some(entry) = cache.read().await.get(&key) { - return entry.clone(); - } - cache.write().await.entry(key).or_default().clone() -} - async fn cache_connection( - remote_address: SocketAddr, remote_pubkey: Pubkey, connection: Connection, - cache: &RwLock, + cache: &Mutex>, ) { - let entries: [Arc>>; 2] = { - let mut cache = cache.write().await; - [Some(remote_pubkey), None].map(|remote_pubkey| { - let key = (remote_address, remote_pubkey); - cache.entry(key).or_default().clone() - }) + let Some(old) = cache.lock().await.insert(remote_pubkey, connection) else { + return; }; - let mut entry = entries[0].write().await; - *entries[1].write().await = Some(connection.clone()); - if let Some(old) = entry.replace(connection) { - drop(entry); - old.close( - CONNECTION_CLOSE_ERROR_CODE_REPLACED, - CONNECTION_CLOSE_REASON_REPLACED, - ); - } + old.close( + CONNECTION_CLOSE_ERROR_CODE_REPLACED, + CONNECTION_CLOSE_REASON_REPLACED, + ); } async fn drop_connection( - remote_address: SocketAddr, remote_pubkey: Pubkey, connection: &Connection, - cache: &RwLock, + cache: &Mutex>, ) { - if connection.close_reason().is_none() { - connection.close( - CONNECTION_CLOSE_ERROR_CODE_DROPPED, - CONNECTION_CLOSE_REASON_DROPPED, - ); - } - let key = (remote_address, Some(remote_pubkey)); - if let Entry::Occupied(entry) = cache.write().await.entry(key) { - if matches!(entry.get().read().await.deref(), - Some(entry) if entry.stable_id() == connection.stable_id()) - { + connection.close( + CONNECTION_CLOSE_ERROR_CODE_DROPPED, + CONNECTION_CLOSE_REASON_DROPPED, + ); + if let Entry::Occupied(entry) = cache.lock().await.entry(remote_pubkey) { + if entry.get().stable_id() == connection.stable_id() { entry.remove(); } } - // Cache entry for (remote_address, None) will be lazily evicted. } impl From> for Error {