diff --git a/src/compress.rs b/src/compress.rs index 24dac27..ed73273 100644 --- a/src/compress.rs +++ b/src/compress.rs @@ -1,25 +1,14 @@ use crate::constants::BASE64_KEY; use crate::constants::CLOSE_CODE; +use crate::constants::START_CODE_BITS; +use crate::constants::U16_CODE; +use crate::constants::U8_CODE; use crate::constants::URI_KEY; use crate::IntoWideIter; use std::collections::HashMap; use std::collections::HashSet; use std::convert::TryInto; -/// The starting size of a codepoint. -/// -/// Compression starts with the following codes: -/// 0: u8 -/// 1: u16 -/// 2: close stream -const START_NUM_BITS: u8 = 2; - -/// The stream code for a `u8`. -const U8_CODE: u32 = 0; - -/// The stream code for a `u16`. -const U16_CODE: u32 = 1; - /// The number of "base codes", /// the default codes of all streams. /// @@ -96,7 +85,7 @@ where bit_buffer: 0, - num_bits: START_NUM_BITS, + num_bits: START_CODE_BITS, bit_position: 0, bits_per_char, @@ -114,10 +103,10 @@ where { Some(Some(first_w_char)) => { if first_w_char < 256 { - self.write_bits(self.num_bits, U8_CODE); + self.write_bits(self.num_bits, U8_CODE.into()); self.write_bits(8, first_w_char.into()); } else { - self.write_bits(self.num_bits, U16_CODE); + self.write_bits(self.num_bits, U16_CODE.into()); self.write_bits(16, first_w_char.into()); } self.decrement_enlarge_in(); diff --git a/src/constants.rs b/src/constants.rs index 7380c7b..b27d490 100644 --- a/src/constants.rs +++ b/src/constants.rs @@ -1,5 +1,19 @@ pub const URI_KEY: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+-$"; pub const BASE64_KEY: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/="; +/// The stream code for a `u8`. +pub const U8_CODE: u8 = 0; + +/// The stream code for a `u16`. +pub const U16_CODE: u8 = 1; + /// End of stream signal -pub const CLOSE_CODE: u16 = 2; +pub const CLOSE_CODE: u8 = 2; + +/// The starting size of a code. +/// +/// Compression starts with the following codes: +/// 0: u8 +/// 1: u16 +/// 2: close stream +pub const START_CODE_BITS: u8 = 2; diff --git a/src/decompress.rs b/src/decompress.rs index f86efca..b6a39f8 100644 --- a/src/decompress.rs +++ b/src/decompress.rs @@ -1,8 +1,8 @@ -// TODO: Disable this -#![allow(clippy::cast_possible_truncation)] - use crate::constants::BASE64_KEY; use crate::constants::CLOSE_CODE; +use crate::constants::START_CODE_BITS; +use crate::constants::U16_CODE; +use crate::constants::U8_CODE; use crate::constants::URI_KEY; use crate::IntoWideIter; use std::convert::TryFrom; @@ -12,8 +12,8 @@ use std::convert::TryInto; pub struct DecompressContext { val: u16, compressed_data: I, - position: usize, - reset_val: usize, + position: u16, + reset_val: u16, } impl DecompressContext @@ -24,8 +24,17 @@ where /// /// # Errors /// Returns `None` if the iterator is empty. + /// + /// # Panics + /// Panics if `bits_per_char` is greater than the number of bits in a `u16`. #[inline] - pub fn new(mut compressed_data: I, reset_val: usize) -> Option { + pub fn new(mut compressed_data: I, bits_per_char: u8) -> Option { + assert!(usize::from(bits_per_char) <= std::mem::size_of::() * 8); + + let reset_val_pow = bits_per_char - 1; + // (1 << 15) <= u16::MAX + let reset_val: u16 = 1 << reset_val_pow; + Some(DecompressContext { val: compressed_data.next()?, compressed_data, @@ -36,7 +45,7 @@ where #[inline] pub fn read_bit(&mut self) -> Option { - let res = self.val & (self.position as u16); + let res = self.val & self.position; self.position >>= 1; if self.position == 0 { @@ -47,11 +56,14 @@ where Some(res != 0) } + /// Read n bits. + /// + /// `u32` is the return type as we expect all possible codes to be within that type's range. #[inline] - pub fn read_bits(&mut self, n: usize) -> Option { + pub fn read_bits(&mut self, n: u8) -> Option { let mut res = 0; - let max_power = 2_u32.pow(n as u32); - let mut power = 1; + let max_power: u32 = 1 << n; + let mut power: u32 = 1; while power != max_power { res |= u32::from(self.read_bit()?) * power; power <<= 1; @@ -162,69 +174,67 @@ pub fn decompress_from_uint8_array(compressed: &[u8]) -> Option> { /// # Panics /// Panics if `bits_per_char` is greater than the number of bits in a `u16`. #[inline] -pub fn decompress_internal(compressed: I, bits_per_char: usize) -> Option> +pub fn decompress_internal(compressed: I, bits_per_char: u8) -> Option> where I: Iterator, { - assert!(bits_per_char <= std::mem::size_of::() * 8); - - // u16::MAX < u32::MAX - let reset_val_pow = u32::try_from(bits_per_char).unwrap() - 1; - let reset_val = 2_usize.pow(reset_val_pow); - let mut ctx = match DecompressContext::new(compressed, reset_val) { + let mut ctx = match DecompressContext::new(compressed, bits_per_char) { Some(ctx) => ctx, None => return Some(Vec::new()), }; - let mut dictionary: Vec> = Vec::with_capacity(3); + let mut dictionary: Vec> = Vec::with_capacity(16); for i in 0_u16..3_u16 { dictionary.push(vec![i]); } - let next = ctx.read_bits(2)?; - let first_entry: u16 = match next as u16 { - 0 | 1 => { - let bits_to_read = (next * 8) + 8; - ctx.read_bits(bits_to_read as usize)? as u16 + // u8::MAX > u2::MAX + let code = u8::try_from(ctx.read_bits(START_CODE_BITS)?).unwrap(); + let first_entry = match code { + U8_CODE | U16_CODE => { + let bits_to_read = (code * 8) + 8; + // bits_to_read == 8 or 16 <= 16 + u16::try_from(ctx.read_bits(bits_to_read)?).unwrap() } CLOSE_CODE => return Some(Vec::new()), _ => return None, }; - dictionary.insert(3, vec![first_entry]); + dictionary.push(vec![first_entry]); let mut w = vec![first_entry]; let mut result = vec![first_entry]; - let mut num_bits = 3; - let mut enlarge_in = 4; - let mut dict_size = 4; + let mut num_bits: u8 = 3; + let mut enlarge_in: u64 = 4; let mut entry; loop { - let mut cc = ctx.read_bits(num_bits)? as usize; - match cc as u16 { - 0 | 1 => { - let bits_to_read = (cc * 8) + 8; + let mut code = ctx.read_bits(num_bits)?; + match u8::try_from(code) { + Ok(code_u8 @ (U8_CODE | U16_CODE)) => { + let bits_to_read = (code_u8 * 8) + 8; // if cc == 0 { // if (errorCount++ > 10000) return "Error"; // TODO: Error logic // } - let bits = ctx.read_bits(bits_to_read)? as u16; + // bits_to_read == 8 or 16 <= 16 + let bits = u16::try_from(ctx.read_bits(bits_to_read)?).unwrap(); dictionary.push(vec![bits]); - dict_size += 1; - cc = dict_size - 1; + code = u32::try_from(dictionary.len() - 1).ok()?; enlarge_in -= 1; } - CLOSE_CODE => return Some(result), + Ok(CLOSE_CODE) => return Some(result), _ => {} } if enlarge_in == 0 { - enlarge_in = 2_u32.pow(num_bits as u32); + enlarge_in = 1 << num_bits; num_bits += 1; } - if let Some(entry_value) = dictionary.get(cc) { + // Return error if code cannot be converted to dictionary index + let code_usize = usize::try_from(code).ok()?; + if let Some(entry_value) = dictionary.get(code_usize) { entry = entry_value.clone(); - } else if cc == dict_size { + } else if code_usize == dictionary.len() { entry = w.clone(); entry.push(*w.first()?); } else { @@ -237,13 +247,12 @@ where let mut to_be_inserted = w.clone(); to_be_inserted.push(*entry.first()?); dictionary.push(to_be_inserted); - dict_size += 1; enlarge_in -= 1; w = entry; if enlarge_in == 0 { - enlarge_in = 2_u32.pow(num_bits as u32); + enlarge_in = 1 << num_bits; num_bits += 1; } } diff --git a/tests/valid_inputs.rs b/tests/valid_inputs.rs index 38e9f14..4aee9df 100644 --- a/tests/valid_inputs.rs +++ b/tests/valid_inputs.rs @@ -35,4 +35,8 @@ fn valid_long_input_round() { assert_eq!(a, b, "[index={}] {} != {}", i, a, b); } } + assert_eq!(compressed, js_compressed); + + let decompressed = lz_str::decompress(&compressed).expect("decompression failed"); + assert_eq!(decompressed, data); }