Skip to content

Commit

Permalink
Add support for arbitrary arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
c410-f3r committed Mar 31, 2021
1 parent bc826d1 commit 926f70a
Show file tree
Hide file tree
Showing 9 changed files with 306 additions and 108 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
- Extend `hashbrown` optional dependency supported versions to include 0.11. [#1496](https://github.com/PyO3/pyo3/pull/1496)

### Added
- Add conversion for `[T; N]` for all `N` on Rust 1.51 and up. [#1128](https://github.com/PyO3/pyo3/pull/1128)
- Add conversions between `OsStr`/`OsString`/`Path`/`PathBuf` and Python strings. [#1379](https://github.com/PyO3/pyo3/pull/1379)
- Add `#[pyo3(from_py_with = "...")]` attribute for function arguments and struct fields to override the default from-Python conversion. [#1411](https://github.com/PyO3/pyo3/pull/1411)
- Add FFI definition `PyCFunction_CheckExact` for Python 3.9 and later. [#1425](https://github.com/PyO3/pyo3/pull/1425)
Expand Down
23 changes: 23 additions & 0 deletions build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,27 @@ fn abi3_without_interpreter() -> Result<()> {
Ok(())
}

fn rustc_minor_version() -> Option<u32> {
let rustc = env::var_os("RUSTC")?;
let output = Command::new(rustc).arg("--version").output().ok()?;
let version = core::str::from_utf8(&output.stdout).ok()?;
let mut pieces = version.split('.');
if pieces.next() != Some("rustc 1") {
return None;
}
pieces.next()?.parse().ok()
}

fn manage_min_const_generics() {
let rustc_minor_version = match rustc_minor_version() {
Some(inner) => inner,
None => return,
};
if rustc_minor_version >= 51 {
println!("cargo:rustc-cfg=min_const_generics");
}
}

fn main_impl() -> Result<()> {
// If PYO3_NO_PYTHON is set with abi3, we can build PyO3 without calling Python.
// We only check for the abi3-py3{ABI3_MAX_MINOR} because lower versions depend on it.
Expand Down Expand Up @@ -916,6 +937,8 @@ fn main_impl() -> Result<()> {
println!("cargo:rustc-cfg=__pyo3_ci");
}

manage_min_const_generics();

Ok(())
}

Expand Down
9 changes: 3 additions & 6 deletions src/buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

//! `PyBuffer` implementation
use crate::err::{self, PyResult};
use crate::utils::invalid_sequence_length;
use crate::{exceptions, ffi, AsPyPointer, FromPyObject, PyAny, PyNativeType, Python};
use std::ffi::CStr;
use std::marker::PhantomData;
Expand Down Expand Up @@ -441,9 +442,7 @@ impl<T: Element> PyBuffer<T> {

fn copy_to_slice_impl(&self, py: Python, target: &mut [T], fort: u8) -> PyResult<()> {
if mem::size_of_val(target) != self.len_bytes() {
return Err(exceptions::PyBufferError::new_err(
"Slice length does not match buffer length.",
));
return Err(invalid_sequence_length(self.item_count(), target.len()));
}
unsafe {
err::error_on_minusone(
Expand Down Expand Up @@ -528,9 +527,7 @@ impl<T: Element> PyBuffer<T> {
return buffer_readonly_error();
}
if mem::size_of_val(source) != self.len_bytes() {
return Err(exceptions::PyBufferError::new_err(
"Slice length does not match buffer length.",
));
return Err(invalid_sequence_length(source.len(), self.item_count()));
}
unsafe {
err::error_on_minusone(
Expand Down
269 changes: 269 additions & 0 deletions src/conversions/array.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,269 @@
use crate::{FromPyObject, IntoPy, PyAny, PyObject, PyResult, PyTryFrom, Python, ToPyObject};

#[cfg(not(min_const_generics))]
macro_rules! array_impls {
($($N:expr),+) => {
$(
impl<'a, T> FromPyObject<'a> for [T; $N]
where
T: Copy + Default + FromPyObject<'a>,
{
#[cfg(not(feature = "nightly"))]
fn extract(obj: &'a PyAny) -> PyResult<Self> {
let mut array = [T::default(); $N];
extract_sequence_into_slice(obj, &mut array)?;
Ok(array)
}

#[cfg(feature = "nightly")]
default fn extract(obj: &'a PyAny) -> PyResult<Self> {
let mut array = [T::default(); $N];
extract_sequence_into_slice(obj, &mut array)?;
Ok(array)
}
}

#[cfg(feature = "nightly")]
impl<'source, T> FromPyObject<'source> for [T; $N]
where
for<'a> T: Default + FromPyObject<'a> + crate::buffer::Element,
{
fn extract(obj: &'source PyAny) -> PyResult<Self> {
let mut array = [T::default(); $N];
// first try buffer protocol
if unsafe { crate::ffi::PyObject_CheckBuffer(obj.as_ptr()) } == 1 {
if let Ok(buf) = crate::buffer::PyBuffer::get(obj) {
if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() {
buf.release(obj.py());
return Ok(array);
}
buf.release(obj.py());
}
}
// fall back to sequence protocol
extract_sequence_into_slice(obj, &mut array)?;
Ok(array)
}
}
)+
}
}

#[cfg(not(min_const_generics))]
array_impls!(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
26, 27, 28, 29, 30, 31, 32
);

#[cfg(all(min_const_generics, not(feature = "nightly")))]
impl<'a, T, const N: usize> FromPyObject<'a> for [T; N]
where
T: FromPyObject<'a>,
{
#[cfg(not(feature = "nightly"))]
fn extract(obj: &'a PyAny) -> PyResult<Self> {
create_array_from_obj(obj)
}

#[cfg(feature = "nightly")]
default fn extract(obj: &'a PyAny) -> PyResult<Self> {
create_array_from_obj(obj)
}
}

#[cfg(not(min_const_generics))]
macro_rules! array_impls {
($($N:expr),+) => {
$(
impl<T> IntoPy<PyObject> for [T; $N]
where
T: ToPyObject
{
fn into_py(self, py: Python) -> PyObject {
self.as_ref().to_object(py)
}
}
)+
}
}

#[cfg(not(min_const_generics))]
array_impls!(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
26, 27, 28, 29, 30, 31, 32
);

#[cfg(min_const_generics)]
impl<T, const N: usize> IntoPy<PyObject> for [T; N]
where
T: ToPyObject,
{
fn into_py(self, py: Python) -> PyObject {
self.as_ref().to_object(py)
}
}

#[cfg(all(min_const_generics, feature = "nightly"))]
impl<'source, T, const N: usize> FromPyObject<'source> for [T; N]
where
for<'a> T: FromPyObject<'a> + crate::buffer::Element,
{
fn extract(obj: &'source PyAny) -> PyResult<Self> {
let mut array: core::mem::MaybeUninit<[T; N]> = core::mem::MaybeUninit::uninit();
// first try buffer protocol
if unsafe { crate::ffi::PyObject_CheckBuffer(obj.as_ptr()) } == 1 {
if let Ok(buf) = crate::buffer::PyBuffer::get(obj) {
if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() {
buf.release(obj.py());
// SAFETY: The array should be fully filled by `copy_to_slice`
return Ok(unsafe { array.assume_init() });
}
buf.release(obj.py());
}
}
// fall back to sequence protocol
extract_sequence_into_slice(obj, &mut array)?;
// SAFETY: The array should be fully filled by `extract_sequence_into_slice`
Ok(unsafe { array.assume_init() })
}
}

// Helper to safely create arrays since the standard library doesn't
// provide one yet. Shouldn't be necessary in the future.
#[cfg(min_const_generics)]
struct ArrayGuard<'a, T, const N: usize> {
dst: *mut T,
initialized: &'a mut usize,
}

#[cfg(min_const_generics)]
impl<T, const N: usize> Drop for ArrayGuard<'_, T, N> {
fn drop(&mut self) {
debug_assert!(*self.initialized <= N);
let initialized_part = core::ptr::slice_from_raw_parts_mut(self.dst, *self.initialized);
unsafe {
core::ptr::drop_in_place(initialized_part);
}
}
}

#[cfg(min_const_generics)]
fn create_array_from_obj<'s, T, const N: usize>(obj: &'s PyAny) -> PyResult<[T; N]>
where
T: FromPyObject<'s>,
{
let seq = <crate::types::PySequence as PyTryFrom>::try_from(obj)?;
let expected_len = seq.len()? as usize;
let mut counter = 0;
try_create_array(&mut counter, |idx| {
seq.get_item(idx as isize)
.map_err(|_| crate::utils::invalid_sequence_length(expected_len, idx + 1))?
.extract::<T>()
})
}

#[cfg(any(all(min_const_generics, feature = "nightly"), not(min_const_generics)))]
fn extract_sequence_into_slice<'s, T>(obj: &'s PyAny, slice: &mut [T]) -> PyResult<()>
where
T: FromPyObject<'s>,
{
let seq = <crate::types::PySequence as PyTryFrom>::try_from(obj)?;
let expected_len = seq.len()? as usize;
if expected_len != slice.len() {
return Err(crate::utils::invalid_sequence_length(expected_len, slice.len()));
}
for (value, item) in slice.iter_mut().zip(seq.iter()?) {
*value = item?.extract::<T>()?;
}
Ok(())
}

#[cfg(min_const_generics)]
fn try_create_array<E, F, T, const N: usize>(counter: &mut usize, mut cb: F) -> Result<[T; N], E>
where
F: FnMut(usize) -> Result<T, E>,
{
let mut array: MaybeUninit<[T; N]> = MaybeUninit::uninit();
let guard: ArrayGuard<T, N> = ArrayGuard {
dst: array.as_mut_ptr() as _,
initialized: counter,
};
unsafe {
for (idx, value_ptr) in (&mut *array.as_mut_ptr()).iter_mut().enumerate() {
core::ptr::write(value_ptr, cb(idx)?);
*guard.initialized += 1;
}
core::mem::forget(guard);
Ok(array.assume_init())
}
}

#[cfg(test)]
mod test {
use crate::Python;
use std::panic;
use std::sync::{Arc, Mutex};
use std::thread::sleep;
use std::time;

#[cfg(min_const_generics)]
#[test]
fn try_create_array() {
#[allow(clippy::mutex_atomic)]
let counter = Arc::new(Mutex::new(0));
let counter_unwind = Arc::clone(&counter);
let _ = catch_unwind_silent(move || {
let mut locked = counter_unwind.lock().unwrap();
let _: Result<[i32; 4], _> = super::try_create_array(&mut *locked, |idx| {
if idx == 2 {
panic!("peek a boo");
}
Ok::<_, ()>(1)
});
});
sleep(time::Duration::from_secs(2));
assert_eq!(*counter.lock().unwrap_err().into_inner(), 2);
}

#[cfg(not(min_const_generics))]
#[test]
fn test_extract_bytearray_to_array() {
let gil = Python::acquire_gil();
let py = gil.python();
let v: [u8; 3] = py
.eval("bytearray(b'abc')", None, None)
.unwrap()
.extract()
.unwrap();
assert!(&v == b"abc");
}

#[cfg(min_const_generics)]
#[test]
fn test_extract_bytearray_to_array() {
let gil = Python::acquire_gil();
let py = gil.python();
let v: [u8; 33] = py
.eval(
"bytearray(b'abcabcabcabcabcabcabcabcabcabcabc')",
None,
None,
)
.unwrap()
.extract()
.unwrap();
assert!(&v == b"abcabcabcabcabcabcabcabcabcabcabc");
}

// https://stackoverflow.com/a/59211505
fn catch_unwind_silent<F, R>(f: F) -> std::thread::Result<R>
where
F: FnOnce() -> R + panic::UnwindSafe,
{
let prev_hook = panic::take_hook();
panic::set_hook(Box::new(|_| {}));
let result = panic::catch_unwind(f);
panic::set_hook(prev_hook);
result
}
}
1 change: 1 addition & 0 deletions src/conversions/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! This module contains conversions between non-String Rust object and their string representation
//! in Python
mod array;
mod osstr;
mod path;
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,7 @@ pub mod pyclass_slots;
mod python;
pub mod type_object;
pub mod types;
mod utils;

#[cfg(feature = "serde")]
pub mod serde;
Expand Down
20 changes: 0 additions & 20 deletions src/types/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,26 +178,6 @@ where
}
}

macro_rules! array_impls {
($($N:expr),+) => {
$(
impl<T> IntoPy<PyObject> for [T; $N]
where
T: ToPyObject
{
fn into_py(self, py: Python) -> PyObject {
self.as_ref().to_object(py)
}
}
)+
}
}

array_impls!(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
26, 27, 28, 29, 30, 31, 32
);

impl<T> ToPyObject for Vec<T>
where
T: ToPyObject,
Expand Down
Loading

0 comments on commit 926f70a

Please sign in to comment.