Skip to content

Commit

Permalink
Add support for shifted elements_with_stride loop in range for
Browse files Browse the repository at this point in the history
  • Loading branch information
ericcano authored and fwyzard committed Dec 21, 2023
1 parent 3dfb1a9 commit ce64e0a
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions HeterogeneousCore/AlpakaInterface/interface/workdivision.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,31 @@ namespace cms::alpakatools {
};

/* elements_with_stride
*
* `elements_with_stride(acc, [first, ]extent)` returns an iteratable range that spans the element indices required to
* cover the given problem size:
* - `first` (optional) is index to the first element; if not specified, the loop starts from 0;
* - `extent` is the total size of the problem, including any elements that may come before `first`.
*/

template <typename TAcc, typename = std::enable_if_t<alpaka::isAccelerator<TAcc> and alpaka::Dim<TAcc>::value == 1>>
class elements_with_stride {
public:
ALPAKA_FN_ACC inline elements_with_stride(TAcc const& acc)
: elements_{alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u]},
thread_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
first_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
extent_{stride_} {}

ALPAKA_FN_ACC inline elements_with_stride(TAcc const& acc, Idx extent)
: elements_{alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u]},
thread_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
first_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
extent_{extent} {}

ALPAKA_FN_ACC inline elements_with_stride(TAcc const& acc, Idx first, Idx extent)
: elements_{alpaka::getWorkDiv<alpaka::Thread, alpaka::Elems>(acc)[0u]},
first_{alpaka::getIdx<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_ + first},
stride_{alpaka::getWorkDiv<alpaka::Grid, alpaka::Threads>(acc)[0u] * elements_},
extent_{extent} {}

Expand Down Expand Up @@ -164,13 +175,13 @@ namespace cms::alpakatools {
Idx range_;
};

ALPAKA_FN_ACC inline iterator begin() const { return iterator(elements_, stride_, extent_, thread_); }
ALPAKA_FN_ACC inline iterator begin() const { return iterator(elements_, stride_, extent_, first_); }

ALPAKA_FN_ACC inline iterator end() const { return iterator(elements_, stride_, extent_, extent_); }

private:
const Idx elements_;
const Idx thread_;
const Idx first_;
const Idx stride_;
const Idx extent_;
};
Expand Down

0 comments on commit ce64e0a

Please sign in to comment.