Skip to content

Commit

Permalink
Python bindings: Nanobind SID adapter (#1762)
Browse files Browse the repository at this point in the history
Add a SID adapter for dlpack via nanobind.
  • Loading branch information
petiaccja authored and havogt committed Aug 16, 2023
1 parent 4f36f88 commit b3a37d0
Show file tree
Hide file tree
Showing 5 changed files with 190 additions and 17 deletions.
90 changes: 90 additions & 0 deletions include/gridtools/storage/adapter/nanobind_adapter.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#pragma once

#include <algorithm>
#include <stdexcept>

#include <nanobind/nanobind.h>
#include <nanobind/ndarray.h>

#include "../../common/array.hpp"
#include "../../common/integral_constant.hpp"
#include "../../common/tuple.hpp"
#include "../../sid/simple_ptr_holder.hpp"
#include "../../sid/synthetic.hpp"
#include "../../sid/unknown_kind.hpp"

namespace gridtools {
namespace nanobind_sid_adapter_impl_ {

// Use nanobind::any for dynamic stride, use an integral value for static stride.
template <std::size_t... Values>
using stride_spec = std::index_sequence<Values...>;

template <class IndexSequence>
struct dynamic_strides_helper;

template <std::size_t... Indices>
struct dynamic_strides_helper<std::index_sequence<Indices...>> {
using type = stride_spec<(void(Indices), nanobind::any)...>;
};

template <std::size_t N>
using fully_dynamic_strides = typename dynamic_strides_helper<std::make_index_sequence<N>>::type;

template <std::size_t SpecValue>
auto select_static_stride_value(std::size_t dyn_value) {
if constexpr (SpecValue == nanobind::any) {
return dyn_value;
} else {
if (SpecValue != dyn_value) {
throw std::invalid_argument("static stride in stride specification doesn't match dynamic stride");
}
return gridtools::integral_constant<std::size_t, SpecValue>{};
}
}

template <std::size_t... SpecValues, std::size_t... IndexValues>
auto select_static_strides_helper(
stride_spec<SpecValues...>, const std::size_t *dyn_values, std::index_sequence<IndexValues...>) {

return gridtools::tuple{select_static_stride_value<SpecValues>(dyn_values[IndexValues])...};
}

template <std::size_t... SpecValues>
auto select_static_strides(stride_spec<SpecValues...> spec, const std::size_t *dyn_values) {
return select_static_strides_helper(spec, dyn_values, std::make_index_sequence<sizeof...(SpecValues)>{});
}

template <class T,
std::size_t... Sizes,
class... Args,
class Strides = fully_dynamic_strides<sizeof...(Sizes)>,
class StridesKind = sid::unknown_kind>
auto as_sid(nanobind::ndarray<T, nanobind::shape<Sizes...>, Args...> ndarray,
Strides stride_spec_ = {},
StridesKind = {}) {
using sid::property;
const auto ptr = ndarray.data();
constexpr auto ndim = sizeof...(Sizes);
assert(ndim == ndarray.ndim());
gridtools::array<std::size_t, ndim> shape;
std::copy_n(ndarray.shape_ptr(), ndim, shape.begin());
gridtools::array<std::size_t, ndim> strides_;
std::copy_n(ndarray.stride_ptr(), ndim, strides_.begin());
const auto strides = select_static_strides(stride_spec_, strides_.data());

return sid::synthetic()
.template set<property::origin>(sid::host_device::simple_ptr_holder<T *>{ptr})
.template set<property::strides>(strides)
.template set<property::strides_kind, StridesKind>()
.template set<property::lower_bounds>(gridtools::array<integral_constant<std::size_t, 0>, ndim>())
.template set<property::upper_bounds>(shape);
}
} // namespace nanobind_sid_adapter_impl_

namespace nanobind {
using nanobind_sid_adapter_impl_::as_sid;
using nanobind_sid_adapter_impl_::fully_dynamic_strides;
using nanobind_sid_adapter_impl_::stride_spec;
} // namespace nanobind
} // namespace gridtools
15 changes: 15 additions & 0 deletions tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,21 @@ function(gridtools_add_mpi_test arch tgt)
endfunction()


# This option is only useful if we have a broken setup,
# where a Python environment is found but doesn't work.
option(GT_TESTS_ENABLE_PYTHON_TESTS "Enable Python tests" ON)
option(GT_TESTS_REQUIRE_Python "Enable Python tests" OFF)

# Find Python libraries
if (${GT_TESTS_ENABLE_PYTHON_TESTS})
if(GT_TESTS_REQUIRE_Python)
set(_GT_TESTS_Python_REQUIRED "REQUIRED")
endif()

find_package(Python "3.8" ${_GT_TESTS_Python_REQUIRED} COMPONENTS Interpreter Development NumPy)
endif()


add_subdirectory(regression)
add_subdirectory(unit_tests)

Expand Down
21 changes: 4 additions & 17 deletions tests/regression/py_bindings/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,20 +1,7 @@
option(GT_TESTS_REQUIRE_Python "CMake will abort if no Python environment can be found" OFF)

if(GT_TESTS_REQUIRE_Python)
set(_GT_TESTS_Python_REQUIRED "REQUIRED")
endif()

find_package(Python3 ${_GT_TESTS_Python_REQUIRED} COMPONENTS Interpreter Development NumPy)

if (NOT Python3_FOUND OR NOT Python3_Development_FOUND OR NOT Python3_NumPy_FOUND)
if (NOT ${GT_TESTS_ENABLE_PYTHON_TESTS})
return()
endif()

# This option is only useful if we have a broken setup,
# where a Python environment is found but doesn't work.
option(GT_TESTS_ENABLE_PYTHON_TESTS "Enable Python tests" ON)

if(NOT GT_TESTS_ENABLE_PYTHON_TESTS)
if (NOT ${Python_Development_FOUND} OR NOT ${Python_NumPy_FOUND})
return()
endif()

Expand All @@ -28,12 +15,12 @@ FetchContent_Declare(
FetchContent_GetProperties(pybind11)
if(NOT pybind11_POPULATED)
FetchContent_Populate(pybind11)
set(PYTHON_EXECUTABLE ${Python3_EXECUTABLE})
set(PYTHON_EXECUTABLE ${Python_EXECUTABLE})
add_subdirectory(${pybind11_SOURCE_DIR} ${pybind11_BINARY_DIR})
endif()

pybind11_add_module(py_implementation implementation.cpp)

target_link_libraries(py_implementation PRIVATE gridtools)

add_test(NAME py_bindings COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/driver.py ${GT_CUDA_TYPE})
add_test(NAME py_bindings COMMAND ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/driver.py ${GT_CUDA_TYPE})
20 changes: 20 additions & 0 deletions tests/unit_tests/storage/adapter/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,23 @@ gridtools_add_unit_test(test_fortran_array_adapter
SOURCES test_fortran_array_adapter.cpp
LIBRARIES cpp_bindgen_interface
NO_NVCC)


if (${GT_TESTS_ENABLE_PYTHON_TESTS})
if (${Python_Development_FOUND})
FetchContent_Declare(
nanobind
GIT_REPOSITORY https://github.com/wjakob/nanobind.git
GIT_TAG v1.4.0
)
FetchContent_MakeAvailable(nanobind)
nanobind_build_library(nanobind-static)

gridtools_add_unit_test(test_nanobind_adapter
SOURCES test_nanobind_adapter.cpp
LIBRARIES nanobind-static Python::Python
NO_NVCC)
nanobind_compile_options(test_nanobind_adapter)
nanobind_link_options(test_nanobind_adapter)
endif()
endif()
61 changes: 61 additions & 0 deletions tests/unit_tests/storage/adapter/test_nanobind_adapter.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* GridTools
*
* Copyright (c) 2014-2023, ETH Zurich
* All rights reserved.
*
* Please, refer to the LICENSE file in the root directory.
* SPDX-License-Identifier: BSD-3-Clause
*/

#include <array>

#include <gtest/gtest.h>

#include <gridtools/common/integral_constant.hpp>
#include <gridtools/sid/concept.hpp>
#include <gridtools/storage/adapter/nanobind_adapter.hpp>

namespace nb = nanobind;

TEST(NanobindAdapter, DataDynStrides) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<nb::any, nb::any>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

const auto sid = gridtools::nanobind::as_sid(ndarray);
const auto s_origin = sid_get_origin(sid);
const auto s_strides = sid_get_strides(sid);
const auto s_ptr = s_origin();

EXPECT_EQ(s_ptr, data);
EXPECT_EQ(strides[0], gridtools::get<0>(s_strides));
EXPECT_EQ(strides[1], gridtools::get<1>(s_strides));
}

TEST(NanobindAdapter, StaticStridesMatch) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<nb::any, nb::any>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

const auto sid = gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<1, nanobind::any>{});
const auto s_strides = sid_get_strides(sid);

EXPECT_EQ(strides[0], gridtools::get<0>(s_strides).value);
EXPECT_EQ(strides[1], gridtools::get<1>(s_strides));
}

TEST(NanobindAdapter, StaticStridesMismatch) {
const auto data = reinterpret_cast<void *>(0xDEADBEEF);
constexpr int ndim = 2;
constexpr std::array<std::size_t, ndim> shape = {3, 4};
constexpr std::array<std::intptr_t, ndim> strides = {1, 3};
nb::ndarray<int, nb::shape<nb::any, nb::any>> ndarray{data, ndim, shape.data(), nb::handle{}, strides.data()};

EXPECT_THROW(
gridtools::nanobind::as_sid(ndarray, gridtools::nanobind::stride_spec<2, nanobind::any>{}), std::invalid_argument);
}

0 comments on commit b3a37d0

Please sign in to comment.