diff --git a/CHANGELOG.md b/CHANGELOG.md index 972731dd..bcc8018d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Moka — Change Log +## Version 0.5.0 (Not released yet) + +### Added + +- Add `get_or_insert_with` and `get_or_try_insert_with` methods to `sync` and + `future` caches. ([#20][gh-pull-0020]) + + ## Version 0.4.0 ### Fixed @@ -15,6 +23,7 @@ - Add `invalidate_entries_if` method to `sync`, `future` and `unsync` caches. ([#12][gh-pull-0012]) + ## Version 0.3.1 ### Changed @@ -65,6 +74,7 @@ [caffeine-git]: https://github.com/ben-manes/caffeine +[gh-pull-0020]: https://github.com/moka-rs/moka/pull/20/ [gh-pull-0019]: https://github.com/moka-rs/moka/pull/19/ [gh-pull-0016]: https://github.com/moka-rs/moka/pull/16/ [gh-pull-0012]: https://github.com/moka-rs/moka/pull/12/ diff --git a/Cargo.toml b/Cargo.toml index 63d46141..ebc73734 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "moka" -version = "0.4.0" +version = "0.5.0" authors = ["Tatsuya Kawano "] edition = "2018" @@ -21,7 +21,7 @@ features = ["future"] [features] default = [] -future = ["async-io"] +future = ["async-io", "async-lock"] [dependencies] cht = "0.4" @@ -35,7 +35,8 @@ thiserror = "1.0" uuid = { version = "0.8", features = ["v4"] } # Optional dependencies -async-io = { version = "1", optional = true } +async-io = { version = "1.4", optional = true } +async-lock = { version = "2.4", optional = true } [dev-dependencies] actix-rt2 = { package = "actix-rt", version = "2", default-features = false } diff --git a/README.md b/README.md index 08762aa4..7d9a0d7f 100644 --- a/README.md +++ b/README.md @@ -61,14 +61,14 @@ Add this to your `Cargo.toml`: ```toml [dependencies] -moka = "0.4" +moka = "0.5" ``` To use the asynchronous cache, enable a crate feature called "future". 
```toml [dependencies] -moka = { version = "0.4", features = ["future"] } +moka = { version = "0.5", features = ["future"] } ``` @@ -164,7 +164,7 @@ Here is a similar program to the previous example, but using asynchronous cache // Cargo.toml // // [dependencies] -// moka = { version = "0.4", features = ["future"] } +// moka = { version = "0.5", features = ["future"] } // tokio = { version = "1", features = ["rt-multi-thread", "macros" ] } // futures = "0.3" diff --git a/src/future.rs b/src/future.rs index c65d6393..511a7ac2 100644 --- a/src/future.rs +++ b/src/future.rs @@ -4,6 +4,7 @@ mod builder; mod cache; +mod value_initializer; pub use builder::CacheBuilder; pub use cache::Cache; diff --git a/src/future/cache.rs b/src/future/cache.rs index 50ab1a1b..a2778a6f 100644 --- a/src/future/cache.rs +++ b/src/future/cache.rs @@ -1,4 +1,7 @@ -use super::ConcurrentCacheExt; +use super::{ + value_initializer::{InitResult, ValueInitializer}, + ConcurrentCacheExt, +}; use crate::{ sync::{ base_cache::{BaseCache, HouseKeeperArc, MAX_SYNC_REPEATS, WRITE_RETRY_INTERVAL_MICROS}, @@ -12,6 +15,8 @@ use crossbeam_channel::{Sender, TrySendError}; use std::{ borrow::Borrow, collections::hash_map::RandomState, + error::Error, + future::Future, hash::{BuildHasher, Hash}, sync::Arc, time::Duration, @@ -52,7 +57,7 @@ use std::{ /// // Cargo.toml /// // /// // [dependencies] -/// // moka = { version = "0.4", features = ["future"] } +/// // moka = { version = "0.5", features = ["future"] } /// // tokio = { version = "1", features = ["rt-multi-thread", "macros" ] } /// // futures = "0.3" /// @@ -187,6 +192,7 @@ use std::{ #[derive(Clone)] pub struct Cache { base: BaseCache, + value_initializer: Arc>, } unsafe impl Send for Cache @@ -240,11 +246,12 @@ where base: BaseCache::new( max_capacity, initial_capacity, - build_hasher, + build_hasher.clone(), time_to_live, time_to_idle, invalidator_enabled, ), + value_initializer: Arc::new(ValueInitializer::with_hasher(build_hasher)), } } @@ -266,11 
+273,47 @@ where self.base.get_with_hash(key, self.base.hash(key)) } + /// Ensures the value of the key exists by inserting the output of the init + /// future if not exist, and returns a _clone_ of the value. + /// + /// This method prevents to resolve the init future multiple times on the same + /// key even if the method is concurrently called by many async tasks; only one + /// of the calls resolves its future, and other calls wait for that future to + /// complete. + pub async fn get_or_insert_with(&self, key: K, init: impl Future) -> V { + let hash = self.base.hash(&key); + let key = Arc::new(key); + self.get_or_insert_with_hash_and_fun(key, hash, init).await + } + + /// Try to ensure the value of the key exists by inserting an `Ok` output of the + /// init future if not exist, and returns a _clone_ of the value or the `Err` + /// produced by the future. + /// + /// This method prevents to resolve the init future multiple times on the same + /// key even if the method is concurrently called by many async tasks; only one + /// of the calls resolves its future, and other calls wait for that future to + /// complete. + pub async fn get_or_try_insert_with( + &self, + key: K, + init: F, + ) -> Result>> + where + F: Future>>, + { + let hash = self.base.hash(&key); + let key = Arc::new(key); + self.get_or_try_insert_with_hash_and_fun(key, hash, init) + .await + } + /// Inserts a key-value pair into the cache. /// /// If the cache has this key present, the value is updated. pub async fn insert(&self, key: K, value: V) { let hash = self.base.hash(&key); + let key = Arc::new(key); self.insert_with_hash(key, hash, value).await } @@ -280,6 +323,7 @@ where /// synchronous code. 
pub fn blocking_insert(&self, key: K, value: V) { let hash = self.base.hash(&key); + let key = Arc::new(key); let op = self.base.do_insert_with_hash(key, hash, value); let hk = self.base.housekeeper.as_ref(); if Self::blocking_schedule_write_op(&self.base.write_op_ch, op, hk).is_err() { @@ -414,7 +458,63 @@ where V: Clone + Send + Sync + 'static, S: BuildHasher + Clone + Send + Sync + 'static, { - async fn insert_with_hash(&self, key: K, hash: u64, value: V) { + async fn get_or_insert_with_hash_and_fun( + &self, + key: Arc, + hash: u64, + init: impl Future, + ) -> V { + if let Some(v) = self.base.get_with_hash(&key, hash) { + return v; + } + + match self + .value_initializer + .init_or_read(Arc::clone(&key), init) + .await + { + InitResult::Initialized(v) => { + self.insert_with_hash(Arc::clone(&key), hash, v.clone()) + .await; + self.value_initializer.remove_waiter(&key); + v + } + InitResult::ReadExisting(v) => v, + InitResult::InitErr(_) => unreachable!(), + } + } + + async fn get_or_try_insert_with_hash_and_fun( + &self, + key: Arc, + hash: u64, + init: F, + ) -> Result>> + where + F: Future>>, + { + if let Some(v) = self.base.get_with_hash(&key, hash) { + return Ok(v); + } + + match self + .value_initializer + .try_init_or_read(Arc::clone(&key), init) + .await + { + InitResult::Initialized(v) => { + let hash = self.base.hash(&key); + self.insert_with_hash(Arc::clone(&key), hash, v.clone()) + .await; + self.value_initializer.remove_waiter(&key); + Ok(v) + } + InitResult::ReadExisting(v) => Ok(v), + InitResult::InitErr(e) => Err(e), + } + } + + async fn insert_with_hash(&self, key: Arc, hash: u64, value: V) { let op = self.base.do_insert_with_hash(key, hash, value); let hk = self.base.housekeeper.as_ref(); if Self::schedule_write_op(&self.base.write_op_ch, op, hk) @@ -508,6 +608,7 @@ mod tests { use super::{Cache, ConcurrentCacheExt}; use crate::future::CacheBuilder; + use async_io::Timer; use quanta::Clock; use std::time::Duration; @@ -834,4 +935,221 @@ mod 
tests { assert_eq!(cache.get(&"b"), None); assert!(cache.is_table_empty()); } + + #[tokio::test] + async fn get_or_insert_with() { + let cache = Cache::new(100); + const KEY: u32 = 0; + + // This test will run five async tasks: + // + // Task1 will be the first task to call `get_or_insert_with` for a key, so + // its async block will be evaluated and then a &str value "task1" will be + // inserted to the cache. + let task1 = { + let cache1 = cache.clone(); + async move { + // Call `get_or_insert_with` immediately. + let v = cache1 + .get_or_insert_with(KEY, async { + // Wait for 300 ms and return a &str value. + Timer::after(Duration::from_millis(300)).await; + "task1" + }) + .await; + assert_eq!(v, "task1"); + } + }; + + // Task2 will be the second task to call `get_or_insert_with` for the same + // key, so its async block will not be evaluated. Once task1's async block + // finishes, it will get the value inserted by task1's async block. + let task2 = { + let cache2 = cache.clone(); + async move { + // Wait for 100 ms before calling `get_or_insert_with`. + Timer::after(Duration::from_millis(100)).await; + let v = cache2 + .get_or_insert_with(KEY, async { unreachable!() }) + .await; + assert_eq!(v, "task1"); + } + }; + + // Task3 will be the third task to call `get_or_insert_with` for the same + // key. By the time it calls, task1's async block should have finished + // already and the value should be already inserted to the cache. So its + // async block will not be evaluated and will get the value insert by task1's + // async block immediately. + let task3 = { + let cache3 = cache.clone(); + async move { + // Wait for 400 ms before calling `get_or_insert_with`. + Timer::after(Duration::from_millis(400)).await; + let v = cache3 + .get_or_insert_with(KEY, async { unreachable!() }) + .await; + assert_eq!(v, "task1"); + } + }; + + // Task4 will call `get` for the same key. It will call when task1's async + // block is still running, so it will get none for the key. 
+ let task4 = { + let cache4 = cache.clone(); + async move { + // Wait for 200 ms before calling `get`. + Timer::after(Duration::from_millis(200)).await; + let maybe_v = cache4.get(&KEY); + assert!(maybe_v.is_none()); + } + }; + + // Task5 will call `get` for the same key. It will call after task1's async + // block finished, so it will get the value insert by task1's async block. + let task5 = { + let cache5 = cache.clone(); + async move { + // Wait for 400 ms before calling `get`. + Timer::after(Duration::from_millis(400)).await; + let maybe_v = cache5.get(&KEY); + assert_eq!(maybe_v, Some("task1")); + } + }; + + futures::join!(task1, task2, task3, task4, task5); + } + + #[tokio::test] + async fn get_or_try_insert_with() { + let cache = Cache::new(100); + const KEY: u32 = 0; + + // This test will run eight async tasks: + // + // Task1 will be the first task to call `get_or_insert_with` for a key, so + // its async block will be evaluated and then an error will be returned. + // Nothing will be inserted to the cache. + let task1 = { + let cache1 = cache.clone(); + async move { + // Call `get_or_try_insert_with` immediately. + let v = cache1 + .get_or_try_insert_with(KEY, async { + // Wait for 300 ms and return an error. + Timer::after(Duration::from_millis(300)).await; + Err("task1 error".into()) + }) + .await; + assert!(v.is_err()); + } + }; + + // Task2 will be the second task to call `get_or_insert_with` for the same + // key, so its async block will not be evaluated. Once task1's async block + // finishes, it will get the same error value returned by task1's async + // block. + let task2 = { + let cache2 = cache.clone(); + async move { + // Wait for 100 ms before calling `get_or_try_insert_with`. + Timer::after(Duration::from_millis(100)).await; + let v = cache2 + .get_or_try_insert_with(KEY, async { unreachable!() }) + .await; + assert!(v.is_err()); + } + }; + + // Task3 will be the third task to call `get_or_insert_with` for the same + // key. 
By the time it calls, task1's async block should have finished + // already, but the key still does not exist in the cache. So its async block + // will be evaluated and then an okay &str value will be returned. That value + // will be inserted to the cache. + let task3 = { + let cache3 = cache.clone(); + async move { + // Wait for 400 ms before calling `get_or_try_insert_with`. + Timer::after(Duration::from_millis(400)).await; + let v = cache3 + .get_or_try_insert_with(KEY, async { + // Wait for 300 ms and return an Ok(&str) value. + Timer::after(Duration::from_millis(300)).await; + Ok("task3") + }) + .await; + assert_eq!(v.unwrap(), "task3"); + } + }; + + // Task4 will be the fourth task to call `get_or_insert_with` for the same + // key. So its async block will not be evaluated. Once task3's async block + // finishes, it will get the same okay &str value. + let task4 = { + let cache4 = cache.clone(); + async move { + // Wait for 500 ms before calling `get_or_try_insert_with`. + Timer::after(Duration::from_millis(500)).await; + let v = cache4 + .get_or_try_insert_with(KEY, async { unreachable!() }) + .await; + assert_eq!(v.unwrap(), "task3"); + } + }; + + // Task5 will be the fifth task to call `get_or_insert_with` for the same + // key. So its async block will not be evaluated. By the time it calls, + // task3's async block should have finished already, so its async block will + // not be evaluated and will get the value insert by task3's async block + // immediately. + let task5 = { + let cache5 = cache.clone(); + async move { + // Wait for 800 ms before calling `get_or_try_insert_with`. + Timer::after(Duration::from_millis(800)).await; + let v = cache5 + .get_or_try_insert_with(KEY, async { unreachable!() }) + .await; + assert_eq!(v.unwrap(), "task3"); + } + }; + + // Task6 will call `get` for the same key. It will call when task1's async + // block is still running, so it will get none for the key. 
+ let task6 = { + let cache6 = cache.clone(); + async move { + // Wait for 200 ms before calling `get`. + Timer::after(Duration::from_millis(200)).await; + let maybe_v = cache6.get(&KEY); + assert!(maybe_v.is_none()); + } + }; + + // Task7 will call `get` for the same key. It will call after task1's async + // block finished with an error. So it will get none for the key. + let task7 = { + let cache7 = cache.clone(); + async move { + // Wait for 400 ms before calling `get`. + Timer::after(Duration::from_millis(400)).await; + let maybe_v = cache7.get(&KEY); + assert!(maybe_v.is_none()); + } + }; + + // Task8 will call `get` for the same key. It will call after task3's async + // block finished, so it will get the value insert by task3's async block. + let task8 = { + let cache8 = cache.clone(); + async move { + // Wait for 800 ms before calling `get`. + Timer::after(Duration::from_millis(800)).await; + let maybe_v = cache8.get(&KEY); + assert_eq!(maybe_v, Some("task3")); + } + }; + + futures::join!(task1, task2, task3, task4, task5, task6, task7, task8); + } } diff --git a/src/future/value_initializer.rs b/src/future/value_initializer.rs new file mode 100644 index 00000000..e64fe485 --- /dev/null +++ b/src/future/value_initializer.rs @@ -0,0 +1,111 @@ +use async_lock::RwLock; +use std::{ + error::Error, + future::Future, + hash::{BuildHasher, Hash}, + sync::Arc, +}; + +type Waiter = Arc>>>>>; + +pub(crate) enum InitResult { + Initialized(V), + ReadExisting(V), + InitErr(Arc>), +} + +pub(crate) struct ValueInitializer { + waiters: cht::HashMap, Waiter, S>, +} + +impl ValueInitializer +where + Arc: Eq + Hash, + V: Clone, + S: BuildHasher, +{ + pub(crate) fn with_hasher(hasher: S) -> Self { + Self { + waiters: cht::HashMap::with_hasher(hasher), + } + } + + pub(crate) async fn init_or_read(&self, key: Arc, init: F) -> InitResult + where + F: Future, + { + use InitResult::*; + + let waiter = Arc::new(RwLock::new(None)); + let mut lock = waiter.write().await; + + match 
self.try_insert_waiter(&key, &waiter) { + None => { + // Inserted. Resolve the init future. + let value = init.await; + *lock = Some(Ok(value.clone())); + Initialized(value) + } + Some(res) => { + // Value already exists. Drop our write lock and wait for a read lock + // to become available. + std::mem::drop(lock); + match &*res.read().await { + Some(Ok(value)) => ReadExisting(value.clone()), + Some(Err(_)) | None => unreachable!(), + } + } + } + } + + pub(crate) async fn try_init_or_read(&self, key: Arc, init: F) -> InitResult + where + F: Future>>, + { + use InitResult::*; + + let waiter = Arc::new(RwLock::new(None)); + let mut lock = waiter.write().await; + + match self.try_insert_waiter(&key, &waiter) { + None => { + // Inserted. Resolve the init future. + match init.await { + Ok(value) => { + *lock = Some(Ok(value.clone())); + Initialized(value) + } + Err(e) => { + let err = Arc::new(e); + *lock = Some(Err(Arc::clone(&err))); + self.remove_waiter(&key); + InitErr(err) + } + } + } + Some(res) => { + // Value already exists. Drop our write lock and wait for a read lock + // to become available. 
+ std::mem::drop(lock); + match &*res.read().await { + Some(Ok(value)) => ReadExisting(value.clone()), + Some(Err(e)) => InitErr(Arc::clone(e)), + None => unreachable!(), + } + } + } + } + + #[inline] + pub(crate) fn remove_waiter(&self, key: &Arc) { + self.waiters.remove(key); + } + + fn try_insert_waiter(&self, key: &Arc, waiter: &Waiter) -> Option> { + let key = Arc::clone(key); + let waiter = Arc::clone(waiter); + + self.waiters + .insert_with_or_modify(key, || waiter, |_, w| Arc::clone(w)) + } +} diff --git a/src/sync.rs b/src/sync.rs index c98d82d0..7fa63299 100644 --- a/src/sync.rs +++ b/src/sync.rs @@ -19,6 +19,7 @@ mod deques; pub(crate) mod housekeeper; mod invalidator; mod segment; +mod value_initializer; pub use builder::CacheBuilder; pub use cache::Cache; diff --git a/src/sync/base_cache.rs b/src/sync/base_cache.rs index 853190c9..afb125ff 100644 --- a/src/sync/base_cache.rs +++ b/src/sync/base_cache.rs @@ -230,9 +230,7 @@ where } #[inline] - pub(crate) fn do_insert_with_hash(&self, key: K, hash: u64, value: V) -> WriteOp { - let key = Arc::new(key); - + pub(crate) fn do_insert_with_hash(&self, key: Arc, hash: u64, value: V) -> WriteOp { let op_cnt1 = Rc::new(AtomicU8::new(0)); let op_cnt2 = Rc::clone(&op_cnt1); let mut op1 = None; diff --git a/src/sync/cache.rs b/src/sync/cache.rs index f07b1051..7f6e4b17 100644 --- a/src/sync/cache.rs +++ b/src/sync/cache.rs @@ -1,14 +1,16 @@ use super::{ base_cache::{BaseCache, HouseKeeperArc, MAX_SYNC_REPEATS, WRITE_RETRY_INTERVAL_MICROS}, housekeeper::InnerSync, + value_initializer::ValueInitializer, ConcurrentCacheExt, PredicateId, WriteOp, }; -use crate::PredicateError; +use crate::{sync::value_initializer::InitResult, PredicateError}; use crossbeam_channel::{Sender, TrySendError}; use std::{ borrow::Borrow, collections::hash_map::RandomState, + error::Error, hash::{BuildHasher, Hash}, sync::Arc, time::Duration, @@ -162,6 +164,7 @@ use std::{ #[derive(Clone)] pub struct Cache { base: BaseCache, + 
value_initializer: Arc>, } unsafe impl Send for Cache @@ -215,11 +218,12 @@ where base: BaseCache::new( max_capacity, initial_capacity, - build_hasher, + build_hasher.clone(), time_to_live, time_to_idle, invalidator_enabled, ), + value_initializer: Arc::new(ValueInitializer::with_hasher(build_hasher)), } } @@ -249,15 +253,98 @@ where self.base.get_with_hash(key, hash) } + /// Ensures the value of the key exists by inserting the result of the init + /// function if not exist, and returns a _clone_ of the value. + /// + /// This method prevents to evaluate the init function multiple times on the same + /// key even if the method is concurrently called by many threads; only one of + /// the calls evaluates its function, and other calls wait for that function to + /// complete. + pub fn get_or_insert_with(&self, key: K, init: impl FnOnce() -> V) -> V { + let hash = self.base.hash(&key); + let key = Arc::new(key); + self.get_or_insert_with_hash_and_fun(key, hash, init) + } + + pub(crate) fn get_or_insert_with_hash_and_fun( + &self, + key: Arc, + hash: u64, + init: impl FnOnce() -> V, + ) -> V { + if let Some(v) = self.get_with_hash(&key, hash) { + return v; + } + + match self.value_initializer.init_or_read(Arc::clone(&key), init) { + InitResult::Initialized(v) => { + self.insert_with_hash(Arc::clone(&key), hash, v.clone()); + self.value_initializer.remove_waiter(&key); + v + } + InitResult::ReadExisting(v) => v, + InitResult::InitErr(_) => unreachable!(), + } + } + + /// Try to ensure the value of the key exists by inserting an `Ok` result of the + /// init function if not exist, and returns a _clone_ of the value or the `Err` + /// returned by the function. + /// + /// This method prevents to evaluate the init function multiple times on the same + /// key even if the method is concurrently called by many threads; only one of + /// the calls evaluates its function, and other calls wait for that function to + /// complete. 
+ pub fn get_or_try_insert_with( + &self, + key: K, + init: F, + ) -> Result>> + where + F: FnOnce() -> Result>, + { + let hash = self.base.hash(&key); + let key = Arc::new(key); + self.get_or_try_insert_with_hash_and_fun(key, hash, init) + } + + pub(crate) fn get_or_try_insert_with_hash_and_fun( + &self, + key: Arc, + hash: u64, + init: F, + ) -> Result>> + where + F: FnOnce() -> Result>, + { + if let Some(v) = self.get_with_hash(&key, hash) { + return Ok(v); + } + + match self + .value_initializer + .try_init_or_read(Arc::clone(&key), init) + { + InitResult::Initialized(v) => { + self.insert_with_hash(Arc::clone(&key), hash, v.clone()); + self.value_initializer.remove_waiter(&key); + Ok(v) + } + InitResult::ReadExisting(v) => Ok(v), + InitResult::InitErr(e) => Err(e), + } + } + /// Inserts a key-value pair into the cache. /// /// If the cache has this key present, the value is updated. pub fn insert(&self, key: K, value: V) { let hash = self.base.hash(&key); + let key = Arc::new(key); self.insert_with_hash(key, hash, value) } - pub(crate) fn insert_with_hash(&self, key: K, hash: u64, value: V) { + pub(crate) fn insert_with_hash(&self, key: Arc, hash: u64, value: V) { let op = self.base.do_insert_with_hash(key, hash, value); let hk = self.base.housekeeper.as_ref(); Self::schedule_write_op(&self.base.write_op_ch, op, hk).expect("Failed to insert"); @@ -708,4 +795,215 @@ mod tests { assert_eq!(cache.get(&"b"), None); assert!(cache.is_table_empty()); } + + #[test] + fn get_or_insert_with() { + use std::thread::{sleep, spawn}; + + let cache = Cache::new(100); + const KEY: u32 = 0; + + // This test will run five threads: + // + // Thread1 will be the first thread to call `get_or_insert_with` for a key, so + // its async block will be evaluated and then a &str value "thread1" will be + // inserted to the cache. + let thread1 = { + let cache1 = cache.clone(); + spawn(move || { + // Call `get_or_insert_with` immediately. 
+ let v = cache1.get_or_insert_with(KEY, || { + // Wait for 300 ms and return a &str value. + sleep(Duration::from_millis(300)); + "thread1" + }); + assert_eq!(v, "thread1"); + }) + }; + + // Thread2 will be the second thread to call `get_or_insert_with` for the same + // key, so its async block will not be evaluated. Once thread1's async block + // finishes, it will get the value inserted by thread1's async block. + let thread2 = { + let cache2 = cache.clone(); + spawn(move || { + // Wait for 100 ms before calling `get_or_insert_with`. + sleep(Duration::from_millis(100)); + let v = cache2.get_or_insert_with(KEY, || unreachable!()); + assert_eq!(v, "thread1"); + }) + }; + + // Thread3 will be the third thread to call `get_or_insert_with` for the same + // key. By the time it calls, thread1's async block should have finished + // already and the value should be already inserted to the cache. So its + // async block will not be evaluated and will get the value insert by thread1's + // async block immediately. + let thread3 = { + let cache3 = cache.clone(); + spawn(move || { + // Wait for 400 ms before calling `get_or_insert_with`. + sleep(Duration::from_millis(400)); + let v = cache3.get_or_insert_with(KEY, || unreachable!()); + assert_eq!(v, "thread1"); + }) + }; + + // Thread4 will call `get` for the same key. It will call when thread1's async + // block is still running, so it will get none for the key. + let thread4 = { + let cache4 = cache.clone(); + spawn(move || { + // Wait for 200 ms before calling `get`. + sleep(Duration::from_millis(200)); + let maybe_v = cache4.get(&KEY); + assert!(maybe_v.is_none()); + }) + }; + + // Thread5 will call `get` for the same key. It will call after thread1's async + // block finished, so it will get the value insert by thread1's async block. + let thread5 = { + let cache5 = cache.clone(); + spawn(move || { + // Wait for 400 ms before calling `get`. 
+ sleep(Duration::from_millis(400)); + let maybe_v = cache5.get(&KEY); + assert_eq!(maybe_v, Some("thread1")); + }) + }; + + for t in vec![thread1, thread2, thread3, thread4, thread5] { + t.join().expect("Failed to join"); + } + } + + #[test] + fn get_or_try_insert_with() { + use std::thread::{sleep, spawn}; + + let cache = Cache::new(100); + const KEY: u32 = 0; + + // This test will run eight async threads: + // + // Thread1 will be the first thread to call `get_or_insert_with` for a key, so + // its async block will be evaluated and then an error will be returned. + // Nothing will be inserted to the cache. + let thread1 = { + let cache1 = cache.clone(); + spawn(move || { + // Call `get_or_try_insert_with` immediately. + let v = cache1.get_or_try_insert_with(KEY, || { + // Wait for 300 ms and return an error. + sleep(Duration::from_millis(300)); + Err("thread1 error".into()) + }); + assert!(v.is_err()); + }) + }; + + // Thread2 will be the second thread to call `get_or_insert_with` for the same + // key, so its async block will not be evaluated. Once thread1's async block + // finishes, it will get the same error value returned by thread1's async + // block. + let thread2 = { + let cache2 = cache.clone(); + spawn(move || { + // Wait for 100 ms before calling `get_or_try_insert_with`. + sleep(Duration::from_millis(100)); + let v = cache2.get_or_try_insert_with(KEY, || unreachable!()); + assert!(v.is_err()); + }) + }; + + // Thread3 will be the third thread to call `get_or_insert_with` for the same + // key. By the time it calls, thread1's async block should have finished + // already, but the key still does not exist in the cache. So its async block + // will be evaluated and then an okay &str value will be returned. That value + // will be inserted to the cache. + let thread3 = { + let cache3 = cache.clone(); + spawn(move || { + // Wait for 400 ms before calling `get_or_try_insert_with`. 
+ sleep(Duration::from_millis(400)); + let v = cache3.get_or_try_insert_with(KEY, || { + // Wait for 300 ms and return an Ok(&str) value. + sleep(Duration::from_millis(300)); + Ok("thread3") + }); + assert_eq!(v.unwrap(), "thread3"); + }) + }; + + // thread4 will be the fourth thread to call `get_or_insert_with` for the same + // key. So its async block will not be evaluated. Once thread3's async block + // finishes, it will get the same okay &str value. + let thread4 = { + let cache4 = cache.clone(); + spawn(move || { + // Wait for 500 ms before calling `get_or_try_insert_with`. + sleep(Duration::from_millis(500)); + let v = cache4.get_or_try_insert_with(KEY, || unreachable!()); + assert_eq!(v.unwrap(), "thread3"); + }) + }; + + // Thread5 will be the fifth thread to call `get_or_insert_with` for the same + // key. So its async block will not be evaluated. By the time it calls, + // thread3's async block should have finished already, so its async block will + // not be evaluated and will get the value insert by thread3's async block + // immediately. + let thread5 = { + let cache5 = cache.clone(); + spawn(move || { + // Wait for 800 ms before calling `get_or_try_insert_with`. + sleep(Duration::from_millis(800)); + let v = cache5.get_or_try_insert_with(KEY, || unreachable!()); + assert_eq!(v.unwrap(), "thread3"); + }) + }; + + // Thread6 will call `get` for the same key. It will call when thread1's async + // block is still running, so it will get none for the key. + let thread6 = { + let cache6 = cache.clone(); + spawn(move || { + // Wait for 200 ms before calling `get`. + sleep(Duration::from_millis(200)); + let maybe_v = cache6.get(&KEY); + assert!(maybe_v.is_none()); + }) + }; + + // Thread7 will call `get` for the same key. It will call after thread1's async + // block finished with an error. So it will get none for the key. + let thread7 = { + let cache7 = cache.clone(); + spawn(move || { + // Wait for 400 ms before calling `get`. 
+ sleep(Duration::from_millis(400)); + let maybe_v = cache7.get(&KEY); + assert!(maybe_v.is_none()); + }) + }; + + // Thread8 will call `get` for the same key. It will call after thread3's async + // block finished, so it will get the value insert by thread3's async block. + let thread8 = { + let cache8 = cache.clone(); + spawn(move || { + // Wait for 800 ms before calling `get`. + sleep(Duration::from_millis(800)); + let maybe_v = cache8.get(&KEY); + assert_eq!(maybe_v, Some("thread3")); + }) + }; + + for t in vec![ + thread1, thread2, thread3, thread4, thread5, thread6, thread7, thread8, + ] { + t.join().expect("Failed to join"); + } + } } diff --git a/src/sync/segment.rs b/src/sync/segment.rs index 271d8ada..9c02338d 100644 --- a/src/sync/segment.rs +++ b/src/sync/segment.rs @@ -4,6 +4,7 @@ use crate::PredicateError; use std::{ borrow::Borrow, collections::hash_map::RandomState, + error::Error, hash::{BuildHasher, Hash, Hasher}, sync::Arc, time::Duration, @@ -131,11 +132,50 @@ where self.inner.select(hash).get_with_hash(key, hash) } + /// Ensures the value of the key exists by inserting the result of the init + /// function if not exist, and returns a _clone_ of the value. + /// + /// This method prevents to evaluate the init function multiple times on the same + /// key even if the method is concurrently called by many threads; only one of + /// the calls evaluates its function, and other calls wait for that function to + /// complete. + pub fn get_or_insert_with(&self, key: K, init: impl FnOnce() -> V) -> V { + let hash = self.inner.hash(&key); + let key = Arc::new(key); + self.inner + .select(hash) + .get_or_insert_with_hash_and_fun(key, hash, init) + } + + /// Try to ensure the value of the key exists by inserting an `Ok` result of the + /// init function if not exist, and returns a _clone_ of the value or the `Err` + /// returned by the function. 
+ /// + /// This method prevents to evaluate the init function multiple times on the same + /// key even if the method is concurrently called by many threads; only one of + /// the calls evaluates its function, and other calls wait for that function to + /// complete. + pub fn get_or_try_insert_with( + &self, + key: K, + init: F, + ) -> Result>> + where + F: FnOnce() -> Result>, + { + let hash = self.inner.hash(&key); + let key = Arc::new(key); + self.inner + .select(hash) + .get_or_try_insert_with_hash_and_fun(key, hash, init) + } + /// Inserts a key-value pair into the cache. /// /// If the cache has this key present, the value is updated. pub fn insert(&self, key: K, value: V) { let hash = self.inner.hash(&key); + let key = Arc::new(key); self.inner.select(hash).insert_with_hash(key, hash, value); } @@ -567,4 +607,215 @@ mod tests { Ok(()) } + + #[test] + fn get_or_insert_with() { + use std::thread::{sleep, spawn}; + + let cache = SegmentedCache::new(100, 4); + const KEY: u32 = 0; + + // This test will run five threads: + // + // Thread1 will be the first thread to call `get_or_insert_with` for a key, so + // its async block will be evaluated and then a &str value "thread1" will be + // inserted to the cache. + let thread1 = { + let cache1 = cache.clone(); + spawn(move || { + // Call `get_or_insert_with` immediately. + let v = cache1.get_or_insert_with(KEY, || { + // Wait for 300 ms and return a &str value. + sleep(Duration::from_millis(300)); + "thread1" + }); + assert_eq!(v, "thread1"); + }) + }; + + // Thread2 will be the second thread to call `get_or_insert_with` for the same + // key, so its async block will not be evaluated. Once thread1's async block + // finishes, it will get the value inserted by thread1's async block. + let thread2 = { + let cache2 = cache.clone(); + spawn(move || { + // Wait for 100 ms before calling `get_or_insert_with`. 
+                sleep(Duration::from_millis(100));
+                let v = cache2.get_or_insert_with(KEY, || unreachable!());
+                assert_eq!(v, "thread1");
+            })
+        };
+
+        // Thread3 will be the third thread to call `get_or_insert_with` for the same
+        // key. By the time it calls, thread1's init closure should have finished
+        // already and the value should be already inserted to the cache. So its
+        // closure will not be evaluated and will get the value inserted by thread1's
+        // closure immediately.
+        let thread3 = {
+            let cache3 = cache.clone();
+            spawn(move || {
+                // Wait for 400 ms before calling `get_or_insert_with`.
+                sleep(Duration::from_millis(400));
+                let v = cache3.get_or_insert_with(KEY, || unreachable!());
+                assert_eq!(v, "thread1");
+            })
+        };
+
+        // Thread4 will call `get` for the same key. It will call when thread1's init
+        // closure is still running, so it will get none for the key.
+        let thread4 = {
+            let cache4 = cache.clone();
+            spawn(move || {
+                // Wait for 200 ms before calling `get`.
+                sleep(Duration::from_millis(200));
+                let maybe_v = cache4.get(&KEY);
+                assert!(maybe_v.is_none());
+            })
+        };
+
+        // Thread5 will call `get` for the same key. It will call after thread1's init
+        // closure finished, so it will get the value inserted by thread1's closure.
+        let thread5 = {
+            let cache5 = cache.clone();
+            spawn(move || {
+                // Wait for 400 ms before calling `get`.
+                sleep(Duration::from_millis(400));
+                let maybe_v = cache5.get(&KEY);
+                assert_eq!(maybe_v, Some("thread1"));
+            })
+        };
+
+        for t in vec![thread1, thread2, thread3, thread4, thread5] {
+            t.join().expect("Failed to join");
+        }
+    }
+
+    #[test]
+    fn get_or_try_insert_with() {
+        use std::thread::{sleep, spawn};
+
+        let cache = SegmentedCache::new(100, 4);
+        const KEY: u32 = 0;
+
+        // This test will run eight threads:
+        //
+        // Thread1 will be the first thread to call `get_or_try_insert_with` for a
+        // key, so its init closure will be evaluated and then an error will be returned.
+        // Nothing will be inserted to the cache.
+        let thread1 = {
+            let cache1 = cache.clone();
+            spawn(move || {
+                // Call `get_or_try_insert_with` immediately.
+                let v = cache1.get_or_try_insert_with(KEY, || {
+                    // Wait for 300 ms and return an error.
+                    sleep(Duration::from_millis(300));
+                    Err("thread1 error".into())
+                });
+                assert!(v.is_err());
+            })
+        };
+
+        // Thread2 will be the second thread to call `get_or_try_insert_with` for the
+        // same key, so its init closure will not be evaluated. Once thread1's
+        // closure finishes, it will get the same error value returned by thread1's
+        // closure.
+        let thread2 = {
+            let cache2 = cache.clone();
+            spawn(move || {
+                // Wait for 100 ms before calling `get_or_try_insert_with`.
+                sleep(Duration::from_millis(100));
+                let v = cache2.get_or_try_insert_with(KEY, || unreachable!());
+                assert!(v.is_err());
+            })
+        };
+
+        // Thread3 will be the third thread to call `get_or_try_insert_with` for the
+        // same key. By the time it calls, thread1's init closure should have finished
+        // already, but the key still does not exist in the cache. So its closure
+        // will be evaluated and then an okay &str value will be returned. That value
+        // will be inserted to the cache.
+        let thread3 = {
+            let cache3 = cache.clone();
+            spawn(move || {
+                // Wait for 400 ms before calling `get_or_try_insert_with`.
+                sleep(Duration::from_millis(400));
+                let v = cache3.get_or_try_insert_with(KEY, || {
+                    // Wait for 300 ms and return an Ok(&str) value.
+                    sleep(Duration::from_millis(300));
+                    Ok("thread3")
+                });
+                assert_eq!(v.unwrap(), "thread3");
+            })
+        };
+
+        // Thread4 will be the fourth thread to call `get_or_try_insert_with` for the
+        // same key. So its init closure will not be evaluated. Once thread3's
+        // closure finishes, it will get the same okay &str value.
+        let thread4 = {
+            let cache4 = cache.clone();
+            spawn(move || {
+                // Wait for 500 ms before calling `get_or_try_insert_with`.
+                sleep(Duration::from_millis(500));
+                let v = cache4.get_or_try_insert_with(KEY, || unreachable!());
+                assert_eq!(v.unwrap(), "thread3");
+            })
+        };
+
+        // Thread5 will be the fifth thread to call `get_or_try_insert_with` for the
+        // same key. So its init closure will not be evaluated. By the time it calls,
+        // thread3's closure should have finished already, so its closure will
+        // not be evaluated and will get the value inserted by thread3's closure
+        // immediately.
+        let thread5 = {
+            let cache5 = cache.clone();
+            spawn(move || {
+                // Wait for 800 ms before calling `get_or_try_insert_with`.
+                sleep(Duration::from_millis(800));
+                let v = cache5.get_or_try_insert_with(KEY, || unreachable!());
+                assert_eq!(v.unwrap(), "thread3");
+            })
+        };
+
+        // Thread6 will call `get` for the same key. It will call when thread1's init
+        // closure is still running, so it will get none for the key.
+        let thread6 = {
+            let cache6 = cache.clone();
+            spawn(move || {
+                // Wait for 200 ms before calling `get`.
+                sleep(Duration::from_millis(200));
+                let maybe_v = cache6.get(&KEY);
+                assert!(maybe_v.is_none());
+            })
+        };
+
+        // Thread7 will call `get` for the same key. It will call after thread1's init
+        // closure finished with an error. So it will get none for the key.
+        let thread7 = {
+            let cache7 = cache.clone();
+            spawn(move || {
+                // Wait for 400 ms before calling `get`.
+                sleep(Duration::from_millis(400));
+                let maybe_v = cache7.get(&KEY);
+                assert!(maybe_v.is_none());
+            })
+        };
+
+        // Thread8 will call `get` for the same key. It will call after thread3's init
+        // closure finished, so it will get the value inserted by thread3's closure.
+        let thread8 = {
+            let cache8 = cache.clone();
+            spawn(move || {
+                // Wait for 800 ms before calling `get`.
+                sleep(Duration::from_millis(800));
+                let maybe_v = cache8.get(&KEY);
+                assert_eq!(maybe_v, Some("thread3"));
+            })
+        };
+
+        for t in vec![
+            thread1, thread2, thread3, thread4, thread5, thread6, thread7, thread8,
+        ] {
+            t.join().expect("Failed to join");
+        }
+    }
 }
diff --git a/src/sync/value_initializer.rs b/src/sync/value_initializer.rs
new file mode 100644
index 00000000..cad10b26
--- /dev/null
+++ b/src/sync/value_initializer.rs
@@ -0,0 +1,107 @@
+use parking_lot::RwLock;
+use std::{
+    error::Error,
+    hash::{BuildHasher, Hash},
+    sync::Arc,
+};
+
+type Waiter<V> = Arc<RwLock<Option<Result<V, Arc<Box<dyn Error + Send + Sync + 'static>>>>>>;
+
+pub(crate) enum InitResult<V> {
+    Initialized(V),
+    ReadExisting(V),
+    InitErr(Arc<Box<dyn Error + Send + Sync + 'static>>),
+}
+
+pub(crate) struct ValueInitializer<K, V, S> {
+    waiters: cht::HashMap<Arc<K>, Waiter<V>, S>,
+}
+
+impl<K, V, S> ValueInitializer<K, V, S>
+where
+    Arc<K>: Eq + Hash,
+    V: Clone,
+    S: BuildHasher,
+{
+    pub(crate) fn with_hasher(hasher: S) -> Self {
+        Self {
+            waiters: cht::HashMap::with_hasher(hasher),
+        }
+    }
+
+    pub(crate) fn init_or_read(&self, key: Arc<K>, init: impl FnOnce() -> V) -> InitResult<V> {
+        use InitResult::*;
+
+        let waiter = Arc::new(RwLock::new(None));
+        let mut lock = waiter.write();
+
+        match self.try_insert_waiter(&key, &waiter) {
+            None => {
+                // Our waiter was inserted. Evaluate the init closure.
+                let value = init();
+                *lock = Some(Ok(value.clone()));
+                Initialized(value)
+            }
+            Some(res) => {
+                // Another thread's waiter already exists. Drop our write lock and
+                // wait for its read lock to become available.
+                std::mem::drop(lock);
+                match &*res.read() {
+                    Some(Ok(value)) => ReadExisting(value.clone()),
+                    Some(Err(_)) | None => unreachable!(),
+                }
+            }
+        }
+    }
+
+    pub(crate) fn try_init_or_read<F>(&self, key: Arc<K>, init: F) -> InitResult<V>
+    where
+        F: FnOnce() -> Result<V, Box<dyn Error + Send + Sync + 'static>>,
+    {
+        use InitResult::*;
+
+        let waiter = Arc::new(RwLock::new(None));
+        let mut lock = waiter.write();
+
+        match self.try_insert_waiter(&key, &waiter) {
+            None => {
+                // Our waiter was inserted. Evaluate the init closure.
+                match init() {
+                    Ok(value) => {
+                        *lock = Some(Ok(value.clone()));
+                        Initialized(value)
+                    }
+                    Err(e) => {
+                        let err = Arc::new(e);
+                        *lock = Some(Err(Arc::clone(&err)));
+                        self.remove_waiter(&key);
+                        InitErr(err)
+                    }
+                }
+            }
+            Some(res) => {
+                // Another thread's waiter already exists. Drop our write lock and
+                // wait for its read lock to become available.
+                std::mem::drop(lock);
+                match &*res.read() {
+                    Some(Ok(value)) => ReadExisting(value.clone()),
+                    Some(Err(e)) => InitErr(Arc::clone(e)),
+                    None => unreachable!(),
+                }
+            }
+        }
+    }
+
+    #[inline]
+    pub(crate) fn remove_waiter(&self, key: &Arc<K>) {
+        self.waiters.remove(key);
+    }
+
+    fn try_insert_waiter(&self, key: &Arc<K>, waiter: &Waiter<V>) -> Option<Waiter<V>> {
+        let key = Arc::clone(key);
+        let waiter = Arc::clone(waiter);
+
+        self.waiters
+            .insert_with_or_modify(key, || waiter, |_, w| Arc::clone(w))
+    }
+}