From d06273c81a90fb7d5cea2d00f4f94e453b7eb68c Mon Sep 17 00:00:00 2001 From: Jed Denlea Date: Fri, 20 Sep 2024 03:00:46 -0700 Subject: [PATCH] Tighten JsonPointer and methods (#613) --- crates/core/Cargo.toml | 5 +- crates/core/src/json_pointer.rs | 234 ++++++++++++++++---------------- 2 files changed, 124 insertions(+), 115 deletions(-) diff --git a/crates/core/Cargo.toml b/crates/core/Cargo.toml index 5c726c89d..c32e962d6 100644 --- a/crates/core/Cargo.toml +++ b/crates/core/Cargo.toml @@ -12,4 +12,7 @@ documentation = "https://docs.rs/ssi-core/" thiserror.workspace = true async-trait.workspace = true serde = { workspace = true, features = ["derive"] } -pin-project.workspace = true \ No newline at end of file +pin-project.workspace = true + +[dev-dependencies] +serde_json.workspace = true diff --git a/crates/core/src/json_pointer.rs b/crates/core/src/json_pointer.rs index 6812022a8..32c7b2454 100644 --- a/crates/core/src/json_pointer.rs +++ b/crates/core/src/json_pointer.rs @@ -1,9 +1,5 @@ -use core::fmt; -use std::{ - borrow::{Borrow, Cow}, - ops::Deref, - str::FromStr, -}; +use core::{fmt, ops::Deref, str::FromStr}; +use std::borrow::{Borrow, Cow}; use serde::{Deserialize, Serialize}; @@ -12,9 +8,11 @@ use crate::BytesBuf; #[macro_export] macro_rules! json_pointer { ($value:literal) => { - match $crate::JsonPointer::from_str_const($value) { - Ok(p) => p, - Err(_) => panic!("invalid JSON pointer"), + const { + match $crate::JsonPointer::from_str_const($value) { + Ok(p) => p, + Err(_) => panic!("invalid JSON pointer"), + } } }; } @@ -27,7 +25,8 @@ pub struct InvalidJsonPointer(pub T); /// /// See: #[derive(Debug, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct JsonPointer([u8]); +#[repr(transparent)] +pub struct JsonPointer(str); impl<'a> Default for &'a JsonPointer { fn default() -> Self { @@ -36,25 +35,22 @@ impl<'a> Default for &'a JsonPointer { } impl JsonPointer { - pub const ROOT: &'static Self = unsafe { - // SAFETY: the empty string is a valid JSON pointer. - JsonPointer::new_unchecked(&[]) - }; + pub const ROOT: &'static Self = json_pointer!(""); /// Converts the given string into a JSON pointer. - pub fn new>(s: &S) -> Result<&Self, InvalidJsonPointer<&S>> { - let bytes = s.as_ref(); - if Self::validate(bytes) { - Ok(unsafe { Self::new_unchecked(bytes) }) - } else { - Err(InvalidJsonPointer(s)) - } + pub fn new(s: &S) -> Result<&Self, InvalidJsonPointer<&S>> + where + S: AsRef<[u8]> + ?Sized, + { + core::str::from_utf8(s.as_ref()) + .ok() + .and_then(|s| Self::from_str_const(s).ok()) + .ok_or(InvalidJsonPointer(s)) } pub const fn from_str_const(s: &str) -> Result<&Self, InvalidJsonPointer<&str>> { - let bytes = s.as_bytes(); - if Self::validate(bytes) { - Ok(unsafe { Self::new_unchecked(bytes) }) + if Self::validate_str(s) { + Ok(unsafe { Self::new_unchecked_str(s) }) } else { Err(InvalidJsonPointer(s)) } @@ -65,14 +61,36 @@ impl JsonPointer { /// # Safety /// /// The input string *must* be a valid JSON pointer. - pub const unsafe fn new_unchecked(s: &[u8]) -> &Self { + pub const unsafe fn new_unchecked_str(s: &str) -> &Self { std::mem::transmute(s) } - pub const fn validate(bytes: &[u8]) -> bool { - if std::str::from_utf8(bytes).is_err() { + /// Converts the given string into a JSON pointer without validation. + /// + /// # Safety + /// + /// The input string *must* be a valid JSON pointer. + pub const unsafe fn new_unchecked(s: &[u8]) -> &Self { + Self::new_unchecked_str(core::str::from_utf8_unchecked(s)) + } + + /// Confirms the validity of a string such that it may be safely used for + /// [`Self::new_unchecked`]. + pub const fn validate_bytes(s: &[u8]) -> bool { + match core::str::from_utf8(s) { + Ok(s) => Self::validate_str(s), + Err(_) => false, + } + } + + /// Confirms the validity of a string such that it may be safely used for + /// [`Self::new_unchecked_str`]. + pub const fn validate_str(s: &str) -> bool { + let bytes = s.as_bytes(); + + if !matches!(bytes, [] | [b'/', ..]) { return false; - }; + } let mut i = 0; while i < bytes.len() { @@ -91,49 +109,30 @@ impl JsonPointer { } pub fn as_bytes(&self) -> &[u8] { - &self.0 + self.0.as_bytes() } pub fn as_str(&self) -> &str { - unsafe { - // SAFETY: a JSON pointer is an UTF-8 encoded string by definition. - std::str::from_utf8_unchecked(&self.0) - } + &self.0 } pub fn is_empty(&self) -> bool { self.0.is_empty() } - fn token_end(&self) -> Option { - if self.is_empty() { - None - } else { - let mut i = 1; - - while i < self.0.len() { - if self.0[i] == b'/' { - break; - } - - i += 1 - } - - Some(i) - } - } - pub fn split_first(&self) -> Option<(&ReferenceToken, &Self)> { - self.token_end().map(|i| unsafe { - ( - ReferenceToken::new_unchecked(&self.0[1..i]), - Self::new_unchecked(&self.0[i..]), - ) + self.0.strip_prefix("/").map(|s| { + let (left, right) = s.find("/").map(|idx| s.split_at(idx)).unwrap_or((s, "")); + // Safety: the token is guaranteed not to include a '/', and remaining shall be either + // empty or a valid pointer starting with '/'. + let token = unsafe { ReferenceToken::new_unchecked(left) }; + let remaining = unsafe { Self::new_unchecked_str(right) }; + (token, remaining) }) } pub fn iter(&self) -> JsonPointerIter { - let mut tokens = self.as_str().split('/'); + let mut tokens = self.0.split('/'); tokens.next(); JsonPointerIter(tokens) } @@ -178,11 +177,21 @@ impl<'a> Iterator for JsonPointerIter<'a> { } } +impl<'de> Deserialize<'de> for &'de JsonPointer { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s: &str = <&str as Deserialize>::deserialize(deserializer)?; + JsonPointer::new(s).map_err(serde::de::Error::custom) + } +} + /// JSON Pointer buffer. /// /// See: #[derive(Debug, Clone, Serialize, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct JsonPointerBuf(Vec); +pub struct JsonPointerBuf(String); impl Default for JsonPointerBuf { fn default() -> Self { @@ -193,44 +202,38 @@ impl Default for JsonPointerBuf { impl JsonPointerBuf { /// Converts the given byte string into an owned JSON pointer. pub fn new(value: B) -> Result> { - if JsonPointer::validate(value.as_ref()) { - Ok(Self(value.into())) + if JsonPointer::validate_bytes(value.as_ref()) { + let v: Vec = value.into(); + // SAFETY: we've just ensured the contents of the BytesBuf is a valid UTF-8 string and + // JsonPointer. + Ok(Self(unsafe { String::from_utf8_unchecked(v) })) } else { Err(InvalidJsonPointer(value)) } } pub fn push(&mut self, token: &str) { - self.0.push(b'/'); + self.0.reserve(1 + token.len()); + self.0.push('/'); for c in token.chars() { match c { - '~' => { - self.0.push(b'~'); - self.0.push(b'0'); - } - '/' => { - self.0.push(b'~'); - self.0.push(b'1'); - } - _ => { - let i = self.0.len(); - let len = c.len_utf8(); - self.0.resize(i + len, 0); - c.encode_utf8(&mut self.0[i..]); - } + '~' => self.0.push_str("~0"), + '/' => self.0.push_str("~1"), + _ => self.0.push(c), } } } pub fn push_index(&mut self, i: usize) { - self.push(&i.to_string()) + use core::fmt::Write; + write!(self.0, "/{i}").unwrap() } pub fn as_json_pointer(&self) -> &JsonPointer { unsafe { // SAFETY: the inner bytes are representing a JSON pointer by // construction. - JsonPointer::new_unchecked(&self.0) + JsonPointer::new_unchecked_str(&self.0) } } } @@ -239,7 +242,7 @@ impl Deref for JsonPointerBuf { type Target = JsonPointer; fn deref(&self) -> &Self::Target { - unsafe { JsonPointer::new_unchecked(&self.0) } + self.as_json_pointer() } } @@ -259,7 +262,7 @@ impl FromStr for JsonPointerBuf { type Err = InvalidJsonPointer; fn from_str(s: &str) -> Result { - Self::new(s.to_owned()) + s.to_owned().try_into() } } @@ -267,7 +270,11 @@ impl TryFrom for JsonPointerBuf { type Error = InvalidJsonPointer; fn try_from(value: String) -> Result { - Self::new(value) + if JsonPointer::validate_str(&value) { + Ok(Self(value)) + } else { + Err(InvalidJsonPointer(value)) + } } } @@ -283,14 +290,14 @@ impl<'de> Deserialize<'de> for JsonPointerBuf { D: serde::Deserializer<'de>, { String::deserialize(deserializer)? - .parse() + .try_into() .map_err(serde::de::Error::custom) } } #[derive(Debug)] #[repr(transparent)] -pub struct ReferenceToken([u8]); +pub struct ReferenceToken(str); impl ReferenceToken { /// Converts the given string into a JSON pointer reference token without @@ -299,24 +306,20 @@ impl ReferenceToken { /// # Safety /// /// The input string *must* be a valid JSON pointer reference token. - pub const unsafe fn new_unchecked(s: &[u8]) -> &Self { + pub const unsafe fn new_unchecked(s: &str) -> &Self { std::mem::transmute(s) } pub fn is_escaped(&self) -> bool { - self.0.contains(&b'~') + self.0.contains("~") } pub fn as_bytes(&self) -> &[u8] { - &self.0 + self.0.as_bytes() } pub fn as_str(&self) -> &str { - unsafe { - // SAFETY: a reference token is an UTF-8 encoded string by - // definition. - std::str::from_utf8_unchecked(&self.0) - } + &self.0 } pub fn to_decoded(&self) -> Cow { @@ -328,39 +331,27 @@ impl ReferenceToken { } pub fn decode(&self) -> String { - let mut result = String::new(); - let mut chars = self.as_str().chars(); - while let Some(c) = chars.next() { - let decoded_c = match c { + let mut buf = String::with_capacity(self.0.len()); + let mut chars = self.0.chars(); + buf.extend(core::iter::from_fn(|| { + Some(match chars.next()? { '~' => match chars.next() { Some('0') => '~', Some('1') => '/', _ => unreachable!(), }, c => c, - }; - - result.push(decoded_c); - } - - result + }) + })); + buf } pub fn as_array_index(&self) -> Option { - let mut chars = self.as_str().chars(); - let mut i = chars.next()?.to_digit(10)? as usize; - if i == 0 { - match chars.next() { - Some(_) => None, - None => Some(0), - } - } else { - for c in chars { - let d = c.to_digit(10)? as usize; - i = i * 10 + d; - } - - Some(i) + // Like usize::from_str, but don't allow leading '+' or '0'. + match self.0.as_bytes() { + [c @ b'0'..=b'9'] => Some((c - b'0') as usize), + [b'1'..=b'9', ..] => self.0.parse().ok(), + _ => None, } } } @@ -370,3 +361,18 @@ impl fmt::Display for ReferenceToken { self.as_str().fmt(f) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_serde_borrow() { + let s = String::from("\"/foo/b~1ar\""); + let p: JsonPointerBuf = serde_json::from_str(&s).unwrap(); + let jp: &JsonPointer = serde_json::from_str(&s).unwrap(); + assert_eq!(p.0, jp.0); + + serde_json::from_str::<&JsonPointer>("\"invalid\"").unwrap_err(); + } +}