diff --git a/include/pybind11/buffer_info.h b/include/pybind11/buffer_info.h index 8bdfc07374..ead972fb70 100644 --- a/include/pybind11/buffer_info.h +++ b/include/pybind11/buffer_info.h @@ -7,7 +7,7 @@ BSD-style license that can be found in the LICENSE file. */ -#pragma once +#pragma once #include "common.h" @@ -26,25 +26,22 @@ struct buffer_info { buffer_info() { } buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t ndim, - const std::vector &shape, const std::vector &strides) - : ptr(ptr), itemsize(itemsize), size(1), format(format), - ndim(ndim), shape(shape), strides(strides) { + detail::any_container shape_in, detail::any_container strides_in) + : ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim), + shape(std::move(shape_in)), strides(std::move(strides_in)) { + if (ndim != shape.size() || ndim != strides.size()) + pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length"); for (size_t i = 0; i < ndim; ++i) size *= shape[i]; } buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t size) - : buffer_info(ptr, itemsize, format, 1, std::vector { size }, - std::vector { itemsize }) { } - - explicit buffer_info(Py_buffer *view, bool ownview = true) - : ptr(view->buf), itemsize((size_t) view->itemsize), size(1), format(view->format), - ndim((size_t) view->ndim), shape((size_t) view->ndim), strides((size_t) view->ndim), view(view), ownview(ownview) { - for (size_t i = 0; i < (size_t) view->ndim; ++i) { - shape[i] = (size_t) view->shape[i]; - strides[i] = (size_t) view->strides[i]; - size *= shape[i]; - } + : buffer_info(ptr, itemsize, format, 1, { size }, { itemsize }) { } + + explicit buffer_info(Py_buffer *view, bool ownview_in = true) + : buffer_info(view->buf, (size_t) view->itemsize, view->format, (size_t) view->ndim, + {view->shape, view->shape + view->ndim}, {view->strides, view->strides + view->ndim}) { + ownview = ownview_in; } buffer_info(const buffer_info &) = delete; diff --git a/include/pybind11/common.h b/include/pybind11/common.h index 941842b21f..bf032bbabd 100644 --- a/include/pybind11/common.h +++ b/include/pybind11/common.h @@ -490,6 +490,12 @@ struct is_instantiation> : std::true_type { }; /// Check if T is std::shared_ptr where U can be anything template using is_shared_ptr = is_instantiation; +/// Check if T looks like an input iterator +template struct is_input_iterator : std::false_type {}; +template +struct is_input_iterator()), decltype(++std::declval())>> + : std::true_type {}; + /// Ignore that a variable is unused in compiler warnings inline void ignore_unused(const int *) { } @@ -651,4 +657,46 @@ static constexpr auto const_ = std::true_type{}; #endif // overload_cast +NAMESPACE_BEGIN(detail) + +// Adaptor for converting arbitrary container arguments into a vector; implicitly convertible from +// any standard container (or C-style array) supporting std::begin/std::end. +template +class any_container { + std::vector v; +public: + any_container() = default; + + // Can construct from a pair of iterators + template ::value>> + any_container(It first, It last) : v(first, last) { } + + // Implicit conversion constructor from any arbitrary container type with values convertible to T + template ())), T>::value>> + any_container(const Container &c) : any_container(std::begin(c), std::end(c)) { } + + // initializer_list's aren't deducible, so don't get matched by the above template; we need this + // to explicitly allow implicit conversion from one: + template ::value>> + any_container(const std::initializer_list &c) : any_container(c.begin(), c.end()) { } + + // Avoid copying if given an rvalue vector of the correct type. + any_container(std::vector &&v) : v(std::move(v)) { } + + // Moves the vector out of an rvalue any_container + operator std::vector &&() && { return std::move(v); } + + // Dereferencing obtains a reference to the underlying vector + std::vector &operator*() { return v; } + const std::vector &operator*() const { return v; } + + // -> lets you call methods on the underlying vector + std::vector *operator->() { return &v; } + const std::vector *operator->() const { return &v; } +}; + +NAMESPACE_END(detail) + + + NAMESPACE_END(pybind11) diff --git a/include/pybind11/eigen.h b/include/pybind11/eigen.h index 69194a2c42..2b465622c8 100644 --- a/include/pybind11/eigen.h +++ b/include/pybind11/eigen.h @@ -201,18 +201,13 @@ template struct EigenProps { // otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array. template handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) { constexpr size_t elem_size = sizeof(typename props::Scalar); - std::vector shape, strides; - if (props::vector) { - shape.push_back(src.size()); - strides.push_back(elem_size * src.innerStride()); - } - else { - shape.push_back(src.rows()); - shape.push_back(src.cols()); - strides.push_back(elem_size * src.rowStride()); - strides.push_back(elem_size * src.colStride()); - } - array a(std::move(shape), std::move(strides), src.data(), base); + array a; + if (props::vector) + a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base); + else + a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() }, + src.data(), base); + if (!writeable) array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_; diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index f8ce14ebf1..27ef096c03 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -455,12 +455,18 @@ class array : public buffer { array() : array(0, static_cast(nullptr)) {} - array(const pybind11::dtype &dt, const std::vector &shape, - const std::vector &strides, const void *ptr = nullptr, - handle base = handle()) { - auto& api = detail::npy_api::get(); - auto ndim = shape.size(); - if (shape.size() != strides.size()) + using ShapeContainer = detail::any_container; + using StridesContainer = detail::any_container; + + // Constructs an array taking shape/strides from arbitrary container types + array(const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides, + const void *ptr = nullptr, handle base = handle()) { + + if (strides->empty()) + strides = default_strides(*shape, dt.itemsize()); + + auto ndim = shape->size(); + if (ndim != strides->size()) pybind11_fail("NumPy: shape ndim doesn't match strides ndim"); auto descr = dt; @@ -474,10 +480,9 @@ class array : public buffer { flags = detail::npy_api::NPY_ARRAY_WRITEABLE_; } + auto &api = detail::npy_api::get(); auto tmp = reinterpret_steal(api.PyArray_NewFromDescr_( - api.PyArray_Type_, descr.release().ptr(), (int) ndim, - reinterpret_cast(const_cast(shape.data())), - reinterpret_cast(const_cast(strides.data())), + api.PyArray_Type_, descr.release().ptr(), (int) ndim, shape->data(), strides->data(), const_cast(ptr), flags, nullptr)); if (!tmp) pybind11_fail("NumPy: unable to create array!"); @@ -491,27 +496,24 @@ class array : public buffer { m_ptr = tmp.release().ptr(); } - array(const pybind11::dtype &dt, const std::vector &shape, - const void *ptr = nullptr, handle base = handle()) - : array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { } + array(const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr, handle base = handle()) + : array(dt, std::move(shape), {}, ptr, base) { } array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr, handle base = handle()) - : array(dt, std::vector{ count }, ptr, base) { } + : array(dt, ShapeContainer{{ count }}, ptr, base) { } - template array(const std::vector& shape, - const std::vector& strides, - const T* ptr, handle base = handle()) - : array(pybind11::dtype::of(), shape, strides, (const void *) ptr, base) { } + template + array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle()) + : array(pybind11::dtype::of(), std::move(shape), std::move(strides), ptr, base) { } template - array(const std::vector &shape, const T *ptr, - handle base = handle()) - : array(shape, default_strides(shape, sizeof(T)), ptr, base) { } + array(ShapeContainer shape, const T *ptr, handle base = handle()) + : array(std::move(shape), {}, ptr, base) { } template array(size_t count, const T *ptr, handle base = handle()) - : array(std::vector{ count }, ptr, base) { } + : array({{ count }}, ptr, base) { } explicit array(const buffer_info &info) : array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { } @@ -673,9 +675,9 @@ class array : public buffer { throw std::domain_error("array is not writeable"); } - static std::vector default_strides(const std::vector& shape, size_t itemsize) { + static std::vector default_strides(const std::vector& shape, size_t itemsize) { auto ndim = shape.size(); - std::vector strides(ndim); + std::vector strides(ndim); if (ndim) { std::fill(strides.begin(), strides.end(), itemsize); for (size_t i = 0; i < ndim - 1; i++) @@ -729,14 +731,11 @@ template class array_t : public explicit array_t(const buffer_info& info) : array(info) { } - array_t(const std::vector &shape, - const std::vector &strides, const T *ptr = nullptr, - handle base = handle()) - : array(shape, strides, ptr, base) { } + array_t(ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr, handle base = handle()) + : array(std::move(shape), std::move(strides), ptr, base) { } - explicit array_t(const std::vector &shape, const T *ptr = nullptr, - handle base = handle()) - : array(shape, ptr, base) { } + explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle()) + : array(std::move(shape), ptr, base) { } explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle()) : array(count, ptr, base) { } diff --git a/tests/test_numpy_array.cpp b/tests/test_numpy_array.cpp index cd6487249b..08ade64a0b 100644 --- a/tests/test_numpy_array.cpp +++ b/tests/test_numpy_array.cpp @@ -13,7 +13,6 @@ #include #include -#include using arr = py::array; using arr_t = py::array_t; @@ -119,8 +118,8 @@ test_initializer numpy_array([](py::module &m) { sm.def("wrap", [](py::array a) { return py::array( a.dtype(), - std::vector(a.shape(), a.shape() + a.ndim()), - std::vector(a.strides(), a.strides() + a.ndim()), + {a.shape(), a.shape() + a.ndim()}, + {a.strides(), a.strides() + a.ndim()}, a.data(), a );