diff --git a/docs/api_extra.rst b/docs/api_extra.rst
index 5332f326..c197dde2 100644
--- a/docs/api_extra.rst
+++ b/docs/api_extra.rst
@@ -664,6 +664,12 @@ section <ndarrays>`.
 
 .. cpp:class:: template <typename... Args> ndarray
 
+   .. cpp:var:: ReadOnly
+
+      A constant static boolean that is true if the array's data is read-only.
+      This is determined by the class template arguments, not by any dynamic
+      properties of the referenced array.
+
    .. cpp:function:: ndarray() = default
 
       Create an invalid array.
@@ -793,14 +799,19 @@ section <ndarrays>`.
       In a multi-device/GPU setup, this function returns the ID of the device
       storing the array.
 
-   .. cpp:function:: const Scalar * data() const
+   .. cpp:function:: Scalar * data() const
+
+      Return a pointer to the array data.
+      If :cpp:var:`ReadOnly` is true, a pointer-to-const is returned.
 
-      Return a const pointer to the array data.
+   .. cpp:function:: template <typename... Ts> auto& operator()(Ts... indices)
 
-   .. cpp:function:: Scalar * data()
+      Return a reference to the element stored at the provided index/indices.
+      If :cpp:var:`ReadOnly` is true, a reference-to-const is returned.
+      Note that ``sizeof...(Ts)`` must match :cpp:func:`ndim()`.
 
-      Return a mutable pointer to the array data. Only enabled when `Scalar` is
-      not itself ``const``.
+      This accessor is only available when the scalar type and array dimension
+      were specified as template parameters.
 
    .. cpp:function:: template <typename... Extra> auto view()
 
@@ -813,14 +824,6 @@ section <ndarrays>`.
       ``shape()``, ``stride()``, and ``operator()`` following the conventions
       of the `ndarray` type.
 
-   .. cpp:function:: template <typename... Ts> auto& operator()(Ts... indices)
-
-      Return a mutable reference to the element at stored at the provided
-      index/indices. ``sizeof(Ts)`` must match :cpp:func:`ndim()`.
-
-      This accessor is only available when the scalar type and array dimension
-      were specified as template parameters.
-
 Data types
 ^^^^^^^^^^
diff --git a/docs/ndarray.rst b/docs/ndarray.rst
index 0ccbbb11..affc9ba2 100644
--- a/docs/ndarray.rst
+++ b/docs/ndarray.rst
@@ -33,14 +33,14 @@ Binding functions that take arrays as input
 -------------------------------------------
 
 A function that accepts a :cpp:class:`nb::ndarray\<\> <ndarray>`-typed parameter
-(i.e., *without* template parameters) can be called with *any* array
+(i.e., *without* template parameters) can be called with *any* writable array
 from any framework regardless of the device on which it is stored. The
 following example binding declaration uses this functionality to inspect
 the properties of an arbitrary input array:
 
 .. code-block:: cpp
 
-   m.def("inspect", [](nb::ndarray<> a) {
+   m.def("inspect", [](const nb::ndarray<>& a) {
        printf("Array data pointer : %p\n", a.data());
        printf("Array dimension : %zu\n", a.ndim());
        for (size_t i = 0; i < a.ndim(); ++i) {
diff --git a/include/nanobind/ndarray.h b/include/nanobind/ndarray.h
index 0aef12e2..bde13ada 100644
--- a/include/nanobind/ndarray.h
+++ b/include/nanobind/ndarray.h
@@ -110,7 +110,7 @@ template <size_t N> using ndim = typename detail::ndim_shape<std::make_index_sequence<N>>::type;
 template <typename T> constexpr dlpack::dtype dtype() {
     static_assert(
         detail::is_ndarray_scalar_v<T>,
-        "nanobind::dtype<T>: T must be a floating point or integer variable!"
+        "nanobind::dtype<T>: T must be a floating point or integer type!"
     );
 
     dlpack::dtype result;
@@ -266,6 +266,7 @@ template <typename... Ts> struct ndarray_arg<shape<Ts...>> {
 template <typename... Ts> struct ndarray_info {
     using scalar_type = void;
     using shape_type = void;
+    constexpr static bool ReadOnly = false;
     constexpr static auto name = const_name("ndarray");
     constexpr static ndarray_framework framework = ndarray_framework::none;
     constexpr static char order = '\0';
@@ -273,39 +274,68 @@ template <typename... Ts> struct ndarray_info {
 
 template <typename T, typename... Ts> struct ndarray_info<T, Ts...> : ndarray_info<Ts...> {
     using scalar_type =
-        std::conditional_t<ndarray_traits<T>::is_float || ndarray_traits<T>::is_int ||
-                           ndarray_traits<T>::is_bool || ndarray_traits<T>::is_complex,
-                           T, typename ndarray_info<Ts...>::scalar_type>;
+        std::conditional_t<
+            detail::is_ndarray_scalar_v<T> &&
+            std::is_void_v<typename ndarray_info<Ts...>::scalar_type>,
+            T, typename ndarray_info<Ts...>::scalar_type>;
+
+    constexpr static bool ReadOnly = ndarray_info<Ts...>::ReadOnly ||
+        (detail::is_ndarray_scalar_v<T> && std::is_const_v<T>);
+};
+
+template <typename... Ts> struct ndarray_info<ro, Ts...> : ndarray_info<Ts...> {
+    constexpr static bool ReadOnly = true;
+};
 
 template <size_t... Is, typename... Ts> struct ndarray_info<shape<Is...>, Ts...> : ndarray_info<Ts...> {
-    using shape_type = shape<Is...>;
+    using shape_type =
+        std::conditional_t<
+            std::is_void_v<typename ndarray_info<Ts...>::shape_type>,
+            shape<Is...>, typename ndarray_info<Ts...>::shape_type>;
 };
 
 template <typename... Ts> struct ndarray_info<c_contig, Ts...> : ndarray_info<Ts...> {
+    static_assert(ndarray_info<Ts...>::order == '\0'
+               || ndarray_info<Ts...>::order == 'C',
+                  "The order can only be set once.");
     constexpr static char order = 'C';
 };
 
 template <typename... Ts> struct ndarray_info<f_contig, Ts...> : ndarray_info<Ts...> {
+    static_assert(ndarray_info<Ts...>::order == '\0'
+               || ndarray_info<Ts...>::order == 'F',
+                  "The order can only be set once.");
     constexpr static char order = 'F';
 };
 
 template <typename... Ts> struct ndarray_info<numpy, Ts...> : ndarray_info<Ts...> {
+    static_assert(ndarray_info<Ts...>::framework == ndarray_framework::none
+               || ndarray_info<Ts...>::framework == ndarray_framework::numpy,
+                  "The framework can only be set once.");
     constexpr static auto name = const_name("numpy.ndarray");
     constexpr static ndarray_framework framework = ndarray_framework::numpy;
 };
 
 template <typename... Ts> struct ndarray_info<pytorch, Ts...> : ndarray_info<Ts...> {
+    static_assert(ndarray_info<Ts...>::framework == ndarray_framework::none
+               || ndarray_info<Ts...>::framework == ndarray_framework::pytorch,
+                  "The framework can only be set once.");
     constexpr static auto name = const_name("torch.Tensor");
     constexpr static ndarray_framework framework = ndarray_framework::pytorch;
 };
 
 template <typename... Ts> struct ndarray_info<tensorflow, Ts...> : ndarray_info<Ts...> {
+    static_assert(ndarray_info<Ts...>::framework == ndarray_framework::none
+               || ndarray_info<Ts...>::framework == ndarray_framework::tensorflow,
+                  "The framework can only be set once.");
     constexpr static auto name = const_name("tensorflow.python.framework.ops.EagerTensor");
     constexpr static ndarray_framework framework = ndarray_framework::tensorflow;
 };
 
 template <typename... Ts> struct ndarray_info<jax, Ts...> : ndarray_info<Ts...> {
+    static_assert(ndarray_info<Ts...>::framework == ndarray_framework::none
+               || ndarray_info<Ts...>::framework == ndarray_framework::jax,
+                  "The framework can only be set once.");
     constexpr static auto name = const_name("jaxlib.xla_extension.DeviceArray");
     constexpr static ndarray_framework framework = ndarray_framework::jax;
 };
@@ -314,9 +344,9 @@ template <typename... Ts> struct ndarray_info<cupy, Ts...> : ndarray_info<Ts...> {
     constexpr static ndarray_framework framework = ndarray_framework::cupy;
 };
 
-
 NAMESPACE_END(detail)
 
+
 template <typename Scalar, typename Shape> struct ndarray_view {
     static constexpr size_t Dim = Shape::size;
@@ -380,7 +410,10 @@ template <typename... Args> class ndarray {
     template <typename... Args2> friend class ndarray;
 
     using Info = detail::ndarray_info<Args...>;
-    using Scalar = typename Info::scalar_type;
+    static constexpr bool ReadOnly = Info::ReadOnly;
+    using Scalar = std::conditional_t<ReadOnly, const typename Info::scalar_type,
+                                      typename Info::scalar_type>;
 
     ndarray() = default;
@@ -392,7 +425,7 @@ template <typename... Args> class ndarray {
     template <typename... Args2>
     explicit ndarray(const ndarray<Args2...> &other) : ndarray(other.m_handle) { }
 
-    ndarray(std::conditional_t<std::is_const_v<Scalar>, const void *, void *> data,
+    ndarray(std::conditional_t<ReadOnly, const void *, void *> data,
             size_t ndim,
             const size_t *shape,
             handle owner,
@@ -402,11 +435,11 @@ template <typename... Args> class ndarray {
             int32_t device_id = 0) {
         m_handle = detail::ndarray_create(
             (void *) data, ndim, shape, owner.ptr(), strides, &dtype,
-            std::is_const_v<Scalar>, device_type, device_id);
+            ReadOnly, device_type, device_id);
         m_dltensor = *detail::ndarray_inc_ref(m_handle);
     }
 
-    ndarray(std::conditional_t<std::is_const_v<Scalar>, const void *, void *> data,
+    ndarray(std::conditional_t<ReadOnly, const void *, void *> data,
             std::initializer_list<size_t> shape,
             handle owner,
             std::initializer_list<int64_t> strides = { },
@@ -420,7 +453,7 @@ template <typename... Args> class ndarray {
         m_handle = detail::ndarray_create(
             (void *) data, shape.size(), shape.begin(), owner.ptr(),
             (strides.size() == 0) ? nullptr : strides.begin(), &dtype,
-            std::is_const_v<Scalar>, device_type, device_id);
+            ReadOnly, device_type, device_id);
         m_dltensor = *detail::ndarray_inc_ref(m_handle);
     }
 
@@ -476,35 +509,26 @@ template <typename... Args> class ndarray {
     size_t itemsize() const { return ((size_t) dtype().bits + 7) / 8; }
     size_t nbytes() const { return ((size_t) dtype().bits * size() + 7) / 8; }
 
-    const Scalar *data() const {
-        return (const Scalar *)((const uint8_t *) m_dltensor.data + m_dltensor.byte_offset);
-    }
-
-    template <typename T = Scalar, std::enable_if_t<!std::is_const_v<T>, int> = 1>
-    Scalar *data() {
+    Scalar *data() const {
         return (Scalar *) ((uint8_t *) m_dltensor.data + m_dltensor.byte_offset);
     }
 
-    template <typename T = Scalar, std::enable_if_t<!std::is_const_v<T>, int> = 1, typename... Ts>
-    NB_INLINE auto &operator()(Ts... indices) {
+    template <typename... Ts>
+    NB_INLINE auto& operator()(Ts... indices) const {
         return *(Scalar *) ((uint8_t *) m_dltensor.data +
                             byte_offset(indices...));
     }
 
-    template <typename... Ts> NB_INLINE const auto & operator()(Ts... indices) const {
-        return *(const Scalar *) ((const uint8_t *) m_dltensor.data +
-                                  byte_offset(indices...));
-    }
-
     template <typename... Extra> NB_INLINE auto view() const {
         using Info2 = typename ndarray<Args..., Extra...>::Info;
-        using Scalar2 = typename Info2::scalar_type;
+        using Scalar2 = std::conditional_t<Info2::ReadOnly, const typename Info2::scalar_type,
+                                           typename Info2::scalar_type>;
         using Shape2 = typename Info2::shape_type;
 
-        constexpr bool has_scalar = !std::is_same_v<Scalar2, void>,
-                       has_shape  = !std::is_same_v<Shape2, void>;
+        constexpr bool has_scalar = !std::is_void_v<Scalar2>,
+                       has_shape  = !std::is_void_v<Shape2>;
 
         static_assert(has_scalar,
             "To use the ndarray::view<..>() method, you must add a scalar type "
@@ -528,8 +552,8 @@ template <typename... Args> class ndarray {
 private:
     template <typename... Ts> NB_INLINE int64_t byte_offset(Ts... indices) const {
-        constexpr bool has_scalar = !std::is_same_v<Scalar, void>,
-                       has_shape  = !std::is_same_v<typename Info::shape_type, void>;
+        constexpr bool has_scalar = !std::is_void_v<Scalar>,
+                       has_shape  = !std::is_void_v<typename Info::shape_type>;
 
         static_assert(has_scalar,
             "To use ndarray::operator(), you must add a scalar type "
@@ -547,7 +571,7 @@ template <typename... Args> class ndarray {
             int64_t index = 0;
             ((index += int64_t(indices) * m_dltensor.strides[counter++]), ...);
 
-            return (int64_t) m_dltensor.byte_offset + index * sizeof(typename Info::scalar_type);
+            return (int64_t) m_dltensor.byte_offset + index * sizeof(Scalar);
         } else {
             return 0;
         }
diff --git a/tests/test_ndarray.cpp b/tests/test_ndarray.cpp
index a8e66b8b..fe66b8fe 100644
--- a/tests/test_ndarray.cpp
+++ b/tests/test_ndarray.cpp
@@ -24,6 +24,43 @@ namespace nanobind {
 }
 #endif
 
+template <bool expect_ro, bool is_shaped, typename... Args>
+bool check_ro(const nb::ndarray<Args...> &a) { // Pytest passes five doubles
+    static_assert(std::remove_reference_t<decltype(a)>::ReadOnly == expect_ro);
+    static_assert(std::is_const_v<std::remove_pointer_t<decltype(a.data())>>
+                  == expect_ro);
+    auto vd = a.template view<double, nb::shape<nb::any>>();
+    static_assert(std::is_const_v<std::remove_pointer_t<decltype(vd.data())>>
+                  == expect_ro);
+    static_assert(std::is_const_v<std::remove_reference_t<decltype(vd(0))>>
+                  == expect_ro);
+    auto vcd = a.template view<const double, nb::shape<nb::any>>();
+    static_assert(std::is_const_v<std::remove_pointer_t<decltype(vcd.data())>>);
+    static_assert(std::is_const_v<std::remove_reference_t<decltype(vcd(0))>>);
+
+    bool pass = vd.data() == a.data() && vcd.data() == a.data();
+    if constexpr (!expect_ro) {
+        vd(1) = 1.414214;
+        pass &= vcd(1) == 1.414214;
+    }
+    if constexpr (is_shaped) {
+        static_assert(std::is_const_v<std::remove_reference_t<decltype(a(0))>>
+                      == expect_ro);
+        auto v = a.view();
+        static_assert(std::is_const_v<std::remove_pointer_t<decltype(v.data())>>
+                      == expect_ro);
+        static_assert(std::is_const_v<std::remove_reference_t<decltype(v(0))>>
+                      == expect_ro);
+        pass &= v.data() == a.data();
+        if constexpr (!expect_ro) {
+            a(2) = 2.718282;
+            v(4) = 16.0;
+        }
+    }
+    pass &= vcd(3) == 3.14159;
+    return pass;
+}
+
 NB_MODULE(test_ndarray_ext, m) {
     m.def("get_is_valid", [](const nb::ndarray<nb::ro> &t) {
         return t.is_valid();
@@ -90,6 +127,57 @@ NB_MODULE(test_ndarray_ext, m) {
           [](const nb::ndarray<float, nb::c_contig, nb::shape<nb::any, nb::any, 4>> &) {},
           "array"_a.noconvert());
 
+    m.def("check_rw_by_value",
+          [](nb::ndarray<> a) {
+              return check_ro<false, false>(a);
+          });
+    m.def("check_ro_by_value_ro",
+          [](nb::ndarray<nb::ro> a) {
+              return check_ro<true, false>(a);
+          });
+    m.def("check_rw_by_value_float64",
+          [](nb::ndarray<double, nb::shape<nb::any>> a) {
+              return check_ro<false, true>(a);
+          });
+    m.def("check_ro_by_value_const_float64",
+          [](nb::ndarray<const double, nb::shape<nb::any>> a) {
+              return check_ro<true, true>(a);
+          });
+
+    m.def("check_rw_by_const_ref",
+          [](const nb::ndarray<>& a) {
+              return check_ro<false, false>(a);
+          });
+    m.def("check_ro_by_const_ref_ro",
+          [](const nb::ndarray<nb::ro>& a) {
+              return check_ro<true, false>(a);
+          });
+    m.def("check_rw_by_const_ref_float64",
+          [](const nb::ndarray<double, nb::shape<nb::any>>& a) {
+              return check_ro<false, true>(a);
+          });
+    m.def("check_ro_by_const_ref_const_float64",
+          [](const nb::ndarray<const double, nb::shape<nb::any>>& a) {
+              return check_ro<true, true>(a);
+          });
+
+    m.def("check_rw_by_rvalue_ref",
+          [](nb::ndarray<>&& a) {
+              return check_ro<false, false>(a);
+          });
+    m.def("check_ro_by_rvalue_ref_ro",
+          [](nb::ndarray<nb::ro>&& a) {
+              return check_ro<true, false>(a);
+          });
+    m.def("check_rw_by_rvalue_ref_float64",
+          [](nb::ndarray<double, nb::shape<nb::any>>&& a) {
+              return check_ro<false, true>(a);
+          });
+    m.def("check_ro_by_rvalue_ref_const_float64",
+          [](nb::ndarray<const double, nb::shape<nb::any>>&& a) {
+              return check_ro<true, true>(a);
+          });
+
     m.def("check_order", [](nb::ndarray<nb::c_contig>) -> char { return 'C'; });
     m.def("check_order", [](nb::ndarray<nb::f_contig>) -> char { return 'F'; });
     m.def("check_order", [](nb::ndarray<>) -> char { return '?'; });
@@ -123,7 +211,7 @@ NB_MODULE(test_ndarray_ext, m) {
           [](nb::ndarray<float, nb::c_contig, nb::shape<2, nb::any>>) { return 0; },
           "array"_a);
 
-    m.def("inspect_ndarray", [](nb::ndarray<> ndarray) {
+    m.def("inspect_ndarray", [](const nb::ndarray<>& ndarray) {
         printf("Tensor data pointer : %p\n", ndarray.data());
         printf("Tensor dimension : %zu\n", ndarray.ndim());
         for (size_t i = 0; i < ndarray.ndim(); ++i) {
@@ -285,6 +373,12 @@ NB_MODULE(test_ndarray_ext, m) {
                 v(i, j) *= std::complex<float>(-1.0f, 2.0f);
     }, "x"_a.noconvert());
 
+    m.def("fill_view_6", [](nb::ndarray<std::complex<float>, nb::shape<2, 2>, nb::c_contig, nb::device::cpu> x) {
+        auto v = x.view<nb::shape<4>>();
+        for (size_t i = 0; i < v.shape(0); ++i)
+            v(i) = -v(i);
+    }, "x"_a.noconvert());
+
 #if defined(__aarch64__)
     m.def("ret_numpy_half", []() {
         __fp16 *f = new __fp16[8] { 1, 2, 3, 4, 5, 6, 7, 8 };
diff --git a/tests/test_ndarray.py b/tests/test_ndarray.py
index dbf5bd72..baa56b2f 100644
--- a/tests/test_ndarray.py
+++ b/tests/test_ndarray.py
@@ -628,6 +628,9 @@ def test31_view():
     x2 = x1 * (-1+2j)
     t.fill_view_5(x1)
     assert np.allclose(x1, x2)
+    x2 = -x2
+    t.fill_view_6(x1)
+    assert np.allclose(x1, x2)
 
 @needs_numpy
 def test32_half():
@@ -701,9 +704,88 @@ def test37_noninteger_stride():
         t.get_stride(v, 0);
     assert 'incompatible function arguments' in str(excinfo.value)
 
+@needs_numpy
+def test38_const_qualifiers_numpy():
+    a = np.array([0, 0, 0, 3.14159, 0], dtype=np.float64)
+    assert t.check_rw_by_value(a)
+    assert a[1] == 1.414214
+    assert t.check_rw_by_value_float64(a)
+    assert a[2] == 2.718282
+    assert a[4] == 16.0
+    assert t.check_ro_by_value_ro(a)
+    assert t.check_ro_by_value_const_float64(a)
+    a.setflags(write=False)
+    assert t.check_ro_by_value_ro(a)
+    assert t.check_ro_by_value_const_float64(a)
+    assert a[0] == 0.0
+    assert a[3] == 3.14159
+
+    a = np.array([0, 0, 0, 3.14159, 0], dtype=np.float64)
+    assert t.check_rw_by_const_ref(a)
+    assert a[1] == 1.414214
+    assert t.check_rw_by_const_ref_float64(a)
+    assert a[2] == 2.718282
+    assert a[4] == 16.0
+    assert t.check_ro_by_const_ref_ro(a)
+    assert t.check_ro_by_const_ref_const_float64(a)
+    a.setflags(write=False)
+    assert t.check_ro_by_const_ref_ro(a)
+    assert t.check_ro_by_const_ref_const_float64(a)
+    assert a[0] == 0.0
+    assert a[3] == 3.14159
+
+    a = np.array([0, 0, 0, 3.14159, 0], dtype=np.float64)
+    assert t.check_rw_by_rvalue_ref(a)
+    assert a[1] == 1.414214
+    assert t.check_rw_by_rvalue_ref_float64(a)
+    assert a[2] == 2.718282
+    assert a[4] == 16.0
+    assert t.check_ro_by_rvalue_ref_ro(a)
+    assert t.check_ro_by_rvalue_ref_const_float64(a)
+    a.setflags(write=False)
+    assert t.check_ro_by_rvalue_ref_ro(a)
+    assert t.check_ro_by_rvalue_ref_const_float64(a)
+    assert a[0] == 0.0
+    assert a[3] == 3.14159
+
+@needs_torch
+def test39_const_qualifiers_pytorch():
+    a = torch.tensor([0, 0, 0, 3.14159, 0], dtype=torch.float64)
+    assert t.check_rw_by_value(a)
+    assert a[1] == 1.414214
+    assert t.check_rw_by_value_float64(a)
+    assert a[2] == 2.718282
+    assert a[4] == 16.0
+    assert t.check_ro_by_value_ro(a)
+    assert t.check_ro_by_value_const_float64(a)
+    assert a[0] == 0.0
+    assert a[3] == 3.14159
+
+    a = torch.tensor([0, 0, 0, 3.14159, 0], dtype=torch.float64)
+    assert t.check_rw_by_const_ref(a)
+    assert a[1] == 1.414214
+    assert t.check_rw_by_const_ref_float64(a)
+    assert a[2] == 2.718282
+    assert a[4] == 16.0
+    assert t.check_ro_by_const_ref_ro(a)
+    assert t.check_ro_by_const_ref_const_float64(a)
+    assert a[0] == 0.0
+    assert a[3] == 3.14159
+
+    a = torch.tensor([0, 0, 0, 3.14159, 0], dtype=torch.float64)
+    assert t.check_rw_by_rvalue_ref(a)
+    assert a[1] == 1.414214
+    assert t.check_rw_by_rvalue_ref_float64(a)
+    assert a[2] == 2.718282
+    assert a[4] == 16.0
+    assert t.check_ro_by_rvalue_ref_ro(a)
+    assert t.check_ro_by_rvalue_ref_const_float64(a)
+    assert a[0] == 0.0
+    assert a[3] == 3.14159
+
 @needs_cupy
 @pytest.mark.filterwarnings
-def test38_constrain_order_cupy():
+def test40_constrain_order_cupy():
     try:
         c = cp.zeros((3, 5))
         c.__dlpack__()
@@ -719,7 +801,7 @@ def test40_constrain_order_cupy():
 
 @needs_cupy
-def test39_implicit_conversion_cupy():
+def test41_implicit_conversion_cupy():
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         try:
diff --git a/tests/test_ndarray_ext.pyi.ref b/tests/test_ndarray_ext.pyi.ref
index 29a9518c..aecbe0cf 100644
--- a/tests/test_ndarray_ext.pyi.ref
+++ b/tests/test_ndarray_ext.pyi.ref
@@ -43,6 +43,30 @@ def check_order(arg: Annotated[ArrayLike, dict(order='F')], /) -> str: ...
 @overload
 def check_order(arg: ArrayLike, /) -> str: ...
 
+def check_ro_by_const_ref_const_float64(arg: Annotated[ArrayLike, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ...
+
+def check_ro_by_const_ref_ro(arg: Annotated[ArrayLike, dict(writable=False)], /) -> bool: ...
+
+def check_ro_by_rvalue_ref_const_float64(arg: Annotated[ArrayLike, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ...
+
+def check_ro_by_rvalue_ref_ro(arg: Annotated[ArrayLike, dict(writable=False)], /) -> bool: ...
+
+def check_ro_by_value_const_float64(arg: Annotated[ArrayLike, dict(dtype='float64', writable=False, shape=(None))], /) -> bool: ...
+
+def check_ro_by_value_ro(arg: Annotated[ArrayLike, dict(writable=False)], /) -> bool: ...
+
+def check_rw_by_const_ref(arg: ArrayLike, /) -> bool: ...
+
+def check_rw_by_const_ref_float64(arg: Annotated[ArrayLike, dict(dtype='float64', shape=(None))], /) -> bool: ...
+
+def check_rw_by_rvalue_ref(arg: ArrayLike, /) -> bool: ...
+
+def check_rw_by_rvalue_ref_float64(arg: Annotated[ArrayLike, dict(dtype='float64', shape=(None))], /) -> bool: ...
+
+def check_rw_by_value(arg: ArrayLike, /) -> bool: ...
+
+def check_rw_by_value_float64(arg: Annotated[ArrayLike, dict(dtype='float64', shape=(None))], /) -> bool: ...
+
 def check_shape_ptr(arg: ArrayLike, /) -> bool: ...
 
 def check_stride_ptr(arg: ArrayLike, /) -> bool: ...
@@ -59,6 +83,8 @@ def fill_view_4(x: Annotated[ArrayLike, dict(dtype='float32', shape=(3, 4), order='F', device='cpu')]) -> None: ...
 
 def fill_view_5(x: Annotated[ArrayLike, dict(dtype='complex64', shape=(2, 2), order='C', device='cpu')]) -> None: ...
 
+def fill_view_6(x: Annotated[ArrayLike, dict(dtype='complex64', shape=(2, 2), order='C', device='cpu')]) -> None: ...
+
 def get_is_valid(array: Annotated[ArrayLike, dict(writable=False)] | None) -> bool: ...
 
 def get_itemsize(array: ArrayLike | None) -> int: ...
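
Usage sketch (not part of the patch above; the module name `example_ext` and function names `sum_1d`/`scale_1d` are hypothetical, and `nb::shape<nb::any>` follows the spelling used in the tests). It illustrates the shallow-const semantics this change introduces on the binding side: a `const`-qualified scalar (or the `nb::ro` annotation) sets `ReadOnly`, so `data()` and `operator()` become const accessors and read-only inputs such as NumPy arrays with `WRITEABLE=False` are accepted, while an unqualified scalar keeps the array writable and only accepts writable inputs.

    #include <nanobind/nanobind.h>
    #include <nanobind/ndarray.h>

    namespace nb = nanobind;

    NB_MODULE(example_ext, m) {
        // Read-only binding: ndarray<const double, ...> sets ReadOnly, so
        // data() yields 'const double *' and a(i) yields 'const double &'.
        m.def("sum_1d", [](const nb::ndarray<const double, nb::shape<nb::any>> &a) {
            double accum = 0.0;
            for (size_t i = 0; i < a.shape(0); ++i)
                accum += a(i);
            return accum;
        });

        // Writable binding: without 'const' or 'nb::ro', ReadOnly is false and
        // a(i) returns a mutable reference, even though operator() is now a
        // const member function (const is shallow, as for a pointer).
        m.def("scale_1d", [](nb::ndarray<double, nb::shape<nb::any>> a, double factor) {
            for (size_t i = 0; i < a.shape(0); ++i)
                a(i) *= factor;
        });
    }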