From 0f96c717924178c47491e94e802cca8a4a3a7175 Mon Sep 17 00:00:00 2001 From: Tobias Decking Date: Tue, 17 Jan 2023 21:56:34 +0100 Subject: [PATCH] Improve the floating point parser in `dec2flt`. * Remove all remaining traces of unsafe. * Put `parse_8digits` inside a loop. * Rework parsing of inf/NaN values. --- library/core/src/num/dec2flt/common.rs | 178 ++++--------------- library/core/src/num/dec2flt/decimal.rs | 65 ++++--- library/core/src/num/dec2flt/mod.rs | 7 +- library/core/src/num/dec2flt/parse.rs | 224 +++++++++++++----------- library/core/tests/num/dec2flt/parse.rs | 2 +- 5 files changed, 188 insertions(+), 288 deletions(-) diff --git a/library/core/src/num/dec2flt/common.rs b/library/core/src/num/dec2flt/common.rs index 3e8d75df37868..11a626485191c 100644 --- a/library/core/src/num/dec2flt/common.rs +++ b/library/core/src/num/dec2flt/common.rs @@ -1,165 +1,60 @@ //! Common utilities, for internal use only. -use crate::ptr; - /// Helper methods to process immutable bytes. -pub(crate) trait ByteSlice: AsRef<[u8]> { - unsafe fn first_unchecked(&self) -> u8 { - debug_assert!(!self.is_empty()); - // SAFETY: safe as long as self is not empty - unsafe { *self.as_ref().get_unchecked(0) } - } - - /// Get if the slice contains no elements. - fn is_empty(&self) -> bool { - self.as_ref().is_empty() - } - - /// Check if the slice at least `n` length. - fn check_len(&self, n: usize) -> bool { - n <= self.as_ref().len() - } - - /// Check if the first character in the slice is equal to c. - fn first_is(&self, c: u8) -> bool { - self.as_ref().first() == Some(&c) - } - - /// Check if the first character in the slice is equal to c1 or c2. - fn first_is2(&self, c1: u8, c2: u8) -> bool { - if let Some(&c) = self.as_ref().first() { c == c1 || c == c2 } else { false } - } - - /// Bounds-checked test if the first character in the slice is a digit. - fn first_isdigit(&self) -> bool { - if let Some(&c) = self.as_ref().first() { c.is_ascii_digit() } else { false } - } - - /// Check if self starts with u with a case-insensitive comparison. - fn starts_with_ignore_case(&self, u: &[u8]) -> bool { - debug_assert!(self.as_ref().len() >= u.len()); - let iter = self.as_ref().iter().zip(u.iter()); - let d = iter.fold(0, |i, (&x, &y)| i | (x ^ y)); - d == 0 || d == 32 - } - - /// Get the remaining slice after the first N elements. - fn advance(&self, n: usize) -> &[u8] { - &self.as_ref()[n..] - } - - /// Get the slice after skipping all leading characters equal c. - fn skip_chars(&self, c: u8) -> &[u8] { - let mut s = self.as_ref(); - while s.first_is(c) { - s = s.advance(1); - } - s - } - - /// Get the slice after skipping all leading characters equal c1 or c2. - fn skip_chars2(&self, c1: u8, c2: u8) -> &[u8] { - let mut s = self.as_ref(); - while s.first_is2(c1, c2) { - s = s.advance(1); - } - s - } - +pub(crate) trait ByteSlice { /// Read 8 bytes as a 64-bit integer in little-endian order. - unsafe fn read_u64_unchecked(&self) -> u64 { - debug_assert!(self.check_len(8)); - let src = self.as_ref().as_ptr() as *const u64; - // SAFETY: safe as long as self is at least 8 bytes - u64::from_le(unsafe { ptr::read_unaligned(src) }) - } + fn read_u64(&self) -> u64; - /// Try to read the next 8 bytes from the slice. - fn read_u64(&self) -> Option { - if self.check_len(8) { - // SAFETY: self must be at least 8 bytes. - Some(unsafe { self.read_u64_unchecked() }) - } else { - None - } - } - - /// Calculate the offset of slice from another. - fn offset_from(&self, other: &Self) -> isize { - other.as_ref().len() as isize - self.as_ref().len() as isize - } -} - -impl ByteSlice for [u8] {} - -/// Helper methods to process mutable bytes. -pub(crate) trait ByteSliceMut: AsMut<[u8]> { /// Write a 64-bit integer as 8 bytes in little-endian order. - unsafe fn write_u64_unchecked(&mut self, value: u64) { - debug_assert!(self.as_mut().len() >= 8); - let dst = self.as_mut().as_mut_ptr() as *mut u64; - // NOTE: we must use `write_unaligned`, since dst is not - // guaranteed to be properly aligned. Miri will warn us - // if we use `write` instead of `write_unaligned`, as expected. - // SAFETY: safe as long as self is at least 8 bytes - unsafe { - ptr::write_unaligned(dst, u64::to_le(value)); - } - } -} + fn write_u64(&mut self, value: u64); -impl ByteSliceMut for [u8] {} + /// Calculate the offset of a slice from another. + fn offset_from(&self, other: &Self) -> isize; -/// Bytes wrapper with specialized methods for ASCII characters. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub(crate) struct AsciiStr<'a> { - slc: &'a [u8], + /// Iteratively parse and consume digits from bytes. + /// Returns the same bytes with consumed digits being + /// elided. + fn parse_digits(&self, func: impl FnMut(u8)) -> &Self; } -impl<'a> AsciiStr<'a> { - pub fn new(slc: &'a [u8]) -> Self { - Self { slc } +impl ByteSlice for [u8] { + #[inline(always)] // inlining this is crucial to remove bound checks + fn read_u64(&self) -> u64 { + let mut tmp = [0; 8]; + tmp.copy_from_slice(&self[..8]); + u64::from_le_bytes(tmp) } - /// Advance the view by n, advancing it in-place to (n..). - pub unsafe fn step_by(&mut self, n: usize) -> &mut Self { - // SAFETY: safe as long n is less than the buffer length - self.slc = unsafe { self.slc.get_unchecked(n..) }; - self + #[inline(always)] // inlining this is crucial to remove bound checks + fn write_u64(&mut self, value: u64) { + self[..8].copy_from_slice(&value.to_le_bytes()) } - /// Advance the view by n, advancing it in-place to (1..). - pub unsafe fn step(&mut self) -> &mut Self { - // SAFETY: safe as long as self is not empty - unsafe { self.step_by(1) } + #[inline] + fn offset_from(&self, other: &Self) -> isize { + other.len() as isize - self.len() as isize } - /// Iteratively parse and consume digits from bytes. - pub fn parse_digits(&mut self, mut func: impl FnMut(u8)) { - while let Some(&c) = self.as_ref().first() { + #[inline] + fn parse_digits(&self, mut func: impl FnMut(u8)) -> &Self { + let mut s = self; + + // FIXME: Can't use s.split_first() here yet, + // see https://github.com/rust-lang/rust/issues/109328 + while let [c, s_next @ ..] = s { let c = c.wrapping_sub(b'0'); if c < 10 { func(c); - // SAFETY: self cannot be empty - unsafe { - self.step(); - } + s = s_next; } else { break; } } - } -} -impl<'a> AsRef<[u8]> for AsciiStr<'a> { - #[inline] - fn as_ref(&self) -> &[u8] { - self.slc + s } } -impl<'a> ByteSlice for AsciiStr<'a> {} - /// Determine if 8 bytes are all decimal digits. /// This does not care about the order in which the bytes were loaded. pub(crate) fn is_8digits(v: u64) -> bool { @@ -168,19 +63,6 @@ pub(crate) fn is_8digits(v: u64) -> bool { (a | b) & 0x8080_8080_8080_8080 == 0 } -/// Iteratively parse and consume digits from bytes. -pub(crate) fn parse_digits(s: &mut &[u8], mut f: impl FnMut(u8)) { - while let Some(&c) = s.get(0) { - let c = c.wrapping_sub(b'0'); - if c < 10 { - f(c); - *s = s.advance(1); - } else { - break; - } - } -} - /// A custom 64-bit floating point type, representing `f * 2^e`. /// e is biased, so it be directly shifted into the exponent bits. #[derive(Debug, Copy, Clone, PartialEq, Eq, Default)] diff --git a/library/core/src/num/dec2flt/decimal.rs b/library/core/src/num/dec2flt/decimal.rs index 2019f71e69b8c..350f64bb4f7a3 100644 --- a/library/core/src/num/dec2flt/decimal.rs +++ b/library/core/src/num/dec2flt/decimal.rs @@ -9,7 +9,7 @@ //! algorithm can be found in "ParseNumberF64 by Simple Decimal Conversion", //! available online: . -use crate::num::dec2flt::common::{is_8digits, parse_digits, ByteSlice, ByteSliceMut}; +use crate::num::dec2flt::common::{is_8digits, ByteSlice}; #[derive(Clone)] pub struct Decimal { @@ -205,29 +205,32 @@ impl Decimal { pub fn parse_decimal(mut s: &[u8]) -> Decimal { let mut d = Decimal::default(); let start = s; - s = s.skip_chars(b'0'); - parse_digits(&mut s, |digit| d.try_add_digit(digit)); - if s.first_is(b'.') { - s = s.advance(1); + + while let Some((&b'0', s_next)) = s.split_first() { + s = s_next; + } + + s = s.parse_digits(|digit| d.try_add_digit(digit)); + + if let Some((b'.', s_next)) = s.split_first() { + s = s_next; let first = s; // Skip leading zeros. if d.num_digits == 0 { - s = s.skip_chars(b'0'); + while let Some((&b'0', s_next)) = s.split_first() { + s = s_next; + } } while s.len() >= 8 && d.num_digits + 8 < Decimal::MAX_DIGITS { - // SAFETY: s is at least 8 bytes. - let v = unsafe { s.read_u64_unchecked() }; + let v = s.read_u64(); if !is_8digits(v) { break; } - // SAFETY: d.num_digits + 8 is less than d.digits.len() - unsafe { - d.digits[d.num_digits..].write_u64_unchecked(v - 0x3030_3030_3030_3030); - } + d.digits[d.num_digits..].write_u64(v - 0x3030_3030_3030_3030); d.num_digits += 8; - s = s.advance(8); + s = &s[8..]; } - parse_digits(&mut s, |digit| d.try_add_digit(digit)); + s = s.parse_digits(|digit| d.try_add_digit(digit)); d.decimal_point = s.len() as i32 - first.len() as i32; } if d.num_digits != 0 { @@ -248,22 +251,26 @@ pub fn parse_decimal(mut s: &[u8]) -> Decimal { d.num_digits = Decimal::MAX_DIGITS; } } - if s.first_is2(b'e', b'E') { - s = s.advance(1); - let mut neg_exp = false; - if s.first_is(b'-') { - neg_exp = true; - s = s.advance(1); - } else if s.first_is(b'+') { - s = s.advance(1); - } - let mut exp_num = 0_i32; - parse_digits(&mut s, |digit| { - if exp_num < 0x10000 { - exp_num = 10 * exp_num + digit as i32; + if let Some((&ch, s_next)) = s.split_first() { + if ch == b'e' || ch == b'E' { + s = s_next; + let mut neg_exp = false; + if let Some((&ch, s_next)) = s.split_first() { + neg_exp = ch == b'-'; + if ch == b'-' || ch == b'+' { + s = s_next; + } } - }); - d.decimal_point += if neg_exp { -exp_num } else { exp_num }; + let mut exp_num = 0_i32; + + s.parse_digits(|digit| { + if exp_num < 0x10000 { + exp_num = 10 * exp_num + digit as i32; + } + }); + + d.decimal_point += if neg_exp { -exp_num } else { exp_num }; + } } for i in d.num_digits..Decimal::MAX_DIGITS_WITHOUT_OVERFLOW { d.digits[i] = 0; diff --git a/library/core/src/num/dec2flt/mod.rs b/library/core/src/num/dec2flt/mod.rs index 58ffb950ad86b..a4bc8b1c9b0c3 100644 --- a/library/core/src/num/dec2flt/mod.rs +++ b/library/core/src/num/dec2flt/mod.rs @@ -79,7 +79,7 @@ use crate::error::Error; use crate::fmt; use crate::str::FromStr; -use self::common::{BiasedFp, ByteSlice}; +use self::common::BiasedFp; use self::float::RawFloat; use self::lemire::compute_float; use self::parse::{parse_inf_nan, parse_number}; @@ -238,17 +238,18 @@ pub fn dec2flt(s: &str) -> Result { }; let negative = c == b'-'; if c == b'-' || c == b'+' { - s = s.advance(1); + s = &s[1..]; } if s.is_empty() { return Err(pfe_invalid()); } - let num = match parse_number(s, negative) { + let mut num = match parse_number(s) { Some(r) => r, None if let Some(value) = parse_inf_nan(s, negative) => return Ok(value), None => return Err(pfe_invalid()), }; + num.negative = negative; if let Some(value) = num.try_fast_path::() { return Ok(value); } diff --git a/library/core/src/num/dec2flt/parse.rs b/library/core/src/num/dec2flt/parse.rs index 1a90e0d206fd7..b0a23835c5bd4 100644 --- a/library/core/src/num/dec2flt/parse.rs +++ b/library/core/src/num/dec2flt/parse.rs @@ -1,6 +1,6 @@ //! Functions to parse floating-point numbers. -use crate::num::dec2flt::common::{is_8digits, AsciiStr, ByteSlice}; +use crate::num::dec2flt::common::{is_8digits, ByteSlice}; use crate::num::dec2flt::float::RawFloat; use crate::num::dec2flt::number::Number; @@ -26,24 +26,39 @@ fn parse_8digits(mut v: u64) -> u64 { } /// Parse digits until a non-digit character is found. -fn try_parse_digits(s: &mut AsciiStr<'_>, x: &mut u64) { +fn try_parse_digits(mut s: &[u8], mut x: u64) -> (&[u8], u64) { // may cause overflows, to be handled later - s.parse_digits(|digit| { - *x = x.wrapping_mul(10).wrapping_add(digit as _); + + while s.len() >= 8 { + let num = s.read_u64(); + if is_8digits(num) { + x = x.wrapping_mul(1_0000_0000).wrapping_add(parse_8digits(num)); + s = &s[8..]; + } else { + break; + } + } + + s = s.parse_digits(|digit| { + x = x.wrapping_mul(10).wrapping_add(digit as _); }); + + (s, x) } /// Parse up to 19 digits (the max that can be stored in a 64-bit integer). -fn try_parse_19digits(s: &mut AsciiStr<'_>, x: &mut u64) { +fn try_parse_19digits(s_ref: &mut &[u8], x: &mut u64) { + let mut s = *s_ref; + while *x < MIN_19DIGIT_INT { - if let Some(&c) = s.as_ref().first() { + // FIXME: Can't use s.split_first() here yet, + // see https://github.com/rust-lang/rust/issues/109328 + if let [c, s_next @ ..] = s { let digit = c.wrapping_sub(b'0'); + if digit < 10 { *x = (*x * 10) + digit as u64; // no overflows here - // SAFETY: cannot be empty - unsafe { - s.step(); - } + s = s_next; } else { break; } @@ -51,46 +66,26 @@ fn try_parse_19digits(s: &mut AsciiStr<'_>, x: &mut u64) { break; } } -} -/// Try to parse 8 digits at a time, using an optimized algorithm. -fn try_parse_8digits(s: &mut AsciiStr<'_>, x: &mut u64) { - // may cause overflows, to be handled later - if let Some(v) = s.read_u64() { - if is_8digits(v) { - *x = x.wrapping_mul(1_0000_0000).wrapping_add(parse_8digits(v)); - // SAFETY: already ensured the buffer was >= 8 bytes in read_u64. - unsafe { - s.step_by(8); - } - if let Some(v) = s.read_u64() { - if is_8digits(v) { - *x = x.wrapping_mul(1_0000_0000).wrapping_add(parse_8digits(v)); - // SAFETY: already ensured the buffer was >= 8 bytes in try_read_u64. - unsafe { - s.step_by(8); - } - } - } - } - } + *s_ref = s; } /// Parse the scientific notation component of a float. -fn parse_scientific(s: &mut AsciiStr<'_>) -> Option { - let mut exponent = 0_i64; +fn parse_scientific(s_ref: &mut &[u8]) -> Option { + let mut exponent = 0i64; let mut negative = false; - if let Some(&c) = s.as_ref().get(0) { + + let mut s = *s_ref; + + if let Some((&c, s_next)) = s.split_first() { negative = c == b'-'; if c == b'-' || c == b'+' { - // SAFETY: s cannot be empty - unsafe { - s.step(); - } + s = s_next; } } - if s.first_isdigit() { - s.parse_digits(|digit| { + + if matches!(s.first(), Some(&x) if x.is_ascii_digit()) { + *s_ref = s.parse_digits(|digit| { // no overflows here, saturate well before overflow if exponent < 0x10000 { exponent = 10 * exponent + digit as i64; @@ -98,6 +93,7 @@ fn parse_scientific(s: &mut AsciiStr<'_>) -> Option { }); if negative { Some(-exponent) } else { Some(exponent) } } else { + *s_ref = s; None } } @@ -106,28 +102,29 @@ fn parse_scientific(s: &mut AsciiStr<'_>) -> Option { /// /// This creates a representation of the float as the /// significant digits and the decimal exponent. -fn parse_partial_number(s: &[u8], negative: bool) -> Option<(Number, usize)> { - let mut s = AsciiStr::new(s); - let start = s; +fn parse_partial_number(mut s: &[u8]) -> Option<(Number, usize)> { debug_assert!(!s.is_empty()); // parse initial digits before dot let mut mantissa = 0_u64; - let digits_start = s; - try_parse_digits(&mut s, &mut mantissa); - let mut n_digits = s.offset_from(&digits_start); + let start = s; + let tmp = try_parse_digits(s, mantissa); + s = tmp.0; + mantissa = tmp.1; + let mut n_digits = s.offset_from(start); // handle dot with the following digits let mut n_after_dot = 0; let mut exponent = 0_i64; let int_end = s; - if s.first_is(b'.') { - // SAFETY: s cannot be empty due to first_is - unsafe { s.step() }; + + if let Some((&b'.', s_next)) = s.split_first() { + s = s_next; let before = s; - try_parse_8digits(&mut s, &mut mantissa); - try_parse_digits(&mut s, &mut mantissa); - n_after_dot = s.offset_from(&before); + let tmp = try_parse_digits(s, mantissa); + s = tmp.0; + mantissa = tmp.1; + n_after_dot = s.offset_from(before); exponent = -n_after_dot as i64; } @@ -138,65 +135,60 @@ fn parse_partial_number(s: &[u8], negative: bool) -> Option<(Number, usize)> { // handle scientific format let mut exp_number = 0_i64; - if s.first_is2(b'e', b'E') { - // SAFETY: s cannot be empty - unsafe { - s.step(); + if let Some((&c, s_next)) = s.split_first() { + if c == b'e' || c == b'E' { + s = s_next; + // If None, we have no trailing digits after exponent, or an invalid float. + exp_number = parse_scientific(&mut s)?; + exponent += exp_number; } - // If None, we have no trailing digits after exponent, or an invalid float. - exp_number = parse_scientific(&mut s)?; - exponent += exp_number; } - let len = s.offset_from(&start) as _; + let len = s.offset_from(start) as _; // handle uncommon case with many digits if n_digits <= 19 { - return Some((Number { exponent, mantissa, negative, many_digits: false }, len)); + return Some((Number { exponent, mantissa, negative: false, many_digits: false }, len)); } n_digits -= 19; let mut many_digits = false; - let mut p = digits_start; - while p.first_is2(b'0', b'.') { - // SAFETY: p cannot be empty due to first_is2 - unsafe { - // '0' = b'.' + 2 - n_digits -= p.first_unchecked().saturating_sub(b'0' - 1) as isize; - p.step(); + let mut p = start; + while let Some((&c, p_next)) = p.split_first() { + if c == b'.' || c == b'0' { + n_digits -= c.saturating_sub(b'0' - 1) as isize; + p = p_next; + } else { + break; } } if n_digits > 0 { // at this point we have more than 19 significant digits, let's try again many_digits = true; mantissa = 0; - let mut s = digits_start; + let mut s = start; try_parse_19digits(&mut s, &mut mantissa); exponent = if mantissa >= MIN_19DIGIT_INT { // big int - int_end.offset_from(&s) + int_end.offset_from(s) } else { - // SAFETY: the next byte must be present and be '.' - // We know this is true because we had more than 19 - // digits previously, so we overflowed a 64-bit integer, - // but parsing only the integral digits produced less - // than 19 digits. That means we must have a decimal - // point, and at least 1 fractional digit. - unsafe { s.step() }; + s = &s[1..]; let before = s; try_parse_19digits(&mut s, &mut mantissa); - -s.offset_from(&before) + -s.offset_from(before) } as i64; // add back the explicit part exponent += exp_number; } - Some((Number { exponent, mantissa, negative, many_digits }, len)) + Some((Number { exponent, mantissa, negative: false, many_digits }, len)) } -/// Try to parse a non-special floating point number. -pub fn parse_number(s: &[u8], negative: bool) -> Option { - if let Some((float, rest)) = parse_partial_number(s, negative) { +/// Try to parse a non-special floating point number, +/// as well as two slices with integer and fractional parts +/// and the parsed exponent. +pub fn parse_number(s: &[u8]) -> Option { + if let Some((float, rest)) = parse_partial_number(s) { if rest == s.len() { return Some(float); } @@ -204,30 +196,48 @@ pub fn parse_number(s: &[u8], negative: bool) -> Option { None } -/// Parse a partial representation of a special, non-finite float. -fn parse_partial_inf_nan(s: &[u8]) -> Option<(F, usize)> { - fn parse_inf_rest(s: &[u8]) -> usize { - if s.len() >= 8 && s[3..].as_ref().starts_with_ignore_case(b"inity") { 8 } else { 3 } - } - if s.len() >= 3 { - if s.starts_with_ignore_case(b"nan") { - return Some((F::NAN, 3)); - } else if s.starts_with_ignore_case(b"inf") { - return Some((F::INFINITY, parse_inf_rest(s))); - } - } - None -} - /// Try to parse a special, non-finite float. -pub fn parse_inf_nan(s: &[u8], negative: bool) -> Option { - if let Some((mut float, rest)) = parse_partial_inf_nan::(s) { - if rest == s.len() { - if negative { - float = -float; - } - return Some(float); - } +pub(crate) fn parse_inf_nan(s: &[u8], negative: bool) -> Option { + // Since a valid string has at most the length 8, we can load + // all relevant characters into a u64 and work from there. + // This also generates much better code. + + let mut register; + let len: usize; + + // All valid strings are either of length 8 or 3. + if s.len() == 8 { + register = s.read_u64(); + len = 8; + } else if s.len() == 3 { + let a = s[0] as u64; + let b = s[1] as u64; + let c = s[2] as u64; + register = (c << 16) | (b << 8) | a; + len = 3; + } else { + return None; } - None + + // Clear out the bits which turn ASCII uppercase characters into + // lowercase characters. The resulting string is all uppercase. + // What happens to other characters is irrelevant. + register &= 0xDFDFDFDFDFDFDFDF; + + // u64 values corresponding to relevant cases + const INF_3: u64 = 0x464E49; // "INF" + const INF_8: u64 = 0x5954494E49464E49; // "INFINITY" + const NAN: u64 = 0x4E414E; // "NAN" + + // Match register value to constant to parse string. + // Also match on the string length to catch edge cases + // like "inf\0\0\0\0\0". + let float = match (register, len) { + (INF_3, 3) => F::INFINITY, + (INF_8, 8) => F::INFINITY, + (NAN, 3) => F::NAN, + _ => return None, + }; + + if negative { Some(-float) } else { Some(float) } } diff --git a/library/core/tests/num/dec2flt/parse.rs b/library/core/tests/num/dec2flt/parse.rs index edc77377d5820..4a5d24ba7d5fa 100644 --- a/library/core/tests/num/dec2flt/parse.rs +++ b/library/core/tests/num/dec2flt/parse.rs @@ -32,7 +32,7 @@ fn invalid_chars() { } fn parse_positive(s: &[u8]) -> Option { - parse_number(s, false) + parse_number(s) } #[test]