-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Optimize and simplify
BitReaderReversed
(#81)
- Loading branch information
Showing
1 changed file
with
93 additions
and
167 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,223 +1,149 @@ | ||
use crate::io::Read; | ||
use core::convert::TryInto; | ||
|
||
/// Zstandard encodes some types of data in a way that the data must be read | ||
/// back to front to decode it properly. `BitReaderReversed` provides a | ||
/// convenient interface to do that. | ||
pub struct BitReaderReversed<'s> { | ||
idx: isize, //index counts bits already read | ||
/// The index of the last read byte in the source. | ||
index: usize, | ||
|
||
/// How many bits have been consumed from `bit_container`. | ||
bits_consumed: u8, | ||
|
||
/// How many bits have been consumed past the end of the input. Will be zero until all the input | ||
/// has been read. | ||
extra_bits: usize, | ||
|
||
/// The source data to read from. | ||
source: &'s [u8], | ||
/// The reader doesn't read directly from the source, | ||
/// it reads bits from here, and the container is | ||
/// "refilled" as it's emptied. | ||
|
||
/// The reader doesn't read directly from the source, it reads bits from here, and the container | ||
/// is "refilled" as it's emptied. | ||
bit_container: u64, | ||
bits_in_container: u8, | ||
} | ||
|
||
impl<'s> BitReaderReversed<'s> { | ||
/// How many bits are left to read by the reader. | ||
pub fn bits_remaining(&self) -> isize { | ||
self.idx + self.bits_in_container as isize | ||
self.index as isize * 8 + (64 - self.bits_consumed as isize) - self.extra_bits as isize | ||
} | ||
|
||
pub fn new(source: &'s [u8]) -> BitReaderReversed<'s> { | ||
BitReaderReversed { | ||
idx: source.len() as isize * 8, | ||
index: source.len(), | ||
bits_consumed: 64, | ||
source, | ||
bit_container: 0, | ||
bits_in_container: 0, | ||
extra_bits: 0, | ||
} | ||
} | ||
|
||
/// We refill the container in full bytes, shifting the still unread portion to the left, and filling the lower bits with new data | ||
#[inline(always)] | ||
fn refill_container(&mut self) { | ||
let byte_idx = self.byte_idx() as usize; | ||
|
||
let retain_bytes = (self.bits_in_container + 7) / 8; | ||
let want_to_read_bits = 64 - (retain_bytes * 8); | ||
|
||
// if there are >= 8 byte left to read we go a fast path: | ||
// The slice is looking something like this |U..UCCCCCCCCR..R| Where U are some unread bytes, C are the bytes in the container, and R are already read bytes | ||
// What we do is, we shift the container by a few bytes to the left by just reading a u64 from the correct position, rereading the portion we did not yet return from the conainer. | ||
// Technically this would still work for positions lower than 8 but this guarantees that enough bytes are in the source and generally makes for less edge cases | ||
if byte_idx >= 8 { | ||
self.refill_fast(byte_idx, retain_bytes, want_to_read_bits) | ||
} else { | ||
// In the slow path we just read however many bytes we can | ||
self.refill_slow(byte_idx, want_to_read_bits) | ||
#[cold] | ||
fn refill(&mut self) { | ||
let bytes_consumed = self.bits_consumed as usize / 8; | ||
if bytes_consumed == 0 { | ||
return; | ||
} | ||
} | ||
|
||
#[inline(always)] | ||
fn refill_fast(&mut self, byte_idx: usize, retain_bytes: u8, want_to_read_bits: u8) { | ||
let load_from_byte_idx = byte_idx - 7 + retain_bytes as usize; | ||
let tmp_bytes: [u8; 8] = (&self.source[load_from_byte_idx..][..8]) | ||
.try_into() | ||
.unwrap(); | ||
let refill = u64::from_le_bytes(tmp_bytes); | ||
self.bit_container = refill; | ||
self.bits_in_container += want_to_read_bits; | ||
self.idx -= want_to_read_bits as isize; | ||
} | ||
|
||
#[cold] | ||
fn refill_slow(&mut self, byte_idx: usize, want_to_read_bits: u8) { | ||
let can_read_bits = isize::min(want_to_read_bits as isize, self.idx); | ||
let can_read_bytes = can_read_bits / 8; | ||
let mut tmp_bytes = [0u8; 8]; | ||
let offset @ 1..=8 = can_read_bytes as usize else { | ||
unreachable!() | ||
}; | ||
let bits_read = offset * 8; | ||
|
||
let _ = (&self.source[byte_idx - (offset - 1)..]).read_exact(&mut tmp_bytes[0..offset]); | ||
self.bits_in_container += bits_read as u8; | ||
self.idx -= bits_read as isize; | ||
if offset < 8 { | ||
self.bit_container <<= bits_read; | ||
self.bit_container |= u64::from_le_bytes(tmp_bytes); | ||
if self.index >= bytes_consumed { | ||
self.index -= bytes_consumed; | ||
self.bits_consumed &= 7; | ||
self.bit_container = | ||
u64::from_le_bytes((&self.source[self.index..][..8]).try_into().unwrap()); | ||
} else if self.index > 0 { | ||
if self.source.len() >= 8 { | ||
self.bit_container = u64::from_le_bytes((&self.source[..8]).try_into().unwrap()); | ||
} else { | ||
let mut value = [0; 8]; | ||
value[..self.source.len()].copy_from_slice(self.source); | ||
self.bit_container = u64::from_le_bytes(value); | ||
} | ||
|
||
self.bits_consumed -= 8 * self.index as u8; | ||
self.index = 0; | ||
|
||
self.bit_container <<= self.bits_consumed; | ||
self.extra_bits += self.bits_consumed as usize; | ||
self.bits_consumed = 0; | ||
} else if self.bits_consumed < 64 { | ||
self.bit_container <<= self.bits_consumed; | ||
self.extra_bits += self.bits_consumed as usize; | ||
self.bits_consumed = 0; | ||
} else { | ||
self.bit_container = u64::from_le_bytes(tmp_bytes); | ||
self.extra_bits += self.bits_consumed as usize; | ||
self.bits_consumed = 0; | ||
self.bit_container = 0; | ||
} | ||
} | ||
|
||
/// Next byte that should be read into the container | ||
/// Negative values mean that the source buffer as been read into the container completetly. | ||
fn byte_idx(&self) -> isize { | ||
(self.idx - 1) / 8 | ||
// Assert that at least `56 = 64 - 8` bits are available to read. | ||
debug_assert!(self.bits_consumed < 8); | ||
} | ||
|
||
/// Read `n` number of bits from the source. Will read at most 56 bits. | ||
/// If there are no more bits to be read from the source zero bits will be returned instead. | ||
#[inline(always)] | ||
pub fn get_bits(&mut self, n: u8) -> u64 { | ||
if n == 0 { | ||
return 0; | ||
} | ||
if self.bits_in_container >= n { | ||
return self.get_bits_unchecked(n); | ||
if self.bits_consumed + n > 64 { | ||
self.refill(); | ||
} | ||
|
||
self.get_bits_cold(n) | ||
let value = self.peek_bits(n); | ||
self.consume(n); | ||
value | ||
} | ||
|
||
#[cold] | ||
fn get_bits_cold(&mut self, n: u8) -> u64 { | ||
let n = u8::min(n, 56); | ||
let signed_n = n as isize; | ||
|
||
if self.bits_remaining() <= 0 { | ||
self.idx -= signed_n; | ||
/// Get the next `n` bits from the source without consuming them. | ||
#[inline(always)] | ||
pub fn peek_bits(&mut self, n: u8) -> u64 { | ||
if n == 0 { | ||
return 0; | ||
} | ||
|
||
if self.bits_remaining() < signed_n { | ||
let emulated_read_shift = signed_n - self.bits_remaining(); | ||
let v = self.get_bits(self.bits_remaining() as u8); | ||
debug_assert!(self.idx == 0); | ||
let value = v.wrapping_shl(emulated_read_shift as u32); | ||
self.idx -= emulated_read_shift; | ||
return value; | ||
} | ||
|
||
while (self.bits_in_container < n) && self.idx > 0 { | ||
self.refill_container(); | ||
} | ||
|
||
debug_assert!(self.bits_in_container >= n); | ||
|
||
//if we reach this point there are enough bits in the container | ||
let mask = (1u64 << n) - 1u64; | ||
let shift_by = 64 - self.bits_consumed - n; | ||
(self.bit_container >> shift_by) & mask | ||
} | ||
|
||
self.get_bits_unchecked(n) | ||
/// Consume `n` bits from the source. | ||
#[inline(always)] | ||
pub fn consume(&mut self, n: u8) { | ||
self.bits_consumed += n; | ||
debug_assert!(self.bits_consumed <= 64); | ||
} | ||
|
||
/// Same as calling get_bits three times but slightly more performant | ||
#[inline(always)] | ||
pub fn get_bits_triple(&mut self, n1: u8, n2: u8, n3: u8) -> (u64, u64, u64) { | ||
let sum = n1 as usize + n2 as usize + n3 as usize; | ||
if sum == 0 { | ||
return (0, 0, 0); | ||
} | ||
if sum > 56 { | ||
// try and get the values separately | ||
return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3)); | ||
} | ||
let sum = sum as u8; | ||
let sum = n1 + n2 + n3; | ||
if sum <= 56 { | ||
self.refill(); | ||
|
||
if self.bits_in_container >= sum { | ||
let v1 = if n1 == 0 { | ||
0 | ||
} else { | ||
self.get_bits_unchecked(n1) | ||
}; | ||
let v2 = if n2 == 0 { | ||
0 | ||
} else { | ||
self.get_bits_unchecked(n2) | ||
}; | ||
let v3 = if n3 == 0 { | ||
0 | ||
} else { | ||
self.get_bits_unchecked(n3) | ||
}; | ||
let v1 = self.peek_bits(n1); | ||
self.consume(n1); | ||
let v2 = self.peek_bits(n2); | ||
self.consume(n2); | ||
let v3 = self.peek_bits(n3); | ||
self.consume(n3); | ||
|
||
return (v1, v2, v3); | ||
} | ||
|
||
self.get_bits_triple_cold(n1, n2, n3, sum) | ||
(self.get_bits(n1), self.get_bits(n2), self.get_bits(n3)) | ||
} | ||
} | ||
|
||
#[cold] | ||
fn get_bits_triple_cold(&mut self, n1: u8, n2: u8, n3: u8, sum: u8) -> (u64, u64, u64) { | ||
let sum_signed = sum as isize; | ||
|
||
if self.bits_remaining() <= 0 { | ||
self.idx -= sum_signed; | ||
return (0, 0, 0); | ||
} | ||
|
||
if self.bits_remaining() < sum_signed { | ||
return (self.get_bits(n1), self.get_bits(n2), self.get_bits(n3)); | ||
} | ||
|
||
while (self.bits_in_container < sum) && self.idx > 0 { | ||
self.refill_container(); | ||
} | ||
|
||
debug_assert!(self.bits_in_container >= sum); | ||
|
||
//if we reach this point there are enough bits in the container | ||
|
||
let v1 = if n1 == 0 { | ||
0 | ||
} else { | ||
self.get_bits_unchecked(n1) | ||
}; | ||
let v2 = if n2 == 0 { | ||
0 | ||
} else { | ||
self.get_bits_unchecked(n2) | ||
}; | ||
let v3 = if n3 == 0 { | ||
0 | ||
} else { | ||
self.get_bits_unchecked(n3) | ||
}; | ||
|
||
(v1, v2, v3) | ||
} | ||
|
||
#[inline(always)] | ||
fn get_bits_unchecked(&mut self, n: u8) -> u64 { | ||
let shift_by = self.bits_in_container - n; | ||
let mask = (1u64 << n) - 1u64; | ||
|
||
let value = self.bit_container >> shift_by; | ||
self.bits_in_container -= n; | ||
let value_masked = value & mask; | ||
debug_assert!(value_masked < (1 << n)); | ||
|
||
value_masked | ||
#[cfg(test)] | ||
mod test { | ||
|
||
#[test] | ||
fn it_works() { | ||
let data = [0b10101010, 0b01010101]; | ||
let mut br = super::BitReaderReversed::new(&data); | ||
assert_eq!(br.get_bits(1), 0); | ||
assert_eq!(br.get_bits(1), 1); | ||
assert_eq!(br.get_bits(1), 0); | ||
assert_eq!(br.get_bits(4), 0b1010); | ||
assert_eq!(br.get_bits(4), 0b1101); | ||
} | ||
} |