Skip to content

Commit

Permalink
Using SIDs as neighbour tables in functional API (GridTools#1730)
Browse files Browse the repository at this point in the history
Adds a simple class that wraps an SID and implements the neighbour table concept. This makes it possible to use Python buffers as neighbour tables by first wrapping them into an SID, but any SID is suitable as a neighbour table.
  • Loading branch information
petiaccja authored and havogt committed Dec 12, 2022
1 parent adb9971 commit b474f72
Show file tree
Hide file tree
Showing 5 changed files with 192 additions and 0 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ Willem Deconinck (wdeconinck), ECMWF
Auriane Reverdell (aurianer), ETH Zurich (CSCS)
Mikael Simberg (msimberg), ETH Zurich (CSCS)
Till Ehrengruber (tehrengruber), ETH Zurich (CSCS)
Péter Kardos (petiaccja), ETH Zurich (EXCLAIM)
76 changes: 76 additions & 0 deletions include/gridtools/fn/sid_neighbor_table.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* GridTools
*
* Copyright (c) 2014-2022, ETH Zurich
* All rights reserved.
*
* Please, refer to the LICENSE file in the root directory.
* SPDX-License-Identifier: BSD-3-Clause
*/
#pragma once

#include <cstddef>
#include <type_traits>

#include "../common/array.hpp"
#include "../fn/unstructured.hpp"
#include "../sid/concept.hpp"

namespace gridtools::fn::sid_neighbor_table {
namespace sid_neighbor_table_impl_ {
template <class IndexDimension,
class NeighborDimension,
std::size_t MaxNumNeighbors,
class PtrHolder,
class Strides>
struct sid_neighbor_table {
PtrHolder origin;
Strides strides;
};

template <class IndexDimension,
class NeighborDimension,
std::size_t MaxNumNeighbors,
class PtrHolder,
class Strides>
GT_FUNCTION auto neighbor_table_neighbors(
sid_neighbor_table<IndexDimension, NeighborDimension, MaxNumNeighbors, PtrHolder, Strides> const &table,
int index) {

using namespace gridtools::literals;

auto ptr = table.origin();
using element_type = std::decay_t<decltype(*ptr)>;

gridtools::array<element_type, MaxNumNeighbors> neighbors;

sid::shift(ptr, sid::get_stride<IndexDimension>(table.strides), index);
for (std::size_t element_idx = 0; element_idx < MaxNumNeighbors; ++element_idx) {
neighbors[element_idx] = *ptr;
sid::shift(ptr, sid::get_stride<NeighborDimension>(table.strides), 1_c);
}
return neighbors;
}

template <class IndexDimension, class NeighborDimension, std::size_t MaxNumNeighbors, class Sid>
auto as_neighbor_table(Sid &&sid) -> sid_neighbor_table<IndexDimension,
NeighborDimension,
MaxNumNeighbors,
sid::ptr_holder_type<Sid>,
sid::strides_type<Sid>> {

static_assert(gridtools::tuple_util::size<decltype(sid::get_strides(std::declval<Sid>()))>::value == 2,
"Neighbor tables must have exactly two dimensions: the index dimension and the neighbor dimension");
static_assert(!std::is_same_v<IndexDimension, NeighborDimension>,
"The index dimension and the neighbor dimension must be different.");

const auto origin = sid::get_origin(sid);
const auto strides = sid::get_strides(sid);

return {origin, strides};
}
} // namespace sid_neighbor_table_impl_

using sid_neighbor_table_impl_::as_neighbor_table;

} // namespace gridtools::fn::sid_neighbor_table
5 changes: 5 additions & 0 deletions tests/unit_tests/fn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ gridtools_add_unit_test(test_fn_run SOURCES test_fn_run.cpp)
gridtools_add_unit_test(test_fn_column_stage SOURCES test_fn_column_stage.cpp)
gridtools_add_unit_test(test_fn_stencil_stage SOURCES test_fn_stencil_stage.cpp LABELS fn)
gridtools_add_unit_test(test_fn_unstructured SOURCES test_fn_unstructured.cpp LABELS fn)
gridtools_add_unit_test(test_fn_sid_neighbor_table SOURCES test_fn_sid_neighbor_table.cpp LABELS fn)

if(TARGET _gridtools_cuda)
gridtools_add_unit_test(test_fn_backend_gpu_cuda
Expand All @@ -29,4 +30,8 @@ if(TARGET _gridtools_cuda)
SOURCES test_extents.cu
LIBRARIES _gridtools_cuda
LABELS cuda fn)
gridtools_add_unit_test(test_fn_sid_neighbor_table_cuda
SOURCES test_fn_sid_neighbor_table.cu
LIBRARIES _gridtools_cuda
LABELS cuda fn)
endif()
45 changes: 45 additions & 0 deletions tests/unit_tests/fn/test_fn_sid_neighbor_table.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* GridTools
*
* Copyright (c) 2014-2022, ETH Zurich
* All rights reserved.
*
* Please, refer to the LICENSE file in the root directory.
* SPDX-License-Identifier: BSD-3-Clause
*/

#include <gridtools/fn/sid_neighbor_table.hpp>

#include <array>
#include <cstddef>
#include <cstdint>

#include <gtest/gtest.h>

namespace gridtools::fn {
namespace {

using sid_neighbor_table::as_neighbor_table;

using edge_dim_t = integral_constant<int_t, 0>;
using edge_to_cell_dim_t = integral_constant<int_t, 1>;

TEST(sid_neighbor_table, correctness) {
constexpr std::size_t num_elements = 3;
constexpr std::size_t num_neighbors = 2;
const int contents[num_elements][num_neighbors] = {{0, 1}, {10, 11}, {20, 21}};
const auto table = as_neighbor_table<edge_dim_t, edge_to_cell_dim_t, num_neighbors>(contents);

auto [n00, n01] = neighbor_table::neighbors(table, 0);
auto [n10, n11] = neighbor_table::neighbors(table, 1);
auto [n20, n21] = neighbor_table::neighbors(table, 2);
EXPECT_EQ(n00, 0);
EXPECT_EQ(n01, 1);
EXPECT_EQ(n10, 10);
EXPECT_EQ(n11, 11);
EXPECT_EQ(n20, 20);
EXPECT_EQ(n21, 21);
}

} // namespace
} // namespace gridtools::fn
65 changes: 65 additions & 0 deletions tests/unit_tests/fn/test_fn_sid_neighbor_table.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* GridTools
*
* Copyright (c) 2014-2021, ETH Zurich
* All rights reserved.
*
* Please, refer to the LICENSE file in the root directory.
* SPDX-License-Identifier: BSD-3-Clause
*/

#include <gridtools/fn/sid_neighbor_table.hpp>

#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>

#include <gtest/gtest.h>

#include <cuda_test_helper.hpp>
#include <gridtools/sid/synthetic.hpp>

namespace gridtools::fn {
namespace {
using sid_neighbor_table::as_neighbor_table;

using edge_dim_t = integral_constant<int_t, 0>;
using edge_to_cell_dim_t = integral_constant<int_t, 1>;

template <class Table>
__device__ auto neighbor_table_neighbors_device(Table const &table, int index)
-> array<int, 2> {
return neighbor_table::neighbors(table, index);
}

TEST(sid_neighbor_table, correctness_cuda) {
constexpr std::size_t num_elements = 3;
constexpr std::size_t num_neighbors = 2;

const int data[num_elements][num_neighbors] = {{0, 1}, {10, 11}, {20, 21}};
const auto device_data = cuda_util::cuda_malloc<int>(num_elements * num_neighbors);
GT_CUDA_CHECK(cudaMemcpy(device_data.get(), &data, sizeof data, cudaMemcpyHostToDevice));
using dim_hymap_t = hymap::keys<edge_dim_t, edge_to_cell_dim_t>;
auto contents = sid::synthetic()
.set<sid::property::origin>(sid::host_device::simple_ptr_holder(device_data.get()))
.set<sid::property::strides>(dim_hymap_t::make_values(num_neighbors, 1));

const auto table = as_neighbor_table<edge_dim_t, edge_to_cell_dim_t, num_neighbors>(contents);
using table_t = std::decay_t<decltype(table)>;

auto [n00, n01] = on_device::exec(
GT_MAKE_INTEGRAL_CONSTANT_FROM_VALUE(&neighbor_table_neighbors_device<table_t>), table, 0);
auto [n10, n11] = on_device::exec(
GT_MAKE_INTEGRAL_CONSTANT_FROM_VALUE(&neighbor_table_neighbors_device<table_t>), table, 1);
auto [n20, n21] = on_device::exec(
GT_MAKE_INTEGRAL_CONSTANT_FROM_VALUE(&neighbor_table_neighbors_device<table_t>), table, 2);
EXPECT_EQ(n00, 0);
EXPECT_EQ(n01, 1);
EXPECT_EQ(n10, 10);
EXPECT_EQ(n11, 11);
EXPECT_EQ(n20, 20);
EXPECT_EQ(n21, 21);
}
} // namespace
} // namespace gridtools::fn

0 comments on commit b474f72

Please sign in to comment.