diff --git a/consensus/src/dag/dag_fetcher.rs b/consensus/src/dag/dag_fetcher.rs index db3da1f9630cfe..8216b6ae14c0cf 100644 --- a/consensus/src/dag/dag_fetcher.rs +++ b/consensus/src/dag/dag_fetcher.rs @@ -13,33 +13,60 @@ use aptos_infallible::RwLock; use aptos_logger::error; use aptos_time_service::TimeService; use aptos_types::epoch_state::EpochState; -use futures::{stream::FuturesUnordered, StreamExt}; -use tokio::sync::{oneshot, mpsc::{Sender, Receiver}}; -use std::{collections::HashMap, sync::Arc, time::Duration}; +use futures::{stream::FuturesUnordered, Stream, StreamExt}; +use std::{ + collections::HashMap, + pin::Pin, + sync::Arc, + task::{Context, Poll}, + time::Duration, +}; use thiserror::Error as ThisError; +use tokio::sync::{ + mpsc::{Receiver, Sender}, + oneshot, +}; -pub struct FetchRequester { - request_tx: Sender, - node_rx_futures: FuturesUnordered>, - certified_node_rx_futures: FuturesUnordered>, +pub struct FetchWaiter { + rx: Receiver>, + futures: Pin>>>, } -impl FetchRequester { - pub fn new(request_tx: Sender) -> Self { +impl FetchWaiter { + fn new(rx: Receiver>) -> Self { Self { - request_tx, - node_rx_futures: FuturesUnordered::new(), - certified_node_rx_futures: FuturesUnordered::new(), + rx, + futures: Box::pin(FuturesUnordered::new()), } } +} +impl Stream for FetchWaiter { + type Item = Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + if let Poll::Ready(Some(rx)) = self.rx.poll_recv(cx) { + self.futures.push(rx); + } + + self.futures.as_mut().poll_next(cx) + } +} + +pub struct FetchRequester { + request_tx: Sender, + node_tx: Sender>, + certified_node_tx: Sender>, +} + +impl FetchRequester { pub fn request_for_node(&self, node: Node) -> anyhow::Result<()> { let (res_tx, res_rx) = oneshot::channel(); let fetch_req = LocalFetchRequest::Node(node, res_tx); self.request_tx .try_send(fetch_req) .map_err(|e| anyhow::anyhow!("unable to send fetch request to channel: {}", e))?; - self.node_rx_futures.push(res_rx); + self.node_tx.try_send(res_rx)?; Ok(()) } @@ -49,19 +76,9 @@ impl FetchRequester { self.request_tx .try_send(fetch_req) .map_err(|e| anyhow::anyhow!("unable to send fetch request to channel: {}", e))?; - self.certified_node_rx_futures.push(res_rx); + self.certified_node_tx.try_send(res_rx)?; Ok(()) } - - pub async fn next_ready_node(&mut self) -> Option> { - self.node_rx_futures.next().await - } - - pub async fn next_ready_certified_node( - &mut self, - ) -> Option> { - self.certified_node_rx_futures.next().await - } } #[derive(Debug)] @@ -113,8 +130,15 @@ impl DagFetcher { network: Arc, dag: Arc>, time_service: TimeService, - ) -> (Self, FetchRequester) { + ) -> ( + Self, + FetchRequester, + FetchWaiter, + FetchWaiter, + ) { let (request_tx, request_rx) = tokio::sync::mpsc::channel(16); + let (node_tx, node_rx) = tokio::sync::mpsc::channel(100); + let (certified_node_tx, certified_node_rx) = tokio::sync::mpsc::channel(100); ( Self { epoch_state, @@ -125,9 +149,11 @@ impl DagFetcher { }, FetchRequester { request_tx, - node_rx_futures: FuturesUnordered::new(), - certified_node_rx_futures: FuturesUnordered::new(), + node_tx, + certified_node_tx, }, + FetchWaiter::new(node_rx), + FetchWaiter::new(certified_node_rx), ) } diff --git a/consensus/src/dag/dag_handler.rs b/consensus/src/dag/dag_handler.rs index ff047eef60abf4..55c8b6cfa4065e 100644 --- a/consensus/src/dag/dag_handler.rs +++ b/consensus/src/dag/dag_handler.rs @@ -2,11 +2,12 @@ use super::{ dag_driver::DagDriver, - dag_fetcher::{DagFetcher, FetchRequestHandler}, + dag_fetcher::{DagFetcher, FetchRequestHandler, FetchWaiter}, dag_network::DAGNetworkSender, order_rule::OrderRule, storage::DAGStorage, types::TDAGMessage, + CertifiedNode, Node, }; use crate::{ dag::{ @@ -28,6 +29,7 @@ use aptos_types::{epoch_state::EpochState, validator_signer::ValidatorSigner}; use bytes::Bytes; use futures::StreamExt; use std::sync::Arc; +use tokio::select; use tokio_retry::strategy::ExponentialBackoff; struct NetworkHandler { @@ -36,6 +38,8 @@ struct NetworkHandler { dag_driver: DagDriver, fetch_receiver: FetchRequestHandler, epoch_state: Arc, + node_fetch_waiter: FetchWaiter, + certified_node_fetch_waiter: FetchWaiter, } impl NetworkHandler { @@ -57,13 +61,14 @@ impl NetworkHandler { ExponentialBackoff::from_millis(10), time_service.clone(), )); - // TODO: wire dag fetcher - let (_dag_fetcher, fetch_requester) = DagFetcher::new( - epoch_state.clone(), - dag_network_sender, - dag.clone(), - time_service.clone(), - ); + let (_dag_fetcher, fetch_requester, node_fetch_waiter, certified_node_fetch_waiter) = + DagFetcher::new( + epoch_state.clone(), + dag_network_sender, + dag.clone(), + time_service.clone(), + ); + let fetch_requester = Arc::new(fetch_requester); Self { dag_rpc_rx, node_receiver: NodeBroadcastHandler::new( @@ -82,10 +87,12 @@ impl NetworkHandler { time_service, storage, order_rule, - Arc::new(fetch_requester), + fetch_requester, ), epoch_state: epoch_state.clone(), fetch_receiver: FetchRequestHandler::new(dag, epoch_state), + node_fetch_waiter, + certified_node_fetch_waiter, } } @@ -93,9 +100,23 @@ impl NetworkHandler { self.dag_driver.try_enter_new_round(); // TODO(ibalajiarun): clean up Reliable Broadcast storage periodically. - while let Some(msg) = self.dag_rpc_rx.next().await { - if let Err(e) = self.process_rpc(msg).await { - warn!(error = ?e, "error processing rpc"); + loop { + select! { + Some(msg) = self.dag_rpc_rx.next() => { + if let Err(e) = self.process_rpc(msg).await { + warn!(error = ?e, "error processing rpc"); + } + }, + Some(res) = self.node_fetch_waiter.next() => { + if let Err(e) = res.map_err(|e| anyhow::anyhow!("recv error: {}", e)).and_then(|node| self.node_receiver.process(node)) { + warn!(error = ?e, "error processing node fetch notification"); + } + }, + Some(res) = self.certified_node_fetch_waiter.next() => { + if let Err(e) = res.map_err(|e| anyhow::anyhow!("recv error: {}", e)).and_then(|certified_node| self.dag_driver.process(certified_node)) { + warn!(error = ?e, "error processing certified node fetch notification"); + } + } } } } diff --git a/consensus/src/dag/tests/dag_driver_tests.rs b/consensus/src/dag/tests/dag_driver_tests.rs index 4c7e11259cbbfd..018710402f9434 100644 --- a/consensus/src/dag/tests/dag_driver_tests.rs +++ b/consensus/src/dag/tests/dag_driver_tests.rs @@ -4,7 +4,7 @@ use crate::{ dag::{ anchor_election::RoundRobinAnchorElection, dag_driver::{DagDriver, DagDriverError}, - dag_fetcher::FetchRequester, + dag_fetcher::DagFetcher, dag_network::{DAGNetworkSender, RpcWithFallback}, dag_store::Dag, order_rule::OrderRule, @@ -76,9 +76,10 @@ fn test_certified_node_handler() { let zeroth_round_node = new_certified_node(0, signers[0].author(), vec![]); + let network_sender = Arc::new(MockNetworkSender {}); let rb = Arc::new(ReliableBroadcast::new( signers.iter().map(|s| s.author()).collect(), - Arc::new(MockNetworkSender {}), + network_sender.clone(), ExponentialBackoff::from_millis(10), aptos_time_service::TimeService::mock(), )); @@ -93,8 +94,13 @@ fn test_certified_node_handler() { ordered_nodes_sender, ); - let (request_tx, _) = tokio::sync::mpsc::channel(10); - let fetch_requester = Arc::new(FetchRequester::new(request_tx)); + let (_, fetch_requester, _, _) = DagFetcher::new( + epoch_state.clone(), + network_sender, + dag.clone(), + aptos_time_service::TimeService::mock(), + ); + let fetch_requester = Arc::new(fetch_requester); let mut driver = DagDriver::new( signers[0].author(),