diff --git a/include/gridtools/storage/adapter/nanobind_adapter.hpp b/include/gridtools/storage/adapter/nanobind_adapter.hpp index a2962e9ba..df9f2135b 100644 --- a/include/gridtools/storage/adapter/nanobind_adapter.hpp +++ b/include/gridtools/storage/adapter/nanobind_adapter.hpp @@ -25,24 +25,24 @@ namespace gridtools { namespace nanobind_sid_adapter_impl_ { - // Use nanobind::any for dynamic stride, use an integral value for static stride. - template - using stride_spec = std::index_sequence; + // Use `-1` for dynamic stride, use an integral value for static stride. + template + using stride_spec = std::integer_sequence; template struct dynamic_strides_helper; template struct dynamic_strides_helper> { - using type = stride_spec<(void(Indices), nanobind::any)...>; + using type = stride_spec<(void(Indices), -1)...>; }; template using fully_dynamic_strides = typename dynamic_strides_helper>::type; - template + template auto select_static_stride_value(std::size_t dyn_value) { - if constexpr (SpecValue == nanobind::any) { + if constexpr (SpecValue == -1) { return dyn_value; } else { if (SpecValue != dyn_value) { @@ -52,20 +52,20 @@ namespace gridtools { } } - template + template auto select_static_strides_helper( stride_spec, const std::size_t *dyn_values, std::index_sequence) { return gridtools::tuple{select_static_stride_value(dyn_values[IndexValues])...}; } - template + template auto select_static_strides(stride_spec spec, const std::size_t *dyn_values) { return select_static_strides_helper(spec, dyn_values, std::make_index_sequence{}); } template , class StridesKind = sid::unknown_kind> diff --git a/tests/unit_tests/storage/adapter/CMakeLists.txt b/tests/unit_tests/storage/adapter/CMakeLists.txt index e825bb437..ac7c8e3a2 100644 --- a/tests/unit_tests/storage/adapter/CMakeLists.txt +++ b/tests/unit_tests/storage/adapter/CMakeLists.txt @@ -21,7 +21,7 @@ if (${GT_TESTS_ENABLE_PYTHON_TESTS}) FetchContent_Declare( nanobind GIT_REPOSITORY https://github.com/wjakob/nanobind.git - GIT_TAG v1.4.0 + GIT_TAG v2.0.0 ) FetchContent_MakeAvailable(nanobind) nanobind_build_library(nanobind-static) diff --git a/tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp b/tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp index 429190b8b..552b0b0e5 100644 --- a/tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp +++ b/tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp @@ -8,22 +8,29 @@ * SPDX-License-Identifier: BSD-3-Clause */ -#include - -#include +#include +#include +#include #include #include -#include + +#include namespace nb = nanobind; -TEST(NanobindAdapter, DataDynStrides) { +class python_init_fixture : public ::testing::Test { + protected: + void SetUp() override { Py_Initialize(); } + void TearDown() override { Py_FinalizeEx(); } +}; + +TEST_F(python_init_fixture, NanobindAdapterDataDynStrides) { const auto data = reinterpret_cast(0xDEADBEEF); constexpr int ndim = 2; constexpr std::array shape = {3, 4}; constexpr std::array strides = {1, 3}; - nb::ndarray> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; + nb::ndarray> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; const auto sid = gridtools::nanobind::as_sid(ndarray); const auto s_origin = sid_get_origin(sid); @@ -35,27 +42,27 @@ TEST(NanobindAdapter, DataDynStrides) { EXPECT_EQ(strides[1], gridtools::get<1>(s_strides)); } -TEST(NanobindAdapter, StaticStridesMatch) { +TEST_F(python_init_fixture, NanobindAdapterStaticStridesMatch) { const auto data = reinterpret_cast(0xDEADBEEF); constexpr int ndim = 2; constexpr std::array shape = {3, 4}; constexpr std::array strides = {1, 3}; - nb::ndarray> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; + nb::ndarray> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; - const auto sid = gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<1, nanobind::any>{}); + const auto sid = gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<1, -1>{}); const auto s_strides = sid_get_strides(sid); EXPECT_EQ(strides[0], gridtools::get<0>(s_strides).value); EXPECT_EQ(strides[1], gridtools::get<1>(s_strides)); } -TEST(NanobindAdapter, StaticStridesMismatch) { +TEST_F(python_init_fixture, NanobindAdapterStaticStridesMismatch) { const auto data = reinterpret_cast(0xDEADBEEF); constexpr int ndim = 2; constexpr std::array shape = {3, 4}; constexpr std::array strides = {1, 3}; - nb::ndarray> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; + nb::ndarray> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; EXPECT_THROW( - gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<2, nanobind::any>{}), std::invalid_argument); + gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<2, -1>{}), std::invalid_argument); }