Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Separate const ndarray from const data; remove view and ro. #498

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 6 additions & 20 deletions docs/api_extra.rst
Original file line number Diff line number Diff line change
Expand Up @@ -678,30 +678,16 @@ section <ndarrays>`.
In a multi-device/GPU setup, this function returns the ID of the device
storing the array.

.. cpp:function:: const Scalar * data() const
.. cpp:function:: Scalar * data() const

Return a const pointer to the array data.

.. cpp:function:: Scalar * data()

Return a mutable pointer to the array data. Only enabled when `Scalar` is
not itself ``const``.

.. cpp:function:: template <typename... Extra> auto view()

Returns an nd-array view that is optimized for fast array access on the
CPU. You may optionally specify additional ndarray constraints via the
`Extra` parameter (though a runtime check should first be performed to
ensure that the array possesses these properties).

The returned view provides the operations ``data()``, ``ndim()``,
``shape()``, ``stride()``, and ``operator()`` following the conventions
of the `ndarray` type.
Return a pointer to the array data.
If the scalar type is ``const``-qualified, a pointer-to-const is returned.

.. cpp:function:: template <typename... Ts> auto& operator()(Ts... indices)

Return a mutable reference to the element at stored at the provided
index/indices. ``sizeof(Ts)`` must match :cpp:func:`ndim()`.
Return a reference to the element stored at the provided index/indices.
If the scalar type is ``const``-qualified, a reference-to-const is
returned. Note that ``sizeof...(Ts)`` must match :cpp:func:`ndim()`.

This accessor is only available when the scalar type and array dimension
were specified as template parameters.
Expand Down
6 changes: 3 additions & 3 deletions docs/ndarray.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,14 @@ Binding functions that take arrays as input
-------------------------------------------

A function that accepts a :cpp:class:`nb::ndarray\<\> <ndarray>`-typed parameter
(i.e., *without* template parameters) can be called with *any* array
(i.e., *without* template parameters) can be called with *any* writable array
from any framework regardless of the device on which it is stored. The
following example binding declaration uses this functionality to inspect the
properties of an arbitrary input array:

.. code-block:: cpp

m.def("inspect", [](nb::ndarray<> a) {
m.def("inspect", [](const nb::ndarray<>& a) {
printf("Array data pointer : %p\n", a.data());
printf("Array dimension : %zu\n", a.ndim());
for (size_t i = 0; i < a.ndim(); ++i) {
Expand Down Expand Up @@ -536,7 +536,7 @@ cannot be called using NumPy arrays that are marked as constant.

If you wish your function to be callable with constant input, either change the
parameter to ``nb::ndarray<const T, ...>`` (if the array is parameterized by
type), or write ``nb::ndarray<nb::ro>`` to accept a read-only array of any
type), or write ``nb::ndarray<const void>`` to accept a read-only array of any
type.

Limitations related to ``dtypes``
Expand Down
121 changes: 12 additions & 109 deletions include/nanobind/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,6 @@ struct numpy { };
struct tensorflow { };
struct pytorch { };
struct jax { };
struct ro { };

template <typename T> struct ndarray_traits {
static constexpr bool is_complex = detail::is_complex_v<T>;
Expand Down Expand Up @@ -109,7 +108,7 @@ template <size_t N> using ndim = typename detail::ndim_shape<std::make_index_seq
template <typename T> constexpr dlpack::dtype dtype() {
static_assert(
detail::is_ndarray_scalar_v<T>,
"nanobind::dtype<T>: T must be a floating point or integer variable!"
"nanobind::dtype<T>: T must be a floating point or integer type!"
);

dlpack::dtype result;
Expand Down Expand Up @@ -213,13 +212,14 @@ template <typename T> struct ndarray_arg<T, enable_if_t<ndarray_traits<T>::is_bo
}
};

template<> struct ndarray_arg<ro> {
template <typename T> struct ndarray_arg<T, enable_if_t<std::is_void_v<T>>> {
static constexpr size_t size = 0;

static constexpr auto name = const_name("writable=False");
static constexpr auto name =
const_name<std::is_const_v<T>>("writable=False", "");

static void apply(ndarray_req &tr) {
tr.req_ro = true;
tr.req_ro = std::is_const_v<T>;
}
};

Expand Down Expand Up @@ -272,8 +272,7 @@ template <typename... Ts> struct ndarray_info {

template <typename T, typename... Ts> struct ndarray_info<T, Ts...> : ndarray_info<Ts...> {
using scalar_type =
std::conditional_t<ndarray_traits<T>::is_float || ndarray_traits<T>::is_int ||
ndarray_traits<T>::is_bool || ndarray_traits<T>::is_complex,
std::conditional_t<detail::is_ndarray_scalar_v<T> || std::is_void_v<T>,
T, typename ndarray_info<Ts...>::scalar_type>;
};

Expand Down Expand Up @@ -309,66 +308,8 @@ template <typename... Ts> struct ndarray_info<jax, Ts...> : ndarray_info<Ts...>
constexpr static ndarray_framework framework = ndarray_framework::jax;
};


NAMESPACE_END(detail)

template <typename Scalar, typename Shape, char Order> struct ndarray_view {
static constexpr size_t Dim = Shape::size;

ndarray_view() = default;
ndarray_view(const ndarray_view &) = default;
ndarray_view(ndarray_view &&) = default;
ndarray_view &operator=(const ndarray_view &) = default;
ndarray_view &operator=(ndarray_view &&) noexcept = default;
~ndarray_view() noexcept = default;

template <typename... Ts> NB_INLINE Scalar &operator()(Ts... indices) const {
static_assert(
sizeof...(Ts) == Dim,
"ndarray_view::operator(): invalid number of arguments");

const int64_t indices_i64[] { (int64_t) indices... };
int64_t offset = 0;
for (size_t i = 0; i < Dim; ++i)
offset += indices_i64[i] * m_strides[i];

return *(m_data + offset);
}

size_t ndim() const { return Dim; }
size_t shape(size_t i) const { return m_shape[i]; }
int64_t stride(size_t i) const { return m_strides[i]; }
Scalar *data() const { return m_data; }

private:
template <typename...> friend class ndarray;

template <size_t... I1, ssize_t... I2>
ndarray_view(Scalar *data, const int64_t *shape, const int64_t *strides,
std::index_sequence<I1...>, nanobind::shape<I2...>)
: m_data(data) {

/* Initialize shape/strides with compile-time knowledge if
available (to permit vectorization, loop unrolling, etc.) */
((m_shape[I1] = (I2 == -1) ? shape[I1] : (int64_t) I2), ...);
((m_strides[I1] = strides[I1]), ...);

if constexpr (Order == 'F') {
m_strides[0] = 1;
for (size_t i = 1; i < Dim; ++i)
m_strides[i] = m_strides[i - 1] * m_shape[i - 1];
} else if constexpr (Order == 'C') {
m_strides[Dim - 1] = 1;
for (Py_ssize_t i = (Py_ssize_t) Dim - 2; i >= 0; --i)
m_strides[i] = m_strides[i + 1] * m_shape[i + 1];
}
}

Scalar *m_data = nullptr;
int64_t m_shape[Dim] { };
int64_t m_strides[Dim] { };
};


template <typename... Args> class ndarray {
public:
Expand Down Expand Up @@ -471,60 +412,22 @@ template <typename... Args> class ndarray {
size_t itemsize() const { return ((size_t) dtype().bits + 7) / 8; }
size_t nbytes() const { return ((size_t) dtype().bits * size() + 7) / 8; }

const Scalar *data() const {
return (const Scalar *)((const uint8_t *) m_dltensor.data + m_dltensor.byte_offset);
}

template <typename T = Scalar, std::enable_if_t<!std::is_const_v<T>, int> = 1>
Scalar *data() {
Scalar *data() const {
return (Scalar *) ((uint8_t *) m_dltensor.data +
m_dltensor.byte_offset);
}

template <typename T = Scalar,
std::enable_if_t<!std::is_const_v<T>, int> = 1, typename... Ts>
NB_INLINE auto &operator()(Ts... indices) {
template <typename... Ts>
NB_INLINE auto& operator()(Ts... indices) const {
return *(Scalar *) ((uint8_t *) m_dltensor.data +
byte_offset(indices...));
}

template <typename... Ts> NB_INLINE const auto & operator()(Ts... indices) const {
return *(const Scalar *) ((const uint8_t *) m_dltensor.data +
byte_offset(indices...));
}

template <typename... Extra> NB_INLINE auto view() const {
using Info2 = typename ndarray<Args..., Extra...>::Info;
using Scalar2 = typename Info2::scalar_type;
using Shape2 = typename Info2::shape_type;

constexpr bool has_scalar = !std::is_same_v<Scalar2, void>,
has_shape = !std::is_same_v<Shape2, void>;

static_assert(has_scalar,
"To use the ndarray::view<..>() method, you must add a scalar type "
"annotation (e.g. 'float') to the template parameters of the parent "
"ndarray, or to the call to .view<..>()");

static_assert(has_shape,
"To use the ndarray::view<..>() method, you must add a shape<..> "
"or ndim<..> annotation to the template parameters of the parent "
"ndarray, or to the call to .view<..>()");

if constexpr (has_scalar && has_shape) {
return ndarray_view<Scalar2, Shape2, Info2::order>(
(Scalar2 *) data(), shape_ptr(), stride_ptr(),
std::make_index_sequence<Shape2::size>(), Shape2());
} else {
return nullptr;
}
}

private:
template <typename... Ts>
NB_INLINE int64_t byte_offset(Ts... indices) const {
constexpr bool has_scalar = !std::is_same_v<Scalar, void>,
has_shape = !std::is_same_v<typename Info::shape_type, void>;
constexpr bool has_scalar = !std::is_void_v<Scalar>,
has_shape = !std::is_void_v<typename Info::shape_type>;

static_assert(has_scalar,
"To use ndarray::operator(), you must add a scalar type "
Expand All @@ -542,7 +445,7 @@ template <typename... Args> class ndarray {
int64_t index = 0;
((index += int64_t(indices) * m_dltensor.strides[counter++]), ...);

return (int64_t) m_dltensor.byte_offset + index * sizeof(typename Info::scalar_type);
return (int64_t) m_dltensor.byte_offset + index * sizeof(Scalar);
} else {
return 0;
}
Expand Down
107 changes: 68 additions & 39 deletions tests/test_ndarray.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,23 @@ namespace nanobind {
}
#endif

template<bool writable, bool is_shaped, typename... Ts>
bool check_const(const nb::ndarray<Ts...>& a) { // Pytest passes 3 doubles
static_assert(std::is_const_v<std::remove_pointer_t<decltype(a.data())>>
== !writable);
if constexpr (is_shaped) {
static_assert(std::is_const_v<std::remove_reference_t<decltype(a(0))>>
== !writable);
if constexpr (writable) {
a(1) *= 2.0;
return a(0) == 0.0 && a(2) == 2.718282;
}
}
return true;
}

NB_MODULE(test_ndarray_ext, m) {
m.def("get_shape", [](const nb::ndarray<nb::ro> &t) {
m.def("get_shape", [](const nb::ndarray<const void> &t) {
nb::list l;
for (size_t i = 0; i < t.ndim(); ++i)
l.append(t.shape(i));
Expand Down Expand Up @@ -86,6 +101,57 @@ NB_MODULE(test_ndarray_ext, m) {
[](const nb::ndarray<float, nb::c_contig,
nb::shape<-1, -1, 4>> &) {}, "array"_a.noconvert());

m.def("check_rw_by_value",
[](nb::ndarray<> a) {
return check_const</*writable=*/true, /*is_shaped=*/false>(a);
});
m.def("check_ro_by_value_const_void",
[](nb::ndarray<const void> a) {
return check_const</*writable=*/false, /*is_shaped=*/false>(a);
});
m.def("check_rw_by_value_float64",
[](nb::ndarray<double, nb::ndim<1>> a) {
return check_const</*writable=*/true, /*is_shaped=*/true>(a);
});
m.def("check_ro_by_value_const_float64",
[](nb::ndarray<const double, nb::ndim<1>> a) {
return check_const</*writable=*/false, /*is_shaped=*/true>(a);
});

m.def("check_rw_by_const_ref",
[](const nb::ndarray<>& a) {
return check_const</*writable=*/true, /*is_shaped=*/false>(a);
});
m.def("check_ro_by_const_ref_const_void",
[](const nb::ndarray<const void>& a) {
return check_const</*writable=*/false, /*is_shaped=*/false>(a);
});
m.def("check_rw_by_const_ref_float64",
[](nb::ndarray<double, nb::ndim<1>> a) {
return check_const</*writable=*/true, /*is_shaped=*/true>(a);
});
m.def("check_ro_by_const_ref_const_float64",
[](const nb::ndarray<const double, nb::ndim<1>>& a) {
return check_const</*writable=*/false, /*is_shaped=*/true>(a);
});

m.def("check_rw_by_rvalue_ref",
[](nb::ndarray<>&& a) {
return check_const</*writable=*/true, /*is_shaped=*/false>(a);
});
m.def("check_ro_by_rvalue_ref_const_void",
[](nb::ndarray<const void>&& a) {
return check_const</*writable=*/false, /*is_shaped=*/false>(a);
});
m.def("check_rw_by_rvalue_ref_float64",
[](nb::ndarray<double, nb::ndim<1>>&& a) {
return check_const</*writable=*/true, /*is_shaped=*/true>(a);
});
m.def("check_ro_by_rvalue_ref_const_float64",
[](nb::ndarray<const double, nb::ndim<1>>&& a) {
return check_const</*writable=*/false, /*is_shaped=*/true>(a);
});

m.def("check_order", [](nb::ndarray<nb::c_contig>) -> char { return 'C'; });
m.def("check_order", [](nb::ndarray<nb::f_contig>) -> char { return 'F'; });
m.def("check_order", [](nb::ndarray<>) -> char { return '?'; });
Expand Down Expand Up @@ -119,7 +185,7 @@ NB_MODULE(test_ndarray_ext, m) {
[](nb::ndarray<float, nb::c_contig, nb::shape<2, 2>>) { return 0; },
"array"_a);

m.def("inspect_ndarray", [](nb::ndarray<> ndarray) {
m.def("inspect_ndarray", [](const nb::ndarray<>& ndarray) {
printf("Tensor data pointer : %p\n", ndarray.data());
printf("Tensor dimension : %zu\n", ndarray.ndim());
for (size_t i = 0; i < ndarray.ndim(); ++i) {
Expand Down Expand Up @@ -241,43 +307,6 @@ NB_MODULE(test_ndarray_ext, m) {
.def("f2_ri", &Cls::f2, nb::rv_policy::reference_internal)
.def("f3_ri", &Cls::f3, nb::rv_policy::reference_internal);

m.def("fill_view_1", [](nb::ndarray<> x) {
if (x.ndim() == 2 && x.dtype() == nb::dtype<float>()) {
auto v = x.view<float, nb::ndim<2>>();
for (size_t i = 0; i < v.shape(0); i++)
for (size_t j = 0; j < v.shape(1); j++)
v(i, j) *= 2;
}
}, "x"_a.noconvert());

m.def("fill_view_2", [](nb::ndarray<float, nb::ndim<2>, nb::device::cpu> x) {
auto v = x.view();
for (size_t i = 0; i < v.shape(0); ++i)
for (size_t j = 0; j < v.shape(1); ++j)
v(i, j) = (float) (i * 10 + j);
}, "x"_a.noconvert());

m.def("fill_view_3", [](nb::ndarray<float, nb::shape<3, 4>, nb::c_contig, nb::device::cpu> x) {
auto v = x.view();
for (size_t i = 0; i < v.shape(0); ++i)
for (size_t j = 0; j < v.shape(1); ++j)
v(i, j) = (float) (i * 10 + j);
}, "x"_a.noconvert());

m.def("fill_view_4", [](nb::ndarray<float, nb::shape<3, 4>, nb::f_contig, nb::device::cpu> x) {
auto v = x.view();
for (size_t i = 0; i < v.shape(0); ++i)
for (size_t j = 0; j < v.shape(1); ++j)
v(i, j) = (float) (i * 10 + j);
}, "x"_a.noconvert());

m.def("fill_view_5", [](nb::ndarray<std::complex<float>, nb::shape<2, 2>, nb::c_contig, nb::device::cpu> x) {
auto v = x.view();
for (size_t i = 0; i < v.shape(0); ++i)
for (size_t j = 0; j < v.shape(1); ++j)
v(i, j) *= std::complex<float>(-1.0f, 2.0f);
}, "x"_a.noconvert());

#if defined(__aarch64__)
m.def("ret_numpy_half", []() {
__fp16 *f = new __fp16[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
Expand Down
Loading
Loading