Skip to content

Commit

Permalink
make code compile
Browse files Browse the repository at this point in the history
a bit of a hack to for now not require send bounds
for service async call fns

this until we find a fix for
plabayo/tower-async#9
  • Loading branch information
glendc committed Jul 27, 2023
1 parent e30ccc9 commit e423ffa
Show file tree
Hide file tree
Showing 8 changed files with 105 additions and 43 deletions.
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ bytes = "1.4"
pin-project-lite = "0.2"
tokio = { version = "1", features = ["full"] }
tokio-util = "^0.7.8"
tower-async = { version = "0.1" }
tower-async = { version = "0.1", features = ["full"]}
tracing = "0.1"

[dev-dependencies]
anyhow = "1.0"
tokio-test = "0.4"
tracing-subscriber = "0.3"
tracing-subscriber = { version = "0.3", features = ["env-filter"] }
16 changes: 16 additions & 0 deletions examples/tcp_echo_server.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use rama::transport::{bytes::service::EchoService, tcp::server::TcpListener};

use tower_async::{make::Shared, BoxError};
use tracing_subscriber::{fmt, prelude::*, EnvFilter};

#[tokio::main]
async fn main() -> Result<(), BoxError> {
tracing_subscriber::registry()
.with(fmt::layer())
.with(EnvFilter::from_default_env())
.init();

let service = Shared::new(EchoService::new());
TcpListener::new()?.serve(service).await?;
Ok(())
}
6 changes: 3 additions & 3 deletions src/transport/bytes/service/echo.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::io::{Error, ErrorKind};

use tower_async::Service;

use crate::transport::{bytes::ByteStream, connection::Connection};
use crate::transport::{bytes::ByteStream, Connection};

/// An async service which echoes the incoming bytes back on the same connection.
///
Expand All @@ -15,7 +15,7 @@ use crate::transport::{bytes::ByteStream, connection::Connection};
/// # #[tokio::main]
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let stream = tokio_test::io::Builder::new().read(b"hello world").write(b"hello world").build();
/// # let conn = rama::transport::connection::Connection::new(stream, rama::transport::graceful::Token::pending(), ());
/// # let conn = rama::transport::Connection::new(stream, rama::transport::graceful::Token::pending(), ());
/// let mut service = EchoService::new()
/// .respect_shutdown(Some(std::time::Duration::from_secs(5)));
///
Expand All @@ -24,7 +24,7 @@ use crate::transport::{bytes::ByteStream, connection::Connection};
/// # Ok(())
/// # }
/// ```
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct EchoService {
respect_shutdown: bool,
shutdown_delay: Option<std::time::Duration>,
Expand Down
4 changes: 2 additions & 2 deletions src/transport/bytes/service/forward.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{

use tower_async::Service;

use crate::transport::{bytes::ByteStream, connection::Connection};
use crate::transport::{bytes::ByteStream, Connection};

/// Async service which forwards the incoming connection bytes to the given destination,
/// and forwards the response back from the destination to the incoming connection.
Expand All @@ -20,7 +20,7 @@ use crate::transport::{bytes::ByteStream, connection::Connection};
/// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// # let destination = tokio_test::io::Builder::new().write(b"hello world").read(b"hello world").build();
/// # let stream = tokio_test::io::Builder::new().read(b"hello world").write(b"hello world").build();
/// # let conn = rama::transport::connection::Connection::new(stream, rama::transport::graceful::Token::pending(), ());
/// # let conn = rama::transport::Connection::new(stream, rama::transport::graceful::Token::pending(), ());
/// let mut service = ForwardService::new(destination)
/// .respect_shutdown(Some(std::time::Duration::from_secs(5)));
///
Expand Down
8 changes: 4 additions & 4 deletions src/transport/connection/service_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ use super::Connection;
///
/// ```
/// use rama::transport::connection::service_fn;
/// # use rama::transport::connection::Connection;
/// # use rama::transport::Connection;
/// # use rama::transport::graceful::Token;
/// use tower_async::Service;
/// use std::convert::Infallible;
Expand Down Expand Up @@ -54,7 +54,7 @@ use super::Connection;
///
/// ```
/// use rama::transport::connection::service_fn;
/// # use rama::transport::connection::Connection;
/// # use rama::transport::Connection;
/// # use rama::transport::graceful::Token;
/// use tower_async::Service;
/// use std::convert::Infallible;
Expand Down Expand Up @@ -99,7 +99,7 @@ use super::Connection;
///
/// ```
/// use rama::transport::connection::service_fn;
/// # use rama::transport::connection::Connection;
/// # use rama::transport::Connection;
/// # use rama::transport::graceful::Token;
/// use tower_async::Service;
/// use std::convert::Infallible;
Expand Down Expand Up @@ -143,7 +143,7 @@ use super::Connection;
///
/// ```
/// use rama::transport::connection::service_fn;
/// # use rama::transport::connection::Connection;
/// # use rama::transport::Connection;
/// # use rama::transport::graceful::Token;
/// use tower_async::Service;
/// use std::convert::Infallible;
Expand Down
3 changes: 3 additions & 0 deletions src/transport/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,6 @@ pub mod bytes;
pub mod connection;
pub mod graceful;
pub mod tcp;

pub use connection::Connection;
pub use graceful::GracefulService;
106 changes: 75 additions & 31 deletions src/transport/tcp/server/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,11 @@ use std::{
time::Duration,
};
use tokio::net::TcpStream;
use tower_async::{BoxError, Service};
use tower_async::{BoxError, MakeService, Service};
use tracing::info;

use super::error::{Error, ErrorHandler, ErrorKind};
use crate::transport::{connection::Connection, graceful};
use crate::transport::{graceful, Connection, GracefulService};

/// Listens to incoming TCP connections and serves them with a [`tower_async::Service`].
///
Expand All @@ -21,7 +21,7 @@ use crate::transport::{connection::Connection, graceful};
pub struct TcpListener<S, H> {
listener: tokio::net::TcpListener,
shutdown_timeout: Option<Duration>,
graceful: graceful::GracefulService,
graceful: GracefulService,
err_handler: H,
state: S,
}
Expand Down Expand Up @@ -76,7 +76,7 @@ impl<H> TcpListener<private::NoState, H> {
/// which will be passed to the [`tower_async::Service`] for each incoming connection.
///
/// [`tower_async::Service`]: https://docs.rs/tower-async/*/tower_async/trait.Service.html
pub fn state<S>(self, state: S) -> TcpListener<S, H>
pub fn state<S>(self, state: S) -> TcpListener<private::SomeState<S>, H>
where
S: Clone + Send + 'static,
{
Expand All @@ -85,7 +85,7 @@ impl<H> TcpListener<private::NoState, H> {
shutdown_timeout: self.shutdown_timeout,
graceful: self.graceful,
err_handler: self.err_handler,
state,
state: private::SomeState(state),
}
}
}
Expand Down Expand Up @@ -125,27 +125,35 @@ impl<S, H> TcpListener<S, H> {
impl<S, H> TcpListener<S, H>
where
H: ErrorHandler,
S: Clone + Send + 'static,
S: private::IntoState,
S::State: Clone + Send + 'static,
{
/// Serves incoming connections with a [`tower_async::Service`] that acts as a factory,
/// creating a new [`Service`] for each incoming connection.
pub async fn serve<F>(mut self, mut service_factory: F) -> Result<(), BoxError>
pub async fn serve<Factory>(mut self, mut service_factory: Factory) -> Result<(), BoxError>
where
F: Service<SocketAddr>,
F::Response: Service<Connection<TcpStream, S>, call(): Send> + Send + 'static,
F::Error: Into<BoxError>,
<F::Response as Service<Connection<TcpStream, S>>>::Error: Into<BoxError> + Send + 'static,
Factory: MakeService<SocketAddr, Connection<TcpStream, S::State>>,
// Factory::Service: Service<Connection<TcpStream, S::State>, call(): Send> + Send + 'static,
Factory::Service: Service<Connection<TcpStream, S::State>> + Send + 'static,
Factory::MakeError: Into<BoxError>,
Factory::Error: Into<BoxError> + Send + 'static,
<Factory as MakeService<
std::net::SocketAddr,
Connection<tokio::net::TcpStream, <S as private::IntoState>::State>,
>>::Service: Send,
{
let (service_err_tx, mut service_err_rx) = tokio::sync::mpsc::unbounded_channel();
let state = self.state.into_state();

// let (service_err_tx, mut service_err_rx) = tokio::sync::mpsc::unbounded_channel();
loop {
let (socket, peer_addr) = tokio::select! {
maybe_err = service_err_rx.recv() => {
if let Some(err) = maybe_err {
let error = Error::new(ErrorKind::Accept, err);
self.err_handler.handle(error).await?;
}
continue;
},
// maybe_err = service_err_rx.recv() => {
// if let Some(err) = maybe_err {
// let error = Error::new(ErrorKind::Accept, err);
// self.err_handler.handle(error).await?;
// }
// continue;
// },
result = self.listener.accept() => {
match result{
Ok((socket, peer_addr)) => (socket, peer_addr),
Expand All @@ -159,7 +167,7 @@ where
_ = self.graceful.shutdown_req() => break,
};

let mut service = match service_factory.call(peer_addr).await {
let mut service = match service_factory.make_service(peer_addr).await {
Ok(service) => service,
Err(err) => {
let error = Error::new(ErrorKind::Factory, err);
Expand All @@ -169,14 +177,22 @@ where
};

let token = self.graceful.token();
let state = self.state.clone();
let service_err_tx = service_err_tx.clone();
tokio::spawn(async move {
let conn: Connection<_, _> = Connection::new(socket, token, state);
if let Err(err) = service.call(conn).await {
let _ = service_err_tx.send(err);
}
});
let state = state.clone();
// let service_err_tx = service_err_tx.clone();
let conn: Connection<_, _> = Connection::new(socket, token, state);

// TODO: enable this kind of features once again when
// this bug is fixed: https://github.com/plabayo/tower-async/issues/9
// tokio::spawn(async move {
// if let Err(err) = service.call(conn).await {
// let _ = service_err_tx.send(err);
// }
// });

if let Err(err) = service.call(conn).await {
let error = Error::new(ErrorKind::Accept, err);
self.err_handler.handle(error).await?;
}
}

// wait for all services to finish
Expand All @@ -199,11 +215,39 @@ mod private {

use crate::transport::tcp::server::error::{Error, ErrorHandler, ErrorKind};

#[derive(Debug, Clone, Copy, Default)]
pub(super) struct NoState;
/// Marker trait for the [`super::TcpListener`] to indicate
/// no state is defined, meaning we'll fallback to the empty type `()`.
#[derive(Debug)]
pub struct NoState;

/// Marker trait for the [`super::TcpListener`] to indicate
/// some state is defined, meaning we'll use the type `T` in the end for the
/// passed down [`crate::transport::Connection`].
#[derive(Debug)]
pub struct SomeState<T>(pub T);

pub trait IntoState {
type State;

fn into_state(self) -> Self::State;
}

impl IntoState for NoState {
type State = ();

fn into_state(self) -> Self::State {}
}

impl<T> IntoState for SomeState<T> {
type State = T;

fn into_state(self) -> Self::State {
self.0
}
}

#[derive(Debug, Clone, Copy, Default)]
pub(super) struct DefaultErrorHandler;
pub struct DefaultErrorHandler;

impl ErrorHandler for DefaultErrorHandler {
async fn handle(&mut self, error: Error) -> std::result::Result<(), BoxError> {
Expand Down
1 change: 0 additions & 1 deletion src/transport/tcp/server/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
//! as the entrypoint of pretty much any Rama server.
pub mod error;
pub mod factory;

mod listener;
pub use listener::*;

0 comments on commit e423ffa

Please sign in to comment.