Skip to content

Commit

Permalink
fix(websocket): lock conflict between emit and poll
Browse files Browse the repository at this point in the history
  • Loading branch information
SSebo committed Aug 11, 2022
1 parent 2a47343 commit 7a9af9d
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 7 deletions.
1 change: 1 addition & 0 deletions engineio/src/asynchronous/async_transports/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ mod websocket_secure;

pub use self::polling::PollingTransport;
pub use self::websocket::WebsocketTransport;
pub use self::websocket_general::AsyncWebsocketSender;
pub use self::websocket_secure::WebsocketSecureTransport;
10 changes: 9 additions & 1 deletion engineio/src/asynchronous/async_transports/websocket.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::fmt::Debug;
use std::pin::Pin;
use std::sync::Arc;
use tokio::sync::Mutex;

use crate::asynchronous::transport::AsyncTransport;
use crate::error::Result;
Expand All @@ -14,14 +15,15 @@ use tokio_tungstenite::connect_async;
use tungstenite::client::IntoClientRequest;
use url::Url;

use super::websocket_general::AsyncWebsocketGeneralTransport;
use super::websocket_general::{AsyncWebsocketGeneralTransport, AsyncWebsocketSender};

/// An asynchronous websocket transport type.
/// This type only allows for plain websocket
/// connections ("ws://").
#[derive(Clone)]
pub struct WebsocketTransport {
inner: AsyncWebsocketGeneralTransport,
sender: Arc<Mutex<AsyncWebsocketSender>>,
base_url: Arc<RwLock<Url>>,
}

Expand All @@ -42,12 +44,18 @@ impl WebsocketTransport {
let (sen, rec) = ws_stream.split();

let inner = AsyncWebsocketGeneralTransport::new(sen, rec).await;
let sender = inner.sender();
Ok(WebsocketTransport {
inner,
sender,
base_url: Arc::new(RwLock::new(url)),
})
}

pub fn sender(&self) -> Arc<Mutex<AsyncWebsocketSender>> {
Arc::clone(&self.sender)
}

/// Sends probe packet to ensure connection is valid, then sends upgrade
/// request
pub(crate) async fn upgrade(&self) -> Result<()> {
Expand Down
34 changes: 30 additions & 4 deletions engineio/src/asynchronous/async_transports/websocket_general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,34 @@ use tokio::{net::TcpStream, sync::Mutex};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use tungstenite::Message;

pub struct AsyncWebsocketSender {
inner: SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>,
}

type AsyncWebsocketReceiver = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;

impl AsyncWebsocketSender {
pub async fn emit(&mut self, data: Bytes, is_binary_att: bool) -> Result<()> {
let message = if is_binary_att {
Message::binary(Cow::Borrowed(data.as_ref()))
} else {
Message::text(Cow::Borrowed(std::str::from_utf8(data.as_ref())?))
};

self.inner.send(message).await?;

Ok(())
}
}

/// A general purpose asynchronous websocket transport type. Holds
/// the sender and receiver stream of a websocket connection
/// and implements the common methods `update` and `emit`. This also
/// implements `Stream`.
#[derive(Clone)]
pub(crate) struct AsyncWebsocketGeneralTransport {
sender: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
receiver: Arc<Mutex<SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>>,
sender: Arc<Mutex<AsyncWebsocketSender>>,
receiver: Arc<Mutex<AsyncWebsocketReceiver>>,
}

impl AsyncWebsocketGeneralTransport {
Expand All @@ -27,7 +47,7 @@ impl AsyncWebsocketGeneralTransport {
receiver: SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>,
) -> Self {
AsyncWebsocketGeneralTransport {
sender: Arc::new(Mutex::new(sender)),
sender: Arc::new(Mutex::new(AsyncWebsocketSender { inner: sender })),
receiver: Arc::new(Mutex::new(receiver)),
}
}
Expand All @@ -39,6 +59,7 @@ impl AsyncWebsocketGeneralTransport {
let mut sender = self.sender.lock().await;

sender
.inner
.send(Message::text(Cow::Borrowed(from_utf8(&Bytes::from(
Packet::new(PacketId::Ping, Bytes::from("probe")),
))?)))
Expand All @@ -54,6 +75,7 @@ impl AsyncWebsocketGeneralTransport {
}

sender
.inner
.send(Message::text(Cow::Borrowed(from_utf8(&Bytes::from(
Packet::new(PacketId::Upgrade, Bytes::from("")),
))?)))
Expand All @@ -71,10 +93,14 @@ impl AsyncWebsocketGeneralTransport {
Message::text(Cow::Borrowed(std::str::from_utf8(data.as_ref())?))
};

sender.send(message).await?;
sender.inner.send(message).await?;

Ok(())
}

pub(crate) fn sender(&self) -> Arc<Mutex<AsyncWebsocketSender>> {
Arc::clone(&self.sender)
}
}

impl Stream for AsyncWebsocketGeneralTransport {
Expand Down
8 changes: 6 additions & 2 deletions engineio/src/transports/websocket.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::{
asynchronous::{
async_transports::WebsocketTransport as AsyncWebsocketTransport, transport::AsyncTransport,
async_transports::{AsyncWebsocketSender, WebsocketTransport as AsyncWebsocketTransport},
transport::AsyncTransport,
},
error::Result,
transport::Transport,
Expand All @@ -17,6 +18,7 @@ use url::Url;
pub struct WebsocketTransport {
runtime: Arc<Runtime>,
inner: Arc<Mutex<AsyncWebsocketTransport>>,
sender: Arc<Mutex<AsyncWebsocketSender>>,
}

impl WebsocketTransport {
Expand All @@ -27,10 +29,12 @@ impl WebsocketTransport {
.build()?;

let inner = runtime.block_on(AsyncWebsocketTransport::new(base_url, headers))?;
let sender = inner.sender();

Ok(WebsocketTransport {
runtime: Arc::new(runtime),
inner: Arc::new(Mutex::new(inner)),
sender,
})
}

Expand All @@ -47,7 +51,7 @@ impl WebsocketTransport {
impl Transport for WebsocketTransport {
fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
self.runtime.block_on(async {
let lock = self.inner.lock().await;
let mut lock = self.sender.lock().await;
lock.emit(data, is_binary_att).await
})
}
Expand Down

0 comments on commit 7a9af9d

Please sign in to comment.