Skip to content

Commit

Permalink
minor ndarray tweak to improve an error message
Browse files Browse the repository at this point in the history
  • Loading branch information
wjakob committed Sep 10, 2023
1 parent 36b9778 commit e0bcd51
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions include/nanobind/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,21 +410,26 @@ template <typename... Args> class ndarray {
private:
template <typename... Ts>
NB_INLINE int64_t byte_offset(Ts... indices) const {
static_assert(
!std::is_same_v<Scalar, void>,
constexpr bool has_scalar = !std::is_same_v<Scalar, void>,
has_shape = !std::is_same_v<typename Info::shape_type, void>;

static_assert(has_scalar,
"To use nb::ndarray::operator(), you must add a scalar type "
"annotation (e.g. 'float') to the ndarray template parameters.");
static_assert(
!std::is_same_v<Scalar, void>,
static_assert(has_shape,
"To use nb::ndarray::operator(), you must add a nb::shape<> "
"annotation to the ndarray template parameters.");
static_assert(sizeof...(Ts) == Info::shape_type::size,
"nb::ndarray::operator(): invalid number of arguments");
size_t counter = 0;
int64_t index = 0;
((index += int64_t(indices) * m_dltensor.strides[counter++]), ...);

return (int64_t) m_dltensor.byte_offset + index * sizeof(typename Info::scalar_type);
if constexpr (has_scalar && has_shape) {
static_assert(sizeof...(Ts) == Info::shape_type::size,
"nb::ndarray::operator(): invalid number of arguments");
size_t counter = 0;
int64_t index = 0;
((index += int64_t(indices) * m_dltensor.strides[counter++]), ...);

return (int64_t) m_dltensor.byte_offset + index * sizeof(typename Info::scalar_type);
} else {
return 0;
}
}

detail::ndarray_handle *m_handle = nullptr;
Expand Down

0 comments on commit e0bcd51

Please sign in to comment.