Skip to content
This repository has been archived by the owner on Mar 21, 2024. It is now read-only.

Fix block shuffle #311

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/CubCudaConfig.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ enable_language(CUDA)
# Architecture options:
#

set(all_archs 35 37 50 52 53 60 61 62 70 72 75 80)
set(all_archs 35 37 50 52 53 60 61 62 70 72 75 80 86)
alliepiper marked this conversation as resolved.
Show resolved Hide resolved
set(arch_message "CUB: Explicitly enabled compute architectures:")

# Thrust sets up the architecture flags in CMAKE_CUDA_FLAGS already. Just
Expand Down
35 changes: 15 additions & 20 deletions cub/block/block_shuffle.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,7 @@ private:
******************************************************************************/

/// Shared memory storage layout type (last element from each thread's input)
struct _TempStorage
{
T prev[BLOCK_THREADS];
T next[BLOCK_THREADS];
};
typedef T _TempStorage[BLOCK_THREADS];


public:
Expand Down Expand Up @@ -171,14 +167,14 @@ public:
T& output, ///< [out] The \p input item from the successor (or predecessor) thread <em>thread</em><sub><em>i</em>+<tt>distance</tt></sub> (may be aliased to \p input). This value is only updated for for <em>thread<sub>i</sub></em> when 0 <= (<em>i</em> + \p distance) < <tt>BLOCK_THREADS-1</tt>
int distance = 1) ///< [in] Offset distance (may be negative)
{
temp_storage[linear_tid].prev = input;
temp_storage[linear_tid] = input;

CTA_SYNC();

const int offset_tid = static_cast<int>(linear_tid) + distance;
if ((offset_tid >= 0) && (offset_tid < BLOCK_THREADS))
{
output = temp_storage[static_cast<size_t>(offset_tid)].prev;
output = temp_storage[static_cast<size_t>(offset_tid)];
}
}

Expand All @@ -194,15 +190,15 @@ public:
T& output, ///< [out] The \p input item from thread <em>thread</em><sub>(<em>i</em>+<tt>distance></tt>)%<tt><BLOCK_THREADS></tt></sub> (may be aliased to \p input). This value is not updated for <em>thread</em><sub>BLOCK_THREADS-1</sub>
unsigned int distance = 1) ///< [in] Offset distance (0 < \p distance < <tt>BLOCK_THREADS</tt>)
{
temp_storage[linear_tid].prev = input;
temp_storage[linear_tid] = input;

CTA_SYNC();

unsigned int offset = threadIdx.x + distance;
if (offset >= BLOCK_THREADS)
offset -= BLOCK_THREADS;

output = temp_storage[offset].prev;
output = temp_storage[offset];
}


Expand All @@ -219,17 +215,16 @@ public:
T (&input)[ITEMS_PER_THREAD], ///< [in] The calling thread's input items
T (&prev)[ITEMS_PER_THREAD]) ///< [out] The corresponding predecessor items (may be aliased to \p input). The item \p prev[0] is not updated for <em>thread</em><sub>0</sub>.
{
temp_storage[linear_tid].prev = input[ITEMS_PER_THREAD - 1];
temp_storage[linear_tid] = input[ITEMS_PER_THREAD - 1];

CTA_SYNC();

#pragma unroll
for (int ITEM = ITEMS_PER_THREAD - 1; ITEM > 0; --ITEM)
prev[ITEM] = input[ITEM - 1];


if (linear_tid > 0)
prev[0] = temp_storage[linear_tid - 1].prev;
prev[0] = temp_storage[linear_tid - 1];
}


Expand All @@ -248,7 +243,7 @@ public:
T &block_suffix) ///< [out] The item \p input[ITEMS_PER_THREAD-1] from <em>thread</em><sub><tt>BLOCK_THREADS-1</tt></sub>, provided to all threads
{
Up(input, prev);
block_suffix = temp_storage[BLOCK_THREADS - 1].prev;
block_suffix = temp_storage[BLOCK_THREADS - 1];
}


Expand All @@ -265,16 +260,16 @@ public:
T (&input)[ITEMS_PER_THREAD], ///< [in] The calling thread's input items
T (&prev)[ITEMS_PER_THREAD]) ///< [out] The corresponding predecessor items (may be aliased to \p input). The value \p prev[0] is not updated for <em>thread</em><sub>BLOCK_THREADS-1</sub>.
{
temp_storage[linear_tid].prev = input[ITEMS_PER_THREAD - 1];
temp_storage[linear_tid] = input[0];

CTA_SYNC();

#pragma unroll
for (int ITEM = ITEMS_PER_THREAD - 1; ITEM > 0; --ITEM)
prev[ITEM] = input[ITEM - 1];
for (int ITEM = 0; ITEM < ITEMS_PER_THREAD - 1; ITEM++)
prev[ITEM] = input[ITEM + 1];

if (linear_tid > 0)
prev[0] = temp_storage[linear_tid - 1].prev;
if (linear_tid < BLOCK_THREADS - 1)
prev[ITEMS_PER_THREAD - 1] = temp_storage[linear_tid + 1];
}


Expand All @@ -292,8 +287,8 @@ public:
T (&prev)[ITEMS_PER_THREAD], ///< [out] The corresponding predecessor items (may be aliased to \p input). The value \p prev[0] is not updated for <em>thread</em><sub>BLOCK_THREADS-1</sub>.
T &block_prefix) ///< [out] The item \p input[0] from <em>thread</em><sub><tt>0</tt></sub>, provided to all threads
{
Up(input, prev);
block_prefix = temp_storage[BLOCK_THREADS - 1].prev;
Down(input, prev);
block_prefix = temp_storage[0];
}

//@} end member group
Expand Down
Loading