Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

64-bit Offsets in DeviceRadixSort #340

Merged
1 commit merged on Jan 27, 2022
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 33 additions & 12 deletions cub/agent/agent_radix_sort_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
#include "../block/radix_rank_sort_operations.cuh"
#include "../config.cuh"
#include "../thread/thread_reduce.cuh"
#include "../util_math.cuh"
#include "../util_type.cuh"


Expand Down Expand Up @@ -97,12 +98,13 @@ struct AgentRadixSortHistogram
};

typedef RadixSortTwiddle<IS_DESCENDING, KeyT> Twiddle;
typedef OffsetT ShmemAtomicOffsetT;
typedef std::uint32_t ShmemCounterT;
typedef ShmemCounterT ShmemAtomicCounterT;
typedef typename Traits<KeyT>::UnsignedBits UnsignedBits;

struct _TempStorage
{
ShmemAtomicOffsetT bins[MAX_NUM_PASSES][RADIX_DIGITS][NUM_PARTS];
ShmemAtomicCounterT bins[MAX_NUM_PASSES][RADIX_DIGITS][NUM_PARTS];
};

struct TempStorage : Uninitialized<_TempStorage> {};
Expand Down Expand Up @@ -133,8 +135,11 @@ struct AgentRadixSortHistogram
d_keys_in(reinterpret_cast<const UnsignedBits*>(d_keys_in)),
num_items(num_items), begin_bit(begin_bit), end_bit(end_bit),
num_passes((end_bit - begin_bit + RADIX_BITS - 1) / RADIX_BITS)
{}

__device__ __forceinline__ void Init()
{
// init bins
// Initialize bins to 0.
#pragma unroll
for (int bin = threadIdx.x; bin < RADIX_DIGITS; bin += BLOCK_THREADS)
{
Expand Down Expand Up @@ -219,17 +224,33 @@ struct AgentRadixSortHistogram

__device__ __forceinline__ void Process()
{
for (OffsetT tile_offset = blockIdx.x * TILE_ITEMS; tile_offset < num_items;
tile_offset += TILE_ITEMS * gridDim.x)
// Within a portion, avoid overflowing (u)int32 counters.
// Between portions, accumulate results in global memory.
const OffsetT MAX_PORTION_SIZE = 1 << 30;
[Review comment thread on the preceding line: marked as resolved by alliepiper]
OffsetT num_portions = cub::DivideAndRoundUp(num_items, MAX_PORTION_SIZE);
for (OffsetT portion = 0; portion < num_portions; ++portion)
{
UnsignedBits keys[ITEMS_PER_THREAD];
LoadTileKeys(tile_offset, keys);
AccumulateSharedHistograms(tile_offset, keys);
}
CTA_SYNC();
// Reset the counters.
Init();
CTA_SYNC();

// accumulate in global memory
AccumulateGlobalHistograms();
// Process the tiles.
OffsetT portion_offset = portion * MAX_PORTION_SIZE;
OffsetT portion_end =
portion_offset + CUB_MIN(MAX_PORTION_SIZE, num_items - portion_offset);
for (OffsetT tile_offset = portion_offset + blockIdx.x * TILE_ITEMS;
tile_offset < portion_end; tile_offset += TILE_ITEMS * gridDim.x)
{
UnsignedBits keys[ITEMS_PER_THREAD];
LoadTileKeys(tile_offset, keys);
AccumulateSharedHistograms(tile_offset, keys);
}
CTA_SYNC();

// Accumulate the result in global memory.
AccumulateGlobalHistograms();
CTA_SYNC();
}
}
};

Expand Down
29 changes: 15 additions & 14 deletions cub/agent/agent_radix_sort_onesweep.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ template <
bool IS_DESCENDING,
typename KeyT,
typename ValueT,
typename OffsetT>
typename OffsetT,
typename PortionOffsetT>
struct AgentRadixSortOnesweep
{
// constants
Expand All @@ -110,14 +111,14 @@ struct AgentRadixSortOnesweep
WARP_THREADS = CUB_PTX_WARP_THREADS,
BLOCK_WARPS = BLOCK_THREADS / WARP_THREADS,
WARP_MASK = ~0,
LOOKBACK_PARTIAL_MASK = 1 << (OffsetT(sizeof(OffsetT)) * 8 - 2),
LOOKBACK_GLOBAL_MASK = 1 << (OffsetT(sizeof(OffsetT)) * 8 - 1),
LOOKBACK_PARTIAL_MASK = 1 << (PortionOffsetT(sizeof(PortionOffsetT)) * 8 - 2),
LOOKBACK_GLOBAL_MASK = 1 << (PortionOffsetT(sizeof(PortionOffsetT)) * 8 - 1),
LOOKBACK_KIND_MASK = LOOKBACK_PARTIAL_MASK | LOOKBACK_GLOBAL_MASK,
LOOKBACK_VALUE_MASK = ~LOOKBACK_KIND_MASK,
};

typedef typename Traits<KeyT>::UnsignedBits UnsignedBits;
typedef OffsetT AtomicOffsetT;
typedef PortionOffsetT AtomicOffsetT;

static const RadixRankAlgorithm RANK_ALGORITHM =
AgentRadixSortOnesweepPolicy::RANK_ALGORITHM;
Expand Down Expand Up @@ -165,7 +166,7 @@ struct AgentRadixSortOnesweep
union
{
OffsetT global_offsets[RADIX_DIGITS];
OffsetT block_idx;
PortionOffsetT block_idx;
};
};

Expand All @@ -183,13 +184,13 @@ struct AgentRadixSortOnesweep
const UnsignedBits* d_keys_in;
ValueT* d_values_out;
const ValueT* d_values_in;
OffsetT num_items;
PortionOffsetT num_items;
ShiftDigitExtractor<KeyT> digit_extractor;

// other thread variables
int warp;
int lane;
OffsetT block_idx;
PortionOffsetT block_idx;
bool full_block;

// helper methods
Expand All @@ -213,7 +214,7 @@ struct AgentRadixSortOnesweep
{
// write the local sum into the bin
AtomicOffsetT& loc = d_lookback[block_idx * RADIX_DIGITS + bin];
OffsetT value = bins[u] | LOOKBACK_PARTIAL_MASK;
PortionOffsetT value = bins[u] | LOOKBACK_PARTIAL_MASK;
ThreadStore<STORE_VOLATILE>(&loc, value);
}
}
Expand All @@ -222,7 +223,7 @@ struct AgentRadixSortOnesweep
struct CountsCallback
{
typedef AgentRadixSortOnesweep<AgentRadixSortOnesweepPolicy, IS_DESCENDING, KeyT,
ValueT, OffsetT> AgentT;
ValueT, OffsetT, PortionOffsetT> AgentT;
AgentT& agent;
int (&bins)[BINS_PER_THREAD];
UnsignedBits (&keys)[ITEMS_PER_THREAD];
Expand Down Expand Up @@ -251,13 +252,13 @@ struct AgentRadixSortOnesweep
int bin = ThreadBin(u);
if (FULL_BINS || bin < RADIX_DIGITS)
{
OffsetT inc_sum = bins[u];
PortionOffsetT inc_sum = bins[u];
int want_mask = ~0;
// backtrack as long as necessary
for (OffsetT block_jdx = block_idx - 1; block_jdx >= 0; --block_jdx)
for (PortionOffsetT block_jdx = block_idx - 1; block_jdx >= 0; --block_jdx)
{
// wait for some value to appear
OffsetT value_j = 0;
PortionOffsetT value_j = 0;
AtomicOffsetT& loc_j = d_lookback[block_jdx * RADIX_DIGITS + bin];
do {
__threadfence_block(); // prevent hoisting loads from loop
Expand All @@ -269,7 +270,7 @@ struct AgentRadixSortOnesweep
if (value_j & LOOKBACK_GLOBAL_MASK) break;
}
AtomicOffsetT& loc_i = d_lookback[block_idx * RADIX_DIGITS + bin];
OffsetT value_i = inc_sum | LOOKBACK_GLOBAL_MASK;
PortionOffsetT value_i = inc_sum | LOOKBACK_GLOBAL_MASK;
ThreadStore<STORE_VOLATILE>(&loc_i, value_i);
s.global_offsets[bin] += inc_sum - bins[u];
}
Expand Down Expand Up @@ -638,7 +639,7 @@ struct AgentRadixSortOnesweep
const KeyT *d_keys_in,
ValueT *d_values_out,
const ValueT *d_values_in,
OffsetT num_items,
PortionOffsetT num_items,
int current_bit,
int num_bits)
: s(temp_storage.Alias())
Expand Down
Loading