Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Python bindings: Nanobind SID adapter #1762

Merged
merged 15 commits into from
Aug 9, 2023
90 changes: 90 additions & 0 deletions include/gridtools/storage/adapter/nanobind_adapter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#pragma once

#include <algorithm>
#include <stdexcept>

#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>

#include "../../common/array.hpp"
#include "../../common/integral_constant.hpp"
#include "../../common/tuple.hpp"
#include "../../sid/simple_ptr_holder.hpp"
#include "../../sid/synthetic.hpp"
#include "../../sid/unknown_kind.hpp"

namespace gridtools {
namespace nanobind_sid_adapter_impl_ {

// Use nanobind::any for dynamic stride, use an integral value for static stride.
template <std::size_t... Values>
using stride_spec = std::index_sequence<Values...>;

template <class IndexSequence>
struct dynamic_strides_helper;

template <std::size_t... Indices>
struct dynamic_strides_helper<std::index_sequence<Indices...>> {
using type = stride_spec<(void(Indices), nanobind::any)...>;
};

template <std::size_t N>
using fully_dynamic_strides = typename dynamic_strides_helper<std::make_index_sequence<N>>::type;

template <std::size_t SpecValue>
auto select_static_stride_value(std::size_t dyn_value) {
if constexpr (SpecValue == nanobind::any) {
return dyn_value;
} else {
if (SpecValue != dyn_value) {
throw std::invalid_argument("static stride in stride specification doesn't match dynamic stride");
}
return gridtools::integral_constant<std::size_t, SpecValue>{};
}
}

template <std::size_t... SpecValues, std::size_t... IndexValues>
auto select_static_strides_helper(
stride_spec<SpecValues...>, const std::size_t *dyn_values, std::index_sequence<IndexValues...>) {

return gridtools::tuple{select_static_stride_value<SpecValues>(dyn_values[IndexValues])...};
}

template <std::size_t... SpecValues>
auto select_static_strides(stride_spec<SpecValues...> spec, const std::size_t *dyn_values) {
return select_static_strides_helper(spec, dyn_values, std::make_index_sequence<sizeof...(SpecValues)>{});
}

template <class T,
std::size_t... Sizes,
class... Args,
class Strides = fully_dynamic_strides<sizeof...(Sizes)>,
class StridesKind = sid::unknown_kind>
auto as_sid(nanobind::ndarray<T, nanobind::shape<Sizes...>, Args...> ndarray,
Strides stride_spec_ = {},
StridesKind = {}) {
using sid::property;
petiaccja marked this conversation as resolved.
Show resolved Hide resolved
const auto ptr = ndarray.data();
constexpr auto ndim = sizeof...(Sizes);
assert(ndim == ndarray.ndim());
gridtools::array<std::size_t, ndim> shape;
std::copy_n(ndarray.shape_ptr(), ndim, shape.begin());
gridtools::array<std::size_t, ndim> strides_;
std::copy_n(ndarray.stride_ptr(), ndim, strides_.begin());
const auto strides = select_static_strides(stride_spec_, strides_.data());

return sid::synthetic()
.template set<property::origin>(sid::host_device::simple_ptr_holder<T *>{ptr})
.template set<property::strides>(strides)
.template set<property::strides_kind, StridesKind>()
.template set<property::lower_bounds>(gridtools::array<integral_constant<std::size_t, 0>, ndim>())
.template set<property::upper_bounds>(shape);
}
} // namespace nanobind_sid_adapter_impl_

namespace nanobind {
using nanobind_sid_adapter_impl_::as_sid;
using nanobind_sid_adapter_impl_::fully_dynamic_strides;
using nanobind_sid_adapter_impl_::stride_spec;
} // namespace nanobind
} // namespace gridtools
1 change: 1 addition & 0 deletions include/gridtools/storage/data_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
*/
#pragma once

#include <cstdint>
#include <memory>
#include <string>
#include <type_traits>
Expand Down
15 changes: 15 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ function(gridtools_add_mpi_test arch tgt)
endfunction()


# This option is only useful if we have a broken setup,
# where a Python environment is found but doesn't work.
option(GT_TESTS_ENABLE_PYTHON_TESTS "Enable Python tests" ON)
option(GT_TESTS_REQUIRE_Python "Enable Python tests" OFF)

# Find Python libraries
if (${GT_TESTS_ENABLE_PYTHON_TESTS})
if(GT_TESTS_REQUIRE_Python)
set(_GT_TESTS_Python_REQUIRED "REQUIRED")
endif()

find_package(Python "3.8" ${_GT_TESTS_Python_REQUIRED} COMPONENTS Interpreter Development NumPy)
endif()


add_subdirectory(regression)
add_subdirectory(unit_tests)

Expand Down
21 changes: 4 additions & 17 deletions tests/regression/py_bindings/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,20 +1,7 @@
option(GT_TESTS_REQUIRE_Python "CMake will abort if no Python environment can be found" OFF)

if(GT_TESTS_REQUIRE_Python)
set(_GT_TESTS_Python_REQUIRED "REQUIRED")
endif()

find_package(Python3 ${_GT_TESTS_Python_REQUIRED} COMPONENTS Interpreter Development NumPy)

if (NOT Python3_FOUND OR NOT Python3_Development_FOUND OR NOT Python3_NumPy_FOUND)
if (NOT ${GT_TESTS_ENABLE_PYTHON_TESTS})
return()
endif()

# This option is only useful if we have a broken setup,
# where a Python environment is found but doesn't work.
option(GT_TESTS_ENABLE_PYTHON_TESTS "Enable Python tests" ON)

if(NOT GT_TESTS_ENABLE_PYTHON_TESTS)
if (NOT ${Python_Development_FOUND} OR NOT ${Python_NumPy_FOUND})
return()
endif()

Expand All @@ -28,12 +15,12 @@ FetchContent_Declare(
FetchContent_GetProperties(pybind11)
if(NOT pybind11_POPULATED)
FetchContent_Populate(pybind11)
set(PYTHON_EXECUTABLE ${Python3_EXECUTABLE})
set(PYTHON_EXECUTABLE ${Python_EXECUTABLE})
add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR})
endif()

pybind11_add_module(py_implementation implementation.cpp)

target_link_libraries(py_implementation PRIVATE gridtools)

add_test(NAME py_bindings COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/driver.py ${GT_CUDA_TYPE})
add_test(NAME py_bindings COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/driver.py ${GT_CUDA_TYPE})
20 changes: 20 additions & 0 deletions tests/unit_tests/storage/adapter/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,23 @@ gridtools_add_unit_test(test_fortran_array_adapter
SOURCES test_fortran_array_adapter.cpp
LIBRARIES cpp_bindgen_interface
NO_NVCC)


if (${GT_TESTS_ENABLE_PYTHON_TESTS})
if (${Python_Development_FOUND})
FetchContent_Declare(
nanobind
GIT_REPOSITORY https://github.com/wjakob/nanobind.git
GIT_TAG v1.4.0
)
FetchContent_MakeAvailable(nanobind)
nanobind_build_library(nanobind-static)

gridtools_add_unit_test(test_nanobind_adapter
SOURCES test_nanobind_adapter.cpp
LIBRARIES nanobind-static Python::Python
NO_NVCC)
nanobind_compile_options(test_nanobind_adapter)
nanobind_link_options(test_nanobind_adapter)
endif()
endif()
61 changes: 61 additions & 0 deletions tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* GridTools
*
* Copyright (c) 2014-2023, ETH Zurich
* All rights reserved.
*
* Please, refer to the LICENSE file in the root directory.
* SPDX-License-Identifier: BSD-3-Clause
*/

#include <array>

#include <gtest/gtest.h>

#include <gridtools/common/integral_constant.hpp>
#include <gridtools/sid/concept.hpp>
#include <gridtools/storage/adapter/nanobind_adapter.hpp>

namespace nb = nanobind;

TEST(NanobindAdapter, DataDynStrides) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<nb::any, nb::any>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

const auto sid = gridtools::nanobind::as_sid(ndarray);
const auto s_origin = sid_get_origin(sid);
const auto s_strides = sid_get_strides(sid);
const auto s_ptr = s_origin();

EXPECT_EQ(s_ptr, data);
EXPECT_EQ(strides[0], gridtools::get<0>(s_strides));
EXPECT_EQ(strides[1], gridtools::get<1>(s_strides));
}

TEST(NanobindAdapter, StaticStridesMatch) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<nb::any, nb::any>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

const auto sid = gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<1, nanobind::any>{});
const auto s_strides = sid_get_strides(sid);

EXPECT_EQ(strides[0], gridtools::get<0>(s_strides).value);
EXPECT_EQ(strides[1], gridtools::get<1>(s_strides));
}

TEST(NanobindAdapter, StaticStridesMismatch) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<nb::any, nb::any>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

EXPECT_THROW(
gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<2, nanobind::any>{}), std::invalid_argument);
}
Loading