Skip to content

Commit

Permalink
Reuse config deserialization logic across filters (#182)
Browse files Browse the repository at this point in the history
Updates CreateFilterArgs to contain either static or proto filter
config using an enum.
Add a function to deserialize the config enum, returning the static
version regardless of input type. This function requires that the
static config can be created from the proto type. Update all filters
with proto config type TODO as they do not yet have proto configs.
  • Loading branch information
iffyio authored Feb 2, 2021
1 parent 1bd9996 commit dbe56ea
Show file tree
Hide file tree
Showing 9 changed files with 142 additions and 39 deletions.
2 changes: 1 addition & 1 deletion src/extensions/filter_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl FilterChain {
for filter_config in config.source.get_filters() {
match filter_registry.get(
&filter_config.name,
CreateFilterArgs::new(filter_config.config.as_ref())
CreateFilterArgs::fixed(filter_config.config.as_ref())
.with_metrics_registry(metrics_registry.clone()),
) {
Ok(filter) => filters.push(filter),
Expand Down
61 changes: 49 additions & 12 deletions src/extensions/filter_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

use bytes::Bytes;
use std::any::Any;
use std::collections::HashMap;
use std::fmt;
Expand Down Expand Up @@ -204,6 +205,7 @@ pub trait Filter: Send + Sync {
/// Error is an error when attempting to create a Filter from_config() from a FilterFactory
pub enum Error {
NotFound(String),
MissingConfig(String),
FieldInvalid { field: String, reason: String },
DeserializeFailed(String),
InitializeMetricsFailed(String),
Expand All @@ -213,6 +215,9 @@ impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::NotFound(key) => write!(f, "filter {} is not found", key),
Error::MissingConfig(filter_name) => {
write!(f, "filter `{}` requires a configuration", filter_name)
}
Error::FieldInvalid { field, reason } => {
write!(f, "field {} is invalid: {}", field, reason)
}
Expand All @@ -236,18 +241,47 @@ impl From<MetricsError> for Error {
}
}

pub enum ConfigType<'a> {
Static(&'a serde_yaml::Value),
Dynamic(prost_types::Any),
}

impl ConfigType<'_> {
/// Deserializes a config based on the input type.
pub fn deserialize<T, P>(self, filter_name: &str) -> Result<T, Error>
where
P: prost::Message + Default,
T: for<'de> serde::Deserialize<'de> + From<P>,
{
match self {
ConfigType::Static(config) => serde_yaml::to_string(config)
.and_then(|raw_config| serde_yaml::from_str(raw_config.as_str()))
.map_err(|err| Error::DeserializeFailed(err.to_string())),
ConfigType::Dynamic(config) => prost::Message::decode(Bytes::from(config.value))
.map(T::from)
.map_err(|err| {
Error::DeserializeFailed(format!(
"filter `{}`: config decode error: {}",
filter_name,
err.to_string()
))
}),
}
}
}

/// Arguments needed to create a new filter.
pub struct CreateFilterArgs<'a> {
/// Configuration for the filter.
pub config: Option<&'a serde_yaml::Value>,
pub config: Option<ConfigType<'a>>,
/// metrics_registry is used to register filter metrics collectors.
pub metrics_registry: Registry,
}

impl CreateFilterArgs<'_> {
pub fn new(config: Option<&serde_yaml::Value>) -> CreateFilterArgs {
pub fn fixed(config: Option<&serde_yaml::Value>) -> CreateFilterArgs {
CreateFilterArgs {
config,
config: config.map(|config| ConfigType::Static(config)),
metrics_registry: Registry::default(),
}
}
Expand All @@ -258,12 +292,6 @@ impl CreateFilterArgs<'_> {
..self
}
}

pub fn parse_config<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, Error> {
serde_yaml::to_string(&self.config)
.and_then(|raw_config| serde_yaml::from_str(raw_config.as_str()))
.map_err(|err| Error::DeserializeFailed(err.to_string()))
}
}

/// FilterFactory provides the name and creation function for a given Filter.
Expand All @@ -281,6 +309,15 @@ pub trait FilterFactory: Sync + Send {

/// Returns a filter based on the provided arguments.
fn create_filter(&self, args: CreateFilterArgs) -> Result<Box<dyn Filter>, Error>;

/// Returns the [`ConfigType`] from the provided Option, otherwise it returns
/// Error::MissingConfig if the Option is None.
fn require_config<'a, 'b>(
&'a self,
config: Option<ConfigType<'b>>,
) -> Result<ConfigType<'b>, Error> {
config.ok_or_else(|| Error::MissingConfig(self.name()))
}
}

/// FilterRegistry is the registry of all Filters that can be applied in the system.
Expand Down Expand Up @@ -335,17 +372,17 @@ mod tests {
let mut reg = FilterRegistry::default();
reg.insert(TestFilterFactory {});

match reg.get(&String::from("not.found"), CreateFilterArgs::new(None)) {
match reg.get(&String::from("not.found"), CreateFilterArgs::fixed(None)) {
Ok(_) => unreachable!("should not be filter"),
Err(err) => assert_eq!(Error::NotFound("not.found".to_string()), err),
};

assert!(reg
.get(&String::from("TestFilter"), CreateFilterArgs::new(None))
.get(&String::from("TestFilter"), CreateFilterArgs::fixed(None))
.is_ok());

let filter = reg
.get(&String::from("TestFilter"), CreateFilterArgs::new(None))
.get(&String::from("TestFilter"), CreateFilterArgs::fixed(None))
.unwrap();

let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
Expand Down
16 changes: 12 additions & 4 deletions src/extensions/filters/capture_bytes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,17 @@ impl FilterFactory for CaptureBytesFactory {
}

fn create_filter(&self, args: CreateFilterArgs) -> Result<Box<dyn Filter>, Error> {
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TODO;
impl From<TODO> for Config {
fn from(_: TODO) -> Self {
unimplemented!()
}
}
Ok(Box::new(CaptureBytes::new(
&self.log,
args.parse_config()?,
self.require_config(args.config)?
.deserialize::<Config, TODO>(self.name().as_str())?,
Metrics::new(&args.metrics_registry)?,
)))
}
Expand Down Expand Up @@ -213,7 +221,7 @@ mod tests {
map.insert(Value::String("remove".into()), Value::Bool(true));

let filter = factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map))))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map))))
.unwrap();
assert_end_strategy(filter.as_ref(), TOKEN_KEY, true);
}
Expand All @@ -224,7 +232,7 @@ mod tests {
let mut map = Mapping::new();
map.insert(Value::String("size".into()), Value::Number(3.into()));
let filter = factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map))))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map))))
.unwrap();
assert_end_strategy(filter.as_ref(), CAPTURED_BYTES, false);
}
Expand All @@ -235,7 +243,7 @@ mod tests {
let mut map = Mapping::new();
map.insert(Value::String("size".into()), Value::String("WRONG".into()));

let result = factory.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map))));
let result = factory.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map))));
assert!(result.is_err(), "Should be an error");
}

Expand Down
14 changes: 11 additions & 3 deletions src/extensions/filters/compress/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,17 @@ impl FilterFactory for CompressFactory {
&self,
args: CreateFilterArgs,
) -> std::result::Result<Box<dyn Filter>, RegistryError> {
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TODO;
impl From<TODO> for Config {
fn from(_: TODO) -> Self {
unimplemented!()
}
}
Ok(Box::new(Compress::new(
&self.log,
args.parse_config()?,
self.require_config(args.config)?
.deserialize::<Config, TODO>(self.name().as_str())?,
Metrics::new(&args.metrics_registry)?,
)))
}
Expand Down Expand Up @@ -249,7 +257,7 @@ mod tests {
Value::String("DOWNSTREAM".into()),
);
let filter = factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map))))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map))))
.expect("should create a filter");
assert_downstream_direction(filter.as_ref());
}
Expand All @@ -265,7 +273,7 @@ mod tests {
Value::String("DOWNSTREAM".into()),
);
let config = Value::Mapping(map);
let args = CreateFilterArgs::new(Some(&config));
let args = CreateFilterArgs::fixed(Some(&config));

let filter = factory.create_filter(args).expect("should create a filter");
assert_downstream_direction(filter.as_ref());
Expand Down
22 changes: 16 additions & 6 deletions src/extensions/filters/concatenate_bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,17 @@ impl FilterFactory for ConcatBytesFactory {
}

fn create_filter(&self, args: CreateFilterArgs) -> Result<Box<dyn Filter>, Error> {
Ok(Box::new(ConcatenateBytes::new(args.parse_config()?)))
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TODO;
impl From<TODO> for Config {
fn from(_: TODO) -> Self {
unimplemented!()
}
}
Ok(Box::new(ConcatenateBytes::new(
self.require_config(args.config)?
.deserialize::<Config, TODO>(self.name().as_str())?,
)))
}
}

Expand Down Expand Up @@ -118,7 +128,7 @@ mod tests {
);

let filter = factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map.clone()))))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map.clone()))))
.unwrap();
assert_with_filter(filter.as_ref(), "abchello");

Expand All @@ -129,7 +139,7 @@ mod tests {
);

let filter = factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map.clone()))))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map.clone()))))
.unwrap();
assert_with_filter(filter.as_ref(), "abchello");

Expand All @@ -140,7 +150,7 @@ mod tests {
);

let filter = factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map))))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map))))
.unwrap();

assert_with_filter(filter.as_ref(), "helloabc");
Expand All @@ -152,7 +162,7 @@ mod tests {
let mut map = Mapping::new();

let result =
factory.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map.clone()))));
factory.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map.clone()))));
assert!(result.is_err());

// broken strategy
Expand All @@ -161,7 +171,7 @@ mod tests {
Value::String("WRONG".into()),
);

let result = factory.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map))));
let result = factory.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map))));
assert!(result.is_err());
}

Expand Down
19 changes: 14 additions & 5 deletions src/extensions/filters/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ mod quilkin {
pub(crate) mod filters {
pub(crate) mod debug {
pub(crate) mod v1alpha1 {
#![cfg(not(doctest))]
#![doc(hidden)]
tonic::include_proto!("quilkin.extensions.filters.debug.v1alpha1");
}
}
}
}
}
use self::quilkin::extensions::filters::debug::v1alpha1::Debug as ProtoDebug;

/// Debug logs all incoming and outgoing packets
///
Expand Down Expand Up @@ -80,6 +80,12 @@ struct Config {
id: Option<String>,
}

impl From<ProtoDebug> for Config {
fn from(p: ProtoDebug) -> Self {
Config { id: p.id }
}
}

/// Factory for the Debug
pub struct DebugFactory {
log: Logger,
Expand All @@ -97,7 +103,10 @@ impl FilterFactory for DebugFactory {
}

fn create_filter(&self, args: CreateFilterArgs) -> Result<Box<dyn Filter>, Error> {
let config: Option<Config> = args.parse_config()?;
let config: Option<Config> = args
.config
.map(|config| config.deserialize::<Config, ProtoDebug>(self.name().as_str()))
.transpose()?;
Ok(Box::new(Debug::new(
&self.log,
config.and_then(|cfg| cfg.id),
Expand Down Expand Up @@ -161,7 +170,7 @@ mod tests {

map.insert(Value::from("id"), Value::from("name"));
assert!(factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map)),))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map)),))
.is_ok());
}

Expand All @@ -173,7 +182,7 @@ mod tests {

map.insert(Value::from("id"), Value::from("name"));
assert!(factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map)),))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map)),))
.is_ok());
}

Expand All @@ -185,7 +194,7 @@ mod tests {

map.insert(Value::from("id"), Value::Sequence(vec![]));
assert!(factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map))))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map))))
.is_err());
}
}
13 changes: 11 additions & 2 deletions src/extensions/filters/load_balancer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,16 @@ impl FilterFactory for LoadBalancerFilterFactory {
}

fn create_filter(&self, args: CreateFilterArgs) -> Result<Box<dyn Filter>, Error> {
let config: Config = args.parse_config()?;
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TODO;
impl From<TODO> for Config {
fn from(_: TODO) -> Self {
unimplemented!()
}
}
let config: Config = self
.require_config(args.config)?
.deserialize::<Config, TODO>(self.name().as_str())?;

let endpoint_chooser: Box<dyn EndpointChooser> = match config.policy {
Policy::RoundRobin => Box::new(RoundRobinEndpointChooser::new()),
Expand Down Expand Up @@ -137,7 +146,7 @@ mod tests {
fn create_filter(config: &str) -> Box<dyn Filter> {
let factory = LoadBalancerFilterFactory;
factory
.create_filter(CreateFilterArgs::new(Some(
.create_filter(CreateFilterArgs::fixed(Some(
&serde_yaml::from_str(config).unwrap(),
)))
.unwrap()
Expand Down
11 changes: 10 additions & 1 deletion src/extensions/filters/local_rate_limit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,16 @@ impl FilterFactory for RateLimitFilterFactory {
}

fn create_filter(&self, args: CreateFilterArgs) -> Result<Box<dyn Filter>, Error> {
let config: Config = args.parse_config()?;
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TODO;
impl From<TODO> for Config {
fn from(_: TODO) -> Self {
unimplemented!()
}
}
let config: Config = self
.require_config(args.config)?
.deserialize::<Config, TODO>(self.name().as_str())?;

match config.period {
Some(period) if period.lt(&Duration::from_millis(100)) => Err(Error::FieldInvalid {
Expand Down
Loading

0 comments on commit dbe56ea

Please sign in to comment.