From 16369342fe1b252ec1d5b46c26c213a7e097cbef Mon Sep 17 00:00:00 2001 From: Mark Mandel Date: Wed, 11 Nov 2020 13:25:47 -0800 Subject: [PATCH] EndpointAuthentication filter Implementation of the filter that will route packets to Endpoints that have a matching connection_id to the auth token found in the dynamic metadata. Closes #8 --- docs/extensions/filters/capture_bytes.md | 2 +- .../filters/endpoint_authentication.md | 70 +++++ docs/extensions/filters/filters.md | 3 +- src/config/mod.rs | 6 + .../endpoint_authentication/metrics.rs | 38 +++ .../filters/endpoint_authentication/mod.rs | 275 ++++++++++++++++++ src/extensions/filters/mod.rs | 2 + src/extensions/mod.rs | 1 + src/lib.rs | 1 + tests/endpoint_authentication.rs | 95 ++++++ 10 files changed, 491 insertions(+), 2 deletions(-) create mode 100644 docs/extensions/filters/endpoint_authentication.md create mode 100644 src/extensions/filters/endpoint_authentication/metrics.rs create mode 100644 src/extensions/filters/endpoint_authentication/mod.rs create mode 100644 tests/endpoint_authentication.rs diff --git a/docs/extensions/filters/capture_bytes.md b/docs/extensions/filters/capture_bytes.md index fff32f2f9d..2483f5be23 100644 --- a/docs/extensions/filters/capture_bytes.md +++ b/docs/extensions/filters/capture_bytes.md @@ -6,7 +6,7 @@ down the chain. This is often used as a way of retrieving authentication tokens from a packet, and used in combination with [ConcatenateBytes](./concatenate_bytes.md) and -`[[TODO: add router filter name when ready]]` filter to provide common packet routing utilities. +[EndpointAuthentication](endpoint_authentication.md) filter to provide common packet routing utilities. #### Filter name ```text diff --git a/docs/extensions/filters/endpoint_authentication.md b/docs/extensions/filters/endpoint_authentication.md new file mode 100644 index 0000000000..2377f5023f --- /dev/null +++ b/docs/extensions/filters/endpoint_authentication.md @@ -0,0 +1,70 @@ +# EndpointAuthentication + +The `EndpointAuthentication` filter's job is to ensure only authorised clients are able to send packets to Endpoints that +they have access to. + +It does this via matching an authentication token found in the +[Filter dynamic metadata]`(TODO: add link to dynamic metadata docs)`, and comparing it to Endpoint's connection_id +values, and only letting packets through to those Endpoints if there is a match. + +Capturing the authentication token from an incoming packet can be implemented via the [CaptureByte](./capture_bytes.md) +filter, with an example outlined below, or any other filter that populates the configured dynamic metadata key for the +authentication token to reside. + +On the game client side the [ConcatenateBytes](./concatenate_bytes.md) filter can be used to add authentication tokens +to outgoing packets. + +#### Filter name +```text +quilkin.extensions.filters.endpoint_authentication.v1alpha1.EndpointAuthentication +``` + +### Configuration Examples +```rust +# let yaml = " +local: + port: 7000 +filters: + - name: quilkin.extensions.filters.capture_bytes.v1alpha1.CaptureBytes # This filter is often used in conjunction to capture the authentication token + config: + metadataKey: myapp.com/myownkey + size: 3 + remove: true + - name: quilkin.extensions.filters.endpoint_authentication.v1alpha1.EndpointAuthentication + config: + metadataKey: myapp.com/myownkey +server: + endpoints: + - name: Game Server No. 1 + address: 127.0.0.1:26000 + connection_ids: + - MXg3aWp5Ng== # Authentication is provided by these ids, and matched against + - OGdqM3YyaQ== # the value stored in Filter dynamic metadata + - name: Game Server No. 2 + address: 127.0.0.1:26001 + connection_ids: + - bmt1eTcweA== +# "; +# let config = quilkin::config::Config::from_reader(yaml.as_bytes()).unwrap(); +# assert_eq!(config.filters.len(), 2); +# quilkin::proxy::Builder::from(std::sync::Arc::new(config)).validate().unwrap(); +``` + +View the [CaptureBytes](./capture_bytes.md) filter documentation for more details. + +### Configuration Options + +```yaml +properties: + metadataKey: + type: string + default: quilkin.dev/captured_bytes + description: | + The key under which the captured bytes are stored in the Filter invocation values. +``` + +### Metrics + +* `quilkin_filter_EndpointAuthentication_packets_dropped` + A counter of the total number of packets that have been dropped as they could not be authenticated against an + Endpoint. diff --git a/docs/extensions/filters/filters.md b/docs/extensions/filters/filters.md index 872f1fa68d..620934495b 100644 --- a/docs/extensions/filters/filters.md +++ b/docs/extensions/filters/filters.md @@ -66,7 +66,8 @@ Quilkin includes several filters out of the box. | [Debug](debug.md) | Logs every packet | | [LocalRateLimiter](./local_rate_limit.md) | Limit the frequency of packets. | | [ConcatenateBytes](./concatenate_bytes.md) | Add authentication tokens to packets. | -| [CaptureBytes](capture_bytes.md) | Capture bytes from a packet into the Filter Context. | +| [CaptureBytes](capture_bytes.md) | Capture specific bytes from a packet and store them in filter dynamic metadata. | +| [EndpointAuthentication](endpoint_authentication.md) | Only sends packets to Endpoints that they are authenticated to access. | ### FilterConfig Represents configuration for a filter instance. diff --git a/src/config/mod.rs b/src/config/mod.rs index 0bfd39f973..e34470808a 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -105,6 +105,12 @@ impl From<&str> for ConnectionId { } } +impl PartialEq> for ConnectionId { + fn eq(&self, other: &Vec) -> bool { + self.0.eq(other) + } +} + /// ConnectionConfig is the configuration for either a Client or Server proxy #[derive(Debug, Deserialize, Serialize)] pub enum ConnectionConfig { diff --git a/src/extensions/filters/endpoint_authentication/metrics.rs b/src/extensions/filters/endpoint_authentication/metrics.rs new file mode 100644 index 0000000000..ce2d863411 --- /dev/null +++ b/src/extensions/filters/endpoint_authentication/metrics.rs @@ -0,0 +1,38 @@ +/* + * Copyright 2020 Google LLC All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +use prometheus::core::{AtomicI64, GenericCounter}; +use prometheus::Result as MetricsResult; +use prometheus::{IntCounter, Registry}; + +use crate::metrics::{filter_opts, CollectorExt}; + +/// Register and manage metrics for this filter +pub(super) struct Metrics { + pub(super) packets_dropped_total: GenericCounter, +} + +impl Metrics { + pub(super) fn new(registry: &Registry) -> MetricsResult { + Ok(Metrics { + packets_dropped_total: IntCounter::with_opts(filter_opts( + "packets_dropped", + "EndpointAuthentication", + "Total number of packets dropped due to invalid connection_id values.", + ))? + .register(registry)?, + }) + } +} diff --git a/src/extensions/filters/endpoint_authentication/mod.rs b/src/extensions/filters/endpoint_authentication/mod.rs new file mode 100644 index 0000000000..79891bfb3a --- /dev/null +++ b/src/extensions/filters/endpoint_authentication/mod.rs @@ -0,0 +1,275 @@ +/* + * Copyright 2020 Google LLC All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +use serde::{Deserialize, Serialize}; +use slog::{error, o, Logger}; + +use crate::extensions::filters::endpoint_authentication::metrics::Metrics; +use crate::extensions::filters::CAPTURED_BYTES; +use crate::extensions::{ + CreateFilterArgs, DownstreamContext, DownstreamResponse, Error, Filter, FilterFactory, + UpstreamContext, UpstreamResponse, +}; + +mod metrics; + +#[derive(Serialize, Deserialize, Debug)] +struct Config { + /// the key to use when retrieving the captured bytes in the filter context + #[serde(rename = "metadataKey")] + #[serde(default = "default_metadata_key")] + metadata_key: String, +} + +/// default value for the context key in the Config +fn default_metadata_key() -> String { + CAPTURED_BYTES.into() +} + +impl Default for Config { + fn default() -> Self { + Self { + metadata_key: default_metadata_key(), + } + } +} + +struct EndpointAuthentication { + log: Logger, + values_key: String, + metrics: Metrics, +} + +/// Factory for the EndpointAuthentication filter that only allows packets to be passed to Endpoints that have a matching +/// connection_id to the token stored in the Filter's dynamic metadata. +pub struct EndpointAuthenticationFactory { + log: Logger, +} + +impl EndpointAuthenticationFactory { + pub fn new(base: &Logger) -> Self { + EndpointAuthenticationFactory { log: base.clone() } + } +} + +impl FilterFactory for EndpointAuthenticationFactory { + fn name(&self) -> String { + "quilkin.extensions.filters.endpoint_authentication.v1alpha1.EndpointAuthentication".into() + } + + fn create_filter(&self, args: CreateFilterArgs) -> Result, Error> { + let config: Option = serde_yaml::to_string(&args.config) + .and_then(|raw_config| serde_yaml::from_str(raw_config.as_str())) + .map_err(|err| Error::DeserializeFailed(err.to_string()))?; + + Ok(Box::new(EndpointAuthentication::new( + &self.log, + config.unwrap_or_default(), + Metrics::new(&args.metrics_registry)?, + ))) + } +} + +impl EndpointAuthentication { + fn new(base: &Logger, config: Config, metrics: Metrics) -> Self { + Self { + log: base.new(o!("source" => "extensions::EndpointAuthentication")), + values_key: config.metadata_key, + metrics, + } + } +} + +impl Filter for EndpointAuthentication { + fn on_downstream_receive(&self, mut ctx: DownstreamContext) -> Option { + match ctx.metadata.get(self.values_key.as_str()) { + None => { + error!(self.log, "Value key not found in DownstreamContext"; "key" => self.values_key.clone()); + self.metrics.packets_dropped_total.inc(); + None + } + Some(value) => match value.downcast_ref::>() { + Some(connection_id) => { + ctx.endpoints + .retain(|e| e.connection_ids.iter().any(|id| id == connection_id)); + if ctx.endpoints.is_empty() { + self.metrics.packets_dropped_total.inc(); + return None; + } + Some(ctx.into()) + } + None => { + error!(self.log, "Type of value stored in DownstreamContext.values is not Vec"; + "key" => self.values_key.clone()); + self.metrics.packets_dropped_total.inc(); + None + } + }, + } + } + fn on_upstream_receive(&self, ctx: UpstreamContext) -> Option { + Some(ctx.into()) + } +} + +#[cfg(test)] +mod tests { + use std::ops::Deref; + + use prometheus::Registry; + use serde_yaml::{Mapping, Value}; + + use crate::config::{ConnectionConfig, ConnectionId, EndPoint}; + use crate::test_utils::{assert_filter_on_upstream_receive_no_change, logger}; + + use super::*; + + const TOKEN_KEY: &str = "TOKEN"; + + fn router(config: Config) -> EndpointAuthentication { + EndpointAuthentication::new( + &logger(), + config, + Metrics::new(&Registry::default()).unwrap(), + ) + } + + #[test] + fn factory_custom_tokens() { + let factory = EndpointAuthenticationFactory::new(&logger()); + let connection = ConnectionConfig::Server { endpoints: vec![] }; + let mut map = Mapping::new(); + map.insert( + Value::String("metadataKey".into()), + Value::String(TOKEN_KEY.into()), + ); + + let filter = factory + .create_filter(CreateFilterArgs::new( + &connection, + Some(&Value::Mapping(map)), + )) + .unwrap(); + let mut ctx = new_ctx(); + ctx.metadata + .insert(TOKEN_KEY.into(), Box::new(b"123".to_vec())); + assert_on_downstream_receive(filter.deref(), ctx); + } + + #[test] + fn factory_empty_config() { + let factory = EndpointAuthenticationFactory::new(&logger()); + let connection = ConnectionConfig::Server { endpoints: vec![] }; + let map = Mapping::new(); + + let filter = factory + .create_filter(CreateFilterArgs::new( + &connection, + Some(&Value::Mapping(map)), + )) + .unwrap(); + let mut ctx = new_ctx(); + ctx.metadata + .insert(CAPTURED_BYTES.into(), Box::new(b"123".to_vec())); + assert_on_downstream_receive(filter.deref(), ctx); + } + + #[test] + fn factory_no_config() { + let factory = EndpointAuthenticationFactory::new(&logger()); + let connection = ConnectionConfig::Server { endpoints: vec![] }; + + let filter = factory + .create_filter(CreateFilterArgs::new(&connection, None)) + .unwrap(); + let mut ctx = new_ctx(); + ctx.metadata + .insert(CAPTURED_BYTES.into(), Box::new(b"123".to_vec())); + assert_on_downstream_receive(filter.deref(), ctx); + } + + #[test] + fn downstream_receive() { + // valid key + let config = Config { + metadata_key: CAPTURED_BYTES.into(), + }; + let filter = router(config); + + let mut ctx = new_ctx(); + ctx.metadata + .insert(CAPTURED_BYTES.into(), Box::new(b"123".to_vec())); + assert_on_downstream_receive(&filter, ctx); + + // invalid key + let mut ctx = new_ctx(); + ctx.metadata + .insert(CAPTURED_BYTES.into(), Box::new(b"567".to_vec())); + assert!(filter.on_downstream_receive(ctx).is_none()); + assert_eq!(1, filter.metrics.packets_dropped_total.get()); + + // no key + let ctx = new_ctx(); + assert!(filter.on_downstream_receive(ctx).is_none()); + assert_eq!(2, filter.metrics.packets_dropped_total.get()); + + // wrong type key + let mut ctx = new_ctx(); + ctx.metadata + .insert(CAPTURED_BYTES.into(), Box::new(String::from("wrong"))); + assert!(filter.on_downstream_receive(ctx).is_none()); + assert_eq!(3, filter.metrics.packets_dropped_total.get()); + } + + #[test] + fn on_upstream_receive() { + let config = Config { + metadata_key: CAPTURED_BYTES.into(), + }; + let filter = router(config); + assert_filter_on_upstream_receive_no_change(&filter); + } + + fn new_ctx() -> DownstreamContext { + let endpoint1 = EndPoint::new( + "one".into(), + "127.0.0.1:80".parse().unwrap(), + vec![ConnectionId::from("123")], + ); + let endpoint2 = EndPoint::new( + "two".into(), + "127.0.0.1:90".parse().unwrap(), + vec![ConnectionId::from("456")], + ); + + DownstreamContext::new( + vec![endpoint1, endpoint2], + "127.0.0.1:100".parse().unwrap(), + b"hello".to_vec(), + ) + } + + fn assert_on_downstream_receive(filter: &F, ctx: DownstreamContext) + where + F: Filter + ?Sized, + { + let result = filter.on_downstream_receive(ctx).unwrap(); + + assert_eq!(b"hello".to_vec(), result.contents); + assert_eq!(1, result.endpoints.len()); + assert_eq!("one", result.endpoints[0].name); + } +} diff --git a/src/extensions/filters/mod.rs b/src/extensions/filters/mod.rs index dc0ca864c9..4d19051c80 100644 --- a/src/extensions/filters/mod.rs +++ b/src/extensions/filters/mod.rs @@ -17,12 +17,14 @@ pub use capture_bytes::CaptureBytesFactory; pub use concatenate_bytes::ConcatBytesFactory; pub use debug::DebugFactory; +pub use endpoint_authentication::EndpointAuthenticationFactory; pub use load_balancer::LoadBalancerFilterFactory; pub use local_rate_limit::RateLimitFilterFactory; mod capture_bytes; mod concatenate_bytes; mod debug; +mod endpoint_authentication; mod load_balancer; mod local_rate_limit; diff --git a/src/extensions/mod.rs b/src/extensions/mod.rs index 619a7993cb..854c3fde7c 100644 --- a/src/extensions/mod.rs +++ b/src/extensions/mod.rs @@ -37,5 +37,6 @@ pub fn default_registry(base: &Logger) -> FilterRegistry { fr.insert(filters::ConcatBytesFactory::default()); fr.insert(filters::LoadBalancerFilterFactory::default()); fr.insert(filters::CaptureBytesFactory::new(base)); + fr.insert(filters::EndpointAuthenticationFactory::new(base)); fr } diff --git a/src/lib.rs b/src/lib.rs index 427eed6f69..bd15093e54 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -40,4 +40,5 @@ pub mod external_doc_tests { #![doc(include = "../docs/extensions/filters/debug.md")] #![doc(include = "../docs/extensions/filters/concatenate_bytes.md")] #![doc(include = "../docs/extensions/filters/capture_bytes.md")] + #![doc(include = "../docs/extensions/filters/endpoint_authentication.md")] } diff --git a/tests/endpoint_authentication.rs b/tests/endpoint_authentication.rs new file mode 100644 index 0000000000..f47d6c05e6 --- /dev/null +++ b/tests/endpoint_authentication.rs @@ -0,0 +1,95 @@ +/* + * Copyright 2020 Google LLC All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#[cfg(test)] +mod tests { + use std::net::{IpAddr, Ipv4Addr, SocketAddr}; + + use slog::debug; + use tokio::select; + use tokio::time::{delay_for, Duration}; + + use quilkin::config::{Builder, ConnectionConfig, ConnectionId, EndPoint, Filter, Local}; + use quilkin::extensions::filters::{CaptureBytesFactory, EndpointAuthenticationFactory}; + use quilkin::extensions::FilterFactory; + use quilkin::test_utils::{logger, TestHelper}; + + /// This test covers both endpoint_authentication and capture_bytes filters, + /// since they work in concert together. + #[tokio::test] + async fn endpoint_authentication() { + let log = logger(); + let mut t = TestHelper::default(); + let echo = t.run_echo_server().await; + + let capture_yaml = " +size: 3 +remove: true +"; + let server_port = 12348; + let server_config = Builder::empty() + .with_local(Local { port: server_port }) + .with_filters(vec![ + Filter { + name: CaptureBytesFactory::new(&log).name(), + config: serde_yaml::from_str(capture_yaml).unwrap(), + }, + Filter { + name: EndpointAuthenticationFactory::new(&log).name(), + config: None, + }, + ]) + .with_connections(ConnectionConfig::Server { + endpoints: vec![EndPoint { + name: "server".to_string(), + address: echo, + connection_ids: vec![ConnectionId::from("abc")], + }], + }) + .build(); + server_config.validate().unwrap(); + t.run_server(server_config); + + // valid packet + let (mut recv_chan, mut send) = t.open_socket_and_recv_multiple_packets().await; + + let local_addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), server_port); + let msg = b"helloabc"; + debug!(log, "sending message"; "content" => format!("{:?}", msg)); + send.send_to(msg, &local_addr).await.unwrap(); + + select! { + res = recv_chan.recv() => { + assert_eq!("hello", res.unwrap()); + } + _ = delay_for(Duration::from_secs(5)) => { + unreachable!("should have received a packet"); + } + }; + + // send an invalid packet + let msg = b"helloxyz"; + debug!(log, "sending message"; "content" => format!("{:?}", msg)); + send.send_to(msg, &local_addr).await.unwrap(); + + select! { + _ = recv_chan.recv() => { + unreachable!("should not have received a packet") + } + _ = delay_for(Duration::from_secs(3)) => {} + }; + } +}