Skip to content

Commit

Permalink
chore: make websocket to optional support
Browse files Browse the repository at this point in the history
  • Loading branch information
driftluo committed Sep 14, 2020
1 parent 4883ee4 commit 3d38cc3
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 13 deletions.
3 changes: 2 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ tokio-util = { version = "0.3.0", features = ["codec"] }
log = "0.4"
bytes = "0.5.0"
thiserror = "1.0"
tokio-tungstenite = { version = "0.11" }
tokio-tungstenite = { version = "0.11", optional = true }

flatbuffers = { version = "0.6.0", optional = true }
flatbuffers-verifier = { version = "0.2.0", optional = true }
Expand Down Expand Up @@ -58,6 +58,7 @@ default = []
flatc = [ "flatbuffers", "flatbuffers-verifier", "secio/flatc" ]
# use molecule to handshake
molc = [ "molecule", "secio/molc" ]
ws = ["tokio-tungstenite"]

[workspace]
members = [
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ fmt:
cargo fmt --all -- --check

clippy:
RUSTFLAGS='-F warnings' cargo clippy --all --tests --features molc -- -D clippy::let_underscore_must_use
RUSTFLAGS='-F warnings' cargo clippy --all --tests --features molc,ws -- -D clippy::let_underscore_must_use
RUSTFLAGS='-F warnings' cargo clippy --all --tests --features flatc -- -D clippy::let_underscore_must_use

test:
RUSTFLAGS='-F warnings' RUST_BACKTRACE=full cargo test --all --features molc
RUSTFLAGS='-F warnings' RUST_BACKTRACE=full cargo test --all --features molc,ws
RUSTFLAGS='-F warnings' RUST_BACKTRACE=full cargo test --all --features flatc

examples:
Expand Down
4 changes: 2 additions & 2 deletions multiaddr/src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ const IP6: u32 = 0x29;
const P2P: u32 = 0x01a5;
const TCP: u32 = 0x06;
const TLS: u32 = 0x01c0;
const WS: u32 = 477;
const WSS: u32 = 478;
const WS: u32 = 0x01dd;
const WSS: u32 = 0x01de;

const SHA256_CODE: u16 = 0x12;
const SHA256_SIZE: u8 = 32;
Expand Down
30 changes: 25 additions & 5 deletions src/transports/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use crate::{
utils::socketaddr_to_multiaddr,
};

use futures::{prelude::Stream, FutureExt, StreamExt};
use futures::{prelude::Stream, FutureExt};
use log::debug;
use std::{
fmt,
Expand All @@ -19,12 +19,14 @@ use tokio::{
prelude::{AsyncRead, AsyncWrite},
};

use self::{
tcp::{TcpDialFuture, TcpListenFuture, TcpTransport},
ws::{WebsocketListener, WsDialFuture, WsListenFuture, WsStream, WsTransport},
};
use self::tcp::{TcpDialFuture, TcpListenFuture, TcpTransport};
#[cfg(feature = "ws")]
use self::ws::{WebsocketListener, WsDialFuture, WsListenFuture, WsStream, WsTransport};
#[cfg(feature = "ws")]
use futures::StreamExt;

mod tcp;
#[cfg(feature = "ws")]
mod ws;

type Result<T> = std::result::Result<T, TransportErrorKind>;
Expand Down Expand Up @@ -61,10 +63,13 @@ impl Transport for MultiTransport {
Ok(future) => Ok(MultiListenFuture::Tcp(future)),
Err(e) => Err(e),
},
#[cfg(feature = "ws")]
TransportType::Ws => match WsTransport::new(self.timeout).listen(address) {
Ok(future) => Ok(MultiListenFuture::Ws(future)),
Err(e) => Err(e),
},
#[cfg(not(feature = "ws"))]
TransportType::Ws => Err(TransportErrorKind::NotSupported(address)),
TransportType::Wss => Err(TransportErrorKind::NotSupported(address)),
TransportType::TLS => Err(TransportErrorKind::NotSupported(address)),
}
Expand All @@ -76,10 +81,13 @@ impl Transport for MultiTransport {
Ok(res) => Ok(MultiDialFuture::Tcp(res)),
Err(e) => Err(e),
},
#[cfg(feature = "ws")]
TransportType::Ws => match WsTransport::new(self.timeout).dial(address) {
Ok(future) => Ok(MultiDialFuture::Ws(future)),
Err(e) => Err(e),
},
#[cfg(not(feature = "ws"))]
TransportType::Ws => Err(TransportErrorKind::NotSupported(address)),
TransportType::Wss => Err(TransportErrorKind::NotSupported(address)),
TransportType::TLS => Err(TransportErrorKind::NotSupported(address)),
}
Expand All @@ -88,6 +96,7 @@ impl Transport for MultiTransport {

pub enum MultiListenFuture {
Tcp(TcpListenFuture),
#[cfg(feature = "ws")]
Ws(WsListenFuture),
}

Expand All @@ -100,6 +109,7 @@ impl Future for MultiListenFuture {
Pin::new(&mut inner.map(|res| res.map(|res| (res.0, MultiIncoming::Tcp(res.1)))))
.poll(cx)
}
#[cfg(feature = "ws")]
MultiListenFuture::Ws(inner) => {
Pin::new(&mut inner.map(|res| res.map(|res| (res.0, MultiIncoming::Ws(res.1)))))
.poll(cx)
Expand All @@ -110,6 +120,7 @@ impl Future for MultiListenFuture {

pub enum MultiDialFuture {
Tcp(TcpDialFuture),
#[cfg(feature = "ws")]
Ws(WsDialFuture),
}

Expand All @@ -122,6 +133,7 @@ impl Future for MultiDialFuture {
Pin::new(&mut inner.map(|res| res.map(|res| (res.0, MultiStream::Tcp(res.1)))))
.poll(cx)
}
#[cfg(feature = "ws")]
MultiDialFuture::Ws(inner) => Pin::new(
&mut inner.map(|res| res.map(|res| (res.0, MultiStream::Ws(Box::new(res.1))))),
)
Expand All @@ -132,13 +144,15 @@ impl Future for MultiDialFuture {

pub enum MultiStream {
Tcp(TcpStream),
#[cfg(feature = "ws")]
Ws(Box<WsStream>),
}

impl fmt::Debug for MultiStream {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
MultiStream::Tcp(_) => write!(f, "Tcp stream"),
#[cfg(feature = "ws")]
MultiStream::Ws(_) => write!(f, "Websocket stream"),
}
}
Expand All @@ -152,6 +166,7 @@ impl AsyncRead for MultiStream {
) -> Poll<io::Result<usize>> {
match self.get_mut() {
MultiStream::Tcp(inner) => Pin::new(inner).poll_read(cx, buf),
#[cfg(feature = "ws")]
MultiStream::Ws(inner) => Pin::new(inner).poll_read(cx, buf),
}
}
Expand All @@ -161,13 +176,15 @@ impl AsyncWrite for MultiStream {
fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
match self.get_mut() {
MultiStream::Tcp(inner) => Pin::new(inner).poll_write(cx, buf),
#[cfg(feature = "ws")]
MultiStream::Ws(inner) => Pin::new(inner).poll_write(cx, buf),
}
}

fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
match self.get_mut() {
MultiStream::Tcp(inner) => Pin::new(inner).poll_flush(cx),
#[cfg(feature = "ws")]
MultiStream::Ws(inner) => Pin::new(inner).poll_flush(cx),
}
}
Expand All @@ -176,6 +193,7 @@ impl AsyncWrite for MultiStream {
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
match self.get_mut() {
MultiStream::Tcp(inner) => Pin::new(inner).poll_shutdown(cx),
#[cfg(feature = "ws")]
MultiStream::Ws(inner) => Pin::new(inner).poll_shutdown(cx),
}
}
Expand All @@ -184,6 +202,7 @@ impl AsyncWrite for MultiStream {
#[derive(Debug)]
pub enum MultiIncoming {
Tcp(TcpListener),
#[cfg(feature = "ws")]
Ws(WebsocketListener),
}

Expand All @@ -208,6 +227,7 @@ impl Stream for MultiIncoming {
},
Poll::Pending => Poll::Pending,
},
#[cfg(feature = "ws")]
MultiIncoming::Ws(inner) => match inner.poll_next_unpin(cx)? {
Poll::Ready(Some((addr, stream))) => {
Poll::Ready(Some(Ok((addr, MultiStream::Ws(Box::new(stream))))))
Expand Down
5 changes: 2 additions & 3 deletions src/transports/ws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,8 @@ impl Stream for WebsocketListener {
type Item = std::result::Result<(Multiaddr, WsStream), io::Error>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
match self.poll_pending(cx) {
Poll::Ready(res) => return Poll::Ready(res),
Poll::Pending => (),
if let Poll::Ready(res) = self.poll_pending(cx) {
return Poll::Ready(res);
}

match self.inner.poll_accept(cx)? {
Expand Down

0 comments on commit 3d38cc3

Please sign in to comment.