Skip to content

Commit

Permalink
manager: Avoid overflow on stream implementation for `TransportContex…
Browse files Browse the repository at this point in the history
…t` (#283)

This PR ensures that the stream implementation of `TransportContext`
does not overflow.
Instead, this PR ensures a round-robin polling strategy with the index
capped at the number of elements registered to the `TransportContext`.

While at it, added a test to ensure polling functionality works with
round-robin expectations.

Discovered during: #282

cc @paritytech/networking

---------

Signed-off-by: Alexandru Vasile <alexandru.vasile@parity.io>
  • Loading branch information
lexnv authored Nov 12, 2024
1 parent 65eb2c2 commit b1866ae
Showing 1 changed file with 164 additions and 16 deletions.
180 changes: 164 additions & 16 deletions src/transport/manager/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,27 +170,29 @@ impl Stream for TransportContext {
type Item = (SupportedTransport, TransportEvent);

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
let len = match self.transports.len() {
0 => return Poll::Ready(None),
len => len,
};
let start_index = self.index;

loop {
let index = self.index % len;
self.index += 1;
if self.transports.is_empty() {
// Terminate if we don't have any transports installed.
return Poll::Ready(None);
}

let (key, stream) = self.transports.get_index_mut(index).expect("transport to exist");
let len = self.transports.len();
self.index = (self.index + 1) % len;
for index in 0..len {
let current = (self.index + index) % len;
let (key, stream) = self.transports.get_index_mut(current).expect("transport to exist");
match stream.poll_next_unpin(cx) {
Poll::Pending => {}
Poll::Ready(None) => return Poll::Ready(None),
Poll::Ready(Some(event)) => return Poll::Ready(Some((*key, event))),
}

if self.index == start_index + len {
break Poll::Pending;
Poll::Ready(None) => {
return Poll::Ready(None);
}
Poll::Ready(Some(event)) => {
let event = Some((*key, event));
return Poll::Ready(event);
}
}
}

Poll::Pending
}
}

Expand Down Expand Up @@ -1349,6 +1351,152 @@ mod tests {
(dial_address, connection_id)
}

struct MockTransport {
rx: tokio::sync::mpsc::Receiver<TransportEvent>,
}

impl MockTransport {
fn new(rx: tokio::sync::mpsc::Receiver<TransportEvent>) -> Self {
Self { rx }
}
}

impl Transport for MockTransport {
fn dial(&mut self, _connection_id: ConnectionId, _address: Multiaddr) -> crate::Result<()> {
Ok(())
}

fn accept(&mut self, _connection_id: ConnectionId) -> crate::Result<()> {
Ok(())
}

fn accept_pending(&mut self, _connection_id: ConnectionId) -> crate::Result<()> {
Ok(())
}

fn reject_pending(&mut self, _connection_id: ConnectionId) -> crate::Result<()> {
Ok(())
}

fn reject(&mut self, _connection_id: ConnectionId) -> crate::Result<()> {
Ok(())
}

fn open(
&mut self,
_connection_id: ConnectionId,
_addresses: Vec<Multiaddr>,
) -> crate::Result<()> {
Ok(())
}

fn negotiate(&mut self, _connection_id: ConnectionId) -> crate::Result<()> {
Ok(())
}

fn cancel(&mut self, _connection_id: ConnectionId) {}
}
impl Stream for MockTransport {
type Item = TransportEvent;
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
self.rx.poll_recv(cx)
}
}

#[tokio::test]
#[cfg(feature = "websocket")]
async fn transport_events() {
let mut transports = TransportContext::new();

let (tx_tcp, rx) = tokio::sync::mpsc::channel(8);
let transport = MockTransport::new(rx);
transports.register_transport(SupportedTransport::Tcp, Box::new(transport));

let (tx_ws, rx) = tokio::sync::mpsc::channel(8);
let transport = MockTransport::new(rx);
transports.register_transport(SupportedTransport::WebSocket, Box::new(transport));

assert_eq!(transports.index, 0);
assert_eq!(transports.transports.len(), 2);
// No items.
futures::future::poll_fn(|cx| match transports.poll_next_unpin(cx) {
std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"),
std::task::Poll::Pending => std::task::Poll::Ready(()),
})
.await;
assert_eq!(transports.index, 1);

// Websocket events.
tx_ws
.send(TransportEvent::PendingInboundConnection {
connection_id: ConnectionId::from(1),
})
.await
.expect("chanel to be open");

let event = futures::future::poll_fn(|cx| transports.poll_next_unpin(cx))
.await
.expect("expected event");
assert_eq!(event.0, SupportedTransport::WebSocket);
assert!(std::matches!(
event.1,
TransportEvent::PendingInboundConnection { .. }
));
assert_eq!(transports.index, 0);

// TCP events.
tx_tcp
.send(TransportEvent::PendingInboundConnection {
connection_id: ConnectionId::from(2),
})
.await
.expect("chanel to be open");

let event = futures::future::poll_fn(|cx| transports.poll_next_unpin(cx))
.await
.expect("expected event");
assert_eq!(event.0, SupportedTransport::Tcp);
assert!(std::matches!(
event.1,
TransportEvent::PendingInboundConnection { .. }
));
assert_eq!(transports.index, 1);

// Both transports produce events.
tx_ws
.send(TransportEvent::PendingInboundConnection {
connection_id: ConnectionId::from(3),
})
.await
.expect("chanel to be open");
tx_tcp
.send(TransportEvent::PendingInboundConnection {
connection_id: ConnectionId::from(4),
})
.await
.expect("chanel to be open");

let event = futures::future::poll_fn(|cx| transports.poll_next_unpin(cx))
.await
.expect("expected event");
assert_eq!(event.0, SupportedTransport::Tcp);
assert!(std::matches!(
event.1,
TransportEvent::PendingInboundConnection { .. }
));
assert_eq!(transports.index, 0);

let event = futures::future::poll_fn(|cx| transports.poll_next_unpin(cx))
.await
.expect("expected event");
assert_eq!(event.0, SupportedTransport::WebSocket);
assert!(std::matches!(
event.1,
TransportEvent::PendingInboundConnection { .. }
));
assert_eq!(transports.index, 1);
}

#[test]
#[should_panic]
#[cfg(debug_assertions)]
Expand Down

0 comments on commit b1866ae

Please sign in to comment.