diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 7f29b990a..0cd42cbae 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -42,6 +42,14 @@ path = "src/load_balance/client.rs" name = "load-balance-server" path = "src/load_balance/server.rs" +[[bin]] +name = "dynamic-load-balance-client" +path = "src/dynamic_load_balance/client.rs" + +[[bin]] +name = "dynamic-load-balance-server" +path = "src/dynamic_load_balance/server.rs" + [[bin]] name = "tls-client" path = "src/tls/client.rs" @@ -123,7 +131,7 @@ serde_json = "1.0" rand = "0.7" # Tracing tracing = "0.1" -tracing-subscriber = { version = "0.2.0-alpha", features = ["tracing-log"] } +tracing-subscriber = { version = "0.2", features = ["tracing-log"] } tracing-attributes = "0.1" tracing-futures = "0.2" # Required for wellknown types diff --git a/examples/README.md b/examples/README.md index 1783b4bd3..784af6279 100644 --- a/examples/README.md +++ b/examples/README.md @@ -44,7 +44,7 @@ $ cargo run --bin authentication-client $ cargo run --bin authentication-server ``` -## Load balance +## Load Balance ### Client @@ -58,6 +58,20 @@ $ cargo run --bin load-balance-client $ cargo run --bin load-balance-server ``` +## Dyanmic Load Balance + +### Client + +```bash +$ cargo run --bin dynamic-load-balance-client +``` + +### Server + +```bash +$ cargo run --bin dynamic-load-balance-server +``` + ## TLS (rustls) ### Client diff --git a/examples/src/dynamic_load_balance/client.rs b/examples/src/dynamic_load_balance/client.rs new file mode 100644 index 000000000..4ffa6a66d --- /dev/null +++ b/examples/src/dynamic_load_balance/client.rs @@ -0,0 +1,81 @@ +pub mod pb { + tonic::include_proto!("grpc.examples.echo"); +} + +use pb::{echo_client::EchoClient, EchoRequest}; +use tonic::transport::Channel; + +use tonic::transport::Endpoint; + +use std::sync::Arc; + +use std::sync::atomic::{AtomicBool, Ordering::SeqCst}; +use tokio::time::timeout; +use tower::discover::Change; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let e1 = Endpoint::from_static("http://[::1]:50051"); + let e2 = Endpoint::from_static("http://[::1]:50052"); + + let (channel, mut rx) = Channel::balance_channel(10); + let mut client = EchoClient::new(channel); + + let done = Arc::new(AtomicBool::new(false)); + let demo_done = done.clone(); + tokio::spawn(async move { + tokio::time::delay_for(tokio::time::Duration::from_secs(5)).await; + println!("Added first endpoint"); + let change = Change::Insert("1", e1); + let res = rx.send(change).await; + println!("{:?}", res); + tokio::time::delay_for(tokio::time::Duration::from_secs(5)).await; + println!("Added second endpoint"); + let change = Change::Insert("2", e2); + let res = rx.send(change).await; + println!("{:?}", res); + tokio::time::delay_for(tokio::time::Duration::from_secs(5)).await; + println!("Removed first endpoint"); + let change = Change::Remove("1"); + let res = rx.send(change).await; + println!("{:?}", res); + + tokio::time::delay_for(tokio::time::Duration::from_secs(5)).await; + println!("Removed second endpoint"); + let change = Change::Remove("2"); + let res = rx.send(change).await; + println!("{:?}", res); + + tokio::time::delay_for(tokio::time::Duration::from_secs(5)).await; + println!("Added third endpoint"); + let e3 = Endpoint::from_static("http://[::1]:50051"); + let change = Change::Insert("3", e3); + let res = rx.send(change).await; + println!("{:?}", res); + + tokio::time::delay_for(tokio::time::Duration::from_secs(5)).await; + println!("Removed third endpoint"); + let change = Change::Remove("3"); + let res = rx.send(change).await; + println!("{:?}", res); + demo_done.swap(true, SeqCst); + }); + + while !done.load(SeqCst) { + tokio::time::delay_for(tokio::time::Duration::from_millis(500)).await; + let request = tonic::Request::new(EchoRequest { + message: "hello".into(), + }); + + let rx = client.unary_echo(request); + if let Ok(resp) = timeout(tokio::time::Duration::from_secs(10), rx).await { + println!("RESPONSE={:?}", resp); + } else { + println!("did not receive value within 10 secs"); + } + } + + println!("... Bye"); + + Ok(()) +} diff --git a/examples/src/dynamic_load_balance/server.rs b/examples/src/dynamic_load_balance/server.rs new file mode 100644 index 000000000..8623d8f7a --- /dev/null +++ b/examples/src/dynamic_load_balance/server.rs @@ -0,0 +1,82 @@ +pub mod pb { + tonic::include_proto!("grpc.examples.echo"); +} + +use futures::Stream; +use std::net::SocketAddr; +use std::pin::Pin; +use tokio::sync::mpsc; +use tonic::{transport::Server, Request, Response, Status, Streaming}; + +use pb::{EchoRequest, EchoResponse}; + +type EchoResult = Result, Status>; +type ResponseStream = Pin> + Send + Sync>>; + +#[derive(Debug)] +pub struct EchoServer { + addr: SocketAddr, +} + +#[tonic::async_trait] +impl pb::echo_server::Echo for EchoServer { + async fn unary_echo(&self, request: Request) -> EchoResult { + let message = format!("{} (from {})", request.into_inner().message, self.addr); + + Ok(Response::new(EchoResponse { message })) + } + + type ServerStreamingEchoStream = ResponseStream; + + async fn server_streaming_echo( + &self, + _: Request, + ) -> EchoResult { + Err(Status::unimplemented("not implemented")) + } + + async fn client_streaming_echo( + &self, + _: Request>, + ) -> EchoResult { + Err(Status::unimplemented("not implemented")) + } + + type BidirectionalStreamingEchoStream = ResponseStream; + + async fn bidirectional_streaming_echo( + &self, + _: Request>, + ) -> EchoResult { + Err(Status::unimplemented("not implemented")) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let addrs = ["[::1]:50051", "[::1]:50052"]; + + let (tx, mut rx) = mpsc::unbounded_channel(); + + for addr in &addrs { + let addr = addr.parse()?; + let tx = tx.clone(); + + let server = EchoServer { addr }; + let serve = Server::builder() + .add_service(pb::echo_server::EchoServer::new(server)) + .serve(addr); + + tokio::spawn(async move { + if let Err(e) = serve.await { + eprintln!("Error = {:?}", e); + } + + tx.send(()).unwrap(); + }); + } + + rx.recv().await; + + Ok(()) +} diff --git a/tonic/src/transport/channel/mod.rs b/tonic/src/transport/channel/mod.rs index acbafb2cf..773f8f20f 100644 --- a/tonic/src/transport/channel/mod.rs +++ b/tonic/src/transport/channel/mod.rs @@ -9,7 +9,7 @@ pub use endpoint::Endpoint; #[cfg(feature = "tls")] pub use tls::ClientTlsConfig; -use super::service::{Connection, ServiceList}; +use super::service::{Connection, DynamicServiceStream}; use crate::{body::BoxBody, client::GrpcService}; use bytes::Bytes; use http::{ @@ -20,13 +20,18 @@ use hyper::client::connect::Connection as HyperConnection; use std::{ fmt, future::Future, + hash::Hash, pin::Pin, task::{Context, Poll}, }; -use tokio::io::{AsyncRead, AsyncWrite}; +use tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::mpsc::{channel, Sender}, +}; + use tower::{ buffer::{self, Buffer}, - discover::Discover, + discover::{Change, Discover}, util::{BoxService, Either}, Service, }; @@ -104,17 +109,25 @@ impl Channel { /// This creates a [`Channel`] that will load balance accross all the /// provided endpoints. pub fn balance_list(list: impl Iterator) -> Self { - let list = list.collect::>(); + let (channel, mut tx) = Self::balance_channel(DEFAULT_BUFFER_SIZE); + list.for_each(|endpoint| { + tx.try_send(Change::Insert(endpoint.uri.clone(), endpoint)) + .unwrap(); + }); - let buffer_size = list - .iter() - .next() - .and_then(|e| e.buffer_size) - .unwrap_or(DEFAULT_BUFFER_SIZE); - - let discover = ServiceList::new(list); + channel + } - Self::balance(discover, buffer_size) + /// Balance a list of [`Endpoint`]'s. + /// + /// This creates a [`Channel`] that will listen to a stream of change events and will add or remove provided endpoints. + pub fn balance_channel(capacity: usize) -> (Self, Sender>) + where + K: Hash + Eq + Send + Clone + 'static, + { + let (tx, rx) = channel(capacity); + let list = DynamicServiceStream::new(rx); + (Self::balance(list, DEFAULT_BUFFER_SIZE), tx) } pub(crate) async fn connect(connector: C, endpoint: Endpoint) -> Result diff --git a/tonic/src/transport/service/discover.rs b/tonic/src/transport/service/discover.rs index 072cde5f1..bed7f1310 100644 --- a/tonic/src/transport/service/discover.rs +++ b/tonic/src/transport/service/discover.rs @@ -1,34 +1,36 @@ use super::super::service; use super::connection::Connection; use crate::transport::Endpoint; + use std::{ - collections::VecDeque, - fmt, future::Future, + hash::Hash, pin::Pin, task::{Context, Poll}, }; +use tokio::{stream::Stream, sync::mpsc::Receiver}; + use tower::discover::{Change, Discover}; -pub(crate) struct ServiceList { - list: VecDeque, - connecting: - Option> + Send + 'static>>>, - i: usize, +pub(crate) struct DynamicServiceStream { + changes: Receiver>, + connecting: Option<( + K, + Pin> + Send + 'static>>, + )>, } -impl ServiceList { - pub(crate) fn new(list: Vec) -> Self { +impl DynamicServiceStream { + pub(crate) fn new(changes: Receiver>) -> Self { Self { - list: list.into(), + changes, connecting: None, - i: 0, } } } -impl Discover for ServiceList { - type Key = usize; +impl Discover for DynamicServiceStream { + type Key = K; type Service = Connection; type Error = crate::Error; @@ -37,43 +39,40 @@ impl Discover for ServiceList { cx: &mut Context<'_>, ) -> Poll, Self::Error>> { loop { - if let Some(connecting) = &mut self.connecting { + if let Some((key, connecting)) = &mut self.connecting { let svc = futures_core::ready!(Pin::new(connecting).poll(cx))?; + let key = key.to_owned(); self.connecting = None; - - let i = self.i; - self.i += 1; - - let change = Ok(Change::Insert(i, svc)); - + let change = Ok(Change::Insert(key, svc)); return Poll::Ready(change); - } - - if let Some(endpoint) = self.list.pop_front() { - let mut http = hyper::client::connect::HttpConnector::new(); - http.set_nodelay(endpoint.tcp_nodelay); - http.set_keepalive(endpoint.tcp_keepalive); - http.enforce_http(false); - - #[cfg(feature = "tls")] - let connector = service::connector(http, endpoint.tls.clone()); + }; - #[cfg(not(feature = "tls"))] - let connector = service::connector(http); + let c = &mut self.changes; + match Pin::new(&mut *c).poll_next(cx) { + Poll::Pending => return Poll::Pending, + Poll::Ready(None) => { + return Poll::Pending; + } + Poll::Ready(Some(change)) => match change { + Change::Insert(k, endpoint) => { + let mut http = hyper::client::connect::HttpConnector::new(); + http.set_nodelay(endpoint.tcp_nodelay); + http.set_keepalive(endpoint.tcp_keepalive); + http.enforce_http(false); + #[cfg(feature = "tls")] + let connector = service::connector(http, endpoint.tls.clone()); - let fut = Connection::new(connector, endpoint); - self.connecting = Some(Box::pin(fut)); - } else { - return Poll::Pending; + #[cfg(not(feature = "tls"))] + let connector = service::connector(http); + let fut = Connection::new(connector, endpoint); + self.connecting = Some((k, Box::pin(fut))); + continue; + } + Change::Remove(k) => return Poll::Ready(Ok(Change::Remove(k))), + }, } } } } -impl fmt::Debug for ServiceList { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("ServiceList") - .field("list", &self.list) - .finish() - } -} +impl Unpin for DynamicServiceStream {} diff --git a/tonic/src/transport/service/mod.rs b/tonic/src/transport/service/mod.rs index 3bc691871..92453cdbf 100644 --- a/tonic/src/transport/service/mod.rs +++ b/tonic/src/transport/service/mod.rs @@ -12,7 +12,7 @@ mod tls; pub(crate) use self::add_origin::AddOrigin; pub(crate) use self::connection::Connection; pub(crate) use self::connector::connector; -pub(crate) use self::discover::ServiceList; +pub(crate) use self::discover::DynamicServiceStream; pub(crate) use self::io::ServerIo; pub(crate) use self::layer::ServiceBuilderExt; pub(crate) use self::router::{Or, Routes};