Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Merge branch 'main' into sort
Browse files Browse the repository at this point in the history
  • Loading branch information
canonizer committed Nov 2, 2020
2 parents 56151c8 + af39ee2 commit 349ad8f
Show file tree
Hide file tree
Showing 17 changed files with 65 additions and 716 deletions.
2 changes: 1 addition & 1 deletion cub/agent/agent_radix_sort_upsweep.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ struct AgentRadixSortUpsweep
PACKING_RATIO = sizeof(PackedCounter) / sizeof(DigitCounter),
LOG_PACKING_RATIO = Log2<PACKING_RATIO>::VALUE,

LOG_COUNTER_LANES = CUB_MAX(0, RADIX_BITS - LOG_PACKING_RATIO),
LOG_COUNTER_LANES = CUB_MAX(0, int(RADIX_BITS) - int(LOG_PACKING_RATIO)),
COUNTER_LANES = 1 << LOG_COUNTER_LANES,

// To prevent counter overflow, we must periodically unpack and aggregate the
Expand Down
2 changes: 1 addition & 1 deletion cub/agent/agent_reduce.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ struct AgentReduce
{
BLOCK_THREADS = AgentReducePolicy::BLOCK_THREADS,
ITEMS_PER_THREAD = AgentReducePolicy::ITEMS_PER_THREAD,
VECTOR_LOAD_LENGTH = CUB_MIN(ITEMS_PER_THREAD, AgentReducePolicy::VECTOR_LOAD_LENGTH),
VECTOR_LOAD_LENGTH = CUB_MIN(int(ITEMS_PER_THREAD), int(AgentReducePolicy::VECTOR_LOAD_LENGTH)),
TILE_ITEMS = BLOCK_THREADS * ITEMS_PER_THREAD,

// Can vectorize according to the policy if the input iterator is a native pointer to a primitive type
Expand Down
26 changes: 23 additions & 3 deletions cub/agent/agent_scan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -294,9 +294,19 @@ struct AgentScan
OutputT items[ITEMS_PER_THREAD];

if (IS_LAST_TILE)
BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items, num_remaining);
{
// Fill last element with the first element because collectives are
// not suffix guarded.
BlockLoadT(temp_storage.load)
.Load(d_in + tile_offset,
items,
num_remaining,
*(d_in + tile_offset));
}
else
{
BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items);
}

CTA_SYNC();

Expand Down Expand Up @@ -330,7 +340,7 @@ struct AgentScan
* Scan tiles of items as part of a dynamic chained scan
*/
__device__ __forceinline__ void ConsumeRange(
int num_items, ///< Total number of input items
OffsetT num_items, ///< Total number of input items
ScanTileStateT& tile_state, ///< Global tile state descriptor
int start_tile) ///< The starting tile for the current grid
{
Expand Down Expand Up @@ -371,9 +381,19 @@ struct AgentScan
OutputT items[ITEMS_PER_THREAD];

if (IS_LAST_TILE)
BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items, valid_items);
{
// Fill last element with the first element because collectives are
// not suffix guarded.
BlockLoadT(temp_storage.load)
.Load(d_in + tile_offset,
items,
valid_items,
*(d_in + tile_offset));
}
else
{
BlockLoadT(temp_storage.load).Load(d_in + tile_offset, items);
}

CTA_SYNC();

Expand Down
4 changes: 2 additions & 2 deletions cub/block/block_load.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ private:
};

// Assert BLOCK_THREADS must be a multiple of WARP_THREADS
CUB_STATIC_ASSERT((BLOCK_THREADS % WARP_THREADS == 0), "BLOCK_THREADS must be a multiple of WARP_THREADS");
CUB_STATIC_ASSERT((int(BLOCK_THREADS) % int(WARP_THREADS) == 0), "BLOCK_THREADS must be a multiple of WARP_THREADS");

// BlockExchange utility type for keys
typedef BlockExchange<InputT, BLOCK_DIM_X, ITEMS_PER_THREAD, false, BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH> BlockExchange;
Expand Down Expand Up @@ -940,7 +940,7 @@ private:
};

// Assert BLOCK_THREADS must be a multiple of WARP_THREADS
CUB_STATIC_ASSERT((BLOCK_THREADS % WARP_THREADS == 0), "BLOCK_THREADS must be a multiple of WARP_THREADS");
CUB_STATIC_ASSERT((int(BLOCK_THREADS) % int(WARP_THREADS) == 0), "BLOCK_THREADS must be a multiple of WARP_THREADS");

// BlockExchange utility type for keys
typedef BlockExchange<InputT, BLOCK_DIM_X, ITEMS_PER_THREAD, true, BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH> BlockExchange;
Expand Down
2 changes: 1 addition & 1 deletion cub/block/block_radix_rank.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ private:
PACKING_RATIO = sizeof(PackedCounter) / sizeof(DigitCounter),
LOG_PACKING_RATIO = Log2<PACKING_RATIO>::VALUE,

LOG_COUNTER_LANES = CUB_MAX((RADIX_BITS - LOG_PACKING_RATIO), 0), // Always at least one lane
LOG_COUNTER_LANES = CUB_MAX((int(RADIX_BITS) - int(LOG_PACKING_RATIO)), 0), // Always at least one lane
COUNTER_LANES = 1 << LOG_COUNTER_LANES,

// The number of packed counters per thread (plus one for padding)
Expand Down
4 changes: 2 additions & 2 deletions cub/block/block_store.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -697,7 +697,7 @@ private:
};

// Assert BLOCK_THREADS must be a multiple of WARP_THREADS
CUB_STATIC_ASSERT((BLOCK_THREADS % WARP_THREADS == 0), "BLOCK_THREADS must be a multiple of WARP_THREADS");
CUB_STATIC_ASSERT((int(BLOCK_THREADS) % int(WARP_THREADS) == 0), "BLOCK_THREADS must be a multiple of WARP_THREADS");

// BlockExchange utility type for keys
typedef BlockExchange<T, BLOCK_DIM_X, ITEMS_PER_THREAD, false, BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH> BlockExchange;
Expand Down Expand Up @@ -765,7 +765,7 @@ private:
};

// Assert BLOCK_THREADS must be a multiple of WARP_THREADS
CUB_STATIC_ASSERT((BLOCK_THREADS % WARP_THREADS == 0), "BLOCK_THREADS must be a multiple of WARP_THREADS");
CUB_STATIC_ASSERT((int(BLOCK_THREADS) % int(WARP_THREADS) == 0), "BLOCK_THREADS must be a multiple of WARP_THREADS");

// BlockExchange utility type for keys
typedef BlockExchange<T, BLOCK_DIM_X, ITEMS_PER_THREAD, true, BLOCK_DIM_Y, BLOCK_DIM_Z, PTX_ARCH> BlockExchange;
Expand Down
2 changes: 1 addition & 1 deletion cub/block/specializations/block_reduce_raking.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ struct BlockReduceRaking
SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH,

/// Cooperative work can be entirely warp synchronous
WARP_SYNCHRONOUS = (RAKING_THREADS == BLOCK_THREADS),
WARP_SYNCHRONOUS = (int(RAKING_THREADS) == int(BLOCK_THREADS)),

/// Whether or not warp-synchronous reduction should be unguarded (i.e., the warp-reduction elements is a power of two
WARP_SYNCHRONOUS_UNGUARDED = PowerOfTwo<RAKING_THREADS>::VALUE,
Expand Down
2 changes: 1 addition & 1 deletion cub/block/specializations/block_scan_raking.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ struct BlockScanRaking
SEGMENT_LENGTH = BlockRakingLayout::SEGMENT_LENGTH,

/// Cooperative work can be entirely warp synchronous
WARP_SYNCHRONOUS = (BLOCK_THREADS == RAKING_THREADS),
WARP_SYNCHRONOUS = (int(BLOCK_THREADS) == int(RAKING_THREADS)),
};

/// WarpScan utility type
Expand Down
71 changes: 2 additions & 69 deletions cub/device/dispatch/dispatch_histogram.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -341,52 +341,6 @@ struct DipatchHistogram
};
};


/// SM11
struct Policy110
{
// HistogramSweepPolicy
typedef AgentHistogramPolicy<
512,
(NUM_CHANNELS == 1) ? 8 : 2,
BLOCK_LOAD_DIRECT,
LOAD_DEFAULT,
true,
GMEM,
false>
HistogramSweepPolicy;
};

/// SM20
struct Policy200
{
// HistogramSweepPolicy
typedef AgentHistogramPolicy<
(NUM_CHANNELS == 1) ? 256 : 128,
(NUM_CHANNELS == 1) ? 8 : 3,
(NUM_CHANNELS == 1) ? BLOCK_LOAD_DIRECT : BLOCK_LOAD_WARP_TRANSPOSE,
LOAD_DEFAULT,
true,
SMEM,
false>
HistogramSweepPolicy;
};

/// SM30
struct Policy300
{
// HistogramSweepPolicy
typedef AgentHistogramPolicy<
512,
(NUM_CHANNELS == 1) ? 8 : 2,
BLOCK_LOAD_DIRECT,
LOAD_DEFAULT,
true,
GMEM,
false>
HistogramSweepPolicy;
};

/// SM35
struct Policy350
{
Expand Down Expand Up @@ -426,17 +380,8 @@ struct DipatchHistogram
#if (CUB_PTX_ARCH >= 500)
typedef Policy500 PtxPolicy;

#elif (CUB_PTX_ARCH >= 350)
typedef Policy350 PtxPolicy;

#elif (CUB_PTX_ARCH >= 300)
typedef Policy300 PtxPolicy;

#elif (CUB_PTX_ARCH >= 200)
typedef Policy200 PtxPolicy;

#else
typedef Policy110 PtxPolicy;
typedef Policy350 PtxPolicy;

#endif

Expand Down Expand Up @@ -473,21 +418,9 @@ struct DipatchHistogram
{
result = histogram_sweep_config.template Init<typename Policy500::HistogramSweepPolicy>();
}
else if (ptx_version >= 350)
{
result = histogram_sweep_config.template Init<typename Policy350::HistogramSweepPolicy>();
}
else if (ptx_version >= 300)
{
result = histogram_sweep_config.template Init<typename Policy300::HistogramSweepPolicy>();
}
else if (ptx_version >= 200)
{
result = histogram_sweep_config.template Init<typename Policy200::HistogramSweepPolicy>();
}
else
{
result = histogram_sweep_config.template Init<typename Policy110::HistogramSweepPolicy>();
result = histogram_sweep_config.template Init<typename Policy350::HistogramSweepPolicy>();
}
#endif
}
Expand Down
122 changes: 3 additions & 119 deletions cub/device/dispatch/dispatch_radix_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,8 @@ template <
typename ValueT, ///< Value type
typename OffsetT> ///< Signed integer type for global offsets
__launch_bounds__ (int((ALT_DIGIT_BITS) ?
ChainedPolicyT::ActivePolicy::AltDownsweepPolicy::BLOCK_THREADS :
ChainedPolicyT::ActivePolicy::DownsweepPolicy::BLOCK_THREADS))
int(ChainedPolicyT::ActivePolicy::AltDownsweepPolicy::BLOCK_THREADS) :
int(ChainedPolicyT::ActivePolicy::DownsweepPolicy::BLOCK_THREADS)))
__global__ void DeviceRadixSortDownsweepKernel(
const KeyT *d_keys_in, ///< [in] Input keys buffer
KeyT *d_keys_out, ///< [in] Output keys buffer
Expand Down Expand Up @@ -622,124 +622,8 @@ struct DeviceRadixSortPolicy
// Architecture-specific tuning policies
//------------------------------------------------------------------------------

/// SM20
struct Policy200 : ChainedPolicy<200, Policy200, Policy200>
{
enum {
PRIMARY_RADIX_BITS = 5,
ALT_RADIX_BITS = PRIMARY_RADIX_BITS - 1,

// Relative size of KeyT type to a 4-byte word
SCALE_FACTOR_4B = (CUB_MAX(sizeof(KeyT), sizeof(ValueT)) + 3) / 4,
ONESWEEP = false,
ONESWEEP_RADIX_BITS = 8,
};

// Histogram policy
typedef AgentRadixSortHistogramPolicy <256, 8, 1, KeyT, ONESWEEP_RADIX_BITS> HistogramPolicy;

// Exclusive sum policy
typedef AgentRadixSortExclusiveSumPolicy <256, ONESWEEP_RADIX_BITS> ExclusiveSumPolicy;

// Onesweep policy
typedef AgentRadixSortOnesweepPolicy <256, 21, DominantT, 1,
RADIX_RANK_MATCH_EARLY_COUNTS_ANY, BLOCK_SCAN_WARP_SCANS, RADIX_SORT_STORE_DIRECT,
ONESWEEP_RADIX_BITS> OnesweepPolicy;

// Keys-only upsweep policies
typedef AgentRadixSortUpsweepPolicy <64, 18, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS> UpsweepPolicyKeys;
typedef AgentRadixSortUpsweepPolicy <64, 18, DominantT, LOAD_DEFAULT, ALT_RADIX_BITS> AltUpsweepPolicyKeys;

// Key-value pairs upsweep policies
typedef AgentRadixSortUpsweepPolicy <128, 13, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS> UpsweepPolicyPairs;
typedef AgentRadixSortUpsweepPolicy <128, 13, DominantT, LOAD_DEFAULT, ALT_RADIX_BITS> AltUpsweepPolicyPairs;

// Upsweep policies
typedef typename If<KEYS_ONLY, UpsweepPolicyKeys, UpsweepPolicyPairs>::Type UpsweepPolicy;
typedef typename If<KEYS_ONLY, AltUpsweepPolicyKeys, AltUpsweepPolicyPairs>::Type AltUpsweepPolicy;

// Scan policy
typedef AgentScanPolicy <512, 4, OffsetT, BLOCK_LOAD_VECTORIZE, LOAD_DEFAULT, BLOCK_STORE_VECTORIZE, BLOCK_SCAN_RAKING_MEMOIZE> ScanPolicy;

// Keys-only downsweep policies
typedef AgentRadixSortDownsweepPolicy <64, 18, DominantT, BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_BASIC, BLOCK_SCAN_WARP_SCANS, PRIMARY_RADIX_BITS> DownsweepPolicyKeys;
typedef AgentRadixSortDownsweepPolicy <64, 18, DominantT, BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_BASIC, BLOCK_SCAN_WARP_SCANS, ALT_RADIX_BITS> AltDownsweepPolicyKeys;

// Key-value pairs downsweep policies
typedef AgentRadixSortDownsweepPolicy <128, 13, DominantT, BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_BASIC, BLOCK_SCAN_WARP_SCANS, PRIMARY_RADIX_BITS> DownsweepPolicyPairs;
typedef AgentRadixSortDownsweepPolicy <128, 13, DominantT, BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_BASIC, BLOCK_SCAN_WARP_SCANS, ALT_RADIX_BITS> AltDownsweepPolicyPairs;

// Downsweep policies
typedef typename If<KEYS_ONLY, DownsweepPolicyKeys, DownsweepPolicyPairs>::Type DownsweepPolicy;
typedef typename If<KEYS_ONLY, AltDownsweepPolicyKeys, AltDownsweepPolicyPairs>::Type AltDownsweepPolicy;

// Single-tile policy
typedef DownsweepPolicy SingleTilePolicy;

// Segmented policies
typedef DownsweepPolicy SegmentedPolicy;
typedef AltDownsweepPolicy AltSegmentedPolicy;
};

/// SM30
struct Policy300 : ChainedPolicy<300, Policy300, Policy200>
{
enum {
PRIMARY_RADIX_BITS = 5,
ALT_RADIX_BITS = PRIMARY_RADIX_BITS - 1,
ONESWEEP = false,
ONESWEEP_RADIX_BITS = 8,
};

// Histogram policy
typedef AgentRadixSortHistogramPolicy <256, 8, 1, KeyT, ONESWEEP_RADIX_BITS> HistogramPolicy;

// Exclusive sum policy
typedef AgentRadixSortExclusiveSumPolicy <256, ONESWEEP_RADIX_BITS> ExclusiveSumPolicy;

// Onesweep policy
typedef AgentRadixSortOnesweepPolicy <256, 21, DominantT, 1,
RADIX_RANK_MATCH_EARLY_COUNTS_ANY, BLOCK_SCAN_WARP_SCANS, RADIX_SORT_STORE_DIRECT,
ONESWEEP_RADIX_BITS> OnesweepPolicy;

// Keys-only upsweep policies
typedef AgentRadixSortUpsweepPolicy <256, 7, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS> UpsweepPolicyKeys;
typedef AgentRadixSortUpsweepPolicy <256, 7, DominantT, LOAD_DEFAULT, ALT_RADIX_BITS> AltUpsweepPolicyKeys;

// Key-value pairs upsweep policies
typedef AgentRadixSortUpsweepPolicy <256, 5, DominantT, LOAD_DEFAULT, PRIMARY_RADIX_BITS> UpsweepPolicyPairs;
typedef AgentRadixSortUpsweepPolicy <256, 5, DominantT, LOAD_DEFAULT, ALT_RADIX_BITS> AltUpsweepPolicyPairs;

// Upsweep policies
typedef typename If<KEYS_ONLY, UpsweepPolicyKeys, UpsweepPolicyPairs>::Type UpsweepPolicy;
typedef typename If<KEYS_ONLY, AltUpsweepPolicyKeys, AltUpsweepPolicyPairs>::Type AltUpsweepPolicy;

// Scan policy
typedef AgentScanPolicy <1024, 4, OffsetT, BLOCK_LOAD_VECTORIZE, LOAD_DEFAULT, BLOCK_STORE_VECTORIZE, BLOCK_SCAN_WARP_SCANS> ScanPolicy;

// Keys-only downsweep policies
typedef AgentRadixSortDownsweepPolicy <128, 14, DominantT, BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_BASIC, BLOCK_SCAN_WARP_SCANS, PRIMARY_RADIX_BITS> DownsweepPolicyKeys;
typedef AgentRadixSortDownsweepPolicy <128, 14, DominantT, BLOCK_LOAD_WARP_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_BASIC, BLOCK_SCAN_WARP_SCANS, ALT_RADIX_BITS> AltDownsweepPolicyKeys;

// Key-value pairs downsweep policies
typedef AgentRadixSortDownsweepPolicy <128, 10, DominantT, BLOCK_LOAD_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_BASIC, BLOCK_SCAN_WARP_SCANS, PRIMARY_RADIX_BITS> DownsweepPolicyPairs;
typedef AgentRadixSortDownsweepPolicy <128, 10, DominantT, BLOCK_LOAD_TRANSPOSE, LOAD_DEFAULT, RADIX_RANK_BASIC, BLOCK_SCAN_WARP_SCANS, ALT_RADIX_BITS> AltDownsweepPolicyPairs;

// Downsweep policies
typedef typename If<KEYS_ONLY, DownsweepPolicyKeys, DownsweepPolicyPairs>::Type DownsweepPolicy;
typedef typename If<KEYS_ONLY, AltDownsweepPolicyKeys, AltDownsweepPolicyPairs>::Type AltDownsweepPolicy;

// Single-tile policy
typedef DownsweepPolicy SingleTilePolicy;

// Segmented policies
typedef DownsweepPolicy SegmentedPolicy;
typedef AltDownsweepPolicy AltSegmentedPolicy;
};


/// SM35
struct Policy350 : ChainedPolicy<350, Policy350, Policy300>
struct Policy350 : ChainedPolicy<350, Policy350, Policy350>
{
enum {
PRIMARY_RADIX_BITS = (sizeof(KeyT) > 1) ? 6 : 5, // 1.72B 32b keys/s, 1.17B 32b pairs/s, 1.55B 32b segmented keys/s (K40m)
Expand Down
Loading

0 comments on commit 349ad8f

Please sign in to comment.