Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow std::complex field with PYBIND11_NUMPY_DTYPE #831

Merged
merged 9 commits into from
May 10, 2017
19 changes: 15 additions & 4 deletions include/pybind11/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -608,13 +608,24 @@ template <typename T> struct is_fmt_numeric<T, enable_if_t<std::is_arithmetic<T>
NAMESPACE_END(detail)

template <typename T> struct format_descriptor<T, detail::enable_if_t<detail::is_fmt_numeric<T>::value>> {
static constexpr const char c = "?bBhHiIqQfdgFDG"[detail::is_fmt_numeric<T>::index];
static constexpr const char value[2] = { c, '\0' };
static std::string format() { return std::string(1, c); }
static constexpr const char c1 = "?bBhHiIqQfdgZZZ"[detail::is_fmt_numeric<T>::index];
static constexpr const char c2 = "\0\0\0\0\0\0\0\0\0\0\0\0fdg"[detail::is_fmt_numeric<T>::index];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have to say, I'm not a big fan of this part. I think it would be cleaner to just remove FDG from here and specialize the whole format_descriptor struct for std::complex inside complex.h.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Especially in light of the extra hack required below for MSVC).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done. Now for some reason the VS2017 x86 Python 3.6 build (but none of the others) is failing, and there is no indication of what the error is. Also, Travis doesn't seem to have been run at all. Any ideas?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The AppVeyor issue is most likely random: #792. Not sure what happened to Travis.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, looks like it was just random failures. Appveyor is still running the VS2015 builds, but it should all be working now.

static constexpr const char value[3] = { c1, c2, '\0' };
static std::string format() {
#if !defined(_MSC_VER) || _MSC_VER >= 1910
return std::string(value);
#else
// MSVC 2015 has trouble constructing std::string from value
std::string out(1, c1);
if (c2)
out += c2;
return out;
#endif
}
};

template <typename T> constexpr const char format_descriptor<
T, detail::enable_if_t<detail::is_fmt_numeric<T>::value>>::value[2];
T, detail::enable_if_t<detail::is_fmt_numeric<T>::value>>::value[3];

/// RAII wrapper that temporarily clears any Python error state
struct error_scope {
Expand Down
21 changes: 14 additions & 7 deletions include/pybind11/numpy.h
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,14 @@ template <typename T> struct is_complex : std::false_type { };
template <typename T> struct is_complex<std::complex<T>> : std::true_type { };

template <typename T> using is_pod_struct = all_of<
std::is_pod<T>, // since we're accessing directly in memory we need a POD type
std::is_standard_layout<T>, // since we're accessing directly in memory we need a standard layout type
#if !defined(__GNUG__) || defined(__clang__) || __GNUC__ >= 5
std::is_trivially_copyable<T>,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For libstcd++ < 5 (which doesn't have is_trivially_copyable) this workaround might work:

    std::is_standard_layout<T>,
#if !defined(__GNUG__) || defined(__clang__) || __GNUC__ >= 5
    std::is_trivially_copyable<T>,
#else
    std::is_trivially_destructible<T>, std::has_trivial_copy_constructor<T>,
#endif
    satisfies_none_of<...

It's not quite the same as trivially copyable, but it ought to be close enough in practice.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Closer still:

std::is_trivially_destructible<T>, satisifies_any_of<T, std::has_trivial_copy_constructor, std::has_trivial_copy_assign>

libstdc++ doesn't appear to have an equivalent has_trivial_move_constructor, but the above should handle most cases.

#else
// GCC 4 doesn't implement is_trivially_copyable, so approximate it
std::is_trivially_destructible<T>,
satisfies_any_of<T, std::has_trivial_copy_constructor, std::has_trivial_copy_assign>,
#endif
satisfies_none_of<T, std::is_reference, std::is_array, is_std_array, std::is_arithmetic, is_complex, std::is_enum>
>;

Expand Down Expand Up @@ -948,7 +955,6 @@ struct field_descriptor {
const char *name;
size_t offset;
size_t size;
size_t alignment;
std::string format;
dtype descr;
};
Expand Down Expand Up @@ -985,13 +991,15 @@ inline PYBIND11_NOINLINE void register_structured_dtype(
[](const field_descriptor &a, const field_descriptor &b) { return a.offset < b.offset; });
size_t offset = 0;
std::ostringstream oss;
oss << "T{";
// mark the structure as unaligned with '^', because numpy and C++ don't
// always agree about alignment (particularly for complex), and we're
// explicitly listing all our padding. This depends on none of the fields
// overriding the endianness. Putting the ^ in front of individual fields
// isn't guaranteed to work due to https://github.com/numpy/numpy/issues/9049
oss << "^T{";
for (auto& field : ordered_fields) {
if (field.offset > offset)
oss << (field.offset - offset) << 'x';
// mark unaligned fields with '^' (unaligned native type)
if (field.offset % field.alignment)
oss << '^';
oss << field.format << ':' << field.name << ':';
offset = field.offset + field.size;
}
Expand Down Expand Up @@ -1053,7 +1061,6 @@ template <typename T, typename SFINAE> struct npy_format_descriptor {
#define PYBIND11_FIELD_DESCRIPTOR_EX(T, Field, Name) \
::pybind11::detail::field_descriptor { \
Name, offsetof(T, Field), sizeof(decltype(std::declval<T>().Field)), \
alignof(decltype(std::declval<T>().Field)), \
::pybind11::format_descriptor<decltype(std::declval<T>().Field)>::format(), \
::pybind11::detail::npy_format_descriptor<decltype(std::declval<T>().Field)>::dtype() \
}
Expand Down
30 changes: 28 additions & 2 deletions tests/test_numpy_dtypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,15 @@ struct StringStruct {
std::array<char, 3> b;
};

struct ComplexStruct {
std::complex<float> cflt;
std::complex<double> cdbl;
};

std::ostream& operator<<(std::ostream& os, const ComplexStruct& v) {
return os << "c:" << v.cflt << "," << v.cdbl;
}

PYBIND11_PACKED(struct StructWithUglyNames {
int8_t __x__;
uint64_t __y__;
Expand Down Expand Up @@ -173,6 +182,18 @@ py::array_t<EnumStruct, 0> create_enum_array(size_t n) {
return arr;
}

py::array_t<ComplexStruct, 0> create_complex_array(size_t n) {
auto arr = mkarray_via_buffer<ComplexStruct>(n);
auto ptr = (ComplexStruct *) arr.mutable_data();
for (size_t i = 0; i < n; i++) {
ptr[i].cflt.real(float(i));
ptr[i].cflt.imag(float(i) + 0.25f);
ptr[i].cdbl.real(double(i) + 0.5);
ptr[i].cdbl.imag(double(i) + 0.75);
}
return arr;
}

template <typename S>
py::list print_recarray(py::array_t<S, 0> arr) {
const auto req = arr.request();
Expand All @@ -194,7 +215,8 @@ py::list print_format_descriptors() {
py::format_descriptor<PartialStruct>::format(),
py::format_descriptor<PartialNestedStruct>::format(),
py::format_descriptor<StringStruct>::format(),
py::format_descriptor<EnumStruct>::format()
py::format_descriptor<EnumStruct>::format(),
py::format_descriptor<ComplexStruct>::format()
};
auto l = py::list();
for (const auto &fmt : fmts) {
Expand All @@ -212,7 +234,8 @@ py::list print_dtypes() {
py::str(py::dtype::of<PartialNestedStruct>()),
py::str(py::dtype::of<StringStruct>()),
py::str(py::dtype::of<EnumStruct>()),
py::str(py::dtype::of<StructWithUglyNames>())
py::str(py::dtype::of<StructWithUglyNames>()),
py::str(py::dtype::of<ComplexStruct>())
};
auto l = py::list();
for (const auto &s : dtypes) {
Expand Down Expand Up @@ -352,6 +375,7 @@ test_initializer numpy_dtypes([](py::module &m) {
PYBIND11_NUMPY_DTYPE(PartialNestedStruct, a);
PYBIND11_NUMPY_DTYPE(StringStruct, a, b);
PYBIND11_NUMPY_DTYPE(EnumStruct, e1, e2);
PYBIND11_NUMPY_DTYPE(ComplexStruct, cflt, cdbl);
PYBIND11_NUMPY_DTYPE(TrailingPaddingStruct, a, b);
PYBIND11_NUMPY_DTYPE(CompareStruct, x, y, z);

Expand Down Expand Up @@ -380,6 +404,8 @@ test_initializer numpy_dtypes([](py::module &m) {
m.def("print_string_array", &print_recarray<StringStruct>);
m.def("create_enum_array", &create_enum_array);
m.def("print_enum_array", &print_recarray<EnumStruct>);
m.def("create_complex_array", &create_complex_array);
m.def("print_complex_array", &print_recarray<ComplexStruct>);
m.def("test_array_ctors", &test_array_ctors);
m.def("test_dtype_ctors", &test_dtype_ctors);
m.def("test_dtype_methods", &test_dtype_methods);
Expand Down
36 changes: 28 additions & 8 deletions tests/test_numpy_dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,20 +73,21 @@ def test_format_descriptors():

ld = np.dtype('longdouble')
ldbl_fmt = ('4x' if ld.alignment > 4 else '') + ld.char
ss_fmt = "T{?:bool_:3xI:uint_:f:float_:" + ldbl_fmt + ":ldbl_:}"
ss_fmt = "^T{?:bool_:3xI:uint_:f:float_:" + ldbl_fmt + ":ldbl_:}"
dbl = np.dtype('double')
partial_fmt = ("T{?:bool_:3xI:uint_:f:float_:" +
partial_fmt = ("^T{?:bool_:3xI:uint_:f:float_:" +
str(4 * (dbl.alignment > 4) + dbl.itemsize + 8 * (ld.alignment > 8)) +
"xg:ldbl_:}")
nested_extra = str(max(8, ld.alignment))
assert print_format_descriptors() == [
ss_fmt,
"T{?:bool_:^I:uint_:^f:float_:^g:ldbl_:}",
"T{" + ss_fmt + ":a:T{?:bool_:^I:uint_:^f:float_:^g:ldbl_:}:b:}",
"^T{?:bool_:I:uint_:f:float_:g:ldbl_:}",
"^T{" + ss_fmt + ":a:^T{?:bool_:I:uint_:f:float_:g:ldbl_:}:b:}",
partial_fmt,
"T{" + nested_extra + "x" + partial_fmt + ":a:" + nested_extra + "x}",
"T{3s:a:3s:b:}",
'T{q:e1:B:e2:}'
"^T{" + nested_extra + "x" + partial_fmt + ":a:" + nested_extra + "x}",
"^T{3s:a:3s:b:}",
'^T{q:e1:B:e2:}',
'^T{Zf:cflt:Zd:cdbl:}'
]


Expand All @@ -104,7 +105,8 @@ def test_dtype(simple_dtype):
partial_nested_fmt(),
"[('a', 'S3'), ('b', 'S3')]",
"[('e1', '" + e + "i8'), ('e2', 'u1')]",
"[('x', 'i1'), ('y', '" + e + "u8')]"
"[('x', 'i1'), ('y', '" + e + "u8')]",
"[('cflt', '" + e + "c8'), ('cdbl', '" + e + "c16')]"
]

d1 = np.dtype({'names': ['a', 'b'], 'formats': ['int32', 'float64'],
Expand Down Expand Up @@ -231,6 +233,24 @@ def test_enum_array():
assert create_enum_array(0).dtype == dtype


def test_complex_array():
from pybind11_tests import create_complex_array, print_complex_array
from sys import byteorder
e = '<' if byteorder == 'little' else '>'

arr = create_complex_array(3)
dtype = arr.dtype
assert dtype == np.dtype([('cflt', e + 'c8'), ('cdbl', e + 'c16')])
assert print_complex_array(arr) == [
"c:(0,0.25),(0.5,0.75)",
"c:(1,1.25),(1.5,1.75)",
"c:(2,2.25),(2.5,2.75)"
]
assert arr['cflt'].tolist() == [0.0 + 0.25j, 1.0 + 1.25j, 2.0 + 2.25j]
assert arr['cdbl'].tolist() == [0.5 + 0.75j, 1.5 + 1.75j, 2.5 + 2.75j]
assert create_complex_array(0).dtype == dtype


def test_signature(doc):
from pybind11_tests import create_rec_nested

Expand Down