diff --git a/src/cli/proxy.rs b/src/cli/proxy.rs index 12e7bbefff..e29d01df65 100644 --- a/src/cli/proxy.rs +++ b/src/cli/proxy.rs @@ -118,10 +118,10 @@ impl Proxy { }); } - if config.clusters.read().endpoints().count() == 0 && self.management_server.is_empty() { + if !config.clusters.read().has_endpoints() && self.management_server.is_empty() { return Err(eyre::eyre!( - "`quilkin proxy` requires at least one `to` address or `management_server` endpoint." - )); + "`quilkin proxy` requires at least one `to` address or `management_server` endpoint." + )); } let id = config.id.load(); @@ -263,7 +263,7 @@ impl RuntimeConfig { .read() .as_ref() .map_or(true, |health| health.load(Ordering::SeqCst)) - && config.clusters.read().endpoints().count() != 0 + && config.clusters.read().has_endpoints() } } @@ -393,17 +393,20 @@ impl DownstreamReceiveWorkerConfig { config: &Arc, sessions: &Arc, ) -> Result { - let endpoints: Vec<_> = config.clusters.read().endpoints().collect(); - if endpoints.is_empty() { + if !config.clusters.read().has_endpoints() { return Err(PipelineError::NoUpstreamEndpoints); } let filters = config.filters.load(); - let mut context = ReadContext::new(endpoints, packet.source.into(), packet.contents); + let mut context = ReadContext::new( + config.clusters.clone_value(), + packet.source.into(), + packet.contents, + ); filters.read(&mut context).await?; let mut bytes_written = 0; - for endpoint in context.endpoints.iter() { + for endpoint in context.destinations.iter() { let session_key = SessionKey { source: packet.source, dest: endpoint.address.to_socket_addr().await?, diff --git a/src/config.rs b/src/config.rs index 8da5296ee7..cc2f689f08 100644 --- a/src/config.rs +++ b/src/config.rs @@ -190,7 +190,7 @@ impl Config { pub fn apply_metrics(&self) { let clusters = self.clusters.read(); crate::net::cluster::active_clusters().set(clusters.len() as i64); - crate::net::cluster::active_endpoints().set(clusters.endpoints().count() as i64); + crate::net::cluster::active_endpoints().set(clusters.num_of_endpoints() as i64); } } diff --git a/src/config/watch.rs b/src/config/watch.rs index bb0a00ab46..4c38b1952e 100644 --- a/src/config/watch.rs +++ b/src/config/watch.rs @@ -38,6 +38,10 @@ impl Watch { pub fn watch(&self) -> watch::Receiver { self.watchers.subscribe() } + + pub fn clone_value(&self) -> std::sync::Arc { + self.value.clone() + } } impl Watch { diff --git a/src/filters/capture.rs b/src/filters/capture.rs index cd7e30211f..0f90c1ddb3 100644 --- a/src/filters/capture.rs +++ b/src/filters/capture.rs @@ -159,10 +159,12 @@ mod tests { }), }; let filter = Capture::from_config(config.into()); - let endpoints = vec![Endpoint::new("127.0.0.1:81".parse().unwrap())]; + let endpoints = crate::net::cluster::ClusterMap::new_default( + [Endpoint::new("127.0.0.1:81".parse().unwrap())].into(), + ); assert!(filter .read(&mut ReadContext::new( - endpoints, + endpoints.into(), (std::net::Ipv4Addr::LOCALHOST, 80).into(), "abc".to_string().into_bytes(), )) @@ -235,9 +237,11 @@ mod tests { where F: Filter + ?Sized, { - let endpoints = vec![Endpoint::new("127.0.0.1:81".parse().unwrap())]; + let endpoints = crate::net::cluster::ClusterMap::new_default( + [Endpoint::new("127.0.0.1:81".parse().unwrap())].into(), + ); let mut context = ReadContext::new( - endpoints, + endpoints.into(), "127.0.0.1:80".parse().unwrap(), "helloabc".to_string().into_bytes(), ); diff --git a/src/filters/chain.rs b/src/filters/chain.rs index 865e587396..3435f0cd91 100644 --- a/src/filters/chain.rs +++ b/src/filters/chain.rs @@ -278,6 +278,17 @@ impl Filter for FilterChain { } } + // Special case to handle to allow for pass-through, if no filter + // has rejected, and the destinations is empty, we passthrough to all. + // Which mimics the old behaviour while avoid clones in most cases. + if ctx.destinations.is_empty() { + ctx.destinations = ctx + .endpoints + .iter() + .flat_map(|e| e.value().iter().cloned().collect::>()) + .collect(); + } + Ok(()) } @@ -340,11 +351,15 @@ mod tests { assert!(result.is_err()); } - fn endpoints() -> Vec { - vec![ - Endpoint::new("127.0.0.1:80".parse().unwrap()), - Endpoint::new("127.0.0.1:90".parse().unwrap()), - ] + fn endpoints() -> std::sync::Arc { + crate::net::cluster::ClusterMap::new_default( + [ + Endpoint::new("127.0.0.1:80".parse().unwrap()), + Endpoint::new("127.0.0.1:90".parse().unwrap()), + ] + .into(), + ) + .into() } #[tokio::test] @@ -361,7 +376,10 @@ mod tests { config.filters.read(&mut context).await.unwrap(); let expected = endpoints_fixture.clone(); - assert_eq!(expected, &*context.endpoints); + assert_eq!( + &*expected.endpoints().collect::>(), + &*context.destinations + ); assert_eq!(b"hello:odr:127.0.0.1:70", &*context.contents); assert_eq!( "receive", @@ -369,7 +387,12 @@ mod tests { ); let mut context = WriteContext::new( - endpoints_fixture[0].address.clone(), + endpoints_fixture + .endpoints() + .next() + .unwrap() + .address + .clone(), "127.0.0.1:70".parse().unwrap(), b"hello".to_vec(), ); @@ -405,7 +428,10 @@ mod tests { chain.read(&mut context).await.unwrap(); let expected = endpoints_fixture.clone(); - assert_eq!(expected, context.endpoints.to_vec()); + assert_eq!( + expected.endpoints().collect::>(), + context.destinations + ); assert_eq!( b"hello:odr:127.0.0.1:70:odr:127.0.0.1:70", &*context.contents @@ -416,7 +442,12 @@ mod tests { ); let mut context = WriteContext::new( - endpoints_fixture[0].address.clone(), + endpoints_fixture + .endpoints() + .next() + .unwrap() + .address + .clone(), "127.0.0.1:70".parse().unwrap(), b"hello".to_vec(), ); diff --git a/src/filters/compress.rs b/src/filters/compress.rs index d9fc863877..0d6dcdcfda 100644 --- a/src/filters/compress.rs +++ b/src/filters/compress.rs @@ -176,8 +176,11 @@ mod tests { let expected = contents_fixture(); // read compress + let endpoints = crate::net::cluster::ClusterMap::new_default( + [Endpoint::new("127.0.0.1:81".parse().unwrap())].into(), + ); let mut read_context = ReadContext::new( - vec![Endpoint::new("127.0.0.1:80".parse().unwrap())], + endpoints.into(), "127.0.0.1:8080".parse().unwrap(), expected.clone(), ); @@ -238,9 +241,12 @@ mod tests { Metrics::new(), ); + let endpoints = crate::net::cluster::ClusterMap::new_default( + [Endpoint::new("127.0.0.1:81".parse().unwrap())].into(), + ); assert!(compression .read(&mut ReadContext::new( - vec![Endpoint::new("127.0.0.1:80".parse().unwrap())], + endpoints.into(), "127.0.0.1:8080".parse().unwrap(), b"hello".to_vec(), )) @@ -259,8 +265,11 @@ mod tests { Metrics::new(), ); + let endpoints = crate::net::cluster::ClusterMap::new_default( + [Endpoint::new("127.0.0.1:81".parse().unwrap())].into(), + ); let mut read_context = ReadContext::new( - vec![Endpoint::new("127.0.0.1:80".parse().unwrap())], + endpoints.into(), "127.0.0.1:8080".parse().unwrap(), b"hello".to_vec(), ); @@ -345,8 +354,11 @@ mod tests { ); // read decompress + let endpoints = crate::net::cluster::ClusterMap::new_default( + [Endpoint::new("127.0.0.1:81".parse().unwrap())].into(), + ); let mut read_context = ReadContext::new( - vec![Endpoint::new("127.0.0.1:80".parse().unwrap())], + endpoints.into(), "127.0.0.1:8080".parse().unwrap(), write_context.contents.clone(), ); diff --git a/src/filters/firewall.rs b/src/filters/firewall.rs index f8d9aa5bae..597de7de9e 100644 --- a/src/filters/firewall.rs +++ b/src/filters/firewall.rs @@ -137,18 +137,16 @@ mod tests { }; let local_ip = [192, 168, 75, 20]; - let mut ctx = ReadContext::new( - vec![Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())], - (local_ip, 80).into(), - vec![], + let endpoints = crate::net::cluster::ClusterMap::new_default( + [Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())].into(), ); + let mut ctx = ReadContext::new(endpoints.into(), (local_ip, 80).into(), vec![]); assert!(firewall.read(&mut ctx).await.is_ok()); - let mut ctx = ReadContext::new( - vec![Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())], - (local_ip, 2000).into(), - vec![], + let endpoints = crate::net::cluster::ClusterMap::new_default( + [Endpoint::new((Ipv4Addr::LOCALHOST, 8080).into())].into(), ); + let mut ctx = ReadContext::new(endpoints.into(), (local_ip, 2000).into(), vec![]); assert!(logs_contain("quilkin::filters::firewall")); // the given name to the the logger by tracing assert!(logs_contain("Allow")); diff --git a/src/filters/load_balancer.rs b/src/filters/load_balancer.rs index 728fe4cd1f..02e8696076 100644 --- a/src/filters/load_balancer.rs +++ b/src/filters/load_balancer.rs @@ -68,16 +68,18 @@ mod tests { input_addresses: &[EndpointAddress], source: EndpointAddress, ) -> Vec { - let mut context = ReadContext::new( - Vec::from_iter(input_addresses.iter().cloned().map(Endpoint::new)), - source, - vec![], - ); + let endpoints = input_addresses + .iter() + .cloned() + .map(Endpoint::new) + .collect::>(); + let endpoints = crate::net::cluster::ClusterMap::new_default(endpoints); + let mut context = ReadContext::new(endpoints.into(), source, vec![]); filter.read(&mut context).await.unwrap(); context - .endpoints + .destinations .iter() .map(|ep| ep.address.clone()) .collect::>() diff --git a/src/filters/load_balancer/endpoint_chooser.rs b/src/filters/load_balancer/endpoint_chooser.rs index 1f2321af9d..028304ae55 100644 --- a/src/filters/load_balancer/endpoint_chooser.rs +++ b/src/filters/load_balancer/endpoint_chooser.rs @@ -48,7 +48,11 @@ impl EndpointChooser for RoundRobinEndpointChooser { fn choose_endpoints(&self, ctx: &mut ReadContext) { let count = self.next_endpoint.fetch_add(1, Ordering::Relaxed); // Note: The index is guaranteed to be in range. - ctx.endpoints = vec![ctx.endpoints[count % ctx.endpoints.len()].clone()]; + ctx.destinations = vec![ctx + .endpoints + .nth_endpoint(count % ctx.endpoints.num_of_endpoints()) + .unwrap() + .clone()]; } } @@ -58,8 +62,8 @@ pub struct RandomEndpointChooser; impl EndpointChooser for RandomEndpointChooser { fn choose_endpoints(&self, ctx: &mut ReadContext) { // The index is guaranteed to be in range. - let index = thread_rng().gen_range(0..ctx.endpoints.len()); - ctx.endpoints = vec![ctx.endpoints[index].clone()]; + let index = thread_rng().gen_range(0..ctx.endpoints.num_of_endpoints()); + ctx.destinations = vec![ctx.endpoints.nth_endpoint(index).unwrap().clone()]; } } @@ -70,6 +74,10 @@ impl EndpointChooser for HashEndpointChooser { fn choose_endpoints(&self, ctx: &mut ReadContext) { let mut hasher = DefaultHasher::new(); ctx.source.hash(&mut hasher); - ctx.endpoints = vec![ctx.endpoints[hasher.finish() as usize % ctx.endpoints.len()].clone()]; + ctx.destinations = vec![ctx + .endpoints + .nth_endpoint(hasher.finish() as usize % ctx.endpoints.num_of_endpoints()) + .unwrap() + .clone()]; } } diff --git a/src/filters/local_rate_limit.rs b/src/filters/local_rate_limit.rs index 2ec14b4d5c..cc6448e137 100644 --- a/src/filters/local_rate_limit.rs +++ b/src/filters/local_rate_limit.rs @@ -222,11 +222,14 @@ mod tests { /// Send a packet to the filter and assert whether or not it was processed. async fn read(r: &LocalRateLimit, address: &EndpointAddress, should_succeed: bool) { - let endpoints = vec![crate::net::endpoint::Endpoint::new( - (Ipv4Addr::LOCALHOST, 8089).into(), - )]; - - let mut context = ReadContext::new(endpoints, address.clone(), vec![9]); + let endpoints = crate::net::cluster::ClusterMap::new_default( + [crate::net::endpoint::Endpoint::new( + (Ipv4Addr::LOCALHOST, 8089).into(), + )] + .into(), + ); + + let mut context = ReadContext::new(endpoints.into(), address.clone(), vec![9]); let result = r.read(&mut context).await; if should_succeed { diff --git a/src/filters/match.rs b/src/filters/match.rs index 2a87961b67..4b0c929bc5 100644 --- a/src/filters/match.rs +++ b/src/filters/match.rs @@ -205,8 +205,11 @@ mod tests { assert_eq!(0, filter.metrics.packets_matched_total.get()); // config so we can test match and fallthrough. + let endpoints = crate::net::cluster::ClusterMap::new_default( + [Endpoint::new("127.0.0.1:81".parse().unwrap())].into(), + ); let mut ctx = ReadContext::new( - vec![Default::default()], + endpoints.into(), ([127, 0, 0, 1], 7000).into(), contents.clone(), ); @@ -216,11 +219,10 @@ mod tests { assert_eq!(1, filter.metrics.packets_matched_total.get()); assert_eq!(0, filter.metrics.packets_fallthrough_total.get()); - let mut ctx = ReadContext::new( - vec![Default::default()], - ([127, 0, 0, 1], 7000).into(), - contents, + let endpoints = crate::net::cluster::ClusterMap::new_default( + [Endpoint::new("127.0.0.1:81".parse().unwrap())].into(), ); + let mut ctx = ReadContext::new(endpoints.into(), ([127, 0, 0, 1], 7000).into(), contents); ctx.metadata.insert(key, "xyz".into()); let result = filter.read(&mut ctx).await; diff --git a/src/filters/read.rs b/src/filters/read.rs index e2f025a185..00880c259b 100644 --- a/src/filters/read.rs +++ b/src/filters/read.rs @@ -14,15 +14,22 @@ * limitations under the License. */ +use std::sync::Arc; + #[cfg(doc)] use crate::filters::Filter; -use crate::net::endpoint::{metadata::DynamicMetadata, Endpoint, EndpointAddress}; +use crate::net::{ + endpoint::{metadata::DynamicMetadata, Endpoint, EndpointAddress}, + ClusterMap, +}; /// The input arguments to [`Filter::read`]. #[non_exhaustive] pub struct ReadContext { /// The upstream endpoints that the packet will be forwarded to. - pub endpoints: Vec, + pub endpoints: Arc, + /// The upstream endpoints that the packet will be forwarded to. + pub destinations: Vec, /// The source of the received packet. pub source: EndpointAddress, /// Contents of the received packet. @@ -33,9 +40,10 @@ pub struct ReadContext { impl ReadContext { /// Creates a new [`ReadContext`]. - pub fn new(endpoints: Vec, source: EndpointAddress, contents: Vec) -> Self { + pub fn new(endpoints: Arc, source: EndpointAddress, contents: Vec) -> Self { Self { endpoints, + destinations: Vec::new(), source, contents, metadata: DynamicMetadata::new(), diff --git a/src/filters/registry.rs b/src/filters/registry.rs index 0228838ccf..d5be9cb1a4 100644 --- a/src/filters/registry.rs +++ b/src/filters/registry.rs @@ -104,9 +104,10 @@ mod tests { let addr: EndpointAddress = (Ipv4Addr::LOCALHOST, 8080).into(); let endpoint = Endpoint::new(addr.clone()); + let endpoints = crate::net::cluster::ClusterMap::new_default([endpoint.clone()].into()); assert!(filter .read(&mut ReadContext::new( - vec![endpoint.clone()], + endpoints.into(), addr.clone(), vec![] )) diff --git a/src/filters/timestamp.rs b/src/filters/timestamp.rs index b1253c09c7..c14b588cf9 100644 --- a/src/filters/timestamp.rs +++ b/src/filters/timestamp.rs @@ -169,7 +169,7 @@ mod tests { const TIMESTAMP_KEY: &str = "BASIC"; let filter = Timestamp::from_config(Config::new(TIMESTAMP_KEY).into()); let mut ctx = ReadContext::new( - vec![], + <_>::default(), (std::net::Ipv4Addr::UNSPECIFIED, 0).into(), b"hello".to_vec(), ); @@ -200,7 +200,7 @@ mod tests { let timestamp = Timestamp::from_config(Config::new(TIMESTAMP_KEY).into()); let source = (std::net::Ipv4Addr::UNSPECIFIED, 0); let mut ctx = ReadContext::new( - vec![], + <_>::default(), source.into(), [0, 0, 0, 0, 99, 81, 55, 181].to_vec(), ); diff --git a/src/filters/token_router.rs b/src/filters/token_router.rs index a4a1959f3f..b07a5ecff6 100644 --- a/src/filters/token_router.rs +++ b/src/filters/token_router.rs @@ -54,7 +54,7 @@ impl Filter for TokenRouter { async fn read(&self, ctx: &mut ReadContext) -> Result<(), FilterError> { match ctx.metadata.get(&self.config.metadata_key) { Some(metadata::Value::Bytes(token)) => { - ctx.endpoints.retain(|endpoint| { + ctx.destinations = ctx.endpoints.filter_endpoints(|endpoint| { if endpoint.metadata.known.tokens.contains(&**token) { tracing::trace!(%endpoint.address, token = &*crate::codec::base64::encode(token), "Endpoint matched"); true @@ -63,7 +63,7 @@ impl Filter for TokenRouter { } }); - if ctx.endpoints.is_empty() { + if ctx.destinations.is_empty() { Err(FilterError::new(Error::NoEndpointMatch( self.config.metadata_key, crate::codec::base64::encode(token), @@ -257,8 +257,10 @@ mod tests { }, ); + let endpoints = crate::net::cluster::ClusterMap::default(); + endpoints.insert_default([endpoint1, endpoint2].into()); ReadContext::new( - vec![endpoint1, endpoint2], + endpoints.into(), "127.0.0.1:100".parse().unwrap(), b"hello".to_vec(), ) diff --git a/src/net/cluster.rs b/src/net/cluster.rs index f5d90ec4cb..47e3e4f47a 100644 --- a/src/net/cluster.rs +++ b/src/net/cluster.rs @@ -127,6 +127,35 @@ impl ClusterMap { self.0.iter() } + #[cfg(test)] + pub fn endpoints(&self) -> impl Iterator + '_ { + self.0 + .iter() + .flat_map(|entry| entry.value().iter().cloned().collect::>()) + } + + pub fn nth_endpoint(&self, mut index: usize) -> Option { + for set in self.iter() { + if index < set.len() { + return set.value().iter().nth(index).cloned(); + } else { + index -= set.len(); + } + } + + None + } + + pub fn filter_endpoints(&self, f: impl Fn(&Endpoint) -> bool) -> Vec { + let mut endpoints = Vec::new(); + + for set in self.iter() { + endpoints.extend(set.iter().filter(|e| (f)(e)).cloned()); + } + + endpoints + } + pub fn entry( &self, key: Option, @@ -138,10 +167,12 @@ impl ClusterMap { self.entry(None).or_default() } - pub fn endpoints(&self) -> impl Iterator + '_ { - self.0 - .iter() - .flat_map(|entry| entry.value().iter().cloned().collect::>()) + pub fn num_of_endpoints(&self) -> usize { + self.0.iter().map(|entry| entry.value().len()).sum() + } + + pub fn has_endpoints(&self) -> bool { + self.num_of_endpoints() != 0 } pub fn update_unlocated_endpoints(&self, locality: Locality) { diff --git a/src/test.rs b/src/test.rs index e9bfa5da84..1c7337592e 100644 --- a/src/test.rs +++ b/src/test.rs @@ -307,13 +307,17 @@ pub async fn assert_filter_read_no_change(filter: &F) where F: Filter, { - let endpoints = vec!["127.0.0.1:80".parse::().unwrap()]; + let endpoints = std::sync::Arc::new(crate::net::cluster::ClusterMap::default()); + endpoints.insert_default(std::collections::BTreeSet::from(["127.0.0.1:80" + .parse::() + .unwrap()])); let source = "127.0.0.1:90".parse().unwrap(); let contents = "hello".to_string().into_bytes(); let mut context = ReadContext::new(endpoints.clone(), source, contents.clone()); filter.read(&mut context).await.unwrap(); - assert_eq!(endpoints, &*context.endpoints); + assert!(context.destinations.is_empty()); + assert_eq!(endpoints, context.endpoints); assert_eq!(contents, &*context.contents); }