Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Templated type of num_items in DeviceRadixSort.
Browse files Browse the repository at this point in the history
List of individual changes:

- OffsetT == unsigned long long for the 64-bit case
- using std::{is_same,conditional}
- using "portion" consistently for 2^28-2^30-sized chunks of the input array
- HasEnoughMemory() takes overwrite into account.
- moved checking for enough memory earlier.
- added a CTA_SYNC() to the histogram kernel
- disabled tests with NumItemsT != int for segmented sort
- testing with 4.5 bln. items
- tests for different NumItemsT
- NumItemsT for all device sorting functions
- wrapped ChooseOffsetT into namespace detail
- fixed typos
- templatized the type of num_items in 2 methods of DeviceRadixSort
- tuned radix sort with 64-bit OffsetT for V100
- tuned for 64-bit OffsetT for A100
- separate tuning parameters for 64-bit OffsetT
- improved downsweep policy for GP100
- option for 64-bit num_items with 32-bit shared memory histogram counters.
- introduced PartOffsetT into Onesweep kernel.
  - OffsetT is now only used for offsets into the whole array
    (e.g. bin counts or global read/write offsets)
  - PartOffsetT is used for offsets that do not exceed a single part
    (e.g. decoupled look-back, block index, number of items inside a part)
  - this fixes problems when OffsetT is unsigned, and also contributes
    towards supporting 64-bit num_items
  • Loading branch information
canonizer authored and alliepiper committed Jan 19, 2022
1 parent 5d31d2d commit 5912195
Show file tree
Hide file tree
Showing 5 changed files with 342 additions and 227 deletions.
44 changes: 32 additions & 12 deletions cub/agent/agent_radix_sort_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,13 @@ struct AgentRadixSortHistogram
};

typedef RadixSortTwiddle<IS_DESCENDING, KeyT> Twiddle;
typedef OffsetT ShmemAtomicOffsetT;
typedef std::uint32_t ShmemCounterT;
typedef ShmemCounterT ShmemAtomicCounterT;
typedef typename Traits<KeyT>::UnsignedBits UnsignedBits;

struct _TempStorage
{
ShmemAtomicOffsetT bins[MAX_NUM_PASSES][RADIX_DIGITS][NUM_PARTS];
ShmemAtomicCounterT bins[MAX_NUM_PASSES][RADIX_DIGITS][NUM_PARTS];
};

struct TempStorage : Uninitialized<_TempStorage> {};
Expand Down Expand Up @@ -133,8 +134,11 @@ struct AgentRadixSortHistogram
d_keys_in(reinterpret_cast<const UnsignedBits*>(d_keys_in)),
num_items(num_items), begin_bit(begin_bit), end_bit(end_bit),
num_passes((end_bit - begin_bit + RADIX_BITS - 1) / RADIX_BITS)
{}

__device__ __forceinline__ void Init()
{
// init bins
// Initialize bins to 0.
#pragma unroll
for (int bin = threadIdx.x; bin < RADIX_DIGITS; bin += BLOCK_THREADS)
{
Expand Down Expand Up @@ -219,17 +223,33 @@ struct AgentRadixSortHistogram

__device__ __forceinline__ void Process()
{
for (OffsetT tile_offset = blockIdx.x * TILE_ITEMS; tile_offset < num_items;
tile_offset += TILE_ITEMS * gridDim.x)
// Within a portion, avoid overflowing (u)int32 counters.
// Between portions, accumulate results in global memory.
const OffsetT MAX_PORTION_SIZE = 1 << 30;
OffsetT num_portions = cub::DivideAndRoundUp(num_items, MAX_PORTION_SIZE);
for (OffsetT portion = 0; portion < num_portions; ++portion)
{
UnsignedBits keys[ITEMS_PER_THREAD];
LoadTileKeys(tile_offset, keys);
AccumulateSharedHistograms(tile_offset, keys);
}
CTA_SYNC();
// Reset the counters.
Init();
CTA_SYNC();

// accumulate in global memory
AccumulateGlobalHistograms();
// Process the tiles.
OffsetT portion_offset = portion * MAX_PORTION_SIZE;
OffsetT portion_end =
portion_offset + CUB_MIN(MAX_PORTION_SIZE, num_items - portion_offset);
for (OffsetT tile_offset = portion_offset + blockIdx.x * TILE_ITEMS;
tile_offset < portion_end; tile_offset += TILE_ITEMS * gridDim.x)
{
UnsignedBits keys[ITEMS_PER_THREAD];
LoadTileKeys(tile_offset, keys);
AccumulateSharedHistograms(tile_offset, keys);
}
CTA_SYNC();

// Accumulate the result in global memory.
AccumulateGlobalHistograms();
CTA_SYNC();
}
}
};

Expand Down
29 changes: 15 additions & 14 deletions cub/agent/agent_radix_sort_onesweep.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,8 @@ template <
bool IS_DESCENDING,
typename KeyT,
typename ValueT,
typename OffsetT>
typename OffsetT,
typename PortionOffsetT>
struct AgentRadixSortOnesweep
{
// constants
Expand All @@ -110,14 +111,14 @@ struct AgentRadixSortOnesweep
WARP_THREADS = CUB_PTX_WARP_THREADS,
BLOCK_WARPS = BLOCK_THREADS / WARP_THREADS,
WARP_MASK = ~0,
LOOKBACK_PARTIAL_MASK = 1 << (OffsetT(sizeof(OffsetT)) * 8 - 2),
LOOKBACK_GLOBAL_MASK = 1 << (OffsetT(sizeof(OffsetT)) * 8 - 1),
LOOKBACK_PARTIAL_MASK = 1 << (PortionOffsetT(sizeof(PortionOffsetT)) * 8 - 2),
LOOKBACK_GLOBAL_MASK = 1 << (PortionOffsetT(sizeof(PortionOffsetT)) * 8 - 1),
LOOKBACK_KIND_MASK = LOOKBACK_PARTIAL_MASK | LOOKBACK_GLOBAL_MASK,
LOOKBACK_VALUE_MASK = ~LOOKBACK_KIND_MASK,
};

typedef typename Traits<KeyT>::UnsignedBits UnsignedBits;
typedef OffsetT AtomicOffsetT;
typedef PortionOffsetT AtomicOffsetT;

static const RadixRankAlgorithm RANK_ALGORITHM =
AgentRadixSortOnesweepPolicy::RANK_ALGORITHM;
Expand Down Expand Up @@ -165,7 +166,7 @@ struct AgentRadixSortOnesweep
union
{
OffsetT global_offsets[RADIX_DIGITS];
OffsetT block_idx;
PortionOffsetT block_idx;
};
};

Expand All @@ -183,13 +184,13 @@ struct AgentRadixSortOnesweep
const UnsignedBits* d_keys_in;
ValueT* d_values_out;
const ValueT* d_values_in;
OffsetT num_items;
PortionOffsetT num_items;
ShiftDigitExtractor<KeyT> digit_extractor;

// other thread variables
int warp;
int lane;
OffsetT block_idx;
PortionOffsetT block_idx;
bool full_block;

// helper methods
Expand All @@ -213,7 +214,7 @@ struct AgentRadixSortOnesweep
{
// write the local sum into the bin
AtomicOffsetT& loc = d_lookback[block_idx * RADIX_DIGITS + bin];
OffsetT value = bins[u] | LOOKBACK_PARTIAL_MASK;
PortionOffsetT value = bins[u] | LOOKBACK_PARTIAL_MASK;
ThreadStore<STORE_VOLATILE>(&loc, value);
}
}
Expand All @@ -222,7 +223,7 @@ struct AgentRadixSortOnesweep
struct CountsCallback
{
typedef AgentRadixSortOnesweep<AgentRadixSortOnesweepPolicy, IS_DESCENDING, KeyT,
ValueT, OffsetT> AgentT;
ValueT, OffsetT, PortionOffsetT> AgentT;
AgentT& agent;
int (&bins)[BINS_PER_THREAD];
UnsignedBits (&keys)[ITEMS_PER_THREAD];
Expand Down Expand Up @@ -251,13 +252,13 @@ struct AgentRadixSortOnesweep
int bin = ThreadBin(u);
if (FULL_BINS || bin < RADIX_DIGITS)
{
OffsetT inc_sum = bins[u];
PortionOffsetT inc_sum = bins[u];
int want_mask = ~0;
// backtrack as long as necessary
for (OffsetT block_jdx = block_idx - 1; block_jdx >= 0; --block_jdx)
for (PortionOffsetT block_jdx = block_idx - 1; block_jdx >= 0; --block_jdx)
{
// wait for some value to appear
OffsetT value_j = 0;
PortionOffsetT value_j = 0;
AtomicOffsetT& loc_j = d_lookback[block_jdx * RADIX_DIGITS + bin];
do {
__threadfence_block(); // prevent hoisting loads from loop
Expand All @@ -269,7 +270,7 @@ struct AgentRadixSortOnesweep
if (value_j & LOOKBACK_GLOBAL_MASK) break;
}
AtomicOffsetT& loc_i = d_lookback[block_idx * RADIX_DIGITS + bin];
OffsetT value_i = inc_sum | LOOKBACK_GLOBAL_MASK;
PortionOffsetT value_i = inc_sum | LOOKBACK_GLOBAL_MASK;
ThreadStore<STORE_VOLATILE>(&loc_i, value_i);
s.global_offsets[bin] += inc_sum - bins[u];
}
Expand Down Expand Up @@ -638,7 +639,7 @@ struct AgentRadixSortOnesweep
const KeyT *d_keys_in,
ValueT *d_values_out,
const ValueT *d_values_in,
OffsetT num_items,
PortionOffsetT num_items,
int current_bit,
int num_bits)
: s(temp_storage.Alias())
Expand Down
Loading

0 comments on commit 5912195

Please sign in to comment.