Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow std::complex field with PYBIND11_NUMPY_DTYPE #831

Merged
merged 9 commits into from
May 10, 2017
9 changes: 6 additions & 3 deletions docs/advanced/pycpp/numpy.rst
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,12 @@ expects the type followed by field names:
/* now both A and B can be used as template arguments to py::array_t */
}

The structure should consist of fundamental arithmetic types, previously
registered substructures, and arrays of any of the above. Both C++ arrays and
``std::array`` are supported.
The structure should consist of fundamental arithmetic types, ``std::complex``,
previously registered substructures, and arrays of any of the above. Both C++
arrays and ``std::array`` are supported. While there is a static assertion to
prevent many types of unsupported structures, it is still the user's
responsibility to use only "plain" structures that can be safely manipulated as
raw memory without violating invariants.

Vectorizing functions
=====================
Expand Down
6 changes: 3 additions & 3 deletions include/pybind11/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -608,14 +608,14 @@ template <typename T> struct is_fmt_numeric<T, enable_if_t<std::is_arithmetic<T>
};
NAMESPACE_END(detail)

template <typename T> struct format_descriptor<T, detail::enable_if_t<detail::is_fmt_numeric<T>::value>> {
static constexpr const char c = "?bBhHiIqQfdgFDG"[detail::is_fmt_numeric<T>::index];
template <typename T> struct format_descriptor<T, detail::enable_if_t<std::is_arithmetic<T>::value>> {
static constexpr const char c = "?bBhHiIqQfdg"[detail::is_fmt_numeric<T>::index];
static constexpr const char value[2] = { c, '\0' };
static std::string format() { return std::string(1, c); }
};

template <typename T> constexpr const char format_descriptor<
T, detail::enable_if_t<detail::is_fmt_numeric<T>::value>>::value[2];
T, detail::enable_if_t<std::is_arithmetic<T>::value>>::value[2];

/// RAII wrapper that temporarily clears any Python error state
struct error_scope {
Expand Down
13 changes: 11 additions & 2 deletions include/pybind11/complex.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,19 @@
#endif

NAMESPACE_BEGIN(pybind11)

template <typename T> struct format_descriptor<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
static constexpr const char c = format_descriptor<T>::c;
static constexpr const char value[3] = { 'Z', c, '\0' };
static std::string format() { return std::string(value); }
};

template <typename T> constexpr const char format_descriptor<
std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>>::value[3];

NAMESPACE_BEGIN(detail)

// The format codes are already in the string in common.h, we just need to provide a specialization
template <typename T> struct is_fmt_numeric<std::complex<T>> {
template <typename T> struct is_fmt_numeric<std::complex<T>, detail::enable_if_t<std::is_floating_point<T>::value>> {
static constexpr bool value = true;
static constexpr int index = is_fmt_numeric<T>::index + 3;
};
Expand Down
21 changes: 14 additions & 7 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,14 @@ template <typename T, size_t N> struct array_info<T[N]> : array_info<std::array<
template <typename T> using remove_all_extents_t = typename array_info<T>::type;

template <typename T> using is_pod_struct = all_of<
std::is_pod<T>, // since we're accessing directly in memory we need a POD type
std::is_standard_layout<T>, // since we're accessing directly in memory we need a standard layout type
#if !defined(__GNUG__) || defined(__clang__) || __GNUC__ >= 5
std::is_trivially_copyable<T>,
#else
// GCC 4 doesn't implement is_trivially_copyable, so approximate it
std::is_trivially_destructible<T>,
satisfies_any_of<T, std::has_trivial_copy_constructor, std::has_trivial_copy_assign>,
#endif
satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
>;

Expand Down Expand Up @@ -1016,7 +1023,6 @@ struct field_descriptor {
const char *name;
ssize_t offset;
ssize_t size;
ssize_t alignment;
std::string format;
dtype descr;
};
Expand Down Expand Up @@ -1053,13 +1059,15 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
[](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
ssize_t offset = 0;
std::ostringstream oss;
oss << "T{";
// mark the structure as unaligned with '^', because numpy and C++ don't
// always agree about alignment (particularly for complex), and we're
// explicitly listing all our padding. This depends on none of the fields
// overriding the endianness. Putting the ^ in front of individual fields
// isn't guaranteed to work due to https://github.com/numpy/numpy/issues/9049
oss << "^T{";
for (auto& field : ordered_fields) {
if (field.offset > offset)
oss << (field.offset - offset) << 'x';
// mark unaligned fields with '^' (unaligned native type)
if (field.offset % field.alignment)
oss << '^';
oss << field.format << ':' << field.name << ':';
offset = field.offset + field.size;
}
Expand Down Expand Up @@ -1121,7 +1129,6 @@ template <typename T, typename SFINAE> struct npy_format_descriptor {
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
::pybind11::detail::field_descriptor { \
Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
alignof(decltype(std::declval<T>().Field)), \
::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(), \
::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \
}
Expand Down
30 changes: 28 additions & 2 deletions tests/test_numpy_dtypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ struct StringStruct {
std::array<char, 3> b;
};

struct ComplexStruct {
std::complex<float> cflt;
std::complex<double> cdbl;
};

std::ostream& operator<<(std::ostream& os, const ComplexStruct& v) {
return os << "c:" << v.cflt << "," << v.cdbl;
}

struct ArrayStruct {
char a[3][4];
int32_t b[2];
Expand Down Expand Up @@ -219,6 +228,18 @@ py::array_t<EnumStruct, 0> create_enum_array(size_t n) {
return arr;
}

py::array_t<ComplexStruct, 0> create_complex_array(size_t n) {
auto arr = mkarray_via_buffer<ComplexStruct>(n);
auto ptr = (ComplexStruct *) arr.mutable_data();
for (size_t i = 0; i < n; i++) {
ptr[i].cflt.real(float(i));
ptr[i].cflt.imag(float(i) + 0.25f);
ptr[i].cdbl.real(double(i) + 0.5);
ptr[i].cdbl.imag(double(i) + 0.75);
}
return arr;
}

template <typename S>
py::list print_recarray(py::array_t<S, 0> arr) {
const auto req = arr.request();
Expand All @@ -241,7 +262,8 @@ py::list print_format_descriptors() {
py::format_descriptor<PartialNestedStruct>::format(),
py::format_descriptor<StringStruct>::format(),
py::format_descriptor<ArrayStruct>::format(),
py::format_descriptor<EnumStruct>::format()
py::format_descriptor<EnumStruct>::format(),
py::format_descriptor<ComplexStruct>::format()
};
auto l = py::list();
for (const auto &fmt : fmts) {
Expand All @@ -260,7 +282,8 @@ py::list print_dtypes() {
py::str(py::dtype::of<StringStruct>()),
py::str(py::dtype::of<ArrayStruct>()),
py::str(py::dtype::of<EnumStruct>()),
py::str(py::dtype::of<StructWithUglyNames>())
py::str(py::dtype::of<StructWithUglyNames>()),
py::str(py::dtype::of<ComplexStruct>())
};
auto l = py::list();
for (const auto &s : dtypes) {
Expand Down Expand Up @@ -401,6 +424,7 @@ test_initializer numpy_dtypes([](py::module &m) {
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
PYBIND11_NUMPY_DTYPE(ArrayStruct, a, b, c, d);
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
PYBIND11_NUMPY_DTYPE(ComplexStruct, cflt, cdbl);
PYBIND11_NUMPY_DTYPE(TrailingPaddingStruct, a, b);
PYBIND11_NUMPY_DTYPE(CompareStruct, x, y, z);

Expand Down Expand Up @@ -431,6 +455,8 @@ test_initializer numpy_dtypes([](py::module &m) {
m.def("print_array_array", &print_recarray<ArrayStruct>);
m.def("create_enum_array", &create_enum_array);
m.def("print_enum_array", &print_recarray<EnumStruct>);
m.def("create_complex_array", &create_complex_array);
m.def("print_complex_array", &print_recarray<ComplexStruct>);
m.def("test_array_ctors", &test_array_ctors);
m.def("test_dtype_ctors", &test_dtype_ctors);
m.def("test_dtype_methods", &test_dtype_methods);
Expand Down
38 changes: 29 additions & 9 deletions tests/test_numpy_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,21 +73,22 @@ def test_format_descriptors():

ld = np.dtype('longdouble')
ldbl_fmt = ('4x' if ld.alignment > 4 else '') + ld.char
ss_fmt = "T{?:bool_:3xI:uint_:f:float_:" + ldbl_fmt + ":ldbl_:}"
ss_fmt = "^T{?:bool_:3xI:uint_:f:float_:" + ldbl_fmt + ":ldbl_:}"
dbl = np.dtype('double')
partial_fmt = ("T{?:bool_:3xI:uint_:f:float_:" +
partial_fmt = ("^T{?:bool_:3xI:uint_:f:float_:" +
str(4 * (dbl.alignment > 4) + dbl.itemsize + 8 * (ld.alignment > 8)) +
"xg:ldbl_:}")
nested_extra = str(max(8, ld.alignment))
assert print_format_descriptors() == [
ss_fmt,
"T{?:bool_:^I:uint_:^f:float_:^g:ldbl_:}",
"T{" + ss_fmt + ":a:T{?:bool_:^I:uint_:^f:float_:^g:ldbl_:}:b:}",
"^T{?:bool_:I:uint_:f:float_:g:ldbl_:}",
"^T{" + ss_fmt + ":a:^T{?:bool_:I:uint_:f:float_:g:ldbl_:}:b:}",
partial_fmt,
"T{" + nested_extra + "x" + partial_fmt + ":a:" + nested_extra + "x}",
"T{3s:a:3s:b:}",
"T{(3)4s:a:(2)i:b:(3)B:c:1x(4, 2)f:d:}",
'T{q:e1:B:e2:}'
"^T{" + nested_extra + "x" + partial_fmt + ":a:" + nested_extra + "x}",
"^T{3s:a:3s:b:}",
"^T{(3)4s:a:(2)i:b:(3)B:c:1x(4, 2)f:d:}",
'^T{q:e1:B:e2:}',
'^T{Zf:cflt:Zd:cdbl:}'
]


Expand All @@ -108,7 +109,8 @@ def test_dtype(simple_dtype):
"'formats':[('S4', (3,)),('<i4', (2,)),('u1', (3,)),('<f4', (4, 2))], " +
"'offsets':[0,12,20,24], 'itemsize':56}}").format(e=e),
"[('e1', '" + e + "i8'), ('e2', 'u1')]",
"[('x', 'i1'), ('y', '" + e + "u8')]"
"[('x', 'i1'), ('y', '" + e + "u8')]",
"[('cflt', '" + e + "c8'), ('cdbl', '" + e + "c16')]"
]

d1 = np.dtype({'names': ['a', 'b'], 'formats': ['int32', 'float64'],
Expand Down Expand Up @@ -260,6 +262,24 @@ def test_enum_array():
assert create_enum_array(0).dtype == dtype


def test_complex_array():
from pybind11_tests import create_complex_array, print_complex_array
from sys import byteorder
e = '<' if byteorder == 'little' else '>'

arr = create_complex_array(3)
dtype = arr.dtype
assert dtype == np.dtype([('cflt', e + 'c8'), ('cdbl', e + 'c16')])
assert print_complex_array(arr) == [
"c:(0,0.25),(0.5,0.75)",
"c:(1,1.25),(1.5,1.75)",
"c:(2,2.25),(2.5,2.75)"
]
assert arr['cflt'].tolist() == [0.0 + 0.25j, 1.0 + 1.25j, 2.0 + 2.25j]
assert arr['cdbl'].tolist() == [0.5 + 0.75j, 1.5 + 1.75j, 2.5 + 2.75j]
assert create_complex_array(0).dtype == dtype


def test_signature(doc):
from pybind11_tests import create_rec_nested

Expand Down