Skip to content

Commit

Permalink
removes intermediate vector allocations in ClusterNodes::get_retransmit_addrs (#4146)

Browse files Browse the repository at this point in the history

ClusterNodes::get_retransmit_addrs does 2 intermediate collects:
https://github.com/anza-xyz/agave/blob/3890ce5bc/turbine/src/cluster_nodes.rs#L222
https://github.com/anza-xyz/agave/blob/3890ce5bc/turbine/src/cluster_nodes.rs#L239

The commit avoids both by chaining iterator operations.
  • Loading branch information
behzadnouri authored Dec 18, 2024
1 parent dd63bae commit 11b2e32
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 90 deletions.
1 change: 1 addition & 0 deletions streamer/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ impl SocketAddrSpace {
}

/// Returns true if the IP address is valid.
#[inline]
#[must_use]
pub fn check(&self, addr: &SocketAddr) -> bool {
if matches!(self, SocketAddrSpace::Unspecified) {
Expand Down
9 changes: 7 additions & 2 deletions turbine/benches/cluster_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ use {
solana_gossip::contact_info::ContactInfo,
solana_ledger::shred::{Shred, ShredFlags},
solana_sdk::{clock::Slot, genesis_config::ClusterType, pubkey::Pubkey},
solana_streamer::socket::SocketAddrSpace,
solana_turbine::{
cluster_nodes::{make_test_cluster, new_cluster_nodes, ClusterNodes},
retransmit_stage::RetransmitStage,
Expand Down Expand Up @@ -45,8 +46,12 @@ fn get_retransmit_peers_deterministic(
0,
0,
);
let _retransmit_peers =
cluster_nodes.get_retransmit_peers(slot_leader, &shred.id(), /*fanout:*/ 200);
let _retransmit_peers = cluster_nodes.get_retransmit_addrs(
slot_leader,
&shred.id(),
200, // fanout
&SocketAddrSpace::Unspecified,
);
}
}

Expand Down
154 changes: 72 additions & 82 deletions turbine/src/cluster_nodes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use {
std::{
any::TypeId,
cmp::Reverse,
collections::HashMap,
collections::{HashMap, HashSet},
iter::repeat_with,
marker::PhantomData,
net::{IpAddr, SocketAddr},
Expand Down Expand Up @@ -83,14 +83,6 @@ pub struct ClusterNodesCache<T> {
ttl: Duration, // Time to live.
}

pub struct RetransmitPeers<'a> {
root_distance: usize, // distance from the root node
children: Vec<&'a Node>,
// Maps tvu addresses to the first node
// in the shuffle with the same address.
addrs: HashMap<SocketAddr, Pubkey>, // tvu addresses
}

impl Node {
#[inline]
fn pubkey(&self) -> Pubkey {
Expand Down Expand Up @@ -168,33 +160,13 @@ impl ClusterNodes<BroadcastStage> {
}

impl ClusterNodes<RetransmitStage> {
pub(crate) fn get_retransmit_addrs(
pub fn get_retransmit_addrs(
&self,
slot_leader: &Pubkey,
shred: &ShredId,
fanout: usize,
socket_addr_space: &SocketAddrSpace,
) -> Result<(/*root_distance:*/ usize, Vec<SocketAddr>), Error> {
let RetransmitPeers {
root_distance,
children,
addrs,
} = self.get_retransmit_peers(slot_leader, shred, fanout)?;
let protocol = get_broadcast_protocol(shred);
let peers = children.into_iter().filter_map(|node| {
node.contact_info()?
.tvu(protocol)
.ok()
.filter(|addr| addrs.get(addr) == Some(&node.pubkey()))
});
Ok((root_distance, peers.collect()))
}

pub fn get_retransmit_peers(
&self,
slot_leader: &Pubkey,
shred: &ShredId,
fanout: usize,
) -> Result<RetransmitPeers, Error> {
let mut weighted_shuffle = self.weighted_shuffle.clone();
// Exclude slot leader from list of nodes.
if slot_leader == &self.pubkey {
Expand All @@ -206,39 +178,30 @@ impl ClusterNodes<RetransmitStage> {
if let Some(index) = self.index.get(slot_leader) {
weighted_shuffle.remove_index(*index);
}
let mut addrs = HashMap::<SocketAddr, Pubkey>::with_capacity(self.nodes.len());
let mut rng = get_seeded_rng(slot_leader, shred);
let protocol = get_broadcast_protocol(shred);
let nodes: Vec<_> = weighted_shuffle
.shuffle(&mut rng)
.map(|index| &self.nodes[index])
.inspect(|node| {
if let Some(node) = node.contact_info() {
if let Ok(addr) = node.tvu(protocol) {
addrs.entry(addr).or_insert(*node.pubkey());
}
}
let nodes = {
let protocol = get_broadcast_protocol(shred);
// If there are 2 nodes in the shuffle with the same socket-addr,
// we only send shreds to the first one. The hash-set below allows
// to track if a socket-addr was observed earlier in the shuffle.
let mut addrs = HashSet::<SocketAddr>::with_capacity(self.nodes.len());
weighted_shuffle.shuffle(&mut rng).map(move |index| {
let node = &self.nodes[index];
let addr: Option<SocketAddr> = node
.contact_info()
.and_then(|node| node.tvu(protocol).ok())
.filter(|&addr| addrs.insert(addr));
(node, addr)
})
.collect();
let self_index = nodes
.iter()
.position(|node| node.pubkey() == self.pubkey)
.unwrap();
let root_distance = if self_index == 0 {
0
} else if self_index <= fanout {
1
} else if self_index <= fanout.saturating_add(1).saturating_mul(fanout) {
2
} else {
3 // If changed, update MAX_NUM_TURBINE_HOPS.
};
let peers = get_retransmit_peers(fanout, self_index, &nodes);
Ok(RetransmitPeers {
root_distance,
children: peers.collect(),
addrs,
})
let (index, peers) =
get_retransmit_peers(fanout, |(node, _)| node.pubkey() == self.pubkey, nodes);
let peers = peers
.filter_map(|(_, addr)| addr)
.filter(|addr| socket_addr_space.check(addr))
.collect();
let root_distance = get_root_distance(index, fanout);
Ok((root_distance, peers))
}

// Returns the parent node in the turbine broadcast tree.
Expand Down Expand Up @@ -393,22 +356,29 @@ fn get_seeded_rng(leader: &Pubkey, shred: &ShredId) -> ChaChaRng {
// Each other node retransmits shreds to fanout many nodes in the next layer.
// For example the node k in the 1st layer will retransmit to nodes:
// fanout + k, 2*fanout + k, ..., fanout*fanout + k
// Returns the node's index within the shuffle, plus a lazy iterator over the
// nodes this node should retransmit shreds to.
//
// The broadcast tree is implicit in the shuffled order: the root (index 0)
// sends to the first `fanout` nodes; node k in the 1st layer sends to nodes
//     fanout + k, 2*fanout + k, ..., fanout*fanout + k
// and so on for deeper layers.
//
// NOTE(review): the scraped diff interleaved the old `<T: Copy>` slice-based
// signature with the new iterator-based one; this is the reconstructed new
// implementation, which consumes `nodes` as a single forward pass and so
// avoids the intermediate `Vec` collect.
fn get_retransmit_peers<T>(
    fanout: usize,
    // Predicate fn which identifies this node in the shuffle.
    pred: impl Fn(T) -> bool,
    nodes: impl IntoIterator<Item = T>,
) -> (/*this node's index:*/ usize, impl Iterator<Item = T>) {
    let mut nodes = nodes.into_iter();
    // This node's index within shuffled nodes.
    // Panics if the predicate matches no node; callers guarantee this node
    // is always present in the shuffle.
    let index = nodes.by_ref().position(pred).unwrap();
    // Node's index within its neighborhood.
    let offset = index.saturating_sub(1) % fanout;
    // First node in the neighborhood.
    let anchor = index - offset;
    // The root sends to `fanout` consecutive nodes; every other node sends
    // to `fanout` nodes spaced `fanout` apart in the next layer.
    let step = if index == 0 { 1 } else { fanout };
    let peers = (anchor * fanout + offset + 1..)
        .step_by(step)
        .take(fanout)
        // `nodes` has already been consumed up to and including `index`, so
        // track the last consumed position in `state` and advance with
        // `nth` to reach each child index `k` in a single forward pass.
        .scan(index, move |state, k| -> Option<T> {
            let peer = nodes.by_ref().nth(k - *state - 1)?;
            *state = k;
            Some(peer)
        });
    (index, peers)
}

// Returns the parent node in the turbine broadcast tree.
Expand Down Expand Up @@ -519,6 +489,19 @@ pub(crate) fn get_broadcast_protocol(_: &ShredId) -> Protocol {
Protocol::UDP
}

#[inline]
fn get_root_distance(index: usize, fanout: usize) -> usize {
    // Number of turbine hops between the root node and the node at the
    // given index within the shuffled broadcast tree: the root itself is
    // distance 0, the first `fanout` nodes are distance 1, the next
    // `(fanout + 1) * fanout` nodes are distance 2, everything else 3.
    let last_layer_two = fanout.saturating_add(1).saturating_mul(fanout);
    match index {
        0 => 0,
        k if k <= fanout => 1,
        k if k <= last_layer_two => 2,
        _ => 3, // If changed, update MAX_NUM_TURBINE_HOPS.
    }
}

pub fn make_test_cluster<R: Rng>(
rng: &mut R,
num_nodes: usize,
Expand Down Expand Up @@ -710,7 +693,7 @@ mod tests {
T: Copy + Eq + PartialEq + Debug + Hash,
{
// Map node identities to their index within the shuffled tree.
let index: HashMap<_, _> = nodes
let cache: HashMap<_, _> = nodes
.iter()
.copied()
.enumerate()
Expand All @@ -720,18 +703,22 @@ mod tests {
// Root node's parent is None.
assert_eq!(get_retransmit_parent(fanout, /*index:*/ 0, nodes), None);
for (k, peers) in peers.into_iter().enumerate() {
assert_eq!(
get_retransmit_peers(fanout, k, nodes).collect::<Vec<_>>(),
peers
);
{
let (index, retransmit_peers) =
get_retransmit_peers(fanout, |node| node == &nodes[k], nodes);
assert_eq!(peers, retransmit_peers.copied().collect::<Vec<_>>());
assert_eq!(index, k);
}
let parent = Some(nodes[k]);
for peer in peers {
assert_eq!(get_retransmit_parent(fanout, index[&peer], nodes), parent);
assert_eq!(get_retransmit_parent(fanout, cache[&peer], nodes), parent);
}
}
// Remaining nodes have no children.
for k in offset..=nodes.len() {
assert_eq!(get_retransmit_peers(fanout, k, nodes).next(), None);
for k in offset..nodes.len() {
let (index, mut peers) = get_retransmit_peers(fanout, |node| node == &nodes[k], nodes);
assert_eq!(peers.next(), None);
assert_eq!(index, k);
}
}

Expand Down Expand Up @@ -860,7 +847,7 @@ mod tests {
let mut nodes: Vec<_> = (0..size).collect();
nodes.shuffle(&mut rng);
// Map node identities to their index within the shuffled tree.
let index: HashMap<_, _> = nodes
let cache: HashMap<_, _> = nodes
.iter()
.copied()
.enumerate()
Expand All @@ -870,13 +857,16 @@ mod tests {
assert_eq!(get_retransmit_parent(fanout, /*index:*/ 0, &nodes), None);
for k in 1..size {
let parent = get_retransmit_parent(fanout, k, &nodes).unwrap();
let mut peers = get_retransmit_peers(fanout, index[&parent], &nodes);
assert_eq!(peers.find(|&peer| peer == nodes[k]), Some(nodes[k]));
let (index, mut peers) = get_retransmit_peers(fanout, |node| node == &parent, &nodes);
assert_eq!(index, cache[&parent]);
assert_eq!(peers.find(|&&peer| peer == nodes[k]), Some(&nodes[k]));
}
for k in 0..size {
let parent = Some(nodes[k]);
for peer in get_retransmit_peers(fanout, k, &nodes) {
assert_eq!(get_retransmit_parent(fanout, index[&peer], &nodes), parent);
let (index, peers) = get_retransmit_peers(fanout, |node| node == &nodes[k], &nodes);
assert_eq!(index, k);
for peer in peers {
assert_eq!(get_retransmit_parent(fanout, cache[peer], &nodes), parent);
}
}
}
Expand Down
12 changes: 6 additions & 6 deletions turbine/src/retransmit_stage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -328,12 +328,12 @@ fn retransmit_shred(
) -> Result<(/*root_distance:*/ usize, /*num_nodes:*/ usize), Error> {
let mut compute_turbine_peers = Measure::start("turbine_start");
let data_plane_fanout = cluster_nodes::get_data_plane_fanout(key.slot(), root_bank);
let (root_distance, addrs) =
cluster_nodes.get_retransmit_addrs(slot_leader, key, data_plane_fanout)?;
let addrs: Vec<_> = addrs
.into_iter()
.filter(|addr| socket_addr_space.check(addr))
.collect();
let (root_distance, addrs) = cluster_nodes.get_retransmit_addrs(
slot_leader,
key,
data_plane_fanout,
socket_addr_space,
)?;
compute_turbine_peers.stop();
stats
.compute_turbine_peers_total
Expand Down

0 comments on commit 11b2e32

Please sign in to comment.