diff --git a/AUTHORS b/AUTHORS index 20ca08997..45790d056 100644 --- a/AUTHORS +++ b/AUTHORS @@ -15,3 +15,4 @@ Willem Deconinck (wdeconinck), ECMWF Auriane Reverdell (aurianer), ETH Zurich (CSCS) Mikael Simberg (msimberg), ETH Zurich (CSCS) Till Ehrengruber (tehrengruber), ETH Zurich (CSCS) +Péter Kardos (petiaccja), ETH Zurich (EXCLAIM) diff --git a/include/gridtools/fn/sid_neighbor_table.hpp b/include/gridtools/fn/sid_neighbor_table.hpp new file mode 100644 index 000000000..f61733bf3 --- /dev/null +++ b/include/gridtools/fn/sid_neighbor_table.hpp @@ -0,0 +1,76 @@ +/* + * GridTools + * + * Copyright (c) 2014-2022, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ +#pragma once + +#include +#include + +#include "../common/array.hpp" +#include "../fn/unstructured.hpp" +#include "../sid/concept.hpp" + +namespace gridtools::fn::sid_neighbor_table { + namespace sid_neighbor_table_impl_ { + template + struct sid_neighbor_table { + PtrHolder origin; + Strides strides; + }; + + template + GT_FUNCTION auto neighbor_table_neighbors( + sid_neighbor_table const &table, + int index) { + + using namespace gridtools::literals; + + auto ptr = table.origin(); + using element_type = std::decay_t; + + gridtools::array neighbors; + + sid::shift(ptr, sid::get_stride(table.strides), index); + for (std::size_t element_idx = 0; element_idx < MaxNumNeighbors; ++element_idx) { + neighbors[element_idx] = *ptr; + sid::shift(ptr, sid::get_stride(table.strides), 1_c); + } + return neighbors; + } + + template + auto as_neighbor_table(Sid &&sid) -> sid_neighbor_table, + sid::strides_type> { + + static_assert(gridtools::tuple_util::size()))>::value == 2, + "Neighbor tables must have exactly two dimensions: the index dimension and the neighbor dimension"); + static_assert(!std::is_same_v, + "The index dimension and the neighbor dimension must be different."); + + const auto origin = sid::get_origin(sid); + const auto strides = sid::get_strides(sid); + + return {origin, strides}; + } + } // namespace sid_neighbor_table_impl_ + + using sid_neighbor_table_impl_::as_neighbor_table; + +} // namespace gridtools::fn::sid_neighbor_table \ No newline at end of file diff --git a/tests/unit_tests/fn/CMakeLists.txt b/tests/unit_tests/fn/CMakeLists.txt index 9d920a78f..bd2d59c84 100644 --- a/tests/unit_tests/fn/CMakeLists.txt +++ b/tests/unit_tests/fn/CMakeLists.txt @@ -7,6 +7,7 @@ gridtools_add_unit_test(test_fn_run SOURCES test_fn_run.cpp) gridtools_add_unit_test(test_fn_column_stage SOURCES test_fn_column_stage.cpp) gridtools_add_unit_test(test_fn_stencil_stage SOURCES test_fn_stencil_stage.cpp LABELS fn) gridtools_add_unit_test(test_fn_unstructured SOURCES test_fn_unstructured.cpp LABELS fn) +gridtools_add_unit_test(test_fn_sid_neighbor_table SOURCES test_fn_sid_neighbor_table.cpp LABELS fn) if(TARGET _gridtools_cuda) gridtools_add_unit_test(test_fn_backend_gpu_cuda @@ -29,4 +30,8 @@ if(TARGET _gridtools_cuda) SOURCES test_extents.cu LIBRARIES _gridtools_cuda LABELS cuda fn) + gridtools_add_unit_test(test_fn_sid_neighbor_table_cuda + SOURCES test_fn_sid_neighbor_table.cu + LIBRARIES _gridtools_cuda + LABELS cuda fn) endif() diff --git a/tests/unit_tests/fn/test_fn_sid_neighbor_table.cpp b/tests/unit_tests/fn/test_fn_sid_neighbor_table.cpp new file mode 100644 index 000000000..2f43ab7ee --- /dev/null +++ b/tests/unit_tests/fn/test_fn_sid_neighbor_table.cpp @@ -0,0 +1,45 @@ +/* + * GridTools + * + * Copyright (c) 2014-2022, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ + +#include + +#include +#include +#include + +#include + +namespace gridtools::fn { + namespace { + + using sid_neighbor_table::as_neighbor_table; + + using edge_dim_t = integral_constant; + using edge_to_cell_dim_t = integral_constant; + + TEST(sid_neighbor_table, correctness) { + constexpr std::size_t num_elements = 3; + constexpr std::size_t num_neighbors = 2; + const int contents[num_elements][num_neighbors] = {{0, 1}, {10, 11}, {20, 21}}; + const auto table = as_neighbor_table(contents); + + auto [n00, n01] = neighbor_table::neighbors(table, 0); + auto [n10, n11] = neighbor_table::neighbors(table, 1); + auto [n20, n21] = neighbor_table::neighbors(table, 2); + EXPECT_EQ(n00, 0); + EXPECT_EQ(n01, 1); + EXPECT_EQ(n10, 10); + EXPECT_EQ(n11, 11); + EXPECT_EQ(n20, 20); + EXPECT_EQ(n21, 21); + } + + } // namespace +} // namespace gridtools::fn \ No newline at end of file diff --git a/tests/unit_tests/fn/test_fn_sid_neighbor_table.cu b/tests/unit_tests/fn/test_fn_sid_neighbor_table.cu new file mode 100644 index 000000000..14c393496 --- /dev/null +++ b/tests/unit_tests/fn/test_fn_sid_neighbor_table.cu @@ -0,0 +1,65 @@ +/* + * GridTools + * + * Copyright (c) 2014-2021, ETH Zurich + * All rights reserved. + * + * Please, refer to the LICENSE file in the root directory. + * SPDX-License-Identifier: BSD-3-Clause + */ + +#include + +#include +#include +#include +#include + +#include + +#include +#include + +namespace gridtools::fn { + namespace { + using sid_neighbor_table::as_neighbor_table; + + using edge_dim_t = integral_constant; + using edge_to_cell_dim_t = integral_constant; + + template + __device__ auto neighbor_table_neighbors_device(Table const &table, int index) + -> array { + return neighbor_table::neighbors(table, index); + } + + TEST(sid_neighbor_table, correctness_cuda) { + constexpr std::size_t num_elements = 3; + constexpr std::size_t num_neighbors = 2; + + const int data[num_elements][num_neighbors] = {{0, 1}, {10, 11}, {20, 21}}; + const auto device_data = cuda_util::cuda_malloc(num_elements * num_neighbors); + GT_CUDA_CHECK(cudaMemcpy(device_data.get(), &data, sizeof data, cudaMemcpyHostToDevice)); + using dim_hymap_t = hymap::keys; + auto contents = sid::synthetic() + .set(sid::host_device::simple_ptr_holder(device_data.get())) + .set(dim_hymap_t::make_values(num_neighbors, 1)); + + const auto table = as_neighbor_table(contents); + using table_t = std::decay_t; + + auto [n00, n01] = on_device::exec( + GT_MAKE_INTEGRAL_CONSTANT_FROM_VALUE(&neighbor_table_neighbors_device), table, 0); + auto [n10, n11] = on_device::exec( + GT_MAKE_INTEGRAL_CONSTANT_FROM_VALUE(&neighbor_table_neighbors_device), table, 1); + auto [n20, n21] = on_device::exec( + GT_MAKE_INTEGRAL_CONSTANT_FROM_VALUE(&neighbor_table_neighbors_device), table, 2); + EXPECT_EQ(n00, 0); + EXPECT_EQ(n01, 1); + EXPECT_EQ(n10, 10); + EXPECT_EQ(n11, 11); + EXPECT_EQ(n20, 20); + EXPECT_EQ(n21, 21); + } + } // namespace +} // namespace gridtools::fn \ No newline at end of file