Skip to content

Commit

Permalink
Accept abitrary containers and iterators for shape/strides
Browse files Browse the repository at this point in the history
This adds support for constructing `buffer_info` and `array`s using
arbitrary containers or iterators instead of requiring a vector.

This is primarily needed by PR pybind#782 (which makes strides signed to
properly support negative strides), but also needs to preserve backwards
compatibility with 2.1 and earlier which accepts the strides parameter
as a vector of size_t's.

Rather than adding nearly duplicate constructors for each stride-taking
constructor, it seems nicer to simply allow any type of container (or
iterator pairs).  This adds iterator pair constructors, and also adds
a `detail::any_container` class that handles implicit conversion of
arbitrary containers into a vector of the desired type.
  • Loading branch information
jagerman committed Apr 8, 2017
1 parent 3d95ad8 commit 544ed2f
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 60 deletions.
34 changes: 19 additions & 15 deletions include/pybind11/buffer_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
BSD-style license that can be found in the LICENSE file.
*/

#pragma once
#pragma once

#include "common.h"

Expand All @@ -26,25 +26,29 @@ struct buffer_info {
buffer_info() { }

buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t ndim,
const std::vector<size_t> &shape, const std::vector<size_t> &strides)
: ptr(ptr), itemsize(itemsize), size(1), format(format),
ndim(ndim), shape(shape), strides(strides) {
detail::any_container<size_t> shape_in, detail::any_container<size_t> strides_in)
: ptr(ptr), itemsize(itemsize), size(1), format(format), ndim(ndim),
shape(std::move(shape_in)), strides(std::move(strides_in)) {
if (ndim != shape.size() || ndim != strides.size())
pybind11_fail("buffer_info: ndim doesn't match shape and/or strides length");
for (size_t i = 0; i < ndim; ++i)
size *= shape[i];
}


template <typename ShapeIt, typename StridesIt,
typename = detail::enable_if_t<detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t ndim,
ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last)
: buffer_info(ptr, itemsize, format, ndim, {shape_first, shape_last}, {strides_first, strides_last}) { }

buffer_info(void *ptr, size_t itemsize, const std::string &format, size_t size)
: buffer_info(ptr, itemsize, format, 1, std::vector<size_t> { size },
std::vector<size_t> { itemsize }) { }

explicit buffer_info(Py_buffer *view, bool ownview = true)
: ptr(view->buf), itemsize((size_t) view->itemsize), size(1), format(view->format),
ndim((size_t) view->ndim), shape((size_t) view->ndim), strides((size_t) view->ndim), view(view), ownview(ownview) {
for (size_t i = 0; i < (size_t) view->ndim; ++i) {
shape[i] = (size_t) view->shape[i];
strides[i] = (size_t) view->strides[i];
size *= shape[i];
}
: buffer_info(ptr, itemsize, format, 1, { size }, { itemsize }) { }

explicit buffer_info(Py_buffer *view, bool ownview_in = true)
: buffer_info(view->buf, (size_t) view->itemsize, view->format, (size_t) view->ndim,
view->shape, view->shape + view->ndim, view->strides, view->strides + view->ndim) {
ownview = ownview_in;
}

buffer_info(const buffer_info &) = delete;
Expand Down
48 changes: 48 additions & 0 deletions include/pybind11/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,12 @@ struct is_instantiation<Class, Class<Us...>> : std::true_type { };
/// Check if T is std::shared_ptr<U> where U can be anything
template <typename T> using is_shared_ptr = is_instantiation<std::shared_ptr, T>;

/// Check if T looks like an input iterator
template <typename T, typename = void> struct is_input_iterator : std::false_type {};
template <typename T>
struct is_input_iterator<T, void_t<decltype(*std::declval<T &>()), decltype(++std::declval<T &>())>>
: std::true_type {};

/// Ignore that a variable is unused in compiler warnings
inline void ignore_unused(const int *) { }

Expand Down Expand Up @@ -651,4 +657,46 @@ static constexpr auto const_ = std::true_type{};

#endif // overload_cast

NAMESPACE_BEGIN(detail)

// Adaptor for converting arbitrary container arguments into a vector; implicitly convertible from
// any standard container (or C-style array) supporting std::begin/std::end.
template <typename T>
class any_container {
std::vector<T> v;
public:
any_container() = default;

// Can construct from a pair of iterators
template <typename It, typename = enable_if_t<is_input_iterator<It>::value>>
any_container(It first, It last) : v(first, last) { }

// Implicit conversion constructor from any arbitrary container type with values convertible to T
template <typename Container, typename = enable_if_t<std::is_convertible<decltype(*std::begin(std::declval<const Container &>())), T>::value>>
any_container(const Container &c) : any_container(std::begin(c), std::end(c)) { }

// initializer_list's aren't deducible, so don't get matched by the above template; we need this
// to explicitly allow implicit conversion from one:
template <typename TIn, typename = enable_if_t<std::is_convertible<TIn, T>::value>>
any_container(const std::initializer_list<TIn> &c) : any_container(c.begin(), c.end()) { }

// Avoid copying if given an rvalue vector of the correct type.
any_container(std::vector<T> &&v) : v(std::move(v)) { }

// Moves the vector out of an rvalue any_container
operator std::vector<T> &&() && { return std::move(v); }

// Dereferencing obtains a reference to the underlying vector
std::vector<T> &operator*() { return v; }
const std::vector<T> &operator*() const { return v; }

// -> lets you call methods on the underlying vector
std::vector<T> *operator->() { return &v; }
const std::vector<T> *operator->() const { return &v; }
};

NAMESPACE_END(detail)



NAMESPACE_END(pybind11)
19 changes: 7 additions & 12 deletions include/pybind11/eigen.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,18 +200,13 @@ template <typename Type_> struct EigenProps {
// otherwise it'll make a copy. writeable lets you turn off the writeable flag for the array.
template <typename props> handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) {
constexpr size_t elem_size = sizeof(typename props::Scalar);
std::vector<size_t> shape, strides;
if (props::vector) {
shape.push_back(src.size());
strides.push_back(elem_size * src.innerStride());
}
else {
shape.push_back(src.rows());
shape.push_back(src.cols());
strides.push_back(elem_size * src.rowStride());
strides.push_back(elem_size * src.colStride());
}
array a(std::move(shape), std::move(strides), src.data(), base);
array a;
if (props::vector)
a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base);
else
a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() },
src.data(), base);

if (!writeable)
array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_;

Expand Down
79 changes: 49 additions & 30 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -455,12 +455,18 @@ class array : public buffer {

array() : array(0, static_cast<const double *>(nullptr)) {}

array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
const std::vector<size_t> &strides, const void *ptr = nullptr,
handle base = handle()) {
auto& api = detail::npy_api::get();
auto ndim = shape.size();
if (shape.size() != strides.size())
using ShapeContainer = detail::any_container<Py_intptr_t>;
using StridesContainer = detail::any_container<Py_intptr_t>;

// Constructs an array taking shape/strides from arbitrary container types
array(const pybind11::dtype &dt, ShapeContainer shape, StridesContainer strides,
const void *ptr = nullptr, handle base = handle()) {

if (strides->empty())
strides = default_strides(*shape, dt.itemsize());

auto ndim = shape->size();
if (ndim != strides->size())
pybind11_fail("NumPy: shape ndim doesn't match strides ndim");
auto descr = dt;

Expand All @@ -474,10 +480,9 @@ class array : public buffer {
flags = detail::npy_api::NPY_ARRAY_WRITEABLE_;
}

auto &api = detail::npy_api::get();
auto tmp = reinterpret_steal<object>(api.PyArray_NewFromDescr_(
api.PyArray_Type_, descr.release().ptr(), (int) ndim,
reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(shape.data())),
reinterpret_cast<Py_intptr_t *>(const_cast<size_t*>(strides.data())),
api.PyArray_Type_, descr.release().ptr(), (int) ndim, shape->data(), strides->data(),
const_cast<void *>(ptr), flags, nullptr));
if (!tmp)
pybind11_fail("NumPy: unable to create array!");
Expand All @@ -491,27 +496,37 @@ class array : public buffer {
m_ptr = tmp.release().ptr();
}

array(const pybind11::dtype &dt, const std::vector<size_t> &shape,
const void *ptr = nullptr, handle base = handle())
: array(dt, shape, default_strides(shape, dt.itemsize()), ptr, base) { }
template <typename ShapeIt, typename StridesIt,
typename = detail::enable_if_t<detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
array(const pybind11::dtype &dt, ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
const void *ptr = nullptr, handle base = handle())
: array(dt, {shape_first, shape_last}, {strides_first, strides_last}, ptr, base) { }

array(const pybind11::dtype &dt, ShapeContainer shape, const void *ptr = nullptr, handle base = handle())
: array(dt, std::move(shape), {}, ptr, base) { }

array(const pybind11::dtype &dt, size_t count, const void *ptr = nullptr,
handle base = handle())
: array(dt, std::vector<size_t>{ count }, ptr, base) { }
: array(dt, ShapeContainer{{ count }}, ptr, base) { }

template<typename T> array(const std::vector<size_t>& shape,
const std::vector<size_t>& strides,
const T* ptr, handle base = handle())
: array(pybind11::dtype::of<T>(), shape, strides, (const void *) ptr, base) { }
template <typename T, typename ShapeIt, typename StridesIt,
typename = detail::enable_if_t<detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
array(ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
const T *ptr = nullptr, handle base = handle())
: array(pybind11::dtype::of<T>(), ShapeContainer(std::move(shape_first), std::move(shape_last)),
StrideContainer(std::move(strides_first), std::move(strides_last)), ptr, base) { }

template <typename T>
array(const std::vector<size_t> &shape, const T *ptr,
handle base = handle())
: array(shape, default_strides(shape, sizeof(T)), ptr, base) { }
array(ShapeContainer shape, StridesContainer strides, const T *ptr, handle base = handle())
: array(pybind11::dtype::of<T>(), std::move(shape), std::move(strides), ptr, base) { }

template <typename T>
array(ShapeContainer shape, const T *ptr, handle base = handle())
: array(std::move(shape), {}, ptr, base) { }

template <typename T>
array(size_t count, const T *ptr, handle base = handle())
: array(std::vector<size_t>{ count }, ptr, base) { }
: array({{ count }}, ptr, base) { }

explicit array(const buffer_info &info)
: array(pybind11::dtype(info), info.shape, info.strides, info.ptr) { }
Expand Down Expand Up @@ -673,9 +688,9 @@ class array : public buffer {
throw std::domain_error("array is not writeable");
}

static std::vector<size_t> default_strides(const std::vector<size_t>& shape, size_t itemsize) {
static std::vector<Py_intptr_t> default_strides(const std::vector<Py_intptr_t>& shape, size_t itemsize) {
auto ndim = shape.size();
std::vector<size_t> strides(ndim);
std::vector<Py_intptr_t> strides(ndim);
if (ndim) {
std::fill(strides.begin(), strides.end(), itemsize);
for (size_t i = 0; i < ndim - 1; i++)
Expand Down Expand Up @@ -729,14 +744,18 @@ template <typename T, int ExtraFlags = array::forcecast> class array_t : public

explicit array_t(const buffer_info& info) : array(info) { }

array_t(const std::vector<size_t> &shape,
const std::vector<size_t> &strides, const T *ptr = nullptr,
handle base = handle())
: array(shape, strides, ptr, base) { }
array_t(ShapeContainer shape, StridesContainer strides, const T *ptr = nullptr, handle base = handle())
: array(std::move(shape), std::move(strides), ptr, base) { }

template <typename ShapeIt, typename StridesIt,
typename = detail::enable_if_t<detail::is_input_iterator<ShapeIt>::value && detail::is_input_iterator<StridesIt>::value>>
array_t(ShapeIt shape_first, ShapeIt shape_last, StridesIt strides_first, StridesIt strides_last,
const T *ptr = nullptr, handle base = handle())
: array(ShapeContainer(std::move(shape_first), std::move(shape_last)),
StridesContainer(std::move(strides_first), std::move(strides_last)), ptr, base) { }

explicit array_t(const std::vector<size_t> &shape, const T *ptr = nullptr,
handle base = handle())
: array(shape, ptr, base) { }
explicit array_t(ShapeContainer shape, const T *ptr = nullptr, handle base = handle())
: array(std::move(shape), ptr, base) { }

explicit array_t(size_t count, const T *ptr = nullptr, handle base = handle())
: array(count, ptr, base) { }
Expand Down
5 changes: 2 additions & 3 deletions tests/test_numpy_array.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
#include <pybind11/stl.h>

#include <cstdint>
#include <vector>

using arr = py::array;
using arr_t = py::array_t<uint16_t, 0>;
Expand Down Expand Up @@ -119,8 +118,8 @@ test_initializer numpy_array([](py::module &m) {
sm.def("wrap", [](py::array a) {
return py::array(
a.dtype(),
std::vector<size_t>(a.shape(), a.shape() + a.ndim()),
std::vector<size_t>(a.strides(), a.strides() + a.ndim()),
a.shape(), a.shape() + a.ndim(),
a.strides(), a.strides() + a.ndim(),
a.data(),
a
);
Expand Down

0 comments on commit 544ed2f

Please sign in to comment.