From ebbcce876e48a127a3ed7358e1656e2068339dfa Mon Sep 17 00:00:00 2001 From: Mmanu Chaturvedi Date: Thu, 26 Oct 2017 14:26:56 -0400 Subject: [PATCH] Add ability to create object matrices --- include/pybind11/eigen.h | 164 ++++++++++++++++++++++++++++++++++----- include/pybind11/numpy.h | 15 ++++ 2 files changed, 160 insertions(+), 19 deletions(-) diff --git a/include/pybind11/eigen.h b/include/pybind11/eigen.h index 2a234cef7f5..ac923889437 100644 --- a/include/pybind11/eigen.h +++ b/include/pybind11/eigen.h @@ -10,6 +10,7 @@ #pragma once #include "numpy.h" +#include "numpy/ndarraytypes.h" #if defined(__INTEL_COMPILER) # pragma warning(disable: 1682) // implicit conversion of a 64-bit integral type to a smaller integral type (potential portability problem) @@ -139,14 +140,19 @@ template struct EigenProps { const auto dims = a.ndim(); if (dims < 1 || dims > 2) return false; - + bool is_pyobject = false; + if (npy_format_descriptor::value == npy_api::NPY_OBJECT_) + is_pyobject = true; + ssize_t scalar_size = (is_pyobject ? static_cast(sizeof(PyObject*)) : + static_cast(sizeof(Scalar))); if (dims == 2) { // Matrix type: require exact match (or dynamic) EigenIndex np_rows = a.shape(0), np_cols = a.shape(1), - np_rstride = a.strides(0) / static_cast(sizeof(Scalar)), - np_cstride = a.strides(1) / static_cast(sizeof(Scalar)); + np_rstride = a.strides(0) / scalar_size, + np_cstride = a.strides(1) / scalar_size; + if ((fixed_rows && np_rows != rows) || (fixed_cols && np_cols != cols)) return false; @@ -156,7 +162,7 @@ template struct EigenProps { // Otherwise we're storing an n-vector. Only one of the strides will be used, but whichever // is used, we want the (single) numpy stride value. const EigenIndex n = a.shape(0), - stride = a.strides(0) / static_cast(sizeof(Scalar)); + stride = a.strides(0) / scalar_size; if (vector) { // Eigen type is a compile-time vector if (fixed && size != n) @@ -207,11 +213,52 @@ template struct EigenProps { template handle eigen_array_cast(typename props::Type const &src, handle base = handle(), bool writeable = true) { constexpr ssize_t elem_size = sizeof(typename props::Scalar); array a; - if (props::vector) - a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base); - else - a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() }, - src.data(), base); + using Scalar = typename props::Type::Scalar; + bool is_pyoject = npy_format_descriptor::value == npy_api::NPY_OBJECT_; + + if (!is_pyoject) { + if (props::vector) + a = array({ src.size() }, { elem_size * src.innerStride() }, src.data(), base); + else + a = array({ src.rows(), src.cols() }, { elem_size * src.rowStride(), elem_size * src.colStride() }, + src.data(), base); + } + else { + if (props::vector) { + a = array( + npy_format_descriptor::dtype(), + { (size_t) src.size() }, + nullptr, + base + ); + auto policy = base ? return_value_policy::automatic_reference : return_value_policy::copy; + for (ssize_t i = 0; i < src.size(); ++i) { + auto value_ = reinterpret_steal(make_caster::cast(src(i, 0), policy, base)); + if (!value_) + return handle(); + auto p = a.mutable_data(i); + PyArray_SETITEM(a.ptr(), p, value_.release().ptr()); + } + } + else { + a = array( + npy_format_descriptor::dtype(), + {(size_t) src.rows(), (size_t) src.cols()}, + nullptr, + base + ); + auto policy = base ? return_value_policy::automatic_reference : return_value_policy::copy; + for (ssize_t i = 0; i < src.rows(); ++i) { + for (ssize_t j = 0; j < src.cols(); ++j) { + auto value_ = reinterpret_steal(make_caster::cast(src(i, j), policy, base)); + if (!value_) + return handle(); + auto p = a.mutable_data(i, j); + PyArray_SETITEM(a.ptr(), p, value_.release().ptr()); + } + } + } + } if (!writeable) array_proxy(a.ptr())->flags &= ~detail::npy_api::NPY_ARRAY_WRITEABLE_; @@ -265,14 +312,47 @@ struct type_caster::value>> { auto fits = props::conformable(buf); if (!fits) return false; - + int result = 0; // Allocate the new type, then build a numpy reference into it value = Type(fits.rows, fits.cols); - auto ref = reinterpret_steal(eigen_ref_array(value)); - if (dims == 1) ref = ref.squeeze(); - else if (ref.ndim() == 1) buf = buf.squeeze(); - - int result = detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr()); + bool is_pyobject = npy_format_descriptor::value == npy_api::NPY_OBJECT_; + + if (!is_pyobject) { + auto ref = reinterpret_steal(eigen_ref_array(value)); + if (dims == 1) ref = ref.squeeze(); + else if (ref.ndim() == 1) buf = buf.squeeze(); + result = + detail::npy_api::get().PyArray_CopyInto_(ref.ptr(), buf.ptr()); + } + else { + if (dims == 1){ + if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) { + value.resize(buf.shape(0), 1); + } + for (ssize_t i = 0; i < buf.shape(0); ++i) { + auto p = buf.mutable_data(i); + make_caster conv_val; + if (!conv_val.load(PyArray_GETITEM(buf.ptr(), p), convert)) + return false; + value(i) = cast_op(conv_val); + } + } else { + if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) { + value.resize(buf.shape(0), buf.shape(1)); + } + for (ssize_t i = 0; i < buf.shape(0); ++i) { + for (ssize_t j = 0; j < buf.shape(1); ++j) { + // p is the const void pointer to the item + auto p = buf.mutable_data(i, j); + make_caster conv_val; + if (!conv_val.load(PyArray_GETITEM(buf.ptr(), p), + convert)) + return false; + value(i,j) = cast_op(conv_val); + } + } + } + } if (result < 0) { // Copy failed! PyErr_Clear(); @@ -424,6 +504,7 @@ struct type_caster< // storage order conversion. (Note that we refuse to use this temporary copy when loading an // argument for a Ref with M non-const, i.e. a read-write reference). Array copy_or_ref; + typename std::remove_cv::type val; public: bool load(handle src, bool convert) { // First check whether what we have is already an array of the right type. If not, we can't @@ -431,6 +512,11 @@ struct type_caster< bool need_copy = !isinstance(src); EigenConformable fits; + bool is_pyobject = false; + if (npy_format_descriptor::value == npy_api::NPY_OBJECT_) { + is_pyobject = true; + need_copy = true; + } if (!need_copy) { // We don't need a converting copy, but we also need to check whether the strides are // compatible with the Ref's stride requirements @@ -453,15 +539,55 @@ struct type_caster< // We need to copy: If we need a mutable reference, or we're not supposed to convert // (either because we're in the no-convert overload pass, or because we're explicitly // instructed not to copy (via `py::arg().noconvert()`) we have to fail loading. - if (!convert || need_writeable) return false; + if (!is_pyobject && (!convert || need_writeable)) { + return false; + } Array copy = Array::ensure(src); if (!copy) return false; fits = props::conformable(copy); - if (!fits || !fits.template stride_compatible()) + if (!fits || !fits.template stride_compatible()) { return false; - copy_or_ref = std::move(copy); - loader_life_support::add_patient(copy_or_ref); + } + + if (!is_pyobject) { + copy_or_ref = std::move(copy); + loader_life_support::add_patient(copy_or_ref); + } + else { + auto dims = copy.ndim(); + if (dims == 1){ + if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) { + val.resize(copy.shape(0), 1); + } + for (ssize_t i = 0; i < copy.shape(0); ++i) { + auto p = copy.mutable_data(i); + make_caster conv_val; + if (!conv_val.load(PyArray_GETITEM(copy.ptr(), p), + convert)) + return false; + val(i) = cast_op(conv_val); + + } + } else { + if (Type::RowsAtCompileTime == Eigen::Dynamic || Type::ColsAtCompileTime == Eigen::Dynamic) { + val.resize(copy.shape(0), copy.shape(1)); + } + for (ssize_t i = 0; i < copy.shape(0); ++i) { + for (ssize_t j = 0; j < copy.shape(1); ++j) { + // p is the const void pointer to the item + auto p = copy.mutable_data(i, j); + make_caster conv_val; + if (!conv_val.load(PyArray_GETITEM(copy.ptr(), p), + convert)) + return false; + val(i, j) = cast_op(conv_val); + } + } + } + ref.reset(new Type(val)); + return true; + } } ref.reset(); diff --git a/include/pybind11/numpy.h b/include/pybind11/numpy.h index 6fd8fdf3779..b00c4dace27 100644 --- a/include/pybind11/numpy.h +++ b/include/pybind11/numpy.h @@ -1227,6 +1227,21 @@ template struct npy_format_descriptor { ::pybind11::detail::npy_format_descriptor::register_dtype \ ({PYBIND11_MAP2_LIST (PYBIND11_FIELD_DESCRIPTOR_EX, Type, __VA_ARGS__)}) +#define PYBIND11_NUMPY_OBJECT_DTYPE(Type) \ + namespace pybind11 { namespace detail { \ + template <> struct npy_format_descriptor { \ + public: \ + enum { value = npy_api::NPY_OBJECT_ }; \ + static pybind11::dtype dtype() { \ + if (auto ptr = npy_api::get().PyArray_DescrFromType_(value)) { \ + return reinterpret_borrow(ptr); \ + } \ + pybind11_fail("Unsupported buffer format!"); \ + } \ + static constexpr auto name = _("object"); \ + }; \ + }} + #endif // __CLION_IDE__ template