From b1866aef25f38bc022b4996f386e9b09852d6a29 Mon Sep 17 00:00:00 2001 From: Alexandru Vasile <60601340+lexnv@users.noreply.github.com> Date: Tue, 12 Nov 2024 13:42:50 +0200 Subject: [PATCH] manager: Avoid overflow on stream implementation for `TransportContext` (#283) This PR ensures that the stream implementation of `TransportContext` does not overflow. Instead, this PR ensures a round-robin polling strategy with the index capped at the number of elements registered to the `TransportContext`. While at it, added a test to ensure polling functionality works with round-robin expectations. Discovered during: https://github.com/paritytech/litep2p/issues/282 cc @paritytech/networking --------- Signed-off-by: Alexandru Vasile --- src/transport/manager/mod.rs | 180 +++++++++++++++++++++++++++++++---- 1 file changed, 164 insertions(+), 16 deletions(-) diff --git a/src/transport/manager/mod.rs b/src/transport/manager/mod.rs index d7fce91e..12c963d7 100644 --- a/src/transport/manager/mod.rs +++ b/src/transport/manager/mod.rs @@ -170,27 +170,29 @@ impl Stream for TransportContext { type Item = (SupportedTransport, TransportEvent); fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - let len = match self.transports.len() { - 0 => return Poll::Ready(None), - len => len, - }; - let start_index = self.index; - - loop { - let index = self.index % len; - self.index += 1; + if self.transports.is_empty() { + // Terminate if we don't have any transports installed. + return Poll::Ready(None); + } - let (key, stream) = self.transports.get_index_mut(index).expect("transport to exist"); + let len = self.transports.len(); + self.index = (self.index + 1) % len; + for index in 0..len { + let current = (self.index + index) % len; + let (key, stream) = self.transports.get_index_mut(current).expect("transport to exist"); match stream.poll_next_unpin(cx) { Poll::Pending => {} - Poll::Ready(None) => return Poll::Ready(None), - Poll::Ready(Some(event)) => return Poll::Ready(Some((*key, event))), - } - - if self.index == start_index + len { - break Poll::Pending; + Poll::Ready(None) => { + return Poll::Ready(None); + } + Poll::Ready(Some(event)) => { + let event = Some((*key, event)); + return Poll::Ready(event); + } } } + + Poll::Pending } } @@ -1349,6 +1351,152 @@ mod tests { (dial_address, connection_id) } + struct MockTransport { + rx: tokio::sync::mpsc::Receiver, + } + + impl MockTransport { + fn new(rx: tokio::sync::mpsc::Receiver) -> Self { + Self { rx } + } + } + + impl Transport for MockTransport { + fn dial(&mut self, _connection_id: ConnectionId, _address: Multiaddr) -> crate::Result<()> { + Ok(()) + } + + fn accept(&mut self, _connection_id: ConnectionId) -> crate::Result<()> { + Ok(()) + } + + fn accept_pending(&mut self, _connection_id: ConnectionId) -> crate::Result<()> { + Ok(()) + } + + fn reject_pending(&mut self, _connection_id: ConnectionId) -> crate::Result<()> { + Ok(()) + } + + fn reject(&mut self, _connection_id: ConnectionId) -> crate::Result<()> { + Ok(()) + } + + fn open( + &mut self, + _connection_id: ConnectionId, + _addresses: Vec, + ) -> crate::Result<()> { + Ok(()) + } + + fn negotiate(&mut self, _connection_id: ConnectionId) -> crate::Result<()> { + Ok(()) + } + + fn cancel(&mut self, _connection_id: ConnectionId) {} + } + impl Stream for MockTransport { + type Item = TransportEvent; + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.rx.poll_recv(cx) + } + } + + #[tokio::test] + #[cfg(feature = "websocket")] + async fn transport_events() { + let mut transports = TransportContext::new(); + + let (tx_tcp, rx) = tokio::sync::mpsc::channel(8); + let transport = MockTransport::new(rx); + transports.register_transport(SupportedTransport::Tcp, Box::new(transport)); + + let (tx_ws, rx) = tokio::sync::mpsc::channel(8); + let transport = MockTransport::new(rx); + transports.register_transport(SupportedTransport::WebSocket, Box::new(transport)); + + assert_eq!(transports.index, 0); + assert_eq!(transports.transports.len(), 2); + // No items. + futures::future::poll_fn(|cx| match transports.poll_next_unpin(cx) { + std::task::Poll::Ready(_) => panic!("didn't expect event from `TransportService`"), + std::task::Poll::Pending => std::task::Poll::Ready(()), + }) + .await; + assert_eq!(transports.index, 1); + + // Websocket events. + tx_ws + .send(TransportEvent::PendingInboundConnection { + connection_id: ConnectionId::from(1), + }) + .await + .expect("chanel to be open"); + + let event = futures::future::poll_fn(|cx| transports.poll_next_unpin(cx)) + .await + .expect("expected event"); + assert_eq!(event.0, SupportedTransport::WebSocket); + assert!(std::matches!( + event.1, + TransportEvent::PendingInboundConnection { .. } + )); + assert_eq!(transports.index, 0); + + // TCP events. + tx_tcp + .send(TransportEvent::PendingInboundConnection { + connection_id: ConnectionId::from(2), + }) + .await + .expect("chanel to be open"); + + let event = futures::future::poll_fn(|cx| transports.poll_next_unpin(cx)) + .await + .expect("expected event"); + assert_eq!(event.0, SupportedTransport::Tcp); + assert!(std::matches!( + event.1, + TransportEvent::PendingInboundConnection { .. } + )); + assert_eq!(transports.index, 1); + + // Both transports produce events. + tx_ws + .send(TransportEvent::PendingInboundConnection { + connection_id: ConnectionId::from(3), + }) + .await + .expect("chanel to be open"); + tx_tcp + .send(TransportEvent::PendingInboundConnection { + connection_id: ConnectionId::from(4), + }) + .await + .expect("chanel to be open"); + + let event = futures::future::poll_fn(|cx| transports.poll_next_unpin(cx)) + .await + .expect("expected event"); + assert_eq!(event.0, SupportedTransport::Tcp); + assert!(std::matches!( + event.1, + TransportEvent::PendingInboundConnection { .. } + )); + assert_eq!(transports.index, 0); + + let event = futures::future::poll_fn(|cx| transports.poll_next_unpin(cx)) + .await + .expect("expected event"); + assert_eq!(event.0, SupportedTransport::WebSocket); + assert!(std::matches!( + event.1, + TransportEvent::PendingInboundConnection { .. } + )); + assert_eq!(transports.index, 1); + } + #[test] #[should_panic] #[cfg(debug_assertions)]