Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Optimize compilation time for the common case #400

Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 123 additions & 55 deletions cub/device/dispatch/dispatch_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@

#pragma once

#include "../../util_math.cuh"
#include "../../util_device.cuh"
#include "../../util_namespace.cuh"
#include "../../agent/agent_merge_sort.cuh"
#include <cub/util_math.cuh>
#include <cub/util_device.cuh>
#include <cub/util_namespace.cuh>
#include <cub/agent/agent_merge_sort.cuh>

#include <thrust/system/cuda/detail/core/triple_chevron_launch.h>
#include <thrust/detail/integer_math.h>
Expand Down Expand Up @@ -183,7 +183,7 @@ DeviceMergeSortMergeKernel(bool ping,
agent.Process();
}

/******************************************************************************
/*******************************************************************************
* Policy
******************************************************************************/

Expand All @@ -192,9 +192,9 @@ struct DeviceMergeSortPolicy
{
using KeyT = typename std::iterator_traits<KeyIteratorT>::value_type;

//------------------------------------------------------------------------------
//----------------------------------------------------------------------------
// Architecture-specific tuning policies
//------------------------------------------------------------------------------
//----------------------------------------------------------------------------

struct Policy350 : ChainedPolicy<350, Policy350, Policy350>
{
Expand Down Expand Up @@ -245,7 +245,8 @@ template <typename KeyInputIteratorT,
typename ActivePolicyT,
typename CompareOpT,
typename KeyT,
typename ValueT>
typename ValueT,
bool AgentFitsIntoDefaultShmemSize>
struct BlockSortLauncher
{
int num_tiles;
Expand Down Expand Up @@ -309,12 +310,15 @@ struct BlockSortLauncher
template <bool UseVShmem>
CUB_RUNTIME_FUNCTION __forceinline__ void launch_impl() const
{
constexpr bool use_vshmem = (AgentFitsIntoDefaultShmemSize == false) &&
UseVShmem;

THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(
num_tiles,
ActivePolicyT::MergeSortPolicy::BLOCK_THREADS,
block_sort_shmem_size,
use_vshmem ? 0 : block_sort_shmem_size,
stream)
.doit(DeviceMergeSortBlockSortKernel<UseVShmem,
.doit(DeviceMergeSortBlockSortKernel<use_vshmem,
ChainedPolicyT,
KeyInputIteratorT,
ValueInputIteratorT,
Expand Down Expand Up @@ -344,7 +348,8 @@ template <typename KeyIteratorT,
typename ActivePolicyT,
typename CompareOpT,
typename KeyT,
typename ValueT>
typename ValueT,
bool AgentFitsIntoDefaultShmemSize>
struct MergeLauncher
{
int num_tiles;
Expand Down Expand Up @@ -402,12 +407,15 @@ struct MergeLauncher
CUB_RUNTIME_FUNCTION __forceinline__ void
launch_impl(bool ping, OffsetT target_merged_tiles_number) const
{
constexpr bool use_vshmem = (AgentFitsIntoDefaultShmemSize == false) &&
UseVShmem;

THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(
num_tiles,
ActivePolicyT::MergeSortPolicy::BLOCK_THREADS,
merge_shmem_size,
use_vshmem ? 0 : merge_shmem_size,
stream)
.doit(DeviceMergeSortMergeKernel<UseVShmem,
.doit(DeviceMergeSortMergeKernel<use_vshmem,
ChainedPolicyT,
KeyIteratorT,
ValueIteratorT,
Expand Down Expand Up @@ -440,28 +448,45 @@ struct DispatchMergeSort : SelectedPolicy
using KeyT = typename std::iterator_traits<KeyIteratorT>::value_type;
using ValueT = typename std::iterator_traits<ValueIteratorT>::value_type;

// Whether or not there are values to be trucked along with keys
/// Whether or not there are values to be trucked along with keys
static constexpr bool KEYS_ONLY = Equals<ValueT, NullType>::VALUE;

//------------------------------------------------------------------------------
// Problem state
//------------------------------------------------------------------------------

void *d_temp_storage; ///< [in] Device-accessible allocation of temporary storage. When NULL, the required allocation size is written to \p temp_storage_bytes and no work is done.
std::size_t &temp_storage_bytes; ///< [in,out] Reference to size in bytes of \p d_temp_storage allocation
KeyInputIteratorT d_input_keys; ///< [in] Pointer to the input sequence of unsorted input keys
ValueInputIteratorT d_input_items;///< [in] Pointer to the input sequence of unsorted input values
KeyIteratorT d_output_keys; ///< [out] Pointer to the output sequence of sorted input keys
ValueIteratorT d_output_items; ///< [out] Pointer to the output sequence of sorted input values
OffsetT num_items; ///< [in] Number of items to sort
CompareOpT compare_op; ///< [in] Comparison function object which returns true if the first argument is ordered before the second
cudaStream_t stream; ///< [in] <b>[optional]</b> CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
bool debug_synchronous; ///< [in] Whether or not to synchronize the stream after every kernel launch to check for errors. Also causes launch configurations to be printed to the console. Default is \p false.
int ptx_version;

//------------------------------------------------------------------------------
// Constructor
//------------------------------------------------------------------------------
/// Device-accessible allocation of temporary storage. When NULL, the required
/// allocation size is written to \p temp_storage_bytes and no work is done.
void *d_temp_storage;

/// Reference to size in bytes of \p d_temp_storage allocation
std::size_t &temp_storage_bytes;

/// Pointer to the input sequence of unsorted input keys
KeyInputIteratorT d_input_keys;

/// Pointer to the input sequence of unsorted input values
ValueInputIteratorT d_input_items;

/// Pointer to the output sequence of sorted input keys
KeyIteratorT d_output_keys;

/// Pointer to the output sequence of sorted input values
ValueIteratorT d_output_items;

/// Number of items to sort
OffsetT num_items;

/// Comparison function object which returns true if the first argument is
/// ordered before the second
CompareOpT compare_op;

/// CUDA stream to launch kernels within. Default is stream<sub>0</sub>.
cudaStream_t stream;

/// Whether or not to synchronize the stream after every kernel launch to
/// check for errors. Also causes launch configurations to be printed to the
/// console. Default is \p false.
bool debug_synchronous;
int ptx_version;

CUB_RUNTIME_FUNCTION __forceinline__ std::size_t
vshmem_size(std::size_t max_shmem,
Expand All @@ -478,7 +503,7 @@ struct DispatchMergeSort : SelectedPolicy
}
}

/// Constructor
// Constructor
CUB_RUNTIME_FUNCTION __forceinline__
DispatchMergeSort(void *d_temp_storage,
std::size_t &temp_storage_bytes,
Expand All @@ -504,7 +529,7 @@ struct DispatchMergeSort : SelectedPolicy
, ptx_version(ptx_version)
{}

/// Invocation
// Invocation
template <typename ActivePolicyT>
CUB_RUNTIME_FUNCTION __forceinline__ cudaError_t Invoke()
{
Expand Down Expand Up @@ -544,22 +569,43 @@ struct DispatchMergeSort : SelectedPolicy
}

// Get shared memory size
int max_shmem = 0;
if (CubDebug(error = cudaDeviceGetAttribute(&max_shmem,
cudaDevAttrMaxSharedMemoryPerBlock,
device_ordinal)))
{
break;
}

const auto tile_size = MergePolicyT::ITEMS_PER_TILE;
const auto num_tiles = cub::DivideAndRoundUp(num_items, tile_size);

const auto block_sort_shmem_size =
/**
* Merge sort supports large types, which can lead to excessive shared
* memory size requirements. In these cases, merge sort allocates virtual
* shared memory that resides in global memory:
* ```
* extern __shared__ char shmem[];
* typename AgentT::TempStorage &storage =
* *reinterpret_cast<typename AgentT::TempStorage *>(
* UseVShmem ? vshmem + vshmem_offset : shmem);
* ```
* Having `UseVShmem` as a runtime variable leads to the generation of
* generic loads and stores, which causes a slowdown. Therefore,
* `UseVShmem` has to be known at compilation time.
* In the generic case, the available shared memory size is queried at runtime
* to check whether the kernels' requirements are satisfied. Since the query
* result is not known at compile time, merge sort kernels are specialized
* for both cases.
* To limit the resulting increase in compilation time, the dispatch layer
* checks at compile time whether the kernels' requirements fit into the
* default shared memory size (48KB). If they do, there's no need for the
* virtual shared memory specialization.
*/
constexpr std::size_t default_shared_memory_size = 48 * 1024;
gevtushenko marked this conversation as resolved.
Show resolved Hide resolved
constexpr auto block_sort_shmem_size =
static_cast<std::size_t>(BlockSortAgentT::SHARED_MEMORY_SIZE);
constexpr bool block_sort_fits_into_default_shmem =
block_sort_shmem_size < default_shared_memory_size;

const auto merge_shmem_size =
constexpr auto merge_shmem_size =
static_cast<std::size_t>(MergeAgentT::SHARED_MEMORY_SIZE);
constexpr bool merge_fits_into_default_shmem = merge_shmem_size <
default_shared_memory_size;
constexpr bool runtime_shmem_size_check_is_required =
!(merge_fits_into_default_shmem && block_sort_fits_into_default_shmem);

const auto merge_partitions_size =
static_cast<std::size_t>(1 + num_tiles) * sizeof(OffsetT);
Expand All @@ -570,10 +616,32 @@ struct DispatchMergeSort : SelectedPolicy
const auto temporary_values_storage_size =
static_cast<std::size_t>(num_items * sizeof(ValueT)) * !KEYS_ONLY;

const auto virtual_shared_memory_size =
vshmem_size(static_cast<std::size_t>(max_shmem),
(cub::max)(block_sort_shmem_size, merge_shmem_size),
static_cast<std::size_t>(num_tiles));
std::size_t virtual_shared_memory_size = 0;
bool block_sort_requires_vshmem = false;
bool merge_requires_vshmem = false;

if (runtime_shmem_size_check_is_required)
{
int max_shmem = 0;
if (CubDebug(
error = cudaDeviceGetAttribute(&max_shmem,
cudaDevAttrMaxSharedMemoryPerBlock,
device_ordinal)))
{
break;
}

block_sort_requires_vshmem = block_sort_shmem_size >
static_cast<std::size_t>(max_shmem);
merge_requires_vshmem = merge_shmem_size >
static_cast<std::size_t>(max_shmem);

virtual_shared_memory_size =
vshmem_size(static_cast<std::size_t>(max_shmem),
(cub::max)(block_sort_shmem_size, merge_shmem_size),
static_cast<std::size_t>(num_tiles));
}


void *allocations[4] = {nullptr, nullptr, nullptr, nullptr};
std::size_t allocation_sizes[4] = {merge_partitions_size,
Expand Down Expand Up @@ -633,11 +701,10 @@ struct DispatchMergeSort : SelectedPolicy
ActivePolicyT,
CompareOpT,
KeyT,
ValueT>
ValueT,
block_sort_fits_into_default_shmem>
block_sort_launcher(static_cast<int>(num_tiles),
virtual_shared_memory_size > 0
? 0
: block_sort_shmem_size,
block_sort_shmem_size,
ping,
d_input_keys,
d_input_items,
Expand All @@ -648,7 +715,7 @@ struct DispatchMergeSort : SelectedPolicy
stream,
keys_buffer,
items_buffer,
vshmem_ptr);
block_sort_requires_vshmem ? vshmem_ptr : nullptr);

block_sort_launcher.launch();

Expand Down Expand Up @@ -679,9 +746,10 @@ struct DispatchMergeSort : SelectedPolicy
ActivePolicyT,
CompareOpT,
KeyT,
ValueT>
ValueT,
merge_fits_into_default_shmem>
merge_launcher(static_cast<int>(num_tiles),
virtual_shared_memory_size > 0 ? 0 : merge_shmem_size,
merge_shmem_size,
d_output_keys,
d_output_items,
num_items,
Expand All @@ -690,7 +758,7 @@ struct DispatchMergeSort : SelectedPolicy
stream,
keys_buffer,
items_buffer,
vshmem_ptr);
merge_requires_vshmem ? vshmem_ptr : nullptr);

for (int pass = 0; pass < num_passes; ++pass, ping = !ping)
{
Expand Down