add wip improved merging + tests
upsj committed Jan 13, 2020
1 parent d44f3e3 commit b6a4c75
Showing 2 changed files with 175 additions and 83 deletions.
156 changes: 82 additions & 74 deletions common/components/merging.hpp.inc
@@ -35,21 +35,41 @@ namespace detail {

/**
* @internal
* The result from the @ref warp_merge_step function.
* The result from the @ref group_merge_step function.
*/
template <typename ValueType>
struct warp_merge_result {
/** true iff the element at this thread originates from sequence `a`. */
bool merged_a;
/** the warp lane index from which the element at this thread originates. */
int source_index;
/** the value of the element at this thread. */
ValueType value;
/** how many elements of `a` did we merge? */
struct merge_result {
/** The element of a being merged in the current thread. */
ValueType a_val;
/** The element of b being merged in the current thread. */
ValueType b_val;
/** The index from a that is being merged in the current thread. */
int a_idx;
/** The index from b that is being merged in the current thread. */
int b_idx;
/** The number of elements from a that have been merged in total. */
int a_advance;
/** The number of elements from b that have been merged in total. */
int b_advance;
};


template <typename ValueType, typename Group>
ValueType shfl_read(Group group, ValueType v, int idx)
{
// avoid lane ID overflows (especially in the negative direction)
return group.shfl(v, (idx + group.size()) % group.size());
}


template <typename ValueType>
ValueType checked_load(const ValueType *p, int i, int size)
{
constexpr auto sentinel = device_numeric_limits<ValueType>::max();
return i < size ? p[i] : sentinel;
}


} // namespace detail
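For reference, a minimal host-side sketch of the semantics these two helpers provide — the index wrapping of shfl_read and the sentinel padding of checked_load. This is an illustration only, not part of the commit; all names and values below are made up.

// Illustrative host-side sketch (not part of the commit): shfl_read wraps the
// lane index into [0, group_size), checked_load pads out-of-bounds reads with
// the maximal sentinel value.
#include <cassert>
#include <limits>
#include <vector>

int main()
{
    const int group_size = 8;
    // one register value per lane of a hypothetical group
    std::vector<int> lanes{10, 11, 12, 13, 14, 15, 16, 17};

    auto shfl_read = [&](int idx) {
        // same wrapping as the device helper: avoids negative lane indices
        return lanes[(idx + group_size) % group_size];
    };
    assert(shfl_read(-1) == 17);  // -1 wraps around to the last lane
    assert(shfl_read(3) == 13);

    std::vector<int> data{1, 2, 3};
    auto checked_load = [&](int i) {
        // out-of-bounds reads act as an "infinitely large" element
        return i < int(data.size()) ? data[i]
                                    : std::numeric_limits<int>::max();
    };
    assert(checked_load(2) == 3);
    assert(checked_load(5) == std::numeric_limits<int>::max());
}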


@@ -63,51 +83,43 @@ struct warp_merge_result {
*
* @param a the element from the first range
* @param b the element from the second range
* @param size the maximum number of elements from both ranges
* @param group the cooperative group that executes the merge
* @return a structure containing the merge result distributed over the group.
*/
template <typename ValueType>
__device__ detail::warp_merge_result<ValueType> warp_merge_step(ValueType a,
ValueType b)
template <typename ValueType, typename Group>
__device__ detail::merge_result<ValueType> group_merge_step(ValueType a,
ValueType b,
int size,
Group group)
{
auto warp = group::thread_block_tile<config::warp_size>();
// thread i takes care of the diagonal (0, i) -> (i, 0)
auto diag = threadIdx.x % config::warp_size;
auto a_pos = [&](int i) {
// avoid out-of-bounds lane accesses
return (config::warp_size + diag - i) % config::warp_size;
};
auto b_pos = [&](int i) { return i; };

// find the intersection of the diagonal with the merge path
// we need to "extend" the diagonals such that they all have the same size
// otherwise not all threads would participate in the shuffle.
auto intersection =
synchronous_binary_search<config::warp_size>([&](int d) {
auto a_remote = warp.shfl(a, a_pos(d));
auto b_remote = warp.shfl(b, b_pos(d));
// outside the diagonal, the predicate must be true (sentinel)
return a_remote < b_remote || d > diag;
});
// determine if we merged a or b:
auto intersection_prev = warp.shfl_up(intersection, 1);
intersection_prev = diag == 0 ? 0 : intersection_prev;
auto a_pos_int = a_pos(intersection_prev);
auto b_pos_int = b_pos(intersection_prev);
// the intersection index is equal to b_pos
auto merged_a = intersection_prev == intersection;

// fetch the corresponding values of a and b
auto a_remote = warp.shfl(a, a_pos_int);
auto b_remote = warp.shfl(b, b_pos_int);
// assert a_remote < b_remote

// merge them at the current position
detail::warp_merge_result<ValueType> result{};
result.merged_a = merged_a;
result.source_index = merged_a ? a_pos_int : b_pos_int;
result.value = merged_a ? a_remote : b_remote;
result.a_advance =
__popc(warp.ballot(merged_a)); // TODO replace by shuffle
return result;
// round up to the next power of two
auto size_pow2 = 1 << (32 - clz(int32(size - 1)));
// thread i takes care of ith element of the merged sequence
auto i = group.thread_rank();

// we want to find the smallest index `x` such that a[x] > b[i - x]
// this means that `x` gives us the first element of `a` that is no longer
// part of the output in the range [0...i]. That especially means that
// a[0...x - 1] and b[0...i - x] form the output range c[0...i]
// and a[x - 1] and b[i - x] were the last elements to be compared.
auto max_x = synchronous_binary_search(size_pow2, [&](int x) {
auto a_remote = shfl_read(group, a, x);
auto b_remote = shfl_read(group, b, (i - x));

// `true` sentinel for i - x < 0
return a_remote > b_remote || x > i;
});

auto a_idx = max_x - 1;
auto b_idx = i - max_x;
auto a_val = shfl_read(group, a, a_idx);
auto b_val = shfl_read(group, b, b_idx);
auto cmp = a_val < b_val;
auto a_advance = popcnt(group.ballot(cmp));
auto b_advance = group.size() - a_advance;

return {a_val, b_val, a_idx, b_idx, a_advance, b_advance};
}
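For reference, a minimal sequential sketch of the merge-path search above: for output position i, the smallest x with a[x] > b[i - x] means that a[0...x-1] and b[0...i-x] make up c[0...i], so c[i] is the larger of the two elements compared last. This is a host-side illustration only, not part of the commit; the inputs and helper names are made up.

// Illustrative host-side sketch (not part of the commit): sequential version
// of the merge-path search; iteration i corresponds to thread i of the group.
#include <algorithm>
#include <cassert>
#include <limits>
#include <vector>

int main()
{
    std::vector<int> a{1, 3, 5, 8};
    std::vector<int> b{2, 4, 6, 7};
    const int n = int(a.size());
    // reads past the end act as +infinity (cf. checked_load), reads before
    // the beginning act as -infinity
    auto at = [n](const std::vector<int> &v, int i) {
        return i < 0 ? std::numeric_limits<int>::min()
                     : i < n ? v[i] : std::numeric_limits<int>::max();
    };

    std::vector<int> c(n);
    for (int i = 0; i < n; ++i) {
        // smallest x such that a[x] > b[i - x]; `x > i` plays the role of the
        // sentinel for negative b indices, as in the device code
        int x = 0;
        while (x < n && !(at(a, x) > at(b, i - x) || x > i)) {
            ++x;
        }
        // a[0...x-1] and b[0...i-x] form c[0...i], so c[i] is the larger of
        // the two elements that were compared last
        c[i] = std::max(at(a, x - 1), at(b, i - x));
    }
    assert((c == std::vector<int>{1, 2, 3, 4}));
}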


@@ -127,45 +139,41 @@ __device__ detail::warp_merge_result<ValueType> warp_merge_step(ValueType a,
* element is from `a` or `b`, and `source_index` is the index
* of `val` in `a` or `b`.
*/
template <typename IndexType, typename ValueType, typename Callback>
__device__ void warp_merge(const ValueType *a, IndexType a_size,
const ValueType *b, IndexType b_size,
Callback merge_fn)
template <typename IndexType, typename ValueType, typename Group,
typename Callback>
__device__ void group_merge(const ValueType *a, IndexType a_size,
const ValueType *b, IndexType b_size, Group group,
Callback merge_fn)
{
constexpr auto sentinel = device_numeric_limits<ValueType>::max();
auto warp = group::thread_block_tile<config::warp_size>();
auto c_size = a_size + b_size;
auto a_begin = 0;
auto b_begin = 0;
auto c_begin = 0;
auto lane = threadIdx.x % config::warp_size;
auto checked_load = [](const int *p, int i, int size) {
return i < size ? p[i] : sentinel;
};
auto a_local = checked_load(a, lane, a_size);
auto b_local = checked_load(b, lane, b_size);
auto lane = group.thread_rank();
auto a_cur = checked_load(a, lane, a_size);
auto b_cur = checked_load(b, lane, b_size);
while (c_begin < c_size) {
auto merge_result = warp_merge_step(a_local, b_local);
auto merge_size = min(group.size(), c_size - c_begin);
auto merge_result = group_merge_step(a_cur, b_cur, merge_size, group);
if (c_begin + lane < c_size) {
auto source_idx = merge_result.source_index + merge_result.merged_a
? a_begin
: b_begin;
merge_fn(merge_result.value, merge_result.merged_a, source_idx);
merge_fn(merge_result.a_idx + a_begin, merge_result.a_val,
merge_result.b_idx + b_begin, merge_result.b_val,
c_begin + lane);
}
auto a_advance = merge_result.a_advance;
auto b_advance = config::warp_size - a_advance;
auto b_advance = merge_result.b_advance;
a_begin += a_advance;
b_begin += b_advance;
c_begin += a_advance + b_advance;
c_begin += group.size();

// shuffle the unmerged elements to the front
a_local = warp.shfl_down(a_local, a_advance);
b_local = warp.shfl_down(b_local, b_advance);
/*
* To optimize memory access, we load the new elements for `a` and `b`
* with a single load instruction:
* the lower part of the warp loads new elements for `a`
* the upper part of the warp loads new elements for `b`
* the lower part of the group loads new elements for `a`
* the upper part of the group loads new elements for `b`
* `load_lane` is the part-local lane idx
* The elements for `a` have to be shuffled up afterwards.
*/
@@ -178,7 +186,7 @@ __device__ void warp_merge(const ValueType *a, IndexType a_size,
auto load_idx = load_begin + load_lane;
auto loaded = checked_load(load_source, load_idx, load_size);
// shuffle the `a` values to the end of the warp
auto lower_loaded = warp.shfl_up(loaded, b_advance);
auto lower_loaded = group.shfl_up(loaded, b_advance);
a_local = lane < b_advance ? a_local : lower_loaded;
b_local = lane < a_advance ? b_local : loaded;
}
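For reference, a sequential sketch of the bookkeeping this loop performs: each iteration merges up to one group-sized chunk of output, then advances the windows into `a` and `b` by a_advance and b_advance; the combined load above only changes how the vacated lanes are refilled, not which elements they end up holding. Host-side illustration only, not part of the commit; names and inputs are made up.

// Illustrative host-side sketch (not part of the commit): sequential analogue
// of the loop structure in group_merge. Each outer iteration merges up to one
// group-sized chunk of output and then slides the windows into `a` and `b` by
// a_advance and b_advance, which is what the shuffles and the combined load
// implement on the device.
#include <cstdio>
#include <limits>
#include <vector>

int main()
{
    constexpr int group_size = 4;  // stands in for config::warp_size
    std::vector<int> a{1, 4, 4, 9, 11, 15};
    std::vector<int> b{2, 3, 5, 8, 13};
    auto checked = [](const std::vector<int> &v, int i) {
        return i < int(v.size()) ? v[i] : std::numeric_limits<int>::max();
    };

    std::vector<int> c;
    const int c_size = int(a.size() + b.size());
    int a_begin = 0;
    int b_begin = 0;
    while (int(c.size()) < c_size) {
        int a_advance = 0;
        int b_advance = 0;
        // one chunk: on the device, every lane produces one of these elements
        for (int lane = 0; lane < group_size && int(c.size()) < c_size;
             ++lane) {
            auto a_val = checked(a, a_begin + a_advance);
            auto b_val = checked(b, b_begin + b_advance);
            if (a_val <= b_val) {  // ties take the element from `a`
                c.push_back(a_val);
                ++a_advance;
            } else {
                c.push_back(b_val);
                ++b_advance;
            }
        }
        // advance both windows; the device code refills the vacated lanes
        // with a single combined load and two shuffles
        a_begin += a_advance;
        b_begin += b_advance;
    }

    for (auto v : c) {
        std::printf("%d ", v);  // prints 1 2 3 4 4 5 8 9 11 13 15
    }
    std::printf("\n");
}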
102 changes: 93 additions & 9 deletions hip/test/components/merging.hip.cpp
@@ -37,6 +37,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "hip/components/searching.hip.hpp"


#include <algorithm>
#include <memory>
#include <random>

@@ -60,23 +61,106 @@ class Merging : public ::testing::Test {
: ref(gko::ReferenceExecutor::create()),
hip(gko::HipExecutor::create(0, ref)),
rng(123456),
ddata(hip)
max_size{1637},
sizes{},
data1(ref, max_size),
data2(ref, max_size),
outdata(ref, 2 * max_size),
refdata(ref, 2 * max_size),
ddata1(hip),
ddata2(hip),
doutdata(hip, 2 * max_size)
{}

void init_data()
{
std::uniform_int_distribution<gko::int32> dist(0, max_size);
for (auto i = 0; i < max_size; ++i) {
data1.get_data()[i] = dist(rng);
data2.get_data()[i] = dist(rng);
}
std::sort(data1.get_data(), data1.get_data() + max_size);
std::sort(data2.get_data(), data2.get_data() + max_size);

ddata1 = data1;
ddata2 = data2;
}

void assert_eq_ref(int size, int eq_size)
{
outdata = doutdata;
auto out_ptr = outdata.get_const_data();
auto out_end = out_ptr + eq_size;
auto ref_ptr = refdata.get_data();
std::copy_n(data1.get_const_data(), size, ref_ptr);
std::copy_n(data2.get_const_data(), size, ref_ptr + size);
std::sort(ref_ptr, ref_ptr + 2 * size);

ASSERT_TRUE(std::equal(out_ptr, out_end, ref_ptr));
}

std::shared_ptr<gko::ReferenceExecutor> ref;
std::shared_ptr<gko::HipExecutor> hip;
std::default_random_engine rng;
gko::Array<gko::int32> ddata;
};


TEST_F(Merging, MergeStep) {}


TEST_F(Merging, FullMerge) {}
int max_size;
std::vector<int> sizes;
gko::Array<gko::int32> data1;
gko::Array<gko::int32> data2;
gko::Array<gko::int32> outdata;
gko::Array<gko::int32> refdata;
gko::Array<gko::int32> ddata1;
gko::Array<gko::int32> ddata2;
gko::Array<gko::int32> doutdata;
};


TEST_F(Merging, EqualFullMerge) {}
__global__ void test_merge_step(const gko::int32 *a, const gko::int32 *b,
gko::int32 *c)
{
auto warp = tiled_partition<config::warp_size>(this_thread_block());
auto i = warp.thread_rank();
auto result = kernel::group_merge_step(a[i], b[i], config::warp_size, warp);
c[i] = min(result.a_val, result.b_val);
}

TEST_F(Merging, MergeStep)
{
for (auto i = 0; i < rng_runs; ++i) {
init_data();
test_merge_step<<<1, config::warp_size>>>(ddata1.get_const_data(),
ddata2.get_const_data(),
doutdata.get_data());

assert_eq_ref(config::warp_size, config::warp_size);
}
}


__global__ void test_merge(const gko::int32 *a, const gko::int32 *b, int size,
gko::int32 *c)
{
auto warp = tiled_partition<config::warp_size>(this_thread_block());
kernel::group_merge(
a, size, b, size, warp,
[&](int a_idx, gko::int32 a_val, int b_idx, gko::int32 b_val, int i) {
c[i] = min(a_val, b_val);
});
}

TEST_F(Merging, FullMerge)
{
for (auto i = 0; i < rng_runs; ++i) {
init_data();
for (auto size : sizes) {
test_merge<<<1, config::warp_size>>>(
ddata1.get_const_data(), ddata2.get_const_data(), size,
doutdata.get_data());

assert_eq_ref(size, 2 * size);
}
}
}


} // namespace
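Since data1 and data2 are sorted before the kernels run, the reference that assert_eq_ref builds by concatenating and sorting is the same sequence std::merge would produce. A small host-side sketch of that equivalence follows; it is an illustration only, not part of the commit, and the vectors are made up.

// Illustrative host-side sketch (not part of the commit): for sorted inputs,
// concatenating and sorting (as assert_eq_ref does) yields the same reference
// sequence as std::merge, which runs in linear instead of O(n log n) time.
#include <algorithm>
#include <cassert>
#include <vector>

int main()
{
    std::vector<int> data1{1, 3, 5, 7};
    std::vector<int> data2{2, 3, 6, 8};

    // reference as built in assert_eq_ref: copy both inputs, then sort
    std::vector<int> ref_sort(data1);
    ref_sort.insert(ref_sort.end(), data2.begin(), data2.end());
    std::sort(ref_sort.begin(), ref_sort.end());

    // equivalent reference via std::merge
    std::vector<int> ref_merge(data1.size() + data2.size());
    std::merge(data1.begin(), data1.end(), data2.begin(), data2.end(),
               ref_merge.begin());

    assert(ref_sort == ref_merge);
}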
