From 3d290cbb9c269304ebd7bdbd52399ccc091b8ba8 Mon Sep 17 00:00:00 2001 From: ijl Date: Sun, 3 Jul 2022 15:17:13 +0000 Subject: [PATCH] python3.12 unicode compatibility --- script/pysort | 2 +- src/deserialize/utf8.rs | 59 ++++++++++++++++++-------------------- src/serialize/dataclass.rs | 45 ++++++++++++----------------- src/serialize/dict.rs | 55 ++++++++++++++--------------------- src/serialize/numpy.rs | 8 ++---- src/serialize/str.rs | 14 ++++----- src/unicode.rs | 59 +++++++++++++++++++++----------------- 7 files changed, 111 insertions(+), 131 deletions(-) diff --git a/script/pysort b/script/pysort index df40a89b..b1a47204 100755 --- a/script/pysort +++ b/script/pysort @@ -18,7 +18,7 @@ import orjson os.sched_setaffinity(os.getpid(), {0, 1}) -dirname = os.path.join(os.path.dirname(__file__), "data") +dirname = os.path.join(os.path.dirname(__file__), "..", "data") def read_fixture_obj(filename): diff --git a/src/deserialize/utf8.rs b/src/deserialize/utf8.rs index c878822e..d08ffd6f 100644 --- a/src/deserialize/utf8.rs +++ b/src/deserialize/utf8.rs @@ -35,46 +35,43 @@ pub fn read_input_to_buf( ptr: *mut pyo3_ffi::PyObject, ) -> Result<&'static [u8], DeserializeError<'static>> { let obj_type_ptr = ob_type!(ptr); - let contents: &[u8]; - if is_type!(obj_type_ptr, STR_TYPE) { - let mut str_size: pyo3_ffi::Py_ssize_t = 0; - let uni = read_utf8_from_str(ptr, &mut str_size); - if unlikely!(uni.is_null()) { + let buffer: *const u8; + let length: usize; + if is_type!(obj_type_ptr, BYTES_TYPE) { + buffer = unsafe { PyBytes_AS_STRING(ptr) as *const u8 }; + length = unsafe { PyBytes_GET_SIZE(ptr) as usize }; + } else if is_type!(obj_type_ptr, STR_TYPE) { + let uni = unicode_to_str(ptr); + if unlikely!(uni.is_none()) { return Err(DeserializeError::new(Cow::Borrowed(INVALID_STR), 0, 0, "")); } - contents = unsafe { std::slice::from_raw_parts(uni, str_size as usize) }; - } else { - let buffer: *const u8; - let length: usize; - if is_type!(obj_type_ptr, BYTES_TYPE) { - buffer = unsafe { PyBytes_AS_STRING(ptr) as *const u8 }; - length = unsafe { PyBytes_GET_SIZE(ptr) as usize }; - } else if is_type!(obj_type_ptr, MEMORYVIEW_TYPE) { - let membuf = unsafe { PyMemoryView_GET_BUFFER(ptr) }; - if unsafe { pyo3_ffi::PyBuffer_IsContiguous(membuf, b'C' as c_char) == 0 } { - return Err(DeserializeError::new( - Cow::Borrowed("Input type memoryview must be a C contiguous buffer"), - 0, - 0, - "", - )); - } - buffer = unsafe { (*membuf).buf as *const u8 }; - length = unsafe { (*membuf).len as usize }; - } else if is_type!(obj_type_ptr, BYTEARRAY_TYPE) { - buffer = ffi!(PyByteArray_AsString(ptr)) as *const u8; - length = ffi!(PyByteArray_Size(ptr)) as usize; - } else { + let as_str = uni.unwrap(); + buffer = as_str.as_ptr(); + length = as_str.len(); + } else if is_type!(obj_type_ptr, MEMORYVIEW_TYPE) { + let membuf = unsafe { PyMemoryView_GET_BUFFER(ptr) }; + if unsafe { pyo3_ffi::PyBuffer_IsContiguous(membuf, b'C' as c_char) == 0 } { return Err(DeserializeError::new( - Cow::Borrowed("Input must be bytes, bytearray, memoryview, or str"), + Cow::Borrowed("Input type memoryview must be a C contiguous buffer"), 0, 0, "", )); } - contents = unsafe { std::slice::from_raw_parts(buffer, length) }; + buffer = unsafe { (*membuf).buf as *const u8 }; + length = unsafe { (*membuf).len as usize }; + } else if is_type!(obj_type_ptr, BYTEARRAY_TYPE) { + buffer = ffi!(PyByteArray_AsString(ptr)) as *const u8; + length = ffi!(PyByteArray_Size(ptr)) as usize; + } else { + return Err(DeserializeError::new( + Cow::Borrowed("Input must be bytes, bytearray, memoryview, or str"), + 0, + 0, + "", + )); } - Ok(contents) + Ok(unsafe { std::slice::from_raw_parts(buffer, length) }) } pub fn read_buf_to_str(contents: &[u8]) -> Result<&str, DeserializeError> { diff --git a/src/serialize/dataclass.rs b/src/serialize/dataclass.rs index e30fd74f..2d57e15c 100644 --- a/src/serialize/dataclass.rs +++ b/src/serialize/dataclass.rs @@ -50,10 +50,9 @@ impl Serialize for DataclassFastSerializer { } let mut map = serializer.serialize_map(None).unwrap(); let mut pos = 0isize; - let mut str_size: pyo3_ffi::Py_ssize_t = 0; let mut key: *mut pyo3_ffi::PyObject = std::ptr::null_mut(); let mut value: *mut pyo3_ffi::PyObject = std::ptr::null_mut(); - for _ in 0..=len - 1 { + for _ in 0..=len.saturating_sub(1) { unsafe { pyo3_ffi::_PyDict_Next( self.dict, @@ -63,7 +62,7 @@ impl Serialize for DataclassFastSerializer { std::ptr::null_mut(), ) }; - let value = PyObjectSerializer::new( + let pyvalue = PyObjectSerializer::new( value, self.opts, self.default_calls, @@ -73,19 +72,16 @@ impl Serialize for DataclassFastSerializer { if unlikely!(unsafe { ob_type!(key) != STR_TYPE }) { err!(SerializeError::KeyMustBeStr) } - { - let data = read_utf8_from_str(key, &mut str_size); - if unlikely!(data.is_null()) { - err!(SerializeError::InvalidStr) - } - let key_as_str = str_from_slice!(data, str_size); - if unlikely!(key_as_str.as_bytes()[0] == b'_') { - continue; - } - map.serialize_key(key_as_str).unwrap(); + let data = unicode_to_str(key); + if unlikely!(data.is_none()) { + err!(SerializeError::InvalidStr) } - - map.serialize_value(&value)?; + let key_as_str = data.unwrap(); + if unlikely!(key_as_str.as_bytes()[0] == b'_') { + continue; + } + map.serialize_key(key_as_str).unwrap(); + map.serialize_value(&pyvalue)?; } map.end() } @@ -131,7 +127,6 @@ impl Serialize for DataclassFallbackSerializer { } let mut map = serializer.serialize_map(None).unwrap(); let mut pos = 0isize; - let mut str_size: pyo3_ffi::Py_ssize_t = 0; let mut attr: *mut pyo3_ffi::PyObject = std::ptr::null_mut(); let mut field: *mut pyo3_ffi::PyObject = std::ptr::null_mut(); for _ in 0..=len - 1 { @@ -149,21 +144,19 @@ impl Serialize for DataclassFallbackSerializer { if unsafe { field_type != FIELD_TYPE.as_ptr() } { continue; } - { - let data = read_utf8_from_str(attr, &mut str_size); - if unlikely!(data.is_null()) { - err!(SerializeError::InvalidStr); - } - let key_as_str = str_from_slice!(data, str_size); - if key_as_str.as_bytes()[0] == b'_' { - continue; - } - map.serialize_key(key_as_str).unwrap(); + let data = unicode_to_str(attr); + if unlikely!(data.is_none()) { + err!(SerializeError::InvalidStr); + } + let key_as_str = data.unwrap(); + if key_as_str.as_bytes()[0] == b'_' { + continue; } let value = ffi!(PyObject_GetAttr(self.ptr, attr)); ffi!(Py_DECREF(value)); + map.serialize_key(key_as_str).unwrap(); map.serialize_value(&PyObjectSerializer::new( value, self.opts, diff --git a/src/serialize/dict.rs b/src/serialize/dict.rs index 63b5d609..00b1d285 100644 --- a/src/serialize/dict.rs +++ b/src/serialize/dict.rs @@ -50,7 +50,6 @@ impl Serialize for Dict { { let mut map = serializer.serialize_map(None).unwrap(); let mut pos = 0isize; - let mut str_size: pyo3_ffi::Py_ssize_t = 0; let mut key: *mut pyo3_ffi::PyObject = std::ptr::null_mut(); let mut value: *mut pyo3_ffi::PyObject = std::ptr::null_mut(); for _ in 0..=unsafe { PyDict_GET_SIZE(self.ptr) as usize } - 1 { @@ -63,25 +62,22 @@ impl Serialize for Dict { std::ptr::null_mut(), ) }; - let value = PyObjectSerializer::new( + if unlikely!(unsafe { ob_type!(key) != STR_TYPE }) { + err!(SerializeError::KeyMustBeStr) + } + let key_as_str = unicode_to_str(key); + if unlikely!(key_as_str.is_none()) { + err!(SerializeError::InvalidStr) + } + let pyvalue = PyObjectSerializer::new( value, self.opts, self.default_calls, self.recursion + 1, self.default, ); - if unlikely!(unsafe { ob_type!(key) != STR_TYPE }) { - err!(SerializeError::KeyMustBeStr) - } - { - let data = read_utf8_from_str(key, &mut str_size); - if unlikely!(data.is_null()) { - err!(SerializeError::InvalidStr) - } - map.serialize_key(str_from_slice!(data, str_size)).unwrap(); - } - - map.serialize_value(&value)?; + map.serialize_key(key_as_str.unwrap()).unwrap(); + map.serialize_value(&pyvalue)?; } map.end() } @@ -123,7 +119,6 @@ impl Serialize for DictSortedKey { let mut items: SmallVec<[(&str, *mut pyo3_ffi::PyObject); 8]> = SmallVec::with_capacity(len); let mut pos = 0isize; - let mut str_size: pyo3_ffi::Py_ssize_t = 0; let mut key: *mut pyo3_ffi::PyObject = std::ptr::null_mut(); let mut value: *mut pyo3_ffi::PyObject = std::ptr::null_mut(); for _ in 0..=len - 1 { @@ -139,11 +134,11 @@ impl Serialize for DictSortedKey { if unlikely!(unsafe { ob_type!(key) != STR_TYPE }) { err!(SerializeError::KeyMustBeStr) } - let data = read_utf8_from_str(key, &mut str_size); - if unlikely!(data.is_null()) { + let data = unicode_to_str(key); + if unlikely!(data.is_none()) { err!(SerializeError::InvalidStr) } - items.push((str_from_slice!(data, str_size), value)); + items.push((data.unwrap(), value)); } items.sort_unstable_by(|a, b| a.0.cmp(b.0)); @@ -263,21 +258,19 @@ impl DictNonStrKey { } ObType::Str => { // because of ObType::Enum - let mut str_size: pyo3_ffi::Py_ssize_t = 0; - let uni = read_utf8_from_str(key, &mut str_size); - if unlikely!(uni.is_null()) { + let uni = unicode_to_str(key); + if unlikely!(uni.is_none()) { Err(SerializeError::InvalidStr) } else { - Ok(InlinableString::from(str_from_slice!(uni, str_size))) + Ok(InlinableString::from(uni.unwrap())) } } ObType::StrSubclass => { - let mut str_size: pyo3_ffi::Py_ssize_t = 0; - let uni = ffi!(PyUnicode_AsUTF8AndSize(key, &mut str_size)) as *const u8; - if unlikely!(uni.is_null()) { + let uni = unicode_to_str_via_ffi(key); + if unlikely!(uni.is_none()) { Err(SerializeError::InvalidStr) } else { - Ok(InlinableString::from(str_from_slice!(uni, str_size))) + Ok(InlinableString::from(uni.unwrap())) } } ObType::Tuple @@ -301,7 +294,6 @@ impl Serialize for DictNonStrKey { let mut items: SmallVec<[(InlinableString, *mut pyo3_ffi::PyObject); 8]> = SmallVec::with_capacity(len); let mut pos = 0isize; - let mut str_size: pyo3_ffi::Py_ssize_t = 0; let mut key: *mut pyo3_ffi::PyObject = std::ptr::null_mut(); let mut value: *mut pyo3_ffi::PyObject = std::ptr::null_mut(); let opts = self.opts & NOT_PASSTHROUGH; @@ -316,14 +308,11 @@ impl Serialize for DictNonStrKey { ) }; if is_type!(ob_type!(key), STR_TYPE) { - let data = read_utf8_from_str(key, &mut str_size); - if unlikely!(data.is_null()) { + let data = unicode_to_str(key); + if unlikely!(data.is_none()) { err!(SerializeError::InvalidStr) } - items.push(( - InlinableString::from(str_from_slice!(data, str_size)), - value, - )); + items.push((InlinableString::from(data.unwrap()), value)); } else { match self.pyobject_to_string(key, opts) { Ok(key_as_str) => items.push((key_as_str, value)), diff --git a/src/serialize/numpy.rs b/src/serialize/numpy.rs index f41e832f..1968cb78 100644 --- a/src/serialize/numpy.rs +++ b/src/serialize/numpy.rs @@ -722,15 +722,13 @@ impl NumpyDatetimeUnit { let el0 = ffi!(PyList_GET_ITEM(descr, 0)); ffi!(Py_DECREF(descr)); let descr_str = ffi!(PyTuple_GET_ITEM(el0, 1)); - let mut str_size: pyo3_ffi::Py_ssize_t = 0; - let uni = crate::unicode::read_utf8_from_str(descr_str, &mut str_size); - if str_size < 5 { + let uni = crate::unicode::unicode_to_str(descr_str).unwrap(); + if uni.len() < 5 { return Self::NaT; } - let fmt = str_from_slice!(uni, str_size); // unit descriptions are found at // https://github.com/numpy/numpy/blob/b235f9e701e14ed6f6f6dcba885f7986a833743f/numpy/core/src/multiarray/datetime.c#L79-L96. - match &fmt[4..fmt.len() - 1] { + match &uni[4..uni.len() - 1] { "Y" => Self::Years, "M" => Self::Months, "W" => Self::Weeks, diff --git a/src/serialize/str.rs b/src/serialize/str.rs index c5b534de..ff0c9a6a 100644 --- a/src/serialize/str.rs +++ b/src/serialize/str.rs @@ -21,12 +21,11 @@ impl Serialize for StrSerializer { where S: Serializer, { - let mut str_size: pyo3_ffi::Py_ssize_t = 0; - let uni = read_utf8_from_str(self.ptr, &mut str_size); - if unlikely!(uni.is_null()) { + let uni = unicode_to_str(self.ptr); + if unlikely!(uni.is_none()) { err!(SerializeError::InvalidStr) } - serializer.serialize_str(str_from_slice!(uni, str_size)) + serializer.serialize_str(uni.unwrap()) } } @@ -47,11 +46,10 @@ impl Serialize for StrSubclassSerializer { where S: Serializer, { - let mut str_size: pyo3_ffi::Py_ssize_t = 0; - let uni = ffi!(PyUnicode_AsUTF8AndSize(self.ptr, &mut str_size)) as *const u8; - if unlikely!(uni.is_null()) { + let uni = unicode_to_str_via_ffi(self.ptr); + if unlikely!(uni.is_none()) { err!(SerializeError::InvalidStr) } - serializer.serialize_str(str_from_slice!(uni, str_size)) + serializer.serialize_str(uni.unwrap()) } } diff --git a/src/unicode.rs b/src/unicode.rs index a88b0ed7..49630015 100644 --- a/src/unicode.rs +++ b/src/unicode.rs @@ -27,10 +27,17 @@ pub struct PyCompactUnicodeObject { pub wstr_length: Py_ssize_t, } +#[cfg(not(Py_3_12))] const STATE_ASCII: u32 = 0b00000000000000000000000001000000; #[cfg(not(Py_3_12))] const STATE_COMPACT: u32 = 0b00000000000000000000000000100000; -#[cfg(not(Py_3_12))] + +#[cfg(Py_3_12)] +const STATE_ASCII: u32 = 0b00000000000000000000000000100000; + +#[cfg(Py_3_12)] +const STATE_COMPACT: u32 = 0b00000000000000000000000000010000; + const STATE_COMPACT_ASCII: u32 = STATE_COMPACT | STATE_ASCII; fn is_four_byte(buf: &str) -> bool { @@ -108,42 +115,40 @@ pub fn unicode_from_str(buf: &str) -> *mut pyo3_ffi::PyObject { } } -#[cfg(Py_3_12)] -pub fn read_utf8_from_str(op: *mut PyObject, str_size: &mut Py_ssize_t) -> *const u8 { +#[inline] +pub fn hash_str(op: *mut PyObject) -> Py_hash_t { unsafe { - if (*op.cast::()).state & STATE_ASCII != 0 { - *str_size = (*op.cast::()).length; - op.cast::().offset(1) as *const u8 - } else if !(*op.cast::()).utf8.is_null() { - *str_size = (*op.cast::()).utf8_length; - (*op.cast::()).utf8 as *const u8 - } else { - PyUnicode_AsUTF8AndSize(op, str_size) as *const u8 - } + (*op.cast::()).hash = STR_HASH_FUNCTION.unwrap()(op); + (*op.cast::()).hash } } -#[cfg(not(Py_3_12))] -pub fn read_utf8_from_str(op: *mut PyObject, str_size: &mut Py_ssize_t) -> *const u8 { +#[inline(never)] +pub fn unicode_to_str_via_ffi(op: *mut PyObject) -> Option<&'static str> { + let mut str_size: pyo3_ffi::Py_ssize_t = 0; + let ptr = ffi!(PyUnicode_AsUTF8AndSize(op, &mut str_size)) as *const u8; + if unlikely!(ptr.is_null()) { + None + } else { + Some(str_from_slice!(ptr, str_size as usize)) + } +} + +#[inline(always)] +pub fn unicode_to_str(op: *mut PyObject) -> Option<&'static str> { unsafe { if (*op.cast::()).state & STATE_COMPACT_ASCII == STATE_COMPACT_ASCII { - *str_size = (*op.cast::()).length; - op.cast::().offset(1) as *const u8 + let ptr = op.cast::().offset(1) as *const u8; + let len = (*op.cast::()).length as usize; + Some(str_from_slice!(ptr, len)) } else if (*op.cast::()).state & STATE_COMPACT == STATE_COMPACT && !(*op.cast::()).utf8.is_null() { - *str_size = (*op.cast::()).utf8_length; - (*op.cast::()).utf8 as *const u8 + let ptr = (*op.cast::()).utf8 as *const u8; + let len = (*op.cast::()).utf8_length as usize; + Some(str_from_slice!(ptr, len)) } else { - PyUnicode_AsUTF8AndSize(op, str_size) as *const u8 + unicode_to_str_via_ffi(op) } } } - -#[inline] -pub fn hash_str(op: *mut PyObject) -> Py_hash_t { - unsafe { - (*op.cast::()).hash = STR_HASH_FUNCTION.unwrap()(op); - (*op.cast::()).hash - } -}