Skip to content

Commit 378861a

Browse files
authored
feat: implement check for arm and portable-simd (#10)
1 parent a3944b6 commit 378861a

File tree

8 files changed

+248
-144
lines changed

8 files changed

+248
-144
lines changed

.github/workflows/ci.yml

+4-6
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ jobs:
4040
- uses: Swatinem/rust-cache@v2
4141

4242
- run: cargo build
43-
- run: cargo test
4443
- run: cargo build --no-default-features
45-
- run: cargo test --tests --no-default-features
46-
- run: cargo test --tests --no-default-features --features force-generic
47-
- run: cargo test --tests --no-default-features --features nightly,portable-simd
44+
- run: cargo test
45+
- run: cargo test --no-default-features
46+
- run: cargo test --no-default-features --features force-generic
47+
- run: cargo test --no-default-features --features nightly,portable-simd
4848
if: matrix.rust == 'nightly'
4949
- run: cargo bench --no-run
5050
if: matrix.rust == 'nightly'
@@ -65,8 +65,6 @@ jobs:
6565
- uses: dtolnay/rust-toolchain@miri
6666
with:
6767
target: ${{ matrix.target }}
68-
- uses: Swatinem/rust-cache@v2
69-
- run: cargo miri setup --target ${{ matrix.target }} ${{ matrix.flags }}
7068
- run: cargo miri test --target ${{ matrix.target }} ${{ matrix.flags }}
7169

7270
fuzz:

README.md

+108-89
Large diffs are not rendered by default.

benches/bench/main.rs

+44
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,50 @@ impl<const N: usize> fmt::Display for StdFormat<N> {
2929

3030
macro_rules! benches {
3131
($($name:ident($enc:expr, $dec:expr))*) => {
32+
mod check {
33+
use super::*;
34+
35+
mod const_hex {
36+
use super::*;
37+
38+
$(
39+
#[bench]
40+
fn $name(b: &mut Bencher) {
41+
b.iter(|| {
42+
::const_hex::check(black_box($dec))
43+
});
44+
}
45+
)*
46+
}
47+
48+
mod faster_hex {
49+
use super::*;
50+
51+
$(
52+
#[bench]
53+
fn $name(b: &mut Bencher) {
54+
b.iter(|| {
55+
::faster_hex::hex_check(black_box($dec.as_bytes()))
56+
});
57+
}
58+
)*
59+
}
60+
61+
mod naive {
62+
use super::*;
63+
64+
$(
65+
#[bench]
66+
fn $name(b: &mut Bencher) {
67+
b.iter(|| {
68+
let dec = black_box($dec.as_bytes());
69+
dec.iter().all(u8::is_ascii_hexdigit)
70+
});
71+
}
72+
)*
73+
}
74+
}
75+
3276
#[cfg(feature = "alloc")]
3377
mod decode {
3478
use super::*;

src/arch/aarch64.rs

+40-2
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use super::generic;
44
use crate::get_chars_table;
55
use core::arch::aarch64::*;
66

7-
pub(crate) const USE_CHECK_FN: bool = false;
7+
pub(crate) const USE_CHECK_FN: bool = true;
88
const CHUNK_SIZE: usize = core::mem::size_of::<uint8x16_t>();
99

1010
cfg_if::cfg_if! {
@@ -63,6 +63,44 @@ pub(crate) unsafe fn encode_neon<const UPPER: bool>(input: &[u8], output: *mut u
6363
}
6464
}
6565

66-
pub(crate) use generic::check;
66+
#[inline]
67+
pub(crate) fn check(input: &[u8]) -> bool {
68+
if cfg!(miri) || !has_neon() || input.len() < CHUNK_SIZE {
69+
return generic::check(input);
70+
}
71+
unsafe { check_neon(input) }
72+
}
73+
74+
#[target_feature(enable = "neon")]
75+
pub(crate) unsafe fn check_neon(input: &[u8]) -> bool {
76+
let ascii_zero = vdupq_n_u8(b'0' - 1);
77+
let ascii_nine = vdupq_n_u8(b'9' + 1);
78+
let ascii_ua = vdupq_n_u8(b'A' - 1);
79+
let ascii_uf = vdupq_n_u8(b'F' + 1);
80+
let ascii_la = vdupq_n_u8(b'a' - 1);
81+
let ascii_lf = vdupq_n_u8(b'f' + 1);
82+
83+
let (prefix, chunks, suffix) = input.align_to::<uint8x16_t>();
84+
generic::check(prefix)
85+
&& chunks.iter().all(|&chunk| {
86+
let ge0 = vcgtq_u8(chunk, ascii_zero);
87+
let le9 = vcltq_u8(chunk, ascii_nine);
88+
let valid_digit = vandq_u8(ge0, le9);
89+
90+
let geua = vcgtq_u8(chunk, ascii_ua);
91+
let leuf = vcltq_u8(chunk, ascii_uf);
92+
let valid_upper = vandq_u8(geua, leuf);
93+
94+
let gela = vcgtq_u8(chunk, ascii_la);
95+
let lelf = vcltq_u8(chunk, ascii_lf);
96+
let valid_lower = vandq_u8(gela, lelf);
97+
98+
let valid_letter = vorrq_u8(valid_lower, valid_upper);
99+
let valid_mask = vorrq_u8(valid_digit, valid_letter);
100+
vminvq_u8(valid_mask) == 0xFF
101+
})
102+
&& generic::check(suffix)
103+
}
104+
67105
pub(crate) use generic::decode_checked;
68106
pub(crate) use generic::decode_unchecked;

src/arch/generic.rs

+5-6
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
2424
/// Default check function.
2525
#[inline]
2626
pub(crate) const fn check(mut input: &[u8]) -> bool {
27-
while let [byte, rest @ ..] = input {
28-
if HEX_DECODE_LUT[*byte as usize] == NIL {
27+
while let &[byte, ref rest @ ..] = input {
28+
if HEX_DECODE_LUT[byte as usize] == NIL {
2929
return false;
3030
}
3131
input = rest;
@@ -48,8 +48,9 @@ pub(crate) unsafe fn decode_checked(input: &[u8], output: &mut [u8]) -> bool {
4848
///
4949
/// Assumes `output.len() == input.len() / 2` and that the input is valid hex.
5050
pub(crate) unsafe fn decode_unchecked(input: &[u8], output: &mut [u8]) {
51-
let r = unsafe { decode_maybe_check::<false>(input, output) };
52-
debug_assert!(r);
51+
#[allow(unused_braces)] // False positive on older rust versions.
52+
let success = unsafe { decode_maybe_check::<{ cfg!(debug_assertions) }>(input, output) };
53+
debug_assert!(success);
5354
}
5455

5556
/// Default decoding function. Checks input validity if `CHECK` is `true`, otherwise assumes it.
@@ -67,8 +68,6 @@ unsafe fn decode_maybe_check<const CHECK: bool>(input: &[u8], output: &mut [u8])
6768
if $var == NIL {
6869
return false;
6970
}
70-
} else {
71-
debug_assert_ne!($var, NIL, "invalid hex input");
7271
}
7372
};
7473
}

src/arch/portable_simd.rs

+22-8
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
use super::generic;
22
use crate::get_chars_table;
3-
use core::simd::u8x16;
3+
use core::simd::prelude::*;
44
use core::slice;
55

6-
pub(crate) const USE_CHECK_FN: bool = false;
7-
const CHUNK_SIZE: usize = core::mem::size_of::<u8x16>();
6+
type Simd = u8x16;
7+
8+
pub(crate) const USE_CHECK_FN: bool = true;
9+
const CHUNK_SIZE: usize = core::mem::size_of::<Simd>();
810

911
pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
1012
let mut i = 0;
@@ -14,18 +16,18 @@ pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
1416
unsafe { generic::encode::<UPPER>(prefix, output) };
1517
i += prefix.len() * 2;
1618

17-
let hex_table = u8x16::from_array(*get_chars_table::<UPPER>());
19+
let hex_table = Simd::from_array(*get_chars_table::<UPPER>());
1820
for &chunk in chunks {
1921
// Load input bytes and mask to nibbles.
20-
let mut lo = chunk & u8x16::splat(15);
21-
let mut hi = chunk >> u8x16::splat(4);
22+
let mut lo = chunk & Simd::splat(15);
23+
let mut hi = chunk >> Simd::splat(4);
2224

2325
// Lookup the corresponding ASCII hex digit for each nibble.
2426
lo = hex_table.swizzle_dyn(lo);
2527
hi = hex_table.swizzle_dyn(hi);
2628

2729
// Interleave the nibbles ([hi[0], lo[0], hi[1], lo[1], ...]).
28-
let (hex_lo, hex_hi) = u8x16::interleave(hi, lo);
30+
let (hex_lo, hex_hi) = Simd::interleave(hi, lo);
2931

3032
// Store result into the output buffer.
3133
// SAFETY: ensured by caller.
@@ -41,6 +43,18 @@ pub(crate) unsafe fn encode<const UPPER: bool>(input: &[u8], output: *mut u8) {
4143
unsafe { generic::encode::<UPPER>(suffix, output.add(i)) };
4244
}
4345

44-
pub(crate) use generic::check;
46+
pub(crate) fn check(input: &[u8]) -> bool {
47+
let (prefix, chunks, suffix) = input.as_simd::<CHUNK_SIZE>();
48+
generic::check(prefix)
49+
&& chunks.iter().all(|&chunk| {
50+
let valid_digit = chunk.simd_ge(Simd::splat(b'0')) & chunk.simd_le(Simd::splat(b'9'));
51+
let valid_upper = chunk.simd_ge(Simd::splat(b'A')) & chunk.simd_le(Simd::splat(b'F'));
52+
let valid_lower = chunk.simd_ge(Simd::splat(b'a')) & chunk.simd_le(Simd::splat(b'f'));
53+
let valid = valid_digit | valid_upper | valid_lower;
54+
valid.all()
55+
})
56+
&& generic::check(suffix)
57+
}
58+
4559
pub(crate) use generic::decode_checked;
4660
pub(crate) use generic::decode_unchecked;

src/arch/x86.rs

+24-32
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ pub(crate) const USE_CHECK_FN: bool = true;
1212
const CHUNK_SIZE_SSE: usize = core::mem::size_of::<__m128i>();
1313
const CHUNK_SIZE_AVX: usize = core::mem::size_of::<__m256i>();
1414

15-
const T_MASK: i32 = 65535;
16-
1715
cfg_if::cfg_if! {
1816
if #[cfg(feature = "std")] {
1917
#[inline(always)]
@@ -58,11 +56,11 @@ unsafe fn encode_ssse3<const UPPER: bool>(input: &[u8], output: *mut u8) {
5856
let input_remainder = input_chunks.remainder();
5957

6058
let mut i = 0;
61-
for input_chunk in input_chunks {
59+
for chunk in input_chunks {
6260
// Load input bytes and mask to nibbles.
63-
let input_bytes = _mm_loadu_si128(input_chunk.as_ptr().cast());
64-
let mut lo = _mm_and_si128(input_bytes, mask_lo);
65-
let mut hi = _mm_srli_epi32::<4>(_mm_and_si128(input_bytes, mask_hi));
61+
let chunk = _mm_loadu_si128(chunk.as_ptr().cast());
62+
let mut lo = _mm_and_si128(chunk, mask_lo);
63+
let mut hi = _mm_srli_epi32::<4>(_mm_and_si128(chunk, mask_hi));
6664

6765
// Lookup the corresponding ASCII hex digit for each nibble.
6866
lo = _mm_shuffle_epi8(hex_table, lo);
@@ -101,32 +99,26 @@ unsafe fn check_sse2(input: &[u8]) -> bool {
10199
let ascii_la = _mm_set1_epi8((b'a' - 1) as i8);
102100
let ascii_lf = _mm_set1_epi8((b'f' + 1) as i8);
103101

104-
let input_chunks = input.chunks_exact(CHUNK_SIZE_SSE);
105-
let input_remainder = input_chunks.remainder();
106-
for input_chunk in input_chunks {
107-
let unchecked = _mm_loadu_si128(input_chunk.as_ptr().cast());
108-
109-
let gt0 = _mm_cmpgt_epi8(unchecked, ascii_zero);
110-
let lt9 = _mm_cmplt_epi8(unchecked, ascii_nine);
111-
let valid_digit = _mm_and_si128(gt0, lt9);
112-
113-
let gtua = _mm_cmpgt_epi8(unchecked, ascii_ua);
114-
let ltuf = _mm_cmplt_epi8(unchecked, ascii_uf);
115-
116-
let gtla = _mm_cmpgt_epi8(unchecked, ascii_la);
117-
let ltlf = _mm_cmplt_epi8(unchecked, ascii_lf);
118-
119-
let valid_lower = _mm_and_si128(gtla, ltlf);
120-
let valid_upper = _mm_and_si128(gtua, ltuf);
121-
let valid_letter = _mm_or_si128(valid_lower, valid_upper);
122-
123-
let ret = _mm_movemask_epi8(_mm_or_si128(valid_digit, valid_letter));
124-
if ret != T_MASK {
125-
return false;
126-
}
127-
}
128-
129-
generic::check(input_remainder)
102+
let (prefix, chunks, suffix) = input.align_to::<__m128i>();
103+
generic::check(prefix)
104+
&& chunks.iter().all(|&chunk| {
105+
let ge0 = _mm_cmpgt_epi8(chunk, ascii_zero);
106+
let le9 = _mm_cmplt_epi8(chunk, ascii_nine);
107+
let valid_digit = _mm_and_si128(ge0, le9);
108+
109+
let geua = _mm_cmpgt_epi8(chunk, ascii_ua);
110+
let leuf = _mm_cmplt_epi8(chunk, ascii_uf);
111+
let valid_upper = _mm_and_si128(geua, leuf);
112+
113+
let gela = _mm_cmpgt_epi8(chunk, ascii_la);
114+
let lelf = _mm_cmplt_epi8(chunk, ascii_lf);
115+
let valid_lower = _mm_and_si128(gela, lelf);
116+
117+
let valid_letter = _mm_or_si128(valid_lower, valid_upper);
118+
let valid_mask = _mm_movemask_epi8(_mm_or_si128(valid_digit, valid_letter));
119+
valid_mask == 0xffff
120+
})
121+
&& generic::check(suffix)
130122
}
131123

132124
#[inline]

src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
#![cfg_attr(
2222
feature = "nightly",
2323
feature(core_intrinsics, inline_const),
24-
allow(internal_features)
24+
allow(internal_features, stable_features)
2525
)]
2626
#![cfg_attr(feature = "portable-simd", feature(portable_simd))]
2727
#![warn(

0 commit comments

Comments
 (0)