Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reuse config deserialization logic across filters #182

Merged
merged 1 commit into from
Feb 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/extensions/filter_chain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ impl FilterChain {
for filter_config in config.source.get_filters() {
match filter_registry.get(
&filter_config.name,
CreateFilterArgs::new(filter_config.config.as_ref())
CreateFilterArgs::fixed(filter_config.config.as_ref())
.with_metrics_registry(metrics_registry.clone()),
) {
Ok(filter) => filters.push(filter),
Expand Down
61 changes: 49 additions & 12 deletions src/extensions/filter_registry.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

use bytes::Bytes;
use std::any::Any;
use std::collections::HashMap;
use std::fmt;
Expand Down Expand Up @@ -204,6 +205,7 @@ pub trait Filter: Send + Sync {
/// Error is an error when attempting to create a Filter from_config() from a FilterFactory
pub enum Error {
NotFound(String),
MissingConfig(String),
FieldInvalid { field: String, reason: String },
DeserializeFailed(String),
InitializeMetricsFailed(String),
Expand All @@ -213,6 +215,9 @@ impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Error::NotFound(key) => write!(f, "filter {} is not found", key),
Error::MissingConfig(filter_name) => {
write!(f, "filter `{}` requires a configuration", filter_name)
}
Error::FieldInvalid { field, reason } => {
write!(f, "field {} is invalid: {}", field, reason)
}
Expand All @@ -236,18 +241,47 @@ impl From<MetricsError> for Error {
}
}

pub enum ConfigType<'a> {
Static(&'a serde_yaml::Value),
Dynamic(prost_types::Any),
}

impl ConfigType<'_> {
/// Deserializes a config based on the input type.
pub fn deserialize<T, P>(self, filter_name: &str) -> Result<T, Error>
where
P: prost::Message + Default,
T: for<'de> serde::Deserialize<'de> + From<P>,
{
match self {
ConfigType::Static(config) => serde_yaml::to_string(config)
.and_then(|raw_config| serde_yaml::from_str(raw_config.as_str()))
.map_err(|err| Error::DeserializeFailed(err.to_string())),
ConfigType::Dynamic(config) => prost::Message::decode(Bytes::from(config.value))
.map(T::from)
.map_err(|err| {
Error::DeserializeFailed(format!(
"filter `{}`: config decode error: {}",
filter_name,
err.to_string()
))
}),
}
}
}

/// Arguments needed to create a new filter.
pub struct CreateFilterArgs<'a> {
/// Configuration for the filter.
pub config: Option<&'a serde_yaml::Value>,
pub config: Option<ConfigType<'a>>,
/// metrics_registry is used to register filter metrics collectors.
pub metrics_registry: Registry,
}

impl CreateFilterArgs<'_> {
pub fn new(config: Option<&serde_yaml::Value>) -> CreateFilterArgs {
pub fn fixed(config: Option<&serde_yaml::Value>) -> CreateFilterArgs {
Copy link
Contributor

@markmandel markmandel Feb 2, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shame we can't call it static - I even went to double check it, but it's a keyword!

CreateFilterArgs {
config,
config: config.map(|config| ConfigType::Static(config)),
metrics_registry: Registry::default(),
}
}
Expand All @@ -258,12 +292,6 @@ impl CreateFilterArgs<'_> {
..self
}
}

pub fn parse_config<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, Error> {
serde_yaml::to_string(&self.config)
.and_then(|raw_config| serde_yaml::from_str(raw_config.as_str()))
.map_err(|err| Error::DeserializeFailed(err.to_string()))
}
}

/// FilterFactory provides the name and creation function for a given Filter.
Expand All @@ -281,6 +309,15 @@ pub trait FilterFactory: Sync + Send {

/// Returns a filter based on the provided arguments.
fn create_filter(&self, args: CreateFilterArgs) -> Result<Box<dyn Filter>, Error>;

/// Returns the [`ConfigType`] from the provided Option, otherwise it returns
/// Error::MissingConfig if the Option is None.
fn require_config<'a, 'b>(
&'a self,
config: Option<ConfigType<'b>>,
) -> Result<ConfigType<'b>, Error> {
config.ok_or_else(|| Error::MissingConfig(self.name()))
}
}

/// FilterRegistry is the registry of all Filters that can be applied in the system.
Expand Down Expand Up @@ -335,17 +372,17 @@ mod tests {
let mut reg = FilterRegistry::default();
reg.insert(TestFilterFactory {});

match reg.get(&String::from("not.found"), CreateFilterArgs::new(None)) {
match reg.get(&String::from("not.found"), CreateFilterArgs::fixed(None)) {
Ok(_) => unreachable!("should not be filter"),
Err(err) => assert_eq!(Error::NotFound("not.found".to_string()), err),
};

assert!(reg
.get(&String::from("TestFilter"), CreateFilterArgs::new(None))
.get(&String::from("TestFilter"), CreateFilterArgs::fixed(None))
.is_ok());

let filter = reg
.get(&String::from("TestFilter"), CreateFilterArgs::new(None))
.get(&String::from("TestFilter"), CreateFilterArgs::fixed(None))
.unwrap();

let addr = SocketAddr::new(IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 8080);
Expand Down
16 changes: 12 additions & 4 deletions src/extensions/filters/capture_bytes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,17 @@ impl FilterFactory for CaptureBytesFactory {
}

fn create_filter(&self, args: CreateFilterArgs) -> Result<Box<dyn Filter>, Error> {
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TODO;
impl From<TODO> for Config {
fn from(_: TODO) -> Self {
unimplemented!()
}
}
Ok(Box::new(CaptureBytes::new(
&self.log,
args.parse_config()?,
self.require_config(args.config)?
.deserialize::<Config, TODO>(self.name().as_str())?,
Metrics::new(&args.metrics_registry)?,
)))
}
Expand Down Expand Up @@ -213,7 +221,7 @@ mod tests {
map.insert(Value::String("remove".into()), Value::Bool(true));

let filter = factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map))))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map))))
.unwrap();
assert_end_strategy(filter.as_ref(), TOKEN_KEY, true);
}
Expand All @@ -224,7 +232,7 @@ mod tests {
let mut map = Mapping::new();
map.insert(Value::String("size".into()), Value::Number(3.into()));
let filter = factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map))))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map))))
.unwrap();
assert_end_strategy(filter.as_ref(), CAPTURED_BYTES, false);
}
Expand All @@ -235,7 +243,7 @@ mod tests {
let mut map = Mapping::new();
map.insert(Value::String("size".into()), Value::String("WRONG".into()));

let result = factory.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map))));
let result = factory.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map))));
assert!(result.is_err(), "Should be an error");
}

Expand Down
14 changes: 11 additions & 3 deletions src/extensions/filters/compress/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,17 @@ impl FilterFactory for CompressFactory {
&self,
args: CreateFilterArgs,
) -> std::result::Result<Box<dyn Filter>, RegistryError> {
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TODO;
impl From<TODO> for Config {
fn from(_: TODO) -> Self {
unimplemented!()
}
}
Ok(Box::new(Compress::new(
&self.log,
args.parse_config()?,
self.require_config(args.config)?
.deserialize::<Config, TODO>(self.name().as_str())?,
Metrics::new(&args.metrics_registry)?,
)))
}
Expand Down Expand Up @@ -249,7 +257,7 @@ mod tests {
Value::String("DOWNSTREAM".into()),
);
let filter = factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map))))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map))))
.expect("should create a filter");
assert_downstream_direction(filter.as_ref());
}
Expand All @@ -265,7 +273,7 @@ mod tests {
Value::String("DOWNSTREAM".into()),
);
let config = Value::Mapping(map);
let args = CreateFilterArgs::new(Some(&config));
let args = CreateFilterArgs::fixed(Some(&config));

let filter = factory.create_filter(args).expect("should create a filter");
assert_downstream_direction(filter.as_ref());
Expand Down
22 changes: 16 additions & 6 deletions src/extensions/filters/concatenate_bytes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,17 @@ impl FilterFactory for ConcatBytesFactory {
}

fn create_filter(&self, args: CreateFilterArgs) -> Result<Box<dyn Filter>, Error> {
Ok(Box::new(ConcatenateBytes::new(args.parse_config()?)))
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TODO;
impl From<TODO> for Config {
fn from(_: TODO) -> Self {
unimplemented!()
}
}
Ok(Box::new(ConcatenateBytes::new(
self.require_config(args.config)?
.deserialize::<Config, TODO>(self.name().as_str())?,
)))
}
}

Expand Down Expand Up @@ -118,7 +128,7 @@ mod tests {
);

let filter = factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map.clone()))))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map.clone()))))
.unwrap();
assert_with_filter(filter.as_ref(), "abchello");

Expand All @@ -129,7 +139,7 @@ mod tests {
);

let filter = factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map.clone()))))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map.clone()))))
.unwrap();
assert_with_filter(filter.as_ref(), "abchello");

Expand All @@ -140,7 +150,7 @@ mod tests {
);

let filter = factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map))))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map))))
.unwrap();

assert_with_filter(filter.as_ref(), "helloabc");
Expand All @@ -152,7 +162,7 @@ mod tests {
let mut map = Mapping::new();

let result =
factory.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map.clone()))));
factory.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map.clone()))));
assert!(result.is_err());

// broken strategy
Expand All @@ -161,7 +171,7 @@ mod tests {
Value::String("WRONG".into()),
);

let result = factory.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map))));
let result = factory.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map))));
assert!(result.is_err());
}

Expand Down
19 changes: 14 additions & 5 deletions src/extensions/filters/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ mod quilkin {
pub(crate) mod filters {
pub(crate) mod debug {
pub(crate) mod v1alpha1 {
#![cfg(not(doctest))]
#![doc(hidden)]
tonic::include_proto!("quilkin.extensions.filters.debug.v1alpha1");
}
}
}
}
}
use self::quilkin::extensions::filters::debug::v1alpha1::Debug as ProtoDebug;

/// Debug logs all incoming and outgoing packets
///
Expand Down Expand Up @@ -80,6 +80,12 @@ struct Config {
id: Option<String>,
}

impl From<ProtoDebug> for Config {
fn from(p: ProtoDebug) -> Self {
Config { id: p.id }
}
}

/// Factory for the Debug
pub struct DebugFactory {
log: Logger,
Expand All @@ -97,7 +103,10 @@ impl FilterFactory for DebugFactory {
}

fn create_filter(&self, args: CreateFilterArgs) -> Result<Box<dyn Filter>, Error> {
let config: Option<Config> = args.parse_config()?;
let config: Option<Config> = args
.config
.map(|config| config.deserialize::<Config, ProtoDebug>(self.name().as_str()))
.transpose()?;
Ok(Box::new(Debug::new(
&self.log,
config.and_then(|cfg| cfg.id),
Expand Down Expand Up @@ -161,7 +170,7 @@ mod tests {

map.insert(Value::from("id"), Value::from("name"));
assert!(factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map)),))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map)),))
.is_ok());
}

Expand All @@ -173,7 +182,7 @@ mod tests {

map.insert(Value::from("id"), Value::from("name"));
assert!(factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map)),))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map)),))
.is_ok());
}

Expand All @@ -185,7 +194,7 @@ mod tests {

map.insert(Value::from("id"), Value::Sequence(vec![]));
assert!(factory
.create_filter(CreateFilterArgs::new(Some(&Value::Mapping(map))))
.create_filter(CreateFilterArgs::fixed(Some(&Value::Mapping(map))))
.is_err());
}
}
13 changes: 11 additions & 2 deletions src/extensions/filters/load_balancer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,16 @@ impl FilterFactory for LoadBalancerFilterFactory {
}

fn create_filter(&self, args: CreateFilterArgs) -> Result<Box<dyn Filter>, Error> {
let config: Config = args.parse_config()?;
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TODO;
impl From<TODO> for Config {
fn from(_: TODO) -> Self {
unimplemented!()
}
}
let config: Config = self
.require_config(args.config)?
.deserialize::<Config, TODO>(self.name().as_str())?;

let endpoint_chooser: Box<dyn EndpointChooser> = match config.policy {
Policy::RoundRobin => Box::new(RoundRobinEndpointChooser::new()),
Expand Down Expand Up @@ -137,7 +146,7 @@ mod tests {
fn create_filter(config: &str) -> Box<dyn Filter> {
let factory = LoadBalancerFilterFactory;
factory
.create_filter(CreateFilterArgs::new(Some(
.create_filter(CreateFilterArgs::fixed(Some(
&serde_yaml::from_str(config).unwrap(),
)))
.unwrap()
Expand Down
11 changes: 10 additions & 1 deletion src/extensions/filters/local_rate_limit/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,16 @@ impl FilterFactory for RateLimitFilterFactory {
}

fn create_filter(&self, args: CreateFilterArgs) -> Result<Box<dyn Filter>, Error> {
let config: Config = args.parse_config()?;
#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TODO;
impl From<TODO> for Config {
fn from(_: TODO) -> Self {
unimplemented!()
}
}
let config: Config = self
.require_config(args.config)?
.deserialize::<Config, TODO>(self.name().as_str())?;

match config.period {
Some(period) if period.lt(&Duration::from_millis(100)) => Err(Error::FieldInvalid {
Expand Down
Loading