Skip to content

Commit

Permalink
chore(socket): simplify handle_engineio_packet
Browse files Browse the repository at this point in the history
  • Loading branch information
SSebo committed Aug 12, 2022
1 parent 31996cf commit f3dc134
Show file tree
Hide file tree
Showing 6 changed files with 64 additions and 52 deletions.
4 changes: 4 additions & 0 deletions engineio/src/asynchronous/async_transports/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ impl WebsocketTransport {
pub(crate) async fn upgrade(&self) -> Result<()> {
self.inner.upgrade().await
}

pub(crate) async fn poll_next(&self) -> Result<Option<Bytes>> {
self.inner.poll_next().await
}
}

#[async_trait]
Expand Down
28 changes: 26 additions & 2 deletions engineio/src/asynchronous/async_transports/websocket_general.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,17 @@ use tokio::{net::TcpStream, sync::Mutex};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream};
use tungstenite::Message;

type AsyncWebsocketSender = SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>;
type AsyncWebsocketReceiver = SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>;

/// A general purpose asynchronous websocket transport type. Holds
/// the sender and receiver stream of a websocket connection
/// and implements the common methods `update` and `emit`. This also
/// implements `Stream`.
#[derive(Clone)]
pub(crate) struct AsyncWebsocketGeneralTransport {
sender: Arc<Mutex<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
receiver: Arc<Mutex<SplitStream<WebSocketStream<MaybeTlsStream<TcpStream>>>>>,
sender: Arc<Mutex<AsyncWebsocketSender>>,
receiver: Arc<Mutex<AsyncWebsocketReceiver>>,
}

impl AsyncWebsocketGeneralTransport {
Expand Down Expand Up @@ -75,6 +78,27 @@ impl AsyncWebsocketGeneralTransport {

Ok(())
}

pub(crate) async fn poll_next(&self) -> Result<Option<Bytes>> {
loop {
let mut receiver = self.receiver.lock().await;
let next = receiver.next().await;
match next {
Some(Ok(Message::Text(str))) => return Ok(Some(Bytes::from(str))),
Some(Ok(Message::Binary(data))) => {
let mut msg = BytesMut::with_capacity(data.len() + 1);
msg.put_u8(PacketId::Message as u8);
msg.put(data.as_ref());

return Ok(Some(msg.freeze()));
}
// ignore packets other than text and binary
Some(Ok(_)) => (),
Some(Err(err)) => return Err(err.into()),
None => return Ok(None),
}
}
}
}

impl Stream for AsyncWebsocketGeneralTransport {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ impl WebsocketSecureTransport {
pub(crate) async fn upgrade(&self) -> Result<()> {
self.inner.upgrade().await
}

pub(crate) async fn poll_next(&self) -> Result<Option<Bytes>> {
self.inner.poll_next().await
}
}

impl Stream for WebsocketSecureTransport {
Expand Down
36 changes: 14 additions & 22 deletions engineio/src/transports/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,15 @@ use crate::{
Error,
};
use bytes::Bytes;
use futures_util::StreamExt;
use http::HeaderMap;
use std::sync::Arc;
use tokio::{runtime::Runtime, sync::Mutex};
use tokio::runtime::Runtime;
use url::Url;

#[derive(Clone)]
pub struct WebsocketTransport {
runtime: Arc<Runtime>,
inner: Arc<Mutex<AsyncWebsocketTransport>>,
inner: Arc<AsyncWebsocketTransport>,
}

impl WebsocketTransport {
Expand All @@ -30,47 +29,40 @@ impl WebsocketTransport {

Ok(WebsocketTransport {
runtime: Arc::new(runtime),
inner: Arc::new(Mutex::new(inner)),
inner: Arc::new(inner),
})
}

/// Sends probe packet to ensure connection is valid, then sends upgrade
/// request
pub(crate) fn upgrade(&self) -> Result<()> {
self.runtime.block_on(async {
let lock = self.inner.lock().await;
lock.upgrade().await
})
self.runtime.block_on(async { self.inner.upgrade().await })
}
}

impl Transport for WebsocketTransport {
fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
self.runtime.block_on(async {
let lock = self.inner.lock().await;
lock.emit(data, is_binary_att).await
})
self.runtime
.block_on(async { self.inner.emit(data, is_binary_att).await })
}

fn poll(&self) -> Result<Bytes> {
self.runtime.block_on(async {
let mut lock = self.inner.lock().await;
lock.next().await.ok_or(Error::IncompletePacket())?
let r = self.inner.poll_next().await;
match r {
Ok(b) => b.ok_or(Error::IncompletePacket()),
Err(_) => Err(Error::IncompletePacket()),
}
})
}

fn base_url(&self) -> Result<url::Url> {
self.runtime.block_on(async {
let lock = self.inner.lock().await;
lock.base_url().await
})
self.runtime.block_on(async { self.inner.base_url().await })
}

fn set_base_url(&self, url: url::Url) -> Result<()> {
self.runtime.block_on(async {
let lock = self.inner.lock().await;
lock.set_base_url(url).await
})
self.runtime
.block_on(async { self.inner.set_base_url(url).await })
}
}

Expand Down
36 changes: 14 additions & 22 deletions engineio/src/transports/websocket_secure.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,16 @@ use crate::{
Error,
};
use bytes::Bytes;
use futures_util::StreamExt;
use http::HeaderMap;
use native_tls::TlsConnector;
use std::sync::Arc;
use tokio::{runtime::Runtime, sync::Mutex};
use tokio::runtime::Runtime;
use url::Url;

#[derive(Clone)]
pub struct WebsocketSecureTransport {
runtime: Arc<Runtime>,
inner: Arc<Mutex<AsyncWebsocketSecureTransport>>,
inner: Arc<AsyncWebsocketSecureTransport>,
}

impl WebsocketSecureTransport {
Expand All @@ -38,47 +37,40 @@ impl WebsocketSecureTransport {

Ok(WebsocketSecureTransport {
runtime: Arc::new(runtime),
inner: Arc::new(Mutex::new(inner)),
inner: Arc::new(inner),
})
}

/// Sends probe packet to ensure connection is valid, then sends upgrade
/// request
pub(crate) fn upgrade(&self) -> Result<()> {
self.runtime.block_on(async {
let lock = self.inner.lock().await;
lock.upgrade().await
})
self.runtime.block_on(async { self.inner.upgrade().await })
}
}

impl Transport for WebsocketSecureTransport {
fn emit(&self, data: Bytes, is_binary_att: bool) -> Result<()> {
self.runtime.block_on(async {
let lock = self.inner.lock().await;
lock.emit(data, is_binary_att).await
})
self.runtime
.block_on(async { self.inner.emit(data, is_binary_att).await })
}

fn poll(&self) -> Result<Bytes> {
self.runtime.block_on(async {
let mut lock = self.inner.lock().await;
lock.next().await.ok_or(Error::IncompletePacket())?
let r = self.inner.poll_next().await;
match r {
Ok(b) => b.ok_or(Error::IncompletePacket()),
Err(_) => Err(Error::IncompletePacket()),
}
})
}

fn base_url(&self) -> Result<url::Url> {
self.runtime.block_on(async {
let lock = self.inner.lock().await;
lock.base_url().await
})
self.runtime.block_on(async { self.inner.base_url().await })
}

fn set_base_url(&self, url: url::Url) -> Result<()> {
self.runtime.block_on(async {
let lock = self.inner.lock().await;
lock.set_base_url(url).await
})
self.runtime
.block_on(async { self.inner.set_base_url(url).await })
}
}

Expand Down
8 changes: 2 additions & 6 deletions socketio/src/socket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,12 +159,8 @@ impl Socket {

/// Handles new incoming engineio packets
fn handle_engineio_packet(&self, packet: EnginePacket) -> Result<Packet> {
let socket_packet = Packet::try_from(&packet.data);
if let Err(err) = socket_packet {
return Err(err);
}
// SAFETY: checked above to see if it was Err
let mut socket_packet = socket_packet.unwrap();
let mut socket_packet = Packet::try_from(&packet.data)?;

// Only handle attachments if there are any
if socket_packet.attachment_count > 0 {
let mut attachments_left = socket_packet.attachment_count;
Expand Down

0 comments on commit f3dc134

Please sign in to comment.