From 92a88a8502d03052b4a0f161d6ee0f2ca7525f39 Mon Sep 17 00:00:00 2001 From: uentity Date: Thu, 13 Apr 2017 21:41:55 +0500 Subject: [PATCH] array: implement array resize --- include/pybind11/numpy.h | 23 +++++++++++++++++++++++ tests/test_numpy_array.cpp | 21 +++++++++++++++++++++ tests/test_numpy_array.py | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+) diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index b32c3887472..f28160b5c6e 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -129,6 +129,11 @@ struct npy_api { NPY_STRING_, NPY_UNICODE_, NPY_VOID_ }; + typedef struct { + Py_intptr_t *ptr; + int len; + } PyArray_Dims; + static npy_api& get() { static npy_api api = lookup(); return api; @@ -158,6 +163,7 @@ struct npy_api { Py_ssize_t *, PyObject **, PyObject *); PyObject *(*PyArray_Squeeze_)(PyObject *); int (*PyArray_SetBaseObject_)(PyObject *, PyObject *); + PyObject* (*PyArray_Resize_)(PyObject*, PyArray_Dims*, int, int); private: enum functions { API_PyArray_Type = 2, @@ -166,6 +172,7 @@ struct npy_api { API_PyArray_DescrFromType = 45, API_PyArray_DescrFromScalar = 57, API_PyArray_FromAny = 69, + API_PyArray_Resize = 80, API_PyArray_NewCopy = 85, API_PyArray_NewFromDescr = 94, API_PyArray_DescrNewFromType = 9, @@ -192,6 +199,7 @@ struct npy_api { DECL_NPY_API(PyArray_DescrFromType); DECL_NPY_API(PyArray_DescrFromScalar); DECL_NPY_API(PyArray_FromAny); + DECL_NPY_API(PyArray_Resize); DECL_NPY_API(PyArray_NewCopy); DECL_NPY_API(PyArray_NewFromDescr); DECL_NPY_API(PyArray_DescrNewFromType); @@ -647,6 +655,21 @@ class array : public buffer { return reinterpret_steal(api.PyArray_Squeeze_(m_ptr)); } + /// Resize array to given shape + /// If refcheck is true and more that one reference exist to this array + /// then resize will succeed only if it makes a reshape, i.e. original size doesn't change + void resize(ShapeContainer new_shape, bool refcheck = true) { + detail::npy_api::PyArray_Dims d = { + new_shape->data(), int(new_shape->size()) + }; + // try to resize, set ordering param to -1 cause it's not used anyway + object new_array = reinterpret_steal( + detail::npy_api::get().PyArray_Resize_(m_ptr, &d, int(refcheck), -1) + ); + if (!new_array) throw error_already_set(); + if (isinstance(new_array)) { *this = std::move(new_array); } + } + /// Ensure that the argument is a NumPy array /// In case of an error, nullptr is returned and the Python error is cleared. static array ensure(handle h, int ExtraFlags = 0) { diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index 7252547503e..ed8df3c53b8 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -267,4 +267,25 @@ test_initializer numpy_array([](py::module &m) { // Issue #785: Uninformative "Unknown internal error" exception when constructing array from empty object: sm.def("array_fail_test", []() { return py::array(py::object()); }); sm.def("array_t_fail_test", []() { return py::array_t(py::object()); }); + + // reshape array to 2D without changing size + sm.def("array_reshape2", [](py::array_t a) { + const size_t dim_sz = (size_t)std::sqrt(a.size()); + if (dim_sz * dim_sz != a.size()) + throw std::domain_error("array_reshape2: input array total size is not a squared integer"); + a.resize({dim_sz, dim_sz}); + }); + + // resize to 3D array with each dimension = N + sm.def("array_resize3", [](py::array_t a, size_t N, bool refcheck) { + a.resize({N, N, N}, refcheck); + }); + + // return 2D array with Nrows = Ncols = N + sm.def("create_and_resize", [](size_t N) { + py::array_t a; + a.resize({N, N}); + std::fill(a.mutable_data(), a.mutable_data() + a.size(), 42.); + return a; + }); }); diff --git a/tests/test_numpy_array.py b/tests/test_numpy_array.py index 6281fa4789d..10af7486a16 100644 --- a/tests/test_numpy_array.py +++ b/tests/test_numpy_array.py @@ -389,3 +389,38 @@ def test_array_failure(): with pytest.raises(ValueError) as excinfo: array_t_fail_test() assert str(excinfo.value) == 'cannot create a pybind11::array_t from a nullptr' + + +def test_array_resize(msg): + from pybind11_tests.array import (array_reshape2, array_resize3) + + a = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9], dtype='float64') + array_reshape2(a) + assert(a.size == 9) + assert(np.all(a == [[1, 2, 3], [4, 5, 6], [7, 8, 9]])) + + # total size change should succced with refcheck off + array_resize3(a, 4, False) + assert(a.size == 64) + # ... and fail with refcheck on + try: + array_resize3(a, 3, True) + except ValueError as e: + assert(str(e).startswith("cannot resize an array")) + # transposed array doesn't own data + b = a.transpose() + try: + array_resize3(b, 3, False) + except ValueError as e: + assert(str(e).startswith("cannot resize this array: it does not own its data")) + # ... but reshape should be fine + array_reshape2(b) + assert(b.shape == (8, 8)) + + +@pytest.unsupported_on_pypy +def test_array_create_and_resize(msg): + from pybind11_tests.array import create_and_resize + a = create_and_resize(2) + assert(a.size == 4) + assert(np.all(a == 42.))