Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Commit

Permalink
Fix block shuffle
Browse files Browse the repository at this point in the history
  • Loading branch information
gevtushenko committed May 26, 2021
1 parent ad5299d commit dcf0671
Show file tree
Hide file tree
Showing 3 changed files with 374 additions and 21 deletions.
2 changes: 1 addition & 1 deletion cmake/CubCudaConfig.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ enable_language(CUDA)
# Architecture options:
#

set(all_archs 35 37 50 52 53 60 61 62 70 72 75 80)
set(all_archs 35 37 50 52 53 60 61 62 70 72 75 80 86)
set(arch_message "CUB: Explicitly enabled compute architectures:")

# Thrust sets up the architecture flags in CMAKE_CUDA_FLAGS already. Just
Expand Down
35 changes: 15 additions & 20 deletions cub/block/block_shuffle.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,7 @@ private:
******************************************************************************/

/// Shared memory storage layout type (last element from each thread's input)
struct _TempStorage
{
T prev[BLOCK_THREADS];
T next[BLOCK_THREADS];
};
typedef T _TempStorage[BLOCK_THREADS];


public:
Expand Down Expand Up @@ -171,14 +167,14 @@ public:
T& output, ///< [out] The \p input item from the successor (or predecessor) thread <em>thread</em><sub><em>i</em>+<tt>distance</tt></sub> (may be aliased to \p input). This value is only updated for <em>thread<sub>i</sub></em> when 0 <= (<em>i</em> + \p distance) <= <tt>BLOCK_THREADS-1</tt>
int distance = 1) ///< [in] Offset distance (may be negative)
{
temp_storage[linear_tid].prev = input;
temp_storage[linear_tid] = input;

CTA_SYNC();

const int offset_tid = static_cast<int>(linear_tid) + distance;
if ((offset_tid >= 0) && (offset_tid < BLOCK_THREADS))
{
output = temp_storage[static_cast<size_t>(offset_tid)].prev;
output = temp_storage[static_cast<size_t>(offset_tid)];
}
}

Expand All @@ -194,15 +190,15 @@ public:
T& output, ///< [out] The \p input item from thread <em>thread</em><sub>(<em>i</em>+<tt>distance</tt>)%<tt>BLOCK_THREADS</tt></sub> (may be aliased to \p input). This value is updated for every thread.
unsigned int distance = 1) ///< [in] Offset distance (0 < \p distance < <tt>BLOCK_THREADS</tt>)
{
temp_storage[linear_tid].prev = input;
temp_storage[linear_tid] = input;

CTA_SYNC();

unsigned int offset = threadIdx.x + distance;
if (offset >= BLOCK_THREADS)
offset -= BLOCK_THREADS;

output = temp_storage[offset].prev;
output = temp_storage[offset];
}


Expand All @@ -219,17 +215,16 @@ public:
T (&input)[ITEMS_PER_THREAD], ///< [in] The calling thread's input items
T (&prev)[ITEMS_PER_THREAD]) ///< [out] The corresponding predecessor items (may be aliased to \p input). The item \p prev[0] is not updated for <em>thread</em><sub>0</sub>.
{
temp_storage[linear_tid].prev = input[ITEMS_PER_THREAD - 1];
temp_storage[linear_tid] = input[ITEMS_PER_THREAD - 1];

CTA_SYNC();

#pragma unroll
for (int ITEM = ITEMS_PER_THREAD - 1; ITEM > 0; --ITEM)
prev[ITEM] = input[ITEM - 1];


if (linear_tid > 0)
prev[0] = temp_storage[linear_tid - 1].prev;
prev[0] = temp_storage[linear_tid - 1];
}


Expand All @@ -248,7 +243,7 @@ public:
T &block_suffix) ///< [out] The item \p input[ITEMS_PER_THREAD-1] from <em>thread</em><sub><tt>BLOCK_THREADS-1</tt></sub>, provided to all threads
{
Up(input, prev);
block_suffix = temp_storage[BLOCK_THREADS - 1].prev;
block_suffix = temp_storage[BLOCK_THREADS - 1];
}


Expand All @@ -265,16 +260,16 @@ public:
T (&input)[ITEMS_PER_THREAD], ///< [in] The calling thread's input items
T (&prev)[ITEMS_PER_THREAD]) ///< [out] The corresponding successor items (may be aliased to \p input). The value \p prev[ITEMS_PER_THREAD-1] is not updated for <em>thread</em><sub>BLOCK_THREADS-1</sub>.
{
temp_storage[linear_tid].prev = input[ITEMS_PER_THREAD - 1];
temp_storage[linear_tid] = input[0];

CTA_SYNC();

#pragma unroll
for (int ITEM = ITEMS_PER_THREAD - 1; ITEM > 0; --ITEM)
prev[ITEM] = input[ITEM - 1];
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD - 1; ITEM++)
prev[ITEM] = input[ITEM + 1];

if (linear_tid > 0)
prev[0] = temp_storage[linear_tid - 1].prev;
if (linear_tid < BLOCK_THREADS - 1)
prev[ITEMS_PER_THREAD - 1] = temp_storage[linear_tid + 1];
}


Expand All @@ -292,8 +287,8 @@ public:
T (&prev)[ITEMS_PER_THREAD], ///< [out] The corresponding successor items (may be aliased to \p input). The value \p prev[ITEMS_PER_THREAD-1] is not updated for <em>thread</em><sub>BLOCK_THREADS-1</sub>.
T &block_prefix) ///< [out] The item \p input[0] from <em>thread</em><sub><tt>0</tt></sub>, provided to all threads
{
Up(input, prev);
block_prefix = temp_storage[BLOCK_THREADS - 1].prev;
Down(input, prev);
block_prefix = temp_storage[0];
}

//@} end member group
Expand Down
Loading

0 comments on commit dcf0671

Please sign in to comment.