Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(transport): Allow custom IO and UDS example #184

Merged
merged 5 commits into from
Dec 13, 2019
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Add custom incoming support for server
  • Loading branch information
LucioFranco committed Dec 13, 2019
commit 4741b6cc3359b3767e48ce0c2c9bdb0c0101bec5
74 changes: 74 additions & 0 deletions tonic/src/transport/server/incoming.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
use super::Server;
use crate::transport::service::BoxedIo;
use futures_core::Stream;
use futures_util::stream::TryStreamExt;
use hyper::server::{
accept::Accept,
conn::{AddrIncoming, AddrStream},
};
use std::{
net::SocketAddr,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use tokio::io::{AsyncRead, AsyncWrite};
#[cfg(feature = "tls")]
use tracing::error;

pub(crate) fn tcp_incoming<IO, IE>(
incoming: impl Stream<Item = Result<IO, IE>>,
server: Server,
) -> impl Stream<Item = Result<BoxedIo, crate::Error>>
where
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
IE: Into<crate::Error>,
{
async_stream::try_stream! {
futures_util::pin_mut!(incoming);

while let Some(stream) = incoming.try_next().await? {
#[cfg(feature = "tls")]
{
if let Some(tls) = &server.tls {
let io = match tls.accept(stream).await {
Ok(io) => io,
Err(error) => {
error!(message = "Unable to accept incoming connection.", %error);
continue
},
};
yield BoxedIo::new(io);
continue;
}
}

yield BoxedIo::new(stream);
}
}
}

pub(crate) struct TcpIncoming {
inner: AddrIncoming,
}

impl TcpIncoming {
pub(crate) fn new(
addr: SocketAddr,
nodelay: bool,
keepalive: Option<Duration>,
) -> Result<Self, crate::Error> {
let mut inner = AddrIncoming::bind(&addr)?;
inner.set_nodelay(nodelay);
inner.set_keepalive(keepalive);
Ok(TcpIncoming { inner })
}
}

impl Stream for TcpIncoming {
type Item = Result<AddrStream, std::io::Error>;

fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
Pin::new(&mut self.inner).poll_accept(cx)
}
}
93 changes: 37 additions & 56 deletions tonic/src/transport/server/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Server implementation and builder.

mod incoming;
#[cfg(feature = "tls")]
mod tls;

@@ -9,36 +10,34 @@ pub use tls::ServerTlsConfig;
#[cfg(feature = "tls")]
use super::service::TlsAcceptor;

use super::service::{layer_fn, BoxedIo, Or, Routes, ServiceBuilderExt};
use incoming::TcpIncoming;

use super::service::{layer_fn, Or, Routes, ServiceBuilderExt};
use crate::body::BoxBody;
use futures_core::Stream;
use futures_util::{
future::{self, poll_fn, MapErr},
future::{self, MapErr},
TryFutureExt,
};
use http::{HeaderMap, Request, Response};
use hyper::{
server::{accept::Accept, conn},
Body,
};
use std::time::Duration;
use hyper::{server::accept, Body};
use std::{
fmt,
future::Future,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
// time::Duration,
time::Duration,
};
use tokio::io::{AsyncRead, AsyncWrite};
use tower::{
layer::{Layer, Stack},
limit::concurrency::ConcurrencyLimitLayer,
// timeout::TimeoutLayer,
Service,
ServiceBuilder,
};
#[cfg(feature = "tls")]
use tracing::error;
use tracing_futures::{Instrument, Instrumented};

type BoxService = tower::util::BoxService<Request<Body>, Response<BoxBody>, crate::Error>;
@@ -242,16 +241,19 @@ impl Server {
Router::new(self.clone(), svc)
}

pub(crate) async fn serve_with_shutdown<S, F>(
pub(crate) async fn serve_with_shutdown<S, I, F, IO, IE>(
self,
addr: SocketAddr,
svc: S,
incoming: I,
signal: Option<F>,
) -> Result<(), super::Error>
where
S: Service<Request<Body>, Response = Response<BoxBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
S::Error: Into<crate::Error> + Send,
I: Stream<Item = Result<IO, IE>>,
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
IE: Into<crate::Error>,
F: Future<Output = ()>,
{
let interceptor = self.interceptor.clone();
@@ -262,35 +264,8 @@ impl Server {
let max_concurrent_streams = self.max_concurrent_streams;
// let timeout = self.timeout.clone();

let incoming = hyper::server::accept::from_stream::<_, _, crate::Error>(
async_stream::try_stream! {
let mut incoming = conn::AddrIncoming::bind(&addr)?;

incoming.set_nodelay(self.tcp_nodelay);
incoming.set_keepalive(self.tcp_keepalive);



while let Some(stream) = next_accept(&mut incoming).await? {
#[cfg(feature = "tls")]
{
if let Some(tls) = &self.tls {
let io = match tls.connect(stream.into_inner()).await {
Ok(io) => io,
Err(error) => {
error!(message = "Unable to accept incoming connection.", %error);
continue
},
};
yield BoxedIo::new(io);
continue;
}
}

yield BoxedIo::new(stream);
}
},
);
let tcp = incoming::tcp_incoming(incoming, self);
let incoming = accept::from_stream::<_, _, crate::Error>(tcp);

let svc = MakeSvc {
inner: svc,
@@ -384,8 +359,10 @@ where
///
/// [`Server`]: struct.Server.html
pub async fn serve(self, addr: SocketAddr) -> Result<(), super::Error> {
let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive)
.map_err(map_err)?;
self.server
.serve_with_shutdown::<_, future::Ready<()>>(addr, self.routes, None)
.serve_with_shutdown::<_, _, future::Ready<()>, _, _>(self.routes, incoming, None)
.await
}

@@ -399,8 +376,25 @@ where
addr: SocketAddr,
f: F,
) -> Result<(), super::Error> {
let incoming = TcpIncoming::new(addr, self.server.tcp_nodelay, self.server.tcp_keepalive)
.map_err(map_err)?;
self.server
.serve_with_shutdown(addr, self.routes, Some(f))
.serve_with_shutdown(self.routes, incoming, Some(f))
.await
}

/// Consume this [`Server`] creating a future that will execute the server on
/// the provided incoming stream of `AsyncRead + AsyncWrite`.
///
/// [`Server`]: struct.Server.html
pub async fn serve_with_incoming<I, IO, IE>(self, incoming: I) -> Result<(), super::Error>
where
I: Stream<Item = Result<IO, IE>>,
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
IE: Into<crate::Error>,
{
self.server
.serve_with_shutdown::<_, _, future::Ready<()>, _, _>(self.routes, incoming, None)
.await
}
}
@@ -523,16 +517,3 @@ impl Service<Request<Body>> for Unimplemented {
)
}
}

// Implement try_next for `Accept::poll_accept`.
async fn next_accept(
incoming: &mut conn::AddrIncoming,
) -> Result<Option<conn::AddrStream>, crate::Error> {
let res = poll_fn(|cx| Pin::new(&mut *incoming).poll_accept(cx)).await;

if let Some(res) = res {
Ok(Some(res?))
} else {
return Ok(None);
}
}
4 changes: 2 additions & 2 deletions tonic/src/transport/service/io.rs
Original file line number Diff line number Diff line change
@@ -5,11 +5,11 @@ use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};

pub(in crate::transport) trait Io:
AsyncRead + AsyncWrite + Send + Unpin + 'static
AsyncRead + AsyncWrite + Send + 'static
{
}

impl<T> Io for T where T: AsyncRead + AsyncWrite + Send + Unpin + 'static {}
impl<T> Io for T where T: AsyncRead + AsyncWrite + Send + 'static {}

pub(crate) struct BoxedIo(Pin<Box<dyn Io>>);

6 changes: 4 additions & 2 deletions tonic/src/transport/service/tls.rs
Original file line number Diff line number Diff line change
@@ -4,7 +4,6 @@ use crate::transport::{Certificate, Identity};
use rustls_native_certs;
use std::{fmt, sync::Arc};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::net::TcpStream;
#[cfg(feature = "tls")]
use tokio_rustls::{
rustls::{ClientConfig, NoClientAuth, ServerConfig, Session},
@@ -158,7 +157,10 @@ impl TlsAcceptor {
})
}

pub(crate) async fn connect(&self, io: TcpStream) -> Result<BoxedIo, crate::Error> {
pub(crate) async fn accept<IO>(&self, io: IO) -> Result<BoxedIo, crate::Error>
where
IO: AsyncRead + AsyncWrite + Unpin + Send + 'static,
{
let io = {
let acceptor = RustlsAcceptor::from(self.inner.clone());
let tls = acceptor.accept(io).await?;