Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
separates out routing shreds from establishing connections
Browse files Browse the repository at this point in the history
Currently each outgoing shred will attempt to establish a connection if
one does not already exist. This is very wasteful and consumes many
tokio tasks if the remote node is down or unresponsive.

The commit decouples routing packets from establishing connections by
adding a buffering channel for each remote address. Outgoing packets are
always sent down this channel to be processed once the connection is
established. If connecting attempt fails, all packets already pushed to
the channel are dropped at once, reducing the number of attempts to make
a connection if the remote node is down or unresponsive.
  • Loading branch information
behzadnouri committed Oct 9, 2023
1 parent c924719 commit f3572dc
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 101 deletions.
2 changes: 1 addition & 1 deletion turbine/src/cluster_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,7 @@ impl From<Pubkey> for NodeId {

#[inline]
pub(crate) fn get_broadcast_protocol(_: &ShredId) -> Protocol {
Protocol::UDP
Protocol::QUIC
}

pub fn make_test_cluster<R: Rng>(
Expand Down
227 changes: 127 additions & 100 deletions turbine/src/quic_endpoint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,15 @@ use {
thiserror::Error,
tokio::{
sync::{
mpsc::{Receiver as AsyncReceiver, Sender as AsyncSender},
mpsc::{error::TrySendError, Receiver as AsyncReceiver, Sender as AsyncSender},
RwLock,
},
task::JoinHandle,
},
};

const CLIENT_CHANNEL_CAPACITY: usize = 1 << 20;
const CLIENT_CHANNEL_BUFFER: usize = 1 << 14;
const ROUTER_CHANNEL_BUFFER: usize = 64;
const INITIAL_MAXIMUM_TRANSMISSION_UNIT: u16 = 1280;
const ALPN_TURBINE_PROTOCOL_ID: &[u8] = b"solana-turbine";
const CONNECT_SERVER_NAME: &str = "solana-turbine";
Expand All @@ -47,7 +48,7 @@ const CONNECTION_CLOSE_REASON_INVALID_IDENTITY: &[u8] = b"INVALID_IDENTITY";
const CONNECTION_CLOSE_REASON_REPLACED: &[u8] = b"REPLACED";

pub type AsyncTryJoinHandle = TryJoin<JoinHandle<()>, JoinHandle<()>>;
type ConnectionCache = HashMap<(SocketAddr, Option<Pubkey>), Arc<RwLock<Option<Connection>>>>;
type ConnectionCache = HashMap<Pubkey, Arc<RwLock<Option<Connection>>>>;

#[derive(Error, Debug)]
pub enum Error {
Expand Down Expand Up @@ -100,9 +101,21 @@ pub fn new_quic_endpoint(
};
endpoint.set_default_client_config(client_config);
let cache = Arc::<RwLock<ConnectionCache>>::default();
let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_CAPACITY);
let server_task = runtime.spawn(run_server(endpoint.clone(), sender.clone(), cache.clone()));
let client_task = runtime.spawn(run_client(endpoint.clone(), client_receiver, sender, cache));
let router = Arc::<RwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>::default();
let (client_sender, client_receiver) = tokio::sync::mpsc::channel(CLIENT_CHANNEL_BUFFER);
let server_task = runtime.spawn(run_server(
endpoint.clone(),
sender.clone(),
router.clone(),
cache.clone(),
));
let client_task = runtime.spawn(run_client(
endpoint.clone(),
client_receiver,
sender,
router,
cache,
));
let task = futures::future::try_join(server_task, client_task);
Ok((endpoint, client_sender, task))
}
Expand Down Expand Up @@ -152,13 +165,15 @@ fn new_transport_config() -> TransportConfig {
async fn run_server(
endpoint: Endpoint,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
router: Arc<RwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<RwLock<ConnectionCache>>,
) {
while let Some(connecting) = endpoint.accept().await {
tokio::task::spawn(handle_connecting_error(
endpoint.clone(),
connecting,
sender.clone(),
router.clone(),
cache.clone(),
));
}
Expand All @@ -168,27 +183,60 @@ async fn run_client(
endpoint: Endpoint,
mut receiver: AsyncReceiver<(SocketAddr, Bytes)>,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
router: Arc<RwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<RwLock<ConnectionCache>>,
) {
while let Some((remote_address, bytes)) = receiver.recv().await {
tokio::task::spawn(send_datagram_task(
let bytes = match router.read().await.get(&remote_address) {
None => bytes,
Some(sender) => match sender.try_send(bytes) {
Ok(()) => continue,
Err(TrySendError::Full(_)) => {
error!("TrySendError::Full {remote_address}");
continue;
}
Err(TrySendError::Closed(bytes)) => bytes,
},
};
let receiver = {
let mut router = router.write().await;
let bytes = match router.get(&remote_address) {
None => bytes,
Some(sender) => match sender.try_send(bytes) {
Ok(()) => continue,
Err(TrySendError::Full(_)) => {
error!("TrySendError::Full {remote_address}");
continue;
}
Err(TrySendError::Closed(bytes)) => bytes,
},
};
let (sender, receiver) = tokio::sync::mpsc::channel(ROUTER_CHANNEL_BUFFER);
sender.try_send(bytes).unwrap();
router.insert(remote_address, sender);
receiver
};
tokio::task::spawn(make_connection_task(
endpoint.clone(),
remote_address,
bytes,
sender.clone(),
receiver,
cache.clone(),
));
}
close_quic_endpoint(&endpoint);
// Drop sender channels to unblock threads waiting on the receiving end.
router.write().await.clear();
}

async fn handle_connecting_error(
endpoint: Endpoint,
connecting: Connecting,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
router: Arc<RwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<RwLock<ConnectionCache>>,
) {
if let Err(err) = handle_connecting(endpoint, connecting, sender, cache).await {
if let Err(err) = handle_connecting(endpoint, connecting, sender, router, cache).await {
error!("handle_connecting: {err:?}");
}
}
Expand All @@ -197,60 +245,76 @@ async fn handle_connecting(
endpoint: Endpoint,
connecting: Connecting,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
router: Arc<RwLock<HashMap<SocketAddr, AsyncSender<Bytes>>>>,
cache: Arc<RwLock<ConnectionCache>>,
) -> Result<(), Error> {
let connection = connecting.await?;
let remote_address = connection.remote_address();
let remote_pubkey = get_remote_pubkey(&connection)?;
handle_connection_error(
let receiver = {
let (sender, receiver) = tokio::sync::mpsc::channel(ROUTER_CHANNEL_BUFFER);
router.write().await.insert(remote_address, sender);
receiver
};
handle_connection(
endpoint,
remote_address,
remote_pubkey,
connection,
sender,
receiver,
cache,
)
.await;
Ok(())
}

async fn handle_connection_error(
async fn handle_connection(
endpoint: Endpoint,
remote_address: SocketAddr,
remote_pubkey: Pubkey,
connection: Connection,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
receiver: AsyncReceiver<Bytes>,
cache: Arc<RwLock<ConnectionCache>>,
) {
cache_connection(remote_address, remote_pubkey, connection.clone(), &cache).await;
if let Err(err) = handle_connection(
&endpoint,
cache_connection(remote_pubkey, connection.clone(), &cache).await;
let send_datagram_task = tokio::task::spawn(send_datagram_task(connection.clone(), receiver));
let read_datagram_task = tokio::task::spawn(read_datagram_task(
endpoint,
remote_address,
remote_pubkey,
&connection,
&sender,
)
.await
{
drop_connection(remote_address, remote_pubkey, &connection, &cache).await;
error!("handle_connection: {remote_pubkey}, {remote_address}, {err:?}");
connection.clone(),
sender,
));
match futures::future::try_join(send_datagram_task, read_datagram_task).await {
Err(err) => error!("handle_connection: {remote_pubkey}, {remote_address}, {err:?}"),
Ok(out) => {
if let (Err(ref err), _) = out {
error!("send_datagram_task: {remote_pubkey}, {remote_address}, {err:?}");
}
if let (_, Err(ref err)) = out {
error!("read_datagram_task: {remote_pubkey}, {remote_address}, {err:?}");
}
}
}
drop_connection(remote_pubkey, &connection, &cache).await;
}

async fn handle_connection(
endpoint: &Endpoint,
async fn read_datagram_task(
endpoint: Endpoint,
remote_address: SocketAddr,
remote_pubkey: Pubkey,
connection: &Connection,
sender: &Sender<(Pubkey, SocketAddr, Bytes)>,
connection: Connection,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
) -> Result<(), Error> {
// Assert that send won't block.
debug_assert_eq!(sender.capacity(), None);
loop {
match connection.read_datagram().await {
Ok(bytes) => {
if let Err(err) = sender.send((remote_pubkey, remote_address, bytes)) {
close_quic_endpoint(endpoint);
close_quic_endpoint(&endpoint);
return Err(Error::from(err));
}
}
Expand All @@ -265,67 +329,48 @@ async fn handle_connection(
}

async fn send_datagram_task(
connection: Connection,
mut receiver: AsyncReceiver<Bytes>,
) -> Result<(), Error> {
while let Some(bytes) = receiver.recv().await {
connection.send_datagram(bytes)?;
}
Ok(())
}

async fn make_connection_task(
endpoint: Endpoint,
remote_address: SocketAddr,
bytes: Bytes,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
receiver: AsyncReceiver<Bytes>,
cache: Arc<RwLock<ConnectionCache>>,
) {
if let Err(err) = send_datagram(&endpoint, remote_address, bytes, sender, cache).await {
error!("send_datagram: {remote_address}, {err:?}");
if let Err(err) = make_connection(endpoint, remote_address, sender, receiver, cache).await {
error!("make_connection: {remote_address}, {err:?}");
}
}

async fn send_datagram(
endpoint: &Endpoint,
async fn make_connection(
endpoint: Endpoint,
remote_address: SocketAddr,
bytes: Bytes,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
receiver: AsyncReceiver<Bytes>,
cache: Arc<RwLock<ConnectionCache>>,
) -> Result<(), Error> {
let connection = get_connection(endpoint, remote_address, sender, cache).await?;
connection.send_datagram(bytes)?;
Ok(())
}

async fn get_connection(
endpoint: &Endpoint,
remote_address: SocketAddr,
sender: Sender<(Pubkey, SocketAddr, Bytes)>,
cache: Arc<RwLock<ConnectionCache>>,
) -> Result<Connection, Error> {
let entry = get_cache_entry(remote_address, &cache).await;
{
let connection: Option<Connection> = entry.read().await.clone();
if let Some(connection) = connection {
if connection.close_reason().is_none() {
return Ok(connection);
}
}
}
let connection = {
// Need to write lock here so that only one task initiates
// a new connection to the same remote_address.
let mut entry = entry.write().await;
if let Some(connection) = entry.deref() {
if connection.close_reason().is_none() {
return Ok(connection.clone());
}
}
let connection = endpoint
.connect(remote_address, CONNECT_SERVER_NAME)?
.await?;
entry.insert(connection).clone()
};
tokio::task::spawn(handle_connection_error(
endpoint.clone(),
let connection = endpoint
.connect(remote_address, CONNECT_SERVER_NAME)?
.await?;
handle_connection(
endpoint,
connection.remote_address(),
get_remote_pubkey(&connection)?,
connection.clone(),
connection,
sender,
receiver,
cache,
));
Ok(connection)
)
.await;
Ok(())
}

fn get_remote_pubkey(connection: &Connection) -> Result<Pubkey, Error> {
Expand All @@ -341,43 +386,25 @@ fn get_remote_pubkey(connection: &Connection) -> Result<Pubkey, Error> {
}
}

async fn get_cache_entry(
remote_address: SocketAddr,
cache: &RwLock<ConnectionCache>,
) -> Arc<RwLock<Option<Connection>>> {
let key = (remote_address, /*remote_pubkey:*/ None);
if let Some(entry) = cache.read().await.get(&key) {
return entry.clone();
}
cache.write().await.entry(key).or_default().clone()
}

async fn cache_connection(
remote_address: SocketAddr,
remote_pubkey: Pubkey,
connection: Connection,
cache: &RwLock<ConnectionCache>,
) {
let entries: [Arc<RwLock<Option<Connection>>>; 2] = {
let entry = {
let mut cache = cache.write().await;
[Some(remote_pubkey), None].map(|remote_pubkey| {
let key = (remote_address, remote_pubkey);
cache.entry(key).or_default().clone()
})
cache.entry(remote_pubkey).or_default().clone()
};
let mut entry = entries[0].write().await;
*entries[1].write().await = Some(connection.clone());
if let Some(old) = entry.replace(connection) {
drop(entry);
old.close(
CONNECTION_CLOSE_ERROR_CODE_REPLACED,
CONNECTION_CLOSE_REASON_REPLACED,
);
}
let Some(old) = entry.write().await.replace(connection) else {
return;
};
old.close(
CONNECTION_CLOSE_ERROR_CODE_REPLACED,
CONNECTION_CLOSE_REASON_REPLACED,
);
}

async fn drop_connection(
remote_address: SocketAddr,
remote_pubkey: Pubkey,
connection: &Connection,
cache: &RwLock<ConnectionCache>,
Expand All @@ -388,8 +415,7 @@ async fn drop_connection(
CONNECTION_CLOSE_REASON_DROPPED,
);
}
let key = (remote_address, Some(remote_pubkey));
if let Entry::Occupied(entry) = cache.write().await.entry(key) {
if let Entry::Occupied(entry) = cache.write().await.entry(remote_pubkey) {
if matches!(entry.get().read().await.deref(),
Some(entry) if entry.stable_id() == connection.stable_id())
{
Expand All @@ -416,6 +442,7 @@ mod tests {

#[test]
fn test_quic_endpoint() {
solana_logger::setup();
const NUM_ENDPOINTS: usize = 3;
const RECV_TIMEOUT: Duration = Duration::from_secs(60);
let runtime = tokio::runtime::Builder::new_multi_thread()
Expand Down

0 comments on commit f3572dc

Please sign in to comment.