diff --git a/fuzz/fuzz_targets/common.rs b/fuzz/fuzz_targets/common.rs index 3c0e7eae..1aebcc09 100644 --- a/fuzz/fuzz_targets/common.rs +++ b/fuzz/fuzz_targets/common.rs @@ -144,6 +144,14 @@ pub(crate) enum FuzzOperation { modulus: U64Between<1, 8>, reversed: bool, }, + Retain { + modulus: U64Between<1, 8>, + }, + RetainIn { + start_key: BoundedU64, + len: BoundedU64, + modulus: U64Between<1, 8>, + }, Range { start_key: BoundedU64, len: BoundedU64, diff --git a/fuzz/fuzz_targets/fuzz_redb.rs b/fuzz/fuzz_targets/fuzz_redb.rs index 145df851..7df2ae90 100644 --- a/fuzz/fuzz_targets/fuzz_redb.rs +++ b/fuzz/fuzz_targets/fuzz_redb.rs @@ -274,6 +274,12 @@ fn handle_multimap_table_op(op: &FuzzOperation, reference: &mut BTreeMap { // no-op. Multimap tables don't support this } + FuzzOperation::Retain { .. } => { + // no-op. Multimap tables don't support this + } + FuzzOperation::RetainIn { .. } => { + // no-op. Multimap tables don't support this + } FuzzOperation::Range { start_key, len, @@ -428,6 +434,18 @@ fn handle_table_op(op: &FuzzOperation, reference: &mut BTreeMap, tab panic!(); } } + FuzzOperation::RetainIn { start_key, len, modulus } => { + let start = start_key.value; + let end = start + len.value; + let modulus = modulus.value; + table.retain_in(|x, _| x % modulus == 0, start..end)?; + reference.retain(|x, _| (*x < start || *x >= end) || *x % modulus == 0); + } + FuzzOperation::Retain { modulus } => { + let modulus = modulus.value; + table.retain(|x, _| x % modulus == 0)?; + reference.retain(|x, _| *x % modulus == 0); + } FuzzOperation::Range { start_key, len, diff --git a/src/table.rs b/src/table.rs index 8d363ca8..43feb98c 100644 --- a/src/table.rs +++ b/src/table.rs @@ -160,6 +160,30 @@ impl<'txn, K: Key + 'static, V: Value + 'static> Table<'txn, K, V> { .map(DrainFilter::new) } + /// Applies `predicate` to all key-value pairs. All entries for which + /// `predicate` evaluates to `false` are removed. + /// + pub fn retain Fn(K::SelfType<'f>, V::SelfType<'f>) -> bool>( + &mut self, + predicate: F, + ) -> Result { + self.tree.retain_in::, F>(predicate, ..) + } + + /// Applies `predicate` to all key-value pairs in the range `start..end`. All entries for which + /// `predicate` evaluates to `false` are removed. + /// + pub fn retain_in<'a, KR, F: for<'f> Fn(K::SelfType<'f>, V::SelfType<'f>) -> bool>( + &mut self, + predicate: F, + range: impl RangeBounds + 'a, + ) -> Result + where + KR: Borrow> + 'a, + { + self.tree.retain_in(predicate, range) + } + /// Insert mapping of the given key to the given value /// /// Returns the old value, if the key was present in the table diff --git a/src/tree_store/btree.rs b/src/tree_store/btree.rs index a16f8df4..4386d5ec 100644 --- a/src/tree_store/btree.rs +++ b/src/tree_store/btree.rs @@ -436,6 +436,29 @@ impl BtreeMut<'_, K, V> { Ok(result) } + pub(crate) fn retain_in<'a, KR, F: for<'f> Fn(K::SelfType<'f>, V::SelfType<'f>) -> bool>( + &mut self, + predicate: F, + range: impl RangeBounds + 'a, + ) -> Result + where + KR: Borrow> + 'a, + { + let iter = self.range(&range)?; + let mut freed = vec![]; + let mut operation: MutateHelper<'_, '_, K, V> = + MutateHelper::new_do_not_modify(&mut self.root, self.mem.clone(), &mut freed); + for entry in iter { + let entry = entry?; + if !predicate(entry.key(), entry.value()) { + assert!(operation.delete(&entry.key())?.is_some()); + } + } + self.freed_pages.lock().unwrap().extend_from_slice(&freed); + + Ok(()) + } + pub(crate) fn len(&self) -> Result { self.read_tree()?.len() } diff --git a/tests/basic_tests.rs b/tests/basic_tests.rs index b9b67f45..278553f9 100644 --- a/tests/basic_tests.rs +++ b/tests/basic_tests.rs @@ -192,6 +192,75 @@ fn drain() { write_txn.abort().unwrap(); } +#[test] +fn retain() { + let tmpfile = create_tempfile(); + let db = Database::create(tmpfile.path()).unwrap(); + let write_txn = db.begin_write().unwrap(); + { + let mut table = write_txn.open_table(U64_TABLE).unwrap(); + for i in 0..10 { + table.insert(&i, &i).unwrap(); + } + // Test retain uncommitted data + table.retain(|k, _| k >= 5).unwrap(); + for i in 0..5 { + assert!(table.insert(&i, &i).unwrap().is_none()); + } + assert_eq!(table.len().unwrap(), 10); + + // Test matching on the value + table.retain(|_, v| v >= 5).unwrap(); + for i in 0..5 { + assert!(table.insert(&i, &i).unwrap().is_none()); + } + assert_eq!(table.len().unwrap(), 10); + + // Test retain_in + table.retain_in(|_, _| false, ..5).unwrap(); + for i in 0..5 { + assert!(table.insert(&i, &i).unwrap().is_none()); + } + assert_eq!(table.len().unwrap(), 10); + } + write_txn.commit().unwrap(); + + let write_txn = db.begin_write().unwrap(); + { + let mut table = write_txn.open_table(U64_TABLE).unwrap(); + assert_eq!(table.len().unwrap(), 10); + table.retain(|x, _| x >= 5).unwrap(); + assert_eq!(table.len().unwrap(), 5); + + let mut i = 5u64; + for item in table.range(0..10).unwrap() { + let (k, v) = item.unwrap(); + assert_eq!(i, k.value()); + assert_eq!(i, v.value()); + i += 1; + } + } + write_txn.abort().unwrap(); + + let write_txn = db.begin_write().unwrap(); + { + let mut table = write_txn.open_table(U64_TABLE).unwrap(); + table.retain(|x, _| x % 2 == 0).unwrap(); + } + write_txn.commit().unwrap(); + + let read_txn = db.begin_write().unwrap(); + { + let table = read_txn.open_table(U64_TABLE).unwrap(); + assert_eq!(table.len().unwrap(), 5); + for entry in table.iter().unwrap() { + let (k, v) = entry.unwrap(); + assert_eq!(k.value() % 2, 0); + assert_eq!(k.value(), v.value()); + } + } +} + #[test] fn drain_filter() { let tmpfile = create_tempfile();