Skip to content

Commit

Permalink
Add better handling for long inputs (#27)
Browse files Browse the repository at this point in the history
* Start adding new clippy lints

* Improve compressor bitwriter slightly

* Make num_bits a u32

* Add test for compressing long inputs

* Make `num_bits` a `u8`

* Remove unreachable `expect`s

* Add cache to Bindings action
  • Loading branch information
adumbidiot authored Oct 16, 2022
1 parent 4b0911b commit 37cc0e5
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 39 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/Bindings.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,17 @@ jobs:
- name: Checkout
uses: actions/checkout@v3

- name: Cache
uses: actions/cache@v3
with:
path: |
~/.cargo/bin/
~/.cargo/registry/index/
~/.cargo/registry/cache/
~/.cargo/git/db/
target/
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}

- name: Install Wasm-Pack
run: curl https://rustwasm.github.io/wasm-pack/installer/init.sh -sSf | sh

Expand Down
90 changes: 52 additions & 38 deletions src/compress.rs
Original file line number Diff line number Diff line change
@@ -1,39 +1,61 @@
use crate::{
constants::{BASE64_KEY, CLOSE_CODE, URI_KEY},
IntoWideIter,
};
use std::collections::{HashMap, HashSet};
use crate::constants::BASE64_KEY;
use crate::constants::CLOSE_CODE;
use crate::constants::URI_KEY;
use crate::IntoWideIter;
use std::collections::HashMap;
use std::collections::HashSet;

#[derive(Debug)]
pub(crate) struct CompressContext<F> {
dictionary: HashMap<Vec<u16>, u16>,
dictionary: HashMap<Vec<u16>, u32>,
dictionary_to_create: HashSet<Vec<u16>>,
wc: Vec<u16>,
w: Vec<u16>,
enlarge_in: usize,
dict_size: usize,
num_bits: usize,

/// The current number of bits in a code.
///
/// This is a u8,
/// because we currently assume the max code size is 32 bits.
/// 32 < u8::MAX
num_bits: u8,

// result: Vec<u16>,

// Data
output: Vec<u16>,
val: u16,
position: usize,
// Limits
bits_per_char: usize,

/// The current bit position.
bit_position: u8,

/// The maximum # of bits per char.
///
/// This value may not exceed 16,
/// as the reference implementation will also not handle values over 16.
bits_per_char: u8,

/// A transformation function to map a u16 to another u16,
/// before appending it to the output buffer.
to_char: F,
}

impl<F> CompressContext<F>
where
F: Fn(u16) -> u16,
{
/// Make a new [`CompressContext`].
///
/// # Panics
/// Panics if `bits_per_char` exceeds 16.
#[inline]
pub fn new(bits_per_char: usize, to_char: F) -> Self {
pub fn new(bits_per_char: u8, to_char: F) -> Self {
assert!(bits_per_char <= 16);

CompressContext {
dictionary: Default::default(),
dictionary_to_create: HashSet::new(),
dictionary: HashMap::with_capacity(16),
dictionary_to_create: HashSet::with_capacity(16),
wc: Vec::new(),
w: Vec::new(),
enlarge_in: 2,
Expand All @@ -43,7 +65,8 @@ where
// result: Vec::new(),
output: Vec::new(),
val: 0,
position: 0,

bit_position: 0,
bits_per_char,
to_char,
}
Expand All @@ -55,10 +78,10 @@ where
let first_w_char = self.w[0];
if first_w_char < 256 {
self.write_bits(self.num_bits, 0);
self.write_bits(8, first_w_char);
self.write_bits(8, first_w_char.into());
} else {
self.write_bits(self.num_bits, 1);
self.write_bits(16, first_w_char);
self.write_bits(16, first_w_char.into());
}
self.decrement_enlarge_in();
self.dictionary_to_create.remove(&self.w);
Expand All @@ -69,20 +92,19 @@ where
}

#[inline]
pub fn write_bit(&mut self, value: u16) {
self.val = (self.val << 1) | value;
if self.position == self.bits_per_char - 1 {
self.position = 0;
pub fn write_bit(&mut self, value: u32) {
self.val = (self.val << 1) | (value as u16);
self.bit_position += 1;
if self.bit_position == self.bits_per_char {
self.bit_position = 0;
let char_data = (self.to_char)(self.val);
self.output.push(char_data);
self.val = 0;
} else {
self.position += 1;
}
}

#[inline]
pub fn write_bits(&mut self, n: usize, mut value: u16) {
pub fn write_bits(&mut self, n: u8, mut value: u32) {
for _ in 0..n {
self.write_bit(value & 1);
value >>= 1;
Expand All @@ -93,7 +115,7 @@ where
pub fn decrement_enlarge_in(&mut self) {
self.enlarge_in -= 1;
if self.enlarge_in == 0 {
self.enlarge_in = 2_usize.pow(self.num_bits as u32);
self.enlarge_in = 2_usize.pow(self.num_bits.into());
self.num_bits += 1;
}
}
Expand All @@ -103,7 +125,7 @@ where
pub fn write_u16(&mut self, c: u16) {
let c = vec![c];
if !self.dictionary.contains_key(&c) {
self.dictionary.insert(c.clone(), self.dict_size as u16);
self.dictionary.insert(c.clone(), self.dict_size as u32);
self.dict_size += 1;
self.dictionary_to_create.insert(c.clone());
}
Expand All @@ -116,7 +138,7 @@ where
self.produce_w();
// Add wc to the dictionary.
self.dictionary
.insert(self.wc.clone(), self.dict_size as u16);
.insert(self.wc.clone(), self.dict_size as u32);
self.dict_size += 1;
self.w = c;
}
Expand All @@ -131,7 +153,7 @@ where
}

// Mark the end of the stream
self.write_bits(self.num_bits, CLOSE_CODE);
self.write_bits(self.num_bits, CLOSE_CODE.into());

let str_len = self.output.len();
// Flush the last char
Expand Down Expand Up @@ -170,11 +192,7 @@ pub fn compress_to_utf16(input: impl IntoWideIter) -> String {
#[inline]
pub fn compress_to_encoded_uri_component(data: impl IntoWideIter) -> String {
let compressed = compress_internal(data.into_wide_iter(), 6, |n| {
u16::from(
*URI_KEY
.get(usize::from(n))
.expect("Invalid index into `URI_KEY` in `compress_to_encoded_uri_component`"),
)
u16::from(URI_KEY[usize::from(n)])
});

String::from_utf16(&compressed)
Expand All @@ -186,11 +204,7 @@ pub fn compress_to_encoded_uri_component(data: impl IntoWideIter) -> String {
/// This function converts the result back into a Rust [`String`] since it is guaranteed to be valid unicode.
pub fn compress_to_base64(data: impl IntoWideIter) -> String {
let mut compressed = compress_internal(data.into_wide_iter(), 6, |n| {
u16::from(
*BASE64_KEY
.get(usize::from(n))
.expect("Invalid index into `BASE64_KEY` in `compress_to_base64`"),
)
u16::from(BASE64_KEY[usize::from(n)])
});

let mod_4 = compressed.len() % 4;
Expand All @@ -217,7 +231,7 @@ pub fn compress_to_uint8_array(data: impl IntoWideIter) -> Vec<u8> {
/// All other compression functions are built on top of this.
/// It generally should not be used directly.
#[inline]
pub fn compress_internal<I, F>(uncompressed: I, bits_per_char: usize, to_char: F) -> Vec<u16>
pub fn compress_internal<I, F>(uncompressed: I, bits_per_char: u8, to_char: F) -> Vec<u16>
where
I: Iterator<Item = u16>,
F: Fn(u16) -> u16,
Expand Down
4 changes: 4 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
#![forbid(unsafe_code)]
#![deny(missing_docs)]
#![warn(clippy::cast_lossless)]
#![warn(clippy::cast_possible_wrap)]
// TODO: Enable this
// #![warn(clippy::cast_possible_truncation)]

//! A port of [lz-string](https://github.com/pieroxy/lz-string) to Rust.
//!
Expand Down
1 change: 1 addition & 0 deletions test_data/long_compressed_js.txt

Large diffs are not rendered by default.

28 changes: 27 additions & 1 deletion tests/valid_inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,33 @@ fn valid_decompress() {
&[("red123", vec![0x80, 0x80]), ("腆퍂蚂荂", vec![0xD8A0])];
for (data, expected) in valid_data {
let arr: Vec<u16> = data.encode_utf16().collect();
let decompressed = decompress(&arr).expect("Valid Decompress");
let decompressed = decompress(&arr).expect("decompression failed");
assert_eq!(&decompressed, expected);
}
}

#[test]
fn valid_long_input_round() {
// let buffer = [];
// for(let i = 0; i < 100000; i++){
// buffer.push(i % 65_535);
// }
// result = LZString144.compress(String.fromCharCode(...buffer));
// Array.from(result).map((v) => v.charCodeAt(0));
let data: Vec<u16> = (0u64..100_000u64)
.map(|val| (val % u64::from(std::u16::MAX)) as u16)
.collect();

let compressed = lz_str::compress(&data);

let js_compressed = include_str!("../test_data/long_compressed_js.txt")
.split(',')
.map(|s| s.trim().parse::<u16>().unwrap())
.collect::<Vec<u16>>();

for (i, (a, b)) in compressed.iter().zip(js_compressed.iter()).enumerate() {
if a != b {
assert_eq!(a, b, "[index={}] {} != {}", i, a, b);
}
}
}

0 comments on commit 37cc0e5

Please sign in to comment.