
Optimize compilation time for the common case #400

Merged
93 changes: 61 additions & 32 deletions cub/device/dispatch/dispatch_merge_sort.cuh
@@ -27,10 +27,10 @@
 
 #pragma once
 
-#include "../../util_math.cuh"
-#include "../../util_device.cuh"
-#include "../../util_namespace.cuh"
-#include "../../agent/agent_merge_sort.cuh"
+#include <cub/util_math.cuh>
+#include <cub/util_device.cuh>
+#include <cub/util_namespace.cuh>
+#include <cub/agent/agent_merge_sort.cuh>
 
 #include <thrust/system/cuda/detail/core/triple_chevron_launch.h>
 #include <thrust/detail/integer_math.h>
@@ -245,7 +245,8 @@ template <typename KeyInputIteratorT,
           typename ActivePolicyT,
           typename CompareOpT,
           typename KeyT,
-          typename ValueT>
+          typename ValueT,
+          bool AgentFitsIntoDefaultShmemSize>
 struct BlockSortLauncher
 {
   int num_tiles;
@@ -309,12 +310,15 @@ struct BlockSortLauncher
   template <bool UseVShmem>
   CUB_RUNTIME_FUNCTION __forceinline__ void launch_impl() const
   {
+    constexpr bool use_vshmem = (AgentFitsIntoDefaultShmemSize == false) &&
+                                UseVShmem;
+
     THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(
       num_tiles,
       ActivePolicyT::MergeSortPolicy::BLOCK_THREADS,
-      block_sort_shmem_size,
+      use_vshmem ? 0 : block_sort_shmem_size,
       stream)
-      .doit(DeviceMergeSortBlockSortKernel<UseVShmem,
+      .doit(DeviceMergeSortBlockSortKernel<use_vshmem,
                                            ChainedPolicyT,
                                            KeyInputIteratorT,
                                            ValueInputIteratorT,
@@ -344,7 +348,8 @@ template <typename KeyIteratorT,
           typename ActivePolicyT,
           typename CompareOpT,
           typename KeyT,
-          typename ValueT>
+          typename ValueT,
+          bool AgentFitsIntoDefaultShmemSize>
 struct MergeLauncher
 {
   int num_tiles;
@@ -402,12 +407,15 @@ struct MergeLauncher
   CUB_RUNTIME_FUNCTION __forceinline__ void
   launch_impl(bool ping, OffsetT target_merged_tiles_number) const
   {
+    constexpr bool use_vshmem = (AgentFitsIntoDefaultShmemSize == false) &&
+                                UseVShmem;
+
     THRUST_NS_QUALIFIER::cuda_cub::launcher::triple_chevron(
       num_tiles,
       ActivePolicyT::MergeSortPolicy::BLOCK_THREADS,
-      merge_shmem_size,
+      use_vshmem ? 0 : merge_shmem_size,
       stream)
-      .doit(DeviceMergeSortMergeKernel<UseVShmem,
+      .doit(DeviceMergeSortMergeKernel<use_vshmem,
                                        ChainedPolicyT,
                                        KeyIteratorT,
                                        ValueIteratorT,
@@ -544,22 +552,21 @@ struct DispatchMergeSort : SelectedPolicy
       }
 
-      // Get shared memory size
-      int max_shmem = 0;
-      if (CubDebug(error = cudaDeviceGetAttribute(&max_shmem,
-                                                  cudaDevAttrMaxSharedMemoryPerBlock,
-                                                  device_ordinal)))
-      {
-        break;
-      }
 
       const auto tile_size = MergePolicyT::ITEMS_PER_TILE;
       const auto num_tiles = cub::DivideAndRoundUp(num_items, tile_size);
 
-      const auto block_sort_shmem_size =
+      constexpr std::size_t default_shared_memory_size = 48 * 1024;
+      constexpr auto block_sort_shmem_size =
         static_cast<std::size_t>(BlockSortAgentT::SHARED_MEMORY_SIZE);
+      constexpr bool block_sort_fits_into_default_shmem =
+        block_sort_shmem_size < default_shared_memory_size;
 
-      const auto merge_shmem_size =
+      constexpr auto merge_shmem_size =
         static_cast<std::size_t>(MergeAgentT::SHARED_MEMORY_SIZE);
+      constexpr bool merge_fits_into_default_shmem = merge_shmem_size <
+                                                     default_shared_memory_size;
+      constexpr bool runtime_shmem_size_check_is_required =
+        !(merge_fits_into_default_shmem && block_sort_fits_into_default_shmem);
 
       const auto merge_partitions_size =
         static_cast<std::size_t>(1 + num_tiles) * sizeof(OffsetT);
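
The constant `48 * 1024` is the portable minimum: every CUDA device provides at least 48 KiB of dynamic shared memory per block without any opt-in, and the agents' `SHARED_MEMORY_SIZE` members are compile-time constants, so both `*_fits_into_default_shmem` flags fold at compile time. A self-contained sketch of the gate (the agent sizes below are hypothetical stand-ins):

#include <cstddef>

struct BlockSortAgent { static constexpr std::size_t SHARED_MEMORY_SIZE = 9 * 1024; };
struct MergeAgent     { static constexpr std::size_t SHARED_MEMORY_SIZE = 8 * 1024; };

constexpr std::size_t default_shared_memory_size = 48 * 1024;

constexpr bool runtime_shmem_size_check_is_required =
  !(BlockSortAgent::SHARED_MEMORY_SIZE < default_shared_memory_size &&
    MergeAgent::SHARED_MEMORY_SIZE < default_shared_memory_size);

// For typical key/value types both agents fit, so the attribute query and
// the vshmem machinery are skipped entirely.
static_assert(!runtime_shmem_size_check_is_required,
              "common case: no runtime check needed");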
@@ -570,10 +577,32 @@
       const auto temporary_values_storage_size =
         static_cast<std::size_t>(num_items * sizeof(ValueT)) * !KEYS_ONLY;
 
-      const auto virtual_shared_memory_size =
-        vshmem_size(static_cast<std::size_t>(max_shmem),
-                    (cub::max)(block_sort_shmem_size, merge_shmem_size),
-                    static_cast<std::size_t>(num_tiles));
+      std::size_t virtual_shared_memory_size = 0;
+      bool block_sort_requires_vshmem = false;
+      bool merge_requires_vshmem = false;
+
+      if (runtime_shmem_size_check_is_required)
+      {
+        int max_shmem = 0;
+        if (CubDebug(
+              error = cudaDeviceGetAttribute(&max_shmem,
+                                             cudaDevAttrMaxSharedMemoryPerBlock,
+                                             device_ordinal)))
+        {
+          break;
+        }
+
+        block_sort_requires_vshmem = block_sort_shmem_size >
+                                     static_cast<std::size_t>(max_shmem);
+        merge_requires_vshmem = merge_shmem_size >
+                                static_cast<std::size_t>(max_shmem);
+
+        virtual_shared_memory_size =
+          vshmem_size(static_cast<std::size_t>(max_shmem),
+                      (cub::max)(block_sort_shmem_size, merge_shmem_size),
+                      static_cast<std::size_t>(num_tiles));
+      }
+
 
       void *allocations[4] = {nullptr, nullptr, nullptr, nullptr};
       std::size_t allocation_sizes[4] = {merge_partitions_size,
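
Only when an agent might exceed the 48 KiB baseline does the dispatch now pay for the `cudaDeviceGetAttribute` call and size the global-memory fallback. A rough sketch of how such a fallback buffer could be sized (the actual `vshmem_size` helper in CUB may differ, e.g. in padding and alignment handling):

#include <cstddef>

std::size_t vshmem_size_sketch(std::size_t max_shmem_per_block,
                               std::size_t required_shmem,
                               std::size_t num_blocks)
{
  // Requirement fits into real shared memory: no fallback buffer needed.
  if (required_shmem <= max_shmem_per_block)
  {
    return 0;
  }

  // Otherwise each block gets its own global-memory slice of the required size.
  return required_shmem * num_blocks;
}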
@@ -633,11 +662,10 @@
                         ActivePolicyT,
                         CompareOpT,
                         KeyT,
-                        ValueT>
+                        ValueT,
+                        block_sort_fits_into_default_shmem>
         block_sort_launcher(static_cast<int>(num_tiles),
-                            virtual_shared_memory_size > 0
-                              ? 0
-                              : block_sort_shmem_size,
+                            block_sort_shmem_size,
                             ping,
                             d_input_keys,
                             d_input_items,
@@ -648,7 +676,7 @@
                             stream,
                             keys_buffer,
                             items_buffer,
-                            vshmem_ptr);
+                            block_sort_requires_vshmem ? vshmem_ptr : nullptr);
 
         block_sort_launcher.launch();
 
@@ -679,9 +707,10 @@
                       ActivePolicyT,
                       CompareOpT,
                       KeyT,
-                      ValueT>
+                      ValueT,
+                      merge_fits_into_default_shmem>
         merge_launcher(static_cast<int>(num_tiles),
-                       virtual_shared_memory_size > 0 ? 0 : merge_shmem_size,
+                       merge_shmem_size,
                        d_output_keys,
                        d_output_items,
                        num_items,
@@ -690,7 +719,7 @@
                        stream,
                        keys_buffer,
                        items_buffer,
-                       vshmem_ptr);
+                       merge_requires_vshmem ? vshmem_ptr : nullptr);
 
         for (int pass = 0; pass < num_passes; ++pass, ping = !ping)
         {
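
For context, nothing in this PR changes the public entry point; a typical call that reaches the dispatch layer above uses CUB's standard two-phase temp-storage idiom (compiled with nvcc):

#include <cub/cub.cuh>
#include <cstddef>

struct Less
{
  __device__ bool operator()(int a, int b) const { return a < b; }
};

void sort_keys(int *d_keys, int num_items, cudaStream_t stream)
{
  void *d_temp_storage = nullptr;
  std::size_t temp_storage_bytes = 0;

  // First call only computes the temporary storage requirement.
  cub::DeviceMergeSort::SortKeys(d_temp_storage, temp_storage_bytes,
                                 d_keys, num_items, Less{}, stream);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);

  // Second call dispatches through the DispatchMergeSort layer shown above.
  cub::DeviceMergeSort::SortKeys(d_temp_storage, temp_storage_bytes,
                                 d_keys, num_items, Less{}, stream);
  cudaFree(d_temp_storage);
}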