diff --git a/Cargo.lock b/Cargo.lock index bbbd89c1..301842c1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1486,6 +1486,7 @@ dependencies = [ "r2d2", "rand", "redis", + "redis-test", "reqwest", "rmp-serde", "rocksdb", @@ -2388,6 +2389,16 @@ dependencies = [ "url", ] +[[package]] +name = "redis-test" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a948b3cec9e4b1fedbb0f0788e79029fb1f641b6cfefb7a15d044f803854427" +dependencies = [ + "futures", + "redis", +] + [[package]] name = "redox_syscall" version = "0.4.1" diff --git a/limitador/Cargo.toml b/limitador/Cargo.toml index 103043df..2d6646c6 100644 --- a/limitador/Cargo.toml +++ b/limitador/Cargo.toml @@ -54,6 +54,7 @@ base64 = { version = "0.22", optional = true } [dev-dependencies] serial_test = "3.0" criterion = { version = "0.5.1", features = ["html_reports"] } +redis-test = { version = "0.4.0", features = ["aio"] } paste = "1" rand = "0.8" tempfile = "3.5.0" diff --git a/limitador/src/storage/keys.rs b/limitador/src/storage/keys.rs index 51b257c2..86110521 100644 --- a/limitador/src/storage/keys.rs +++ b/limitador/src/storage/keys.rs @@ -36,7 +36,7 @@ pub fn prefix_for_namespace(namespace: &str) -> String { } pub fn counter_from_counter_key(key: &str, limit: &Limit) -> Counter { - let mut counter = partial_counter_from_counter_key(key, limit.namespace().as_ref()); + let mut counter = partial_counter_from_counter_key(key); if !counter.update_to_limit(limit) { // this means some kind of data corruption _or_ most probably // an out of sync `impl PartialEq for Limit` vs `pub fn key_for_counter(counter: &Counter) -> String` @@ -49,11 +49,24 @@ pub fn counter_from_counter_key(key: &str, limit: &Limit) -> Counter { counter } -pub fn partial_counter_from_counter_key(key: &str, namespace: &str) -> Counter { - let offset = ",counter:".len(); - let start_pos_counter = prefix_for_namespace(namespace).len() + offset; - - let counter: Counter = serde_json::from_str(&key[start_pos_counter..]).unwrap(); +pub fn partial_counter_from_counter_key(key: &str) -> Counter { + let namespace_prefix = "namespace:"; + let counter_prefix = ",counter:"; + + // Find the start position of the counter portion + let start_pos_namespace = key + .find(namespace_prefix) + .expect("Namespace not found in the key"); + let start_pos_counter = key[start_pos_namespace..] + .find(counter_prefix) + .expect("Counter not found in the key") + + start_pos_namespace + + counter_prefix.len(); + + // Extract counter JSON substring and deserialize it + let counter_str = &key[start_pos_counter..]; + let counter: Counter = + serde_json::from_str(counter_str).expect("Failed to deserialize counter JSON"); counter } @@ -87,7 +100,7 @@ mod tests { let limit = Limit::new(namespace, 1, 1, vec!["req.method == 'GET'"], vec!["app_id"]); let counter = Counter::new(limit.clone(), HashMap::default()); let raw = key_for_counter(&counter); - assert_eq!(counter, partial_counter_from_counter_key(&raw, namespace)); + assert_eq!(counter, partial_counter_from_counter_key(&raw)); let prefix = prefix_for_namespace(namespace); assert_eq!(&raw[0..prefix.len()], &prefix); } diff --git a/limitador/src/storage/redis/counters_cache.rs b/limitador/src/storage/redis/counters_cache.rs index 603fc3ac..2f2b22ca 100644 --- a/limitador/src/storage/redis/counters_cache.rs +++ b/limitador/src/storage/redis/counters_cache.rs @@ -121,9 +121,10 @@ impl CountersCache { ); if let Some(ttl) = cache_ttl.checked_sub(ttl_margin) { if ttl > Duration::ZERO { - let value = CachedCounterValue::from(&counter, counter_val, cache_ttl); - let previous = self.cache.get_with(counter.clone(), || Arc::new(value)); - if previous.expired_at(now) { + let previous = self.cache.get_with(counter.clone(), || { + Arc::new(CachedCounterValue::from(&counter, counter_val, cache_ttl)) + }); + if previous.expired_at(now) || previous.value.value() < counter_val { previous.set_from_authority(&counter, counter_val, cache_ttl); } return previous; diff --git a/limitador/src/storage/redis/redis_cached.rs b/limitador/src/storage/redis/redis_cached.rs index c08989b5..de98ab06 100644 --- a/limitador/src/storage/redis/redis_cached.rs +++ b/limitador/src/storage/redis/redis_cached.rs @@ -4,21 +4,22 @@ use crate::storage::atomic_expiring_value::AtomicExpiringValue; use crate::storage::keys::*; use crate::storage::redis::counters_cache::{CountersCache, CountersCacheBuilder}; use crate::storage::redis::redis_async::AsyncRedisStorage; -use crate::storage::redis::scripts::VALUES_AND_TTLS; +use crate::storage::redis::scripts::{BATCH_UPDATE_COUNTERS, VALUES_AND_TTLS}; use crate::storage::redis::{ DEFAULT_FLUSHING_PERIOD_SEC, DEFAULT_MAX_CACHED_COUNTERS, DEFAULT_MAX_TTL_CACHED_COUNTERS_SEC, DEFAULT_RESPONSE_TIMEOUT_MS, DEFAULT_TTL_RATIO_CACHED_COUNTERS, }; use crate::storage::{AsyncCounterStorage, Authorization, StorageErr}; use async_trait::async_trait; -use redis::aio::ConnectionManager; +use redis::aio::{ConnectionLike, ConnectionManager}; use redis::{ConnectionInfo, RedisError}; use std::collections::{HashMap, HashSet}; +use std::future::Future; use std::str::FromStr; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant, SystemTime}; -use tracing::{error, warn}; +use tracing::{debug_span, error, warn, Instrument}; // This is just a first version. // @@ -39,7 +40,7 @@ use tracing::{error, warn}; // multiple times when it is not cached. pub struct CachedRedisStorage { - cached_counters: CountersCache, + cached_counters: Arc, batcher_counter_updates: Arc>>, async_redis_storage: AsyncRedisStorage, redis_conn_manager: ConnectionManager, @@ -226,52 +227,43 @@ impl CachedRedisStorage { ) .await?; + let cached_counters = CountersCacheBuilder::new() + .max_cached_counters(max_cached_counters) + .max_ttl_cached_counter(ttl_cached_counters) + .ttl_ratio_cached_counter(ttl_ratio_cached_counters) + .build(); + + let counters_cache = Arc::new(cached_counters); let partitioned = Arc::new(AtomicBool::new(false)); let async_redis_storage = AsyncRedisStorage::new_with_conn_manager(redis_conn_manager.clone()); - - let storage = async_redis_storage.clone(); let batcher: Arc>> = Arc::new(Mutex::new(Default::default())); - let p = Arc::clone(&partitioned); - let batcher_flusher = batcher.clone(); - let mut interval = tokio::time::interval(flushing_period); - tokio::spawn(async move { - loop { - if p.load(Ordering::Acquire) { - if storage.is_alive().await { - warn!("Partition to Redis resolved!"); - p.store(false, Ordering::Release); - } - } else { - let counters = { - let mut batch = batcher_flusher.lock().unwrap(); - std::mem::take(&mut *batch) - }; - let now = SystemTime::now(); - for (counter, delta) in counters { - let delta = delta.value_at(now); - if delta > 0 { - storage - .update_counter(&counter, delta) - .await - .or_else(|err| if err.is_transient() { Ok(()) } else { Err(err) }) - .expect("Unrecoverable Redis error!"); - } - } - } - interval.tick().await; - } - }); - let cached_counters = CountersCacheBuilder::new() - .max_cached_counters(max_cached_counters) - .max_ttl_cached_counter(ttl_cached_counters) - .ttl_ratio_cached_counter(ttl_ratio_cached_counters) - .build(); + { + let storage = async_redis_storage.clone(); + let counters_cache_clone = counters_cache.clone(); + let conn = redis_conn_manager.clone(); + let p = Arc::clone(&partitioned); + let batcher_flusher = batcher.clone(); + let mut interval = tokio::time::interval(flushing_period); + tokio::spawn(async move { + loop { + flush_batcher_and_update_counters( + conn.clone(), + batcher_flusher.clone(), + storage.is_alive(), + counters_cache_clone.clone(), + p.clone(), + ) + .await; + interval.tick().await; + } + }); + } Ok(Self { - cached_counters, + cached_counters: counters_cache, batcher_counter_updates: batcher, redis_conn_manager, async_redis_storage, @@ -395,10 +387,108 @@ impl CachedRedisStorageBuilder { } } +async fn update_counters( + redis_conn: &mut C, + counters_and_deltas: HashMap, +) -> Result, StorageErr> { + let redis_script = redis::Script::new(BATCH_UPDATE_COUNTERS); + let mut script_invocation = redis_script.prepare_invoke(); + + let mut res: Vec<(Counter, i64, i64)> = Vec::new(); + let now = SystemTime::now(); + for (counter, delta) in counters_and_deltas { + let delta = delta.value_at(now); + if delta > 0 { + script_invocation.key(key_for_counter(&counter)); + script_invocation.key(key_for_counters_of_limit(counter.limit())); + script_invocation.arg(counter.seconds()); + script_invocation.arg(delta); + // We need to store the counter in the actual order we are sending it to the script + res.push((counter, 0, 0)); + } + } + + let span = debug_span!("datastore"); + // The redis crate is not working with tables, thus the response will be a Vec of counter values + let script_res: Vec = script_invocation + .invoke_async(redis_conn) + .instrument(span) + .await?; + + // We need to update the values and ttls returned by redis + let counters_range = 0..res.len(); + let script_res_range = (0..script_res.len()).step_by(2); + + for (i, j) in counters_range.zip(script_res_range) { + let (_, val, ttl) = &mut res[i]; + *val = script_res[j]; + *ttl = script_res[j + 1]; + } + + Ok(res) +} + +async fn flush_batcher_and_update_counters( + mut redis_conn: C, + batcher: Arc>>, + storage_is_alive: impl Future, + cached_counters: Arc, + partitioned: Arc, +) { + if partitioned.load(Ordering::Acquire) { + if storage_is_alive.await { + warn!("Partition to Redis resolved!"); + partitioned.store(false, Ordering::Release); + } + } else { + let counters = { + let mut batch = batcher.lock().unwrap(); + std::mem::take(&mut *batch) + }; + + let time_start_update_counters = Instant::now(); + + let updated_counters = update_counters(&mut redis_conn, counters) + .await + .or_else(|err| { + if err.is_transient() { + partitioned.store(true, Ordering::Release); + Ok(Vec::new()) + } else { + Err(err) + } + }) + .expect("Unrecoverable Redis error!"); + + for (counter, value, ttl) in updated_counters { + cached_counters.insert( + counter, + Option::from(value), + ttl, + Duration::from_millis( + (Instant::now() - time_start_update_counters).as_millis() as u64 + ), + SystemTime::now(), + ); + } + } +} + #[cfg(test)] mod tests { + use crate::counter::Counter; + use crate::limit::Limit; + use crate::storage::atomic_expiring_value::AtomicExpiringValue; + use crate::storage::keys::{key_for_counter, key_for_counters_of_limit}; + use crate::storage::redis::counters_cache::{CountersCache, CountersCacheBuilder}; + use crate::storage::redis::redis_cached::{flush_batcher_and_update_counters, update_counters}; use crate::storage::redis::CachedRedisStorage; - use redis::ErrorKind; + use redis::{ErrorKind, Value}; + use redis_test::{MockCmd, MockRedisConnection}; + use std::collections::HashMap; + use std::sync::atomic::AtomicBool; + use std::sync::{Arc, Mutex}; + use std::time::{Duration, SystemTime}; #[tokio::test] async fn errs_on_bad_url() { @@ -415,4 +505,116 @@ mod tests { assert_eq!(error.kind(), ErrorKind::IoError); assert!(error.is_connection_refusal()) } + + #[tokio::test] + async fn batch_update_counters() { + let mut counters_and_deltas = HashMap::new(); + let counter = Counter::new( + Limit::new( + "test_namespace", + 10, + 60, + vec!["req.method == 'GET'"], + vec!["app_id"], + ), + Default::default(), + ); + + let expiring_value = + AtomicExpiringValue::new(1, SystemTime::now() + Duration::from_secs(60)); + + counters_and_deltas.insert(counter.clone(), expiring_value); + + let mock_response = Value::Bulk(vec![Value::Int(10), Value::Int(60)]); + + let mut mock_client = MockRedisConnection::new(vec![MockCmd::new( + redis::cmd("EVALSHA") + .arg("1e87383cf7dba2bd0f9972ed73671274e6cbd5da") + .arg("2") + .arg(key_for_counter(&counter)) + .arg(key_for_counters_of_limit(counter.limit())) + .arg(60) + .arg(1), + Ok(mock_response.clone()), + )]); + + let result = update_counters(&mut mock_client, counters_and_deltas).await; + + assert!(result.is_ok()); + + let (c, v, t) = result.unwrap()[0].clone(); + assert_eq!( + "req.method == \"GET\"", + c.limit().conditions().iter().collect::>()[0] + ); + assert_eq!(10, v); + assert_eq!(60, t); + } + + #[tokio::test] + async fn flush_batcher_and_update_counters_test() { + let counter = Counter::new( + Limit::new( + "test_namespace", + 10, + 60, + vec!["req.method == 'POST'"], + vec!["app_id"], + ), + Default::default(), + ); + + let mock_response = Value::Bulk(vec![Value::Int(8), Value::Int(60)]); + + let mock_client = MockRedisConnection::new(vec![MockCmd::new( + redis::cmd("EVALSHA") + .arg("1e87383cf7dba2bd0f9972ed73671274e6cbd5da") + .arg("2") + .arg(key_for_counter(&counter)) + .arg(key_for_counters_of_limit(counter.limit())) + .arg(60) + .arg(2), + Ok(mock_response.clone()), + )]); + + let mut batched_counters = HashMap::new(); + batched_counters.insert( + counter.clone(), + AtomicExpiringValue::new(2, SystemTime::now() + Duration::from_secs(60)), + ); + + let batcher: Arc>> = + Arc::new(Mutex::new(batched_counters)); + let cache = CountersCacheBuilder::new().build(); + cache.insert( + counter.clone(), + Some(1), + 10, + Duration::from_secs(0), + SystemTime::now(), + ); + let cached_counters: Arc = Arc::new(cache); + let partitioned = Arc::new(AtomicBool::new(false)); + + async fn future_true() -> bool { + true + } + + if let Some(c) = cached_counters.get(&counter) { + assert_eq!(c.hits(&counter), 1); + } + + flush_batcher_and_update_counters( + mock_client, + batcher, + future_true(), + cached_counters.clone(), + partitioned, + ) + .await; + + if let Some(c) = cached_counters.get(&counter) { + assert_eq!(c.hits(&counter), 8); + } + } } diff --git a/limitador/src/storage/redis/scripts.rs b/limitador/src/storage/redis/scripts.rs index 6c2432c9..b241d88d 100644 --- a/limitador/src/storage/redis/scripts.rs +++ b/limitador/src/storage/redis/scripts.rs @@ -19,6 +19,33 @@ pub const SCRIPT_UPDATE_COUNTER: &str = " end return c"; +// KEY[i]: Counter key +// KEY[i+1]: Limit key +// ARGV[i]: TTLs +// ARGV[i+1]: Deltas +// This function returns a list with the values and TTLs for the updated counter_keys, +// the first position the counter value and the second the TTL +pub const BATCH_UPDATE_COUNTERS: &str = " + local res = {} + for i = 1, #KEYS, 2 do + local counter_key = KEYS[i] + local limit_key = KEYS[i+1] + local ttl = ARGV[i] + local delta = ARGV[i+1] + + local c = redis.call('incrby', counter_key, delta) + table.insert(res, c) + if c == tonumber(delta) then + redis.call('expire', counter_key, ttl) + redis.call('sadd', limit_key, counter_key) + table.insert(res, ttl*1000) + else + table.insert(res, redis.call('pttl', counter_key)) + end + end + return res +"; + // KEYS: the function returns the value and TTL (in ms) for these keys // The first position of the list returned contains the value of KEYS[1], the // second position contains its TTL. The third position contains the value of diff --git a/limitador/tests/integration_tests.rs b/limitador/tests/integration_tests.rs index 5cdb6b88..844a90ed 100644 --- a/limitador/tests/integration_tests.rs +++ b/limitador/tests/integration_tests.rs @@ -59,7 +59,7 @@ macro_rules! test_with_all_storage_impls { #[serial] async fn [<$function _with_async_redis_and_local_cache>]() { let storage_builder = CachedRedisStorageBuilder::new("redis://127.0.0.1:6379"). - flushing_period(Duration::from_millis(200)). + flushing_period(Duration::from_millis(2)). max_ttl_cached_counters(Duration::from_secs(3600)). ttl_ratio_cached_counters(1). max_cached_counters(10000); @@ -149,6 +149,7 @@ mod test { test_with_all_storage_impls!(delete_limits_of_a_namespace_also_deletes_counters); test_with_all_storage_impls!(delete_limits_of_an_empty_namespace_does_nothing); test_with_all_storage_impls!(rate_limited); + test_with_all_storage_impls!(multiple_limits_rate_limited); test_with_all_storage_impls!(rate_limited_with_delta_higher_than_one); test_with_all_storage_impls!(rate_limited_with_delta_higher_than_max); test_with_all_storage_impls!(takes_into_account_only_vars_of_the_limits); @@ -478,6 +479,76 @@ mod test { .unwrap()); } + async fn multiple_limits_rate_limited(rate_limiter: &mut TestsLimiter) { + let namespace = "test_namespace"; + let max_hits = 3; + let limits = vec![ + Limit::new( + namespace, + max_hits, + 60, + vec!["req.method == 'GET'"], + vec!["app_id"], + ), + Limit::new( + namespace, + max_hits + 1, + 60, + vec!["req.method == 'POST'"], + vec!["app_id"], + ), + ]; + + for limit in limits { + rate_limiter.add_limit(&limit).await; + } + + let mut get_values: HashMap = HashMap::new(); + get_values.insert("req.method".to_string(), "GET".to_string()); + get_values.insert("app_id".to_string(), "test_app_id".to_string()); + + let mut post_values: HashMap = HashMap::new(); + post_values.insert("req.method".to_string(), "POST".to_string()); + post_values.insert("app_id".to_string(), "test_app_id".to_string()); + + for i in 0..max_hits { + assert!( + !rate_limiter + .is_rate_limited(namespace, &get_values, 1) + .await + .unwrap(), + "Must not be limited after {i}" + ); + assert!( + !rate_limiter + .is_rate_limited(namespace, &post_values, 1) + .await + .unwrap(), + "Must not be limited after {i}" + ); + rate_limiter + .check_rate_limited_and_update(namespace, &get_values, 1, false) + .await + .unwrap(); + rate_limiter + .check_rate_limited_and_update(namespace, &post_values, 1, false) + .await + .unwrap(); + } + + // We wait for the flushing period to pass so the counters are flushed in the cached storage + tokio::time::sleep(Duration::from_millis(3)).await; + + assert!(rate_limiter + .is_rate_limited(namespace, &get_values, 1) + .await + .unwrap()); + assert!(!rate_limiter + .is_rate_limited(namespace, &post_values, 1) + .await + .unwrap()); + } + async fn rate_limited_with_delta_higher_than_one(rate_limiter: &mut TestsLimiter) { let namespace = "test_namespace"; let limit = Limit::new(