Skip to content

Commit

Permalink
Provide pushforward methods for Kokkos::View indexing
Browse files Browse the repository at this point in the history
Previously, we relied on automatically generated pushforwards
for these operator calls, but this solution is way safer and
should work for more machines and Kokkos versions.
  • Loading branch information
gojakuch authored and vgvassilev committed Aug 26, 2024
1 parent 6f4b081 commit c3b76c0
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 2 deletions.
71 changes: 69 additions & 2 deletions include/clad/Differentiator/KokkosBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,75 @@ constructor_pushforward(
Kokkos::View<DataType, ViewParams...>(
"_diff_" + name, idx0, idx1, idx2, idx3, idx4, idx5, idx6, idx7)};
}

/// View indexing
template <typename View, typename Idx>
inline clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, const View* d_v,
Idx /*d_i0*/) {
return {(*v)(i0), (*d_v)(i0)};
}
template <typename View, typename Idx>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, const View* d_v,
Idx /*d_i0*/, Idx /*d_i1*/) {
return {(*v)(i0, i1), (*d_v)(i0, i1)};
}
template <typename View, typename Idx>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2,
const View* d_v, Idx /*d_i0*/, Idx /*d_i1*/,
Idx /*d_i2*/) {
return {(*v)(i0, i1, i2), (*d_v)(i0, i1, i2)};
}
template <typename View, typename Idx>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3,
const View* d_v, Idx /*d_i0*/, Idx /*d_i1*/,
Idx /*d_i2*/, Idx /*d_i3*/) {
return {(*v)(i0, i1, i2, i3), (*d_v)(i0, i1, i2, i3)};
}
template <typename View, typename Idx>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3, Idx i4,
const View* d_v, Idx /*d_i0*/, Idx /*d_i1*/,
Idx /*d_i2*/, Idx /*d_i3*/, Idx /*d_i4*/) {
return {(*v)(i0, i1, i2, i3, i4), (*d_v)(i0, i1, i2, i3, i4)};
}
template <typename View, typename Idx>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3, Idx i4,
Idx i5, const View* d_v, Idx /*d_i0*/, Idx /*d_i1*/,
Idx /*d_i2*/, Idx /*d_i3*/, Idx /*d_i4*/,
Idx /*d_i5*/) {
return {(*v)(i0, i1, i2, i3, i4, i5), (*d_v)(i0, i1, i2, i3, i4, i5)};
}
template <typename View, typename Idx>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3, Idx i4,
Idx i5, Idx i6, const View* d_v, Idx /*d_i0*/,
Idx /*d_i1*/, Idx /*d_i2*/, Idx /*d_i3*/,
Idx /*d_i4*/, Idx /*d_i5*/, Idx /*d_i6*/) {
return {(*v)(i0, i1, i2, i3, i4, i5, i6), (*d_v)(i0, i1, i2, i3, i4, i5, i6)};
}
template <typename View, typename Idx>
clad::ValueAndPushforward<typename View::reference_type,
typename View::reference_type>
operator_call_pushforward(const View* v, Idx i0, Idx i1, Idx i2, Idx i3, Idx i4,
Idx i5, Idx i6, Idx i7, const View* d_v, Idx /*d_i0*/,
Idx /*d_i1*/, Idx /*d_i2*/, Idx /*d_i3*/,
Idx /*d_i4*/, Idx /*d_i5*/, Idx /*d_i6*/,
Idx /*d_i7*/) {
return {(*v)(i0, i1, i2, i3, i4, i5, i6, i7),
(*d_v)(i0, i1, i2, i3, i4, i5, i6, i7)};
}
} // namespace class_functions

/// Kokkos functions (view utils)
Expand All @@ -39,7 +108,6 @@ inline void deep_copy_pushforward(const View1& dst, const View2& src, T param,
deep_copy(dst, src);
deep_copy(d_dst, d_src);
}

template <class View>
inline void resize_pushforward(View& v, const size_t n0, const size_t n1,
const size_t n2, const size_t n3,
Expand All @@ -52,7 +120,6 @@ inline void resize_pushforward(View& v, const size_t n0, const size_t n1,
::Kokkos::resize(v, n0, n1, n2, n3, n4, n5, n6, n7);
::Kokkos::resize(d_v, n0, n1, n2, n3, n4, n5, n6, n7);
}

template <class I, class dI, class View>
inline void resize_pushforward(const I& arg, View& v, const size_t n0,
const size_t n1, const size_t n2,
Expand Down
31 changes: 31 additions & 0 deletions unittests/Kokkos/ViewBasics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -248,4 +248,35 @@ TEST(ViewBasics, TestResize4) {
for (double x = 3; x <= 5; x += 1)
for (double y = 3; y <= 5; y += 1)
EXPECT_NEAR(df.execute(x, y), df_true(x, y), eps);
}

template <typename View> struct FooModifier {
double x;

FooModifier(View& v, double x) : x(x) {}

void operator()(View& v) { v(1, 0, 1, 0, 1, 0, 1) += x; }
};

double f_basics_call(double x) {
Kokkos::View<double[2][2][2][2][2][2][2], Kokkos::LayoutLeft,
Kokkos::HostSpace>
a("a");
Kokkos::deep_copy(a, 3 * x);

FooModifier<Kokkos::View<double[2][2][2][2][2][2][2], Kokkos::LayoutLeft,
Kokkos::HostSpace>>
f(a, x);

f(a);

return a(1, 0, 1, 0, 1, 0, 1);
}

TEST(ViewBasics, FunctorCall4) {
const double eps = 1e-8;

auto df = clad::differentiate(f_basics_call, 0);
for (double x = 3; x <= 5; x += 1)
EXPECT_NEAR(df.execute(x), 4, eps);
}

0 comments on commit c3b76c0

Please sign in to comment.