Fix Display for invalid UTF-8 in OsStr/Path #136677

Open · wants to merge 4 commits into master
13 changes: 10 additions & 3 deletions library/core/src/fmt/mod.rs
@@ -1516,8 +1516,11 @@ unsafe fn getcount(args: &[rt::Argument<'_>], cnt: &rt::Count) -> Option<u16> {
}

/// Padding after the end of something. Returned by `Formatter::padding`.
#[doc(hidden)]
#[must_use = "don't forget to write the post padding"]
pub(crate) struct PostPadding {
#[unstable(feature = "fmt_internals", reason = "internal to standard library", issue = "none")]
#[derive(Debug)]
pub struct PostPadding {
fill: char,
padding: u16,
}
@@ -1528,7 +1531,9 @@ impl PostPadding {
}

/// Writes this post padding.
pub(crate) fn write(self, f: &mut Formatter<'_>) -> Result {
#[doc(hidden)]
#[unstable(feature = "fmt_internals", reason = "internal to standard library", issue = "none")]
pub fn write(self, f: &mut Formatter<'_>) -> Result {
for _ in 0..self.padding {
f.buf.write_char(self.fill)?;
}
@@ -1738,7 +1743,9 @@ impl<'a> Formatter<'a> {
///
/// Callers are responsible for ensuring post-padding is written after the
/// thing that is being padded.
pub(crate) fn padding(
#[doc(hidden)]
#[unstable(feature = "fmt_internals", reason = "internal to standard library", issue = "none")]
pub fn padding(
&mut self,
padding: u16,
default: Alignment,
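For context, a minimal sketch of how a `Display` implementation inside core/std could use the now-`pub` (but still `fmt_internals`-gated) `Formatter::padding` and `PostPadding::write`, mirroring the `Slice::fmt` change later in this diff. The type `Token`, the literal `"abc"`, and the crate-level feature gates are assumptions for illustration, not part of the PR:

#![feature(fmt_internals, formatting_options)] // nightly-only, crate root
use core::fmt::{self, Alignment};

struct Token; // hypothetical stand-in for OsStr's platform `Slice`

impl fmt::Display for Token {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let s = "abc";
        let char_count = s.chars().count() as u16;
        match f.options().get_width() {
            // Under the minimum width: write pre-padding, the payload, then
            // the deferred post-padding handed back by `Formatter::padding`.
            Some(width) if char_count < width => {
                let post_padding = f.padding(width - char_count, Alignment::Left)?;
                f.write_str(s)?;
                post_padding.write(f)
            }
            // No width, or already wide enough: just write the payload.
            _ => f.write_str(s),
        }
    }
}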
161 changes: 84 additions & 77 deletions library/core/src/str/lossy.rs
@@ -1,8 +1,8 @@
use super::from_utf8_unchecked;
use super::validations::utf8_char_width;
use crate::fmt;
use crate::fmt::{Formatter, Write};
use crate::iter::FusedIterator;
use crate::{fmt, slice};

impl [u8] {
/// Creates an iterator over the contiguous valid UTF-8 ranges of this
@@ -152,7 +152,7 @@ impl fmt::Debug for Debug<'_> {
///
/// See the [`Utf8Chunk`] type for documentation of the items yielded by this iterator.
///
/// [byteslice]: slice
/// [byteslice]: prim@slice
/// [`from_utf8`]: super::from_utf8
///
/// # Examples
@@ -197,86 +197,29 @@ impl<'a> Iterator for Utf8Chunks<'a> {
return None;
}

const TAG_CONT_U8: u8 = 128;
fn safe_get(xs: &[u8], i: usize) -> u8 {
*xs.get(i).unwrap_or(&0)
}

let mut i = 0;
let mut valid_up_to = 0;
while i < self.source.len() {
// SAFETY: `i < self.source.len()` per previous line.
// For some reason the following are both significantly slower:
// while let Some(&byte) = self.source.get(i) {
// while let Some(byte) = self.source.get(i).copied() {
let byte = unsafe { *self.source.get_unchecked(i) };
i += 1;

if byte < 128 {
// This could be a `1 => ...` case in the match below, but for
// the common case of all-ASCII inputs, we bypass loading the
// sizeable UTF8_CHAR_WIDTH table into cache.
} else {
let w = utf8_char_width(byte);

match w {
2 => {
if safe_get(self.source, i) & 192 != TAG_CONT_U8 {
break;
}
i += 1;
}
3 => {
match (byte, safe_get(self.source, i)) {
(0xE0, 0xA0..=0xBF) => (),
(0xE1..=0xEC, 0x80..=0xBF) => (),
(0xED, 0x80..=0x9F) => (),
(0xEE..=0xEF, 0x80..=0xBF) => (),
_ => break,
}
i += 1;
if safe_get(self.source, i) & 192 != TAG_CONT_U8 {
break;
}
i += 1;
}
4 => {
match (byte, safe_get(self.source, i)) {
(0xF0, 0x90..=0xBF) => (),
(0xF1..=0xF3, 0x80..=0xBF) => (),
(0xF4, 0x80..=0x8F) => (),
_ => break,
}
i += 1;
if safe_get(self.source, i) & 192 != TAG_CONT_U8 {
break;
}
i += 1;
if safe_get(self.source, i) & 192 != TAG_CONT_U8 {
break;
}
i += 1;
}
_ => break,
}
let mut iter = self.source.iter();
let mut len_after_valid = iter.len();
while !iter.is_empty() {
if !advance_utf8(&mut iter) {
// Stop at the first invalid sequence.
break;
}

valid_up_to = i;
len_after_valid = iter.len();
}
let valid_up_to = self.source.len() - len_after_valid;
let inspected_len = self.source.len() - iter.len();

// SAFETY: `i <= self.source.len()` because it is only ever incremented
// via `i += 1` and in between every single one of those increments, `i`
// is compared against `self.source.len()`. That happens either
// literally by `i < self.source.len()` in the while-loop's condition,
// or indirectly by `safe_get(self.source, i) & 192 != TAG_CONT_U8`. The
// loop is terminated as soon as the latest `i += 1` has made `i` no
// longer less than `self.source.len()`, which means it'll be at most
// equal to `self.source.len()`.
let (inspected, remaining) = unsafe { self.source.split_at_unchecked(i) };
// SAFETY: The length of the remaining bytes in `iter` only decreases,
// so `iter.len() <= self.source.len()`. The length of inspected bytes,
// `self.source.len() - iter.len()`, then only increases and can be at
// most `self.source.len()`.
let (inspected, remaining) = unsafe { self.source.split_at_unchecked(inspected_len) };
self.source = remaining;

// SAFETY: `valid_up_to <= i` because it is only ever assigned via
// `valid_up_to = i` and `i` only increases.
// SAFETY: Since `iter.len()` only decreases and `len_after_valid` is
// the value of `iter.len()` from the previous iteration, it follows
// that `len_after_valid <= iter.len()`, which is equivalent to
// `valid_up_to <= inspected_len` by simple substitution.
let (valid, invalid) = unsafe { inspected.split_at_unchecked(valid_up_to) };

Some(Utf8Chunk {
@@ -296,3 +239,67 @@ impl fmt::Debug for Utf8Chunks<'_> {
f.debug_struct("Utf8Chunks").field("source", &self.debug()).finish()
}
}

/// Advances the byte iterator past one UTF-8-encoded Unicode scalar value,
/// allowing invalid UTF-8. If the bytes at the current position do not form a
/// valid sequence, the maximal prefix of a valid UTF-8 code unit sequence is
/// consumed instead. Returns whether the consumed bytes form a valid Unicode
/// scalar value.
#[doc(hidden)]
#[unstable(feature = "str_internals", issue = "none")]
#[inline]
pub fn advance_utf8(bytes: &mut slice::Iter<'_, u8>) -> bool {
const TAG_CONT_U8: u8 = 128;
#[inline]
fn peek(bytes: &slice::Iter<'_, u8>) -> u8 {
*bytes.clone().next().unwrap_or(&0)
}

let Some(&byte) = bytes.next() else { return false };
if byte < 128 {
// This could be a `1 => ...` case in the match below, but for the
// common case of all-ASCII inputs, we bypass loading the sizeable
// UTF8_CHAR_WIDTH table into cache.
} else {
match utf8_char_width(byte) {
2 => {
if peek(bytes) & 192 != TAG_CONT_U8 {
return false;
}
bytes.next();
}
3 => {
match (byte, peek(bytes)) {
(0xE0, 0xA0..=0xBF) => {}
(0xE1..=0xEC, 0x80..=0xBF) => {}
(0xED, 0x80..=0x9F) => {}
(0xEE..=0xEF, 0x80..=0xBF) => {}
_ => return false,
}
bytes.next();
if peek(bytes) & 192 != TAG_CONT_U8 {
return false;
}
bytes.next();
}
4 => {
match (byte, peek(bytes)) {
(0xF0, 0x90..=0xBF) => {}
(0xF1..=0xF3, 0x80..=0xBF) => {}
(0xF4, 0x80..=0x8F) => {}
_ => return false,
}
bytes.next();
if peek(bytes) & 192 != TAG_CONT_U8 {
return false;
}
bytes.next();
if peek(bytes) & 192 != TAG_CONT_U8 {
return false;
}
bytes.next();
}
_ => return false,
}
}
true
}
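As a usage sketch (not part of the diff), the following shows how `advance_utf8` composes with a raw byte iterator, and how the rewritten `Utf8Chunks` splits the same input. The helper name `count_lossy_chars` and the byte string are illustrative; the `str_internals` gate is required on nightly:

#![feature(str_internals)]
use core::str::advance_utf8;

// Hypothetical helper: counts scalar values, with each maximal invalid
// prefix collapsing to a single (replacement) character.
fn count_lossy_chars(bytes: &[u8]) -> usize {
    let mut iter = bytes.iter();
    let mut count = 0;
    while !iter.as_slice().is_empty() {
        // Consumes one scalar value, or the maximal prefix of a valid
        // sequence when the input is invalid; either way it is one "char".
        advance_utf8(&mut iter);
        count += 1;
    }
    count
}

fn main() {
    // 'b', 'é' (two bytes), a lone continuation byte, then 'd'.
    let bytes = b"b\xC3\xA9\x80d";
    assert_eq!(count_lossy_chars(bytes), 4);

    // The rewritten `Utf8Chunks::next` yields ("bé", [0x80]) and then ("d", []).
    let mut chunks = bytes.utf8_chunks();
    let first = chunks.next().unwrap();
    assert_eq!(first.valid(), "bé");
    assert_eq!(first.invalid(), &[0x80]);
    let second = chunks.next().unwrap();
    assert_eq!(second.valid(), "d");
    assert!(second.invalid().is_empty());
    assert!(chunks.next().is_none());
}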
4 changes: 3 additions & 1 deletion library/core/src/str/mod.rs
@@ -10,6 +10,7 @@ mod converts;
mod count;
mod error;
mod iter;
mod lossy;
mod traits;
mod validations;

@@ -21,7 +22,6 @@ use crate::{ascii, mem};

pub mod pattern;

mod lossy;
#[unstable(feature = "str_from_raw_parts", issue = "119206")]
pub use converts::{from_raw_parts, from_raw_parts_mut};
#[stable(feature = "rust1", since = "1.0.0")]
@@ -52,6 +52,8 @@ pub use iter::{Matches, RMatches};
pub use iter::{RSplit, RSplitTerminator, Split, SplitTerminator};
#[stable(feature = "rust1", since = "1.0.0")]
pub use iter::{RSplitN, SplitN};
#[unstable(feature = "str_internals", issue = "none")]
pub use lossy::advance_utf8;
#[stable(feature = "utf8_chunks", since = "1.79.0")]
pub use lossy::{Utf8Chunk, Utf8Chunks};
#[stable(feature = "rust1", since = "1.0.0")]
29 changes: 29 additions & 0 deletions library/std/src/ffi/os_str/tests.rs
@@ -105,6 +105,35 @@ fn test_os_string_join() {
assert_eq!("a b c", strings_abc.join(OsStr::new(" ")));
}

#[test]
fn display() {
let os_string = OsString::from("bcd");
assert_eq!(format!("a{:^10}e", os_string.display()), "a bcd e");
}

#[cfg(unix)]
#[test]
fn display_invalid_utf8_unix() {
use crate::os::unix::ffi::OsStringExt;

let os_string = OsString::from_vec(b"b\xFFd".to_vec());
assert_eq!(format!("a{:^10}e", os_string.display()), "a b�d e");
assert_eq!(format!("a{:^10}e", os_string.as_os_str().display()), "a b�d e");
let os_string = OsString::from_vec(b"b\xE1\xBAd".to_vec());
assert_eq!(format!("a{:^10}e", os_string.display()), "a b�d e");
assert_eq!(format!("a{:^10}e", os_string.as_os_str().display()), "a b�d e");
}

#[cfg(windows)]
#[test]
fn display_invalid_wtf8_windows() {
use crate::os::windows::ffi::OsStringExt;

let os_string = OsString::from_wide(&[b'b' as _, 0xD800, b'd' as _]);
assert_eq!(format!("a{:^10}e", os_string.display()), "a b�d e");
assert_eq!(format!("a{:^10}e", os_string.as_os_str().display()), "a b�d e");
}

#[test]
fn test_os_string_default() {
let os_string: OsString = Default::default();
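Only `OsStr`/`OsString` tests are shown here, but the PR title covers `Path` as well. Assuming `Path::display` continues to forward to the same platform `Slice` formatting, an analogous (hypothetical, Unix-only) check would look like:

#[cfg(unix)]
#[test]
fn display_invalid_utf8_path_unix() {
    use crate::ffi::OsStr;
    use crate::os::unix::ffi::OsStrExt;
    use crate::path::Path;

    let path = Path::new(OsStr::from_bytes(b"b\xFFd"));
    // Same padded, lossy output as the OsStr assertions above.
    assert_eq!(format!("a{:^10}e", path.display()), "a   b�d    e");
}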
1 change: 1 addition & 0 deletions library/std/src/lib.rs
@@ -298,6 +298,7 @@
#![feature(formatting_options)]
#![feature(if_let_guard)]
#![feature(intra_doc_pointers)]
#![feature(iter_advance_by)]
#![feature(lang_items)]
#![feature(let_chains)]
#![feature(link_cfg)]
71 changes: 56 additions & 15 deletions library/std/src/sys/os_str/bytes.rs
@@ -2,6 +2,7 @@
//! systems: just a `Vec<u8>`/`[u8]`.

use core::clone::CloneToUninit;
use core::str::advance_utf8;

use crate::borrow::Cow;
use crate::collections::TryReserveError;
@@ -64,25 +65,37 @@ impl fmt::Debug for Slice {

impl fmt::Display for Slice {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
// If we're the empty string then our iterator won't actually yield
// anything, so perform the formatting manually
if self.inner.is_empty() {
return "".fmt(f);
// Corresponds to `Formatter::pad`, but for `OsStr` instead of `str`.

// Make sure there's a fast path up front.
if f.options().get_width().is_none() && f.options().get_precision().is_none() {
return self.write_lossy(f);
}

for chunk in self.inner.utf8_chunks() {
let valid = chunk.valid();
// If we successfully decoded the whole chunk as a valid string then
// we can return a direct formatting of the string which will also
// respect various formatting flags if possible.
if chunk.invalid().is_empty() {
return valid.fmt(f);
}
// The `precision` field can be interpreted as a maximum width for the
// string being formatted.
let max_char_count = f.options().get_precision().unwrap_or(u16::MAX);
let (truncated, char_count) = truncate_chars(&self.inner, max_char_count as usize);

// If our string is longer than the maximum width, truncate it and
// handle other flags in terms of the truncated string.
// SAFETY: The truncation splits at Unicode scalar value boundaries.
let s = unsafe { Slice::from_encoded_bytes_unchecked(truncated) };

f.write_str(valid)?;
f.write_char(char::REPLACEMENT_CHARACTER)?;
// The `width` field is more of a minimum width parameter at this point.
if let Some(width) = f.options().get_width()
&& char_count < width as usize
{
// If we're under the minimum width, then fill up the minimum width
// with the specified string + some alignment.
let post_padding = f.padding(width - char_count as u16, fmt::Alignment::Left)?;
s.write_lossy(f)?;
post_padding.write(f)
} else {
// If we're over the minimum width or there is no minimum width, we
// can just emit the string.
s.write_lossy(f)
}
Ok(())
}
}

@@ -302,6 +315,18 @@ impl Slice {
String::from_utf8_lossy(&self.inner)
}

/// Writes the string as lossy UTF-8 like [`String::from_utf8_lossy`].
/// It ignores formatter flags.
fn write_lossy(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
for chunk in self.inner.utf8_chunks() {
f.write_str(chunk.valid())?;
if !chunk.invalid().is_empty() {
f.write_char(char::REPLACEMENT_CHARACTER)?;
}
}
Ok(())
}

#[inline]
pub fn to_owned(&self) -> Buf {
Buf { inner: self.inner.to_vec() }
@@ -376,3 +401,19 @@ unsafe impl CloneToUninit for Slice {
unsafe { self.inner.clone_to_uninit(dst) }
}
}

/// Counts the number of Unicode scalar values in the byte string, allowing
/// invalid UTF-8 sequences. For invalid sequences, the maximal prefix of a
/// valid UTF-8 code unit sequence counts as one. At most `max_chars` scalar
/// values are scanned. Returns the truncated byte slice and the number of
/// scalar values it contains.
fn truncate_chars(bytes: &[u8], max_chars: usize) -> (&[u8], usize) {
let mut iter = bytes.iter();
let mut char_count = 0;
while !iter.is_empty() && char_count < max_chars {
advance_utf8(&mut iter);
char_count += 1;
}
let byte_len = bytes.len() - iter.len();
// SAFETY: `iter` borrows from `bytes` and only shrinks, so
// `byte_len <= bytes.len()` and the range is in bounds.
let truncated = unsafe { bytes.get_unchecked(..byte_len) };
(truncated, char_count)
}
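Finally, a hypothetical illustration of the precision path that `truncate_chars` enables; the values are made up, and the expected output assumes the truncation and centering behavior implemented in this diff:

use std::ffi::OsStr;

fn main() {
    let s = OsStr::new("bcdef");
    // `.3` caps the output at three scalar values ("bcd"); `^7` then centers
    // those three characters within a minimum width of seven.
    assert_eq!(format!("[{:^7.3}]", s.display()), "[  bcd  ]");
}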