Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow std::complex field with PYBIND11_NUMPY_DTYPE #831

Merged
merged 9 commits into from
May 10, 2017
9 changes: 5 additions & 4 deletions include/pybind11/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -608,13 +608,14 @@ template <typename T> struct is_fmt_numeric<T, enable_if_t<std::is_arithmetic<T>
NAMESPACE_END(detail)

template <typename T> struct format_descriptor<T, detail::enable_if_t<detail::is_fmt_numeric<T>::value>> {
static constexpr const char c = "?bBhHiIqQfdgFDG"[detail::is_fmt_numeric<T>::index];
static constexpr const char value[2] = { c, '\0' };
static std::string format() { return std::string(1, c); }
static constexpr const char c1 = "?bBhHiIqQfdgZZZ"[detail::is_fmt_numeric<T>::index];
static constexpr const char c2 = "\0\0\0\0\0\0\0\0\0\0\0\0fdg"[detail::is_fmt_numeric<T>::index];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have to say, I'm not a big fan of this part. I think it would be cleaner to just remove FDG from here and specialize the whole format_descriptor struct for std::complex inside complex.h.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Especially in light of the extra hack required below for MSVC).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Now for some reason the VS2017 x86 Python 3.6 build (but none of the others) is failing, and there is no indication of what the error is. Also, Travis doesn't seem to have been run at all. Any ideas?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The AppVeyor issue is most likely random: #792. Not sure what happened to Travis.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, looks like it was just random failures. Appveyor is still running the VS2015 builds, but it should all be working now.

static constexpr const char value[3] = { c1, c2, '\0' };
static std::string format() { return std::string(value); }
};

template <typename T> constexpr const char format_descriptor<
T, detail::enable_if_t<detail::is_fmt_numeric<T>::value>>::value[2];
T, detail::enable_if_t<detail::is_fmt_numeric<T>::value>>::value[3];

/// RAII wrapper that temporarily clears any Python error state
struct error_scope {
Expand Down
12 changes: 6 additions & 6 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,8 @@ template <typename T> struct is_complex : std::false_type { };
template <typename T> struct is_complex<std::complex<T>> : std::true_type { };

template <typename T> using is_pod_struct = all_of<
std::is_pod<T>, // since we're accessing directly in memory we need a POD type
std::is_standard_layout<T>, // since we're accessing directly in memory we need a standard layout type
std::is_trivially_copyable<T>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For libstcd++ < 5 (which doesn't have is_trivially_copyable) this workaround might work:

    std::is_standard_layout<T>,
#if !defined(__GNUG__) || defined(__clang__) || __GNUC__ >= 5
    std::is_trivially_copyable<T>,
#else
    std::is_trivially_destructible<T>, std::has_trivial_copy_constructor<T>,
#endif
    satisfies_none_of<...

It's not quite the same as trivially copyable, but it ought to be close enough in practice.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Closer still:

std::is_trivially_destructible<T>, satisifies_any_of<T, std::has_trivial_copy_constructor, std::has_trivial_copy_assign>

libstdc++ doesn't appear to have an equivalent has_trivial_move_constructor, but the above should handle most cases.

satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
>;

Expand Down Expand Up @@ -948,7 +949,6 @@ struct field_descriptor {
const char *name;
size_t offset;
size_t size;
size_t alignment;
std::string format;
dtype descr;
};
Expand Down Expand Up @@ -989,9 +989,10 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
for (auto& field : ordered_fields) {
if (field.offset > offset)
oss << (field.offset - offset) << 'x';
// mark unaligned fields with '^' (unaligned native type)
if (field.offset % field.alignment)
oss << '^';
// mark all fields with '^' (unaligned native type), because numpy
// and C++ don't always agree about alignment (particularly for
// complex, and we're explicitly listing all padding.
oss << '^';
oss << field.format << ':' << field.name << ':';
offset = field.offset + field.size;
}
Expand Down Expand Up @@ -1053,7 +1054,6 @@ template <typename T, typename SFINAE> struct npy_format_descriptor {
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
::pybind11::detail::field_descriptor { \
Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
alignof(decltype(std::declval<T>().Field)), \
::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(), \
::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \
}
Expand Down
30 changes: 28 additions & 2 deletions tests/test_numpy_dtypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ struct StringStruct {
std::array<char, 3> b;
};

struct ComplexStruct {
std::complex<float> cflt;
std::complex<double> cdbl;
};

std::ostream& operator<<(std::ostream& os, const ComplexStruct& v) {
return os << "c:" << v.cflt << "," << v.cdbl;
}

PYBIND11_PACKED(struct StructWithUglyNames {
int8_t __x__;
uint64_t __y__;
Expand Down Expand Up @@ -173,6 +182,18 @@ py::array_t<EnumStruct, 0> create_enum_array(size_t n) {
return arr;
}

py::array_t<ComplexStruct, 0> create_complex_array(size_t n) {
auto arr = mkarray_via_buffer<ComplexStruct>(n);
auto ptr = (ComplexStruct *) arr.mutable_data();
for (size_t i = 0; i < n; i++) {
ptr[i].cflt.real(float(i));
ptr[i].cflt.imag(float(i) + 0.25f);
ptr[i].cdbl.real(double(i) + 0.5);
ptr[i].cdbl.imag(double(i) + 0.75);
}
return arr;
}

template <typename S>
py::list print_recarray(py::array_t<S, 0> arr) {
const auto req = arr.request();
Expand All @@ -194,7 +215,8 @@ py::list print_format_descriptors() {
py::format_descriptor<PartialStruct>::format(),
py::format_descriptor<PartialNestedStruct>::format(),
py::format_descriptor<StringStruct>::format(),
py::format_descriptor<EnumStruct>::format()
py::format_descriptor<EnumStruct>::format(),
py::format_descriptor<ComplexStruct>::format()
};
auto l = py::list();
for (const auto &fmt : fmts) {
Expand All @@ -212,7 +234,8 @@ py::list print_dtypes() {
py::str(py::dtype::of<PartialNestedStruct>()),
py::str(py::dtype::of<StringStruct>()),
py::str(py::dtype::of<EnumStruct>()),
py::str(py::dtype::of<StructWithUglyNames>())
py::str(py::dtype::of<StructWithUglyNames>()),
py::str(py::dtype::of<ComplexStruct>())
};
auto l = py::list();
for (const auto &s : dtypes) {
Expand Down Expand Up @@ -352,6 +375,7 @@ test_initializer numpy_dtypes([](py::module &m) {
PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a);
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
PYBIND11_NUMPY_DTYPE(ComplexStruct, cflt, cdbl);
PYBIND11_NUMPY_DTYPE(TrailingPaddingStruct, a, b);
PYBIND11_NUMPY_DTYPE(CompareStruct, x, y, z);

Expand Down Expand Up @@ -380,6 +404,8 @@ test_initializer numpy_dtypes([](py::module &m) {
m.def("print_string_array", &print_recarray<StringStruct>);
m.def("create_enum_array", &create_enum_array);
m.def("print_enum_array", &print_recarray<EnumStruct>);
m.def("create_complex_array", &create_complex_array);
m.def("print_complex_array", &print_recarray<ComplexStruct>);
m.def("test_array_ctors", &test_array_ctors);
m.def("test_dtype_ctors", &test_dtype_ctors);
m.def("test_dtype_methods", &test_dtype_methods);
Expand Down
40 changes: 30 additions & 10 deletions tests/test_numpy_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,21 +72,22 @@ def test_format_descriptors():
assert re.match('^NumPy type info missing for .*UnboundStruct.*$', str(excinfo.value))

ld = np.dtype('longdouble')
ldbl_fmt = ('4x' if ld.alignment > 4 else '') + ld.char
ss_fmt = "T{?:bool_:3xI:uint_:f:float_:" + ldbl_fmt + ":ldbl_:}"
ldbl_fmt = ('4x' if ld.alignment > 4 else '') + '^' + ld.char
ss_fmt = "T{^?:bool_:3x^I:uint_:^f:float_:" + ldbl_fmt + ":ldbl_:}"
dbl = np.dtype('double')
partial_fmt = ("T{?:bool_:3xI:uint_:f:float_:" +
partial_fmt = ("T{^?:bool_:3x^I:uint_:^f:float_:" +
str(4 * (dbl.alignment > 4) + dbl.itemsize + 8 * (ld.alignment > 8)) +
"xg:ldbl_:}")
"x^g:ldbl_:}")
nested_extra = str(max(8, ld.alignment))
assert print_format_descriptors() == [
ss_fmt,
"T{?:bool_:^I:uint_:^f:float_:^g:ldbl_:}",
"T{" + ss_fmt + ":a:T{?:bool_:^I:uint_:^f:float_:^g:ldbl_:}:b:}",
"T{^?:bool_:^I:uint_:^f:float_:^g:ldbl_:}",
"T{^" + ss_fmt + ":a:^T{^?:bool_:^I:uint_:^f:float_:^g:ldbl_:}:b:}",
partial_fmt,
"T{" + nested_extra + "x" + partial_fmt + ":a:" + nested_extra + "x}",
"T{3s:a:3s:b:}",
'T{q:e1:B:e2:}'
"T{" + nested_extra + "x^" + partial_fmt + ":a:" + nested_extra + "x}",
"T{^3s:a:^3s:b:}",
'T{^q:e1:^B:e2:}',
'T{^Zf:cflt:^Zd:cdbl:}'
]


Expand All @@ -104,7 +105,8 @@ def test_dtype(simple_dtype):
partial_nested_fmt(),
"[('a', 'S3'), ('b', 'S3')]",
"[('e1', '" + e + "i8'), ('e2', 'u1')]",
"[('x', 'i1'), ('y', '" + e + "u8')]"
"[('x', 'i1'), ('y', '" + e + "u8')]",
"[('cflt', '" + e + "c8'), ('cdbl', '" + e + "c16')]"
]

d1 = np.dtype({'names': ['a', 'b'], 'formats': ['int32', 'float64'],
Expand Down Expand Up @@ -231,6 +233,24 @@ def test_enum_array():
assert create_enum_array(0).dtype == dtype


def test_complex_array():
from pybind11_tests import create_complex_array, print_complex_array
from sys import byteorder
e = '<' if byteorder == 'little' else '>'

arr = create_complex_array(3)
dtype = arr.dtype
assert dtype == np.dtype([('cflt', e + 'c8'), ('cdbl', e + 'c16')])
assert print_complex_array(arr) == [
"c:(0,0.25),(0.5,0.75)",
"c:(1,1.25),(1.5,1.75)",
"c:(2,2.25),(2.5,2.75)"
]
assert arr['cflt'].tolist() == [0.0 + 0.25j, 1.0 + 1.25j, 2.0 + 2.25j]
assert arr['cdbl'].tolist() == [0.5 + 0.75j, 1.5 + 1.75j, 2.5 + 2.75j]
assert create_complex_array(0).dtype == dtype


def test_signature(doc):
from pybind11_tests import create_rec_nested

Expand Down