From b3a37d01f8d15abb48d0e9a0a46f807107b1ef78 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Kardos?= Date: Wed, 9 Aug 2023 10:30:31 +0200 Subject: [PATCH] Python bindings: Nanobind SID adapter (#1762) Add a SID adapter for dlpack via nanobind. --- .../storage/adapter/nanobind_adapter.hpp | 90 +++++++++++++++++++ tests/CMakeLists.txt | 15 ++++ tests/regression/py_bindings/CMakeLists.txt | 21 +---- .../unit_tests/storage/adapter/CMakeLists.txt | 20 +++++ .../storage/adapter/test_nanobind_adapter.cpp | 61 +++++++++++++ 5 files changed, 190 insertions(+), 17 deletions(-) create mode 100644 include/gridtools/storage/adapter/nanobind_adapter.hpp create mode 100644 tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp diff --git a/include/gridtools/storage/adapter/nanobind_adapter.hpp b/include/gridtools/storage/adapter/nanobind_adapter.hpp new file mode 100644 index 0000000000..4fb12adc46 --- /dev/null +++ b/include/gridtools/storage/adapter/nanobind_adapter.hpp @@ -0,0 +1,90 @@ +#pragma once + +#include +#include + +#include +#include + +#include "../../common/array.hpp" +#include "../../common/integral_constant.hpp" +#include "../../common/tuple.hpp" +#include "../../sid/simple_ptr_holder.hpp" +#include "../../sid/synthetic.hpp" +#include "../../sid/unknown_kind.hpp" + +namespace gridtools { + namespace nanobind_sid_adapter_impl_ { + + // Use nanobind::any for dynamic stride, use an integral value for static stride. + template + using stride_spec = std::index_sequence; + + template + struct dynamic_strides_helper; + + template + struct dynamic_strides_helper> { + using type = stride_spec<(void(Indices), nanobind::any)...>; + }; + + template + using fully_dynamic_strides = typename dynamic_strides_helper>::type; + + template + auto select_static_stride_value(std::size_t dyn_value) { + if constexpr (SpecValue == nanobind::any) { + return dyn_value; + } else { + if (SpecValue != dyn_value) { + throw std::invalid_argument("static stride in stride specification doesn't match dynamic stride"); + } + return gridtools::integral_constant{}; + } + } + + template + auto select_static_strides_helper( + stride_spec, const std::size_t *dyn_values, std::index_sequence) { + + return gridtools::tuple{select_static_stride_value(dyn_values[IndexValues])...}; + } + + template + auto select_static_strides(stride_spec spec, const std::size_t *dyn_values) { + return select_static_strides_helper(spec, dyn_values, std::make_index_sequence{}); + } + + template , + class StridesKind = sid::unknown_kind> + auto as_sid(nanobind::ndarray, Args...> ndarray, + Strides stride_spec_ = {}, + StridesKind = {}) { + using sid::property; + const auto ptr = ndarray.data(); + constexpr auto ndim = sizeof...(Sizes); + assert(ndim == ndarray.ndim()); + gridtools::array shape; + std::copy_n(ndarray.shape_ptr(), ndim, shape.begin()); + gridtools::array strides_; + std::copy_n(ndarray.stride_ptr(), ndim, strides_.begin()); + const auto strides = select_static_strides(stride_spec_, strides_.data()); + + return sid::synthetic() + .template set(sid::host_device::simple_ptr_holder{ptr}) + .template set(strides) + .template set() + .template set(gridtools::array, ndim>()) + .template set(shape); + } + } // namespace nanobind_sid_adapter_impl_ + + namespace nanobind { + using nanobind_sid_adapter_impl_::as_sid; + using nanobind_sid_adapter_impl_::fully_dynamic_strides; + using nanobind_sid_adapter_impl_::stride_spec; + } // namespace nanobind +} // namespace gridtools \ No newline at end of file diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c149097595..143943db0c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -101,6 +101,21 @@ function(gridtools_add_mpi_test arch tgt) endfunction() +# This option is only useful if we have a broken setup, +# where a Python environment is found but doesn't work. +option(GT_TESTS_ENABLE_PYTHON_TESTS "Enable Python tests" ON) +option(GT_TESTS_REQUIRE_Python "Enable Python tests" OFF) + +# Find Python libraries +if (${GT_TESTS_ENABLE_PYTHON_TESTS}) + if(GT_TESTS_REQUIRE_Python) + set(_GT_TESTS_Python_REQUIRED "REQUIRED") + endif() + + find_package(Python "3.8" ${_GT_TESTS_Python_REQUIRED} COMPONENTS Interpreter Development NumPy) +endif() + + add_subdirectory(regression) add_subdirectory(unit_tests) diff --git a/tests/regression/py_bindings/CMakeLists.txt b/tests/regression/py_bindings/CMakeLists.txt index 967f00d8f7..c2187ff2dc 100644 --- a/tests/regression/py_bindings/CMakeLists.txt +++ b/tests/regression/py_bindings/CMakeLists.txt @@ -1,20 +1,7 @@ -option(GT_TESTS_REQUIRE_Python "CMake will abort if no Python environment can be found" OFF) - -if(GT_TESTS_REQUIRE_Python) - set(_GT_TESTS_Python_REQUIRED "REQUIRED") -endif() - -find_package(Python3 ${_GT_TESTS_Python_REQUIRED} COMPONENTS Interpreter Development NumPy) - -if (NOT Python3_FOUND OR NOT Python3_Development_FOUND OR NOT Python3_NumPy_FOUND) +if (NOT ${GT_TESTS_ENABLE_PYTHON_TESTS}) return() endif() - -# This option is only useful if we have a broken setup, -# where a Python environment is found but doesn't work. -option(GT_TESTS_ENABLE_PYTHON_TESTS "Enable Python tests" ON) - -if(NOT GT_TESTS_ENABLE_PYTHON_TESTS) +if (NOT ${Python_Development_FOUND} OR NOT ${Python_NumPy_FOUND}) return() endif() @@ -28,7 +15,7 @@ FetchContent_Declare( FetchContent_GetProperties(pybind11) if(NOT pybind11_POPULATED) FetchContent_Populate(pybind11) - set(PYTHON_EXECUTABLE ${Python3_EXECUTABLE}) + set(PYTHON_EXECUTABLE ${Python_EXECUTABLE}) add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR}) endif() @@ -36,4 +23,4 @@ pybind11_add_module(py_implementation implementation.cpp) target_link_libraries(py_implementation PRIVATE gridtools) -add_test(NAME py_bindings COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/driver.py ${GT_CUDA_TYPE}) +add_test(NAME py_bindings COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/driver.py ${GT_CUDA_TYPE}) diff --git a/tests/unit_tests/storage/adapter/CMakeLists.txt b/tests/unit_tests/storage/adapter/CMakeLists.txt index d5be60f65c..e825bb4375 100644 --- a/tests/unit_tests/storage/adapter/CMakeLists.txt +++ b/tests/unit_tests/storage/adapter/CMakeLists.txt @@ -14,3 +14,23 @@ gridtools_add_unit_test(test_fortran_array_adapter SOURCES test_fortran_array_adapter.cpp LIBRARIES cpp_bindgen_interface NO_NVCC) + + +if (${GT_TESTS_ENABLE_PYTHON_TESTS}) + if (${Python_Development_FOUND}) + FetchContent_Declare( + nanobind + GIT_REPOSITORY https://github.com/wjakob/nanobind.git + GIT_TAG v1.4.0 + ) + FetchContent_MakeAvailable(nanobind) + nanobind_build_library(nanobind-static) + + gridtools_add_unit_test(test_nanobind_adapter + SOURCES test_nanobind_adapter.cpp + LIBRARIES nanobind-static Python::Python + NO_NVCC) + nanobind_compile_options(test_nanobind_adapter) + nanobind_link_options(test_nanobind_adapter) + endif() +endif() diff --git a/tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp b/tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp new file mode 100644 index 0000000000..ef97a1fe61 --- /dev/null +++ b/tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp @@ -0,0 +1,61 @@ +/* + * GridTools + * + * Copyright (c) 2014-2023, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ + +#include + +#include + +#include +#include +#include + +namespace nb = nanobind; + +TEST(NanobindAdapter, DataDynStrides) { + const auto data = reinterpret_cast(0xDEADBEEF); + constexpr int ndim = 2; + constexpr std::array shape = {3, 4}; + constexpr std::array strides = {1, 3}; + nb::ndarray> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; + + const auto sid = gridtools::nanobind::as_sid(ndarray); + const auto s_origin = sid_get_origin(sid); + const auto s_strides = sid_get_strides(sid); + const auto s_ptr = s_origin(); + + EXPECT_EQ(s_ptr, data); + EXPECT_EQ(strides[0], gridtools::get<0>(s_strides)); + EXPECT_EQ(strides[1], gridtools::get<1>(s_strides)); +} + +TEST(NanobindAdapter, StaticStridesMatch) { + const auto data = reinterpret_cast(0xDEADBEEF); + constexpr int ndim = 2; + constexpr std::array shape = {3, 4}; + constexpr std::array strides = {1, 3}; + nb::ndarray> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; + + const auto sid = gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<1, nanobind::any>{}); + const auto s_strides = sid_get_strides(sid); + + EXPECT_EQ(strides[0], gridtools::get<0>(s_strides).value); + EXPECT_EQ(strides[1], gridtools::get<1>(s_strides)); +} + +TEST(NanobindAdapter, StaticStridesMismatch) { + const auto data = reinterpret_cast(0xDEADBEEF); + constexpr int ndim = 2; + constexpr std::array shape = {3, 4}; + constexpr std::array strides = {1, 3}; + nb::ndarray> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()}; + + EXPECT_THROW( + gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<2, nanobind::any>{}), std::invalid_argument); +} \ No newline at end of file