Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Resolve review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
zasdfgbnm authored and alliepiper committed Jan 19, 2022
1 parent 70d2fbb commit 8986eef
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 306 deletions.
28 changes: 14 additions & 14 deletions cub/agent/agent_unique_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,14 @@ struct AgentUniqueByKey
};

// Cache-modified Input iterator wrapper type (for applying cache modifier) for keys
using WrappedKeyInputIteratorT = typename If<IsPointer<KeyInputIteratorT>::VALUE,
using WrappedKeyInputIteratorT = typename std::conditional<IsPointer<KeyInputIteratorT>::VALUE,
CacheModifiedInputIterator<AgentUniqueByKeyPolicyT::LOAD_MODIFIER, KeyT, OffsetT>, // Wrap the native input pointer with CacheModifiedValuesInputIterator
KeyInputIteratorT>::Type; // Directly use the supplied input iterator type
KeyInputIteratorT>::type; // Directly use the supplied input iterator type

// Cache-modified Input iterator wrapper type (for applying cache modifier) for values
using WrappedValueInputIteratorT = typename If<IsPointer<ValueInputIteratorT>::VALUE,
using WrappedValueInputIteratorT = typename std::conditional<IsPointer<ValueInputIteratorT>::VALUE,
CacheModifiedInputIterator<AgentUniqueByKeyPolicyT::LOAD_MODIFIER, ValueT, OffsetT>, // Wrap the native input pointer with CacheModifiedValuesInputIterator
ValueInputIteratorT>::Type; // Directly use the supplied input iterator type
ValueInputIteratorT>::type; // Directly use the supplied input iterator type

// Parameterized BlockLoad type for input data
using BlockLoadKeys = BlockLoad<
Expand Down Expand Up @@ -214,16 +214,16 @@ struct AgentUniqueByKey
// Utility functions
//---------------------------------------------------------------------

struct key_tag {};
struct value_tag {};
struct KeyTagT {};
struct ValueTagT {};

__device__ __forceinline__
KeyExchangeT &get_shared(key_tag)
KeyExchangeT &GetShared(KeyTagT)
{
return temp_storage.shared_keys.Alias();
}
__device__ __forceinline__
ValueExchangeT &get_shared(value_tag)
ValueExchangeT &GetShared(ValueTagT)
{
return temp_storage.shared_values.Alias();
}
Expand Down Expand Up @@ -253,7 +253,7 @@ struct AgentUniqueByKey
num_selections_prefix;
if (selection_flags[ITEM])
{
get_shared(tag)[local_scatter_offset] = items[ITEM];
GetShared(tag)[local_scatter_offset] = items[ITEM];
}
}

Expand All @@ -263,7 +263,7 @@ struct AgentUniqueByKey
item < num_tile_selections;
item += BLOCK_THREADS)
{
items_out[num_selections_prefix + item] = get_shared(tag)[item];
items_out[num_selections_prefix + item] = GetShared(tag)[item];
}

CTA_SYNC();
Expand Down Expand Up @@ -364,7 +364,7 @@ struct AgentUniqueByKey

CTA_SYNC();

Scatter(key_tag(),
Scatter(KeyTagT(),
d_keys_out,
keys,
selection_flags,
Expand All @@ -376,7 +376,7 @@ struct AgentUniqueByKey

CTA_SYNC();

Scatter(value_tag(),
Scatter(ValueTagT(),
d_values_out,
values,
selection_flags,
Expand Down Expand Up @@ -481,7 +481,7 @@ struct AgentUniqueByKey

CTA_SYNC();

Scatter(key_tag(),
Scatter(KeyTagT(),
d_keys_out,
keys,
selection_flags,
Expand All @@ -493,7 +493,7 @@ struct AgentUniqueByKey

CTA_SYNC();

Scatter(value_tag(),
Scatter(ValueTagT(),
d_values_out,
values,
selection_flags,
Expand Down
2 changes: 1 addition & 1 deletion cub/device/dispatch/dispatch_merge_sort.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -622,7 +622,7 @@ struct DispatchMergeSort : SelectedPolicy
static_cast<std::size_t>(max_shmem);

virtual_shared_memory_size =
VshmemSize(static_cast<std::size_t>(max_shmem),
detail::VshmemSize(static_cast<std::size_t>(max_shmem),
(cub::max)(block_sort_shmem_size, merge_shmem_size),
static_cast<std::size_t>(num_tiles));
}
Expand Down
107 changes: 7 additions & 100 deletions cub/device/dispatch/dispatch_unique_by_key.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,6 @@
* cub::DeviceSelect::UniqueByKey provides device-wide, parallel operations for selecting unique items by key from sequences of data items residing within device-accessible memory.
*/

// TODO: remove this note:
// copy-pasted from https://github.com/NVIDIA/thrust/blob/main/thrust/system/cuda/detail/unique_by_key.h

#include "../../agent/agent_unique_by_key.cuh"
#include "../../util_math.cuh"
#include "../../util_macro.cuh"
Expand Down Expand Up @@ -126,7 +123,7 @@ struct DeviceUniqueByKeyPolicy
NOMINAL_4B_ITEMS_PER_THREAD = 11,
ITEMS_PER_THREAD = Nominal4BItemsToItems<KeyT>(NOMINAL_4B_ITEMS_PER_THREAD),
};

using UniqueByKeyPolicyT = AgentUniqueByKeyPolicy<64,
ITEMS_PER_THREAD,
cub::BLOCK_LOAD_WARP_TRANSPOSE,
Expand Down Expand Up @@ -223,13 +220,13 @@ struct DispatchUniqueByKey: SelectedPolicy
cudaError_t Invoke(InitKernel init_kernel, ScanKernel scan_kernel)
{
#ifndef CUB_RUNTIME_ENABLED

(void)init_kernel;
(void)scan_kernel;

// Kernel launch not supported from this device
return CubDebug(cudaErrorNotSupported);

#else

using Policy = typename ActivePolicyT::UniqueByKeyPolicyT;
Expand Down Expand Up @@ -261,7 +258,7 @@ struct DispatchUniqueByKey: SelectedPolicy
{
break;
}
std::size_t vshmem_size = VshmemSize(max_shmem, sizeof(typename UniqueByKeyAgentT::TempStorage), num_tiles);
std::size_t vshmem_size = detail::VshmemSize(max_shmem, sizeof(typename UniqueByKeyAgentT::TempStorage), num_tiles);

// Specify temporary storage allocation requirements
size_t allocation_sizes[2] = {0, vshmem_size};
Expand Down Expand Up @@ -355,10 +352,10 @@ struct DispatchUniqueByKey: SelectedPolicy
while(0);

return error;

#endif // CUB_RUNTIME_ENABLED
}

template <typename ActivePolicyT>
CUB_RUNTIME_FUNCTION __host__ __forceinline__
cudaError_t Invoke()
Expand All @@ -379,7 +376,7 @@ struct DispatchUniqueByKey: SelectedPolicy
);
}


/**
* Internal dispatch routine
*/
Expand Down Expand Up @@ -430,94 +427,4 @@ struct DispatchUniqueByKey: SelectedPolicy
}
};


// template <typename Derived,
// typename KeyInputIteratorT,
// typename ValueInputIteratorT,
// typename KeyOutputIteratorT,
// typename ValueOutputIteratorT,
// typename EqualityOpT>
// THRUST_RUNTIME_FUNCTION
// pair<KeyOutputIteratorT, ValueOutputIteratorT>
// unique_by_key(execution_policy<Derived>& policy,
// KeyInputIteratorT keys_first,
// KeyInputIteratorT keys_last,
// ValueInputIteratorT values_first,
// KeyOutputIteratorT keys_result,
// ValueOutputIteratorT values_result,
// EqualityOpT binary_pred)
// {

// typedef int size_type;

// size_type num_items
// = static_cast<size_type>(thrust::distance(keys_first, keys_last));

// size_t temp_storage_bytes = 0;
// cudaStream_t stream = cuda_cub::stream(policy);
// bool debug_sync = THRUST_DEBUG_SYNC_FLAG;

// cudaError_t status;
// status = __unique_by_key::doit_step(NULL,
// temp_storage_bytes,
// keys_first,
// values_first,
// keys_result,
// values_result,
// binary_pred,
// reinterpret_cast<size_type*>(NULL),
// num_items,
// stream,
// debug_sync);
// cuda_cub::throw_on_error(status, "unique_by_key: failed on 1st step");

// size_t allocation_sizes[2] = {sizeof(size_type), temp_storage_bytes};
// void * allocations[2] = {NULL, NULL};

// size_t storage_size = 0;
// status = core::alias_storage(NULL,
// storage_size,
// allocations,
// allocation_sizes);
// cuda_cub::throw_on_error(status, "unique_by_key failed on 1st alias_storage");

// // Allocate temporary storage.
// thrust::detail::temporary_array<thrust::detail::uint8_t, Derived>
// tmp(policy, storage_size);
// void *ptr = static_cast<void*>(tmp.data().get());

// status = core::alias_storage(ptr,
// storage_size,
// allocations,
// allocation_sizes);
// cuda_cub::throw_on_error(status, "unique_by_key failed on 2nd alias_storage");

// size_type* d_num_selected_out
// = thrust::detail::aligned_reinterpret_cast<size_type*>(allocations[0]);

// status = __unique_by_key::doit_step(allocations[1],
// temp_storage_bytes,
// keys_first,
// values_first,
// keys_result,
// values_result,
// binary_pred,
// d_num_selected_out,
// num_items,
// stream,
// debug_sync);
// cuda_cub::throw_on_error(status, "unique_by_key: failed on 2nd step");

// status = cuda_cub::synchronize(policy);
// cuda_cub::throw_on_error(status, "unique_by_key: failed to synchronize");

// size_type num_selected = get_value(policy, d_num_selected_out);

// return thrust::make_pair(
// keys_result + num_selected,
// values_result + num_selected
// );
// }


CUB_NAMESPACE_END
31 changes: 15 additions & 16 deletions cub/util_math.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,21 @@ using is_integral_or_enum =
std::integral_constant<bool,
std::is_integral<T>::value || std::is_enum<T>::value>;

__host__ __device__ __forceinline__ constexpr std::size_t
VshmemSize(std::size_t max_shmem,
std::size_t shmem_per_block,
std::size_t num_blocks)
{
if (shmem_per_block > max_shmem)
{
return shmem_per_block * num_blocks;
}
else
{
return 0;
}
}

}

/**
Expand All @@ -67,22 +82,6 @@ DivideAndRoundUp(NumeratorT n, DenominatorT d)
return static_cast<NumeratorT>(n / d + (n % d != 0 ? 1 : 0));
}


__host__ __device__ __forceinline__ constexpr std::size_t
VshmemSize(std::size_t max_shmem,
std::size_t shmem_per_block,
std::size_t num_blocks)
{
if (shmem_per_block > max_shmem)
{
return shmem_per_block * num_blocks;
}
else
{
return 0;
}
}

constexpr __device__ __host__ int
Nominal4BItemsToItemsCombined(int nominal_4b_items_per_thread, int combined_bytes)
{
Expand Down
Loading

0 comments on commit 8986eef

Please sign in to comment.