Skip to content

Commit

Permalink
Tokio 0.3 (hyperium#29)
Browse files Browse the repository at this point in the history
* Remove futures-core

* Upgrade Tokio 0.3

* clean code

* Fix ci

* Fix lint
  • Loading branch information
quininer authored Oct 16, 2020
1 parent c3bf063 commit e6ef546
Show file tree
Hide file tree
Showing 14 changed files with 100 additions and 303 deletions.
2 changes: 1 addition & 1 deletion tokio-native-tls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
rust_2018_idioms,
unreachable_pub
)]
#![deny(intra_doc_link_resolution_failure)]
#![deny(broken_intra_doc_links)]
#![doc(test(
no_crate_inject,
attr(deny(warnings, rust_2018_idioms), allow(dead_code, unused_variables))
Expand Down
10 changes: 3 additions & 7 deletions tokio-rustls/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "tokio-rustls"
version = "0.14.1"
version = "0.20.0"
authors = ["quininer kel <quininer@live.com>"]
license = "MIT/Apache-2.0"
repository = "https://github.com/tokio-rs/tls"
Expand All @@ -12,20 +12,16 @@ categories = ["asynchronous", "cryptography", "network-programming"]
edition = "2018"

[dependencies]
tokio = "0.2.0"
futures-core = "0.3.1"
tokio = "0.3"
rustls = "0.18"
webpki = "0.21"

bytes = { version = "0.5", optional = true }

[features]
early-data = []
dangerous_configuration = ["rustls/dangerous_configuration"]
unstable = ["bytes"]

[dev-dependencies]
tokio = { version = "0.2.0", features = ["macros", "net", "io-util", "rt-core", "time"] }
tokio = { version = "0.3", features = ["full"] }
futures-util = "0.3.1"
lazy_static = "1"
webpki-roots = "0.20"
2 changes: 1 addition & 1 deletion tokio-rustls/examples/client/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ authors = ["quininer <quininer@live.com>"]
edition = "2018"

[dependencies]
tokio = { version = "0.2", features = [ "full" ] }
tokio = { version = "0.3", features = [ "full" ] }
argh = "0.1"
tokio-rustls = { path = "../.." }
webpki-roots = "0.20"
2 changes: 1 addition & 1 deletion tokio-rustls/examples/server/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@ authors = ["quininer <quininer@live.com>"]
edition = "2018"

[dependencies]
tokio = { version = "0.2", features = [ "full" ] }
tokio = { version = "0.3", features = [ "full" ] }
argh = "0.1"
tokio-rustls = { path = "../.." }
2 changes: 1 addition & 1 deletion tokio-rustls/examples/server/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ async fn main() -> io::Result<()> {
.map_err(|err| io::Error::new(io::ErrorKind::InvalidInput, err))?;
let acceptor = TlsAcceptor::from(Arc::new(config));

let mut listener = TcpListener::bind(&addr).await?;
let listener = TcpListener::bind(&addr).await?;

loop {
let (stream, peer_addr) = listener.accept().await?;
Expand Down
27 changes: 11 additions & 16 deletions tokio-rustls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,38 +52,36 @@ impl<IO> AsyncRead for TlsStream<IO>
where
IO: AsyncRead + AsyncWrite + Unpin,
{
#[cfg(feature = "unstable")]
unsafe fn prepare_uninitialized_buffer(&self, _buf: &mut [std::mem::MaybeUninit<u8>]) -> bool {
false
}

fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
match self.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData(..) => Poll::Pending,
TlsState::Stream | TlsState::WriteShutdown => {
let this = self.get_mut();
let mut stream =
Stream::new(&mut this.io, &mut this.session).set_eof(!this.state.readable());
let prev = buf.remaining();

match stream.as_mut_pin().poll_read(cx, buf) {
Poll::Ready(Ok(0)) => {
this.state.shutdown_read();
Poll::Ready(Ok(0))
Poll::Ready(Ok(())) => {
if prev == buf.remaining() {
this.state.shutdown_read();
}

Poll::Ready(Ok(()))
}
Poll::Ready(Ok(n)) => Poll::Ready(Ok(n)),
Poll::Ready(Err(ref e)) if e.kind() == io::ErrorKind::ConnectionAborted => {
this.state.shutdown_read();
Poll::Ready(Ok(0))
Poll::Ready(Ok(()))
}
output => output,
}
}
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(0)),
TlsState::ReadShutdown | TlsState::FullyShutdown => Poll::Ready(Ok(())),
}
}
}
Expand All @@ -107,7 +105,6 @@ where
match this.state {
#[cfg(feature = "early-data")]
TlsState::EarlyData(ref mut pos, ref mut data) => {
use futures_core::ready;
use std::io::Write;

// write early data
Expand Down Expand Up @@ -153,8 +150,6 @@ where

#[cfg(feature = "early-data")]
{
use futures_core::ready;

if let TlsState::EarlyData(ref mut pos, ref mut data) = this.state {
// complete handshake
while stream.session.is_handshaking() {
Expand Down
16 changes: 0 additions & 16 deletions tokio-rustls/src/common/handshake.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use crate::common::{Stream, TlsState};
use futures_core::future::FusedFuture;
use rustls::Session;
use std::future::Future;
use std::pin::Pin;
Expand All @@ -21,21 +20,6 @@ pub(crate) enum MidHandshake<IS> {
End,
}

impl<IS> FusedFuture for MidHandshake<IS>
where
IS: IoSession + Unpin,
IS::Io: AsyncRead + AsyncWrite + Unpin,
IS::Session: Session + Unpin,
{
fn is_terminated(&self) -> bool {
if let MidHandshake::End = self {
true
} else {
false
}
}
}

impl<IS> Future for MidHandshake<IS>
where
IS: IoSession + Unpin,
Expand Down
82 changes: 30 additions & 52 deletions tokio-rustls/src/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,11 @@
mod handshake;

#[cfg(feature = "unstable")]
mod vecbuf;

use futures_core as futures;
pub(crate) use handshake::{IoSession, MidHandshake};
use rustls::Session;
use std::io::{self, Read, Write};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

#[derive(Debug)]
pub enum TlsState {
Expand Down Expand Up @@ -40,27 +36,18 @@ impl TlsState {

#[inline]
pub fn writeable(&self) -> bool {
match *self {
TlsState::WriteShutdown | TlsState::FullyShutdown => false,
_ => true,
}
!matches!(*self, TlsState::WriteShutdown | TlsState::FullyShutdown)
}

#[inline]
pub fn readable(&self) -> bool {
match self {
TlsState::ReadShutdown | TlsState::FullyShutdown => false,
_ => true,
}
!matches!(*self, TlsState::ReadShutdown | TlsState::FullyShutdown)
}

#[inline]
#[cfg(feature = "early-data")]
pub fn is_early_data(&self) -> bool {
match self {
TlsState::EarlyData(..) => true,
_ => false,
}
matches!(self, TlsState::EarlyData(..))
}

#[inline]
Expand Down Expand Up @@ -105,8 +92,10 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
impl<'a, 'b, T: AsyncRead + Unpin> Read for Reader<'a, 'b, T> {
#[inline]
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
match Pin::new(&mut self.io).poll_read(self.cx, buf) {
Poll::Ready(result) => result,
let mut buf = ReadBuf::new(buf);
match Pin::new(&mut self.io).poll_read(self.cx, &mut buf) {
Poll::Ready(Ok(())) => Ok(buf.filled().len()),
Poll::Ready(Err(err)) => Err(err),
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}
Expand All @@ -133,9 +122,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
}

pub fn write_io(&mut self, cx: &mut Context) -> Poll<io::Result<usize>> {
#[cfg(feature = "unstable")]
use std::io::IoSlice;

struct Writer<'a, 'b, T> {
io: &'a mut T,
cx: &'a mut Context<'b>,
Expand All @@ -150,19 +136,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
}
}

#[cfg(feature = "unstable")]
#[inline]
fn write_vectored(&mut self, bufs: &[IoSlice]) -> io::Result<usize> {
use vecbuf::VecBuf;

let mut vbuf = VecBuf::new(bufs);

match Pin::new(&mut self.io).poll_write_buf(self.cx, &mut vbuf) {
Poll::Ready(result) => result,
Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
}
}

fn flush(&mut self) -> io::Result<()> {
match Pin::new(&mut self.io).poll_flush(self.cx) {
Poll::Ready(result) => result,
Expand Down Expand Up @@ -232,12 +205,12 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> Stream<'a, IO, S> {
impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a, IO, S> {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
let mut pos = 0;
cx: &mut Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let prev = buf.remaining();

while pos != buf.len() {
while buf.remaining() != 0 {
let mut would_block = false;

// read a packet
Expand All @@ -256,22 +229,28 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncRead for Stream<'a
}
}

return match self.session.read(&mut buf[pos..]) {
Ok(0) if pos == 0 && would_block => Poll::Pending,
Ok(n) if self.eof || would_block => Poll::Ready(Ok(pos + n)),
return match self.session.read(buf.initialize_unfilled()) {
Ok(0) if prev == buf.remaining() && would_block => Poll::Pending,
Ok(n) => {
pos += n;
continue;
buf.advance(n);

if self.eof || would_block {
break;
} else {
continue;
}
}
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
Err(ref err) if err.kind() == io::ErrorKind::ConnectionAborted && pos != 0 => {
Poll::Ready(Ok(pos))
Err(ref err)
if err.kind() == io::ErrorKind::ConnectionAborted
&& prev != buf.remaining() =>
{
break
}
Err(err) => Poll::Ready(Err(err)),
};
}

Poll::Ready(Ok(pos))
Poll::Ready(Ok(()))
}
}

Expand All @@ -288,7 +267,6 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'

match self.session.write(&buf[pos..]) {
Ok(n) => pos += n,
Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (),
Err(err) => return Poll::Ready(Err(err)),
};

Expand Down Expand Up @@ -316,14 +294,14 @@ impl<'a, IO: AsyncRead + AsyncWrite + Unpin, S: Session> AsyncWrite for Stream<'
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
self.session.flush()?;
while self.session.wants_write() {
futures::ready!(self.write_io(cx))?;
ready!(self.write_io(cx))?;
}
Pin::new(&mut self.io).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
while self.session.wants_write() {
futures::ready!(self.write_io(cx))?;
ready!(self.write_io(cx))?;
}
Pin::new(&mut self.io).poll_shutdown(cx)
}
Expand Down
27 changes: 17 additions & 10 deletions tokio-rustls/src/common/test_stream.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use super::Stream;
use futures_core::ready;
use futures_util::future::poll_fn;
use futures_util::task::noop_waker_ref;
use rustls::internal::pemfile::{certs, rsa_private_keys};
Expand All @@ -8,7 +7,7 @@ use std::io::{self, BufReader, Cursor, Read, Write};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadBuf};
use webpki::DNSNameRef;

struct Good<'a>(&'a mut dyn Session);
Expand All @@ -17,9 +16,17 @@ impl<'a> AsyncRead for Good<'a> {
fn poll_read(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
mut buf: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(self.0.write_tls(buf.by_ref()))
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
let mut buf2 = buf.initialize_unfilled();

Poll::Ready(match self.0.write_tls(buf2.by_ref()) {
Ok(n) => {
buf.advance(n);
Ok(())
}
Err(err) => Err(err),
})
}
}

Expand Down Expand Up @@ -55,8 +62,8 @@ impl AsyncRead for Pending {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_: &mut [u8],
) -> Poll<io::Result<usize>> {
_: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Poll::Pending
}
}
Expand Down Expand Up @@ -85,9 +92,9 @@ impl AsyncRead for Eof {
fn poll_read(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_: &mut [u8],
) -> Poll<io::Result<usize>> {
Poll::Ready(Ok(0))
_: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
Poll::Ready(Ok(()))
}
}

Expand Down
Loading

0 comments on commit e6ef546

Please sign in to comment.