diff --git a/Cargo.lock b/Cargo.lock index f463a3951909..0fb7468b1d87 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -263,6 +263,7 @@ dependencies = [ name = "actor" version = "0.1.0" dependencies = [ + "ahash 0.3.8", "bitfield", "byteorder 1.3.4", "clock", @@ -761,8 +762,8 @@ checksum = "5f0dc55f2d8a1a85650ac47858bb001b4c0dd73d79e3c455a842925e68d29cd3" name = "bitfield" version = "0.1.0" dependencies = [ + "ahash 0.3.8", "bitvec", - "fnv", "forest_encoding", "rand 0.7.3", "rand_xorshift", diff --git a/blockchain/state_manager/src/utils.rs b/blockchain/state_manager/src/utils.rs index 24aab85c156e..e17f781c402e 100644 --- a/blockchain/state_manager/src/utils.rs +++ b/blockchain/state_manager/src/utils.rs @@ -94,13 +94,9 @@ fn get_proving_set_raw( where DB: BlockStore, { - let mut not_proving = actor_state - .faults - .clone() - .merge(&actor_state.recoveries) - .map_err(|_| Error::Other("Could not merge bitfield".to_string()))?; + let not_proving = &actor_state.faults | &actor_state.recoveries; actor_state - .load_sector_infos(&*state_manager.get_block_store(), &mut not_proving) + .load_sector_infos(&*state_manager.get_block_store(), ¬_proving) .map_err(|err| Error::Other(format!("failed to get proving set :{:}", err))) } diff --git a/utils/bitfield/Cargo.toml b/utils/bitfield/Cargo.toml index c06a8a70342a..1e87f215a7f2 100644 --- a/utils/bitfield/Cargo.toml +++ b/utils/bitfield/Cargo.toml @@ -12,7 +12,7 @@ bitvec = "0.17.3" unsigned-varint = "0.4" serde = { version = "1.0", features = ["derive"] } serde_bytes = "0.11.3" -fnv = "1.0.6" +ahash = "0.3" [dev-dependencies] rand_xorshift = "0.2.0" diff --git a/utils/bitfield/src/bitvec_serde.rs b/utils/bitfield/src/bitvec_serde.rs deleted file mode 100644 index 7ac28e1d8ca2..000000000000 --- a/utils/bitfield/src/bitvec_serde.rs +++ /dev/null @@ -1,80 +0,0 @@ -// Copyright 2020 ChainSafe Systems -// SPDX-License-Identifier: Apache-2.0, MIT - -use super::{decode_and_apply_cache, rleplus::encode, BitField}; -use bitvec::prelude::BitVec; -use serde::{ser, Deserialize, Deserializer, Serialize, Serializer}; - -impl Serialize for BitField { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - match self { - BitField::Encoded { bv, set, unset } => { - if set.is_empty() && unset.is_empty() { - serde_bytes::serialize(bv.as_slice(), serializer) - } else { - let decoded = - decode_and_apply_cache(bv, set, unset).map_err(ser::Error::custom)?; - serde_bytes::serialize(encode(&decoded).as_slice(), serializer) - } - } - BitField::Decoded(bv) => serde_bytes::serialize(encode(bv).as_slice(), serializer), - } - } -} - -impl<'de> Deserialize<'de> for BitField { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - let bz: Vec = serde_bytes::deserialize(deserializer)?; - Ok(BitField::Encoded { - bv: BitVec::from_vec(bz), - set: Default::default(), - unset: Default::default(), - }) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use bitvec::bitvec; - use encoding::{from_slice, to_vec}; - - #[test] - fn serialize_node_symmetric() { - let bit_field: BitField = bitvec![Lsb0, u8; 0, 1, 0, 1, 1, 1, 1, 1, 1].into(); - let cbor_bz = to_vec(&bit_field).unwrap(); - let mut deserialized: BitField = from_slice(&cbor_bz).unwrap(); - assert_eq!(deserialized.count().unwrap(), 7); - // assert_eq!(deserialized, bit_field); - } - - #[test] - // ported test from specs-actors `bitfield_test.go` with added vector - fn bit_vec_unset_vector() { - let mut bf = BitField::default(); - bf.set(1); - bf.set(2); - bf.set(3); - bf.set(4); - 
bf.set(5); - - bf.unset(3); - - assert_ne!(bf.get(3).unwrap(), true); - assert_eq!(bf.count().unwrap(), 4); - - // Test cbor marshal and unmarshal - let cbor_bz = to_vec(&bf).unwrap(); - assert_eq!(&cbor_bz, &[0x43, 0xa8, 0x54, 0x0]); - let mut deserialized: BitField = from_slice(&cbor_bz).unwrap(); - - assert_eq!(deserialized.count().unwrap(), 4); - assert_ne!(bf.get(3).unwrap(), true); - } -} diff --git a/utils/bitfield/src/iter.rs b/utils/bitfield/src/iter.rs new file mode 100644 index 000000000000..869af9b2abbb --- /dev/null +++ b/utils/bitfield/src/iter.rs @@ -0,0 +1,596 @@ +// Copyright 2020 ChainSafe Systems +// SPDX-License-Identifier: Apache-2.0, MIT + +use std::{ + iter::{self, FusedIterator}, + ops::Range, +}; + +/// A trait for iterators over `Range`. +/// +/// Requirements: +/// - all ranges are non-empty +/// - the ranges are in ascending order +/// - no two ranges overlap or touch +/// - the iterator must be fused, i.e. once it has returned `None`, it must keep returning `None` +pub trait RangeIterator: FusedIterator> + Sized { + /// Returns a new `RangeIterator` over the bits that are in `self`, in `other`, or in both. + fn merge(self, other: R) -> Union { + Union { + a_iter: self, + b_iter: other, + a_range: None, + b_range: None, + } + } + + /// Returns a new `RangeIterator` over the bits that are in both `self` and `other`. + fn intersection(self, other: R) -> Intersection { + Intersection { + a_iter: self, + b_iter: other, + a_range: None, + b_range: None, + } + } + + /// Returns a new `RangeIterator` over the bits that are in `self` but not in `other`. + fn difference(self, other: R) -> Difference { + Difference { + a_iter: self, + b_iter: other, + a_range: None, + b_range: None, + } + } + + /// Returns a new `RangeIterator` over the bits in `self` after skipping the first `n` bits. + fn skip_bits(self, n: usize) -> Skip { + Skip { + iter: self, + skip: n, + } + } + + /// Returns a new `RangeIterator` over the first `n` bits in `self`. + fn take_bits(self, n: usize) -> Take { + Take { + iter: self, + take: n, + } + } +} + +/// A `RangeIterator` over the bits that represent the union of two other `RangeIterator`s. 
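+///
+/// For example, merging the ranges `0..3, 8..11` with `2..5` yields `0..5, 8..11`
+/// (one of the cases in the tests below): overlapping or touching ranges are coalesced,
+/// disjoint ones are kept separate.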
+pub struct Union { + a_iter: A, + b_iter: B, + a_range: Option>, + b_range: Option>, +} + +impl Iterator for Union { + type Item = Range; + + fn next(&mut self) -> Option { + let (mut a, mut b) = match ( + self.a_range.take().or_else(|| self.a_iter.next()), + self.b_range.take().or_else(|| self.b_iter.next()), + ) { + (Some(a), Some(b)) => (a, b), + (a, b) => return a.or(b), + }; + + loop { + if a.start <= b.start { + if a.end < b.start { + // a.start < a.end < b.start < b.end + // + // a: -xxx----- + // b: -----xxx- + + self.b_range = Some(b); + return Some(a); + } else if a.end < b.end { + // a.start <= b.start <= a.end < b.end + // + // a: -?xxx--- + // b: --xxxxx- + + // we resize `b` to be the union of `a` and `b`, but don't + // return it yet because it might overlap with another range + // in `a_iter` + b.start = a.start; + match self.a_iter.next() { + Some(range) => a = range, + None => return Some(b), + } + } else { + // a.start <= b.start < b.end <= a.end + // + // a: -?xxx?- + // b: --xxx-- + + match self.b_iter.next() { + Some(range) => b = range, + None => return Some(a), + } + } + } else { + // the union operator is symmetric, so this is exactly + // the same as above but with `a` and `b` swapped + + if b.end < a.start { + self.a_range = Some(a); + return Some(b); + } else if b.end < a.end { + a.start = b.start; + match self.b_iter.next() { + Some(range) => b = range, + None => return Some(a), + } + } else { + match self.a_iter.next() { + Some(range) => a = range, + None => return Some(b), + } + } + } + } + } +} + +impl FusedIterator for Union {} +impl RangeIterator for Union {} + +/// A `RangeIterator` over the bits that represent the intersection of two other `RangeIterator`s. +pub struct Intersection { + a_iter: A, + b_iter: B, + a_range: Option>, + b_range: Option>, +} + +impl Iterator for Intersection { + type Item = Range; + + fn next(&mut self) -> Option { + let (mut a, mut b) = match ( + self.a_range.take().or_else(|| self.a_iter.next()), + self.b_range.take().or_else(|| self.b_iter.next()), + ) { + (Some(a), Some(b)) => (a, b), + _ => return None, + }; + + loop { + if a.start <= b.start { + if a.end <= b.start { + // a.start < a.end <= b.start < b.end + // + // a: -xxx----- + // b: -----xxx- + + a = self.a_iter.next()?; + } else if a.end < b.end { + // a.start <= b.start < a.end < b.end + // + // a: -?xxx--- + // b: --xxxxx- + + let intersection = b.start..a.end; + self.b_range = Some(b); + return Some(intersection); + } else { + // a.start <= b.start < b.end <= a.end + // + // a: -?xxx?- + // b: --xxx-- + + self.a_range = Some(a); + return Some(b); + } + } else { + // the intersection operator is symmetric, so this is exactly + // the same as above but with `a` and `b` swapped + + if b.end <= a.start { + b = self.b_iter.next()?; + } else if b.end < a.end { + let intersection = a.start..b.end; + self.a_range = Some(a); + return Some(intersection); + } else { + self.b_range = Some(b); + return Some(a); + } + } + } + } +} + +impl FusedIterator for Intersection {} +impl RangeIterator for Intersection {} + +/// A `RangeIterator` over the bits that represent the difference between two other `RangeIterator`s. 
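+///
+/// For example, the difference of `0..6` and `1..3` is the two ranges `0..1` and `3..6`
+/// (one of the cases in the tests below).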
+pub struct Difference { + a_iter: A, + b_iter: B, + a_range: Option>, + b_range: Option>, +} + +impl Iterator for Difference { + type Item = Range; + + fn next(&mut self) -> Option { + let (mut a, mut b) = match ( + self.a_range.take().or_else(|| self.a_iter.next()), + self.b_range.take().or_else(|| self.b_iter.next()), + ) { + (Some(a), Some(b)) => (a, b), + (a, _) => return a, + }; + + loop { + if a.start < b.start { + if a.end <= b.start { + // a.start < a.end <= b.start < b.end + // + // a: -xxx---- + // b: ----xxx- + + self.b_range = Some(b); + return Some(a); + } else if b.end < a.end { + // a.start < b.start < b.end < a.end + // + // a: -xxxxxxx- + // b: ---xxx--- + + self.a_range = Some(b.end..a.end); + return Some(a.start..b.start); + } else { + // a.start < b.start < a.end <= b.end + // + // a: -xxxx--- + // b: ---xx?-- + + let difference = a.start..b.start; + self.b_range = Some(b); + return Some(difference); + } + } else { + // b.start <= a.start + + if b.end <= a.start { + // b.start < b.end <= a.start < a.end + // + // a: ----xxx- + // b: -xxx---- + + match self.b_iter.next() { + Some(range) => b = range, + None => return Some(a), + } + } else if a.end <= b.end { + // b.start <= a.start < a.end <= b.end + // + // a: --xxx-- + // b: -?xxx?- + + a = self.a_iter.next()?; + } else { + // b.start <= a.start < b.end < a.end + // + // a: --xxxxx- + // b: -?xxx--- + + a.start = b.end; + match self.b_iter.next() { + Some(range) => b = range, + None => return Some(a), + } + } + } + } + } +} + +impl FusedIterator for Difference {} +impl RangeIterator for Difference {} + +/// A `RangeIterator` that skips over `n` bits of antoher `RangeIterator`. +pub struct Skip { + iter: I, + skip: usize, +} + +impl Iterator for Skip { + type Item = Range; + + fn next(&mut self) -> Option { + loop { + let mut range = self.iter.next()?; + + if range.len() > self.skip { + range.start += self.skip; + self.skip = 0; + return Some(range); + } else { + self.skip -= range.len(); + } + } + } +} + +impl FusedIterator for Skip {} +impl RangeIterator for Skip {} + +/// A `RangeIterator` that iterates over the first `n` bits of antoher `RangeIterator`. +pub struct Take { + iter: I, + take: usize, +} + +impl Iterator for Take { + type Item = Range; + + fn next(&mut self) -> Option { + if self.take == 0 { + return None; + } + + let mut range = self.iter.next()?; + + if range.len() > self.take { + range.end = range.start + self.take; + } + + self.take -= range.len(); + Some(range) + } +} + +impl FusedIterator for Take {} +impl RangeIterator for Take {} + +/// A `RangeIterator` that wraps a regular iterator over `Range` as a way to explicitly +/// indicate that this iterator satisfies the requirements of the `RangeIterator` trait. +pub struct Ranges(I); + +impl Ranges +where + I: Iterator>, +{ + /// Creates a new `Ranges` instance. + pub fn new(iter: II) -> Self + where + II: IntoIterator>, + { + Self(iter.into_iter()) + } +} + +impl Iterator for Ranges +where + I: Iterator>, +{ + type Item = Range; + + fn next(&mut self) -> Option { + self.0.next() + } +} + +impl FusedIterator for Ranges where I: Iterator> {} +impl RangeIterator for Ranges where I: Iterator> {} + +/// Returns a `RangeIterator` which ranges contain the values from the provided iterator. +/// The values need to be in ascending order. 
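+///
+/// For example, the bits `2, 3, 4, 7` are turned into the ranges `2..5` and `7..8`.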
+pub fn ranges_from_bits(bits: impl IntoIterator) -> impl RangeIterator { + let mut iter = bits.into_iter().peekable(); + + Ranges::new(iter::from_fn(move || { + let start = iter.next()?; + let mut end = start + 1; + while iter.peek() == Some(&end) { + end += 1; + iter.next(); + } + Some(start..end) + })) +} + +#[cfg(test)] +mod tests { + use super::*; + + fn ranges(slice: &[Range]) -> impl RangeIterator + '_ { + Ranges::new(slice.iter().cloned()) + } + + #[test] + fn test_combinators() { + struct Case<'a> { + lhs: &'a [Range], + rhs: &'a [Range], + union: &'a [Range], + intersection: &'a [Range], + difference: &'a [Range], + } + + for &Case { + lhs, + rhs, + union, + intersection, + difference, + } in &[ + Case { + lhs: &[2..5], + rhs: &[], + union: &[2..5], + intersection: &[], + difference: &[2..5], + }, + Case { + lhs: &[0..3, 10..13], + rhs: &[5..8], + union: &[0..3, 5..8, 10..13], + intersection: &[], + difference: &[0..3, 10..13], + }, + Case { + lhs: &[0..3, 8..11], + rhs: &[2..5], + union: &[0..5, 8..11], + intersection: &[2..3], + difference: &[0..2, 8..11], + }, + Case { + lhs: &[0..3, 4..7, 8..11], + rhs: &[2..5, 6..9, 10..13], + union: &[0..13], + intersection: &[2..3, 4..5, 6..7, 8..9, 10..11], + difference: &[0..2, 5..6, 9..10], + }, + Case { + lhs: &[0..6], + rhs: &[1..3], + union: &[0..6], + intersection: &[1..3], + difference: &[0..1, 3..6], + }, + Case { + lhs: &[0..6], + rhs: &[1..3, 5..7, 9..11], + union: &[0..7, 9..11], + intersection: &[1..3, 5..6], + difference: &[0..1, 3..5], + }, + Case { + lhs: &[3..6], + rhs: &[0..2, 4..5, 8..10], + union: &[0..2, 3..6, 8..10], + intersection: &[4..5], + difference: &[3..4, 5..6], + }, + Case { + lhs: &[3..6, 8..10], + rhs: &[2..7, 8..11], + union: &[2..7, 8..11], + intersection: &[3..6, 8..10], + difference: &[], + }, + Case { + lhs: &[3..6, 8..10], + rhs: &[2..4], + union: &[2..6, 8..10], + intersection: &[3..4], + difference: &[4..6, 8..10], + }, + ] { + assert_eq!(ranges(lhs).merge(ranges(rhs)).collect::>(), union); + assert_eq!(ranges(rhs).merge(ranges(lhs)).collect::>(), union); + + assert_eq!( + ranges(lhs).intersection(ranges(rhs)).collect::>(), + intersection + ); + assert_eq!( + ranges(rhs).intersection(ranges(lhs)).collect::>(), + intersection + ); + + assert_eq!( + ranges(lhs).difference(ranges(rhs)).collect::>(), + difference + ); + } + } + + #[test] + fn test_ranges_from_bits() { + struct Case<'a> { + input: &'a [usize], + output: &'a [Range], + } + for &Case { input, output } in &[ + Case { + input: &[], + output: &[], + }, + Case { + input: &[10], + output: &[10..11], + }, + Case { + input: &[2, 3, 4, 7, 9, 11, 12], + output: &[2..5, 7..8, 9..10, 11..13], + }, + ] { + assert_eq!( + ranges_from_bits(input.iter().copied()).collect::>(), + output + ); + } + } + + #[test] + fn test_skip_take() { + struct Case<'a> { + input: &'a [Range], + n: usize, + skip: &'a [Range], + take: &'a [Range], + } + + for &Case { + input, + n, + skip, + take, + } in &[ + Case { + input: &[], + n: 0, + skip: &[], + take: &[], + }, + Case { + input: &[], + n: 3, + skip: &[], + take: &[], + }, + Case { + input: &[1..3, 4..6], + n: 0, + skip: &[1..3, 4..6], + take: &[], + }, + Case { + input: &[1..3, 4..6], + n: 1, + skip: &[2..3, 4..6], + take: &[1..2], + }, + Case { + input: &[1..3, 4..6], + n: 2, + skip: &[4..6], + take: &[1..3], + }, + Case { + input: &[1..3, 4..6], + n: 3, + skip: &[5..6], + take: &[1..3, 4..5], + }, + ] { + assert_eq!(ranges(input).skip_bits(n).collect::>(), skip); + assert_eq!(ranges(input).take_bits(n).collect::>(), 
take); + } + } +} diff --git a/utils/bitfield/src/lib.rs b/utils/bitfield/src/lib.rs index 3454a8ff95d8..e4942574f0f1 100644 --- a/utils/bitfield/src/lib.rs +++ b/utils/bitfield/src/lib.rs @@ -1,462 +1,270 @@ // Copyright 2020 ChainSafe Systems // SPDX-License-Identifier: Apache-2.0, MIT -pub mod bitvec_serde; +mod iter; + pub mod rleplus; -pub use bitvec; -use bitvec::prelude::*; -use core::ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, Not}; -use fnv::FnvHashSet; -use std::iter::FromIterator; +use ahash::AHashSet; +use iter::{ranges_from_bits, RangeIterator}; +use rleplus::RlePlus; +use serde::{Deserialize, Serialize}; +use std::{ + iter::FromIterator, + ops::{BitAnd, BitAndAssign, BitOr, BitOrAssign, Sub, SubAssign}, +}; -type BitVec = bitvec::prelude::BitVec; +type BitVec = bitvec::prelude::BitVec; type Result = std::result::Result; -/// Represents a bitfield to track bits set at indexes in the range of `u64`. -#[derive(Debug, Clone)] -pub enum BitField { - Encoded { - bv: BitVec, - set: FnvHashSet, - unset: FnvHashSet, - }, - // TODO would be beneficial in future to only keep encoded bitvec in memory, but comes at a cost - Decoded(BitVec), +/// An RLE+ encoded bit field with buffered insertion/removal. Similar to `HashSet`, +/// but more memory-efficient when long runs of 1s and 0s are present. +/// +/// When deserializing a bit field, in order to distinguish between an invalid RLE+ encoding +/// and any other deserialization errors, deserialize into an `UnverifiedBitField` and +/// call `verify` on it. +#[derive(Debug, Default, Clone, Serialize, Deserialize)] +#[serde(from = "RlePlus", into = "RlePlus")] +pub struct BitField { + /// The underlying RLE+ encoded bitvec. + bitvec: RlePlus, + /// Bits set to 1. Never overlaps with `unset`. + set: AHashSet, + /// Bits set to 0. Never overlaps with `set`. + unset: AHashSet, } -impl Default for BitField { - fn default() -> Self { - Self::Decoded(BitVec::new()) +impl PartialEq for BitField { + fn eq(&self, other: &Self) -> bool { + Iterator::eq(self.ranges(), other.ranges()) } } -impl BitField { - pub fn new() -> Self { - Self::default() - } - - /// Generates a new bitfield with a slice of all indexes to set. - pub fn new_from_set(set_bits: &[u64]) -> Self { - let mut vec = match set_bits.iter().max() { - Some(&max) => bitvec![_, u8; 0; max as usize + 1], - None => return Self::new(), - }; - - // Set all bits in bitfield - for b in set_bits { - vec.set(*b as usize, true); - } - - Self::Decoded(vec) - } - - /// Sets bit at bit index provided - pub fn set(&mut self, bit: u64) { - match self { - BitField::Encoded { set, unset, .. } => { - unset.remove(&bit); - set.insert(bit); - } - BitField::Decoded(bv) => { - let index = bit as usize; - if bv.len() <= index { - bv.resize(index + 1, false); - } - bv.set(index, true); - } - } - } - - /// Removes bit at bit index provided - pub fn unset(&mut self, bit: u64) { - match self { - BitField::Encoded { set, unset, .. } => { - set.remove(&bit); - unset.insert(bit); - } - BitField::Decoded(bv) => { - let index = bit as usize; - if bv.len() <= index { - return; - } - bv.set(index, false); - } - } +impl FromIterator for BitField { + fn from_iter>(iter: I) -> Self { + let mut vec: Vec<_> = iter.into_iter().collect(); + vec.sort_unstable(); + Self::from_ranges(ranges_from_bits(vec)) } +} - /// Gets the bit at the given index. - // TODO this probably should not require mut self and RLE decode bits - pub fn get(&mut self, index: u64) -> Result { - match self { - BitField::Encoded { set, unset, .. 
} => { - if set.contains(&index) { - return Ok(true); - } - - if unset.contains(&index) { - return Ok(false); - } - - // Check in encoded for the given bit - // This can be changed to not flush changes - if let Some(true) = self.as_mut_flushed()?.get(index as usize) { - Ok(true) - } else { - Ok(false) - } - } - BitField::Decoded(bv) => { - if let Some(true) = bv.get(index as usize) { - Ok(true) - } else { - Ok(false) - } - } +impl From for BitField { + fn from(bitvec: RlePlus) -> Self { + Self { + bitvec, + ..Default::default() } } +} - /// Retrieves the index of the first set bit, and error if invalid encoding or no bits set. - pub fn first(&mut self) -> Result { - for (i, b) in (0..).zip(self.as_mut_flushed()?.iter()) { - if b == &true { - return Ok(i); - } +impl From for RlePlus { + fn from(bitfield: BitField) -> Self { + if bitfield.set.is_empty() && bitfield.unset.is_empty() { + bitfield.bitvec + } else { + Self::from_ranges(bitfield.ranges()) } - // Return error if none found, not ideal but no reason not to match - Err("Bitfield has no set bits") } +} - fn retrieve_set_indices>(&mut self, max: usize) -> Result { - let flushed = self.as_mut_flushed()?; - if flushed.count_ones() > max { - return Err("Bits set exceeds max in retrieval"); - } - - Ok((0..) - .zip(flushed.iter()) - .filter_map(|(i, b)| if b == &true { Some(i) } else { None }) - .collect()) +impl BitField { + /// Creates an empty bit field. + pub fn new() -> Self { + Self::default() } - /// Returns a vector of indexes of all set bits - pub fn all(&mut self, max: usize) -> Result> { - self.retrieve_set_indices(max) + /// Creates a new bit field from a `RangeIterator`. + pub fn from_ranges(iter: impl RangeIterator) -> Self { + RlePlus::from_ranges(iter).into() } - /// Returns a Hash set of indexes of all set bits - pub fn all_set(&mut self, max: usize) -> Result> { - self.retrieve_set_indices(max) + /// Adds the bit at a given index to the bit field. + pub fn set(&mut self, bit: usize) { + self.unset.remove(&bit); + self.set.insert(bit); } - pub fn for_each(&mut self, mut callback: F) -> std::result::Result<(), String> - where - F: FnMut(u64) -> std::result::Result<(), String>, - { - let flushed = self.as_mut_flushed()?; - - for (i, &b) in (0..).zip(flushed.iter()) { - if b { - callback(i)?; - } - } - Ok(()) + /// Removes the bit at a given index from the bit field. + pub fn unset(&mut self, bit: usize) { + self.set.remove(&bit); + self.unset.insert(bit); } - /// Returns true if there are no bits set, false if the bitfield is empty. - pub fn is_empty(&mut self) -> Result { - for b in self.as_mut_flushed()?.iter() { - if b == &true { - return Ok(false); - } + /// Returns `true` if the bit field contains the bit at a given index. + pub fn get(&self, index: usize) -> bool { + if self.set.contains(&index) { + true + } else if self.unset.contains(&index) { + false + } else { + self.bitvec.get(index) } - - Ok(true) } - /// Returns a slice of the bitfield with the start index of set bits - /// and number of bits to include in slice. 
- pub fn slice(&mut self, start: u64, count: u64) -> Result { - if count == 0 { - return Ok(BitField::default()); - } - - // These conversions aren't ideal, but we aren't supporting 32 bit targets - let mut start = start as usize; - let mut count = count as usize; - - let bitvec = self.as_mut_flushed()?; - let mut start_idx: usize = 0; - let mut range: usize = 0; - if start != 0 { - for (i, v) in bitvec.iter().enumerate() { - if v == &true { - start -= 1; - if start == 0 { - start_idx = i + 1; - break; - } - } - } + /// Returns the index of the lowest bit present in the bit field. + pub fn first(&self) -> Option { + self.iter().next() + } + + /// Returns an iterator over the indices of the bit field's set bits. + pub fn iter(&self) -> impl Iterator + '_ { + // this code results in the same values as `self.ranges().flatten()`, but there's + // a key difference: + // + // `ranges()` needs to traverse both `self.set` and `self.unset` up front (so before + // iteration starts) in order to not have to visit each individual bit of `self.bitvec` + // during iteration, while here we can get away with only traversing `self.set` up + // front and checking `self.unset` containment for the candidate bits on the fly + // because we're visiting all bits either way + // + // consequently, `self.first()` is only linear in the length of `self.set`, not + // in the length of `self.unset` (as opposed to getting the first range with + // `self.ranges().next()` which is linear in both) + + let mut set_bits: Vec<_> = self.set.iter().copied().collect(); + set_bits.sort_unstable(); + + self.bitvec + .ranges() + .merge(ranges_from_bits(set_bits)) + .flatten() + .filter(move |i| !self.unset.contains(i)) + } + + /// Returns an iterator over the indices of the bit field's set bits if the number + /// of set bits in the bit field does not exceed `max`. Returns an error otherwise. + pub fn bounded_iter(&self, max: usize) -> Result + '_> { + if max <= self.len() { + Ok(self.iter()) + } else { + Err("Bits set exceeds max in retrieval") } - - for (i, v) in bitvec[start_idx..].iter().enumerate() { - if v == &true { - count -= 1; - if count == 0 { - range = i + 1; - break; - } - } - } - - if count > 0 { - return Err("Not enough bits to index the slice"); - } - - let mut slice = BitVec::with_capacity(start_idx + range); - slice.resize(start_idx, false); - slice.extend_from_slice(&bitvec[start_idx..start_idx + range]); - Ok(BitField::Decoded(slice)) } - /// Retrieves number of set bits in the bitfield - /// - /// This function requires a mutable reference for now to be able to handle the cached - /// changes in the case of an RLE encoded bitfield. - pub fn count(&mut self) -> Result { - Ok(self.as_mut_flushed()?.count_ones()) - } - - fn flush(&mut self) -> Result<()> { - if let BitField::Encoded { bv, set, unset } = self { - *self = BitField::Decoded(decode_and_apply_cache(bv, set, unset)?); - } - - Ok(()) - } + /// Returns an iterator over the ranges of set bits that make up the bit field. The + /// ranges are in ascending order, are non-empty, and don't overlap. + pub fn ranges(&self) -> impl RangeIterator + '_ { + let ranges = |set: &AHashSet| { + let mut vec: Vec<_> = set.iter().copied().collect(); + vec.sort_unstable(); + ranges_from_bits(vec) + }; - fn into_flushed(mut self) -> Result { - self.flush()?; - match self { - BitField::Decoded(bv) => Ok(bv), - // Unreachable because flushed before this. 
- _ => unreachable!(), - } + self.bitvec + .ranges() + .merge(ranges(&self.set)) + .difference(ranges(&self.unset)) } - fn as_mut_flushed(&mut self) -> Result<&mut BitVec> { - self.flush()?; - match self { - BitField::Decoded(bv) => Ok(bv), - // Unreachable because flushed before this. - _ => unreachable!(), - } + /// Returns `true` if the bit field is empty. + pub fn is_empty(&self) -> bool { + self.set.is_empty() + && self + .bitvec + .ranges() + .flatten() + .all(|bit| self.unset.contains(&bit)) } - /// Merges to bitfields together (equivalent of bitwise OR `|` operator) - pub fn merge(mut self, other: &Self) -> Result { - self.merge_assign(other)?; - Ok(self) - } + /// Returns a slice of the bit field with the start index of set bits + /// and number of bits to include in the slice. Returns an error if the + /// bit field contains fewer than `start + len` set bits. + pub fn slice(&self, start: usize, len: usize) -> Result { + let slice = BitField::from_ranges(self.ranges().skip_bits(start).take_bits(len)); - /// Merges to bitfields into `self` (equivalent of bitwise OR `|` operator) - pub fn merge_assign(&mut self, other: &Self) -> Result<()> { - let a = self.as_mut_flushed()?; - match other { - BitField::Encoded { bv, set, unset } => { - let v = decode_and_apply_cache(bv, set, unset)?; - bit_or(a, v.into_iter()) - } - BitField::Decoded(bv) => bit_or(a, bv.iter().copied()), + if slice.len() == len { + Ok(slice) + } else { + Err("Not enough bits") } - - Ok(()) } - /// Intersection of two bitfields (equivalent of bit AND `&`) - pub fn intersect(mut self, other: &Self) -> Result { - self.intersect_assign(other)?; - Ok(self) - } - - /// Intersection of two bitfields and assigns to self (equivalent of bit AND `&`) - pub fn intersect_assign(&mut self, other: &Self) -> Result<()> { - match other { - BitField::Encoded { bv, set, unset } => { - *self.as_mut_flushed()? &= decode_and_apply_cache(bv, set, unset)? - } - BitField::Decoded(bv) => *self.as_mut_flushed()? &= bv.iter().copied(), - } - Ok(()) + /// Returns the number of set bits in the bit field. + pub fn len(&self) -> usize { + self.ranges().map(|range| range.len()).sum() } - /// Subtract other bitfield from self (equivalent of `a & !b`) - pub fn subtract(mut self, other: &Self) -> Result { - self.subtract_assign(other)?; - Ok(self) + /// Returns a new `RangeIterator` over the bits that are in `self`, in `other`, or in both. + /// + /// The `|` operator is the eager version of this. + pub fn merge<'a>(&'a self, other: &'a Self) -> impl RangeIterator + 'a { + self.ranges().merge(other.ranges()) } - /// Subtract other bitfield from self (equivalent of `a & !b`) - pub fn subtract_assign(&mut self, other: &Self) -> Result<()> { - match other { - BitField::Encoded { bv, set, unset } => { - *self.as_mut_flushed()? &= !decode_and_apply_cache(bv, set, unset)? - } - BitField::Decoded(bv) => *self.as_mut_flushed()? &= bv.iter().copied().map(|b| !b), - } - Ok(()) + /// Returns a new `RangeIterator` over the bits that are in both `self` and `other`. + /// + /// The `&` operator is the eager version of this. + pub fn intersection<'a>(&'a self, other: &'a Self) -> impl RangeIterator + 'a { + self.ranges().intersection(other.ranges()) } - /// Creates a bitfield which is a union of a vector of bitfields. 
- pub fn union<'a>(bit_fields: impl IntoIterator) -> Result { - let mut ret = Self::default(); - for bf in bit_fields.into_iter() { - ret.merge_assign(bf)?; - } - Ok(ret) + /// Returns a new `RangeIterator` over the bits that are in `self` but not in `other`. + /// + /// The `-` operator is the eager version of this. + pub fn difference<'a>(&'a self, other: &'a Self) -> impl RangeIterator + 'a { + self.ranges().difference(other.ranges()) } - /// Returns true if BitFields have any overlapping bits. - pub fn contains_any(&mut self, other: &mut BitField) -> Result { - for (&a, &b) in self - .as_mut_flushed()? - .iter() - .zip(other.as_mut_flushed()?.iter()) - { - if a && b { - return Ok(true); - } - } - Ok(false) + /// Returns the union of the given bit fields as a new bit field. + pub fn union<'a>(bitfields: impl IntoIterator) -> Self { + bitfields.into_iter().fold(Self::new(), |a, b| &a | b) } - /// Returns true if the self `BitField` has all the bits set in the other `BitField`. - pub fn contains_all(&mut self, other: &mut BitField) -> Result { - let a_bf = self.as_mut_flushed()?; - let b_bf = other.as_mut_flushed()?; - - // Checking lengths should be sufficient in most cases, but does not take into account - // decoded bitfields with extra 0 bits. This makes sure there are no extra bits in the - // extension. - if b_bf.len() > a_bf.len() && b_bf[a_bf.len()..].count_ones() > 0 { - return Ok(false); - } - - for (a, b) in a_bf.iter().zip(b_bf.iter()) { - if *b && !a { - return Ok(false); - } - } - - Ok(true) + /// Returns true if `self` overlaps with `other`. + pub fn contains_any(&self, other: &BitField) -> bool { + self.intersection(other).next().is_some() } -} -fn bit_or(a: &mut BitVec, mut b: I) -where - I: Iterator, -{ - for mut a_i in a.iter_mut() { - match b.next() { - Some(true) => *a_i = true, - Some(false) => (), - None => return, - } + /// Returns true if the `self` is a superset of `other`. 
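+    ///
+    /// Equivalently, the difference `other - self` contains no bits.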
+ pub fn contains_all(&self, other: &BitField) -> bool { + other.difference(self).next().is_none() } - - a.extend(b); } -fn decode_and_apply_cache( - bit_vec: &BitVec, - set: &FnvHashSet, - unset: &FnvHashSet, -) -> Result { - let mut decoded = rleplus::decode(bit_vec)?; - - // Resize before setting any values - if let Some(&max) = set.iter().max() { - let max = max as usize; - if max >= bit_vec.len() { - decoded.resize(max + 1, false); - } - }; - - // Set all values in the cache - for &b in set.iter() { - decoded.set(b as usize, true); - } - - // Unset all values from the encoded cache - for &b in unset.iter() { - decoded.set(b as usize, false); - } +impl BitOr<&BitField> for &BitField { + type Output = BitField; - Ok(decoded) -} - -impl AsRef for BitField { - fn as_ref(&self) -> &Self { - self + #[inline] + fn bitor(self, rhs: &BitField) -> Self::Output { + BitField::from_ranges(self.merge(rhs)) } } -impl From for BitField { - fn from(b: BitVec) -> Self { - Self::Decoded(b) +impl BitOrAssign<&BitField> for BitField { + #[inline] + fn bitor_assign(&mut self, rhs: &BitField) { + *self = &*self | rhs; } } -impl BitOr for BitField -where - B: AsRef, -{ - type Output = Self; +impl BitAnd<&BitField> for &BitField { + type Output = BitField; #[inline] - fn bitor(self, rhs: B) -> Self { - self.merge(rhs.as_ref()).unwrap() + fn bitand(self, rhs: &BitField) -> Self::Output { + BitField::from_ranges(self.intersection(rhs)) } } -impl BitOrAssign for BitField -where - B: AsRef, -{ +impl BitAndAssign<&BitField> for BitField { #[inline] - fn bitor_assign(&mut self, rhs: B) { - self.merge_assign(rhs.as_ref()).unwrap() + fn bitand_assign(&mut self, rhs: &BitField) { + *self = &*self & rhs; } } -impl BitAnd for BitField -where - B: AsRef, -{ - type Output = Self; +impl Sub<&BitField> for &BitField { + type Output = BitField; #[inline] - fn bitand(self, rhs: B) -> Self::Output { - self.intersect(rhs.as_ref()).unwrap() + fn sub(self, rhs: &BitField) -> Self::Output { + BitField::from_ranges(self.difference(rhs)) } } -impl BitAndAssign for BitField -where - B: AsRef, -{ - #[inline] - fn bitand_assign(&mut self, rhs: B) { - self.intersect_assign(rhs.as_ref()).unwrap() - } -} - -impl Not for BitField { - type Output = Self; - +impl SubAssign<&BitField> for BitField { #[inline] - fn not(self) -> Self::Output { - Self::Decoded(!self.into_flushed().unwrap()) + fn sub_assign(&mut self, rhs: &BitField) { + *self = &*self - rhs; } } diff --git a/utils/bitfield/src/rleplus.rs b/utils/bitfield/src/rleplus.rs deleted file mode 100644 index 4d4c4c027e51..000000000000 --- a/utils/bitfield/src/rleplus.rs +++ /dev/null @@ -1,282 +0,0 @@ -// Copyright 2020 ChainSafe Systems -// SPDX-License-Identifier: Apache-2.0, MIT - -//! # RLE+ Bitset Encoding -//! -//! RLE+ is a lossless compression format based on [RLE](https://en.wikipedia.org/wiki/Run-length_encoding). -//! It's primary goal is to reduce the size in the case of many individual bits, where RLE breaks down quickly, -//! while keeping the same level of compression for large sets of contigous bits. -//! -//! In tests it has shown to be more compact than RLE iteself, as well as [Concise](https://arxiv.org/pdf/1004.0403.pdf) and [Roaring](https://roaringbitmap.org/). -//! -//! ## Format -//! -//! The format consists of a header, followed by a series of blocks, of which there are three different types. -//! -//! The format can be expressed as the following [BNF](https://en.wikipedia.org/wiki/Backus%E2%80%93Naur_form) grammar. -//! -//! ```text -//! ::=
-//!
::= -//! ::= "00" -//! ::= | "" -//! ::= | | -//! ::= "1" -//! ::= "01" -//! ::= "00" -//! ::= "0" | "1" -//! ``` -//! -//! An `` is defined as specified [here](https://github.com/multiformats/unsigned-varint). -//! -//! ### Header -//! -//! The header indiciates the very first bit of the bit vector to encode. This means the first bit is always -//! the same for the encoded and non encoded form. -//! -//! ### Blocks -//! -//! The blocks represent how many bits, of the current bit type there are. As `0` and `1` alternate in a bit vector -//! the inital bit, which is stored in the header, is enough to determine if a length is currently referencing -//! a set of `0`s, or `1`s. -//! -//! #### Block Single -//! -//! If the running length of the current bit is only `1`, it is encoded as a single set bit. -//! -//! #### Block Short -//! -//! If the running length is less than `16`, it can be encoded into up to four bits, which a short block -//! represents. The length is encoded into a 4 bits, and prefixed with `01`, to indicate a short block. -//! -//! #### Block Long -//! -//! If the running length is `16` or larger, it is encoded into a varint, and then prefixed with `00` to indicate -//! a long block. -//! -//! -//! > **Note:** The encoding is unique, so no matter which algorithm for encoding is used, it should produce -//! > the same encoding, given the same input. -//! - -use super::BitVec; - -/// Encode the given bitset into their RLE+ encoded representation. -pub fn encode(raw: &BitVec) -> BitVec { - let mut encoding = BitVec::new(); - - if raw.is_empty() { - return encoding; - } - - // Header - // encode version "00" and push to start of encoding - encoding.insert(0, false); - encoding.insert(0, false); - - // encode the very first bit (the first block contains this, then alternating) - encoding.push(*raw.get(0).unwrap()); - - // the running length - let mut count = 1; - - // the current bit type - let mut current = raw.get(0); - - let last = raw.len(); - - for i in 1..=raw.len() { - if raw.get(i) != current || i == last { - if i == last && raw.get(i) == current { - count += 1; - } - - if count == 1 { - // Block Single - encoding.push(true); - } else if count < 16 { - // Block Short - // 4 bits - let s_vec: BitVec = BitVec::from(&[count as u8][..]); - - // prefix: 01 - encoding.push(false); - encoding.push(true); - encoding.extend(s_vec.into_iter().take(4)); - count = 1; - } else { - // Block Long - let mut v = [0u8; 10]; - let s = unsigned_varint::encode::u64(count, &mut v); - let s_vec: BitVec = BitVec::from(s); - - // prefix: 00 - encoding.push(false); - encoding.push(false); - - encoding.extend(s_vec.into_iter()); - count = 1; - } - current = raw.get(i); - } else { - count += 1; - } - } - - encoding -} - -/// Decode an RLE+ encoded bitset into its original form. 
-pub fn decode(enc: &BitVec) -> Result { - let mut decoded = BitVec::new(); - - if enc.is_empty() { - return Ok(decoded); - } - - // Header - if enc.len() < 3 { - return Err("Failed to decode, bytes must be at least 3 bits long"); - } - - // read version (expects "00") - if *enc.get(0).unwrap() || *enc.get(1).unwrap() { - return Err("Invalid version, expected '00'"); - } - - // read the inital bit - let mut cur = *enc.get(2).unwrap(); - - // pointer into the encoded bitvec - let mut i = 3; - - let len = enc.len(); - - while i < len { - // read the next prefix - if *enc.get(i).unwrap_or(&false) { - decoded.push(cur); - i += 1; - } else { - let enc_iter = enc.iter().skip(i + 2); - if *enc.get(i + 1).ok_or_else(|| "premature end to bits")? { - // Block Short - // prefix: 01 - let buf = enc_iter.take(4).copied().collect::(); - let res: Vec = buf.into(); - - if res.len() != 1 { - return Err("Invalid short block encoding"); - } - - let len = res[0] as usize; - - // prefix - i += 2; - // length of the encoded number - i += 4; - - decoded.extend((0..len).map(|_| cur)); - } else { - let buf = enc_iter.take(10 * 8).copied().collect::(); - let buf_ref: &[u8] = buf.as_ref(); - let (len, rest) = unsigned_varint::decode::u64(buf_ref) - .map_err(|_| "Failed to decode uvarint")?; - - // prefix - i += 2; - // this is how much space the varint took in bits - i += (buf_ref.len() * 8) - (rest.len() * 8); - - // insert this many bits - decoded.extend((0..len).map(|_| cur)); - } - } - - // swith the cur value - cur = !cur; - } - - Ok(decoded) -} - -#[cfg(test)] -mod tests { - use super::*; - - use bitvec::prelude::Lsb0; - use bitvec::*; - use rand::{Rng, RngCore, SeedableRng}; - use rand_xorshift::XorShiftRng; - - #[test] - fn test_rle_plus_basics() { - let cases: Vec<(BitVec, BitVec)> = vec![ - ( - bitvec![Lsb0, u8; 0; 8], - bitvec![Lsb0, u8; - 0, 0, // version - 0, // starts with 0 - 0, 1, // fits into 4 bits - 0, 0, 0, 1, // 8 - ], - ), - ( - bitvec![Lsb0, u8; 0, 0, 0, 0, 1, 0, 0, 0], - bitvec![Lsb0, u8; - 0, 0, // version - 0, // starts with 0 - 0, 1, // fits into 4 bits - 0, 0, 1, 0, // 4 - 0 - 1, // 1 - 1 - 0, 1, // fits into 4 bits - 1, 1, 0, 0 // 3 - 0 - ], - ), - ]; - - for (i, case) in cases.into_iter().enumerate() { - assert_eq!(encode(&case.0), case.1, "case: {}", i); - } - } - - #[test] - #[ignore] - fn test_rle_plus_roundtrip_small() { - let mut rng = XorShiftRng::from_seed([1u8; 16]); - - for _i in 0..10000 { - let len: usize = rng.gen_range(0, 1000); - - let mut src = vec![0u8; len]; - rng.fill_bytes(&mut src); - - let original: BitVec = src.into(); - - let encoded = encode(&original); - let decoded = decode(&encoded).unwrap(); - - assert_eq!(original, decoded); - } - } - - #[test] - #[ignore] - fn test_rle_plus_roundtrip_large() { - let mut rng = XorShiftRng::from_seed([2u8; 16]); - - for _i in 0..100 { - let len: usize = rng.gen_range(0, 100000); - - let mut src = vec![0u8; len]; - rng.fill_bytes(&mut src); - - let original: BitVec = src.into(); - - let encoded = encode(&original); - let decoded = decode(&encoded).unwrap(); - - assert_eq!(original, decoded); - } - } -} diff --git a/utils/bitfield/src/rleplus/iter.rs b/utils/bitfield/src/rleplus/iter.rs new file mode 100644 index 000000000000..e8964ad4ee88 --- /dev/null +++ b/utils/bitfield/src/rleplus/iter.rs @@ -0,0 +1,86 @@ +// Copyright 2020 ChainSafe Systems +// SPDX-License-Identifier: Apache-2.0, MIT + +use super::{BitReader, RangeIterator, Result, RlePlus}; +use std::{iter::FusedIterator, ops::Range}; + +/// An iterator over the runs of 
1s and 0s of RLE+ encoded data. +pub struct Runs<'a> { + /// The `BitReader` that is read from. + reader: BitReader<'a>, + /// The value of the next bit. + next_value: bool, +} + +impl<'a> Runs<'a> { + /// Creates a new `Runs` instance given data that may or may + /// not be correctly RLE+ encoded. Immediately returns an + /// error if the version number is incorrect. + pub fn new(bytes: &'a [u8]) -> Result { + let mut reader = BitReader::new(bytes); + + let version = reader.read(2); + if version != 0 { + return Err("incorrect version"); + } + + let next_value = reader.read(1) == 1; + Ok(Self { reader, next_value }) + } +} + +impl Iterator for Runs<'_> { + type Item = Result<(bool, usize)>; + + fn next(&mut self) -> Option { + let len = match self.reader.read_len() { + Ok(len) => len?, + Err(e) => return Some(Err(e)), + }; + + let run = (self.next_value, len); + self.next_value = !self.next_value; + Some(Ok(run)) + } +} + +/// An iterator over the ranges of 1s of RLE+ encoded data that has already been verified. +pub struct Ranges<'a> { + /// The underlying runs of 1s and 0s. + runs: Runs<'a>, + /// The current position, i.e. the end of the last range that was read, + /// or 0 if no ranges have been read yet. + offset: usize, +} + +impl<'a> Ranges<'a> { + pub(super) fn new(encoded: &'a RlePlus) -> Self { + Self { + // the data has already been verified, so this cannot fail + runs: Runs::new(encoded.as_bytes()).unwrap(), + offset: 0, + } + } +} + +impl Iterator for Ranges<'_> { + type Item = Range; + + fn next(&mut self) -> Option { + // this loop will run either 1 or 2 times because runs alternate + loop { + // the data has already been verified, so this cannot fail + let (value, len) = self.runs.next()?.unwrap(); + + let start = self.offset; + self.offset += len; + + if value { + return Some(start..self.offset); + } + } + } +} + +impl FusedIterator for Ranges<'_> {} +impl RangeIterator for Ranges<'_> {} diff --git a/utils/bitfield/src/rleplus/mod.rs b/utils/bitfield/src/rleplus/mod.rs new file mode 100644 index 000000000000..1bdea1bb5671 --- /dev/null +++ b/utils/bitfield/src/rleplus/mod.rs @@ -0,0 +1,312 @@ +// Copyright 2020 ChainSafe Systems +// SPDX-License-Identifier: Apache-2.0, MIT + +//! # RLE+ Bitset Encoding +//! +//! (from https://github.com/filecoin-project/specs/blob/master/src/listings/data_structures.md) +//! +//! RLE+ is a lossless compression format based on [RLE](https://en.wikipedia.org/wiki/Run-length_encoding). +//! Its primary goal is to reduce the size in the case of many individual bits, where RLE breaks down quickly, +//! while keeping the same level of compression for large sets of contiguous bits. +//! +//! In tests it has shown to be more compact than RLE iteself, as well as [Concise](https://arxiv.org/pdf/1004.0403.pdf) and [Roaring](https://roaringbitmap.org/). +//! +//! ## Format +//! +//! The format consists of a header, followed by a series of blocks, of which there are three different types. +//! +//! The format can be expressed as the following [BNF](https://en.wikipedia.org/wiki/Backus%E2%80%93Naur_form) grammar. +//! +//! ```text +//! ::=
+//!
::= +//! ::= "00" +//! ::= | "" +//! ::= | | +//! ::= "1" +//! ::= "01" +//! ::= "00" +//! ::= "0" | "1" +//! ``` +//! +//! An `` is defined as specified [here](https://github.com/multiformats/unsigned-varint). +//! +//! ### Header +//! +//! The header indiciates the very first bit of the bit vector to encode. This means the first bit is always +//! the same for the encoded and non encoded form. +//! +//! ### Blocks +//! +//! The blocks represent how many bits, of the current bit type there are. As `0` and `1` alternate in a bit vector +//! the inital bit, which is stored in the header, is enough to determine if a length is currently referencing +//! a set of `0`s, or `1`s. +//! +//! #### Block Single +//! +//! If the running length of the current bit is only `1`, it is encoded as a single set bit. +//! +//! #### Block Short +//! +//! If the running length is less than `16`, it can be encoded into up to four bits, which a short block +//! represents. The length is encoded into a 4 bits, and prefixed with `01`, to indicate a short block. +//! +//! #### Block Long +//! +//! If the running length is `16` or larger, it is encoded into a varint, and then prefixed with `00` to indicate +//! a long block. +//! +//! +//! > **Note:** The encoding is unique, so no matter which algorithm for encoding is used, it should produce +//! > the same encoding, given the same input. +//! + +mod iter; +mod reader; +mod writer; + +pub use iter::{Ranges, Runs}; +use reader::BitReader; +use writer::BitWriter; + +use super::{ranges_from_bits, BitVec, RangeIterator, Result}; +use serde::{Deserialize, Deserializer, Serialize, Serializer}; + +// https://github.com/multiformats/unsigned-varint#practical-maximum-of-9-bytes-for-security +const VARINT_MAX_BYTES: usize = 9; + +/// An RLE+ encoded bit field. +#[derive(Debug, Default, Clone, PartialEq)] +pub struct RlePlus(BitVec); + +impl Serialize for RlePlus { + fn serialize(&self, serializer: S) -> std::result::Result + where + S: Serializer, + { + serde_bytes::serialize(&self.0.as_slice(), serializer) + } +} + +impl<'de> Deserialize<'de> for RlePlus { + fn deserialize(deserializer: D) -> std::result::Result + where + D: Deserializer<'de>, + { + let bytes: Vec = serde_bytes::deserialize(deserializer)?; + Self::new(bytes.into()).map_err(serde::de::Error::custom) + } +} + +impl RlePlus { + /// Creates a new `RlePlus` instance with an already encoded bitvec. Returns an + /// error if the given bitvec is not RLE+ encoded correctly. + pub fn new(encoded: BitVec) -> Result { + // iterating the runs of the encoded bitvec ensures that it's encoded correctly, + // and adding the lengths of the runs together ensures that the total length of + // 1s and 0s fits in a `usize` + Runs::new(encoded.as_slice())?.try_fold(0_usize, |total_len, run| { + let (_value, len) = run?; + total_len.checked_add(len).ok_or("RLE+ overflow") + })?; + Ok(Self(encoded)) + } + + /// Encodes the given bitset into its RLE+ encoded representation. + pub fn encode(raw: &BitVec) -> Self { + let bits = raw + .iter() + .enumerate() + .filter(|(_, &bit)| bit) + .map(|(i, _)| i); + Self::from_ranges(ranges_from_bits(bits)) + } + + /// Decodes an RLE+ encoded bitset into its original form. 
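+    ///
+    /// Decoding is the inverse of `encode` for bitvecs without trailing zeros
+    /// (see the roundtrip tests below, which strip trailing zeros before comparing).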
+ pub fn decode(&self) -> BitVec { + // the underlying bitvec has already been validated, so nothing here can fail + let mut bitvec = BitVec::new(); + for run in Runs::new(self.as_bytes()).unwrap() { + let (value, len) = run.unwrap(); + bitvec.extend(std::iter::repeat(value).take(len)); + } + bitvec + } + + /// Returns an iterator over the ranges of 1s of the RLE+ encoded data. + pub fn ranges(&self) -> Ranges<'_> { + Ranges::new(self) + } + + /// Returns `true` if the RLE+ encoded data contains the bit at a given index. + pub fn get(&self, index: usize) -> bool { + self.ranges() + .take_while(|range| range.start < index) + .any(|range| range.contains(&index)) + } + + /// RLE+ encodes the ranges of 1s from a given `RangeIterator`. + pub fn from_ranges(mut iter: impl RangeIterator) -> Self { + let first_range = match iter.next() { + Some(range) => range, + None => return Default::default(), + }; + + let mut writer = BitWriter::new(); + writer.write(0, 2); // version 00 + + if first_range.start == 0 { + writer.write(1, 1); // the first bit is a 1 + } else { + writer.write(0, 1); // the first bit is a 0 + writer.write_len(first_range.start); // the number of leading 0s + } + + writer.write_len(first_range.len()); + let mut index = first_range.end; + + // for each range of 1s we first encode the number of 0s that came prior + // before encoding the number of 1s + for range in iter { + writer.write_len(range.start - index); // zeros + writer.write_len(range.len()); // ones + index = range.end; + } + + let (bytes, padding_zeros) = writer.finish(); + let mut bitvec = BitVec::from(bytes); + + // `bitvec` now may also contains padding zeros if the number of written + // bits is not a multiple of 8 + for _ in 0..padding_zeros { + bitvec.pop(); + } + + // no need to verify, this is valid RLE+ by construction + Self(bitvec) + } + + // Returns a byte slice of the bit field's contents. + pub fn as_bytes(&self) -> &[u8] { + self.0.as_slice() + } + + // Converts a bit field into a byte vector. 
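+    // This consumes the bit field; use `as_bytes` above to borrow the encoded bytes instead.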
+ pub fn into_bytes(self) -> Vec { + self.0.into() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use bitvec::prelude::Lsb0; + use bitvec::*; + use rand::{Rng, RngCore, SeedableRng}; + use rand_xorshift::XorShiftRng; + + #[test] + fn test_rle_plus_basics() { + let cases: Vec<(BitVec, BitVec)> = vec![ + ( + bitvec![Lsb0, u8; 1; 8], + bitvec![Lsb0, u8; + 0, 0, // version + 1, // starts with 1 + 0, 1, // fits into 4 bits + 0, 0, 0, 1, // 8 - 1 + ], + ), + ( + bitvec![Lsb0, u8; 1, 1, 1, 1, 0, 1, 1, 1], + bitvec![Lsb0, u8; + 0, 0, // version + 1, // starts with 1 + 0, 1, // fits into 4 bits + 0, 0, 1, 0, // 4 - 1 + 1, // 1 - 0 + 0, 1, // fits into 4 bits + 1, 1, 0, 0 // 3 - 1 + ], + ), + ( + bitvec![Lsb0, u8; 1; 25], + bitvec![Lsb0, u8; + 0, 0, // version + 1, // starts with 1 + 0, 0, // does not fit into 4 bits + 1, 0, 0, 1, 1, 0, 0, 0 // 25 - 1 + ], + ), + ]; + + for (i, case) in cases.into_iter().enumerate() { + assert_eq!( + RlePlus::encode(&case.0), + RlePlus::new(case.1.clone()).unwrap(), + "encoding case {}", + i + ); + assert_eq!( + RlePlus::new(case.1).unwrap().decode(), + case.0, + "decoding case: {}", + i + ); + } + } + + #[test] + fn test_zero_short_block() { + // decoding should end whenever a length of 0 is encountered + + let encoded = bitvec![Lsb0, u8; + 0, 0, // version + 1, // starts with 1 + 1, // 1 - 1 + 0, 1, // fits into 4 bits + 0, 0, 0, 0, // 0 - 0 + 1, // 1 - 1 + ]; + + let decoded = RlePlus::new(encoded).unwrap().decode(); + assert_eq!(decoded, bitvec![Lsb0, u8; 1]); + } + + fn roundtrip(rng: &mut XorShiftRng, range: usize) { + let len: usize = rng.gen_range(0, range); + + let mut src = vec![0u8; len]; + rng.fill_bytes(&mut src); + + let mut bitvec = BitVec::from(src); + while bitvec.last() == Some(&false) { + bitvec.pop(); + } + + let encoded = RlePlus::encode(&bitvec); + let decoded = encoded.decode(); + assert_eq!(&bitvec, &decoded); + } + + #[test] + #[ignore] + fn test_rle_plus_roundtrip_small() { + let mut rng = XorShiftRng::seed_from_u64(1); + + for _i in 0..10_000 { + roundtrip(&mut rng, 1000); + } + } + + #[test] + #[ignore] + fn test_rle_plus_roundtrip_large() { + let mut rng = XorShiftRng::seed_from_u64(2); + + for _i in 0..10_000 { + roundtrip(&mut rng, 100_000); + } + } +} diff --git a/utils/bitfield/src/rleplus/reader.rs b/utils/bitfield/src/rleplus/reader.rs new file mode 100644 index 000000000000..e2ff59d5e320 --- /dev/null +++ b/utils/bitfield/src/rleplus/reader.rs @@ -0,0 +1,119 @@ +// Copyright 2020 ChainSafe Systems +// SPDX-License-Identifier: Apache-2.0, MIT + +use super::{Result, VARINT_MAX_BYTES}; + +/// A `BitReader` allows for efficiently reading bits to a byte buffer, up to a byte at a time. +/// +/// It works by always storing at least the next 8 bits in `bits`, which lets us conveniently +/// and efficiently read bits that cross a byte boundary. It's filled with the bits from `next_byte` +/// after every read operation, which is in turn replaced by the next byte from `bytes` as soon +/// as the next read might read bits from `next_byte`. +pub struct BitReader<'a> { + /// The bytes that have not been read from yet. + bytes: &'a [u8], + /// The next byte from `bytes` to be added to `bits`. + next_byte: u8, + /// The next bits to be read. + bits: u16, + /// The number of bits in `bits` from bytes that came before `next_byte` (at least 8, at most 15). + num_bits: u32, +} + +impl<'a> BitReader<'a> { + /// Creates a new `BitReader`. 
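+    ///
+    /// The first two bytes (or zeros, if fewer are available) are preloaded into `bits` and
+    /// `next_byte`, so that at least the next 8 bits are always ready to be read.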
+ pub fn new(bytes: &'a [u8]) -> Self { + let &byte1 = bytes.get(0).unwrap_or(&0); + let &byte2 = bytes.get(1).unwrap_or(&0); + let bytes = if bytes.len() > 2 { &bytes[2..] } else { &[] }; + + Self { + bytes, + bits: byte1 as u16, + next_byte: byte2, + num_bits: 8, + } + } + + /// Reads a given number of bits from the buffer. Will keep returning 0 once + /// the buffer has been exhausted. + pub fn read(&mut self, num_bits: u32) -> u8 { + debug_assert!(num_bits <= 8); + + // creates a mask with a `num_bits` number of 1s in order + // to get only the bits we need from `self.bits` + let mask = (1 << num_bits) - 1; + let res = (self.bits & mask) as u8; + + // removes the bits we've just read from local storage + // because we don't need them anymore + self.bits >>= num_bits; + self.num_bits -= num_bits; + + // this unconditionally adds the next byte to `bits`, + // regardless of whether there's enough space or not. the + // point is to make sure that `bits` always contains + // at least the next 8 bits to be read + self.bits |= (self.next_byte as u16) << self.num_bits; + + // if fewer than 8 bits remain, we increment `self.num_bits` + // to include the bits from `next_byte` (which is already + // contained in `bits`) and we update `next_byte` with the + // data to be read after that + if self.num_bits < 8 { + self.num_bits += 8; + + let (&next_byte, bytes) = self.bytes.split_first().unwrap_or((&0, &[])); + self.next_byte = next_byte; + self.bytes = bytes; + } + + res + } + + /// Reads a varint from the buffer. Returns an error if the + /// current position on the buffer contains no valid varint. + fn read_varint(&mut self) -> Result { + let mut len = 0; + + for i in 0..VARINT_MAX_BYTES { + let byte = self.read(8); + + // strip off the most significant bit and add + // it to the output + len |= (byte as usize & 0x7f) << (i * 7); + + // if the most significant bit is a 0, we've + // reached the end of the varint + if byte & 0x80 == 0 { + return Ok(len); + } + } + + Err("Invalid varint") + } + + /// Reads a length from the buffer according to RLE+ encoding. + pub fn read_len(&mut self) -> Result> { + let prefix_0 = self.read(1); + + let len = if prefix_0 == 1 { + // Block Single (prefix 1) + 1 + } else { + let prefix_1 = self.read(1); + + if prefix_1 == 1 { + // Block Short (prefix 01) + self.read(4) as usize + } else { + // Block Long (prefix 00) + self.read_varint()? + } + }; + + // decoding ends when a length of 0 is encountered, regardless of + // whether it is a short block or a long block + Ok(if len > 0 { Some(len) } else { None }) + } +} diff --git a/utils/bitfield/src/rleplus/writer.rs b/utils/bitfield/src/rleplus/writer.rs new file mode 100644 index 000000000000..de9025d0fead --- /dev/null +++ b/utils/bitfield/src/rleplus/writer.rs @@ -0,0 +1,69 @@ +// Copyright 2020 ChainSafe Systems +// SPDX-License-Identifier: Apache-2.0, MIT + +#[derive(Default)] +/// A `BitWriter` allows for efficiently writing bits to a byte buffer, up to a byte at a time. +pub struct BitWriter { + /// The buffer that is written to. + bytes: Vec, + /// The most recently written bits. Whenever this exceeds 8 bits, one byte is written to `bytes`. + bits: u16, + /// The number of bits currently stored in `bits`. + num_bits: u32, +} + +impl BitWriter { + /// Creates a new `BitWriter`. + pub fn new() -> Self { + Default::default() + } + + /// Writes a given number of bits from `byte` to the buffer. 
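+    ///
+    /// Bits are written least significant bit first; only the `num_bits` lowest bits of
+    /// `byte` are expected to be set.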
+ pub fn write(&mut self, byte: u8, num_bits: u32) { + debug_assert!(num_bits <= 8); + + self.bits |= (byte as u16) << self.num_bits; + self.num_bits += num_bits; + + // when we have a full byte in `self.bits`, we write it to `self.bytes` + if self.num_bits >= 8 { + self.bytes.push(self.bits as u8); + self.bits >>= 8; + self.num_bits -= 8; + } + } + + /// Writes a given length to the buffer according to RLE+ encoding. + pub fn write_len(&mut self, len: usize) { + debug_assert!(len > 0); + + if len == 1 { + // Block Single (prefix 1) + self.write(1, 1); + } else if len < 16 { + // Block Short (prefix 01) + self.write(2, 2); // 2 == 01 with the least significant bit first + self.write(len as u8, 4); + } else { + // Block Long (prefix 00) + self.write(0, 2); + + let mut buffer = unsigned_varint::encode::usize_buffer(); + for &byte in unsigned_varint::encode::usize(len, &mut buffer) { + self.write(byte, 8); + } + } + } + + /// Writes any remaining bits to the buffer and returns it, as well as the number of + /// padding zeros that were (possibly) added to fill the last byte. + pub fn finish(mut self) -> (Vec, u32) { + let padding = if self.num_bits > 0 { + self.bytes.push(self.bits as u8); + 8 - self.num_bits + } else { + 0 + }; + (self.bytes, padding) + } +} diff --git a/utils/bitfield/tests/bitfield_tests.rs b/utils/bitfield/tests/bitfield_tests.rs index 72df471942db..50c6a94f7131 100644 --- a/utils/bitfield/tests/bitfield_tests.rs +++ b/utils/bitfield/tests/bitfield_tests.rs @@ -1,26 +1,25 @@ // Copyright 2020 ChainSafe Systems // SPDX-License-Identifier: Apache-2.0, MIT -use bitfield::*; +use ahash::AHashSet; +use bitfield::{rleplus::RlePlus, *}; use bitvec::*; -use fnv::FnvHashSet; use rand::{Rng, SeedableRng}; use rand_xorshift::XorShiftRng; +use std::iter::FromIterator; -fn gen_random_index_set(range: u64, seed: u8) -> Vec { - let mut rng = XorShiftRng::from_seed([seed; 16]); - +fn random_indices(range: usize, seed: u64) -> Vec { + let mut rng = XorShiftRng::seed_from_u64(seed); (0..range).filter(|_| rng.gen::()).collect() } #[test] fn bitfield_slice() { - let vals = gen_random_index_set(10000, 2); - - let mut bf = BitField::new_from_set(&vals); + let vals = random_indices(10000, 2); + let bf: BitField = vals.iter().copied().collect(); - let mut slice = bf.slice(600, 500).unwrap(); - let out_vals = slice.all(10000).unwrap(); + let slice = bf.slice(600, 500).unwrap(); + let out_vals: Vec<_> = slice.iter().collect(); let expected_slice = &vals[600..1100]; assert_eq!(out_vals[..500], expected_slice[..500]); @@ -28,37 +27,38 @@ fn bitfield_slice() { #[test] fn bitfield_slice_small() { - let mut bf = BitField::from(bitvec![Lsb0, u8; 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0]); - let mut slice = bf.slice(1, 3).unwrap(); + let bf = BitField::from(RlePlus::encode( + &bitvec![Lsb0, u8; 0, 1, 0, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0], + )); + let slice = bf.slice(1, 3).unwrap(); - assert_eq!(slice.count().unwrap(), 3); - assert_eq!(slice.all(10).unwrap(), &[4, 7, 9]); + assert_eq!(slice.len(), 3); + assert_eq!(slice.iter().collect::>(), &[4, 7, 9]); // Test all combinations let vals = [1, 5, 6, 7, 10, 11, 12, 15]; let test_permutations = |start, count: usize| { - let mut bf = BitField::new_from_set(&vals); - let mut sl = bf.slice(start as u64, count as u64).unwrap(); + let bf: BitField = vals.iter().copied().collect(); + let sl = bf.slice(start, count).unwrap(); let exp = &vals[start..start + count]; - let out = sl.all(10000).unwrap(); + let out: Vec<_> = sl.iter().collect(); assert_eq!(out, exp); }; 
for i in 0..vals.len() { for j in 0..vals.len() - i { - println!("{}, {}", i, j); test_permutations(i, j); } } } -fn set_up_test_bitfields() -> (Vec, Vec, BitField, BitField) { - let a = gen_random_index_set(100, 1); - let b = gen_random_index_set(100, 2); +fn set_up_test_bitfields() -> (Vec, Vec, BitField, BitField) { + let a = random_indices(100, 1); + let b = random_indices(100, 2); - let bf_a = BitField::new_from_set(&a); - let bf_b = BitField::new_from_set(&b); + let bf_a: BitField = a.iter().copied().collect(); + let bf_b: BitField = b.iter().copied().collect(); (a, b, bf_a, bf_b) } @@ -67,66 +67,60 @@ fn set_up_test_bitfields() -> (Vec, Vec, BitField, BitField) { fn bitfield_union() { let (a, b, bf_a, bf_b) = set_up_test_bitfields(); - let mut expected: FnvHashSet = a.iter().copied().collect(); + let mut expected: AHashSet<_> = a.iter().copied().collect(); expected.extend(b); - let mut merged = bf_a.merge(&bf_b).unwrap(); - - assert_eq!(expected, merged.all_set(100).unwrap()); + let merged = &bf_a | &bf_b; + assert_eq!(expected, merged.iter().collect()); } #[test] fn bitfield_intersection() { let (a, b, bf_a, bf_b) = set_up_test_bitfields(); - let hs_a: FnvHashSet = a.into_iter().collect(); - let hs_b: FnvHashSet = b.into_iter().collect(); - let expected: FnvHashSet = hs_a.intersection(&hs_b).copied().collect(); + let hs_a: AHashSet<_> = a.into_iter().collect(); + let hs_b: AHashSet<_> = b.into_iter().collect(); + let expected: AHashSet<_> = hs_a.intersection(&hs_b).copied().collect(); - let mut merged = bf_a.intersect(&bf_b).unwrap(); - - assert_eq!(expected, merged.all_set(100).unwrap()); + let merged = &bf_a & &bf_b; + assert_eq!(expected, merged.iter().collect()); } #[test] -fn bitfield_subtraction() { +fn bitfield_difference() { let (a, b, bf_a, bf_b) = set_up_test_bitfields(); - let mut expected: FnvHashSet = a.into_iter().collect(); + let mut expected: AHashSet<_> = a.into_iter().collect(); for i in b.iter() { expected.remove(i); } - let mut merged = bf_a.subtract(&bf_b).unwrap(); - assert_eq!(expected, merged.all_set(100).unwrap()); + let merged = &bf_a - &bf_b; + assert_eq!(expected, merged.iter().collect()); } // Ported test from go impl (specs-actors) #[test] fn subtract_more() { - let have = BitField::new_from_set(&[5, 6, 8, 10, 11, 13, 14, 17]); - let s1 = BitField::new_from_set(&[5, 6]).subtract(&have).unwrap(); - let s2 = BitField::new_from_set(&[8, 10]).subtract(&have).unwrap(); - let s3 = BitField::new_from_set(&[11, 13]).subtract(&have).unwrap(); - let s4 = BitField::new_from_set(&[14, 17]).subtract(&have).unwrap(); - - let mut u = BitField::union(&[s1, s2, s3, s4]).unwrap(); - assert_eq!(u.count().unwrap(), 0); + let have = BitField::from_iter(vec![5, 6, 8, 10, 11, 13, 14, 17]); + let s1 = &BitField::from_iter(vec![5, 6]) - &have; + let s2 = &BitField::from_iter(vec![8, 10]) - &have; + let s3 = &BitField::from_iter(vec![11, 13]) - &have; + let s4 = &BitField::from_iter(vec![14, 17]) - &have; + + let u = BitField::union(&[s1, s2, s3, s4]); + assert_eq!(u.len(), 0); } #[test] fn contains_any() { assert_eq!( - BitField::new_from_set(&[0, 4]) - .contains_any(&mut BitField::new_from_set(&[1, 3, 5])) - .unwrap(), + BitField::from_iter(vec![0, 4]).contains_any(&BitField::from_iter(vec![1, 3, 5])), false ); assert_eq!( - BitField::new_from_set(&[0, 2, 5, 6]) - .contains_any(&mut BitField::new_from_set(&[1, 3, 5])) - .unwrap(), + BitField::from_iter(vec![0, 2, 5, 6]).contains_any(&BitField::from_iter(vec![1, 3, 5])), true ); } @@ -134,47 +128,98 @@ fn contains_any() { 
#[test] fn contains_all() { assert_eq!( - BitField::new_from_set(&[0, 2, 4]) - .contains_all(&mut BitField::new_from_set(&[0, 2, 4, 5])) - .unwrap(), + BitField::from_iter(vec![0, 2, 4]).contains_all(&BitField::from_iter(vec![0, 2, 4, 5])), false ); assert_eq!( - BitField::new_from_set(&[0, 2, 4, 5]) - .contains_all(&mut BitField::new_from_set(&[0, 2, 4])) - .unwrap(), + BitField::from_iter(vec![0, 2, 4, 5]).contains_all(&BitField::from_iter(vec![0, 2, 4])), true ); assert_eq!( - BitField::new_from_set(&[1, 2, 3]) - .contains_any(&mut BitField::new_from_set(&[1, 2, 3])) - .unwrap(), + BitField::from_iter(vec![1, 2, 3]).contains_any(&BitField::from_iter(vec![1, 2, 3])), true ); } #[test] fn bit_ops() { - let mut a = BitField::new_from_set(&[1, 2, 3]) & BitField::new_from_set(&[1, 3, 4]); - assert_eq!(a.all(5).unwrap(), &[1, 3]); + let a = &BitField::from_iter(vec![1, 2, 3]) & &BitField::from_iter(vec![1, 3, 4]); + assert_eq!(a.iter().collect::>(), &[1, 3]); - let mut a = BitField::new_from_set(&[1, 2, 3]); - a &= BitField::new_from_set(&[1, 3, 4]); - assert_eq!(a.all(5).unwrap(), &[1, 3]); + let mut a = BitField::from_iter(vec![1, 2, 3]); + a &= &BitField::from_iter(vec![1, 3, 4]); + assert_eq!(a.iter().collect::>(), &[1, 3]); - let mut a = BitField::new_from_set(&[1, 2, 3]) | BitField::new_from_set(&[1, 3, 4]); - assert_eq!(a.all(5).unwrap(), &[1, 2, 3, 4]); + let a = &BitField::from_iter(vec![1, 2, 3]) | &BitField::from_iter(vec![1, 3, 4]); + assert_eq!(a.iter().collect::>(), &[1, 2, 3, 4]); - let mut a = BitField::new_from_set(&[1, 2, 3]); - a |= BitField::new_from_set(&[1, 3, 4]); - assert_eq!(a.all(5).unwrap(), &[1, 2, 3, 4]); + let mut a = BitField::from_iter(vec![1, 2, 3]); + a |= &BitField::from_iter(vec![1, 3, 4]); + assert_eq!(a.iter().collect::>(), &[1, 2, 3, 4]); +} - assert_eq!( - (!BitField::from(bitvec![Lsb0, u8; 1, 0, 1, 0])) - .all(5) - .unwrap(), - &[1, 3] - ); +#[test] +fn ranges() { + let bitvec = bitvec![Lsb0, u8; 0, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 0, 0]; + let mut bit_field = BitField::from(RlePlus::encode(&bitvec)); + + assert_eq!(bit_field.ranges().count(), 4); + bit_field.set(5); + assert_eq!(bit_field.ranges().count(), 3); + bit_field.unset(4); + assert_eq!(bit_field.ranges().count(), 4); + bit_field.unset(2); + assert_eq!(bit_field.ranges().count(), 4); +} + +#[test] +fn serialize_node_symmetric() { + let bitvec = bitvec![Lsb0, u8; 0, 1, 0, 1, 1, 1, 1, 1, 1]; + let bit_field = BitField::from(RlePlus::encode(&bitvec)); + let cbor_bz = encoding::to_vec(&bit_field).unwrap(); + let deserialized: BitField = encoding::from_slice(&cbor_bz).unwrap(); + assert_eq!(deserialized.len(), 7); + assert_eq!(deserialized, bit_field); +} + +#[test] +// ported test from specs-actors `bitfield_test.go` with added vector +fn bit_vec_unset_vector() { + let mut bf = BitField::new(); + bf.set(1); + bf.set(2); + bf.set(3); + bf.set(4); + bf.set(5); + + bf.unset(3); + + assert_eq!(bf.get(3), false); + assert_eq!(bf.len(), 4); + + // Test cbor marshal and unmarshal + let cbor_bz = encoding::to_vec(&bf).unwrap(); + assert_eq!(&cbor_bz, &[0x43, 0xa8, 0x54, 0x00]); + + let deserialized: BitField = encoding::from_slice(&cbor_bz).unwrap(); + assert_eq!(deserialized.len(), 4); + assert_eq!(bf.get(3), false); +} + +#[test] +fn padding() { + // bits: 0 1 0 1 + // rle+: 0 0 0 1 1 1 1 + // when decoded it will have an extra 0 at the end for padding, + // which is not part of a block prefix + + let mut bf = BitField::new(); + bf.set(1); + bf.set(3); + + let cbor = 
encoding::to_vec(&bf).unwrap(); + let deserialized: BitField = encoding::from_slice(&cbor).unwrap(); + assert_eq!(deserialized, bf); } diff --git a/vm/actor/Cargo.toml b/vm/actor/Cargo.toml index 4ab95afdb0bd..c1911556b381 100644 --- a/vm/actor/Cargo.toml +++ b/vm/actor/Cargo.toml @@ -26,6 +26,7 @@ crypto = { package = "forest_crypto", path = "../../crypto" } bitfield = { path = "../../utils/bitfield" } fil_types = { path = "../../types" } byteorder = "1.3.4" +ahash = "0.3" [dev-dependencies] db = { path = "../../node/db" } diff --git a/vm/actor/src/builtin/miner/deadlines.rs b/vm/actor/src/builtin/miner/deadlines.rs index 66782b2af3d1..45352eb82198 100644 --- a/vm/actor/src/builtin/miner/deadlines.rs +++ b/vm/actor/src/builtin/miner/deadlines.rs @@ -135,11 +135,8 @@ pub fn deadline_count( )); } - let sector_count = d.due.get_mut(deadline_idx).unwrap().count()?; - let mut partition_count = sector_count / partition_size; - if sector_count % partition_size != 0 { - partition_count += 1; - }; + let sector_count = d.due[deadline_idx].len(); + let partition_count = (sector_count + partition_size - 1) / partition_size; Ok((partition_count, sector_count)) } /// Computes a bitfield of the sector numbers included in a sequence of partitions due at some deadline. @@ -156,7 +153,7 @@ pub fn compute_partitions_sector( // Work out which sector numbers the partitions correspond to. let deadline_sectors = d .due - .get_mut(deadline_idx) + .get(deadline_idx) .ok_or(format!("unable to find deadline: {}", deadline_idx))?; let partitions_sectors = partitions .iter() @@ -172,7 +169,8 @@ pub fn compute_partitions_sector( // Slice out the sectors corresponding to this partition from the deadline's sector bitfield. let sector_offset = (p_idx - deadline_first_partition) * partition_size; let sector_count = std::cmp::min(partition_size, deadline_sector_count - sector_offset); - let partition_sectors = deadline_sectors.slice(sector_offset, sector_count)?; + let partition_sectors = + deadline_sectors.slice(sector_offset as usize, sector_count as usize)?; Ok(partition_sectors) }) .collect::>()?; @@ -185,7 +183,7 @@ pub fn compute_partitions_sector( pub fn assign_new_sectors( deadlines: &mut Deadlines, partition_size: usize, - new_sectors: &[u64], + new_sectors: &[usize], _seed: Randomness, ) -> Result<(), String> { let mut next_new_sector: usize = 0; diff --git a/vm/actor/src/builtin/miner/mod.rs b/vm/actor/src/builtin/miner/mod.rs index e4b94a789b2a..05485d51c2b4 100644 --- a/vm/actor/src/builtin/miner/mod.rs +++ b/vm/actor/src/builtin/miner/mod.rs @@ -29,6 +29,7 @@ use crate::{ STORAGE_POWER_ACTOR_ADDR, }; use address::{Address, Payload, Protocol}; +use ahash::AHashSet; use bitfield::BitField; use byteorder::{BigEndian, ByteOrder}; use cid::{multihash::Blake2b256, Cid}; @@ -334,14 +335,9 @@ impl Actor { ) })?; - let proven_sectors = BitField::union(&partitions_sectors).map_err(|e| { - ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to union partitions of sectors: {}", e), - ) - })?; + let proven_sectors = BitField::union(&partitions_sectors); - let (sector_infos, mut declared_recoveries) = st + let (sector_infos, declared_recoveries) = st .load_sector_infos_for_proof(rt.store(), proven_sectors) .map_err(|e| { ActorError::new( @@ -355,16 +351,9 @@ impl Actor { verify_windowed_post(rt, deadline.challenge, §or_infos, params.proofs.clone())?; // Record the successful submission - let mut posted_partitions = BitField::new_from_set(¶ms.partitions); - let contains = st - .post_submissions - 
.contains_any(&mut posted_partitions) - .map_err(|e| { - ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to intersect post partitions: {}", e), - ) - })?; + let posted_partitions: BitField = + params.partitions.iter().map(|&i| i as usize).collect(); + let contains = st.post_submissions.contains_any(&posted_partitions); if contains { return Err(ActorError::new( ExitCode::ErrIllegalArgument, @@ -382,7 +371,7 @@ impl Actor { })?; // If the PoSt was successful, the declared recoveries should be restored - st.remove_faults(rt.store(), &mut declared_recoveries) + st.remove_faults(rt.store(), &declared_recoveries) .map_err(|e| { ActorError::new( ExitCode::ErrIllegalState, @@ -390,35 +379,28 @@ impl Actor { ) })?; - st.remove_recoveries(&mut declared_recoveries) - .map_err(|e| { - ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to remove recoveries: {}", e), - ) - })?; + st.remove_recoveries(&declared_recoveries).map_err(|e| { + ActorError::new( + ExitCode::ErrIllegalState, + format!("failed to remove recoveries: {}", e), + ) + })?; // Load info for recovered sectors for recovery of power outside this state transaction. - match declared_recoveries.is_empty() { - Ok(true) => Ok(st.info.sector_size), - Ok(false) => { - let mut sectors_by_number: HashMap = - HashMap::new(); - for sec in sector_infos { - sectors_by_number.insert(sec.info.sector_number, sec); - } - let _ = declared_recoveries.for_each(|i| { - let key: SectorNumber = i as u64; - let s = sectors_by_number.get(&key).cloned().unwrap(); - recovered_sectors.push(s); - Ok(()) - }); - Ok(st.info.sector_size) + if declared_recoveries.is_empty() { + Ok(st.info.sector_size) + } else { + let mut sectors_by_number: HashMap = + HashMap::new(); + for sec in sector_infos { + sectors_by_number.insert(sec.info.sector_number, sec); } - Err(e) => Err(ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to check if bitfield was empty: {}", e), - )), + declared_recoveries.iter().for_each(|i| { + let key = i as u64; + let s = sectors_by_number.get(&key).cloned().unwrap(); + recovered_sectors.push(s); + }); + Ok(st.info.sector_size) } })??; // Remove power for new faults, and burn penalties. @@ -571,7 +553,7 @@ impl Actor { notify_pledge_change(rt, &BigInt::from(newly_vested_amount).neg())?; let mut bf = BitField::new(); - bf.set(params.sector_number); + bf.set(params.sector_number as usize); // Request deferred Cron check for PreCommit expiry check. let cron_payload = CronEventPayload { @@ -914,7 +896,7 @@ impl Actor { fn terminate_sectors( rt: &mut RT, - mut params: TerminateSectorsParams, + params: TerminateSectorsParams, ) -> Result<(), ActorError> where BS: BlockStore, @@ -925,7 +907,7 @@ impl Actor { // Note: this cannot terminate pre-committed but un-proven sectors. // They must be allowed to expire (and deposit burnt). 
- terminate_sectors(rt, &mut params.sectors, SECTOR_TERMINATION_MANUAL)?; + terminate_sectors(rt, ¶ms.sectors, SECTOR_TERMINATION_MANUAL)?; Ok(()) } @@ -969,7 +951,7 @@ impl Actor { let declared_sectors = params .faults .into_iter() - .map(|mut decl| { + .map(|decl| { let target_deadline: DeadlineInfo = declaration_deadline_info( st.proving_period_start, decl.deadline as usize, @@ -981,7 +963,7 @@ impl Actor { format!("invalid fault declaration deadline: {}", e), ) })?; - validate_fr_declaration(&mut deadlines, &target_deadline, &mut decl.sectors) + validate_fr_declaration(&mut deadlines, &target_deadline, &decl.sectors) .map_err(|e| { ActorError::new( ExitCode::ErrIllegalArgument, @@ -992,124 +974,81 @@ impl Actor { }) .collect::, ActorError>>()?; - let all_declared = BitField::union(&declared_sectors).map_err(|e| { - ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to union faults: {}", e), - ) - })?; + let all_declared = BitField::union(&declared_sectors); // Split declarations into declarations of new faults, and retraction of declared recoveries. - let mut recoveries = st - .recoveries - .clone() - .intersect(&all_declared) - .map_err(|e| { - ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to intersect sectors with recoveries: {}", e), - ) - })?; - - let mut new_faults = all_declared.subtract(&recoveries).map_err(|e| { - ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to subtract recoveries from sectors: {}", e), - ) - })?; + let recoveries = &st.recoveries & &all_declared; + let new_faults = &all_declared - &recoveries; + + if !new_faults.is_empty() { + // check new fault are really new + if st.faults.contains_any(&new_faults) { + // This could happen if attempting to declare a fault for a deadline that's already passed, + // detected and added to Faults above. + // The miner must for the fault detection at proving period end, or submit again omitting + // sectors in deadlines that have passed. + // Alternatively, we could subtract the just-detected faults from new faults. + return Err(ActorError::new( + ExitCode::ErrIllegalArgument, + "attempted to re-declare fault".to_string(), + )); + } - match new_faults.is_empty() { - Ok(true) => { - // check new fault are really new - let contains = st.faults.contains_any(&mut new_faults).map_err(|e| { + // Add new faults to state and charge fee. + // Note: this sets the fault epoch for all declarations to be the beginning of this proving period, + // even if some sectors have already been proven in this period. + // It would better to use the target deadline's proving period start (which may be the one subsequent + // to the current). + st.add_faults(rt.store(), &new_faults, st.proving_period_start) + .map_err(|e| { ActorError::new( ExitCode::ErrIllegalState, - format!("failed to intersect existing faults: {}", e), + format!("failed to add faults: {}", e), ) })?; - if contains { - // This could happen if attempting to declare a fault for a deadline that's already passed, - // detected and added to Faults above. - // The miner must for the fault detection at proving period end, or submit again omitting - // sectors in deadlines that have passed. - // Alternatively, we could subtract the just-detected faults from new faults. - return Err(ActorError::new( - ExitCode::ErrIllegalArgument, - "attempted to re-declare fault".to_string(), - )); - } - - // Add new faults to state and charge fee. 
- // Note: this sets the fault epoch for all declarations to be the beginning of this proving period, - // even if some sectors have already been proven in this period. - // It would better to use the target deadline's proving period start (which may be the one subsequent - // to the current). - st.add_faults(rt.store(), &mut new_faults, st.proving_period_start) - .map_err(|e| { - ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to add faults: {}", e), - ) - })?; - // Note: this charges a fee for all declarations, even if the sectors have already been proven - // in this proving period. This discourages early declaration compared with waiting for - // the proving period to roll over. - // It would be better to charge a fee for this proving period only if the target deadline has - // not already passed. If it _has_ already passed then either: - // - the miner submitted PoSt successfully and should not be penalised more relative to - // submitting this declaration after the proving period rolls over, or - // - the miner failed to submit PoSt and will be penalised at the proving period end - // In either case, the miner will pay a fee for the subsequent proving period at the start - // of that period, unless faults are recovered sooner. - - // Load info for sectors. - let declared_fault_sectors = st - .load_sector_infos(rt.store(), &mut new_faults) - .map_err(|e| { - ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to load fault sectors: {}", e), - ) - })?; - - // Unlock penalty for declared faults. - let declared_penalty = unlock_penalty( - st, - rt.store(), - current_epoch, - &declared_fault_sectors, - &pledge_penalty_for_sector_declared_fault, - ) - .map_err(|e| { + // Note: this charges a fee for all declarations, even if the sectors have already been proven + // in this proving period. This discourages early declaration compared with waiting for + // the proving period to roll over. + // It would be better to charge a fee for this proving period only if the target deadline has + // not already passed. If it _has_ already passed then either: + // - the miner submitted PoSt successfully and should not be penalised more relative to + // submitting this declaration after the proving period rolls over, or + // - the miner failed to submit PoSt and will be penalised at the proving period end + // In either case, the miner will pay a fee for the subsequent proving period at the start + // of that period, unless faults are recovered sooner. + + // Load info for sectors. + let declared_fault_sectors = + st.load_sector_infos(rt.store(), &new_faults).map_err(|e| { ActorError::new( ExitCode::ErrIllegalState, - format!("failed to charge fault fee: {}", e), + format!("failed to load fault sectors: {}", e), ) })?; - penalty += declared_penalty; - let empty = recoveries.is_empty().map_err(|e| { + // Unlock penalty for declared faults. 
+ let declared_penalty = unlock_penalty( + st, + rt.store(), + current_epoch, + &declared_fault_sectors, + &pledge_penalty_for_sector_declared_fault, + ) + .map_err(|e| { + ActorError::new( + ExitCode::ErrIllegalState, + format!("failed to charge fault fee: {}", e), + ) + })?; + penalty += declared_penalty; + + if !recoveries.is_empty() { + st.remove_recoveries(&recoveries).map_err(|e| { ActorError::new( ExitCode::ErrIllegalState, - format!("failed to check if bitfield was empty: {}", e), + format!("failed to remove recoveries: {}", e), ) })?; - - if !empty { - st.remove_recoveries(&mut recoveries).map_err(|e| { - ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to remove recoveries: {}", e), - ) - })?; - } - } - Ok(false) => {} - Err(e) => { - return Err(ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to check if bitfield was empty: {}", e), - )); } } @@ -1166,7 +1105,7 @@ impl Actor { let declared_sectors = params .recoveries .into_iter() - .map(|mut decl| { + .map(|decl| { let target_deadline = declaration_deadline_info( st.proving_period_start, decl.deadline as usize, @@ -1179,7 +1118,7 @@ impl Actor { ) })?; - validate_fr_declaration(&mut deadlines, &target_deadline, &mut decl.sectors) + validate_fr_declaration(&mut deadlines, &target_deadline, &decl.sectors) .map_err(|e| { ActorError::new( ExitCode::ErrIllegalArgument, @@ -1190,42 +1129,23 @@ impl Actor { }) .collect::, ActorError>>()?; - let mut all_recoveries = BitField::union(&declared_sectors).map_err(|e| { - ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to union recoveries: {}", e), - ) - })?; + let all_recoveries = BitField::union(&declared_sectors); - let mut contains = st.faults.contains_all(&mut all_recoveries).map_err(|e| { - ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to check recoveries are faulty: {}", e), - ) - })?; - if !contains { + if !st.faults.contains_all(&all_recoveries) { return Err(ActorError::new( ExitCode::ErrIllegalArgument, "declared recoveries not currently faulty".to_string(), )); } - contains = st - .recoveries - .contains_any(&mut all_recoveries) - .map_err(|e| { - ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to intersect new recoveries: {}", e), - ) - })?; - if contains { + + if st.recoveries.contains_any(&all_recoveries) { return Err(ActorError::new( ExitCode::ErrIllegalArgument, "sector already declared recovered".to_string(), )); } - st.add_recoveries(&mut all_recoveries).map_err(|e| { + st.add_recoveries(&all_recoveries).map_err(|e| { ActorError::new( ExitCode::ErrIllegalArgument, format!("invalid recoveries: {}", e), @@ -1397,7 +1317,7 @@ impl Actor { fn on_deferred_cron_event( rt: &mut RT, - mut payload: CronEventPayload, + payload: CronEventPayload, ) -> Result<(), ActorError> where BS: BlockStore, @@ -1405,7 +1325,7 @@ impl Actor { { match payload.event_type { CRON_EVENT_PROVING_PERIOD => handle_proving_period(rt)?, - CRON_EVENT_PRE_COMMIT_EXPIRY => check_precommit_expiry(rt, &mut payload.sectors)?, + CRON_EVENT_PRE_COMMIT_EXPIRY => check_precommit_expiry(rt, &payload.sectors)?, CRON_EVENT_WORKER_KEY_CHANGE => commit_worker_key_change(rt)?, _ => (), }; @@ -1483,27 +1403,26 @@ where { // Expire sectors that are due. 
- let mut expired_sectors = - rt.transaction::, _>(|st, rt| { - Ok( - pop_sector_expirations(st, rt.store(), deadline.period_end()).map_err(|e| { - ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to load expired sectors {:}", e), - ) - })?, - ) - })??; + let expired_sectors = rt.transaction::, _>(|st, rt| { + Ok( + pop_sector_expirations(st, rt.store(), deadline.period_end()).map_err(|e| { + ActorError::new( + ExitCode::ErrIllegalState, + format!("failed to load expired sectors {:}", e), + ) + })?, + ) + })??; // Terminate expired sectors (sends messages to power and market actors). - terminate_sectors(rt, &mut expired_sectors, SECTOR_TERMINATION_EXPIRED)?; + terminate_sectors(rt, &expired_sectors, SECTOR_TERMINATION_EXPIRED)?; } { // Terminate sectors with faults that are too old, and pay fees for ongoing faults. - let (mut expired_faults, ongoing_fault_penalty) = rt + let (expired_faults, ongoing_fault_penalty) = rt .transaction::, _>(|st, rt| { - let (expired_faults, mut ongoing_faults) = + let (expired_faults, ongoing_faults) = pop_expired_faults(st, rt.store(), deadline.period_end() - FAULT_MAX_AGE) .map_err(|e| { ActorError::new( @@ -1515,7 +1434,7 @@ where // Load info for ongoing faults. // TODO: this is potentially super expensive for a large miner with ongoing faults let ongoing_fault_info = st - .load_sector_infos(rt.store(), &mut ongoing_faults) + .load_sector_infos(rt.store(), &ongoing_faults) .map_err(|e| { ActorError::new( ExitCode::ErrIllegalState, @@ -1540,7 +1459,7 @@ where Ok((expired_faults, ongoing_fault_penalty)) })??; - terminate_sectors(rt, &mut expired_faults, SECTOR_TERMINATION_FAULTY)?; + terminate_sectors(rt, &expired_faults, SECTOR_TERMINATION_FAULTY)?; burn_funds_and_notify_pledge_change(rt, &ongoing_fault_penalty)?; } @@ -1555,15 +1474,16 @@ where })?; // assign new sectors to deadlines - let new_sectors = st + let new_sectors: Vec<_> = st .new_sectors - .all(NEW_SECTORS_PER_PERIOD_MAX) + .bounded_iter(NEW_SECTORS_PER_PERIOD_MAX) .map_err(|e| { ActorError::new( ExitCode::ErrIllegalState, format!("failed to expand new sectors {:}", e), ) - })?; + })? + .collect(); if !new_sectors.is_empty() { let randomness_epoch = std::cmp::min( @@ -1691,7 +1611,7 @@ where ) ); - let (mut detected_faults, mut failed_recoveries) = compute_faults_from_missing_posts( + let (detected_faults, failed_recoveries) = compute_faults_from_missing_posts( st, deadlines, st.next_deadline_to_process_faults, @@ -1705,7 +1625,7 @@ where })?; st.next_deadline_to_process_faults = before_deadline % WPOST_PERIOD_DEADLINES; - st.add_faults(store, &mut detected_faults, period_start) + st.add_faults(store, &detected_faults, period_start) .map_err(|e| { ActorError::new( ExitCode::ErrIllegalState, @@ -1713,7 +1633,7 @@ where ) })?; - st.remove_recoveries(&mut failed_recoveries).map_err(|e| { + st.remove_recoveries(&failed_recoveries).map_err(|e| { ActorError::new( ExitCode::ErrIllegalState, format!("failed to record failed recoveries: {}", e), @@ -1722,21 +1642,20 @@ where // Load info for sectors. 
let mut detected_fault_sectors = - st.load_sector_infos(store, &mut detected_faults) + st.load_sector_infos(store, &detected_faults).map_err(|e| { + ActorError::new( + ExitCode::ErrIllegalState, + format!("failed to load fault sectors: {}", e), + ) + })?; + let mut failed_recovery_sectors = + st.load_sector_infos(store, &failed_recoveries) .map_err(|e| { ActorError::new( ExitCode::ErrIllegalState, - format!("failed to load fault sectors: {}", e), + format!("failed to load failed recovery sectors: {}", e), ) })?; - let mut failed_recovery_sectors = st - .load_sector_infos(store, &mut failed_recoveries) - .map_err(|e| { - ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to load failed recovery sectors: {}", e), - ) - })?; // unlock sector penalty for all undeclared faults detected_fault_sectors.append(&mut failed_recovery_sectors); @@ -1769,9 +1688,10 @@ fn compute_faults_from_missing_posts( // TODO: Iterating this bitfield and keeping track of what partitions we're expecting could remove the // need to expand this into a potentially-giant map. But it's tricksy. let partition_size = st.info.window_post_partition_sectors; - let submissions = st + let submissions: AHashSet<_> = st .post_submissions - .all_set(active_partitions_max(partition_size))?; + .bounded_iter(active_partitions_max(partition_size))? + .collect(); let mut f_groups: Vec = Vec::new(); let mut r_groups: Vec = Vec::new(); @@ -1789,10 +1709,10 @@ fn compute_faults_from_missing_posts( let deadline_sectors = deadlines .due - .get_mut(dl_idx) + .get(dl_idx) .expect("Should be able to index due deadlines"); for dl_part_idx in 0..dl_part_count { - if !submissions.contains(&(deadline_first_partition + dl_part_idx as u64)) { + if !submissions.contains(&(deadline_first_partition as usize + dl_part_idx)) { // no PoSt received in prior period let part_first_sector_idx = dl_part_idx * partition_size as usize; let part_sector_count = std::cmp::min( @@ -1800,22 +1720,22 @@ fn compute_faults_from_missing_posts( dl_sector_count - part_first_sector_idx, ); - let partition_sectors = deadline_sectors - .slice(part_first_sector_idx as u64, part_sector_count as u64)?; + let partition_sectors = + deadline_sectors.slice(part_first_sector_idx, part_sector_count)?; // record newly-faulty sectors - let new_faults = st.faults.clone().subtract(&partition_sectors)?; + let new_faults = &st.faults - &partition_sectors; f_groups.push(new_faults); // record failed recoveries - let failed_recovery = st.recoveries.clone().intersect(&partition_sectors)?; + let failed_recovery = &st.recoveries & &partition_sectors; r_groups.push(failed_recovery); } } deadline_first_partition += dl_part_count as u64; } - let detected_faults = BitField::union(&f_groups)?; - let failed_recoveries = BitField::union(&r_groups)?; + let detected_faults = BitField::union(&f_groups); + let failed_recoveries = BitField::union(&r_groups); Ok((detected_faults, failed_recoveries)) } @@ -1856,7 +1776,7 @@ where st.clear_sector_expirations(store, &expired_epochs)?; - let all_expiries = BitField::union(&expired_sectors)?; + let all_expiries = BitField::union(&expired_sectors); Ok(all_expiries) } @@ -1877,10 +1797,10 @@ where st.for_each_fault_epoch(store, |fault_start: ChainEpoch, faults: &BitField| { if fault_start <= latest_termination { - all_expiries.merge_assign(faults)?; + all_expiries |= faults; expired_epochs.push(fault_start); } else { - all_ongoing_faults.merge_assign(faults)?; + all_ongoing_faults |= faults; } Ok(()) })?; @@ -1892,7 +1812,7 @@ where fn 
check_precommit_expiry( rt: &mut RT, - optional_sectors: &mut Option, + optional_sectors: &Option, ) -> Result<(), ActorError> where BS: BlockStore, @@ -1903,7 +1823,9 @@ where rt.transaction::, _>(|st, rt| { if let Some(sectors) = optional_sectors { sectors - .for_each(|sec_num| { + .iter() + .try_for_each(|i| { + let sec_num = i as u64; let sector = match st.get_precommitted_sector(rt.store(), sec_num)? { Some(sec) => sec, // Already committed/deleted @@ -1918,7 +1840,7 @@ where Ok(()) }) - .map_err(|e| { + .map_err(|e: String| { ActorError::new( ExitCode::ErrIllegalState, format!("failed to check precommit expires: {}", e), @@ -1936,21 +1858,14 @@ where fn terminate_sectors( rt: &mut RT, - sector_nos: &mut BitField, + sector_nos: &BitField, termination_type: SectorTermination, ) -> Result<(), ActorError> where BS: BlockStore, RT: Runtime, { - let empty = sector_nos.is_empty().map_err(|_| { - ActorError::new( - ExitCode::ErrIllegalState, - "failed to count sectors".to_string(), - ) - })?; - - if empty { + if sector_nos.is_empty() { return Ok(()); } @@ -1970,22 +1885,23 @@ where })?; // narrow faults to just the set that are expiring, before expanding to a map - let mut faults = st.faults.clone().intersect(§or_nos).map_err(|e| { - ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to load faults: {}", e), - ) - })?; + let faults = &st.faults & sector_nos; - let faults_map = faults.all_set(max_allowed_faults as usize).map_err(|e| { - ActorError::new( - ExitCode::ErrIllegalState, - format!("failed to expand faults: {}", e), - ) - })?; + let faults_map: AHashSet<_> = faults + .bounded_iter(max_allowed_faults as usize) + .map_err(|e| { + ActorError::new( + ExitCode::ErrIllegalState, + format!("failed to expand faults: {}", e), + ) + })? + .map(|i| i as u64) + .collect(); sector_nos - .for_each(|i| { + .iter() + .try_for_each(|i| { + let i = i as u64; let sector = st .get_sector(rt.store(), i)? 
.ok_or_else(|| format!("no sector found: {}", i))?; @@ -1999,7 +1915,7 @@ where all_sectors.push(sector); Ok(()) }) - .map_err(|e| { + .map_err(|e: String| { ActorError::new( ExitCode::ErrIllegalState, format!("failed to load sector metadata: {}", e), @@ -2048,13 +1964,13 @@ fn remove_terminated_sectors( st: &mut State, store: &BS, deadlines: &mut Deadlines, - sectors: &mut BitField, + sectors: &BitField, ) -> Result<(), String> where BS: BlockStore, { st.delete_sector(store, sectors)?; - st.remove_new_sectors(sectors)?; + st.remove_new_sectors(sectors); deadlines.remove_from_all_deadlines(sectors)?; st.remove_faults(store, sectors)?; st.remove_recoveries(sectors)?; @@ -2547,7 +2463,7 @@ fn declaration_deadline_info( fn validate_fr_declaration( deadlines: &mut Deadlines, deadline: &DeadlineInfo, - mut declared_sectors: &mut BitField, + declared_sectors: &BitField, ) -> Result<(), String> { if deadline.fault_cutoff_passed() { return Err("late fault or recovery declaration".to_string()); @@ -2556,10 +2472,9 @@ fn validate_fr_declaration( // check that the declared sectors are actually due at the deadline let deadline_sectors = deadlines .due - .get_mut(deadline.index as usize) + .get(deadline.index) .ok_or("deadline not found")?; - let contains = deadline_sectors.contains_all(&mut declared_sectors)?; - if !contains { + if !deadline_sectors.contains_all(&declared_sectors) { return Err(format!( "sectors not all due at deadline {}", deadline.index diff --git a/vm/actor/src/builtin/miner/state.rs b/vm/actor/src/builtin/miner/state.rs index 91952a8d699a..d4f6b504975c 100644 --- a/vm/actor/src/builtin/miner/state.rs +++ b/vm/actor/src/builtin/miner/state.rs @@ -6,6 +6,7 @@ use super::policy::*; use super::types::*; use crate::{power, u64_key, BytesKey, HAMT_BIT_WIDTH}; use address::Address; +use ahash::AHashSet; use bitfield::BitField; use cid::{multihash::Blake2b256, Cid}; use clock::ChainEpoch; @@ -230,16 +231,15 @@ impl State { pub fn delete_sector( &mut self, store: &BS, - sector_nos: &mut BitField, + sector_nos: &BitField, ) -> Result<(), AmtError> { let mut sectors = Amt::::load(&self.sectors, store)?; - sector_nos - .for_each(|sector_num| { - sectors.delete(sector_num)?; - Ok(()) - }) - .map_err(|e| AmtError::Other(format!("could not delete sector number: {}", e)))?; + for sector_num in sector_nos.iter() { + sectors + .delete(sector_num as u64) + .map_err(|e| AmtError::Other(format!("could not delete sector number: {}", e)))?; + } self.sectors = sectors.flush()?; Ok(()) @@ -255,24 +255,23 @@ impl State { pub fn add_new_sectors(&mut self, sector_nos: &[SectorNumber]) -> Result<(), String> { let mut ns = BitField::new(); for §or in sector_nos { - ns.set(sector) + ns.set(sector as usize) } - self.new_sectors.merge_assign(&ns)?; + self.new_sectors |= &ns; - let count = self.new_sectors.count()?; - if count > NEW_SECTORS_PER_PERIOD_MAX { + let len = self.new_sectors.len(); + if len > NEW_SECTORS_PER_PERIOD_MAX { return Err(format!( "too many new sectors {}, max {}", - count, NEW_SECTORS_PER_PERIOD_MAX + len, NEW_SECTORS_PER_PERIOD_MAX )); } Ok(()) } /// Removes some sector numbers from the new sectors bitfield, if present. - pub fn remove_new_sectors(&mut self, sector_nos: &BitField) -> Result<(), String> { - self.new_sectors.subtract_assign(§or_nos)?; - Ok(()) + pub fn remove_new_sectors(&mut self, sector_nos: &BitField) { + self.new_sectors -= §or_nos; } /// Gets the sector numbers expiring at some epoch. 
pub fn get_sector_expirations( @@ -306,12 +305,14 @@ impl State { ) -> Result<(), String> { let mut sector_arr = Amt::::load(&self.sector_expirations, store)?; let mut bf: BitField = sector_arr.get(expiry)?.ok_or("unable to find sector")?; - bf.merge_assign(&BitField::new_from_set(sectors))?; - let count = bf.count()?; - if count > SECTORS_MAX { + for §or in sectors { + bf.set(sector as usize); + } + let len = bf.len(); + if len > SECTORS_MAX { return Err(format!( "too many sectors at expiration {}, {}, max {}", - expiry, count, SECTORS_MAX + expiry, len, SECTORS_MAX )); } @@ -329,9 +330,10 @@ impl State { ) -> Result<(), String> { let mut sector_arr = Amt::::load(&self.sector_expirations, store)?; - let bf: BitField = sector_arr.get(expiry)?.ok_or("unable to find sector")?; - bf.clone() - .subtract_assign(&BitField::new_from_set(sectors))?; + let mut bf = sector_arr.get(expiry)?.ok_or("unable to find sector")?; + for §or in sectors { + bf.unset(sector as usize); + } sector_arr.set(expiry, bf)?; @@ -359,18 +361,18 @@ impl State { pub fn add_faults( &mut self, store: &BS, - sector_nos: &mut BitField, + sector_nos: &BitField, fault_epoch: ChainEpoch, ) -> Result<(), String> { - if sector_nos.is_empty()? { + if sector_nos.is_empty() { return Err(format!("sectors are empty: {:?}", sector_nos)); } - self.faults.merge_assign(sector_nos)?; + self.faults |= sector_nos; - let count = self.faults.count()?; - if count > SECTORS_MAX { - return Err(format!("too many faults {}, max {}", count, SECTORS_MAX)); + let len = self.faults.len(); + if len > SECTORS_MAX { + return Err(format!("too many faults {}, max {}", len, SECTORS_MAX)); } let mut epoch_fault_arr = Amt::::load(&self.fault_epochs, store)?; @@ -378,7 +380,7 @@ impl State { .get(fault_epoch)? .ok_or("unable to find sector")?; - bf.merge_assign(sector_nos)?; + bf |= sector_nos; epoch_fault_arr.set(fault_epoch, bf)?; @@ -390,24 +392,22 @@ impl State { pub fn remove_faults( &mut self, store: &BS, - sector_nos: &mut BitField, + sector_nos: &BitField, ) -> Result<(), String> { - if sector_nos.is_empty()? { + if sector_nos.is_empty() { return Err(format!("sectors are empty: {:?}", sector_nos)); } - self.faults.subtract_assign(sector_nos)?; + self.faults -= sector_nos; let mut sector_arr = Amt::::load(&self.fault_epochs, store)?; let mut changed: Vec<(u64, BitField)> = Vec::new(); sector_arr.for_each(|i, bf1: &BitField| { - let c1 = bf1.clone().count()?; - - let mut bf2 = bf1.clone().subtract(sector_nos)?; - - let c2 = bf2.count()?; + let c1 = bf1.clone().len(); + let bf2 = bf1 - sector_nos; + let c2 = bf2.len(); if c1 != c2 { changed.push((i, bf2)); @@ -453,29 +453,26 @@ impl State { Ok(()) } /// Adds sectors to recoveries. - pub fn add_recoveries(&mut self, sector_nos: &mut BitField) -> Result<(), String> { - if sector_nos.is_empty()? { + pub fn add_recoveries(&mut self, sector_nos: &BitField) -> Result<(), String> { + if sector_nos.is_empty() { return Err(format!("sectors are empty: {:?}", sector_nos)); } - self.recoveries.clone().merge_assign(sector_nos)?; + self.recoveries |= sector_nos; - let count = self.recoveries.count()?; - if count > SECTORS_MAX { - return Err(format!( - "too many recoveries {}, max {}", - count, SECTORS_MAX - )); + let len = self.recoveries.len(); + if len > SECTORS_MAX { + return Err(format!("too many recoveries {}, max {}", len, SECTORS_MAX)); } Ok(()) } /// Removes sectors from recoveries, if present. - pub fn remove_recoveries(&mut self, sector_nos: &mut BitField) -> Result<(), String> { - if sector_nos.is_empty()? 
{ + pub fn remove_recoveries(&mut self, sector_nos: &BitField) -> Result<(), String> { + if sector_nos.is_empty() { return Err(format!("sectors are empty: {:?}", sector_nos)); } - self.recoveries.subtract_assign(sector_nos)?; + self.recoveries -= sector_nos; Ok(()) } @@ -483,18 +480,16 @@ impl State { pub fn load_sector_infos( &self, store: &BS, - sectors: &mut BitField, + sectors: &BitField, ) -> Result, String> { let mut sector_infos: Vec = Vec::new(); - sectors.for_each(|i| { - let key: SectorNumber = i; + for i in sectors.iter() { + let key = i as u64; let sector_on_chain = self .get_sector(store, key)? .ok_or(format!("sector not found: {}", i))?; sector_infos.push(sector_on_chain); - Ok(()) - })?; - + } Ok(sector_infos) } @@ -504,33 +499,25 @@ impl State { pub fn load_sector_infos_for_proof( &mut self, store: &BS, - mut proven_sectors: BitField, + proven_sectors: BitField, ) -> Result<(Vec, BitField), String> { // Extract a fault set relevant to the sectors being submitted, for expansion into a map. - let declared_faults = self.faults.clone().intersect(&proven_sectors)?; - - let recoveries = self.recoveries.clone().intersect(&declared_faults)?; - - let mut expected_faults = declared_faults.subtract(&recoveries)?; - - let mut non_faults = expected_faults.clone().subtract(&proven_sectors)?; - - if non_faults.is_empty()? { - return Err(format!( - "failed to check if bitfield was empty: {:?}", - non_faults - )); - } + let declared_faults = &self.faults & &proven_sectors; + let recoveries = &self.recoveries & &declared_faults; + let expected_faults = &declared_faults - &recoveries; + let non_faults = &expected_faults - &proven_sectors; // Select a non-faulty sector as a substitute for faulty ones. - let good_sector_no = non_faults.first()?; + let good_sector_no = non_faults + .first() + .ok_or("no non-faulty sectors in partitions")?; // load sector infos let sector_infos = self.load_sector_infos_with_fault_mask( store, - &mut proven_sectors, - &mut expected_faults, - good_sector_no, + &proven_sectors, + &expected_faults, + good_sector_no as u64, )?; Ok((sector_infos, recoveries)) @@ -539,8 +526,8 @@ impl State { fn load_sector_infos_with_fault_mask( &self, store: &BS, - sectors: &mut BitField, - faults: &mut BitField, + sectors: &BitField, + faults: &BitField, fault_stand_in: SectorNumber, ) -> Result, String> { let sector_on_chain = self @@ -549,30 +536,28 @@ impl State { // Expand faults into a map for quick lookups. // The faults bitfield should already be a subset of the sectors bitfield. - let fault_max = sectors.count()?; - let fault_set = faults.all_set(fault_max)?; + let fault_max = sectors.len(); + let fault_set: AHashSet<_> = faults.bounded_iter(fault_max)?.collect(); // Load the sector infos, masking out fault sectors with a good one. 
let mut sector_infos: Vec = Vec::new(); - sectors.for_each(|i| { - let mut sector = sector_on_chain.clone(); - let _faulty = fault_set.get(&i).ok_or_else(|| { - let new_sector_on_chain = self - .get_sector(store, fault_stand_in) + for i in sectors.iter() { + let sector = if fault_set.contains(&i) { + sector_on_chain.clone() + } else { + self.get_sector(store, fault_stand_in) .unwrap() .ok_or(format!("unable to find sector: {}", i)) - .unwrap(); - sector = new_sector_on_chain; - }); + .unwrap() + }; sector_infos.push(sector); - Ok(()) - })?; + } Ok(sector_infos) } /// Adds partition numbers to the set of PoSt submissions pub fn add_post_submissions(&mut self, partition_nos: BitField) -> Result<(), String> { - self.post_submissions.merge_assign(&partition_nos)?; + self.post_submissions |= &partition_nos; Ok(()) } /// Removes all PoSt submissions @@ -868,22 +853,24 @@ impl Deadlines { } /// Adds sector numbers to a deadline. - /// The sector numbers are given as uint64 to avoid pointless conversions for bitfield use. - pub fn add_to_deadline(&mut self, deadline: usize, new_sectors: &[u64]) -> Result<(), String> { - let ns = BitField::new_from_set(new_sectors); + pub fn add_to_deadline( + &mut self, + deadline: usize, + new_sectors: &[usize], + ) -> Result<(), String> { + let ns: BitField = new_sectors.iter().copied().collect(); let sec = self .due .get_mut(deadline) .ok_or(format!("unable to find deadline: {}", deadline))?; - sec.merge_assign(&ns)?; + *sec |= &ns; Ok(()) } /// Removes sector numbers from all deadlines. - pub fn remove_from_all_deadlines(&mut self, sector_nos: &mut BitField) -> Result<(), String> { + pub fn remove_from_all_deadlines(&mut self, sector_nos: &BitField) -> Result<(), String> { for d in self.due.iter_mut() { - d.subtract_assign(§or_nos)?; + *d -= sector_nos; } - Ok(()) } }
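
Note on the expected CBOR vector in `bit_vec_unset_vector` (a hand check, not part of the patch): RLE+ as used here starts with a 2-bit version field (00) and one bit giving the value of the first run, followed by alternating run lengths encoded exactly as in `write_len`/`read_len` (a lone 1 bit for a run of length 1, 01 plus four length bits for runs shorter than 16, 00 plus an unsigned varint otherwise), all least-significant-bit first. The set {1, 2, 4, 5} left by the test is the run sequence one 0, two 1s, one 0, two 1s, which encodes as 00 0 1 01 0100 1 01 0100, i.e. the bytes 0xa8 0x54 0x00; the leading 0x43 is just the CBOR header for a three-byte byte string.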
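
The length coding can also be sanity-checked by round-tripping through the new writer and reader. The sketch below is illustrative only: the `bitfield::rleplus::{BitReader, BitWriter}` import path and the error type behind `read_len` (assumed to be unwrap-able) are assumptions, not something this patch pins down.

    // Minimal round-trip sketch for the RLE+ length coding (assumed re-export paths).
    #[test]
    fn rleplus_len_round_trip_sketch() {
        use bitfield::rleplus::{BitReader, BitWriter};

        let mut writer = BitWriter::new();
        for &len in &[1usize, 7, 300] {
            // 1 -> "1", 7 -> "01" + 4 length bits, 300 -> "00" + unsigned varint
            writer.write_len(len);
        }
        let (bytes, _padding) = writer.finish();

        let mut reader = BitReader::new(&bytes);
        assert_eq!(reader.read_len().unwrap(), Some(1));
        assert_eq!(reader.read_len().unwrap(), Some(7));
        assert_eq!(reader.read_len().unwrap(), Some(300));
        // The trailing padding zeros decode as a zero length, i.e. end of stream.
        assert_eq!(reader.read_len().unwrap(), None);
    }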
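
For readers coming from the old `merge`/`intersect`/`subtract` API, the reworked `BitField` surface used throughout this diff reads roughly as follows. This is a sketch distilled from the updated tests, not repository code; the helper function is hypothetical, and the error type behind `bounded_iter` is assumed to be unwrap-able.

    // Sketch of the operator-based BitField API exercised by the tests in this patch.
    use ahash::AHashSet;
    use bitfield::BitField;

    fn bitfield_api_sketch() {
        let a: BitField = vec![1usize, 2, 3].into_iter().collect();
        let b: BitField = vec![1usize, 3, 4].into_iter().collect();

        let union = &a | &b; // {1, 2, 3, 4}
        let inter = &a & &b; // {1, 3}
        let diff = &a - &b;  // {2}

        assert_eq!(union.len(), 4);
        assert_eq!(inter.iter().collect::<Vec<_>>(), vec![1, 3]);
        assert!(a.contains_all(&inter));
        assert!(!diff.contains_any(&b));

        // In-place variants take references, mirroring the state.rs changes.
        let mut c = a;
        c |= &b;
        assert_eq!(c.iter().collect::<Vec<_>>(), vec![1, 2, 3, 4]);

        // bounded_iter(max) replaces the old all()/all_set() expansion; it appears to
        // error instead of iterating when more than `max` bits are set.
        let as_set: AHashSet<usize> = union.bounded_iter(100).unwrap().collect();
        assert_eq!(as_set.len(), 4);
    }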