Skip to content

Commit

Permalink
array: improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
davidhewitt committed Apr 27, 2021
1 parent 7ead166 commit 4fac205
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 79 deletions.
163 changes: 86 additions & 77 deletions src/conversions/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,70 +3,6 @@ use crate::{
ToPyObject,
};

#[cfg(not(min_const_generics))]
macro_rules! array_impls {
($($N:expr),+) => {
$(
impl<T> IntoPy<PyObject> for [T; $N]
where
T: ToPyObject
{
fn into_py(self, py: Python) -> PyObject {
self.as_ref().to_object(py)
}
}

impl<'a, T> FromPyObject<'a> for [T; $N]
where
T: Copy + Default + FromPyObject<'a>,
{
#[cfg(not(feature = "nightly"))]
fn extract(obj: &'a PyAny) -> PyResult<Self> {
let mut array = [T::default(); $N];
_extract_sequence_into_slice(obj, &mut array)?;
Ok(array)
}

#[cfg(feature = "nightly")]
default fn extract(obj: &'a PyAny) -> PyResult<Self> {
let mut array = [T::default(); $N];
_extract_sequence_into_slice(obj, &mut array)?;
Ok(array)
}
}

#[cfg(feature = "nightly")]
impl<'source, T> FromPyObject<'source> for [T; $N]
where
for<'a> T: Default + FromPyObject<'a> + crate::buffer::Element,
{
fn extract(obj: &'source PyAny) -> PyResult<Self> {
let mut array = [T::default(); $N];
// first try buffer protocol
if unsafe { crate::ffi::PyObject_CheckBuffer(obj.as_ptr()) } == 1 {
if let Ok(buf) = crate::buffer::PyBuffer::get(obj) {
if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() {
buf.release(obj.py());
return Ok(array);
}
buf.release(obj.py());
}
}
// fall back to sequence protocol
_extract_sequence_into_slice(obj, &mut array)?;
Ok(array)
}
}
)+
}
}

#[cfg(not(min_const_generics))]
array_impls!(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
26, 27, 28, 29, 30, 31, 32
);

#[cfg(min_const_generics)]
impl<T, const N: usize> IntoPy<PyObject> for [T; N]
where
Expand Down Expand Up @@ -100,20 +36,18 @@ where
{
fn extract(obj: &'source PyAny) -> PyResult<Self> {
use crate::{AsPyPointer, PyNativeType};
let mut array = [T::default(); N];
// first try buffer protocol
if unsafe { crate::ffi::PyObject_CheckBuffer(obj.as_ptr()) } == 1 {
if let Ok(buf) = crate::buffer::PyBuffer::get(obj) {
let mut array = [T::default(); N];
if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() {
buf.release(obj.py());
return Ok(array);
}
buf.release(obj.py());
}
}
// fall back to sequence protocol
_extract_sequence_into_slice(obj, &mut array)?;
Ok(array)
create_array_from_obj(obj)
}
}

Expand All @@ -123,12 +57,11 @@ where
T: FromPyObject<'s>,
{
let seq = <crate::types::PySequence as PyTryFrom>::try_from(obj)?;
let expected_len = seq.len()? as usize;
array_try_from_fn(|idx| {
seq.get_item(idx as isize)
.map_err(|_| invalid_sequence_length(expected_len, idx + 1))?
.extract::<T>()
})
let seq_len = seq.len()? as usize;
if seq_len != N {
return Err(invalid_sequence_length(N, seq_len));
}
array_try_from_fn(|idx| seq.get_item(idx as isize).and_then(PyAny::extract))
}

// TODO use std::array::try_from_fn, if that stabilises:
Expand Down Expand Up @@ -174,7 +107,72 @@ where
}
}

fn _extract_sequence_into_slice<'s, T>(obj: &'s PyAny, slice: &mut [T]) -> PyResult<()>
#[cfg(not(min_const_generics))]
macro_rules! array_impls {
($($N:expr),+) => {
$(
impl<T> IntoPy<PyObject> for [T; $N]
where
T: ToPyObject
{
fn into_py(self, py: Python) -> PyObject {
self.as_ref().to_object(py)
}
}

impl<'a, T> FromPyObject<'a> for [T; $N]
where
T: Copy + Default + FromPyObject<'a>,
{
#[cfg(not(feature = "nightly"))]
fn extract(obj: &'a PyAny) -> PyResult<Self> {
let mut array = [T::default(); $N];
extract_sequence_into_slice(obj, &mut array)?;
Ok(array)
}

#[cfg(feature = "nightly")]
default fn extract(obj: &'a PyAny) -> PyResult<Self> {
let mut array = [T::default(); $N];
extract_sequence_into_slice(obj, &mut array)?;
Ok(array)
}
}

#[cfg(feature = "nightly")]
impl<'source, T> FromPyObject<'source> for [T; $N]
where
for<'a> T: Default + FromPyObject<'a> + crate::buffer::Element,
{
fn extract(obj: &'source PyAny) -> PyResult<Self> {
let mut array = [T::default(); $N];
// first try buffer protocol
if unsafe { crate::ffi::PyObject_CheckBuffer(obj.as_ptr()) } == 1 {
if let Ok(buf) = crate::buffer::PyBuffer::get(obj) {
if buf.dimensions() == 1 && buf.copy_to_slice(obj.py(), &mut array).is_ok() {
buf.release(obj.py());
return Ok(array);
}
buf.release(obj.py());
}
}
// fall back to sequence protocol
extract_sequence_into_slice(obj, &mut array)?;
Ok(array)
}
}
)+
}
}

#[cfg(not(min_const_generics))]
array_impls!(
0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,
26, 27, 28, 29, 30, 31, 32
);

#[cfg(not(min_const_generics))]
fn extract_sequence_into_slice<'s, T>(obj: &'s PyAny, slice: &mut [T]) -> PyResult<()>
where
T: FromPyObject<'s>,
{
Expand All @@ -189,7 +187,7 @@ where
Ok(())
}

pub fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr {
fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr {
exceptions::PyValueError::new_err(format!(
"expected a sequence of length {} (got {})",
expected, actual
Expand All @@ -198,7 +196,7 @@ pub fn invalid_sequence_length(expected: usize, actual: usize) -> PyErr {

#[cfg(test)]
mod test {
use crate::Python;
use crate::{Python, PyResult};
#[cfg(min_const_generics)]
use std::{
panic,
Expand Down Expand Up @@ -238,6 +236,17 @@ mod test {
assert!(&v == b"abc");
}

#[test]
fn test_extract_invalid_sequence_length() {
let gil = Python::acquire_gil();
let py = gil.python();
let v: PyResult<[u8; 3]> = py
.eval("bytearray(b'abcdefg')", None, None)
.unwrap()
.extract();
assert_eq!(v.unwrap_err().to_string(), "ValueError: expected a sequence of length 3 (got 7)");
}

#[cfg(min_const_generics)]
#[test]
fn test_extract_bytearray_to_array() {
Expand Down
3 changes: 1 addition & 2 deletions src/conversions/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
//! This module contains conversions between non-String Rust object and their string representation
//! in Python
//! This module contains conversions between various Rust object and their representation in Python.
mod array;
mod osstr;
Expand Down

0 comments on commit 4fac205

Please sign in to comment.