Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix decompressing long inputs #30

Merged
merged 18 commits into from
Oct 20, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 6 additions & 17 deletions src/compress.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,14 @@
use crate::constants::BASE64_KEY;
use crate::constants::CLOSE_CODE;
use crate::constants::START_CODE_BITS;
use crate::constants::U16_CODE;
use crate::constants::U8_CODE;
use crate::constants::URI_KEY;
use crate::IntoWideIter;
use std::collections::HashMap;
use std::collections::HashSet;
use std::convert::TryInto;

/// The starting size of a codepoint.
///
/// Compression starts with the following codes:
/// 0: u8
/// 1: u16
/// 2: close stream
const START_NUM_BITS: u8 = 2;

/// The stream code for a `u8`.
const U8_CODE: u32 = 0;

/// The stream code for a `u16`.
const U16_CODE: u32 = 1;

/// The number of "base codes",
/// the default codes of all streams.
///
Expand Down Expand Up @@ -96,7 +85,7 @@ where

bit_buffer: 0,

num_bits: START_NUM_BITS,
num_bits: START_CODE_BITS,

bit_position: 0,
bits_per_char,
Expand All @@ -114,10 +103,10 @@ where
{
Some(Some(first_w_char)) => {
if first_w_char < 256 {
self.write_bits(self.num_bits, U8_CODE);
self.write_bits(self.num_bits, U8_CODE.into());
self.write_bits(8, first_w_char.into());
} else {
self.write_bits(self.num_bits, U16_CODE);
self.write_bits(self.num_bits, U16_CODE.into());
self.write_bits(16, first_w_char.into());
}
self.decrement_enlarge_in();
Expand Down
16 changes: 15 additions & 1 deletion src/constants.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
pub const URI_KEY: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+-$";
pub const BASE64_KEY: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/=";

/// The stream code for a `u8`.
pub const U8_CODE: u8 = 0;

/// The stream code for a `u16`.
pub const U16_CODE: u8 = 1;

/// End of stream signal
pub const CLOSE_CODE: u16 = 2;
pub const CLOSE_CODE: u8 = 2;

/// The starting size of a code.
///
/// Compression starts with the following codes:
/// 0: u8
/// 1: u16
/// 2: close stream
pub const START_CODE_BITS: u8 = 2;
89 changes: 49 additions & 40 deletions src/decompress.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// TODO: Disable this
#![allow(clippy::cast_possible_truncation)]

use crate::constants::BASE64_KEY;
use crate::constants::CLOSE_CODE;
use crate::constants::START_CODE_BITS;
use crate::constants::U16_CODE;
use crate::constants::U8_CODE;
use crate::constants::URI_KEY;
use crate::IntoWideIter;
use std::convert::TryFrom;
Expand All @@ -12,8 +12,8 @@ use std::convert::TryInto;
pub struct DecompressContext<I> {
val: u16,
compressed_data: I,
position: usize,
reset_val: usize,
position: u16,
reset_val: u16,
}

impl<I> DecompressContext<I>
Expand All @@ -24,8 +24,17 @@ where
///
/// # Errors
/// Returns `None` if the iterator is empty.
///
/// # Panics
/// Panics if `bits_per_char` is greater than the number of bits in a `u16`.
#[inline]
pub fn new(mut compressed_data: I, reset_val: usize) -> Option<Self> {
pub fn new(mut compressed_data: I, bits_per_char: u8) -> Option<Self> {
assert!(usize::from(bits_per_char) <= std::mem::size_of::<u16>() * 8);

let reset_val_pow = bits_per_char - 1;
// (1 << 15) <= u16::MAX
let reset_val: u16 = 1 << reset_val_pow;

Some(DecompressContext {
val: compressed_data.next()?,
compressed_data,
Expand All @@ -36,7 +45,7 @@ where

#[inline]
pub fn read_bit(&mut self) -> Option<bool> {
let res = self.val & (self.position as u16);
let res = self.val & self.position;
self.position >>= 1;

if self.position == 0 {
Expand All @@ -47,11 +56,14 @@ where
Some(res != 0)
}

/// Read n bits.
///
/// `u32` is the return type as we expect all possible codes to be within that type's range.
#[inline]
pub fn read_bits(&mut self, n: usize) -> Option<u32> {
pub fn read_bits(&mut self, n: u8) -> Option<u32> {
let mut res = 0;
let max_power = 2_u32.pow(n as u32);
let mut power = 1;
let max_power: u32 = 1 << n;
let mut power: u32 = 1;
while power != max_power {
res |= u32::from(self.read_bit()?) * power;
power <<= 1;
Expand Down Expand Up @@ -162,69 +174,67 @@ pub fn decompress_from_uint8_array(compressed: &[u8]) -> Option<Vec<u16>> {
/// # Panics
/// Panics if `bits_per_char` is greater than the number of bits in a `u16`.
#[inline]
pub fn decompress_internal<I>(compressed: I, bits_per_char: usize) -> Option<Vec<u16>>
pub fn decompress_internal<I>(compressed: I, bits_per_char: u8) -> Option<Vec<u16>>
where
I: Iterator<Item = u16>,
{
assert!(bits_per_char <= std::mem::size_of::<u16>() * 8);

// u16::MAX < u32::MAX
let reset_val_pow = u32::try_from(bits_per_char).unwrap() - 1;
let reset_val = 2_usize.pow(reset_val_pow);
let mut ctx = match DecompressContext::new(compressed, reset_val) {
let mut ctx = match DecompressContext::new(compressed, bits_per_char) {
Some(ctx) => ctx,
None => return Some(Vec::new()),
};

let mut dictionary: Vec<Vec<u16>> = Vec::with_capacity(3);
let mut dictionary: Vec<Vec<u16>> = Vec::with_capacity(16);
for i in 0_u16..3_u16 {
dictionary.push(vec![i]);
}

let next = ctx.read_bits(2)?;
let first_entry: u16 = match next as u16 {
0 | 1 => {
let bits_to_read = (next * 8) + 8;
ctx.read_bits(bits_to_read as usize)? as u16
// u8::MAX > u2::MAX
let code = u8::try_from(ctx.read_bits(START_CODE_BITS)?).unwrap();
let first_entry = match code {
U8_CODE | U16_CODE => {
let bits_to_read = (code * 8) + 8;
// bits_to_read == 8 or 16 <= 16
u16::try_from(ctx.read_bits(bits_to_read)?).unwrap()
}
CLOSE_CODE => return Some(Vec::new()),
_ => return None,
};
dictionary.insert(3, vec![first_entry]);
dictionary.push(vec![first_entry]);

let mut w = vec![first_entry];
let mut result = vec![first_entry];
let mut num_bits = 3;
let mut enlarge_in = 4;
let mut dict_size = 4;
let mut num_bits: u8 = 3;
let mut enlarge_in: u64 = 4;
let mut entry;
loop {
let mut cc = ctx.read_bits(num_bits)? as usize;
match cc as u16 {
0 | 1 => {
let bits_to_read = (cc * 8) + 8;
let mut code = ctx.read_bits(num_bits)?;
match u8::try_from(code) {
Ok(code_u8 @ (U8_CODE | U16_CODE)) => {
let bits_to_read = (code_u8 * 8) + 8;
// if cc == 0 {
// if (errorCount++ > 10000) return "Error"; // TODO: Error logic
// }

let bits = ctx.read_bits(bits_to_read)? as u16;
// bits_to_read == 8 or 16 <= 16
let bits = u16::try_from(ctx.read_bits(bits_to_read)?).unwrap();
dictionary.push(vec![bits]);
dict_size += 1;
cc = dict_size - 1;
code = u32::try_from(dictionary.len() - 1).ok()?;
enlarge_in -= 1;
}
CLOSE_CODE => return Some(result),
Ok(CLOSE_CODE) => return Some(result),
_ => {}
}

if enlarge_in == 0 {
enlarge_in = 2_u32.pow(num_bits as u32);
enlarge_in = 1 << num_bits;
num_bits += 1;
}

if let Some(entry_value) = dictionary.get(cc) {
// Return error if code cannot be converted to dictionary index
let code_usize = usize::try_from(code).ok()?;
if let Some(entry_value) = dictionary.get(code_usize) {
entry = entry_value.clone();
} else if cc == dict_size {
} else if code_usize == dictionary.len() {
entry = w.clone();
entry.push(*w.first()?);
} else {
Expand All @@ -237,13 +247,12 @@ where
let mut to_be_inserted = w.clone();
to_be_inserted.push(*entry.first()?);
dictionary.push(to_be_inserted);
dict_size += 1;
enlarge_in -= 1;

w = entry;

if enlarge_in == 0 {
enlarge_in = 2_u32.pow(num_bits as u32);
enlarge_in = 1 << num_bits;
num_bits += 1;
}
}
Expand Down
4 changes: 4 additions & 0 deletions tests/valid_inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,8 @@ fn valid_long_input_round() {
assert_eq!(a, b, "[index={}] {} != {}", i, a, b);
}
}
assert_eq!(compressed, js_compressed);

let decompressed = lz_str::decompress(&compressed).expect("decompression failed");
assert_eq!(decompressed, data);
}