forked from GridTools/gridtools
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Using SIDs as neighbour tables in functional API (GridTools#1730)
Adds a simple class that wraps an SID and implements the neighbour table concept. This makes it possible to use Python buffers as neighbour tables by first wrapping them into an SID, but any SID is suitable as a neighbour table.
- Loading branch information
Showing
5 changed files
with
192 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
/* | ||
* GridTools | ||
* | ||
* Copyright (c) 2014-2022, ETH Zurich | ||
* All rights reserved. | ||
* | ||
* Please, refer to the LICENSE file in the root directory. | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
#pragma once | ||
|
||
#include <cstddef> | ||
#include <type_traits> | ||
|
||
#include "../common/array.hpp" | ||
#include "../fn/unstructured.hpp" | ||
#include "../sid/concept.hpp" | ||
|
||
namespace gridtools::fn::sid_neighbor_table { | ||
namespace sid_neighbor_table_impl_ { | ||
template <class IndexDimension, | ||
class NeighborDimension, | ||
std::size_t MaxNumNeighbors, | ||
class PtrHolder, | ||
class Strides> | ||
struct sid_neighbor_table { | ||
PtrHolder origin; | ||
Strides strides; | ||
}; | ||
|
||
template <class IndexDimension, | ||
class NeighborDimension, | ||
std::size_t MaxNumNeighbors, | ||
class PtrHolder, | ||
class Strides> | ||
GT_FUNCTION auto neighbor_table_neighbors( | ||
sid_neighbor_table<IndexDimension, NeighborDimension, MaxNumNeighbors, PtrHolder, Strides> const &table, | ||
int index) { | ||
|
||
using namespace gridtools::literals; | ||
|
||
auto ptr = table.origin(); | ||
using element_type = std::decay_t<decltype(*ptr)>; | ||
|
||
gridtools::array<element_type, MaxNumNeighbors> neighbors; | ||
|
||
sid::shift(ptr, sid::get_stride<IndexDimension>(table.strides), index); | ||
for (std::size_t element_idx = 0; element_idx < MaxNumNeighbors; ++element_idx) { | ||
neighbors[element_idx] = *ptr; | ||
sid::shift(ptr, sid::get_stride<NeighborDimension>(table.strides), 1_c); | ||
} | ||
return neighbors; | ||
} | ||
|
||
template <class IndexDimension, class NeighborDimension, std::size_t MaxNumNeighbors, class Sid> | ||
auto as_neighbor_table(Sid &&sid) -> sid_neighbor_table<IndexDimension, | ||
NeighborDimension, | ||
MaxNumNeighbors, | ||
sid::ptr_holder_type<Sid>, | ||
sid::strides_type<Sid>> { | ||
|
||
static_assert(gridtools::tuple_util::size<decltype(sid::get_strides(std::declval<Sid>()))>::value == 2, | ||
"Neighbor tables must have exactly two dimensions: the index dimension and the neighbor dimension"); | ||
static_assert(!std::is_same_v<IndexDimension, NeighborDimension>, | ||
"The index dimension and the neighbor dimension must be different."); | ||
|
||
const auto origin = sid::get_origin(sid); | ||
const auto strides = sid::get_strides(sid); | ||
|
||
return {origin, strides}; | ||
} | ||
} // namespace sid_neighbor_table_impl_ | ||
|
||
using sid_neighbor_table_impl_::as_neighbor_table; | ||
|
||
} // namespace gridtools::fn::sid_neighbor_table |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
/* | ||
* GridTools | ||
* | ||
* Copyright (c) 2014-2022, ETH Zurich | ||
* All rights reserved. | ||
* | ||
* Please, refer to the LICENSE file in the root directory. | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
|
||
#include <gridtools/fn/sid_neighbor_table.hpp> | ||
|
||
#include <array> | ||
#include <cstddef> | ||
#include <cstdint> | ||
|
||
#include <gtest/gtest.h> | ||
|
||
namespace gridtools::fn { | ||
namespace { | ||
|
||
using sid_neighbor_table::as_neighbor_table; | ||
|
||
using edge_dim_t = integral_constant<int_t, 0>; | ||
using edge_to_cell_dim_t = integral_constant<int_t, 1>; | ||
|
||
TEST(sid_neighbor_table, correctness) { | ||
constexpr std::size_t num_elements = 3; | ||
constexpr std::size_t num_neighbors = 2; | ||
const int contents[num_elements][num_neighbors] = {{0, 1}, {10, 11}, {20, 21}}; | ||
const auto table = as_neighbor_table<edge_dim_t, edge_to_cell_dim_t, num_neighbors>(contents); | ||
|
||
auto [n00, n01] = neighbor_table::neighbors(table, 0); | ||
auto [n10, n11] = neighbor_table::neighbors(table, 1); | ||
auto [n20, n21] = neighbor_table::neighbors(table, 2); | ||
EXPECT_EQ(n00, 0); | ||
EXPECT_EQ(n01, 1); | ||
EXPECT_EQ(n10, 10); | ||
EXPECT_EQ(n11, 11); | ||
EXPECT_EQ(n20, 20); | ||
EXPECT_EQ(n21, 21); | ||
} | ||
|
||
} // namespace | ||
} // namespace gridtools::fn |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
/* | ||
* GridTools | ||
* | ||
* Copyright (c) 2014-2021, ETH Zurich | ||
* All rights reserved. | ||
* | ||
* Please, refer to the LICENSE file in the root directory. | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
|
||
#include <gridtools/fn/sid_neighbor_table.hpp> | ||
|
||
#include <array> | ||
#include <cstddef> | ||
#include <cstdint> | ||
#include <type_traits> | ||
|
||
#include <gtest/gtest.h> | ||
|
||
#include <cuda_test_helper.hpp> | ||
#include <gridtools/sid/synthetic.hpp> | ||
|
||
namespace gridtools::fn { | ||
namespace { | ||
using sid_neighbor_table::as_neighbor_table; | ||
|
||
using edge_dim_t = integral_constant<int_t, 0>; | ||
using edge_to_cell_dim_t = integral_constant<int_t, 1>; | ||
|
||
template <class Table> | ||
__device__ auto neighbor_table_neighbors_device(Table const &table, int index) | ||
-> array<int, 2> { | ||
return neighbor_table::neighbors(table, index); | ||
} | ||
|
||
TEST(sid_neighbor_table, correctness_cuda) { | ||
constexpr std::size_t num_elements = 3; | ||
constexpr std::size_t num_neighbors = 2; | ||
|
||
const int data[num_elements][num_neighbors] = {{0, 1}, {10, 11}, {20, 21}}; | ||
const auto device_data = cuda_util::cuda_malloc<int>(num_elements * num_neighbors); | ||
GT_CUDA_CHECK(cudaMemcpy(device_data.get(), &data, sizeof data, cudaMemcpyHostToDevice)); | ||
using dim_hymap_t = hymap::keys<edge_dim_t, edge_to_cell_dim_t>; | ||
auto contents = sid::synthetic() | ||
.set<sid::property::origin>(sid::host_device::simple_ptr_holder(device_data.get())) | ||
.set<sid::property::strides>(dim_hymap_t::make_values(num_neighbors, 1)); | ||
|
||
const auto table = as_neighbor_table<edge_dim_t, edge_to_cell_dim_t, num_neighbors>(contents); | ||
using table_t = std::decay_t<decltype(table)>; | ||
|
||
auto [n00, n01] = on_device::exec( | ||
GT_MAKE_INTEGRAL_CONSTANT_FROM_VALUE(&neighbor_table_neighbors_device<table_t>), table, 0); | ||
auto [n10, n11] = on_device::exec( | ||
GT_MAKE_INTEGRAL_CONSTANT_FROM_VALUE(&neighbor_table_neighbors_device<table_t>), table, 1); | ||
auto [n20, n21] = on_device::exec( | ||
GT_MAKE_INTEGRAL_CONSTANT_FROM_VALUE(&neighbor_table_neighbors_device<table_t>), table, 2); | ||
EXPECT_EQ(n00, 0); | ||
EXPECT_EQ(n01, 1); | ||
EXPECT_EQ(n10, 10); | ||
EXPECT_EQ(n11, 11); | ||
EXPECT_EQ(n20, 20); | ||
EXPECT_EQ(n21, 21); | ||
} | ||
} // namespace | ||
} // namespace gridtools::fn |