Skip to content

Commit

Permalink
Update nanobind to v2 (#1777)
Browse files Browse the repository at this point in the history
- nb::any is replaced by `nb::ssize_t(-1)`
- Because of a change in `~ndarray`, we need to make sure that Python is initialized in our tests (see wjakob/nanobind#377)
  • Loading branch information
havogt authored Jun 13, 2024
1 parent d704213 commit 88e7d91
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 22 deletions.
18 changes: 9 additions & 9 deletions include/gridtools/storage/adapter/nanobind_adapter.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,24 +25,24 @@
namespace gridtools {
namespace nanobind_sid_adapter_impl_ {

// Use nanobind::any for dynamic stride, use an integral value for static stride.
template <std::size_t... Values>
using stride_spec = std::index_sequence<Values...>;
// Use `-1` for dynamic stride, use an integral value for static stride.
template <nanobind::ssize_t... Values>
using stride_spec = std::integer_sequence<nanobind::ssize_t, Values...>;

template <class IndexSequence>
struct dynamic_strides_helper;

template <std::size_t... Indices>
struct dynamic_strides_helper<std::index_sequence<Indices...>> {
using type = stride_spec<(void(Indices), nanobind::any)...>;
using type = stride_spec<(void(Indices), -1)...>;
};

template <std::size_t N>
using fully_dynamic_strides = typename dynamic_strides_helper<std::make_index_sequence<N>>::type;

template <std::size_t SpecValue>
template <nanobind::ssize_t SpecValue>
auto select_static_stride_value(std::size_t dyn_value) {
if constexpr (SpecValue == nanobind::any) {
if constexpr (SpecValue == -1) {
return dyn_value;
} else {
if (SpecValue != dyn_value) {
Expand All @@ -52,20 +52,20 @@ namespace gridtools {
}
}

template <std::size_t... SpecValues, std::size_t... IndexValues>
template <nanobind::ssize_t... SpecValues, std::size_t... IndexValues>
auto select_static_strides_helper(
stride_spec<SpecValues...>, const std::size_t *dyn_values, std::index_sequence<IndexValues...>) {

return gridtools::tuple{select_static_stride_value<SpecValues>(dyn_values[IndexValues])...};
}

template <std::size_t... SpecValues>
template <nanobind::ssize_t... SpecValues>
auto select_static_strides(stride_spec<SpecValues...> spec, const std::size_t *dyn_values) {
return select_static_strides_helper(spec, dyn_values, std::make_index_sequence<sizeof...(SpecValues)>{});
}

template <class T,
std::size_t... Sizes,
nanobind::ssize_t... Sizes,
class... Args,
class Strides = fully_dynamic_strides<sizeof...(Sizes)>,
class StridesKind = sid::unknown_kind>
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/storage/adapter/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ if (${GT_TESTS_ENABLE_PYTHON_TESTS})
FetchContent_Declare(
nanobind
GIT_REPOSITORY https://github.com/wjakob/nanobind.git
GIT_TAG v1.4.0
GIT_TAG v2.0.0
)
FetchContent_MakeAvailable(nanobind)
nanobind_build_library(nanobind-static)
Expand Down
31 changes: 19 additions & 12 deletions tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,29 @@
* SPDX-License-Identifier: BSD-3-Clause
*/

#include <array>

#include <gtest/gtest.h>
#include <gridtools/storage/adapter/nanobind_adapter.hpp>

#include <Python.h>
#include <array>
#include <gridtools/common/integral_constant.hpp>
#include <gridtools/sid/concept.hpp>
#include <gridtools/storage/adapter/nanobind_adapter.hpp>

#include <gtest/gtest.h>

namespace nb = nanobind;

TEST(NanobindAdapter, DataDynStrides) {
class python_init_fixture : public ::testing::Test {
protected:
void SetUp() override { Py_Initialize(); }
void TearDown() override { Py_FinalizeEx(); }
};

TEST_F(python_init_fixture, NanobindAdapterDataDynStrides) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<nb::any, nb::any>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};
nb::ndarray<int, nb::shape<-1, -1>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

const auto sid = gridtools::nanobind::as_sid(ndarray);
const auto s_origin = sid_get_origin(sid);
Expand All @@ -35,27 +42,27 @@ TEST(NanobindAdapter, DataDynStrides) {
EXPECT_EQ(strides[1], gridtools::get<1>(s_strides));
}

TEST(NanobindAdapter, StaticStridesMatch) {
TEST_F(python_init_fixture, NanobindAdapterStaticStridesMatch) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<nb::any, nb::any>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};
nb::ndarray<int, nb::shape<-1, -1>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

const auto sid = gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<1, nanobind::any>{});
const auto sid = gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<1, -1>{});
const auto s_strides = sid_get_strides(sid);

EXPECT_EQ(strides[0], gridtools::get<0>(s_strides).value);
EXPECT_EQ(strides[1], gridtools::get<1>(s_strides));
}

TEST(NanobindAdapter, StaticStridesMismatch) {
TEST_F(python_init_fixture, NanobindAdapterStaticStridesMismatch) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<nb::any, nb::any>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};
nb::ndarray<int, nb::shape<-1, -1>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

EXPECT_THROW(
gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<2, nanobind::any>{}), std::invalid_argument);
gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<2, -1>{}), std::invalid_argument);
}

0 comments on commit 88e7d91

Please sign in to comment.