Merge pull request #516 from senior-zero/fix-main/github/warp_reduce
Fix warp reduce
gevtushenko authored Jul 26, 2022
2 parents 443e020 + 91963c8 commit 728a2a2
Showing 3 changed files with 178 additions and 20 deletions.
120 changes: 105 additions & 15 deletions cub/warp/specializations/warp_reduce_shfl.cuh
@@ -39,6 +39,9 @@
#include "../../util_type.cuh"

#include <stdint.h>
#include <type_traits>

#include <nv/target>

CUB_NAMESPACE_BEGIN

@@ -455,35 +458,122 @@ struct WarpReduceShfl
//---------------------------------------------------------------------
// Reduction operations
//---------------------------------------------------------------------

- /// Reduction
- template <
- bool ALL_LANES_VALID, ///< Whether all lanes in each warp are contributing a valid fold of items
- typename ReductionOp>
- __device__ __forceinline__ T Reduce(
+ template <typename ReductionOp>
+ __device__ __forceinline__ T ReduceImpl(
+ Int2Type<0> /* all_lanes_valid */,
T input, ///< [in] Calling thread's input
int valid_items, ///< [in] Total number of valid items across the logical warp
ReductionOp reduction_op) ///< [in] Binary reduction operator
{
- int last_lane = (ALL_LANES_VALID) ?
- LOGICAL_WARP_THREADS - 1 :
- valid_items - 1;
+ int last_lane = valid_items - 1;

T output = input;

// // Iterate reduction steps
// #pragma unroll
// for (int STEP = 0; STEP < STEPS; STEP++)
// {
// output = ReduceStep(output, reduction_op, last_lane, 1 << STEP, Int2Type<IsInteger<T>::IS_SMALL_UNSIGNED>());
// }
// Template-iterate reduction steps
ReduceStep(output, reduction_op, last_lane, Int2Type<0>());

return output;
}

template <typename ReductionOp>
__device__ __forceinline__ T ReduceImpl(
Int2Type<1> /* all_lanes_valid */,
T input, ///< [in] Calling thread's input
int /* valid_items */, ///< [in] Total number of valid items across the logical warp
ReductionOp reduction_op) ///< [in] Binary reduction operator
{
int last_lane = LOGICAL_WARP_THREADS - 1;

T output = input;

// Template-iterate reduction steps
ReduceStep(output, reduction_op, last_lane, Int2Type<0>());

return output;
}

// Warp reduce functions are not supported by nvc++ (NVBug 3694682)
#ifndef _NVHPC_CUDA
template <class U = T>
__device__ __forceinline__
typename std::enable_if<
std::is_same<int, U>::value
|| std::is_same<unsigned int, U>::value, T>::type
ReduceImpl(Int2Type<1> /* all_lanes_valid */,
T input,
int /* valid_items */,
cub::Sum /* reduction_op */)
{
T output = input;

NV_IF_TARGET(NV_PROVIDES_SM_80,
(output = __reduce_add_sync(member_mask, input);),
(output = ReduceImpl<cub::Sum>(Int2Type<1>{},
input,
LOGICAL_WARP_THREADS,
cub::Sum{});));

return output;
}

template <class U = T>
__device__ __forceinline__
typename std::enable_if<
std::is_same<int, U>::value
|| std::is_same<unsigned int, U>::value, T>::type
ReduceImpl(Int2Type<1> /* all_lanes_valid */,
T input,
int /* valid_items */,
cub::Min /* reduction_op */)
{
T output = input;

NV_IF_TARGET(NV_PROVIDES_SM_80,
(output = __reduce_min_sync(member_mask, input);),
(output = ReduceImpl<cub::Min>(Int2Type<1>{},
input,
LOGICAL_WARP_THREADS,
cub::Min{});));

return output;
}

template <class U = T>
__device__ __forceinline__
typename std::enable_if<
std::is_same<int, U>::value
|| std::is_same<unsigned int, U>::value, T>::type
ReduceImpl(Int2Type<1> /* all_lanes_valid */,
T input,
int /* valid_items */,
cub::Max /* reduction_op */)
{
T output = input;

NV_IF_TARGET(NV_PROVIDES_SM_80,
(output = __reduce_max_sync(member_mask, input);),
(output = ReduceImpl<cub::Max>(Int2Type<1>{},
input,
LOGICAL_WARP_THREADS,
cub::Max{});));

return output;
}
#endif // _NVHPC_CUDA
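
The three overloads above all follow the same shape: on SM80 and newer, a full-warp reduction of `int`/`unsigned int` can compile to a single hardware `redux` instruction, while older architectures fall back to the generic shuffle path. A minimal standalone sketch of that pattern (illustrative kernel and names, not part of this commit; assumes a full 32-lane warp):

#include <nv/target>

// Sketch: full-warp integer sum. NV_IF_TARGET selects __reduce_add_sync
// (one redux.sync instruction, SM80+) or a butterfly shuffle fallback.
__device__ __forceinline__ int warp_sum(int value)
{
  const unsigned int mask = 0xffffffffu; // all 32 lanes participate
  int result;

  NV_IF_TARGET(NV_PROVIDES_SM_80,
               (result = __reduce_add_sync(mask, value);),
               (result = value;
                // Butterfly reduction: each XOR step combines partial sums
                // across progressively smaller strides, so every lane ends
                // up with the full sum, matching redux semantics.
                for (int offset = 16; offset > 0; offset /= 2)
                {
                  result += __shfl_xor_sync(mask, result, offset);
                }));

  return result;
}

__global__ void sum_kernel(const int *in, int *out)
{
  out[threadIdx.x] = warp_sum(in[threadIdx.x]);
}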

/// Reduction
template <
bool ALL_LANES_VALID, ///< Whether all lanes in each warp are contributing a valid fold of items
typename ReductionOp>
__device__ __forceinline__ T Reduce(
T input, ///< [in] Calling thread's input
int valid_items, ///< [in] Total number of valid items across the logical warp
ReductionOp reduction_op) ///< [in] Binary reduction operator
{
return ReduceImpl(
Int2Type<ALL_LANES_VALID>{}, input, valid_items, reduction_op);
}
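
The removed runtime ternary on `ALL_LANES_VALID` becomes a compile-time choice: `Reduce` forwards an `Int2Type<ALL_LANES_VALID>` tag and overload resolution picks the matching `ReduceImpl`. Crucially, the intrinsic-backed overloads above are declared only for the `Int2Type<1>` (all-lanes-valid) tag, so partial-warp reductions never take the `redux` path. A simplified sketch of the idiom (hypothetical names, host-side C++ for brevity):

#include <cstdio>

// Sketch of the Int2Type tag-dispatch idiom (simplified).
template <int N>
struct Int2Type
{
  static constexpr int value = N;
};

constexpr int warp_threads = 32; // stand-in for LOGICAL_WARP_THREADS

// Partial-warp case: the last contributing lane depends on valid_items.
int last_lane(Int2Type<0> /* all_lanes_valid */, int valid_items)
{
  return valid_items - 1;
}

// Full-warp case: valid_items is ignored.
int last_lane(Int2Type<1> /* all_lanes_valid */, int /* valid_items */)
{
  return warp_threads - 1;
}

template <bool ALL_LANES_VALID>
int reduce_last_lane(int valid_items)
{
  // The tag turns a runtime branch into overload resolution.
  return last_lane(Int2Type<ALL_LANES_VALID>{}, valid_items);
}

int main()
{
  std::printf("%d %d\n",
              reduce_last_lane<false>(7), // prints 6
              reduce_last_lane<true>(7)); // prints 31
}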


/// Segmented reduction
template <
65 changes: 65 additions & 0 deletions cub/warp/warp_reduce.cuh
@@ -603,6 +603,71 @@ public:
//@} end member group
};

template <typename T, int LEGACY_PTX_ARCH>
class WarpReduce<T, 1, LEGACY_PTX_ARCH>
{
private:
using _TempStorage = cub::NullType;

public:
struct TempStorage : Uninitialized<_TempStorage>
{};

__device__ __forceinline__ WarpReduce(TempStorage & /*temp_storage */)
{}

__device__ __forceinline__ T Sum(T input) { return input; }

__device__ __forceinline__ T Sum(T input, int /* valid_items */)
{
return input;
}

template <typename FlagT>
__device__ __forceinline__ T HeadSegmentedSum(T input, FlagT /* head_flag */)
{
return input;
}

template <typename FlagT>
__device__ __forceinline__ T TailSegmentedSum(T input, FlagT /* tail_flag */)
{
return input;
}

template <typename ReductionOp>
__device__ __forceinline__ T Reduce(T input, ReductionOp /* reduction_op */)
{
return input;
}

template <typename ReductionOp>
__device__ __forceinline__ T Reduce(T input,
ReductionOp /* reduction_op */,
int /* valid_items */)
{
return input;
}

template <typename ReductionOp, typename FlagT>
__device__ __forceinline__ T
HeadSegmentedReduce(T input,
FlagT /* head_flag */,
ReductionOp /* reduction_op */)
{
return input;
}

template <typename ReductionOp, typename FlagT>
__device__ __forceinline__ T
TailSegmentedReduce(T input,
FlagT /* tail_flag */,
ReductionOp /* reduction_op */)
{
return input;
}
};

/** @} */ // end group WarpModule

CUB_NAMESPACE_END
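
With a logical warp of one thread, every lane is its own reduction, so each operation in this specialization returns its input unchanged with no shuffles or storage traffic. A minimal usage sketch (hypothetical kernel and buffer names; only the `cub::WarpReduce<int, 1>` instantiation itself comes from this commit):

#include <cub/warp/warp_reduce.cuh>

// Sketch: LOGICAL_WARP_THREADS == 1 makes every "warp" a single thread,
// so Sum(x) == x.
__global__ void single_lane_sums(const int *in, int *out)
{
  using WarpReduce = cub::WarpReduce<int, 1>;

  // TempStorage is an empty placeholder for this specialization.
  __shared__ typename WarpReduce::TempStorage temp_storage;

  int aggregate = WarpReduce(temp_storage).Sum(in[threadIdx.x]);
  out[threadIdx.x] = aggregate; // aggregate == in[threadIdx.x]
}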
13 changes: 8 additions & 5 deletions test/test_warp_reduce.cu
@@ -418,6 +418,8 @@ void Initialize(
RandomBits(bits, flag_entropy);
h_flags[i] = bits & 0x1;
}
h_flags[warps * warp_threads] = {};
h_tail_out[warps * warp_threads] = {};

// Accumulate segments (lane 0 of each warp is implicitly a segment head)
for (int warp = 0; warp < warps; ++warp)
@@ -483,9 +485,9 @@ void TestReduce(

// Allocate host arrays
T *h_in = new T[BLOCK_THREADS];
- int *h_flags = new int[BLOCK_THREADS];
+ int *h_flags = new int[BLOCK_THREADS + 1];
T *h_out = new T[BLOCK_THREADS];
- T *h_tail_out = new T[BLOCK_THREADS];
+ T *h_tail_out = new T[BLOCK_THREADS + 1];

// Initialize problem
Initialize(gen_mode, -1, h_in, h_flags, WARPS, LOGICAL_WARP_THREADS, valid_warp_threads, reduction_op, h_out, h_tail_out);
@@ -578,9 +580,9 @@ void TestSegmentedReduce(
// Allocate host arrays
int compare;
T *h_in = new T[BLOCK_THREADS];
- int *h_flags = new int[BLOCK_THREADS];
- T *h_head_out = new T[BLOCK_THREADS];
- T *h_tail_out = new T[BLOCK_THREADS];
+ int *h_flags = new int[BLOCK_THREADS + 1];
+ T *h_head_out = new T[BLOCK_THREADS + 1];
+ T *h_tail_out = new T[BLOCK_THREADS + 1];

// Initialize problem
Initialize(gen_mode, flag_entropy, h_in, h_flags, WARPS, LOGICAL_WARP_THREADS, LOGICAL_WARP_THREADS, reduction_op, h_head_out, h_tail_out);
@@ -817,6 +819,7 @@ int main(int argc, char** argv)
Test<16>();
Test<9>();
Test<7>();
Test<1>();

return 0;
}
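
The test-side changes are an out-of-bounds hardening: `h_flags` and the tail-output arrays gain one trailing slot, zero-initialized in `Initialize`, so a look-ahead past the last lane in the reference computation reads a defined sentinel instead of unallocated memory. The pattern in isolation (illustrative names, not the test code itself):

#include <vector>

// Sketch: sentinel-terminated flag array. The extra zeroed element keeps a
// look-ahead such as flags[i + 1] at i == num_items - 1 in bounds, reading
// "no segment boundary here".
std::vector<int> make_flags(int num_items)
{
  std::vector<int> flags(num_items + 1, 0); // includes the sentinel slot
  // ... fill flags[0 .. num_items - 1] with 0/1 segment markers ...
  return flags;
}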
