Skip to content

Commit

Permalink
Refactor and add SAFETY comments to PyArrayUnicode
Browse files Browse the repository at this point in the history
Replace deprecated `PyUnicode_FromUnicode` with `PyUnicode_FromKindAndData`
  • Loading branch information
messense authored and h-vetinari committed Apr 12, 2022
1 parent bc9e5fd commit 87f157c
Showing 1 changed file with 19 additions and 17 deletions.
36 changes: 19 additions & 17 deletions bindings/python/src/tokenizer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,48 +259,50 @@ impl<'s> From<TextInputSequence<'s>> for tk::InputSequence<'s> {
/// Newtype wrapping a `Vec<String>` that is extracted from a Python object.
///
/// The `FromPyObject` impl below rejects, with a `PyTypeError`, any object that is
/// not a one-dimensional, contiguous NumPy array with a unicode dtype (`dtype='U'`).
/// NOTE(review): the elements appear to be copied into owned Rust strings, but the
/// end of the conversion is not visible here — confirm against the full impl.
struct PyArrayUnicode(Vec<String>);
impl FromPyObject<'_> for PyArrayUnicode {
fn extract(ob: &PyAny) -> PyResult<Self> {
// SAFETY Making sure the pointer is a valid numpy array requires calling numpy C code
if unsafe { npyffi::PyArray_Check(ob.py(), ob.as_ptr()) } == 0 {
return Err(exceptions::PyTypeError::new_err("Expected an np.array"));
}
let arr = ob.as_ptr() as *mut npyffi::PyArrayObject;
if unsafe { (*arr).nd } != 1 {
return Err(exceptions::PyTypeError::new_err(
"Expected a 1 dimensional np.array",
));
}
if unsafe { (*arr).flags }
& (npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS)
== 0
{
return Err(exceptions::PyTypeError::new_err(
"Expected a contiguous np.array",
));
}
let n_elem = unsafe { *(*arr).dimensions } as usize;
let (type_num, elsize, alignment, data) = unsafe {
// SAFETY Getting all the metadata about the numpy array to check its sanity
let (type_num, elsize, alignment, data, nd, flags) = unsafe {
let desc = (*arr).descr;
(
(*desc).type_num,
(*desc).elsize as usize,
(*desc).alignment as usize,
(*arr).data,
(*arr).nd,
(*arr).flags,
)
};

if nd != 1 {
return Err(exceptions::PyTypeError::new_err(
"Expected a 1 dimensional np.array",
));
}
if flags & (npyffi::NPY_ARRAY_C_CONTIGUOUS | npyffi::NPY_ARRAY_F_CONTIGUOUS) == 0 {
return Err(exceptions::PyTypeError::new_err(
"Expected a contiguous np.array",
));
}
if type_num != npyffi::types::NPY_TYPES::NPY_UNICODE as i32 {
return Err(exceptions::PyTypeError::new_err(
"Expected a np.array[dtype='U']",
));
}

// SAFETY Looking at the raw numpy data to create new owned Rust strings via copies (so it's safe afterwards).
unsafe {
let n_elem = *(*arr).dimensions as usize;
let all_bytes = std::slice::from_raw_parts(data as *const u8, elsize * n_elem);

let seq = (0..n_elem)
.map(|i| {
let bytes = &all_bytes[i * elsize..(i + 1) * elsize];
#[allow(deprecated)]
let unicode = pyo3::ffi::PyUnicode_FromUnicode(
let unicode = pyo3::ffi::PyUnicode_FromKindAndData(
pyo3::ffi::PyUnicode_4BYTE_KIND as _,
bytes.as_ptr() as *const _,
elsize as isize / alignment as isize,
);
Expand Down

0 comments on commit 87f157c

Please sign in to comment.