Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using SIDs as neighbour tables in functional API #1730

Merged
merged 27 commits into from
Oct 11, 2022
Merged
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a68f595
added neighbor table and a test
petiaccja Aug 17, 2022
9c5a26b
proper reading of elements
petiaccja Aug 18, 2022
9f27022
includes and size_t
petiaccja Aug 24, 2022
d93bf1f
more includes and size_t
petiaccja Aug 24, 2022
813a41f
includes
petiaccja Aug 24, 2022
ae473c2
includes
petiaccja Aug 24, 2022
47c8393
test namespaces
petiaccja Aug 24, 2022
d1f99ef
use sid::shift
petiaccja Aug 24, 2022
dd74069
use C arrays instead of std::array
petiaccja Aug 24, 2022
fa40976
code style
petiaccja Aug 24, 2022
6fb6a86
changed implementation to use ptr holder and strides instead of whole…
petiaccja Aug 24, 2022
0bbbdc7
make ADL function a friend
petiaccja Aug 25, 2022
ab55a32
test on cuda
petiaccja Aug 25, 2022
fa522e3
remove unused includes
petiaccja Aug 25, 2022
d9569af
cuda playing funny games
petiaccja Aug 25, 2022
b67bdf8
use dynamic device memory
petiaccja Aug 25, 2022
9afdaae
to trailing return type
petiaccja Aug 25, 2022
f47a4ec
using neighbor_table::neighbors wrapper fun
petiaccja Aug 25, 2022
af63168
removed superfluous template params
petiaccja Aug 26, 2022
fbde942
consistent const placement
petiaccja Aug 26, 2022
63eb832
i don't know, just trying to get it compile...
petiaccja Aug 26, 2022
37742bd
ints
petiaccja Aug 26, 2022
33bbd76
removed friend function to fix old gcc builds
petiaccja Aug 26, 2022
9e58c84
use integral constant expressions
petiaccja Aug 29, 2022
95f088a
use decay_t to support const ptr holders
petiaccja Aug 30, 2022
4497a48
use const array in tests
petiaccja Aug 30, 2022
907d43c
Update AUTHORS
havogt Oct 4, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions include/gridtools/fn/sid_neighbor_table.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* GridTools
*
* Copyright (c) 2014-2022, ETH Zurich
* All rights reserved.
*
* Please, refer to the LICENSE file in the root directory.
* SPDX-License-Identifier: BSD-3-Clause
*/
#pragma once

#include <cstddef>
#include <type_traits>
petiaccja marked this conversation as resolved.
Show resolved Hide resolved

#include "../common/array.hpp"
#include "../fn/unstructured.hpp"
#include "../sid/concept.hpp"

namespace gridtools::fn::sid_neighbor_table {
namespace sid_neighbor_table_impl_ {
template <class IndexDimension,
class NeighborDimension,
std::size_t MaxNumNeighbors,
class PtrHolder,
class Strides>
struct sid_neighbor_table {
PtrHolder origin;
Strides strides;
};

template <class IndexDimension,
class NeighborDimension,
std::size_t MaxNumNeighbors,
class PtrHolder,
class Strides>
GT_FUNCTION auto neighbor_table_neighbors(
sid_neighbor_table<IndexDimension, NeighborDimension, MaxNumNeighbors, PtrHolder, Strides> const &table,
int index) {

using namespace gridtools::literals;

auto ptr = table.origin();
using element_type = std::remove_reference_t<decltype(*ptr)>;
petiaccja marked this conversation as resolved.
Show resolved Hide resolved

petiaccja marked this conversation as resolved.
Show resolved Hide resolved
gridtools::array<element_type, MaxNumNeighbors> neighbors;

sid::shift(ptr, sid::get_stride<IndexDimension>(table.strides), index);
for (std::size_t element_idx = 0; element_idx < MaxNumNeighbors; ++element_idx) {
neighbors[element_idx] = *ptr;
sid::shift(ptr, sid::get_stride<NeighborDimension>(table.strides), 1_c);
}
return neighbors;
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With explicit return type, friend works for the case that failed in CI, but we can keep the current version.

Suggested change
class Strides>
struct sid_neighbor_table {
PtrHolder origin;
Strides strides;
};
template <class IndexDimension,
class NeighborDimension,
std::size_t MaxNumNeighbors,
class PtrHolder,
class Strides>
GT_FUNCTION auto neighbor_table_neighbors(
sid_neighbor_table<IndexDimension, NeighborDimension, MaxNumNeighbors, PtrHolder, Strides> const &table,
int index) {
using namespace gridtools::literals;
auto ptr = table.origin();
using element_type = std::remove_reference_t<decltype(*ptr)>;
gridtools::array<element_type, MaxNumNeighbors> neighbors;
sid::shift(ptr, sid::get_stride<IndexDimension>(table.strides), index);
for (std::size_t element_idx = 0; element_idx < MaxNumNeighbors; ++element_idx) {
neighbors[element_idx] = *ptr;
sid::shift(ptr, sid::get_stride<NeighborDimension>(table.strides), 1_c);
}
return neighbors;
}
class Strides,
class ElementType = std::remove_reference_t<decltype(*std::declval<PtrHolder>()())>>
struct sid_neighbor_table {
PtrHolder origin;
Strides strides;
friend GT_FUNCTION gridtools::array<ElementType, MaxNumNeighbors> neighbor_table_neighbors(sid_neighbor_table const &table, int index) {
auto ptr = table.origin();
gridtools::array<ElementType, MaxNumNeighbors> neighbors;
sid::shift(ptr, sid::get_stride<IndexDimension>(table.strides), index);
for (std::size_t element_idx = 0; element_idx < MaxNumNeighbors; ++element_idx) {
neighbors[element_idx] = *ptr;
sid::shift(ptr, sid::get_stride<NeighborDimension>(table.strides), 1);
}
return neighbors;
}
};

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I kinda got used to the friend method so it would be nice, but I think I'll just leave it and we can move it once we drop support for older GCC.


template <class IndexDimension, class NeighborDimension, std::size_t MaxNumNeighbors, class Sid>
auto as_neighbor_table(Sid &&sid) -> sid_neighbor_table<IndexDimension,
NeighborDimension,
MaxNumNeighbors,
sid::ptr_holder_type<Sid>,
sid::strides_type<Sid>> {

static_assert(gridtools::tuple_util::size<decltype(sid::get_strides(std::declval<Sid>()))>::value == 2,
"Neighbor tables must have exactly two dimensions: the index dimension and the neighbor dimension");
static_assert(!std::is_same_v<IndexDimension, NeighborDimension>,
"The index dimension and the neighbor dimension must be different.");

const auto origin = sid::get_origin(sid);
const auto strides = sid::get_strides(sid);

return {origin, strides};
havogt marked this conversation as resolved.
Show resolved Hide resolved
}
} // namespace sid_neighbor_table_impl_

using sid_neighbor_table_impl_::as_neighbor_table;

} // namespace gridtools::fn::sid_neighbor_table
havogt marked this conversation as resolved.
Show resolved Hide resolved
5 changes: 5 additions & 0 deletions tests/unit_tests/fn/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ gridtools_add_unit_test(test_fn_run SOURCES test_fn_run.cpp)
gridtools_add_unit_test(test_fn_column_stage SOURCES test_fn_column_stage.cpp)
gridtools_add_unit_test(test_fn_stencil_stage SOURCES test_fn_stencil_stage.cpp LABELS fn)
gridtools_add_unit_test(test_fn_unstructured SOURCES test_fn_unstructured.cpp LABELS fn)
gridtools_add_unit_test(test_fn_sid_neighbor_table SOURCES test_fn_sid_neighbor_table.cpp LABELS fn)

if(TARGET _gridtools_cuda)
gridtools_add_unit_test(test_fn_backend_gpu_cuda
Expand All @@ -29,4 +30,8 @@ if(TARGET _gridtools_cuda)
SOURCES test_extents.cu
LIBRARIES _gridtools_cuda
LABELS cuda fn)
gridtools_add_unit_test(test_fn_sid_neighbor_table_cuda
SOURCES test_fn_sid_neighbor_table.cu
LIBRARIES _gridtools_cuda
LABELS cuda fn)
endif()
45 changes: 45 additions & 0 deletions tests/unit_tests/fn/test_fn_sid_neighbor_table.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* GridTools
*
* Copyright (c) 2014-2022, ETH Zurich
* All rights reserved.
*
* Please, refer to the LICENSE file in the root directory.
* SPDX-License-Identifier: BSD-3-Clause
*/

#include <gridtools/fn/sid_neighbor_table.hpp>

#include <array>
#include <cstddef>
#include <cstdint>

#include <gtest/gtest.h>

namespace gridtools::fn {
namespace {

using sid_neighbor_table::as_neighbor_table;

using edge_dim_t = integral_constant<int_t, 0>;
using edge_to_cell_dim_t = integral_constant<int_t, 1>;

TEST(sid_neighbor_table, correctness) {
constexpr std::size_t num_elements = 3;
constexpr std::size_t num_neighbors = 2;
int contents[num_elements][num_neighbors] = {{0, 1}, {10, 11}, {20, 21}};
const auto table = as_neighbor_table<edge_dim_t, edge_to_cell_dim_t, num_neighbors>(contents);

auto [n00, n01] = neighbor_table::neighbors(table, 0);
auto [n10, n11] = neighbor_table::neighbors(table, 1);
auto [n20, n21] = neighbor_table::neighbors(table, 2);
EXPECT_EQ(n00, 0);
EXPECT_EQ(n01, 1);
EXPECT_EQ(n10, 10);
EXPECT_EQ(n11, 11);
EXPECT_EQ(n20, 20);
EXPECT_EQ(n21, 21);
}

} // namespace
} // namespace gridtools::fn
65 changes: 65 additions & 0 deletions tests/unit_tests/fn/test_fn_sid_neighbor_table.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* GridTools
*
* Copyright (c) 2014-2021, ETH Zurich
* All rights reserved.
*
* Please, refer to the LICENSE file in the root directory.
* SPDX-License-Identifier: BSD-3-Clause
*/

#include <gridtools/fn/sid_neighbor_table.hpp>

#include <array>
#include <cstddef>
#include <cstdint>
#include <type_traits>

#include <gtest/gtest.h>

#include <cuda_test_helper.hpp>
#include <gridtools/sid/synthetic.hpp>

namespace gridtools::fn {
namespace {
using sid_neighbor_table::as_neighbor_table;

using edge_dim_t = integral_constant<int_t, 0>;
using edge_to_cell_dim_t = integral_constant<int_t, 1>;

template <class Table>
__device__ auto neighbor_table_neighbors_device(Table const &table, int index)
-> array<int, 2> {
Comment on lines +31 to +32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
__device__ auto neighbor_table_neighbors_device(Table const &table, int index)
-> array<int, 2> {
constexpr __device__ auto neighbor_table_neighbors_device(Table const &table, int index) {

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does this work?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Well, surprisingly, yes, but I don't think it should though, because it's calling non-constexpr functions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is some special handling of constexpr functions in nvcc, even if not called in constexpr context it affects how nvcc handles the functions in good and in bad ways...

return neighbor_table::neighbors(table, index);
}

TEST(sid_neighbor_table, correctness_cuda) {
constexpr std::size_t num_elements = 3;
constexpr std::size_t num_neighbors = 2;

const int data[num_elements][num_neighbors] = {{0, 1}, {10, 11}, {20, 21}};
const auto device_data = cuda_util::cuda_malloc<int>(num_elements * num_neighbors);
GT_CUDA_CHECK(cudaMemcpy(device_data.get(), &data, sizeof data, cudaMemcpyHostToDevice));
using dim_hymap_t = hymap::keys<edge_dim_t, edge_to_cell_dim_t>;
auto contents = sid::synthetic()
.set<sid::property::origin>(sid::host_device::simple_ptr_holder(device_data.get()))
.set<sid::property::strides>(dim_hymap_t::make_values(num_neighbors, 1));

const auto table = as_neighbor_table<edge_dim_t, edge_to_cell_dim_t, num_neighbors>(contents);
using table_t = std::decay_t<decltype(table)>;

auto [n00, n01] = on_device::exec(
GT_MAKE_INTEGRAL_CONSTANT_FROM_VALUE(&neighbor_table_neighbors_device<table_t>), table, 0);
auto [n10, n11] = on_device::exec(
GT_MAKE_INTEGRAL_CONSTANT_FROM_VALUE(&neighbor_table_neighbors_device<table_t>), table, 1);
auto [n20, n21] = on_device::exec(
GT_MAKE_INTEGRAL_CONSTANT_FROM_VALUE(&neighbor_table_neighbors_device<table_t>), table, 2);
EXPECT_EQ(n00, 0);
EXPECT_EQ(n01, 1);
EXPECT_EQ(n10, 10);
EXPECT_EQ(n11, 11);
EXPECT_EQ(n20, 20);
EXPECT_EQ(n21, 21);
}
} // namespace
} // namespace gridtools::fn