diff --git a/cub/agent/agent_batch_memcpy.cuh b/cub/agent/agent_batch_memcpy.cuh
index dcabae69fa..1e30adb5be 100644
--- a/cub/agent/agent_batch_memcpy.cuh
+++ b/cub/agent/agent_batch_memcpy.cuh
@@ -231,8 +231,10 @@ private:
   static constexpr uint32_t TLEV_BUFFERS_PER_THREAD = BUFFERS_PER_THREAD;
   static constexpr uint32_t BLEV_BUFFERS_PER_THREAD = BUFFERS_PER_THREAD;
 
-  static constexpr uint32_t WARP_LEVEL_THRESHOLD  = 64;
-  static constexpr uint32_t BLOCK_LEVEL_THRESHOLD = 1024;
+  static constexpr uint32_t WARP_LEVEL_THRESHOLD  = 32;
+  static constexpr uint32_t BLOCK_LEVEL_THRESHOLD = 8 * 1024;
+
+  static constexpr uint32_t BUFFER_STABLE_PARTITION = false;
 
   // Constants
   enum : uint32_t
@@ -246,6 +248,9 @@ private:
   //---------------------------------------------------------------------
   // TYPE DECLARATIONS
   //---------------------------------------------------------------------
+  /// Internal load/store type. For byte-wise memcpy, a single-byte type
+  using AliasT = uint8_t;
+
   /// Type that has to be sufficiently large to hold any of the buffers' sizes.
   /// The BufferSizeIteratorT's value type must be convertible to this type.
   using BufferSizeT = typename std::iterator_traits<BufferSizeIteratorT>::value_type;
@@ -280,6 +285,13 @@ private:
     BlockBufferOffsetT buffer_id;
   };
 
+  // Load buffers in a striped arrangement if we do not want to perform a stable partitioning into small, medium, and
+  // large buffers, otherwise load them in a blocked arrangement
+  using BufferLoadT = BlockLoad<BufferSizeT,
+                                BLOCK_THREADS,
+                                BUFFERS_PER_THREAD,
+                                BUFFER_STABLE_PARTITION ? BLOCK_LOAD_WARP_TRANSPOSE : BLOCK_LOAD_STRIPED>;
+
   // A vectorized counter that will count the number of buffers that fall into each of the size-classes. Where the size
   // class represents the collaboration level that is required to process a buffer. The collaboration level being
   // either:
@@ -296,11 +308,8 @@ private:
   // Block-level run-length decode algorithm to evenly distribute work of all buffers requiring thread-level
   // collaboration
-  using BlockRunLengthDecodeT = cub::BlockRunLengthDecode<BlockBufferOffsetT,
-                                                          BLOCK_THREADS,
-                                                          TLEV_BUFFERS_PER_THREAD,
-                                                          TLEV_BYTES_PER_THREAD>;
+  using BlockRunLengthDecodeT =
+    cub::BlockRunLengthDecode<BlockBufferOffsetT, BLOCK_THREADS, TLEV_BUFFERS_PER_THREAD, TLEV_BYTES_PER_THREAD>;
 
   using BlockExchangeTLevT = cub::BlockExchange<ZippedTLevByteAssignment, BLOCK_THREADS, TLEV_BYTES_PER_THREAD>;
@@ -314,6 +323,8 @@ private:
   {
     union
     {
+      typename BufferLoadT::TempStorage load_storage;
+
       // Stage 1: histogram over the size classes in preparation for partitioning buffers by size
       typename BlockSizeClassScanT::TempStorage size_scan_storage;
@@ -369,7 +380,7 @@ private:
   __device__ __forceinline__ void LoadBufferSizesFullTile(BufferSizeIteratorT tile_buffer_sizes_it,
                                                           BufferSizeT (&buffer_sizes)[BUFFERS_PER_THREAD])
   {
-    LoadDirectStriped<BLOCK_THREADS>(threadIdx.x, tile_buffer_sizes_it, buffer_sizes);
+    BufferLoadT(temp_storage.load_storage).Load(tile_buffer_sizes_it, buffer_sizes);
   }
 
   /**
@@ -383,11 +394,7 @@ private:
     // Out-of-bounds buffer items are initialized to '0', so those buffers will simply be ignored later on
     constexpr BufferSizeT OOB_DEFAULT_BUFFER_SIZE = 0U;
 
-    LoadDirectStriped<BLOCK_THREADS>(threadIdx.x,
-                                     tile_buffer_sizes_it,
-                                     buffer_sizes,
-                                     num_valid,
-                                     OOB_DEFAULT_BUFFER_SIZE);
+    BufferLoadT(temp_storage.load_storage).Load(tile_buffer_sizes_it, buffer_sizes, num_valid, OOB_DEFAULT_BUFFER_SIZE);
   }
 
   /**
@@ -421,7 +428,13 @@ private:
                                                             VectorizedSizeClassCounterT &vectorized_offsets,
                                                             BufferTuple (&buffers_by_size_class)[BUFFERS_PER_BLOCK])
   {
-    BlockBufferOffsetT buffer_id = threadIdx.x;
+    // If we intend to perform a stable partitioning, the thread's buffers are in a blocked arrangement,
+    // otherwise they are in a striped arrangement
+    BlockBufferOffsetT buffer_id = BUFFER_STABLE_PARTITION ? (BUFFERS_PER_THREAD * threadIdx.x) : (threadIdx.x);
+    constexpr BlockBufferOffsetT BUFFER_STRIDE = BUFFER_STABLE_PARTITION
+                                                   ? static_cast<BlockBufferOffsetT>(1)
+                                                   : static_cast<BlockBufferOffsetT>(BLOCK_THREADS);
+
 #pragma unroll
     for (uint32_t i = 0; i < BUFFERS_PER_THREAD; i++)
     {
@@ -434,7 +447,7 @@ private:
        buffers_by_size_class[write_offset] = {static_cast<TLevBufferSizeT>(buffer_sizes[i]), buffer_id};
        vectorized_offsets.Add(buffer_size_class, 1U);
      }
-      buffer_id += BLOCK_THREADS;
+      buffer_id += BUFFER_STRIDE;
     }
   }
 
@@ -619,6 +632,16 @@ private:
     // Read in the TLEV buffer partition (i.e., the buffers that require thread-level collaboration)
     uint32_t tlev_buffer_offset = threadIdx.x * TLEV_BUFFERS_PER_THREAD;
+
+    // Pre-populate the buffer sizes to 0 (i.e. zero-padding towards the end) to ensure out-of-bounds TLEV buffers will
+    // not be considered
+#pragma unroll
+    for (uint32_t i = 0; i < TLEV_BUFFERS_PER_THREAD; i++)
+    {
+      tlev_buffer_sizes[i] = 0;
+    }
+
+    // Assign TLEV buffers in a blocked arrangement (each thread is assigned consecutive TLEV buffers)
 #pragma unroll
     for (uint32_t i = 0; i < TLEV_BUFFERS_PER_THREAD; i++)
     {
@@ -627,19 +650,16 @@ private:
        tlev_buffer_ids[i] = buffers_by_size_class[tlev_buffer_offset].buffer_id;
        tlev_buffer_sizes[i] = buffers_by_size_class[tlev_buffer_offset].size;
      }
-      else
-      {
-        // Out-of-bounds buffers are assigned a size of '0'
-        tlev_buffer_sizes[i] = 0;
-      }
      tlev_buffer_offset++;
     }
 
     // Evenly distribute all the bytes that have to be copied from all the buffers that require thread-level
     // collaboration using BlockRunLengthDecode
     uint32_t num_total_tlev_bytes = 0U;
-    BlockRunLengthDecodeT block_run_length_decode(temp_storage.run_length_decode);
-    block_run_length_decode.Init(tlev_buffer_ids, tlev_buffer_sizes, num_total_tlev_bytes);
+    BlockRunLengthDecodeT block_run_length_decode(temp_storage.run_length_decode,
+                                                  tlev_buffer_ids,
+                                                  tlev_buffer_sizes,
+                                                  num_total_tlev_bytes);
 
     // Run-length decode the buffers' sizes into a window buffer of limited size. This is repeated until we were able to
     // cover all the bytes of TLEV buffers
@@ -650,7 +670,7 @@ private:
       TLevBufferSizeT buffer_byte_offset[TLEV_BYTES_PER_THREAD];
 
       // Now we have a balanced assignment: buffer_id[i] will hold the tile's buffer id and buffer_byte_offset[i] that
-      // buffer's byte that we're supposed to copy
+      // buffer's byte that this thread is supposed to copy
      block_run_length_decode.RunLengthDecode(buffer_id, buffer_byte_offset, decoded_window_offset);
 
      // Zip from SoA to AoS
@@ -671,16 +691,21 @@ private:
      if (is_full_window)
      {
        uint32_t absolute_tlev_byte_offset = decoded_window_offset + threadIdx.x;
+        AliasT src_byte[TLEV_BYTES_PER_THREAD];
 #pragma unroll
        for (int32_t i = 0; i < TLEV_BYTES_PER_THREAD; i++)
        {
-          uint8_t src_byte = reinterpret_cast<const uint8_t *>(
+          src_byte[i] = reinterpret_cast<const AliasT *>(
            tile_buffer_srcs[zipped_byte_assignment[i].tile_buffer_id])[zipped_byte_assignment[i].buffer_byte_offset];
-          reinterpret_cast<uint8_t *>(
-            tile_buffer_dsts[zipped_byte_assignment[i].tile_buffer_id])[zipped_byte_assignment[i].buffer_byte_offset] =
-            src_byte;
          absolute_tlev_byte_offset += BLOCK_THREADS;
        }
+#pragma unroll
+        for (int32_t i = 0; i < TLEV_BYTES_PER_THREAD; i++)
+        {
+          reinterpret_cast<AliasT *>(
+            tile_buffer_dsts[zipped_byte_assignment[i].tile_buffer_id])[zipped_byte_assignment[i].buffer_byte_offset] =
+            src_byte[i];
+        }
      }
      else
      {
@@ -690,9 +715,9 @@ private:
        {
          if (absolute_tlev_byte_offset < num_total_tlev_bytes)
          {
-            uint8_t src_byte = reinterpret_cast<const uint8_t *>(
+            AliasT src_byte = reinterpret_cast<const AliasT *>(
              tile_buffer_srcs[zipped_byte_assignment[i].tile_buffer_id])[zipped_byte_assignment[i].buffer_byte_offset];
-            reinterpret_cast<uint8_t *>(
+            reinterpret_cast<AliasT *>(
              tile_buffer_dsts[zipped_byte_assignment[i].tile_buffer_id])[zipped_byte_assignment[i].buffer_byte_offset] =
              src_byte;
          }
@@ -728,6 +753,9 @@ public:
      LoadBufferSizesPartialTile(tile_buffer_sizes_it, buffer_sizes, num_buffers - buffer_offset);
    }
 
+    // Ensure we can repurpose the BlockLoad's temporary storage
+    CTA_SYNC();
+
    // Count how many buffers fall into each size-class
    VectorizedSizeClassCounterT size_class_histogram = {};
    GetBufferSizeClassHistogram(buffer_sizes, size_class_histogram);
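For context on the indexing change in the partitioning step above (not part of the diff itself): with BUFFER_STABLE_PARTITION == false the buffer sizes are loaded in a striped arrangement, so a thread starts at threadIdx.x and advances by BLOCK_THREADS; with a stable partition the load is blocked, so a thread starts at BUFFERS_PER_THREAD * threadIdx.x and advances by 1. The standalone sketch below illustrates only that stride logic; the kernel name, constants, and toy sizes are hypothetical and not taken from CUB.

// Standalone illustration (not from the PR): which buffer indices a thread
// visits under the striped vs. blocked arrangement used in the diff above.
#include <cstdio>

constexpr int BLOCK_THREADS       = 4; // toy sizes for readable output
constexpr int BUFFERS_PER_THREAD  = 3;
constexpr bool STABLE_PARTITION   = false;

__global__ void ShowBufferAssignment()
{
  // Blocked arrangement: each thread owns consecutive buffers (stride 1).
  // Striped arrangement: a thread's buffers are spaced BLOCK_THREADS apart.
  int buffer_id    = static_cast<int>(STABLE_PARTITION ? (BUFFERS_PER_THREAD * threadIdx.x) : threadIdx.x);
  const int stride = STABLE_PARTITION ? 1 : BLOCK_THREADS;

  for (int i = 0; i < BUFFERS_PER_THREAD; i++)
  {
    printf("thread %d -> buffer %d\n", static_cast<int>(threadIdx.x), buffer_id);
    buffer_id += stride;
  }
}

int main()
{
  ShowBufferAssignment<<<1, BLOCK_THREADS>>>();
  cudaDeviceSynchronize();
  return 0;
}

With the toy sizes above, the striped pattern gives thread 0 the buffers {0, 4, 8}, while the blocked pattern gives {0, 1, 2}; keeping a thread's buffers contiguous is what allows the partition into small, medium, and large buffers to be stable.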