Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(transport): Dynamic load balancing #341

Merged
10 changes: 9 additions & 1 deletion examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,14 @@ path = "src/load_balance/client.rs"
name = "load-balance-server"
path = "src/load_balance/server.rs"

[[bin]]
name = "load-balance-client-discovery"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do you think about naming this dynamic-load-balance? Might be a bit more clear

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

path = "src/load_balance_with_discovery/client.rs"

[[bin]]
name = "load-balance-server-discovery"
path = "src/load_balance_with_discovery/server.rs"

[[bin]]
name = "tls-client"
path = "src/tls/client.rs"
Expand Down Expand Up @@ -123,7 +131,7 @@ serde_json = "1.0"
rand = "0.7"
# Tracing
tracing = "0.1"
tracing-subscriber = { version = "0.2.0-alpha", features = ["tracing-log"] }
tracing-subscriber = { version = "0.2", features = ["tracing-log"] }
tracing-attributes = "0.1"
tracing-futures = "0.2"
# Required for wellknown types
Expand Down
81 changes: 81 additions & 0 deletions examples/src/load_balance_with_discovery/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
pub mod pb {
tonic::include_proto!("grpc.examples.echo");
}

use pb::{echo_client::EchoClient, EchoRequest};
use tonic::transport::Channel;

use tonic::transport::Endpoint;

use std::sync::Arc;

use std::sync::atomic::{AtomicBool, Ordering::SeqCst};
use tokio::time::timeout;
use tower::discover::Change;

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let e1 = Endpoint::from_static("http://[::1]:50051").timeout(std::time::Duration::from_secs(1));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why the timeout here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just me playing around, removed

let e2 = Endpoint::from_static("http://[::1]:50052").timeout(std::time::Duration::from_secs(1));

let (channel, mut rx) = Channel::balance_channel(10);
let mut client = EchoClient::new(channel);

let done = Arc::new(AtomicBool::new(false));
let demo_done = done.clone();
tokio::spawn(async move {
tokio::time::delay_for(tokio::time::Duration::from_secs(5)).await;
println!("Added first endpoint");
let change = Change::Insert("1", e1);
let res = rx.send(change).await;
println!("{:?}", res);
tokio::time::delay_for(tokio::time::Duration::from_secs(5)).await;
println!("Added second endpoint");
let change = Change::Insert("2", e2);
let res = rx.send(change).await;
println!("{:?}", res);
tokio::time::delay_for(tokio::time::Duration::from_secs(5)).await;
println!("Removed first endpoint");
let change = Change::Remove("1");
let res = rx.send(change).await;
println!("{:?}", res);

tokio::time::delay_for(tokio::time::Duration::from_secs(5)).await;
println!("Removed second endpoint");
let change = Change::Remove("2");
let res = rx.send(change).await;
println!("{:?}", res);

tokio::time::delay_for(tokio::time::Duration::from_secs(5)).await;
println!("Added third endpoint");
let e3 = Endpoint::from_static("http://[::1]:50051");
let change = Change::Insert("3", e3);
let res = rx.send(change).await;
println!("{:?}", res);

tokio::time::delay_for(tokio::time::Duration::from_secs(5)).await;
println!("Removed third endpoint");
let change = Change::Remove("3");
let res = rx.send(change).await;
println!("{:?}", res);
demo_done.swap(true, SeqCst);
});

while !done.load(SeqCst) {
tokio::time::delay_for(tokio::time::Duration::from_millis(500)).await;
let request = tonic::Request::new(EchoRequest {
message: "hello".into(),
});

let rx = client.unary_echo(request);
if let Ok(resp) = timeout(tokio::time::Duration::from_secs(10), rx).await {
println!("RESPONSE={:?}", resp);
} else {
println!("did not receive value within 10 secs");
}
}

println!("... Bye");

Ok(())
}
82 changes: 82 additions & 0 deletions examples/src/load_balance_with_discovery/server.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
pub mod pb {
tonic::include_proto!("grpc.examples.echo");
}

use futures::Stream;
use std::net::SocketAddr;
use std::pin::Pin;
use tokio::sync::mpsc;
use tonic::{transport::Server, Request, Response, Status, Streaming};

use pb::{EchoRequest, EchoResponse};

type EchoResult<T> = Result<Response<T>, Status>;
type ResponseStream = Pin<Box<dyn Stream<Item = Result<EchoResponse, Status>> + Send + Sync>>;

#[derive(Debug)]
pub struct EchoServer {
addr: SocketAddr,
}

#[tonic::async_trait]
impl pb::echo_server::Echo for EchoServer {
async fn unary_echo(&self, request: Request<EchoRequest>) -> EchoResult<EchoResponse> {
let message = format!("{} (from {})", request.into_inner().message, self.addr);

Ok(Response::new(EchoResponse { message }))
}

type ServerStreamingEchoStream = ResponseStream;

async fn server_streaming_echo(
&self,
_: Request<EchoRequest>,
) -> EchoResult<Self::ServerStreamingEchoStream> {
Err(Status::unimplemented("not implemented"))
}

async fn client_streaming_echo(
&self,
_: Request<Streaming<EchoRequest>>,
) -> EchoResult<EchoResponse> {
Err(Status::unimplemented("not implemented"))
}

type BidirectionalStreamingEchoStream = ResponseStream;

async fn bidirectional_streaming_echo(
&self,
_: Request<Streaming<EchoRequest>>,
) -> EchoResult<Self::BidirectionalStreamingEchoStream> {
Err(Status::unimplemented("not implemented"))
}
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let addrs = ["[::1]:50051", "[::1]:50052"];

let (tx, mut rx) = mpsc::unbounded_channel();

for addr in &addrs {
let addr = addr.parse()?;
let tx = tx.clone();

let server = EchoServer { addr };
let serve = Server::builder()
.add_service(pb::echo_server::EchoServer::new(server))
.serve(addr);

tokio::spawn(async move {
if let Err(e) = serve.await {
eprintln!("Error = {:?}", e);
}

tx.send(()).unwrap();
});
}

rx.recv().await;

Ok(())
}
33 changes: 22 additions & 11 deletions tonic/src/transport/channel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,26 @@ pub use endpoint::Endpoint;
#[cfg(feature = "tls")]
pub use tls::ClientTlsConfig;

use super::service::{Connection, ServiceList};
use super::service::{Connection, DynamicServiceStream};
use crate::{body::BoxBody, client::GrpcService};
use bytes::Bytes;
use http::{
uri::{InvalidUri, Uri},
Request, Response,
};
use hyper::client::connect::Connection as HyperConnection;
use std::hash::Hash;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor nit: can we move this import into the one below?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

use std::{
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite};

use tower::{
buffer::{self, Buffer},
discover::Discover,
discover::{Change, Discover},
util::{BoxService, Either},
Service,
};
Expand Down Expand Up @@ -104,17 +106,26 @@ impl Channel {
/// This creates a [`Channel`] that will load balance accross all the
/// provided endpoints.
pub fn balance_list(list: impl Iterator<Item = Endpoint>) -> Self {
let list = list.collect::<Vec<_>>();

let buffer_size = list
.iter()
.next()
.and_then(|e| e.buffer_size)
.unwrap_or(DEFAULT_BUFFER_SIZE);
let (channel, mut tx) = Self::balance_channel(DEFAULT_BUFFER_SIZE);
list.for_each(|endpoint| {
let _res = tx.try_send(Change::Insert(endpoint.uri.clone(), endpoint));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

probably should just unwrap here if there is an error we want to know because that would mean there is a bug in this code!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

});

let discover = ServiceList::new(list);
channel
}

Self::balance(discover, buffer_size)
/// Balance a list of [`Endpoint`]'s.
///
/// This creates a [`Channel`] that will listen to a stream of change events and will add or remove provided endpoints.
pub fn balance_channel<K>(
capacity: usize,
) -> (Self, tokio::sync::mpsc::Sender<Change<K, Endpoint>>)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

might want to just import this type, its making this signature a bit long 😄

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

where
K: Hash + Eq + Send + Clone + 'static,
{
let (tx, rx) = tokio::sync::mpsc::channel(capacity);
let list = DynamicServiceStream::new(rx);
(Self::balance(list, DEFAULT_BUFFER_SIZE), tx)
}

pub(crate) async fn connect<C>(connector: C, endpoint: Endpoint) -> Result<Self, super::Error>
Expand Down
83 changes: 40 additions & 43 deletions tonic/src/transport/service/discover.rs
Original file line number Diff line number Diff line change
@@ -1,34 +1,34 @@
use super::super::service;
use super::connection::Connection;
use crate::transport::Endpoint;
use std::hash::Hash;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

use std::{
collections::VecDeque,
fmt,
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::stream::Stream;
use tower::discover::{Change, Discover};

pub(crate) struct ServiceList {
list: VecDeque<Endpoint>,
connecting:
Option<Pin<Box<dyn Future<Output = Result<Connection, crate::Error>> + Send + 'static>>>,
i: usize,
pub(crate) struct DynamicServiceStream<K: Hash + Eq + Clone> {
changes: tokio::sync::mpsc::Receiver<Change<K, Endpoint>>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here can we import this type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

connecting: Option<(
K,
Pin<Box<dyn Future<Output = Result<Connection, crate::Error>> + Send + 'static>>,
)>,
}

impl ServiceList {
pub(crate) fn new(list: Vec<Endpoint>) -> Self {
impl<K: Hash + Eq + Clone> DynamicServiceStream<K> {
pub(crate) fn new(changes: tokio::sync::mpsc::Receiver<Change<K, Endpoint>>) -> Self {
Self {
list: list.into(),
changes,
connecting: None,
i: 0,
}
}
}

impl Discover for ServiceList {
type Key = usize;
impl<K: Hash + Eq + Clone> Discover for DynamicServiceStream<K> {
type Key = K;
type Service = Connection;
type Error = crate::Error;

Expand All @@ -37,43 +37,40 @@ impl Discover for ServiceList {
cx: &mut Context<'_>,
) -> Poll<Result<Change<Self::Key, Self::Service>, Self::Error>> {
loop {
if let Some(connecting) = &mut self.connecting {
if let Some((key, connecting)) = &mut self.connecting {
let svc = futures_core::ready!(Pin::new(connecting).poll(cx))?;
let key = key.to_owned();
self.connecting = None;

let i = self.i;
self.i += 1;

let change = Ok(Change::Insert(i, svc));

let change = Ok(Change::Insert(key, svc));
return Poll::Ready(change);
}
};

if let Some(endpoint) = self.list.pop_front() {
let mut http = hyper::client::connect::HttpConnector::new();
http.set_nodelay(endpoint.tcp_nodelay);
http.set_keepalive(endpoint.tcp_keepalive);
http.enforce_http(false);
let c = &mut self.changes;
match Pin::new(&mut *c).poll_next(cx) {
Poll::Pending => return Poll::Pending,
Poll::Ready(None) => {
return Poll::Pending;
}
Poll::Ready(Some(change)) => match change {
Change::Insert(k, endpoint) => {
let mut http = hyper::client::connect::HttpConnector::new();
http.set_nodelay(endpoint.tcp_nodelay);
http.set_keepalive(endpoint.tcp_keepalive);
http.enforce_http(false);
#[cfg(feature = "tls")]
let connector = service::connector(http, endpoint.tls.clone());

#[cfg(feature = "tls")]
let connector = service::connector(http, endpoint.tls.clone());

#[cfg(not(feature = "tls"))]
let connector = service::connector(http);

let fut = Connection::new(connector, endpoint);
self.connecting = Some(Box::pin(fut));
} else {
return Poll::Pending;
#[cfg(not(feature = "tls"))]
let connector = service::connector(http);
let fut = Connection::new(connector, endpoint);
self.connecting = Some((k, Box::pin(fut)));
continue;
}
Change::Remove(k) => return Poll::Ready(Ok(Change::Remove(k))),
},
}
}
}
}

impl fmt::Debug for ServiceList {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("ServiceList")
.field("list", &self.list)
.finish()
}
}
impl<K: Hash + Eq + Clone> Unpin for DynamicServiceStream<K> {}
2 changes: 1 addition & 1 deletion tonic/src/transport/service/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ mod tls;
pub(crate) use self::add_origin::AddOrigin;
pub(crate) use self::connection::Connection;
pub(crate) use self::connector::connector;
pub(crate) use self::discover::ServiceList;
pub(crate) use self::discover::DynamicServiceStream;
pub(crate) use self::io::ServerIo;
pub(crate) use self::layer::ServiceBuilderExt;
pub(crate) use self::router::{Or, Routes};
Expand Down