Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(network): implement background keep-alive loop #427

Merged
merged 3 commits into from
Apr 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 2 additions & 25 deletions examples/n2n-miniprotocols/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use pallas::{
miniprotocols::{blockfetch, chainsync, keepalive, Point, MAINNET_MAGIC},
},
};
use std::time::Duration;

use thiserror::Error;
use tokio::time::Instant;

Expand Down Expand Up @@ -117,14 +117,6 @@ async fn do_chainsync(
}
}

async fn do_keepalive(mut keepalive_client: keepalive::Client) -> Result<(), Error> {
loop {
tokio::time::sleep(Duration::from_secs(20)).await;
keepalive_client.send_keepalive().await?;
tracing::info!("keepalive sent");
}
}

#[tokio::main]
async fn main() {
tracing::subscriber::set_global_default(
Expand All @@ -145,25 +137,10 @@ async fn main() {
plexer,
chainsync,
blockfetch,
keepalive,
..
} = peer;

let chainsync_handle = tokio::spawn(do_chainsync(chainsync, blockfetch));
let keepalive_handle = tokio::spawn(do_keepalive(keepalive));

// If any of these concurrent tasks exit or fail, the others are canceled.
let (chainsync_result, keepalive_result) =
tokio::try_join!(chainsync_handle, keepalive_handle)
.expect("error joining tokio threads");

if let Err(err) = chainsync_result {
tracing::error!("chainsync error: {:?}", err);
}

if let Err(err) = keepalive_result {
tracing::error!("keepalive error: {:?}", err);
}
do_chainsync(chainsync, blockfetch).await.unwrap();

plexer.abort().await;

Expand Down
124 changes: 91 additions & 33 deletions pallas-network/src/facades.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
use std::net::SocketAddr;
use std::path::Path;

Check warning on line 2 in pallas-network/src/facades.rs

View workflow job for this annotation

GitHub Actions / Check (windows-latest, stable)

unused import: `std::path::Path`
use std::time::Duration;
use thiserror::Error;
use tracing::error;
use tracing::{debug, error, warn};

use tokio::net::{TcpListener, ToSocketAddrs};

#[cfg(unix)]
use tokio::net::{unix::SocketAddr as UnixSocketAddr, UnixListener};

use crate::miniprotocols::handshake::{n2c, n2n, Confirmation, VersionNumber};

Check warning on line 12 in pallas-network/src/facades.rs

View workflow job for this annotation

GitHub Actions / Check (windows-latest, stable)

unused imports: `Confirmation`, `VersionNumber`, `n2c`

use crate::miniprotocols::{
blockfetch, chainsync, handshake, keepalive, localstate, txsubmission, PROTOCOL_N2C_CHAIN_SYNC,
Expand All @@ -30,53 +31,104 @@
#[error("handshake protocol error")]
HandshakeProtocol(handshake::Error),

#[error("keepalive client loop error")]
KeepAliveClientLoop(keepalive::ClientError),

#[error("keepalive server loop error")]
KeepAliveServerLoop(keepalive::ServerError),

#[error("handshake version not accepted")]
IncompatibleVersion,
}

pub const DEFAULT_KEEP_ALIVE_INTERVAL_SEC: u64 = 20;

pub type KeepAliveHandle = tokio::task::JoinHandle<Result<(), Error>>;

pub enum KeepAliveLoop {
Client(keepalive::Client, Duration),
Server(keepalive::Server),
}

impl KeepAliveLoop {
pub fn client(client: keepalive::Client, interval: Duration) -> Self {
Self::Client(client, interval)
}

pub fn server(server: keepalive::Server) -> Self {
Self::Server(server)
}

pub async fn run_client(
mut client: keepalive::Client,
interval: Duration,
) -> Result<(), Error> {
let mut interval = tokio::time::interval(interval);

loop {
interval.tick().await;
warn!("sending keepalive request");

client
.keepalive_roundtrip()
.await
.map_err(Error::KeepAliveClientLoop)?;
}
}

pub async fn run_server(mut server: keepalive::Server) -> Result<(), Error> {
loop {
debug!("waiting keepalive request");

server
.keepalive_roundtrip()
.await
.map_err(Error::KeepAliveServerLoop)?;
}
}

pub fn spawn(self) -> KeepAliveHandle {
match self {
KeepAliveLoop::Client(client, interval) => {
tokio::spawn(Self::run_client(client, interval))
}
KeepAliveLoop::Server(server) => tokio::spawn(Self::run_server(server)),
}
}
}

/// Client of N2N Ouroboros
pub struct PeerClient {
pub plexer: RunningPlexer,
pub handshake: handshake::N2NClient,
pub keepalive: KeepAliveHandle,
pub chainsync: chainsync::N2NClient,
pub blockfetch: blockfetch::Client,
pub txsubmission: txsubmission::Client,
pub keepalive: keepalive::Client,
}

impl PeerClient {
pub fn new(bearer: Bearer) -> Self {
pub async fn connect(addr: impl ToSocketAddrs, magic: u64) -> Result<Self, Error> {
let bearer = Bearer::connect_tcp(addr)
.await
.map_err(Error::ConnectFailure)?;

let mut plexer = multiplexer::Plexer::new(bearer);

let hs_channel = plexer.subscribe_client(PROTOCOL_N2N_HANDSHAKE);
let channel = plexer.subscribe_client(PROTOCOL_N2N_HANDSHAKE);
let mut handshake = handshake::Client::new(channel);

let cs_channel = plexer.subscribe_client(PROTOCOL_N2N_CHAIN_SYNC);
let bf_channel = plexer.subscribe_client(PROTOCOL_N2N_BLOCK_FETCH);
let txsub_channel = plexer.subscribe_client(PROTOCOL_N2N_TX_SUBMISSION);
let keepalive_channel = plexer.subscribe_client(PROTOCOL_N2N_KEEP_ALIVE);

let plexer = plexer.spawn();

Self {
plexer,
handshake: handshake::Client::new(hs_channel),
chainsync: chainsync::Client::new(cs_channel),
blockfetch: blockfetch::Client::new(bf_channel),
txsubmission: txsubmission::Client::new(txsub_channel),
keepalive: keepalive::Client::new(keepalive_channel),
}
}

pub async fn connect(addr: impl ToSocketAddrs, magic: u64) -> Result<Self, Error> {
let bearer = Bearer::connect_tcp(addr)
.await
.map_err(Error::ConnectFailure)?;
let channel = plexer.subscribe_client(PROTOCOL_N2N_KEEP_ALIVE);
let keepalive = keepalive::Client::new(channel);

let mut client = Self::new(bearer);
let plexer = plexer.spawn();

let versions = handshake::n2n::VersionTable::v7_and_above(magic);

let handshake = client
.handshake()
let handshake = handshake
.handshake(versions)
.await
.map_err(Error::HandshakeProtocol)?;
Expand All @@ -86,11 +138,21 @@
return Err(Error::IncompatibleVersion);
}

Ok(client)
}
let keepalive = KeepAliveLoop::client(
keepalive,
Duration::from_secs(DEFAULT_KEEP_ALIVE_INTERVAL_SEC),
)
.spawn();

pub fn handshake(&mut self) -> &mut handshake::N2NClient {
&mut self.handshake
let client = Self {
plexer,
keepalive,
chainsync: chainsync::Client::new(cs_channel),
blockfetch: blockfetch::Client::new(bf_channel),
txsubmission: txsubmission::Client::new(txsub_channel),
};

Ok(client)
}

pub fn chainsync(&mut self) -> &mut chainsync::N2NClient {
Expand All @@ -114,10 +176,6 @@
&mut self.txsubmission
}

pub fn keepalive(&mut self) -> &mut keepalive::Client {
&mut self.keepalive
}

pub async fn abort(self) {
self.plexer.abort().await
}
Expand Down
46 changes: 22 additions & 24 deletions pallas-network/src/miniprotocols/keepalive/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,11 @@ pub enum ClientError {
Plexer(multiplexer::Error),
}

pub struct KeepAliveSharedState {
saved_cookie: u16,
}

pub struct Client(State, multiplexer::ChannelBuffer, KeepAliveSharedState);
pub struct Client(State, multiplexer::ChannelBuffer);

impl Client {
pub fn new(channel: multiplexer::AgentChannel) -> Self {
Self(
State::Client,
multiplexer::ChannelBuffer::new(channel),
KeepAliveSharedState { saved_cookie: 0 },
)
Self(State::Client, multiplexer::ChannelBuffer::new(channel))
}

pub fn state(&self) -> &State {
Expand All @@ -53,7 +45,7 @@ impl Client {
fn has_agency(&self) -> bool {
match &self.0 {
State::Client => true,
State::Server => false,
State::Server(..) => false,
State::Done => false,
}
}
Expand Down Expand Up @@ -84,7 +76,7 @@ impl Client {

fn assert_inbound_state(&self, msg: &Message) -> Result<(), ClientError> {
match (&self.0, msg) {
(State::Server, Message::ResponseKeepAlive(..)) => Ok(()),
(State::Server(..), Message::ResponseKeepAlive(..)) => Ok(()),
_ => Err(ClientError::InvalidInbound),
}
}
Expand All @@ -108,32 +100,38 @@ impl Client {
Ok(msg)
}

pub async fn send_keepalive(&mut self) -> Result<(), ClientError> {
pub async fn send_keepalive_request(&mut self) -> Result<(), ClientError> {
// generate random cookie value
let cookie = rand::thread_rng().gen::<KeepAliveCookie>();
let cookie = rand::thread_rng().gen::<Cookie>();
let msg = Message::KeepAlive(cookie);
self.send_message(&msg).await?;
self.2.saved_cookie = cookie;
self.0 = State::Server;
self.0 = State::Server(cookie);
debug!("sent keepalive message with cookie {}", cookie);

self.recv_while_sending_keepalive().await?;

Ok(())
}

async fn recv_while_sending_keepalive(&mut self) -> Result<(), ClientError> {
pub async fn recv_keepalive_response(&mut self) -> Result<(), ClientError> {
match self.recv_message().await? {
Message::ResponseKeepAlive(cookie) => {
debug!("received keepalive response with cookie {}", cookie);
if cookie == self.2.saved_cookie {
self.0 = State::Client;
Ok(())
} else {
Err(ClientError::KeepAliveCookieMismatch)
match self.state() {
State::Server(expected) if *expected == cookie => {
self.0 = State::Client;
Ok(())
}
State::Server(..) => Err(ClientError::KeepAliveCookieMismatch),
_ => unreachable!(),
}
}
_ => Err(ClientError::InvalidInbound),
}
}

pub async fn keepalive_roundtrip(&mut self) -> Result<(), ClientError> {
self.send_keepalive_request().await?;
self.recv_keepalive_response().await?;

Ok(())
}
}
8 changes: 4 additions & 4 deletions pallas-network/src/miniprotocols/keepalive/protocol.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
pub type KeepAliveCookie = u16;
pub type Cookie = u16;

#[derive(Debug, PartialEq, Eq, Clone)]
pub enum State {
Client,
Server,
Server(Cookie),
Done,
}

#[derive(Debug, Clone)]
pub enum Message {
KeepAlive(KeepAliveCookie),
ResponseKeepAlive(KeepAliveCookie),
KeepAlive(Cookie),
ResponseKeepAlive(Cookie),
Done,
}
Loading
Loading