Skip to content

Commit

Permalink
Added ConnectionConfig to CreateFilterArgs
Browse files Browse the repository at this point in the history
Some filters will need to be aware of whether they are client/server,
and also authentication connection details.

To provide this functionality, this PR added ConnectionConfig to the
CreateFilterArgs, so it can be passed into FilterFactory's when needed.

Work on #1
  • Loading branch information
markmandel committed Aug 26, 2020
1 parent b8e6b7f commit dc8eadf
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 26 deletions.
4 changes: 2 additions & 2 deletions src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ pub struct Local {

/// LoadBalancerPolicy represents how a proxy load-balances
/// traffic between endpoints.
#[derive(Debug, Deserialize, Serialize, Eq, PartialEq)]
#[derive(Debug, Deserialize, Serialize, Eq, PartialEq, Clone)]
pub enum LoadBalancerPolicy {
/// Send all traffic to all endpoints.
#[serde(rename = "BROADCAST")]
Expand Down Expand Up @@ -88,7 +88,7 @@ impl From<&str> for ConnectionId {
}

/// ConnectionConfig is the configuration for either a Client or Server proxy
#[derive(Debug, Deserialize, Serialize)]
#[derive(Debug, Deserialize, Serialize, Clone)]
pub enum ConnectionConfig {
/// Client is the configuration for a client proxy, for sitting behind a game client.
#[serde(rename = "client")]
Expand Down
2 changes: 1 addition & 1 deletion src/extensions/filter_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ impl FilterChain {
for filter_config in &config.filters {
match filter_registry.get(
&filter_config.name,
CreateFilterArgs::from_config(&filter_config.config)
CreateFilterArgs::new(&config.connections, &filter_config.config)
.with_metrics_registry(metrics_registry.clone()),
) {
Ok(filter) => filters.push(filter),
Expand Down
25 changes: 15 additions & 10 deletions src/extensions/filter_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ use std::collections::HashMap;
use std::fmt;
use std::net::SocketAddr;

use prometheus::{Error as MetricsError, Registry};
use serde::export::Formatter;

use crate::config::EndPoint;
use prometheus::{Error as MetricsError, Registry};
use crate::config::{ConnectionConfig, EndPoint};

/// Filter is a trait for routing and manipulating packets.
pub trait Filter: Send + Sync {
Expand Down Expand Up @@ -81,18 +81,21 @@ impl From<MetricsError> for Error {
}

/// Arguments needed to create a new filter.
pub struct CreateFilterArgs<'a> {
pub struct CreateFilterArgs {
/// Configuration for the filter.
pub config: &'a serde_yaml::Value,
pub config: serde_yaml::Value,
/// metrics_registry is used to register filter metrics collectors.
pub metrics_registry: Registry,
/// connection is used to pass the connection configuration
pub connection: ConnectionConfig,
}

impl CreateFilterArgs<'_> {
pub fn from_config(config: &serde_yaml::Value) -> CreateFilterArgs {
impl CreateFilterArgs {
pub fn new(connection: &ConnectionConfig, config: &serde_yaml::Value) -> CreateFilterArgs {
CreateFilterArgs {
config,
config: config.clone(),
metrics_registry: Registry::default(),
connection: connection.clone(),
}
}

Expand Down Expand Up @@ -181,10 +184,12 @@ mod tests {
fn insert_and_get() {
let mut reg = FilterRegistry::default();
reg.insert(TestFilterFactory {});
let config = serde_yaml::Value::Null;
let connection = ConnectionConfig::Server { endpoints: vec![] };

match reg.get(
&String::from("not.found"),
CreateFilterArgs::from_config(&serde_yaml::Value::Null),
CreateFilterArgs::new(&connection, &config),
) {
Ok(_) => assert!(false, "should not be filter"),
Err(err) => assert_eq!(Error::NotFound("not.found".to_string()), err),
Expand All @@ -193,14 +198,14 @@ mod tests {
assert!(reg
.get(
&String::from("TestFilter"),
CreateFilterArgs::from_config(&serde_yaml::Value::Null)
CreateFilterArgs::new(&connection, &config)
)
.is_ok());

let filter = reg
.get(
&String::from("TestFilter"),
CreateFilterArgs::from_config(&serde_yaml::Value::Null),
CreateFilterArgs::new(&connection, &config),
)
.unwrap();

Expand Down
29 changes: 17 additions & 12 deletions src/extensions/filters/debug_filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ impl FilterFactory for DebugFilterFactory {

fn create_filter(&self, args: CreateFilterArgs) -> Result<Box<dyn Filter>, Error> {
// pull out the Option<&Value>
let prefix = match args.config {
let prefix = match &args.config {
serde_yaml::Value::Mapping(map) => map.get(&serde_yaml::Value::from("id")),
_ => None,
};
Expand Down Expand Up @@ -137,14 +137,16 @@ fn packet_to_string(contents: Vec<u8>) -> String {

#[cfg(test)]
mod tests {
use serde_yaml::Mapping;
use serde_yaml::Value;

use crate::config::ConnectionConfig::Server;
use crate::test_utils::{
assert_filter_on_downstream_receive_no_change, assert_filter_on_upstream_receive_no_change,
logger,
};

use super::*;
use serde_yaml::Mapping;
use serde_yaml::Value;

#[test]
fn on_downstream_receive() {
Expand All @@ -162,34 +164,37 @@ mod tests {
fn from_config_with_id() {
let log = logger();
let mut map = Mapping::new();
let provider = DebugFilterFactory::new(&log);
let connection = Server { endpoints: vec![] };
let factory = DebugFilterFactory::new(&log);

map.insert(Value::from("id"), Value::from("name"));
assert!(provider
.create_filter(CreateFilterArgs::from_config(&Value::Mapping(map),))
assert!(factory
.create_filter(CreateFilterArgs::new(&connection, &Value::Mapping(map),))
.is_ok());
}

#[test]
fn from_config_without_id() {
let log = logger();
let mut map = Mapping::new();
let provider = DebugFilterFactory::new(&log);
let connection = Server { endpoints: vec![] };
let factory = DebugFilterFactory::new(&log);

map.insert(Value::from("id"), Value::from("name"));
assert!(provider
.create_filter(CreateFilterArgs::from_config(&Value::Mapping(map),))
assert!(factory
.create_filter(CreateFilterArgs::new(&connection, &Value::Mapping(map),))
.is_ok());
}

#[test]
fn from_config_should_panic() {
fn from_config_should_error() {
let log = logger();
let mut map = Mapping::new();
let provider = DebugFilterFactory::new(&log);
let connection = Server { endpoints: vec![] };
let factory = DebugFilterFactory::new(&log);

map.insert(Value::from("id"), Value::from(false));
match provider.create_filter(CreateFilterArgs::from_config(&Value::Mapping(map))) {
match factory.create_filter(CreateFilterArgs::new(&connection, &Value::Mapping(map))) {
Ok(_) => assert!(false, "should be an error"),
Err(err) => {
assert_eq!(
Expand Down
2 changes: 1 addition & 1 deletion src/extensions/filters/local_rate_limit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ impl FilterFactory for RateLimitFilterFactory {
}

fn create_filter(&self, args: CreateFilterArgs) -> Result<Box<dyn Filter>, Error> {
let config: Config = serde_yaml::to_string(args.config)
let config: Config = serde_yaml::to_string(&args.config)
.and_then(|raw_config| serde_yaml::from_str(raw_config.as_str()))
.map_err(|err| Error::DeserializeFailed(err.to_string()))?;

Expand Down

0 comments on commit dc8eadf

Please sign in to comment.