This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Faster Least Significant Digit Radix Sort Implementation #204

Merged 1 commit on Nov 5, 2020

2 changes: 2 additions & 0 deletions .gitignore
@@ -1 +1,3 @@
.p4config
*~
\#*
67 changes: 30 additions & 37 deletions cub/agent/agent_radix_sort_downsweep.cuh
@@ -56,16 +56,6 @@ namespace cub {
* Tuning policy types
******************************************************************************/

/**
* Radix ranking algorithm
*/
enum RadixRankAlgorithm
{
RADIX_RANK_BASIC,
RADIX_RANK_MEMOIZE,
RADIX_RANK_MATCH
};

/**
* Parameterizable tuning policy type for AgentRadixSortDownsweep
*/
@@ -137,6 +127,9 @@ struct AgentRadixSortDownsweep

RADIX_DIGITS = 1 << RADIX_BITS,
KEYS_ONLY = Equals<ValueT, NullType>::VALUE,
LOAD_WARP_STRIPED = RANK_ALGORITHM == RADIX_RANK_MATCH ||
RANK_ALGORITHM == RADIX_RANK_MATCH_EARLY_COUNTS_ANY ||
RANK_ALGORITHM == RADIX_RANK_MATCH_EARLY_COUNTS_ATOMIC_OR,
};
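
The new LOAD_WARP_STRIPED flag folds what were per-ranking-algorithm load specializations into a single compile-time boolean: all three match-based ranking variants consume keys in a warp-striped arrangement, so they can share one load path. A minimal sketch of the idea (illustrative names, not CUB's):

    enum RankAlgo { BASIC, MEMOIZE, MATCH, MATCH_ANY, MATCH_ATOMIC_OR };

    template <RankAlgo A>
    struct UsesWarpStripedLoad
    {
        static const bool value =
            (A == MATCH) || (A == MATCH_ANY) || (A == MATCH_ATOMIC_OR);
    };

    static_assert(UsesWarpStripedLoad<MATCH_ANY>::value,
                  "match-based ranking variants load warp-striped");
    static_assert(!UsesWarpStripedLoad<BASIC>::value,
                  "basic/memoize ranking keeps the blocked load path");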

// Input iterator wrapper type (for applying cache modifiers)
@@ -148,7 +141,15 @@ struct AgentRadixSortDownsweep
BlockRadixRank<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, false, SCAN_ALGORITHM>,
typename If<(RANK_ALGORITHM == RADIX_RANK_MEMOIZE),
BlockRadixRank<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, true, SCAN_ALGORITHM>,
BlockRadixRankMatch<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, SCAN_ALGORITHM>
typename If<(RANK_ALGORITHM == RADIX_RANK_MATCH),
BlockRadixRankMatch<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING, SCAN_ALGORITHM>,
typename If<(RANK_ALGORITHM == RADIX_RANK_MATCH_EARLY_COUNTS_ANY),
BlockRadixRankMatchEarlyCounts<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING,
SCAN_ALGORITHM, WARP_MATCH_ANY>,
BlockRadixRankMatchEarlyCounts<BLOCK_THREADS, RADIX_BITS, IS_DESCENDING,
SCAN_ALGORITHM, WARP_MATCH_ATOMIC_OR>
>::Type
>::Type
>::Type
>::Type BlockRadixRankT;
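
The nested cub::If<> chain above resolves to exactly one ranking implementation at compile time, now extended to cover the two BlockRadixRankMatchEarlyCounts variants. A self-contained sketch of the same selection pattern using std::conditional, which cub::If mirrors (placeholder types, not CUB's):

    #include <type_traits>

    enum Algo { ALGO_BASIC, ALGO_MEMOIZE, ALGO_MATCH };
    struct RankBasic {}; struct RankMemoize {}; struct RankMatch {};

    template <Algo A>
    using RankT = typename std::conditional<(A == ALGO_BASIC), RankBasic,
                  typename std::conditional<(A == ALGO_MEMOIZE), RankMemoize,
                  RankMatch>::type>::type;

    static_assert(std::is_same<RankT<ALGO_MEMOIZE>, RankMemoize>::value,
                  "the nested conditionals resolve to exactly one type");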

@@ -303,16 +304,15 @@ struct AgentRadixSortDownsweep
}

/**
* Load a tile of keys (specialized for full tile, any ranking algorithm)
* Load a tile of keys (specialized for full tile, block load)
*/
template <int _RANK_ALGORITHM>
__device__ __forceinline__ void LoadKeys(
UnsignedBits (&keys)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
UnsignedBits oob_item,
Int2Type<true> is_full_tile,
Int2Type<_RANK_ALGORITHM> rank_algorithm)
Int2Type<false> warp_striped)
{
BlockLoadKeysT(temp_storage.load_keys).Load(
d_keys_in + block_offset, keys);
@@ -322,16 +322,15 @@ struct AgentRadixSortDownsweep


/**
* Load a tile of keys (specialized for partial tile, any ranking algorithm)
* Load a tile of keys (specialized for partial tile, block load)
*/
template <int _RANK_ALGORITHM>
__device__ __forceinline__ void LoadKeys(
UnsignedBits (&keys)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
UnsignedBits oob_item,
Int2Type<false> is_full_tile,
Int2Type<_RANK_ALGORITHM> rank_algorithm)
Int2Type<false> warp_striped)
{
// Register pressure work-around: moving valid_items through shfl prevents compiler
// from reusing guards/addressing from prior guarded loads
@@ -345,30 +344,29 @@


/**
* Load a tile of keys (specialized for full tile, match ranking algorithm)
* Load a tile of keys (specialized for full tile, warp-striped load)
*/
__device__ __forceinline__ void LoadKeys(
UnsignedBits (&keys)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
UnsignedBits oob_item,
Int2Type<true> is_full_tile,
Int2Type<RADIX_RANK_MATCH> rank_algorithm)
Int2Type<true> warp_striped)
{
LoadDirectWarpStriped(threadIdx.x, d_keys_in + block_offset, keys);
}
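
In a warp-striped arrangement, consecutive lanes of a warp read consecutive addresses, so the load coalesces even though each thread's items end up strided in registers. A sketch of the access pattern behind LoadDirectWarpStriped (simplified; assumes a 32-thread warp and a precomputed per-warp base pointer):

    template <int ITEMS_PER_THREAD, typename T>
    __device__ void WarpStripedLoad(int lane, const T *warp_base,
                                    T (&items)[ITEMS_PER_THREAD])
    {
        #pragma unroll
        for (int i = 0; i < ITEMS_PER_THREAD; ++i)
            items[i] = warp_base[lane + i * 32];  // lanes 0..31 hit adjacent addresses
    }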


/**
* Load a tile of keys (specialized for partial tile, match ranking algorithm)
* Load a tile of keys (specialized for partial tile, warp-striped load)
*/
__device__ __forceinline__ void LoadKeys(
UnsignedBits (&keys)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
UnsignedBits oob_item,
Int2Type<false> is_full_tile,
Int2Type<RADIX_RANK_MATCH> rank_algorithm)
Int2Type<true> warp_striped)
{
// Register pressure work-around: moving valid_items through shfl prevents compiler
// from reusing guards/addressing from prior guarded loads
@@ -377,17 +375,15 @@ struct AgentRadixSortDownsweep
LoadDirectWarpStriped(threadIdx.x, d_keys_in + block_offset, keys, valid_items, oob_item);
}
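
These overloads are selected by Int2Type tag dispatch: wrapping a compile-time boolean in a distinct type lets ordinary overload resolution pick the blocked or warp-striped path with no runtime branch. A minimal sketch of the mechanism (illustrative names, outside CUB):

    template <int N> struct Int2TypeSketch { enum { VALUE = N }; };

    struct LoaderSketch
    {
        __device__ void LoadTile(Int2TypeSketch<false>) { /* blocked (BlockLoad) path */ }
        __device__ void LoadTile(Int2TypeSketch<true>)  { /* warp-striped path */ }

        template <bool WARP_STRIPED>
        __device__ void Dispatch()
        {
            LoadTile(Int2TypeSketch<WARP_STRIPED>());  // overload chosen at compile time
        }
    };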


/**
* Load a tile of values (specialized for full tile, any ranking algorithm)
* Load a tile of values (specialized for full tile, block load)
*/
template <int _RANK_ALGORITHM>
__device__ __forceinline__ void LoadValues(
ValueT (&values)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
Int2Type<true> is_full_tile,
Int2Type<_RANK_ALGORITHM> rank_algorithm)
Int2Type<false> warp_striped)
{
BlockLoadValuesT(temp_storage.load_values).Load(
d_values_in + block_offset, values);
@@ -397,15 +393,14 @@ struct AgentRadixSortDownsweep


/**
* Load a tile of values (specialized for partial tile, any ranking algorithm)
* Load a tile of values (specialized for partial tile, block load)
*/
template <int _RANK_ALGORITHM>
__device__ __forceinline__ void LoadValues(
ValueT (&values)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
Int2Type<false> is_full_tile,
Int2Type<_RANK_ALGORITHM> rank_algorithm)
Int2Type<false> warp_striped)
{
// Register pressure work-around: moving valid_items through shfl prevents compiler
// from reusing guards/addressing from prior guarded loads
@@ -419,28 +414,27 @@ struct AgentRadixSortDownsweep


/**
* Load a tile of items (specialized for full tile, match ranking algorithm)
* Load a tile of items (specialized for full tile, warp-striped load)
*/
__device__ __forceinline__ void LoadValues(
ValueT (&values)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
Int2Type<true> is_full_tile,
Int2Type<RADIX_RANK_MATCH> rank_algorithm)
Int2Type<true> warp_striped)
{
LoadDirectWarpStriped(threadIdx.x, d_values_in + block_offset, values);
}


/**
* Load a tile of items (specialized for partial tile, match ranking algorithm)
* Load a tile of items (specialized for partial tile, warp-striped load)
*/
__device__ __forceinline__ void LoadValues(
ValueT (&values)[ITEMS_PER_THREAD],
OffsetT block_offset,
OffsetT valid_items,
Int2Type<false> is_full_tile,
Int2Type<RADIX_RANK_MATCH> rank_algorithm)
Int2Type<true> warp_striped)
{
// Register pressure work-around: moving valid_items through shfl prevents compiler
// from reusing guards/addressing from prior guarded loads
@@ -449,7 +443,6 @@ struct AgentRadixSortDownsweep
LoadDirectWarpStriped(threadIdx.x, d_values_in + block_offset, values, valid_items);
}
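
The recurring "register pressure work-around" comment refers to routing valid_items through a warp shuffle before the guarded load; the collapsed diff elides the actual statement, but the effect can be sketched with the raw CUDA intrinsic (hypothetical helper name):

    __device__ __forceinline__ int LaunderValidItems(int valid_items)
    {
        // Broadcasting from lane 0 hides the value's provenance from the
        // compiler, so guard predicates and addressing computed for earlier
        // guarded loads are not kept live in registers for reuse here.
        return __shfl_sync(0xffffffffu, valid_items, 0);
    }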


/**
* Truck along associated values
*/
@@ -470,7 +463,7 @@ struct AgentRadixSortDownsweep
block_offset,
valid_items,
Int2Type<FULL_TILE>(),
Int2Type<RANK_ALGORITHM>());
Int2Type<LOAD_WARP_STRIPED>());

ScatterValues<FULL_TILE>(
values,
@@ -515,7 +508,7 @@ struct AgentRadixSortDownsweep
valid_items,
default_key,
Int2Type<FULL_TILE>(),
Int2Type<RANK_ALGORITHM>());
Int2Type<LOAD_WARP_STRIPED>());

// Twiddle key bits if necessary
#pragma unroll